@@ -28,7 +28,7 @@ struct TraceEventBuilder {
2828 halide_trace_event_code_t event;
2929 Expr parent_id, value_index;
3030
31- Expr build () {
31+ Expr build () const {
3232 Expr values = Call::make (type_of<void *>(), Call::make_struct,
3333 value, Call::Intrinsic);
3434 Expr coords = Call::make (type_of<int32_t *>(), Call::make_struct,
@@ -63,6 +63,8 @@ class InjectTracing : public IRMutator {
6363 map<string, vector<Type>> funcs_touched;
6464 map<string, vector<Type>> images_touched;
6565
66+ string call_stack = " pipeline" ;
67+
6668 InjectTracing (const map<string, Function> &e, const Target &t)
6769 : env(e),
6870 trace_all_loads (t.has_feature(Target::TraceLoads)),
@@ -101,22 +103,21 @@ class InjectTracing : public IRMutator {
101103 }
102104 }
103105
106+ protected:
104107 using IRMutator::visit;
105108
106109 Expr visit (const Call *op) override {
107110 Expr expr = IRMutator::visit (op);
108111 op = expr.as <Call>();
109112 internal_assert (op);
110113 bool trace_it = false ;
111- Expr trace_parent;
112114 if (op->call_type == Call::Halide) {
113115 auto it = env.find (op->name );
114116 internal_assert (it != env.end ()) << op->name << " not in environment\n " ;
115- Function f = it->second ;
117+ const Function & f = it->second ;
116118 internal_assert (!f.can_be_inlined () || !f.schedule ().compute_level ().is_inlined ());
117119
118120 trace_it = trace_all_loads || f.is_tracing_loads ();
119- trace_parent = Variable::make (Int (32 ), op->name + " .trace_id" );
120121 if (trace_it) {
121122 add_trace_tags (op->name , f.get_trace_tags ());
122123 touch (funcs_touched, op->name , op->value_index , op->type );
@@ -125,7 +126,6 @@ class InjectTracing : public IRMutator {
125126 // op->param is defined when we're loading from an ImageParam, and undefined
126127 // when we're loading from an inlined Buffer.
127128 trace_it = trace_all_loads || (op->param .defined () && op->param .is_tracing_loads ());
128- trace_parent = Variable::make (Int (32 ), " pipeline.trace_id" );
129129 if (trace_it) {
130130 if (op->param .defined ()) {
131131 add_trace_tags (op->name , op->param .get_trace_tags ());
@@ -144,7 +144,7 @@ class InjectTracing : public IRMutator {
144144 builder.coordinates = op->args ;
145145 builder.type = op->type ;
146146 builder.event = halide_trace_load;
147- builder.parent_id = trace_parent ;
147+ builder.parent_id = Variable::make ( Int ( 32 ), call_stack + " .trace_id " ) ;
148148 builder.value_index = op->value_index ;
149149 Expr trace = builder.build ();
150150
@@ -177,7 +177,7 @@ class InjectTracing : public IRMutator {
177177 builder.func = f.name ();
178178 builder.coordinates = op->args ;
179179 builder.event = halide_trace_store;
180- builder.parent_id = Variable::make (Int (32 ), op-> name + " .trace_id" );
180+ builder.parent_id = Variable::make (Int (32 ), call_stack + " .trace_id" );
181181 for (size_t i = 0 ; i < values.size (); i++) {
182182 Type t = values[i].type ();
183183 touch (funcs_touched, f.name (), (int )i, t);
@@ -220,111 +220,113 @@ class InjectTracing : public IRMutator {
220220 }
221221
222222 Stmt visit (const Realize *op) override {
223- Stmt stmt = IRMutator::visit (op);
224- op = stmt.as <Realize>();
225- internal_assert (op);
226-
227- map<string, Function>::const_iterator iter = env.find (op->name );
223+ auto iter = env.find (op->name );
228224 if (iter == env.end ()) {
229- return stmt ;
225+ return IRMutator::visit (op) ;
230226 }
227+
231228 Function f = iter->second ;
232- if (f.is_tracing_realizations () || trace_all_realizations) {
233- add_trace_tags (op->name , f.get_trace_tags ());
234- for (size_t i = 0 ; i < op->types .size (); i++) {
235- touch (funcs_touched, op->name , i, op->types [i]);
236- }
237229
238- // Throw a tracing call before and after the realize body
239- TraceEventBuilder builder;
240- builder.func = op->name ;
241- builder.parent_id = Variable::make (Int (32 ), " pipeline.trace_id" );
242- builder.event = halide_trace_begin_realization;
243- for (const auto &bound : op->bounds ) {
244- builder.coordinates .push_back (bound.min );
245- builder.coordinates .push_back (bound.extent );
230+ if (!(f.is_tracing_realizations () || trace_all_realizations)) {
231+ if (f.is_tracing_stores () || f.is_tracing_loads ()) {
232+ // We need a trace id defined to pass to the loads and stores
233+ ScopedValue new_func{call_stack, op->name };
234+ Stmt _keep_alive = IRMutator::visit (op);
235+ op = _keep_alive.as <Realize>();
236+ internal_assert (op);
237+ return Realize::make (op->name , op->types , op->memory_type , op->bounds , op->condition ,
238+ LetStmt::make (op->name + " .trace_id" , 0 ,
239+ op->body ));
240+ } else {
241+ return IRMutator::visit (op);
246242 }
243+ }
247244
248- // Begin realization returns a unique token to pass to further trace calls affecting this buffer.
249- Expr call_before = builder.build ();
250-
251- builder.event = halide_trace_end_realization;
252- builder.parent_id = Variable::make (Int (32 ), op->name + " .trace_id" );
253- Expr call_after = builder.build ();
254-
255- Stmt new_body = op->body ;
256- new_body = Block::make (new_body, Evaluate::make (call_after));
257- new_body = LetStmt::make (op->name + " .trace_id" , call_before, new_body);
258- stmt = Realize::make (op->name , op->types , op->memory_type , op->bounds , op->condition , new_body);
259- // Warning: 'op' may be invalid at this point
260- } else if (f.is_tracing_stores () || f.is_tracing_loads ()) {
261- // We need a trace id defined to pass to the loads and stores
262- Stmt new_body = op->body ;
263- new_body = LetStmt::make (op->name + " .trace_id" , 0 , new_body);
264- stmt = Realize::make (op->name , op->types , op->memory_type , op->bounds , op->condition , new_body);
245+ // ReSharper disable once CppTooWideScope CppJoinDeclarationAndAssignment
246+ Stmt _keep_alive;
247+ {
248+ ScopedValue new_func{call_stack, op->name };
249+ _keep_alive = IRMutator::visit (op);
250+ op = _keep_alive.as <Realize>();
265251 }
266- return stmt;
252+ internal_assert (op);
253+
254+ add_trace_tags (op->name , f.get_trace_tags ());
255+ for (size_t i = 0 ; i < op->types .size (); i++) {
256+ touch (funcs_touched, op->name , i, op->types [i]);
257+ }
258+
259+ // Throw a tracing call before and after the realize body
260+ TraceEventBuilder builder;
261+ builder.func = op->name ;
262+ builder.parent_id = Variable::make (Int (32 ), call_stack + " .trace_id" );
263+ builder.event = halide_trace_begin_realization;
264+ for (const auto &bound : op->bounds ) {
265+ builder.coordinates .push_back (bound.min );
266+ builder.coordinates .push_back (bound.extent );
267+ }
268+
269+ // Begin realization returns a unique token to pass to further trace calls affecting this buffer.
270+ Expr call_before = builder.build ();
271+
272+ builder.event = halide_trace_end_realization;
273+ builder.parent_id = Variable::make (Int (32 ), op->name + " .trace_id" );
274+ Expr call_after = builder.build ();
275+
276+ return Realize::make (op->name , op->types , op->memory_type , op->bounds , op->condition ,
277+ LetStmt::make (op->name + " .trace_id" , call_before,
278+ Block::make (op->body , Evaluate::make (call_after))));
267279 }
268280
269281 Stmt visit (const ProducerConsumer *op) override {
270- Stmt stmt = IRMutator::visit (op);
271- op = stmt.as <ProducerConsumer>();
272- internal_assert (op);
273- map<string, Function>::const_iterator iter = env.find (op->name );
282+ auto iter = env.find (op->name );
274283 if (iter == env.end ()) {
275- return stmt ;
284+ return IRMutator::visit (op) ;
276285 }
277- Function f = iter->second ;
278- if (f.is_tracing_realizations () || trace_all_realizations) {
279- // Throw a tracing call around each pipeline event
280- TraceEventBuilder builder;
281- builder.func = op->name ;
282- builder.parent_id = Variable::make (Int (32 ), op->name + " .trace_id" );
283-
284- // Use the size of the pure step
285- const vector<string> &f_args = f.args ();
286- for (int i = 0 ; i < f.dimensions (); i++) {
287- Expr min = Variable::make (Int (32 ), f.name () + " .s0." + f_args[i] + " .min" );
288- Expr max = Variable::make (Int (32 ), f.name () + " .s0." + f_args[i] + " .max" );
289- Expr extent = (max + 1 ) - min;
290- builder.coordinates .push_back (min);
291- builder.coordinates .push_back (extent);
292- }
293286
294- builder.event = (op->is_producer ? halide_trace_produce : halide_trace_consume);
295- Expr begin_op_call = builder.build ();
287+ const Function &f = iter->second ;
296288
297- builder.event = (op->is_producer ? halide_trace_end_produce : halide_trace_end_consume);
298- Expr end_op_call = builder.build ();
289+ if (!(f.is_tracing_realizations () || trace_all_realizations)) {
290+ return IRMutator::visit (op);
291+ }
299292
300- Stmt new_body = Block::make (op->body , Evaluate::make (end_op_call));
293+ // ReSharper disable once CppTooWideScope CppJoinDeclarationAndAssignment
294+ Stmt _keep_alive;
295+ {
296+ ScopedValue new_func{call_stack, op->name };
297+ _keep_alive = IRMutator::visit (op);
298+ op = _keep_alive.as <ProducerConsumer>();
299+ internal_assert (op);
300+ }
301301
302- stmt = LetStmt::make (f.name () + " .trace_id" , begin_op_call,
303- ProducerConsumer::make (op->name , op->is_producer , new_body));
302+ // Throw a tracing call around each pipeline event
303+ TraceEventBuilder builder;
304+ builder.func = op->name ;
305+
306+ // Use the size of the pure step
307+ for (const auto &arg : f.args ()) {
308+ Expr min = Variable::make (Int (32 ), op->name + " .s0." + arg + " .min" );
309+ Expr max = Variable::make (Int (32 ), op->name + " .s0." + arg + " .max" );
310+ Expr extent = (max + 1 ) - min;
311+ builder.coordinates .push_back (min);
312+ builder.coordinates .push_back (extent);
304313 }
305- return stmt;
306- }
307- };
308314
309- class RemoveRealizeOverOutput : public IRMutator {
310- using IRMutator::visit ;
311- const vector<Function> &outputs ;
315+ builder. parent_id = Variable::make ( Int ( 32 ), call_stack + " .trace_id " );
316+ builder. event = (op-> is_producer ? halide_trace_produce : halide_trace_consume) ;
317+ Expr begin_op_call = builder. build () ;
312318
313- Stmt visit (const Realize *op) override {
314- for (const Function &f : outputs) {
315- if (op->name == f.name ()) {
316- return mutate (op->body );
317- }
318- }
319- return IRMutator::visit (op);
320- }
319+ builder.parent_id = Variable::make (Int (32 ), op->name + " .trace_id" );
320+ builder.event = (op->is_producer ? halide_trace_end_produce : halide_trace_end_consume);
321+ Expr end_op_call = builder.build ();
321322
322- public:
323- RemoveRealizeOverOutput (const vector<Function> &o)
324- : outputs(o) {
323+ return LetStmt::make (op->name + " .trace_id" , begin_op_call,
324+ ProducerConsumer::make (op->name , op->is_producer ,
325+ Block::make (
326+ op->body ,
327+ Evaluate::make (end_op_call))));
325328 }
326329};
327-
328330} // namespace
329331
330332Stmt inject_tracing (Stmt s, const string &pipeline_name, bool trace_pipeline,
@@ -351,7 +353,14 @@ Stmt inject_tracing(Stmt s, const string &pipeline_name, bool trace_pipeline,
351353 s = tracing.mutate (s);
352354
353355 // Strip off the dummy realize blocks
354- s = RemoveRealizeOverOutput (outputs).mutate (s);
356+ s = mutate_with (s, [&](auto *self, const Realize *op) {
357+ for (const Function &f : outputs) {
358+ if (op->name == f.name ()) {
359+ return self->mutate (op->body );
360+ }
361+ }
362+ return self->visit_base (op);
363+ });
355364
356365 if (!s.same_as (original) || trace_pipeline || t.has_feature (Target::TracePipeline)) {
357366 // Add pipeline start and end events
0 commit comments