diff --git a/dpnp/backend/kernels/dpnp_krnl_fft.cpp b/dpnp/backend/kernels/dpnp_krnl_fft.cpp index 8a79a35e844b..89acd5f61136 100644 --- a/dpnp/backend/kernels/dpnp_krnl_fft.cpp +++ b/dpnp/backend/kernels/dpnp_krnl_fft.cpp @@ -180,7 +180,9 @@ void dpnp_fft_fft_mathlib_cmplx_to_cmplx_c(DPCTLSyclQueueRef q_ref, const size_t input_size, const size_t result_size, _Descriptor_type& desc, - const size_t norm) + size_t inverse, + double backward_scale, + double forward_scale) { if (!shape_size) { @@ -199,9 +201,6 @@ void dpnp_fft_fft_mathlib_cmplx_to_cmplx_c(DPCTLSyclQueueRef q_ref, const size_t shift = input_shape[shape_size - 1]; - double forward_scale = 1.0; - double backward_scale = 1.0 / shift; - desc.set_value(mkl_dft::config_param::BACKWARD_SCALE, backward_scale); desc.set_value(mkl_dft::config_param::FORWARD_SCALE, forward_scale); // enum value from math library C interface @@ -213,7 +212,11 @@ void dpnp_fft_fft_mathlib_cmplx_to_cmplx_c(DPCTLSyclQueueRef q_ref, fft_events.reserve(n_iter); for (size_t i = 0; i < n_iter; ++i) { - fft_events.push_back(mkl_dft::compute_forward(desc, array_1 + i * shift, result + i * shift)); + if (inverse) { + fft_events.push_back(mkl_dft::compute_backward(desc, array_1 + i * shift, result + i * shift)); + } else { + fft_events.push_back(mkl_dft::compute_forward(desc, array_1 + i * shift, result + i * shift)); + } } sycl::event::wait(fft_events); @@ -234,7 +237,9 @@ void dpnp_fft_fft_mathlib_real_to_cmplx_c(DPCTLSyclQueueRef q_ref, const size_t input_size, const size_t result_size, _Descriptor_type& desc, - const size_t norm, + size_t inverse, + double backward_scale, + double forward_scale, const size_t real) { if (!shape_size) @@ -255,11 +260,9 @@ void dpnp_fft_fft_mathlib_real_to_cmplx_c(DPCTLSyclQueueRef q_ref, const size_t input_shift = input_shape[shape_size - 1]; const size_t result_shift = result_shape[shape_size - 1];; - double forward_scale = 1.0; - double backward_scale = 1.0 / input_shift; - desc.set_value(mkl_dft::config_param::BACKWARD_SCALE, backward_scale); desc.set_value(mkl_dft::config_param::FORWARD_SCALE, forward_scale); + desc.set_value(mkl_dft::config_param::PLACEMENT, DFTI_NOT_INPLACE); desc.commit(queue); @@ -267,7 +270,11 @@ void dpnp_fft_fft_mathlib_real_to_cmplx_c(DPCTLSyclQueueRef q_ref, fft_events.reserve(n_iter); for (size_t i = 0; i < n_iter; ++i) { - fft_events.push_back(mkl_dft::compute_forward(desc, array_1 + i * input_shift, result + i * result_shift * 2)); + if (inverse) { + fft_events.push_back(mkl_dft::compute_backward(desc, array_1 + i * input_shift, result + i * result_shift * 2)); + } else { + fft_events.push_back(mkl_dft::compute_forward(desc, array_1 + i * input_shift, result + i * result_shift * 2)); + } } sycl::event::wait(fft_events); @@ -330,6 +337,21 @@ DPCTLSyclEventRef dpnp_fft_fft_c(DPCTLSyclQueueRef q_ref, size_t dim = input_shape[shape_size - 1]; + double backward_scale = 1; + double forward_scale = 1; + + if (norm == 0) { // norm = "backward" + backward_scale = 1. / dim; + } else if (norm == 1) { // norm = "forward" + forward_scale = 1. / dim; + } else { // norm = "ortho" + if (inverse) { + backward_scale = 1. / sqrt(dim); + } else { + forward_scale = 1. / sqrt(dim); + } + } + if constexpr (std::is_same<_DataType_output, std::complex>::value || std::is_same<_DataType_output, std::complex>::value) { @@ -338,7 +360,7 @@ DPCTLSyclEventRef dpnp_fft_fft_c(DPCTLSyclQueueRef q_ref, { desc_dp_cmplx_t desc(dim); dpnp_fft_fft_mathlib_cmplx_to_cmplx_c<_DataType_input, _DataType_output, desc_dp_cmplx_t>( - q_ref, array1_in, result_out, input_shape, result_shape, shape_size, input_size, result_size, desc, norm); + q_ref, array1_in, result_out, input_shape, result_shape, shape_size, input_size, result_size, desc, inverse, backward_scale, forward_scale); } /* complex-to-complex, single precision */ else if constexpr (std::is_same<_DataType_input, std::complex>::value && @@ -346,7 +368,7 @@ DPCTLSyclEventRef dpnp_fft_fft_c(DPCTLSyclQueueRef q_ref, { desc_sp_cmplx_t desc(dim); dpnp_fft_fft_mathlib_cmplx_to_cmplx_c<_DataType_input, _DataType_output, desc_sp_cmplx_t>( - q_ref, array1_in, result_out, input_shape, result_shape, shape_size, input_size, result_size, desc, norm); + q_ref, array1_in, result_out, input_shape, result_shape, shape_size, input_size, result_size, desc, inverse, backward_scale, forward_scale); } /* real-to-complex, double precision */ else if constexpr (std::is_same<_DataType_input, double>::value && @@ -355,7 +377,7 @@ DPCTLSyclEventRef dpnp_fft_fft_c(DPCTLSyclQueueRef q_ref, desc_dp_real_t desc(dim); dpnp_fft_fft_mathlib_real_to_cmplx_c<_DataType_input, double, desc_dp_real_t>( - q_ref, array1_in, result_out, input_shape, result_shape, shape_size, input_size, result_size, desc, norm, 0); + q_ref, array1_in, result_out, input_shape, result_shape, shape_size, input_size, result_size, desc, inverse, backward_scale, forward_scale, 0); } /* real-to-complex, single precision */ else if constexpr (std::is_same<_DataType_input, float>::value && @@ -363,7 +385,7 @@ DPCTLSyclEventRef dpnp_fft_fft_c(DPCTLSyclQueueRef q_ref, { desc_sp_real_t desc(dim); // try: 2 * result_size dpnp_fft_fft_mathlib_real_to_cmplx_c<_DataType_input, float, desc_sp_real_t>( - q_ref, array1_in, result_out, input_shape, result_shape, shape_size, input_size, result_size, desc, norm, 0); + q_ref, array1_in, result_out, input_shape, result_shape, shape_size, input_size, result_size, desc, inverse, backward_scale, forward_scale, 0); } else if constexpr (std::is_same<_DataType_input, int32_t>::value || std::is_same<_DataType_input, int64_t>::value) @@ -380,7 +402,7 @@ DPCTLSyclEventRef dpnp_fft_fft_c(DPCTLSyclQueueRef q_ref, desc_dp_real_t desc(dim); dpnp_fft_fft_mathlib_real_to_cmplx_c( - q_ref, array1_copy, result_out, input_shape, result_shape, shape_size, input_size, result_size, desc, norm, 0); + q_ref, array1_copy, result_out, input_shape, result_shape, shape_size, input_size, result_size, desc, inverse, backward_scale, forward_scale, 0); dpnp_memory_free_c(q_ref, array1_copy); dpnp_memory_free_c(q_ref, copy_strides); @@ -484,6 +506,20 @@ DPCTLSyclEventRef dpnp_fft_rfft_c(DPCTLSyclQueueRef q_ref, size_t dim = input_shape[shape_size - 1]; + double backward_scale = 1; + double forward_scale = 1; + if (norm == 0) { // norm = "backward" + backward_scale = 1. / dim; + } else if (norm == 1) { // norm = "forward" + forward_scale = 1. / dim; + } else { // norm = "ortho" + if (inverse) { + backward_scale = 1. / sqrt(dim); + } else { + forward_scale = 1. / sqrt(dim); + } + } + if constexpr (std::is_same<_DataType_output, std::complex>::value || std::is_same<_DataType_output, std::complex>::value) { @@ -493,7 +529,7 @@ DPCTLSyclEventRef dpnp_fft_rfft_c(DPCTLSyclQueueRef q_ref, desc_dp_real_t desc(dim); dpnp_fft_fft_mathlib_real_to_cmplx_c<_DataType_input, double, desc_dp_real_t>( - q_ref, array1_in, result_out, input_shape, result_shape, shape_size, input_size, result_size, desc, norm, 1l); + q_ref, array1_in, result_out, input_shape, result_shape, shape_size, input_size, result_size, desc, inverse, backward_scale, forward_scale, 1); } /* real-to-complex, single precision */ else if constexpr (std::is_same<_DataType_input, float>::value && @@ -501,7 +537,7 @@ DPCTLSyclEventRef dpnp_fft_rfft_c(DPCTLSyclQueueRef q_ref, { desc_sp_real_t desc(dim); // try: 2 * result_size dpnp_fft_fft_mathlib_real_to_cmplx_c<_DataType_input, float, desc_sp_real_t>( - q_ref, array1_in, result_out, input_shape, result_shape, shape_size, input_size, result_size, desc, norm, 1); + q_ref, array1_in, result_out, input_shape, result_shape, shape_size, input_size, result_size, desc, inverse, backward_scale, forward_scale, 1); } else if constexpr (std::is_same<_DataType_input, int32_t>::value || std::is_same<_DataType_input, int64_t>::value) @@ -518,7 +554,7 @@ DPCTLSyclEventRef dpnp_fft_rfft_c(DPCTLSyclQueueRef q_ref, desc_dp_real_t desc(dim); dpnp_fft_fft_mathlib_real_to_cmplx_c( - q_ref, array1_copy, result_out, input_shape, result_shape, shape_size, input_size, result_size, desc, norm, 1); + q_ref, array1_copy, result_out, input_shape, result_shape, shape_size, input_size, result_size, desc, inverse, backward_scale, forward_scale, 1); dpnp_memory_free_c(q_ref, array1_copy); dpnp_memory_free_c(q_ref, copy_strides); diff --git a/dpnp/fft/dpnp_iface_fft.py b/dpnp/fft/dpnp_iface_fft.py index d245de85054a..6f49be633a06 100644 --- a/dpnp/fft/dpnp_iface_fft.py +++ b/dpnp/fft/dpnp_iface_fft.py @@ -118,8 +118,6 @@ def fft(x1, n=None, axis=-1, norm=None): pass # let fallback to handle exception elif input_boundarie < 1: pass # let fallback to handle exception - elif norm is not None: - pass elif n is not None: pass elif axis != -1: @@ -308,7 +306,7 @@ def ifft(x1, n=None, axis=-1, norm=None): """ x1_desc = dpnp.get_dpnp_descriptor(x1) - if x1_desc and 0: + if x1_desc: norm_ = get_validated_norm(norm) if axis is None: @@ -325,13 +323,12 @@ def ifft(x1, n=None, axis=-1, norm=None): pass # let fallback to handle exception elif input_boundarie < 1: pass # let fallback to handle exception - elif norm is not None: - pass elif n is not None: pass + elif x1_desc.dtype not in (numpy.complex128, numpy.complex64): + pass else: output_boundarie = input_boundarie - return dpnp_fft(x1_desc, input_boundarie, output_boundarie, axis_param, True, norm_.value).get_pyobj() return call_origin(numpy.fft.ifft, x1, n, axis, norm) diff --git a/tests/test_fft.py b/tests/test_fft.py index 703495f3b585..66019defd1ac 100644 --- a/tests/test_fft.py +++ b/tests/test_fft.py @@ -6,7 +6,8 @@ @pytest.mark.parametrize("type", ['complex128', 'complex64', 'float32', 'float64', 'int32', 'int64']) -def test_fft(type): +@pytest.mark.parametrize("norm", [None, 'forward', 'ortho']) +def test_fft(type, norm): # 1 dim array data = numpy.arange(100, dtype=numpy.dtype(type)) # TODO: @@ -14,8 +15,8 @@ def test_fft(type): # dpnp_data = dpnp.arange(100, dtype=dpnp.dtype(type)) dpnp_data = dpnp.array(data) - np_res = numpy.fft.fft(data) - dpnp_res = dpnp.asnumpy(dpnp.fft.fft(dpnp_data)) + np_res = numpy.fft.fft(data, norm=norm) + dpnp_res = dpnp.asnumpy(dpnp.fft.fft(dpnp_data, norm=norm)) numpy.testing.assert_allclose(dpnp_res, np_res, rtol=1e-4, atol=1e-7) assert dpnp_res.dtype == np_res.dtype @@ -23,12 +24,27 @@ def test_fft(type): @pytest.mark.parametrize("type", ['complex128', 'complex64', 'float32', 'float64', 'int32', 'int64']) @pytest.mark.parametrize("shape", [(8, 8), (4, 16), (4, 4, 4), (2, 4, 4, 2)]) -def test_fft_ndim(type, shape): +@pytest.mark.parametrize("norm", [None, 'forward', 'ortho']) +def test_fft_ndim(type, shape, norm): np_data = numpy.arange(64, dtype=numpy.dtype(type)).reshape(shape) dpnp_data = dpnp.arange(64, dtype=numpy.dtype(type)).reshape(shape) - np_res = numpy.fft.fft(np_data) - dpnp_res = dpnp.fft.fft(dpnp_data) + np_res = numpy.fft.fft(np_data, norm=norm) + dpnp_res = dpnp.fft.fft(dpnp_data, norm=norm) + + numpy.testing.assert_allclose(dpnp_res, np_res, rtol=1e-4, atol=1e-7) + assert dpnp_res.dtype == np_res.dtype + + +@pytest.mark.parametrize("type", ['complex128', 'complex64', 'float32', 'float64', 'int32', 'int64']) +@pytest.mark.parametrize("shape", [(64,), (8, 8), (4, 16), (4, 4, 4), (2, 4, 4, 2)]) +@pytest.mark.parametrize("norm", [None, 'forward', 'ortho']) +def test_fft_ifft(type, shape, norm): + np_data = numpy.arange(64, dtype=numpy.dtype(type)).reshape(shape) + dpnp_data = dpnp.arange(64, dtype=numpy.dtype(type)).reshape(shape) + + np_res = numpy.fft.ifft(np_data, norm=norm) + dpnp_res = dpnp.fft.ifft(dpnp_data, norm=norm) numpy.testing.assert_allclose(dpnp_res, np_res, rtol=1e-4, atol=1e-7) assert dpnp_res.dtype == np_res.dtype