[SelectionDAG] Add expansion for llvm.convert.to.arbitrary.fp #193595
The expansion converts a native IEEE float to an arbitrary-precision FP format, returning the result as an integer, following this algorithm:
1. Bitcast the source float to an integer and extract the sign, exponent, and mantissa bit fields via masks and shifts.
2. Classify the input (zero/denormal/normal/Inf/NaN).
3. Normalize source denormals by finding the MSB position of the mantissa and adjusting the effective exponent.
4. Normal path: adjust the exponent bias from the source to the destination format and truncate the mantissa with rounding (supports NearestTiesToEven, TowardZero, TowardPositive, TowardNegative, NearestTiesToAway).
5. Denormal destination path: when the biased destination exponent is <= 0, shift the mantissa right to produce a denormalized result with rounding.
6. Handle mantissa overflow from rounding and exponent overflow. Produce Inf or saturate to the max finite value, depending on the format and the saturation flag.
7. Build special-value results (canonical qNaN, signed Inf, signed zero) adapted to the destination format's non-finite behavior (IEEE754, NanOnly, FiniteOnly).
8. Final selection in priority order: NaN > Inf > Zero > Overflow > Normal/Denorm.
Currently only conversions to OCP floats are covered; in LLVM terms these are Float8E5M2, Float8E4M3FN, Float6E3M2FN, Float6E2M3FN, and Float4E2M1FN.
OCP spec: https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
E2E testing on X86 was done with the assistance of Claude Code Opus 4.6.
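To make the rounding rule in step 4 concrete, here is a minimal scalar sketch of the round-bit / sticky-bit / LSB decision. It is not code from the patch: the `Rounding` enum and the `roundIncrement` helper are hypothetical names chosen for illustration, and the sketch assumes a 32-bit mantissa container with at least one discarded bit.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical stand-in for the five rounding modes listed in step 4.
enum class Rounding {
  NearestTiesToEven,
  TowardZero,
  TowardPositive,
  TowardNegative,
  NearestTiesToAway
};

// Returns 1 if the mantissa truncated by `Shift` bits must be incremented,
// otherwise 0. `Mant` is the untruncated mantissa; `Negative` is the sign.
inline uint32_t roundIncrement(Rounding Mode, bool Negative, uint32_t Mant,
                               unsigned Shift) {
  assert(Shift >= 1 && Shift < 32 && "need at least one discarded bit");
  // Highest discarded bit.
  const uint32_t RoundBit = (Mant >> (Shift - 1)) & 1u;
  // 1 if any bit below the round bit is set.
  const uint32_t Sticky = (Mant & ((1u << (Shift - 1)) - 1u)) != 0 ? 1u : 0u;
  // Lowest bit that survives truncation.
  const uint32_t LSB = (Mant >> Shift) & 1u;
  // 1 if anything at all is discarded.
  const uint32_t Inexact = RoundBit | Sticky;

  switch (Mode) {
  case Rounding::NearestTiesToEven:
    return RoundBit & (Sticky | LSB); // above half, or a tie with odd LSB
  case Rounding::TowardZero:
    return 0; // plain truncation
  case Rounding::TowardPositive:
    return Negative ? 0u : Inexact; // magnitude grows only for positives
  case Rounding::TowardNegative:
    return Negative ? Inexact : 0u; // magnitude grows only for negatives
  case Rounding::NearestTiesToAway:
    return RoundBit; // half or more always rounds away from zero
  }
  return 0; // not reached for valid enumerators
}
```

For an f32 source (23 mantissa bits) and a 3-bit destination mantissa the shift is 20, so `roundIncrement(Rounding::NearestTiesToEven, false, 1u << 19, 20)` describes an exact tie with an even LSB and yields 0, while the same input under `NearestTiesToAway` yields 1.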
@llvm/pr-subscribers-backend-nvptx @llvm/pr-subscribers-backend-amdgpu
Author: Dmitry Sidorov (MrSidims)
Patch is 172.47 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/193595.diff
14 Files Affected:
diff --git a/llvm/include/llvm/CodeGen/ISDOpcodes.h b/llvm/include/llvm/CodeGen/ISDOpcodes.h
index 8a8a9ee71ca02..de9835ee43836 100644
--- a/llvm/include/llvm/CodeGen/ISDOpcodes.h
+++ b/llvm/include/llvm/CodeGen/ISDOpcodes.h
@@ -1020,6 +1020,15 @@ enum NodeType {
/// The second operand is a constant indicating the source FP semantics.
CONVERT_FROM_ARBITRARY_FP,
+ /// CONVERT_TO_ARBITRARY_FP - Converts a native FP value to an arbitrary
+ /// floating-point format, returning the result as an integer.
+ /// The first operand is the source value.
+ /// The second operand is a constant indicating the destination FP semantics.
+ /// The third operand is a constant indicating the rounding mode.
+ /// The last operand is a boolean constant indicating whether the result has
+ /// to be saturated.
+ CONVERT_TO_ARBITRARY_FP,
+
/// Perform various unary floating-point operations inspired by libm. For
/// FPOWI, the result is undefined if the integer operand doesn't fit into
/// sizeof(int).
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
index 54d86dfbfa303..7dcc3a1f1c753 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
@@ -3782,6 +3782,479 @@ bool SelectionDAGLegalize::ExpandNode(SDNode *Node) {
Results.push_back(Result);
break;
}
+ case ISD::CONVERT_TO_ARBITRARY_FP: {
+ // Expand conversion from a native IEEE float type to an arbitrary FP
+ // format, returning the result as an integer using bit manipulation.
+ //
+ // TODO: currently only conversions to FP4, FP6 and FP8 formats from OCP
+ // specification are expanded. Remaining arbitrary FP types: Float8E4M3,
+ // Float8E3M4, Float8E5M2FNUZ, Float8E4M3FNUZ, Float8E4M3B11FNUZ,
+ // Float8E8M0FNU.
+ EVT ResVT = Node->getValueType(0);
+
+ SDValue FloatVal = Node->getOperand(0);
+ const uint64_t SemEnum = Node->getConstantOperandVal(1);
+ const auto Sem = static_cast<APFloatBase::Semantics>(SemEnum);
+ const auto RoundMode =
+ static_cast<RoundingMode>(Node->getConstantOperandVal(2));
+ const bool Saturate = Node->getConstantOperandVal(3) != 0;
+
+ // Supported destination formats.
+ switch (Sem) {
+ case APFloatBase::S_Float8E5M2:
+ case APFloatBase::S_Float8E4M3FN:
+ case APFloatBase::S_Float6E3M2FN:
+ case APFloatBase::S_Float6E2M3FN:
+ case APFloatBase::S_Float4E2M1FN:
+ break;
+ default:
+ DAG.getContext()->emitError("CONVERT_TO_ARBITRARY_FP: not implemented "
+ "destination format (semantics enum " +
+ Twine(SemEnum) + ")");
+ Results.push_back(DAG.getPOISON(ResVT));
+ break;
+ }
+ if (!Results.empty())
+ break;
+
+ // Destination format parameters.
+ const fltSemantics &DstSem = APFloatBase::EnumToSemantics(Sem);
+ const unsigned DstBits = APFloat::getSizeInBits(DstSem);
+ const unsigned DstPrecision = APFloat::semanticsPrecision(DstSem);
+ const unsigned DstMant = DstPrecision - 1;
+ const unsigned DstExpBits = DstBits - DstMant - 1;
+ const int DstBias = 1 - APFloat::semanticsMinExponent(DstSem);
+ const unsigned DstExpMax = (1U << DstExpBits) - 1;
+ const uint64_t DstMantMask = (DstMant > 0) ? ((1ULL << DstMant) - 1) : 0;
+ const fltNonfiniteBehavior DstNFBehavior = DstSem.nonFiniteBehavior;
+ const fltNanEncoding DstNanEnc = DstSem.nanEncoding;
+
+ // Compute the maximum normal exponent for the destination format.
+ unsigned DstExpMaxNormal;
+ if (DstNFBehavior == fltNonfiniteBehavior::IEEE754)
+ DstExpMaxNormal = DstExpMax - 1;
+ else
+ DstExpMaxNormal = DstExpMax;
+
+ // For NanOnly formats the max exponent field for finite values
+ // is DstExpMax, but the encoding with exp = DstExpMax and
+ // mant = all-ones is NaN. So DstExpMaxNormal = DstExpMax, but max
+ // mantissa at that exponent is DstMantMask - 1 (if NanEnc == AllOnes) to
+ // avoid the NaN encoding.
+ uint64_t DstMaxMantAtMaxExp = DstMantMask;
+ if (DstNFBehavior == fltNonfiniteBehavior::NanOnly &&
+ DstNanEnc == fltNanEncoding::AllOnes)
+ DstMaxMantAtMaxExp = DstMantMask - 1;
+
+ // Source format parameters.
+ EVT SrcVT = FloatVal.getValueType();
+ const fltSemantics &SrcSem = SrcVT.getFltSemantics();
+ const unsigned SrcBits = APFloat::getSizeInBits(SrcSem);
+ const unsigned SrcPrecision = APFloat::semanticsPrecision(SrcSem);
+ const unsigned SrcMant = SrcPrecision - 1;
+ const unsigned SrcExpBits = SrcBits - SrcMant - 1;
+ const int SrcBias = 1 - APFloat::semanticsMinExponent(SrcSem);
+ const uint64_t SrcMantMask = (1ULL << SrcMant) - 1;
+ const uint64_t SrcExpMask = (1ULL << SrcExpBits) - 1;
+
+ // Work in the source integer type.
+ EVT IntVT = EVT::getIntegerVT(*DAG.getContext(), SrcBits);
+ EVT SetCCVT = getSetCCResultType(IntVT);
+
+ SDValue Zero = DAG.getConstant(0, dl, IntVT);
+ SDValue One = DAG.getConstant(1, dl, IntVT);
+
+ // Bitcast source float to integer and extract bit fields.
+ SDValue Src = DAG.getNode(ISD::BITCAST, dl, IntVT, FloatVal);
+ SDValue SrcMantField = DAG.getNode(ISD::AND, dl, IntVT, Src,
+ DAG.getConstant(SrcMantMask, dl, IntVT));
+
+ SDValue SrcExpField = DAG.getNode(
+ ISD::AND, dl, IntVT,
+ DAG.getNode(ISD::SRL, dl, IntVT, Src,
+ DAG.getShiftAmountConstant(SrcMant, IntVT, dl)),
+ DAG.getConstant(SrcExpMask, dl, IntVT));
+
+ SDValue SignBit =
+ DAG.getNode(ISD::SRL, dl, IntVT, Src,
+ DAG.getShiftAmountConstant(SrcBits - 1, IntVT, dl));
+
+ // Classify the input value.
+ SDValue SrcExpAllOnes = DAG.getConstant(SrcExpMask, dl, IntVT);
+ SDValue IsExpAllOnes =
+ DAG.getSetCC(dl, SetCCVT, SrcExpField, SrcExpAllOnes, ISD::SETEQ);
+ SDValue IsExpZero =
+ DAG.getSetCC(dl, SetCCVT, SrcExpField, Zero, ISD::SETEQ);
+ SDValue IsMantZero =
+ DAG.getSetCC(dl, SetCCVT, SrcMantField, Zero, ISD::SETEQ);
+ SDValue IsMantNonZero =
+ DAG.getSetCC(dl, SetCCVT, SrcMantField, Zero, ISD::SETNE);
+
+ // If the source is an IEEE float, then NaN = exp_all_ones && mant != 0.
+ SDValue IsNaN =
+ DAG.getNode(ISD::AND, dl, SetCCVT, IsExpAllOnes, IsMantNonZero);
+ // Inf = exp_all_ones && mant == 0.
+ SDValue IsInf =
+ DAG.getNode(ISD::AND, dl, SetCCVT, IsExpAllOnes, IsMantZero);
+ // Zero = exp == 0 && mant == 0.
+ SDValue IsZero =
+ DAG.getNode(ISD::AND, dl, SetCCVT, IsExpZero, IsMantZero);
+ // Source denorm = exp == 0 && mant != 0.
+ SDValue IsSrcDenorm =
+ DAG.getNode(ISD::AND, dl, SetCCVT, IsExpZero, IsMantNonZero);
+
+ // Source denormal normalization.
+ // For a source denormal, the true exponent is (1 - SrcBias) and the
+ // mantissa has no implicit leading 1. Normalize by finding the position
+ // of the leading 1 in the mantissa.
+ SDValue LeadingZeros =
+ DAG.getNode(ISD::CTLZ_ZERO_UNDEF, dl, IntVT, SrcMantField);
+
+ // normShift = LeadingZeros - (SrcBits - 1 - SrcMant).
+ const unsigned LZOffset = SrcBits - 1 - SrcMant;
+ SDValue NormShift = DAG.getNode(ISD::SUB, dl, IntVT, LeadingZeros,
+ DAG.getConstant(LZOffset, dl, IntVT));
+
+ // Normalized mantissa.
+ SDValue NormMant =
+ DAG.getNode(ISD::AND, dl, IntVT,
+ DAG.getNode(ISD::SHL, dl, IntVT, SrcMantField, NormShift),
+ DAG.getConstant(SrcMantMask, dl, IntVT));
+
+ // effective_exp = 1 - NormShift.
+ SDValue DenormSrcExp =
+ DAG.getNode(ISD::SUB, dl, IntVT, One, NormShift);
+
+ // Select between normal and denorm source.
+ SDValue EffSrcExp =
+ DAG.getSelect(dl, IntVT, IsSrcDenorm, DenormSrcExp, SrcExpField);
+ SDValue EffSrcMant =
+ DAG.getSelect(dl, IntVT, IsSrcDenorm, NormMant, SrcMantField);
+
+ // Compute new biased exponent for destination.
+ // new_exp = src_exp - SrcBias + DstBias
+ const int BiasAdjust = DstBias - SrcBias;
+ SDValue NewExp = DAG.getNode(
+ ISD::ADD, dl, IntVT, EffSrcExp,
+ DAG.getConstant(APInt(SrcBits, BiasAdjust, true), dl, IntVT));
+
+ // Compute rounding increment given the round bit, sticky bits, and LSB
+ // of the truncated mantissa.
+ auto ComputeRoundUp = [&](SDValue RoundBit, SDValue StickyBits,
+ SDValue LSB) -> SDValue {
+ if (RoundMode == RoundingMode::NearestTiesToEven) {
+ // Round up if round_bit && (sticky || lsb)
+ SDValue StickyOrLSB = DAG.getNode(ISD::OR, dl, IntVT, StickyBits, LSB);
+ return DAG.getNode(ISD::AND, dl, IntVT, RoundBit, StickyOrLSB);
+ }
+ if (RoundMode == RoundingMode::TowardZero)
+ return Zero;
+ if (RoundMode == RoundingMode::TowardPositive) {
+ // Round up if positive and any truncated bits are set.
+ SDValue AnyTruncBits =
+ DAG.getNode(ISD::OR, dl, IntVT, RoundBit, StickyBits);
+ SDValue HasTruncBits =
+ DAG.getSetCC(dl, SetCCVT, AnyTruncBits, Zero, ISD::SETNE);
+ SDValue IsPositive =
+ DAG.getSetCC(dl, SetCCVT, SignBit, Zero, ISD::SETEQ);
+ SDValue DoRound =
+ DAG.getNode(ISD::AND, dl, SetCCVT, HasTruncBits, IsPositive);
+ return DAG.getNode(ISD::ZERO_EXTEND, dl, IntVT, DoRound);
+ }
+ if (RoundMode == RoundingMode::TowardNegative) {
+ // Round up if negative and any truncated bits are set (to -Inf).
+ SDValue AnyTruncBits =
+ DAG.getNode(ISD::OR, dl, IntVT, RoundBit, StickyBits);
+ SDValue HasTruncBits =
+ DAG.getSetCC(dl, SetCCVT, AnyTruncBits, Zero, ISD::SETNE);
+ SDValue IsNegative =
+ DAG.getSetCC(dl, SetCCVT, SignBit, Zero, ISD::SETNE);
+ SDValue DoRound =
+ DAG.getNode(ISD::AND, dl, SetCCVT, HasTruncBits, IsNegative);
+ return DAG.getNode(ISD::ZERO_EXTEND, dl, IntVT, DoRound);
+ }
+ if (RoundMode == RoundingMode::NearestTiesToAway)
+ return RoundBit;
+ llvm_unreachable("Unsupported rounding mode");
+ };
+
+ // Round mantissa from SrcMant bits to DstMant bits.
+ SDValue TruncMant;
+ SDValue RoundUp;
+ if (SrcMant > DstMant) {
+ const unsigned Shift = SrcMant - DstMant;
+ SDValue ShiftConst = DAG.getShiftAmountConstant(Shift, IntVT, dl);
+ TruncMant =
+ DAG.getNode(ISD::SRL, dl, IntVT, EffSrcMant, ShiftConst);
+
+ // Check bit at position Shift - 1 aka the round bit.
+ SDValue RoundBit;
+ if (Shift >= 1) {
+ RoundBit = DAG.getNode(
+ ISD::AND, dl, IntVT,
+ DAG.getNode(ISD::SRL, dl, IntVT, EffSrcMant,
+ DAG.getShiftAmountConstant(Shift - 1, IntVT, dl)),
+ One);
+ } else {
+ RoundBit = Zero;
+ }
+
+ // OR of all bits below the round bit to get sticky bits.
+ SDValue StickyBits;
+ if (Shift >= 2) {
+ uint64_t StickyMask = (1ULL << (Shift - 1)) - 1;
+ StickyBits = DAG.getNode(ISD::AND, dl, IntVT, EffSrcMant,
+ DAG.getConstant(StickyMask, dl, IntVT));
+ StickyBits = DAG.getSetCC(dl, SetCCVT, StickyBits, Zero, ISD::SETNE);
+ StickyBits =
+ DAG.getNode(ISD::ZERO_EXTEND, dl, IntVT, StickyBits);
+ } else {
+ StickyBits = Zero;
+ }
+
+ // LSB of truncated mantissa.
+ SDValue LSB = DAG.getNode(ISD::AND, dl, IntVT, TruncMant, One);
+
+ RoundUp = ComputeRoundUp(RoundBit, StickyBits, LSB);
+ } else {
+ // If DstMant >= SrcMant, then no rounding needed, just shift left.
+ SDValue MantShift =
+ DAG.getShiftAmountConstant(DstMant - SrcMant, IntVT, dl);
+ TruncMant = DAG.getNode(ISD::SHL, dl, IntVT, EffSrcMant, MantShift);
+ RoundUp = Zero;
+ }
+
+ // Apply rounding.
+ SDValue RoundedMant =
+ DAG.getNode(ISD::ADD, dl, IntVT, TruncMant, RoundUp);
+
+ // Handle mantissa overflow from rounding.
+ // If rounded_mant > DstMantMask, carry into exponent.
+ SDValue MantOverflow = DAG.getSetCC(
+ dl, SetCCVT, RoundedMant,
+ DAG.getConstant(DstMantMask, dl, IntVT), ISD::SETGT);
+ // On overflow: mant = 0, exp += 1.
+ SDValue AdjMant =
+ DAG.getSelect(dl, IntVT, MantOverflow, Zero, RoundedMant);
+ SDValue AdjExp = DAG.getNode(
+ ISD::ADD, dl, IntVT, NewExp,
+ DAG.getNode(ISD::ZERO_EXTEND, dl, IntVT, MantOverflow));
+
+ // Precompute sign shifted to MSB of destination.
+ SDValue SignShifted = DAG.getNode(
+ ISD::SHL, dl, IntVT, SignBit,
+ DAG.getShiftAmountConstant(DstBits - 1, IntVT, dl));
+
+ // Destination denormal conversion (when new_exp <= 0).
+ // Shift the mantissa right by 1 - new_exp additional bits and set the
+ // exponent field to 0.
+ SDValue ExpIsNeg =
+ DAG.getSetCC(dl, SetCCVT, AdjExp,
+ DAG.getConstant(1, dl, IntVT), ISD::SETLT);
+
+ SDValue DenormResult;
+ {
+ // denorm_shift = 1 - NewExp.
+ SDValue DenormShift =
+ DAG.getNode(ISD::SUB, dl, IntVT, One, NewExp);
+
+ // full_src_mant = (1 << SrcMant) | EffSrcMant.
+ SDValue ImplicitOne = DAG.getNode(
+ ISD::SHL, dl, IntVT, One,
+ DAG.getShiftAmountConstant(SrcMant, IntVT, dl));
+ SDValue FullSrcMant =
+ DAG.getNode(ISD::OR, dl, IntVT, EffSrcMant, ImplicitOne);
+
+ // Total right shift = (SrcMant - DstMant) + DenormShift
+ SDValue TotalShift;
+ if (SrcMant >= DstMant) {
+ TotalShift =
+ DAG.getNode(ISD::ADD, dl, IntVT, DenormShift,
+ DAG.getConstant(SrcMant - DstMant, dl, IntVT));
+ } else {
+ TotalShift =
+ DAG.getNode(ISD::SUB, dl, IntVT, DenormShift,
+ DAG.getConstant(DstMant - SrcMant, dl, IntVT));
+ }
+
+ // Clamp total shift to avoid UB, then truncate denorm mantissa.
+ SDValue MaxShift = DAG.getConstant(SrcBits - 1, dl, IntVT);
+ SDValue ClampedShift = DAG.getNode(ISD::UMIN, dl, IntVT, TotalShift,
+ MaxShift);
+ SDValue DenormTruncMant =
+ DAG.getNode(ISD::SRL, dl, IntVT, FullSrcMant, ClampedShift);
+
+ // Rounding for denorm path.
+ SDValue DenormRoundUp;
+ {
+ // Round bit is at position TotalShift - 1 of FullSrcMant.
+ // Clamp to at least 1 so the subtraction doesn't underflow and create
+ // shift nodes with invalid shift amounts.
+ SDValue SafeShift =
+ DAG.getNode(ISD::UMAX, dl, IntVT, ClampedShift, One);
+ SDValue RoundBitPos =
+ DAG.getNode(ISD::SUB, dl, IntVT, SafeShift, One);
+ SDValue DenormRoundBit = DAG.getNode(
+ ISD::AND, dl, IntVT,
+ DAG.getNode(ISD::SRL, dl, IntVT, FullSrcMant, RoundBitPos), One);
+
+ // Sticky: all bits below round bit.
+ // sticky_mask = (1 << RoundBitPos) - 1
+ SDValue StickyMask = DAG.getNode(
+ ISD::SUB, dl, IntVT,
+ DAG.getNode(ISD::SHL, dl, IntVT, One, RoundBitPos), One);
+ SDValue DenormStickyBits =
+ DAG.getNode(ISD::AND, dl, IntVT, FullSrcMant, StickyMask);
+ SDValue HasSticky =
+ DAG.getNode(ISD::ZERO_EXTEND, dl, IntVT,
+ DAG.getSetCC(dl, SetCCVT, DenormStickyBits, Zero,
+ ISD::SETNE));
+
+ SDValue DenormLSB =
+ DAG.getNode(ISD::AND, dl, IntVT, DenormTruncMant, One);
+
+ DenormRoundUp =
+ ComputeRoundUp(DenormRoundBit, HasSticky, DenormLSB);
+
+ // Only apply rounding if TotalShift >= 1 (i.e., there are bits to
+ // round).
+ SDValue ShiftGEOne =
+ DAG.getSetCC(dl, SetCCVT, ClampedShift, One, ISD::SETUGE);
+ DenormRoundUp = DAG.getSelect(dl, IntVT, ShiftGEOne, DenormRoundUp,
+ Zero);
+ }
+
+ SDValue DenormRoundedMant =
+ DAG.getNode(ISD::ADD, dl, IntVT, DenormTruncMant, DenormRoundUp);
+
+ // If rounding caused overflow into the normal range, then we get the
+ // smallest normal number.
+ SDValue DenormMantOF = DAG.getSetCC(
+ dl, SetCCVT, DenormRoundedMant,
+ DAG.getConstant(DstMantMask, dl, IntVT), ISD::SETGT);
+ SDValue DenormFinalMant = DAG.getSelect(
+ dl, IntVT, DenormMantOF, Zero, DenormRoundedMant);
+ SDValue DenormFinalExp = DAG.getSelect(
+ dl, IntVT, DenormMantOF, One, Zero);
+
+ // Assemble: sign | (exp << DstMant) | mant
+ SDValue DenormExpShifted = DAG.getNode(
+ ISD::SHL, dl, IntVT, DenormFinalExp,
+ DAG.getShiftAmountConstant(DstMant, IntVT, dl));
+ DenormResult = DAG.getNode(
+ ISD::OR, dl, IntVT,
+ DAG.getNode(ISD::OR, dl, IntVT, SignShifted, DenormExpShifted),
+ DenormFinalMant);
+
+ // If the value is too small for even a denorm (all mantissa bits
+ // shifted away), handle based on rounding mode.
+ // This is covered by the DenormRoundedMant = 0 case naturally.
+ }
+
+ // Exponent overflow detection.
+ SDValue ExpOF = DAG.getSetCC(
+ dl, SetCCVT, AdjExp,
+ DAG.getConstant(DstExpMaxNormal, dl, IntVT), ISD::SETGT);
+
+ // Also check if AdjExp == DstExpMaxNormal and mantissa overflow into
+ // a value that exceeds the max allowed mantissa at that exponent.
+ SDValue ExpAtMax = DAG.getSetCC(
+ dl, SetCCVT, AdjExp,
+ DAG.getConstant(DstExpMaxNormal, dl, IntVT), ISD::SETEQ);
+ SDValue MantExceedsMax = DAG.getSetCC(
+ dl, SetCCVT, AdjMant,
+ DAG.getConstant(DstMaxMantAtMaxExp, dl, IntVT), ISD::SETGT);
+ SDValue ExpMantOF =
+ DAG.getNode(ISD::AND, dl, SetCCVT, ExpAtMax, MantExceedsMax);
+ SDValue IsOverflow =
+ DAG.getNode(ISD::OR, dl, SetCCVT, ExpOF, ExpMantOF);
+
+ // Build overflow result.
+ SDValue OverflowResult;
+
+ if (Saturate) {
+ // Clamp to max finite value:
+ // sign | (DstExpMaxNormal << DstMant) | DstMaxMantAtMaxExp
+ uint64_t MaxFinite =
+ ((uint64_t)DstExpMaxNormal << DstMant) | DstMaxMantAtMaxExp;
+ OverflowResult =
+ DAG.getNode(ISD::OR, dl, IntVT, SignShifted,
+ DAG.getConstant(MaxFinite, dl, IntVT));
+ } else if (DstNFBehavior == fltNonfiniteBehavior::IEEE754) {
+ // Produce infinity.
+ uint64_t InfBits = (uint64_t)DstExpMax << DstMant;
+ OverflowResult =
+ DAG.getNode(ISD::OR, dl, IntVT, SignShifted,
+ DAG.getConstant(InfBits, dl, IntVT));
+ } else {
+ // Emit poison if no Inf in format and not saturating.
+ OverflowResult = DAG.getPOISON(IntVT);
+ }
+
+ // Assemble normal result: sign | (AdjExp << DstMant) | AdjMant
+ SDValue NormExpShifted = DAG.getNode(
+ ISD::SHL, dl, IntVT, AdjExp,
+ DAG.getShiftAmountConstant(DstMant, IntVT, dl));
+ SDValue NormResult = DAG.getNode(
+ ISD::OR, dl, IntVT,
+ DAG.getNode(ISD::OR, dl, IntVT, SignShifted, NormExpShifted), AdjMant);
+
+ // Build special-value results.
+ SDValue NaNResult;
+ if (DstNFBehavior == fltNonfiniteBehavior::IEEE754) {
+ // Produce canonical NaN.
+ const uint64_t QNaNBit = (DstMant > 0) ? (1ULL << (DstMant - 1)) : 0;
+ NaNResult =
+ DAG.getConstant(((uint64_t)DstExpMax << DstMant) | QNaNBit, dl,
+ IntVT);
+ } else if (DstNFBehavior == fltNonfiniteBehavior::NanOnly &&
+ DstNanEnc == fltNanEncoding::AllOnes) {
+ // E4M3FN-style: NaN is exp=all-ones, mant=all-ones.
+ NaNResult =
+ DAG.getConstant(((uint64_t)DstExpMax << DstMant) | DstMantMask, dl,
+ IntVT);
+ } else {
+ // NaN -> poison for finite only values.
+ NaNResult = DAG.getPOISON(IntVT);
+ }
+
+ // Inf handling.
+ SDValue InfResult;
+ if (DstNFBehavior == fltNonfiniteBehavior::IEEE754) {
+ // Produce signed infinity.
+ uint64_t InfBits = (uint64_t)DstExpMax << DstMant;
+ InfResult = DAG.getNode(ISD::OR, dl, IntVT, SignShifted,
+ DAG.getConstant(InfBits, dl, IntVT));
+ } else if (Saturate) {
+ // Inf saturate...
[truncated]
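As a cross-check on the destination format parameters derived in the hunk above (`DstExpMax`, `DstExpMaxNormal`, `DstMaxMantAtMaxExp`) and on the saturating overflow result, the following standalone sketch recomputes them for the five supported formats. It is illustrative only: `FormatParams`, `NonFinite`, and the helper functions are hypothetical names, the per-format classification (E5M2 as IEEE754-like, E4M3FN as NaN-only, the FP6 and FP4 formats as finite-only) is stated here as an assumption rather than read from `APFloat`, and `NanOnly` is assumed to use the all-ones NaN encoding as in Float8E4M3FN.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical mirror of the non-finite behaviors named in the patch.
enum class NonFinite { IEEE754, NanOnly, FiniteOnly };

struct FormatParams {
  unsigned Bits;      // total storage bits
  unsigned MantBits;  // explicit mantissa bits (precision - 1)
  NonFinite Behavior; // NanOnly is assumed to mean the all-ones NaN encoding
};

// All-ones exponent field.
constexpr unsigned expMax(const FormatParams &F) {
  return (1u << (F.Bits - F.MantBits - 1)) - 1;
}

// Largest exponent field that still encodes a finite value.
constexpr unsigned expMaxNormal(const FormatParams &F) {
  return F.Behavior == NonFinite::IEEE754 ? expMax(F) - 1 : expMax(F);
}

// Largest mantissa at that exponent; all-ones is reserved for NaN in NanOnly.
constexpr uint64_t maxMantAtMaxExp(const FormatParams &F) {
  return F.Behavior == NonFinite::NanOnly ? ((1ull << F.MantBits) - 1) - 1
                                          : (1ull << F.MantBits) - 1;
}

// Encoding of the largest finite magnitude, without the sign bit.
constexpr uint64_t maxFiniteBits(const FormatParams &F) {
  return (uint64_t(expMaxNormal(F)) << F.MantBits) | maxMantAtMaxExp(F);
}

// Saturated overflow result: sign | max finite.
inline uint64_t saturatedBits(const FormatParams &F, unsigned Sign) {
  assert(Sign <= 1 && "sign must be a single bit");
  return (uint64_t(Sign) << (F.Bits - 1)) | maxFiniteBits(F);
}

constexpr FormatParams Float8E5M2{8, 2, NonFinite::IEEE754};
constexpr FormatParams Float8E4M3FN{8, 3, NonFinite::NanOnly};
constexpr FormatParams Float6E3M2FN{6, 2, NonFinite::FiniteOnly};
constexpr FormatParams Float6E2M3FN{6, 3, NonFinite::FiniteOnly};
constexpr FormatParams Float4E2M1FN{4, 1, NonFinite::FiniteOnly};

static_assert(maxFiniteBits(Float8E5M2) == 0x7B, "E5M2 max finite");
static_assert(maxFiniteBits(Float8E4M3FN) == 0x7E, "E4M3FN max finite");
static_assert(maxFiniteBits(Float6E3M2FN) == 0x1F, "E3M2FN max finite");
static_assert(maxFiniteBits(Float6E2M3FN) == 0x1F, "E2M3FN max finite");
static_assert(maxFiniteBits(Float4E2M1FN) == 0x07, "E2M1FN max finite");
```

Under these assumptions a saturating conversion of a large negative input to Float8E4M3FN would produce `saturatedBits(Float8E4M3FN, 1)`, that is `0xFE`, which stays one encoding below the NaN pattern `0xFF`.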
|
|
@llvm/pr-subscribers-llvm-selectiondag Author: Dmitry Sidorov (MrSidims) ChangesThe expansion converts a native IEEE float to an arbitrary-precision FP format, returning the result as an integer, following this algorithm:
Currently only conversions to OCP floats are covered, in LLVM terms these are: Float8E5M2, Float8E4M3FN, Float6E3M2FN, Float6E2M3FN, Float4E2M1FN. OCP spec: E2E testing on X86 done with an assistance of Claude Code Opus 4.6. Patch is 172.47 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/193595.diff 14 Files Affected:
diff --git a/llvm/include/llvm/CodeGen/ISDOpcodes.h b/llvm/include/llvm/CodeGen/ISDOpcodes.h
index 8a8a9ee71ca02..de9835ee43836 100644
--- a/llvm/include/llvm/CodeGen/ISDOpcodes.h
+++ b/llvm/include/llvm/CodeGen/ISDOpcodes.h
@@ -1020,6 +1020,15 @@ enum NodeType {
/// The second operand is a constant indicating the source FP semantics.
CONVERT_FROM_ARBITRARY_FP,
+ /// CONVERT_TO_ARBITRARY_FP - Converts a native FP value to an arbitrary
+ /// floating-point format, returning the result as an integer.
+ /// The first operand is the source value.
+ /// The second operand is a constant indicating the destination FP semantics.
+ /// The third operand is a constant indication the rounding mode.
+ /// The last operand is a boolean consant indicating whether the result has
+ /// to be saturated.
+ CONVERT_TO_ARBITRARY_FP,
+
/// Perform various unary floating-point operations inspired by libm. For
/// FPOWI, the result is undefined if the integer operand doesn't fit into
/// sizeof(int).
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
index 54d86dfbfa303..7dcc3a1f1c753 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
@@ -3782,6 +3782,479 @@ bool SelectionDAGLegalize::ExpandNode(SDNode *Node) {
Results.push_back(Result);
break;
}
+ case ISD::CONVERT_TO_ARBITRARY_FP: {
+ // Expand conversion from a native IEEE float type to an arbitrary FP
+ // format, returning the result as an integer using bit manipulation.
+ //
+ // TODO: currently only conversions to FP4, FP6 and FP8 formats from OCP
+ // specification are expanded. Remaining arbitrary FP types: Float8E4M3,
+ // Float8E3M4, Float8E5M2FNUZ, Float8E4M3FNUZ, Float8E4M3B11FNUZ,
+ // Float8E8M0FNU.
+ EVT ResVT = Node->getValueType(0);
+
+ SDValue FloatVal = Node->getOperand(0);
+ const uint64_t SemEnum = Node->getConstantOperandVal(1);
+ const auto Sem = static_cast<APFloatBase::Semantics>(SemEnum);
+ const auto RoundMode =
+ static_cast<RoundingMode>(Node->getConstantOperandVal(2));
+ const bool Saturate = Node->getConstantOperandVal(3) != 0;
+
+ // Supported destination formats.
+ switch (Sem) {
+ case APFloatBase::S_Float8E5M2:
+ case APFloatBase::S_Float8E4M3FN:
+ case APFloatBase::S_Float6E3M2FN:
+ case APFloatBase::S_Float6E2M3FN:
+ case APFloatBase::S_Float4E2M1FN:
+ break;
+ default:
+ DAG.getContext()->emitError("CONVERT_TO_ARBITRARY_FP: not implemented "
+ "destination format (semantics enum " +
+ Twine(SemEnum) + ")");
+ Results.push_back(DAG.getPOISON(ResVT));
+ break;
+ }
+ if (!Results.empty())
+ break;
+
+ // Destination format parameters.
+ const fltSemantics &DstSem = APFloatBase::EnumToSemantics(Sem);
+ const unsigned DstBits = APFloat::getSizeInBits(DstSem);
+ const unsigned DstPrecision = APFloat::semanticsPrecision(DstSem);
+ const unsigned DstMant = DstPrecision - 1;
+ const unsigned DstExpBits = DstBits - DstMant - 1;
+ const int DstBias = 1 - APFloat::semanticsMinExponent(DstSem);
+ const unsigned DstExpMax = (1U << DstExpBits) - 1;
+ const uint64_t DstMantMask = (DstMant > 0) ? ((1ULL << DstMant) - 1) : 0;
+ const fltNonfiniteBehavior DstNFBehavior = DstSem.nonFiniteBehavior;
+ const fltNanEncoding DstNanEnc = DstSem.nanEncoding;
+
+ // Compute the maximum normal exponent for the destination format.
+ unsigned DstExpMaxNormal;
+ if (DstNFBehavior == fltNonfiniteBehavior::IEEE754)
+ DstExpMaxNormal = DstExpMax - 1;
+ else
+ DstExpMaxNormal = DstExpMax;
+
+ // For NanOnly formats the max exponent field for finite values
+ // is DstExpMax, but the encoding with exp = DstExpMax and
+ // mant = all-ones is NaN. So DstExpMaxNormal = DstExpMax, but max
+ // mantissa at that exponent is DstMantMask - 1 (if NanEnc == AllOnes) to
+ // avoid the NaN encoding.
+ uint64_t DstMaxMantAtMaxExp = DstMantMask;
+ if (DstNFBehavior == fltNonfiniteBehavior::NanOnly &&
+ DstNanEnc == fltNanEncoding::AllOnes)
+ DstMaxMantAtMaxExp = DstMantMask - 1;
+
+ // Source format parameters.
+ EVT SrcVT = FloatVal.getValueType();
+ const fltSemantics &SrcSem = SrcVT.getFltSemantics();
+ const unsigned SrcBits = APFloat::getSizeInBits(SrcSem);
+ const unsigned SrcPrecision = APFloat::semanticsPrecision(SrcSem);
+ const unsigned SrcMant = SrcPrecision - 1;
+ const unsigned SrcExpBits = SrcBits - SrcMant - 1;
+ const int SrcBias = 1 - APFloat::semanticsMinExponent(SrcSem);
+ const uint64_t SrcMantMask = (1ULL << SrcMant) - 1;
+ const uint64_t SrcExpMask = (1ULL << SrcExpBits) - 1;
+
+ // Work in the source integer type.
+ EVT IntVT = EVT::getIntegerVT(*DAG.getContext(), SrcBits);
+ EVT SetCCVT = getSetCCResultType(IntVT);
+
+ SDValue Zero = DAG.getConstant(0, dl, IntVT);
+ SDValue One = DAG.getConstant(1, dl, IntVT);
+
+ // Bitcast source float to integer and extract bit fields.
+ SDValue Src = DAG.getNode(ISD::BITCAST, dl, IntVT, FloatVal);
+ SDValue SrcMantField = DAG.getNode(ISD::AND, dl, IntVT, Src,
+ DAG.getConstant(SrcMantMask, dl, IntVT));
+
+ SDValue SrcExpField = DAG.getNode(
+ ISD::AND, dl, IntVT,
+ DAG.getNode(ISD::SRL, dl, IntVT, Src,
+ DAG.getShiftAmountConstant(SrcMant, IntVT, dl)),
+ DAG.getConstant(SrcExpMask, dl, IntVT));
+
+ SDValue SignBit =
+ DAG.getNode(ISD::SRL, dl, IntVT, Src,
+ DAG.getShiftAmountConstant(SrcBits - 1, IntVT, dl));
+
+ // Classify the input value.
+ SDValue SrcExpAllOnes = DAG.getConstant(SrcExpMask, dl, IntVT);
+ SDValue IsExpAllOnes =
+ DAG.getSetCC(dl, SetCCVT, SrcExpField, SrcExpAllOnes, ISD::SETEQ);
+ SDValue IsExpZero =
+ DAG.getSetCC(dl, SetCCVT, SrcExpField, Zero, ISD::SETEQ);
+ SDValue IsMantZero =
+ DAG.getSetCC(dl, SetCCVT, SrcMantField, Zero, ISD::SETEQ);
+ SDValue IsMantNonZero =
+ DAG.getSetCC(dl, SetCCVT, SrcMantField, Zero, ISD::SETNE);
+
+ // If source is IEEE fp, tehn NaN = exp_all_ones && mant != 0.
+ SDValue IsNaN =
+ DAG.getNode(ISD::AND, dl, SetCCVT, IsExpAllOnes, IsMantNonZero);
+ // Inf = exp_all_ones && mant == 0.
+ SDValue IsInf =
+ DAG.getNode(ISD::AND, dl, SetCCVT, IsExpAllOnes, IsMantZero);
+ // Zero = exp == 0 && mant == 0.
+ SDValue IsZero =
+ DAG.getNode(ISD::AND, dl, SetCCVT, IsExpZero, IsMantZero);
+ // Source denorm = exp == 0 && mant != 0.
+ SDValue IsSrcDenorm =
+ DAG.getNode(ISD::AND, dl, SetCCVT, IsExpZero, IsMantNonZero);
+
+ // Source denormal normalization.
+ // For a source denormal, the true exponent is (1 - SrcBias) and the
+ // mantissa has no implicit leading 1. Normalize by finding the position
+ // of the leading 1 in the mantissa.
+ SDValue LeadingZeros =
+ DAG.getNode(ISD::CTLZ_ZERO_UNDEF, dl, IntVT, SrcMantField);
+
+ // normShift = LeadingZeros - (SrcBits - 1 - SrcMant).
+ const unsigned LZOffset = SrcBits - 1 - SrcMant;
+ SDValue NormShift = DAG.getNode(ISD::SUB, dl, IntVT, LeadingZeros,
+ DAG.getConstant(LZOffset, dl, IntVT));
+
+ // Normalized mantissa.
+ SDValue NormMant =
+ DAG.getNode(ISD::AND, dl, IntVT,
+ DAG.getNode(ISD::SHL, dl, IntVT, SrcMantField, NormShift),
+ DAG.getConstant(SrcMantMask, dl, IntVT));
+
+ // effective_exp = 1 - NormShift.
+ SDValue DenormSrcExp =
+ DAG.getNode(ISD::SUB, dl, IntVT, One, NormShift);
+
+ // Select between normal and denorm source.
+ SDValue EffSrcExp =
+ DAG.getSelect(dl, IntVT, IsSrcDenorm, DenormSrcExp, SrcExpField);
+ SDValue EffSrcMant =
+ DAG.getSelect(dl, IntVT, IsSrcDenorm, NormMant, SrcMantField);
+
+ // Compute new biased exponent for destination.
+ // new_exp = src_exp - SrcBias + DstBias
+ const int BiasAdjust = DstBias - SrcBias;
+ SDValue NewExp = DAG.getNode(
+ ISD::ADD, dl, IntVT, EffSrcExp,
+ DAG.getConstant(APInt(SrcBits, BiasAdjust, true), dl, IntVT));
+
+ // Compute rounding increment given the round bit, sticky bits, and LSB
+ // of the truncated mantissa.
+ auto ComputeRoundUp = [&](SDValue RoundBit, SDValue StickyBits,
+ SDValue LSB) -> SDValue {
+ if (RoundMode == RoundingMode::NearestTiesToEven) {
+ // Round up if round_bit && (sticky || lsb)
+ SDValue StickyOrLSB = DAG.getNode(ISD::OR, dl, IntVT, StickyBits, LSB);
+ return DAG.getNode(ISD::AND, dl, IntVT, RoundBit, StickyOrLSB);
+ }
+ if (RoundMode == RoundingMode::TowardZero)
+ return Zero;
+ if (RoundMode == RoundingMode::TowardPositive) {
+ // Round up if positive and any truncated bits are set.
+ SDValue AnyTruncBits =
+ DAG.getNode(ISD::OR, dl, IntVT, RoundBit, StickyBits);
+ SDValue HasTruncBits =
+ DAG.getSetCC(dl, SetCCVT, AnyTruncBits, Zero, ISD::SETNE);
+ SDValue IsPositive =
+ DAG.getSetCC(dl, SetCCVT, SignBit, Zero, ISD::SETEQ);
+ SDValue DoRound =
+ DAG.getNode(ISD::AND, dl, SetCCVT, HasTruncBits, IsPositive);
+ return DAG.getNode(ISD::ZERO_EXTEND, dl, IntVT, DoRound);
+ }
+ if (RoundMode == RoundingMode::TowardNegative) {
+ // Round up if negative and any truncated bits are set (to -Inf).
+ SDValue AnyTruncBits =
+ DAG.getNode(ISD::OR, dl, IntVT, RoundBit, StickyBits);
+ SDValue HasTruncBits =
+ DAG.getSetCC(dl, SetCCVT, AnyTruncBits, Zero, ISD::SETNE);
+ SDValue IsNegative =
+ DAG.getSetCC(dl, SetCCVT, SignBit, Zero, ISD::SETNE);
+ SDValue DoRound =
+ DAG.getNode(ISD::AND, dl, SetCCVT, HasTruncBits, IsNegative);
+ return DAG.getNode(ISD::ZERO_EXTEND, dl, IntVT, DoRound);
+ }
+ if (RoundMode == RoundingMode::NearestTiesToAway)
+ return RoundBit;
+ llvm_unreachable("Unsupported rounding mode");
+ };
+
+ // Round mantissa from SrcMant bits to DstMant bits.
+ SDValue TruncMant;
+ SDValue RoundUp;
+ if (SrcMant > DstMant) {
+ const unsigned Shift = SrcMant - DstMant;
+ SDValue ShiftConst = DAG.getShiftAmountConstant(Shift, IntVT, dl);
+ TruncMant =
+ DAG.getNode(ISD::SRL, dl, IntVT, EffSrcMant, ShiftConst);
+
+ // Check bit at position Shift - 1 aka the round bit.
+ SDValue RoundBit;
+ if (Shift >= 1) {
+ RoundBit = DAG.getNode(
+ ISD::AND, dl, IntVT,
+ DAG.getNode(ISD::SRL, dl, IntVT, EffSrcMant,
+ DAG.getShiftAmountConstant(Shift - 1, IntVT, dl)),
+ One);
+ } else {
+ RoundBit = Zero;
+ }
+
+ // OR of all bits below the round bit to get sticky bits.
+ SDValue StickyBits;
+ if (Shift >= 2) {
+ uint64_t StickyMask = (1ULL << (Shift - 1)) - 1;
+ StickyBits = DAG.getNode(ISD::AND, dl, IntVT, EffSrcMant,
+ DAG.getConstant(StickyMask, dl, IntVT));
+ StickyBits = DAG.getSetCC(dl, SetCCVT, StickyBits, Zero, ISD::SETNE);
+ StickyBits =
+ DAG.getNode(ISD::ZERO_EXTEND, dl, IntVT, StickyBits);
+ } else {
+ StickyBits = Zero;
+ }
+
+ // LSB of truncated mantissa.
+ SDValue LSB = DAG.getNode(ISD::AND, dl, IntVT, TruncMant, One);
+
+ RoundUp = ComputeRoundUp(RoundBit, StickyBits, LSB);
+ } else {
+ // If DstMant >= SrcMant, then no rounding needed, just shift left.
+ SDValue MantShift =
+ DAG.getShiftAmountConstant(DstMant - SrcMant, IntVT, dl);
+ TruncMant = DAG.getNode(ISD::SHL, dl, IntVT, EffSrcMant, MantShift);
+ RoundUp = Zero;
+ }
+
+ // Apply rounding.
+ SDValue RoundedMant =
+ DAG.getNode(ISD::ADD, dl, IntVT, TruncMant, RoundUp);
+
+ // Handle mantissa overflow from rounding.
+ // If rounded_mant > DstMantMask, carry into exponent.
+ SDValue MantOverflow = DAG.getSetCC(
+ dl, SetCCVT, RoundedMant,
+ DAG.getConstant(DstMantMask, dl, IntVT), ISD::SETGT);
+ // On overflow: mant = 0, exp += 1.
+ SDValue AdjMant =
+ DAG.getSelect(dl, IntVT, MantOverflow, Zero, RoundedMant);
+ SDValue AdjExp = DAG.getNode(
+ ISD::ADD, dl, IntVT, NewExp,
+ DAG.getNode(ISD::ZERO_EXTEND, dl, IntVT, MantOverflow));
+
+ // Precompute sign shifted to MSB of destination.
+ SDValue SignShifted = DAG.getNode(
+ ISD::SHL, dl, IntVT, SignBit,
+ DAG.getShiftAmountConstant(DstBits - 1, IntVT, dl));
+
+ // Destination denormal conversion (when new_exp <= 0).
+ // Shift the mantissa right by 1 - new_exp additional bits and set the
+ // exponent field to 0.
+ SDValue ExpIsNeg =
+ DAG.getSetCC(dl, SetCCVT, AdjExp,
+ DAG.getConstant(1, dl, IntVT), ISD::SETLT);
+
+ SDValue DenormResult;
+ {
+ // denorm_shift = 1 - NewExp.
+ SDValue DenormShift =
+ DAG.getNode(ISD::SUB, dl, IntVT, One, NewExp);
+
+ // full_src_mant = (1 << SrcMant) | EffSrcMant.
+ SDValue ImplicitOne = DAG.getNode(
+ ISD::SHL, dl, IntVT, One,
+ DAG.getShiftAmountConstant(SrcMant, IntVT, dl));
+ SDValue FullSrcMant =
+ DAG.getNode(ISD::OR, dl, IntVT, EffSrcMant, ImplicitOne);
+
+ // Total right shift = (SrcMant - DstMant) + DenormShift
+ SDValue TotalShift;
+ if (SrcMant >= DstMant) {
+ TotalShift =
+ DAG.getNode(ISD::ADD, dl, IntVT, DenormShift,
+ DAG.getConstant(SrcMant - DstMant, dl, IntVT));
+ } else {
+ TotalShift =
+ DAG.getNode(ISD::SUB, dl, IntVT, DenormShift,
+ DAG.getConstant(DstMant - SrcMant, dl, IntVT));
+ }
+
+ // Clamp total shift to avoid UB, then truncate denorm mantissa.
+ SDValue MaxShift = DAG.getConstant(SrcBits - 1, dl, IntVT);
+ SDValue ClampedShift = DAG.getNode(ISD::UMIN, dl, IntVT, TotalShift,
+ MaxShift);
+ SDValue DenormTruncMant =
+ DAG.getNode(ISD::SRL, dl, IntVT, FullSrcMant, ClampedShift);
+
+ // Rounding for denorm path.
+ SDValue DenormRoundUp;
+ {
+ // Round bit is at position TotalShift - 1 of FullSrcMant.
+ // Clamp to at least 1 so the subtraction doesn't underflow and create
+ // shift nodes with invalid shift amounts.
+ SDValue SafeShift =
+ DAG.getNode(ISD::UMAX, dl, IntVT, ClampedShift, One);
+ SDValue RoundBitPos =
+ DAG.getNode(ISD::SUB, dl, IntVT, SafeShift, One);
+ SDValue DenormRoundBit = DAG.getNode(
+ ISD::AND, dl, IntVT,
+ DAG.getNode(ISD::SRL, dl, IntVT, FullSrcMant, RoundBitPos), One);
+
+ // Sticky: all bits below round bit.
+ // sticky_mask = (1 << RoundBitPos) - 1
+ SDValue StickyMask = DAG.getNode(
+ ISD::SUB, dl, IntVT,
+ DAG.getNode(ISD::SHL, dl, IntVT, One, RoundBitPos), One);
+ SDValue DenormStickyBits =
+ DAG.getNode(ISD::AND, dl, IntVT, FullSrcMant, StickyMask);
+ SDValue HasSticky =
+ DAG.getNode(ISD::ZERO_EXTEND, dl, IntVT,
+ DAG.getSetCC(dl, SetCCVT, DenormStickyBits, Zero,
+ ISD::SETNE));
+
+ SDValue DenormLSB =
+ DAG.getNode(ISD::AND, dl, IntVT, DenormTruncMant, One);
+
+ DenormRoundUp =
+ ComputeRoundUp(DenormRoundBit, HasSticky, DenormLSB);
+
+ // Only apply rounding if TotalShift >= 1 (i.e., there are bits to
+ // round).
+ SDValue ShiftGEOne =
+ DAG.getSetCC(dl, SetCCVT, ClampedShift, One, ISD::SETUGE);
+ DenormRoundUp = DAG.getSelect(dl, IntVT, ShiftGEOne, DenormRoundUp,
+ Zero);
+ }
+
+ SDValue DenormRoundedMant =
+ DAG.getNode(ISD::ADD, dl, IntVT, DenormTruncMant, DenormRoundUp);
+
+ // If rounding caused overflow into the normal range, then we get the
+ // smallest normal number.
+ SDValue DenormMantOF = DAG.getSetCC(
+ dl, SetCCVT, DenormRoundedMant,
+ DAG.getConstant(DstMantMask, dl, IntVT), ISD::SETGT);
+ SDValue DenormFinalMant = DAG.getSelect(
+ dl, IntVT, DenormMantOF, Zero, DenormRoundedMant);
+ SDValue DenormFinalExp = DAG.getSelect(
+ dl, IntVT, DenormMantOF, One, Zero);
+
+ // Assemble: sign | (exp << DstMant) | mant
+ SDValue DenormExpShifted = DAG.getNode(
+ ISD::SHL, dl, IntVT, DenormFinalExp,
+ DAG.getShiftAmountConstant(DstMant, IntVT, dl));
+ DenormResult = DAG.getNode(
+ ISD::OR, dl, IntVT,
+ DAG.getNode(ISD::OR, dl, IntVT, SignShifted, DenormExpShifted),
+ DenormFinalMant);
+
+ // If the value is too small for even a denorm (all mantissa bits
+ // shifted away), the result depends on the rounding mode.
+ // This is covered naturally by the DenormRoundedMant == 0 case.
+ }
+
+ // Exponent overflow detection.
+ SDValue ExpOF = DAG.getSetCC(
+ dl, SetCCVT, AdjExp,
+ DAG.getConstant(DstExpMaxNormal, dl, IntVT), ISD::SETGT);
+
+ // Also check if AdjExp == DstExpMaxNormal and mantissa overflow into
+ // a value that exceeds the max allowed mantissa at that exponent.
+ SDValue ExpAtMax = DAG.getSetCC(
+ dl, SetCCVT, AdjExp,
+ DAG.getConstant(DstExpMaxNormal, dl, IntVT), ISD::SETEQ);
+ SDValue MantExceedsMax = DAG.getSetCC(
+ dl, SetCCVT, AdjMant,
+ DAG.getConstant(DstMaxMantAtMaxExp, dl, IntVT), ISD::SETGT);
+ SDValue ExpMantOF =
+ DAG.getNode(ISD::AND, dl, SetCCVT, ExpAtMax, MantExceedsMax);
+ SDValue IsOverflow =
+ DAG.getNode(ISD::OR, dl, SetCCVT, ExpOF, ExpMantOF);
+
+ // Build overflow result.
+ SDValue OverflowResult;
+
+ if (Saturate) {
+ // Clamp to max finite value:
+ // sign | (DstExpMaxNormal << DstMant) | DstMaxMantAtMaxExp
+ uint64_t MaxFinite =
+ ((uint64_t)DstExpMaxNormal << DstMant) | DstMaxMantAtMaxExp;
+ OverflowResult =
+ DAG.getNode(ISD::OR, dl, IntVT, SignShifted,
+ DAG.getConstant(MaxFinite, dl, IntVT));
+ } else if (DstNFBehavior == fltNonfiniteBehavior::IEEE754) {
+ // Produce infinity.
+ uint64_t InfBits = (uint64_t)DstExpMax << DstMant;
+ OverflowResult =
+ DAG.getNode(ISD::OR, dl, IntVT, SignShifted,
+ DAG.getConstant(InfBits, dl, IntVT));
+ } else {
+ // Emit poison if no Inf in format and not saturating.
+ OverflowResult = DAG.getPOISON(IntVT);
+ }
+
+ // Assemble normal result: sign | (AdjExp << DstMant) | AdjMant
+ SDValue NormExpShifted = DAG.getNode(
+ ISD::SHL, dl, IntVT, AdjExp,
+ DAG.getShiftAmountConstant(DstMant, IntVT, dl));
+ SDValue NormResult = DAG.getNode(
+ ISD::OR, dl, IntVT,
+ DAG.getNode(ISD::OR, dl, IntVT, SignShifted, NormExpShifted), AdjMant);
+
+ // Build special-value results.
+ SDValue NaNResult;
+ if (DstNFBehavior == fltNonfiniteBehavior::IEEE754) {
+ // Produce canonical NaN.
+ const uint64_t QNaNBit = (DstMant > 0) ? (1ULL << (DstMant - 1)) : 0;
+ NaNResult =
+ DAG.getConstant(((uint64_t)DstExpMax << DstMant) | QNaNBit, dl,
+ IntVT);
+ } else if (DstNFBehavior == fltNonfiniteBehavior::NanOnly &&
+ DstNanEnc == fltNanEncoding::AllOnes) {
+ // E4M3FN-style: NaN is exp=all-ones, mant=all-ones.
+ NaNResult =
+ DAG.getConstant(((uint64_t)DstExpMax << DstMant) | DstMantMask, dl,
+ IntVT);
+ } else {
+ // NaN -> poison for finite-only formats.
+ NaNResult = DAG.getPOISON(IntVT);
+ }
+
+ // Inf handling.
+ SDValue InfResult;
+ if (DstNFBehavior == fltNonfiniteBehavior::IEEE754) {
+ // Produce signed infinity.
+ uint64_t InfBits = (uint64_t)DstExpMax << DstMant;
+ InfResult = DAG.getNode(ISD::OR, dl, IntVT, SignShifted,
+ DAG.getConstant(InfBits, dl, IntVT));
+ } else if (Saturate) {
+ // Inf saturate...
[truncated]
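For readers following the rounding logic in the diff above, here is a minimal scalar sketch of the round-bit / sticky / LSB scheme for the narrowing case (SrcMant > DstMant), covering both the normal path and the denormal destination path. It is not the patch's code: `ComputeRoundUp` is not visible in the truncated diff, so the tie-to-even rule below is an assumption, only NearestTiesToEven is modeled, and all function and struct names are hypothetical.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>

// Hypothetical scalar model, round-to-nearest-even only. Names are
// illustrative and do not come from LLVM.

// Assumed tie-to-even rule: round up when above the halfway point, or when
// exactly halfway and the kept LSB is odd.
static uint32_t roundUpRNE(uint32_t RoundBit, uint32_t Sticky, uint32_t LSB) {
  return RoundBit & (Sticky | LSB);
}

// Shift Value right by Shift bits (1 <= Shift <= 31) and round the result.
static uint32_t shiftRightRNE(uint32_t Value, unsigned Shift) {
  assert(Shift >= 1 && Shift <= 31 && "sketch needs a non-zero shift");
  uint32_t Trunc = Value >> Shift;
  uint32_t RoundBit = (Value >> (Shift - 1)) & 1u;
  uint32_t Sticky = (Value & ((1u << (Shift - 1)) - 1u)) != 0;
  return Trunc + roundUpRNE(RoundBit, Sticky, Trunc & 1u);
}

struct ExpMant {
  uint32_t Exp;  // biased destination exponent field
  uint32_t Mant; // destination mantissa field
};

// Normal path: Mant is the source mantissa field without the implicit one,
// NewExp is the already re-biased destination exponent (>= 1).
ExpMant roundNormal(uint32_t Mant, unsigned SrcMant, unsigned DstMant,
                    uint32_t NewExp) {
  const uint32_t DstMantMask = (1u << DstMant) - 1u;
  uint32_t Rounded = shiftRightRNE(Mant, SrcMant - DstMant);
  bool Carry = Rounded > DstMantMask; // mantissa overflow carries into exp
  return {NewExp + (Carry ? 1u : 0u), Carry ? 0u : Rounded};
}

// Denormal destination path: NewExp <= 0, so the implicit one is made
// explicit and shifted right by an extra (1 - NewExp) bits.
ExpMant roundDenormal(uint32_t Mant, unsigned SrcMant, unsigned DstMant,
                      int NewExp) {
  const uint32_t DstMantMask = (1u << DstMant) - 1u;
  uint32_t Full = (1u << SrcMant) | Mant;
  unsigned Total = (SrcMant - DstMant) + static_cast<unsigned>(1 - NewExp);
  Total = std::min(Total, 31u); // clamp, as the patch does with UMIN
  uint32_t Rounded = shiftRightRNE(Full, Total);
  bool Carry = Rounded > DstMantMask; // rounds up to the smallest normal
  return {Carry ? 1u : 0u, Carry ? 0u : Rounded};
}
```

With f32 as the source (SrcMant = 23) and a 3-bit destination mantissa, `roundNormal(0x7FFFFF, 23, 3, 7)` carries into the exponent, and `roundDenormal(0x7FFFFF, 23, 3, 0)` rounds up into the smallest normal encoding (exponent field 1, mantissa 0).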
See: https://github.com/MrSidims/llvm-project/tree/only-for-testing-ocl-fp (it doesn't cover all cases yet, it's just a PoC)
✅ With the latest revision this PR passed the C/C++ code formatter.
unsigned DstExpMaxNormal;
if (DstNFBehavior == fltNonfiniteBehavior::IEEE754)
  DstExpMaxNormal = DstExpMax - 1;
else
  DstExpMaxNormal = DstExpMax;
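To illustrate how the largest normal exponent feeds the saturation value `sign | (DstExpMaxNormal << DstMant) | DstMaxMantAtMaxExp`, here is a small standalone sketch. The definition of `DstMaxMantAtMaxExp` is not visible in the truncated diff, so the rule used below (all-ones mantissa, minus one when the all-ones pattern is the NaN encoding) is an assumption that holds for the OCP formats listed in the description. The enum and struct are hypothetical stand-ins, not LLVM's APFloat API.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical mirror of the non-finite classification used in the patch.
enum class NonFinite { IEEE754, NanOnly, FiniteOnly };

struct DstFormat {
  unsigned ExpBits;
  unsigned MantBits;
  NonFinite Behavior;
};

// Unsigned (sign bit clear) encoding of the largest finite value.
uint64_t maxFiniteEncoding(const DstFormat &F) {
  assert(F.ExpBits >= 1 && F.MantBits >= 1 && "sketch needs non-empty fields");
  const uint64_t ExpMax = (1ULL << F.ExpBits) - 1;
  const uint64_t MantMask = (1ULL << F.MantBits) - 1;
  // IEEE754-style formats reserve the all-ones exponent for Inf and NaN.
  const uint64_t ExpMaxNormal =
      F.Behavior == NonFinite::IEEE754 ? ExpMax - 1 : ExpMax;
  // Assumption: NanOnly means an all-ones NaN encoding (as in Float8E4M3FN),
  // so the all-ones mantissa at the top exponent is not a finite value.
  const uint64_t MaxMant =
      F.Behavior == NonFinite::NanOnly ? MantMask - 1 : MantMask;
  return (ExpMaxNormal << F.MantBits) | MaxMant;
}
```

Under these assumptions, Float8E5M2 yields 0x7B (57344), Float8E4M3FN yields 0x7E (448), and Float4E2M1FN yields 0x7 (6.0).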
const uint64_t SrcExpMask = (1ULL << SrcExpBits) - 1;
// Work in the source integer type.
EVT IntVT = EVT::getIntegerVT(*DAG.getContext(), SrcBits);
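As a scalar analogue of working in the source integer type, the field extraction from step 1 of the description (bitcast, then masks and shifts) can be sketched for an f32 source as follows. This assumes `float` is IEEE-754 binary32; `F32Fields` and `splitF32` are hypothetical names.

```cpp
#include <cassert>
#include <cstdint>
#include <cstring>

static_assert(sizeof(float) == sizeof(uint32_t),
              "sketch assumes a 32-bit float");

struct F32Fields {
  uint32_t Sign; // 1 bit
  uint32_t Exp;  // 8-bit biased exponent
  uint32_t Mant; // 23-bit mantissa field, no implicit one
};

// Bitcast the float to an integer and pull out the three bit fields.
F32Fields splitF32(float V) {
  uint32_t Bits;
  std::memcpy(&Bits, &V, sizeof(Bits)); // well-defined bitcast
  constexpr unsigned SrcMant = 23, SrcExpBits = 8;
  constexpr uint32_t SrcExpMask = (1u << SrcExpBits) - 1;
  constexpr uint32_t SrcMantMask = (1u << SrcMant) - 1;
  return {Bits >> (SrcMant + SrcExpBits), (Bits >> SrcMant) & SrcExpMask,
          Bits & SrcMantMask};
}
```

An all-zero exponent field then marks zero or a denormal, and an all-ones field marks Inf or NaN, which is what the classification step keys on.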
    cast<MetadataAsValue>(I.getArgOperand(2))->getMetadata();
StringRef RoundStr = cast<MDString>(RoundMD)->getString();
std::optional<RoundingMode> RoundMode = convertStrToRoundingMode(RoundStr);
if (!RoundMode || *RoundMode == RoundingMode::Dynamic) {
I'd assume the IR verifier would reject dynamic rounding mode
  %r = call i8 @llvm.convert.to.arbitrary.fp.i8.f32(
      float %v, metadata !"Float8E8M0FNU", metadata !"round.tonearest", i1 false)
  ret i8 %r
}