
import pickle
import warnings
from pathlib import Path

import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import (
    accuracy_score,
    average_precision_score,
    classification_report,
    confusion_matrix,
    precision_recall_curve,
    roc_auc_score,
    roc_curve,
)
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
# NOTE(review): this silences *every* warning for the rest of the run,
# including any convergence warning from the LogisticRegression fit below.
warnings.filterwarnings('ignore')

# =============================================================================
# Load Engineered Features
# =============================================================================
feature_data = pd.read_csv('temp_files/engineered_features.csv')

# One feature name per line. Blank lines (e.g. an extra empty line at the end
# of the file) are skipped; otherwise they would become an empty-string column
# name and make the column selection below fail with a KeyError.
with open('temp_files/feature_columns.txt', 'r', encoding='utf-8') as f:
    feature_cols = [name for name in (line.strip() for line in f) if name]

X = feature_data[feature_cols]
y = feature_data['sla_breached']

print("=" * 70)
print("PHASE 4: MODEL TRAINING & EVALUATION")
print("=" * 70)

# =============================================================================
# Train-Test Split
# =============================================================================
# Hold out a quarter of the rows for evaluation. Stratifying on the target
# keeps the breach proportion the same in both splits.
X_train, X_test, y_train, y_test = train_test_split(
    X,
    y,
    stratify=y,
    test_size=0.25,
    random_state=42,
)

print(f"\nTraining set: {len(X_train):,} samples")
print(f"Test set: {len(X_test):,} samples")
for _split_name, _split_target in (("Train", y_train), ("Test", y_test)):
    print(f"{_split_name} breach rate: {_split_target.mean()*100:.1f}%")

# =============================================================================
# Scale Features
# =============================================================================
# Learn mean/std from the training split only, then apply the same transform
# to both splits so no test-set statistics influence preprocessing.
scaler = StandardScaler().fit(X_train)
X_train_scaled = scaler.transform(X_train)
X_test_scaled = scaler.transform(X_test)

# =============================================================================
# Baseline Logistic Regression Model
# =============================================================================
print("\n--- Training Baseline Logistic Regression ---")

# class_weight='balanced' reweights samples inversely to class frequency to
# counter the class imbalance in the target.
baseline_model = LogisticRegression(
    solver='lbfgs',
    class_weight='balanced',
    max_iter=1000,
    random_state=42,
)
baseline_model.fit(X_train_scaled, y_train)

# Hard labels for threshold-based metrics; positive-class (column 1)
# probabilities for the ranking metrics (ROC-AUC, average precision).
y_pred = baseline_model.predict(X_test_scaled)
y_prob = baseline_model.predict_proba(X_test_scaled)[:, 1]

accuracy, roc_auc, avg_precision = (
    accuracy_score(y_test, y_pred),
    roc_auc_score(y_test, y_prob),
    average_precision_score(y_test, y_prob),
)

print("\nBaseline Model Performance:")
for _metric_name, _metric_value in (
    ("Accuracy", accuracy),
    ("ROC-AUC", roc_auc),
    ("Average Precision", avg_precision),
):
    print(f"{_metric_name}: {_metric_value:.4f}")
print("\nClassification Report:")
print(classification_report(y_test, y_pred, target_names=['Within SLA', 'Breach']))

# =============================================================================
# Confusion Matrix
# =============================================================================
# sklearn convention: rows are actual classes, columns are predicted classes,
# both ordered by label value (0 = Within SLA, 1 = Breach).
cm = confusion_matrix(y_test, y_pred)
print("\nConfusion Matrix:")
print(cm)

# Plot Confusion Matrix
fig1 = go.Figure(data=go.Heatmap(
    z=cm,
    x=['Predicted: Within SLA', 'Predicted: Breach'],
    y=['Actual: Within SLA', 'Actual: Breach'],
    text=cm,
    texttemplate='%{text}',
    textfont={"size": 18},
    colorscale='Blues',
    showscale=False
))
fig1.update_layout(
    title='Confusion Matrix: SLA Breach Prediction',
    xaxis_title='Predicted',
    # Plotly draws heatmap row 0 at the bottom by default; reversing the
    # y-axis puts "Actual: Within SLA" on top so the figure matches the
    # printed matrix above.
    yaxis=dict(title='Actual', autorange='reversed'),
    template='plotly_white',
    width=650, height=550
)
# Make sure the output directories exist so a fresh checkout doesn't fail.
Path('assets/images/html').mkdir(parents=True, exist_ok=True)
Path('assets/images/png').mkdir(parents=True, exist_ok=True)
fig1.write_html('assets/images/html/model_confusion_matrix.html')
fig1.write_image('assets/images/png/model_confusion_matrix.png', width=650, height=550, scale=2)
print("Saved: model_confusion_matrix")

