1+ /* Copyright 2025 The xLLM Authors. All Rights Reserved.
2+
3+ Licensed under the Apache License, Version 2.0 (the "License");
4+ you may not use this file except in compliance with the License.
5+ You may obtain a copy of the License at
6+
7+ https://github.com/jd-opensource/xllm/blob/main/LICENSE
8+
9+ Unless required by applicable law or agreed to in writing, software
10+ distributed under the License is distributed on an "AS IS" BASIS,
11+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+ See the License for the specific language governing permissions and
13+ limitations under the License.
14+ ==============================================================================*/
15+
16+ #include < c10/core/Device.h>
17+ #include < glog/logging.h>
18+ #include < torch/torch.h>
19+ #include < torch_npu/csrc/libs/init_npu.h>
20+ #include < torch_npu/torch_npu.h>
21+
22+ #include < nlohmann/json.hpp>
23+ #ifdef TORCH_HIGHER_THAN_PTA6
24+ #include < torch_npu/csrc/framework/OpCommand.h>
25+ #else
26+ #include < torch_npu/csrc/aten/NPUNativeFunctions.h>
27+ #include < torch_npu/csrc/framework/utils/OpPreparation.h>
28+ #endif
29+
30+ #include " acl/acl.h"
31+ #include " aclnn_causal_conv1d.h"
32+ #include " core/common/macros.h"
33+ #include " core/kernels/npu/utils.h"
34+ #include " xllm_ops_api.h"
35+
36+ namespace xllm ::kernel::npu {
37+
38+ torch::Tensor causal_conv1d (const torch::Tensor& x,
39+ const torch::Tensor& weight,
40+ const torch::Tensor& conv_state,
41+ const std::optional<torch::Tensor>& bias_opt,
42+ const torch::IntArrayRef query_start_loc_opt,
43+ const torch::IntArrayRef cache_indices_opt,
44+ const torch::IntArrayRef initial_state_mode_opt,
45+ const torch::IntArrayRef num_accepted_tokens_opt,
46+ int64_t activation_mode,
47+ int64_t pad_slot_id,
48+ int64_t run_mode) {
49+ check_tensor (x, " x" , " causal_conv1d" );
50+ check_tensor (weight, " weight" , " causal_conv1d" );
51+ check_tensor (conv_state, " conv_state" , " causal_conv1d" );
52+
53+ aclTensor* x_ids = nullptr ;
54+ aclTensor* weight_ids = nullptr ;
55+ aclTensor* bias_ids = nullptr ;
56+ aclTensor* conv_state_ids = nullptr ;
57+ aclTensor* output_ids = nullptr ;
58+ aclIntArray* query_start_loc_ids = nullptr ;
59+ aclIntArray* cache_indices_ids = nullptr ;
60+ aclIntArray* initial_state_mode_ids = nullptr ;
61+ aclIntArray* num_accepted_tokens_ids = nullptr ;
62+
63+ int32_t device_id = x.device ().index ();
64+ aclrtStream stream = c10_npu::getCurrentNPUStream (device_id).stream ();
65+
66+ create_acltensor (&x_ids, x);
67+ create_acltensor (&weight_ids, weight);
68+ create_acltensor (&conv_state_ids, conv_state);
69+ if (bias_opt.has_value () && bias_opt.value ().defined ()) {
70+ create_acltensor (&bias_ids, bias_opt.value ());
71+ }
72+ query_start_loc_ids = aclCreateIntArray (query_start_loc_opt.data (),
73+ query_start_loc_opt.size ());
74+ cache_indices_ids = aclCreateIntArray (cache_indices_opt.data (),
75+ cache_indices_opt.size ());
76+ initial_state_mode_ids = aclCreateIntArray (initial_state_mode_opt.data (),
77+ initial_state_mode_opt.size ());
78+ num_accepted_tokens_ids = aclCreateIntArray (num_accepted_tokens_opt.data (),
79+ num_accepted_tokens_opt.size ());
80+
81+ torch::Tensor output = torch::empty (x.sizes (), x.options ());
82+ create_acltensor (&output_ids, output);
83+
84+ uint64_t workspace_size = 0 ;
85+ aclOpExecutor* executor = nullptr ;
86+
87+ CHECK_ACL_SUCCESS (
88+ aclnnCausalConv1dGetWorkspaceSize (x_ids,
89+ weight_ids,
90+ bias_ids,
91+ conv_state_ids,
92+ query_start_loc_ids,
93+ cache_indices_ids,
94+ initial_state_mode_ids,
95+ num_accepted_tokens_ids,
96+ activation_mode,
97+ pad_slot_id,
98+ run_mode,
99+ output_ids,
100+ &workspace_size,
101+ &executor),
102+ " causal_conv1d: failed to get workspace size" );
103+
104+ void * workspace_addr = nullptr ;
105+ if (workspace_size > 0 ) {
106+ CHECK_ACL_SUCCESS (
107+ aclrtMalloc (&workspace_addr, workspace_size, ACL_MEM_MALLOC_HUGE_FIRST),
108+ " causal_conv1d: failed to allocate workspace" );
109+ }
110+
111+ CHECK_ACL_SUCCESS (
112+ aclnnCausalConv1d (workspace_addr, workspace_size, executor, stream),
113+ " causal_conv1d: failed to perform causal conv1d" );
114+
115+ CHECK_ACL_SUCCESS (aclrtSynchronizeStream (stream),
116+ " causal_conv1d: failed to synchronize stream" );
117+
118+ aclDestroyTensor (x_ids);
119+ aclDestroyTensor (weight_ids);
120+ aclDestroyTensor (conv_state_ids);
121+ aclDestroyTensor (output_ids);
122+ if (bias_ids != nullptr ) {
123+ aclDestroyTensor (bias_ids);
124+ }
125+ aclDestroyIntArray (query_start_loc_ids);
126+ aclDestroyIntArray (cache_indices_ids);
127+ aclDestroyIntArray (initial_state_mode_ids);
128+ aclDestroyIntArray (num_accepted_tokens_ids);
129+
130+ if (workspace_size > 0 ) {
131+ CHECK_ACL_SUCCESS (aclrtFree (workspace_addr),
132+ " causal_conv1d: failed to free workspace" );
133+ }
134+
135+ return output;
136+ }
137+ } // namespace xllm::kernel::npu
0 commit comments