Skip to content
This repository was archived by the owner on Jan 25, 2023. It is now read-only.

Commit f198ab7

Browse files
PokhodenkoSAreazulhoque
authored andcommitted
Patch for support numba-dppy with device_context (numba#6899)
This modifications make jit() decorator use TargetDispatcher from dppl. Changes made in #57 by @AlexanderKalistratov and @1e-to. Patch to fix SDC integration testing (#203)
1 parent b4be10b commit f198ab7

File tree

8 files changed

+75
-23
lines changed

8 files changed

+75
-23
lines changed

numba/core/decorators.py

Lines changed: 41 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,7 @@ def bar(x, y):
158158
target = options.pop('target')
159159
warnings.warn("The 'target' keyword argument is deprecated.", NumbaDeprecationWarning)
160160
else:
161-
target = options.pop('_target', 'cpu')
161+
target = options.pop('_target', None)
162162

163163
options['boundscheck'] = boundscheck
164164

@@ -195,28 +195,17 @@ def bar(x, y):
195195
jit_registry[hardware_registry['cpu']] = jit
196196

197197
def _jit(sigs, locals, target, cache, targetoptions, **dispatcher_args):
198-
199198
dispatcher = resolve_dispatcher_from_str(target)
200199

201-
def wrapper(func):
202-
if extending.is_jitted(func):
203-
raise TypeError(
204-
"A jit decorator was called on an already jitted function "
205-
f"{func}. If trying to access the original python "
206-
f"function, use the {func}.py_func attribute."
207-
)
208-
209-
if not inspect.isfunction(func):
210-
raise TypeError(
211-
"The decorated object is not a function (got type "
212-
f"{type(func)})."
213-
)
214-
200+
def wrapper(func, dispatcher):
215201
if config.ENABLE_CUDASIM and target == 'cuda':
216202
from numba import cuda
217203
return cuda.jit(func)
218204
if config.DISABLE_JIT and not target == 'npyufunc':
219205
return func
206+
if target == 'dppl':
207+
from . import dppl
208+
return dppl.jit(func)
220209
disp = dispatcher(py_func=func, locals=locals,
221210
targetoptions=targetoptions,
222211
**dispatcher_args)
@@ -232,7 +221,42 @@ def wrapper(func):
232221
disp.disable_compile()
233222
return disp
234223

235-
return wrapper
224+
def __wrapper(func):
225+
if extending.is_jitted(func):
226+
raise TypeError(
227+
"A jit decorator was called on an already jitted function "
228+
f"{func}. If trying to access the original python "
229+
f"function, use the {func}.py_func attribute."
230+
)
231+
232+
if not inspect.isfunction(func):
233+
raise TypeError(
234+
"The decorated object is not a function (got type "
235+
f"{type(func)})."
236+
)
237+
238+
is_numba_dppy_present = False
239+
try:
240+
import numba_dppy.config as dppy_config
241+
242+
is_numba_dppy_present = dppy_config.dppy_present
243+
except ImportError:
244+
pass
245+
246+
if (not is_numba_dppy_present
247+
or target == 'npyufunc' or targetoptions.get('no_cpython_wrapper')
248+
or sigs or config.DISABLE_JIT or not targetoptions.get('nopython')):
249+
target_ = target
250+
if target_ is None:
251+
target_ = 'cpu'
252+
disp = registry.dispatcher_registry[target_]
253+
return wrapper(func, disp)
254+
255+
from numba_dppy.target_dispatcher import TargetDispatcher
256+
disp = TargetDispatcher(func, wrapper, target, targetoptions.get('parallel'))
257+
return disp
258+
259+
return __wrapper
236260

237261

238262
def generated_jit(function=None, target='cpu', cache=False,

numba/core/dispatcher.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -742,7 +742,14 @@ def _set_uuid(self, u):
742742
self._recent.append(self)
743743

744744

745-
class Dispatcher(serialize.ReduceMixin, _MemoMixin, _DispatcherBase):
745+
import abc
746+
747+
class DispatcherMeta(abc.ABCMeta):
748+
def __instancecheck__(self, other):
749+
return type(type(other)) == DispatcherMeta
750+
751+
752+
class Dispatcher(serialize.ReduceMixin, _MemoMixin, _DispatcherBase, metaclass=DispatcherMeta):
746753
"""
747754
Implementation of user-facing dispatcher objects (i.e. created using
748755
the @jit decorator).
@@ -995,6 +1002,9 @@ def get_function_type(self):
9951002
cres = tuple(self.overloads.values())[0]
9961003
return types.FunctionType(cres.signature)
9971004

1005+
def get_compiled(self):
1006+
return self
1007+
9981008

9991009
class LiftedCode(serialize.ReduceMixin, _MemoMixin, _DispatcherBase):
10001010
"""

numba/core/registry.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from numba.core.descriptors import TargetDescriptor
44
from numba.core import utils, typing, dispatcher, cpu
5+
from numba.core.compiler_lock import global_compiler_lock
56

67
# -----------------------------------------------------------------------------
78
# Default CPU target descriptors
@@ -26,16 +27,19 @@ class CPUTarget(TargetDescriptor):
2627
_nested = _NestedContext()
2728

2829
@utils.cached_property
30+
@global_compiler_lock
2931
def _toplevel_target_context(self):
3032
# Lazily-initialized top-level target context, for all threads
3133
return cpu.CPUContext(self.typing_context, self._target_name)
3234

3335
@utils.cached_property
36+
@global_compiler_lock
3437
def _toplevel_typing_context(self):
3538
# Lazily-initialized top-level typing context, for all threads
3639
return typing.Context()
3740

3841
@property
42+
@global_compiler_lock
3943
def target_context(self):
4044
"""
4145
The target context for CPU targets.
@@ -47,6 +51,7 @@ def target_context(self):
4751
return self._toplevel_target_context
4852

4953
@property
54+
@global_compiler_lock
5055
def typing_context(self):
5156
"""
5257
The typing context for CPU targets.
@@ -57,6 +62,7 @@ def typing_context(self):
5762
else:
5863
return self._toplevel_typing_context
5964

65+
@global_compiler_lock
6066
def nested_context(self, typing_context, target_context):
6167
"""
6268
A context manager temporarily replacing the contexts with the

numba/core/typing/context.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -235,6 +235,14 @@ def _resolve_user_function_type(self, func, args, kws, literals=None):
235235
if functy is not None:
236236
func = functy
237237

238+
from numba.core.registry import CPUDispatcher
239+
if isinstance(func, CPUDispatcher) and func is not CPUDispatcher:
240+
# if we are here it's numba-dppy case and we got TargetDispatcher, so get compiled version
241+
func = func.get_compiled()
242+
functy = self._lookup_global(func)
243+
if functy is not None:
244+
func = functy
245+
238246
if isinstance(func, types.Type):
239247
# If it's a type, it may support a __call__ method
240248
func_type = self.resolve_getattr(func, "__call__")

numba/tests/test_dispatcher.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -398,6 +398,8 @@ def test_serialization(self):
398398
def foo(x):
399399
return x + 1
400400

401+
foo = foo.get_compiled()
402+
401403
self.assertEqual(foo(1), 2)
402404

403405
# get serialization memo

numba/tests/test_nrt.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -249,6 +249,8 @@ def alloc_nrt_memory():
249249
"""
250250
return np.empty(N, dtype)
251251

252+
alloc_nrt_memory = alloc_nrt_memory.get_compiled()
253+
252254
def keep_memory():
253255
return alloc_nrt_memory()
254256

numba/tests/test_record_dtype.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -803,8 +803,8 @@ def test_record_arg_transform(self):
803803
self.assertIn('Array', transformed)
804804
self.assertNotIn('first', transformed)
805805
self.assertNotIn('second', transformed)
806-
# Length is usually 50 - 5 chars tolerance as above.
807-
self.assertLess(len(transformed), 50)
806+
# Length is usually 60 - 5 chars tolerance as above.
807+
self.assertLess(len(transformed), 60)
808808

809809
def test_record_two_arrays(self):
810810
"""

numba/tests/test_serialize.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -136,9 +136,9 @@ def test_reuse(self):
136136
137137
Note that "same function" is intentionally under-specified.
138138
"""
139-
func = closure(5)
139+
func = closure(5).get_compiled()
140140
pickled = pickle.dumps(func)
141-
func2 = closure(6)
141+
func2 = closure(6).get_compiled()
142142
pickled2 = pickle.dumps(func2)
143143

144144
f = pickle.loads(pickled)
@@ -153,7 +153,7 @@ def test_reuse(self):
153153
self.assertEqual(h(2, 3), 11)
154154

155155
# Now make sure the original object doesn't exist when deserializing
156-
func = closure(7)
156+
func = closure(7).get_compiled()
157157
func(42, 43)
158158
pickled = pickle.dumps(func)
159159
del func

0 commit comments

Comments
 (0)