# ML Monitoring and Drift Detection
Monitor production ML models with data drift detection, performance tracking, and automated alerting. Includes working Python implementations.
Your model performs well on the test set. You deploy it. Six months later, accuracy has dropped 20% and no one noticed. This is the silent failure mode of production ML—models degrade gradually, without obvious errors, until users complain or business metrics crater.
Monitoring ML systems is fundamentally different from monitoring traditional software. You’re not just watching for crashes and errors. You’re watching for subtle statistical shifts in data and model behavior that indicate degradation.
## The Three Types of ML Drift
| Drift Type | What Changes | Detection Method | Business Impact |
|---|---|---|---|
| Data drift | Input feature distributions | Statistical tests (KS, chi-squared) | Model sees unfamiliar patterns |
| Concept drift | Relationship between features and labels | Performance monitoring | What was true is no longer true |
| Model drift | Prediction distribution and confidence | Output monitoring | Model becoming uncertain |
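Of the three, concept drift is the hardest to catch from inputs alone: feature distributions can look unchanged while the feature-label relationship shifts underneath them. As the table notes, it is usually detected through performance monitoring. A minimal sketch of that idea, assuming delayed ground-truth labels are available (the function name and thresholds here are illustrative, not from a library):

```python
import numpy as np

def concept_drift_alert(correct: np.ndarray, window: int = 100,
                        baseline: float = 0.90, tolerance: float = 0.05) -> bool:
    """Flag possible concept drift when accuracy over the most recent
    window of labeled predictions falls below baseline - tolerance."""
    recent = correct[-window:]   # 1 = correct prediction, 0 = wrong
    return recent.mean() < baseline - tolerance
```

This only works where labels arrive with acceptable delay; when labels lag by weeks, proxy signals such as prediction confidence are the fallback.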
## Data Drift Detection
Data drift occurs when production data differs from training data. Use statistical tests to compare distributions.
### Kolmogorov-Smirnov Test (Numerical Features)
The KS test compares two distributions by measuring the maximum distance between their cumulative distribution functions.
```python
from scipy import stats
import numpy as np

def detect_numerical_drift(reference: np.ndarray, production: np.ndarray,
                           p_threshold: float = 0.05) -> dict:
    """
    Detect drift in numerical features using KS test.

    Args:
        reference: Training/validation data distribution
        production: Current production data
        p_threshold: Significance level (typically 0.05)

    Returns:
        Dictionary with test results
    """
    statistic, p_value = stats.ks_2samp(reference, production)
    return {
        'test': 'Kolmogorov-Smirnov',
        'statistic': statistic,
        'p_value': p_value,
        'is_drift': p_value < p_threshold,
        'severity': 'high' if p_value < 0.001 else 'medium' if p_value < 0.01 else 'low'
    }

# Example: Age distribution shifted
reference_age = np.random.normal(35, 10, 1000)   # Training data
production_age = np.random.normal(40, 12, 500)   # Production data

result = detect_numerical_drift(reference_age, production_age)
print(f"Age drift: p={result['p_value']:.4f} -> {result['severity']}")
```

```
Age drift: p=0.0000 -> high
```
### Chi-Squared Test (Categorical Features)
For categorical features, use chi-squared to compare frequency distributions.
```python
def detect_categorical_drift(reference: np.ndarray, production: np.ndarray,
                             p_threshold: float = 0.05) -> dict:
    """
    Detect drift in categorical features using chi-squared test.
    """
    # Get all categories from both distributions
    all_categories = np.union1d(np.unique(reference), np.unique(production))

    # Count occurrences
    ref_counts = np.array([np.sum(reference == cat) for cat in all_categories])
    prod_counts = np.array([np.sum(production == cat) for cat in all_categories])

    # Normalize to frequencies
    ref_freq = ref_counts / len(reference)
    prod_freq = prod_counts / len(production)

    # Expected counts under the reference distribution; the small epsilon
    # avoids zero expected counts for categories absent from the reference
    expected = ref_freq * len(production) + 1e-10
    observed = prod_freq * len(production)

    statistic, p_value = stats.chisquare(observed, expected)
    return {
        'test': 'Chi-squared',
        'statistic': statistic,
        'p_value': p_value,
        'is_drift': p_value < p_threshold,
        'reference_dist': dict(zip(all_categories, ref_freq.round(3))),
        'production_dist': dict(zip(all_categories, prod_freq.round(3)))
    }

# Example: Category distribution changed
reference_cat = np.random.choice(['A', 'B', 'C'], 1000, p=[0.5, 0.3, 0.2])
production_cat = np.random.choice(['A', 'B', 'C'], 500, p=[0.3, 0.4, 0.3])

result = detect_categorical_drift(reference_cat, production_cat)
print(f"Category drift: p={result['p_value']:.4f}")
print(f"  Reference:  {result['reference_dist']}")
print(f"  Production: {result['production_dist']}")
```

```
Category drift: p=0.0000
  Reference:  {'A': 0.5, 'B': 0.3, 'C': 0.2}
  Production: {'A': 0.3, 'B': 0.4, 'C': 0.3}
```

### Complete Data Drift Detector
Here’s a complete class that handles both numerical and categorical features:
```python
from dataclasses import dataclass
from typing import Dict, List

import numpy as np
from scipy import stats

@dataclass
class DriftResult:
    """Result of a drift detection test."""
    feature: str
    test_name: str
    statistic: float
    p_value: float
    is_drift: bool
    severity: str

class DataDriftDetector:
    """Detect drift between reference and production data."""

    def __init__(self, p_threshold: float = 0.05):
        self.p_threshold = p_threshold
        self.reference_stats = {}

    def fit(self, reference_data: Dict[str, np.ndarray]):
        """Store reference data for comparison."""
        for feature, values in reference_data.items():
            is_categorical = values.dtype.kind in ['U', 'S', 'O']
            self.reference_stats[feature] = {
                'values': values,
                'is_categorical': is_categorical
            }
        print(f"Fitted on {len(reference_data)} features")

    def detect(self, production_data: Dict[str, np.ndarray]) -> List[DriftResult]:
        """Detect drift for all features."""
        results = []
        for feature, prod_values in production_data.items():
            if feature not in self.reference_stats:
                continue
            ref_stats = self.reference_stats[feature]
            ref_values = ref_stats['values']
            if ref_stats['is_categorical']:
                result = self._chi_squared_test(feature, ref_values, prod_values)
            else:
                result = self._ks_test(feature, ref_values, prod_values)
            results.append(result)
        return results

    def _ks_test(self, feature: str, ref: np.ndarray, prod: np.ndarray) -> DriftResult:
        stat, p = stats.ks_2samp(ref, prod)
        return DriftResult(
            feature=feature,
            test_name='Kolmogorov-Smirnov',
            statistic=stat,
            p_value=p,
            is_drift=p < self.p_threshold,
            severity=self._severity(p)
        )

    def _chi_squared_test(self, feature: str, ref: np.ndarray, prod: np.ndarray) -> DriftResult:
        categories = np.union1d(np.unique(ref), np.unique(prod))
        ref_freq = np.array([np.sum(ref == c) for c in categories]) / len(ref)
        prod_freq = np.array([np.sum(prod == c) for c in categories]) / len(prod)
        expected = ref_freq * len(prod) + 1e-10
        observed = prod_freq * len(prod)
        stat, p = stats.chisquare(observed, expected)
        return DriftResult(
            feature=feature,
            test_name='Chi-squared',
            statistic=stat,
            p_value=p,
            is_drift=p < self.p_threshold,
            severity=self._severity(p)
        )

    def _severity(self, p: float) -> str:
        if p < 0.001: return 'high'
        if p < 0.01: return 'medium'
        if p < 0.05: return 'low'
        return 'none'
```
### Running Data Drift Detection
```python
# Reference data (from training)
reference = {
    'age': np.random.normal(35, 10, 1000),
    'income': np.random.normal(50000, 15000, 1000),
    'category': np.random.choice(['A', 'B', 'C'], 1000, p=[0.5, 0.3, 0.2])
}

# Production data (current)
production = {
    'age': np.random.normal(40, 12, 500),            # Drift!
    'income': np.random.normal(50000, 15000, 500),   # No drift
    'category': np.random.choice(['A', 'B', 'C'], 500, p=[0.3, 0.4, 0.3])  # Drift!
}

detector = DataDriftDetector(p_threshold=0.05)
detector.fit(reference)
results = detector.detect(production)

for r in results:
    status = "DRIFT" if r.is_drift else "OK"
    print(f"{r.feature:15} p={r.p_value:.4f} -> {status} ({r.severity})")
```

```
Fitted on 3 features
age             p=0.0000 -> DRIFT (high)
income          p=0.3475 -> OK (none)
category        p=0.0000 -> DRIFT (high)
```
## Model Drift Detection
Model drift tracks changes in prediction patterns, not input features. Monitor these signals:
- **Prediction distribution shift**: class proportions changing over time
- **Confidence degradation**: the model becoming less certain
- **Low confidence ratio**: the percentage of uncertain predictions
```python
class ModelDriftDetector:
    """Detect drift in model predictions and confidence."""

    def __init__(self, distribution_threshold: float = 0.1,
                 confidence_threshold: float = 0.3):
        self.distribution_threshold = distribution_threshold
        self.confidence_threshold = confidence_threshold
        self.reference_dist = None
        self.reference_confidence = None

    def fit(self, predictions: np.ndarray, confidences: np.ndarray):
        """Store reference prediction statistics."""
        unique, counts = np.unique(predictions, return_counts=True)
        self.reference_dist = counts / len(predictions)
        self.reference_classes = unique
        self.reference_confidence = np.mean(confidences)
        print(f"Reference distribution: {dict(zip(unique, self.reference_dist.round(3)))}")
        print(f"Reference confidence: {self.reference_confidence:.3f}")

    def detect(self, predictions: np.ndarray, confidences: np.ndarray) -> dict:
        """Detect drift in production predictions."""
        # Current distribution
        unique, counts = np.unique(predictions, return_counts=True)
        current_dist = counts / len(predictions)

        # Jensen-Shannon divergence for distribution comparison
        js_div = self._js_divergence(self.reference_dist, current_dist)

        # Confidence comparison
        current_conf = np.mean(confidences)
        conf_drop = self.reference_confidence - current_conf

        # Low confidence ratio
        low_conf_ratio = np.mean(confidences < 0.6)

        return {
            'distribution': {
                'js_divergence': round(js_div, 4),
                'is_drift': js_div > self.distribution_threshold
            },
            'confidence': {
                'reference': round(self.reference_confidence, 3),
                'current': round(current_conf, 3),
                'drop': round(conf_drop, 3),
                'is_drift': conf_drop > self.confidence_threshold
            },
            'low_confidence_ratio': {
                'ratio': round(low_conf_ratio, 3),
                'is_concerning': low_conf_ratio > 0.2
            }
        }

    def _js_divergence(self, p: np.ndarray, q: np.ndarray) -> float:
        """Jensen-Shannon divergence between distributions."""
        # Pad to equal length; assumes both arrays order classes identically
        max_len = max(len(p), len(q))
        p_pad = np.zeros(max_len)
        q_pad = np.zeros(max_len)
        p_pad[:len(p)] = p
        q_pad[:len(q)] = q
        p_pad = (p_pad + 1e-10) / (p_pad + 1e-10).sum()
        q_pad = (q_pad + 1e-10) / (q_pad + 1e-10).sum()
        m = 0.5 * (p_pad + q_pad)
        return 0.5 * stats.entropy(p_pad, m) + 0.5 * stats.entropy(q_pad, m)
```
### Model Drift in Action
```python
# Reference: high confidence, balanced predictions
ref_preds = np.random.choice([0, 1], 1000, p=[0.4, 0.6])
ref_confs = np.random.beta(5, 2, 1000)   # High confidence (mean ~0.71)

# Production: distribution shifted, lower confidence
prod_preds = np.random.choice([0, 1], 500, p=[0.6, 0.4])
prod_confs = np.random.beta(2, 2, 500)   # Lower confidence (mean ~0.5)

detector = ModelDriftDetector()
detector.fit(ref_preds, ref_confs)
results = detector.detect(prod_preds, prod_confs)

print(f"Distribution drift: JS={results['distribution']['js_divergence']:.4f}")
print(f"Confidence drop: {results['confidence']['drop']:.3f}")
print(f"Low confidence ratio: {results['low_confidence_ratio']['ratio']:.3f}")
```

```
Reference distribution: {0: 0.421, 1: 0.579}
Reference confidence: 0.713
Distribution drift: JS=0.0150
Confidence drop: 0.192
Low confidence ratio: 0.596
```

## Performance Monitoring
Track accuracy and latency with alerting when metrics degrade.
```python
class PerformanceMonitor:
    """Monitor model performance with alerting."""

    def __init__(self, accuracy_threshold: float = 0.05,
                 latency_threshold_ms: float = 100):
        self.accuracy_threshold = accuracy_threshold
        self.latency_threshold_ms = latency_threshold_ms
        self.baseline_accuracy = None
        self.baseline_latency = None
        self.history = []

    def set_baseline(self, accuracy: float, latency_p95_ms: float):
        """Set performance baseline."""
        self.baseline_accuracy = accuracy
        self.baseline_latency = latency_p95_ms
        print(f"Baseline: accuracy={accuracy:.3f}, latency_p95={latency_p95_ms:.1f}ms")

    def log_batch(self, predictions: np.ndarray, labels: np.ndarray,
                  latencies_ms: np.ndarray) -> dict:
        """Log batch and check for alerts."""
        accuracy = np.mean(predictions == labels)
        latency_p95 = np.percentile(latencies_ms, 95)
        alerts = []

        # Check accuracy
        if self.baseline_accuracy:
            drop = self.baseline_accuracy - accuracy
            if drop > self.accuracy_threshold:
                alerts.append({
                    'type': 'accuracy_degradation',
                    'severity': 'high' if drop > 0.1 else 'medium',
                    'message': f"Accuracy dropped {drop:.2%} below baseline"
                })

        # Check latency
        if self.baseline_latency:
            increase = latency_p95 - self.baseline_latency
            if increase > self.latency_threshold_ms:
                alerts.append({
                    'type': 'latency_spike',
                    'severity': 'high' if increase > 200 else 'medium',
                    'message': f"P95 latency increased {increase:.0f}ms"
                })

        metrics = {
            'accuracy': round(accuracy, 4),
            'latency_p95_ms': round(latency_p95, 2),
            'sample_count': len(predictions),
            'alerts': alerts
        }
        self.history.append(metrics)
        return metrics

    def get_summary(self) -> dict:
        """Get monitoring summary."""
        if not self.history:
            return {}
        all_alerts = [a for m in self.history for a in m.get('alerts', [])]
        return {
            'total_batches': len(self.history),
            'total_samples': sum(m['sample_count'] for m in self.history),
            'avg_accuracy': round(np.mean([m['accuracy'] for m in self.history]), 4),
            'total_alerts': len(all_alerts)
        }
```
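Per-batch accuracy is noisy at small batch sizes, so a raw threshold check like the one in PerformanceMonitor can page on a single unlucky batch. A common refinement (not built into the monitor above) is to smooth the metric with an exponentially weighted moving average before comparing it to the baseline; a minimal sketch:

```python
def ewma(values, alpha: float = 0.3):
    """Exponentially weighted moving average: recent values weigh more,
    but a single outlier moves the smoothed value only by alpha of its size."""
    smoothed = values[0]
    out = [smoothed]
    for v in values[1:]:
        smoothed = alpha * v + (1 - alpha) * smoothed
        out.append(smoothed)
    return out

# One bad batch (0.70) barely moves the smoothed accuracy:
# ewma([0.92, 0.91, 0.70, 0.93]) -> [0.92, 0.917, ~0.852, ~0.875]
```

Lower `alpha` means heavier smoothing and slower reaction; tune it against how quickly you need to detect genuine degradation.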
### Simulating Performance Degradation
```python
monitor = PerformanceMonitor(accuracy_threshold=0.05, latency_threshold_ms=50)
monitor.set_baseline(accuracy=0.92, latency_p95_ms=45.0)

# Simulate degrading performance over 5 batches
for i in range(5):
    # Performance degrades over time
    accuracy_factor = 1.0 - (i * 0.03)
    latency_factor = 1.0 + (i * 0.2)

    batch_size = 100
    preds = np.random.choice([0, 1], batch_size)
    labels = preds.copy()

    # Introduce errors
    n_errors = int(batch_size * (1 - 0.92 * accuracy_factor))
    error_idx = np.random.choice(batch_size, n_errors, replace=False)
    labels[error_idx] = 1 - labels[error_idx]

    latencies = np.random.gamma(4, 10 * latency_factor, batch_size)

    metrics = monitor.log_batch(preds, labels, latencies)
    alerts = f" ALERTS: {len(metrics['alerts'])}" if metrics['alerts'] else ""
    print(f"Batch {i+1}: accuracy={metrics['accuracy']:.3f}, "
          f"p95={metrics['latency_p95_ms']:.1f}ms{alerts}")

summary = monitor.get_summary()
print(f"\nTotal alerts: {summary['total_alerts']}")
```

```
Baseline: accuracy=0.920, latency_p95=45.0ms
Batch 1: accuracy=0.930, p95=82.0ms
Batch 2: accuracy=0.900, p95=90.2ms
Batch 3: accuracy=0.870, p95=100.8ms ALERTS: 2
Batch 4: accuracy=0.840, p95=121.3ms ALERTS: 2
Batch 5: accuracy=0.810, p95=133.1ms ALERTS: 2

Total alerts: 6
```
## Alerting Integration
### CloudWatch Metrics (AWS)
```python
import boto3

cloudwatch = boto3.client('cloudwatch')

def publish_metrics(metrics: dict, model_name: str):
    """Publish metrics to CloudWatch."""
    cloudwatch.put_metric_data(
        Namespace='MLModels',
        MetricData=[
            {
                'MetricName': 'Accuracy',
                'Value': metrics['accuracy'],
                'Unit': 'None',
                'Dimensions': [{'Name': 'ModelName', 'Value': model_name}]
            },
            {
                'MetricName': 'LatencyP95',
                'Value': metrics['latency_p95_ms'],
                'Unit': 'Milliseconds',
                'Dimensions': [{'Name': 'ModelName', 'Value': model_name}]
            },
            {
                'MetricName': 'DriftScore',
                'Value': metrics.get('drift_score', 0),
                'Unit': 'None',
                'Dimensions': [{'Name': 'ModelName', 'Value': model_name}]
            }
        ]
    )
```
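Publishing metrics alone does not page anyone; you also need an alarm on them. A sketch of the alarm parameters for the Accuracy metric above, built as a plain dict so it can be reviewed before calling CloudWatch's `put_metric_alarm` (the alarm name, period, and threshold here are illustrative):

```python
def build_accuracy_alarm(model_name: str, threshold: float = 0.85) -> dict:
    """Kwargs for cloudwatch.put_metric_alarm: fire when average accuracy
    stays below the threshold for three consecutive 5-minute periods."""
    return {
        'AlarmName': f'{model_name}-accuracy-low',
        'Namespace': 'MLModels',
        'MetricName': 'Accuracy',
        'Dimensions': [{'Name': 'ModelName', 'Value': model_name}],
        'Statistic': 'Average',
        'Period': 300,
        'EvaluationPeriods': 3,
        'Threshold': threshold,
        'ComparisonOperator': 'LessThanThreshold',
        # Missing data usually means the metrics pipeline itself is down,
        # so treat it as a breach rather than silence
        'TreatMissingData': 'breaching',
    }

# cloudwatch.put_metric_alarm(**build_accuracy_alarm('churn-model'))
```

In practice you would also pass `AlarmActions` with an SNS topic ARN so the alarm actually notifies someone.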
### Prometheus Metrics
```python
from prometheus_client import Gauge, Counter, Histogram

# Define metrics
accuracy_gauge = Gauge('model_accuracy', 'Model accuracy', ['model_name'])
latency_histogram = Histogram('model_latency_seconds', 'Inference latency',
                              ['model_name'],
                              buckets=[.01, .025, .05, .1, .25, .5, 1])
drift_counter = Counter('model_drift_detected', 'Drift events',
                        ['model_name', 'feature'])

def record_metrics(model_name: str, accuracy: float, latency: float,
                   drift_features: list):
    """Record Prometheus metrics."""
    accuracy_gauge.labels(model_name=model_name).set(accuracy)
    latency_histogram.labels(model_name=model_name).observe(latency)
    for feature in drift_features:
        drift_counter.labels(model_name=model_name, feature=feature).inc()
```
## Retraining Triggers
Set up automatic retraining based on drift signals:
```python
class RetrainingTrigger:
    """Determine when to retrain based on monitoring signals."""

    def __init__(self,
                 accuracy_threshold: float = 0.85,
                 drift_feature_threshold: int = 2,
                 consecutive_alerts_threshold: int = 3):
        self.accuracy_threshold = accuracy_threshold
        self.drift_feature_threshold = drift_feature_threshold
        self.consecutive_alerts_threshold = consecutive_alerts_threshold
        self.alert_count = 0

    def should_retrain(self, metrics: dict, drift_results: list) -> dict:
        """Determine if retraining is needed."""
        reasons = []

        # Check accuracy below threshold
        if metrics['accuracy'] < self.accuracy_threshold:
            reasons.append(f"Accuracy {metrics['accuracy']:.2%} below "
                           f"{self.accuracy_threshold:.2%}")

        # Check number of features with drift
        drifted_features = [r.feature for r in drift_results if r.is_drift]
        if len(drifted_features) >= self.drift_feature_threshold:
            reasons.append(f"Drift in {len(drifted_features)} features: {drifted_features}")

        # Track consecutive alerts
        if metrics.get('alerts'):
            self.alert_count += 1
        else:
            self.alert_count = 0
        if self.alert_count >= self.consecutive_alerts_threshold:
            reasons.append(f"{self.alert_count} consecutive batches with alerts")

        return {
            'should_retrain': len(reasons) > 0,
            'reasons': reasons,
            'urgency': 'high' if len(reasons) > 1 else 'medium' if reasons else 'none'
        }
```
## Monitoring Dashboard Architecture
```
┌─────────────────────────────────────────────────────────┐
│                   ML MONITORING STACK                   │
├─────────────────────────────────────────────────────────┤
│                                                         │
│  ┌──────────────┐  ┌──────────────┐  ┌──────────────┐   │
│  │    Model     │  │     Data     │  │ Performance  │   │
│  │    Drift     │  │    Drift     │  │   Metrics    │   │
│  │   Detector   │  │   Detector   │  │  Collector   │   │
│  └──────┬───────┘  └──────┬───────┘  └──────┬───────┘   │
│         │                 │                 │           │
│         └────────────────┬┴─────────────────┘           │
│                          │                              │
│                   ┌──────▼──────┐                       │
│                   │   Metrics   │                       │
│                   │  Aggregator │                       │
│                   └──────┬──────┘                       │
│                          │                              │
│        ┌─────────────────┼─────────────────┐            │
│        │                 │                 │            │
│  ┌─────▼─────┐     ┌─────▼─────┐     ┌─────▼─────┐      │
│  │CloudWatch │     │Prometheus │     │   Slack   │      │
│  │           │     │ + Grafana │     │  Alerts   │      │
│  └───────────┘     └───────────┘     └───────────┘      │
│                                                         │
└─────────────────────────────────────────────────────────┘
```
## Production Checklist
Before deploying monitoring:
**Data Drift**
- Store reference distributions from training data
- Choose appropriate statistical tests (KS for numerical, chi-squared for categorical)
- Set alert thresholds based on business impact
- Monitor all features or prioritize based on feature importance
**Model Drift**
- Track prediction distribution over time
- Monitor confidence/uncertainty metrics
- Compare against baseline from validation set
**Performance**
- Set accuracy baseline from validation
- Set latency baseline from load testing
- Configure alerting thresholds
- Test alert delivery (don’t let alerts fail silently)
**Retraining**
- Define retraining triggers
- Automate data collection for retraining
- Set up retraining pipeline
- Implement gradual rollout for retrained models
## What’s Next
You’re now monitoring your models in production. Complete the series:
- ML Security - IAM roles, secrets management, VPC configuration
## Full Code
Complete monitoring implementation: largo-tutorials/ml-monitoring