first commit
This commit is contained in:
+131
@@ -0,0 +1,131 @@
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import List
|
||||
import os
|
||||
import sys
|
||||
import uvicorn
|
||||
import json
|
||||
|
||||
# Add the project root to the path so we can import from src
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
|
||||
from src import config
|
||||
from src.predict import load_model, predict_fraud, predict_batch
|
||||
|
||||
# Initialize FastAPI app
|
||||
app = FastAPI(
|
||||
title="Fraud Detection API",
|
||||
description="API for detecting fraudulent transactions",
|
||||
version="1.0.0"
|
||||
)
|
||||
|
||||
# Load the model at startup
|
||||
model = None
|
||||
|
||||
|
||||
@app.on_event("startup")
|
||||
async def startup_event():
|
||||
global model
|
||||
model = load_model()
|
||||
if model is None:
|
||||
print("Warning: Model could not be loaded. API will not function correctly.")
|
||||
|
||||
|
||||
# Define request and response models
|
||||
class TransactionRequest(BaseModel):
|
||||
trans_date_trans_time: str = Field(..., description="Transaction timestamp")
|
||||
cc_num: str = Field(..., description="Credit card number")
|
||||
merchant: str = Field(..., description="Merchant name")
|
||||
category: str = Field(..., description="Transaction category")
|
||||
amt: float = Field(..., description="Transaction amount")
|
||||
first: str = Field(..., description="Cardholder first name")
|
||||
last: str = Field(..., description="Cardholder last name")
|
||||
gender: str = Field(..., description="Cardholder gender")
|
||||
street: str = Field(..., description="Cardholder street address")
|
||||
city: str = Field(..., description="Cardholder city")
|
||||
state: str = Field(..., description="Cardholder state")
|
||||
zip: str = Field(..., description="Cardholder ZIP code")
|
||||
lat: float = Field(..., description="Cardholder latitude")
|
||||
long: float = Field(..., description="Cardholder longitude")
|
||||
city_pop: int = Field(..., description="City population")
|
||||
job: str = Field(..., description="Cardholder job")
|
||||
dob: str = Field(..., description="Cardholder date of birth")
|
||||
trans_num: str = Field(..., description="Transaction number")
|
||||
unix_time: int = Field(..., description="Unix timestamp")
|
||||
merch_lat: float = Field(..., description="Merchant latitude")
|
||||
merch_long: float = Field(..., description="Merchant longitude")
|
||||
|
||||
|
||||
class PredictionResponse(BaseModel):
|
||||
is_fraud: bool = Field(..., description="Fraud prediction (True/False)")
|
||||
fraud_probability: float = Field(..., description="Probability of fraud")
|
||||
risk_level: str = Field(..., description="Risk level (low/medium/high)")
|
||||
|
||||
|
||||
class BatchPredictionRequest(BaseModel):
|
||||
transactions: List[TransactionRequest] = Field(..., description="List of transactions")
|
||||
|
||||
|
||||
class BatchPredictionResponse(BaseModel):
|
||||
predictions: List[PredictionResponse] = Field(..., description="List of predictions")
|
||||
|
||||
|
||||
@app.get("/")
|
||||
async def root():
|
||||
return {"message": "Welcome to the Fraud Detection API"}
|
||||
|
||||
|
||||
@app.post("/predict", response_model=PredictionResponse)
|
||||
async def predict(transaction: TransactionRequest):
|
||||
if model is None:
|
||||
raise HTTPException(status_code=503, detail="Model not loaded")
|
||||
|
||||
# Convert Pydantic model to dictionary
|
||||
transaction_dict = transaction.model_dump()
|
||||
|
||||
# Make prediction
|
||||
result = predict_fraud(model, transaction_dict)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@app.post("/predict/batch", response_model=BatchPredictionResponse)
|
||||
async def predict_multiple(request: BatchPredictionRequest):
|
||||
if model is None:
|
||||
raise HTTPException(status_code=503, detail="Model not loaded")
|
||||
|
||||
# Convert Pydantic models to dictionaries
|
||||
transactions_dict = [transaction.model_dump() for transaction in request.transactions]
|
||||
|
||||
# Make predictions
|
||||
results = predict_batch(model, transactions_dict)
|
||||
|
||||
return {"predictions": results}
|
||||
|
||||
|
||||
@app.get("/health")
|
||||
async def health_check():
|
||||
return {"status": "healthy", "model_loaded": model is not None}
|
||||
|
||||
|
||||
@app.get("/model-info")
|
||||
async def model_info():
|
||||
try:
|
||||
with open(config.MODEL_METADATA_PATH, 'r') as f:
|
||||
metadata = json.load(f)
|
||||
return metadata
|
||||
except FileNotFoundError:
|
||||
raise HTTPException(status_code=404, detail="Model metadata not found")
|
||||
|
||||
|
||||
def main():
|
||||
"""Run the API server"""
|
||||
uvicorn.run(
|
||||
"src.api.app:app",
|
||||
host=config.API_HOST,
|
||||
port=config.API_PORT,
|
||||
reload=True
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,64 @@
|
||||
import os
|
||||
import sys
|
||||
import joblib
|
||||
|
||||
# Add the project root to the path so we can import from src
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
|
||||
from src import config
|
||||
from src.predict import predict_fraud, predict_batch
|
||||
|
||||
|
||||
class FraudDetectionModel:
|
||||
"""
|
||||
Class to handle model loading and inference
|
||||
"""
|
||||
def __init__(self):
|
||||
self.model = None
|
||||
self.load_model()
|
||||
|
||||
def load_model(self):
|
||||
"""
|
||||
Load the trained model
|
||||
"""
|
||||
try:
|
||||
self.model = joblib.load(config.MODEL_PATH)
|
||||
print(f"Model loaded successfully from {config.MODEL_PATH}")
|
||||
return True
|
||||
except FileNotFoundError:
|
||||
print(f"Model file not found at {config.MODEL_PATH}")
|
||||
return False
|
||||
except Exception as e:
|
||||
print(f"Error loading model: {str(e)}")
|
||||
return False
|
||||
|
||||
def predict(self, transaction_data):
|
||||
"""
|
||||
Predict fraud for a single transaction
|
||||
"""
|
||||
if self.model is None:
|
||||
raise ValueError("Model not loaded")
|
||||
|
||||
return predict_fraud(self.model, transaction_data)
|
||||
|
||||
def predict_batch(self, transactions_data):
|
||||
"""
|
||||
Predict fraud for multiple transactions
|
||||
"""
|
||||
if self.model is None:
|
||||
raise ValueError("Model not loaded")
|
||||
|
||||
return predict_batch(self.model, transactions_data)
|
||||
|
||||
|
||||
# Singleton instance
|
||||
model_instance = None
|
||||
|
||||
|
||||
def get_model_instance():
|
||||
"""
|
||||
Get or create the model instance
|
||||
"""
|
||||
global model_instance
|
||||
if model_instance is None:
|
||||
model_instance = FraudDetectionModel()
|
||||
return model_instance
|
||||
@@ -0,0 +1,29 @@
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
# Base directories
|
||||
BASE_DIR = Path(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
DATA_DIR = BASE_DIR / 'data'
|
||||
RAW_DATA_DIR = DATA_DIR / 'raw'
|
||||
PROCESSED_DATA_DIR = DATA_DIR / 'processed'
|
||||
MODELS_DIR = BASE_DIR / 'models'
|
||||
|
||||
# Data files
|
||||
TRAIN_DATA_PATH = RAW_DATA_DIR / 'fraudTrain.csv'
|
||||
TEST_DATA_PATH = RAW_DATA_DIR / 'fraudTest.csv'
|
||||
|
||||
# Processed data files
|
||||
PROCESSED_TRAIN_DATA_PATH = PROCESSED_DATA_DIR / 'processed_train.csv'
|
||||
PROCESSED_TEST_DATA_PATH = PROCESSED_DATA_DIR / 'processed_test.csv'
|
||||
|
||||
# Model files
|
||||
MODEL_PATH = MODELS_DIR / 'fraud_model.pkl'
|
||||
MODEL_METADATA_PATH = MODELS_DIR / 'model_metadata.json'
|
||||
|
||||
# API settings
|
||||
API_HOST = '0.0.0.0'
|
||||
API_PORT = 8000
|
||||
|
||||
# Web UI settings
|
||||
WEB_HOST = '0.0.0.0'
|
||||
WEB_PORT = 8501
|
||||
@@ -0,0 +1,185 @@
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from sklearn.preprocessing import StandardScaler, OneHotEncoder
|
||||
from sklearn.compose import ColumnTransformer
|
||||
from sklearn.pipeline import Pipeline
|
||||
# datetime is used implicitly when working with pandas datetime objects
|
||||
from geopy.distance import geodesic
|
||||
import os
|
||||
import sys
|
||||
|
||||
# Add the project root to the path so we can import from src
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
from src import config
|
||||
|
||||
|
||||
def load_data(file_path):
|
||||
"""
|
||||
Load data from CSV file
|
||||
"""
|
||||
return pd.read_csv(file_path)
|
||||
|
||||
|
||||
def calculate_distance(row):
|
||||
"""
|
||||
Calculate the distance between the cardholder and merchant in kilometers
|
||||
"""
|
||||
try:
|
||||
cardholder_coords = (row['lat'], row['long'])
|
||||
merchant_coords = (row['merch_lat'], row['merch_long'])
|
||||
return geodesic(cardholder_coords, merchant_coords).kilometers
|
||||
except:
|
||||
return np.nan
|
||||
|
||||
|
||||
def extract_time_features(df):
|
||||
"""
|
||||
Extract time-based features from the transaction timestamp
|
||||
"""
|
||||
# Convert to datetime
|
||||
df['trans_date_trans_time'] = pd.to_datetime(df['trans_date_trans_time'])
|
||||
|
||||
# Extract features
|
||||
df['hour'] = df['trans_date_trans_time'].dt.hour
|
||||
df['day'] = df['trans_date_trans_time'].dt.day
|
||||
df['weekday'] = df['trans_date_trans_time'].dt.weekday
|
||||
df['month'] = df['trans_date_trans_time'].dt.month
|
||||
df['year'] = df['trans_date_trans_time'].dt.year
|
||||
|
||||
# Create is_weekend feature
|
||||
df['is_weekend'] = df['weekday'].apply(lambda x: 1 if x >= 5 else 0)
|
||||
|
||||
# Create time of day categories
|
||||
df['time_of_day'] = df['hour'].apply(lambda x:
|
||||
'night' if 0 <= x < 6 else
|
||||
'morning' if 6 <= x < 12 else
|
||||
'afternoon' if 12 <= x < 18 else
|
||||
'evening')
|
||||
|
||||
return df
|
||||
|
||||
|
||||
def calculate_age(df):
|
||||
"""
|
||||
Calculate age of the cardholder based on date of birth
|
||||
"""
|
||||
# Convert to datetime
|
||||
df['dob'] = pd.to_datetime(df['dob'])
|
||||
|
||||
# Calculate age at the time of transaction
|
||||
df['age'] = df.apply(lambda row: (row['trans_date_trans_time'].year - row['dob'].year) -
|
||||
((row['trans_date_trans_time'].month, row['trans_date_trans_time'].day) <
|
||||
(row['dob'].month, row['dob'].day)), axis=1)
|
||||
|
||||
return df
|
||||
|
||||
|
||||
def preprocess_data(df, is_training=True):
|
||||
"""
|
||||
Preprocess the data for model training or prediction
|
||||
"""
|
||||
# Make a copy to avoid modifying the original dataframe
|
||||
df_processed = df.copy()
|
||||
|
||||
# Handle missing values
|
||||
for col in df_processed.columns:
|
||||
if df_processed[col].dtype == 'object':
|
||||
# Fix for FutureWarning - avoid chained assignment with inplace=True
|
||||
df_processed[col] = df_processed[col].fillna('unknown')
|
||||
else:
|
||||
# Fix for FutureWarning - avoid chained assignment with inplace=True
|
||||
df_processed[col] = df_processed[col].fillna(df_processed[col].median())
|
||||
|
||||
# Extract time features
|
||||
df_processed = extract_time_features(df_processed)
|
||||
|
||||
# Calculate age
|
||||
df_processed = calculate_age(df_processed)
|
||||
|
||||
# Calculate distance between cardholder and merchant
|
||||
df_processed['distance_km'] = df_processed.apply(calculate_distance, axis=1)
|
||||
|
||||
# Create feature for transaction amount relative to average for that category
|
||||
if is_training:
|
||||
category_avg = df_processed.groupby('category')['amt'].mean().to_dict()
|
||||
else:
|
||||
# Load the category averages from the training data
|
||||
# This would be stored during training
|
||||
category_avg = pd.read_csv(config.PROCESSED_DATA_DIR / 'category_avg.csv').set_index('category')['amt'].to_dict()
|
||||
|
||||
df_processed['amt_to_category_avg'] = df_processed.apply(
|
||||
lambda row: row['amt'] / category_avg.get(row['category'], 1), axis=1)
|
||||
|
||||
# Select features for model
|
||||
feature_cols = [
|
||||
'amt', 'distance_km', 'age', 'hour', 'day', 'weekday', 'month',
|
||||
'is_weekend', 'amt_to_category_avg', 'city_pop', 'category', 'time_of_day'
|
||||
]
|
||||
|
||||
# For training data, save the category averages
|
||||
if is_training:
|
||||
pd.DataFrame(list(category_avg.items()), columns=['category', 'amt']).to_csv(
|
||||
config.PROCESSED_DATA_DIR / 'category_avg.csv', index=False)
|
||||
|
||||
# Return the processed data with selected features
|
||||
return df_processed[feature_cols + (['is_fraud'] if 'is_fraud' in df_processed.columns else [])]
|
||||
|
||||
|
||||
def get_preprocessing_pipeline():
|
||||
"""
|
||||
Create a preprocessing pipeline for numerical and categorical features
|
||||
"""
|
||||
# Define numerical and categorical features
|
||||
numerical_features = [
|
||||
'amt', 'distance_km', 'age', 'hour', 'day', 'weekday', 'month',
|
||||
'amt_to_category_avg', 'city_pop'
|
||||
]
|
||||
|
||||
categorical_features = ['category', 'time_of_day']
|
||||
binary_features = ['is_weekend']
|
||||
|
||||
# Create preprocessing pipelines
|
||||
numerical_transformer = Pipeline(steps=[
|
||||
('scaler', StandardScaler())
|
||||
])
|
||||
|
||||
categorical_transformer = Pipeline(steps=[
|
||||
('onehot', OneHotEncoder(handle_unknown='ignore'))
|
||||
])
|
||||
|
||||
# Combine preprocessing steps
|
||||
preprocessor = ColumnTransformer(
|
||||
transformers=[
|
||||
('num', numerical_transformer, numerical_features),
|
||||
('cat', categorical_transformer, categorical_features),
|
||||
('bin', 'passthrough', binary_features)
|
||||
])
|
||||
|
||||
return preprocessor
|
||||
|
||||
|
||||
def main():
|
||||
"""
|
||||
Main function to preprocess the data and save it
|
||||
"""
|
||||
print("Loading training data...")
|
||||
train_data = load_data(config.TRAIN_DATA_PATH)
|
||||
|
||||
print("Loading test data...")
|
||||
test_data = load_data(config.TEST_DATA_PATH)
|
||||
|
||||
print("Preprocessing training data...")
|
||||
processed_train = preprocess_data(train_data, is_training=True)
|
||||
|
||||
print("Preprocessing test data...")
|
||||
processed_test = preprocess_data(test_data, is_training=False)
|
||||
|
||||
print("Saving processed data...")
|
||||
processed_train.to_csv(config.PROCESSED_TRAIN_DATA_PATH, index=False)
|
||||
processed_test.to_csv(config.PROCESSED_TEST_DATA_PATH, index=False)
|
||||
|
||||
print("Data preprocessing completed!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,176 @@
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
import joblib
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import matplotlib.pyplot as plt
|
||||
import seaborn as sns
|
||||
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix
|
||||
from sklearn.metrics import roc_curve, auc, precision_recall_curve, average_precision_score
|
||||
|
||||
# Add the project root to the path so we can import from src
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
from src import config
|
||||
|
||||
|
||||
def load_model():
|
||||
"""
|
||||
Load the trained model
|
||||
"""
|
||||
try:
|
||||
model = joblib.load(config.MODEL_PATH)
|
||||
return model
|
||||
except FileNotFoundError:
|
||||
print(f"Model file not found at {config.MODEL_PATH}")
|
||||
return None
|
||||
|
||||
|
||||
def load_test_data():
|
||||
"""
|
||||
Load the processed test data
|
||||
"""
|
||||
try:
|
||||
test_data = pd.read_csv(config.PROCESSED_TEST_DATA_PATH)
|
||||
return test_data
|
||||
except FileNotFoundError:
|
||||
print(f"Test data file not found at {config.PROCESSED_TEST_DATA_PATH}")
|
||||
return None
|
||||
|
||||
|
||||
def evaluate_model(model, test_data):
|
||||
"""
|
||||
Evaluate the model on test data
|
||||
"""
|
||||
if 'is_fraud' not in test_data.columns:
|
||||
print("Target variable 'is_fraud' not found in test data")
|
||||
return None
|
||||
|
||||
# Split features and target
|
||||
X_test = test_data.drop('is_fraud', axis=1)
|
||||
y_test = test_data['is_fraud']
|
||||
|
||||
# Make predictions
|
||||
y_pred = model.predict(X_test)
|
||||
y_prob = model.predict_proba(X_test)[:, 1] # Probability of positive class
|
||||
|
||||
# Calculate metrics
|
||||
metrics = {
|
||||
'accuracy': accuracy_score(y_test, y_pred),
|
||||
'precision': precision_score(y_test, y_pred),
|
||||
'recall': recall_score(y_test, y_pred),
|
||||
'f1': f1_score(y_test, y_pred),
|
||||
'confusion_matrix': confusion_matrix(y_test, y_pred).tolist()
|
||||
}
|
||||
|
||||
# Print metrics
|
||||
print("Test Set Metrics:")
|
||||
print(f"Accuracy: {metrics['accuracy']:.4f}")
|
||||
print(f"Precision: {metrics['precision']:.4f}")
|
||||
print(f"Recall: {metrics['recall']:.4f}")
|
||||
print(f"F1 Score: {metrics['f1']:.4f}")
|
||||
print("Confusion Matrix:")
|
||||
print(metrics['confusion_matrix'])
|
||||
|
||||
# Plot ROC curve
|
||||
plot_roc_curve(y_test, y_prob)
|
||||
|
||||
# Plot Precision-Recall curve
|
||||
plot_precision_recall_curve(y_test, y_prob)
|
||||
|
||||
# Plot confusion matrix
|
||||
plot_confusion_matrix(y_test, y_pred)
|
||||
|
||||
return metrics, y_pred, y_prob
|
||||
|
||||
|
||||
def plot_roc_curve(y_true, y_prob):
|
||||
"""
|
||||
Plot ROC curve
|
||||
"""
|
||||
fpr, tpr, _ = roc_curve(y_true, y_prob)
|
||||
roc_auc = auc(fpr, tpr)
|
||||
|
||||
plt.figure(figsize=(8, 6))
|
||||
plt.plot(fpr, tpr, color='darkorange', lw=2, label=f'ROC curve (area = {roc_auc:.2f})')
|
||||
plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
|
||||
plt.xlim([0.0, 1.0])
|
||||
plt.ylim([0.0, 1.05])
|
||||
plt.xlabel('False Positive Rate')
|
||||
plt.ylabel('True Positive Rate')
|
||||
plt.title('Receiver Operating Characteristic (ROC) Curve')
|
||||
plt.legend(loc="lower right")
|
||||
plt.savefig(os.path.join(config.MODELS_DIR, 'roc_curve.png'))
|
||||
plt.close()
|
||||
|
||||
|
||||
def plot_precision_recall_curve(y_true, y_prob):
|
||||
"""
|
||||
Plot Precision-Recall curve
|
||||
"""
|
||||
precision, recall, _ = precision_recall_curve(y_true, y_prob)
|
||||
avg_precision = average_precision_score(y_true, y_prob)
|
||||
|
||||
plt.figure(figsize=(8, 6))
|
||||
plt.plot(recall, precision, color='blue', lw=2, label=f'Precision-Recall curve (AP = {avg_precision:.2f})')
|
||||
plt.xlabel('Recall')
|
||||
plt.ylabel('Precision')
|
||||
plt.title('Precision-Recall Curve')
|
||||
plt.legend(loc="lower left")
|
||||
plt.savefig(os.path.join(config.MODELS_DIR, 'precision_recall_curve.png'))
|
||||
plt.close()
|
||||
|
||||
|
||||
def plot_confusion_matrix(y_true, y_pred):
|
||||
"""
|
||||
Plot confusion matrix
|
||||
"""
|
||||
cm = confusion_matrix(y_true, y_pred)
|
||||
plt.figure(figsize=(8, 6))
|
||||
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', cbar=False)
|
||||
plt.xlabel('Predicted')
|
||||
plt.ylabel('True')
|
||||
plt.title('Confusion Matrix')
|
||||
plt.savefig(os.path.join(config.MODELS_DIR, 'confusion_matrix.png'))
|
||||
plt.close()
|
||||
|
||||
|
||||
def save_evaluation_results(metrics):
|
||||
"""
|
||||
Save evaluation results to a file
|
||||
"""
|
||||
results_path = os.path.join(config.MODELS_DIR, 'evaluation_results.json')
|
||||
with open(results_path, 'w') as f:
|
||||
json.dump(metrics, f, indent=4)
|
||||
|
||||
print(f"Evaluation results saved to {results_path}")
|
||||
|
||||
|
||||
def main():
|
||||
"""
|
||||
Main function to evaluate the model
|
||||
"""
|
||||
# Load the model
|
||||
print("Loading the model...")
|
||||
model = load_model()
|
||||
if model is None:
|
||||
return
|
||||
|
||||
# Load test data
|
||||
print("Loading test data...")
|
||||
test_data = load_test_data()
|
||||
if test_data is None:
|
||||
return
|
||||
|
||||
# Evaluate the model
|
||||
print("Evaluating the model...")
|
||||
metrics, _, _ = evaluate_model(model, test_data)
|
||||
|
||||
# Save evaluation results
|
||||
save_evaluation_results(metrics)
|
||||
|
||||
print("Model evaluation completed!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,259 @@
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
import json
|
||||
import joblib
|
||||
import os
|
||||
import sys
|
||||
from sklearn.model_selection import train_test_split, GridSearchCV
|
||||
from sklearn.ensemble import RandomForestClassifier
|
||||
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix
|
||||
from sklearn.pipeline import Pipeline
|
||||
from sklearn.preprocessing import OneHotEncoder
|
||||
from sklearn.compose import ColumnTransformer
|
||||
from imblearn.over_sampling import SMOTE
|
||||
import matplotlib.pyplot as plt
|
||||
import seaborn as sns
|
||||
|
||||
# Add the project root to the path so we can import from src
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
from src import config
|
||||
from src.data_preprocessing import get_preprocessing_pipeline
|
||||
|
||||
|
||||
def train_model(X_train, y_train, X_val=None, y_val=None, use_smote=True):
|
||||
"""
|
||||
Train a model on the given data
|
||||
"""
|
||||
# Get preprocessing pipeline
|
||||
preprocessor = get_preprocessing_pipeline()
|
||||
|
||||
# Handle categorical features before SMOTE
|
||||
print("Preprocessing data...")
|
||||
# Identify categorical columns
|
||||
categorical_cols = X_train.select_dtypes(include=['object', 'category']).columns.tolist()
|
||||
print(f"Categorical columns: {categorical_cols}")
|
||||
|
||||
if use_smote and categorical_cols:
|
||||
# We need to preprocess categorical features before applying SMOTE
|
||||
print("Preprocessing categorical features for SMOTE...")
|
||||
# Create a preprocessing pipeline just for categorical features
|
||||
cat_preprocessor = ColumnTransformer(
|
||||
transformers=[
|
||||
('cat', OneHotEncoder(handle_unknown='ignore'), categorical_cols)
|
||||
],
|
||||
remainder='passthrough'
|
||||
)
|
||||
|
||||
# Apply preprocessing to training data
|
||||
X_train_processed = cat_preprocessor.fit_transform(X_train)
|
||||
|
||||
# Apply SMOTE to the preprocessed data
|
||||
print("Applying SMOTE to handle class imbalance...")
|
||||
smote = SMOTE(random_state=42)
|
||||
X_train_resampled, y_train_resampled = smote.fit_resample(X_train_processed, y_train)
|
||||
|
||||
# For the final pipeline, we'll use the original data and let the full preprocessor handle it
|
||||
X_train_for_pipeline, y_train_for_pipeline = X_train, y_train
|
||||
elif use_smote:
|
||||
# If no categorical features, apply SMOTE directly
|
||||
print("Applying SMOTE to handle class imbalance...")
|
||||
smote = SMOTE(random_state=42)
|
||||
X_train_resampled, y_train_resampled = smote.fit_resample(X_train, y_train)
|
||||
X_train_for_pipeline, y_train_for_pipeline = X_train_resampled, y_train_resampled
|
||||
else:
|
||||
# No SMOTE, use original data
|
||||
X_train_for_pipeline, y_train_for_pipeline = X_train, y_train
|
||||
|
||||
# Create the full pipeline with preprocessing and model
|
||||
pipeline = Pipeline(steps=[
|
||||
('preprocessor', preprocessor),
|
||||
('classifier', RandomForestClassifier(random_state=42))
|
||||
])
|
||||
|
||||
# Define hyperparameters for grid search
|
||||
param_grid = {
|
||||
'classifier__n_estimators': [100, 200],
|
||||
'classifier__max_depth': [None, 10, 20],
|
||||
'classifier__min_samples_split': [2, 5, 10]
|
||||
}
|
||||
|
||||
# Perform grid search with cross-validation
|
||||
print("Performing grid search with cross-validation...")
|
||||
grid_search = GridSearchCV(pipeline, param_grid, cv=3, scoring='f1', n_jobs=-1)
|
||||
grid_search.fit(X_train_for_pipeline, y_train_for_pipeline)
|
||||
|
||||
# Get the best model
|
||||
best_model = grid_search.best_estimator_
|
||||
print(f"Best parameters: {grid_search.best_params_}")
|
||||
|
||||
# Evaluate on validation set if provided
|
||||
if X_val is not None and y_val is not None:
|
||||
y_pred = best_model.predict(X_val)
|
||||
print("Validation metrics:")
|
||||
print_metrics(y_val, y_pred)
|
||||
|
||||
return best_model, grid_search.best_params_
|
||||
|
||||
|
||||
def print_metrics(y_true, y_pred):
|
||||
"""
|
||||
Print evaluation metrics
|
||||
"""
|
||||
accuracy = accuracy_score(y_true, y_pred)
|
||||
precision = precision_score(y_true, y_pred)
|
||||
recall = recall_score(y_true, y_pred)
|
||||
f1 = f1_score(y_true, y_pred)
|
||||
|
||||
print(f"Accuracy: {accuracy:.4f}")
|
||||
print(f"Precision: {precision:.4f}")
|
||||
print(f"Recall: {recall:.4f}")
|
||||
print(f"F1 Score: {f1:.4f}")
|
||||
|
||||
# Confusion matrix
|
||||
cm = confusion_matrix(y_true, y_pred)
|
||||
print("Confusion Matrix:")
|
||||
print(cm)
|
||||
|
||||
return {
|
||||
'accuracy': accuracy,
|
||||
'precision': precision,
|
||||
'recall': recall,
|
||||
'f1': f1,
|
||||
'confusion_matrix': cm.tolist()
|
||||
}
|
||||
|
||||
|
||||
def plot_feature_importance(model, feature_names):
|
||||
"""
|
||||
Plot feature importance
|
||||
"""
|
||||
# Get feature importance from the model
|
||||
if hasattr(model, 'feature_importances_'):
|
||||
importances = model.feature_importances_
|
||||
else:
|
||||
importances = model.named_steps['classifier'].feature_importances_
|
||||
|
||||
# Get the transformed feature names from the pipeline
|
||||
if hasattr(model, 'named_steps'):
|
||||
# For pipeline models, get the feature names from the preprocessor
|
||||
preprocessor = model.named_steps['preprocessor']
|
||||
# Get the transformed feature names
|
||||
transformed_features = []
|
||||
|
||||
# Handle numerical features (they keep their names)
|
||||
numerical_features = preprocessor.transformers_[0][2] # Numerical features list
|
||||
transformed_features.extend(numerical_features)
|
||||
|
||||
# Handle categorical features (they get expanded with one-hot encoding)
|
||||
categorical_features = preprocessor.transformers_[1][2] # Categorical features list
|
||||
categorical_transformer = preprocessor.transformers_[1][1] # OneHotEncoder
|
||||
if hasattr(categorical_transformer, 'get_feature_names_out'):
|
||||
# For newer scikit-learn versions
|
||||
cat_feature_names = categorical_transformer.get_feature_names_out(categorical_features)
|
||||
else:
|
||||
# For older scikit-learn versions
|
||||
cat_feature_names = categorical_transformer.named_steps['onehot'].get_feature_names(categorical_features)
|
||||
transformed_features.extend(cat_feature_names)
|
||||
|
||||
# Handle binary features (they pass through)
|
||||
binary_features = preprocessor.transformers_[2][2] # Binary features list
|
||||
transformed_features.extend(binary_features)
|
||||
|
||||
# Use the transformed feature names
|
||||
feature_names = transformed_features
|
||||
|
||||
# Make sure the lengths match
|
||||
if len(feature_names) != len(importances):
|
||||
print(f"Warning: Feature names length ({len(feature_names)}) doesn't match importances length ({len(importances)})")
|
||||
# Use generic feature names if lengths don't match
|
||||
feature_names = [f'Feature {i}' for i in range(len(importances))]
|
||||
|
||||
# Create a DataFrame for visualization
|
||||
feature_importance = pd.DataFrame({
|
||||
'Feature': feature_names,
|
||||
'Importance': importances
|
||||
}).sort_values('Importance', ascending=False)
|
||||
|
||||
# Plot
|
||||
plt.figure(figsize=(10, 6))
|
||||
sns.barplot(x='Importance', y='Feature', data=feature_importance)
|
||||
plt.title('Feature Importance')
|
||||
plt.tight_layout()
|
||||
plt.savefig(os.path.join(config.MODELS_DIR, 'feature_importance.png'))
|
||||
plt.close()
|
||||
|
||||
return feature_importance
|
||||
|
||||
|
||||
def save_model(model, metadata):
|
||||
"""
|
||||
Save the trained model and its metadata
|
||||
"""
|
||||
# Create models directory if it doesn't exist
|
||||
os.makedirs(config.MODELS_DIR, exist_ok=True)
|
||||
|
||||
# Save the model
|
||||
joblib.dump(model, config.MODEL_PATH)
|
||||
|
||||
# Save metadata
|
||||
with open(config.MODEL_METADATA_PATH, 'w') as f:
|
||||
json.dump(metadata, f, indent=4)
|
||||
|
||||
print(f"Model saved to {config.MODEL_PATH}")
|
||||
print(f"Model metadata saved to {config.MODEL_METADATA_PATH}")
|
||||
|
||||
|
||||
def main():
|
||||
"""
|
||||
Main function to train the model
|
||||
"""
|
||||
# Load processed data
|
||||
print("Loading processed training data...")
|
||||
try:
|
||||
train_data = pd.read_csv(config.PROCESSED_TRAIN_DATA_PATH)
|
||||
except FileNotFoundError:
|
||||
print("Processed training data not found. Please run data_preprocessing.py first.")
|
||||
return
|
||||
|
||||
# Split features and target
|
||||
X = train_data.drop('is_fraud', axis=1)
|
||||
y = train_data['is_fraud']
|
||||
|
||||
# Split into training and validation sets
|
||||
X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.2, random_state=42, stratify=y)
|
||||
|
||||
print(f"Training data shape: {X_train.shape}")
|
||||
print(f"Validation data shape: {X_val.shape}")
|
||||
|
||||
# Train the model
|
||||
print("Training the model...")
|
||||
model, best_params = train_model(X_train, y_train, X_val, y_val)
|
||||
|
||||
# Evaluate on validation set
|
||||
print("\nEvaluating on validation set:")
|
||||
y_pred = model.predict(X_val)
|
||||
metrics = print_metrics(y_val, y_pred)
|
||||
|
||||
# Get feature names after preprocessing
|
||||
feature_names = X.columns.tolist()
|
||||
|
||||
# Plot feature importance
|
||||
print("\nPlotting feature importance...")
|
||||
feature_importance = plot_feature_importance(model, feature_names)
|
||||
|
||||
# Save the model and metadata
|
||||
metadata = {
|
||||
'model_type': 'RandomForestClassifier',
|
||||
'best_parameters': best_params,
|
||||
'metrics': metrics,
|
||||
'feature_importance': feature_importance.to_dict(orient='records'),
|
||||
'features': feature_names
|
||||
}
|
||||
|
||||
save_model(model, metadata)
|
||||
|
||||
print("Model training completed!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
+143
@@ -0,0 +1,143 @@
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
import joblib
|
||||
import os
|
||||
import sys
|
||||
import json
|
||||
|
||||
# Add the project root to the path so we can import from src
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
from src import config
|
||||
from src.data_preprocessing import preprocess_data
|
||||
|
||||
|
||||
def load_model():
|
||||
"""
|
||||
Load the trained model
|
||||
"""
|
||||
try:
|
||||
model = joblib.load(config.MODEL_PATH)
|
||||
return model
|
||||
except FileNotFoundError:
|
||||
print(f"Model file not found at {config.MODEL_PATH}")
|
||||
return None
|
||||
|
||||
|
||||
def load_model_metadata():
|
||||
"""
|
||||
Load the model metadata
|
||||
"""
|
||||
try:
|
||||
with open(config.MODEL_METADATA_PATH, 'r') as f:
|
||||
metadata = json.load(f)
|
||||
return metadata
|
||||
except FileNotFoundError:
|
||||
print(f"Model metadata file not found at {config.MODEL_METADATA_PATH}")
|
||||
return None
|
||||
|
||||
|
||||
def prepare_input_data(transaction_data):
|
||||
"""
|
||||
Prepare input data for prediction
|
||||
"""
|
||||
# Convert to DataFrame if it's a dictionary
|
||||
if isinstance(transaction_data, dict):
|
||||
transaction_data = pd.DataFrame([transaction_data])
|
||||
|
||||
# Preprocess the data
|
||||
processed_data = preprocess_data(transaction_data, is_training=False)
|
||||
|
||||
return processed_data
|
||||
|
||||
|
||||
def predict_fraud(model, transaction_data):
|
||||
"""
|
||||
Predict fraud for a transaction
|
||||
"""
|
||||
# Prepare the input data
|
||||
processed_data = prepare_input_data(transaction_data)
|
||||
|
||||
# Make prediction
|
||||
prediction = model.predict(processed_data)[0]
|
||||
probability = model.predict_proba(processed_data)[0, 1] # Probability of fraud
|
||||
|
||||
result = {
|
||||
'is_fraud': bool(prediction),
|
||||
'fraud_probability': float(probability),
|
||||
'risk_level': 'high' if probability > 0.7 else 'medium' if probability > 0.3 else 'low'
|
||||
}
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def predict_batch(model, transactions_data):
|
||||
"""
|
||||
Predict fraud for multiple transactions
|
||||
"""
|
||||
# Prepare the input data
|
||||
processed_data = prepare_input_data(transactions_data)
|
||||
|
||||
# Make predictions
|
||||
predictions = model.predict(processed_data)
|
||||
probabilities = model.predict_proba(processed_data)[:, 1] # Probabilities of fraud
|
||||
|
||||
results = []
|
||||
for i in range(len(predictions)):
|
||||
result = {
|
||||
'is_fraud': bool(predictions[i]),
|
||||
'fraud_probability': float(probabilities[i]),
|
||||
'risk_level': 'high' if probabilities[i] > 0.7 else 'medium' if probabilities[i] > 0.3 else 'low'
|
||||
}
|
||||
results.append(result)
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def main():
|
||||
"""
|
||||
Main function for demonstration
|
||||
"""
|
||||
# Load the model
|
||||
print("Loading the model...")
|
||||
model = load_model()
|
||||
if model is None:
|
||||
return
|
||||
|
||||
# Example transaction data
|
||||
example_transaction = {
|
||||
'trans_date_trans_time': '2019-01-01 00:00:00',
|
||||
'cc_num': '4532315247148429',
|
||||
'merchant': 'fraud_Rippin, Kub and Mann',
|
||||
'category': 'grocery_pos',
|
||||
'amt': 4.97,
|
||||
'first': 'John',
|
||||
'last': 'Doe',
|
||||
'gender': 'M',
|
||||
'street': '123 Main St',
|
||||
'city': 'New York',
|
||||
'state': 'NY',
|
||||
'zip': '10001',
|
||||
'lat': 40.7128,
|
||||
'long': -74.0060,
|
||||
'city_pop': 8336817,
|
||||
'job': 'Data Scientist',
|
||||
'dob': '1980-01-01',
|
||||
'trans_num': 'a795d3a0f8f11f9c45d3a4aa62b5c0f3',
|
||||
'unix_time': 1546300800,
|
||||
'merch_lat': 40.7128,
|
||||
'merch_long': -74.0060
|
||||
}
|
||||
|
||||
# Make prediction
|
||||
print("Making prediction...")
|
||||
result = predict_fraud(model, example_transaction)
|
||||
|
||||
# Print result
|
||||
print("\nPrediction Result:")
|
||||
print(f"Is Fraud: {result['is_fraud']}")
|
||||
print(f"Fraud Probability: {result['fraud_probability']:.4f}")
|
||||
print(f"Risk Level: {result['risk_level']}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1 @@
|
||||
|
||||
+111
@@ -0,0 +1,111 @@
|
||||
import os
|
||||
import sys
|
||||
import json
|
||||
import requests
|
||||
from flask import Flask, render_template, request, jsonify
|
||||
|
||||
# Add the project root to the path so we can import from src
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
|
||||
from src import config
|
||||
|
||||
# Initialize Flask app
|
||||
app = Flask(__name__)
|
||||
|
||||
# API URL
|
||||
API_URL = f"http://{config.API_HOST}:{config.API_PORT}"
|
||||
|
||||
|
||||
@app.route('/')
|
||||
def index():
|
||||
"""
|
||||
Render the main page
|
||||
"""
|
||||
return render_template('index.html')
|
||||
|
||||
|
||||
@app.route('/predict', methods=['POST'])
|
||||
def predict():
|
||||
"""
|
||||
Handle prediction request
|
||||
"""
|
||||
try:
|
||||
# Get form data
|
||||
transaction_data = {
|
||||
'trans_date_trans_time': request.form.get('trans_date_trans_time'),
|
||||
'cc_num': request.form.get('cc_num'),
|
||||
'merchant': request.form.get('merchant'),
|
||||
'category': request.form.get('category'),
|
||||
'amt': float(request.form.get('amt')),
|
||||
'first': request.form.get('first'),
|
||||
'last': request.form.get('last'),
|
||||
'gender': request.form.get('gender'),
|
||||
'street': request.form.get('street'),
|
||||
'city': request.form.get('city'),
|
||||
'state': request.form.get('state'),
|
||||
'zip': request.form.get('zip'),
|
||||
'lat': float(request.form.get('lat')),
|
||||
'long': float(request.form.get('long')),
|
||||
'city_pop': int(request.form.get('city_pop')),
|
||||
'job': request.form.get('job'),
|
||||
'dob': request.form.get('dob'),
|
||||
'trans_num': request.form.get('trans_num'),
|
||||
'unix_time': int(request.form.get('unix_time')),
|
||||
'merch_lat': float(request.form.get('merch_lat')),
|
||||
'merch_long': float(request.form.get('merch_long'))
|
||||
}
|
||||
|
||||
# Call API
|
||||
response = requests.post(f"{API_URL}/predict", json=transaction_data)
|
||||
|
||||
if response.status_code == 200:
|
||||
result = response.json()
|
||||
return render_template('result.html', result=result, transaction=transaction_data)
|
||||
else:
|
||||
error_message = f"API Error: {response.status_code} - {response.text}"
|
||||
return render_template('error.html', error=error_message)
|
||||
|
||||
except Exception as e:
|
||||
return render_template('error.html', error=str(e))
|
||||
|
||||
|
||||
@app.route('/api-status')
|
||||
def api_status():
|
||||
"""
|
||||
Check API status
|
||||
"""
|
||||
try:
|
||||
response = requests.get(f"{API_URL}/health")
|
||||
return jsonify(response.json())
|
||||
except Exception as e:
|
||||
return jsonify({"status": "error", "message": str(e)})
|
||||
|
||||
|
||||
@app.route('/model-info')
|
||||
def model_info():
|
||||
"""
|
||||
Get model information
|
||||
"""
|
||||
try:
|
||||
response = requests.get(f"{API_URL}/model-info")
|
||||
if response.status_code == 200:
|
||||
return render_template('model_info.html', model_info=response.json())
|
||||
else:
|
||||
error_message = f"API Error: {response.status_code} - {response.text}"
|
||||
return render_template('error.html', error=error_message)
|
||||
except Exception as e:
|
||||
return render_template('error.html', error=str(e))
|
||||
|
||||
|
||||
def main():
|
||||
"""
|
||||
Run the web server
|
||||
"""
|
||||
app.run(
|
||||
host=config.WEB_HOST,
|
||||
port=config.WEB_PORT,
|
||||
debug=True
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,69 @@
|
||||
/* Custom styles for Fraud Detection System */
|
||||
|
||||
body {
|
||||
background-color: #f8f9fa;
|
||||
}
|
||||
|
||||
.card {
|
||||
box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1);
|
||||
border: none;
|
||||
}
|
||||
|
||||
.card-header {
|
||||
background-color: #007bff;
|
||||
color: white;
|
||||
}
|
||||
|
||||
.btn-primary {
|
||||
background-color: #007bff;
|
||||
border-color: #007bff;
|
||||
}
|
||||
|
||||
.btn-primary:hover {
|
||||
background-color: #0069d9;
|
||||
border-color: #0062cc;
|
||||
}
|
||||
|
||||
.alert {
|
||||
box-shadow: 0 2px 4px rgba(0, 0, 0, 0.1);
|
||||
}
|
||||
|
||||
/* Risk level badges */
|
||||
.badge.bg-danger {
|
||||
font-size: 1rem;
|
||||
padding: 0.5rem 1rem;
|
||||
}
|
||||
|
||||
.badge.bg-warning {
|
||||
font-size: 1rem;
|
||||
padding: 0.5rem 1rem;
|
||||
}
|
||||
|
||||
.badge.bg-success {
|
||||
font-size: 1rem;
|
||||
padding: 0.5rem 1rem;
|
||||
}
|
||||
|
||||
/* Form styling */
|
||||
.form-label {
|
||||
font-weight: 500;
|
||||
}
|
||||
|
||||
.form-control:focus, .form-select:focus {
|
||||
border-color: #80bdff;
|
||||
box-shadow: 0 0 0 0.2rem rgba(0, 123, 255, 0.25);
|
||||
}
|
||||
|
||||
/* Table styling */
|
||||
.table th {
|
||||
background-color: #f8f9fa;
|
||||
}
|
||||
|
||||
/* Progress bar for feature importance */
|
||||
.progress {
|
||||
height: 20px;
|
||||
}
|
||||
|
||||
.progress-bar {
|
||||
transition: width 0.6s ease;
|
||||
}
|
||||
@@ -0,0 +1,66 @@
|
||||
// JavaScript for Fraud Detection System
|
||||
|
||||
// Function to generate a random transaction number
|
||||
function generateTransactionNumber() {
|
||||
return Math.random().toString(36).substring(2, 15) +
|
||||
Math.random().toString(36).substring(2, 15);
|
||||
}
|
||||
|
||||
// Function to fill form with sample data
|
||||
function fillSampleData() {
|
||||
// Sample transaction data
|
||||
const now = new Date();
|
||||
const formattedDate = now.toISOString().slice(0, 16);
|
||||
|
||||
document.getElementById('trans_date_trans_time').value = formattedDate;
|
||||
document.getElementById('cc_num').value = '4532' + Math.floor(1000000000000 + Math.random() * 9000000000000);
|
||||
document.getElementById('merchant').value = 'Sample Merchant';
|
||||
document.getElementById('category').value = 'shopping_pos';
|
||||
document.getElementById('amt').value = (Math.random() * 1000).toFixed(2);
|
||||
document.getElementById('first').value = 'John';
|
||||
document.getElementById('last').value = 'Doe';
|
||||
document.getElementById('gender').value = 'M';
|
||||
document.getElementById('dob').value = '1980-01-01';
|
||||
document.getElementById('job').value = 'Software Developer';
|
||||
document.getElementById('street').value = '123 Main St';
|
||||
document.getElementById('city').value = 'New York';
|
||||
document.getElementById('state').value = 'NY';
|
||||
document.getElementById('zip').value = '10001';
|
||||
document.getElementById('lat').value = '40.7128';
|
||||
document.getElementById('long').value = '-74.0060';
|
||||
document.getElementById('city_pop').value = '8336817';
|
||||
document.getElementById('merch_lat').value = '40.7128';
|
||||
document.getElementById('merch_long').value = '-74.0060';
|
||||
document.getElementById('trans_num').value = generateTransactionNumber();
|
||||
document.getElementById('unix_time').value = Math.floor(now.getTime() / 1000);
|
||||
}
|
||||
|
||||
// Add event listener when DOM is loaded
|
||||
document.addEventListener('DOMContentLoaded', function() {
|
||||
// Add sample data button if it exists
|
||||
const sampleDataBtn = document.getElementById('sample-data-btn');
|
||||
if (sampleDataBtn) {
|
||||
sampleDataBtn.addEventListener('click', fillSampleData);
|
||||
}
|
||||
|
||||
// Set current date and time as default for transaction time
|
||||
const transDateTimeInput = document.getElementById('trans_date_trans_time');
|
||||
if (transDateTimeInput) {
|
||||
const now = new Date();
|
||||
const formattedDate = now.toISOString().slice(0, 16);
|
||||
transDateTimeInput.value = formattedDate;
|
||||
}
|
||||
|
||||
// Set current unix time as default
|
||||
const unixTimeInput = document.getElementById('unix_time');
|
||||
if (unixTimeInput && !unixTimeInput.value) {
|
||||
const now = new Date();
|
||||
unixTimeInput.value = Math.floor(now.getTime() / 1000);
|
||||
}
|
||||
|
||||
// Generate random transaction number if empty
|
||||
const transNumInput = document.getElementById('trans_num');
|
||||
if (transNumInput && !transNumInput.value) {
|
||||
transNumInput.value = generateTransactionNumber();
|
||||
}
|
||||
});
|
||||
@@ -0,0 +1,40 @@
|
||||
<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
<title>Error - Fraud Detection System</title>
|
||||
<link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/bootstrap@5.3.0/dist/css/bootstrap.min.css">
|
||||
<link rel="stylesheet" href="{{ url_for('static', filename='css/style.css') }}">
|
||||
</head>
|
||||
<body>
|
||||
<div class="container mt-5">
|
||||
<div class="row">
|
||||
<div class="col-md-12 text-center mb-4">
|
||||
<h1>Error</h1>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div class="row">
|
||||
<div class="col-md-12">
|
||||
<div class="card">
|
||||
<div class="card-header bg-danger text-white">
|
||||
<h3>An Error Occurred</h3>
|
||||
</div>
|
||||
<div class="card-body">
|
||||
<div class="alert alert-danger">
|
||||
<p>{{ error }}</p>
|
||||
</div>
|
||||
|
||||
<div class="text-center mt-4">
|
||||
<a href="/" class="btn btn-primary">Return to Home</a>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<script src="https://cdn.jsdelivr.net/npm/bootstrap@5.3.0/dist/js/bootstrap.bundle.min.js"></script>
|
||||
</body>
|
||||
</html>
|
||||
@@ -0,0 +1,216 @@
|
||||
<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
<title>Fraud Detection System</title>
|
||||
<link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/bootstrap@5.3.0/dist/css/bootstrap.min.css">
|
||||
<link rel="stylesheet" href="{{ url_for('static', filename='css/style.css') }}">
|
||||
</head>
|
||||
<body>
|
||||
<div class="container mt-5">
|
||||
<div class="row">
|
||||
<div class="col-md-12 text-center mb-4">
|
||||
<h1>Fraud Detection System</h1>
|
||||
<p class="lead">Enter transaction details to check for potential fraud</p>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div class="row">
|
||||
<div class="col-md-12">
|
||||
<div class="card">
|
||||
<div class="card-header">
|
||||
<h3>Transaction Details</h3>
|
||||
</div>
|
||||
<div class="card-body">
|
||||
<form action="/predict" method="post">
|
||||
<div class="row">
|
||||
<div class="col-md-6 mb-3">
|
||||
<label for="trans_date_trans_time" class="form-label">Transaction Date/Time</label>
|
||||
<input type="datetime-local" class="form-control" id="trans_date_trans_time" name="trans_date_trans_time" required>
|
||||
</div>
|
||||
<div class="col-md-6 mb-3">
|
||||
<label for="cc_num" class="form-label">Credit Card Number</label>
|
||||
<input type="text" class="form-control" id="cc_num" name="cc_num" required>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div class="row">
|
||||
<div class="col-md-6 mb-3">
|
||||
<label for="merchant" class="form-label">Merchant</label>
|
||||
<input type="text" class="form-control" id="merchant" name="merchant" required>
|
||||
</div>
|
||||
<div class="col-md-6 mb-3">
|
||||
<label for="category" class="form-label">Category</label>
|
||||
<select class="form-select" id="category" name="category" required>
|
||||
<option value="">Select Category</option>
|
||||
<option value="grocery_pos">Grocery</option>
|
||||
<option value="shopping_pos">Shopping</option>
|
||||
<option value="food_dining">Food & Dining</option>
|
||||
<option value="entertainment">Entertainment</option>
|
||||
<option value="gas_transport">Gas & Transport</option>
|
||||
<option value="health_fitness">Health & Fitness</option>
|
||||
<option value="travel">Travel</option>
|
||||
<option value="home">Home</option>
|
||||
<option value="kids_pets">Kids & Pets</option>
|
||||
<option value="personal_care">Personal Care</option>
|
||||
<option value="misc_pos">Miscellaneous</option>
|
||||
</select>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div class="row">
|
||||
<div class="col-md-4 mb-3">
|
||||
<label for="amt" class="form-label">Amount</label>
|
||||
<input type="number" step="0.01" class="form-control" id="amt" name="amt" required>
|
||||
</div>
|
||||
<div class="col-md-4 mb-3">
|
||||
<label for="first" class="form-label">First Name</label>
|
||||
<input type="text" class="form-control" id="first" name="first" required>
|
||||
</div>
|
||||
<div class="col-md-4 mb-3">
|
||||
<label for="last" class="form-label">Last Name</label>
|
||||
<input type="text" class="form-control" id="last" name="last" required>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div class="row">
|
||||
<div class="col-md-4 mb-3">
|
||||
<label for="gender" class="form-label">Gender</label>
|
||||
<select class="form-select" id="gender" name="gender" required>
|
||||
<option value="">Select Gender</option>
|
||||
<option value="M">Male</option>
|
||||
<option value="F">Female</option>
|
||||
</select>
|
||||
</div>
|
||||
<div class="col-md-4 mb-3">
|
||||
<label for="dob" class="form-label">Date of Birth</label>
|
||||
<input type="date" class="form-control" id="dob" name="dob" required>
|
||||
</div>
|
||||
<div class="col-md-4 mb-3">
|
||||
<label for="job" class="form-label">Job</label>
|
||||
<input type="text" class="form-control" id="job" name="job" required>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div class="row">
|
||||
<div class="col-md-12 mb-3">
|
||||
<label for="street" class="form-label">Street Address</label>
|
||||
<input type="text" class="form-control" id="street" name="street" required>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div class="row">
|
||||
<div class="col-md-4 mb-3">
|
||||
<label for="city" class="form-label">City</label>
|
||||
<input type="text" class="form-control" id="city" name="city" required>
|
||||
</div>
|
||||
<div class="col-md-4 mb-3">
|
||||
<label for="state" class="form-label">State</label>
|
||||
<input type="text" class="form-control" id="state" name="state" required>
|
||||
</div>
|
||||
<div class="col-md-4 mb-3">
|
||||
<label for="zip" class="form-label">ZIP Code</label>
|
||||
<input type="text" class="form-control" id="zip" name="zip" required>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div class="row">
|
||||
<div class="col-md-4 mb-3">
|
||||
<label for="lat" class="form-label">Latitude</label>
|
||||
<input type="number" step="0.000001" class="form-control" id="lat" name="lat" required>
|
||||
</div>
|
||||
<div class="col-md-4 mb-3">
|
||||
<label for="long" class="form-label">Longitude</label>
|
||||
<input type="number" step="0.000001" class="form-control" id="long" name="long" required>
|
||||
</div>
|
||||
<div class="col-md-4 mb-3">
|
||||
<label for="city_pop" class="form-label">City Population</label>
|
||||
<input type="number" class="form-control" id="city_pop" name="city_pop" required>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div class="row">
|
||||
<div class="col-md-4 mb-3">
|
||||
<label for="merch_lat" class="form-label">Merchant Latitude</label>
|
||||
<input type="number" step="0.000001" class="form-control" id="merch_lat" name="merch_lat" required>
|
||||
</div>
|
||||
<div class="col-md-4 mb-3">
|
||||
<label for="merch_long" class="form-label">Merchant Longitude</label>
|
||||
<input type="number" step="0.000001" class="form-control" id="merch_long" name="merch_long" required>
|
||||
</div>
|
||||
<div class="col-md-4 mb-3">
|
||||
<label for="trans_num" class="form-label">Transaction Number</label>
|
||||
<input type="text" class="form-control" id="trans_num" name="trans_num" required>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div class="row">
|
||||
<div class="col-md-12 mb-3">
|
||||
<label for="unix_time" class="form-label">Unix Time</label>
|
||||
<input type="number" class="form-control" id="unix_time" name="unix_time" required>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div class="row">
|
||||
<div class="col-md-12 text-center">
|
||||
<button type="submit" class="btn btn-primary btn-lg">Check for Fraud</button>
|
||||
</div>
|
||||
</div>
|
||||
</form>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div class="row mt-4">
|
||||
<div class="col-md-12 text-center">
|
||||
<a href="/model-info" class="btn btn-secondary">View Model Information</a>
|
||||
<button id="check-api" class="btn btn-info">Check API Status</button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div class="row mt-3">
|
||||
<div class="col-md-12">
|
||||
<div id="api-status" class="alert alert-info d-none"></div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<script src="https://cdn.jsdelivr.net/npm/bootstrap@5.3.0/dist/js/bootstrap.bundle.min.js"></script>
|
||||
<script src="{{ url_for('static', filename='js/script.js') }}"></script>
|
||||
<script>
|
||||
// Set current date and time as default
|
||||
document.addEventListener('DOMContentLoaded', function() {
|
||||
const now = new Date();
|
||||
const formattedDate = now.toISOString().slice(0, 16);
|
||||
document.getElementById('trans_date_trans_time').value = formattedDate;
|
||||
document.getElementById('unix_time').value = Math.floor(now.getTime() / 1000);
|
||||
|
||||
// Check API status button
|
||||
document.getElementById('check-api').addEventListener('click', function() {
|
||||
fetch('/api-status')
|
||||
.then(response => response.json())
|
||||
.then(data => {
|
||||
const statusDiv = document.getElementById('api-status');
|
||||
statusDiv.classList.remove('d-none', 'alert-danger', 'alert-success', 'alert-info');
|
||||
|
||||
if (data.status === 'healthy') {
|
||||
statusDiv.classList.add('alert-success');
|
||||
statusDiv.textContent = 'API Status: Healthy' + (data.model_loaded ? ' (Model Loaded)' : ' (Model Not Loaded)');
|
||||
} else {
|
||||
statusDiv.classList.add('alert-danger');
|
||||
statusDiv.textContent = 'API Status: ' + data.status + ' - ' + data.message;
|
||||
}
|
||||
})
|
||||
.catch(error => {
|
||||
const statusDiv = document.getElementById('api-status');
|
||||
statusDiv.classList.remove('d-none', 'alert-info');
|
||||
statusDiv.classList.add('alert-danger');
|
||||
statusDiv.textContent = 'Error connecting to API: ' + error.message;
|
||||
});
|
||||
});
|
||||
});
|
||||
</script>
|
||||
</body>
|
||||
</html>
|
||||
@@ -0,0 +1,134 @@
|
||||
<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
<title>Model Information - Fraud Detection System</title>
|
||||
<link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/bootstrap@5.3.0/dist/css/bootstrap.min.css">
|
||||
<link rel="stylesheet" href="{{ url_for('static', filename='css/style.css') }}">
|
||||
</head>
|
||||
<body>
|
||||
<div class="container mt-5">
|
||||
<div class="row">
|
||||
<div class="col-md-12 text-center mb-4">
|
||||
<h1>Model Information</h1>
|
||||
<p class="lead">Details about the fraud detection model</p>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div class="row">
|
||||
<div class="col-md-12">
|
||||
<div class="card">
|
||||
<div class="card-header">
|
||||
<h3>Model Details</h3>
|
||||
</div>
|
||||
<div class="card-body">
|
||||
<div class="row">
|
||||
<div class="col-md-6">
|
||||
<h4>Model Type</h4>
|
||||
<p>{{ model_info.model_type }}</p>
|
||||
|
||||
<h4>Best Parameters</h4>
|
||||
<ul class="list-group">
|
||||
{% for param, value in model_info.best_parameters.items() %}
|
||||
<li class="list-group-item">
|
||||
<strong>{{ param.replace('classifier__', '') }}:</strong> {{ value }}
|
||||
</li>
|
||||
{% endfor %}
|
||||
</ul>
|
||||
</div>
|
||||
|
||||
<div class="col-md-6">
|
||||
<h4>Performance Metrics</h4>
|
||||
<table class="table table-striped">
|
||||
<tbody>
|
||||
<tr>
|
||||
<th>Accuracy</th>
|
||||
<td>{{ "%.4f"|format(model_info.metrics.accuracy) }}</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<th>Precision</th>
|
||||
<td>{{ "%.4f"|format(model_info.metrics.precision) }}</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<th>Recall</th>
|
||||
<td>{{ "%.4f"|format(model_info.metrics.recall) }}</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<th>F1 Score</th>
|
||||
<td>{{ "%.4f"|format(model_info.metrics.f1) }}</td>
|
||||
</tr>
|
||||
</tbody>
|
||||
</table>
|
||||
|
||||
<h4>Confusion Matrix</h4>
|
||||
<table class="table table-bordered text-center">
|
||||
<thead>
|
||||
<tr>
|
||||
<th></th>
|
||||
<th>Predicted Negative</th>
|
||||
<th>Predicted Positive</th>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody>
|
||||
<tr>
|
||||
<th>Actual Negative</th>
|
||||
<td class="bg-success text-white">{{ model_info.metrics.confusion_matrix[0][0] }}</td>
|
||||
<td class="bg-warning">{{ model_info.metrics.confusion_matrix[0][1] }}</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<th>Actual Positive</th>
|
||||
<td class="bg-danger">{{ model_info.metrics.confusion_matrix[1][0] }}</td>
|
||||
<td class="bg-success text-white">{{ model_info.metrics.confusion_matrix[1][1] }}</td>
|
||||
</tr>
|
||||
</tbody>
|
||||
</table>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div class="row mt-4">
|
||||
<div class="col-md-12">
|
||||
<h4>Feature Importance</h4>
|
||||
<table class="table table-striped">
|
||||
<thead>
|
||||
<tr>
|
||||
<th>Feature</th>
|
||||
<th>Importance</th>
|
||||
<th>Visualization</th>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody>
|
||||
{% for feature in model_info.feature_importance|sort(attribute='Importance', reverse=True) %}
|
||||
<tr>
|
||||
<td>{{ feature.Feature }}</td>
|
||||
<td>{{ "%.4f"|format(feature.Importance) }}</td>
|
||||
<td>
|
||||
<div class="progress">
|
||||
<div class="progress-bar bg-primary" role="progressbar"
|
||||
style="width: {{ feature.Importance * 100 }}%"
|
||||
aria-valuenow="{{ feature.Importance * 100 }}"
|
||||
aria-valuemin="0" aria-valuemax="100">
|
||||
</div>
|
||||
</div>
|
||||
</td>
|
||||
</tr>
|
||||
{% endfor %}
|
||||
</tbody>
|
||||
</table>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div class="row mt-4">
|
||||
<div class="col-md-12 text-center">
|
||||
<a href="/" class="btn btn-primary">Return to Home</a>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<script src="https://cdn.jsdelivr.net/npm/bootstrap@5.3.0/dist/js/bootstrap.bundle.min.js"></script>
|
||||
</body>
|
||||
</html>
|
||||
@@ -0,0 +1,90 @@
|
||||
<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
<title>Fraud Detection Result</title>
|
||||
<link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/bootstrap@5.3.0/dist/css/bootstrap.min.css">
|
||||
<link rel="stylesheet" href="{{ url_for('static', filename='css/style.css') }}">
|
||||
</head>
|
||||
<body>
|
||||
<div class="container mt-5">
|
||||
<div class="row">
|
||||
<div class="col-md-12 text-center mb-4">
|
||||
<h1>Fraud Detection Result</h1>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div class="row">
|
||||
<div class="col-md-12">
|
||||
<div class="card">
|
||||
<div class="card-header">
|
||||
<h3>Prediction Result</h3>
|
||||
</div>
|
||||
<div class="card-body">
|
||||
<div class="row">
|
||||
<div class="col-md-12 text-center mb-4">
|
||||
{% if result.is_fraud %}
|
||||
<div class="alert alert-danger">
|
||||
<h2>Potential Fraud Detected!</h2>
|
||||
<p>Fraud Probability: {{ "%.2f"|format(result.fraud_probability * 100) }}%</p>
|
||||
<p>Risk Level: <span class="badge bg-danger">{{ result.risk_level|upper }}</span></p>
|
||||
</div>
|
||||
{% else %}
|
||||
<div class="alert alert-success">
|
||||
<h2>Transaction Appears Legitimate</h2>
|
||||
<p>Fraud Probability: {{ "%.2f"|format(result.fraud_probability * 100) }}%</p>
|
||||
<p>Risk Level: <span class="badge bg-{{ 'warning' if result.risk_level == 'medium' else 'success' }}">{{ result.risk_level|upper }}</span></p>
|
||||
</div>
|
||||
{% endif %}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div class="row">
|
||||
<div class="col-md-12">
|
||||
<h4>Transaction Details</h4>
|
||||
<table class="table table-striped">
|
||||
<tbody>
|
||||
<tr>
|
||||
<th>Transaction Date/Time</th>
|
||||
<td>{{ transaction.trans_date_trans_time }}</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<th>Credit Card Number</th>
|
||||
<td>{{ transaction.cc_num }}</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<th>Merchant</th>
|
||||
<td>{{ transaction.merchant }}</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<th>Category</th>
|
||||
<td>{{ transaction.category }}</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<th>Amount</th>
|
||||
<td>${{ "%.2f"|format(transaction.amt) }}</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<th>Cardholder</th>
|
||||
<td>{{ transaction.first }} {{ transaction.last }}</td>
|
||||
</tr>
|
||||
</tbody>
|
||||
</table>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div class="row mt-4">
|
||||
<div class="col-md-12 text-center">
|
||||
<a href="/" class="btn btn-primary">Check Another Transaction</a>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<script src="https://cdn.jsdelivr.net/npm/bootstrap@5.3.0/dist/js/bootstrap.bundle.min.js"></script>
|
||||
</body>
|
||||
</html>
|
||||
Reference in New Issue
Block a user