[SelectionDAG] Add expansion for llvm.convert.to.arbitrary.fp #193595

Open
MrSidims wants to merge 3 commits into llvm:main from MrSidims:expand-float-to-apfloat

@MrSidims
Contributor

The expansion converts a native IEEE float to an arbitrary-precision FP format, returning the result as an integer, following this algorithm:

  1. Bitcast the source float to an integer and extract sign, exponent, and mantissa bit fields via masks and shifts.
  2. Classify the input (zero/denormal/normal/Inf/NaN).
  3. Normalize source denormals by finding the MSB position of the mantissa and adjusting the effective exponent.
  4. Normal path: adjust the exponent bias from source to destination format and truncate the mantissa with rounding (supports NearestTiesToEven, TowardZero, TowardPositive, TowardNegative, NearestTiesToAway).
  5. Denormal destination path: when the biased destination exponent is <= 0, shift the mantissa right to produce a denormalized result with rounding.
  6. Handle mantissa overflow from rounding and exponent overflow. Produce Inf or saturate to max finite, depending on format and saturation flag.
  7. Build special-value results (canonical qNaN, signed Inf, signed zero) adapted to the destination format's non-finite behavior (IEEE754, NanOnly, FiniteOnly).
  8. Final selection in priority order: NaN > Inf > Zero > Overflow > Normal/Denorm.
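
For readers who want a scalar picture of steps 1 to 3, here is a minimal sketch for a `float` source. It is not the code from the patch: the names `splitF32`, `classify`, and `normalizeDenormal` are made up for illustration, and the hard-coded constants assume IEEE binary32 (8 exponent bits, 23 mantissa bits). Like the `CTLZ_ZERO_UNDEF` node used in the expansion, `normalizeDenormal` requires a non-zero mantissa.

```cpp
#include <cstdint>
#include <cstring>

enum class FpClass { Zero, Denormal, Normal, Inf, NaN };

// Bit fields of an IEEE binary32 value (illustrative helper type).
struct F32Fields {
  uint32_t Sign; // 0 or 1
  uint32_t Exp;  // biased exponent field, 8 bits
  uint32_t Mant; // mantissa field, 23 bits, no implicit leading 1
};

// Step 1: bitcast to an integer and extract the fields with shifts and masks.
inline F32Fields splitF32(float F) {
  uint32_t Bits;
  std::memcpy(&Bits, &F, sizeof(Bits));
  return {Bits >> 31, (Bits >> 23) & 0xFFu, Bits & 0x7FFFFFu};
}

// Step 2: classify the input from the exponent and mantissa fields.
inline FpClass classify(const F32Fields &V) {
  if (V.Exp == 0xFFu)
    return V.Mant ? FpClass::NaN : FpClass::Inf;
  if (V.Exp == 0)
    return V.Mant ? FpClass::Denormal : FpClass::Zero;
  return FpClass::Normal;
}

// Result of normalizing a source denormal.
struct Normalized {
  int32_t EffExp; // effective biased exponent, 1 - shift, may be <= 0
  uint32_t Mant;  // mantissa with the leading 1 moved to the implicit bit
};

// Step 3: shift the leading 1 of a denormal mantissa into bit 23 (the
// implicit position) and lower the effective exponent by the shift amount.
// Precondition: 0 < Mant < 2^23.
inline Normalized normalizeDenormal(uint32_t Mant) {
  int Shift = 0;
  while (((Mant << Shift) & 0x800000u) == 0)
    ++Shift;
  return {1 - Shift, (Mant << Shift) & 0x7FFFFFu};
}
```

For example, the smallest binary32 denormal (mantissa field `1`) normalizes with a shift of 23, giving an effective biased exponent of -22 and an empty mantissa field, which is 1.0 * 2^-149 once the source bias of 127 is removed.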

Currently only conversions to OCP floats are covered. In LLVM terms these are: Float8E5M2, Float8E4M3FN, Float6E3M2FN, Float6E2M3FN, Float4E2M1FN.
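
As a quick reference for these five formats, the sketch below tabulates their layouts and derives the largest finite encoding the same way the expansion does (IEEE754-style formats reserve the top exponent, NanOnly formats with an all-ones NaN reserve only the top mantissa value at the top exponent). The `MiniFormat` struct and the helper names are illustrative assumptions, not LLVM APIs. The field widths and biases are taken from the OCP formats as I understand them, so please check them against the spec before relying on them.

```cpp
#include <cmath>
#include <cstdint>

enum class NonFinite { IEEE754, NanOnly, FiniteOnly };

// Illustrative description of a small FP format (not an LLVM type).
struct MiniFormat {
  const char *Name;
  unsigned ExpBits;
  unsigned MantBits;
  int Bias;
  NonFinite Behavior;
};

constexpr MiniFormat Formats[] = {
    {"Float8E5M2", 5, 2, 15, NonFinite::IEEE754},
    {"Float8E4M3FN", 4, 3, 7, NonFinite::NanOnly},
    {"Float6E3M2FN", 3, 2, 3, NonFinite::FiniteOnly},
    {"Float6E2M3FN", 2, 3, 1, NonFinite::FiniteOnly},
    {"Float4E2M1FN", 2, 1, 1, NonFinite::FiniteOnly},
};

// Largest finite magnitude, as exponent and mantissa fields without the sign.
inline uint32_t maxFiniteBits(const MiniFormat &F) {
  const uint32_t ExpMax = (1u << F.ExpBits) - 1;
  const uint32_t MantMask = (1u << F.MantBits) - 1;
  // IEEE754: the all-ones exponent is reserved for Inf and NaN.
  const uint32_t ExpMaxNormal =
      F.Behavior == NonFinite::IEEE754 ? ExpMax - 1 : ExpMax;
  // NanOnly with all-ones NaN: exp = max, mant = all-ones encodes NaN.
  const uint32_t MaxMant =
      F.Behavior == NonFinite::NanOnly ? MantMask - 1 : MantMask;
  return (ExpMaxNormal << F.MantBits) | MaxMant;
}

// Decode a normal (non-zero exponent field) encoding to its magnitude.
inline double decodeNormal(const MiniFormat &F, uint32_t Bits) {
  const uint32_t Mant = Bits & ((1u << F.MantBits) - 1);
  const uint32_t Exp = (Bits >> F.MantBits) & ((1u << F.ExpBits) - 1);
  return std::ldexp(1.0 + Mant / double(1u << F.MantBits),
                    int(Exp) - F.Bias);
}
```

Under these assumptions the saturation targets come out as 57344 (E5M2), 448 (E4M3FN), 28 (E3M2FN), 7.5 (E2M3FN), and 6 (E2M1FN).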

OCP spec:
https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf

E2E testing on X86 was done with the assistance of Claude Code Opus 4.6.

@llvmbot
Member

llvmbot commented Apr 22, 2026

@llvm/pr-subscribers-backend-nvptx
@llvm/pr-subscribers-backend-x86

@llvm/pr-subscribers-backend-amdgpu

Author: Dmitry Sidorov (MrSidims)

Changes


Patch is 172.47 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/193595.diff

14 Files Affected:

  • (modified) llvm/include/llvm/CodeGen/ISDOpcodes.h (+9)
  • (modified) llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp (+473)
  • (modified) llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp (+20)
  • (modified) llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp (+13)
  • (modified) llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h (+3)
  • (modified) llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp (+1)
  • (modified) llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp (+73-23)
  • (modified) llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp (+43)
  • (modified) llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp (+1)
  • (modified) llvm/lib/CodeGen/TargetLoweringBase.cpp (+2-1)
  • (added) llvm/test/CodeGen/AMDGPU/float-to-arbitrary-fp.ll (+905)
  • (added) llvm/test/CodeGen/NVPTX/float-to-arbitrary-fp.ll (+1028)
  • (added) llvm/test/CodeGen/X86/float-to-arbitrary-fp-error.ll (+76)
  • (added) llvm/test/CodeGen/X86/float-to-arbitrary-fp.ll (+1323)
diff --git a/llvm/include/llvm/CodeGen/ISDOpcodes.h b/llvm/include/llvm/CodeGen/ISDOpcodes.h
index 8a8a9ee71ca02..de9835ee43836 100644
--- a/llvm/include/llvm/CodeGen/ISDOpcodes.h
+++ b/llvm/include/llvm/CodeGen/ISDOpcodes.h
@@ -1020,6 +1020,15 @@ enum NodeType {
   /// The second operand is a constant indicating the source FP semantics.
   CONVERT_FROM_ARBITRARY_FP,
 
+  /// CONVERT_TO_ARBITRARY_FP - Converts a native FP value to an arbitrary
+  /// floating-point format, returning the result as an integer.
+  /// The first operand is the source value.
+  /// The second operand is a constant indicating the destination FP semantics.
+  /// The third operand is a constant indicating the rounding mode.
+  /// The last operand is a boolean constant indicating whether the result has
+  /// to be saturated.
+  CONVERT_TO_ARBITRARY_FP,
+
   /// Perform various unary floating-point operations inspired by libm. For
   /// FPOWI, the result is undefined if the integer operand doesn't fit into
   /// sizeof(int).
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
index 54d86dfbfa303..7dcc3a1f1c753 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
@@ -3782,6 +3782,479 @@ bool SelectionDAGLegalize::ExpandNode(SDNode *Node) {
     Results.push_back(Result);
     break;
   }
+  case ISD::CONVERT_TO_ARBITRARY_FP: {
+    // Expand conversion from a native IEEE float type to an arbitrary FP
+    // format, returning the result as an integer using bit manipulation.
+    //
+    // TODO: currently only conversions to FP4, FP6 and FP8 formats from OCP
+    // specification are expanded. Remaining arbitrary FP types: Float8E4M3,
+    // Float8E3M4, Float8E5M2FNUZ, Float8E4M3FNUZ, Float8E4M3B11FNUZ,
+    // Float8E8M0FNU.
+    EVT ResVT = Node->getValueType(0);
+
+    SDValue FloatVal = Node->getOperand(0);
+    const uint64_t SemEnum = Node->getConstantOperandVal(1);
+    const auto Sem = static_cast<APFloatBase::Semantics>(SemEnum);
+    const auto RoundMode =
+        static_cast<RoundingMode>(Node->getConstantOperandVal(2));
+    const bool Saturate = Node->getConstantOperandVal(3) != 0;
+
+    // Supported destination formats.
+    switch (Sem) {
+    case APFloatBase::S_Float8E5M2:
+    case APFloatBase::S_Float8E4M3FN:
+    case APFloatBase::S_Float6E3M2FN:
+    case APFloatBase::S_Float6E2M3FN:
+    case APFloatBase::S_Float4E2M1FN:
+      break;
+    default:
+      DAG.getContext()->emitError("CONVERT_TO_ARBITRARY_FP: not implemented "
+                                  "destination format (semantics enum " +
+                                  Twine(SemEnum) + ")");
+      Results.push_back(DAG.getPOISON(ResVT));
+      break;
+    }
+    if (!Results.empty())
+      break;
+
+    // Destination format parameters.
+    const fltSemantics &DstSem = APFloatBase::EnumToSemantics(Sem);
+    const unsigned DstBits = APFloat::getSizeInBits(DstSem);
+    const unsigned DstPrecision = APFloat::semanticsPrecision(DstSem);
+    const unsigned DstMant = DstPrecision - 1;
+    const unsigned DstExpBits = DstBits - DstMant - 1;
+    const int DstBias = 1 - APFloat::semanticsMinExponent(DstSem);
+    const unsigned DstExpMax = (1U << DstExpBits) - 1;
+    const uint64_t DstMantMask = (DstMant > 0) ? ((1ULL << DstMant) - 1) : 0;
+    const fltNonfiniteBehavior DstNFBehavior = DstSem.nonFiniteBehavior;
+    const fltNanEncoding DstNanEnc = DstSem.nanEncoding;
+
+    // Compute the maximum normal exponent for the destination format.
+    unsigned DstExpMaxNormal;
+    if (DstNFBehavior == fltNonfiniteBehavior::IEEE754)
+      DstExpMaxNormal = DstExpMax - 1;
+    else
+      DstExpMaxNormal = DstExpMax;
+
+    // For NanOnly formats the max exponent field for finite values
+    // is DstExpMax, but the encoding with exp = DstExpMax and
+    // mant = all-ones is NaN. So DstExpMaxNormal = DstExpMax, but max
+    // mantissa at that exponent is DstMantMask - 1 (if NanEnc == AllOnes) to
+    // avoid the NaN encoding.
+    uint64_t DstMaxMantAtMaxExp = DstMantMask;
+    if (DstNFBehavior == fltNonfiniteBehavior::NanOnly &&
+        DstNanEnc == fltNanEncoding::AllOnes)
+      DstMaxMantAtMaxExp = DstMantMask - 1;
+
+    // Source format parameters.
+    EVT SrcVT = FloatVal.getValueType();
+    const fltSemantics &SrcSem = SrcVT.getFltSemantics();
+    const unsigned SrcBits = APFloat::getSizeInBits(SrcSem);
+    const unsigned SrcPrecision = APFloat::semanticsPrecision(SrcSem);
+    const unsigned SrcMant = SrcPrecision - 1;
+    const unsigned SrcExpBits = SrcBits - SrcMant - 1;
+    const int SrcBias = 1 - APFloat::semanticsMinExponent(SrcSem);
+    const uint64_t SrcMantMask = (1ULL << SrcMant) - 1;
+    const uint64_t SrcExpMask = (1ULL << SrcExpBits) - 1;
+
+    // Work in the source integer type.
+    EVT IntVT = EVT::getIntegerVT(*DAG.getContext(), SrcBits);
+    EVT SetCCVT = getSetCCResultType(IntVT);
+
+    SDValue Zero = DAG.getConstant(0, dl, IntVT);
+    SDValue One = DAG.getConstant(1, dl, IntVT);
+
+    // Bitcast source float to integer and extract bit fields.
+    SDValue Src = DAG.getNode(ISD::BITCAST, dl, IntVT, FloatVal);
+    SDValue SrcMantField = DAG.getNode(ISD::AND, dl, IntVT, Src,
+                                       DAG.getConstant(SrcMantMask, dl, IntVT));
+
+    SDValue SrcExpField = DAG.getNode(
+        ISD::AND, dl, IntVT,
+        DAG.getNode(ISD::SRL, dl, IntVT, Src,
+                    DAG.getShiftAmountConstant(SrcMant, IntVT, dl)),
+        DAG.getConstant(SrcExpMask, dl, IntVT));
+
+    SDValue SignBit =
+        DAG.getNode(ISD::SRL, dl, IntVT, Src,
+                    DAG.getShiftAmountConstant(SrcBits - 1, IntVT, dl));
+
+    // Classify the input value.
+    SDValue SrcExpAllOnes = DAG.getConstant(SrcExpMask, dl, IntVT);
+    SDValue IsExpAllOnes =
+        DAG.getSetCC(dl, SetCCVT, SrcExpField, SrcExpAllOnes, ISD::SETEQ);
+    SDValue IsExpZero =
+        DAG.getSetCC(dl, SetCCVT, SrcExpField, Zero, ISD::SETEQ);
+    SDValue IsMantZero =
+        DAG.getSetCC(dl, SetCCVT, SrcMantField, Zero, ISD::SETEQ);
+    SDValue IsMantNonZero =
+        DAG.getSetCC(dl, SetCCVT, SrcMantField, Zero, ISD::SETNE);
+
+    // If source is IEEE fp, then NaN = exp_all_ones && mant != 0.
+    SDValue IsNaN =
+        DAG.getNode(ISD::AND, dl, SetCCVT, IsExpAllOnes, IsMantNonZero);
+    // Inf = exp_all_ones && mant == 0.
+    SDValue IsInf =
+        DAG.getNode(ISD::AND, dl, SetCCVT, IsExpAllOnes, IsMantZero);
+    // Zero = exp == 0 && mant == 0.
+    SDValue IsZero =
+        DAG.getNode(ISD::AND, dl, SetCCVT, IsExpZero, IsMantZero);
+    // Source denorm = exp == 0 && mant != 0.
+    SDValue IsSrcDenorm =
+        DAG.getNode(ISD::AND, dl, SetCCVT, IsExpZero, IsMantNonZero);
+
+    // Source denormal normalization.
+    // For a source denormal, the true exponent is (1 - SrcBias) and the
+    // mantissa has no implicit leading 1. Normalize by finding the position
+    // of the leading 1 in the mantissa.
+    SDValue LeadingZeros =
+        DAG.getNode(ISD::CTLZ_ZERO_UNDEF, dl, IntVT, SrcMantField);
+
+    // normShift = LeadingZeros - (SrcBits - 1 - SrcMant).
+    const unsigned LZOffset = SrcBits - 1 - SrcMant;
+    SDValue NormShift = DAG.getNode(ISD::SUB, dl, IntVT, LeadingZeros,
+                                    DAG.getConstant(LZOffset, dl, IntVT));
+
+    // Normalized mantissa.
+    SDValue NormMant =
+        DAG.getNode(ISD::AND, dl, IntVT,
+                    DAG.getNode(ISD::SHL, dl, IntVT, SrcMantField, NormShift),
+                    DAG.getConstant(SrcMantMask, dl, IntVT));
+
+    // effective_exp = 1 - NormShift.
+    SDValue DenormSrcExp =
+        DAG.getNode(ISD::SUB, dl, IntVT, One, NormShift);
+
+    // Select between normal and denorm source.
+    SDValue EffSrcExp =
+        DAG.getSelect(dl, IntVT, IsSrcDenorm, DenormSrcExp, SrcExpField);
+    SDValue EffSrcMant =
+        DAG.getSelect(dl, IntVT, IsSrcDenorm, NormMant, SrcMantField);
+
+    // Compute new biased exponent for destination.
+    // new_exp = src_exp - SrcBias + DstBias
+    const int BiasAdjust = DstBias - SrcBias;
+    SDValue NewExp = DAG.getNode(
+        ISD::ADD, dl, IntVT, EffSrcExp,
+        DAG.getConstant(APInt(SrcBits, BiasAdjust, true), dl, IntVT));
+
+    // Compute rounding increment given the round bit, sticky bits, and LSB
+    // of the truncated mantissa.
+    auto ComputeRoundUp = [&](SDValue RoundBit, SDValue StickyBits,
+                              SDValue LSB) -> SDValue {
+      if (RoundMode == RoundingMode::NearestTiesToEven) {
+        // Round up if round_bit && (sticky || lsb)
+        SDValue StickyOrLSB = DAG.getNode(ISD::OR, dl, IntVT, StickyBits, LSB);
+        return DAG.getNode(ISD::AND, dl, IntVT, RoundBit, StickyOrLSB);
+      }
+      if (RoundMode == RoundingMode::TowardZero)
+        return Zero;
+      if (RoundMode == RoundingMode::TowardPositive) {
+        // Round up if positive and any truncated bits are set.
+        SDValue AnyTruncBits =
+            DAG.getNode(ISD::OR, dl, IntVT, RoundBit, StickyBits);
+        SDValue HasTruncBits =
+            DAG.getSetCC(dl, SetCCVT, AnyTruncBits, Zero, ISD::SETNE);
+        SDValue IsPositive =
+            DAG.getSetCC(dl, SetCCVT, SignBit, Zero, ISD::SETEQ);
+        SDValue DoRound =
+            DAG.getNode(ISD::AND, dl, SetCCVT, HasTruncBits, IsPositive);
+        return DAG.getNode(ISD::ZERO_EXTEND, dl, IntVT, DoRound);
+      }
+      if (RoundMode == RoundingMode::TowardNegative) {
+        // Round up if negative and any truncated bits are set (to -Inf).
+        SDValue AnyTruncBits =
+            DAG.getNode(ISD::OR, dl, IntVT, RoundBit, StickyBits);
+        SDValue HasTruncBits =
+            DAG.getSetCC(dl, SetCCVT, AnyTruncBits, Zero, ISD::SETNE);
+        SDValue IsNegative =
+            DAG.getSetCC(dl, SetCCVT, SignBit, Zero, ISD::SETNE);
+        SDValue DoRound =
+            DAG.getNode(ISD::AND, dl, SetCCVT, HasTruncBits, IsNegative);
+        return DAG.getNode(ISD::ZERO_EXTEND, dl, IntVT, DoRound);
+      }
+      if (RoundMode == RoundingMode::NearestTiesToAway)
+        return RoundBit;
+      llvm_unreachable("Unsupported rounding mode");
+    };
+
+    // Round mantissa from SrcMant bits to DstMant bits.
+    SDValue TruncMant;
+    SDValue RoundUp;
+    if (SrcMant > DstMant) {
+      const unsigned Shift = SrcMant - DstMant;
+      SDValue ShiftConst = DAG.getShiftAmountConstant(Shift, IntVT, dl);
+      TruncMant =
+          DAG.getNode(ISD::SRL, dl, IntVT, EffSrcMant, ShiftConst);
+
+      // Check bit at position Shift - 1 aka the round bit.
+      SDValue RoundBit;
+      if (Shift >= 1) {
+        RoundBit = DAG.getNode(
+            ISD::AND, dl, IntVT,
+            DAG.getNode(ISD::SRL, dl, IntVT, EffSrcMant,
+                        DAG.getShiftAmountConstant(Shift - 1, IntVT, dl)),
+            One);
+      } else {
+        RoundBit = Zero;
+      }
+
+      // OR of all bits below the round bit to get sticky bits.
+      SDValue StickyBits;
+      if (Shift >= 2) {
+        uint64_t StickyMask = (1ULL << (Shift - 1)) - 1;
+        StickyBits = DAG.getNode(ISD::AND, dl, IntVT, EffSrcMant,
+                                 DAG.getConstant(StickyMask, dl, IntVT));
+        StickyBits = DAG.getSetCC(dl, SetCCVT, StickyBits, Zero, ISD::SETNE);
+        StickyBits =
+            DAG.getNode(ISD::ZERO_EXTEND, dl, IntVT, StickyBits);
+      } else {
+        StickyBits = Zero;
+      }
+
+      // LSB of truncated mantissa.
+      SDValue LSB = DAG.getNode(ISD::AND, dl, IntVT, TruncMant, One);
+
+      RoundUp = ComputeRoundUp(RoundBit, StickyBits, LSB);
+    } else {
+      // If DstMant >= SrcMant, then no rounding needed, just shift left.
+      SDValue MantShift =
+          DAG.getShiftAmountConstant(DstMant - SrcMant, IntVT, dl);
+      TruncMant = DAG.getNode(ISD::SHL, dl, IntVT, EffSrcMant, MantShift);
+      RoundUp = Zero;
+    }
+
+    // Apply rounding.
+    SDValue RoundedMant =
+        DAG.getNode(ISD::ADD, dl, IntVT, TruncMant, RoundUp);
+
+    // Handle mantissa overflow from rounding.
+    // If rounded_mant > DstMantMask, carry into exponent.
+    SDValue MantOverflow = DAG.getSetCC(
+        dl, SetCCVT, RoundedMant,
+        DAG.getConstant(DstMantMask, dl, IntVT), ISD::SETGT);
+    // On overflow: mant = 0, exp += 1.
+    SDValue AdjMant =
+        DAG.getSelect(dl, IntVT, MantOverflow, Zero, RoundedMant);
+    SDValue AdjExp = DAG.getNode(
+        ISD::ADD, dl, IntVT, NewExp,
+        DAG.getNode(ISD::ZERO_EXTEND, dl, IntVT, MantOverflow));
+
+    // Precompute sign shifted to MSB of destination.
+    SDValue SignShifted = DAG.getNode(
+        ISD::SHL, dl, IntVT, SignBit,
+        DAG.getShiftAmountConstant(DstBits - 1, IntVT, dl));
+
+    // Destination denormal conversion (when new_exp <= 0).
+    // Shift the mantissa right by 1 - new_exp additional bits and set the
+    // exponent field to 0.
+    SDValue ExpIsNeg =
+        DAG.getSetCC(dl, SetCCVT, AdjExp,
+                     DAG.getConstant(1, dl, IntVT), ISD::SETLT);
+
+    SDValue DenormResult;
+    {
+      // denorm_shift = 1 - NewExp.
+      SDValue DenormShift =
+          DAG.getNode(ISD::SUB, dl, IntVT, One, NewExp);
+
+      // full_src_mant = (1 << SrcMant) | EffSrcMant.
+      SDValue ImplicitOne = DAG.getNode(
+          ISD::SHL, dl, IntVT, One,
+          DAG.getShiftAmountConstant(SrcMant, IntVT, dl));
+      SDValue FullSrcMant =
+          DAG.getNode(ISD::OR, dl, IntVT, EffSrcMant, ImplicitOne);
+
+      // Total right shift = (SrcMant - DstMant) + DenormShift
+      SDValue TotalShift;
+      if (SrcMant >= DstMant) {
+        TotalShift =
+            DAG.getNode(ISD::ADD, dl, IntVT, DenormShift,
+                        DAG.getConstant(SrcMant - DstMant, dl, IntVT));
+      } else {
+        TotalShift =
+            DAG.getNode(ISD::SUB, dl, IntVT, DenormShift,
+                        DAG.getConstant(DstMant - SrcMant, dl, IntVT));
+      }
+
+      // Clamp total shift to avoid UB, then truncate denorm mantissa.
+      SDValue MaxShift = DAG.getConstant(SrcBits - 1, dl, IntVT);
+      SDValue ClampedShift = DAG.getNode(ISD::UMIN, dl, IntVT, TotalShift,
+                                         MaxShift);
+      SDValue DenormTruncMant =
+          DAG.getNode(ISD::SRL, dl, IntVT, FullSrcMant, ClampedShift);
+
+      // Rounding for denorm path.
+      SDValue DenormRoundUp;
+      {
+        // Round bit is at position TotalShift - 1 of FullSrcMant.
+        // Clamp to at least 1 so the subtraction doesn't underflow and create
+        // shift nodes with invalid shift amounts.
+        SDValue SafeShift =
+            DAG.getNode(ISD::UMAX, dl, IntVT, ClampedShift, One);
+        SDValue RoundBitPos =
+            DAG.getNode(ISD::SUB, dl, IntVT, SafeShift, One);
+        SDValue DenormRoundBit = DAG.getNode(
+            ISD::AND, dl, IntVT,
+            DAG.getNode(ISD::SRL, dl, IntVT, FullSrcMant, RoundBitPos), One);
+
+        // Sticky: all bits below round bit.
+        // sticky_mask = (1 << RoundBitPos) - 1
+        SDValue StickyMask = DAG.getNode(
+            ISD::SUB, dl, IntVT,
+            DAG.getNode(ISD::SHL, dl, IntVT, One, RoundBitPos), One);
+        SDValue DenormStickyBits =
+            DAG.getNode(ISD::AND, dl, IntVT, FullSrcMant, StickyMask);
+        SDValue HasSticky =
+            DAG.getNode(ISD::ZERO_EXTEND, dl, IntVT,
+                        DAG.getSetCC(dl, SetCCVT, DenormStickyBits, Zero,
+                                     ISD::SETNE));
+
+        SDValue DenormLSB =
+            DAG.getNode(ISD::AND, dl, IntVT, DenormTruncMant, One);
+
+        DenormRoundUp =
+            ComputeRoundUp(DenormRoundBit, HasSticky, DenormLSB);
+
+        // Only apply rounding if TotalShift >= 1 (i.e., there are bits to
+        // round).
+        SDValue ShiftGEOne =
+            DAG.getSetCC(dl, SetCCVT, ClampedShift, One, ISD::SETUGE);
+        DenormRoundUp = DAG.getSelect(dl, IntVT, ShiftGEOne, DenormRoundUp,
+                                      Zero);
+      }
+
+      SDValue DenormRoundedMant =
+          DAG.getNode(ISD::ADD, dl, IntVT, DenormTruncMant, DenormRoundUp);
+
+      // If rounding caused overflow into the normal range, then we get the
+      // smallest normal number.
+      SDValue DenormMantOF = DAG.getSetCC(
+          dl, SetCCVT, DenormRoundedMant,
+          DAG.getConstant(DstMantMask, dl, IntVT), ISD::SETGT);
+      SDValue DenormFinalMant = DAG.getSelect(
+          dl, IntVT, DenormMantOF, Zero, DenormRoundedMant);
+      SDValue DenormFinalExp = DAG.getSelect(
+          dl, IntVT, DenormMantOF, One, Zero);
+
+      // Assemble: sign | (exp << DstMant) | mant
+      SDValue DenormExpShifted = DAG.getNode(
+          ISD::SHL, dl, IntVT, DenormFinalExp,
+          DAG.getShiftAmountConstant(DstMant, IntVT, dl));
+      DenormResult = DAG.getNode(
+          ISD::OR, dl, IntVT,
+          DAG.getNode(ISD::OR, dl, IntVT, SignShifted, DenormExpShifted),
+          DenormFinalMant);
+
+      // If the value is too small for even a denorm (all mantissa bits
+      // shifted away), handle based on rounding mode.
+      // This is covered by the DenormRoundedMant = 0 case naturally.
+    }
+
+    // Exponent overflow detection.
+    SDValue ExpOF = DAG.getSetCC(
+        dl, SetCCVT, AdjExp,
+        DAG.getConstant(DstExpMaxNormal, dl, IntVT), ISD::SETGT);
+
+    // Also check whether AdjExp == DstExpMaxNormal and the rounded mantissa
+    // exceeds the max allowed mantissa at that exponent.
+    SDValue ExpAtMax = DAG.getSetCC(
+        dl, SetCCVT, AdjExp,
+        DAG.getConstant(DstExpMaxNormal, dl, IntVT), ISD::SETEQ);
+    SDValue MantExceedsMax = DAG.getSetCC(
+        dl, SetCCVT, AdjMant,
+        DAG.getConstant(DstMaxMantAtMaxExp, dl, IntVT), ISD::SETGT);
+    SDValue ExpMantOF =
+        DAG.getNode(ISD::AND, dl, SetCCVT, ExpAtMax, MantExceedsMax);
+    SDValue IsOverflow =
+        DAG.getNode(ISD::OR, dl, SetCCVT, ExpOF, ExpMantOF);
+
+    // Build overflow result.
+    SDValue OverflowResult;
+
+    if (Saturate) {
+      // Clamp to max finite value:
+      // sign | (DstExpMaxNormal << DstMant) | DstMaxMantAtMaxExp
+      uint64_t MaxFinite =
+          ((uint64_t)DstExpMaxNormal << DstMant) | DstMaxMantAtMaxExp;
+      OverflowResult =
+          DAG.getNode(ISD::OR, dl, IntVT, SignShifted,
+                      DAG.getConstant(MaxFinite, dl, IntVT));
+    } else if (DstNFBehavior == fltNonfiniteBehavior::IEEE754) {
+      // Produce infinity.
+      uint64_t InfBits = (uint64_t)DstExpMax << DstMant;
+      OverflowResult =
+          DAG.getNode(ISD::OR, dl, IntVT, SignShifted,
+                      DAG.getConstant(InfBits, dl, IntVT));
+    } else {
+      // Emit poison if no Inf in format and not saturating.
+      OverflowResult = DAG.getPOISON(IntVT);
+    }
+
+    // Assemble normal result: sign | (AdjExp << DstMant) | AdjMant
+    SDValue NormExpShifted = DAG.getNode(
+        ISD::SHL, dl, IntVT, AdjExp,
+        DAG.getShiftAmountConstant(DstMant, IntVT, dl));
+    SDValue NormResult = DAG.getNode(
+        ISD::OR, dl, IntVT,
+        DAG.getNode(ISD::OR, dl, IntVT, SignShifted, NormExpShifted), AdjMant);
+
+    // Build special-value results.
+    SDValue NaNResult;
+    if (DstNFBehavior == fltNonfiniteBehavior::IEEE754) {
+      // Produce canonical NaN.
+      const uint64_t QNaNBit = (DstMant > 0) ? (1ULL << (DstMant - 1)) : 0;
+      NaNResult =
+          DAG.getConstant(((uint64_t)DstExpMax << DstMant) | QNaNBit, dl,
+                          IntVT);
+    } else if (DstNFBehavior == fltNonfiniteBehavior::NanOnly &&
+               DstNanEnc == fltNanEncoding::AllOnes) {
+      // E4M3FN-style: NaN is exp=all-ones, mant=all-ones.
+      NaNResult =
+          DAG.getConstant(((uint64_t)DstExpMax << DstMant) | DstMantMask, dl,
+                          IntVT);
+    } else {
+      // NaN -> poison for finite-only formats.
+      NaNResult = DAG.getPOISON(IntVT);
+    }
+
+    // Inf handling.
+    SDValue InfResult;
+    if (DstNFBehavior == fltNonfiniteBehavior::IEEE754) {
+      // Produce signed infinity.
+      uint64_t InfBits = (uint64_t)DstExpMax << DstMant;
+      InfResult = DAG.getNode(ISD::OR, dl, IntVT, SignShifted,
+                              DAG.getConstant(InfBits, dl, IntVT));
+    } else if (Saturate) {
+      // Inf saturate...
[truncated]
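
The rounding decision made by the `ComputeRoundUp` lambda in the diff above can be modeled on plain integers as follows. This is a sketch for discussion, not code from the patch: `Rounding`, `roundUp`, and `roundMantissa` are hypothetical names, and `roundMantissa` assumes a shift amount between 1 and 31.

```cpp
#include <cstdint>

enum class Rounding {
  NearestTiesToEven,
  TowardZero,
  TowardPositive,
  TowardNegative,
  NearestTiesToAway
};

// Returns 1 if the truncated mantissa must be incremented, 0 otherwise.
// RoundBit is the first discarded bit, Sticky is 1 if any lower discarded
// bit is set, LSB is the lowest bit that is kept.
inline uint32_t roundUp(Rounding RM, bool Negative, uint32_t RoundBit,
                        uint32_t Sticky, uint32_t LSB) {
  const bool AnyDiscarded = (RoundBit | Sticky) != 0;
  switch (RM) {
  case Rounding::NearestTiesToEven:
    return RoundBit & (Sticky | LSB);
  case Rounding::TowardZero:
    return 0;
  case Rounding::TowardPositive:
    return (AnyDiscarded && !Negative) ? 1u : 0u;
  case Rounding::TowardNegative:
    return (AnyDiscarded && Negative) ? 1u : 0u;
  case Rounding::NearestTiesToAway:
    return RoundBit;
  }
  return 0;
}

// Drop the low Shift bits of Mant and apply the rounding increment.
// The result may exceed the destination mantissa mask, in which case the
// caller has to carry into the exponent. Precondition: 1 <= Shift <= 31.
inline uint32_t roundMantissa(uint32_t Mant, unsigned Shift, Rounding RM,
                              bool Negative) {
  const uint32_t Trunc = Mant >> Shift;
  const uint32_t RoundBit = (Mant >> (Shift - 1)) & 1u;
  const uint32_t Sticky =
      (Mant & ((1u << (Shift - 1)) - 1)) != 0 ? 1u : 0u;
  return Trunc + roundUp(RM, Negative, RoundBit, Sticky, Trunc & 1u);
}
```

With a 23-bit source mantissa and a 2-bit destination mantissa (shift of 21), the input `0x700000` truncates to 3 and rounds up to 4 under NearestTiesToEven. That exceeds the 2-bit mask, which is the mantissa-overflow case where the expansion zeroes the mantissa and increments the exponent.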

@llvmbot
Member

llvmbot commented Apr 22, 2026

@llvm/pr-subscribers-llvm-selectiondag

Author: Dmitry Sidorov (MrSidims)

+    const unsigned SrcMant = SrcPrecision - 1;
+    const unsigned SrcExpBits = SrcBits - SrcMant - 1;
+    const int SrcBias = 1 - APFloat::semanticsMinExponent(SrcSem);
+    const uint64_t SrcMantMask = (1ULL << SrcMant) - 1;
+    const uint64_t SrcExpMask = (1ULL << SrcExpBits) - 1;
+
+    // Work in the source integer type.
+    EVT IntVT = EVT::getIntegerVT(*DAG.getContext(), SrcBits);
+    EVT SetCCVT = getSetCCResultType(IntVT);
+
+    SDValue Zero = DAG.getConstant(0, dl, IntVT);
+    SDValue One = DAG.getConstant(1, dl, IntVT);
+
+    // Bitcast source float to integer and extract bit fields.
+    SDValue Src = DAG.getNode(ISD::BITCAST, dl, IntVT, FloatVal);
+    SDValue SrcMantField = DAG.getNode(ISD::AND, dl, IntVT, Src,
+                                       DAG.getConstant(SrcMantMask, dl, IntVT));
+
+    SDValue SrcExpField = DAG.getNode(
+        ISD::AND, dl, IntVT,
+        DAG.getNode(ISD::SRL, dl, IntVT, Src,
+                    DAG.getShiftAmountConstant(SrcMant, IntVT, dl)),
+        DAG.getConstant(SrcExpMask, dl, IntVT));
+
+    SDValue SignBit =
+        DAG.getNode(ISD::SRL, dl, IntVT, Src,
+                    DAG.getShiftAmountConstant(SrcBits - 1, IntVT, dl));
+
+    // Classify the input value.
+    SDValue SrcExpAllOnes = DAG.getConstant(SrcExpMask, dl, IntVT);
+    SDValue IsExpAllOnes =
+        DAG.getSetCC(dl, SetCCVT, SrcExpField, SrcExpAllOnes, ISD::SETEQ);
+    SDValue IsExpZero =
+        DAG.getSetCC(dl, SetCCVT, SrcExpField, Zero, ISD::SETEQ);
+    SDValue IsMantZero =
+        DAG.getSetCC(dl, SetCCVT, SrcMantField, Zero, ISD::SETEQ);
+    SDValue IsMantNonZero =
+        DAG.getSetCC(dl, SetCCVT, SrcMantField, Zero, ISD::SETNE);
+
+    // If source is IEEE fp, then NaN = exp_all_ones && mant != 0.
+    SDValue IsNaN =
+        DAG.getNode(ISD::AND, dl, SetCCVT, IsExpAllOnes, IsMantNonZero);
+    // Inf = exp_all_ones && mant == 0.
+    SDValue IsInf =
+        DAG.getNode(ISD::AND, dl, SetCCVT, IsExpAllOnes, IsMantZero);
+    // Zero = exp == 0 && mant == 0.
+    SDValue IsZero =
+        DAG.getNode(ISD::AND, dl, SetCCVT, IsExpZero, IsMantZero);
+    // Source denorm = exp == 0 && mant != 0.
+    SDValue IsSrcDenorm =
+        DAG.getNode(ISD::AND, dl, SetCCVT, IsExpZero, IsMantNonZero);
+
+    // Source denormal normalization.
+    // For a source denormal, the true exponent is (1 - SrcBias) and the
+    // mantissa has no implicit leading 1. Normalize by finding the position
+    // of the leading 1 in the mantissa.
+    SDValue LeadingZeros =
+        DAG.getNode(ISD::CTLZ_ZERO_UNDEF, dl, IntVT, SrcMantField);
+
+    // normShift = LeadingZeros - (SrcBits - 1 - SrcMant).
+    const unsigned LZOffset = SrcBits - 1 - SrcMant;
+    SDValue NormShift = DAG.getNode(ISD::SUB, dl, IntVT, LeadingZeros,
+                                    DAG.getConstant(LZOffset, dl, IntVT));
+
+    // Normalized mantissa.
+    SDValue NormMant =
+        DAG.getNode(ISD::AND, dl, IntVT,
+                    DAG.getNode(ISD::SHL, dl, IntVT, SrcMantField, NormShift),
+                    DAG.getConstant(SrcMantMask, dl, IntVT));
+
+    // effective_exp = 1 - NormShift.
+    SDValue DenormSrcExp =
+        DAG.getNode(ISD::SUB, dl, IntVT, One, NormShift);
+
+    // Select between normal and denorm source.
+    SDValue EffSrcExp =
+        DAG.getSelect(dl, IntVT, IsSrcDenorm, DenormSrcExp, SrcExpField);
+    SDValue EffSrcMant =
+        DAG.getSelect(dl, IntVT, IsSrcDenorm, NormMant, SrcMantField);
+
+    // Compute new biased exponent for destination.
+    // new_exp = src_exp - SrcBias + DstBias
+    const int BiasAdjust = DstBias - SrcBias;
+    SDValue NewExp = DAG.getNode(
+        ISD::ADD, dl, IntVT, EffSrcExp,
+        DAG.getConstant(APInt(SrcBits, BiasAdjust, true), dl, IntVT));
+
+    // Compute rounding increment given the round bit, sticky bits, and LSB
+    // of the truncated mantissa.
+    auto ComputeRoundUp = [&](SDValue RoundBit, SDValue StickyBits,
+                              SDValue LSB) -> SDValue {
+      if (RoundMode == RoundingMode::NearestTiesToEven) {
+        // Round up if round_bit && (sticky || lsb)
+        SDValue StickyOrLSB = DAG.getNode(ISD::OR, dl, IntVT, StickyBits, LSB);
+        return DAG.getNode(ISD::AND, dl, IntVT, RoundBit, StickyOrLSB);
+      }
+      if (RoundMode == RoundingMode::TowardZero)
+        return Zero;
+      if (RoundMode == RoundingMode::TowardPositive) {
+        // Round up if positive and any truncated bits are set.
+        SDValue AnyTruncBits =
+            DAG.getNode(ISD::OR, dl, IntVT, RoundBit, StickyBits);
+        SDValue HasTruncBits =
+            DAG.getSetCC(dl, SetCCVT, AnyTruncBits, Zero, ISD::SETNE);
+        SDValue IsPositive =
+            DAG.getSetCC(dl, SetCCVT, SignBit, Zero, ISD::SETEQ);
+        SDValue DoRound =
+            DAG.getNode(ISD::AND, dl, SetCCVT, HasTruncBits, IsPositive);
+        return DAG.getNode(ISD::ZERO_EXTEND, dl, IntVT, DoRound);
+      }
+      if (RoundMode == RoundingMode::TowardNegative) {
+        // Round up if negative and any truncated bits are set (to -Inf).
+        SDValue AnyTruncBits =
+            DAG.getNode(ISD::OR, dl, IntVT, RoundBit, StickyBits);
+        SDValue HasTruncBits =
+            DAG.getSetCC(dl, SetCCVT, AnyTruncBits, Zero, ISD::SETNE);
+        SDValue IsNegative =
+            DAG.getSetCC(dl, SetCCVT, SignBit, Zero, ISD::SETNE);
+        SDValue DoRound =
+            DAG.getNode(ISD::AND, dl, SetCCVT, HasTruncBits, IsNegative);
+        return DAG.getNode(ISD::ZERO_EXTEND, dl, IntVT, DoRound);
+      }
+      if (RoundMode == RoundingMode::NearestTiesToAway)
+        return RoundBit;
+      llvm_unreachable("Unsupported rounding mode");
+    };
+
+    // Round mantissa from SrcMant bits to DstMant bits.
+    SDValue TruncMant;
+    SDValue RoundUp;
+    if (SrcMant > DstMant) {
+      const unsigned Shift = SrcMant - DstMant;
+      SDValue ShiftConst = DAG.getShiftAmountConstant(Shift, IntVT, dl);
+      TruncMant =
+          DAG.getNode(ISD::SRL, dl, IntVT, EffSrcMant, ShiftConst);
+
+      // Check bit at position Shift - 1 aka the round bit.
+      SDValue RoundBit;
+      if (Shift >= 1) {
+        RoundBit = DAG.getNode(
+            ISD::AND, dl, IntVT,
+            DAG.getNode(ISD::SRL, dl, IntVT, EffSrcMant,
+                        DAG.getShiftAmountConstant(Shift - 1, IntVT, dl)),
+            One);
+      } else {
+        RoundBit = Zero;
+      }
+
+      // OR of all bits below the round bit to get sticky bits.
+      SDValue StickyBits;
+      if (Shift >= 2) {
+        uint64_t StickyMask = (1ULL << (Shift - 1)) - 1;
+        StickyBits = DAG.getNode(ISD::AND, dl, IntVT, EffSrcMant,
+                                 DAG.getConstant(StickyMask, dl, IntVT));
+        StickyBits = DAG.getSetCC(dl, SetCCVT, StickyBits, Zero, ISD::SETNE);
+        StickyBits =
+            DAG.getNode(ISD::ZERO_EXTEND, dl, IntVT, StickyBits);
+      } else {
+        StickyBits = Zero;
+      }
+
+      // LSB of truncated mantissa.
+      SDValue LSB = DAG.getNode(ISD::AND, dl, IntVT, TruncMant, One);
+
+      RoundUp = ComputeRoundUp(RoundBit, StickyBits, LSB);
+    } else {
+      // If DstMant >= SrcMant, then no rounding needed, just shift left.
+      SDValue MantShift =
+          DAG.getShiftAmountConstant(DstMant - SrcMant, IntVT, dl);
+      TruncMant = DAG.getNode(ISD::SHL, dl, IntVT, EffSrcMant, MantShift);
+      RoundUp = Zero;
+    }
+
+    // Apply rounding.
+    SDValue RoundedMant =
+        DAG.getNode(ISD::ADD, dl, IntVT, TruncMant, RoundUp);
+
+    // Handle mantissa overflow from rounding.
+    // If rounded_mant > DstMantMask, carry into exponent.
+    SDValue MantOverflow = DAG.getSetCC(
+        dl, SetCCVT, RoundedMant,
+        DAG.getConstant(DstMantMask, dl, IntVT), ISD::SETGT);
+    // On overflow: mant = 0, exp += 1.
+    SDValue AdjMant =
+        DAG.getSelect(dl, IntVT, MantOverflow, Zero, RoundedMant);
+    SDValue AdjExp = DAG.getNode(
+        ISD::ADD, dl, IntVT, NewExp,
+        DAG.getNode(ISD::ZERO_EXTEND, dl, IntVT, MantOverflow));
+
+    // Precompute sign shifted to MSB of destination.
+    SDValue SignShifted = DAG.getNode(
+        ISD::SHL, dl, IntVT, SignBit,
+        DAG.getShiftAmountConstant(DstBits - 1, IntVT, dl));
+
+    // Destination denormal conversion (when new_exp <= 0).
+    // Shift the mantissa right by 1 - new_exp additional bits and set the
+    // exponent field to 0.
+    SDValue ExpIsNeg =
+        DAG.getSetCC(dl, SetCCVT, AdjExp,
+                     DAG.getConstant(1, dl, IntVT), ISD::SETLT);
+
+    SDValue DenormResult;
+    {
+      // denorm_shift = 1 - NewExp.
+      SDValue DenormShift =
+          DAG.getNode(ISD::SUB, dl, IntVT, One, NewExp);
+
+      // full_src_mant = (1 << SrcMant) | EffSrcMant.
+      SDValue ImplicitOne = DAG.getNode(
+          ISD::SHL, dl, IntVT, One,
+          DAG.getShiftAmountConstant(SrcMant, IntVT, dl));
+      SDValue FullSrcMant =
+          DAG.getNode(ISD::OR, dl, IntVT, EffSrcMant, ImplicitOne);
+
+      // Total right shift = (SrcMant - DstMant) + DenormShift
+      SDValue TotalShift;
+      if (SrcMant >= DstMant) {
+        TotalShift =
+            DAG.getNode(ISD::ADD, dl, IntVT, DenormShift,
+                        DAG.getConstant(SrcMant - DstMant, dl, IntVT));
+      } else {
+        TotalShift =
+            DAG.getNode(ISD::SUB, dl, IntVT, DenormShift,
+                        DAG.getConstant(DstMant - SrcMant, dl, IntVT));
+      }
+
+      // Clamp total shift to avoid UB, then truncate denorm mantissa.
+      SDValue MaxShift = DAG.getConstant(SrcBits - 1, dl, IntVT);
+      SDValue ClampedShift = DAG.getNode(ISD::UMIN, dl, IntVT, TotalShift,
+                                         MaxShift);
+      SDValue DenormTruncMant =
+          DAG.getNode(ISD::SRL, dl, IntVT, FullSrcMant, ClampedShift);
+
+      // Rounding for denorm path.
+      SDValue DenormRoundUp;
+      {
+        // Round bit is at position TotalShift - 1 of FullSrcMant.
+        // Clamp to at least 1 so the subtraction doesn't underflow and create
+        // shift nodes with invalid shift amounts.
+        SDValue SafeShift =
+            DAG.getNode(ISD::UMAX, dl, IntVT, ClampedShift, One);
+        SDValue RoundBitPos =
+            DAG.getNode(ISD::SUB, dl, IntVT, SafeShift, One);
+        SDValue DenormRoundBit = DAG.getNode(
+            ISD::AND, dl, IntVT,
+            DAG.getNode(ISD::SRL, dl, IntVT, FullSrcMant, RoundBitPos), One);
+
+        // Sticky: all bits below round bit.
+        // sticky_mask = (1 << RoundBitPos) - 1
+        SDValue StickyMask = DAG.getNode(
+            ISD::SUB, dl, IntVT,
+            DAG.getNode(ISD::SHL, dl, IntVT, One, RoundBitPos), One);
+        SDValue DenormStickyBits =
+            DAG.getNode(ISD::AND, dl, IntVT, FullSrcMant, StickyMask);
+        SDValue HasSticky =
+            DAG.getNode(ISD::ZERO_EXTEND, dl, IntVT,
+                        DAG.getSetCC(dl, SetCCVT, DenormStickyBits, Zero,
+                                     ISD::SETNE));
+
+        SDValue DenormLSB =
+            DAG.getNode(ISD::AND, dl, IntVT, DenormTruncMant, One);
+
+        DenormRoundUp =
+            ComputeRoundUp(DenormRoundBit, HasSticky, DenormLSB);
+
+        // Only apply rounding if TotalShift >= 1 (i.e., there are bits to
+        // round).
+        SDValue ShiftGEOne =
+            DAG.getSetCC(dl, SetCCVT, ClampedShift, One, ISD::SETUGE);
+        DenormRoundUp = DAG.getSelect(dl, IntVT, ShiftGEOne, DenormRoundUp,
+                                      Zero);
+      }
+
+      SDValue DenormRoundedMant =
+          DAG.getNode(ISD::ADD, dl, IntVT, DenormTruncMant, DenormRoundUp);
+
+      // If rounding caused overflow into the normal range, then we get the
+      // smallest normal number.
+      SDValue DenormMantOF = DAG.getSetCC(
+          dl, SetCCVT, DenormRoundedMant,
+          DAG.getConstant(DstMantMask, dl, IntVT), ISD::SETGT);
+      SDValue DenormFinalMant = DAG.getSelect(
+          dl, IntVT, DenormMantOF, Zero, DenormRoundedMant);
+      SDValue DenormFinalExp = DAG.getSelect(
+          dl, IntVT, DenormMantOF, One, Zero);
+
+      // Assemble: sign | (exp << DstMant) | mant
+      SDValue DenormExpShifted = DAG.getNode(
+          ISD::SHL, dl, IntVT, DenormFinalExp,
+          DAG.getShiftAmountConstant(DstMant, IntVT, dl));
+      DenormResult = DAG.getNode(
+          ISD::OR, dl, IntVT,
+          DAG.getNode(ISD::OR, dl, IntVT, SignShifted, DenormExpShifted),
+          DenormFinalMant);
+
+      // If the value is too small for even a denorm (all mantissa bits
+      // shifted away), handle based on rounding mode.
+      // This is covered by the DenormRoundedMant = 0 case naturally.
+    }
+
+    // Exponent overflow detection.
+    SDValue ExpOF = DAG.getSetCC(
+        dl, SetCCVT, AdjExp,
+        DAG.getConstant(DstExpMaxNormal, dl, IntVT), ISD::SETGT);
+
+    // Also check whether AdjExp == DstExpMaxNormal and the mantissa exceeds
+    // the max allowed mantissa at that exponent.
+    SDValue ExpAtMax = DAG.getSetCC(
+        dl, SetCCVT, AdjExp,
+        DAG.getConstant(DstExpMaxNormal, dl, IntVT), ISD::SETEQ);
+    SDValue MantExceedsMax = DAG.getSetCC(
+        dl, SetCCVT, AdjMant,
+        DAG.getConstant(DstMaxMantAtMaxExp, dl, IntVT), ISD::SETGT);
+    SDValue ExpMantOF =
+        DAG.getNode(ISD::AND, dl, SetCCVT, ExpAtMax, MantExceedsMax);
+    SDValue IsOverflow =
+        DAG.getNode(ISD::OR, dl, SetCCVT, ExpOF, ExpMantOF);
+
+    // Build overflow result.
+    SDValue OverflowResult;
+
+    if (Saturate) {
+      // Clamp to max finite value:
+      // sign | (DstExpMaxNormal << DstMant) | DstMaxMantAtMaxExp
+      uint64_t MaxFinite =
+          ((uint64_t)DstExpMaxNormal << DstMant) | DstMaxMantAtMaxExp;
+      OverflowResult =
+          DAG.getNode(ISD::OR, dl, IntVT, SignShifted,
+                      DAG.getConstant(MaxFinite, dl, IntVT));
+    } else if (DstNFBehavior == fltNonfiniteBehavior::IEEE754) {
+      // Produce infinity.
+      uint64_t InfBits = (uint64_t)DstExpMax << DstMant;
+      OverflowResult =
+          DAG.getNode(ISD::OR, dl, IntVT, SignShifted,
+                      DAG.getConstant(InfBits, dl, IntVT));
+    } else {
+      // Emit poison if no Inf in format and not saturating.
+      OverflowResult = DAG.getPOISON(IntVT);
+    }
+
+    // Assemble normal result: sign | (AdjExp << DstMant) | AdjMant
+    SDValue NormExpShifted = DAG.getNode(
+        ISD::SHL, dl, IntVT, AdjExp,
+        DAG.getShiftAmountConstant(DstMant, IntVT, dl));
+    SDValue NormResult = DAG.getNode(
+        ISD::OR, dl, IntVT,
+        DAG.getNode(ISD::OR, dl, IntVT, SignShifted, NormExpShifted), AdjMant);
+
+    // Build special-value results.
+    SDValue NaNResult;
+    if (DstNFBehavior == fltNonfiniteBehavior::IEEE754) {
+      // Produce canonical NaN.
+      const uint64_t QNaNBit = (DstMant > 0) ? (1ULL << (DstMant - 1)) : 0;
+      NaNResult =
+          DAG.getConstant(((uint64_t)DstExpMax << DstMant) | QNaNBit, dl,
+                          IntVT);
+    } else if (DstNFBehavior == fltNonfiniteBehavior::NanOnly &&
+               DstNanEnc == fltNanEncoding::AllOnes) {
+      // E4M3FN-style: NaN is exp=all-ones, mant=all-ones.
+      NaNResult =
+          DAG.getConstant(((uint64_t)DstExpMax << DstMant) | DstMantMask, dl,
+                          IntVT);
+    } else {
+      // NaN -> poison for finite-only formats.
+      NaNResult = DAG.getPOISON(IntVT);
+    }
+
+    // Inf handling.
+    SDValue InfResult;
+    if (DstNFBehavior == fltNonfiniteBehavior::IEEE754) {
+      // Produce signed infinity.
+      uint64_t InfBits = (uint64_t)DstExpMax << DstMant;
+      InfResult = DAG.getNode(ISD::OR, dl, IntVT, SignShifted,
+                              DAG.getConstant(InfBits, dl, IntVT));
+    } else if (Saturate) {
+      // Inf saturate...
[truncated]
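
The rounding decision made by the `ComputeRoundUp` lambda in the diff above can be modelled on plain integers. The following is a minimal sketch, not code from the patch: `Rounding`, `roundIncrement`, and `roundShiftRight` are hypothetical names, and the enum is only a stand-in for `llvm::RoundingMode`. It assumes the value being rounded is a magnitude with the sign handled separately, as in the expansion.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical stand-in for llvm::RoundingMode (static modes only).
enum class Rounding {
  NearestTiesToEven,
  TowardZero,
  TowardPositive,
  TowardNegative,
  NearestTiesToAway
};

// Returns 1 if the truncated magnitude should be incremented, else 0.
// roundBit: first discarded bit; sticky: OR of all lower discarded bits;
// lsb: lowest kept bit; negative: sign of the value being rounded.
inline uint32_t roundIncrement(Rounding mode, bool negative, uint32_t roundBit,
                               uint32_t sticky, uint32_t lsb) {
  assert(roundBit <= 1 && sticky <= 1 && lsb <= 1);
  switch (mode) {
  case Rounding::NearestTiesToEven:
    return roundBit & (sticky | lsb);
  case Rounding::TowardZero:
    return 0;
  case Rounding::TowardPositive:
    return (!negative && (roundBit | sticky)) ? 1u : 0u;
  case Rounding::TowardNegative:
    return (negative && (roundBit | sticky)) ? 1u : 0u;
  case Rounding::NearestTiesToAway:
    return roundBit;
  }
  return 0;
}

// Drops the low `shift` bits of `mant` and rounds the kept part.
inline uint32_t roundShiftRight(uint32_t mant, unsigned shift, Rounding mode,
                                bool negative) {
  assert(shift >= 1 && shift < 32);
  const uint32_t kept = mant >> shift;
  const uint32_t roundBit = (mant >> (shift - 1)) & 1u;
  const uint32_t sticky = (mant & ((1u << (shift - 1)) - 1)) != 0 ? 1u : 0u;
  return kept + roundIncrement(mode, negative, roundBit, sticky, kept & 1u);
}
```

The result may be one bit wider than the kept field when rounding carries out, which is the case the patch handles with its mantissa-overflow check.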

@MrSidims
Contributor Author

E2E testing on X86 was done with the assistance of Claude Code Opus 4.6.

See: https://github.com/MrSidims/llvm-project/tree/only-for-testing-ocl-fp (it doesn't cover all cases yet; it is just a PoC)
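
For cross-checking results, a scalar reference model of one conversion can be handy. The sketch below is an illustration and not part of the patch or the linked branch: `floatToE5M2` is a hypothetical helper that converts an IEEE binary32 value to Float8E5M2 bits with round-to-nearest-even and no saturation. It simplifies two things relative to the expansion: it keeps the implicit bit in the significand so that a rounding carry lands in the exponent through a plain add, and it skips source-denormal normalization because every binary32 denormal rounds to zero in this format under this rounding mode.

```cpp
#include <cassert>
#include <cstdint>
#include <cstring>

// Hypothetical reference model: float -> Float8E5M2 bit pattern,
// round-to-nearest-even, overflow produces a signed infinity.
inline uint8_t floatToE5M2(float f) {
  static_assert(sizeof(float) == sizeof(uint32_t), "expects IEEE binary32");
  constexpr int SrcMant = 23, SrcBias = 127;
  constexpr int DstMant = 2, DstBias = 15;
  constexpr uint32_t DstInf = 31u << DstMant; // exponent all ones, mantissa 0

  uint32_t bits;
  std::memcpy(&bits, &f, sizeof bits);
  const uint32_t sign = (bits >> 31) << 7;
  const int expField = int((bits >> SrcMant) & 0xFF);
  const uint32_t mant = bits & ((1u << SrcMant) - 1);

  if (expField == 0xFF) // NaN -> canonical qNaN, Inf -> signed Inf
    return uint8_t(mant ? (DstInf | (1u << (DstMant - 1))) : (sign | DstInf));
  if (expField == 0 && mant == 0) // signed zero
    return uint8_t(sign);

  // Significand with the implicit bit (absent for source denormals).
  const bool srcDenorm = expField == 0;
  const uint32_t sig = srcDenorm ? mant : (mant | (1u << SrcMant));
  const int newExp = (srcDenorm ? 1 : expField) - SrcBias + DstBias;

  // Discard 21 bits for normal results, more when the result is denormal.
  int shift = SrcMant - DstMant;
  if (newExp <= 0)
    shift += 1 - newExp;
  if (shift > 31)
    shift = 31; // everything is shifted out; avoids an oversized shift
  assert(shift >= 1 && shift <= 31);

  uint32_t kept = sig >> shift;
  const uint32_t roundBit = (sig >> (shift - 1)) & 1u;
  const uint32_t sticky = (sig & ((1u << (shift - 1)) - 1)) != 0 ? 1u : 0u;
  kept += roundBit & (sticky | (kept & 1u));

  // The implicit bit in `kept` supplies the final +1 of the exponent field,
  // and a rounding carry bumps the exponent in the same add.
  const uint32_t expBase = newExp >= 1 ? uint32_t(newExp - 1) : 0u;
  const uint32_t magnitude = (expBase << DstMant) + kept;
  if (magnitude >= DstInf)
    return uint8_t(sign | DstInf);
  return uint8_t(sign | magnitude);
}
```

Other destination formats, rounding modes, and the saturating variant would need their own handling; this model only covers the one combination named above.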

@github-actions

github-actions Bot commented Apr 22, 2026

✅ With the latest revision this PR passed the C/C++ code formatter.

Comment thread llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp Outdated
@arsenm arsenm added the floating-point Floating-point math label Apr 22, 2026
Comment on lines +3851 to +3855
unsigned DstExpMaxNormal;
if (DstNFBehavior == fltNonfiniteBehavior::IEEE754)
DstExpMaxNormal = DstExpMax - 1;
else
DstExpMaxNormal = DstExpMax;
Contributor

Assign from ?:
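
One reading of this suggestion is to initialize the value with a conditional expression so that it can be declared `const`. A standalone sketch of that shape, assuming a hypothetical `NonfiniteBehavior` enum as a stand-in for `llvm::fltNonfiniteBehavior` and a hypothetical helper name:

```cpp
#include <cassert>

// Hypothetical stand-in for llvm::fltNonfiniteBehavior.
enum class NonfiniteBehavior { IEEE754, NanOnly, FiniteOnly };

// Largest exponent field value that can still encode a finite number.
inline unsigned maxNormalExpField(NonfiniteBehavior nf, unsigned expBits) {
  assert(expBits >= 1 && expBits < 32);
  const unsigned expMax = (1u << expBits) - 1;
  // IEEE754-style formats reserve the all-ones exponent for Inf/NaN.
  const unsigned expMaxNormal =
      nf == NonfiniteBehavior::IEEE754 ? expMax - 1 : expMax;
  return expMaxNormal;
}
```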

const uint64_t SrcExpMask = (1ULL << SrcExpBits) - 1;

// Work in the source integer type.
EVT IntVT = EVT::getIntegerVT(*DAG.getContext(), SrcBits);
Contributor

Not handling vectors?

cast<MetadataAsValue>(I.getArgOperand(2))->getMetadata();
StringRef RoundStr = cast<MDString>(RoundMD)->getString();
std::optional<RoundingMode> RoundMode = convertStrToRoundingMode(RoundStr);
if (!RoundMode || *RoundMode == RoundingMode::Dynamic) {
Contributor

I'd assume the IR verifier would reject dynamic rounding mode
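
For context, the guard under discussion accepts only a rounding mode that is known statically. A standalone sketch of that check, assuming a hypothetical `parseRounding` as a stand-in for `convertStrToRoundingMode` (the metadata strings are the ones used by LLVM's constrained floating-point intrinsics); whether the check is needed here or is already guaranteed by the verifier is the open question in this thread:

```cpp
#include <cassert>
#include <optional>
#include <string_view>

// Hypothetical stand-in for llvm::RoundingMode.
enum class Rounding {
  TowardZero,
  NearestTiesToEven,
  TowardPositive,
  TowardNegative,
  NearestTiesToAway,
  Dynamic
};

// Hypothetical parser mirroring the rounding-mode metadata strings.
inline std::optional<Rounding> parseRounding(std::string_view s) {
  if (s == "round.dynamic") return Rounding::Dynamic;
  if (s == "round.tonearest") return Rounding::NearestTiesToEven;
  if (s == "round.tonearestaway") return Rounding::NearestTiesToAway;
  if (s == "round.downward") return Rounding::TowardNegative;
  if (s == "round.upward") return Rounding::TowardPositive;
  if (s == "round.towardzero") return Rounding::TowardZero;
  return std::nullopt;
}

// True when the mode is recognized and not dynamic, so a static expansion
// can pick its rounding logic at compile time.
inline bool isStaticRounding(std::string_view s) {
  const std::optional<Rounding> mode = parseRounding(s);
  return mode && *mode != Rounding::Dynamic;
}
```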

%r = call i8 @llvm.convert.to.arbitrary.fp.i8.f32(
float %v, metadata !"Float8E8M0FNU", metadata !"round.tonearest", i1 false)
ret i8 %r
}
Contributor

Test failing vector case?
