Skip to content

Commit 3bf934d

Browse files
s-noghabiThe tunix Authors
authored andcommitted
Instrument agentic loop with perf v2
PiperOrigin-RevId: 875514876
1 parent efb4913 commit 3bf934d

File tree

17 files changed

+1814
-44
lines changed

17 files changed

+1814
-44
lines changed

examples/deepscaler/train_deepscaler_nb.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,8 @@
7070
from tunix.utils import math_rewards
7171
from tunix.utils import compat
7272
from tunix.cli.utils import data as data_lib
73+
from tunix import PerfMetricsConfig
74+
from tunix.perf.experimental.export import PerfMetricsExport
7375

7476
try:
7577
import pathwaysutils
@@ -109,7 +111,7 @@
109111
# The number of times the policy generates multiple responses for a given prompt
110112
# within a single training step. This corresponds to `G` in Algorithm 1 in the
111113
# paper. The "group" in GRPO comes from here.
112-
NUM_GENERATIONS = 8
114+
NUM_GENERATIONS = 2
113115

114116
# === other GRPO configs ===
115117
# The number of iterations per batch (𝜇 in GRPO algo 1).
@@ -125,15 +127,15 @@
125127

126128
# ====== Training ======
127129
ENABLE_REMAT = True
128-
BATCH_SIZE = 128
129-
MINI_BATCH_SIZE = 64
130+
BATCH_SIZE = 4
131+
MINI_BATCH_SIZE = 2
130132
NUM_BATCHES = 100
131133
# Keep `NUM_TEST_BATCHES` low so that evaluation runs quickly. It can be
132134
# increased to a max. of 330 (if batch size is 4).
133135
NUM_TEST_BATCHES = 50
134136

135-
EVAL_EVERY_N_STEPS = 1000 # this doesn't matter if `TRAIN_FRACTION = 1.0`.
136-
NUM_EPOCHS = 100 # can potentially train for more epochs
137+
EVAL_EVERY_N_STEPS = 50 # this doesn't matter if `TRAIN_FRACTION = 1.0`.
138+
NUM_EPOCHS = 10 # can potentially train for more epochs
137139

138140
# Number of training steps.
139141
MAX_STEPS = int(NUM_BATCHES * NUM_ITERATIONS * TRAIN_FRACTION * NUM_EPOCHS)
@@ -529,13 +531,20 @@ def get_lora_model(base_model, model_mesh):
529531
max_concurrency=MAX_CONCURRENCY,
530532
)
531533

534+
# Perf Metrics logging
535+
perf_metrics_config = PerfMetricsConfig()
536+
perf_metrics_config.custom_export_fn_v2 = PerfMetricsExport(
537+
"/tmp/agentic_perf"
538+
).export_metrics
539+
532540
# %%
533541
# RL cluster
534542
rl_cluster = rl_cluster_lib.RLCluster(
535543
actor=qwen2_actor,
536544
reference=qwen2_ref,
537545
tokenizer=tokenizer,
538546
cluster_config=cluster_config,
547+
perf_config=perf_metrics_config,
539548
)
540549

