diff --git a/src/coreclr/jit/ifconversion.cpp b/src/coreclr/jit/ifconversion.cpp index 0acda646e0d9e6..d17b321ee836ec 100644 --- a/src/coreclr/jit/ifconversion.cpp +++ b/src/coreclr/jit/ifconversion.cpp @@ -59,6 +59,12 @@ class OptIfConversionDsc GenTree* TryTransformSelectOperOrLocal(GenTree* oper, GenTree* lcl); GenTree* TryTransformSelectOperOrZero(GenTree* oper, GenTree* lcl); + bool CanProfitablyFactorCommonCommutativeOperands(GenTree* trueInput, + GenTree* falseInput, + int maxCost, + int originalCost, + int savingsMargin); + GenTree* TryFactorCommonCommutativeOperands(GenTree* trueInput, GenTree* falseInput); GenTree* TryTransformSelectToOrdinaryOps(GenTree* trueInput, GenTree* falseInput); #ifdef DEBUG void IfConvertDump(); @@ -627,8 +633,59 @@ bool OptIfConversionDsc::optIfConvert(int* pReachabilityBudget) } #endif + // Get the select node inputs. + var_types selectType; + GenTree* selectTrueInput; + GenTree* selectFalseInput; + if (m_mainOper == GT_STORE_LCL_VAR) + { + selectFalseInput = m_thenOperation.node->AsLclVar()->Data(); + selectTrueInput = m_doElseConversion ? m_elseOperation.node->AsLclVar()->Data() : nullptr; + + // Pick the type as the type of the local, which should always be compatible even for implicit coercions. + selectType = genActualType(m_thenOperation.node); + } + else + { + assert(m_mainOper == GT_RETURN); + assert(m_doElseConversion); + assert(m_thenOperation.node->TypeGet() == m_elseOperation.node->TypeGet()); + + selectTrueInput = m_elseOperation.node->AsOp()->GetReturnValue(); + selectFalseInput = m_thenOperation.node->AsOp()->GetReturnValue(); + selectType = genActualType(m_thenOperation.node); + } + + if (!m_compiler->compStressCompile(Compiler::STRESS_IF_CONVERSION_INNER_LOOPS, 25)) + { + // Don't optimise the block if it is inside a loop. Loop-carried + // dependencies can cause significant stalls if if-converted. + // Detect via the block weight as that will be high when inside a loop. 
+ + if (m_startBlock->getBBWeight(m_compiler) > BB_UNITY_WEIGHT * 1.05) + { + JITDUMP("Skipping if-conversion inside loop (via weight)\n"); + return false; + } + + // We may be inside an unnatural loop, so do the expensive check. + Compiler::ReachabilityResult reachability = + m_compiler->optReachableWithBudget(m_finalBlock, m_startBlock, nullptr, pReachabilityBudget); + if (reachability == Compiler::ReachabilityResult::Reachable) + { + JITDUMP("Skipping if-conversion inside loop (via reachability)\n"); + return false; + } + else if (reachability == Compiler::ReachabilityResult::BudgetExceeded) + { + JITDUMP("Skipping if-conversion since we ran out of reachability budget\n"); + return false; + } + } + // Using SELECT nodes means that both Then and Else operations are fully evaluated. // Put a limit on the original source and destinations. + constexpr int ifConversionCostLimit = 7; if (!m_compiler->compStressCompile(Compiler::STRESS_IF_CONVERSION_COST, 25)) { int thenCost = 0; @@ -655,7 +712,11 @@ bool OptIfConversionDsc::optIfConvert(int* pReachabilityBudget) } // Cost to allow for "x = cond ? a + b : c + d". - if (thenCost > 7 || elseCost > 7) + // If the original expressions are expensive, keep if-conversion only + // when we have a factorized form that can avoid duplicated work. + if ((thenCost > ifConversionCostLimit || elseCost > ifConversionCostLimit) && + !CanProfitablyFactorCommonCommutativeOperands(selectTrueInput, selectFalseInput, ifConversionCostLimit, + thenCost + elseCost, 1)) { JITDUMP("Skipping if-conversion that will evaluate RHS unconditionally at costs %d,%d\n", thenCost, elseCost); @@ -663,57 +724,14 @@ bool OptIfConversionDsc::optIfConvert(int* pReachabilityBudget) } } - if (!m_compiler->compStressCompile(Compiler::STRESS_IF_CONVERSION_INNER_LOOPS, 25)) - { - // Don't optimise the block if it is inside a loop. Loop-carried - // dependencies can cause significant stalls if if-converted. 
- // Detect via the block weight as that will be high when inside a loop. - - if (m_startBlock->getBBWeight(m_compiler) > BB_UNITY_WEIGHT * 1.05) - { - JITDUMP("Skipping if-conversion inside loop (via weight)\n"); - return false; - } - - // We may be inside an unnatural loop, so do the expensive check. - Compiler::ReachabilityResult reachability = - m_compiler->optReachableWithBudget(m_finalBlock, m_startBlock, nullptr, pReachabilityBudget); - if (reachability == Compiler::ReachabilityResult::Reachable) - { - JITDUMP("Skipping if-conversion inside loop (via reachability)\n"); - return false; - } - else if (reachability == Compiler::ReachabilityResult::BudgetExceeded) - { - JITDUMP("Skipping if-conversion since we ran out of reachability budget\n"); - return false; - } - } - - // Get the select node inputs. - var_types selectType; - GenTree* selectTrueInput; - GenTree* selectFalseInput; - if (m_mainOper == GT_STORE_LCL_VAR) - { - selectFalseInput = m_thenOperation.node->AsLclVar()->Data(); - selectTrueInput = m_doElseConversion ? m_elseOperation.node->AsLclVar()->Data() : nullptr; + // Build a factorized candidate only after all checks that can reject this + // conversion while still reporting "no change". + GenTree* select = TryFactorCommonCommutativeOperands(selectTrueInput, selectFalseInput); - // Pick the type as the type of the local, which should always be compatible even for implicit coercions. 
- selectType = genActualType(m_thenOperation.node); - } - else + if (select == nullptr) { - assert(m_mainOper == GT_RETURN); - assert(m_doElseConversion); - assert(m_thenOperation.node->TypeGet() == m_elseOperation.node->TypeGet()); - - selectTrueInput = m_elseOperation.node->AsOp()->GetReturnValue(); - selectFalseInput = m_thenOperation.node->AsOp()->GetReturnValue(); - selectType = genActualType(m_thenOperation.node); + select = TryTransformSelectToOrdinaryOps(selectTrueInput, selectFalseInput); } - - GenTree* select = TryTransformSelectToOrdinaryOps(selectTrueInput, selectFalseInput); if (select == nullptr) { #ifdef TARGET_RISCV64 @@ -844,6 +862,309 @@ static IntConstSelectOper MatchIntConstSelectValues(int64_t trueVal, int64_t fal return {GT_NONE}; }
+//-----------------------------------------------------------------------------
+// IsSupportedCommutativeSelectOper: check whether a node is one of the
+// commutative, associative operators that may be factored around a SELECT.
+//
+// Arguments:
+//    tree - the node to inspect
+//
+// Return Value:
+//    true if the node's type satisfies varTypeIsIntegralOrI and the node is
+//    GT_AND, GT_OR, GT_XOR, or a GT_ADD for which gtOverflowEx() is false;
+//    false otherwise.
+//
+static bool IsSupportedCommutativeSelectOper(GenTree* tree)
+{
+    if (!varTypeIsIntegralOrI(tree))
+    {
+        return false;
+    }
+
+    const genTreeOps oper = tree->OperGet();
+    if (oper == GT_ADD)
+    {
+        // Overflow-checked additions are not eligible.
+        return !tree->gtOverflowEx();
+    }
+
+    return (oper == GT_AND) || (oper == GT_OR) || (oper == GT_XOR);
+}
+
+//-----------------------------------------------------------------------------
+// GetCommutativeFactorOperCost: execution cost charged for one supported
+// commutative operator node, not counting the cost of its children.
+//
+// Arguments:
+//    oper - the operator; must be GT_ADD, GT_AND, GT_OR or GT_XOR
+//
+// Return Value:
+//    1 for every supported operator. Any other operator hits unreached().
+//
+static int GetCommutativeFactorOperCost(genTreeOps oper)
+{
+    const bool isSupported = (oper == GT_ADD) || (oper == GT_AND) || (oper == GT_OR) || (oper == GT_XOR);
+    if (!isSupported)
+    {
+        unreached();
+    }
+
+    return 1;
+}
+
+//-----------------------------------------------------------------------------
+// SumOperandCostEx: sum execution costs of all operands in a list.
+//
+// Arguments:
+//    operands - the operand list to total
+//
+// Return Value:
+//    Sum of GetCostEx() over every entry; 0 when the list is empty.
+//
+static int SumOperandCostEx(ArrayStack<GenTree*>* operands)
+{
+    int cost = 0;
+    for (int i = 0; i < operands->Height(); i++)
+    {
+        cost += operands->Bottom(i)->GetCostEx();
+    }
+
+    return cost;
+}
+
+//-----------------------------------------------------------------------------
+// ShouldAvoidFactoringConstAddSelect: guard for cases where
+// factoring
+//   select(x + c1, x + c2) -> x + select(c1, c2)
+// tends to pessimize codegen by forcing constant materialization.
+//
+// Arguments:
+//    oper              - operator shared by both branches
+//    trueOnlyOperands  - operands that appear only in the true branch
+//    falseOnlyOperands - operands that appear only in the false branch
+//
+// Return Value:
+//    true only when "oper" is GT_ADD and each branch contributes exactly one
+//    branch-specific operand that is an integral constant.
+//
+static bool ShouldAvoidFactoringConstAddSelect(genTreeOps            oper,
+                                               ArrayStack<GenTree*>* trueOnlyOperands,
+                                               ArrayStack<GenTree*>* falseOnlyOperands)
+{
+    if ((oper == GT_ADD) && (trueOnlyOperands->Height() == 1) && (falseOnlyOperands->Height() == 1))
+    {
+        return trueOnlyOperands->Bottom(0)->IsIntegralConst() && falseOnlyOperands->Bottom(0)->IsIntegralConst();
+    }
+
+    return false;
+}
+
+// Declared ahead of its definition because TryAnalyzeCommonCommutativeOperands uses it first.
+static void CollectAssociativeOperands(GenTree* node, genTreeOps oper, ArrayStack<GenTree*>* operands);
+
+//-----------------------------------------------------------------------------
+// TryAnalyzeCommonCommutativeOperands: analyze two commutative/associative
+// trees and split operands into common and branch-specific subsets.
+//
+// Arguments:
+//    trueInput         - expression for the true branch (may be nullptr)
+//    falseInput        - expression for the false branch (must not be nullptr)
+//    alloc             - allocator used for the temporary operand lists
+//    oper              - [out] operator shared by both inputs
+//    resultType        - [out] actual type of trueInput
+//    commonOperands    - [out] operands found in both inputs
+//    trueOnlyOperands  - [out] operands found only in trueInput
+//    falseOnlyOperands - [out] operands found only in falseInput
+//
+// Return Value:
+//    true if all three output lists end up non-empty; false otherwise.
+//
+// Notes:
+//    The output lists are only appended to, so callers pass them in empty.
+//    "oper" and "resultType" may be written even when false is returned.
+//    Operands are paired one-to-one using GenTree::Compare, so a repeated
+//    operand counts as common only as many times as it occurs in both inputs.
+//
+static bool TryAnalyzeCommonCommutativeOperands(GenTree*              trueInput,
+                                                GenTree*              falseInput,
+                                                CompAllocator         alloc,
+                                                genTreeOps*           oper,
+                                                var_types*            resultType,
+                                                ArrayStack<GenTree*>* commonOperands,
+                                                ArrayStack<GenTree*>* trueOnlyOperands,
+                                                ArrayStack<GenTree*>* falseOnlyOperands)
+{
+    if (trueInput == nullptr)
+    {
+        // There is no true-branch expression to analyze.
+        return false;
+    }
+
+    // falseInput is dereferenced unconditionally below.
+    assert(falseInput != nullptr);
+
+    if ((trueInput->OperGet() != falseInput->OperGet()) || !IsSupportedCommutativeSelectOper(trueInput) ||
+        !IsSupportedCommutativeSelectOper(falseInput))
+    {
+        return false;
+    }
+
+    *oper       = trueInput->OperGet();
+    *resultType = genActualType(trueInput);
+
+    if (*resultType != genActualType(falseInput))
+    {
+        return false;
+    }
+
+    ArrayStack<GenTree*> trueOperands(alloc);
+    ArrayStack<GenTree*> falseOperands(alloc);
+
+    CollectAssociativeOperands(trueInput, *oper, &trueOperands);
+    CollectAssociativeOperands(falseInput, *oper, &falseOperands);
+
+    if ((trueOperands.Height() < 2) || (falseOperands.Height() < 2))
+    {
+        return false;
+    }
+
+    // falseMatched[j] records whether falseOperands[j] has already been paired
+    // with an operand from the true side.
+    ArrayStack<bool> falseMatched(alloc, falseOperands.Height());
+    for (int i = 0; i < falseOperands.Height(); i++)
+    {
+        falseMatched.Push(false);
+    }
+
+    for (int i = 0; i < trueOperands.Height(); i++)
+    {
+        GenTree* trueOperand = trueOperands.Bottom(i);
+        bool     matched     = false;
+
+        for (int j = 0; j < falseOperands.Height(); j++)
+        {
+            if (falseMatched.Bottom(j))
+            {
+                continue;
+            }
+
+            if (GenTree::Compare(trueOperand, falseOperands.Bottom(j), true))
+            {
+                // Keep the true-side node as the representative of the pair.
+                falseMatched.BottomRef(j) = true;
+                commonOperands->Push(trueOperand);
+                matched = true;
+                break;
+            }
+        }
+
+        if (!matched)
+        {
+            trueOnlyOperands->Push(trueOperand);
+        }
+    }
+
+    for (int i = 0; i < falseOperands.Height(); i++)
+    {
+        if (!falseMatched.Bottom(i))
+        {
+            falseOnlyOperands->Push(falseOperands.Bottom(i));
+        }
+    }
+
+    return !commonOperands->Empty() && !trueOnlyOperands->Empty() && !falseOnlyOperands->Empty();
+}
+
+//-----------------------------------------------------------------------------
+// CollectAssociativeOperands: flatten an associative/commutative operator tree
+// into a linear operand list.
+//
+// Arguments:
+//    node     - root of the (sub)tree to flatten
+//    oper     - the operator being flattened
+//    operands - [out] receives the leaf operands, op1 side before op2 side
+//
+// Notes:
+//    Recursion continues only through nodes that use "oper" and satisfy
+//    IsSupportedCommutativeSelectOper; every other node is pushed as a leaf.
+//
+static void CollectAssociativeOperands(GenTree* node, genTreeOps oper, ArrayStack<GenTree*>* operands)
+{
+    if (node->OperIs(oper) && IsSupportedCommutativeSelectOper(node))
+    {
+        CollectAssociativeOperands(node->gtGetOp1(), oper, operands);
+        CollectAssociativeOperands(node->gtGetOp2(), oper, operands);
+        return;
+    }
+
+    operands->Push(node);
+}
+
+//-----------------------------------------------------------------------------
+// BuildAssociativeTree: build a left-associated binary tree from operands.
+//
+// Arguments:
+//    compiler - compiler instance used to allocate the new nodes
+//    oper     - operator for every node that is created
+//    type     - type given to every node that is created
+//    operands - operands to combine; must hold at least one entry
+//
+// Return Value:
+//    The sole operand itself when the list has one entry (no node is
+//    created), otherwise the root of ((o0 oper o1) oper o2) ...
+//
+static GenTree* BuildAssociativeTree(Compiler*             compiler,
+                                     genTreeOps            oper,
+                                     var_types             type,
+                                     ArrayStack<GenTree*>* operands)
+{
+    assert(operands->Height() > 0);
+
+    GenTree* tree = operands->Bottom(0);
+    for (int i = 1; i < operands->Height(); i++)
+    {
+        tree = compiler->gtNewOperNode(oper, type, tree, operands->Bottom(i));
+    }
+
+    return tree;
+}
+
+//-----------------------------------------------------------------------------
+// CanProfitablyFactorCommonCommutativeOperands: estimate whether a factorized
+// form is cheap enough for if-conversion profitability.
+//
+// Arguments:
+//    trueInput     - expression for the true branch (may be nullptr)
+//    falseInput    - expression for the false branch
+//    maxCost       - cost at or below which the factored form is accepted
+//    originalCost  - combined cost of the two unfactored branches
+//    savingsMargin - amount by which the estimate must undercut originalCost
+//                    when it is above maxCost
+//
+// Return Value:
+//    true if the inputs can be factored and the estimated cost passes either
+//    threshold; false otherwise.
+//
+// Notes:
+//    This applies every rejection test that TryFactorCommonCommutativeOperands
+//    applies, including the branch type check, so that a "true" answer is not
+//    followed by a failed factorization and an unfactored, expensive SELECT.
+//
+bool OptIfConversionDsc::CanProfitablyFactorCommonCommutativeOperands(GenTree* trueInput,
+                                                                      GenTree* falseInput,
+                                                                      int      maxCost,
+                                                                      int      originalCost,
+                                                                      int      savingsMargin)
+{
+    CompAllocator alloc = m_compiler->getAllocator(CMK_Unknown);
+
+    genTreeOps           oper       = GT_NONE;
+    var_types            resultType = TYP_UNDEF;
+    ArrayStack<GenTree*> commonOperands(alloc);
+    ArrayStack<GenTree*> trueOnlyOperands(alloc);
+    ArrayStack<GenTree*> falseOnlyOperands(alloc);
+
+    if (!TryAnalyzeCommonCommutativeOperands(trueInput, falseInput, alloc, &oper, &resultType, &commonOperands,
+                                             &trueOnlyOperands, &falseOnlyOperands))
+    {
+        return false;
+    }
+
+    if (ShouldAvoidFactoringConstAddSelect(oper, &trueOnlyOperands, &falseOnlyOperands))
+    {
+        return false;
+    }
+
+    assert(varTypeIsIntegralOrI(resultType));
+
+    // A single branch-specific operand is used as-is; two or more are combined
+    // into a node of resultType. The factorization is abandoned when the two
+    // resulting types differ, so report it as unprofitable here as well.
+    const var_types trueDiffType =
+        (trueOnlyOperands.Height() == 1) ? genActualType(trueOnlyOperands.Bottom(0)) : resultType;
+    const var_types falseDiffType =
+        (falseOnlyOperands.Height() == 1) ? genActualType(falseOnlyOperands.Bottom(0)) : resultType;
+    if (trueDiffType != falseDiffType)
+    {
+        return false;
+    }
+
+    const int operCost          = GetCommutativeFactorOperCost(oper);
+    const int trueOnlyCost      = SumOperandCostEx(&trueOnlyOperands);
+    const int falseOnlyCost     = SumOperandCostEx(&falseOnlyOperands);
+    const int commonCost        = SumOperandCostEx(&commonOperands);
+    const int trueInnerOpCount  = trueOnlyOperands.Height() - 1;
+    const int falseInnerOpCount = falseOnlyOperands.Height() - 1;
+    const int commonOpCount     = commonOperands.Height();
+
+    // select(diffTrue, diffFalse) + common terms
+    // GT_SELECT cost model: cond + op1 + op2 + 1
+    int estimatedCost = m_cond->GetCostEx() + trueOnlyCost + falseOnlyCost + commonCost + 1 +
+                        ((trueInnerOpCount + falseInnerOpCount + commonOpCount) * operCost);
+
+    // Fast path: keep existing cheap-form threshold.
+    if (estimatedCost <= maxCost)
+    {
+        return true;
+    }
+
+    // Relative profitability fallback: allow factorization when it is materially
+    // cheaper than evaluating both original branches unconditionally.
+    return (estimatedCost + savingsMargin) < originalCost;
+}
+
+//-----------------------------------------------------------------------------
+// TryFactorCommonCommutativeOperands: try factoring common associative and
+// commutative operands out of a select expression.
+//
+// Arguments:
+//    trueInput  - expression for the true branch
+//    falseInput - expression for the false branch
+//
+// Return Value:
+//    Transformed expression if factorization succeeded, otherwise nullptr.
+//
+// Notes:
+//    All rejection tests run before any node is created, so a nullptr return
+//    leaves no newly allocated nodes behind. The operand nodes of the inputs
+//    are reused directly in the returned tree.
+//
+GenTree* OptIfConversionDsc::TryFactorCommonCommutativeOperands(GenTree* trueInput, GenTree* falseInput)
+{
+    CompAllocator        alloc      = m_compiler->getAllocator(CMK_Unknown);
+    genTreeOps           oper       = GT_NONE;
+    var_types            resultType = TYP_UNDEF;
+    ArrayStack<GenTree*> commonOperands(alloc);
+    ArrayStack<GenTree*> trueOnlyOperands(alloc);
+    ArrayStack<GenTree*> falseOnlyOperands(alloc);
+
+    if (!TryAnalyzeCommonCommutativeOperands(trueInput, falseInput, alloc, &oper, &resultType, &commonOperands,
+                                             &trueOnlyOperands, &falseOnlyOperands))
+    {
+        return nullptr;
+    }
+
+    if (ShouldAvoidFactoringConstAddSelect(oper, &trueOnlyOperands, &falseOnlyOperands))
+    {
+        return nullptr;
+    }
+
+    // A single branch-specific operand is used as-is; two or more are combined
+    // into a node of resultType. Both SELECT inputs must agree on type.
+    // NOTE(review): when resultType is a GC pointer type and the GC-typed operand
+    // is among the common operands, typing a multi-operand branch expression with
+    // resultType looks wrong - confirm.
+    const var_types trueDiffType =
+        (trueOnlyOperands.Height() == 1) ? genActualType(trueOnlyOperands.Bottom(0)) : resultType;
+    const var_types falseDiffType =
+        (falseOnlyOperands.Height() == 1) ? genActualType(falseOnlyOperands.Bottom(0)) : resultType;
+    if (trueDiffType != falseDiffType)
+    {
+        return nullptr;
+    }
+
+    GenTree* trueDiffExpr  = BuildAssociativeTree(m_compiler, oper, resultType, &trueOnlyOperands);
+    GenTree* falseDiffExpr = BuildAssociativeTree(m_compiler, oper, resultType, &falseOnlyOperands);
+
+    // Re-associate:
+    //   select(op(common, t), op(common, f)) => op(common, select(t, f))
+    // where "op" is associative+commutative.
+    GenTree* result = m_compiler->gtNewConditionalNode(GT_SELECT, m_cond, trueDiffExpr, falseDiffExpr, trueDiffType);
+
+    for (int i = 0; i < commonOperands.Height(); i++)
+    {
+        result = m_compiler->gtNewOperNode(oper, resultType, result, commonOperands.Bottom(i));
+    }
+
+    return result;
+}
+
 //-----------------------------------------------------------------------------
 // TryTransformSelectOperOrLocal: Try to trasform "cond ?
oper(lcl, (-)1) : lcl" into "oper(')(lcl, cond)" // @@ -930,8 +1251,9 @@ GenTree* OptIfConversionDsc::TryTransformSelectOperOrZero(GenTree* trueInput, Ge //----------------------------------------------------------------------------- // TryTransformSelectToOrdinaryOps: Try transforming the identified if-else expressions to a single expression // -// This is meant mostly for RISC-V where the condition (1 or 0) is stored in a regular general-purpose register -// which can be fed as an argument to standard operations, e.g. +// This handles select simplifications and target-specific patterns. +// It supports RISC-V oriented transforms where the condition (1 or 0) is fed as +// an argument to standard operations, e.g. // * (cond ? 6 : 5) becomes (5 + cond) // * (cond ? -25 : -13) becomes (-25 >> cond) // * if (cond) a++; becomes (a + cond)