Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions csrc/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ FILE(GLOB OP_SRCS
${PROJECT_OP_SRC_BASE}/lightning_indexer/op_host/lightning_indexer.cpp
${PROJECT_OP_SRC_BASE}/lightning_indexer/op_host/tiling/lightning_indexer_tiling.cpp
${PROJECT_OP_SRC_BASE}/tri_inv/op_host/tri_inv.cpp
${PROJECT_OP_SRC_BASE}/apply_top_k_top_p_min_p/op_host/apply_top_k_top_p_min_p.cpp
${PROJECT_OP_SRC_BASE}/apply_top_k_top_p_min_p/op_host/tiling/apply_top_k_top_p_min_p_tiling.cpp
)
if(BUILD_CATLASS_MODULE)
list(APPEND OP_SRCS
Expand Down Expand Up @@ -53,6 +55,7 @@ set(WORKSPACE_KERNEL_SRCS
${PROJECT_OP_SRC_BASE}/alloc_extend/op_kernel/alloc_extend_kernel.cpp
${PROJECT_OP_SRC_BASE}/build_tree/op_kernel/build_tree_kernel.cpp
${PROJECT_OP_SRC_BASE}/lightning_indexer/op_kernel/lightning_indexer_kernel.cpp
${PROJECT_OP_SRC_BASE}/apply_top_k_top_p_min_p/op_kernel/apply_top_k_top_p_min_p_kernel.cpp
)
if(BUILD_CATLASS_MODULE)
list(APPEND WORKSPACE_KERNEL_SRCS
Expand Down
65 changes: 65 additions & 0 deletions csrc/apply_top_k_top_p_min_p/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
## Introduction
A top-k, top-p and min-p sampling implementation for Ascend NPUs.

## Sheet 1: Parameters
| Parameter | Dimension | Data Type | Format | Description |
|--------------|--------------------------|----------------------|--------|--------------------------------------------------|
| probs | [batch_size, vocab_size] | float32/float16/bf16 | ND | Probabilities for sampling.<br>The probabilities should be sorted in descending order. |
| k | [batch_size] | int32 | ND | Representing the threshold for top-k sampling. |
| p | [batch_size] | float32/float16/bf16 | ND | Representing the threshold for top-p sampling. |
| min_p | [batch_size] | float32/float16/bf16 | ND | Representing the threshold for min-p sampling.<br>When min_p is nullptr, the min-p sampling will be skipped. |
| sampled_res | [batch_size, vocab_size] | float32/float16/bf16 | ND | The result after sampling.<br>The DataType of sampled_res must be the same as that of probs. |

## Calculation Formula
$$
sampled\_res[b][v] =
\begin{cases}
0 & \text{v >= k[b]} \\
probs[b][v] & \text{v < k[b]}
\end{cases}
$$
$$probs\_sum = cumsum(sampled\_res, dim=-1)$$
$$top\_p\_mask[b][v] = probs\_sum[b][v] - sampled\_res[b][v] > p[b]$$
$$
sampled\_res[b][v] =
\begin{cases}
0 & \text{top\_p\_mask = True} \\
sampled\_res[b][v] & \text{top\_p\_mask = False}
\end{cases}
$$
$$min\_p\_mask[b][v] = sampled\_res[b][v] < sampled\_res[b][0] * min\_p[b]$$
$$
sampled\_res[b][v] =
\begin{cases}
0 & \text{min\_p\_mask = True} \\
sampled\_res[b][v] & \text{min\_p\_mask = False}
\end{cases}
$$
Where $0 \le b \lt batch\_size$, and $0 \le v \lt vocab\_size$.

## Restrictions
1. Only support Ascend A2/A3.
2. $0 \lt k[b] \le vocab\_size$, where $0 \le b \lt batch\_size$; if $k[b] \lt 0$ or $k[b] \gt vocab\_size$, then $k[b]$ will be regarded as $vocab\_size$.
3. $0 \le p[b] \le 1$, where $0 \le b \lt batch\_size$.

## Sample Code
```python
import numpy as np
import torch
import torch_npu
import sgl_kernel_npu

dtype = torch.float16
batch_size = 4
vocab_size = 128

logits = torch.tensor(np.random.uniform(-10, 10, (batch_size, vocab_size))).to(dtype).npu()
k = torch.tensor(np.random.randint(1, vocab_size, (batch_size))).to(torch.int32).npu()
p = torch.tensor(np.random.uniform(0, 1, (batch_size))).to(dtype).npu()
min_p = torch.tensor(np.random.uniform(0, 1, (batch_size))).to(dtype).npu()

probs = torch.softmax(logits, dim=-1)
probs_sort, probs_idx = probs.sort(dim=-1, descending=True, stable=True)

torch.ops.npu.apply_top_k_top_p_min_p(probs_sort, k, p, min_p=min_p)
```
79 changes: 79 additions & 0 deletions csrc/apply_top_k_top_p_min_p/op_host/apply_top_k_top_p_min_p.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
#include <cstdio>
#include <string>
#include "acl/acl.h"
#include "kernel_tiling/kernel_tiling.h"
#include "tiling/platform/platform_ascendc.h"
#include "tiling/apply_top_k_top_p_min_p_tiling.h"
#include "defines.h"
#include "torch_helper.h"
#include "ge_helper.h"
#include "common_tiling.h"
#include "apply_top_k_top_p_min_p_def.h"
#include "common.h"
#include "aclrtlaunch_apply_top_k_top_p_min_p.h"

namespace sglang::ATKTPMPHost {

using namespace ge_helper;
constexpr uint32_t PADDING_BYTE = 32U;

// Builds the output tensor for apply_top_k_top_p_min_p: same shape, dtype and
// device as probs. Rejects any non-positive dimension before allocating.
inline at::Tensor ConstructApplyTopKTopPMinPOutputTensor(const at::Tensor &probs)
{
    const size_t rank = probs.sizes().size();
    for (size_t dim = 0; dim < rank; dim++) {
        TORCH_CHECK(probs.size(dim) > 0,
                    "All values within probs's shape should be greater "
                    "than 0, but shape[",
                    dim, "] is ", probs.size(dim));
    }
    return at::empty_like(probs);
}
} // namespace sglang::ATKTPMPHost

namespace sglang {
namespace npu_kernel {
// Host-side launcher for the apply_top_k_top_p_min_p kernel.
// probs: [batch_size, vocab_size] sorted probabilities; k: [batch_size] int32
// top-k thresholds; p: [batch_size] top-p thresholds; min_p: optional
// [batch_size] min-p thresholds. Returns the sampled result, same shape and
// dtype as probs.
HOST_API at::Tensor apply_top_k_top_p_min_p(const at::Tensor &probs, const at::Tensor &k, const at::Tensor &p,
                                            const c10::optional<at::Tensor> &min_p)
{
    using namespace ATKTPMPHost;
    at::Tensor sampledRes = ConstructApplyTopKTopPMinPOutputTensor(probs);

    auto probsType = probs.scalar_type();

    // The kernel always receives a min_p operand; substitute a 1-element dummy of
    // probs' dtype when the caller omits it (tiling then disables min-p sampling,
    // because minP's shape is only recorded below when min_p.has_value()).
    at::Tensor minP = min_p.has_value()
                          ? min_p.value()
                          : at::empty({1}, at::TensorOptions().dtype(probsType).device(probs.options().device()));

    // Describe every operand (dtype + shape) for host-side tiling.
    ApplyTopKTopPMinPTilingInfo applyTopKTopPMinPInfo;
    applyTopKTopPMinPInfo.opParamInfo.probs.dtype = SCALAR_TYPE_TO_GE_DATATYPE(probsType);
    applyTopKTopPMinPInfo.opParamInfo.probs.shape = probs.sizes();
    applyTopKTopPMinPInfo.opParamInfo.k.dtype = SCALAR_TYPE_TO_GE_DATATYPE(k.scalar_type());
    applyTopKTopPMinPInfo.opParamInfo.k.shape = k.sizes();
    applyTopKTopPMinPInfo.opParamInfo.p.dtype = SCALAR_TYPE_TO_GE_DATATYPE(p.scalar_type());
    applyTopKTopPMinPInfo.opParamInfo.p.shape = p.sizes();
    if (min_p.has_value()) {
        applyTopKTopPMinPInfo.opParamInfo.minP.dtype = SCALAR_TYPE_TO_GE_DATATYPE(minP.scalar_type());
        applyTopKTopPMinPInfo.opParamInfo.minP.shape = minP.sizes();
    }
    applyTopKTopPMinPInfo.opParamInfo.sampledRes.dtype = SCALAR_TYPE_TO_GE_DATATYPE(sampledRes.scalar_type());
    applyTopKTopPMinPInfo.opParamInfo.sampledRes.shape = sampledRes.sizes();

    ApplyTopKTopPMinPTiling applyTopKTopPMinPTiling(&applyTopKTopPMinPInfo);
    TORCH_CHECK(applyTopKTopPMinPTiling.DoTiling() == ge::GRAPH_SUCCESS, "apply_top_k_top_p_min_p DoTiling failed");

    const auto &tilingData = applyTopKTopPMinPTiling.GetTilingData();

    // Fix: size the device tiling buffer from the tiling DATA struct — the object
    // actually copied below. The original used sizeof(ApplyTopKTopPMinPTiling)
    // (the tiling *class*), which both over-allocated and made the memcpy read
    // past the end of tilingData. Round up to 32-byte alignment for the device.
    uint32_t tilingSize = (sizeof(ApplyTopKTopPMinPTilingData) + PADDING_BYTE - 1) / PADDING_BYTE * PADDING_BYTE;
    auto blockDim = tilingData.coreNum;
    // NOTE(review): function-local static — allocated once and reused across
    // calls; assumes a single device per process. Confirm against other ops.
    static auto tilingBuffer =
        at::empty({tilingSize}, at::TensorOptions().dtype(at::kByte).device(probs.options().device()));
    auto cpyRet = aclrtMemcpy(tilingBuffer.data_ptr<uint8_t>(), tilingSize, &tilingData, sizeof(tilingData),
                              ACL_MEMCPY_HOST_TO_DEVICE);
    TORCH_CHECK(cpyRet == ACL_SUCCESS, "apply_top_k_top_p_min_p: copy tiling data to device failed, ret=", cpyRet);
    at::Tensor tilingTensor = at::from_blob(tilingBuffer.data_ptr<uint8_t>(), tilingSize, at::kByte);

    // Scratch workspace sized by tiling (lib-API area + per-element f32 buffer).
    auto workspace = at::empty({applyTopKTopPMinPInfo.workspaceSize},
                               at::TensorOptions().dtype(at::kByte).device(probs.options().device()));
    EXEC_KERNEL_CMD(apply_top_k_top_p_min_p, blockDim, probs, k, p, minP, sampledRes, workspace, tilingTensor);
    return sampledRes;
}
} // namespace npu_kernel
} // namespace sglang
50 changes: 50 additions & 0 deletions csrc/apply_top_k_top_p_min_p/op_host/apply_top_k_top_p_min_p_def.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
/**
* This program is free software, you can redistribute it and/or modify it.
* Copyright (c) 2026 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING
* BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. See LICENSE in the root of
* the software repository for the full text of the License.
*/

/*!
 * \file apply_top_k_top_p_min_p_def.h
* \brief
*/
#include <cstdint>
#include "ge_helper.h"

namespace sglang {
namespace ATKTPMPHost {
using namespace ge_helper;
// Operator registration metadata for ApplyTopKTopPMinP: declares each operand's
// accepted dtypes and formats. min_p is the only OPTIONAL operand; all floating
// operands accept float32/float16/bf16 in ND format.
class ApplyTopKTopPMinP : public OpDef
{
public:
    explicit ApplyTopKTopPMinP(const char *name) : OpDef(name)
    {
        // probs: [batch_size, vocab_size] probabilities, sorted per README.
        this->Input("probs")
            .ParamType(REQUIRED)
            .DataType({ge::DT_FLOAT, ge::DT_FLOAT16, ge::DT_BF16})
            .FormatList({ge::FORMAT_ND})
            .AutoContiguous();
        // k: [batch_size] int32 top-k thresholds.
        this->Input("k").ParamType(REQUIRED).DataTypeList({ge::DT_INT32}).FormatList({ge::FORMAT_ND}).AutoContiguous();
        // p: [batch_size] top-p thresholds; tiling requires its dtype to match probs.
        this->Input("p")
            .ParamType(REQUIRED)
            .DataType({ge::DT_FLOAT, ge::DT_FLOAT16, ge::DT_BF16})
            .FormatList({ge::FORMAT_ND})
            .AutoContiguous();
        // min_p: [batch_size] optional min-p thresholds; min-p sampling is skipped
        // when absent.
        this->Input("min_p")
            .ParamType(OPTIONAL)
            .DataType({ge::DT_FLOAT, ge::DT_FLOAT16, ge::DT_BF16})
            .FormatList({ge::FORMAT_ND})
            .AutoContiguous();
        // sampled_res: [batch_size, vocab_size] output; dtype must match probs.
        this->Output("sampled_res")
            .ParamType(REQUIRED)
            .DataType({ge::DT_FLOAT, ge::DT_FLOAT16, ge::DT_BF16})
            .FormatList({ge::FORMAT_ND});
    }
};
} // namespace ATKTPMPHost
} // namespace sglang
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
/**
* This program is free software, you can redistribute it and/or modify it.
* Copyright (c) 2026 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING
* BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. See LICENSE in the root of
* the software repository for the full text of the License.
*/

/*!
* \file apply_top_k_top_p_min_p_tiling.cpp
* \brief
*/

#include "apply_top_k_top_p_min_p_tiling.h"

using namespace ge;
using namespace AscendC;
using std::map;
using std::string;
namespace sglang::ATKTPMPHost {

// -------------------- ApplyTopKTopPMinPTiling member function definitions --------------------
// Validates operand dtypes. probs must be float16/bf16/float32; p, sampled_res
// (and min_p, when present) must match probs; k must be int32. TORCH_CHECK
// throws on failure, so reaching the end means success.
ge::graphStatus ApplyTopKTopPMinPTiling::CheckDtype()
{
    TORCH_CHECK((tilingInfo_->opParamInfo.probs.dtype == ge::DT_FLOAT16) ||
                    (tilingInfo_->opParamInfo.probs.dtype == ge::DT_BF16) ||
                    (tilingInfo_->opParamInfo.probs.dtype == ge::DT_FLOAT),
                "The data types of probs, p and sampled_res must be float16, bfloat16 or float.");

    TORCH_CHECK(tilingInfo_->opParamInfo.probs.dtype == tilingInfo_->opParamInfo.p.dtype,
                "The data types of probs and p must be the same.");
    TORCH_CHECK(tilingInfo_->opParamInfo.probs.dtype == tilingInfo_->opParamInfo.sampledRes.dtype,
                "The data types of probs and sampled_res must be the same.");

    // Fix: min_p's dtype was never validated. When the optional min_p is present
    // (non-empty shape, same convention as CheckShape), it must match probs, or
    // the kernel would reinterpret its bytes with the wrong element type.
    if (tilingInfo_->opParamInfo.minP.shape.size() != DIM_NUM_ZERO) {
        TORCH_CHECK(tilingInfo_->opParamInfo.probs.dtype == tilingInfo_->opParamInfo.minP.dtype,
                    "The data types of probs and min_p must be the same.");
    }

    TORCH_CHECK(tilingInfo_->opParamInfo.k.dtype == ge::DT_INT32, "The data types of the input k must be int32.");

    return ge::GRAPH_SUCCESS;
}

// Validates operand shapes and records batchSize/vocabSize into tilingData_.
// Also flips needMinPSample on when a non-empty min_p shape was supplied.
// TORCH_CHECK throws on failure, so reaching the end means success.
ge::graphStatus ApplyTopKTopPMinPTiling::CheckShape()
{
    // probs: [batch_size, vocab_size]
    TORCH_CHECK(tilingInfo_->opParamInfo.probs.shape.size() == DIM_NUM_TWO,
                "ApplyTopKTopPMinP: the dimNum of probs should be ", DIM_NUM_TWO, ", but now is ",
                tilingInfo_->opParamInfo.probs.shape.size(), ".");
    tilingData_.batchSize = tilingInfo_->opParamInfo.probs.shape[DIM_IDX_ZERO];
    tilingData_.vocabSize = tilingInfo_->opParamInfo.probs.shape[DIM_IDX_ONE];

    // k: [batch_size]
    TORCH_CHECK(tilingInfo_->opParamInfo.k.shape.size() == DIM_NUM_ONE, "ApplyTopKTopPMinP: the dimNum of k should be ",
                DIM_NUM_ONE, ", but now is ", tilingInfo_->opParamInfo.k.shape.size(), ".");
    int64_t kSize = tilingInfo_->opParamInfo.k.shape[DIM_IDX_ZERO];
    TORCH_CHECK(kSize == tilingData_.batchSize, "ApplyTopKTopPMinP: the shape of k should be [", tilingData_.batchSize,
                "], but now is [", kSize, "].");

    // p: [batch_size]
    TORCH_CHECK(tilingInfo_->opParamInfo.p.shape.size() == DIM_NUM_ONE, "ApplyTopKTopPMinP: the dimNum of p should be ",
                DIM_NUM_ONE, ", but now is ", tilingInfo_->opParamInfo.p.shape.size(), ".");
    int64_t pSize = tilingInfo_->opParamInfo.p.shape[DIM_IDX_ZERO];
    TORCH_CHECK(pSize == tilingData_.batchSize, "ApplyTopKTopPMinP: the shape of p should be [", tilingData_.batchSize,
                "], but now is [", pSize, "].");

    // min_p is optional: a non-empty shape means the caller passed it, which also
    // enables the min-p sampling stage.
    // Fix: the error message previously said "the shape of p" and lacked the
    // op-name prefix used by every other check here.
    if (tilingInfo_->opParamInfo.minP.shape.size() != DIM_NUM_ZERO) {
        int64_t minPSize = tilingInfo_->opParamInfo.minP.shape[DIM_IDX_ZERO];
        TORCH_CHECK(minPSize == tilingData_.batchSize, "ApplyTopKTopPMinP: the shape of min_p should be [",
                    tilingData_.batchSize, "], but now is [", minPSize, "].");
        tilingInfo_->needMinPSample = 1;
    }

    // sampled_res: [batch_size, vocab_size], same as probs.
    TORCH_CHECK(tilingInfo_->opParamInfo.sampledRes.shape.size() == DIM_NUM_TWO,
                "ApplyTopKTopPMinP: the dimNum of sampled_res should be ", DIM_NUM_TWO, ", but now is ",
                tilingInfo_->opParamInfo.sampledRes.shape.size(), ".");
    int64_t sampledResSize0 = tilingInfo_->opParamInfo.sampledRes.shape[DIM_IDX_ZERO];
    int64_t sampledResSize1 = tilingInfo_->opParamInfo.sampledRes.shape[DIM_IDX_ONE];
    TORCH_CHECK(sampledResSize0 == tilingData_.batchSize && sampledResSize1 == tilingData_.vocabSize,
                "ApplyTopKTopPMinP: the size of sampledRes should be [", tilingData_.batchSize, ", ",
                tilingData_.vocabSize, "], but now is [", sampledResSize0, ", ", sampledResSize1, "].");
    return ge::GRAPH_SUCCESS;
}

void ApplyTopKTopPMinPTiling::SplitTask()
{
tilingData_.loopDataNum = tilingData_.ubSize / BYTES_B32 / LOCAL_TENSOR_NUM / BYTES_PER_REPEAT * BYTES_PER_REPEAT;
tilingData_.coreNum = tilingData_.batchSize > tilingData_.coreNum ? tilingData_.coreNum : tilingData_.batchSize;
tilingData_.batchPerCore = tilingData_.batchSize / std::max(tilingData_.coreNum, static_cast<int64_t>(1));
tilingData_.batchTailCore = tilingData_.batchSize - tilingData_.batchPerCore * tilingData_.coreNum;
}

// Entry point of the tiling computation: validates dtypes/shapes, queries the
// platform (core count, UB size, SoC version), splits the work across cores,
// and fills in workspaceSize and tilingKey. Returns GRAPH_SUCCESS on success;
// TORCH_CHECK throws on any validation failure.
ge::graphStatus ApplyTopKTopPMinPTiling::DoTiling()
{
    if (CheckDtype() != ge::GRAPH_SUCCESS) {
        return ge::GRAPH_FAILED;
    }
    if (CheckShape() != ge::GRAPH_SUCCESS) {
        return ge::GRAPH_FAILED;
    }

    auto ascendcPlatform = *platform_ascendc::PlatformAscendCManager::GetInstance();
    uint32_t aivNum = ascendcPlatform.GetCoreNumAiv();
    uint32_t aicNum = ascendcPlatform.GetCoreNumAic();
    // Fix: the original checked `aivNum != 0` twice and never validated aicNum.
    TORCH_CHECK(aivNum != 0 && aicNum != 0, "num of core obtained is 0");
    tilingData_.coreNum = static_cast<int64_t>(aivNum);

    uint64_t ubSize = 0;
    ascendcPlatform.GetCoreMemSize(platform_ascendc::CoreMemType::UB, ubSize);
    // SELECT_MODE_BYTES of UB are held back — presumably scratch for select-mode
    // instructions; confirm against the kernel.
    tilingData_.ubSize = static_cast<int64_t>(ubSize) - SELECT_MODE_BYTES;

    // Only Ascend A2/A3 (910B / 910_93) are supported, per the README.
    auto socVersion = ascendcPlatform.GetSocVersion();
    TORCH_CHECK(socVersion == platform_ascendc::SocVersion::ASCEND910B ||
                    socVersion == platform_ascendc::SocVersion::ASCEND910_93,
                "soc version does not support ", static_cast<int32_t>(socVersion));

    SplitTask();

    // -------------set workspacesize-----------------
    // Lib-API workspace plus one float32 scratch buffer of [batch_size, vocab_size].
    tilingInfo_->workspaceSize = static_cast<int64_t>(ascendcPlatform.GetLibApiWorkSpaceSize()) +
                                 tilingData_.batchSize * tilingData_.vocabSize * BYTES_B32;

    // -------------set tilingkey-----------------
    // Key encodes the probs dtype in the tens digit and the min-p-enabled flag
    // in the units digit.
    tilingData_.tilingKey =
        G_DTYPE_MAP.at(tilingInfo_->opParamInfo.probs.dtype) * COEF_TEN + tilingInfo_->needMinPSample;

    return ge::GRAPH_SUCCESS;
}

// Read-only accessor for the tiling data filled in by DoTiling(); the returned
// reference stays valid for the lifetime of this tiling object.
const ApplyTopKTopPMinPTilingData &ApplyTopKTopPMinPTiling::GetTilingData() const
{
    return tilingData_;
}
} // namespace sglang::ATKTPMPHost
Loading