-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathpipeline.py
More file actions
executable file
·217 lines (180 loc) · 7.37 KB
/
pipeline.py
File metadata and controls
executable file
·217 lines (180 loc) · 7.37 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
#!/usr/bin/env python3
"""
Full ARC Diffusion Pipeline Runner
Runs the complete training pipeline in sequence:
1. Diffusion model training (with integrated size head)
2. Model evaluation
Usage:
uv run python pipeline.py --config configs/test_config.json
uv run python pipeline.py --config configs/test_config.json --skip-training
"""
import argparse
import json
import sys
import subprocess
from pathlib import Path
import time
import datetime
def run_command(command: list, description: str, cwd: str = None) -> bool:
"""Run a command and return True if successful."""
print(f"\n{'='*60}")
print(f"🚀 {description}")
print(f"Command: {' '.join(command)}")
print(f"{'='*60}")
start_time = time.time()
try:
result = subprocess.run(
command,
cwd=cwd,
check=True,
capture_output=False, # Show output in real-time
text=True
)
elapsed = time.time() - start_time
print(f"\n✅ {description} completed successfully in {elapsed:.1f}s")
return True
except subprocess.CalledProcessError as e:
elapsed = time.time() - start_time
print(f"\n❌ {description} failed after {elapsed:.1f}s")
print(f"Exit code: {e.returncode}")
return False
except KeyboardInterrupt:
print(f"\n⏹️ {description} interrupted by user")
return False
except Exception as e:
elapsed = time.time() - start_time
print(f"\n💥 {description} failed with error after {elapsed:.1f}s: {e}")
return False
def main() -> None:
    """Run the full ARC diffusion pipeline: training, evaluation, HF upload.

    Exits with status 1 on a bad/missing config, a failed training step,
    a missing model checkpoint, or a failed evaluation. The Hugging Face
    upload step is best-effort and never aborts the pipeline.
    """
    parser = argparse.ArgumentParser(
        description="Run full ARC diffusion training pipeline",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  # Run full pipeline (training + evaluation)
  uv run pipeline.py --config configs/my_config.json

  # Skip training, only run evaluation
  uv run pipeline.py --config configs/my_config.json --skip-training

  # Only run training
  uv run pipeline.py --config configs/my_config.json --skip-evaluation
""",
    )
    parser.add_argument("--config", required=True, help="Path to config JSON file")
    parser.add_argument("--skip-training", action="store_true", help="Skip diffusion model training")
    parser.add_argument("--skip-evaluation", action="store_true", help="Skip evaluation")
    parser.add_argument("--eval-limit", type=int, default=0, help="Limit evaluation to N tasks (default: 0 for all)")
    parser.add_argument("--prefer-best", action="store_true", help="Use best_model.pt for evaluation instead of final_model.pt")
    args = parser.parse_args()

    pipeline_start = time.time()  # for the total-runtime report at the end

    # Validate config file
    config_path = Path(args.config)
    if not config_path.exists():
        print(f"❌ Config file not found: {config_path}")
        sys.exit(1)

    # Load and validate config
    try:
        with open(config_path, 'r') as f:
            config = json.load(f)
    except Exception as e:
        print(f"❌ Failed to load config file: {e}")
        sys.exit(1)

    # Extract key paths
    output_dir = Path(config.get('output', {}).get('output_dir', 'outputs/default'))

    print(f"🎯 ARC Diffusion Full Pipeline")
    print(f"📅 Started at: {datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
    print(f"📁 Config: {config_path}")
    print(f"📂 Output dir: {output_dir}")
    print(f"⚡ Evaluation limit: {args.eval_limit}")

    # Track which steps to run
    steps_to_run = []
    if not args.skip_training:
        steps_to_run.append("training (with integrated size head)")
    if not args.skip_evaluation:
        steps_to_run.append("evaluation")
    print(f"📋 Pipeline steps: {' → '.join(steps_to_run)}")

    # Get project root (assuming we're running from repo root)
    project_root = Path.cwd()

    # Step 1: Diffusion model training (with integrated size head)
    if not args.skip_training:
        if not run_command(
            ["uv", "run", "python", "train_diffusion_backbone.py", "--config", str(config_path)],
            "Diffusion model training (with integrated size head)",
            cwd=str(project_root),
        ):
            print("❌ Training failed, stopping pipeline")
            sys.exit(1)
    else:
        print("\n⏭️ Skipping diffusion model training")

    # Step 2: Evaluation
    if not args.skip_evaluation:
        # Check if a trained checkpoint exists before launching evaluation
        best_model_path = output_dir / "best_model.pt"
        final_model_path = output_dir / "final_model.pt"

        # Determine which model will be used based on --prefer-best flag
        if args.prefer_best:
            if best_model_path.exists():
                print(f"✓ Using best model for evaluation: {best_model_path}")
            elif final_model_path.exists():
                print(f"⚠️ Using final model for evaluation (best model not found): {final_model_path}")
            else:
                print(f"❌ No model found: tried {best_model_path} and {final_model_path}")
                print("   Cannot run evaluation without trained model")
                sys.exit(1)
        else:
            if final_model_path.exists():
                print(f"✓ Using final model for evaluation: {final_model_path}")
            elif best_model_path.exists():
                print(f"⚠️ Using best model for evaluation (final model not found): {best_model_path}")
            else:
                print(f"❌ No model found: tried {final_model_path} and {best_model_path}")
                print("   Cannot run evaluation without trained model")
                sys.exit(1)

        # Run evaluation
        # BUG FIX: --prefer-best was previously hard-coded into the base
        # command AND appended again below, so evaluate.py always received
        # it regardless of the user's flag. It is now passed only when
        # the user asked for it.
        eval_command = [
            "uv", "run", "python", "evaluate.py",
            "--config", str(config_path),
            "--limit", str(args.eval_limit),
            "--maj",
            "--stats",
        ]
        if args.prefer_best:
            eval_command.append("--prefer-best")

        if not run_command(
            eval_command,
            f"Model evaluation (limit: {args.eval_limit} tasks)",
            cwd=str(project_root),
        ):
            print("❌ Evaluation failed")
            sys.exit(1)
    else:
        print("\n⏭️ Skipping evaluation")

    # Step 3: Upload to Hugging Face Hub (best-effort; failure is non-fatal)
    if not run_command(
        ["uv", "run", "python", "hf.py", "--push", "--config", str(config_path)],
        "Upload to Hugging Face Hub",
        cwd=str(project_root),
    ):
        print("⚠️ HF upload failed (continuing anyway)")

    # Pipeline completed
    # BUG FIX: total_time previously stored a raw timestamp (time.time())
    # and was never used; it now reports the actual elapsed duration.
    total_time = time.time() - pipeline_start
    print(f"\n{'='*60}")
    print(f"🎉 Pipeline completed successfully!")
    print(f"⏱️ Total pipeline time: {total_time:.1f}s")
    print(f"📂 Results saved in: {output_dir}")
    print(f"📅 Finished at: {datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")

    # Show key output files
    print(f"\n📋 Key output files:")
    model_path = output_dir / "best_model.pt"
    if model_path.exists():
        print(f"   🧠 Model (with integrated size head): {model_path}")

    # Find evaluation results (most recently modified wins)
    eval_files = list(output_dir.glob("evaluation_*.json"))
    if eval_files:
        latest_eval = max(eval_files, key=lambda x: x.stat().st_mtime)
        print(f"   📊 Evaluation: {latest_eval}")

    visualization_path = output_dir / "training_noise_visualization.png"
    if visualization_path.exists():
        print(f"   🎨 Visualization: {visualization_path}")

    print(f"{'='*60}")


if __name__ == "__main__":
    main()