Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions common/cuda_hip/base/index_set_kernels.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors
// SPDX-FileCopyrightText: 2017 - 2026 The Ginkgo authors
//
// SPDX-License-Identifier: BSD-3-Clause

Expand Down Expand Up @@ -32,10 +32,10 @@ GKO_INSTANTIATE_FOR_EACH_INDEX_TYPE(
template <typename IndexType>
void populate_subsets(std::shared_ptr<const DefaultExecutor> exec,
const IndexType index_space_size,
const array<IndexType>* indices,
array<IndexType>* subset_begin,
array<IndexType>* subset_end,
array<IndexType>* superset_indices,
const array<IndexType>& indices,
array<IndexType>& subset_begin,
array<IndexType>& subset_end,
array<IndexType>& superset_indices,
const bool is_sorted) GKO_NOT_IMPLEMENTED;

GKO_INSTANTIATE_FOR_EACH_INDEX_TYPE(GKO_DECLARE_INDEX_SET_POPULATE_KERNEL);
Expand Down
8 changes: 4 additions & 4 deletions common/cuda_hip/distributed/vector_kernels.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// SPDX-FileCopyrightText: 2017 - 2025 The Ginkgo authors
// SPDX-FileCopyrightText: 2017 - 2026 The Ginkgo authors
//
// SPDX-License-Identifier: BSD-3-Clause

Expand Down Expand Up @@ -28,7 +28,7 @@ void build_local(
const device_matrix_data<ValueType, GlobalIndexType>& input,
const experimental::distributed::Partition<LocalIndexType, GlobalIndexType>*
partition,
comm_index_type local_part, matrix::Dense<ValueType>* local_mtx)
comm_index_type local_part, matrix::view::dense<ValueType> local_mtx)
{
const auto* range_bounds = partition->get_range_bounds();
const auto* range_starting_indices =
Expand Down Expand Up @@ -62,7 +62,7 @@ void build_local(
range_id.get_data())),
map_to_local_row);

auto stride = local_mtx->get_stride();
auto stride = local_mtx.stride;
auto map_to_flat_idx =
[stride] __host__ __device__(
const thrust::tuple<LocalIndexType, GlobalIndexType>& row_col) {
Expand All @@ -80,7 +80,7 @@ void build_local(
thrust::scatter_if(
thrust_policy(exec), input.get_const_values(),
input.get_const_values() + input.get_num_stored_elements(), flat_idx_it,
range_id.get_data(), local_mtx->get_values(), is_local_row);
range_id.get_data(), local_mtx.values, is_local_row);
}

