ML Monitoring and Drift Detection

Deep Dive · 35 min read · January 24, 2026

Monitor production ML models with data drift detection, performance tracking, and automated alerting. Includes working Python implementations.

Your model performs well on the test set. You deploy it. Six months later, accuracy has dropped 20% and no one noticed. This is the silent failure mode of production ML—models degrade gradually, without obvious errors, until users complain or business metrics crater.

Monitoring ML systems is fundamentally different from monitoring traditional software. You’re not just watching for crashes and errors. You’re watching for subtle statistical shifts in data and model behavior that indicate degradation.

The Three Types of ML Drift

| Drift Type | What Changes | Detection Method | Business Impact |
| --- | --- | --- | --- |
| Data drift | Input feature distributions | Statistical tests (KS, chi-squared) | Model sees unfamiliar patterns |
| Concept drift | Relationship between features and labels | Performance monitoring | What was true is no longer true |
| Model drift | Prediction distribution and confidence | Output monitoring | Model becoming uncertain |

Data Drift Detection

Data drift occurs when production data differs from training data. Use statistical tests to compare distributions.

Kolmogorov-Smirnov Test (Numerical Features)

The KS test compares two distributions by measuring the maximum distance between their cumulative distribution functions.

from scipy import stats
import numpy as np

def detect_numerical_drift(reference: np.ndarray, production: np.ndarray,
                           p_threshold: float = 0.05) -> dict:
    """
    Detect drift in numerical features using KS test.

    Args:
        reference: Training/validation data distribution
        production: Current production data
        p_threshold: Significance level (typically 0.05)

    Returns:
        Dictionary with test results
    """
    statistic, p_value = stats.ks_2samp(reference, production)

    return {
        'test': 'Kolmogorov-Smirnov',
        'statistic': statistic,
        'p_value': p_value,
        'is_drift': p_value < p_threshold,
        'severity': 'high' if p_value < 0.001 else 'medium' if p_value < 0.01 else 'low' if p_value < p_threshold else 'none'
    }

# Example: Age distribution shifted
reference_age = np.random.normal(35, 10, 1000)  # Training data
production_age = np.random.normal(40, 12, 500)  # Production data

result = detect_numerical_drift(reference_age, production_age)
print(f"Age drift: p={result['p_value']:.4f} -> {result['severity']}")
Output
Age drift: p=0.0000 -> high
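To make "maximum distance between their cumulative distribution functions" concrete, the KS statistic can be computed by hand from the two empirical CDFs and checked against SciPy (a sketch; the seed and distributions here are arbitrary):

```python
import numpy as np
from scipy import stats

rng = np.random.default_rng(0)
reference = rng.normal(35, 10, 1000)
production = rng.normal(40, 12, 500)

# Evaluate both empirical CDFs on the pooled sample
grid = np.sort(np.concatenate([reference, production]))
ecdf_ref = np.searchsorted(np.sort(reference), grid, side='right') / len(reference)
ecdf_prod = np.searchsorted(np.sort(production), grid, side='right') / len(production)

# KS statistic: the maximum vertical distance between the two CDFs
manual_stat = np.max(np.abs(ecdf_ref - ecdf_prod))
scipy_stat, _ = stats.ks_2samp(reference, production)
```

The two statistics agree to floating-point precision, which is a useful sanity check that you understand what the test is measuring.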

Chi-Squared Test (Categorical Features)

For categorical features, use chi-squared to compare frequency distributions.

def detect_categorical_drift(reference: np.ndarray, production: np.ndarray,
                             p_threshold: float = 0.05) -> dict:
    """
    Detect drift in categorical features using chi-squared test.
    """
    # Get all categories from both distributions
    all_categories = np.union1d(np.unique(reference), np.unique(production))

    # Count occurrences in each distribution
    ref_counts = np.array([np.sum(reference == cat) for cat in all_categories])
    prod_counts = np.array([np.sum(production == cat) for cat in all_categories])

    # Normalize to frequencies (reported in the result below)
    ref_freq = ref_counts / len(reference)
    prod_freq = prod_counts / len(production)

    # Expected counts if production followed the reference distribution;
    # the small epsilon guards against zero expected counts
    expected = ref_freq * len(production) + 1e-10
    observed = prod_counts.astype(float)

    statistic, p_value = stats.chisquare(observed, expected)

    return {
        'test': 'Chi-squared',
        'statistic': statistic,
        'p_value': p_value,
        'is_drift': p_value < p_threshold,
        'reference_dist': dict(zip(all_categories, ref_freq.round(3))),
        'production_dist': dict(zip(all_categories, prod_freq.round(3)))
    }

# Example: Category distribution changed
reference_cat = np.random.choice(['A', 'B', 'C'], 1000, p=[0.5, 0.3, 0.2])
production_cat = np.random.choice(['A', 'B', 'C'], 500, p=[0.3, 0.4, 0.3])

result = detect_categorical_drift(reference_cat, production_cat)
print(f"Category drift: p={result['p_value']:.4f}")
print(f"  Reference: {result['reference_dist']}")
print(f"  Production: {result['production_dist']}")
Output
Category drift: p=0.0000
  Reference: {'A': 0.5, 'B': 0.3, 'C': 0.2}
  Production: {'A': 0.3, 'B': 0.4, 'C': 0.3}

Complete Data Drift Detector

Here’s a complete class that handles both numerical and categorical features:

from dataclasses import dataclass
from typing import Dict, List
import numpy as np
from scipy import stats

@dataclass
class DriftResult:
    """Result of a drift detection test."""
    feature: str
    test_name: str
    statistic: float
    p_value: float
    is_drift: bool
    severity: str

class DataDriftDetector:
    """Detect drift between reference and production data."""

    def __init__(self, p_threshold: float = 0.05):
        self.p_threshold = p_threshold
        self.reference_stats = {}

    def fit(self, reference_data: Dict[str, np.ndarray]):
        """Store reference data for comparison."""
        for feature, values in reference_data.items():
            is_categorical = values.dtype.kind in ['U', 'S', 'O']
            self.reference_stats[feature] = {
                'values': values,
                'is_categorical': is_categorical
            }
        print(f"Fitted on {len(reference_data)} features")

    def detect(self, production_data: Dict[str, np.ndarray]) -> List[DriftResult]:
        """Detect drift for all features."""
        results = []

        for feature, prod_values in production_data.items():
            if feature not in self.reference_stats:
                continue

            ref_stats = self.reference_stats[feature]
            ref_values = ref_stats['values']

            if ref_stats['is_categorical']:
                result = self._chi_squared_test(feature, ref_values, prod_values)
            else:
                result = self._ks_test(feature, ref_values, prod_values)

            results.append(result)

        return results

    def _ks_test(self, feature: str, ref: np.ndarray, prod: np.ndarray) -> DriftResult:
        stat, p = stats.ks_2samp(ref, prod)
        return DriftResult(
            feature=feature,
            test_name='Kolmogorov-Smirnov',
            statistic=stat,
            p_value=p,
            is_drift=p < self.p_threshold,
            severity=self._severity(p)
        )

    def _chi_squared_test(self, feature: str, ref: np.ndarray, prod: np.ndarray) -> DriftResult:
        categories = np.union1d(np.unique(ref), np.unique(prod))
        ref_freq = np.array([np.sum(ref == c) for c in categories]) / len(ref)
        observed = np.array([np.sum(prod == c) for c in categories]).astype(float)
        # Expected counts under the reference distribution; epsilon avoids zeros
        expected = ref_freq * len(prod) + 1e-10
        stat, p = stats.chisquare(observed, expected)
        return DriftResult(
            feature=feature,
            test_name='Chi-squared',
            statistic=stat,
            p_value=p,
            is_drift=p < self.p_threshold,
            severity=self._severity(p)
        )

    def _severity(self, p: float) -> str:
        if p < 0.001: return 'high'
        if p < 0.01: return 'medium'
        if p < 0.05: return 'low'
        return 'none'

Running Data Drift Detection

# Reference data (from training)
reference = {
    'age': np.random.normal(35, 10, 1000),
    'income': np.random.normal(50000, 15000, 1000),
    'category': np.random.choice(['A', 'B', 'C'], 1000, p=[0.5, 0.3, 0.2])
}

# Production data (current)
production = {
    'age': np.random.normal(40, 12, 500),      # Drift!
    'income': np.random.normal(50000, 15000, 500),  # No drift
    'category': np.random.choice(['A', 'B', 'C'], 500, p=[0.3, 0.4, 0.3])  # Drift!
}

detector = DataDriftDetector(p_threshold=0.05)
detector.fit(reference)
results = detector.detect(production)

for r in results:
    status = "DRIFT" if r.is_drift else "OK"
    print(f"{r.feature:15} p={r.p_value:.4f} -> {status} ({r.severity})")
Output
Fitted on 3 features
age             p=0.0000 -> DRIFT (high)
income          p=0.3475 -> OK (none)
category        p=0.0000 -> DRIFT (high)
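One caveat for the run above: testing many features at p < 0.05 means roughly 5% of perfectly stable features will flag as drifted by chance. A Bonferroni-style correction (a common convention, not built into `DataDriftDetector`) divides the threshold by the number of features tested:

```python
import numpy as np
from scipy import stats

rng = np.random.default_rng(0)
n_features = 50

# 50 features with NO real drift: reference and production share a distribution
p_values = np.array([
    stats.ks_2samp(rng.normal(0, 1, 1000), rng.normal(0, 1, 500)).pvalue
    for _ in range(n_features)
])

naive_flags = int(np.sum(p_values < 0.05))                   # false alarms at the raw threshold
corrected_flags = int(np.sum(p_values < 0.05 / n_features))  # Bonferroni-corrected
```

With the corrected threshold a feature must show much stronger evidence before it alerts, at the cost of some sensitivity to genuine but subtle drift.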

Model Drift Detection

Model drift tracks changes in prediction patterns, not input features. Monitor these signals:

  1. Prediction distribution shift - Class proportions changing over time
  2. Confidence degradation - Model becoming less certain
  3. Low confidence ratio - Percentage of uncertain predictions

class ModelDriftDetector:
    """Detect drift in model predictions and confidence."""

    def __init__(self, distribution_threshold: float = 0.1,
                 confidence_threshold: float = 0.3):
        self.distribution_threshold = distribution_threshold
        self.confidence_threshold = confidence_threshold
        self.reference_dist = None
        self.reference_classes = None
        self.reference_confidence = None

    def fit(self, predictions: np.ndarray, confidences: np.ndarray):
        """Store reference prediction statistics."""
        unique, counts = np.unique(predictions, return_counts=True)
        self.reference_dist = counts / len(predictions)
        self.reference_classes = unique
        self.reference_confidence = np.mean(confidences)

        print(f"Reference distribution: {dict(zip(unique, self.reference_dist.round(3)))}")
        print(f"Reference confidence: {self.reference_confidence:.3f}")

    def detect(self, predictions: np.ndarray, confidences: np.ndarray) -> dict:
        """Detect drift in production predictions."""
        # Current distribution
        unique, counts = np.unique(predictions, return_counts=True)
        current_dist = counts / len(predictions)

        # Jensen-Shannon divergence for distribution comparison
        js_div = self._js_divergence(self.reference_dist, current_dist)

        # Confidence comparison
        current_conf = np.mean(confidences)
        conf_drop = self.reference_confidence - current_conf

        # Low confidence ratio (the 0.6 cutoff is a heuristic; tune per model)
        low_conf_ratio = np.mean(confidences < 0.6)

        return {
            'distribution': {
                'js_divergence': round(js_div, 4),
                'is_drift': js_div > self.distribution_threshold
            },
            'confidence': {
                'reference': round(self.reference_confidence, 3),
                'current': round(current_conf, 3),
                'drop': round(conf_drop, 3),
                'is_drift': conf_drop > self.confidence_threshold
            },
            'low_confidence_ratio': {
                'ratio': round(low_conf_ratio, 3),
                'is_concerning': low_conf_ratio > 0.2
            }
        }

    def _js_divergence(self, p: np.ndarray, q: np.ndarray) -> float:
        """Jensen-Shannon divergence between two distributions.

        Note: pads by index position, so both distributions are assumed
        to cover the same classes in the same order.
        """
        max_len = max(len(p), len(q))
        p_pad = np.zeros(max_len)
        q_pad = np.zeros(max_len)
        p_pad[:len(p)] = p
        q_pad[:len(q)] = q

        p_pad = (p_pad + 1e-10) / (p_pad + 1e-10).sum()
        q_pad = (q_pad + 1e-10) / (q_pad + 1e-10).sum()

        m = 0.5 * (p_pad + q_pad)
        return 0.5 * stats.entropy(p_pad, m) + 0.5 * stats.entropy(q_pad, m)

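As a sanity check on `_js_divergence`: SciPy ships the same quantity as `scipy.spatial.distance.jensenshannon`, which returns the Jensen-Shannon *distance* (the square root of the divergence, natural-log base by default):

```python
import numpy as np
from scipy import stats
from scipy.spatial.distance import jensenshannon

p = np.array([0.4, 0.6])   # reference class distribution
q = np.array([0.6, 0.4])   # production class distribution

m = 0.5 * (p + q)
js_div = 0.5 * stats.entropy(p, m) + 0.5 * stats.entropy(q, m)

# jensenshannon returns the distance; square it to recover the divergence
assert np.isclose(jensenshannon(p, q) ** 2, js_div)
```

Using the library function avoids the hand-rolled entropy arithmetic when your distributions are already aligned.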
Model Drift in Action

# Reference: high confidence, balanced predictions
ref_preds = np.random.choice([0, 1], 1000, p=[0.4, 0.6])
ref_confs = np.random.beta(5, 2, 1000)  # High confidence (mean ~0.71)

# Production: distribution shifted, lower confidence
prod_preds = np.random.choice([0, 1], 500, p=[0.6, 0.4])
prod_confs = np.random.beta(2, 2, 500)  # Lower confidence (mean ~0.5)

detector = ModelDriftDetector()
detector.fit(ref_preds, ref_confs)
results = detector.detect(prod_preds, prod_confs)

print(f"Distribution drift: JS={results['distribution']['js_divergence']:.4f}")
print(f"Confidence drop: {results['confidence']['drop']:.3f}")
print(f"Low confidence ratio: {results['low_confidence_ratio']['ratio']:.3f}")
Output
Reference distribution: {0: 0.421, 1: 0.579}
Reference confidence: 0.713
Distribution drift: JS=0.0150
Confidence drop: 0.192
Low confidence ratio: 0.596

Performance Monitoring

Track accuracy and latency with alerting when metrics degrade.

class PerformanceMonitor:
    """Monitor model performance with alerting."""

    def __init__(self, accuracy_threshold: float = 0.05,
                 latency_threshold_ms: float = 100):
        self.accuracy_threshold = accuracy_threshold
        self.latency_threshold_ms = latency_threshold_ms
        self.baseline_accuracy = None
        self.baseline_latency = None
        self.history = []

    def set_baseline(self, accuracy: float, latency_p95_ms: float):
        """Set performance baseline."""
        self.baseline_accuracy = accuracy
        self.baseline_latency = latency_p95_ms
        print(f"Baseline: accuracy={accuracy:.3f}, latency_p95={latency_p95_ms:.1f}ms")

    def log_batch(self, predictions: np.ndarray, labels: np.ndarray,
                  latencies_ms: np.ndarray) -> dict:
        """Log batch and check for alerts."""
        accuracy = np.mean(predictions == labels)
        latency_p95 = np.percentile(latencies_ms, 95)

        alerts = []

        # Check accuracy
        if self.baseline_accuracy is not None:
            drop = self.baseline_accuracy - accuracy
            if drop > self.accuracy_threshold:
                alerts.append({
                    'type': 'accuracy_degradation',
                    'severity': 'high' if drop > 0.1 else 'medium',
                    'message': f"Accuracy dropped {drop:.2%} below baseline"
                })

        # Check latency
        if self.baseline_latency is not None:
            increase = latency_p95 - self.baseline_latency
            if increase > self.latency_threshold_ms:
                alerts.append({
                    'type': 'latency_spike',
                    'severity': 'high' if increase > 200 else 'medium',
                    'message': f"P95 latency increased {increase:.0f}ms"
                })

        metrics = {
            'accuracy': round(accuracy, 4),
            'latency_p95_ms': round(latency_p95, 2),
            'sample_count': len(predictions),
            'alerts': alerts
        }

        self.history.append(metrics)
        return metrics

    def get_summary(self) -> dict:
        """Get monitoring summary."""
        if not self.history:
            return {}

        all_alerts = [a for m in self.history for a in m.get('alerts', [])]
        return {
            'total_batches': len(self.history),
            'total_samples': sum(m['sample_count'] for m in self.history),
            'avg_accuracy': round(np.mean([m['accuracy'] for m in self.history]), 4),
            'total_alerts': len(all_alerts)
        }

Simulating Performance Degradation

monitor = PerformanceMonitor(accuracy_threshold=0.05, latency_threshold_ms=50)
monitor.set_baseline(accuracy=0.92, latency_p95_ms=45.0)

# Simulate degrading performance over 5 batches
for i in range(5):
    # Performance degrades over time
    accuracy_factor = 1.0 - (i * 0.03)
    latency_factor = 1.0 + (i * 0.2)

    batch_size = 100
    preds = np.random.choice([0, 1], batch_size)
    labels = preds.copy()

    # Introduce errors
    n_errors = int(batch_size * (1 - 0.92 * accuracy_factor))
    error_idx = np.random.choice(batch_size, n_errors, replace=False)
    labels[error_idx] = 1 - labels[error_idx]

    latencies = np.random.gamma(4, 10 * latency_factor, batch_size)

    metrics = monitor.log_batch(preds, labels, latencies)
    alerts = f" ALERTS: {len(metrics['alerts'])}" if metrics['alerts'] else ""
    print(f"Batch {i+1}: accuracy={metrics['accuracy']:.3f}, "
          f"p95={metrics['latency_p95_ms']:.1f}ms{alerts}")

summary = monitor.get_summary()
print(f"\nTotal alerts: {summary['total_alerts']}")
Output
Baseline: accuracy=0.920, latency_p95=45.0ms
Batch 1: accuracy=0.930, p95=82.0ms
Batch 2: accuracy=0.900, p95=90.2ms
Batch 3: accuracy=0.870, p95=100.8ms ALERTS: 2
Batch 4: accuracy=0.840, p95=121.3ms ALERTS: 2
Batch 5: accuracy=0.810, p95=133.1ms ALERTS: 2

Total alerts: 6
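Single-batch accuracy is noisy at this sample size, so batch 2's dip below baseline went unalerted while batch 3 triggered. A common refinement (not part of `PerformanceMonitor` above; the window length and threshold here are illustrative) is to alert on a rolling mean instead of individual batches:

```python
from collections import deque

import numpy as np

window = deque(maxlen=5)  # most recent batch accuracies

def rolling_alert(batch_accuracy: float, baseline: float = 0.92,
                  threshold: float = 0.03) -> bool:
    """Alert only when the rolling-mean accuracy falls more than
    `threshold` below baseline, smoothing out single-batch noise."""
    window.append(batch_accuracy)
    return bool((baseline - np.mean(window)) > threshold)

# The degrading batch accuracies from the simulation above
alerts = [rolling_alert(a) for a in [0.93, 0.90, 0.87, 0.84, 0.81]]
```

The rolling mean trades a batch or two of detection latency for far fewer spurious alerts on an otherwise healthy model.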

Alerting Integration

CloudWatch Metrics (AWS)

import boto3

cloudwatch = boto3.client('cloudwatch')

def publish_metrics(metrics: dict, model_name: str):
    """Publish metrics to CloudWatch."""
    cloudwatch.put_metric_data(
        Namespace='MLModels',
        MetricData=[
            {
                'MetricName': 'Accuracy',
                'Value': metrics['accuracy'],
                'Unit': 'None',
                'Dimensions': [{'Name': 'ModelName', 'Value': model_name}]
            },
            {
                'MetricName': 'LatencyP95',
                'Value': metrics['latency_p95_ms'],
                'Unit': 'Milliseconds',
                'Dimensions': [{'Name': 'ModelName', 'Value': model_name}]
            },
            {
                'MetricName': 'DriftScore',
                'Value': metrics.get('drift_score', 0),
                'Unit': 'None',
                'Dimensions': [{'Name': 'ModelName', 'Value': model_name}]
            }
        ]
    )

Prometheus Metrics

from prometheus_client import Gauge, Counter, Histogram

# Define metrics
accuracy_gauge = Gauge('model_accuracy', 'Model accuracy', ['model_name'])
latency_histogram = Histogram('model_latency_seconds', 'Inference latency',
                              ['model_name'], buckets=[.01, .025, .05, .1, .25, .5, 1])
drift_counter = Counter('model_drift_detected', 'Drift events', ['model_name', 'feature'])

def record_metrics(model_name: str, accuracy: float, latency: float, drift_features: list):
    """Record Prometheus metrics."""
    accuracy_gauge.labels(model_name=model_name).set(accuracy)
    latency_histogram.labels(model_name=model_name).observe(latency)

    for feature in drift_features:
        drift_counter.labels(model_name=model_name, feature=feature).inc()

Retraining Triggers

Set up automatic retraining based on drift signals:

class RetrainingTrigger:
    """Determine when to retrain based on monitoring signals."""

    def __init__(self,
                 accuracy_threshold: float = 0.85,
                 drift_feature_threshold: int = 2,
                 consecutive_alerts_threshold: int = 3):
        self.accuracy_threshold = accuracy_threshold
        self.drift_feature_threshold = drift_feature_threshold
        self.consecutive_alerts_threshold = consecutive_alerts_threshold
        self.alert_count = 0

    def should_retrain(self, metrics: dict, drift_results: list) -> dict:
        """Determine if retraining is needed."""
        reasons = []

        # Check accuracy below threshold
        if metrics['accuracy'] < self.accuracy_threshold:
            reasons.append(f"Accuracy {metrics['accuracy']:.2%} below {self.accuracy_threshold:.2%}")

        # Check number of features with drift
        drifted_features = [r.feature for r in drift_results if r.is_drift]
        if len(drifted_features) >= self.drift_feature_threshold:
            reasons.append(f"Drift in {len(drifted_features)} features: {drifted_features}")

        # Track consecutive alerts
        if metrics.get('alerts'):
            self.alert_count += 1
        else:
            self.alert_count = 0

        if self.alert_count >= self.consecutive_alerts_threshold:
            reasons.append(f"{self.alert_count} consecutive batches with alerts")

        return {
            'should_retrain': len(reasons) > 0,
            'reasons': reasons,
            'urgency': 'high' if len(reasons) > 1 else 'medium' if reasons else 'none'
        }

Monitoring Dashboard Architecture

┌─────────────────────────────────────────────────────────────┐
│                    ML MONITORING STACK                       │
├─────────────────────────────────────────────────────────────┤
│                                                              │
│  ┌──────────────┐   ┌──────────────┐   ┌──────────────┐    │
│  │   Model      │   │   Data       │   │  Performance │    │
│  │   Drift      │   │   Drift      │   │  Metrics     │    │
│  │   Detector   │   │   Detector   │   │  Collector   │    │
│  └──────┬───────┘   └──────┬───────┘   └──────┬───────┘    │
│         │                  │                  │             │
│         └─────────────────┬┴─────────────────┘             │
│                           │                                 │
│                    ┌──────▼──────┐                         │
│                    │  Metrics    │                         │
│                    │  Aggregator │                         │
│                    └──────┬──────┘                         │
│                           │                                 │
│         ┌─────────────────┼─────────────────┐              │
│         │                 │                 │              │
│   ┌─────▼─────┐    ┌─────▼─────┐    ┌─────▼─────┐        │
│   │CloudWatch │    │Prometheus │    │  Slack    │        │
│   │           │    │+ Grafana  │    │  Alerts   │        │
│   └───────────┘    └───────────┘    └───────────┘        │
│                                                              │
└─────────────────────────────────────────────────────────────┘

Production Checklist

Before deploying monitoring:

Data Drift

  • Store reference distributions from training data
  • Choose appropriate statistical tests (KS for numerical, chi-squared for categorical)
  • Set alert thresholds based on business impact
  • Monitor all features or prioritize based on feature importance

Model Drift

  • Track prediction distribution over time
  • Monitor confidence/uncertainty metrics
  • Compare against baseline from validation set

Performance

  • Set accuracy baseline from validation
  • Set latency baseline from load testing
  • Configure alerting thresholds
  • Test alert delivery (don’t let alerts fail silently)

Retraining

  • Define retraining triggers
  • Automate data collection for retraining
  • Set up retraining pipeline
  • Implement gradual rollout for retrained models

What’s Next

You’re now monitoring your models in production. Complete the series:

  • ML Security - IAM roles, secrets management, VPC configuration

Full Code

Complete monitoring implementation: largo-tutorials/ml-monitoring
