Skip to content

Commit 0d110d2

Browse files
committed
Make distinct strided load nodes in the IR distinct in memory too
1 parent 23944a0 commit 0d110d2

File tree

1 file changed

+27
-19
lines changed

1 file changed

+27
-19
lines changed

src/StageStridedLoads.cpp

Lines changed: 27 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -158,9 +158,8 @@ class FindStridedLoads : public IRVisitor {
158158
// Replace a bunch of load expressions in a stmt
159159
class ReplaceStridedLoads : public IRMutator {
160160
public:
161-
std::map<std::pair<const Allocate *, const Load *>, Expr> replacements;
161+
std::map<const Load *, Expr> replacements;
162162
std::map<const Allocate *, int> padding;
163-
Scope<const Allocate *> allocation_scope;
164163
std::map<const IRNode *, std::vector<std::pair<std::string, Expr>>> let_injections;
165164

166165
Stmt mutate(const Stmt &s) override {
@@ -187,11 +186,7 @@ class ReplaceStridedLoads : public IRMutator {
187186

188187
protected:
189188
Expr visit(const Load *op) override {
190-
const Allocate *alloc = nullptr;
191-
if (const Allocate *const *a_ptr = allocation_scope.find(op->name)) {
192-
alloc = *a_ptr;
193-
}
194-
auto it = replacements.find({alloc, op});
189+
auto it = replacements.find(op);
195190
if (it != replacements.end()) {
196191
return mutate(it->second);
197192
} else {
@@ -200,7 +195,6 @@ class ReplaceStridedLoads : public IRMutator {
200195
}
201196

202197
Stmt visit(const Allocate *op) override {
203-
ScopedBinding bind(allocation_scope, op->name, op);
204198
auto it = padding.find(op);
205199
Stmt s = IRMutator::visit(op);
206200
if (it == padding.end()) {
@@ -281,10 +275,25 @@ bool can_hoist_shared_load(const IRNode *n, const std::string &buf, const Expr &
281275

282276
} // namespace
283277

284-
Stmt stage_strided_loads(const Stmt &s, const Target &target) {
278+
Stmt stage_strided_loads(const Stmt &stmt, const Target &target) {
285279
FindStridedLoads finder;
286280
ReplaceStridedLoads replacer;
287281

282+
// Make all strided loads distinct IR nodes so that we can uniquely identify
283+
// them by address. We may want to mutate the same load node in different
284+
// ways depending on the surrounding context.
285+
Stmt s = mutate_with(stmt, [&](auto *self, const Load *l) {
286+
const Ramp *r = l->index.as<Ramp>();
287+
if (l->type.is_scalar() || (r && is_const_one(r->stride))) {
288+
// Definitely not a strided load
289+
return self->visit_base(l);
290+
} else {
291+
// Might be a strided load after simplification
292+
return Load::make(l->type, l->name, self->mutate(l->index), l->image, l->param,
293+
self->mutate(l->predicate), l->alignment);
294+
}
295+
});
296+
288297
// Find related clusters of strided loads anywhere in the stmt. While this
289298
// appears to look globally, it requires expressions to match exactly, so
290299
// really it's only going to find things inside the same loops and let
@@ -293,7 +302,6 @@ Stmt stage_strided_loads(const Stmt &s, const Target &target) {
293302

294303
for (const auto &l : finder.found_loads) {
295304
const FindStridedLoads::Key &k = l.first;
296-
const Allocate *alloc = k.allocation;
297305
const std::map<int64_t, std::vector<const Load *>> &v = l.second;
298306

299307
// Find clusters of strided loads that can share the same dense load.
@@ -352,7 +360,7 @@ Stmt stage_strided_loads(const Stmt &s, const Target &target) {
352360
Shuffle::make_slice(var, row * k.lanes, 1, k.lanes) :
353361
Shuffle::make_slice(var, row, k.stride, k.lanes);
354362
for (const Load *l : load->second) {
355-
replacer.replacements.emplace(std::make_pair(alloc, l), shuf);
363+
replacer.replacements.emplace(l, shuf);
356364
}
357365
}
358366
if (transpose_shared_load) {
@@ -364,7 +372,7 @@ Stmt stage_strided_loads(const Stmt &s, const Target &target) {
364372
int row = load->first - first_offset;
365373
Expr shuf = Shuffle::make_slice(shared_load, row, k.stride, k.lanes);
366374
for (const Load *l : load->second) {
367-
replacer.replacements.emplace(std::make_pair(alloc, l), shuf);
375+
replacer.replacements.emplace(l, shuf);
368376
}
369377
}
370378
}
@@ -374,7 +382,7 @@ Stmt stage_strided_loads(const Stmt &s, const Target &target) {
374382
// picked up in a cluster, but for which we know it's safe to do a
375383
// dense load before their start.
376384
for (const auto &[offset, loads] : reverse_view(v)) {
377-
if (replacer.replacements.count({alloc, loads[0]})) {
385+
if (replacer.replacements.count(loads[0])) {
378386
continue;
379387
}
380388
int64_t delta = k.stride - 1;
@@ -392,14 +400,14 @@ Stmt stage_strided_loads(const Stmt &s, const Target &target) {
392400
dense_load = common_subexpression_elimination(dense_load);
393401
Expr shuf = Shuffle::make_slice(dense_load, delta, k.stride, k.lanes);
394402
for (const Load *l : loads) {
395-
replacer.replacements.emplace(std::make_pair(alloc, l), shuf);
403+
replacer.replacements.emplace(l, shuf);
396404
}
397405
}
398406

399407
// Look for any loads we can densify because an overlapping load occurs
400408
// in any parent scope.
401409
for (const auto &[offset, loads] : reverse_view(v)) {
402-
if (replacer.replacements.count({alloc, loads[0]})) {
410+
if (replacer.replacements.count(loads[0])) {
403411
continue;
404412
}
405413
int64_t min_offset = offset;
@@ -430,7 +438,7 @@ Stmt stage_strided_loads(const Stmt &s, const Target &target) {
430438
dense_load = common_subexpression_elimination(dense_load);
431439
Expr shuf = Shuffle::make_slice(dense_load, offset - final_offset, k.stride, k.lanes);
432440
for (const Load *l : loads) {
433-
replacer.replacements.emplace(std::make_pair(alloc, l), shuf);
441+
replacer.replacements.emplace(l, shuf);
434442
}
435443
}
436444

@@ -439,7 +447,7 @@ Stmt stage_strided_loads(const Stmt &s, const Target &target) {
439447
// external allocations by doing a dense load at a trimmed size. We rely
440448
// on codegen to do a good job at loading vectors of a funny size.
441449
for (const auto &[offset, loads] : v) {
442-
if (replacer.replacements.count({alloc, loads[0]})) {
450+
if (replacer.replacements.count(loads[0])) {
443451
continue;
444452
}
445453

@@ -463,7 +471,7 @@ Stmt stage_strided_loads(const Stmt &s, const Target &target) {
463471
dense_load = common_subexpression_elimination(dense_load);
464472
Expr shuf = Shuffle::make_slice(dense_load, offset - first_offset, k.stride, k.lanes);
465473
for (const Load *l : loads) {
466-
replacer.replacements.emplace(std::make_pair(alloc, l), shuf);
474+
replacer.replacements.emplace(l, shuf);
467475
}
468476

469477
} else if (k.lanes % 2 == 0) {
@@ -486,7 +494,7 @@ Stmt stage_strided_loads(const Stmt &s, const Target &target) {
486494
Expr shuf2 = Shuffle::make_slice(dense_load2, delta, k.stride, k.lanes / 2);
487495
Expr shuf = Shuffle::make_concat({shuf1, shuf2});
488496
for (const Load *l : loads) {
489-
replacer.replacements.emplace(std::make_pair(alloc, l), shuf);
497+
replacer.replacements.emplace(l, shuf);
490498
}
491499
}
492500
}

0 commit comments

Comments
 (0)