Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
39 commits
Select commit Hold shift + click to select a range
2f4f8bc
Add failing rfactor tests
alexreinking Nov 5, 2024
47cdbcc
Rewrite rfactor()
alexreinking Nov 21, 2024
530bd3f
Remove unused operators
alexreinking Nov 24, 2024
6483c12
Clean up dim/var matching helpers
alexreinking Nov 24, 2024
7ffdb66
Fix definition in PyStage.cpp
alexreinking Nov 24, 2024
3fb172f
Use and_condition_over_domain to predicate the reducing definition in…
alexreinking Nov 24, 2024
2792200
Disallow rfactor() on funcs with RVar+Var fused schedules
alexreinking Nov 24, 2024
9ccacd2
Use size_t in place of int
alexreinking Nov 24, 2024
e154fc1
Clean up uses of split_predicate()
alexreinking Nov 25, 2024
ed3c405
Clean out some excess std:: qualifications
alexreinking Nov 25, 2024
22add48
Use dim_match instead of var_name_match
alexreinking Nov 25, 2024
1355d3e
Use unordered_set directly instead of DimSet
alexreinking Nov 25, 2024
a201afe
Compute preserved rdims set earlier to drop find_rvar
alexreinking Nov 25, 2024
7a8c686
Hoist projection code into common block
alexreinking Nov 25, 2024
7819855
Use structured bindings
alexreinking Nov 25, 2024
9b7a4a2
Drop rebind() as add_let() was equivalent
alexreinking Nov 25, 2024
ce9af71
More cleaning
alexreinking Nov 25, 2024
576654d
Remove not-helpful-enough helpers
alexreinking Nov 25, 2024
fd04843
Reorganize update definitions to mirror each other
alexreinking Nov 25, 2024
a57d670
Update Func.cpp
steven-johnson Dec 3, 2024
01f4415
Update Substitute.h
steven-johnson Dec 3, 2024
4cf1d3c
trigger buildbots
steven-johnson Dec 10, 2024
4bce0a7
Simplify the preserved predicate to eliminate outermost NOT.
alexreinking Dec 10, 2024
d757298
Assert that add_let would not shadow a binding
alexreinking Dec 10, 2024
e397d7a
Improve "rvar not found" error message.
alexreinking Dec 10, 2024
70798dd
Further cleanup using the no-shadow invariant
alexreinking Dec 10, 2024
b0e9b29
MAke rfactor_validate_args a private member
alexreinking Dec 11, 2024
87e5242
Partition splits lists while validating to drop clean-up step.
alexreinking Dec 11, 2024
c4fef0c
Remove PurifyRVar split type.
alexreinking Dec 11, 2024
20e3bb7
Only process RVar splits to eliminate intermediate map cleanup
alexreinking Dec 11, 2024
724f0f7
Fix comment in ApplySplit.cpp
alexreinking Dec 11, 2024
06548a9
Factor out weaken_condition_under_domain and document logic
alexreinking Dec 11, 2024
f6a3fac
trigger buildbots
steven-johnson Dec 11, 2024
22a28d1
Remove unused var
steven-johnson Dec 12, 2024
8b87a09
Rename weaken_condition_under_domain ~> or_condition_over_domain
alexreinking Dec 13, 2024
faeede6
Adding comment on missing pure vars
alexreinking Dec 13, 2024
2a3b0df
Add more comments inside project_rdom
alexreinking Dec 13, 2024
57476b1
Add more detailed documentation to add_let.
alexreinking Dec 13, 2024
15a4dbf
Remove references to purify in comments.
alexreinking Dec 13, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion python_bindings/src/halide/halide_/PyStage.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ void define_stage(py::module &m) {
.def("dump_argument_list", &Stage::dump_argument_list)
.def("name", &Stage::name)

.def("rfactor", (Func(Stage::*)(std::vector<std::pair<RVar, Var>>)) & Stage::rfactor,
.def("rfactor", (Func(Stage::*)(const std::vector<std::pair<RVar, Var>> &)) & Stage::rfactor,
py::arg("preserved"))
.def("rfactor", (Func(Stage::*)(const RVar &, const Var &)) & Stage::rfactor,
py::arg("r"), py::arg("v"))
Expand Down
9 changes: 1 addition & 8 deletions src/ApplySplit.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,6 @@ vector<ApplySplitResult> apply_split(const Split &split, const string &prefix,
}
} break;
case Split::RenameVar:
case Split::PurifyRVar:
result.emplace_back(prefix + split.old_var, outer, ApplySplitResult::Substitution);
result.emplace_back(prefix + split.old_var, outer, ApplySplitResult::LetStmt);
break;
Expand All @@ -167,10 +166,7 @@ vector<ApplySplitResult> apply_split(const Split &split, const string &prefix,
}

vector<std::pair<string, Expr>> compute_loop_bounds_after_split(const Split &split, const string &prefix) {
// Define the bounds on the split dimensions using the bounds
// on the function args. If it is a purify, we should use the bounds
// from the dims instead.

// Define the bounds on the split dimensions using the bounds on the function args.
vector<std::pair<string, Expr>> let_stmts;

Expr old_var_extent = Variable::make(Int(32), prefix + split.old_var + ".loop_extent");
Expand Down Expand Up @@ -201,9 +197,6 @@ vector<std::pair<string, Expr>> compute_loop_bounds_after_split(const Split &spl
let_stmts.emplace_back(prefix + split.outer + ".loop_max", old_var_max);
let_stmts.emplace_back(prefix + split.outer + ".loop_extent", old_var_extent);
break;
case Split::PurifyRVar:
// Do nothing for purify
break;
}

return let_stmts;
Expand Down
24 changes: 8 additions & 16 deletions src/BoundsInference.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

#include <algorithm>
#include <iterator>
#include <numeric>

namespace Halide {
namespace Internal {
Expand Down Expand Up @@ -297,7 +298,6 @@ class BoundsInference : public IRMutator {
}

// Default case (no specialization)
vector<Expr> predicates = def.split_predicate();
for (const ReductionVariable &rv : def.schedule().rvars()) {
rvars.insert(rv);
}
Expand All @@ -308,23 +308,15 @@ class BoundsInference : public IRMutator {
}
vecs[1] = def.values();

vector<Expr> predicates = def.split_predicate();
for (size_t i = 0; i < result.size(); ++i) {
for (const Expr &val : vecs[i]) {
if (!predicates.empty()) {
Expr cond_val = Call::make(val.type(),
Internal::Call::if_then_else,
{likely(predicates[0]), val},
Internal::Call::PureIntrinsic);
for (size_t i = 1; i < predicates.size(); ++i) {
cond_val = Call::make(cond_val.type(),
Internal::Call::if_then_else,
{likely(predicates[i]), cond_val},
Internal::Call::PureIntrinsic);
}
result[i].emplace_back(const_true(), cond_val);
} else {
result[i].emplace_back(const_true(), val);
}
Expr cond_val = std::accumulate(
predicates.begin(), predicates.end(), val,
[](const auto &acc, const auto &pred) {
return Call::make(acc.type(), Call::if_then_else, {likely(pred), acc}, Call::PureIntrinsic);
});
result[i].emplace_back(const_true(), cond_val);
}
}

Expand Down
1 change: 0 additions & 1 deletion src/ConstantBounds.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,6 @@ ConstantInterval bounds_helper(const Expr &e,
ScopedBinding bind(scope, op->name, recurse(op->value));
return recurse(op->body);
} else if (const Call *op = e.as<Call>()) {
ConstantInterval result;
if (op->is_intrinsic(Call::abs)) {
return abs(recurse(op->args[0]));
} else if (op->is_intrinsic(Call::absd)) {
Expand Down
2 changes: 1 addition & 1 deletion src/Derivative.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1534,7 +1534,7 @@ void ReverseAccumulationVisitor::propagate_halide_function_call(
// f(r.x) = ... && r is associative
// => f(x) = ...
if (var != nullptr && var->reduction_domain.defined() &&
var->reduction_domain.split_predicate().empty()) {
is_const_one(var->reduction_domain.predicate())) {
ReductionDomain rdom = var->reduction_domain;
int rvar_id = -1;
for (int rid = 0; rid < (int)rdom.domain().size(); rid++) {
Expand Down
2 changes: 0 additions & 2 deletions src/Deserialization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -368,8 +368,6 @@ Split::SplitType Deserializer::deserialize_split_type(Serialize::SplitType split
return Split::SplitType::RenameVar;
case Serialize::SplitType::FuseVars:
return Split::SplitType::FuseVars;
case Serialize::SplitType::PurifyRVar:
return Split::SplitType::PurifyRVar;
default:
user_error << "unknown split type " << (int)split_type << "\n";
return Split::SplitType::SplitVar;
Expand Down
Loading