diff --git a/python/sglang/compile_deep_gemm.py b/python/sglang/compile_deep_gemm.py
index b86086e20b2f..1b570068ddb7 100644
--- a/python/sglang/compile_deep_gemm.py
+++ b/python/sglang/compile_deep_gemm.py
@@ -88,8 +88,45 @@ def launch_server_process_and_send_one_request(
             headers = {
                 "Content-Type": "application/json; charset=utf-8",
             }
-            response = requests.get(f"{base_url}/v1/models", headers=headers)
+            # Bound each readiness probe so a hung connection cannot block forever;
+            # a requests.Timeout is caught by the RequestException handler below.
+            if server_args.node_rank == 0:
+                response = requests.get(
+                    f"{base_url}/v1/models", headers=headers, timeout=10
+                )
+            else:
+                # Non-rank-0 nodes expose /health via launch_dummy_health_check_server.
+                response = requests.get(
+                    f"{base_url}/health", headers=headers, timeout=10
+                )
             if response.status_code == 200:
+                # The rank-0 node sends one request to sync with the other nodes,
+                # then returns.
+                if server_args.node_rank == 0:
+                    response = requests.post(
+                        f"{base_url}/generate",
+                        json={
+                            "input_ids": [0, 1, 2, 3],
+                            "sampling_params": {
+                                "max_new_tokens": 8,
+                                "temperature": 0,
+                            },
+                        },
+                        timeout=600,
+                    )
+                    if response.status_code != 200:
+                        # Use the raw body: response.json() raises on a non-JSON
+                        # error page, which would hide the real failure.
+                        error = response.text
+                        raise RuntimeError(f"Sync request failed: {error}")
+                # Other nodes wait for the exit signal from the rank-0 node.
+                else:
+                    start_time_waiting = time.time()
+                    while proc.is_alive():
+                        if time.time() - start_time_waiting < timeout:
+                            time.sleep(10)
+                        else:
+                            raise TimeoutError("Waiting for main node timeout!")
                 return proc
         except requests.RequestException:
             pass
@@ -122,10 +159,20 @@ def run_compile(server_args: ServerArgs, compile_args: CompileArgs):
 
     proc = launch_server_process_and_send_one_request(server_args, compile_args)
 
-    kill_process_tree(proc.pid)
-
     print("\nDeepGEMM Kernels compilation finished successfully.")
 
+    # Sleep for safety
+    time.sleep(10)
+    if proc.is_alive():
+        # This is the rank-0 node.
+        kill_process_tree(proc.pid)
+    else:
+        # The process has already exited; clean up its tree on a best-effort basis.
+        try:
+            kill_process_tree(proc.pid)
+        except Exception:
+            pass
+
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()