diff --git a/.github/workflows/pr-test-npu.yml b/.github/workflows/pr-test-npu.yml
index 4ee7e3f2d..273cfd6cb 100644
--- a/.github/workflows/pr-test-npu.yml
+++ b/.github/workflows/pr-test-npu.yml
@@ -55,7 +55,7 @@ jobs:
       - name: Run test intranode
         timeout-minutes: 10
         env:
-          HCCL_BUFFSIZE: 2300
+          HCCL_BUFFSIZE: 3000
         run: |
           python3 $GITHUB_WORKSPACE/tests/python/deepep/test_intranode.py

@@ -71,7 +71,8 @@ jobs:
       - name: Run test low latency
         timeout-minutes: 10
         env:
-          HCCL_BUFFSIZE: 1913
+          HCCL_BUFFSIZE: 3000
+          MOE_ENABLE_TOPK_NEG_ONE: 1
         run: |
           python3 $GITHUB_WORKSPACE/tests/python/deepep/test_low_latency.py
           python3 $GITHUB_WORKSPACE/tests/python/deepep/test_low_latency.py --num-tokens=1

@@ -114,6 +115,7 @@ jobs:
         timeout-minutes: 10
         env:
           HCCL_BUFFSIZE: 3000
+          MOE_ENABLE_TOPK_NEG_ONE: 1
         run: |
           python3 $GITHUB_WORKSPACE/tests/python/deepep/test_normal_and_low_latency.py

@@ -160,14 +162,15 @@ jobs:
       - name: Run test intranode
         timeout-minutes: 10
         env:
-          HCCL_BUFFSIZE: 2300
+          HCCL_BUFFSIZE: 3000
         run: |
           python3 $GITHUB_WORKSPACE/tests/python/deepep/test_intranode.py

       - name: Run test low latency
         timeout-minutes: 10
         env:
-          HCCL_BUFFSIZE: 1913
+          HCCL_BUFFSIZE: 3000
+          MOE_ENABLE_TOPK_NEG_ONE: 1
         run: |
           python3 $GITHUB_WORKSPACE/tests/python/deepep/test_low_latency.py
           python3 $GITHUB_WORKSPACE/tests/python/deepep/test_low_latency.py --num-tokens=1

@@ -210,6 +213,7 @@ jobs:
         timeout-minutes: 10
         env:
           HCCL_BUFFSIZE: 3000
+          MOE_ENABLE_TOPK_NEG_ONE: 1
         run: |
           python3 $GITHUB_WORKSPACE/tests/python/deepep/test_normal_and_low_latency.py
diff --git a/tests/python/deepep/test_internode.py b/tests/python/deepep/test_internode.py
index a3bbc6699..e44d0f3b2 100644
--- a/tests/python/deepep/test_internode.py
+++ b/tests/python/deepep/test_internode.py
@@ -1,5 +1,6 @@
 import argparse
 import os
+import random
 import time
 from typing import Optional

@@ -34,13 +35,28 @@ def test_main(
     group: dist.ProcessGroup,
 ):
     # Settings
-    num_tokens, hidden = args.num_tokens, args.hidden
+    base_num_tokens, hidden = args.num_tokens, args.hidden
     num_topk, num_experts = args.num_topk, args.num_experts
     enable_diagnose = args.enable_diagnose
     num_servers = num_ranks // num_local_ranks
     num_nodes = num_servers
     expert_token_nums_type = int(os.getenv("MOE_EXPERT_TOKEN_NUMS_TYPE", 1))

+    fluctuation_percentage = 0.1
+    min_fluctuation = 2
+
+    if base_num_tokens < 10:
+        fluctuation = random.randint(-min_fluctuation, min_fluctuation)
+        num_tokens = base_num_tokens + fluctuation
+    else:
+        fluctuation = random.uniform(
+            1 - fluctuation_percentage, 1 + fluctuation_percentage
+        )
+        num_tokens = int(base_num_tokens * fluctuation)
+
+    # Ensure num_tokens is at least 1
+    num_tokens = max(num_tokens, 1)
+
     assert num_experts % num_ranks == 0 and num_nodes >= 2
     assert num_tokens <= MAX_BATCH_SIZE
     if local_rank == 0:
diff --git a/tests/python/deepep/test_intranode.py b/tests/python/deepep/test_intranode.py
index b59dc823a..bf501e820 100644
--- a/tests/python/deepep/test_intranode.py
+++ b/tests/python/deepep/test_intranode.py
@@ -1,5 +1,6 @@
 import argparse
 import os
+import random
 import time
 from typing import Optional

@@ -30,12 +31,27 @@ def test_main(
     group: dist.ProcessGroup,
 ):
     # Settings
-    num_tokens, hidden = args.num_tokens, args.hidden
+    base_num_tokens, hidden = args.num_tokens, args.hidden
     num_topk, num_experts = args.num_topk, args.num_experts
     enable_diagnose = args.enable_diagnose
     num_servers = num_ranks // num_local_ranks
     expert_token_nums_type = int(os.getenv("MOE_EXPERT_TOKEN_NUMS_TYPE", 1))

+    fluctuation_percentage = 0.1
+    min_fluctuation = 2
+
+    if base_num_tokens < 10:
+        fluctuation = random.randint(-min_fluctuation, min_fluctuation)
+        num_tokens = base_num_tokens + fluctuation
+    else:
+        fluctuation = random.uniform(
+            1 - fluctuation_percentage, 1 + fluctuation_percentage
+        )
+        num_tokens = int(base_num_tokens * fluctuation)
+
+    # Ensure num_tokens is at least 1
+    num_tokens = max(num_tokens, 1)
+
     assert num_experts % num_ranks == 0
     if local_rank == 0:
         print(
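The same token-count jitter rule is added to test_internode.py and test_intranode.py above and repeated in the two files below. Read in isolation it behaves as in the following standalone sketch; the helper name jitter_num_tokens is hypothetical, for illustration only, and is not part of this PR:

import random


def jitter_num_tokens(
    base_num_tokens: int,
    fluctuation_percentage: float = 0.1,
    min_fluctuation: int = 2,
) -> int:
    """Return a randomized per-rank token count, floored at 1."""
    if base_num_tokens < 10:
        # Small batches: jitter by an absolute offset in [-2, 2].
        num_tokens = base_num_tokens + random.randint(
            -min_fluctuation, min_fluctuation
        )
    else:
        # Larger batches: jitter multiplicatively by up to +/-10%.
        num_tokens = int(
            base_num_tokens
            * random.uniform(1 - fluctuation_percentage, 1 + fluctuation_percentage)
        )
    return max(num_tokens, 1)

Small base counts (below 10) get an absolute offset so that even the --num-tokens=1 CI runs exercise varying shapes, larger counts get a relative +/-10% jitter, and the floor of 1 keeps every rank non-empty.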
diff --git a/tests/python/deepep/test_low_latency.py b/tests/python/deepep/test_low_latency.py
index 6bebad3f5..b2403cfa9 100644
--- a/tests/python/deepep/test_low_latency.py
+++ b/tests/python/deepep/test_low_latency.py
@@ -20,7 +20,8 @@


 def test(
-    num_tokens: int,
+    aligned_num_tokens: int,  # max token count after cross-rank alignment
+    actual_num_tokens: int,  # this rank's actual number of valid tokens
     hidden: int,
     num_experts: int,
     num_topk: int,
@@ -43,28 +44,57 @@ def test(
         num_ranks - rank_offset < 257
     ), "Too many ranks (exceeding test precision limit)"

-    x = torch.ones((num_tokens, hidden), dtype=torch.bfloat16, device="npu") * (
-        rank - rank_offset
-    )
-    x[:, -128:] = torch.arange(num_tokens, device="npu").to(torch.bfloat16).view(-1, 1)
+    x = torch.zeros((aligned_num_tokens, hidden), dtype=torch.bfloat16, device="npu")
+
+    if actual_num_tokens > 0:
+        x[:actual_num_tokens] = torch.ones(
+            (actual_num_tokens, hidden), dtype=torch.bfloat16, device="npu"
+        ) * (rank - rank_offset)
+        x[:actual_num_tokens, -128:] = (
+            torch.arange(actual_num_tokens, device="npu").to(torch.bfloat16).view(-1, 1)
+        )
+
     scores = (
-        torch.randn((num_tokens, num_experts), dtype=torch.float32, device="npu").abs()
+        torch.randn(
+            (aligned_num_tokens, num_experts), dtype=torch.float32, device="npu"
+        ).abs()
         + 1
     )
-    topk_idx = torch.topk(scores, num_topk, dim=-1, largest=True, sorted=True)[1]
-    if drop_percent > 0:
-        enable_neg_one = int(os.getenv("MOE_ENABLE_TOPK_NEG_ONE", 0))
-        if enable_neg_one == 0:
-            print(
-                "[ERROR] The kernel can't support drop_percent larger than 0 when MOE_ENABLE_TOPK_NEG_ONE was"
-                "unset or 0. Please set to 1 and try again"
+    topk_idx = torch.full(
+        (aligned_num_tokens, num_topk), -1, dtype=torch.long, device="npu"
+    )
+
+    if actual_num_tokens > 0:
+        actual_scores = scores[:actual_num_tokens]
+        actual_topk_idx = torch.topk(
+            actual_scores, num_topk, dim=-1, largest=True, sorted=True
+        )[1]
+        topk_idx[:actual_num_tokens] = actual_topk_idx
+
+        if drop_percent > 0:
+            enable_neg_one = int(os.getenv("MOE_ENABLE_TOPK_NEG_ONE", 0))
+            if enable_neg_one == 0:
+                print(
+                    "[ERROR] The kernel can't support drop_percent larger than 0 when MOE_ENABLE_TOPK_NEG_ONE is "
+                    "unset or 0. Please set it to 1 and try again"
+                )
+            assert enable_neg_one == 1
+            drop_mask = (
+                torch.rand(
+                    (actual_num_tokens, num_topk), dtype=torch.float32, device="npu"
+                )
+                < drop_percent
             )
-            assert enable_neg_one == 1
-        drop_mask = torch.rand_like(topk_idx, dtype=torch.float32) < drop_percent
-        topk_idx = topk_idx.masked_fill(drop_mask, -1)
-    topk_weights = torch.randn(
-        (num_tokens, num_topk), dtype=torch.float32, device="npu"
-    ).abs()
+            topk_idx[:actual_num_tokens] = topk_idx[:actual_num_tokens].masked_fill(
+                drop_mask, -1
+            )
+    topk_weights = torch.zeros(
+        (aligned_num_tokens, num_topk), dtype=torch.float32, device="npu"
+    )
+    if actual_num_tokens > 0:
+        topk_weights[:actual_num_tokens] = torch.randn(
+            (actual_num_tokens, num_topk), dtype=torch.float32, device="npu"
+        ).abs()

     # Check dispatch correctness
     do_check = True
@@ -79,7 +109,7 @@ def test(
         buffer.low_latency_dispatch(
             x,
             topk_idx,
-            num_tokens,
+            aligned_num_tokens,
             num_experts,
             use_fp8=dispatch_use_fp8,
             round_scale=False,
@@ -94,13 +124,21 @@ def test(
     )

     all_topk_idx = torch.empty(
-        (num_ranks, num_tokens, num_topk), dtype=topk_idx.dtype, device="npu"
+        (num_ranks, aligned_num_tokens, num_topk),
+        dtype=topk_idx.dtype,
+        device="npu",
     )
     dist.all_gather_into_tensor(all_topk_idx, topk_idx, group=group)

+    local_actual_tokens = torch.tensor(
+        [actual_num_tokens], dtype=torch.int32, device="npu"
+    )
+    all_actual_tokens = torch.empty(num_ranks, dtype=torch.int32, device="npu")
+    dist.all_gather_into_tensor(all_actual_tokens, local_actual_tokens, group=group)
+
     for i in range(num_local_experts if do_check else 0):
         expert_id = rank * num_local_experts + i
-        temp = num_tokens / num_local_experts
+        temp = aligned_num_tokens / num_local_experts
         recv_count = packed_recv_count[i]
         recv_x = (
             per_token_cast_back(
@@ -120,19 +158,30 @@ def test(
         # Check expert indices
         int_mask = (2**32) - 1
         num_valid_tokens = recv_count.item()
+
+        expected_valid_tokens = 0
+        for r in range(num_ranks):
+            # Number of tokens rank r actually produced
+            r_actual_tokens = all_actual_tokens[r].item()
+            # Number of tokens rank r sent to this expert
+            r_topk_idx = all_topk_idx[r, :r_actual_tokens]
+            expected_valid_tokens += (r_topk_idx == expert_id).sum().item()
+
         assert (
             num_valid_tokens == (recv_layout_range & int_mask).item()
         ), f"{num_valid_tokens} != {recv_layout_range & int_mask}.item()"
         assert (
-            num_valid_tokens == (all_topk_idx == expert_id).sum().item()
-        ), f"{num_valid_tokens} != {(all_topk_idx == expert_id).sum().item()}"
+            num_valid_tokens == expected_valid_tokens
+        ), f"{num_valid_tokens} != {expected_valid_tokens}"
         if num_valid_tokens == 0:
             continue

         # Check received data
         recv_x = recv_x[:num_valid_tokens]
         recv_x_amin = recv_x[:, :-128].amin(dim=-1)
-        assert torch.equal(recv_x_amin, recv_x[:, :-128].amax(dim=-1))
+        assert torch.allclose(
+            recv_x_amin, recv_x[:, :-128].amax(dim=-1), equal_nan=True
+        )
         if dispatch_use_fp8:
             hash_value ^= hash_tensor(
                 packed_recv_x[0][int(i * temp) : int(i * temp + num_valid_tokens)]
@@ -155,7 +204,9 @@ def test(
             packed_recv_count,
         ) = handle

-        out = torch.empty((num_tokens, hidden), dtype=torch.bfloat16, device="npu")
+        out = torch.empty(
+            (aligned_num_tokens, hidden), dtype=torch.bfloat16, device="npu"
+        )
         combined_x, event, hook = buffer.low_latency_combine(
             simulated_gemm_x,
             topk_idx,
@@ -168,25 +219,47 @@ def test(
         )

         if do_check:
-            diff = calc_diff(
-                x * topk_weights.masked_fill(topk_idx == -1, 0).sum(dim=1).view(-1, 1),
-                combined_x,
+            if actual_num_tokens > 0:
+                # Only consider the valid tokens
+                expected_x = torch.zeros(
+                    (aligned_num_tokens, hidden), dtype=torch.bfloat16, device="npu"
+                )
+                expected_x[:actual_num_tokens] = torch.ones(
+                    (actual_num_tokens, hidden), dtype=torch.bfloat16, device="npu"
+                ) * (rank - rank_offset)
+                expected_x[:actual_num_tokens, -128:] = (
+                    torch.arange(actual_num_tokens, device="npu")
+                    .to(torch.bfloat16)
+                    .view(-1, 1)
+                )
+
+                diff = calc_diff(
+                    expected_x[:actual_num_tokens]
+                    * topk_weights[:actual_num_tokens]
+                    .masked_fill(topk_idx[:actual_num_tokens] == -1, 0)
+                    .sum(dim=1)
+                    .view(-1, 1),
+                    combined_x[:actual_num_tokens],
+                )
+                assert torch.isnan(combined_x[:actual_num_tokens]).sum().item() == 0
+                if dispatch_use_fp8:
+                    assert diff < 1e-4, f"Error: {diff=}"
+                else:
+                    assert diff < 1e-5, f"Error: {diff=}"
+            hash_value ^= hash_tensor(
+                combined_x[:actual_num_tokens]
+                if actual_num_tokens > 0
+                else torch.tensor([], device="npu")
             )
-            assert torch.isnan(combined_x).sum().item() == 0
-            if dispatch_use_fp8:
-                assert diff < 1e-4, f"Error: {diff=}"
-            else:
-                assert diff < 1e-5, f"Error: {diff=}"
-            hash_value ^= hash_tensor(combined_x)
-    print(f"rank {rank} PASSED")
+    print(f"rank {rank} PASSED (actual tokens: {actual_num_tokens})")

     # noinspection PyShadowingNames
     def test_func(zero_copy: bool, return_recv_hook: bool):
         recv_x, recv_count, handle, event, hook = buffer.low_latency_dispatch(
             x,
             topk_idx,
-            num_tokens,
+            aligned_num_tokens,
             num_experts,
             cumulative_local_expert_recv_stats=cumulative_local_expert_recv_stats,
             use_fp8=dispatch_use_fp8,
@@ -205,7 +278,7 @@ def test_func(zero_copy: bool, return_recv_hook: bool):
     # Calculate bandwidth
     num_fp8_bytes, num_bf16_bytes = (hidden + hidden // 128 * 4 + 16), hidden * 2
     num_dispatch_comm_bytes, num_combine_comm_bytes = 0, 0
-    for i in range(num_tokens):
+    for i in range(actual_num_tokens):
         num_selections = (topk_idx[i] != -1).sum().item()
         num_dispatch_comm_bytes += num_fp8_bytes * num_selections
         num_combine_comm_bytes += num_bf16_bytes * num_selections
@@ -216,12 +289,12 @@ def test_func(zero_copy: bool, return_recv_hook: bool):
     )
     print(
         f"[rank {rank}] Dispatch + combine bandwidth: {(num_dispatch_comm_bytes + num_combine_comm_bytes) / 1e9 / avg_t:.2f} GB/s, "
-        f"avg_t={avg_t * 1e6:.2f} us, min_t={min_t * 1e6:.2f} us, max_t={max_t * 1e6:.2f} us",
+        f"avg_t={avg_t * 1e6:.2f} us, min_t={min_t * 1e6:.2f} us, max_t={max_t * 1e6:.2f} us, "
+        f"actual_tokens={actual_num_tokens}",
         flush=True,
     )

     # Separate profiling
-    # return_recv_hook=True is not supported now
     for return_recv_hook in (False,):
         enable_neg_one = int(os.getenv("MOE_ENABLE_TOPK_NEG_ONE", 0))
         dist.barrier()
@@ -262,13 +335,36 @@ def test_func(zero_copy: bool, return_recv_hook: bool):
 def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
     rank, num_ranks, group = init_dist(local_rank, num_local_ranks)
     shared_expert_rank_num = int(os.getenv("MOE_SHARED_EXPERT_RANK_NUM", 0))
-    num_tokens, hidden = args.num_tokens, args.hidden
+    base_num_tokens, hidden = args.num_tokens, args.hidden
     num_topk, num_experts = args.num_topk, args.num_experts
     use_experts = num_experts if shared_expert_rank_num == 0 else (num_experts - 1)
     use_ranks = num_ranks - shared_expert_rank_num
     drop_percent = args.drop_percent
+
+    fluctuation_percentage = 0.1
+    min_fluctuation = 2
+
+    if base_num_tokens < 10:
+        fluctuation = random.randint(-min_fluctuation, min_fluctuation)
+        num_tokens = base_num_tokens + fluctuation
+    else:
+        fluctuation = random.uniform(
+            1 - fluctuation_percentage, 1 + fluctuation_percentage
+        )
+        num_tokens = int(base_num_tokens * fluctuation)
+
+    raw_num_tokens = max(num_tokens, 1)
+
+    local_tokens_tensor = torch.tensor(
+        [raw_num_tokens], dtype=torch.int32, device="npu"
+    )
+    dist.all_reduce(local_tokens_tensor, op=dist.ReduceOp.MAX)
+    aligned_num_tokens = local_tokens_tensor.item()
+
+    print(
+        f"[rank {rank}] raw_num_tokens: {raw_num_tokens}, aligned_num_tokens: {aligned_num_tokens}"
+    )
+
     num_rdma_bytes = Buffer.get_low_latency_rdma_size_hint(
-        num_tokens, hidden, num_ranks, num_experts
+        aligned_num_tokens, hidden, num_ranks, num_experts
     )
     buffer = Buffer(
         group,
@@ -278,7 +374,8 @@ def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
     )

     test(
-        num_tokens,
+        aligned_num_tokens,
+        raw_num_tokens,
         hidden,
         use_experts,
         num_topk,
@@ -295,7 +392,8 @@ def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
         if rank == 0:
             print(f"Testing with seed {seed} ...", flush=True)
         ref_hash = test(
-            num_tokens,
+            aligned_num_tokens,
+            raw_num_tokens,
            hidden,
             use_experts,
             num_topk,
@@ -309,7 +407,8 @@ def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
        for i in range(20):
            assert (
                test(
-                    num_tokens,
+                    aligned_num_tokens,
+                    raw_num_tokens,
                    hidden,
                    use_experts,
                    num_topk,
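Note that the alignment tensor is built from raw_num_tokens (already floored at 1) rather than the raw jittered value, so that aligned_num_tokens can never fall below any rank's actual count even when every rank draws a non-positive jitter in the --num-tokens=1 runs. Both low-latency tests distinguish aligned_num_tokens (the shared tensor shape) from the per-rank actual token count (the rows holding valid data). A minimal sketch of the alignment step, assuming an initialized process group and the same "npu" device as the tests:

import torch
import torch.distributed as dist

actual_num_tokens = 7  # per-rank randomized count, always >= 1
local_tokens = torch.tensor([actual_num_tokens], dtype=torch.int32, device="npu")
dist.all_reduce(local_tokens, op=dist.ReduceOp.MAX)  # every rank learns the global max
aligned_num_tokens = local_tokens.item()  # identical on all ranks

x, topk_idx, and topk_weights are then allocated with aligned_num_tokens rows, while only the first actual_num_tokens rows carry valid data; the padding rows keep topk_idx at -1 so the kernel treats them as dropped tokens, which is presumably why the workflow now sets MOE_ENABLE_TOPK_NEG_ONE=1 for these tests.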
diff --git a/tests/python/deepep/test_normal_and_low_latency.py b/tests/python/deepep/test_normal_and_low_latency.py
index 3bfc9f0ce..3f42444f5 100644
--- a/tests/python/deepep/test_normal_and_low_latency.py
+++ b/tests/python/deepep/test_normal_and_low_latency.py
@@ -1,4 +1,5 @@
 import argparse
+import random

 import deep_ep
 import torch
@@ -10,7 +11,8 @@


 def low_latency_test(
-    num_tokens: int,
+    aligned_num_tokens: int,
+    actual_num_tokens: int,
     hidden: int,
     num_experts: int,
     num_topk: int,
@@ -25,20 +27,41 @@ def low_latency_test(
         num_ranks - rank_offset < 257
     ), "Too many ranks (exceeding test precision limit)"

-    x = torch.ones((num_tokens, hidden), dtype=torch.bfloat16, device="npu") * (
-        rank - rank_offset
-    )
-    x[:, -128:] = torch.arange(num_tokens, device="npu").to(torch.bfloat16).view(-1, 1)
+    x = torch.zeros((aligned_num_tokens, hidden), dtype=torch.bfloat16, device="npu")
+
+    if actual_num_tokens > 0:
+        x[:actual_num_tokens] = torch.ones(
+            (actual_num_tokens, hidden), dtype=torch.bfloat16, device="npu"
+        ) * (rank - rank_offset)
+        x[:actual_num_tokens, -128:] = (
+            torch.arange(actual_num_tokens, device="npu").to(torch.bfloat16).view(-1, 1)
+        )
+
     scores = (
-        torch.randn((num_tokens, num_experts), dtype=torch.float32, device="npu").abs()
+        torch.randn(
+            (aligned_num_tokens, num_experts), dtype=torch.float32, device="npu"
+        ).abs()
         + 1
     )
-    topk_idx = torch.topk(scores, num_topk, dim=-1, largest=True, sorted=False)[1]
+    topk_idx = torch.full(
+        (aligned_num_tokens, num_topk), -1, dtype=torch.long, device="npu"
+    )
+
+    if actual_num_tokens > 0:
+        actual_scores = scores[:actual_num_tokens]
+        actual_topk_idx = torch.topk(
+            actual_scores, num_topk, dim=-1, largest=True, sorted=True
+        )[1]
+        topk_idx[:actual_num_tokens] = actual_topk_idx

-    topk_weights = torch.randn(
-        (num_tokens, num_topk), dtype=torch.float32, device="npu"
-    ).abs()
+    topk_weights = torch.zeros(
+        (aligned_num_tokens, num_topk), dtype=torch.float32, device="npu"
+    )
+    if actual_num_tokens > 0:
+        topk_weights[:actual_num_tokens] = torch.randn(
+            (actual_num_tokens, num_topk), dtype=torch.float32, device="npu"
+        ).abs()

     return_recv_hook = False
     cumulative_local_expert_recv_stats = torch.zeros(
@@ -48,7 +71,7 @@ def low_latency_test(
     packed_recv_x, packed_recv_count, handle, event, hook = buffer.low_latency_dispatch(
         x,
         topk_idx,
-        num_tokens,
+        aligned_num_tokens,
         num_experts,
         use_fp8=dispatch_use_fp8,
         round_scale=False,
@@ -70,7 +93,7 @@ def low_latency_test(
         _,
     ) = handle

-    out = torch.empty((num_tokens, hidden), dtype=torch.bfloat16, device="npu")
+    out = torch.empty((aligned_num_tokens, hidden), dtype=torch.bfloat16, device="npu")
     combined_x, event, hook = buffer.low_latency_combine(
         simulated_gemm_x,
         topk_idx,
@@ -82,15 +105,31 @@ def low_latency_test(
         out=out,
     )

-    diff = calc_diff(
-        x * topk_weights.masked_fill(topk_idx == -1, 0).sum(dim=1).view(-1, 1),
-        combined_x,
-    )
-    assert torch.isnan(combined_x).sum().item() == 0
-    if dispatch_use_fp8:
-        assert diff < 1e-4, f"Error: {diff=}"
-    else:
-        assert diff < 1e-5, f"Error: {diff=}"
+    if actual_num_tokens > 0:
+        # Compute the expected output, considering valid tokens only
+        expected_x = torch.zeros(
+            (aligned_num_tokens, hidden), dtype=torch.bfloat16, device="npu"
+        )
+        expected_x[:actual_num_tokens] = torch.ones(
+            (actual_num_tokens, hidden), dtype=torch.bfloat16, device="npu"
+        ) * (rank - rank_offset)
+        expected_x[:actual_num_tokens, -128:] = (
+            torch.arange(actual_num_tokens, device="npu").to(torch.bfloat16).view(-1, 1)
+        )
+
+        diff = calc_diff(
+            expected_x[:actual_num_tokens]
+            * topk_weights[:actual_num_tokens]
+            .masked_fill(topk_idx[:actual_num_tokens] == -1, 0)
+            .sum(dim=1)
+            .view(-1, 1),
+            combined_x[:actual_num_tokens],
+        )
+        assert torch.isnan(combined_x).sum().item() == 0
+        if dispatch_use_fp8:
+            assert diff < 1e-4, f"Error: {diff=}"
+        else:
+            assert diff < 1e-5, f"Error: {diff=}"


 def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
@@ -98,12 +137,27 @@ def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
     num_topk, num_experts, hidden = args.num_topk, args.num_experts, args.hidden
     assert num_experts % num_ranks == 0
     torch.manual_seed(rank)
-    buffer = deep_ep.Buffer(
-        group, int(2e9), 0, low_latency_mode=True, num_qps_per_rank=1
-    )

     for i in range(args.test_loop):
-        normal_num_tokens = args.normal_num_tokens
+        buffer = deep_ep.Buffer(
+            group, int(2e9), 0, low_latency_mode=True, num_qps_per_rank=1
+        )
+        base_normal_num_tokens = args.normal_num_tokens
+        fluctuation_percentage = 0.1
+        min_fluctuation = 2
+
+        if base_normal_num_tokens < 10:
+            fluctuation = random.randint(-min_fluctuation, min_fluctuation)
+            normal_num_tokens = base_normal_num_tokens + fluctuation
+        else:
+            fluctuation = random.uniform(
+                1 - fluctuation_percentage, 1 + fluctuation_percentage
+            )
+            normal_num_tokens = int(base_normal_num_tokens * fluctuation)
+
+        # Ensure normal_num_tokens is at least 1
+        normal_num_tokens = max(normal_num_tokens, 1)
+
         if local_rank == 0:
             print(f"Start executing normal test loop {i} ...", flush=True)
         normal_test(
@@ -116,10 +170,30 @@ def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
         if local_rank == 0:
             print(f"End executing normal test loop {i} ...", flush=True)

-        low_latency_num_tokens = args.low_latency_num_tokens
+        base_low_latency_num_tokens = args.low_latency_num_tokens
+
+        if base_low_latency_num_tokens < 10:
+            fluctuation = random.randint(-min_fluctuation, min_fluctuation)
+            low_latency_num_tokens = base_low_latency_num_tokens + fluctuation
+        else:
+            fluctuation = random.uniform(
+                1 - fluctuation_percentage, 1 + fluctuation_percentage
+            )
+            low_latency_num_tokens = int(base_low_latency_num_tokens * fluctuation)
+
+        # Ensure low_latency_num_tokens is at least 1
+        low_latency_num_tokens = max(low_latency_num_tokens, 1)
+
+        local_tokens_tensor = torch.tensor(
+            [low_latency_num_tokens], dtype=torch.int32, device="npu"
+        )
+        dist.all_reduce(local_tokens_tensor, op=dist.ReduceOp.MAX)
+        aligned_num_tokens = local_tokens_tensor.item()
+
         if local_rank == 0:
             print(f"Start executing low latency test loop {i} ...", flush=True)
         low_latency_test(
+            aligned_num_tokens,
             low_latency_num_tokens,
             hidden,
             num_experts,
@@ -130,7 +204,9 @@ def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
         )
         if local_rank == 0:
             print(f"End executing low latency test loop {i} ...", flush=True)
-        dist.barrier()
+        del buffer
+        torch.npu.empty_cache()
+        dist.barrier()

     dist.destroy_process_group()
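Consolidated from the hunks above, the per-iteration buffer lifecycle that test_normal_and_low_latency.py adopts looks as follows; this is a sketch under the same setup as test_loop, not additional code from the PR:

for i in range(args.test_loop):
    buffer = deep_ep.Buffer(
        group, int(2e9), 0, low_latency_mode=True, num_qps_per_rank=1
    )
    # ... run normal_test(...) and low_latency_test(...) with this buffer ...
    del buffer               # drop the last reference to the buffer
    torch.npu.empty_cache()  # release cached NPU memory between iterations
    dist.barrier()           # keep ranks in lockstep before the next loop

Recreating the buffer inside the loop, rather than once before it, pairs each iteration's allocation with an explicit teardown, so device memory cannot accumulate across the randomized token-count runs.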