This repository was archived by the owner on Jan 25, 2023. It is now read-only.

Commit e31c726

Author: etotmeni (committed)
Parent: 02b504f

    Add semantics 'with context' for gpu and cpu

File tree: 10 files changed, +263 -23 lines

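To make the new semantics concrete: with this commit, an @njit function declared without a target is dispatched per call, following the active dpctl device context. The following is a minimal sketch modeled on the example file added in this commit; it assumes the numba-dppl fork with dpctl available, and the variable names are illustrative.

# Minimal sketch of the 'with context' semantics this commit adds.
import numpy as np
import dpctl
from numba import njit, prange

@njit  # note: no target= keyword, so the dispatcher is chosen per call
def vecadd(a, b, c, N):
    for i in prange(N):
        c[i] = a[i] + b[i]

N = 10
a, b, c = np.ones(N), np.ones(N), np.zeros(N)

# Inside a device context the call is routed to the GPU offload dispatcher:
with dpctl.device_context(dpctl.device_type.gpu):
    vecadd(a, b, c, N)

# Outside any context the same function object uses the plain CPU target:
vecadd(a, b, c, N)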

numba/core/decorators.py

Lines changed: 31 additions & 17 deletions

@@ -147,8 +147,10 @@ def bar(x, y):
     if 'target' in options:
         target = options.pop('target')
         warnings.warn("The 'target' keyword argument is deprecated.", NumbaDeprecationWarning)
-    else:
+    elif '_target' in options:
         target = options.pop('_target', 'cpu')
+    else:
+        target = None
 
     parallel_option = options.get('parallel')
     if isinstance(parallel_option, dict) and parallel_option.get('offload') is True:

@@ -187,22 +189,8 @@ def bar(x, y):
 
 
 def _jit(sigs, locals, target, cache, targetoptions, **dispatcher_args):
-    dispatcher = registry.dispatcher_registry[target]
-
-    def wrapper(func):
-        if extending.is_jitted(func):
-            raise TypeError(
-                "A jit decorator was called on an already jitted function "
-                f"{func}. If trying to access the original python "
-                f"function, use the {func}.py_func attribute."
-            )
-
-        if not inspect.isfunction(func):
-            raise TypeError(
-                "The decorated object is not a function (got type "
-                f"{type(func)})."
-            )
 
+    def wrapper(func, dispatcher):
         if config.ENABLE_CUDASIM and target == 'cuda':
             from numba import cuda
             return cuda.jit(func)

@@ -226,7 +214,33 @@ def wrapper(func):
         disp.disable_compile()
         return disp
 
-    return wrapper
+    def __wrapper(func):
+        if extending.is_jitted(func):
+            raise TypeError(
+                "A jit decorator was called on an already jitted function "
+                f"{func}. If trying to access the original python "
+                f"function, use the {func}.py_func attribute."
+            )
+
+        if not inspect.isfunction(func):
+            raise TypeError(
+                "The decorated object is not a function (got type "
+                f"{type(func)})."
+            )
+
+        if (target == 'npyufunc' or targetoptions.get('no_cpython_wrapper')
+                or sigs or config.DISABLE_JIT or not targetoptions.get('nopython')):
+            target_ = target
+            if target_ is None:
+                target_ = 'cpu'
+            disp = registry.dispatcher_registry[target_]
+            return wrapper(func, disp)
+
+        from numba.dppl.target_dispatcher import TargetDispatcher
+        disp = TargetDispatcher(func, wrapper, target)
+        return disp
+
+    return __wrapper
 
 
 def generated_jit(function=None, target='cpu', cache=False,
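The restructuring above splits _jit decoration into two stages: __wrapper runs at decoration time and binds a dispatcher immediately only when the choice is forced (explicit target, signatures, DISABLE_JIT, and so on); otherwise it returns a TargetDispatcher proxy that later feeds wrapper(func, dispatcher) a concrete dispatcher per call. A standalone sketch of that deferral pattern, with stand-in names rather than Numba's real objects:

# Sketch of the two-stage decoration used above; DISPATCHERS and the
# returned tuples are stand-ins, not Numba's real objects.
DISPATCHERS = {'cpu': 'CPUDispatcher', 'gpu': 'OffloadDispatcher'}

def make_jit(sigs=None, target=None):
    def wrapper(func, dispatcher):
        # Late stage: compile func under a concrete dispatcher.
        return (dispatcher, func)

    def __wrapper(func):
        # Early stage: runs at decoration time. Bind a dispatcher now only
        # when the choice is forced; otherwise defer it to call time.
        if sigs or target is not None:
            return wrapper(func, DISPATCHERS[target or 'cpu'])
        # Mirrors: return TargetDispatcher(func, wrapper, target)
        return ('deferred-proxy', func, wrapper)

    return __wrapper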

numba/core/dispatcher.py

Lines changed: 8 additions & 1 deletion

@@ -673,7 +673,14 @@ def _set_uuid(self, u):
         self._recent.append(self)
 
 
-class Dispatcher(serialize.ReduceMixin, _MemoMixin, _DispatcherBase):
+import abc
+
+class DispatcherMeta(abc.ABCMeta):
+    def __instancecheck__(self, other):
+        return type(type(other)) == DispatcherMeta
+
+
+class Dispatcher(serialize.ReduceMixin, _MemoMixin, _DispatcherBase, metaclass=DispatcherMeta):
     """
     Implementation of user-facing dispatcher objects (i.e. created using
     the @jit decorator).
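TargetDispatcher (added below in numba/dppl/target_dispatcher.py) is not a Dispatcher subclass, yet existing isinstance(obj, Dispatcher) checks elsewhere in the codebase must keep accepting it. The __instancecheck__ override achieves that: any object whose class was built by DispatcherMeta counts as an instance. A self-contained illustration of the hook, with made-up class names:

import abc

class DispatcherMeta(abc.ABCMeta):
    def __instancecheck__(self, other):
        # An object counts as an instance if its class was built by this
        # metaclass, regardless of actual inheritance.
        return type(type(other)) == DispatcherMeta

class Dispatcher(metaclass=DispatcherMeta):
    pass

class Proxy(metaclass=DispatcherMeta):  # not a Dispatcher subclass
    pass

assert isinstance(Proxy(), Dispatcher)    # passes via the metaclass hook
assert not issubclass(Proxy, Dispatcher)  # yet there is no inheritance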

numba/core/registry.py

Lines changed: 6 additions & 0 deletions

@@ -2,6 +2,7 @@
 
 from numba.core.descriptors import TargetDescriptor
 from numba.core import utils, typing, dispatcher, cpu
+from numba.core.compiler_lock import global_compiler_lock
 
 # -----------------------------------------------------------------------------
 # Default CPU target descriptors

@@ -26,16 +27,19 @@ class CPUTarget(TargetDescriptor):
     _nested = _NestedContext()
 
     @utils.cached_property
+    @global_compiler_lock
    def _toplevel_target_context(self):
         # Lazily-initialized top-level target context, for all threads
         return cpu.CPUContext(self.typing_context)
 
     @utils.cached_property
+    @global_compiler_lock
     def _toplevel_typing_context(self):
         # Lazily-initialized top-level typing context, for all threads
         return typing.Context()
 
     @property
+    @global_compiler_lock
     def target_context(self):
         """
         The target context for CPU targets.

@@ -47,6 +51,7 @@ def target_context(self):
         return self._toplevel_target_context
 
     @property
+    @global_compiler_lock
     def typing_context(self):
         """
         The typing context for CPU targets.

@@ -57,6 +62,7 @@ def typing_context(self):
         else:
             return self._toplevel_typing_context
 
+    @global_compiler_lock
     def nested_context(self, typing_context, target_context):
         """
         A context manager temporarily replacing the contexts with the
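The additions above take global_compiler_lock around the lazily initialized typing and target contexts, so first-time construction is serialized with the rest of the compiler. A minimal sketch of the same decorator-stacking pattern using only the standard library (these names are stand-ins, not Numba's implementation):

import threading
from functools import cached_property

_compiler_lock = threading.RLock()  # stand-in for Numba's global_compiler_lock

def global_compiler_lock(fn):
    # Decorator form of the lock, as used in the diff above.
    def inner(*args, **kwargs):
        with _compiler_lock:
            return fn(*args, **kwargs)
    return inner

class CPUTargetSketch:
    @cached_property
    @global_compiler_lock
    def target_context(self):
        # First access builds the context; holding the compiler lock here
        # serializes construction with everything else using that lock.
        return object()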
(New file; filename not captured on this page)

Lines changed: 35 additions & 0 deletions

@@ -0,0 +1,35 @@
+import numpy as np
+from numba import dppl, njit, prange
+import dpctl
+import dpctl.ocldrv as ocldrv
+
+
+@njit
+def g(a):
+    return a + 1
+
+
+@njit
+def f(a, b, c, N):
+    for i in prange(N):
+        a[i] = b[i] + g(c[i])
+
+
+def main():
+    N = 10
+    a = np.ones(N)
+    b = np.ones(N)
+    c = np.ones(N)
+
+    if ocldrv.has_gpu_device:
+        with dpctl.device_context(dpctl.device_type.gpu):
+            f(a, b, c, N)
+    elif ocldrv.has_cpu_device:
+        with dpctl.device_context(dpctl.device_type.cpu):
+            f(a, b, c, N)
+    else:
+        print("No device found")
+
+
+if __name__ == '__main__':
+    main()

numba/dppl/target_dispatcher.py

Lines changed: 72 additions & 0 deletions

@@ -0,0 +1,72 @@
+from numba.core import registry, serialize, dispatcher
+from numba import types
+from numba.core.errors import UnsupportedError
+import dpctl
+import dpctl.ocldrv as ocldrv
+from numba.core.compiler_lock import global_compiler_lock
+
+
+class TargetDispatcher(serialize.ReduceMixin, metaclass=dispatcher.DispatcherMeta):
+    __numba__ = 'py_func'
+
+    def __init__(self, py_func, wrapper, target, compiled=None):
+
+        self.__py_func = py_func
+        self.__target = target
+        self.__wrapper = wrapper
+        self.__compiled = compiled if compiled is not None else {}
+        self.__doc__ = py_func.__doc__
+        self.__name__ = py_func.__name__
+        self.__module__ = py_func.__module__
+
+    def __call__(self, *args, **kwargs):
+        return self.get_compiled()(*args, **kwargs)
+
+    def __getattr__(self, name):
+        return getattr(self.get_compiled(), name)
+
+    def __get__(self, obj, objtype=None):
+        return self.get_compiled().__get__(obj, objtype)
+
+    def __repr__(self):
+        return self.get_compiled().__repr__()
+
+    @classmethod
+    def _rebuild(cls, py_func, wrapper, target, compiled):
+        self = cls(py_func, wrapper, target, compiled)
+        return self
+
+    def get_compiled(self, target=None):
+        if target is None:
+            target = self.__target
+
+        disp = self.get_current_disp()
+        if not disp in self.__compiled.keys():
+            with global_compiler_lock:
+                if not disp in self.__compiled.keys():
+                    self.__compiled[disp] = self.__wrapper(self.__py_func, disp)
+
+        return self.__compiled[disp]
+
+    def get_current_disp(self):
+        target = self.__target
+
+        if dpctl.is_in_device_context():
+            if self.__target is not None:
+                raise UnsupportedError("Unsupported defined 'target' with using context device")
+            if dpctl.get_current_device_type() == dpctl.device_type.gpu:
+                from numba.dppl import dppl_offload_dispatcher
+                return registry.dispatcher_registry['__dppl_offload_gpu__']
+
+        if target is None:
+            target = 'cpu'
+
+        return registry.dispatcher_registry[target]
+
+    def _reduce_states(self):
+        return dict(
+            py_func=self.__py_func,
+            wrapper=self.__wrapper,
+            target=self.__target,
+            compiled=self.__compiled
+        )
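Two details worth noting in get_compiled: the __compiled cache, keyed by dispatcher, is probed once without the lock and then re-checked under global_compiler_lock before compiling (double-checked locking), so repeat calls on the hot path stay lock-free. From user code the proxy is transparent; tests that need the concrete Dispatcher object unwrap it explicitly, as the updated tests below do. A small usage sketch, assuming this fork:

from numba import njit

@njit
def foo(x):
    return x + 1

foo(1)  # transparent: __call__ forwards to the current target's dispatcher

# Code that pokes Dispatcher internals (serialization, NRT stats in the
# updated tests) first unwraps the proxy:
disp = foo.get_compiled()  # concrete CPU dispatcher outside a device context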
(New test file; filename not captured on this page)

Lines changed: 102 additions & 0 deletions

@@ -0,0 +1,102 @@
+import numba
+import numpy as np
+from numba import dppl, njit
+from numba.core import errors
+from numba.tests.support import captured_stdout
+from numba.dppl.testing import DPPLTestCase, unittest
+import dpctl
+import dpctl.ocldrv as ocldrv
+
+
+@unittest.skipIf(not dpctl.has_gpu_queues(), "No GPU platforms available")
+@unittest.skipIf(not dpctl.has_cpu_queues(), "No CPU platforms available")
+class TestWithDPPLContext(DPPLTestCase):
+    def test_with_dppl_context_gpu(self):
+
+        @njit
+        def nested_func(a, b):
+            np.sin(a, b)
+
+        @njit
+        def func(b):
+            a = np.ones((64), dtype=np.float64)
+            nested_func(a, b)
+
+        numba.dppl.compiler.DEBUG = 1
+        expected = np.ones((64), dtype=np.float64)
+        got_gpu = np.ones((64), dtype=np.float64)
+
+        with captured_stdout() as got_gpu_message:
+            with dpctl.device_context(dpctl.device_type.gpu):
+                func(got_gpu)
+
+        func(expected)
+
+        np.testing.assert_array_equal(expected, got_gpu)
+        self.assertTrue('Parfor lowered on DPPL-device' in got_gpu_message.getvalue())
+
+
+    def test_with_dppl_context_cpu(self):
+
+        @njit
+        def nested_func(a, b):
+            np.sin(a, b)
+
+        @njit
+        def func(b):
+            a = np.ones((64), dtype=np.float64)
+            nested_func(a, b)
+
+        numba.dppl.compiler.DEBUG = 1
+        expected = np.ones((64), dtype=np.float64)
+        got_cpu = np.ones((64), dtype=np.float64)
+
+        with captured_stdout() as got_cpu_message:
+            with dpctl.device_context(dpctl.device_type.cpu):
+                func(got_cpu)
+
+        func(expected)
+
+        np.testing.assert_array_equal(expected, got_cpu)
+        self.assertTrue('Parfor lowered on DPPL-device' not in got_cpu_message.getvalue())
+
+
+    def test_with_dppl_context_target(self):
+
+        @njit(target='cpu')
+        def nested_func_target(a, b):
+            np.sin(a, b)
+
+        @njit(target='gpu')
+        def func_target(b):
+            a = np.ones((64), dtype=np.float64)
+            nested_func_target(a, b)
+
+        @njit
+        def func_no_target(b):
+            a = np.ones((64), dtype=np.float64)
+            nested_func_target(a, b)
+
+        a = np.ones((64), dtype=np.float64)
+        b = np.ones((64), dtype=np.float64)
+
+        with self.assertRaises(errors.UnsupportedError) as raises_1:
+            with dpctl.device_context(dpctl.device_type.gpu):
+                nested_func_target(a, b)
+
+        with self.assertRaises(errors.UnsupportedError) as raises_2:
+            with dpctl.device_context(dpctl.device_type.gpu):
+                func_target(a)
+
+        with self.assertRaises(errors.UnsupportedError) as raises_3:
+            with dpctl.device_context(dpctl.device_type.gpu):
+                func_no_target(a)
+
+        msg = "Unsupported defined 'target' with using context device"
+        self.assertTrue(msg in str(raises_1.exception))
+        self.assertTrue(msg in str(raises_2.exception))
+        self.assertTrue(msg in str(raises_3.exception))
+
+
+if __name__ == '__main__':
+    unittest.main()

numba/tests/test_dispatcher.py

Lines changed: 2 additions & 0 deletions

@@ -398,6 +398,8 @@ def test_serialization(self):
         def foo(x):
             return x + 1
 
+        foo = foo.get_compiled()
+
         self.assertEqual(foo(1), 2)
 
         # get serialization memo
numba/tests/test_nrt.py

Lines changed: 2 additions & 0 deletions

@@ -249,6 +249,8 @@ def alloc_nrt_memory():
             """
             return np.empty(N, dtype)
 
+        alloc_nrt_memory = alloc_nrt_memory.get_compiled()
+
         def keep_memory():
             return alloc_nrt_memory()
 
numba/tests/test_record_dtype.py

Lines changed: 2 additions & 2 deletions

@@ -803,8 +803,8 @@ def test_record_arg_transform(self):
         self.assertIn('Array', transformed)
         self.assertNotIn('first', transformed)
         self.assertNotIn('second', transformed)
-        # Length is usually 50 - 5 chars tolerance as above.
-        self.assertLess(len(transformed), 50)
+        # Length is usually 60 - 5 chars tolerance as above.
+        self.assertLess(len(transformed), 60)
 
     def test_record_two_arrays(self):
         """
numba/tests/test_serialize.py

Lines changed: 3 additions & 3 deletions

@@ -135,9 +135,9 @@ def test_reuse(self):
 
         Note that "same function" is intentionally under-specified.
         """
-        func = closure(5)
+        func = closure(5).get_compiled()
         pickled = pickle.dumps(func)
-        func2 = closure(6)
+        func2 = closure(6).get_compiled()
         pickled2 = pickle.dumps(func2)
 
         f = pickle.loads(pickled)

@@ -152,7 +152,7 @@ def test_reuse(self):
         self.assertEqual(h(2, 3), 11)
 
         # Now make sure the original object doesn't exist when deserializing
-        func = closure(7)
+        func = closure(7).get_compiled()
         func(42, 43)
         pickled = pickle.dumps(func)
         del func