541550
show_hbm_usage("after RLCluster creation")
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
"""Tests for export."""
2+
3+
import os
4+
import pathlib
5+
import time
6+
from absl.testing import absltest
7+
from tunix.perf.experimental import export
8+
from tunix.perf.experimental import tracer
9+
10+
11+
class ExportTest(absltest.TestCase):
12+
13+
def test_perf_metrics_export(self):
14+
# Backward compatibility check
15+
tmp_dir = pathlib.Path(self.create_tempdir().full_path)
16+
exporter = export.PerfMetricsExport(trace_dir=tmp_dir)
17+
18+
# Create dummy timeline
19+
t = tracer.PerfTracer(export_fn=exporter.export_metrics)
20+
with t.span("test_span"):
21+
time.sleep(0.001)
22+
t.export()
23+
24+
files = os.listdir(tmp_dir)
25+
self.assertLen(files, 1)
26+
self.assertTrue(files[0].startswith("perfetto_trace_v2_"))
27+
28+
29+
if __name__ == "__main__":
30+
absltest.main()
Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
# Copyright 2026 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Tests for perfetto."""
16+
17+
import os
18+
import tempfile
19+
import time
20+
21+
from absl.testing import absltest
22+
from tunix.perf.experimental import perfetto
23+
from tunix.perf.experimental import tracer
24+
25+
26+
class PerfettoTest(absltest.TestCase):
27+
28+
def test_create_span_name(self):
29+
# Test basic span name with global_step
30+
name = perfetto._create_span_name("my_span", {"global_step": 10})
31+
self.assertEqual(name, "my_span (step=10)")
32+
33+
# Test peft_train_step with role
34+
name = perfetto._create_span_name(
35+
"peft_train_step", {"global_step": 20, "role": "actor"}
36+
)
37+
self.assertEqual(name, "peft_train_step (step=20, role=actor)")
38+
39+
# Test rollout with group_id and pair_index
40+
name = perfetto._create_span_name(
41+
"rollout", {"group_id": 5, "pair_index": 3, "global_step": 100}
42+
)
43+
self.assertEqual(name, "rollout (step=100, group_id=5, pair_index=3)")
44+
45+
# Test rollout with missing pair_index
46+
name = perfetto._create_span_name("rollout", {"group_id": 5})
47+
self.assertEqual(name, "rollout (group_id=5)")
48+
49+
# Test unknown name with extra tags (should ignore specific logic but keep step)
50+
name = perfetto._create_span_name(
51+
"unknown_span", {"role": "actor", "global_step": 50}
52+
)
53+
self.assertEqual(name, "unknown_span (step=50)")
54+
55+
# Test no tags
56+
name = perfetto._create_span_name("simple_span", {})
57+
self.assertEqual(name, "simple_span")
58+
59+
# TODO(noghabi): Add more tests for PerfettoTraceWriter.
60+
def test_perfetto_trace_writer(self):
61+
with tempfile.TemporaryDirectory() as tmp_dir:
62+
writer = perfetto.PerfettoTraceWriter(trace_dir=tmp_dir)
63+
64+
# Create some dummy timelines
65+
t = tracer.Timeline("test_timeline", time.perf_counter())
66+
s = t.start_span("test_span", time.perf_counter())
67+
time.sleep(0.001)
68+
t.stop_span(time.perf_counter())
69+
70+
timelines = {"test_timeline": t}
71+
72+
writer.write_timelines(timelines)
73+
74+
# Check if file was created
75+
files = os.listdir(tmp_dir)
76+
self.assertLen(files, 1)
77+
self.assertTrue(files[0].startswith("perfetto_trace_v2_"))
78+
self.assertTrue(files[0].endswith(".pb"))
79+
80+
# We could parse the proto back to verify content, but just existence is good for now.
81+
82+
83+
if __name__ == "__main__":
84+
absltest.main()
Lines changed: 220 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,220 @@
1+
# Copyright 2026 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import threading
16+
from unittest import mock
17+
from absl.testing import absltest
18+
from tunix.perf.experimental import constants
19+
from tunix.perf.experimental import timeline
20+
21+
22+
class SpanTest(absltest.TestCase):
23+
24+
def test_span(self):
25+
s = timeline.Span(name="test", begin=1.0, id=0)
26+
self.assertEqual(s.name, "test")
27+
self.assertEqual(s.begin, 1.0)
28+
self.assertEqual(s.end, float("inf"))
29+
self.assertEqual(s.ended, False)
30+
self.assertEqual(s.duration, float("inf"))
31+
32+
def test_span_with_tags(self):
33+
tags_dict = {constants.GLOBAL_STEP: 1, "custom_tag": "value"}
34+
s = timeline.Span(name="test_tags", begin=1.0, id=0, tags=tags_dict)
35+
self.assertEqual(s.tags, tags_dict)
36+
self.assertIn("tags=", repr(s))
37+
self.assertIn("global_step", repr(s))
38+
39+
def test_add_tag(self):
40+
s = timeline.Span(name="test_add_tag", begin=1.0, id=0)
41+
s.add_tag("foo", "bar")
42+
self.assertEqual(s.tags, {"foo": "bar"})
43+
s.add_tag(constants.GLOBAL_STEP, 100)
44+
self.assertEqual(s.tags, {"foo": "bar", "global_step": 100})
45+
46+
def test_add_tag_overwrite_warning(self):
47+
s = timeline.Span(name="test_add_tag_overwrite", begin=1.0, id=0)
48+
s.add_tag("foo", "bar")
49+
with self.assertLogs(level="WARNING") as cm:
50+
s.add_tag("foo", "baz")
51+
self.assertEqual(s.tags, {"foo": "baz"})
52+
self.assertTrue(
53+
any(
54+
"Tag 'foo' already exists with value 'bar'. Overwriting with 'baz'."
55+
in o
56+
for o in cm.output
57+
)
58+
)
59+
60+
def test_repr_with_born_at(self):
61+
born_at = 100.0
62+
s = timeline.Span(name="test_born_at", begin=101.0, id=0)
63+
s.end = 105.0
64+
65+
# Check default repr (born_at=0.0)
66+
expected_default = "[0] test_born_at: 101.000000, 105.000000"
67+
self.assertEqual(repr(s), expected_default)
68+
69+
# Check repr with explicit born_at
70+
expected_adjusted = "[0] test_born_at: 1.000000, 5.000000"
71+
self.assertEqual(s.__repr__(born_at=born_at), expected_adjusted)
72+
73+
74+
class TimelineTest(absltest.TestCase):
75+
76+
def test_basic_span_lifecycle(self):
77+
t = timeline.Timeline("test_tl", 100.0)
78+
s = t.start_span("span1", 101.0)
79+
self.assertEqual(s.name, "span1")
80+
self.assertEqual(s.begin, 101.0)
81+
self.assertEqual(s.id, 0)
82+
self.assertIsNone(s.parent_id)
83+
self.assertFalse(s.ended)
84+
85+
t.stop_span(102.0)
86+
self.assertTrue(s.ended)
87+
self.assertEqual(s.end, 102.0)
88+
89+
def test_nested_spans(self):
90+
t = timeline.Timeline("test_tl", 0.0)
91+
s1 = t.start_span("root", 1.0)
92+
s2 = t.start_span("child", 2.0)
93+
94+
self.assertEqual(s2.parent_id, s1.id)
95+
96+
t.stop_span(3.0) # stops s2
97+
self.assertEqual(s2.end, 3.0)
98+
99+
t.stop_span(4.0) # stops s1
100+
self.assertEqual(s1.end, 4.0)
101+
102+
def test_stop_span_error_cases(self):
103+
t = timeline.Timeline("test_tl", 0.0)
104+
with self.assertRaisesRegex(ValueError, "no more spans to end"):
105+
t.stop_span(1.0)
106+
107+
s = t.start_span("s1", 2.0)
108+
# End before begin
109+
with mock.patch("absl.logging.error") as mock_log:
110+
t.stop_span(1.0)
111+
mock_log.assert_called_once()
112+
self.assertEqual(s.end, 1.0)
113+
114+
def test_nested_timeline_with_tags_repr(self):
115+
born = 1000.0
116+
t = timeline.Timeline("test_tl", born)
117+
118+
# Start root
119+
s_root = t.start_span("root", born + 1.0)
120+
s_root.add_tag("type", "root_span")
121+
122+
# Start nested
123+
s_child = t.start_span("child", born + 2.0)
124+
s_child.add_tag("iter", 1)
125+
126+
# Stop nested
127+
t.stop_span(born + 3.0)
128+
129+
# Stop root
130+
t.stop_span(born + 4.0)
131+
132+
# Check tags are stored correctly
133+
self.assertEqual(s_root.tags, {"type": "root_span"})
134+
self.assertEqual(s_child.tags, {"iter": 1})
135+
136+
# Check full repr string
137+
expected_repr = (
138+
f"Timeline(test_tl, {born:.6f})\n"
139+
"[0] root: 1.000000, 4.000000, tags={'type': 'root_span'}\n"
140+
"[1] child: 2.000000, 3.000000 (parent=0), tags={'iter': 1}\n"
141+
)
142+
self.assertEqual(repr(t), expected_repr)
143+
144+
145+
class AsyncTimelineTest(absltest.TestCase):
146+
147+
def setUp(self):
148+
self.patcher = mock.patch("tunix.perf.experimental.timeline._async_wait")
149+
self.mock_async_wait = self.patcher.start()
150+
151+
# Setup mock behavior for _async_wait to immediately succeed by default
152+
def default_wait(waitlist, success, failure):
153+
success()
154+
return mock.Mock(spec=threading.Thread)
155+
156+
self.mock_async_wait.side_effect = default_wait
157+
158+
def tearDown(self):
159+
self.patcher.stop()
160+
161+
def test_span_success(self):
162+
t = timeline.AsyncTimeline("dev", 0.0)
163+
waitlist = ["thing"]
164+
165+
t.span("async_op", 1.0, waitlist)
166+
167+
self.mock_async_wait.assert_called_once()
168+
self.assertEqual(len(t.spans), 1)
169+
s = t.spans[0]
170+
self.assertEqual(s.name, "async_op")
171+
self.assertEqual(s.begin, 1.0)
172+
self.assertTrue(s.ended) # Ended because mock calls success immediately
173+
174+
def test_span_with_no_waitlist(self):
175+
t = timeline.AsyncTimeline("dev", 0.0)
176+
t.span("immediate", 1.0, [])
177+
self.mock_async_wait.assert_not_called()
178+
self.assertEqual(len(t.spans), 1)
179+
self.assertTrue(t.spans[0].ended)
180+
181+
def test_delayed_completion(self):
182+
t = timeline.AsyncTimeline("dev", 0.0)
183+
184+
# Capture callbacks
185+
callbacks = {}
186+
187+
def capture_wait(waitlist, success, failure):
188+
callbacks["success"] = success
189+
callbacks["failure"] = failure
190+
return mock.Mock(spec=threading.Thread)
191+
192+
self.mock_async_wait.side_effect = capture_wait
193+
194+
t.span("delayed", 1.0, ["wait"])
195+
196+
self.assertEqual(len(t.spans), 0) # Not yet recorded
197+
198+
# Simulate completion
199+
with mock.patch("time.perf_counter", return_value=5.0):
200+
callbacks["success"]()
201+
202+
self.assertEqual(len(t.spans), 1)
203+
s = t.spans[0]
204+
self.assertEqual(s.end, 5.0)
205+
206+
def test_failure(self):
207+
t = timeline.AsyncTimeline("dev", 0.0)
208+
209+
def fail_wait(waitlist, success, failure):
210+
failure(RuntimeError("failed"))
211+
return mock.Mock(spec=threading.Thread)
212+
213+
self.mock_async_wait.side_effect = fail_wait
214+
215+
with self.assertRaisesRegex(RuntimeError, "failed"):
216+
t.span("failed", 1.0, ["wait"])
217+
218+
219+
if __name__ == "__main__":
220+
absltest.main()

0 commit comments

Comments
 (0)