Skip to content

Commit 7ce6e47

Browse files
committed
add conv1d _fn op interface
1 parent 7eff2db commit 7ce6e47

4 files changed

Lines changed: 151 additions & 0 deletions

File tree

xllm/core/kernels/npu/xllm_ops/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ cc_library(
1313
beam_search.cpp
1414
select_unshared_kv.cpp
1515
beam_search_rec.cpp
16+
causal_conv1d.cpp
1617
DEPS
1718
atb
1819
torch_npu
Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,137 @@
1+
/* Copyright 2025 The xLLM Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
https://github.com/jd-opensource/xllm/blob/main/LICENSE
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
16+
#include <c10/core/Device.h>
17+
#include <glog/logging.h>
18+
#include <torch/torch.h>
19+
#include <torch_npu/csrc/libs/init_npu.h>
20+
#include <torch_npu/torch_npu.h>
21+
22+
#include <nlohmann/json.hpp>
23+
#ifdef TORCH_HIGHER_THAN_PTA6
24+
#include <torch_npu/csrc/framework/OpCommand.h>
25+
#else
26+
#include <torch_npu/csrc/aten/NPUNativeFunctions.h>
27+
#include <torch_npu/csrc/framework/utils/OpPreparation.h>
28+
#endif
29+
30+
#include "acl/acl.h"
31+
#include "aclnn_causal_conv1d.h"
32+
#include "core/common/macros.h"
33+
#include "core/kernels/npu/utils.h"
34+
#include "xllm_ops_api.h"
35+
36+
namespace xllm::kernel::npu {
37+
38+
torch::Tensor causal_conv1d(const torch::Tensor& x,
39+
const torch::Tensor& weight,
40+
const torch::Tensor& conv_state,
41+
const std::optional<torch::Tensor>& bias_opt,
42+
const torch::IntArrayRef query_start_loc_opt,
43+
const torch::IntArrayRef cache_indices_opt,
44+
const torch::IntArrayRef initial_state_mode_opt,
45+
const torch::IntArrayRef num_accepted_tokens_opt,
46+
int64_t activation_mode,
47+
int64_t pad_slot_id,
48+
int64_t run_mode) {
49+
check_tensor(x, "x", "causal_conv1d");
50+
check_tensor(weight, "weight", "causal_conv1d");
51+
check_tensor(conv_state, "conv_state", "causal_conv1d");
52+
53+
aclTensor* x_ids = nullptr;
54+
aclTensor* weight_ids = nullptr;
55+
aclTensor* bias_ids = nullptr;
56+
aclTensor* conv_state_ids = nullptr;
57+
aclTensor* output_ids = nullptr;
58+
aclIntArray* query_start_loc_ids = nullptr;
59+
aclIntArray* cache_indices_ids = nullptr;
60+
aclIntArray* initial_state_mode_ids = nullptr;
61+
aclIntArray* num_accepted_tokens_ids = nullptr;
62+
63+
int32_t device_id = x.device().index();
64+
aclrtStream stream = c10_npu::getCurrentNPUStream(device_id).stream();
65+
66+
create_acltensor(&x_ids, x);
67+
create_acltensor(&weight_ids, weight);
68+
create_acltensor(&conv_state_ids, conv_state);
69+
if (bias_opt.has_value() && bias_opt.value().defined()) {
70+
create_acltensor(&bias_ids, bias_opt.value());
71+
}
72+
query_start_loc_ids = aclCreateIntArray(query_start_loc_opt.data(),
73+
query_start_loc_opt.size());
74+
cache_indices_ids = aclCreateIntArray(cache_indices_opt.data(),
75+
cache_indices_opt.size());
76+
initial_state_mode_ids = aclCreateIntArray(initial_state_mode_opt.data(),
77+
initial_state_mode_opt.size());
78+
num_accepted_tokens_ids = aclCreateIntArray(num_accepted_tokens_opt.data(),
79+
num_accepted_tokens_opt.size());
80+
81+
torch::Tensor output = torch::empty(x.sizes(), x.options());
82+
create_acltensor(&output_ids, output);
83+
84+
uint64_t workspace_size = 0;
85+
aclOpExecutor* executor = nullptr;
86+
87+
CHECK_ACL_SUCCESS(
88+
aclnnCausalConv1dGetWorkspaceSize(x_ids,
89+
weight_ids,
90+
bias_ids,
91+
conv_state_ids,
92+
query_start_loc_ids,
93+
cache_indices_ids,
94+
initial_state_mode_ids,
95+
num_accepted_tokens_ids,
96+
activation_mode,
97+
pad_slot_id,
98+
run_mode,
99+
output_ids,
100+
&workspace_size,
101+
&executor),
102+
"causal_conv1d: failed to get workspace size");
103+
104+
void* workspace_addr = nullptr;
105+
if (workspace_size > 0) {
106+
CHECK_ACL_SUCCESS(
107+
aclrtMalloc(&workspace_addr, workspace_size, ACL_MEM_MALLOC_HUGE_FIRST),
108+
"causal_conv1d: failed to allocate workspace");
109+
}
110+
111+
CHECK_ACL_SUCCESS(
112+
aclnnCausalConv1d(workspace_addr, workspace_size, executor, stream),
113+
"causal_conv1d: failed to perform causal conv1d");
114+
115+
CHECK_ACL_SUCCESS(aclrtSynchronizeStream(stream),
116+
"causal_conv1d: failed to synchronize stream");
117+
118+
aclDestroyTensor(x_ids);
119+
aclDestroyTensor(weight_ids);
120+
aclDestroyTensor(conv_state_ids);
121+
aclDestroyTensor(output_ids);
122+
if (bias_ids != nullptr) {
123+
aclDestroyTensor(bias_ids);
124+
}
125+
aclDestroyIntArray(query_start_loc_ids);
126+
aclDestroyIntArray(cache_indices_ids);
127+
aclDestroyIntArray(initial_state_mode_ids);
128+
aclDestroyIntArray(num_accepted_tokens_ids);
129+
130+
if (workspace_size > 0) {
131+
CHECK_ACL_SUCCESS(aclrtFree(workspace_addr),
132+
"causal_conv1d: failed to free workspace");
133+
}
134+
135+
return output;
136+
}
137+
} // namespace xllm::kernel::npu

xllm/core/kernels/npu/xllm_ops/xllm_ops_api.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,4 +53,16 @@ void select_unshared_kv(const torch::Tensor& beam_index,
5353
int64_t decode_step,
5454
int64_t beam_size,
5555
int64_t layer_num);
56+
57+
torch::Tensor causal_conv1d(const torch::Tensor& x,
58+
const torch::Tensor& weight,
59+
const torch::Tensor& conv_state,
60+
const std::optional<torch::Tensor>& bias_opt,
61+
const torch::IntArrayRef query_start_loc_opt,
62+
const torch::IntArrayRef cache_indices_opt,
63+
const torch::IntArrayRef initial_state_mode_opt,
64+
const torch::IntArrayRef num_accepted_tokens_opt,
65+
int64_t activation_mode,
66+
int64_t pad_slot_id,
67+
int64_t run_mode);
5668
} // namespace xllm::kernel::npu

xllm/core/layers/npu_torch/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,4 +25,5 @@ cc_library(
2525
qwen3_5_decoder_layer_impl.cpp
2626
DEPS
2727
:common_layers
28+
$<$<BOOL:${USE_NPU}>:xllm_ops>
2829
)

0 commit comments

Comments
 (0)