From 641fc85209af70350549537897bdcbda4cd7d216 Mon Sep 17 00:00:00 2001 From: Michael Ikehi Date: Wed, 23 Apr 2025 22:45:54 +0100 Subject: [PATCH] Initial project setup for fraud detection system - Define project structure with data, experiments, models, and src directories - Outline key tasks: EDA, feature engineering, model training, API and UI development - Document dataset features and project requirements - Create comprehensive README with implementation roadmap --- src/model_training.py | 35 +++++++++++++++++++++++++++++++++++ 1 file changed, 35 insertions(+) diff --git a/src/model_training.py b/src/model_training.py index 169a52e..06bf372 100644 --- a/src/model_training.py +++ b/src/model_training.py @@ -133,6 +133,41 @@ def plot_feature_importance(model, feature_names): else: importances = model.named_steps['classifier'].feature_importances_ + # Get the transformed feature names from the pipeline + if hasattr(model, 'named_steps'): + # For pipeline models, get the feature names from the preprocessor + preprocessor = model.named_steps['preprocessor'] + # Get the transformed feature names + transformed_features = [] + + # Handle numerical features (they keep their names) + numerical_features = preprocessor.transformers_[0][2] # Numerical features list + transformed_features.extend(numerical_features) + + # Handle categorical features (they get expanded with one-hot encoding) + categorical_features = preprocessor.transformers_[1][2] # Categorical features list + categorical_transformer = preprocessor.transformers_[1][1] # OneHotEncoder + if hasattr(categorical_transformer, 'get_feature_names_out'): + # For newer scikit-learn versions + cat_feature_names = categorical_transformer.get_feature_names_out(categorical_features) + else: + # For older scikit-learn versions + cat_feature_names = categorical_transformer.named_steps['onehot'].get_feature_names(categorical_features) + transformed_features.extend(cat_feature_names) + + # Handle binary features (they pass through) + binary_features = preprocessor.transformers_[2][2] # Binary features list + transformed_features.extend(binary_features) + + # Use the transformed feature names + feature_names = transformed_features + + # Make sure the lengths match + if len(feature_names) != len(importances): + print(f"Warning: Feature names length ({len(feature_names)}) doesn't match importances length ({len(importances)})") + # Use generic feature names if lengths don't match + feature_names = [f'Feature {i}' for i in range(len(importances))] + # Create a DataFrame for visualization feature_importance = pd.DataFrame({ 'Feature': feature_names,