diff --git a/src/model_training.py b/src/model_training.py index 169a52e..06bf372 100644 --- a/src/model_training.py +++ b/src/model_training.py @@ -133,6 +133,41 @@ def plot_feature_importance(model, feature_names): else: importances = model.named_steps['classifier'].feature_importances_ + # Get the transformed feature names from the pipeline + if hasattr(model, 'named_steps'): + # For pipeline models, get the feature names from the preprocessor + preprocessor = model.named_steps['preprocessor'] + # Get the transformed feature names + transformed_features = [] + + # Handle numerical features (they keep their names) + numerical_features = preprocessor.transformers_[0][2] # Numerical features list + transformed_features.extend(numerical_features) + + # Handle categorical features (they get expanded with one-hot encoding) + categorical_features = preprocessor.transformers_[1][2] # Categorical features list + categorical_transformer = preprocessor.transformers_[1][1] # OneHotEncoder + if hasattr(categorical_transformer, 'get_feature_names_out'): + # For newer scikit-learn versions + cat_feature_names = categorical_transformer.get_feature_names_out(categorical_features) + else: + # For older scikit-learn versions + cat_feature_names = categorical_transformer.named_steps['onehot'].get_feature_names(categorical_features) + transformed_features.extend(cat_feature_names) + + # Handle binary features (they pass through) + binary_features = preprocessor.transformers_[2][2] # Binary features list + transformed_features.extend(binary_features) + + # Use the transformed feature names + feature_names = transformed_features + + # Make sure the lengths match + if len(feature_names) != len(importances): + print(f"Warning: Feature names length ({len(feature_names)}) doesn't match importances length ({len(importances)})") + # Use generic feature names if lengths don't match + feature_names = [f'Feature {i}' for i in range(len(importances))] + # Create a DataFrame for visualization feature_importance = pd.DataFrame({ 'Feature': feature_names,