Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def _intrinsic_spirv_global_index_const(
sig = types.int64(types.int32)

def _intrinsic_spirv_global_index_const_gen(
context: SPIRVTargetContext,
context: SPIRVTargetContext, # pylint: disable=unused-argument
builder: llvmir.IRBuilder,
sig, # pylint: disable=unused-argument
args,
Expand All @@ -79,7 +79,16 @@ def _intrinsic_spirv_global_index_const_gen(
dim,
)

return context.cast(builder, res, types.uintp, types.intp)
# Generating same check as sycl does. Did they add it to avoid pointer
# bitcast on special constant?
max_int32 = llvmir.Constant(res.type, 2147483648)
cmp = builder.icmp_unsigned("<", res, max_int32)

inst = builder.assume(cmp)
# TODO: tail does not always work
inst.tail = "tail"

return res

return sig, _intrinsic_spirv_global_index_const_gen

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,11 @@

import llvmlite.ir as llvmir
from llvmlite.ir.builder import IRBuilder
from numba.core import cgutils, types
from numba.core.typing.npydecl import parse_dtype as _ty_parse_dtype
from numba.core.typing.npydecl import parse_shape as _ty_parse_shape
from numba.core.typing.templates import Signature
from numba.extending import intrinsic, overload
from numba.extending import type_callable

from numba_dpex.core.types import USMNdArray
from numba_dpex.experimental.target import DpexExpKernelTypingContext
Expand All @@ -23,55 +24,12 @@
)
from numba_dpex.utils import address_space as AddressSpace

from ..target import DPEX_KERNEL_EXP_TARGET_NAME
from ._registry import lower


@intrinsic(target=DPEX_KERNEL_EXP_TARGET_NAME)
def _intrinsic_private_array_ctor(
ty_context, ty_shape, ty_dtype # pylint: disable=unused-argument
):
require_literal(ty_shape)

ty_array = USMNdArray(
dtype=_ty_parse_dtype(ty_dtype),
ndim=_ty_parse_shape(ty_shape),
layout="C",
addrspace=AddressSpace.PRIVATE,
)

sig = ty_array(ty_shape, ty_dtype)

def codegen(
context: DpexExpKernelTypingContext,
builder: IRBuilder,
sig: Signature,
args: list[llvmir.Value],
):
shape = args[0]
ty_shape = sig.args[0]
ty_array = sig.return_type

ary = make_spirv_generic_array_on_stack(
context, builder, ty_array, ty_shape, shape
)
return ary._getvalue() # pylint: disable=protected-access

return (
sig,
codegen,
)


@overload(
PrivateArray,
prefer_literal=True,
target=DPEX_KERNEL_EXP_TARGET_NAME,
)
def ol_private_array_ctor(
shape,
dtype,
):
"""Overload of the constructor for the class
@type_callable(PrivateArray)
def type_interval(context): # pylint: disable=unused-argument
"""Sets type of the constructor for the class
class:`numba_dpex.kernel_api.PrivateArray`.

Raises:
Expand All @@ -81,11 +39,48 @@ def ol_private_array_ctor(
type.
"""

def ol_private_array_ctor_impl(
shape,
dtype,
):
# pylint: disable=no-value-for-parameter
return _intrinsic_private_array_ctor(shape, dtype)
def typer(shape, dtype, fill_zeros=types.BooleanLiteral(False)):
require_literal(shape)
require_literal(fill_zeros)

return USMNdArray(
dtype=_ty_parse_dtype(dtype),
ndim=_ty_parse_shape(shape),
layout="C",
addrspace=AddressSpace.PRIVATE,
)

return typer


@lower(PrivateArray, types.IntegerLiteral, types.Any, types.BooleanLiteral)
@lower(PrivateArray, types.Tuple, types.Any, types.BooleanLiteral)
@lower(PrivateArray, types.UniTuple, types.Any, types.BooleanLiteral)
@lower(PrivateArray, types.IntegerLiteral, types.Any)
@lower(PrivateArray, types.Tuple, types.Any)
@lower(PrivateArray, types.UniTuple, types.Any)
def dpex_private_array_lower(
context: DpexExpKernelTypingContext,
builder: IRBuilder,
sig: Signature,
args: list[llvmir.Value],
):
"""Implements lower for the class:`numba_dpex.kernel_api.PrivateArray`"""
shape = args[0]
ty_shape = sig.args[0]
if len(sig.args) == 3:
fill_zeros = sig.args[-1].literal_value
else:
fill_zeros = False
ty_array = sig.return_type

ary = make_spirv_generic_array_on_stack(
context, builder, ty_array, ty_shape, shape
)

if fill_zeros:
cgutils.memset(
builder, ary.data, builder.mul(ary.itemsize, ary.nitems), 0
)

return ol_private_array_ctor_impl
return ary._getvalue() # pylint: disable=protected-access
12 changes: 12 additions & 0 deletions numba_dpex/experimental/_kernel_dpcpp_spirv_overloads/_registry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
# SPDX-FileCopyrightText: 2024 Intel Corporation
#
# SPDX-License-Identifier: Apache-2.0

"""
Implements the SPIR-V overloads for the kernel_api.PrivateArray class.
"""

from numba.core.imputils import Registry

registry = Registry()
lower = registry.lower
9 changes: 6 additions & 3 deletions numba_dpex/kernel_api/private_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
kernel function.
"""

from numpy import ndarray
import numpy as np


class PrivateArray:
Expand All @@ -16,10 +16,13 @@ class PrivateArray:
inside kernel work item.
"""

def __init__(self, shape, dtype) -> None:
def __init__(self, shape, dtype, fill_zeros=False) -> None:
"""Creates a new PrivateArray instance of the given shape and dtype."""

self._data = ndarray(shape=shape, dtype=dtype)
if fill_zeros:
self._data = np.zeros(shape=shape, dtype=dtype)
else:
self._data = np.empty(shape=shape, dtype=dtype)

def __getitem__(self, idx_obj):
"""Returns the value stored at the position represented by idx_obj in
Expand Down
4 changes: 3 additions & 1 deletion numba_dpex/kernel_api_impl/spirv/arrayobj.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,9 @@ def require_literal(literal_type: types.Type):

for i, _ in enumerate(literal_type):
if not isinstance(literal_type[i], types.Literal):
raise errors.TypingError("requires literal type")
raise errors.TypingError(
"requires each element of tuple literal type"
)


def make_spirv_array( # pylint: disable=too-many-arguments
Expand Down
26 changes: 18 additions & 8 deletions numba_dpex/kernel_api_impl/spirv/dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
"""Implements a new numba dispatcher class and a compiler class to compile and
call numba_dpex.kernel decorated function.
"""
import hashlib
from collections import namedtuple
from contextlib import ExitStack
from typing import Tuple
Expand Down Expand Up @@ -181,6 +182,9 @@ def _compile_to_spirv(
# all linking libraries getting linked together and final optimization
# including inlining of functions if an inlining level is specified.
kernel_library.finalize()

if config.DUMP_KERNEL_LLVM:
self._dump_kernel(kernel_fndesc, kernel_library)
# Compiled the LLVM IR to SPIR-V
kernel_spirv_module = spirv_generator.llvm_to_spirv(
kernel_targetctx,
Expand Down Expand Up @@ -268,20 +272,26 @@ def _compile_cached(

kcres_attrs.append(kernel_device_ir_module)

if config.DUMP_KERNEL_LLVM:
with open(
cres.fndesc.llvm_func_name + ".ll",
"w",
encoding="UTF-8",
) as fptr:
fptr.write(str(cres.library.final_module))

except errors.TypingError as err:
self._failed_cache[key] = err
return False, err

return True, _SPIRVKernelCompileResult(*kcres_attrs)

def _dump_kernel(self, fndesc, library):
"""Dump kernel into file."""
name = fndesc.llvm_func_name
if len(name) > 200:
sha256 = hashlib.sha256(name.encode("utf-8")).hexdigest()
name = name[:150] + "_" + sha256

with open(
name + ".ll",
"w",
encoding="UTF-8",
) as fptr:
fptr.write(str(library.final_module))


class SPIRVKernelDispatcher(Dispatcher):
"""Dispatcher class designed to compile kernel decorated functions. The
Expand Down
1 change: 1 addition & 0 deletions numba_dpex/kernel_api_impl/spirv/spirv_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@ def finalize(self):
llvm_spirv_args = [
"--spirv-ext=+SPV_EXT_shader_atomic_float_add",
"--spirv-ext=+SPV_EXT_shader_atomic_float_min_max",
"--spirv-ext=+SPV_INTEL_arbitrary_precision_integers",
]
for key in list(self.context.extra_compile_options.keys()):
if key == LLVM_SPIRV_ARGS:
Expand Down
4 changes: 4 additions & 0 deletions numba_dpex/kernel_api_impl/spirv/target.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,12 +383,16 @@ def load_additional_registries(self):
# pylint: disable=import-outside-toplevel
from numba_dpex import printimpl
from numba_dpex.dpnp_iface import dpnpimpl
from numba_dpex.experimental._kernel_dpcpp_spirv_overloads._registry import (
registry as spirv_registry,
)
from numba_dpex.ocl import mathimpl, oclimpl

self.insert_func_defn(oclimpl.registry.functions)
self.insert_func_defn(mathimpl.registry.functions)
self.insert_func_defn(dpnpimpl.registry.functions)
self.install_registry(printimpl.registry)
self.install_registry(spirv_registry)
# Replace dpnp math functions with their OpenCL versions.
self.replace_dpnp_ufunc_with_ocl_intrinsics()

Expand Down
32 changes: 31 additions & 1 deletion numba_dpex/tests/experimental/test_private_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,30 @@ def private_array_kernel(item: Item, a):
a[i] += p[j]


def private_array_kernel_fill_true(item: Item, a):
i = item.get_linear_id()
p = PrivateArray(10, a.dtype, fill_zeros=True)

for j in range(10):
p[j] = j * j

a[i] = 0
for j in range(10):
a[i] += p[j]


def private_array_kernel_fill_false(item: Item, a):
i = item.get_linear_id()
p = PrivateArray(10, a.dtype, fill_zeros=False)

for j in range(10):
p[j] = j * j

a[i] = 0
for j in range(10):
a[i] += p[j]


def private_2d_array_kernel(item: Item, a):
i = item.get_linear_id()
p = PrivateArray(shape=(5, 2), dtype=a.dtype)
Expand All @@ -36,7 +60,13 @@ def private_2d_array_kernel(item: Item, a):


@pytest.mark.parametrize(
"kernel", [private_array_kernel, private_2d_array_kernel]
"kernel",
[
private_array_kernel,
private_array_kernel_fill_true,
private_array_kernel_fill_false,
private_2d_array_kernel,
],
)
@pytest.mark.parametrize(
"call_kernel, decorator",
Expand Down