run tests on assessments predictions pipeline
This commit is contained in:
@@ -0,0 +1,45 @@
|
||||
import os
|
||||
import logging
|
||||
from logging.handlers import RotatingFileHandler
|
||||
from sklearn.ensemble import RandomForestRegressor
|
||||
from sklearn.multioutput import MultiOutputRegressor
|
||||
from src.pipeline.data_preprocessor import DataPreprocessor
|
||||
from src.pipeline.model_trainer import ModelTrainer
|
||||
|
||||
# Module logger: INFO and above go to a size-rotated log file
# (rotates at 100000 bytes, keeping 3 backups).
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

handler = RotatingFileHandler(
    '/root/ds_erp_ai/logs/prediction_pipeline.log',
    maxBytes=100000,
    backupCount=3,
)
logger.addHandler(handler)
|
||||
|
||||
# Pipeline that chains the DataPreprocessor and ModelTrainer steps for each company
|
||||
class CompanyModelPipeline:
    """Run data preprocessing and model training for a set of companies.

    For every company id the pipeline builds the path
    ``<input_base_path>/<company_id>_raw_data.csv``, passes it to
    ``DataPreprocessor``, and then trains a multi-output random-forest
    regressor on the preprocessor's output via ``ModelTrainer``.
    """

    def __init__(self, company_ids, input_base_path):
        """Store the pipeline configuration.

        Args:
            company_ids: Iterable of company identifiers to process.
            input_base_path: Directory that holds the per-company
                ``<company_id>_raw_data.csv`` files.
        """
        self.company_ids = company_ids
        self.input_base_path = input_base_path

    def run_pipeline(self) -> None:
        """Preprocess data and train a model for each configured company.

        Companies are processed in iteration order. A failure while handling
        one company is logged together with its traceback and does not stop
        the remaining companies from being processed.
        """
        for company_id in self.company_ids:
            try:
                # Define paths for the company
                input_path = os.path.join(self.input_base_path, f'{company_id}_raw_data.csv')

                logger.info(f"Starting preprocessing for company {company_id}.")

                # Step 1: Preprocess the data. The value returned by run() is
                # used below as the path of the preprocessed data.
                preprocessor = DataPreprocessor(input_path=input_path, company_id=company_id)
                processed_data_path = preprocessor.run()
                logger.info(f"Data preprocessing completed for company {company_id}. Processed data saved to {processed_data_path}.")

                # Step 2: Train and save the model. A fresh estimator is built
                # per company so no fitted state is shared between companies.
                model = MultiOutputRegressor(RandomForestRegressor(n_estimators=100, random_state=42))
                trainer = ModelTrainer(preprocessed_data_path=processed_data_path, company_id=company_id, model=model)
                model_path, latest_data_path, evaluation_results = trainer.run()

                logger.info(f"Model training and evaluation completed for company {company_id}.")
                logger.info(f"Model saved to {model_path} and latest data saved to {latest_data_path}.")
                logger.info(f"Evaluation Results for company {company_id}: {evaluation_results}")

            except Exception as e:
                # Per-company boundary: record the full traceback (not just
                # the message) and continue with the next company.
                logger.exception(f"An error occurred while processing company {company_id}: {e}")
|
||||
|
||||
Reference in New Issue
Block a user