@@ -158,9 +158,8 @@ class FindStridedLoads : public IRVisitor {
 // Replace a bunch of load expressions in a stmt
 class ReplaceStridedLoads : public IRMutator {
 public:
-    std::map<std::pair<const Allocate *, const Load *>, Expr> replacements;
+    std::map<const Load *, Expr> replacements;
     std::map<const Allocate *, int> padding;
-    Scope<const Allocate *> allocation_scope;
     std::map<const IRNode *, std::vector<std::pair<std::string, Expr>>> let_injections;
 
     Stmt mutate(const Stmt &s) override {
@@ -187,11 +186,7 @@ class ReplaceStridedLoads : public IRMutator {
 
 protected:
     Expr visit(const Load *op) override {
-        const Allocate *alloc = nullptr;
-        if (const Allocate *const *a_ptr = allocation_scope.find(op->name)) {
-            alloc = *a_ptr;
-        }
-        auto it = replacements.find({alloc, op});
+        auto it = replacements.find(op);
         if (it != replacements.end()) {
             return mutate(it->second);
         } else {
@@ -200,7 +195,6 @@ class ReplaceStridedLoads : public IRMutator {
     }
 
     Stmt visit(const Allocate *op) override {
-        ScopedBinding bind(allocation_scope, op->name, op);
        auto it = padding.find(op);
         Stmt s = IRMutator::visit(op);
         if (it == padding.end()) {
@@ -281,10 +275,25 @@ bool can_hoist_shared_load(const IRNode *n, const std::string &buf, const Expr &
 
 }  // namespace
 
-Stmt stage_strided_loads(const Stmt &s, const Target &target) {
+Stmt stage_strided_loads(const Stmt &stmt, const Target &target) {
     FindStridedLoads finder;
     ReplaceStridedLoads replacer;
 
+    // Make all strided loads distinct IR nodes so that we can uniquely identify
+    // them by address. We may want to mutate the same load node in different
+    // ways depending on the surrounding context.
+    Stmt s = mutate_with(stmt, [&](auto *self, const Load *l) {
+        const Ramp *r = l->index.as<Ramp>();
+        if (l->type.is_scalar() || (r && is_const_one(r->stride))) {
+            // Definitely not a strided load
+            return self->visit_base(l);
+        } else {
+            // Might be a strided load after simplification
+            return Load::make(l->type, l->name, self->mutate(l->index), l->image, l->param,
+                              self->mutate(l->predicate), l->alignment);
+        }
+    });
+
     // Find related clusters of strided loads anywhere in the stmt. While this
     // appears to look globally, it requires expressions to match exactly, so
     // really it's only going to find things inside the same loops and let
@@ -293,7 +302,6 @@ Stmt stage_strided_loads(const Stmt &s, const Target &target) {
 
     for (const auto &l : finder.found_loads) {
         const FindStridedLoads::Key &k = l.first;
-        const Allocate *alloc = k.allocation;
         const std::map<int64_t, std::vector<const Load *>> &v = l.second;
 
         // Find clusters of strided loads that can share the same dense load.
@@ -352,7 +360,7 @@ Stmt stage_strided_loads(const Stmt &s, const Target &target) {
                             Shuffle::make_slice(var, row * k.lanes, 1, k.lanes) :
                             Shuffle::make_slice(var, row, k.stride, k.lanes);
             for (const Load *l : load->second) {
-                replacer.replacements.emplace(std::make_pair(alloc, l), shuf);
+                replacer.replacements.emplace(l, shuf);
             }
         }
         if (transpose_shared_load) {
@@ -364,7 +372,7 @@ Stmt stage_strided_loads(const Stmt &s, const Target &target) {
                 int row = load->first - first_offset;
                 Expr shuf = Shuffle::make_slice(shared_load, row, k.stride, k.lanes);
                 for (const Load *l : load->second) {
-                    replacer.replacements.emplace(std::make_pair(alloc, l), shuf);
+                    replacer.replacements.emplace(l, shuf);
                 }
             }
         }
@@ -374,7 +382,7 @@ Stmt stage_strided_loads(const Stmt &s, const Target &target) {
         // picked up in a cluster, but for whom we know it's safe to do a
         // dense load before their start.
        for (const auto &[offset, loads] : reverse_view(v)) {
-            if (replacer.replacements.count({alloc, loads[0]})) {
+            if (replacer.replacements.count(loads[0])) {
                 continue;
             }
             int64_t delta = k.stride - 1;
@@ -392,14 +400,14 @@ Stmt stage_strided_loads(const Stmt &s, const Target &target) {
             dense_load = common_subexpression_elimination(dense_load);
             Expr shuf = Shuffle::make_slice(dense_load, delta, k.stride, k.lanes);
             for (const Load *l : loads) {
-                replacer.replacements.emplace(std::make_pair(alloc, l), shuf);
+                replacer.replacements.emplace(l, shuf);
             }
         }
 
         // Look for any loads we can densify because an overlapping load occurs
         // in any parent scope.
         for (const auto &[offset, loads] : reverse_view(v)) {
-            if (replacer.replacements.count({alloc, loads[0]})) {
+            if (replacer.replacements.count(loads[0])) {
                 continue;
             }
             int64_t min_offset = offset;
@@ -430,7 +438,7 @@ Stmt stage_strided_loads(const Stmt &s, const Target &target) {
             dense_load = common_subexpression_elimination(dense_load);
             Expr shuf = Shuffle::make_slice(dense_load, offset - final_offset, k.stride, k.lanes);
             for (const Load *l : loads) {
-                replacer.replacements.emplace(std::make_pair(alloc, l), shuf);
+                replacer.replacements.emplace(l, shuf);
             }
         }
 
@@ -439,7 +447,7 @@ Stmt stage_strided_loads(const Stmt &s, const Target &target) {
         // external allocations by doing a dense load at a trimmed size. We rely
         // on codegen to do a good job at loading vectors of a funny size.
         for (const auto &[offset, loads] : v) {
-            if (replacer.replacements.count({alloc, loads[0]})) {
+            if (replacer.replacements.count(loads[0])) {
                 continue;
             }
 
@@ -463,7 +471,7 @@ Stmt stage_strided_loads(const Stmt &s, const Target &target) {
                 dense_load = common_subexpression_elimination(dense_load);
                 Expr shuf = Shuffle::make_slice(dense_load, offset - first_offset, k.stride, k.lanes);
                 for (const Load *l : loads) {
-                    replacer.replacements.emplace(std::make_pair(alloc, l), shuf);
+                    replacer.replacements.emplace(l, shuf);
                 }
 
             } else if (k.lanes % 2 == 0) {
@@ -486,7 +494,7 @@ Stmt stage_strided_loads(const Stmt &s, const Target &target) {
                 Expr shuf2 = Shuffle::make_slice(dense_load2, delta, k.stride, k.lanes / 2);
                 Expr shuf = Shuffle::make_concat({shuf1, shuf2});
                 for (const Load *l : loads) {
-                    replacer.replacements.emplace(std::make_pair(alloc, l), shuf);
+                    replacer.replacements.emplace(l, shuf);
                 }
             }
         }