2024-09-12 00:01:03 +00:00
|
|
|
import os
|
|
|
|
|
import logging
|
|
|
|
|
from logging.handlers import RotatingFileHandler
|
|
|
|
|
from sklearn.ensemble import RandomForestRegressor
|
|
|
|
|
from sklearn.multioutput import MultiOutputRegressor
|
|
|
|
|
from src.pipeline.data_preprocessor import DataPreprocessor
|
|
|
|
|
from src.pipeline.model_trainer import ModelTrainer
|
|
|
|
|
|
|
|
|
|
# Set up logging: a size-capped rotating file log for the prediction pipeline.
_LOG_FILE = '/root/ds_erp_ai/logs/prediction_pipeline.log'

# RotatingFileHandler opens its file as soon as it is constructed and raises
# FileNotFoundError when the parent directory is missing, which would make
# this module unimportable on a fresh machine. Create the directory first.
os.makedirs(os.path.dirname(_LOG_FILE), exist_ok=True)

# Roll over once the file reaches maxBytes, keeping the three newest backups.
handler = RotatingFileHandler(_LOG_FILE, maxBytes=100000, backupCount=3)
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
logger.addHandler(handler)
|
|
|
|
|
|
|
|
|
|
# Pipeline that chains the imported DataPreprocessor and ModelTrainer for each company
|
|
|
|
|
class CompanyModelPipeline:
    """Run data preprocessing and model training for a set of companies.

    For each company id the pipeline reads
    ``<input_base_path>/<company_id>_raw_data.csv``, runs ``DataPreprocessor``
    on it, and passes the preprocessor's result to ``ModelTrainer`` together
    with a fresh multi-output random-forest regressor. A failure for one
    company is logged and does not prevent the remaining companies from
    being processed.
    """

    def __init__(self, company_ids, input_base_path):
        """Store the pipeline configuration.

        Args:
            company_ids: Iterable of company identifiers to process; each one
                is interpolated into the raw-data file name.
            input_base_path: Directory containing the
                ``<company_id>_raw_data.csv`` input files.
        """
        self.company_ids = company_ids
        self.input_base_path = input_base_path

    def run_pipeline(self):
        """Process every configured company in order.

        Any exception raised while handling a company is caught here so one
        bad data set cannot abort the whole batch. The failure is logged at
        ERROR level with the full traceback.

        Returns:
            None.
        """
        for company_id in self.company_ids:
            try:
                self._process_company(company_id)
            except Exception as e:
                # logger.exception keeps the original message and level but
                # also records the traceback, which a plain logger.error
                # with only str(e) would discard.
                logger.exception(
                    "An error occurred while processing company %s: %s",
                    company_id,
                    e,
                )

    def _process_company(self, company_id):
        """Preprocess the raw data and train/evaluate a model for one company."""
        # Define paths for the company
        input_path = os.path.join(
            self.input_base_path, f'{company_id}_raw_data.csv'
        )

        logger.info("Starting preprocessing for company %s.", company_id)

        # Step 1: Preprocess the data
        preprocessor = DataPreprocessor(
            input_path=input_path, company_id=company_id
        )
        # The return value of run() is treated as the processed-data location.
        processed_data_path = preprocessor.run()
        logger.info(
            "Data preprocessing completed for company %s. "
            "Processed data saved to %s.",
            company_id,
            processed_data_path,
        )

        # Step 2: Train and save the model
        # A new estimator is built per company so no fitted state is shared.
        model = MultiOutputRegressor(
            RandomForestRegressor(n_estimators=100, random_state=42)
        )
        trainer = ModelTrainer(
            preprocessed_data_path=processed_data_path,
            company_id=company_id,
            model=model,
        )
        model_path, latest_data_path, evaluation_results = trainer.run()

        logger.info(
            "Model training and evaluation completed for company %s.",
            company_id,
        )
        logger.info(
            "Model saved to %s and latest data saved to %s.",
            model_path,
            latest_data_path,
        )
        logger.info(
            "Evaluation Results for company %s: %s",
            company_id,
            evaluation_results,
        )
|
|
|
|
|
|