Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ website/_deploy.sh
# Cython / C extensions
cythonize.json
*.cpp
!thinc/backends/cblas_impl.cpp
*.so
*.so.1

Expand Down
1 change: 1 addition & 0 deletions MANIFEST.in
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,4 @@ include thinc/tests/mypy/configs/*.ini
include thinc/tests/mypy/outputs/*.txt
include thinc/py.typed
recursive-exclude thinc *.cpp
include thinc/backends/cblas_impl.cpp
9 changes: 7 additions & 2 deletions examples/transformers_tagger.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,9 @@ def forward(
return TokensPlus(**token_data), lambda d_tokens: []

return Model(
"tokenizer", forward, attrs={"tokenizer": AutoTokenizer.from_pretrained(name)},
"tokenizer",
forward,
attrs={"tokenizer": AutoTokenizer.from_pretrained(name)},
)


Expand Down Expand Up @@ -166,11 +168,14 @@ def convert_transformer_outputs(model, inputs_outputs, is_train):

def backprop(d_tokvecs: List[Floats2d]) -> ArgsKwargs:
# Restore entries for bos and eos markers.
shim = model.shims[0]
row = model.ops.alloc2f(1, d_tokvecs[0].shape[1])
d_tokvecs = [model.ops.xp.vstack((row, arr, row)) for arr in d_tokvecs]
return ArgsKwargs(
args=(torch_tokvecs,),
kwargs={"grad_tensors": xp2torch(model.ops.pad(d_tokvecs))},
kwargs={
"grad_tensors": xp2torch(model.ops.pad(d_tokvecs, device=shim.device))
},
)

return tokvecs, backprop
Expand Down
2 changes: 1 addition & 1 deletion thinc/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from .util import torch2xp, xp2torch, tensorflow2xp, xp2tensorflow, mxnet2xp, xp2mxnet
from .compat import has_cupy
from .backends import get_ops, set_current_ops, get_current_ops, use_ops
from .backends import Ops, CupyOps, NumpyOps, set_gpu_allocator
from .backends import Ops, CupyOps, MPSOps, NumpyOps, set_gpu_allocator
from .backends import use_pytorch_for_gpu_memory, use_tensorflow_for_gpu_memory

from .layers import Dropout, Embed, expand_window, HashEmbed, LayerNorm, Linear
Expand Down
9 changes: 7 additions & 2 deletions thinc/backends/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,11 @@
from .ops import Ops
from .cupy_ops import CupyOps
from .numpy_ops import NumpyOps
from .mps_ops import MPSOps
from ._cupy_allocators import cupy_tensorflow_allocator, cupy_pytorch_allocator
from ._param_server import ParamServer
from ..util import assert_tensorflow_installed, assert_pytorch_installed
from ..util import is_cupy_array, set_torch_tensor_type_for_ops, require_cpu
from ..util import get_torch_default_device, is_cupy_array, require_cpu
from .. import registry
from ..compat import cupy, has_cupy

Expand Down Expand Up @@ -48,6 +49,10 @@ def use_pytorch_for_gpu_memory() -> None: # pragma: no cover
(or vice versa), but do not currently have an implementation for it.
"""
assert_pytorch_installed()

if get_torch_default_device().type != "cuda":
return

pools = context_pools.get()
if "pytorch" not in pools:
pools["pytorch"] = cupy.cuda.MemoryPool(allocator=cupy_pytorch_allocator)
Expand Down Expand Up @@ -134,7 +139,6 @@ def set_current_ops(ops: Ops) -> None:
"""Change the current backend object."""
context_ops.set(ops)
_get_thread_state().ops = ops
set_torch_tensor_type_for_ops(ops)


def contextvars_eq_thread_ops() -> bool:
Expand Down Expand Up @@ -170,6 +174,7 @@ def _create_thread_local(
"ParamServer",
"Ops",
"CupyOps",
"MPSOps",
"NumpyOps",
"has_cupy",
]
7 changes: 5 additions & 2 deletions thinc/backends/_cupy_allocators.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import cast

from ..types import ArrayXd
from ..util import tensorflow2xp
from ..util import get_torch_default_device, tensorflow2xp
from ..compat import torch, cupy, tensorflow


Expand All @@ -23,6 +23,7 @@ def cupy_tensorflow_allocator(size_in_bytes: int):


def cupy_pytorch_allocator(size_in_bytes: int):
device = get_torch_default_device()
"""Function that can be passed into cupy.cuda.set_allocator, to have cupy
allocate memory via PyTorch. This is important when using the two libraries
together, as otherwise OOM errors can occur when there's available memory
Expand All @@ -34,7 +35,9 @@ def cupy_pytorch_allocator(size_in_bytes: int):
# creating a whole Tensor.
# This turns out to be way faster than making FloatStorage? Maybe
# a Python vs C++ thing I guess?
torch_tensor = torch.zeros((size_in_bytes // 4,), requires_grad=False)
torch_tensor = torch.zeros(
(size_in_bytes // 4,), requires_grad=False, device=device
)
# cupy has a neat class to help us here. Otherwise it will try to free.
# I think this is a private API? It's not in the types.
address = torch_tensor.data_ptr() # type: ignore
Expand Down
23 changes: 18 additions & 5 deletions thinc/backends/cblas.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,10 @@ ctypedef void (*sgemm_ptr)(bint transA, bint transB, int M, int N, int K,
int ldb, float beta, float* C, int ldc) nogil


ctypedef void (*daxpy_ptr)(int N, double alpha, const double* X, int incX,
double *Y, int incY) nogil


ctypedef void (*saxpy_ptr)(int N, float alpha, const float* X, int incX,
float *Y, int incY) nogil

Expand All @@ -16,9 +20,18 @@ ctypedef void (*saxpy_ptr)(int N, float alpha, const float* X, int incX,
cdef struct BlasFuncs



cdef extern from "cblas_impl.hh":
cdef cppclass CBlasImpl:
CBlas() nogil
daxpy_ptr daxpy() nogil
saxpy_ptr saxpy() nogil
sgemm_ptr sgemm() nogil
void set_daxpy(daxpy_ptr daxpy) nogil
void set_saxpy(saxpy_ptr saxpy) nogil
void set_sgemm(sgemm_ptr sgemm) nogil


cdef class CBlas:
cdef shared_ptr[BlasFuncs] ptr
cdef saxpy_ptr saxpy(self) nogil
cdef sgemm_ptr sgemm(self) nogil
cdef void set_saxpy(self, saxpy_ptr saxpy) nogil
cdef void set_sgemm(self, sgemm_ptr sgemm) nogil
cdef CBlasImpl c_impl
cdef CBlasImpl c(self) nogil
31 changes: 6 additions & 25 deletions thinc/backends/cblas.pyx
Original file line number Diff line number Diff line change
@@ -1,32 +1,13 @@
cimport blis.cy
from cython.operator cimport dereference as deref
from libcpp.memory cimport make_shared


cdef struct BlasFuncs:
saxpy_ptr saxpy
sgemm_ptr sgemm
# distutils: sources = thinc/backends/cblas_impl.cpp

cimport blis.cy

cdef class CBlas:
__slots__ = []

def __init__(self):
"""Construct a CBlas instance set to use BLIS implementations of the
supported BLAS functions."""
cdef BlasFuncs funcs
funcs.saxpy = blis.cy.saxpy
funcs.sgemm = blis.cy.sgemm
self.ptr = make_shared[BlasFuncs](funcs)

cdef saxpy_ptr saxpy(self) nogil:
return deref(self.ptr).saxpy

cdef sgemm_ptr sgemm(self) nogil:
return deref(self.ptr).sgemm

cdef void set_saxpy(self, saxpy_ptr saxpy) nogil:
deref(self.ptr).saxpy = saxpy
self.c_impl.set_saxpy(blis.cy.saxpy)
self.c_impl.set_sgemm(blis.cy.sgemm)

cdef void set_sgemm(self, sgemm_ptr sgemm) nogil:
deref(self.ptr).sgemm = sgemm
cdef CBlasImpl c(self) nogil:
return self.c_impl
35 changes: 35 additions & 0 deletions thinc/backends/cblas_impl.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
#include "cblas_impl.hh"

// Table of BLAS function pointers backing CBlasImpl. Defined here (not in
// the header) so the header only needs a forward declaration.
struct BlasFuncs {
    daxpy_ptr daxpy;
    saxpy_ptr saxpy;
    sgemm_ptr sgemm;
};

CBlasImpl::CBlasImpl() {
    // Value-initialize (note the parentheses) so every function pointer
    // starts out as nullptr instead of an indeterminate value. A getter
    // called before the matching set_* then returns a detectable nullptr
    // rather than a garbage pointer.
    blas_funcs.reset(new BlasFuncs());
}

// Getters: return the currently registered routine (nullptr if unset).
daxpy_ptr CBlasImpl::daxpy() {
    return blas_funcs->daxpy;
}

saxpy_ptr CBlasImpl::saxpy() {
    return blas_funcs->saxpy;
}

sgemm_ptr CBlasImpl::sgemm() {
    return blas_funcs->sgemm;
}

// Setters: register a concrete implementation for each routine.
void CBlasImpl::set_daxpy(daxpy_ptr daxpy) {
    blas_funcs->daxpy = daxpy;
}

void CBlasImpl::set_saxpy(saxpy_ptr saxpy) {
    blas_funcs->saxpy = saxpy;
}

void CBlasImpl::set_sgemm(sgemm_ptr sgemm) {
    blas_funcs->sgemm = sgemm;
}
41 changes: 41 additions & 0 deletions thinc/backends/cblas_impl.hh
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
#ifndef CBLAS_HH
#define CBLAS_HH

#include <memory>

typedef int bint;

struct BlasFuncs;

typedef void (*sgemm_ptr)(bint transA, bint transB, int M, int N, int K,
float alpha, const float* A, int lda, const float *B,
int ldb, float beta, float* C, int ldc);


typedef void (*daxpy_ptr)(int N, double alpha, const double* X, int incX,
double *Y, int incY);


typedef void (*saxpy_ptr)(int N, float alpha, const float* X, int incX,
float *Y, int incY);


// Holds a mutable table of BLAS routine pointers (daxpy/saxpy/sgemm) behind
// a shared_ptr, so copies of CBlasImpl are cheap and share one table.
class CBlasImpl {
public:
    CBlasImpl();
    virtual ~CBlasImpl() {}

    // Accessors for the registered routines. NOTE(review): whether these are
    // guaranteed non-null before the matching set_* is called depends on the
    // .cpp — confirm before dereferencing an unset entry.
    daxpy_ptr daxpy();
    saxpy_ptr saxpy();
    sgemm_ptr sgemm();
    // Register a concrete implementation for each routine.
    void set_daxpy(daxpy_ptr daxpy);
    void set_saxpy(saxpy_ptr saxpy);
    void set_sgemm(sgemm_ptr sgemm);

private:
    // BlasFuncs is only forward-declared here; its definition lives in
    // cblas_impl.cpp.
    std::shared_ptr<BlasFuncs> blas_funcs;
};



#endif // CBLAS_HH
4 changes: 2 additions & 2 deletions thinc/backends/cupy_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from ..types import DeviceTypes
from ..util import torch2xp, tensorflow2xp, mxnet2xp
from ..util import is_cupy_array
from ..util import is_torch_gpu_array, is_tensorflow_gpu_array, is_mxnet_gpu_array
from ..util import is_torch_cuda_array, is_tensorflow_gpu_array, is_mxnet_gpu_array
from ..compat import cupy, cupyx


Expand Down Expand Up @@ -62,7 +62,7 @@ def asarray(self, data, dtype=None):
# We'll try to perform a zero-copy conversion if possible.
if is_cupy_array(data):
array = data
elif is_torch_gpu_array(data):
elif is_torch_cuda_array(data):
array = torch2xp(data)
elif is_tensorflow_gpu_array(data):
array = tensorflow2xp(data)
Expand Down
26 changes: 26 additions & 0 deletions thinc/backends/mps_ops.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
from typing import TYPE_CHECKING
import numpy

from .. import registry
from . import NumpyOps, Ops

if TYPE_CHECKING:
    # Type checking does not work with dynamic base classes, since MyPy cannot
    # determine against which base class to check. So, always derive from Ops
    # during type checking.
    _Ops = Ops
else:
    try:
        # Prefer the accelerated AppleOps implementation when the optional
        # thinc_apple_ops package is installed.
        from thinc_apple_ops import AppleOps

        _Ops = AppleOps
    except ImportError:
        # Fall back to the plain NumpyOps implementation otherwise.
        _Ops = NumpyOps


@registry.ops("MPSOps")
class MPSOps(_Ops):
    """Ops class for Metal Performance shaders.

    Inherits from AppleOps when thinc_apple_ops is installed, and from
    NumpyOps otherwise (see the _Ops selection above in this module).
    """

    # Backend identifier used when selecting ops by name.
    name = "mps"
    # Array namespace used by this backend (host-side numpy arrays).
    xp = numpy
5 changes: 1 addition & 4 deletions thinc/backends/numpy_ops.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,6 @@ except ImportError:
has_blis = False


cblas = CBlas()


ctypedef float weight_t


Expand Down Expand Up @@ -88,7 +85,7 @@ class NumpyOps(Ops):
return self.xp.empty(shape, dtype=dtype)

def cblas(self) -> CBlas:
return cblas
return CBlas()

def gemm(self, np.ndarray x, np.ndarray y, *, np.ndarray out=None, trans1=False, trans2=False):
if x.ndim != 2:
Expand Down
13 changes: 12 additions & 1 deletion thinc/compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,13 @@
import torch

has_torch = True
has_torch_gpu = torch.cuda.device_count() != 0
has_torch_cuda_gpu = torch.cuda.device_count() != 0
has_torch_mps_gpu = (
hasattr(torch, "has_mps")
and torch.has_mps
and torch.backends.mps.is_available()
)
has_torch_gpu = has_torch_cuda_gpu
torch_version = Version(str(torch.__version__))
has_torch_amp = (
torch_version >= Version("1.9.0")
Expand All @@ -40,7 +46,9 @@
except ImportError: # pragma: no cover
torch = None # type: ignore
has_torch = False
has_torch_cuda_gpu = False
has_torch_gpu = False
has_torch_mps_gpu = False
has_torch_amp = False
torch_version = Version("0.0.0")

Expand Down Expand Up @@ -68,3 +76,6 @@
import h5py
except ImportError: # pragma: no cover
h5py = None


has_gpu = has_cupy_gpu or has_torch_mps_gpu
Loading