Comprehensive guide to evaluating trained H-JEPA models.
```bash
# 1. Linear Probing (best metric for SSL)
python3.11 scripts/eval_linear_probe.py \
    --checkpoint results/my_exp/checkpoints/best_model.pt \
    --dataset cifar10 \
    --device mps

# 2. k-NN Evaluation (fast, no training)
python3.11 scripts/eval_knn.py \
    --checkpoint results/my_exp/checkpoints/best_model.pt \
    --dataset cifar10 \
    --k 1 5 10 20

# 3. Transfer Learning (comprehensive)
python3.11 scripts/eval_transfer.py \
    --checkpoint results/my_exp/checkpoints/best_model.pt \
    --datasets cifar10 cifar100 stl10 \
    --linear-probe \
    --knn
```

What it measures: Quality of learned features for supervised tasks
How it works:
- Freeze pretrained encoder
- Train a linear classifier on top
- Measure classification accuracy
Interpretation:
- 70-80%: Good representations
- 80-85%: Very good representations
- 85%+: Excellent representations (on CIFAR-10)
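The recipe above (freeze the encoder, train a linear classifier on top, measure accuracy) can be sketched in a few lines. This is a NumPy-only illustration, not the `eval_linear_probe.py` implementation: synthetic `frozen_features` stand in for the frozen encoder's output, and only the linear parameters `W`, `b` are trained.

```python
import numpy as np

# Linear-probe sketch: "frozen_features" stands in for the output of a frozen
# pretrained encoder; we train only a linear softmax classifier on top.
rng = np.random.default_rng(0)

num_classes, dim, n = 3, 16, 300
# Synthetic frozen features: one Gaussian blob per class.
centers = rng.normal(size=(num_classes, dim))
labels = rng.integers(0, num_classes, size=n)
frozen_features = centers[labels] + 0.3 * rng.normal(size=(n, dim))

W = np.zeros((dim, num_classes))   # the only trainable parameters
b = np.zeros(num_classes)

for _ in range(200):               # plain gradient descent on cross-entropy
    logits = frozen_features @ W + b
    logits -= logits.max(axis=1, keepdims=True)
    probs = np.exp(logits) / np.exp(logits).sum(axis=1, keepdims=True)
    grad = probs.copy()
    grad[np.arange(n), labels] -= 1.0     # d(cross-entropy)/d(logits)
    W -= 0.1 * frozen_features.T @ grad / n
    b -= 0.1 * grad.mean(axis=0)

acc = (np.argmax(frozen_features @ W + b, axis=1) == labels).mean()
print(f"probe accuracy: {acc:.2%}")
```

Because the encoder is frozen, the probe accuracy depends only on how linearly separable the learned features are.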
Usage:

```bash
python3.11 scripts/eval_linear_probe.py \
    --checkpoint PATH \
    --dataset cifar10 \
    --epochs 100 \
    --lr 0.001 \
    --hierarchy-level -1   # Use the top hierarchy level
```

Output:

```
Best validation accuracy: 82.45%
Final validation accuracy: 82.31%
Results saved to results/linear_probe/checkpoint_cifar10_results.json
```
What it measures: Clustering quality in feature space
How it works:
- Extract features from train/test sets
- For each test sample, find k nearest neighbors
- Predict via majority vote
Interpretation:
- No training required (fast!)
- Pure measure of representation quality
- Typically 3-5% lower than the linear probe (expected)
Usage:

```bash
python3.11 scripts/eval_knn.py \
    --checkpoint PATH \
    --dataset cifar10 \
    --k 1 5 10 20 \
    --temperature 0.07   # Softmax temperature
```

Output:

```
k= 1: 75.32%
k= 5: 78.91%
k=10: 79.45%
k=20: 79.12%
```
Best k value: Typically 10-20 for CIFAR, 5-10 for ImageNet
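The majority-vote scheme, including the soft `--temperature` weighting, can be sketched as follows. This is a NumPy illustration on synthetic blobs, not the actual `eval_knn.py` code; the exponential vote weighting is an assumption modeled on common SSL k-NN evaluators.

```python
import numpy as np

# Temperature-weighted k-NN classification on L2-normalized features.
def knn_predict(train_feats, train_labels, test_feats, k=5, temperature=0.07):
    # L2-normalize so the dot product is cosine similarity
    train_feats = train_feats / np.linalg.norm(train_feats, axis=1, keepdims=True)
    test_feats = test_feats / np.linalg.norm(test_feats, axis=1, keepdims=True)
    sims = test_feats @ train_feats.T                  # (n_test, n_train)
    num_classes = train_labels.max() + 1
    preds = []
    for row in sims:
        nn_idx = np.argsort(row)[-k:]                  # k nearest neighbors
        weights = np.exp(row[nn_idx] / temperature)    # soft vote weights
        scores = np.zeros(num_classes)
        for idx, w in zip(nn_idx, weights):
            scores[train_labels[idx]] += w
        preds.append(int(np.argmax(scores)))
    return np.array(preds)

rng = np.random.default_rng(0)
train = np.vstack([rng.normal(0, 0.1, (20, 8)) + 1,    # class 0 blob
                   rng.normal(0, 0.1, (20, 8)) - 1])   # class 1 blob
labels = np.array([0] * 20 + [1] * 20)
test = np.vstack([np.ones((3, 8)), -np.ones((3, 8))])
print(knn_predict(train, labels, test, k=5))
```

A lower temperature makes the vote concentrate on the very nearest neighbors; a higher one flattens it toward a plain majority vote.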
What it measures: Generalization to different datasets
How it works:
- Pretrain on dataset A (e.g., CIFAR-10)
- Evaluate on dataset B (e.g., STL-10)
- Compare performance
Usage:

```bash
python3.11 scripts/eval_transfer.py \
    --checkpoint PATH \
    --datasets cifar10 cifar100 stl10 \
    --linear-probe \
    --knn
```

Output:

```
Dataset     Linear Probe    k-NN (k=20)
CIFAR-10    82.45%          79.12%
CIFAR-100   54.32%          48.91%
STL-10      85.67%          82.34%
```
Interactive exploration of trained models.
```bash
python3.11 scripts/explore_model.py \
    --checkpoint PATH \
    --device mps \
    --output-dir results/exploration \
    --sample-idx 0
```

Generates:
- `attention_maps.png` - Multi-head attention patterns
- `hierarchical_representations.png` - Multi-scale features
- `masked_prediction.png` - Prediction demo
- `embedding_similarity.png` - Feature space analysis
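The feature-space analysis behind a plot like the embedding-similarity figure usually boils down to a pairwise cosine-similarity matrix. A minimal sketch, with random embeddings standing in for encoder outputs:

```python
import numpy as np

# Pairwise cosine-similarity matrix over a batch of embeddings (synthetic
# here; in practice these come from the frozen encoder).
rng = np.random.default_rng(0)
embeddings = rng.normal(size=(8, 32))      # 8 samples, feature dim 32

normed = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)
sim = normed @ normed.T                    # (8, 8), values in [-1, 1]

print(np.round(sim, 2))                    # diagonal is exactly 1 (self-similarity)
```

Visualizing `sim` as a heatmap shows whether semantically related samples cluster in feature space.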
Understand what features detect.
```bash
python3.11 scripts/visualize_features.py \
    --checkpoint PATH \
    --dataset cifar10 \
    --hierarchy 0 1 2 \
    --num-samples 200
```

Generates:
- Feature activation maps per channel
- Top activating patches
- Feature statistics and distributions
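Finding the top activating patches is essentially an argsort over per-patch activations. A sketch with random activations; the shapes (200 samples, a 14x14 patch grid, 64 channels) are illustrative assumptions:

```python
import numpy as np

# Top activating patches: given per-patch activations of shape
# (n_samples, n_patches, n_channels), find the patches with the highest
# activation for one channel across the whole dataset.
rng = np.random.default_rng(0)
acts = rng.normal(size=(200, 196, 64))       # 200 images, 14x14 patches, 64 channels

channel = 7
flat = acts[:, :, channel].ravel()           # all patch activations for this channel
top = np.argsort(flat)[-5:][::-1]            # flat indices of the 5 strongest
img_idx, patch_idx = np.unravel_index(top, acts.shape[:2])
for i, p in zip(img_idx, patch_idx):
    print(f"image {i}, patch {p}: activation {acts[i, p, channel]:.3f}")
```

In the real script, the `(image, patch)` pairs would then be cropped from the input images and tiled into a grid per channel.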
Visualize aggregated attention across layers.
```bash
python3.11 scripts/visualize_attention_rollout.py \
    --checkpoint PATH \
    --sample-idx 0 10 20 \
    --discard-ratio 0.1
```

Generates:
- Attention rollout heatmaps
- Layer-by-layer attention comparison
- Attention overlays on images
Note: This script disables Flash Attention so that attention weights can be extracted.
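Attention rollout itself is a short computation: average the heads, optionally discard the weakest weights (the `--discard-ratio` flag), add the identity for the residual connection, renormalize rows, and multiply across layers. A NumPy sketch; the exact discard rule used by the script is an assumption:

```python
import numpy as np

# Attention rollout: accumulate attention across layers, accounting for
# residual connections (Abnar & Zuidema-style rollout).
def attention_rollout(attn_per_layer, discard_ratio=0.1):
    n = attn_per_layer[0].shape[-1]
    rollout = np.eye(n)
    for attn in attn_per_layer:                    # attn: (heads, n, n)
        a = attn.mean(axis=0)                      # fuse heads
        if discard_ratio > 0:                      # drop the weakest weights
            flat = np.sort(a.ravel())
            thresh = flat[int(len(flat) * discard_ratio)]
            a = np.where(a < thresh, 0.0, a)
        a = a + np.eye(n)                          # residual (identity) connection
        a = a / a.sum(axis=1, keepdims=True)       # renormalize rows
        rollout = a @ rollout                      # multiply across layers
    return rollout

rng = np.random.default_rng(0)
# 3 layers, 4 heads, 10 tokens; each attention row is a distribution.
layers = [rng.dirichlet(np.ones(10), size=(4, 10)) for _ in range(3)]
r = attention_rollout(layers, discard_ratio=0.1)
print(r.shape)
```

Each row of the result stays a probability distribution, so it can be reshaped into the patch grid and drawn as a heatmap directly.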
Jupyter notebook for hands-on exploration.
```bash
cd notebooks
jupyter notebook explore_hjepa.ipynb
```

Features:
- Interactive sample browser
- Attention visualization
- Similarity search
- Quick k-NN evaluation
- Feature export
CIFAR-10:
- Linear Probe: 78-85%
- k-NN (k=20): 75-80%
CIFAR-100:
- Linear Probe: 50-60%
- k-NN (k=20): 45-55%
STL-10:
- Linear Probe: 80-88%
- k-NN (k=20): 78-85%
ImageNet-1K:
- Linear Probe: 65-75%
- k-NN (k=200): 55-65%
CIFAR-10 Linear Probe (ViT-Base):

| Method        | Accuracy |
|---------------|----------|
| SimCLR        | 76.5%    |
| DINO          | 81.3%    |
| MAE           | 78.2%    |
| H-JEPA (ours) | 82.5% (target) |
Training efficiency:
- H-JEPA: 100 epochs → 82%
- MAE: 200 epochs → 78%
- DINO: 300 epochs → 81%
Don't rely on a single metric:

```bash
# Comprehensive evaluation: linear probe (supervised) and k-NN (unsupervised)
python3.11 scripts/eval_transfer.py \
    --checkpoint PATH \
    --datasets cifar10 stl10 \
    --linear-probe \
    --knn
```

Evaluate generalization:

```bash
# Train on CIFAR-10, test on multiple datasets
python3.11 scripts/eval_transfer.py \
    --checkpoint pretrained_cifar10.pt \
    --datasets cifar10 cifar100 stl10
```

Good model: performance holds across datasets. Overfit model: performance drops significantly on new data.
Test each hierarchy level:
```bash
for level in -3 -2 -1; do
    python3.11 scripts/eval_linear_probe.py \
        --checkpoint PATH \
        --hierarchy-level $level \
        --output-dir results/hierarchy_$level
done
```

Expected:
- Level 1 (fine): Good for texture tasks
- Level 2 (mid): Good for parts
- Level 3 (coarse): Good for objects/scenes
Evaluate checkpoints during training:
```bash
# Automatic evaluation every epoch
./scripts/watch_and_explore.sh
```

Or manually:

```bash
for epoch in 10 20 30 40 50; do
    python3.11 scripts/eval_knn.py \
        --checkpoint results/checkpoints/checkpoint_epoch_$epoch.pt \
        --dataset cifar10
done
```

Plot the results:
- Accuracy vs. Epoch
- Find optimal checkpoint
Create a results file:

```yaml
# results/my_experiment/results.yaml
experiment:
  name: my_experiment
  date: 2025-01-17
  checkpoint: checkpoint_epoch_100.pt
config:
  model: vit_base_patch16_224
  dataset: cifar10
  epochs: 100
  batch_size: 64
  lr: 0.001
evaluation:
  cifar10:
    linear_probe: 82.45%
    knn_k20: 79.12%
  stl10:
    linear_probe: 85.67%
    knn_k20: 82.34%
notes: |
  - Flash Attention enabled
  - RoPE position embeddings
  - 3 hierarchies
  - Training time: 4.5 hours on M1 Max
```

Accuracy lower than expected? Possible causes:
- Undertrained - train longer
- Bad hyperparameters - tune LR, weight decay
- Too small model - use larger ViT
- Wrong hierarchy - try different levels
Solutions:

```bash
# 1. Train longer
python3.11 scripts/train.py --config CONFIG --epochs 200

# 2. Tune hyperparameters
python3.11 scripts/train.py --config CONFIG --lr 0.003 --weight-decay 0.1

# 3. Larger model
# Edit config.yaml: encoder_type: vit_large_patch16_224

# 4. Try all hierarchy levels
for h in -3 -2 -1; do
    python3.11 scripts/eval_linear_probe.py --checkpoint PATH --hierarchy-level $h
done
```

Expected: k-NN is typically 3-5% lower than the linear probe.
Too large gap (>10%):
- Features not well normalized
- Need more neighbors (try k=50)
- Temperature tuning needed
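The "features not well normalized" case is easy to check: if feature norms vary widely, Euclidean nearest neighbors are dominated by magnitude, and L2-normalizing (i.e., using cosine similarity) can change the neighbor set entirely. A quick NumPy illustration with synthetic features:

```python
import numpy as np

# Compare raw Euclidean neighbors vs cosine neighbors when feature norms
# vary wildly (a common symptom of poorly normalized features).
rng = np.random.default_rng(0)
feats = rng.normal(size=(100, 32))
feats *= rng.uniform(0.1, 10.0, size=(100, 1))   # wildly varying norms

q = feats[0]                                      # query sample
d_raw = np.linalg.norm(feats[1:] - q, axis=1)     # Euclidean distances

normed = feats / np.linalg.norm(feats, axis=1, keepdims=True)
d_cos = 1.0 - normed[1:] @ normed[0]              # cosine distances

norms = np.linalg.norm(feats, axis=1)
print("norm range:", norms.min().round(2), "to", norms.max().round(2))
print("nearest (raw):   ", 1 + np.argsort(d_raw)[:5])
print("nearest (cosine):", 1 + np.argsort(d_cos)[:5])
```

If the two neighbor lists disagree heavily, normalize features (or use cosine similarity) before the k-NN evaluation.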
Solutions:

```bash
# Try more neighbors
python3.11 scripts/eval_knn.py --k 10 20 50 100

# Tune the temperature
python3.11 scripts/eval_knn.py --temperature 0.05   # or 0.1, 0.2
```

If evaluation runs out of memory, solutions:

```bash
# Reduce batch size
python3.11 scripts/eval_linear_probe.py --batch-size 64   # default: 256

# Use CPU for the linear probe (still fast)
python3.11 scripts/eval_linear_probe.py --device cpu

# Extract features once, train the classifier separately
# (Features are cached in the eval scripts)
```
```bash
#!/bin/bash
# evaluate_all.sh
CHECKPOINT="results/final_model/checkpoints/best_model.pt"

# Linear probing on all datasets
for dataset in cifar10 cifar100 stl10; do
    python3.11 scripts/eval_linear_probe.py \
        --checkpoint $CHECKPOINT \
        --dataset $dataset \
        --epochs 100 \
        --output-dir results/paper/linear_probe
done

# k-NN evaluation
python3.11 scripts/eval_knn.py \
    --checkpoint $CHECKPOINT \
    --dataset cifar10 \
    --k 1 5 10 20 50 100 \
    --output-dir results/paper/knn

# Transfer learning
python3.11 scripts/eval_transfer.py \
    --checkpoint $CHECKPOINT \
    --datasets cifar10 cifar100 stl10 \
    --linear-probe \
    --knn \
    --output-dir results/paper/transfer

# Visualizations
python3.11 scripts/visualize_features.py \
    --checkpoint $CHECKPOINT \
    --output-dir results/paper/visualizations

python3.11 scripts/visualize_attention_rollout.py \
    --checkpoint $CHECKPOINT \
    --sample-idx 0 10 20 30 40 \
    --output-dir results/paper/attention
```
```bash
# Test each component
for config in baseline +rope +flash +hierarchies; do
    python3.11 scripts/train.py --config configs/ablation_$config.yaml
    python3.11 scripts/eval_linear_probe.py \
        --checkpoint results/ablation_$config/checkpoints/best.pt \
        --dataset cifar10
done
```
```python
# scripts/generate_tables.py
import pandas as pd

# Ablation results (linear-probe accuracy, %)
results = {
    'Method': ['Baseline', '+RoPE', '+Flash', '+Hierarchies'],
    'CIFAR-10': [78.2, 79.5, 79.4, 82.5],
    'CIFAR-100': [52.1, 53.8, 53.9, 56.2],
    'STL-10': [82.3, 83.9, 84.1, 85.7],
}

df = pd.DataFrame(results)
print(df.to_latex(index=False))     # For the paper
print(df.to_markdown(index=False))  # For the README
```

After evaluation:
- Document results - Create results summary
- Compare to baselines - How does it stack up?
- Identify improvements - What could be better?
- Share checkpoints - Make models available
- Publish findings - Blog post, paper, etc.
Further Reading:
- `docs/TRAINING_GUIDE.md` - Training best practices
- `docs/ARCHITECTURE.md` - Understanding the model
- `notebooks/explore_hjepa.ipynb` - Interactive exploration