diff --git a/numba_dpex/core/passes/passes.py b/numba_dpex/core/passes/passes.py
index 1d4dfaecf0..dad6c79b41 100644
--- a/numba_dpex/core/passes/passes.py
+++ b/numba_dpex/core/passes/passes.py
@@ -24,6 +24,7 @@
     new_error_context,
 )
 from numba.core.ir_utils import remove_dels
+from numba.core.typed_passes import NativeLowering
 from numba.parfors.parfor import Parfor
 from numba.parfors.parfor import ParforPass as _parfor_ParforPass
 from numba.parfors.parfor import PreParforPass as _parfor_PreParforPass
@@ -387,3 +388,38 @@ def run_pass(self, state):
             else:
                 raise RuntimeError("Diagnostics failed.")
         return True
+
+
+@register_pass(mutates_CFG=True, analysis_only=False)
+class QualNameDisambiguationLowering(NativeLowering):
+    """Qualified name disambiguation lowering pass.
+
+    If multiple @func decorated functions exist inside another
+    @func decorated block, the numba compiler machinery creates
+    the same qualified name for different compiled functions.
+    Therefore, we utilize `unique_name` to resolve the ambiguity.
+
+    Args:
+        NativeLowering (CompilerPass): Superclass from which this
+        class has been inherited.
+
+    Returns:
+        bool: True if `run_pass()` of the superclass is successful.
+    """
+
+    _name = "qual-name-disambiguation-lowering"
+
+    def run_pass(self, state):
+        """Lower with `func_qualname` temporarily set to `unique_name`.
+
+        The original qualified name is restored afterwards, even if
+        the superclass lowering raises.
+        """
+        qual_name = state.func_id.func_qualname
+        state.func_id.func_qualname = state.func_id.unique_name
+        try:
+            return super().run_pass(state)
+        finally:
+            # Always undo the temporary rename so a failed lowering does
+            # not leave `state.func_id` carrying the unique name.
+            state.func_id.func_qualname = qual_name
diff --git a/numba_dpex/core/pipelines/kernel_compiler.py b/numba_dpex/core/pipelines/kernel_compiler.py
index 021c0f6465..38c2a1edbd 100644
--- a/numba_dpex/core/pipelines/kernel_compiler.py
+++ b/numba_dpex/core/pipelines/kernel_compiler.py
@@ -7,7 +7,6 @@
 from numba.core.typed_passes import (
     AnnotateTypes,
     IRLegalization,
-    NativeLowering,
     NopythonRewrites,
     NoPythonSupportedFeatureValidation,
     NopythonTypeInference,
@@ -34,6 +33,7 @@
 from numba_dpex.core.passes.passes import (
     ConstantSizeStaticLocalMemoryPass,
     NoPythonBackend,
+    QualNameDisambiguationLowering,
 )
 
 
@@ -139,7 +139,13 @@ def define_nopython_lowering_pipeline(state, name="dpex_kernel_lowering"):
         pm.add_pass(IRLegalization, "ensure IR is legal prior to lowering")
 
         # lower
-        pm.add_pass(NativeLowering, "native lowering")
+        # NativeLowering has an issue with freevar ambiguity (identical
+        # qualified names), so QualNameDisambiguationLowering is used instead.
+        # numba-dpex github issue: https://github.com/IntelPython/numba-dpex/issues/898
+        pm.add_pass(
+            QualNameDisambiguationLowering,
+            "numba_dpex qualified name disambiguation",
+        )
         pm.add_pass(NoPythonBackend, "nopython mode backend")
         pm.finalize()
 
diff --git a/numba_dpex/tests/kernel_tests/test_func_qualname_disambiguation.py b/numba_dpex/tests/kernel_tests/test_func_qualname_disambiguation.py
new file mode 100644
index 0000000000..f77365eaa2
--- /dev/null
+++ b/numba_dpex/tests/kernel_tests/test_func_qualname_disambiguation.py
@@ -0,0 +1,119 @@
+import dpctl
+import dpctl.tensor as dpt
+import numpy as np
+import pytest
+
+import numba_dpex as ndpx
+from numba_dpex.tests._helper import filter_strings
+
+
+def make_write_values_kernel(n_rows):
+    """Uppermost kernel to set 1s in a certain way.
+    The uppermost kernel function invokes two levels
+    of inner functions to set 1s in an empty matrix
+    in a certain way.
+
+    Args:
+        n_rows (int): Number of rows to iterate.
+
+    Returns:
+        numba_dpex.core.kernel_interface.dispatcher.JitKernel:
+        A JitKernel object that encapsulates a @kernel
+        decorated numba_dpex compiled kernel object.
+    """
+    write_values = make_write_values_kernel_func()
+
+    @ndpx.kernel
+    def write_values_kernel(array_in):
+        for row_idx in range(n_rows):
+            is_even = (row_idx % 2) == 0
+            write_values(array_in, row_idx, is_even)
+
+    return write_values_kernel[ndpx.NdRange(ndpx.Range(1), ndpx.Range(1))]
+
+
+def make_write_values_kernel_func():
+    """An upper function to set 1 or 3 ones. A function to set
+    one or three 1s. If the row index is even it will set three 1s,
+    otherwise one 1. It uses the inner function to do this.
+
+    Returns:
+        numba_dpex.core.kernel_interface.func.DpexFunctionTemplate:
+        A DpexFunctionTemplate that encapsulates a @func decorated
+        numba_dpex compiled function object.
+    """
+    write_when_odd = make_write_values_kernel_func_inner(1)
+    write_when_even = make_write_values_kernel_func_inner(3)
+
+    @ndpx.func
+    def write_values(array_in, row_idx, is_even):
+        if is_even:
+            write_when_even(array_in, row_idx)
+        else:
+            write_when_odd(array_in, row_idx)
+
+    return write_values
+
+
+def make_write_values_kernel_func_inner(n_cols):
+    """Inner function to set 1s. An inner function to set 1s in
+    n_cols number of columns.
+
+    Args:
+        n_cols (int): Number of columns to be set to 1.
+
+    Returns:
+        numba_dpex.core.kernel_interface.func.DpexFunctionTemplate:
+        A DpexFunctionTemplate that encapsulates a @func decorated
+        numba_dpex compiled function object.
+    """
+
+    @ndpx.func
+    def write_values_inner(array_in, row_idx):
+        for idx in range(n_cols):
+            array_in[row_idx, idx] = 1
+
+    return write_values_inner
+
+
+@pytest.mark.parametrize("offload_device", filter_strings)
+def test_qualname_basic(offload_device):
+    """A basic test function to test
+    qualified name disambiguation.
+    """
+    ans = np.zeros((10, 10), dtype=np.int64)
+    for i in range(ans.shape[0]):
+        if i % 2 == 0:
+            ans[i, 0:3] = 1
+        else:
+            ans[i, 0] = 1
+
+    a = np.zeros((10, 10), dtype=dpt.int64)
+
+    device = dpctl.SyclDevice(offload_device)
+    queue = dpctl.SyclQueue(device)
+
+    da = dpt.usm_ndarray(
+        a.shape,
+        dtype=a.dtype,
+        buffer="device",
+        buffer_ctor_kwargs={"queue": queue},
+    )
+    da.usm_data.copy_from_host(a.reshape((-1)).view("|u1"))
+
+    kernel = make_write_values_kernel(10)
+    kernel(da)
+
+    result = np.zeros_like(a)
+    da.usm_data.copy_to_host(result.reshape((-1)).view("|u1"))
+
+    print(ans)
+    print(result)
+
+    assert np.array_equal(result, ans)
+
+
+if __name__ == "__main__":
+    test_qualname_basic("level_zero:gpu:0")
+    test_qualname_basic("opencl:gpu:0")
+    test_qualname_basic("opencl:cpu:0")
diff --git a/numba_dpex/tests/test_debuginfo.py b/numba_dpex/tests/test_debuginfo.py
index 6b36e4f2f4..4df788e6ac 100644
--- a/numba_dpex/tests/test_debuginfo.py
+++ b/numba_dpex/tests/test_debuginfo.py
@@ -127,8 +127,8 @@ def data_parallel_sum(a, b, c):
         c[i] = func_sum(a[i], b[i])
 
     ir_tags = [
-        r'\!DISubprogram\(name: ".*func_sum"',
-        r'\!DISubprogram\(name: ".*data_parallel_sum"',
+        r'\!DISubprogram\(name: ".*func_sum\$?\d*"',
+        r'\!DISubprogram\(name: ".*data_parallel_sum\$?\d*"',
     ]
 
     sig = (f32arrty, f32arrty, f32arrty)
@@ -154,8 +154,8 @@ def data_parallel_sum(a, b, c):
         c[i] = func_sum(a[i], b[i])
 
     ir_tags = [
-        r'\!DISubprogram\(name: ".*func_sum"',
-        r'\!DISubprogram\(name: ".*data_parallel_sum"',
+        r'\!DISubprogram\(name: ".*func_sum\$?\d*"',
+        r'\!DISubprogram\(name: ".*data_parallel_sum\$?\d*"',
     ]
 
     sig = (f32arrty, f32arrty, f32arrty)