diff --git a/src/implementations/BigFloat.jl b/src/implementations/BigFloat.jl index 1c3965ba..9353784a 100644 --- a/src/implementations/BigFloat.jl +++ b/src/implementations/BigFloat.jl @@ -121,6 +121,66 @@ function operate_to!(output::BigFloat, ::typeof(*), a::BigFloat, b::BigFloat) return output end +# Base.fma + +function promote_operation( + ::typeof(Base.fma), + ::Type{F}, + ::Type{F}, + ::Type{F}, +) where {F<:BigFloat} + return F +end + +function operate_to!( + output::F, + ::typeof(Base.fma), + x::F, + y::F, + z::F, +) where {F<:BigFloat} + ccall( + (:mpfr_fma, :libmpfr), + Int32, + (Ref{F}, Ref{F}, Ref{F}, Ref{F}, _MPFRRoundingMode), + output, + x, + y, + z, + Base.MPFR.ROUNDING_MODE[], + ) + return output +end + +function operate!(::typeof(Base.fma), x::F, y::F, z::F) where {F<:BigFloat} + return operate_to!(x, Base.fma, x, y, z) +end + +# Base.muladd + +function promote_operation( + ::typeof(Base.muladd), + ::Type{F}, + ::Type{F}, + ::Type{F}, +) where {F<:BigFloat} + return F +end + +function operate_to!( + output::F, + ::typeof(Base.muladd), + x::F, + y::F, + z::F, +) where {F<:BigFloat} + return operate_to!(output, Base.fma, x, y, z) +end + +function operate!(::typeof(Base.muladd), x::F, y::F, z::F) where {F<:BigFloat} + return operate!(Base.fma, x, y, z) +end + function operate_to!( output::BigFloat, op::Union{typeof(+),typeof(-),typeof(*)}, diff --git a/test/bigfloat_fma.jl b/test/bigfloat_fma.jl new file mode 100644 index 00000000..ea484ce0 --- /dev/null +++ b/test/bigfloat_fma.jl @@ -0,0 +1,92 @@ +# Copyright (c) 2023 MutableArithmetics.jl contributors +# +# This Source Code Form is subject to the terms of the Mozilla Public License, +# v.2.0. If a copy of the MPL was not distributed with this file, You can obtain +# one at http://mozilla.org/MPL/2.0/. + +function test_fma_output_values(x::F, y::F, z::F) where {F<:BigFloat} + two_roundings_reference = x * y + z + one_rounding_reference = fma(x, y, z) + @test one_rounding_reference != two_roundings_reference + + @testset "fma $op output values" for op in (MA.operate!, MA.operate!!) + (a, b, c) = map(t -> t + zero(F), (x, y, z)) # copy + @inferred op(fma, a, b, c) + @test one_rounding_reference == a + @test y == b + @test z == c + end + + @testset "fma $op output values" for op in (MA.operate_to!, MA.operate_to!!) + (a, b, c) = map(t -> t + zero(F), (x, y, z)) # copy + out = F() + @inferred op(out, fma, a, b, c) + @test one_rounding_reference == out + @test x == a + @test y == b + @test z == c + end + + return nothing +end + +function test_fma_output_values(x::F, y::F, z::F) where {F<:Float64} + return test_fma_output_values(map(BigFloat, (x, y, z))...) +end + +function test_fma_output_values_func(x::F, y::F, z::F) where {F<:Float64} + return let x = x, y = y, z = z + () -> test_fma_output_values(x, y, z) + end +end + +@testset "fma output values: $exp_x $exp_y $sign_x $sign_y" for exp_x in (-3):3, + exp_y in (-3):3, + sign_x in (-1, 1), + sign_y in (-1, 1) + + # Assuming a two-bit mantissa + significand_length = 2 + + sign_z = -sign_x * sign_y + exp_z = exp_x + exp_y + + base = 2.0 + bit_0 = 2.0^0 + bit_1 = 2.0^-1 + + x = sign_x * base^exp_x * (bit_0 + bit_1) + y = sign_y * base^exp_y * (bit_0 + bit_1) + z = sign_z * base^exp_z * (bit_0 + bit_1) + + setprecision( + test_fma_output_values_func(x, y, z), + BigFloat, + significand_length, + ) +end + +@testset "muladd operate_to!! type inferred" begin + m1 = BigFloat(-1.0) + out = BigFloat() + @test iszero(@inferred MA.operate_to!!(out, Base.muladd, m1, m1, m1)) +end + +@testset "muladd operate!! type inferred" begin + x = BigFloat(-1.0) + y = BigFloat(-1.0) + z = BigFloat(-1.0) + @test iszero(@inferred MA.operate!!(Base.muladd, x, y, z)) +end + +@testset "fma $op doesn't allocate" for op in (MA.operate_to!, MA.operate_to!!) + alloc_test(let op = op, o = big"1.3", x = big"1.3" + () -> op(o, Base.fma, x, x, x) + end, 0) +end + +@testset "fma $op doesn't allocate" for op in (MA.operate!, MA.operate!!) + alloc_test(let op = op, x = big"1.3", y = big"1.3", z = big"1.3" + () -> op(Base.fma, x, y, z) + end, 0) +end diff --git a/test/runtests.jl b/test/runtests.jl index 3d681e4f..90af628b 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -22,6 +22,9 @@ end @testset "BigFloat negation and absolute value" begin include("bigfloat_neg_abs.jl") end +@testset "BigFloat fma and muladd" begin + include("bigfloat_fma.jl") +end @testset "BigFloat dot" begin include("bigfloat_dot.jl") end