Skip to content

Commit 8fe7469

Browse files
zuje123zhengfaan
andauthored
[WIP] Fix bs zero for deepep (#385)
* low_latency a3 support bs=0 * normal a3 support bs=0 * low_latency combine add output param for bs=0 * low_latency a2 support bs=0 * internode normal a2 support bs=0 * intranode normal a2 support bs=0 * low_latency a2 dual support bs=0 * deepep adapter remove padding * fused_deep_moe support bs=0 * fix codelint * remove unuse variable --------- Co-authored-by: zhengfaan <zhengfaan@outlook.com>
1 parent 45aa6eb commit 8fe7469

28 files changed

Lines changed: 107 additions & 315 deletions

csrc/deepep/deep_ep.cpp

Lines changed: 41 additions & 263 deletions
Large diffs are not rendered by default.

csrc/deepep/deep_ep.hpp

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -27,11 +27,6 @@ struct Buffer {
2727
bool combine_enable_long_seq = false; // Whether to enable the Combine Ant Migration feature
2828

2929
bool low_latency_mode = false;
30-
bool is_padding = false;
31-
int padding_cnt = 0;
32-
at::Tensor ori_x;
33-
at::Tensor new_topk_idx;
34-
at::Tensor new_scales;
3530
at::Tensor notify_send_data; // only for internode notify
3631
at::Tensor send_token_idx_small;
3732
int notify_send_data_size; // only for internode notify

csrc/deepep/ops/op_host/cam_moe_combine_normal_tiling.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -435,7 +435,7 @@ static bool CheckAttrs(gert::TilingContext *context, CamMoeCombineNormalTilingDa
435435
// 校验输入topkWeights的维度0并设bs
436436
const gert::StorageShape *topkWeightsStorageShape = context->GetInputShape(TOPK_WEIGHTS_INDEX);
437437
int64_t topkWeightsDim0 = topkWeightsStorageShape->GetStorageShape().GetDim(0);
438-
OP_TILING_CHECK((topkWeightsDim0 <= 0) || (topkWeightsDim0 > BS_UPPER_BOUND),
438+
OP_TILING_CHECK((topkWeightsDim0 < 0) || (topkWeightsDim0 > BS_UPPER_BOUND),
439439
OP_LOGE(nodeName, "Invalid topkWeights dims0(BS) %ld. Should be between [1, %ld].", topkWeightsDim0,
440440
BS_UPPER_BOUND),
441441
return false);

csrc/deepep/ops/op_host/cam_moe_dispatch_normal_tiling.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -378,7 +378,7 @@ static ge::graphStatus CheckAttrs(gert::TilingContext *context, const char *node
378378
// 校验输入x的dim 0并设bs
379379
const gert::StorageShape *xStorageShape = context->GetInputShape(X_INDEX);
380380
const int64_t xDim0 = xStorageShape->GetStorageShape().GetDim(0);
381-
OP_TILING_CHECK((xDim0 > BS_UPPER_BOUND) || (xDim0 <= 0),
381+
OP_TILING_CHECK((xDim0 > BS_UPPER_BOUND) || (xDim0 < 0),
382382
OP_LOGE(nodeName, "xDim0(BS) is invalid. Should be between [1, %ld], but got xDim0=%ld.",
383383
BS_UPPER_BOUND, xDim0),
384384
return ge::GRAPH_FAILED);

csrc/deepep/ops/op_host/fused_deep_moe_tiling.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ constexpr uint32_t ATTR_SHARE_EXPERT_RANK_NUM_INDEX = 5;
8282
constexpr uint32_t ATTR_QUANT_MODE_INDEX = 6;
8383
constexpr uint32_t ATTR_GLOBAL_BS_INDEX = 7;
8484

85-
constexpr uint32_t MIN_BATCH_SIZE = 1;
85+
constexpr uint32_t MIN_BATCH_SIZE = 0;
8686
constexpr uint32_t MAX_BATCH_SIZE = 256;
8787
constexpr uint32_t MAX_MOE_EXERT_NUM = 512;
8888
constexpr uint32_t SUPPORT_TOP_K = 12;

csrc/deepep/ops/op_host/moe_distribute_combine_v2_def.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,11 @@ class MoeDistributeCombineV2 : public OpDef
114114
.DataType({ge::DT_BF16, ge::DT_FLOAT16, ge::DT_BF16, ge::DT_FLOAT16})
115115
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
116116
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
117+
this->Output("combine_send_cost_stats")
118+
.ParamType(OPTIONAL)
119+
.DataType({ge::DT_INT32, ge::DT_INT32, ge::DT_INT32, ge::DT_INT32})
120+
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
121+
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
117122

118123
this->Attr("group_ep").AttrType(REQUIRED).String();
119124
this->Attr("ep_world_size").AttrType(REQUIRED).Int();

csrc/deepep/ops/op_host/moe_distribute_combine_v2_tiling.cpp

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -849,9 +849,7 @@ static bool CheckTensorShape(const gert::TilingContext *context, MoeDistributeCo
849849
const gert::StorageShape *xStorageShape = context->GetOutputShape(OUTPUT_X_INDEX);
850850
int64_t xDim0 = xStorageShape->GetStorageShape().GetDim(0);
851851
int64_t xDim1 = xStorageShape->GetStorageShape().GetDim(1);
852-
OP_TILING_CHECK(xDim0 != expertIdsDim0,
853-
OP_LOGE(nodeName, "x's dim0 not equal to bs, bs = %ld, x's dim0 = %ld", expertIdsDim0, xDim0),
854-
return false);
852+
855853
OP_TILING_CHECK(xDim1 != expandXDim1,
856854
OP_LOGE(nodeName, "x's dim1 not equal to h, x's dim1 = %ld, h = %ld", xDim1, expandXDim1),
857855
return false);
@@ -984,8 +982,8 @@ static bool CheckAttrs(const gert::TilingContext *context, MoeDistributeCombineV
984982
// 校验输入expertIds的维度0并设bs
985983
const gert::StorageShape *expertIdsStorageShape = context->GetInputShape(EXPERT_IDS_INDEX);
986984
int64_t expertIdsDim0 = expertIdsStorageShape->GetStorageShape().GetDim(0);
987-
OP_TILING_CHECK((expertIdsDim0 <= 0) || (expertIdsDim0 > BS_UPPER_BOUND),
988-
OP_LOGE(nodeName, "Invalid expertIds dims0(BS) %ld. Should be between [1, %ld].", expertIdsDim0,
985+
OP_TILING_CHECK((expertIdsDim0 < 0) || (expertIdsDim0 > BS_UPPER_BOUND),
986+
OP_LOGE(nodeName, "Invalid expertIds dims0(BS) %ld. Should be between [0, %ld].", expertIdsDim0,
989987
BS_UPPER_BOUND),
990988
return false);
991989
tilingData.moeDistributeCombineV2Info.bs = static_cast<uint32_t>(expertIdsDim0);

csrc/deepep/ops/op_host/moe_distribute_dispatch_v2_tiling.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -757,8 +757,8 @@ static ge::graphStatus CheckAttrs(const gert::TilingContext *context, const char
757757
// 校验输入x的dim 0并设bs
758758
const gert::StorageShape *xStorageShape = context->GetInputShape(X_INDEX);
759759
const int64_t xDim0 = xStorageShape->GetStorageShape().GetDim(0);
760-
OP_TILING_CHECK((xDim0 > BS_UPPER_BOUND) || (xDim0 <= 0),
761-
OP_LOGE(nodeName, "xDim0(BS) is invalid. Should be between [1, %ld], but got xDim0=%ld.",
760+
OP_TILING_CHECK((xDim0 > BS_UPPER_BOUND) || (xDim0 < 0),
761+
OP_LOGE(nodeName, "xDim0(BS) is invalid. Should be between [0, %ld], but got xDim0=%ld.",
762762
BS_UPPER_BOUND, xDim0),
763763
return ge::GRAPH_FAILED);
764764
tilingData.moeDistributeDispatchV2Info.bs = static_cast<uint32_t>(xDim0);

csrc/deepep/ops/op_host/notify_dispatch_tiling.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -131,8 +131,8 @@ static ge::graphStatus GetAttrAndSetTilingData(gert::TilingContext *context, con
131131
OP_LOGE(nodeName, "sendCount is invalid, only support > 0, but got sendCount=%d.", *sendCountPtr),
132132
return ge::GRAPH_FAILED);
133133
OP_TILING_CHECK(
134-
(*numTokenPtr <= 0),
135-
OP_LOGE(nodeName, "numTokenPtr is invalid, only support > 0, but got numTokenPtr=%d.", *numTokenPtr),
134+
(*numTokenPtr < 0),
135+
OP_LOGE(nodeName, "numTokenPtr is invalid, only support >= 0, but got numTokenPtr=%d.", *numTokenPtr),
136136
return ge::GRAPH_FAILED);
137137

138138
commGroup = std::string(commGroupPtr);

csrc/deepep/ops/op_host/op_api/aclnn_moe_distribute_combine_v2.cpp

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,8 @@ extern aclnnStatus aclnnInnerMoeDistributeCombineV2GetWorkspaceSize(
2323
int64_t epRankId, int64_t moeExpertNum, char *groupTp, int64_t tpWorldSize, int64_t tpRankId,
2424
int64_t expertShardType, int64_t sharedExpertNum, int64_t sharedExpertRankNum, int64_t globalBs, int64_t outDtype,
2525
int64_t commQuantMode, int64_t groupListType, char *commAlg, int64_t zeroExpertNum, int64_t copyExpertNum,
26-
int64_t constExpertNum, const aclTensor *x, uint64_t *workspaceSize, aclOpExecutor **executor);
26+
int64_t constExpertNum, const aclTensor *x, const aclTensor *sendCostStats, uint64_t *workspaceSize,
27+
aclOpExecutor **executor);
2728

2829
extern aclnnStatus aclnnInnerMoeDistributeCombineV2(void *workspace, uint64_t workspaceSize, aclOpExecutor *executor,
2930
aclrtStream stream);
@@ -38,14 +39,15 @@ aclnnStatus aclnnMoeDistributeCombineV2GetWorkspaceSize(
3839
const aclTensor *sharedExpertXOptional, char *groupEp, int64_t epWorldSize, int64_t epRankId, int64_t moeExpertNum,
3940
char *groupTp, int64_t tpWorldSize, int64_t tpRankId, int64_t expertShardType, int64_t sharedExpertNum,
4041
int64_t sharedExpertRankNum, int64_t globalBs, int64_t outDtype, int64_t commQuantMode, int64_t groupListType,
41-
char *commAlg, const aclTensor *xOut, uint64_t *workspaceSize, aclOpExecutor **executor)
42+
char *commAlg, const aclTensor *xOut, const aclTensor *sendCostStats, uint64_t *workspaceSize,
43+
aclOpExecutor **executor)
4244
{
4345
return aclnnInnerMoeDistributeCombineV2GetWorkspaceSize(
4446
expandX, expertIds, assistInfoForCombine, epSendCounts, expertScales, tpSendCountsOptional, xActiveMaskOptional,
4547
activationScaleOptional, weightScaleOptional, groupListOptional, expandScalesOptional, sharedExpertXOptional,
4648
nullptr, nullptr, nullptr, nullptr, nullptr, groupEp, epWorldSize, epRankId, moeExpertNum, groupTp, tpWorldSize,
4749
tpRankId, expertShardType, sharedExpertNum, sharedExpertRankNum, globalBs, outDtype, commQuantMode,
48-
groupListType, commAlg, 0, 0, 0, xOut, workspaceSize, executor);
50+
groupListType, commAlg, 0, 0, 0, xOut, sendCostStats, workspaceSize, executor);
4951
}
5052

5153
aclnnStatus aclnnMoeDistributeCombineV2(void *workspace, uint64_t workspaceSize, aclOpExecutor *executor,

0 commit comments

Comments
 (0)