Skip to content
Merged
24 changes: 11 additions & 13 deletions dpnp/backend/extensions/blas/blas_py.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,13 +64,13 @@ PYBIND11_MODULE(_blas_impl, m)
blas_ext::DotContigFactory>(
dot_dispatch_vector);

auto dot_pypi = [&](sycl::queue exec_q, arrayT src1, arrayT src2,
arrayT dst, const event_vecT &depends = {}) {
auto dot_pyapi = [&](sycl::queue exec_q, arrayT src1, arrayT src2,
arrayT dst, const event_vecT &depends = {}) {
return dot_ext::dot_func(exec_q, src1, src2, dst, depends,
dot_dispatch_vector);
};

m.def("_dot", dot_pypi,
m.def("_dot", dot_pyapi,
"Call `dot` from OneMKL BLAS library to return "
"the dot product of two real-valued vectors.",
py::arg("sycl_queue"), py::arg("vectorA"), py::arg("vectorB"),
Expand All @@ -82,13 +82,13 @@ PYBIND11_MODULE(_blas_impl, m)
blas_ext::DotcContigFactory>(
dotc_dispatch_vector);

auto dotc_pypi = [&](sycl::queue exec_q, arrayT src1, arrayT src2,
arrayT dst, const event_vecT &depends = {}) {
auto dotc_pyapi = [&](sycl::queue exec_q, arrayT src1, arrayT src2,
arrayT dst, const event_vecT &depends = {}) {
return dot_ext::dot_func(exec_q, src1, src2, dst, depends,
dotc_dispatch_vector);
};

m.def("_dotc", dotc_pypi,
m.def("_dotc", dotc_pyapi,
"Call `dotc` from OneMKL BLAS library to return "
"the dot product of two complex vectors, "
"conjugating the first vector.",
Expand All @@ -101,13 +101,13 @@ PYBIND11_MODULE(_blas_impl, m)
blas_ext::DotuContigFactory>(
dotu_dispatch_vector);

auto dotu_pypi = [&](sycl::queue exec_q, arrayT src1, arrayT src2,
arrayT dst, const event_vecT &depends = {}) {
auto dotu_pyapi = [&](sycl::queue exec_q, arrayT src1, arrayT src2,
arrayT dst, const event_vecT &depends = {}) {
return dot_ext::dot_func(exec_q, src1, src2, dst, depends,
dotu_dispatch_vector);
};

m.def("_dotu", dotu_pypi,
m.def("_dotu", dotu_pyapi,
"Call `dotu` from OneMKL BLAS library to return "
"the dot product of two complex vectors.",
py::arg("sycl_queue"), py::arg("vectorA"), py::arg("vectorB"),
Expand All @@ -119,16 +119,14 @@ PYBIND11_MODULE(_blas_impl, m)
"Call `gemm` from OneMKL BLAS library to return "
"the matrix-matrix product with 2-D matrices.",
py::arg("sycl_queue"), py::arg("matrixA"), py::arg("matrixB"),
py::arg("result"), py::arg("depends") = py::list());
py::arg("resultC"), py::arg("depends") = py::list());
}

{
m.def("_gemm_batch", &blas_ext::gemm_batch,
"Call `gemm_batch` from OneMKL BLAS library to return "
"the matrix-matrix product for a batch of 2-D matrices.",
py::arg("sycl_queue"), py::arg("matrixA"), py::arg("matrixB"),
py::arg("result"), py::arg("batch_size"), py::arg("stridea"),
py::arg("strideb"), py::arg("stridec"),
py::arg("depends") = py::list());
py::arg("resultC"), py::arg("depends") = py::list());
}
}
78 changes: 63 additions & 15 deletions dpnp/backend/extensions/blas/gemm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ typedef sycl::event (*gemm_impl_fn_ptr_t)(sycl::queue &,
const std::int64_t,
char *,
const std::int64_t,
bool,
const std::vector<sycl::event> &);

static gemm_impl_fn_ptr_t gemm_dispatch_table[dpctl_td_ns::num_types]
Expand All @@ -77,6 +78,7 @@ static sycl::event gemm_impl(sycl::queue &exec_q,
const std::int64_t ldb,
char *resultC,
const std::int64_t ldc,
bool is_row_major,
const std::vector<sycl::event> &depends)
{
type_utils::validate_type_for_device<Tab>(exec_q);
Expand All @@ -91,7 +93,25 @@ static sycl::event gemm_impl(sycl::queue &exec_q,

sycl::event gemm_event;
try {
gemm_event = mkl_blas::row_major::gemm(
// Layout-selecting wrapper around the oneMKL gemm entry points.
// `is_row_major` is captured by reference from the enclosing scope: when it
// is true the call is forwarded to mkl_blas::row_major::gemm, otherwise to
// mkl_blas::column_major::gemm. All arguments are passed through unchanged
// and the event returned by the selected routine is handed back to the caller.
auto gemm_func =
    [&](sycl::queue &queue, oneapi::mkl::transpose opA,
        oneapi::mkl::transpose opB, std::int64_t rows, std::int64_t cols,
        std::int64_t inner, Tab scale_ab, const Tab *lhs, std::int64_t lhs_ld,
        const Tab *rhs, std::int64_t rhs_ld, Tab scale_c, Tc *out,
        std::int64_t out_ld,
        const std::vector<sycl::event> &wait_for) -> sycl::event {
    // Column-major is the exceptional path; handle it first and fall
    // through to the row-major call.
    if (!is_row_major) {
        return mkl_blas::column_major::gemm(queue, opA, opB, rows, cols,
                                            inner, scale_ab, lhs, lhs_ld,
                                            rhs, rhs_ld, scale_c, out,
                                            out_ld, wait_for);
    }
    return mkl_blas::row_major::gemm(queue, opA, opB, rows, cols, inner,
                                     scale_ab, lhs, lhs_ld, rhs, rhs_ld,
                                     scale_c, out, out_ld, wait_for);
};
gemm_event = gemm_func(
exec_q,
transA, // Defines the transpose operation for matrix A:
// 'N' indicates no transpose, 'T' for transpose,
Expand Down Expand Up @@ -130,7 +150,7 @@ static sycl::event gemm_impl(sycl::queue &exec_q,
return gemm_event;
}

std::pair<sycl::event, sycl::event>
std::tuple<sycl::event, sycl::event, bool>
gemm(sycl::queue &exec_q,
dpctl::tensor::usm_ndarray matrixA,
dpctl::tensor::usm_ndarray matrixB,
Expand Down Expand Up @@ -208,16 +228,44 @@ std::pair<sycl::event, sycl::event>
throw py::value_error(
"Result array is not c-contiguous nor f-contiguous.");
}
oneapi::mkl::transpose transA = is_matrixA_f_contig
? oneapi::mkl::transpose::T
: oneapi::mkl::transpose::N;
oneapi::mkl::transpose transB = is_matrixB_f_contig
? oneapi::mkl::transpose::T
: oneapi::mkl::transpose::N;
// Choose the memory layout passed to the gemm call. Row-major is the
// default; column-major is selected only when BOTH input matrices are
// F-contiguous.
bool is_row_major = true;
if (is_matrixA_f_contig && is_matrixB_f_contig) {
is_row_major = false;
}
// Transpose flags for each input; both are assigned on every path of the
// if/else below.
oneapi::mkl::transpose transA;
oneapi::mkl::transpose transB;
if (is_row_major) {
// Row-major call: an F-contiguous input is flagged as transposed (T),
// any other input as not transposed (N).
transA = is_matrixA_f_contig ? oneapi::mkl::transpose::T
: oneapi::mkl::transpose::N;
transB = is_matrixB_f_contig ? oneapi::mkl::transpose::T
: oneapi::mkl::transpose::N;
}
else {
// Column-major call (both inputs F-contiguous): no transpose is requested
// for either input.
transA = oneapi::mkl::transpose::N;
transB = oneapi::mkl::transpose::N;
}

const std::int64_t lda = (transA == oneapi::mkl::transpose::N) ? k : m;
const std::int64_t ldb = (transB == oneapi::mkl::transpose::N) ? n : k;
const std::int64_t ldc = n; // always n for row_major
// Leading dimensions handed to the gemm call for A (lda), B (ldb) and the
// result (ldc). Their values depend on the layout chosen via `is_row_major`
// and, for row-major, on the transpose flags.
std::int64_t lda;
std::int64_t ldb;
if (is_row_major) {
// Row-major: lda is k when A is not transposed, m when it is.
if (transA == oneapi::mkl::transpose::N) {
lda = k;
}
else {
lda = m;
}
// Row-major: ldb is n when B is not transposed, k when it is.
if (transB == oneapi::mkl::transpose::N) {
ldb = n;
}
else {
ldb = k;
}
}
else {
// Column-major: fixed values for both inputs.
lda = m;
ldb = k;
}
// Result leading dimension: n for row-major, m for column-major.
const std::int64_t ldc = is_row_major ? n : m;

int matrixA_typenum = matrixA.get_typenum();
int matrixB_typenum = matrixB.get_typenum();
Expand All @@ -242,14 +290,14 @@ std::pair<sycl::event, sycl::event>
char *b_typeless_ptr = matrixB.get_data();
char *r_typeless_ptr = resultC.get_data();

sycl::event gemm_ev =
gemm_fn(exec_q, transA, transB, m, n, k, a_typeless_ptr, lda,
b_typeless_ptr, ldb, r_typeless_ptr, ldc, depends);
sycl::event gemm_ev = gemm_fn(exec_q, transA, transB, m, n, k,
a_typeless_ptr, lda, b_typeless_ptr, ldb,
r_typeless_ptr, ldc, is_row_major, depends);

sycl::event args_ev = dpctl::utils::keep_args_alive(
exec_q, {matrixA, matrixB, resultC}, {gemm_ev});

return std::make_pair(args_ev, gemm_ev);
return std::make_tuple(args_ev, gemm_ev, is_row_major);
}

template <typename fnT, typename Tab, typename Tc>
Expand Down
8 changes: 2 additions & 6 deletions dpnp/backend/extensions/blas/gemm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,22 +38,18 @@ namespace ext
{
namespace blas
{
extern std::pair<sycl::event, sycl::event>
extern std::tuple<sycl::event, sycl::event, bool>
gemm(sycl::queue &exec_q,
dpctl::tensor::usm_ndarray matrixA,
dpctl::tensor::usm_ndarray matrixB,
dpctl::tensor::usm_ndarray resultC,
const std::vector<sycl::event> &depends);

extern std::pair<sycl::event, sycl::event>
extern std::tuple<sycl::event, sycl::event, bool>
gemm_batch(sycl::queue &exec_q,
dpctl::tensor::usm_ndarray matrixA,
dpctl::tensor::usm_ndarray matrixB,
dpctl::tensor::usm_ndarray resultC,
const std::int64_t batch_size,
size_t stridea,
size_t strideb,
size_t stridec,
const std::vector<sycl::event> &depends);

extern void init_gemm_dispatch_table(void);
Expand Down
Loading