GKO_INSTANTIATE_FOR_EACH_VALUE_AND_LOCAL_GLOBAL_INDEX_TYPE(
Expand Down
46 changes: 24 additions & 22 deletions common/cuda_hip/matrix/coo_kernels.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// SPDX-FileCopyrightText: 2017 - 2025 The Ginkgo authors
// SPDX-FileCopyrightText: 2017 - 2026 The Ginkgo authors
//
// SPDX-License-Identifier: BSD-3-Clause

Expand Down Expand Up @@ -232,7 +232,8 @@ __global__ __launch_bounds__(spmv_block_size) void abstract_spmm(
template <typename ValueType, typename IndexType>
void spmv(std::shared_ptr<const DefaultExecutor> exec,
const matrix::Coo<ValueType, IndexType>* a,
const matrix::Dense<ValueType>* b, matrix::Dense<ValueType>* c)
matrix::view::dense<const ValueType> b,
matrix::view::dense<ValueType> c)
{
dense::fill(exec, c, zero<ValueType>());
spmv2(exec, a, b, c);
Expand All @@ -243,11 +244,11 @@ GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE(GKO_DECLARE_COO_SPMV_KERNEL);

template <typename ValueType, typename IndexType>
void advanced_spmv(std::shared_ptr<const DefaultExecutor> exec,
const matrix::Dense<ValueType>* alpha,
matrix::view::dense<const ValueType> alpha,
const matrix::Coo<ValueType, IndexType>* a,
const matrix::Dense<ValueType>* b,
const matrix::Dense<ValueType>* beta,
matrix::Dense<ValueType>* c)
matrix::view::dense<const ValueType> b,
matrix::view::dense<const ValueType> beta,
matrix::view::dense<ValueType> c)
{
dense::scale(exec, beta, c);
advanced_spmv2(exec, alpha, a, b, c);
Expand All @@ -260,10 +261,11 @@ GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE(
template <typename ValueType, typename IndexType>
void spmv2(std::shared_ptr<const DefaultExecutor> exec,
const matrix::Coo<ValueType, IndexType>* a,
const matrix::Dense<ValueType>* b, matrix::Dense<ValueType>* c)
matrix::view::dense<const ValueType> b,
matrix::view::dense<ValueType> c)
{
const auto nnz = a->get_num_stored_elements();
const auto b_ncols = b->get_size()[1];
const auto b_ncols = b.size[1];
const dim3 coo_block(config::warp_size, warps_in_block, 1);
const auto nwarps = host_kernel::calculate_nwarps(exec, nnz);

Expand Down Expand Up @@ -296,8 +298,8 @@ void spmv2(std::shared_ptr<const DefaultExecutor> exec,
nnz, num_lines, as_device_type(a->get_const_values()),
a->get_const_col_idxs(),
as_device_type(a->get_const_row_idxs()),
as_device_type(b->get_const_values()), b->get_stride(),
as_device_type(c->get_values()), c->get_stride());
as_device_type(b.values), b.stride, as_device_type(c.values),
c.stride);
} else {
int num_elems =
ceildiv(nnz, nwarps * config::warp_size) * config::warp_size;
Expand All @@ -308,8 +310,8 @@ void spmv2(std::shared_ptr<const DefaultExecutor> exec,
nnz, num_elems, as_device_type(a->get_const_values()),
a->get_const_col_idxs(),
as_device_type(a->get_const_row_idxs()), b_ncols,
as_device_type(b->get_const_values()), b->get_stride(),
as_device_type(c->get_values()), c->get_stride());
as_device_type(b.values), b.stride, as_device_type(c.values),
c.stride);
}
}
}
Expand All @@ -319,15 +321,15 @@ GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE(GKO_DECLARE_COO_SPMV2_KERNEL);

template <typename ValueType, typename IndexType>
void advanced_spmv2(std::shared_ptr<const DefaultExecutor> exec,
const matrix::Dense<ValueType>* alpha,
matrix::view::dense<const ValueType> alpha,
const matrix::Coo<ValueType, IndexType>* a,
const matrix::Dense<ValueType>* b,
matrix::Dense<ValueType>* c)
matrix::view::dense<const ValueType> b,
matrix::view::dense<ValueType> c)
{
const auto nnz = a->get_num_stored_elements();
const auto nwarps = host_kernel::calculate_nwarps(exec, nnz);
const dim3 coo_block(config::warp_size, warps_in_block, 1);
const auto b_ncols = b->get_size()[1];
const auto b_ncols = b.size[1];

if (nwarps <= 0 || b_ncols <= 0) {
return;
Expand Down Expand Up @@ -355,23 +357,23 @@ void advanced_spmv2(std::shared_ptr<const DefaultExecutor> exec,
const dim3 coo_grid(ceildiv(nwarps, warps_in_block), b_ncols);

abstract_spmv<<<coo_grid, coo_block, 0, exec->get_stream()>>>(
nnz, num_lines, as_device_type(alpha->get_const_values()),
nnz, num_lines, as_device_type(alpha.values),
as_device_type(a->get_const_values()), a->get_const_col_idxs(),
as_device_type(a->get_const_row_idxs()),
as_device_type(b->get_const_values()), b->get_stride(),
as_device_type(c->get_values()), c->get_stride());
as_device_type(b.values), b.stride, as_device_type(c.values),
c.stride);
} else {
int num_elems =
ceildiv(nnz, nwarps * config::warp_size) * config::warp_size;
const dim3 coo_grid(ceildiv(nwarps, warps_in_block),
ceildiv(b_ncols, config::warp_size));

abstract_spmm<<<coo_grid, coo_block, 0, exec->get_stream()>>>(
nnz, num_elems, as_device_type(alpha->get_const_values()),
nnz, num_elems, as_device_type(alpha.values),
as_device_type(a->get_const_values()), a->get_const_col_idxs(),
as_device_type(a->get_const_row_idxs()), b_ncols,
as_device_type(b->get_const_values()), b->get_stride(),
as_device_type(c->get_values()), c->get_stride());
as_device_type(b.values), b.stride, as_device_type(c.values),
c.stride);
}
}
}
Expand Down
Loading
Loading