# =============================================================================
# ROC Curve
# =============================================================================
# ROC points from the positive-class probabilities; thresholds are unused.
fpr, tpr, _ = roc_curve(y_test, y_prob)

fig2 = go.Figure()
fig2.add_trace(go.Scatter(x=fpr, y=tpr, mode='lines', name=f'Logistic Regression (AUC = {roc_auc:.3f})',
                            line=dict(color='#2E86AB', width=3)))
# Diagonal reference line: what a classifier with no skill would trace.
fig2.add_trace(go.Scatter(x=[0, 1], y=[0, 1], mode='lines', name='Random Classifier',
                            line=dict(color='gray', dash='dash', width=2)))
fig2.update_layout(
    title='ROC Curve: SLA Breach Prediction',
    xaxis_title='False Positive Rate', yaxis_title='True Positive Rate',
    template='plotly_white',
    legend=dict(orientation='h', yanchor='bottom', y=1.02, xanchor='right', x=1),
    width=700, height=600
)
# Make sure the output directories exist so a fresh checkout doesn't fail.
Path('assets/images/html').mkdir(parents=True, exist_ok=True)
Path('assets/images/png').mkdir(parents=True, exist_ok=True)
fig2.write_html('assets/images/html/model_roc_curve.html')
fig2.write_image('assets/images/png/model_roc_curve.png', width=700, height=600, scale=2)
print("Saved: model_roc_curve")

# =============================================================================
# Feature Importance (Coefficients)
# =============================================================================
# The model was fit on standardized features, so coefficient magnitudes are
# comparable across features (effect per one standard deviation). Positive
# coefficients push predictions toward class 1 ("Breach").
coef_df = pd.DataFrame({
    'feature': feature_cols,
    'coefficient': baseline_model.coef_[0],
    'abs_coefficient': np.abs(baseline_model.coef_[0])
}).sort_values('abs_coefficient', ascending=False)

print("\nTop 20 Most Important Features:")
print(coef_df.head(20)[['feature', 'coefficient']])

# Plot top 20 features
top20 = coef_df.head(20).copy()
fig3 = px.bar(top20, y='feature', x='coefficient', orientation='h',
              title='Top 20 Feature Importance: SLA Breach Prediction',
              color='coefficient', color_continuous_scale='RdBu',
              # Center the diverging scale on zero so the two color families
              # always correspond to the sign of the coefficient.
              color_continuous_midpoint=0)
# Reverse the category axis so the largest-|coefficient| feature is on top.
fig3.update_layout(yaxis=dict(autorange="reversed"), template='plotly_white')
# Make sure the output directories exist so a fresh checkout doesn't fail.
Path('assets/images/html').mkdir(parents=True, exist_ok=True)
Path('assets/images/png').mkdir(parents=True, exist_ok=True)
fig3.write_html('assets/images/html/model_feature_importance.html')
fig3.write_image('assets/images/png/model_feature_importance.png', width=950, height=700, scale=2)
print("Saved: model_feature_importance")

# =============================================================================
# Save Model
# =============================================================================
# Bundle everything needed to score new data: the fitted model, the scaler fit
# on the training split, and the exact feature column order used in training.
model_package = {
    'model': baseline_model,
    'scaler': scaler,
    'feature_columns': feature_cols,
    'metrics': {
        'accuracy': accuracy,
        'roc_auc': roc_auc,
        'avg_precision': avg_precision
    }
}

# The output directory may not exist on a fresh checkout.
Path('models').mkdir(parents=True, exist_ok=True)
# NOTE: pickle files execute code on load; only load this artifact from a
# trusted location.
with open('models/sla_breach_model.pkl', 'wb') as f:
    pickle.dump(model_package, f)

print("\nModel saved to: models/sla_breach_model.pkl")

# Save metrics summary: a one-row CSV of headline numbers, rounded to 4 places.
metrics_summary = dict(
    Model='Logistic Regression',
    Training_Samples=len(X_train),
    Test_Samples=len(X_test),
    Accuracy=round(accuracy, 4),
    ROC_AUC=round(roc_auc, 4),
    Avg_Precision=round(avg_precision, 4),
    Feature_Count=len(feature_cols),
)

metrics_df = pd.DataFrame([metrics_summary])
metrics_df.to_csv('temp_files/model_metrics.csv', index=False)

print(f"\n{'=' * 70}")
print("MODEL TRAINING & EVALUATION COMPLETE")
print('=' * 70)
