Mamba for Predictive Maintenance: State Space Models vs Transformers

Deep Dive | 25 min read | December 29, 2025

Compare Mamba's selective state space architecture against LSTM and Transformer for hard drive failure prediction. Learn when SSMs beat attention.

State Space Models (SSMs) like Mamba are challenging Transformers as the go-to architecture for sequence modeling. But do they actually work better for real-world time series tasks like predictive maintenance?

In this tutorial, we’ll implement a Mamba-based failure prediction model using the Backblaze hard drive dataset and compare it head-to-head against LSTM and Transformer baselines. Spoiler: the results are surprising.

What We’re Building

A comparison study using Backblaze’s Q4 2023 production data (~1.9M drive-day records after preprocessing). We explore two problem framings:

Binary Classification (0.01% positive rate):

| Model | Test AUC | Test F2 | Precision | Recall |
|---|---|---|---|---|
| XGBoost | 0.920 | 0.095 | 2.2% | 54% |
| Mamba | 0.901 | 0.017 | 0.36% | 69% |

Time-to-Failure Regression (predicting days until failure):

| Model | MAE (days) | (unlabeled) | Within 5 days |
|---|---|---|---|
| Mamba | 8.74 | 0.478 | 48.2% |
| LSTM | 10.12 | 0.327 | 39.3% |
| Transformer | 9.52 | 0.392 | 39.3% |

Why State Space Models?

Transformers revolutionized NLP, but they have a problem: quadratic complexity. Self-attention computes pairwise relationships between all tokens, giving O(n²) time and memory complexity.

| Architecture | Complexity | Memory | Parallelizable |
|---|---|---|---|
| Transformer | O(n²) | High | Yes |
| LSTM | O(n) | Moderate | No |
| Mamba | O(n) | Low | Yes |

Mamba gets the best of both worlds: linear complexity like RNNs, but parallelizable like Transformers.
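To make the scaling difference concrete, here is a back-of-the-envelope sketch that counts the pairwise scores in one attention head against the fixed-size state a Mamba layer carries between steps. The function names are illustrative, and the state-size formula (expand × d_model × d_state) is a rough count that ignores the small convolution buffer.

```python
def attention_scores(seq_len: int) -> int:
    """Entries in one head's attention matrix: every token scores every token."""
    return seq_len * seq_len


def ssm_state_size(d_model: int, d_state: int = 16, expand: int = 2) -> int:
    """Rough size of the recurrent state; it does not grow with sequence length."""
    return expand * d_model * d_state


for n in (3, 1_000, 8_000):
    print(f"n={n:>5}: attention scores={attention_scores(n):>12,}  "
          f"SSM state={ssm_state_size(64):,}")
```

At this tutorial's window length of 3 the quadratic term is tiny, so the complexity advantage only starts to matter on much longer sequences.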

The Selective State Space Mechanism

Traditional state space models use fixed transition matrices. Mamba’s key innovation is making these input-dependent:

# Traditional SSM (fixed parameters)
h_t = A @ h_{t-1} + B @ x_t
y_t = C @ h_t

# Mamba SSM (input-dependent selection)
delta_t, B_t, C_t = project(x_t)  # Parameters depend on input!
h_t = exp(delta_t * A) @ h_{t-1} + delta_t * B_t @ x_t
y_t = C_t @ h_t

This selectivity lets Mamba filter out irrelevant information from sequences—crucial for sensor data where most readings are noise.
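The recurrence above can be sketched as a plain sequential loop. This is a minimal single-channel reference with a diagonal A, written for readability only: the projection weights `w_delta`, `w_b`, and `w_c` are hypothetical stand-ins for learned layers, and the real mamba-ssm kernels compute the same kind of recurrence with a fused parallel scan on the GPU.

```python
import math


def softplus(z: float) -> float:
    """Smooth positive function; keeps the step size delta above zero."""
    return math.log1p(math.exp(z))


def selective_scan(xs, a_diag, w_delta, w_b, w_c):
    """Run a one-channel selective SSM over the scalar sequence xs.

    a_diag: negative diagonal entries of A, one per state dimension.
    w_delta, w_b, w_c: toy projection weights (hypothetical, not learned here).
    """
    h = [0.0] * len(a_diag)
    ys = []
    for x in xs:
        delta = softplus(w_delta * x)      # input-dependent step size
        b = [w * x for w in w_b]           # input-dependent B_t
        c = [w * x for w in w_c]           # input-dependent C_t
        # h_t = exp(delta * A) * h_{t-1} + delta * B_t * x_t, with diagonal A
        h = [math.exp(delta * a) * h_i + delta * b_i * x
             for a, h_i, b_i in zip(a_diag, h, b)]
        # y_t = C_t @ h_t
        ys.append(sum(c_i * h_i for c_i, h_i in zip(c, h)))
    return ys


outputs = selective_scan([1.0, 0.0, 2.0], a_diag=[-1.0, -0.5],
                         w_delta=1.0, w_b=[1.0, 0.5], w_c=[1.0, 1.0])
print(outputs)
```

Because delta, B, and C are recomputed from each input, a reading can both change how fast the old state decays and control how much of itself is written into the state.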

The Backblaze Dataset

We’re using Backblaze’s Q4 2023 hard drive data. After loading and preprocessing, we work with 1.86 million daily SMART readings from 266,531 drives.

from data_pipeline import download_dataset, extract_dataset, load_daily_files

# Download and extract (1GB)
zip_path = download_dataset("data_Q4_2023.zip")
data_dir = extract_dataset(zip_path)

# Load daily files
df = load_daily_files(data_dir, max_files=7)
print(f"Loaded {len(df):,} drive-day records")
Output

Loaded 1,862,455 drive-day records
Preprocessed: 1,862,455 records, 266,531 drives

The Class Imbalance Challenge

Hard drives rarely fail: only 0.12% of samples are positive. A model that always predicts "healthy" would score over 99.8% accuracy while catching nothing, so we report F2, which weights recall more heavily than precision:

print(f"Failure rate: {100*df['failure'].mean():.2f}%")
print(f"Positive samples (will fail in 7 days): {y_train.sum():,.0f}")
print(f"Negative samples: {len(y_train) - y_train.sum():,.0f}")
Output

Failure rate: 0.12%
Positive samples (will fail in 7 days): 22
Negative samples: 18,558
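A short sketch of why accuracy misleads here, using the sample counts printed above and the XGBoost precision and recall from the results table. The helper name `f_beta` is illustrative; scikit-learn's `fbeta_score` computes the same quantity from labels and predictions.

```python
def f_beta(precision: float, recall: float, beta: float = 2.0) -> float:
    """F-beta score; beta=2 weights recall more heavily than precision."""
    if precision == 0.0 and recall == 0.0:
        return 0.0
    b2 = beta * beta
    return (1 + b2) * precision * recall / (b2 * precision + recall)


# A model that never predicts "failure", scored on the counts printed above.
positives, negatives = 22, 18_558
accuracy_always_negative = negatives / (positives + negatives)
print(f"Accuracy of always-negative model: {accuracy_always_negative:.4f}")
print(f"F2 of always-negative model: {f_beta(0.0, 0.0):.4f}")

# Precision and recall from the XGBoost row of the results table.
print(f"F2 at 2.2% precision, 54% recall: {f_beta(0.022, 0.54):.3f}")
```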

Implementing the Mamba Model

Architecture Overview

Our Conv-Mamba hybrid follows the pattern from the original tutorial’s Conv-Transformer:

Input (batch, seq_len=3, features=12)
  → Conv1D (extract local patterns across SMART attributes)
  → Mamba blocks (capture temporal dependencies with O(n) complexity)
  → Global average pooling
  → Classification head
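The input shape implies that each sample is a short window of consecutive daily readings. The sketch below shows one way such windows and their "fails within 7 days" labels could be built for a single drive. The function name and the exact labeling rule are assumptions for illustration, not necessarily what data_pipeline.py does.

```python
def make_windows(daily_rows, failure_day=None, seq_len=3, horizon=7):
    """Slice one drive's daily feature rows into overlapping windows.

    daily_rows: list of per-day feature lists, oldest first (index = day).
    failure_day: day index on which the drive failed, or None if it survived.
    A label is 1 when the failure falls within `horizon` days of the
    window's last day.
    """
    windows, labels = [], []
    for end in range(seq_len - 1, len(daily_rows)):
        windows.append(daily_rows[end - seq_len + 1 : end + 1])
        days_left = None if failure_day is None else failure_day - end
        labels.append(int(days_left is not None and 0 <= days_left <= horizon))
    return windows, labels


# Toy drive: 12 days of two features each, failing on day 11.
rows = [[float(day), 0.0] for day in range(12)]
windows, labels = make_windows(rows, failure_day=11)
print(len(windows), labels)
```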

The Conv1D Preprocessing

Just like with Transformers, we found that Conv1D preprocessing dramatically improves Mamba’s performance:

# Conv1D extracts cross-feature patterns
self.conv1 = nn.Conv1d(input_dim, d_model, kernel_size=3, padding=1)
self.conv_norm = nn.LayerNorm(d_model)

Using the Official Mamba Implementation

The mamba-ssm package provides CUDA-optimized Mamba layers:

import torch
import torch.nn as nn
from mamba_ssm import Mamba

class OfficialMambaClassifier(nn.Module):
    def __init__(self, input_dim, d_model=64, d_state=16,
                 d_conv=4, expand=2, num_layers=2, dropout=0.1):
        super().__init__()

        # Conv1D feature extractor
        self.conv1 = nn.Conv1d(input_dim, d_model, kernel_size=3, padding=1)
        self.conv_norm = nn.LayerNorm(d_model)

        # Official Mamba layers
        self.mamba_layers = nn.ModuleList([
            Mamba(
                d_model=d_model,
                d_state=d_state,  # SSM state dimension
                d_conv=d_conv,    # Local convolution width
                expand=expand,    # Expansion factor
            )
            for _ in range(num_layers)
        ])

        self.norms = nn.ModuleList([
            nn.LayerNorm(d_model) for _ in range(num_layers)
        ])
        self.dropout = nn.Dropout(dropout)

        # Classification head
        self.classifier = nn.Sequential(
            nn.LayerNorm(d_model),
            nn.Linear(d_model, d_model // 2),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_model // 2, 1)
        )

    def forward(self, x):
        # Conv1D preprocessing
        x = x.transpose(1, 2)
        x = torch.relu(self.conv1(x))
        x = x.transpose(1, 2)
        x = self.conv_norm(x)

        # Mamba layers with residual connections
        for mamba, norm in zip(self.mamba_layers, self.norms):
            residual = x
            x = norm(x)
            x = mamba(x)
            x = self.dropout(x) + residual

        # Pool and classify
        x = x.mean(dim=1)  # Global average pooling
        logits = self.classifier(x).squeeze(-1)
        return torch.sigmoid(logits)

With extreme class imbalance, the default 0.5 threshold produces zero positive predictions. We search for the optimal threshold:

import numpy as np
from sklearn.metrics import precision_score, recall_score

def find_optimal_threshold(labels, probs):
    """Find threshold that maximizes F2 score."""
    best_f2 = 0
    best_threshold = 0.5
    for threshold in np.arange(0.05, 0.95, 0.05):
        preds = (probs >= threshold).astype(int)
        precision = precision_score(labels, preds, zero_division=0)
        recall = recall_score(labels, preds, zero_division=0)
        f2 = (1 + 4) * precision * recall / (4 * precision + recall + 1e-9)
        if f2 > best_f2:
            best_f2 = f2
            best_threshold = threshold
    return best_threshold, best_f2
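One caveat worth spelling out: the threshold should be chosen on validation data and then frozen before scoring the test set, otherwise the test F2 is optimistically biased. Here is a dependency-free sketch of that workflow on toy data; the helper names and the example scores are made up for illustration.

```python
def f2_at_threshold(labels, probs, threshold):
    """F2 of the hard predictions obtained by thresholding probs."""
    tp = sum(1 for y, p in zip(labels, probs) if p >= threshold and y == 1)
    fp = sum(1 for y, p in zip(labels, probs) if p >= threshold and y == 0)
    fn = sum(1 for y, p in zip(labels, probs) if p < threshold and y == 1)
    denom = 5 * tp + 4 * fn + fp          # F2 written in terms of counts
    return 5 * tp / denom if denom else 0.0


def pick_threshold(val_labels, val_probs, grid):
    """Return the first grid threshold with the highest validation F2."""
    return max(grid, key=lambda t: f2_at_threshold(val_labels, val_probs, t))


grid = [round(0.05 * k, 2) for k in range(1, 19)]   # 0.05 through 0.90
val_labels = [0, 0, 0, 1, 0, 1]
val_probs = [0.10, 0.40, 0.20, 0.80, 0.55, 0.65]
test_labels = [0, 1, 0, 0]
test_probs = [0.30, 0.70, 0.62, 0.10]

threshold = pick_threshold(val_labels, val_probs, grid)   # validation only
test_f2 = f2_at_threshold(test_labels, test_probs, threshold)
print(threshold, test_f2)
```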

Training Results

python model_mamba.py
Output

============================================================
Mamba Model - Hard Drive Failure Prediction

Data shapes:
  X_train: torch.Size([1647755, 3, 12])
  X_val: torch.Size([1620903, 3, 12])
  X_test: torch.Size([4280336, 3, 12])
Positive rate (train): 0.38%

Training on cuda
Using official mamba-ssm implementation (faster)

Parameters: 70,273

Training Mamba…
Epoch 1: Train Loss=0.3107, Val Loss=0.2389, Val F2=0.0572 (thr=0.90), Val AUC=0.8356
Epoch 2: Train Loss=0.1930, Val Loss=0.1969, Val F2=0.0474 (thr=0.90), Val AUC=0.8260
Epoch 3: Train Loss=0.1486, Val Loss=0.1558, Val F2=0.0525 (thr=0.90), Val AUC=0.8265
Epoch 4: Train Loss=0.1272, Val Loss=0.1533, Val F2=0.0451 (thr=0.90), Val AUC=0.8187
Epoch 5: Train Loss=0.1121, Val Loss=0.1334, Val F2=0.0485 (thr=0.90), Val AUC=0.8212
Epoch 6: Train Loss=0.0792, Val Loss=0.1020, Val F2=0.0532 (thr=0.90), Val AUC=0.8264

============================================================
Test Set Evaluation

Precision: 0.0036
Recall: 0.6900
F1: 0.0071
F2: 0.0174
AUC: 0.9014

Model Comparison

Head-to-Head Results

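| Metric | XGBoost | Transformer | LSTM | Mamba |
|---|---|---|---|---|
| Parameters | - | 70,017 | 139,393 | 70,273 |
| Test AUC | 0.920 | 0.916 | 0.907 | 0.901 |
| Test F2 | 0.095 | 0.006 | 0.006 | 0.017 |
| Precision | 2.2% | 0.12% | 0.11% | 0.36% |
| Recall | 54% | 80% | 78% | 69% |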

Analysis

XGBoost dominates with engineered features. Traditional ML with handcrafted rolling statistics (7-day mean, std, delta) achieves about 5.6x the F2 of Mamba, the best neural model, and roughly 16x the F2 of the LSTM and Transformer. With 0.01% positive rate, feature engineering beats end-to-end learning.

All neural models achieve similar AUC (~0.90). They can all rank failures well, but struggle to find a good precision-recall tradeoff at any threshold. High recall comes at the cost of massive false positives.

Mamba offers roughly 3x better F2 than LSTM/Transformer (0.017 vs 0.006). Among neural models, Mamba reaches a better precision-recall operating point even though its AUC is the lowest of the three, but the advantage is modest compared to the XGBoost gap.
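For reference, the rolling statistics mentioned in the XGBoost analysis can be sketched in a few lines. This is a stdlib-only illustration for one SMART attribute of one drive; the exact feature definitions used in the XGBoost baseline are not shown in this article, so the window handling and the `delta` definition here are assumptions.

```python
from statistics import mean, pstdev


def rolling_features(values, window=7):
    """Per-day rolling mean, std, and delta over the trailing `window` days.

    values: one SMART attribute for one drive, oldest first.
    Early days use however many readings are available so far.
    """
    features = []
    for day in range(len(values)):
        recent = values[max(0, day - window + 1) : day + 1]
        features.append({
            "mean": mean(recent),
            "std": pstdev(recent),
            "delta": recent[-1] - recent[0],   # change across the window
        })
    return features


# Toy series: a reallocated-sector count that starts climbing on day 5.
smart_5 = [0, 0, 0, 0, 0, 2, 5, 9]
for day, feats in enumerate(rolling_features(smart_5)):
    print(day, feats)
```

Features like these hand the tree model the trend directly, whereas a neural model fed three raw days has to infer it from very few positive examples.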

Improving Neural Models with Focal Loss

Standard BCE loss treats all of the millions of negatives equally, drowning out the signal from rare failures. Focal Loss downweights easy examples and focuses learning on hard boundary cases:

import torch
import torch.nn as nn
import torch.nn.functional as F

class FocalLoss(nn.Module):
    """Focal Loss - focuses on hard examples by downweighting easy ones."""
    def __init__(self, alpha=0.25, gamma=2.0):
        super().__init__()
        self.alpha = alpha  # Weight for positive class
        self.gamma = gamma  # Focusing parameter

    def forward(self, pred, target):
        bce = F.binary_cross_entropy(pred, target, reduction='none')
        pt = torch.where(target == 1, pred, 1 - pred)
        alpha_t = torch.where(target == 1, self.alpha, 1 - self.alpha)
        focal_weight = alpha_t * (1 - pt) ** self.gamma
        return (focal_weight * bce).mean()
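To see the downweighting in numbers, the sketch below evaluates the same formula for two single examples in plain Python: a confidently correct negative and a badly missed positive. The helper names and the two example probabilities are illustrative choices, not values from the experiment.

```python
import math


def bce_term(prob: float, target: int) -> float:
    """Binary cross-entropy for one example."""
    pt = prob if target == 1 else 1.0 - prob   # probability of the true class
    return -math.log(pt)


def focal_term(prob: float, target: int,
               alpha: float = 0.25, gamma: float = 2.0) -> float:
    """Focal loss for one example, mirroring FocalLoss.forward above."""
    pt = prob if target == 1 else 1.0 - prob
    alpha_t = alpha if target == 1 else 1.0 - alpha
    return alpha_t * (1.0 - pt) ** gamma * bce_term(prob, target)


examples = {
    "easy negative": (0.01, 0),   # healthy drive scored as clearly healthy
    "hard positive": (0.10, 1),   # failing drive the model mostly misses
}
for name, (p, y) in examples.items():
    print(f"{name}: BCE={bce_term(p, y):.5f}  focal={focal_term(p, y):.7f}")
```

The easy negative's contribution shrinks by several orders of magnitude while the hard positive keeps a sizeable share of its loss, so the many easy negatives no longer dominate the gradient.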

BCE vs Focal Loss Results

| Loss Function | AUC | Precision | Recall | F2 | Predictions |
|---|---|---|---|---|---|
| BCE | 0.887 | 0.7% | 43% | 0.034 | 24,077 |
| Focal (γ=2) | 0.867 | 2.7% | 21% | 0.089 | 3,044 |

Focal Loss improves F2 by 2.6x and precision by 3.8x. The model flags 8x fewer drives as failing (3,044 vs 24,077 positive predictions), and a larger share of those flags are correct. The tradeoff is lower recall (43% → 21%) and a slightly lower AUC (0.887 → 0.867), but for F2 this is a better operating point.

Why Our Results Differ from Literature

Some papers report 95% precision with 67% recall on Backblaze data. Why can’t we achieve this?

Key Differences

| Factor | Literature (2017) | Our Experiment (2024) |
|---|---|---|
| Failures in training | 2,586 | 93 |
| Drive models | Single (or few) | 81 mixed |
| Data year | 2014 | Q4 2023 |
| Failure rate | Higher | 0.01% |

1. We have 28x fewer failures. The 2017 paper had 2,586 failures to learn from. With only 93 failures in Q4 2023, models can’t learn robust patterns.

2. Mixed drive models hurt performance. Research consistently shows that SMART attributes mean different things for different drive models. Training on ST4000DM000 alone improves precision from 0.7% to 3-6%.

3. Modern drives fail differently. 2023 drives may have different failure modes than 2014 drives, making historical patterns less predictive.

What Actually Works

Based on recent literature and our experiments:

  1. Single drive model training - Don’t mix ST4000DM000 with TOSHIBA MG07ACA14TA
  2. Modified Focal Loss + Weighted CE - Combines focal loss with class weights
  3. K-means undersampling - Cluster negatives, sample from each cluster
  4. Survival analysis framing - Predict time-to-failure instead of binary classification
  5. Key SMART attributes - Focus on 5, 187, 188, 197, 198 which Backblaze identified as most predictive

Better Framing: Time-to-Failure Regression

Binary classification with 0.01% positive rate doesn’t play to Mamba’s strengths. What if we reframe the problem as regression: predicting how many days until a drive fails?

This approach:

  1. Uses continuous targets (days) instead of binary labels
  2. Trains only on failed drives (where we know the actual TTF)
  3. Leverages Mamba’s ability to model temporal degradation patterns
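The target construction in the list above can be sketched as follows. This assumes each record carries a serial number and a date, and that the failure date of every failed drive is known; the function name and the toy serial numbers are hypothetical.

```python
from datetime import date


def build_ttf_samples(rows, failure_dates):
    """Attach a days-until-failure target to each record of a failed drive.

    rows: iterable of (serial_number, record_date) pairs.
    failure_dates: dict mapping serial_number -> failure date (failed drives only).
    Drives that never failed have no known TTF and are skipped.
    """
    samples = []
    for serial, record_date in rows:
        failed_on = failure_dates.get(serial)
        if failed_on is None or record_date > failed_on:
            continue
        samples.append((serial, record_date, (failed_on - record_date).days))
    return samples


rows = [
    ("DRIVE_A", date(2023, 12, 10)),
    ("DRIVE_A", date(2023, 12, 18)),
    ("DRIVE_B", date(2023, 12, 10)),   # never failed, so it is excluded
]
failure_dates = {"DRIVE_A": date(2023, 12, 20)}
ttf_samples = build_ttf_samples(rows, failure_dates)
for sample in ttf_samples:
    print(sample)
```

Every retained record contributes a target, which is why this framing extracts more training signal from the same small set of failed drives.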

Time-to-Failure Results

Training on ST4000DM000 drives that actually failed, predicting days until failure:

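| Model | MAE (days) | RMSE | (unlabeled) | Within 5d | Within 10d |
|---|---|---|---|---|---|
| Mamba | 8.74 | 12.64 | 0.478 | 48.2% | 69.3% |
| Transformer | 9.52 | 13.64 | 0.392 | 39.3% | 68.0% |
| LSTM | 10.12 | 14.35 | 0.327 | 39.3% | 65.9% |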
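For readers who want to reproduce these columns on their own predictions, here is a minimal sketch of the metrics. It assumes "Within 5d" means the absolute error is at most 5 days, which is the natural reading but is not stated explicitly; the function name and toy values are illustrative.

```python
import math


def regression_report(y_true, y_pred, tolerances=(5, 10)):
    """MAE, RMSE, and the share of predictions within each tolerance in days."""
    errors = [abs(t - p) for t, p in zip(y_true, y_pred)]
    report = {
        "mae": sum(errors) / len(errors),
        "rmse": math.sqrt(sum(e * e for e in errors) / len(errors)),
    }
    for tol in tolerances:
        report[f"within_{tol}d"] = sum(e <= tol for e in errors) / len(errors)
    return report


# Toy example: true vs predicted days until failure for four drives.
report = regression_report([10, 20, 30, 40], [12, 27, 30, 55])
print(report)
```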

Why Regression Works Better

The key insight: survival analysis framing lets Mamba learn from temporal patterns rather than trying to discriminate rare events from noise.

# Instead of: "Will this drive fail in the next 7 days?" (binary, 0.01% positive)
# We ask: "How many days until this drive fails?" (continuous, all training samples useful)

import torch
import torch.nn as nn
from mamba_ssm import Mamba

class MambaRegressor(nn.Module):
    def __init__(self, input_dim, d_model=64):
        super().__init__()
        self.input_norm = nn.LayerNorm(input_dim)
        self.proj = nn.Linear(input_dim, d_model)
        self.mamba = Mamba(d_model=d_model, d_state=16, d_conv=4, expand=2)
        self.fc = nn.Linear(d_model, 1)

    def forward(self, x):
        x = self.input_norm(x)
        x = self.proj(x)
        x = self.mamba(x)
        return torch.clamp(self.fc(x[:, -1, :]).squeeze(-1), 0, 100)

With this approach, every training sample carries a degradation signal, so Mamba's selective state updates have a trajectory in the SMART readings to track rather than a handful of rare events to pick out of noise.

When to Use Mamba

Based on our experiments and recent literature:

| Use Case | Best Architecture |
|---|---|
| Tabular data with extreme imbalance | XGBoost (feature engineering wins) |
| Time-to-failure / survival analysis | Mamba (temporal degradation modeling) |
| Long sequences (1000+ steps) | Mamba (linear complexity wins) |
| Very long sequences (8k+ tokens) | Mamba-Transformer hybrid |
| Balanced classification tasks | Transformer or Mamba |
| Need attention interpretability | Transformer |
| Production inference speed | Mamba (no KV cache) |

Full Code

The complete implementation is available on GitHub:

git clone https://github.com/StoliRocks/largo-tutorials
cd largo-tutorials/predictive-maintenance/hard-drive-failure

# Install dependencies
pip install torch mamba-ssm pandas scikit-learn tqdm requests

# Run the pipeline
python data_pipeline.py  # Download and preprocess
python model_mamba.py    # Train Mamba (binary classification)
python train_ttf_regression.py  # Train all models (time-to-failure regression)
