diff --git a/sgl-model-gateway/bindings/golang/Makefile b/sgl-model-gateway/bindings/golang/Makefile
index a1b73404681a..b7dee7b4ee22 100644
--- a/sgl-model-gateway/bindings/golang/Makefile
+++ b/sgl-model-gateway/bindings/golang/Makefile
@@ -1,10 +1,10 @@
-# Makefile for sglang-router golang bindings
+# Makefile for sgl-model-gateway golang bindings
 # This builds the Rust FFI library and provides convenience targets for Go development

 # Configuration
 CARGO_BUILD_DIR ?= $(shell pwd)/target
 BUILD_MODE ?= release
-LIB_NAME = libsglang_router_rs
+LIB_NAME = libsgl_model_gateway_go

 # Detect OS
 UNAME_S := $(shell uname -s)
@@ -30,7 +30,7 @@ PYTHON_LDFLAGS := $(shell python3-config --ldflags --embed 2>/dev/null || python

 # CGO flags - use exported lib directory if available, otherwise build directory
 LIB_DIR := $(if $(wildcard $(LIB_EXPORT_PATH)),$(LIB_EXPORT_DIR),$(LIB_BUILD_DIR))
-export CGO_LDFLAGS = -L$(LIB_DIR) -lsglang_router_rs $(PYTHON_LDFLAGS) -ldl
+export CGO_LDFLAGS = -L$(LIB_DIR) -lsgl_model_gateway_go $(PYTHON_LDFLAGS) -ldl
 export $(LD_LIBRARY_PATH_VAR) := $(LIB_DIR):$($(LD_LIBRARY_PATH_VAR))

 .PHONY: all build build-dev lib lib-clean clean test examples help run-simple run-streaming check-lib
diff --git a/sgl-model-gateway/bindings/golang/README.md b/sgl-model-gateway/bindings/golang/README.md
index 4a43e1f2e6d4..a5bb8e8dc5b7 100644
--- a/sgl-model-gateway/bindings/golang/README.md
+++ b/sgl-model-gateway/bindings/golang/README.md
@@ -40,15 +40,61 @@ A high-level Go SDK for interacting with SGLang gRPC API, designed with an OpenA
 go get github.com/sglang/sglang-go-grpc-sdk
 ```

+### Sync Dependencies
+
+```bash
+cd sgl-model-gateway/bindings/golang
+go mod tidy
+```
+
 ### Build Requirements

-- Go 1.21 or later
-- Rust toolchain (for building the FFI library)
-- Python 3.x (for Python bindings in Rust FFI)
-- Tokio runtime for async operations
+- Go 1.21+, Rust toolchain, Python 3.x

 ## Quick Start

+### Benchmark
+
+Run the OpenAI-compatible server and benchmark:
+
+```bash
+# Set environment variables
+export SGL_TOKENIZER_PATH="/path/to/tokenizer"
+export SGL_GRPC_ENDPOINT="grpc://localhost:20000"
+
+# Run server
+cd examples/oai_server
+bash run.sh
+
+# Run E2E benchmark
+cd ../..
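+# E2E_* variables (prompt count, input/output lengths, request rate, ...)
+# all have defaults; see `make help` in examples/oai_server/Makefile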
+make e2e E2E_MODEL=/path/to/model E2E_TOKENIZER=/path/to/tokenizer E2E_INPUT_LEN=1024 E2E_OUTPUT_LEN=512
+```
+
+## Examples
+
+The SDK includes several examples in the `examples/` directory:
+
+- **simple**: Basic non-streaming chat completion example
+- **streaming**: Real-time streaming with performance metrics
+
+### Running Examples
+
+```bash
+# Run simple example
+cd bindings/golang/examples/simple
+bash run.sh
+
+# Run streaming example
+cd bindings/golang/examples/streaming
+bash run.sh
+
+# Or use Makefile from bindings/golang directory
+cd bindings/golang
+make run-simple
+make run-streaming
+```
+
 ### Basic Usage (Non-streaming)

 ```go
@@ -162,29 +208,7 @@ func float32Ptr(f float32) *float32 {
 }
 ```

-## Examples
-
-The SDK includes several examples in the `examples/` directory:
-
-- **simple**: Basic non-streaming chat completion example
-- **streaming**: Real-time streaming with performance metrics
-
-### Running Examples
-
-```bash
-# Run simple example
-cd bindings/golang/examples/simple
-bash run.sh
-
-# Run streaming example
-cd bindings/golang/examples/streaming
-bash run.sh
-# Or use Makefile from bindings/golang directory
-cd bindings/golang
-make run-simple
-make run-streaming
-```

 Examples automatically detect the server endpoint and tokenizer path via environment variables or defaults.

@@ -286,15 +310,8 @@ go tool cover -html=coverage.out -o coverage.html

 #### Unit Test Coverage

-- **Configuration validation** (`TestClientConfig`) - Validates ClientConfig requirements
-- **Type structures** - Verifying all struct types work correctly
-- **Response handling** - Testing response parsing and validation
-- **Concurrent operations** (`TestConcurrentClientOperations`) - Thread-safety verification
-- **Benchmarks** (`BenchmarkChatCompletionRequest`) - Performance measurement
-
-**Test Files**:
+- Configuration validation, type structures, response handling, concurrent operations, and benchmarks
 - `client_test.go` - 10 unit tests covering core functionality
-- Tests cover: config validation, message types, request validation, close operations, response types, streaming, tools, concurrency, and context cancellation

 ### Integration Tests

@@ -302,28 +319,12 @@ Integration tests require a running SGLang server and test the full client-serve

 #### Prerequisites

-1. Start an SGLang server:
-
-```bash
-# Using Python (requires sglang package installed)
-python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-hf
-
-# Or using pre-built Docker image
-docker run -p 20000:20000 lmsys/sglang:latest
-
-# Or build your own
-sglang launch_server --model-path 
-```
-
-2. Set required environment variables:
-
-```bash
-# Set the gRPC endpoint (default: grpc://localhost:20000)
-export SGL_GRPC_ENDPOINT=grpc://localhost:20000
-
-# Set the tokenizer path (required)
-export SGL_TOKENIZER_PATH=/path/to/tokenizer
-```
+1. Start an SGLang server: `python -m sglang.launch_server --model-path <model-path>`
+2. Set environment variables:
+   ```bash
+   export SGL_GRPC_ENDPOINT=grpc://localhost:20000
+   export SGL_TOKENIZER_PATH=/path/to/tokenizer
+   ```

 #### Running Integration Tests

@@ -352,54 +353,13 @@ go test -tags=integration -race ./...

 ### Benchmarks

-Measure performance of SDK operations:
-
 ```bash
-# Run all benchmarks
 go test -bench=. -benchmem ./...
-
-# Run specific benchmark
-go test -bench=BenchmarkChatCompletionRequest -benchmem
-
-# Run for longer duration
-go test -bench=. -benchtime=10s ./...
-```
-
-Current benchmarks:
-- `BenchmarkChatCompletionRequest` - Measures request creation performance
-
-### CI/CD Integration
-
-Add to your GitHub Actions workflow:
-
-```yaml
-- name: Run Go tests
-  run: |
-    go test -race -cover ./...
-
-- name: Run integration tests (on main branch)
-  if: github.ref == 'refs/heads/main'
-  env:
-    SGL_GRPC_ENDPOINT: grpc://localhost:20000
-    SGL_TOKENIZER_PATH: /path/to/tokenizer
-  run: go test -tags=integration ./...
 ```

 ## Documentation

-### Code Documentation
-
-All public types and functions include comprehensive documentation:
-
-1. **Package-level documentation** in `client.go` with usage examples
-2. **Type documentation** for all structs with field descriptions
-3. **Function documentation** with:
-   - Purpose and behavior description
-   - Parameter documentation with types and constraints
-   - Return value documentation
-   - Error cases and handling
-   - Safety notes (for FFI functions)
-   - Usage examples
+All public types and functions include comprehensive documentation with usage examples.

 ### Key Documented Components

@@ -413,45 +373,19 @@ All public types and functions include comprehensive documentation:

 ### Viewing Documentation

-Generate and view HTML documentation:
-
 ```bash
-# Install godoc (if not already installed)
-go install golang.org/x/tools/cmd/godoc@latest
-
-# Generate and serve documentation
 godoc -http=:6060
-
 # Visit: http://localhost:6060/pkg/github.com/sglang/sglang-go-grpc-sdk/
 ```

 ## Development

-### Building
-
 ```bash
 cd bindings/golang
-
-# Build the Go bindings (compiles Rust FFI library)
-make build
-
-# Clean build
-make clean && make build
-```
-
-### Code Quality
-
-Ensure code quality before committing:
-
-```bash
-# Run Go vet (check for potential bugs)
-go vet ./...
-
-# Format code
-go fmt ./...
-
-# Run all tests with race detection
-go test -race ./...
+make build          # Build Go bindings
+go vet ./...        # Check code quality
+go fmt ./...        # Format code
+go test -race ./... # Run tests
 ```

 ### Project Structure

@@ -478,27 +412,23 @@

 ## Troubleshooting

-### Connection Errors
+### Missing Dependencies

-**Error**: `connection refused` or `failed to dial`
+Run `go mod tidy` to sync dependencies.

-**Solution**:
-1. Ensure SGLang server is running: `python -m sglang.launch_server`
-2. Check endpoint: `echo $SGL_GRPC_ENDPOINT`
-3. Verify port is not blocked: `nc -zv localhost 20000`
+### Connection Errors

-### Tokenizer Not Found
+Ensure SGLang server is running and check `SGL_GRPC_ENDPOINT`.

-**Error**: `tokenizer path not found` or `tokenizer configuration missing`
+### Tokenizer Not Found

-**Solution**:
-1. Set `SGL_TOKENIZER_PATH` environment variable
+1. Set the `SGL_TOKENIZER_PATH` environment variable
 2. Verify path contains required files: `ls $SGL_TOKENIZER_PATH`
 3. Files should include: `tokenizer.json`, `vocab.json`, `config.json`

 ### Build Failures

-**Error**: `library 'sglang_router_rs' not found`
+**Error**: `library 'sgl_model_gateway_go' not found`

 **Solution**:
 1. 
Rebuild Rust library: `cd sgl-model-gateway/bindings/golang && make build` diff --git a/sgl-model-gateway/bindings/golang/client.go b/sgl-model-gateway/bindings/golang/client.go index 110135da58a2..4bd2cad0da18 100644 --- a/sgl-model-gateway/bindings/golang/client.go +++ b/sgl-model-gateway/bindings/golang/client.go @@ -32,8 +32,9 @@ import ( "io" "strings" "sync" + "time" - "github.com/sglang/sglang-go-grpc-sdk/internal/ffi" + grpcclient "github.com/sglang/sglang-go-grpc-sdk/internal/grpc" ) // Client is the main client for interacting with SGLang gRPC API. @@ -44,7 +45,7 @@ import ( type Client struct { endpoint string tokenizerPath string - clientHandle *ffi.SglangClientHandle + grpcClient *grpcclient.GrpcClient // gRPC-based client mu sync.RWMutex } @@ -58,6 +59,41 @@ type ClientConfig struct { // tokenizer configuration files (e.g., tokenizer.json, vocab.json). // Required field. TokenizerPath string + + // ChannelBufferSizes configures buffer sizes for internal channels. + // If nil, default values will be used (optimized for high concurrency). + ChannelBufferSizes *ChannelBufferSizes + + // Timeouts configures timeout values for various operations. + // If nil, default values will be used. + Timeouts *Timeouts +} + +// ChannelBufferSizes configures buffer sizes for internal channels. +// These affect concurrency and memory usage. Larger buffers allow more +// concurrent operations but use more memory. +type ChannelBufferSizes = grpcclient.ChannelBufferSizes + +// Timeouts configures timeout values for various operations. +type Timeouts = grpcclient.Timeouts + +// defaultChannelBufferSizes returns default channel buffer sizes optimized for high concurrency (10k+). +// These values are designed to handle thousands of concurrent requests without blocking. +func defaultChannelBufferSizes() ChannelBufferSizes { + return ChannelBufferSizes{ + ResultJSONChan: 10000, // Increased for high concurrency: each request may produce 200-500 chunks + ErrChan: 100, // Errors are rare, 100 is sufficient + RecvChan: 2000, // Increased for high concurrency: more gRPC responses to buffer + } +} + +// defaultTimeouts returns default timeout values. +func defaultTimeouts() Timeouts { + return Timeouts{ + KeepaliveTime: 300 * time.Second, // Increased to reduce ping frequency and avoid "too many pings" errors + KeepaliveTimeout: 20 * time.Second, + CloseTimeout: 5 * time.Second, + } } // NewClient creates a new SGLang client with the given configuration. 
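Since `NewClient` only applies configuration overrides that are greater than zero, a caller can tune a single knob and keep the remaining defaults. A minimal usage sketch (the endpoint, tokenizer path, and chosen values below are placeholders, not recommendations):

```go
package main

import (
	"log"
	"time"

	sglang "github.com/sglang/sglang-go-grpc-sdk"
)

func main() {
	// Zero-valued fields fall back to the defaults above; only the
	// overridden knobs need to be set. Endpoint and tokenizer path
	// are placeholders.
	client, err := sglang.NewClient(sglang.ClientConfig{
		Endpoint:      "grpc://localhost:20000",
		TokenizerPath: "/path/to/tokenizer",
		ChannelBufferSizes: &sglang.ChannelBufferSizes{
			ResultJSONChan: 20000, // e.g. raise beyond the 10000 default for very high concurrency
		},
		Timeouts: &sglang.Timeouts{
			CloseTimeout: 10 * time.Second,
		},
	})
	if err != nil {
		log.Fatal(err)
	}
	defer client.Close()
}
```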
@@ -77,15 +113,41 @@ func NewClient(config ClientConfig) (*Client, error) { return nil, errors.New("tokenizer path is required") } - clientHandle, err := ffi.NewClient(config.Endpoint, config.TokenizerPath) + bufferSizes := defaultChannelBufferSizes() + if config.ChannelBufferSizes != nil { + if config.ChannelBufferSizes.ResultJSONChan > 0 { + bufferSizes.ResultJSONChan = config.ChannelBufferSizes.ResultJSONChan + } + if config.ChannelBufferSizes.ErrChan > 0 { + bufferSizes.ErrChan = config.ChannelBufferSizes.ErrChan + } + if config.ChannelBufferSizes.RecvChan > 0 { + bufferSizes.RecvChan = config.ChannelBufferSizes.RecvChan + } + } + + timeouts := defaultTimeouts() + if config.Timeouts != nil { + if config.Timeouts.KeepaliveTime > 0 { + timeouts.KeepaliveTime = config.Timeouts.KeepaliveTime + } + if config.Timeouts.KeepaliveTimeout > 0 { + timeouts.KeepaliveTimeout = config.Timeouts.KeepaliveTimeout + } + if config.Timeouts.CloseTimeout > 0 { + timeouts.CloseTimeout = config.Timeouts.CloseTimeout + } + } + + grpcClient, err := grpcclient.NewGrpcClient(config.Endpoint, config.TokenizerPath, bufferSizes, timeouts) if err != nil { - return nil, fmt.Errorf("failed to create client: %w", err) + return nil, fmt.Errorf("failed to create gRPC client: %w", err) } return &Client{ endpoint: config.Endpoint, tokenizerPath: config.TokenizerPath, - clientHandle: clientHandle, + grpcClient: grpcClient, }, nil } @@ -97,9 +159,11 @@ func (c *Client) Close() error { c.mu.Lock() defer c.mu.Unlock() - if c.clientHandle != nil { - c.clientHandle.Free() - c.clientHandle = nil + if c.grpcClient != nil { + if err := c.grpcClient.Close(); err != nil { + return err + } + c.grpcClient = nil } return nil } @@ -229,7 +293,7 @@ type MessageDelta struct { // // Context Support: // The ctx parameter is fully supported for cancellation and timeouts: -// - If ctx is cancelled, the request will be interrupted on the next stream.Recv() call +// - If ctx is cancelled, the request will be interrupted on the next stream.RecvJSON() call // - If ctx times out, the request will return context.DeadlineExceeded // // Example with timeout: @@ -244,7 +308,6 @@ func (c *Client) CreateChatCompletion(ctx context.Context, req ChatCompletionReq // For non-streaming, we'll collect all chunks and return the final response req.Stream = true // We still use streaming internally, but collect all chunks - // Prepare request: if Tools is empty, set to nil for proper JSON serialization if len(req.Tools) == 0 { req.Tools = nil } @@ -265,7 +328,7 @@ func (c *Client) CreateChatCompletion(ctx context.Context, req ChatCompletionReq var systemFingerprint string for { - chunk, err := stream.Recv() + chunkJSON, err := stream.RecvJSON() if err == io.EOF { break } @@ -273,6 +336,11 @@ func (c *Client) CreateChatCompletion(ctx context.Context, req ChatCompletionReq return nil, err } + var chunk ChatCompletionStreamResponse + if err := json.Unmarshal([]byte(chunkJSON), &chunk); err != nil { + return nil, fmt.Errorf("failed to parse chunk: %w", err) + } + if chunk.ID != "" { responseID = chunk.ID } @@ -293,21 +361,16 @@ func (c *Client) CreateChatCompletion(ctx context.Context, req ChatCompletionReq if len(choice.Delta.ToolCalls) > 0 { fullToolCalls = append(fullToolCalls, choice.Delta.ToolCalls...) 
} - // Always update finish_reason if present (even if empty string, but should not be empty) - // The last chunk (Complete message) should have finish_reason set if choice.FinishReason != "" { finishReason = choice.FinishReason } } - // Extract usage from chunk if available (usually in the last chunk) - // Always update usage if present, as the last chunk should have the final usage if chunk.Usage != nil { usage = *chunk.Usage } } - // Build final response message := Message{ Role: "assistant", Content: fullContent.String(), @@ -316,8 +379,6 @@ func (c *Client) CreateChatCompletion(ctx context.Context, req ChatCompletionReq message.ToolCalls = fullToolCalls } - // Ensure finish_reason is set (defensive check) - // If finish_reason is still empty, default to "stop" if finishReason == "" { finishReason = "stop" } @@ -341,94 +402,22 @@ func (c *Client) CreateChatCompletion(ctx context.Context, req ChatCompletionReq // ChatCompletionStream represents a streaming chat completion type ChatCompletionStream struct { - stream *ffi.SglangStreamHandle - mu sync.Mutex - done bool // Track if stream has been marked as done - ctx context.Context // Context for cancellation support - cancel context.CancelFunc // Cancel function to stop monitoring goroutine - closed chan struct{} // Signal when stream is closed + grpcStream *grpcclient.GrpcChatCompletionStream + ctx context.Context + cancel context.CancelFunc } -// Recv receives the next chunk from the stream. -// -// Supports context cancellation: if the context passed to CreateChatCompletionStream -// is cancelled, Recv will return context.Canceled error on the next call. -func (s *ChatCompletionStream) Recv() (*ChatCompletionStreamResponse, error) { - s.mu.Lock() - defer s.mu.Unlock() - - // Check if context was cancelled - select { - case <-s.ctx.Done(): - return nil, s.ctx.Err() // Returns context.Canceled or context.DeadlineExceeded - default: - } - - if s.stream == nil { - return nil, io.EOF - } - - // If stream was already marked as done, immediately return EOF - // This prevents calling ReadNext() again after isDone=1 - if s.done { - return nil, io.EOF - } - - // Loop to handle empty responses (Ok(None) from Rust) - // Keep reading until we get actual data or stream ends - for { - responseJSON, isDone, err := s.stream.ReadNext() - if err != nil { - return nil, err - } - - // Mark stream as done if ReadNext indicates completion - if isDone { - s.done = true - } - - // If we have a response, parse and return it - if responseJSON != "" { - var response ChatCompletionStreamResponse - if err := json.Unmarshal([]byte(responseJSON), &response); err != nil { - return nil, fmt.Errorf("failed to parse response: %w", err) - } - return &response, nil - } - - // If stream is done but no response, return EOF - if isDone { - return nil, io.EOF - } - - // Empty response and stream not done - loop to read next chunk - // This handles Ok(None) cases where Rust returns no data but stream continues - } +func (s *ChatCompletionStream) RecvJSON() (string, error) { + return s.grpcStream.RecvJSON() } // Close closes the stream and cancels any pending operations. 
func (s *ChatCompletionStream) Close() error { - s.mu.Lock() - defer s.mu.Unlock() - - // Cancel the context to signal the monitoring goroutine to stop if s.cancel != nil { s.cancel() } - - // Signal that stream is closed - select { - case <-s.closed: - // Already closed - default: - close(s.closed) - } - - // Free the stream to mark it as completed - // This prevents AbortOnDropStream from sending abort when dropped - if s.stream != nil { - s.stream.Free() - s.stream = nil + if s.grpcStream != nil { + return s.grpcStream.Close() } return nil } @@ -437,8 +426,8 @@ func (s *ChatCompletionStream) Close() error { // // Context Support: // The ctx parameter is now fully supported for cancellation and timeouts: -// - If ctx is cancelled, stream.Recv() will return context.Canceled on the next call -// - If ctx times out (WithTimeout), stream.Recv() will return context.DeadlineExceeded +// - If ctx is cancelled, stream.RecvJSON() will return context.Canceled on the next call +// - If ctx times out (WithTimeout), stream.RecvJSON() will return context.DeadlineExceeded // - Calling stream.Close() also cancels the context // // Example with timeout: @@ -457,54 +446,38 @@ func (s *ChatCompletionStream) Close() error { // cancel() // Cancel after 5 seconds // }() func (c *Client) CreateChatCompletionStream(ctx context.Context, req ChatCompletionRequest) (*ChatCompletionStream, error) { - c.mu.RLock() - defer c.mu.RUnlock() - - if c.clientHandle == nil { - return nil, errors.New("client is closed") - } - - // Marshal request to JSON, then ensure tools field is always present. - // Due to omitempty tag, empty Tools slice will be omitted from JSON. - // We need to ensure tools field is always present as [] when empty (not omitted), - // matching the behavior of complete_sdk example. 
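+	// Marshal the request, then re-add an explicit empty "tools" array:
+	// the omitempty tag drops an empty Tools slice, but the backend expects
+	// the field to be present even when no tools are supplied.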
reqJSON, err := json.Marshal(req) if err != nil { return nil, fmt.Errorf("failed to marshal request: %w", err) } - // Unmarshal into map and ensure tools field is present var reqMap map[string]interface{} if err := json.Unmarshal(reqJSON, &reqMap); err != nil { return nil, fmt.Errorf("failed to unmarshal request to map: %w", err) } - // Add empty tools array if not present if _, exists := reqMap["tools"]; !exists { reqMap["tools"] = []interface{}{} } - // Marshal back to JSON reqJSON, err = json.Marshal(reqMap) if err != nil { return nil, fmt.Errorf("failed to marshal request map to JSON: %w", err) } - // Create stream - streamHandle, err := c.clientHandle.ChatCompletionStream(string(reqJSON)) - if err != nil { - return nil, fmt.Errorf("failed to create stream: %w", err) + if c.grpcClient == nil { + return nil, errors.New("gRPC client is closed") } - // Create a child context from the provided context for cancellation support - streamCtx, cancel := context.WithCancel(ctx) - - stream := &ChatCompletionStream{ - stream: streamHandle, - ctx: streamCtx, - cancel: cancel, - closed: make(chan struct{}), + grpcStream, err := c.grpcClient.CreateChatCompletionStream(ctx, string(reqJSON)) + if err != nil { + return nil, fmt.Errorf("failed to create gRPC stream: %w", err) } - return stream, nil + streamCtx, cancel := context.WithCancel(ctx) + return &ChatCompletionStream{ + grpcStream: grpcStream, + ctx: streamCtx, + cancel: cancel, + }, nil } diff --git a/sgl-model-gateway/bindings/golang/examples/oai_server/Makefile b/sgl-model-gateway/bindings/golang/examples/oai_server/Makefile new file mode 100644 index 000000000000..a134298438d4 --- /dev/null +++ b/sgl-model-gateway/bindings/golang/examples/oai_server/Makefile @@ -0,0 +1,239 @@ +# Makefile for OAI Server +# Builds binary, runs tests, and provides basic targets + +# Configuration +APP_NAME = oai_server +VERSION ?= $(shell git describe --tags --always --dirty 2>/dev/null || echo "dev") +BUILD_TIME := $(shell date -u '+%Y-%m-%d_%H:%M:%S') +GIT_COMMIT := $(shell git rev-parse --short HEAD 2>/dev/null || echo "unknown") + +# Paths +ROOT_DIR := $(shell pwd) +BINDINGS_DIR := $(shell cd $(ROOT_DIR)/../.. 
&& pwd) +BUILD_DIR := $(ROOT_DIR)/build +BINARY := $(BUILD_DIR)/$(APP_NAME) + +# Rust FFI library paths +LIB_DIR := $(BINDINGS_DIR)/lib +LIB_NAME = libsgl_model_gateway_go + +# Detect OS +UNAME_S := $(shell uname -s) +ifeq ($(UNAME_S),Linux) + LIB_EXT = .so + LD_LIBRARY_PATH_VAR = LD_LIBRARY_PATH + ARCH := $(shell uname -m) + ifeq ($(ARCH),x86_64) + GOARCH = amd64 + else ifeq ($(ARCH),aarch64) + GOARCH = arm64 + endif +endif +ifeq ($(UNAME_S),Darwin) + LIB_EXT = .dylib + LD_LIBRARY_PATH_VAR = DYLD_LIBRARY_PATH + ARCH := $(shell uname -m) + ifeq ($(ARCH),x86_64) + GOARCH = amd64 + else ifeq ($(ARCH),arm64) + GOARCH = arm64 + endif +endif + +# Build flags +LDFLAGS = -X main.Version=$(VERSION) -X main.BuildTime=$(BUILD_TIME) -X main.GitCommit=$(GIT_COMMIT) +GO_BUILD_FLAGS = -ldflags "$(LDFLAGS)" + +# Python LDFLAGS (needed for Rust FFI that depends on Python) +PYTHON_LDFLAGS := $(shell python3-config --ldflags --embed 2>/dev/null || python3-config --ldflags 2>/dev/null || python-config --ldflags --embed 2>/dev/null || python-config --ldflags 2>/dev/null || echo "") + +# CGO flags +CGO_LDFLAGS = -L$(LIB_DIR) $(PYTHON_LDFLAGS) + +.PHONY: all build build-dev test e2e clean help lib run stream check-rust-lib check-server + +# E2E test configuration +E2E_HOST ?= localhost +E2E_PORT ?= 8080 +E2E_MODEL ?= default +E2E_TOKENIZER ?= $(shell echo $$SGL_TOKENIZER_PATH || echo "./examples/tokenizer") +E2E_NUM_PROMPTS ?= 100 +E2E_INPUT_LEN ?= 1024 +E2E_OUTPUT_LEN ?= 512 +E2E_REQUEST_RATE ?= 20 +E2E_MAX_CONCURRENCY ?= 20 +E2E_BASE_URL ?= http://$(E2E_HOST):$(E2E_PORT) + +help: + @echo "OAI Server Makefile" + @echo "" + @echo "Available targets:" + @echo " lib - Build Rust FFI library" + @echo " build - Build binary (release mode)" + @echo " build-dev - Build binary (debug mode)" + @echo " test - Run tests" + @echo " e2e - Run end-to-end test with bench_serving.py" + @echo " run - Run the server (development)" + @echo " stream - Run streaming example" + @echo " clean - Clean build artifacts" + @echo "" + @echo "E2E test variables:" + @echo " E2E_HOST - OAI Server host (default: localhost)" + @echo " E2E_PORT - OAI Server port (default: 8080)" + @echo " E2E_MODEL - Model name (default: default)" + @echo " E2E_TOKENIZER - Tokenizer path" + @echo " E2E_NUM_PROMPTS - Number of prompts (default: 100)" + @echo " E2E_INPUT_LEN - Input token length (default: 1024)" + @echo " E2E_OUTPUT_LEN - Output token length (default: 512)" + @echo " E2E_REQUEST_RATE - Request rate per second (default: 20)" + @echo " E2E_MAX_CONCURRENCY - Max concurrent requests (default: 20)" + +all: build + +# Build Rust FFI library +lib: + @echo "Building Rust FFI library..." + @cd $(BINDINGS_DIR) && $(MAKE) lib + @echo "✓ Rust FFI library built" + +# Check if Rust FFI library exists +check-rust-lib: + @if [ ! -f "$(LIB_DIR)/$(LIB_NAME)$(LIB_EXT)" ]; then \ + echo "Error: Rust FFI library not found at $(LIB_DIR)/$(LIB_NAME)$(LIB_EXT)"; \ + echo "Building Rust library..."; \ + cd $(BINDINGS_DIR) && $(MAKE) lib; \ + fi + @echo "✓ Rust FFI library found" + +# Build binary (release) +build: check-rust-lib + @echo "Building $(APP_NAME) (release mode)..." + @mkdir -p $(BUILD_DIR) + @CGO_ENABLED=1 \ + CGO_LDFLAGS="$(CGO_LDFLAGS)" \ + GOOS=$(shell go env GOOS) \ + GOARCH=$(GOARCH) \ + go build $(GO_BUILD_FLAGS) -o $(BINARY) . + @echo "✓ Binary built: $(BINARY)" + +# Build binary (debug) +build-dev: check-rust-lib + @echo "Building $(APP_NAME) (debug mode)..." 
+ @mkdir -p $(BUILD_DIR) + @CGO_ENABLED=1 \ + CGO_LDFLAGS="$(CGO_LDFLAGS)" \ + go build -o $(BINARY) . + @echo "✓ Binary built (debug): $(BINARY)" + +# Run tests +test: check-rust-lib + @echo "Running tests..." + @CGO_ENABLED=1 \ + CGO_LDFLAGS="$(CGO_LDFLAGS)" \ + export $(LD_LIBRARY_PATH_VAR)="$(LIB_DIR):$$$(LD_LIBRARY_PATH_VAR)" && \ + go test -v ./... + @echo "✓ Tests completed" + +# Check if OAI Server is running +check-server: + @echo "Checking if OAI Server is running at $(E2E_BASE_URL)..." + @if curl -s -f $(E2E_BASE_URL)/health > /dev/null 2>&1; then \ + echo "✓ OAI Server is running"; \ + exit 0; \ + else \ + echo "✗ OAI Server is not running at $(E2E_BASE_URL)"; \ + echo " Start it with: make run"; \ + exit 1; \ + fi + +# Find sglang project root (4 levels up from oai_server) +SGLANG_ROOT := $(shell cd $(ROOT_DIR)/../../../../.. && pwd) + +# Run end-to-end test with bench_serving.py +e2e: check-server + @echo "Checking if bench_serving.py is available..." + @if python -m sglang.bench_serving --help > /dev/null 2>&1; then \ + echo "✓ Using installed bench_serving.py module"; \ + USE_SGLANG_ROOT=false; \ + elif [ -f "$(SGLANG_ROOT)/python/sglang/bench_serving.py" ]; then \ + echo "✓ Using bench_serving.py from $(SGLANG_ROOT)"; \ + USE_SGLANG_ROOT=true; \ + else \ + echo "✗ bench_serving.py is not available"; \ + echo " Install dependencies: pip install aiohttp numpy datasets transformers tqdm pillow pybase64"; \ + exit 1; \ + fi + @echo "Running end-to-end test with bench_serving.py..." + @echo "Configuration:" + @echo " Server: $(E2E_BASE_URL)" + @if [ "$(E2E_MODEL)" != "default" ]; then \ + echo " Model: $(E2E_MODEL)"; \ + fi + @if [ -n "$(E2E_TOKENIZER)" ]; then \ + echo " Tokenizer: $(E2E_TOKENIZER)"; \ + fi + @echo " Prompts: $(E2E_NUM_PROMPTS)" + @echo " Input/Output: $(E2E_INPUT_LEN)/$(E2E_OUTPUT_LEN) tokens" + @echo " Request rate: $(E2E_REQUEST_RATE) req/s" + @echo " Max concurrency: $(E2E_MAX_CONCURRENCY)" + @echo "" + @TOKENIZER_ABS=$$(cd $(ROOT_DIR) && python3 -c "import os; path='$(E2E_TOKENIZER)'; print(os.path.abspath(path) if not os.path.isabs(path) else path)" 2>/dev/null || echo "$(E2E_TOKENIZER)"); \ + if [ -n "$(E2E_TOKENIZER)" ]; then \ + if [ -n "$$TOKENIZER_ABS" ] && ([ -d "$$TOKENIZER_ABS" ] || [ -f "$$TOKENIZER_ABS" ]); then \ + TOKENIZER_ARG="--tokenizer $$TOKENIZER_ABS"; \ + else \ + TOKENIZER_ARG="--tokenizer $(E2E_TOKENIZER)"; \ + fi; \ + else \ + TOKENIZER_ARG=""; \ + fi; \ + if [ "$$USE_SGLANG_ROOT" = "true" ]; then \ + cd $(SGLANG_ROOT) && PYTHONPATH=$(SGLANG_ROOT)/python:$$PYTHONPATH python python/sglang/bench_serving.py \ + --backend sglang-oai-chat \ + --base-url $(E2E_BASE_URL) \ + $$([ "$(E2E_MODEL)" != "default" ] && echo "--model $(E2E_MODEL)") \ + $$TOKENIZER_ARG \ + --dataset-name random \ + --num-prompts $(E2E_NUM_PROMPTS) \ + --random-input-len $(E2E_INPUT_LEN) \ + --random-output-len $(E2E_OUTPUT_LEN) \ + --request-rate $(E2E_REQUEST_RATE) \ + --max-concurrency $(E2E_MAX_CONCURRENCY) \ + --warmup-requests 5 \ + --disable-tqdm || (echo "✗ E2E test failed"; exit 1); \ + else \ + python -m sglang.bench_serving \ + --backend sglang-oai-chat \ + --base-url $(E2E_BASE_URL) \ + $$([ "$(E2E_MODEL)" != "default" ] && echo "--model $(E2E_MODEL)") \ + $$TOKENIZER_ARG \ + --dataset-name random \ + --num-prompts $(E2E_NUM_PROMPTS) \ + --random-input-len $(E2E_INPUT_LEN) \ + --random-output-len $(E2E_OUTPUT_LEN) \ + --request-rate $(E2E_REQUEST_RATE) \ + --max-concurrency $(E2E_MAX_CONCURRENCY) \ + --warmup-requests 5 \ + --disable-tqdm || (echo "✗ E2E 
test failed"; exit 1); \ + fi + @echo "" + @echo "✓ E2E test completed" + +# Run the server (development) +run: build-dev + @echo "Running server..." + @export $(LD_LIBRARY_PATH_VAR)="$(LIB_DIR):$$$(LD_LIBRARY_PATH_VAR)" && \ + $(BINARY) + +# Run streaming example +stream: check-rust-lib + @echo "Running streaming example..." + @cd $(BINDINGS_DIR)/examples/streaming && \ + export $(LD_LIBRARY_PATH_VAR)="$(LIB_DIR):$$$(LD_LIBRARY_PATH_VAR)" && \ + bash run.sh + +# Clean build artifacts +clean: + @echo "Cleaning build artifacts..." + @rm -rf $(BUILD_DIR) + @echo "✓ Clean complete" diff --git a/sgl-model-gateway/bindings/golang/examples/oai_server/README.md b/sgl-model-gateway/bindings/golang/examples/oai_server/README.md new file mode 100644 index 000000000000..46b3cf0817ce --- /dev/null +++ b/sgl-model-gateway/bindings/golang/examples/oai_server/README.md @@ -0,0 +1,305 @@ +# Go SGLang Router - OpenAI Compatible API Server + +Go SGLang Router is a high-performance OpenAI-compatible API server that communicates with the SGLang backend via gRPC and performs efficient preprocessing and postprocessing through Rust FFI. + +## Features + +- ✅ **OpenAI API Compatible**: Fully compatible with OpenAI Chat Completions API +- ✅ **High Performance**: Low latency and high throughput using gRPC and Rust FFI +- ✅ **Streaming Support**: Server-Sent Events (SSE) streaming responses +- ✅ **Thread-Safe**: Pre-created tokenizer handle, lock-free concurrency +- ✅ **Graceful Shutdown**: Context cancellation mechanism to avoid resource leaks and panics +- ✅ **Configurable**: Supports configuring channel buffer sizes and timeout durations + +## Architecture Overview + +**Important Note**: gRPC mode **still calls FFI**, which is used for: +- **Preprocessing**: chat_template and tokenization (request phase) +- **Postprocessing**: token decoding and tool parsing (response phase) + +gRPC is only used for communication with the SGLang backend, while input/output processing completely relies on Rust FFI. + +``` +┌─────────────────────────────────────────────────────────────────┐ +│ HTTP Client │ +│ (OpenAI API Format) │ +└────────────────────────────┬────────────────────────────────────┘ + │ + ▼ +┌─────────────────────────────────────────────────────────────────┐ +│ FastHTTP Server │ +│ handlers/chat.go:HandleChatCompletion │ +│ - Parse request JSON │ +│ - SetBodyStreamWriter (SSE) │ +└────────────────────────────┬────────────────────────────────────┘ + │ + ▼ +┌─────────────────────────────────────────────────────────────────┐ +│ SGLang Client (client.go) │ +│ CreateChatCompletionStream(ctx, req) │ +│ - Wraps gRPC client │ +└────────────────────────────┬────────────────────────────────────┘ + │ + ▼ +┌─────────────────────────────────────────────────────────────────┐ +│ gRPC Client (internal/grpc/client_grpc.go) │ +│ CreateChatCompletionStream(ctx, reqJSON) │ +│ │ +│ ┌──────────────────────────────────────────────────────────┐ │ +│ │ Step 1: FFI Preprocess (Rust FFI) │ │ +│ │ - ffi.PreprocessChatRequestWithTokenizer() │ │ +│ │ - chat_template application │ │ +│ │ - tokenization │ │ +│ │ - tool constraints generation │ │ +│ │ Returns: PromptText, TokenIDs, ToolConstraintsJSON, │ │ +│ │ PromptTokens │ │ +│ └────────────────────┬─────────────────────────────────────┘ │ +│ │ │ +│ ▼ │ +│ ┌──────────────────────────────────────────────────────────┐ │ +│ │ Step 2: Build gRPC Request │ │ +│ │ - Parse request JSON (model, temperature, etc.) 
│ │ +│ │ - Create proto.GenerateRequest │ │ +│ │ - Set TokenizedInput (PromptText, TokenIDs) │ │ +│ │ - Set SamplingParams (temperature, top_p, top_k, etc.) │ │ +│ │ - Set Constraints (from ToolConstraintsJSON) │ │ +│ └────────────────────┬─────────────────────────────────────┘ │ +│ │ │ +│ ▼ │ +│ ┌──────────────────────────────────────────────────────────┐ │ +│ │ Step 3: Create gRPC Stream │ │ +│ │ - client.Generate(generateReq) → gRPC stream │ │ +│ │ - Connects to SGLang Backend (Rust) │ │ +│ └────────────────────┬─────────────────────────────────────┘ │ +│ │ │ +│ ▼ │ +│ ┌──────────────────────────────────────────────────────────┐ │ +│ │ Step 4: Create Converter & BatchPostprocessor │ │ +│ │ - ffi.CreateGrpcResponseConverterWithTokenizer() │ │ +│ │ - Uses preprocessed.PromptTokens for initial count │ │ +│ │ - ffi.NewBatchPostprocessor(batchSize=1, immediate) │ │ +│ └────────────────────┬─────────────────────────────────────┘ │ +│ │ │ +│ ▼ │ +│ ┌──────────────────────────────────────────────────────────┐ │ +│ │ Step 5: Start readLoop (Background Goroutine) │ │ +│ │ - go grpcStream.readLoop() │ │ +│ │ - Returns GrpcChatCompletionStream immediately │ │ +│ └────────────────────┬─────────────────────────────────────┘ │ +└───────────────────────┼────────────────────────────────────────┘ + │ + ▼ +┌─────────────────────────────────────────────────────────────────┐ +│ GrpcChatCompletionStream.readLoop() │ +│ (Background Goroutine) │ +│ │ +│ ┌──────────────────────────────────────────────────────────┐ │ +│ │ Recv() Goroutine (Dedicated) │ │ +│ │ - Continuously calls stream.Recv() │ │ +│ │ - Sends results to recvChan (buffered, 2000) │ │ +│ │ - Exits on ctx.Done() or error │ │ +│ │ - Calls stream.CloseSend() on ctx.Done() │ │ +│ └────────────────────┬─────────────────────────────────────┘ │ +│ │ │ +│ ▼ │ +│ ┌──────────────────────────────────────────────────────────┐ │ +│ │ Main Loop │ │ +│ │ - Reads from recvChan │ │ +│ │ - For each proto.GenerateResponse: │ │ +│ │ → go processAndSendResponse() (async) │ │ +│ │ - protoToJSON() converts proto to JSON string │ │ +│ │ - batchPostprocessor.AddChunk(protoJSON) │ │ +│ │ → FFI postprocessing (token decoding, tool parsing)│ │ +│ │ → Returns OpenAI-format JSON strings │ │ +│ │ - Sends JSON to resultJSONChan (buffered, 10000) │ │ +│ │ - All operations check ctx.Done() for cancellation │ │ +│ │ - On EOF: flush batch, send remaining results, return │ │ +│ │ - On error: send to errChan (buffered, 100) │ │ +│ │ - defer: cancel ctx, wait goroutines, close channels │ │ +│ └────────────────────┬─────────────────────────────────────┘ │ +└───────────────────────┼────────────────────────────────────────┘ + │ + ▼ +┌─────────────────────────────────────────────────────────────────┐ +│ resultJSONChan (Buffered Channel, 10000) │ +│ - Contains OpenAI-format JSON strings │ +│ - Ready for consumption │ +└────────────────────────────┬────────────────────────────────────┘ + │ + ▼ +┌─────────────────────────────────────────────────────────────────┐ +│ ChatCompletionStream.RecvJSON() │ +│ (client.go:410) │ +│ - Direct wrapper: return grpcStream.RecvJSON() │ +│ - No intermediate processing │ +└────────────────────────────┬────────────────────────────────────┘ + │ + ▼ +┌─────────────────────────────────────────────────────────────────┐ +│ FastHTTP SetBodyStreamWriter │ +│ (handlers/chat.go:159) │ +│ - Loop: stream.RecvJSON() → format SSE → flush │ +│ - Format: "data: {json}\n\n" │ +│ - Final: "data: [DONE]\n\n" │ +│ - Immediate flush after each chunk │ 
+└────────────────────────────┬────────────────────────────────────┘ + │ + ▼ +┌─────────────────────────────────────────────────────────────────┐ +│ HTTP Client │ +│ (SSE Stream) │ +│ Receives: data: {...}\n\n │ +└─────────────────────────────────────────────────────────────────┘ +``` + +## Quick Start + +### Start Server + +```bash +./run.sh +``` + +The server will start on port `:8080`. + +### Usage Example + +```bash +curl http://localhost:8080/v1/chat/completions \ + -H "Content-Type: application/json" \ + -d '{ + "model": "/path/to/model", + "messages": [{"role": "user", "content": "Hello!"}], + "stream": true + }' +``` + +## Key Design + +### 1. Thread-Safe Tokenizer +- Pre-create `TokenizerHandle` at startup +- Rust side uses `Arc`, thread-safe +- Lock-free concurrency, eliminating lock contention + +### 2. Context Cancellation Mechanism (Graceful Shutdown) +- Use `context.Context` cancellation mechanism +- In `readLoop`'s `defer`: cancel context first, then wait for all goroutines to complete, finally close channels +- `processAndSendResponse` checks `ctx.Done()` at function start, all `select` statements include `case <-s.ctx.Done()` +- Avoids "send on closed channel" panic + +### 3. Cancellable Recv() +- Use dedicated goroutine to execute `Recv()` +- Pass results through `recvChan` +- Call `CloseSend()` when context is cancelled to make `Recv()` return error + +### 4. Simplified Channel Design +- `resultJSONChan`: Main data channel (gRPC layer) +- `errChan`: Error channel (gRPC layer) +- `recvChan`: Internal communication channel (gRPC layer) +- Removed redundant channels and duplicate reads + +## Configuration + +### Channel Buffer Sizes + +```go +type ChannelBufferSizes struct { + ResultJSONChan int // Default: 10000 + ErrChan int // Default: 100 + RecvChan int // Default: 2000 +} +``` + +### Timeout Configuration + +```go +type Timeouts struct { + KeepaliveTime time.Duration // Default: 300s + KeepaliveTimeout time.Duration // Default: 20s + CloseTimeout time.Duration // Default: 5s +} +``` + +## Performance Optimizations + +1. **Pre-create Tokenizer**: Created at startup to avoid first request latency +2. **Lock-Free Concurrency**: Tokenizer is thread-safe, no locks needed +3. **Lazy Parsing**: JSON parsing deferred until needed +4. **Direct JSON Passing**: `RecvJSON()` avoids parse/serialize overhead +5. **Immediate Batching**: batchSize=1, no delay +6. **Async Processing**: `readLoop` processes in background, doesn't block request handling +7. **Configurable Buffers**: Adjust channel sizes based on concurrency needs + +## File Structure + +``` +sgl-model-gateway/bindings/golang/ +├── client.go # High-level client API +├── internal/ +│ ├── grpc/ +│ │ └── client_grpc.go # gRPC client implementation +│ ├── ffi/ # FFI bindings (Rust) +│ └── proto/ # Protobuf definitions +└── examples/ + └── oai_server/ + ├── handlers/ + │ └── chat.go # HTTP request handling + ├── models/ + │ └── chat.go # Request/response models + └── service/ + └── sglang_service.go # Service layer +``` + +## Error Handling + +### Context Cancellation Mechanism +1. **Client disconnects** → `SetBodyStreamWriter` detects flush error +2. **Cancel streamCtx** → `readLoop` detects `ctx.Done()` +3. **Call stream.CloseSend()** → `Recv()` goroutine returns error +4. **readLoop defer executes**: + - Set `closed` flag + - Cancel context (if not already cancelled) + - Wait for all `processAndSendResponse` goroutines to complete (`processWg.Wait()`) + - Close all channels (`resultJSONChan`, `errChan`, `readLoopDone`) +5. 
**Clean up resources and exit** + +### Channel Blocking and Race Condition Prevention +- **Context cancellation mechanism**: All channel sends use `select` statements with `case <-s.ctx.Done()` +- **Graceful exit**: When context is cancelled, all blocking send operations can return immediately +- **WaitGroup synchronization**: `readLoop`'s `defer` uses `processWg.Wait()` to ensure all goroutines complete before closing channels +- **Avoid panic**: Through context cancellation and WaitGroup synchronization, avoids "send on closed channel" panic + +## Key Functions + +### CreateChatCompletionStream +**Location**: `internal/grpc/client_grpc.go:108` +- Preprocess request (FFI) +- Build gRPC request +- Create converter and batch processor +- Start `readLoop` + +### readLoop +**Location**: `internal/grpc/client_grpc.go:290` +- Start Recv() goroutine (continuously calls `stream.Recv()`) +- Process proto responses +- Asynchronously call `processAndSendResponse` (tracked with `processWg`) +- **Graceful shutdown in defer**: + - Set `closed` flag + - Cancel context (if not already cancelled) + - Wait for all `processAndSendResponse` goroutines to complete (`processWg.Wait()`) + - Close all channels (`resultJSONChan`, `errChan`, `readLoopDone`) + +### processAndSendResponse +**Location**: `internal/grpc/client_grpc.go:379` +- Check `ctx.Done()` at function start, return immediately if cancelled +- Convert proto to JSON +- Call FFI batch processor +- All `select` statements include `case <-s.ctx.Done()` for graceful shutdown handling +- Send JSON to channel + +### RecvJSON +**Location**: +- `internal/grpc/client_grpc.go:412`: gRPC layer implementation +- `client.go:410`: Client wrapper layer +- Read from `resultJSONChan` +- Directly return JSON string, no parsing needed diff --git a/sgl-model-gateway/bindings/golang/examples/oai_server/config/config.go b/sgl-model-gateway/bindings/golang/examples/oai_server/config/config.go new file mode 100644 index 000000000000..5442ac34ae5c --- /dev/null +++ b/sgl-model-gateway/bindings/golang/examples/oai_server/config/config.go @@ -0,0 +1,55 @@ +package config + +import ( + "os" +) + +// Config holds the application configuration +type Config struct { + Endpoint string + TokenizerPath string + Port string + LogDir string + LogLevel string +} + +// Load loads configuration from environment variables with defaults +func Load() *Config { + // Get tokenizer path from environment or use default + tokenizerPath := os.Getenv("SGL_TOKENIZER_PATH") + if tokenizerPath == "" { + tokenizerPath = "../tokenizer" + } + + // Get endpoint from environment or use default + endpoint := os.Getenv("SGL_GRPC_ENDPOINT") + if endpoint == "" { + endpoint = "grpc://localhost:20000" + } + + // Get port from environment or use default + port := os.Getenv("PORT") + if port == "" { + port = "8080" + } + + // Get log directory from environment or use default + logDir := os.Getenv("LOG_DIR") + if logDir == "" { + logDir = "./logs" + } + + // Get log level from environment or use default + logLevel := os.Getenv("LOG_LEVEL") + if logLevel == "" { + logLevel = "info" + } + + return &Config{ + Endpoint: endpoint, + TokenizerPath: tokenizerPath, + Port: port, + LogDir: logDir, + LogLevel: logLevel, + } +} diff --git a/sgl-model-gateway/bindings/golang/examples/oai_server/docs/benchmark_result.md b/sgl-model-gateway/bindings/golang/examples/oai_server/docs/benchmark_result.md new file mode 100644 index 000000000000..4c3fd19fc6ad --- /dev/null +++ 
b/sgl-model-gateway/bindings/golang/examples/oai_server/docs/benchmark_result.md @@ -0,0 +1,121 @@ +/tmp/ShareGPT_V3_unfiltered_cleaned_split.json: 100%|████████████████████| 642M/642M [10:02<00:00, 1.12MB/s] +#Input tokens: 50561 +#Output tokens: 25883 +Starting warmup with 5 sequences... +Warmup completed with 5 sequences. Starting main benchmark run... + +============ Serving Benchmark Result ============ +Backend: sglang-oai-chat +Traffic request rate: 20.0 +Max request concurrency: 20 +Successful requests: 100 +Benchmark duration (s): 107.24 +Total input tokens: 50561 +Total input text tokens: 50561 +Total input vision tokens: 0 +Total generated tokens: 25883 +Total generated tokens (retokenized): 129591 +Request throughput (req/s): 0.93 +Input token throughput (tok/s): 471.48 +Output token throughput (tok/s): 241.36 +Total token throughput (tok/s): 712.84 +Concurrency: 16.42 +----------------End-to-End Latency---------------- +Mean E2E Latency (ms): 17609.46 +Median E2E Latency (ms): 12343.82 +---------------Time to First Token---------------- +Mean TTFT (ms): 190.71 +Median TTFT (ms): 164.86 +P99 TTFT (ms): 397.72 +-----Time per Output Token (excl. 1st token)------ +Mean TPOT (ms): 162.55 +Median TPOT (ms): 63.51 +P99 TPOT (ms): 1337.20 +---------------Inter-Token Latency---------------- +Mean ITL (ms): 25.85 +Median ITL (ms): 24.26 +P95 ITL (ms): 48.26 +P99 ITL (ms): 119.04 +Max ITL (ms): 194.58 +================================================== + +✓ E2E test completed + + +## Rust +============ Serving Benchmark Result ============ +Backend: sglang-oai-chat +Traffic request rate: 20.0 +Max request concurrency: 20 +Successful requests: 100 +Benchmark duration (s): 37.71 +Total input tokens: 50561 +Total input text tokens: 50561 +Total input vision tokens: 0 +Total generated tokens: 25883 +Total generated tokens (retokenized): 25599 +Request throughput (req/s): 2.65 +Input token throughput (tok/s): 1340.75 +Output token throughput (tok/s): 686.35 +Total token throughput (tok/s): 2027.10 +Concurrency: 18.58 +----------------End-to-End Latency---------------- +Mean E2E Latency (ms): 7008.05 +Median E2E Latency (ms): 7061.24 +---------------Time to First Token---------------- +Mean TTFT (ms): 156.09 +Median TTFT (ms): 133.81 +P99 TTFT (ms): 318.53 +-----Time per Output Token (excl. 1st token)------ +Mean TPOT (ms): 26.59 +Median TPOT (ms): 26.75 +P99 TPOT (ms): 29.18 +---------------Inter-Token Latency---------------- +Mean ITL (ms): 26.71 +Median ITL (ms): 23.61 +P95 ITL (ms): 66.11 +P99 ITL (ms): 115.30 +Max ITL (ms): 201.08 +================================================== + + +## golang +#Input tokens: 50561 +#Output tokens: 25883 +Starting warmup with 5 sequences... +Warmup completed with 5 sequences. Starting main benchmark run... 
+ +============ Serving Benchmark Result ============ +Backend: sglang-oai-chat +Traffic request rate: 20.0 +Max request concurrency: 20 +Successful requests: 100 +Benchmark duration (s): 34.22 +Total input tokens: 50561 +Total input text tokens: 50561 +Total input vision tokens: 0 +Total generated tokens: 22970 +Total generated tokens (retokenized): 31740 +Request throughput (req/s): 2.92 +Input token throughput (tok/s): 1477.70 +Output token throughput (tok/s): 671.32 +Total token throughput (tok/s): 2149.03 +Concurrency: 18.42 +----------------End-to-End Latency---------------- +Mean E2E Latency (ms): 6303.33 +Median E2E Latency (ms): 6294.46 +---------------Time to First Token---------------- +Mean TTFT (ms): 157.10 +Median TTFT (ms): 149.16 +P99 TTFT (ms): 251.98 +-----Time per Output Token (excl. 1st token)------ +Mean TPOT (ms): 26.49 +Median TPOT (ms): 27.15 +P99 TPOT (ms): 28.73 +---------------Inter-Token Latency---------------- +Mean ITL (ms): 26.97 +Median ITL (ms): 24.61 +P95 ITL (ms): 52.39 +P99 ITL (ms): 86.52 +Max ITL (ms): 194.55 +================================================== diff --git a/sgl-model-gateway/bindings/golang/examples/oai_server/go.mod b/sgl-model-gateway/bindings/golang/examples/oai_server/go.mod new file mode 100644 index 000000000000..9fec2c981acc --- /dev/null +++ b/sgl-model-gateway/bindings/golang/examples/oai_server/go.mod @@ -0,0 +1,28 @@ +module oai_server + +go 1.24.0 + +toolchain go1.24.10 + +replace github.com/sglang/sglang-go-grpc-sdk => ../.. + +require ( + github.com/sglang/sglang-go-grpc-sdk v0.0.0-00010101000000-000000000000 + github.com/valyala/fasthttp v1.52.0 + go.uber.org/zap v1.27.0 + gopkg.in/natefinch/lumberjack.v2 v2.2.1 +) + +require ( + github.com/andybalholm/brotli v1.1.0 // indirect + github.com/klauspost/compress v1.17.9 // indirect + github.com/stretchr/testify v1.10.0 // indirect + github.com/valyala/bytebufferpool v1.0.0 // indirect + go.uber.org/multierr v1.10.0 // indirect + golang.org/x/net v0.46.1-0.20251013234738-63d1a5100f82 // indirect + golang.org/x/sys v0.37.0 // indirect + golang.org/x/text v0.30.0 // indirect + google.golang.org/genproto/googleapis/rpc v0.0.0-20251022142026-3a174f9686a8 // indirect + google.golang.org/grpc v1.77.0 // indirect + google.golang.org/protobuf v1.36.10 // indirect +) diff --git a/sgl-model-gateway/bindings/golang/examples/oai_server/go.sum b/sgl-model-gateway/bindings/golang/examples/oai_server/go.sum new file mode 100644 index 000000000000..7db4221e3eb3 --- /dev/null +++ b/sgl-model-gateway/bindings/golang/examples/oai_server/go.sum @@ -0,0 +1,60 @@ +github.com/andybalholm/brotli v1.1.0 h1:eLKJA0d02Lf0mVpIDgYnqXcUn0GqVmEFny3VuID1U3M= +github.com/andybalholm/brotli v1.1.0/go.mod h1:sms7XGricyQI9K10gOSf56VKKWS4oLer58Q+mhRPtnY= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI= +github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= +github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag= +github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE= +github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek= +github.com/golang/protobuf v1.5.4/go.mod 
h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps= +github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= +github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/klauspost/compress v1.17.9 h1:6KIumPrER1LHsvBVuDa0r5xaG0Es51mhhB9BQB2qeMA= +github.com/klauspost/compress v1.17.9/go.mod h1:Di0epgTjJY877eYKx5yC51cX2A2Vl2ibi7bDH9ttBbw= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= +github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw= +github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= +github.com/valyala/fasthttp v1.52.0 h1:wqBQpxH71XW0e2g+Og4dzQM8pk34aFYlA1Ga8db7gU0= +github.com/valyala/fasthttp v1.52.0/go.mod h1:hf5C4QnVMkNXMspnsUlfM3WitlgYflyhHYoKol/szxQ= +go.opentelemetry.io/auto/sdk v1.2.1 h1:jXsnJ4Lmnqd11kwkBV2LgLoFMZKizbCi5fNZ/ipaZ64= +go.opentelemetry.io/auto/sdk v1.2.1/go.mod h1:KRTj+aOaElaLi+wW1kO/DZRXwkF4C5xPbEe3ZiIhN7Y= +go.opentelemetry.io/otel v1.38.0 h1:RkfdswUDRimDg0m2Az18RKOsnI8UDzppJAtj01/Ymk8= +go.opentelemetry.io/otel v1.38.0/go.mod h1:zcmtmQ1+YmQM9wrNsTGV/q/uyusom3P8RxwExxkZhjM= +go.opentelemetry.io/otel/metric v1.38.0 h1:Kl6lzIYGAh5M159u9NgiRkmoMKjvbsKtYRwgfrA6WpA= +go.opentelemetry.io/otel/metric v1.38.0/go.mod h1:kB5n/QoRM8YwmUahxvI3bO34eVtQf2i4utNVLr9gEmI= +go.opentelemetry.io/otel/sdk v1.38.0 h1:l48sr5YbNf2hpCUj/FoGhW9yDkl+Ma+LrVl8qaM5b+E= +go.opentelemetry.io/otel/sdk v1.38.0/go.mod h1:ghmNdGlVemJI3+ZB5iDEuk4bWA3GkTpW+DOoZMYBVVg= +go.opentelemetry.io/otel/sdk/metric v1.38.0 h1:aSH66iL0aZqo//xXzQLYozmWrXxyFkBJ6qT5wthqPoM= +go.opentelemetry.io/otel/sdk/metric v1.38.0/go.mod h1:dg9PBnW9XdQ1Hd6ZnRz689CbtrUp0wMMs9iPcgT9EZA= +go.opentelemetry.io/otel/trace v1.38.0 h1:Fxk5bKrDZJUH+AMyyIXGcFAPah0oRcT+LuNtJrmcNLE= +go.opentelemetry.io/otel/trace v1.38.0/go.mod h1:j1P9ivuFsTceSWe1oY+EeW3sc+Pp42sO++GHkg4wwhs= +go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= +go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE= +go.uber.org/multierr v1.10.0 h1:S0h4aNzvfcFsC3dRF1jLoaov7oRaKqRGC/pUEJ2yvPQ= +go.uber.org/multierr v1.10.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y= +go.uber.org/zap v1.27.0 h1:aJMhYGrd5QSmlpLMr2MftRKl7t8J8PTZPA732ud/XR8= +go.uber.org/zap v1.27.0/go.mod h1:GB2qFLM7cTU87MWRP2mPIjqfIDnGu+VIO4V/SdhGo2E= +golang.org/x/net v0.46.1-0.20251013234738-63d1a5100f82 h1:6/3JGEh1C88g7m+qzzTbl3A0FtsLguXieqofVLU/JAo= +golang.org/x/net v0.46.1-0.20251013234738-63d1a5100f82/go.mod h1:Q9BGdFy1y4nkUwiLvT5qtyhAnEHgnQ/zd8PfU6nc210= +golang.org/x/sys v0.37.0 h1:fdNQudmxPjkdUTPnLn5mdQv7Zwvbvpaxqs831goi9kQ= +golang.org/x/sys v0.37.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= +golang.org/x/text v0.30.0 h1:yznKA/E9zq54KzlzBEAWn1NXSQ8DIp/NYMy88xJjl4k= +golang.org/x/text v0.30.0/go.mod h1:yDdHFIX9t+tORqspjENWgzaCVXgk0yYnYuSZ8UzzBVM= +gonum.org/v1/gonum v0.16.0 
h1:5+ul4Swaf3ESvrOnidPp4GZbzf0mxVQpDCYUQE7OJfk= +gonum.org/v1/gonum v0.16.0/go.mod h1:fef3am4MQ93R2HHpKnLk4/Tbh/s0+wqD5nfa6Pnwy4E= +google.golang.org/genproto/googleapis/rpc v0.0.0-20251022142026-3a174f9686a8 h1:M1rk8KBnUsBDg1oPGHNCxG4vc1f49epmTO7xscSajMk= +google.golang.org/genproto/googleapis/rpc v0.0.0-20251022142026-3a174f9686a8/go.mod h1:7i2o+ce6H/6BluujYR+kqX3GKH+dChPTQU19wjRPiGk= +google.golang.org/grpc v1.77.0 h1:wVVY6/8cGA6vvffn+wWK5ToddbgdU3d8MNENr4evgXM= +google.golang.org/grpc v1.77.0/go.mod h1:z0BY1iVj0q8E1uSQCjL9cppRj+gnZjzDnzV0dHhrNig= +google.golang.org/protobuf v1.36.10 h1:AYd7cD/uASjIL6Q9LiTjz8JLcrh/88q5UObnmY3aOOE= +google.golang.org/protobuf v1.36.10/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco= +gopkg.in/natefinch/lumberjack.v2 v2.2.1 h1:bBRl1b0OH9s/DuPhuXpNl+VtCaJXFZ5/uEFST95x9zc= +gopkg.in/natefinch/lumberjack.v2 v2.2.1/go.mod h1:YD8tP3GAjkrDg1eZH7EGmyESg/lsYskCTPBJVb9jqSc= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/sgl-model-gateway/bindings/golang/examples/oai_server/handlers/chat.go b/sgl-model-gateway/bindings/golang/examples/oai_server/handlers/chat.go new file mode 100644 index 000000000000..e7ad8a57a50a --- /dev/null +++ b/sgl-model-gateway/bindings/golang/examples/oai_server/handlers/chat.go @@ -0,0 +1,556 @@ +package handlers + +import ( + "bufio" + "context" + "encoding/json" + "fmt" + "io" + "strings" + "time" + + sglang "github.com/sglang/sglang-go-grpc-sdk" + "github.com/valyala/fasthttp" + "go.uber.org/zap" + + "oai_server/models" + "oai_server/service" + "oai_server/utils" +) + +// ChatHandler handles chat completion requests +type ChatHandler struct { + logger *zap.Logger + service *service.SGLangService +} + +// NewChatHandler creates a new chat handler +func NewChatHandler(logger *zap.Logger, svc *service.SGLangService) *ChatHandler { + return &ChatHandler{ + logger: logger, + service: svc, + } +} + +// recvResult holds the result of a RecvJSON() call +type recvResult struct { + chunkJSON string + err error +} + +// HandleChatCompletion handles POST /v1/chat/completions +func (h *ChatHandler) HandleChatCompletion(ctx *fasthttp.RequestCtx) { + var req models.ChatRequest + if err := json.Unmarshal(ctx.PostBody(), &req); err != nil { + h.logger.Warn("Invalid chat completion request", zap.Error(err)) + utils.RespondError(ctx, 400, fmt.Sprintf("Invalid request: %v", err), "invalid_request_error") + return + } + + path := string(ctx.Path()) + + defer func() { + statusCode := ctx.Response.StatusCode() + if statusCode == 0 { + statusCode = 200 + } + h.logHTTPResponse(statusCode, path) + }() + + // Convert to SGLang format + messages := make([]sglang.ChatMessage, len(req.Messages)) + for i, msg := range req.Messages { + role, roleOk := msg["role"] + content, contentOk := msg["content"] + + // Validate role + if !roleOk || role == "" { + h.logger.Warn("Missing or empty role in message", zap.Int("message_index", i)) + utils.RespondError(ctx, 400, "Message role is required and cannot be empty", "invalid_request_error") + return + } + + // Ensure content is always a string (not null) + // Chat template requires content field to be present, even if empty + // If content is missing or null, use empty string + contentStr := "" + if contentOk && content != "" { + contentStr = content + } + + messages[i] = sglang.ChatMessage{ + Role: role, + Content: contentStr, + } + } + + sglReq := sglang.ChatCompletionRequest{ + Model: 
req.Model, + Messages: messages, + Stream: req.Stream, + } + + if req.Temperature != nil { + temp := float32(*req.Temperature) + sglReq.Temperature = &temp + } + if req.TopP != nil { + topP := float32(*req.TopP) + sglReq.TopP = &topP + } + if req.MaxCompletionTokens != nil { + sglReq.MaxCompletionTokens = req.MaxCompletionTokens + } else if req.MaxTokens != nil { + sglReq.MaxCompletionTokens = req.MaxTokens + } + + requestCtx := context.Background() + + if req.Stream { + h.handleStreamingCompletion(ctx, requestCtx, sglReq) + } else { + h.handleNonStreamingCompletion(ctx, requestCtx, sglReq) + } +} + +// isBrokenPipeError checks if the error is a broken pipe error (client disconnected) +func isBrokenPipeError(err error) bool { + if err == nil { + return false + } + errStr := err.Error() + return strings.Contains(errStr, "broken pipe") || + strings.Contains(errStr, "connection reset by peer") || + strings.Contains(errStr, "connection closed") || + strings.Contains(errStr, "write: connection closed") +} + +// logHTTPResponse logs HTTP response with colored output +func (h *ChatHandler) logHTTPResponse(statusCode int, path string) { + var statusText string + var colorCode string + + switch { + case statusCode >= 200 && statusCode < 300: + colorCode = "\033[32m" // Green + statusText = "OK" + case statusCode >= 300 && statusCode < 400: + colorCode = "\033[33m" // Yellow + statusText = "Redirect" + case statusCode >= 400 && statusCode < 500: + colorCode = "\033[33m" // Yellow + statusText = "Client Error" + case statusCode >= 500: + colorCode = "\033[31m" // Red + statusText = "Server Error" + default: + colorCode = "\033[37m" // White + statusText = "Unknown" + } + + resetCode := "\033[0m" + msg := fmt.Sprintf("%s[%d %s]%s %s", colorCode, statusCode, statusText, resetCode, path) + h.logger.Info(msg) +} + +func (h *ChatHandler) handleStreamingCompletion(ctx *fasthttp.RequestCtx, requestCtx context.Context, req sglang.ChatCompletionRequest) { + + ctx.SetContentType("text/event-stream") + ctx.Response.Header.Set("Cache-Control", "no-cache") + ctx.Response.Header.Set("Connection", "keep-alive") + ctx.Response.Header.Set("X-Accel-Buffering", "no") + ctx.SetStatusCode(200) + + var clientDisconnected bool + // Flush timeout: prevent deadlock if client is slow or disconnected + // This timeout should be longer than typical network latency but shorter than client timeout + const flushTimeout = 5 * time.Second + + ctx.SetBodyStreamWriter(func(w *bufio.Writer) { + streamCtx, cancel := context.WithCancel(context.Background()) + defer cancel() + + stream, err := h.service.Client().CreateChatCompletionStream(streamCtx, req) + if err != nil { + h.logger.Error("Failed to create chat completion stream", + zap.Error(err), + zap.String("model", req.Model), + ) + // Use sendSSEError to send error in consistent format + errInfo, sendErr := h.sendSSEError(w, err) + if sendErr != nil { + h.logger.Warn("Failed to send SSE error", zap.Error(sendErr)) + } else if errInfo.IsTimeout { + h.logger.Error("Stream creation timeout", zap.Error(err)) + } + return + } + defer func() { + if closeErr := stream.Close(); closeErr != nil { + h.logger.Warn("Failed to close stream", zap.Error(closeErr)) + } + }() + + // Use a single dedicated goroutine to continuously call RecvJSON() and send results via channel + recvChan := make(chan recvResult, 20) + recvGoroutineDone := make(chan struct{}) + go func() { + defer func() { + close(recvChan) + close(recvGoroutineDone) + }() + for { + // Check context before calling RecvJSON() to avoid 
blocking if context is cancelled + select { + case <-streamCtx.Done(): + return + default: + } + + // Call RecvJSON() - this may block, but stream.Close() will unblock it + // when context is cancelled (called from main loop) + chunkJSON, err := stream.RecvJSON() + + // Check context again after RecvJSON() returns + select { + case <-streamCtx.Done(): + return + default: + } + + // Send to channel (may block if channel is full) + // If channel is full, this will block until main loop reads from it + // This is acceptable because main loop should be actively reading + select { + case recvChan <- recvResult{chunkJSON: chunkJSON, err: err}: + if err != nil { + // EOF or other error, stop the goroutine + return + } + case <-streamCtx.Done(): + // Context cancelled while sending, stop the goroutine + return + } + } + }() + + for { + if clientDisconnected { + cancel() + // Close stream immediately to unblock RecvJSON() calls + stream.Close() + return + } + + select { + case <-streamCtx.Done(): + // Close stream to ensure RecvJSON() goroutine can exit + stream.Close() + return + case result, ok := <-recvChan: + if !ok { + // Channel closed, stream ended + return + } + if result.err == io.EOF { + if !clientDisconnected { + w.WriteString("data: [DONE]\n\n") + // Flush with timeout to prevent deadlock + flushDone := make(chan error, 1) + go func() { + flushDone <- w.Flush() + }() + flushCtx, flushCancel := context.WithTimeout(streamCtx, flushTimeout) + defer flushCancel() + select { + case flushErr := <-flushDone: + if flushErr != nil && !isBrokenPipeError(flushErr) { + h.logger.Warn("Final flush error", zap.Error(flushErr)) + } + case <-flushCtx.Done(): + if flushCtx.Err() == context.DeadlineExceeded { + h.logger.Warn("Final flush timeout", zap.Duration("timeout", flushTimeout)) + } + case <-streamCtx.Done(): + // Context cancelled, skip flush + } + } + return + } + if result.err != nil { + if result.err == context.Canceled || result.err == context.DeadlineExceeded { + return + } + // Send error to client before closing + errInfo, sendErr := h.sendSSEError(w, result.err) + if sendErr != nil { + h.logger.Warn("Failed to send SSE error", zap.Error(sendErr)) + } + if errInfo.IsTimeout { + h.logger.Error("Stream timeout error", zap.Error(result.err)) + } else { + h.logger.Error("Stream error", zap.Error(result.err)) + } + return + } + if result.chunkJSON == "" { + continue + } + + w.WriteString("data: ") + w.WriteString(result.chunkJSON) + w.WriteString("\n\n") + + // Flush with timeout to prevent deadlock: + // If Flush blocks indefinitely (slow client), RecvJSON goroutine may fill recvChan + // and then block trying to send, causing deadlock + // Note: bufio.Writer.Flush() doesn't have a timeout parameter, so we use + // a goroutine + select pattern to implement timeout behavior + flushDone := make(chan error, 1) + go func() { + flushDone <- w.Flush() + }() + + flushCtx, flushCancel := context.WithTimeout(streamCtx, flushTimeout) + defer flushCancel() + + select { + case err := <-flushDone: + if err != nil { + if isBrokenPipeError(err) { + clientDisconnected = true + cancel() + // Close stream immediately to unblock RecvJSON() calls + stream.Close() + return + } + h.logger.Warn("Flush error", zap.Error(err)) + } + case <-flushCtx.Done(): + // Flush timeout: client may be slow or disconnected + // Continue processing to avoid deadlock, but mark as disconnected + if flushCtx.Err() == context.DeadlineExceeded { + h.logger.Warn("Flush timeout, client may be slow or disconnected", zap.Duration("timeout", 
flushTimeout)) + } + clientDisconnected = true + cancel() + stream.Close() + return + case <-streamCtx.Done(): + // Context cancelled, stop flushing + return + } + } + } + }) +} + +func (h *ChatHandler) handleNonStreamingCompletion(ctx *fasthttp.RequestCtx, requestCtx context.Context, req sglang.ChatCompletionRequest) { + resp, err := h.service.Client().CreateChatCompletion(requestCtx, req) + if err != nil { + h.logger.Error("Failed to create chat completion", + zap.Error(err), + zap.String("model", req.Model), + ) + utils.RespondError(ctx, 500, fmt.Sprintf("Failed to create completion: %v", err), "server_error") + return + } + + // Convert to OpenAI format + response := utils.BuildResponseBase(resp.ID, resp.Created, resp.Model) + response["object"] = "chat.completion" + + choices := make([]map[string]interface{}, len(resp.Choices)) + for i, choice := range resp.Choices { + choiceMap := map[string]interface{}{ + "index": choice.Index, + "message": map[string]interface{}{ + "role": choice.Message.Role, + "content": choice.Message.Content, + }, + "finish_reason": choice.FinishReason, + } + if len(choice.Message.ToolCalls) > 0 { + toolCalls := make([]map[string]interface{}, len(choice.Message.ToolCalls)) + for j, tc := range choice.Message.ToolCalls { + toolCalls[j] = map[string]interface{}{ + "id": tc.ID, + "type": tc.Type, + "function": map[string]interface{}{"name": tc.Function.Name, "arguments": tc.Function.Arguments}, + } + } + choiceMap["message"].(map[string]interface{})["tool_calls"] = toolCalls + } + choices[i] = choiceMap + } + response["choices"] = choices + + // Usage is always present (not a pointer) + response["usage"] = map[string]interface{}{ + "prompt_tokens": resp.Usage.PromptTokens, + "completion_tokens": resp.Usage.CompletionTokens, + "total_tokens": resp.Usage.TotalTokens, + } + + ctx.SetStatusCode(200) + ctx.SetContentType("application/json") + jsonData, _ := json.Marshal(response) + ctx.Write(jsonData) +} + +// StreamErrorInfo holds parsed error information +type StreamErrorInfo struct { + Message string + Type string + Code int + IsTimeout bool +} + +// parseStreamError parses error type and code +func parseStreamError(err error) StreamErrorInfo { + if err == nil { + return StreamErrorInfo{} + } + + errorMsg := err.Error() + // Check timeout error by message prefix + isTimeout := strings.HasPrefix(errorMsg, "stream.Recv() timeout") || strings.Contains(errorMsg, "timeout after") + + errorType := "server_error" + errorCode := 500 + if isTimeout { + errorType = "timeout_error" + errorCode = 504 + } + + return StreamErrorInfo{ + Message: errorMsg, + Type: errorType, + Code: errorCode, + IsTimeout: isTimeout, + } +} + +// formatErrorJSON formats error as OpenAI JSON +func formatErrorJSON(errInfo StreamErrorInfo) string { + errorObj := map[string]interface{}{ + "error": map[string]interface{}{ + "message": errInfo.Message, + "type": errInfo.Type, + "code": errInfo.Code, + }, + } + jsonBytes, _ := json.Marshal(errorObj) + return string(jsonBytes) +} + +// sendSSEError sends SSE error response. Callers should log errors. 
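+// Illustrative payload (shape follows formatErrorJSON above; the message
+// value is only an example of a timeout mapped by parseStreamError):
+//
+//	data: {"error":{"message":"stream.Recv() timeout after 30s","type":"timeout_error","code":504}}
+//
+// The blank SSE line terminates the event, matching the "\n\n" written below.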
+func (h *ChatHandler) sendSSEError(w *bufio.Writer, err error) (StreamErrorInfo, error) { + errInfo := parseStreamError(err) + errorJSON := formatErrorJSON(errInfo) + + w.WriteString("data: ") + w.WriteString(errorJSON) + w.WriteString("\n\n") + + if flushErr := w.Flush(); flushErr != nil && !isBrokenPipeError(flushErr) { + h.logger.Warn("Failed to flush error response", zap.Error(flushErr)) + return errInfo, flushErr + } + + return errInfo, nil +} + +// HandleGenerate handles POST /generate (SGLang native API) +func (h *ChatHandler) HandleGenerate(ctx *fasthttp.RequestCtx) { + path := string(ctx.Path()) + + defer func() { + statusCode := ctx.Response.StatusCode() + if statusCode == 0 { + statusCode = 200 + } + h.logHTTPResponse(statusCode, path) + }() + + // Parse request body + var req map[string]interface{} + if err := json.Unmarshal(ctx.PostBody(), &req); err != nil { + h.logger.Warn("Invalid generate request", zap.Error(err)) + utils.RespondError(ctx, 400, fmt.Sprintf("Invalid request: %v", err), "invalid_request_error") + return + } + + // Extract text and sampling_params + text, ok := req["text"].(string) + if !ok || text == "" { + utils.RespondError(ctx, 400, "Missing or invalid 'text' field", "invalid_request_error") + return + } + + samplingParams, _ := req["sampling_params"].(map[string]interface{}) + if samplingParams == nil { + samplingParams = make(map[string]interface{}) + } + + // Convert to chat completion format for processing + chatReq := sglang.ChatCompletionRequest{ + Model: "default", + Messages: []sglang.ChatMessage{{Role: "user", Content: text}}, + Stream: false, + } + + // Copy sampling params + if maxNewTokens, ok := samplingParams["max_new_tokens"].(float64); ok { + tokens := int(maxNewTokens) + chatReq.MaxCompletionTokens = &tokens + } + if temp, ok := samplingParams["temperature"].(float64); ok { + temp32 := float32(temp) + chatReq.Temperature = &temp32 + } + if topP, ok := samplingParams["top_p"].(float64); ok { + topP32 := float32(topP) + chatReq.TopP = &topP32 + } + if topK, ok := samplingParams["top_k"].(float64); ok { + topKInt := int(topK) + chatReq.TopK = &topKInt + } + + requestCtx := context.Background() + + // Use non-streaming completion for /generate endpoint + resp, err := h.service.Client().CreateChatCompletion(requestCtx, chatReq) + if err != nil { + h.logger.Error("Failed to create completion", + zap.Error(err), + ) + utils.RespondError(ctx, 500, fmt.Sprintf("Failed to create completion: %v", err), "server_error") + return + } + + // Convert to SGLang /generate response format + // meta_info must match SGLang's expected format with completion_tokens at top level + finishReason := resp.Choices[0].FinishReason + if finishReason == "" { + finishReason = "stop" + } + + response := map[string]interface{}{ + "text": resp.Choices[0].Message.Content, + "meta_info": map[string]interface{}{ + "id": resp.ID, + "finish_reason": finishReason, + "prompt_tokens": resp.Usage.PromptTokens, + "completion_tokens": resp.Usage.CompletionTokens, + "cached_tokens": 0, // Not available from chat completion API + "weight_version": "", // Not available from chat completion API + }, + } + + ctx.SetStatusCode(200) + ctx.SetContentType("application/json") + jsonData, _ := json.Marshal(response) + ctx.Write(jsonData) +} diff --git a/sgl-model-gateway/bindings/golang/examples/oai_server/handlers/health.go b/sgl-model-gateway/bindings/golang/examples/oai_server/handlers/health.go new file mode 100644 index 000000000000..f0418154590d --- /dev/null +++ 
b/sgl-model-gateway/bindings/golang/examples/oai_server/handlers/health.go @@ -0,0 +1,33 @@ +package handlers + +import ( + "encoding/json" + + "github.com/valyala/fasthttp" + "go.uber.org/zap" +) + +// HealthHandler handles health check requests +type HealthHandler struct { + logger *zap.Logger +} + +// NewHealthHandler creates a new health handler +func NewHealthHandler(logger *zap.Logger) *HealthHandler { + return &HealthHandler{ + logger: logger, + } +} + +// Check handles GET /health +func (h *HealthHandler) Check(ctx *fasthttp.RequestCtx) { + ctx.SetStatusCode(200) + ctx.SetContentType("application/json") + + response := map[string]string{ + "status": "ok", + } + + jsonData, _ := json.Marshal(response) + ctx.Write(jsonData) +} diff --git a/sgl-model-gateway/bindings/golang/examples/oai_server/handlers/models.go b/sgl-model-gateway/bindings/golang/examples/oai_server/handlers/models.go new file mode 100644 index 000000000000..a03474a6575d --- /dev/null +++ b/sgl-model-gateway/bindings/golang/examples/oai_server/handlers/models.go @@ -0,0 +1,67 @@ +package handlers + +import ( + "encoding/json" + + "github.com/valyala/fasthttp" + "go.uber.org/zap" +) + +// ModelsHandler handles model list requests +type ModelsHandler struct { + logger *zap.Logger + tokenizerPath string +} + +// NewModelsHandler creates a new models handler +func NewModelsHandler(logger *zap.Logger, tokenizerPath string) *ModelsHandler { + return &ModelsHandler{ + logger: logger, + tokenizerPath: tokenizerPath, + } +} + +// List handles GET /v1/models +func (h *ModelsHandler) List(ctx *fasthttp.RequestCtx) { + // Return a default model for OpenAI compatibility + ctx.SetStatusCode(200) + ctx.SetContentType("application/json") + + response := map[string]interface{}{ + "object": "list", + "data": []map[string]interface{}{ + { + "id": "default", + "object": "model", + "created": 1677610602, + "owned_by": "sglang", + }, + }, + } + + jsonData, _ := json.Marshal(response) + ctx.Write(jsonData) +} + +// GetModelInfo handles GET /get_model_info +// Returns model information compatible with SGLang RuntimeEndpoint +func (h *ModelsHandler) GetModelInfo(ctx *fasthttp.RequestCtx) { + ctx.SetStatusCode(200) + ctx.SetContentType("application/json") + + // Return model info compatible with SGLang RuntimeEndpoint expectations + response := map[string]interface{}{ + "model_path": h.tokenizerPath, // Use tokenizer path as model path + "tokenizer_path": h.tokenizerPath, + "is_generation": true, + "preferred_sampling_params": "", + "weight_version": "", + "has_image_understanding": false, + "has_audio_understanding": false, + "model_type": "", + "architectures": nil, + } + + jsonData, _ := json.Marshal(response) + ctx.Write(jsonData) +} diff --git a/sgl-model-gateway/bindings/golang/examples/oai_server/logger/logger.go b/sgl-model-gateway/bindings/golang/examples/oai_server/logger/logger.go new file mode 100644 index 000000000000..3e07cd007645 --- /dev/null +++ b/sgl-model-gateway/bindings/golang/examples/oai_server/logger/logger.go @@ -0,0 +1,67 @@ +package logger + +import ( + "os" + "path/filepath" + "time" + + "go.uber.org/zap" + "go.uber.org/zap/zapcore" + "gopkg.in/natefinch/lumberjack.v2" +) + +// Init initializes the logger with file and console output +func Init(logDir, logLevel string) (*zap.Logger, error) { + // Ensure log directory exists + if err := os.MkdirAll(logDir, 0755); err != nil { + return nil, err + } + + // Parse log level + var level zapcore.Level + if err := level.UnmarshalText([]byte(logLevel)); err 
!= nil { + level = zapcore.InfoLevel + } + + // Create log file path with date + logFile := filepath.Join(logDir, "oai_server-"+time.Now().Format("2006-01-02")+".log") + + // File writer with rotation + fileWriter := zapcore.AddSync(&lumberjack.Logger{ + Filename: logFile, + MaxSize: 100, // megabytes + MaxBackups: 10, + MaxAge: 30, // days + Compress: true, + }) + + // Console writer + consoleWriter := zapcore.AddSync(os.Stdout) + + // Encoder config + encoderConfig := zap.NewProductionEncoderConfig() + encoderConfig.TimeKey = "timestamp" + encoderConfig.EncodeTime = zapcore.ISO8601TimeEncoder + encoderConfig.EncodeLevel = zapcore.CapitalLevelEncoder + + // Create cores + fileCore := zapcore.NewCore( + zapcore.NewJSONEncoder(encoderConfig), + fileWriter, + level, + ) + + consoleCore := zapcore.NewCore( + zapcore.NewConsoleEncoder(encoderConfig), + consoleWriter, + level, + ) + + // Combine cores + core := zapcore.NewTee(fileCore, consoleCore) + + // Create logger + logger := zap.New(core, zap.AddCaller(), zap.AddStacktrace(zapcore.ErrorLevel)) + + return logger, nil +} diff --git a/sgl-model-gateway/bindings/golang/examples/oai_server/main.go b/sgl-model-gateway/bindings/golang/examples/oai_server/main.go new file mode 100644 index 000000000000..5ba7ecfc0bb2 --- /dev/null +++ b/sgl-model-gateway/bindings/golang/examples/oai_server/main.go @@ -0,0 +1,116 @@ +// OpenAI-compatible chat server using SGLang Go SDK and fasthttp framework +package main + +import ( + "fmt" + "net/http" + "os" + + _ "net/http/pprof" // Enable pprof endpoints + + "github.com/valyala/fasthttp" + "go.uber.org/zap" + + "oai_server/config" + "oai_server/handlers" + "oai_server/logger" + "oai_server/service" +) + +// Version information (set at build time via ldflags) +var ( + Version = "dev" + BuildTime = "unknown" + GitCommit = "unknown" +) + +func main() { + // Load configuration + cfg := config.Load() + + // Initialize logger + appLogger, err := logger.Init(cfg.LogDir, cfg.LogLevel) + if err != nil { + panic(fmt.Sprintf("Failed to initialize logger: %v", err)) + } + defer appLogger.Sync() + + appLogger.Info("Starting OpenAI-compatible server", + zap.String("endpoint", cfg.Endpoint), + zap.String("tokenizer", cfg.TokenizerPath), + zap.String("port", cfg.Port), + ) + + // Initialize SGLang service + sglangService, err := service.NewSGLangService(cfg.Endpoint, cfg.TokenizerPath) + if err != nil { + appLogger.Fatal("Failed to create SGLang client", zap.Error(err)) + } + defer sglangService.Close() + + appLogger.Info("SGLang client created successfully") + + // Enable pprof if requested + if os.Getenv("PPROF_ENABLED") == "true" { + pprofPort := os.Getenv("PPROF_PORT") + if pprofPort == "" { + pprofPort = "6060" + } + go func() { + pprofAddr := ":" + pprofPort + appLogger.Info("Starting pprof server", zap.String("address", pprofAddr)) + if err := http.ListenAndServe(pprofAddr, nil); err != nil { + appLogger.Error("pprof server failed", zap.Error(err)) + } + }() + appLogger.Info("pprof enabled", zap.String("port", pprofPort), zap.String("endpoint", fmt.Sprintf("http://localhost:%s/debug/pprof/", pprofPort))) + } + + // Initialize handlers + healthHandler := handlers.NewHealthHandler(appLogger) + modelsHandler := handlers.NewModelsHandler(appLogger, cfg.TokenizerPath) + chatHandler := handlers.NewChatHandler(appLogger, sglangService) + + // Setup fasthttp router + router := func(ctx *fasthttp.RequestCtx) { + path := string(ctx.Path()) + method := string(ctx.Method()) + + switch { + case method == "GET" && path == 
"/health": + healthHandler.Check(ctx) + case method == "GET" && path == "/v1/models": + modelsHandler.List(ctx) + case method == "GET" && path == "/get_model_info": + modelsHandler.GetModelInfo(ctx) + case method == "POST" && path == "/v1/chat/completions": + chatHandler.HandleChatCompletion(ctx) + case (method == "POST" || method == "PUT") && path == "/generate": + chatHandler.HandleGenerate(ctx) + default: + ctx.Error("Not Found", fasthttp.StatusNotFound) + } + } + + // Start server + serverAddr := ":" + cfg.Port + baseURL := fmt.Sprintf("http://localhost:%s", cfg.Port) + + appLogger.Info("Server starting", + zap.String("address", serverAddr), + zap.String("base_url", baseURL), + ) + + // Print available HTTP endpoints (similar to FastAPI startup) + appLogger.Info("Available HTTP endpoints:") + appLogger.Info(fmt.Sprintf(" GET %s/health", baseURL)) + appLogger.Info(fmt.Sprintf(" GET %s/v1/models", baseURL)) + appLogger.Info(fmt.Sprintf(" GET %s/get_model_info", baseURL)) + appLogger.Info(fmt.Sprintf(" POST %s/v1/chat/completions", baseURL)) + appLogger.Info(fmt.Sprintf(" POST %s/generate", baseURL)) + appLogger.Info(fmt.Sprintf("Application startup complete. Listening on %s", baseURL)) + + if err := fasthttp.ListenAndServe(serverAddr, router); err != nil { + appLogger.Fatal("Server failed", zap.Error(err)) + } +} diff --git a/sgl-model-gateway/bindings/golang/examples/oai_server/models/chat.go b/sgl-model-gateway/bindings/golang/examples/oai_server/models/chat.go new file mode 100644 index 000000000000..a6034664cd7a --- /dev/null +++ b/sgl-model-gateway/bindings/golang/examples/oai_server/models/chat.go @@ -0,0 +1,14 @@ +package models + +// ChatRequest represents an OpenAI-compatible chat completion request +type ChatRequest struct { + Model string `json:"model" binding:"required"` + Messages []map[string]string `json:"messages" binding:"required"` + Stream bool `json:"stream,omitempty"` + Temperature *float64 `json:"temperature,omitempty"` + TopP *float64 `json:"top_p,omitempty"` + MaxTokens *int `json:"max_tokens,omitempty"` // OpenAI API standard field + MaxCompletionTokens *int `json:"max_completion_tokens,omitempty"` // SGLang-specific field (used by bench_serving.py) + Tools []map[string]interface{} `json:"tools,omitempty"` + ToolChoice interface{} `json:"tool_choice,omitempty"` +} diff --git a/sgl-model-gateway/bindings/golang/examples/oai_server/run.sh b/sgl-model-gateway/bindings/golang/examples/oai_server/run.sh new file mode 100755 index 000000000000..f62451630239 --- /dev/null +++ b/sgl-model-gateway/bindings/golang/examples/oai_server/run.sh @@ -0,0 +1,111 @@ +#!/bin/bash + +# OpenAI-compatible server runner +# Usage: ./run.sh [tokenizer_path] [endpoint] [port] [--profile] [--pprof-port PORT] +# +# Options: +# --profile Enable pprof profiling (default port: 6060) +# --pprof-port PORT Set pprof port (default: 6060, requires --profile) + +# Set library path for Rust FFI library +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +BINDINGS_DIR="$(cd "$SCRIPT_DIR/../.." && pwd)" +LIB_DIR="${BINDINGS_DIR}/lib" + +if [ ! 
-d "$LIB_DIR" ]; then + echo "Error: Library directory not found at $LIB_DIR" + echo "Please run 'make lib' first to build and export the library" + exit 1 +fi + +# Get Python LDFLAGS (needed for Rust FFI that depends on Python) +PYTHON_LDFLAGS=$(python3-config --ldflags --embed 2>/dev/null || python3-config --ldflags 2>/dev/null || echo "") + +# Set CGO_LDFLAGS to link with the Rust library +# Note: -lsgl_model_gateway_go and -ldl are already in the #cgo directive in internal/ffi/client.go +# We only need to add the library path (-L) and Python flags +export CGO_LDFLAGS="-L${LIB_DIR} ${PYTHON_LDFLAGS}" + +# macOS uses DYLD_LIBRARY_PATH, Linux uses LD_LIBRARY_PATH +if [[ "$OSTYPE" == "darwin"* ]]; then + export DYLD_LIBRARY_PATH="${LIB_DIR}:${DYLD_LIBRARY_PATH}" +else + export LD_LIBRARY_PATH="${LIB_DIR}:${LD_LIBRARY_PATH}" +fi + +# Parse arguments +ENABLE_PROFILE=false +PPROF_PORT="6060" +TOKENIZER_PATH="" +ENDPOINT="" +PORT="" + +while [[ $# -gt 0 ]]; do + case $1 in + --profile) + ENABLE_PROFILE=true + shift + ;; + --pprof-port) + ENABLE_PROFILE=true + PPROF_PORT="$2" + shift 2 + ;; + *) + if [[ -z "$TOKENIZER_PATH" ]]; then + TOKENIZER_PATH="$1" + elif [[ -z "$ENDPOINT" ]]; then + ENDPOINT="$1" + elif [[ -z "$PORT" ]]; then + PORT="$1" + fi + shift + ;; + esac +done + +# Default configuration +DEFAULT_TOKENIZER_PATH="${SGL_TOKENIZER_PATH:-../tokenizer}" +DEFAULT_ENDPOINT="${SGL_GRPC_ENDPOINT:-grpc://localhost:20000}" +DEFAULT_PORT="${PORT:-8080}" + +TOKENIZER_PATH="${TOKENIZER_PATH:-${DEFAULT_TOKENIZER_PATH}}" +ENDPOINT="${ENDPOINT:-${DEFAULT_ENDPOINT}}" +PORT="${PORT:-${DEFAULT_PORT}}" + +echo "Running OpenAI-compatible server..." +echo "Library path: ${LIB_DIR}" +echo "Tokenizer: $TOKENIZER_PATH" +echo "Endpoint: $ENDPOINT" +echo "Port: $PORT" +echo "Client Mode: gRPC (default)" +echo "FFI Postprocessing: ENABLED (normal mode)" +echo "FFI Preprocessing: ENABLED (normal mode)" +if [[ "$ENABLE_PROFILE" == "true" ]]; then + echo "Profiling: enabled (port: $PPROF_PORT)" + echo " pprof endpoint: http://localhost:$PPROF_PORT/debug/pprof/" + export PPROF_ENABLED=true + export PPROF_PORT="$PPROF_PORT" +else + echo "Profiling: disabled" +fi +echo "" + +# Change to script directory +cd "$(dirname "${BASH_SOURCE[0]}")" + +# Ensure Go module is properly initialized +if [ ! -f "go.mod" ]; then + echo "Error: go.mod not found in $(pwd)" + exit 1 +fi + +# Ensure Go modules are enabled +export GO111MODULE=on + +# Sync Go module dependencies +echo "Syncing Go module dependencies..." 
+go mod tidy + +# Run the server (use ./main.go to ensure module context is correct) +SGL_TOKENIZER_PATH="$TOKENIZER_PATH" SGL_GRPC_ENDPOINT="$ENDPOINT" PORT="$PORT" go run ./main.go diff --git a/sgl-model-gateway/bindings/golang/examples/oai_server/scripts/analyze_tpot.sh b/sgl-model-gateway/bindings/golang/examples/oai_server/scripts/analyze_tpot.sh new file mode 100755 index 000000000000..e549840a95be --- /dev/null +++ b/sgl-model-gateway/bindings/golang/examples/oai_server/scripts/analyze_tpot.sh @@ -0,0 +1,554 @@ +#!/bin/bash + +# TPOT performance bottleneck analysis script +# Specifically designed to analyze why Go Router is twice as slow as Rust Router +# +# Usage: +# ./scripts/analyze_tpot.sh [options] +# +# Options: +# --duration SECONDS CPU profile duration (default: 60) +# --requests NUM Number of requests (default: 100) +# --concurrency NUM Concurrency level (default: 20) +# --pprof-port PORT pprof port (default: 6060) +# --server-url URL Server URL (default: http://localhost:8080) + +set -e + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +PROJECT_ROOT="$(cd "$SCRIPT_DIR/.." && pwd)" +PROFILE_DIR="${PROJECT_ROOT}/profiles" +TIMESTAMP=$(date +%Y%m%d_%H%M%S) +OUTPUT_DIR="${PROFILE_DIR}/tpot_analysis_${TIMESTAMP}" + +# Colors +GREEN='\033[0;32m' +YELLOW='\033[1;33m' +RED='\033[0;31m' +BLUE='\033[0;34m' +NC='\033[0m' + +# Default values +DURATION=${DURATION:-60} +NUM_REQUESTS=${NUM_REQUESTS:-100} +CONCURRENCY=${CONCURRENCY:-20} +PPROF_PORT=${PPROF_PORT:-6060} +SERVER_URL=${SERVER_URL:-http://localhost:8080} + +# Parse arguments +while [[ $# -gt 0 ]]; do + case $1 in + --duration) + DURATION="$2" + shift 2 + ;; + --requests) + NUM_REQUESTS="$2" + shift 2 + ;; + --concurrency) + CONCURRENCY="$2" + shift 2 + ;; + --pprof-port) + PPROF_PORT="$2" + shift 2 + ;; + --server-url) + SERVER_URL="$2" + shift 2 + ;; + *) + echo "Unknown option: $1" + exit 1 + ;; + esac +done + +mkdir -p "$OUTPUT_DIR" + +# Check for graphviz (optional, needed for some pprof visualizations) +HAS_GRAPHVIZ=false +if command -v dot >/dev/null 2>&1; then + HAS_GRAPHVIZ=true +fi + +echo -e "${BLUE}========================================${NC}" +echo -e "${BLUE}TPOT Performance Bottleneck Analysis${NC}" +echo -e "${BLUE}========================================${NC}" +echo "" +echo "Configuration:" +echo " Duration: ${DURATION}s" +echo " Requests: $NUM_REQUESTS" +echo " Concurrency: $CONCURRENCY" +echo " Server URL: $SERVER_URL" +echo " pprof Port: $PPROF_PORT" +echo " Output Dir: $OUTPUT_DIR" +if [ "$HAS_GRAPHVIZ" = "false" ]; then + echo "" + echo -e "${YELLOW}Note: graphviz not found. Some pprof visualizations may not work.${NC}" + echo -e "${YELLOW}To install graphviz:${NC}" + echo -e "${YELLOW} macOS: brew install graphviz${NC}" + echo -e "${YELLOW} Ubuntu: sudo apt-get install graphviz${NC}" + echo -e "${YELLOW} CentOS: sudo yum install graphviz${NC}" + echo -e "${YELLOW}Text reports will still be generated without graphviz.${NC}" +fi +echo "" + +# Check if server is running +echo -e "${YELLOW}[Check] Verifying server is running...${NC}" +if ! 
curl -s "${SERVER_URL}/health" > /dev/null 2>&1; then + echo -e "${RED}Error: Server not responding at ${SERVER_URL}${NC}" + echo "" + echo "Please start the server first with profiling enabled:" + echo " ./run.sh --profile --pprof-port $PPROF_PORT" + echo " or" + echo " PPROF_ENABLED=true PPROF_PORT=$PPROF_PORT make run" + exit 1 +fi +echo -e "${GREEN}✓ Server is running${NC}" +echo "" + +# Check if pprof is enabled +echo -e "${YELLOW}[Check] Verifying pprof is enabled...${NC}" +if ! curl -s "http://localhost:${PPROF_PORT}/debug/pprof/" > /dev/null 2>&1; then + echo -e "${RED}Error: pprof not accessible at http://localhost:${PPROF_PORT}/debug/pprof/${NC}" + echo "" + echo "Please start the server with profiling enabled:" + echo " ./run.sh --profile --pprof-port $PPROF_PORT" + exit 1 +fi +echo -e "${GREEN}✓ pprof is enabled${NC}" +echo "" + +# ============================================ +# Step 1: Collect baseline profiles +# ============================================ +echo -e "${GREEN}[Step 1/8] Collecting baseline profiles...${NC}" + +# Baseline memory +go tool pprof -proto -output="${OUTPUT_DIR}/heap_before.pb.gz" \ + "http://localhost:${PPROF_PORT}/debug/pprof/heap" > /dev/null 2>&1 || true + +# Baseline goroutine +go tool pprof -proto -output="${OUTPUT_DIR}/goroutine_before.pb.gz" \ + "http://localhost:${PPROF_PORT}/debug/pprof/goroutine" > /dev/null 2>&1 || true + +echo -e "${GREEN}✓ Baseline profiles collected${NC}" +echo "" + +# ============================================ +# Step 2: Start CPU profile collection +# ============================================ +echo -e "${GREEN}[Step 2/8] Starting CPU profile collection (${DURATION}s)...${NC}" +go tool pprof -proto -output="${OUTPUT_DIR}/cpu_${DURATION}s.pb.gz" \ + "http://localhost:${PPROF_PORT}/debug/pprof/profile?seconds=${DURATION}" & +CPU_PID=$! +sleep 2 +echo -e "${GREEN}✓ CPU profile collection started${NC}" +echo "" + +# ============================================ +# Step 3: Run load test with streaming requests +# ============================================ +echo -e "${GREEN}[Step 3/8] Running load test ($NUM_REQUESTS streaming requests, concurrency=$CONCURRENCY)...${NC}" + +# Function to run a single streaming request +run_streaming_request() { + local request_id=$1 + local start_time=$(date +%s) + local start_nanos=$(date +%N 2>/dev/null || echo "000000000") + + curl -N -s -X POST "${SERVER_URL}/v1/chat/completions" \ + -H "Content-Type: application/json" \ + -d "{ + \"model\": \"default\", + \"messages\": [{\"role\": \"user\", \"content\": \"Write a 500-word story with character dialogue and scene descriptions\"}], + \"stream\": true, + \"max_tokens\": 300, + \"temperature\": 0.7 + }" > /dev/null + + local end_time=$(date +%s) + local end_nanos=$(date +%N 2>/dev/null || echo "000000000") + local duration=$((end_time - start_time)) + echo "$duration" >> "${OUTPUT_DIR}/request_times.txt" +} + +# Run requests with controlled concurrency +# Use a temporary file to track job PIDs to avoid conflicts with CPU_PID +JOB_PIDS_FILE="${OUTPUT_DIR}/.job_pids_$$" +> "$JOB_PIDS_FILE" + +for i in $(seq 1 $NUM_REQUESTS); do + # Wait if we've reached concurrency limit + while [ $(wc -l < "$JOB_PIDS_FILE" 2>/dev/null || echo 0) -ge $CONCURRENCY ]; do + # Check and remove completed jobs + while IFS= read -r pid; do + if [ -n "$pid" ] && ! 
kill -0 "$pid" 2>/dev/null; then + # Process completed, remove from file + grep -v "^${pid}$" "$JOB_PIDS_FILE" > "${JOB_PIDS_FILE}.tmp" && \ + mv "${JOB_PIDS_FILE}.tmp" "$JOB_PIDS_FILE" || true + fi + done < "$JOB_PIDS_FILE" + sleep 0.1 + done + + # Start new request + run_streaming_request $i & + echo $! >> "$JOB_PIDS_FILE" + + # Progress indicator + if [ $((i % 10)) -eq 0 ]; then + echo " Progress: $i/$NUM_REQUESTS requests sent..." + fi +done + +# Wait for all remaining jobs (excluding CPU_PID) +while IFS= read -r pid; do + if [ -n "$pid" ] && [ "$pid" != "$CPU_PID" ]; then + wait "$pid" 2>/dev/null || true + fi +done < "$JOB_PIDS_FILE" + +# Clean up +rm -f "$JOB_PIDS_FILE" "${JOB_PIDS_FILE}.tmp" 2>/dev/null || true + +echo -e "${GREEN}✓ Load test completed${NC}" +echo "" + +# ============================================ +# Step 4: Wait for CPU profile to complete +# ============================================ +echo -e "${GREEN}[Step 4/8] Waiting for CPU profile to complete...${NC}" +# Wait for the process, but handle the case where it might have already completed +if kill -0 $CPU_PID 2>/dev/null; then + wait $CPU_PID 2>/dev/null || true +else + # Process already completed, just wait a bit to ensure file is written + sleep 1 +fi +echo -e "${GREEN}✓ CPU profile collection completed${NC}" +echo "" + +# ============================================ +# Step 5: Collect final profiles +# ============================================ +echo -e "${GREEN}[Step 5/8] Collecting final profiles...${NC}" + +# Final memory +go tool pprof -proto -output="${OUTPUT_DIR}/heap_after.pb.gz" \ + "http://localhost:${PPROF_PORT}/debug/pprof/heap" > /dev/null 2>&1 || true + +# Final goroutine +go tool pprof -proto -output="${OUTPUT_DIR}/goroutine_after.pb.gz" \ + "http://localhost:${PPROF_PORT}/debug/pprof/goroutine" > /dev/null 2>&1 || true + +# Mutex profile +go tool pprof -proto -output="${OUTPUT_DIR}/mutex.pb.gz" \ + "http://localhost:${PPROF_PORT}/debug/pprof/mutex" > /dev/null 2>&1 || true + +# Block profile +go tool pprof -proto -output="${OUTPUT_DIR}/block.pb.gz" \ + "http://localhost:${PPROF_PORT}/debug/pprof/block" > /dev/null 2>&1 || true + +echo -e "${GREEN}✓ Final profiles collected${NC}" +echo "" + +# ============================================ +# Step 6: Generate analysis reports +# ============================================ +echo -e "${GREEN}[Step 6/8] Generating analysis reports...${NC}" + +# CPU analysis +echo " Generating CPU reports..." +go tool pprof -top -cum "${OUTPUT_DIR}/cpu_${DURATION}s.pb.gz" > "${OUTPUT_DIR}/01_cpu_top_cum.txt" 2>&1 || true +go tool pprof -top "${OUTPUT_DIR}/cpu_${DURATION}s.pb.gz" > "${OUTPUT_DIR}/02_cpu_top_flat.txt" 2>&1 || true + +# Memory analysis +echo " Generating memory reports..." +if [ -f "${OUTPUT_DIR}/heap_after.pb.gz" ]; then + go tool pprof -top -alloc_space "${OUTPUT_DIR}/heap_after.pb.gz" > "${OUTPUT_DIR}/03_memory_alloc_space.txt" 2>&1 || true + go tool pprof -top -alloc_objects "${OUTPUT_DIR}/heap_after.pb.gz" > "${OUTPUT_DIR}/04_memory_alloc_objects.txt" 2>&1 || true + go tool pprof -top -inuse_space "${OUTPUT_DIR}/heap_after.pb.gz" > "${OUTPUT_DIR}/05_memory_inuse_space.txt" 2>&1 || true +fi + +# Memory growth +if [ -f "${OUTPUT_DIR}/heap_before.pb.gz" ] && [ -f "${OUTPUT_DIR}/heap_after.pb.gz" ]; then + go tool pprof -top -base="${OUTPUT_DIR}/heap_before.pb.gz" \ + "${OUTPUT_DIR}/heap_after.pb.gz" > "${OUTPUT_DIR}/06_memory_growth.txt" 2>&1 || true +fi + +# FFI/CGO analysis +echo " Analyzing FFI/CGO calls..." 
+go tool pprof -top "${OUTPUT_DIR}/cpu_${DURATION}s.pb.gz" 2>&1 | \
+    grep -iE "(block_on|CGO|FFI|ffi|runtime\.cgo|_Cfunc)" > "${OUTPUT_DIR}/07_ffi_cgo_analysis.txt" || \
+    echo "No FFI/CGO related functions found" > "${OUTPUT_DIR}/07_ffi_cgo_analysis.txt"
+
+# JSON serialization analysis
+echo "  Analyzing JSON serialization..."
+go tool pprof -top "${OUTPUT_DIR}/cpu_${DURATION}s.pb.gz" 2>&1 | \
+    grep -iE "(json|Marshal|Unmarshal|Encode|Decode|sonic|jsoniter)" > "${OUTPUT_DIR}/08_json_analysis.txt" || \
+    echo "No JSON related functions found" > "${OUTPUT_DIR}/08_json_analysis.txt"
+
+# Goroutine analysis
+if [ -f "${OUTPUT_DIR}/goroutine_after.pb.gz" ]; then
+    echo "  Analyzing goroutines..."
+    go tool pprof -top "${OUTPUT_DIR}/goroutine_after.pb.gz" > "${OUTPUT_DIR}/09_goroutine_analysis.txt" 2>&1 || true
+fi
+
+# Mutex analysis
+if [ -f "${OUTPUT_DIR}/mutex.pb.gz" ]; then
+    echo "  Analyzing mutex contention..."
+    go tool pprof -top "${OUTPUT_DIR}/mutex.pb.gz" > "${OUTPUT_DIR}/10_mutex_analysis.txt" 2>&1 || true
+fi
+
+# Block analysis
+if [ -f "${OUTPUT_DIR}/block.pb.gz" ]; then
+    echo "  Analyzing blocking operations..."
+    go tool pprof -top "${OUTPUT_DIR}/block.pb.gz" > "${OUTPUT_DIR}/11_block_analysis.txt" 2>&1 || true
+fi
+
+# Request timing statistics
+if [ -f "${OUTPUT_DIR}/request_times.txt" ] && [ -s "${OUTPUT_DIR}/request_times.txt" ]; then
+    echo "  Calculating request timing statistics..."
+    {
+        echo "Request Timing Statistics"
+        echo "========================"
+        echo ""
+        echo "Total requests: $(wc -l < "${OUTPUT_DIR}/request_times.txt" | tr -d ' ')"
+        echo ""
+        awk '{
+            sum+=$1
+            sumsq+=$1*$1
+            if(NR==1 || $1<min) min=$1
+            if(NR==1 || $1>max) max=$1
+        } END {
+            if(NR > 0) {
+                mean=sum/NR
+                variance=(sumsq/NR - mean*mean)
+                stddev=sqrt(variance)
+                print "Min:    " min "s"
+                print "Max:    " max "s"
+                print "Mean:   " mean "s"
+                print "StdDev: " stddev "s"
+            }
+        }' "${OUTPUT_DIR}/request_times.txt"
+    } > "${OUTPUT_DIR}/12_request_timing.txt"
+fi
+
+echo -e "${GREEN}✓ Analysis reports generated${NC}"
+echo ""
+
+# ============================================
+# Step 7: Generate summary report
+# ============================================
+echo -e "${GREEN}[Step 7/8] Generating summary report...${NC}"
+
+SUMMARY_FILE="${OUTPUT_DIR}/00_SUMMARY.md"
+cat > "$SUMMARY_FILE" <<EOF
+# TPOT Performance Analysis Summary
+
+- Duration: ${DURATION}s
+- Requests: ${NUM_REQUESTS}
+- Concurrency: ${CONCURRENCY}
+- Server URL: ${SERVER_URL}
+
+## Detailed Analysis
+
+### 1. CPU Hotspots (Cumulative)
+
+\`\`\`
+$(head -15 "${OUTPUT_DIR}/01_cpu_top_cum.txt" | tail -10 2>/dev/null || echo "No CPU data collected")
+\`\`\`
+
+### 2. CPU Hotspots (Flat)
+
+\`\`\`
+$(head -15 "${OUTPUT_DIR}/02_cpu_top_flat.txt" | tail -10 2>/dev/null || echo "No CPU data collected")
+\`\`\`
+
+### 3. Memory Allocations
+
+\`\`\`
+$(head -15 "${OUTPUT_DIR}/03_memory_alloc_space.txt" | tail -10 2>/dev/null || echo "No memory data collected")
+\`\`\`
+
+### 4. Memory Growth
+
+\`\`\`
+$(head -15 "${OUTPUT_DIR}/06_memory_growth.txt" | tail -10 2>/dev/null || echo "No memory growth detected")
+\`\`\`
+
+### 5. FFI/CGO Overhead
+
+\`\`\`
+$(head -15 "${OUTPUT_DIR}/07_ffi_cgo_analysis.txt" | tail -10 2>/dev/null || echo "No FFI/CGO related functions found")
+\`\`\`
+
+### 6. JSON Serialization
+
+\`\`\`
+$(head -15 "${OUTPUT_DIR}/08_json_analysis.txt" | tail -10 2>/dev/null || echo "No JSON related functions found")
+\`\`\`
+
+### 7. Mutex Contention
+
+\`\`\`
+$(head -15 "${OUTPUT_DIR}/10_mutex_analysis.txt" | tail -10 2>/dev/null || echo "No significant mutex contention detected")
+\`\`\`
+
+### 8. Blocking Operations
+
+\`\`\`
+$(head -15 "${OUTPUT_DIR}/11_block_analysis.txt" | tail -10 2>/dev/null || echo "No significant blocking detected")
+\`\`\`
+
+## Performance Bottlenecks Identified
+
+### High Priority Issues
+
+1. **FFI/CGO Overhead**
+   - Check: \`cat ${OUTPUT_DIR}/07_ffi_cgo_analysis.txt\`
+   - Impact: FFI calls add overhead compared to native Rust code
+   - Recommendation: Minimize FFI calls, batch operations
+
+2. **JSON Serialization**
+   - Check: \`cat ${OUTPUT_DIR}/08_json_analysis.txt\`
+   - Impact: JSON marshaling/unmarshaling can be expensive
+   - Recommendation: Use faster JSON library (jsoniter), reduce serialization frequency
+
+3. **Memory Allocations**
+   - Check: \`cat ${OUTPUT_DIR}/03_memory_alloc_space.txt\`
+   - Impact: Frequent allocations cause GC pressure
+   - Recommendation: Use object pools, pre-allocate buffers
+
+### Medium Priority Issues
+
+4. **Goroutine Overhead**
+   - Check: \`cat ${OUTPUT_DIR}/09_goroutine_analysis.txt\`
+   - Impact: Too many goroutines can cause scheduling overhead
+   - Recommendation: Limit goroutine count, use worker pools
+
+5. 
**Lock Contention** + - Check: \`cat ${OUTPUT_DIR}/10_mutex_analysis.txt\` + - Impact: Lock contention reduces parallelism + - Recommendation: Reduce lock granularity, use lock-free structures + +## Comparison with Rust Router + +### Expected Differences + +1. **FFI Overhead**: Go → Rust FFI calls add ~100-500ns per call +2. **GC Overhead**: Go's GC can cause pauses (usually <1ms) +3. **JSON Library**: Go's standard library is slower than Rust's serde +4. **Memory Layout**: Go's GC affects cache locality + +### Optimization Opportunities + +1. **Reduce FFI Calls** + - Batch token processing + - Use async FFI (if possible) + - Cache frequently used FFI results + +2. **Optimize JSON** + - Use jsoniter (already implemented) + - Pre-allocate JSON buffers + - Reduce serialization frequency + +3. **Memory Management** + - Use sync.Pool for frequently allocated objects + - Pre-allocate slices with known capacity + - Avoid unnecessary string copies + +4. **Concurrency** + - Use worker pools instead of spawning goroutines per request + - Limit concurrent FFI calls + - Use channels efficiently + +## Next Steps + +1. Review detailed reports in this directory +2. Use interactive pprof: \`go tool pprof -http=:8081 ${OUTPUT_DIR}/cpu_${DURATION}s.pb.gz\` +3. Compare with Rust router profiles (if available) +4. Implement optimizations based on findings +5. Re-run analysis to measure improvements + +## Files Generated + +- \`00_SUMMARY.md\` - This summary +- \`01_cpu_top_cum.txt\` - CPU top functions (cumulative) +- \`02_cpu_top_flat.txt\` - CPU top functions (flat) +- \`03_memory_alloc_space.txt\` - Memory allocation by space +- \`04_memory_alloc_objects.txt\` - Memory allocation by objects +- \`05_memory_inuse_space.txt\` - Memory in use by space +- \`06_memory_growth.txt\` - Memory growth during test +- \`07_ffi_cgo_analysis.txt\` - FFI/CGO overhead analysis +- \`08_json_analysis.txt\` - JSON serialization analysis +- \`09_goroutine_analysis.txt\` - Goroutine analysis +- \`10_mutex_analysis.txt\` - Mutex contention analysis +- \`11_block_analysis.txt\` - Blocking operations analysis +- \`12_request_timing.txt\` - Request timing statistics +- \`*.pb.gz\` - Raw profile files for interactive analysis + +EOF + +echo -e "${GREEN}✓ Summary report generated${NC}" +echo "" + +# ============================================ +# Step 8: Display summary +# ============================================ +echo -e "${GREEN}[Step 8/8] Analysis Complete!${NC}" +echo "" +echo -e "${BLUE}========================================${NC}" +echo -e "${BLUE}Summary${NC}" +echo -e "${BLUE}========================================${NC}" +echo "" +echo -e "${YELLOW}Top CPU Hotspots (Cumulative):${NC}" +head -12 "${OUTPUT_DIR}/01_cpu_top_cum.txt" | tail -10 +echo "" +echo -e "${YELLOW}FFI/CGO Overhead:${NC}" +cat "${OUTPUT_DIR}/07_ffi_cgo_analysis.txt" +echo "" +echo -e "${YELLOW}JSON Serialization Overhead:${NC}" +cat "${OUTPUT_DIR}/08_json_analysis.txt" +echo "" +echo -e "${YELLOW}Top Memory Allocations:${NC}" +head -12 "${OUTPUT_DIR}/03_memory_alloc_space.txt" | tail -10 +echo "" +if [ -f "${OUTPUT_DIR}/12_request_timing.txt" ]; then + echo -e "${YELLOW}Request Timing:${NC}" + cat "${OUTPUT_DIR}/12_request_timing.txt" + echo "" +fi +echo -e "${GREEN}========================================${NC}" +echo "" +echo -e "${BLUE}Detailed Reports:${NC}" +echo " Summary: cat ${OUTPUT_DIR}/00_SUMMARY.md" +echo " CPU (cum): cat ${OUTPUT_DIR}/01_cpu_top_cum.txt" +echo " CPU (flat): cat ${OUTPUT_DIR}/02_cpu_top_flat.txt" +echo " FFI/CGO: cat 
${OUTPUT_DIR}/07_ffi_cgo_analysis.txt" +echo " JSON: cat ${OUTPUT_DIR}/08_json_analysis.txt" +echo " Memory: cat ${OUTPUT_DIR}/03_memory_alloc_space.txt" +echo "" +echo -e "${BLUE}Interactive Analysis:${NC}" +echo " Run: go tool pprof -http=:8081 ${OUTPUT_DIR}/cpu_${DURATION}s.pb.gz" +echo " Then visit:" +echo " - http://localhost:8081/ui/flamegraph (Flame Graph - no graphviz needed)" +echo " - http://localhost:8081/ui/top (Top Functions - no graphviz needed)" +if [ "$HAS_GRAPHVIZ" = "true" ]; then + echo " - http://localhost:8081/ui/graph (Call Graph - requires graphviz)" +else + echo " - http://localhost:8081/ui/graph (Call Graph - requires graphviz, not available)" +fi +echo "" +if [ "$HAS_GRAPHVIZ" = "false" ]; then + echo -e "${YELLOW}Note: Install graphviz to enable call graph visualization:${NC}" + echo -e "${YELLOW} macOS: brew install graphviz${NC}" + echo -e "${YELLOW} Ubuntu: sudo apt-get install graphviz${NC}" + echo -e "${YELLOW} CentOS: sudo yum install graphviz${NC}" + echo "" +fi +echo -e "${GREEN}All files saved to: ${OUTPUT_DIR}${NC}" +echo "" diff --git a/sgl-model-gateway/bindings/golang/examples/oai_server/scripts/pprof_analysis.sh b/sgl-model-gateway/bindings/golang/examples/oai_server/scripts/pprof_analysis.sh new file mode 100755 index 000000000000..7c5eca6f4d10 --- /dev/null +++ b/sgl-model-gateway/bindings/golang/examples/oai_server/scripts/pprof_analysis.sh @@ -0,0 +1,215 @@ +#!/bin/bash + +# pprof performance analysis script +# Used to analyze performance bottlenecks of Go OpenAI server + +set -e + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +cd "$SCRIPT_DIR" + +# Configuration +PPROF_PORT=${PPROF_PORT:-6060} +SERVER_PORT=${SERVER_PORT:-8080} +DURATION=${DURATION:-60} # Performance test duration (seconds) +OUTPUT_DIR="./pprof_results" +TIMESTAMP=$(date +%Y%m%d_%H%M%S) + +# Create output directory +mkdir -p "$OUTPUT_DIR" + +echo "==========================================" +echo "pprof Performance Analysis Tool" +echo "==========================================" +echo "PPROF_PORT: $PPROF_PORT" +echo "SERVER_PORT: $SERVER_PORT" +echo "DURATION: ${DURATION}s" +echo "OUTPUT_DIR: $OUTPUT_DIR" +echo "" + +# Check if go tool pprof is available +if ! command -v go &> /dev/null; then + echo "Error: go command not found" + exit 1 +fi + +# Check if server is running +check_server() { + if curl -s "http://localhost:${SERVER_PORT}/health" > /dev/null 2>&1; then + return 0 + else + return 1 + fi +} + +# Check if pprof is available +check_pprof() { + if curl -s "http://localhost:${PPROF_PORT}/debug/pprof/" > /dev/null 2>&1; then + return 0 + else + return 1 + fi +} + +# Start server (if not running) +if ! check_server; then + echo "Server not running, please start the server first:" + echo " export PPROF_ENABLED=true" + echo " export PPROF_PORT=$PPROF_PORT" + echo " ./oai_server" + echo "" + echo "Or use the following command to start:" + echo " PPROF_ENABLED=true PPROF_PORT=$PPROF_PORT ./oai_server" + echo "" + read -p "Start server now? (y/n) " -n 1 -r + echo + if [[ $REPLY =~ ^[Yy]$ ]]; then + echo "Starting server..." + PPROF_ENABLED=true PPROF_PORT=$PPROF_PORT ./oai_server & + SERVER_PID=$! + echo "Server PID: $SERVER_PID" + + # Wait for server to start + echo "Waiting for server to start..." + for i in {1..30}; do + if check_server; then + echo "Server started" + break + fi + sleep 1 + done + + if ! 
check_server; then + echo "Error: Server failed to start" + kill $SERVER_PID 2>/dev/null || true + exit 1 + fi + else + exit 1 + fi +fi + +# Check if pprof is available +if ! check_pprof; then + echo "Error: pprof not enabled. Please set environment variables:" + echo " export PPROF_ENABLED=true" + echo " export PPROF_PORT=$PPROF_PORT" + exit 1 +fi + +echo "Starting to collect performance data..." +echo "" + +# 1. CPU Profile (30 seconds) +echo "[1/6] Collecting CPU Profile (30 seconds)..." +go tool pprof -proto -output="$OUTPUT_DIR/cpu_${TIMESTAMP}.pb.gz" \ + "http://localhost:${PPROF_PORT}/debug/pprof/profile?seconds=30" & +CPU_PID=$! + +# 2. Collect Heap Profile simultaneously +echo "[2/6] Collecting Heap Profile..." +go tool pprof -proto -output="$OUTPUT_DIR/heap_${TIMESTAMP}.pb.gz" \ + "http://localhost:${PPROF_PORT}/debug/pprof/heap" & +HEAP_PID=$! + +# 3. Collect Goroutine Profile +echo "[3/6] Collecting Goroutine Profile..." +go tool pprof -proto -output="$OUTPUT_DIR/goroutine_${TIMESTAMP}.pb.gz" \ + "http://localhost:${PPROF_PORT}/debug/pprof/goroutine" & +GOROUTINE_PID=$! + +# 4. Collect Mutex Profile +echo "[4/6] Collecting Mutex Profile..." +go tool pprof -proto -output="$OUTPUT_DIR/mutex_${TIMESTAMP}.pb.gz" \ + "http://localhost:${PPROF_PORT}/debug/pprof/mutex" & +MUTEX_PID=$! + +# 5. Collect Block Profile +echo "[5/6] Collecting Block Profile..." +go tool pprof -proto -output="$OUTPUT_DIR/block_${TIMESTAMP}.pb.gz" \ + "http://localhost:${PPROF_PORT}/debug/pprof/block" & +BLOCK_PID=$! + +# 6. Run performance test (during CPU profile collection) +echo "[6/6] Running performance test..." +echo "Tip: Please use your performance testing tool (curl, ab, wrk, etc.) to send requests to the server" +echo " CPU profile will collect 30 seconds of performance data" +echo "" + +# Wait for CPU profile to complete +wait $CPU_PID +echo "CPU Profile collection completed" + +# Wait for other profiles +wait $HEAP_PID +wait $GOROUTINE_PID +wait $MUTEX_PID +wait $BLOCK_PID + +echo "" +echo "==========================================" +echo "Performance data collection completed!" +echo "==========================================" +echo "" +echo "Generated analysis files:" +ls -lh "$OUTPUT_DIR"/*_${TIMESTAMP}.* 2>/dev/null || true +echo "" + +# Generate analysis report +echo "Generating analysis report..." 
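+# Each block below appends one `go tool pprof -top` text report to a single
+# analysis_${TIMESTAMP}.txt file; the trailing `|| true` keeps the report
+# going even if an individual profile was not collected.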
+echo "" + +# CPU Top 20 +echo "=== CPU Top 20 (sorted by flat time) ===" > "$OUTPUT_DIR/analysis_${TIMESTAMP}.txt" +go tool pprof -top -cum "$OUTPUT_DIR/cpu_${TIMESTAMP}.pb.gz" >> "$OUTPUT_DIR/analysis_${TIMESTAMP}.txt" 2>&1 || true +echo "" >> "$OUTPUT_DIR/analysis_${TIMESTAMP}.txt" + +# Heap Top 20 +echo "=== Heap Top 20 (sorted by allocation size) ===" >> "$OUTPUT_DIR/analysis_${TIMESTAMP}.txt" +go tool pprof -top "$OUTPUT_DIR/heap_${TIMESTAMP}.pb.gz" >> "$OUTPUT_DIR/analysis_${TIMESTAMP}.txt" 2>&1 || true +echo "" >> "$OUTPUT_DIR/analysis_${TIMESTAMP}.txt" + +# Goroutine statistics +echo "=== Goroutine Statistics ===" >> "$OUTPUT_DIR/analysis_${TIMESTAMP}.txt" +go tool pprof -top "$OUTPUT_DIR/goroutine_${TIMESTAMP}.pb.gz" >> "$OUTPUT_DIR/analysis_${TIMESTAMP}.txt" 2>&1 || true +echo "" >> "$OUTPUT_DIR/analysis_${TIMESTAMP}.txt" + +# Mutex statistics +echo "=== Mutex Wait Time ===" >> "$OUTPUT_DIR/analysis_${TIMESTAMP}.txt" +go tool pprof -top "$OUTPUT_DIR/mutex_${TIMESTAMP}.pb.gz" >> "$OUTPUT_DIR/analysis_${TIMESTAMP}.txt" 2>&1 || true +echo "" >> "$OUTPUT_DIR/analysis_${TIMESTAMP}.txt" + +# Block statistics +echo "=== Block Wait Time ===" >> "$OUTPUT_DIR/analysis_${TIMESTAMP}.txt" +go tool pprof -top "$OUTPUT_DIR/block_${TIMESTAMP}.pb.gz" >> "$OUTPUT_DIR/analysis_${TIMESTAMP}.txt" 2>&1 || true + +echo "Analysis report saved to: $OUTPUT_DIR/analysis_${TIMESTAMP}.txt" +echo "" + +# Display key information +echo "==========================================" +echo "Key Performance Metrics Summary" +echo "==========================================" +echo "" +echo "View detailed report:" +echo " cat $OUTPUT_DIR/analysis_${TIMESTAMP}.txt" +echo "" +echo "Interactive CPU Profile view:" +echo " go tool pprof $OUTPUT_DIR/cpu_${TIMESTAMP}.pb.gz" +echo "" +echo "Interactive Heap Profile view:" +echo " go tool pprof $OUTPUT_DIR/heap_${TIMESTAMP}.pb.gz" +echo "" +echo "Generate flame graph (requires go-torch or pprof):" +echo " go tool pprof -http=:8080 $OUTPUT_DIR/cpu_${TIMESTAMP}.pb.gz" +echo "" + +# If server was started, ask if it should be closed +if [ -n "$SERVER_PID" ]; then + read -p "Close server? (y/n) " -n 1 -r + echo + if [[ $REPLY =~ ^[Yy]$ ]]; then + kill $SERVER_PID 2>/dev/null || true + echo "Server closed" + fi +fi diff --git a/sgl-model-gateway/bindings/golang/examples/oai_server/scripts/pprof_quick.sh b/sgl-model-gateway/bindings/golang/examples/oai_server/scripts/pprof_quick.sh new file mode 100755 index 000000000000..c8b21a25174e --- /dev/null +++ b/sgl-model-gateway/bindings/golang/examples/oai_server/scripts/pprof_quick.sh @@ -0,0 +1,52 @@ +#!/bin/bash + +# Quick pprof analysis script +# Collects 30-second CPU profile and immediately displays top results + +set -e + +PPROF_PORT=${PPROF_PORT:-6060} +DURATION=${DURATION:-30} + +echo "==========================================" +echo "Quick pprof Analysis" +echo "==========================================" +echo "PPROF_PORT: $PPROF_PORT" +echo "DURATION: ${DURATION}s" +echo "" +echo "Tip: During data collection, please send requests to the server" +echo " You can use: ./pprof_test.sh" +echo "" + +# Check if pprof is available +if ! curl -s "http://localhost:${PPROF_PORT}/debug/pprof/" > /dev/null 2>&1; then + echo "Error: pprof not enabled. Please set environment variables:" + echo " export PPROF_ENABLED=true" + echo " export PPROF_PORT=$PPROF_PORT" + exit 1 +fi + +echo "Starting to collect CPU Profile (${DURATION} seconds)..." 
+echo "" + +# Collect CPU profile and directly display top results +go tool pprof -top -cum "http://localhost:${PPROF_PORT}/debug/pprof/profile?seconds=${DURATION}" + +echo "" +echo "==========================================" +echo "Analysis Complete" +echo "==========================================" +echo "" +echo "More analysis options:" +echo " # Interactive view" +echo " go tool pprof http://localhost:${PPROF_PORT}/debug/pprof/profile?seconds=30" +echo "" +echo " # View heap memory" +echo " go tool pprof http://localhost:${PPROF_PORT}/debug/pprof/heap" +echo "" +echo " # View goroutines" +echo " go tool pprof http://localhost:${PPROF_PORT}/debug/pprof/goroutine" +echo "" +echo " # Generate Web UI" +echo " go tool pprof -http=:8080 http://localhost:${PPROF_PORT}/debug/pprof/profile?seconds=30" +echo "" diff --git a/sgl-model-gateway/bindings/golang/examples/oai_server/scripts/pprof_test.sh b/sgl-model-gateway/bindings/golang/examples/oai_server/scripts/pprof_test.sh new file mode 100755 index 000000000000..e5a69f60a130 --- /dev/null +++ b/sgl-model-gateway/bindings/golang/examples/oai_server/scripts/pprof_test.sh @@ -0,0 +1,87 @@ +#!/bin/bash + +# Simple performance test script for sending requests while collecting pprof data + +set -e + +SERVER_URL=${SERVER_URL:-"http://localhost:8080"} +DURATION=${DURATION:-30} # Test duration (seconds) +CONCURRENT=${CONCURRENT:-1} # Number of concurrent requests + +echo "==========================================" +echo "Performance Test Script" +echo "==========================================" +echo "SERVER_URL: $SERVER_URL" +echo "DURATION: ${DURATION}s" +echo "CONCURRENT: $CONCURRENT" +echo "" + +# Test request JSON +TEST_REQUEST='{ + "model": "default", + "messages": [ + {"role": "user", "content": "Hello, how are you?"} + ], + "stream": true, + "max_tokens": 100 +}' + +# Check if server is available +if ! curl -s "${SERVER_URL}/health" > /dev/null 2>&1; then + echo "Error: Server not available (${SERVER_URL}/health)" + exit 1 +fi + +echo "Starting to send test requests..." 
+echo "" + +# Function to send streaming request +send_stream_request() { + local request_num=$1 + local start_time=$(date +%s.%N) + + curl -s -N -X POST "${SERVER_URL}/v1/chat/completions" \ + -H "Content-Type: application/json" \ + -d "$TEST_REQUEST" \ + > /dev/null 2>&1 + + local end_time=$(date +%s.%N) + local duration=$(echo "$end_time - $start_time" | bc) + echo "Request $request_num completed, duration: ${duration}s" +} + +# Send requests concurrently +if [ "$CONCURRENT" -eq 1 ]; then + # Single-threaded mode: continuously send requests + end_time=$(($(date +%s) + DURATION)) + request_count=0 + + while [ $(date +%s) -lt $end_time ]; do + request_count=$((request_count + 1)) + send_stream_request $request_count + done + + echo "" + echo "Test completed, sent $request_count requests" +else + # Multi-threaded mode: send requests concurrently + end_time=$(($(date +%s) + DURATION)) + request_count=0 + + while [ $(date +%s) -lt $end_time ]; do + # Start concurrent requests + for i in $(seq 1 $CONCURRENT); do + request_count=$((request_count + 1)) + send_stream_request $request_count & + done + + # Wait for all requests to complete + wait + + # Brief rest to avoid overload + sleep 0.1 + done + + echo "" + echo "Test completed, sent $request_count requests" +fi diff --git a/sgl-model-gateway/bindings/golang/examples/oai_server/scripts/profile_tpot.sh b/sgl-model-gateway/bindings/golang/examples/oai_server/scripts/profile_tpot.sh new file mode 100755 index 000000000000..ced5c18051f5 --- /dev/null +++ b/sgl-model-gateway/bindings/golang/examples/oai_server/scripts/profile_tpot.sh @@ -0,0 +1,140 @@ +#!/bin/bash + +# TPOT performance analysis script +# Quickly collect and analyze TPOT-related performance data + +set -e + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +PROJECT_ROOT="$(cd "$SCRIPT_DIR/.." && pwd)" +PROFILE_DIR="${PROJECT_ROOT}/profiles" +TIMESTAMP=$(date +%Y%m%d_%H%M%S) +OUTPUT_DIR="${PROFILE_DIR}/${TIMESTAMP}" + +# Colors +GREEN='\033[0;32m' +YELLOW='\033[1;33m' +NC='\033[0m' + +# Default values +PPROF_PORT=${PPROF_PORT:-6060} +SERVER_URL=${SERVER_URL:-http://localhost:8080} +DURATION=${DURATION:-30} +NUM_REQUESTS=${NUM_REQUESTS:-20} + +mkdir -p "$OUTPUT_DIR" + +echo -e "${GREEN}TPOT Performance Analysis${NC}" +echo "==========================" +echo "Profile directory: $OUTPUT_DIR" +echo "Duration: ${DURATION}s" +echo "Requests: $NUM_REQUESTS" +echo "" + +# Check if server is running +if ! curl -s "${SERVER_URL}/health" > /dev/null 2>&1; then + echo -e "${YELLOW}Warning: Server not responding at ${SERVER_URL}${NC}" + echo "Please start the server first with profiling enabled:" + echo " PPROF_ENABLED=true PPROF_PORT=$PPROF_PORT make run" + exit 1 +fi + +# Collect baseline memory +echo -e "${GREEN}[1/5] Collecting baseline memory profile...${NC}" +go tool pprof -proto -output="${OUTPUT_DIR}/heap_before.pb.gz" \ + "http://localhost:${PPROF_PORT}/debug/pprof/heap" > /dev/null 2>&1 || true + +# Start CPU profile collection in background +echo -e "${GREEN}[2/5] Starting CPU profile collection (${DURATION}s)...${NC}" +go tool pprof -proto -output="${OUTPUT_DIR}/cpu_${DURATION}s.pb.gz" \ + "http://localhost:${PPROF_PORT}/debug/pprof/profile?seconds=${DURATION}" & +CPU_PID=$! 
+ +# Wait a bit for profile to start +sleep 2 + +# Run load test +echo -e "${GREEN}[3/5] Running load test ($NUM_REQUESTS requests)...${NC}" +for i in $(seq 1 $NUM_REQUESTS); do + curl -N -s -X POST "${SERVER_URL}/v1/chat/completions" \ + -H "Content-Type: application/json" \ + -d "{ + \"model\": \"default\", + \"messages\": [{\"role\": \"user\", \"content\": \"Write a story\"}], + \"stream\": true, + \"max_tokens\": 200 + }" > /dev/null & + + # Limit concurrency + if [ $((i % 5)) -eq 0 ]; then + wait + fi +done +wait + +# Wait for CPU profile to complete +echo -e "${GREEN}[4/5] Waiting for CPU profile to complete...${NC}" +# Wait for the CPU profile process, but handle the case where it's not a child process +if kill -0 $CPU_PID 2>/dev/null; then + # Process is still running, wait for it + while kill -0 $CPU_PID 2>/dev/null; do + sleep 1 + done +else + # Process already completed or not found, just wait a bit + sleep 2 +fi + +# Collect final memory +echo -e "${GREEN}[5/5] Collecting final memory profile...${NC}" +go tool pprof -proto -output="${OUTPUT_DIR}/heap_after.pb.gz" \ + "http://localhost:${PPROF_PORT}/debug/pprof/heap" > /dev/null 2>&1 || true + +# Generate reports +echo "" +echo -e "${GREEN}Generating reports...${NC}" + +# CPU top (cumulative) +go tool pprof -top -cum "${OUTPUT_DIR}/cpu_${DURATION}s.pb.gz" > "${OUTPUT_DIR}/cpu_top_cum.txt" 2>&1 || true + +# CPU top (flat) +go tool pprof -top "${OUTPUT_DIR}/cpu_${DURATION}s.pb.gz" > "${OUTPUT_DIR}/cpu_top_flat.txt" 2>&1 || true + +# Memory growth +if [ -f "${OUTPUT_DIR}/heap_before.pb.gz" ] && [ -f "${OUTPUT_DIR}/heap_after.pb.gz" ]; then + go tool pprof -top -base="${OUTPUT_DIR}/heap_before.pb.gz" \ + "${OUTPUT_DIR}/heap_after.pb.gz" > "${OUTPUT_DIR}/heap_growth.txt" 2>&1 || true +fi + +# FFI/CGO related +go tool pprof -top "${OUTPUT_DIR}/cpu_${DURATION}s.pb.gz" 2>&1 | \ + grep -E "(block_on|CGO|FFI|json|Marshal|Unmarshal)" > "${OUTPUT_DIR}/ffi_related.txt" || \ + echo "No FFI/CGO related functions found" > "${OUTPUT_DIR}/ffi_related.txt" + +# Summary +echo "" +echo -e "${GREEN}=== Analysis Summary ===${NC}" +echo "" +echo -e "${YELLOW}CPU Top (Cumulative) - Top 10:${NC}" +head -12 "${OUTPUT_DIR}/cpu_top_cum.txt" | tail -10 || true + +echo "" +echo -e "${YELLOW}CPU Top (Flat) - Top 10:${NC}" +head -12 "${OUTPUT_DIR}/cpu_top_flat.txt" | tail -10 || true + +echo "" +echo -e "${YELLOW}FFI/CGO Related Functions:${NC}" +cat "${OUTPUT_DIR}/ffi_related.txt" || true + +echo "" +echo -e "${GREEN}=== Detailed Reports ===${NC}" +echo "CPU (cumulative): cat ${OUTPUT_DIR}/cpu_top_cum.txt" +echo "CPU (flat): cat ${OUTPUT_DIR}/cpu_top_flat.txt" +echo "Memory growth: cat ${OUTPUT_DIR}/heap_growth.txt" +echo "FFI related: cat ${OUTPUT_DIR}/ffi_related.txt" +echo "" +echo -e "${GREEN}=== Interactive Analysis ===${NC}" +echo "Run: go tool pprof -http=:8081 ${OUTPUT_DIR}/cpu_${DURATION}s.pb.gz" +echo "Then visit: http://localhost:8081/ui/flamegraph" +echo "" +echo "Profile files saved to: ${OUTPUT_DIR}" diff --git a/sgl-model-gateway/bindings/golang/examples/oai_server/service/sglang.go b/sgl-model-gateway/bindings/golang/examples/oai_server/service/sglang.go new file mode 100644 index 000000000000..a76c874fe698 --- /dev/null +++ b/sgl-model-gateway/bindings/golang/examples/oai_server/service/sglang.go @@ -0,0 +1,37 @@ +package service + +import ( + sglang "github.com/sglang/sglang-go-grpc-sdk" +) + +// SGLangService wraps SGLang client +type SGLangService struct { + client *sglang.Client +} + +func NewSGLangService(endpoint, tokenizerPath 
string) (*SGLangService, error) { + client, err := sglang.NewClient(sglang.ClientConfig{ + Endpoint: endpoint, + TokenizerPath: tokenizerPath, + }) + if err != nil { + return nil, err + } + + return &SGLangService{ + client: client, + }, nil +} + +// Client returns the underlying SGLang client +func (s *SGLangService) Client() *sglang.Client { + return s.client +} + +// Close closes the SGLang client +func (s *SGLangService) Close() error { + if s.client != nil { + return s.client.Close() + } + return nil +} diff --git a/sgl-model-gateway/bindings/golang/examples/oai_server/utils/utils.go b/sgl-model-gateway/bindings/golang/examples/oai_server/utils/utils.go new file mode 100644 index 000000000000..a66629acb8f3 --- /dev/null +++ b/sgl-model-gateway/bindings/golang/examples/oai_server/utils/utils.go @@ -0,0 +1,34 @@ +package utils + +import ( + "encoding/json" + + "github.com/valyala/fasthttp" +) + +// RespondError sends an error response in OpenAI format +func RespondError(ctx *fasthttp.RequestCtx, statusCode int, message, errorType string) { + ctx.SetStatusCode(statusCode) + ctx.SetContentType("application/json") + + response := map[string]interface{}{ + "error": map[string]interface{}{ + "message": message, + "type": errorType, + "code": statusCode, + }, + } + + jsonData, _ := json.Marshal(response) + ctx.Write(jsonData) +} + +// BuildResponseBase builds the base response structure for OpenAI-compatible responses +func BuildResponseBase(id string, created int64, model string) map[string]interface{} { + return map[string]interface{}{ + "id": id, + "object": "chat.completion", + "created": created, + "model": model, + } +} diff --git a/sgl-model-gateway/bindings/golang/examples/simple/run.sh b/sgl-model-gateway/bindings/golang/examples/simple/run.sh index 9153f2e2cf7a..400950e74fe4 100755 --- a/sgl-model-gateway/bindings/golang/examples/simple/run.sh +++ b/sgl-model-gateway/bindings/golang/examples/simple/run.sh @@ -19,7 +19,7 @@ fi PYTHON_LDFLAGS=$(python3-config --ldflags --embed 2>/dev/null || python3-config --ldflags 2>/dev/null || echo "") # Set CGO_LDFLAGS to link with the Rust library -export CGO_LDFLAGS="-L${LIB_DIR} -lsglang_router_rs ${PYTHON_LDFLAGS} -ldl" +export CGO_LDFLAGS="-L${LIB_DIR} -lsgl_model_gateway_go ${PYTHON_LDFLAGS} -ldl" # macOS uses DYLD_LIBRARY_PATH, Linux uses LD_LIBRARY_PATH if [[ "$OSTYPE" == "darwin"* ]]; then diff --git a/sgl-model-gateway/bindings/golang/examples/streaming/run.sh b/sgl-model-gateway/bindings/golang/examples/streaming/run.sh index 49911cc0c762..28678dd67225 100755 --- a/sgl-model-gateway/bindings/golang/examples/streaming/run.sh +++ b/sgl-model-gateway/bindings/golang/examples/streaming/run.sh @@ -19,7 +19,7 @@ fi PYTHON_LDFLAGS=$(python3-config --ldflags --embed 2>/dev/null || python3-config --ldflags 2>/dev/null || echo "") # Set CGO_LDFLAGS to link with the Rust library -export CGO_LDFLAGS="-L${LIB_DIR} -lsglang_router_rs ${PYTHON_LDFLAGS} -ldl" +export CGO_LDFLAGS="-L${LIB_DIR} -lsgl_model_gateway_go ${PYTHON_LDFLAGS} -ldl" # macOS uses DYLD_LIBRARY_PATH, Linux uses LD_LIBRARY_PATH if [[ "$OSTYPE" == "darwin"* ]]; then diff --git a/sgl-model-gateway/bindings/golang/go.mod b/sgl-model-gateway/bindings/golang/go.mod new file mode 100644 index 000000000000..bb6823c342d1 --- /dev/null +++ b/sgl-model-gateway/bindings/golang/go.mod @@ -0,0 +1,17 @@ +module github.com/sglang/sglang-go-grpc-sdk + +go 1.24.0 + +toolchain go1.24.10 + +require ( + google.golang.org/grpc v1.77.0 + google.golang.org/protobuf v1.36.10 +) + +require ( 
+ golang.org/x/net v0.46.1-0.20251013234738-63d1a5100f82 // indirect + golang.org/x/sys v0.37.0 // indirect + golang.org/x/text v0.30.0 // indirect + google.golang.org/genproto/googleapis/rpc v0.0.0-20251022142026-3a174f9686a8 // indirect +) diff --git a/sgl-model-gateway/bindings/golang/go.sum b/sgl-model-gateway/bindings/golang/go.sum new file mode 100644 index 000000000000..6d6408a5bb84 --- /dev/null +++ b/sgl-model-gateway/bindings/golang/go.sum @@ -0,0 +1,36 @@ +github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI= +github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= +github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag= +github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE= +github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek= +github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps= +github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= +github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +go.opentelemetry.io/auto/sdk v1.2.1 h1:jXsnJ4Lmnqd11kwkBV2LgLoFMZKizbCi5fNZ/ipaZ64= +go.opentelemetry.io/auto/sdk v1.2.1/go.mod h1:KRTj+aOaElaLi+wW1kO/DZRXwkF4C5xPbEe3ZiIhN7Y= +go.opentelemetry.io/otel v1.38.0 h1:RkfdswUDRimDg0m2Az18RKOsnI8UDzppJAtj01/Ymk8= +go.opentelemetry.io/otel v1.38.0/go.mod h1:zcmtmQ1+YmQM9wrNsTGV/q/uyusom3P8RxwExxkZhjM= +go.opentelemetry.io/otel/metric v1.38.0 h1:Kl6lzIYGAh5M159u9NgiRkmoMKjvbsKtYRwgfrA6WpA= +go.opentelemetry.io/otel/metric v1.38.0/go.mod h1:kB5n/QoRM8YwmUahxvI3bO34eVtQf2i4utNVLr9gEmI= +go.opentelemetry.io/otel/sdk v1.38.0 h1:l48sr5YbNf2hpCUj/FoGhW9yDkl+Ma+LrVl8qaM5b+E= +go.opentelemetry.io/otel/sdk v1.38.0/go.mod h1:ghmNdGlVemJI3+ZB5iDEuk4bWA3GkTpW+DOoZMYBVVg= +go.opentelemetry.io/otel/sdk/metric v1.38.0 h1:aSH66iL0aZqo//xXzQLYozmWrXxyFkBJ6qT5wthqPoM= +go.opentelemetry.io/otel/sdk/metric v1.38.0/go.mod h1:dg9PBnW9XdQ1Hd6ZnRz689CbtrUp0wMMs9iPcgT9EZA= +go.opentelemetry.io/otel/trace v1.38.0 h1:Fxk5bKrDZJUH+AMyyIXGcFAPah0oRcT+LuNtJrmcNLE= +go.opentelemetry.io/otel/trace v1.38.0/go.mod h1:j1P9ivuFsTceSWe1oY+EeW3sc+Pp42sO++GHkg4wwhs= +golang.org/x/net v0.46.1-0.20251013234738-63d1a5100f82 h1:6/3JGEh1C88g7m+qzzTbl3A0FtsLguXieqofVLU/JAo= +golang.org/x/net v0.46.1-0.20251013234738-63d1a5100f82/go.mod h1:Q9BGdFy1y4nkUwiLvT5qtyhAnEHgnQ/zd8PfU6nc210= +golang.org/x/sys v0.37.0 h1:fdNQudmxPjkdUTPnLn5mdQv7Zwvbvpaxqs831goi9kQ= +golang.org/x/sys v0.37.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= +golang.org/x/text v0.30.0 h1:yznKA/E9zq54KzlzBEAWn1NXSQ8DIp/NYMy88xJjl4k= +golang.org/x/text v0.30.0/go.mod h1:yDdHFIX9t+tORqspjENWgzaCVXgk0yYnYuSZ8UzzBVM= +gonum.org/v1/gonum v0.16.0 h1:5+ul4Swaf3ESvrOnidPp4GZbzf0mxVQpDCYUQE7OJfk= +gonum.org/v1/gonum v0.16.0/go.mod h1:fef3am4MQ93R2HHpKnLk4/Tbh/s0+wqD5nfa6Pnwy4E= +google.golang.org/genproto/googleapis/rpc v0.0.0-20251022142026-3a174f9686a8 h1:M1rk8KBnUsBDg1oPGHNCxG4vc1f49epmTO7xscSajMk= +google.golang.org/genproto/googleapis/rpc v0.0.0-20251022142026-3a174f9686a8/go.mod h1:7i2o+ce6H/6BluujYR+kqX3GKH+dChPTQU19wjRPiGk= +google.golang.org/grpc v1.77.0 h1:wVVY6/8cGA6vvffn+wWK5ToddbgdU3d8MNENr4evgXM= +google.golang.org/grpc v1.77.0/go.mod 
h1:z0BY1iVj0q8E1uSQCjL9cppRj+gnZjzDnzV0dHhrNig= +google.golang.org/protobuf v1.36.10 h1:AYd7cD/uASjIL6Q9LiTjz8JLcrh/88q5UObnmY3aOOE= +google.golang.org/protobuf v1.36.10/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco= diff --git a/sgl-model-gateway/bindings/golang/internal/ffi/batch_postprocessor.go b/sgl-model-gateway/bindings/golang/internal/ffi/batch_postprocessor.go new file mode 100644 index 000000000000..4d14155817cd --- /dev/null +++ b/sgl-model-gateway/bindings/golang/internal/ffi/batch_postprocessor.go @@ -0,0 +1,126 @@ +// Package ffi provides Go bindings for SGLang's Rust FFI (Foreign Function Interface). +package ffi + +import ( + "encoding/json" + "fmt" + "strings" + "time" +) + +// BatchPostprocessor handles batch postprocessing of stream chunks to reduce FFI overhead +type BatchPostprocessor struct { + converter *GrpcResponseConverterHandle + buffer []string + batchSize int + flushInterval time.Duration + lastFlush time.Time + timer *time.Timer +} + +// NewBatchPostprocessor creates a new batch postprocessor +func NewBatchPostprocessor(converter *GrpcResponseConverterHandle, batchSize int, flushInterval time.Duration) *BatchPostprocessor { + if batchSize <= 0 { + batchSize = 1 + } + if flushInterval < 0 { + flushInterval = 0 + } + + return &BatchPostprocessor{ + converter: converter, + buffer: make([]string, 0, batchSize), + batchSize: batchSize, + flushInterval: flushInterval, + lastFlush: time.Now(), + } +} + +// AddChunk adds a chunk to the buffer and processes if batch is full +func (b *BatchPostprocessor) AddChunk(chunkJSON string) (results []string, shouldFlush bool, err error) { + if b.batchSize == 1 { + openaiJSON, _, err := PostprocessStreamChunk(b.converter, chunkJSON) + if err != nil { + return nil, false, err + } + return []string{openaiJSON}, false, nil + } + + b.buffer = append(b.buffer, chunkJSON) + shouldProcess := len(b.buffer) >= b.batchSize + shouldFlushTimeout := b.flushInterval > 0 && time.Since(b.lastFlush) >= b.flushInterval + + if shouldProcess || shouldFlushTimeout { + return b.processBatch() + } + + return nil, false, nil +} + +// Flush processes any remaining chunks in the buffer +func (b *BatchPostprocessor) Flush() (results []string, err error) { + if len(b.buffer) == 0 { + return nil, nil + } + + res, _, err := b.processBatch() + return res, err +} + +// processBatch processes the current buffer and returns results +func (b *BatchPostprocessor) processBatch() (results []string, shouldFlush bool, err error) { + if len(b.buffer) == 0 { + return nil, false, nil + } + + var sb strings.Builder + sb.Grow(len(b.buffer) * 200) + sb.WriteString(`[`) + for i, chunkJSONStr := range b.buffer { + if i > 0 { + sb.WriteString(`,`) + } + sb.WriteString(chunkJSONStr) + } + sb.WriteString(`]`) + bufferJSON := sb.String() + + resultJSON, _, err := PostprocessStreamChunksBatch( + b.converter, + bufferJSON, + b.batchSize*2, + ) + if err != nil { + return nil, false, fmt.Errorf("batch postprocessing failed: %w", err) + } + + var resultArray []json.RawMessage + if err := json.Unmarshal([]byte(resultJSON), &resultArray); err != nil { + return nil, false, fmt.Errorf("failed to unmarshal results array: %w", err) + } + + resultStrings := make([]string, 0, len(resultArray)) + for _, rawMsg := range resultArray { + resultStrings = append(resultStrings, string(rawMsg)) + } + + b.buffer = b.buffer[:0] + b.lastFlush = time.Now() + + if b.timer != nil { + b.timer.Stop() + b.timer = nil + } + + return resultStrings, false, nil +} + +// Reset clears the buffer and resets 
the postprocessor state
+func (b *BatchPostprocessor) Reset() {
+	b.buffer = b.buffer[:0]
+	b.lastFlush = time.Now()
+	if b.timer != nil {
+		b.timer.Stop()
+		b.timer = nil
+	}
+}
diff --git a/sgl-model-gateway/bindings/golang/internal/ffi/client.go b/sgl-model-gateway/bindings/golang/internal/ffi/client.go
index 13439e3c9660..b229c38fe7b0 100644
--- a/sgl-model-gateway/bindings/golang/internal/ffi/client.go
+++ b/sgl-model-gateway/bindings/golang/internal/ffi/client.go
@@ -11,7 +11,7 @@ package ffi
 /*
-#cgo LDFLAGS: -lsglang_router_rs -ldl
+#cgo LDFLAGS: -lsgl_model_gateway_go -ldl
 #include <stdlib.h>
 #include <string.h>
diff --git a/sgl-model-gateway/bindings/golang/internal/ffi/grpc_converter.go b/sgl-model-gateway/bindings/golang/internal/ffi/grpc_converter.go
new file mode 100644
index 000000000000..38125d3ae72a
--- /dev/null
+++ b/sgl-model-gateway/bindings/golang/internal/ffi/grpc_converter.go
@@ -0,0 +1,275 @@
+package ffi
+
+/*
+#cgo LDFLAGS: -lsgl_model_gateway_go -ldl
+#include <stdlib.h>
+#include <string.h>
+
+// Error codes (must match client.go)
+typedef enum {
+	SGL_ERROR_SUCCESS = 0,
+	SGL_ERROR_INVALID_ARGUMENT = 1,
+	SGL_ERROR_TOKENIZATION_ERROR = 2,
+	SGL_ERROR_PARSING_ERROR = 3,
+	SGL_ERROR_MEMORY_ERROR = 4,
+	SGL_ERROR_UNKNOWN = 99
+} SglErrorCode;
+
+// Opaque handles
+typedef void* TokenizerHandle;
+typedef void* GrpcResponseConverterHandle;
+
+// Converter functions
+GrpcResponseConverterHandle* sgl_grpc_response_converter_create(
+	TokenizerHandle* tokenizer_handle,
+	const char* model,
+	const char* request_id,
+	const char* tools_json,
+	const char* tool_choice_json,
+	const char* stop,
+	const char* stop_token_ids,
+	int skip_special_tokens,
+	int initial_prompt_tokens,
+	char** error_out
+);
+
+void sgl_grpc_response_converter_free(GrpcResponseConverterHandle* handle);
+
+// Tokenizer functions
+TokenizerHandle* sgl_tokenizer_create_from_file(const char* tokenizer_path, char** error_out);
+void sgl_tokenizer_free(TokenizerHandle* handle);
+
+// Memory management
+void sgl_free_string(char* s);
+*/
+import "C"
+
+import (
+	"fmt"
+	"unsafe"
+)
+
+// CreateGrpcResponseConverter creates a gRPC response converter handle
+// This function creates a new tokenizer handle each time (for backward compatibility)
+// For better performance, use CreateGrpcResponseConverterWithTokenizer with a cached tokenizer
+func CreateGrpcResponseConverter(
+	tokenizerPath string,
+	model string,
+	requestID string,
+	toolsJSON string,
+	toolChoiceJSON string,
+	stopJSON string,
+	stopTokenIDs []uint32,
+	skipSpecialTokens bool,
+	initialPromptTokens int32,
+) (*GrpcResponseConverterHandle, error) {
+	// Create tokenizer handle
+	tokenizerHandle, err := createTokenizerHandle(tokenizerPath)
+	if err != nil {
+		return nil, fmt.Errorf("failed to create tokenizer handle: %w", err)
+	}
+	defer C.sgl_tokenizer_free(tokenizerHandle)
+
+	return createGrpcResponseConverterWithTokenizerHandle(
+		tokenizerHandle,
+		model,
+		requestID,
+		toolsJSON,
+		toolChoiceJSON,
+		stopJSON,
+		stopTokenIDs,
+		skipSpecialTokens,
+		initialPromptTokens,
+	)
+}
+
+// CreateGrpcResponseConverterWithTokenizer creates a gRPC response converter handle using a cached tokenizer
+// This is more efficient as it reuses the tokenizer instead of creating a new one each time
+func CreateGrpcResponseConverterWithTokenizer(
+	tokenizerHandle *TokenizerHandle,
+	model string,
+	requestID string,
+	toolsJSON string,
+	toolChoiceJSON string,
+	stopJSON string,
+	stopTokenIDs []uint32,
+	skipSpecialTokens bool,
+	initialPromptTokens int32,
+) (*GrpcResponseConverterHandle,
error) { + if tokenizerHandle == nil || tokenizerHandle.handle == nil { + return nil, fmt.Errorf("invalid tokenizer handle") + } + + return createGrpcResponseConverterWithTokenizerHandle( + tokenizerHandle.handle, + model, + requestID, + toolsJSON, + toolChoiceJSON, + stopJSON, + stopTokenIDs, + skipSpecialTokens, + initialPromptTokens, + ) +} + +// createGrpcResponseConverterWithTokenizerHandle is the internal implementation +func createGrpcResponseConverterWithTokenizerHandle( + tokenizerHandle *C.TokenizerHandle, + model string, + requestID string, + toolsJSON string, + toolChoiceJSON string, + stopJSON string, + stopTokenIDs []uint32, + skipSpecialTokens bool, + initialPromptTokens int32, +) (*GrpcResponseConverterHandle, error) { + + // Convert strings to C strings + modelC := C.CString(model) + defer C.free(unsafe.Pointer(modelC)) + + requestIDC := C.CString(requestID) + defer C.free(unsafe.Pointer(requestIDC)) + + var toolsJSONC *C.char + if toolsJSON != "" { + toolsJSONC = C.CString(toolsJSON) + defer C.free(unsafe.Pointer(toolsJSONC)) + } + + var toolChoiceJSONC *C.char + if toolChoiceJSON != "" { + toolChoiceJSONC = C.CString(toolChoiceJSON) + defer C.free(unsafe.Pointer(toolChoiceJSONC)) + } + + var stopJSONC *C.char + if stopJSON != "" { + stopJSONC = C.CString(stopJSON) + defer C.free(unsafe.Pointer(stopJSONC)) + } + + // Convert stop_token_ids to JSON string + stopTokenIDsJSON := "" + if len(stopTokenIDs) > 0 { + stopTokenIDsJSON = fmt.Sprintf("[%d", stopTokenIDs[0]) + for i := 1; i < len(stopTokenIDs); i++ { + stopTokenIDsJSON += fmt.Sprintf(",%d", stopTokenIDs[i]) + } + stopTokenIDsJSON += "]" + } + + var stopTokenIDsJSONC *C.char + if stopTokenIDsJSON != "" { + stopTokenIDsJSONC = C.CString(stopTokenIDsJSON) + defer C.free(unsafe.Pointer(stopTokenIDsJSONC)) + } + + var errorOut *C.char + skipSpecialTokensC := C.int(0) + if skipSpecialTokens { + skipSpecialTokensC = C.int(1) + } + + initialPromptTokensC := C.int(initialPromptTokens) + + converterHandle := C.sgl_grpc_response_converter_create( + tokenizerHandle, + modelC, + requestIDC, + toolsJSONC, + toolChoiceJSONC, + stopJSONC, + stopTokenIDsJSONC, + skipSpecialTokensC, + initialPromptTokensC, + &errorOut, + ) + + if converterHandle == nil { + errorMsg := "" + if errorOut != nil { + errorMsg = C.GoString(errorOut) + C.sgl_free_string(errorOut) + } + if errorMsg == "" { + errorMsg = "failed to create converter handle" + } + return nil, fmt.Errorf("%s", errorMsg) + } + + return &GrpcResponseConverterHandle{ + handle: converterHandle, + }, nil +} + +// FreeGrpcResponseConverter frees a gRPC response converter handle +func FreeGrpcResponseConverter(handle *GrpcResponseConverterHandle) { + if handle != nil && handle.handle != nil { + C.sgl_grpc_response_converter_free(handle.handle) + handle.handle = nil + } +} + +// TokenizerHandle wraps the Rust tokenizer FFI handle +type TokenizerHandle struct { + handle *C.TokenizerHandle +} + +// CreateTokenizerHandle creates a tokenizer handle (exported for caching) +func CreateTokenizerHandle(tokenizerPath string) (*TokenizerHandle, error) { + tokenizerPathC := C.CString(tokenizerPath) + defer C.free(unsafe.Pointer(tokenizerPathC)) + + var errorOut *C.char + tokenizerHandle := C.sgl_tokenizer_create_from_file(tokenizerPathC, &errorOut) + + if tokenizerHandle == nil { + errorMsg := "" + if errorOut != nil { + errorMsg = C.GoString(errorOut) + C.sgl_free_string(errorOut) + } + if errorMsg == "" { + errorMsg = "failed to create tokenizer handle" + } + return nil, fmt.Errorf("%s", errorMsg) 
+	}
+
+	return &TokenizerHandle{
+		handle: tokenizerHandle,
+	}, nil
+}
+
+// FreeTokenizerHandle frees a tokenizer handle
+func FreeTokenizerHandle(handle *TokenizerHandle) {
+	if handle != nil && handle.handle != nil {
+		C.sgl_tokenizer_free(handle.handle)
+		handle.handle = nil
+	}
+}
+
+// createTokenizerHandle creates a tokenizer handle (helper function, internal use)
+func createTokenizerHandle(tokenizerPath string) (*C.TokenizerHandle, error) {
+	tokenizerPathC := C.CString(tokenizerPath)
+	defer C.free(unsafe.Pointer(tokenizerPathC))
+
+	var errorOut *C.char
+	tokenizerHandle := C.sgl_tokenizer_create_from_file(tokenizerPathC, &errorOut)
+
+	if tokenizerHandle == nil {
+		errorMsg := ""
+		if errorOut != nil {
+			errorMsg = C.GoString(errorOut)
+			C.sgl_free_string(errorOut)
+		}
+		if errorMsg == "" {
+			errorMsg = "failed to create tokenizer handle"
+		}
+		return nil, fmt.Errorf("%s", errorMsg)
+	}
+
+	return tokenizerHandle, nil
+}
diff --git a/sgl-model-gateway/bindings/golang/internal/ffi/postprocessor.go b/sgl-model-gateway/bindings/golang/internal/ffi/postprocessor.go
new file mode 100644
index 000000000000..1b76ec2ac3f9
--- /dev/null
+++ b/sgl-model-gateway/bindings/golang/internal/ffi/postprocessor.go
@@ -0,0 +1,156 @@
+// Package ffi provides Go bindings for SGLang's Rust FFI (Foreign Function Interface).
+package ffi
+
+/*
+#cgo LDFLAGS: -lsgl_model_gateway_go -ldl
+#include <stdlib.h>
+#include <string.h>
+
+// Error codes (must match client.go)
+typedef enum {
+	SGL_ERROR_SUCCESS = 0,
+	SGL_ERROR_INVALID_ARGUMENT = 1,
+	SGL_ERROR_TOKENIZATION_ERROR = 2,
+	SGL_ERROR_PARSING_ERROR = 3,
+	SGL_ERROR_MEMORY_ERROR = 4,
+	SGL_ERROR_UNKNOWN = 99
+} SglErrorCode;
+
+// Opaque handle (must match grpc_converter.go)
+typedef void* GrpcResponseConverterHandle;
+
+// Postprocessor functions
+SglErrorCode sgl_postprocess_stream_chunk(
+	GrpcResponseConverterHandle* converter_handle,
+	const char* proto_chunk_json,
+	char** openai_json_out,
+	int* is_done_out,
+	char** error_out
+);
+
+SglErrorCode sgl_postprocess_stream_chunks_batch(
+	GrpcResponseConverterHandle* converter_handle,
+	const char* proto_chunks_json_array,
+	int max_chunks,
+	char** openai_chunks_json_array_out,
+	int* chunks_count_out,
+	char** error_out
+);
+
+// Memory management
+void sgl_free_string(char* s);
+*/
+import "C"
+
+import (
+	"fmt"
+	"unsafe"
+)
+
+// GrpcResponseConverterHandle wraps the Rust gRPC response converter FFI handle
+type GrpcResponseConverterHandle struct {
+	handle *C.GrpcResponseConverterHandle
+}
+
+// PostprocessStreamChunk postprocesses a gRPC stream chunk to OpenAI format
+//
+// This function:
+// 1. Parses the proto chunk from JSON
+// 2. Converts it to OpenAI format using the converter handle
+// 3. Returns the OpenAI format JSON
+//
+// Returns the OpenAI format JSON, is_done flag, and any error.
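+//
+// A minimal usage sketch (converter is assumed to come from
+// CreateGrpcResponseConverterWithTokenizer, and chunkJSON from a proto
+// chunk already serialized to JSON):
+//
+//	openaiJSON, done, err := PostprocessStreamChunk(converter, chunkJSON)
+//	if err != nil {
+//		return err // conversion failed on the Rust side
+//	}
+//	send(openaiJSON) // send is a hypothetical SSE writer
+//	if done {
+//		// the stream has finished; emit the final [DONE] event
+//	}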
+func PostprocessStreamChunk(converterHandle *GrpcResponseConverterHandle, protoChunkJSON string) (openaiJSON string, isDone bool, err error) { + if converterHandle == nil || converterHandle.handle == nil { + return "", false, fmt.Errorf("invalid converter handle") + } + + protoChunkJSONC := C.CString(protoChunkJSON) + defer C.free(unsafe.Pointer(protoChunkJSONC)) + + var openaiJSONOut *C.char + var isDoneOut C.int + var errorOut *C.char + + errorCode := C.sgl_postprocess_stream_chunk( + converterHandle.handle, + protoChunkJSONC, + &openaiJSONOut, + &isDoneOut, + &errorOut, + ) + + if errorCode != C.SGL_ERROR_SUCCESS { + errorMsg := "" + if errorOut != nil { + errorMsg = C.GoString(errorOut) + C.sgl_free_string(errorOut) + } + return "", false, fmt.Errorf("postprocessing failed: %s", errorMsg) + } + + openaiJSON = C.GoString(openaiJSONOut) + isDone = isDoneOut != 0 + + // Free the C string allocated by Rust + if openaiJSONOut != nil { + C.sgl_free_string(openaiJSONOut) + } + + return openaiJSON, isDone, nil +} + +// PostprocessStreamChunksBatch postprocesses multiple gRPC stream chunks in batch +// +// This function processes multiple chunks in a single FFI call, significantly reducing +// FFI overhead in streaming scenarios. +// +// Arguments: +// - converterHandle: Converter handle +// - protoChunksJSONArray: JSON array string of proto chunks +// - maxChunks: Maximum number of chunks to process (for safety, typically 10-20) +// +// Returns: +// - openaiChunksJSONArray: JSON array of OpenAI format chunks +// - chunksCount: Number of processed chunks +// - error: Any error that occurred +func PostprocessStreamChunksBatch(converterHandle *GrpcResponseConverterHandle, protoChunksJSONArray string, maxChunks int) (openaiChunksJSONArray string, chunksCount int, err error) { + if converterHandle == nil || converterHandle.handle == nil { + return "", 0, fmt.Errorf("invalid converter handle") + } + + protoChunksJSONArrayC := C.CString(protoChunksJSONArray) + defer C.free(unsafe.Pointer(protoChunksJSONArrayC)) + + var openaiChunksJSONArrayOut *C.char + var chunksCountOut C.int + var errorOut *C.char + + errorCode := C.sgl_postprocess_stream_chunks_batch( + converterHandle.handle, + protoChunksJSONArrayC, + C.int(maxChunks), + &openaiChunksJSONArrayOut, + &chunksCountOut, + &errorOut, + ) + + if errorCode != C.SGL_ERROR_SUCCESS { + errorMsg := "" + if errorOut != nil { + errorMsg = C.GoString(errorOut) + C.sgl_free_string(errorOut) + } + return "", 0, fmt.Errorf("batch postprocessing failed: %s", errorMsg) + } + + openaiChunksJSONArray = C.GoString(openaiChunksJSONArrayOut) + chunksCount = int(chunksCountOut) + + // Free the C string allocated by Rust + if openaiChunksJSONArrayOut != nil { + C.sgl_free_string(openaiChunksJSONArrayOut) + } + + return openaiChunksJSONArray, chunksCount, nil +} diff --git a/sgl-model-gateway/bindings/golang/internal/ffi/preprocessor.go b/sgl-model-gateway/bindings/golang/internal/ffi/preprocessor.go new file mode 100644 index 000000000000..90c05173949e --- /dev/null +++ b/sgl-model-gateway/bindings/golang/internal/ffi/preprocessor.go @@ -0,0 +1,246 @@ +// Package ffi provides Go bindings for SGLang's Rust FFI (Foreign Function Interface). 
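+//
+// The preprocessing step runs entirely inside the Rust library: it applies
+// the chat template, tokenizes the prompt, and derives tool constraints.
+// Go only copies the results out and must call Free afterwards to release
+// the Rust-allocated buffers.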
+package ffi
+
+/*
+#cgo LDFLAGS: -lsgl_model_gateway_go -ldl
+#include <stdlib.h>
+#include <stdint.h>
+
+// Error codes (must match client.go)
+typedef enum {
+	SGL_ERROR_SUCCESS = 0,
+	SGL_ERROR_INVALID_ARGUMENT = 1,
+	SGL_ERROR_TOKENIZATION_ERROR = 2,
+	SGL_ERROR_PARSING_ERROR = 3,
+	SGL_ERROR_MEMORY_ERROR = 4,
+	SGL_ERROR_UNKNOWN = 99
+} SglErrorCode;
+
+// Preprocessor functions
+SglErrorCode sgl_preprocess_chat_request(
+	const char* request_json,
+	const char* tokenizer_path,
+	char** prompt_text_out,
+	uint32_t** token_ids_out,
+	size_t* token_ids_len_out,
+	char** tool_constraints_json_out,
+	int32_t* prompt_tokens_out,
+	char** error_out
+);
+
+// Opaque handle (must match grpc_converter.go)
+typedef void* TokenizerHandle;
+
+SglErrorCode sgl_preprocess_chat_request_with_tokenizer(
+	const char* request_json,
+	void* tokenizer_handle,
+	char** prompt_text_out,
+	uint32_t** token_ids_out,
+	size_t* token_ids_len_out,
+	char** tool_constraints_json_out,
+	int32_t* prompt_tokens_out,
+	char** error_out
+);
+
+void sgl_preprocessed_request_free(
+	char* prompt_text,
+	uint32_t* token_ids,
+	size_t token_ids_len,
+	char* tool_constraints_json
+);
+
+// Memory management
+void sgl_free_string(char* s);
+void sgl_free_token_ids(uint32_t* ptr, size_t count);
+*/
+import "C"
+
+import (
+	"fmt"
+	"unsafe"
+)
+
+// PreprocessedRequest represents a preprocessed chat request
+type PreprocessedRequest struct {
+	PromptText          string
+	TokenIDs            []uint32
+	ToolConstraintsJSON string
+	PromptTokens        int32
+	// Internal pointers for memory management
+	promptTextPtr          *C.char
+	tokenIDsPtr            *C.uint32_t
+	tokenIDsLen            uintptr
+	toolConstraintsJSONPtr *C.char
+}
+
+// PreprocessChatRequest preprocesses a chat completion request
+//
+// This function:
+// 1. Applies chat_template to messages
+// 2. Tokenizes the processed text
+// 3. Generates tool constraints (if tools are present)
+//
+// Returns the preprocessed request data and any error.
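+//
+// A minimal usage sketch (the request JSON and tokenizer path are
+// illustrative):
+//
+//	pre, err := PreprocessChatRequest(reqJSON, "/path/to/tokenizer")
+//	if err != nil {
+//		return err
+//	}
+//	defer pre.Free() // releases the Rust-allocated buffers
+//	stream(pre.TokenIDs, pre.PromptText) // stream is a hypothetical caller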
+func PreprocessChatRequest(requestJSON, tokenizerPath string) (*PreprocessedRequest, error) { + requestJSONC := C.CString(requestJSON) + defer C.free(unsafe.Pointer(requestJSONC)) + + tokenizerPathC := C.CString(tokenizerPath) + defer C.free(unsafe.Pointer(tokenizerPathC)) + + var promptTextOut *C.char + var tokenIDsOut *C.uint32_t + var tokenIDsLenOut C.size_t + var toolConstraintsJSONOut *C.char + var promptTokensOut C.int32_t + var errorOut *C.char + + errorCode := C.sgl_preprocess_chat_request( + requestJSONC, + tokenizerPathC, + &promptTextOut, + &tokenIDsOut, + &tokenIDsLenOut, + &toolConstraintsJSONOut, + &promptTokensOut, + &errorOut, + ) + + if errorCode != C.SGL_ERROR_SUCCESS { + errorMsg := "" + if errorOut != nil { + errorMsg = C.GoString(errorOut) + C.sgl_free_string(errorOut) + } + return nil, fmt.Errorf("preprocessing failed: %s", errorMsg) + } + + result := &PreprocessedRequest{ + PromptText: C.GoString(promptTextOut), + TokenIDs: make([]uint32, tokenIDsLenOut), + ToolConstraintsJSON: "", + PromptTokens: int32(promptTokensOut), + } + + // Copy token IDs + if tokenIDsOut != nil && tokenIDsLenOut > 0 { + tokenIDsSlice := (*[1 << 30]C.uint32_t)(unsafe.Pointer(tokenIDsOut))[:tokenIDsLenOut:tokenIDsLenOut] + for i := range result.TokenIDs { + result.TokenIDs[i] = uint32(tokenIDsSlice[i]) + } + } + + // Copy tool constraints JSON if present + if toolConstraintsJSONOut != nil { + result.ToolConstraintsJSON = C.GoString(toolConstraintsJSONOut) + } + + // Store pointers for later cleanup + result.promptTextPtr = promptTextOut + result.tokenIDsPtr = tokenIDsOut + result.tokenIDsLen = uintptr(tokenIDsLenOut) + result.toolConstraintsJSONPtr = toolConstraintsJSONOut + + return result, nil +} + +// PreprocessChatRequestWithTokenizer preprocesses a chat completion request using an existing tokenizer handle +// +// This function is similar to PreprocessChatRequest, but accepts a TokenizerHandle +// instead of creating a new tokenizer. This allows reusing a cached tokenizer instance, +// significantly reducing initialization overhead in concurrent scenarios. +// +// Returns the preprocessed request data and any error. 
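+//
+// A minimal usage sketch (assumes the tokenizer handle is created once at
+// startup and cached, as GrpcClient does):
+//
+//	tok, err := CreateTokenizerHandle("/path/to/tokenizer")
+//	if err != nil {
+//		return err
+//	}
+//	defer FreeTokenizerHandle(tok)
+//
+//	pre, err := PreprocessChatRequestWithTokenizer(reqJSON, tok)
+//	if err != nil {
+//		return err
+//	}
+//	defer pre.Free()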
+func PreprocessChatRequestWithTokenizer(requestJSON string, tokenizerHandle *TokenizerHandle) (*PreprocessedRequest, error) { + requestJSONC := C.CString(requestJSON) + defer C.free(unsafe.Pointer(requestJSONC)) + + if tokenizerHandle == nil || tokenizerHandle.handle == nil { + return nil, fmt.Errorf("invalid tokenizer handle") + } + + var promptTextOut *C.char + var tokenIDsOut *C.uint32_t + var tokenIDsLenOut C.size_t + var toolConstraintsJSONOut *C.char + var promptTokensOut C.int32_t + var errorOut *C.char + + errorCode := C.sgl_preprocess_chat_request_with_tokenizer( + requestJSONC, + unsafe.Pointer(tokenizerHandle.handle), // Convert *C.TokenizerHandle to void* + &promptTextOut, + &tokenIDsOut, + &tokenIDsLenOut, + &toolConstraintsJSONOut, + &promptTokensOut, + &errorOut, + ) + + if errorCode != C.SGL_ERROR_SUCCESS { + errorMsg := "" + if errorOut != nil { + errorMsg = C.GoString(errorOut) + C.sgl_free_string(errorOut) + } + return nil, fmt.Errorf("preprocessing failed: %s", errorMsg) + } + + result := &PreprocessedRequest{ + PromptText: C.GoString(promptTextOut), + TokenIDs: make([]uint32, tokenIDsLenOut), + ToolConstraintsJSON: "", + PromptTokens: int32(promptTokensOut), + } + + // Copy token IDs + if tokenIDsOut != nil && tokenIDsLenOut > 0 { + tokenIDsSlice := (*[1 << 30]C.uint32_t)(unsafe.Pointer(tokenIDsOut))[:tokenIDsLenOut:tokenIDsLenOut] + for i := range result.TokenIDs { + result.TokenIDs[i] = uint32(tokenIDsSlice[i]) + } + } + + // Copy tool constraints JSON if present + if toolConstraintsJSONOut != nil { + result.ToolConstraintsJSON = C.GoString(toolConstraintsJSONOut) + } + + // Store pointers for later cleanup + result.promptTextPtr = promptTextOut + result.tokenIDsPtr = tokenIDsOut + result.tokenIDsLen = uintptr(tokenIDsLenOut) + result.toolConstraintsJSONPtr = toolConstraintsJSONOut + + return result, nil +} + +// Free frees the memory allocated for a preprocessed request +func (p *PreprocessedRequest) Free() { + if p.promptTextPtr != nil || p.tokenIDsPtr != nil || p.toolConstraintsJSONPtr != nil { + C.sgl_preprocessed_request_free( + p.promptTextPtr, + p.tokenIDsPtr, + C.size_t(p.tokenIDsLen), + p.toolConstraintsJSONPtr, + ) + // Clear pointers to prevent double-free + p.promptTextPtr = nil + p.tokenIDsPtr = nil + p.tokenIDsLen = 0 + p.toolConstraintsJSONPtr = nil + } +} + +// FreePreprocessedRequest frees the memory allocated for a preprocessed request +// This is a convenience function for direct pointer management +func FreePreprocessedRequest(promptTextPtr *C.char, tokenIDsPtr *C.uint32_t, tokenIDsLen uintptr, toolConstraintsJSONPtr *C.char) { + if promptTextPtr != nil || tokenIDsPtr != nil || toolConstraintsJSONPtr != nil { + C.sgl_preprocessed_request_free( + promptTextPtr, + tokenIDsPtr, + C.size_t(tokenIDsLen), + toolConstraintsJSONPtr, + ) + } +} diff --git a/sgl-model-gateway/bindings/golang/internal/grpc/client_grpc.go b/sgl-model-gateway/bindings/golang/internal/grpc/client_grpc.go new file mode 100644 index 000000000000..4213c467e090 --- /dev/null +++ b/sgl-model-gateway/bindings/golang/internal/grpc/client_grpc.go @@ -0,0 +1,684 @@ +// Package grpc provides gRPC client implementation for SGLang +package grpc + +import ( + "context" + "encoding/json" + "fmt" + "io" + "strconv" + "strings" + "sync" + "sync/atomic" + "time" + + "google.golang.org/grpc" + "google.golang.org/grpc/credentials/insecure" + "google.golang.org/grpc/keepalive" + "google.golang.org/protobuf/types/known/timestamppb" + + "github.com/sglang/sglang-go-grpc-sdk/internal/ffi" + 
"github.com/sglang/sglang-go-grpc-sdk/internal/proto" +) + +type grpcClientStream interface { + Recv() (*proto.GenerateResponse, error) + CloseSend() error +} + +// recvResult holds the result of a Recv() call +type recvResult struct { + resp *proto.GenerateResponse + err error +} + +type GrpcClient struct { + conn *grpc.ClientConn + client proto.SglangSchedulerClient + tokenizerPath string + tokenizerHandle *ffi.TokenizerHandle + bufferSizes ChannelBufferSizes + timeouts Timeouts + requestCounter uint64 // Atomic counter to ensure unique request IDs +} + +type ChannelBufferSizes struct { + ResultJSONChan int + ErrChan int + RecvChan int +} + +type Timeouts struct { + KeepaliveTime time.Duration + KeepaliveTimeout time.Duration + CloseTimeout time.Duration +} + +func NewGrpcClient(endpoint, tokenizerPath string, bufferSizes ChannelBufferSizes, timeouts Timeouts) (*GrpcClient, error) { + endpoint = strings.TrimPrefix(endpoint, "grpc://") + if !strings.Contains(endpoint, ":") { + return nil, fmt.Errorf("invalid endpoint format: %s (expected grpc://host:port)", endpoint) + } + + keepaliveParams := keepalive.ClientParameters{ + Time: timeouts.KeepaliveTime, + Timeout: timeouts.KeepaliveTimeout, + PermitWithoutStream: false, + } + + opts := []grpc.DialOption{ + grpc.WithTransportCredentials(insecure.NewCredentials()), + grpc.WithKeepaliveParams(keepaliveParams), + } + + conn, err := grpc.NewClient(endpoint, opts...) + if err != nil { + return nil, fmt.Errorf("failed to connect to gRPC server: %w", err) + } + + client := proto.NewSglangSchedulerClient(conn) + + tokenizerHandle, err := ffi.CreateTokenizerHandle(tokenizerPath) + if err != nil { + conn.Close() + return nil, fmt.Errorf("failed to create tokenizer handle: %w", err) + } + + return &GrpcClient{ + conn: conn, + client: client, + tokenizerPath: tokenizerPath, + tokenizerHandle: tokenizerHandle, + bufferSizes: bufferSizes, + timeouts: timeouts, + }, nil +} + +func (c *GrpcClient) Close() error { + if c.tokenizerHandle != nil { + ffi.FreeTokenizerHandle(c.tokenizerHandle) + c.tokenizerHandle = nil + } + + if c.conn != nil { + return c.conn.Close() + } + return nil +} + +func (c *GrpcClient) CreateChatCompletionStream(ctx context.Context, reqJSON string) (*GrpcChatCompletionStream, error) { + if c.tokenizerHandle == nil { + return nil, fmt.Errorf("tokenizer handle is nil (should be created at startup)") + } + + preprocessed, err := ffi.PreprocessChatRequestWithTokenizer(reqJSON, c.tokenizerHandle) + if err != nil { + return nil, fmt.Errorf("preprocessing failed: %w", err) + } + defer func() { + if preprocessed != nil { + preprocessed.Free() + } + }() + + // Parse request JSON to get parameters + var reqMap map[string]interface{} + if err := json.Unmarshal([]byte(reqJSON), &reqMap); err != nil { + return nil, fmt.Errorf("failed to parse request JSON: %w", err) + } + + model, _ := reqMap["model"].(string) + if model == "" { + model = "default" + } + + // Build GenerateRequest + // Generate unique request ID using timestamp + atomic counter to avoid collisions + // This matches Rust version's UUID-based approach for uniqueness + counter := atomic.AddUint64(&c.requestCounter, 1) + requestID := fmt.Sprintf("chatcmpl-%d-%d", time.Now().UnixNano(), counter) + generateReq := &proto.GenerateRequest{ + RequestId: requestID, + Tokenized: &proto.TokenizedInput{ + OriginalText: preprocessed.PromptText, + InputIds: preprocessed.TokenIDs, + }, + Stream: true, + } + + // Set sampling parameters + samplingParams := &proto.SamplingParams{ + 
Temperature: 1.0, + TopP: 1.0, + TopK: -1, + SkipSpecialTokens: true, + } + + if temp, ok := reqMap["temperature"].(float64); ok { + samplingParams.Temperature = float32(temp) + } + if topP, ok := reqMap["top_p"].(float64); ok { + samplingParams.TopP = float32(topP) + } + if topK, ok := reqMap["top_k"].(float64); ok { + samplingParams.TopK = int32(topK) + } + var maxTokensInt *int32 + if maxCompletionTokens, ok := reqMap["max_completion_tokens"].(float64); ok { + tokens := int32(maxCompletionTokens) + maxTokensInt = &tokens + } else if maxTokens, ok := reqMap["max_tokens"].(float64); ok { + tokens := int32(maxTokens) + maxTokensInt = &tokens + } + if maxTokensInt != nil { + samplingParams.MaxNewTokens = maxTokensInt + } + + // Parse tool constraints if available + if preprocessed.ToolConstraintsJSON != "" { + var toolConstraints map[string]interface{} + if err := json.Unmarshal([]byte(preprocessed.ToolConstraintsJSON), &toolConstraints); err == nil { + if regex, ok := toolConstraints["regex"].(string); ok { + samplingParams.Constraint = &proto.SamplingParams_Regex{Regex: regex} + } else if jsonSchema, ok := toolConstraints["json_schema"].(string); ok { + samplingParams.Constraint = &proto.SamplingParams_JsonSchema{JsonSchema: jsonSchema} + } + } + } + + generateReq.SamplingParams = samplingParams + generateReq.Timestamp = timestamppb.Now() + + stream, err := c.client.Generate(ctx, generateReq) + if err != nil { + return nil, fmt.Errorf("failed to create gRPC stream: %w", err) + } + toolsJSON := "" + if tools, ok := reqMap["tools"].([]interface{}); ok && len(tools) > 0 { + toolsBytes, _ := json.Marshal(tools) + toolsJSON = string(toolsBytes) + } + + toolChoiceJSON := "" + if toolChoice, ok := reqMap["tool_choice"]; ok { + toolChoiceBytes, _ := json.Marshal(toolChoice) + toolChoiceJSON = string(toolChoiceBytes) + } + + stopJSON := "" + if stop, ok := reqMap["stop"]; ok { + stopBytes, _ := json.Marshal(stop) + stopJSON = string(stopBytes) + } + + stopTokenIDs := []uint32{} + if stopTokenIDsVal, ok := reqMap["stop_token_ids"].([]interface{}); ok { + for _, id := range stopTokenIDsVal { + if idFloat, ok := id.(float64); ok { + stopTokenIDs = append(stopTokenIDs, uint32(idFloat)) + } + } + } + + skipSpecialTokens := true + if skipSpecialTokensVal, ok := reqMap["skip_special_tokens"].(bool); ok { + skipSpecialTokens = skipSpecialTokensVal + } + + if c.tokenizerHandle == nil { + stream.CloseSend() + return nil, fmt.Errorf("tokenizer handle is nil (should be created at startup)") + } + + converterHandle, err := ffi.CreateGrpcResponseConverterWithTokenizer( + c.tokenizerHandle, + model, + generateReq.RequestId, + toolsJSON, + toolChoiceJSON, + stopJSON, + stopTokenIDs, + skipSpecialTokens, + preprocessed.PromptTokens, // Pass initial prompt tokens from preprocessing + ) + if err != nil { + stream.CloseSend() + return nil, fmt.Errorf("failed to create converter handle: %w", err) + } + + batchSize := 1 + batchPostprocessor := ffi.NewBatchPostprocessor(converterHandle, batchSize, 0) + + streamCtx, cancel := context.WithCancel(ctx) + grpcStream := &GrpcChatCompletionStream{ + stream: stream, + converterHandle: converterHandle, + batchPostprocessor: batchPostprocessor, + batchSize: batchSize, + ctx: streamCtx, + cancel: cancel, + resultJSONChan: make(chan string, c.bufferSizes.ResultJSONChan), + errChan: make(chan error, c.bufferSizes.ErrChan), + readLoopDone: make(chan struct{}), + requestID: generateReq.RequestId, + model: model, + processWg: sync.WaitGroup{}, + closeTimeout: c.timeouts.CloseTimeout, + 
bufferSizes: c.bufferSizes, + } + + go grpcStream.readLoop() + + return grpcStream, nil +} + +// GrpcChatCompletionStream represents a streaming chat completion via gRPC +type GrpcChatCompletionStream struct { + stream grpcClientStream + converterHandle *ffi.GrpcResponseConverterHandle + batchPostprocessor *ffi.BatchPostprocessor + batchSize int + ctx context.Context + cancel context.CancelFunc + closed int32 + resultJSONChan chan string + errChan chan error + readLoopDone chan struct{} + requestID string + model string + processWg sync.WaitGroup + closeTimeout time.Duration + bufferSizes ChannelBufferSizes + clientDisconnected int32 // Atomic flag: 1 if client disconnected, 0 otherwise +} + +func (s *GrpcChatCompletionStream) readLoop() { + defer func() { + atomic.StoreInt32(&s.closed, 1) + s.processWg.Wait() + close(s.resultJSONChan) + close(s.errChan) + close(s.readLoopDone) + // Cancel context after channels are closed to ensure errors are read first + if s.cancel != nil { + s.cancel() + } + }() + + recvChan := make(chan recvResult, s.bufferSizes.RecvChan) + const firstRecvTimeout = 60 * time.Second + + go func() { + defer close(recvChan) + recvCount := 0 + for { + select { + case <-s.ctx.Done(): + // Skip CloseSend() if client disconnected + if atomic.LoadInt32(&s.clientDisconnected) == 0 { + _ = s.stream.CloseSend() + } + return + default: + } + + recvCount++ + var protoResp *proto.GenerateResponse + var err error + + // First Recv() with timeout + if recvCount == 1 { + recvDone := make(chan recvResult, 1) + go func() { + resp, recvErr := s.stream.Recv() + recvDone <- recvResult{resp: resp, err: recvErr} + }() + + select { + case result := <-recvDone: + protoResp = result.resp + err = result.err + case <-time.After(firstRecvTimeout): + timeoutErr := fmt.Errorf("stream.Recv() timeout after %v: backend may not be responding (request_id=%s)", firstRecvTimeout, s.requestID) + select { + case recvChan <- recvResult{resp: nil, err: timeoutErr}: + case <-s.ctx.Done(): + } + return + case <-s.ctx.Done(): + return + } + } else { + // Normal Recv() + protoResp, err = s.stream.Recv() + } + + if err != nil { + select { + case recvChan <- recvResult{resp: nil, err: err}: + case <-s.ctx.Done(): + return + } + return + } + + select { + case <-s.ctx.Done(): + // Skip CloseSend() if client disconnected + if atomic.LoadInt32(&s.clientDisconnected) == 0 { + _ = s.stream.CloseSend() + } + return + case recvChan <- recvResult{resp: protoResp, err: nil}: + } + } + }() + + for { + select { + case <-s.ctx.Done(): + // Skip CloseSend() if client disconnected + if atomic.LoadInt32(&s.clientDisconnected) == 0 { + _ = s.stream.CloseSend() + } + return + case result, ok := <-recvChan: + if !ok { + return + } + if result.err != nil { + if result.err == io.EOF { + results, flushErr := s.flushBatch() + if flushErr != nil { + select { + case s.errChan <- fmt.Errorf("failed to flush batch: %w", flushErr): + case <-s.ctx.Done(): + } + return + } + for _, resultJSON := range results { + select { + case s.resultJSONChan <- resultJSON: + case <-s.ctx.Done(): + return + } + } + return + } + select { + case s.errChan <- result.err: + case <-s.ctx.Done(): + } + return + } + + if result.resp != nil { + s.processWg.Add(1) + go func(resp *proto.GenerateResponse) { + defer s.processWg.Done() + s.processAndSendResponse(resp) + }(result.resp) + } + } + } +} + +func (s *GrpcChatCompletionStream) processAndSendResponse(protoResp *proto.GenerateResponse) { + select { + case <-s.ctx.Done(): + return + default: + } + + if protoResp == 
nil { + return + } + + protoJSON, err := protoToJSON(protoResp) + if err != nil { + select { + case s.errChan <- fmt.Errorf("failed to convert proto to JSON: %w", err): + case <-s.ctx.Done(): + } + return + } + + if s.batchPostprocessor == nil { + select { + case s.errChan <- fmt.Errorf("batch postprocessor is nil"): + case <-s.ctx.Done(): + } + return + } + + results, _, err := s.batchPostprocessor.AddChunk(protoJSON) + if err != nil { + select { + case s.errChan <- fmt.Errorf("batch postprocessing failed: %w", err): + case <-s.ctx.Done(): + } + return + } + + for _, resultJSON := range results { + select { + case s.resultJSONChan <- resultJSON: + case <-s.ctx.Done(): + return + } + } +} + +func (s *GrpcChatCompletionStream) RecvJSON() (string, error) { + // Use a loop instead of recursion to avoid stack overflow if there are many empty strings + for { + // Check errChan first to prioritize actual errors over context cancellation + select { + case err, ok := <-s.errChan: + if !ok { + return "", io.EOF + } + return "", err + default: + } + + select { + case resultJSON, ok := <-s.resultJSONChan: + if !ok { + return "", io.EOF + } + // Skip empty strings and continue loop instead of recursing + if resultJSON != "" { + return resultJSON, nil + } + // Empty string, continue loop to get next result + continue + case err, ok := <-s.errChan: + if !ok { + return "", io.EOF + } + return "", err + case <-s.ctx.Done(): + return "", s.ctx.Err() + } + } +} + +// SetClientDisconnected marks that the client has disconnected. +// When Close() is called, it will not call CloseSend() to avoid aborting the request on server side. +func (s *GrpcChatCompletionStream) SetClientDisconnected() { + atomic.StoreInt32(&s.clientDisconnected, 1) +} + +func (s *GrpcChatCompletionStream) Close() error { + if !atomic.CompareAndSwapInt32(&s.closed, 0, 1) { + return nil + } + + if s.cancel != nil { + s.cancel() + } + + clientDisconnected := atomic.LoadInt32(&s.clientDisconnected) == 1 + + select { + case <-s.readLoopDone: + // readLoop completed + default: + if !clientDisconnected { + // Call CloseSend() if client didn't disconnect + _ = s.stream.CloseSend() + } + select { + case <-s.readLoopDone: + case <-time.After(s.closeTimeout): + } + } + + _, _ = s.flushBatch() + + if s.converterHandle != nil { + ffi.FreeGrpcResponseConverter(s.converterHandle) + } + + return nil +} + +func (s *GrpcChatCompletionStream) flushBatch() ([]string, error) { + if s.batchPostprocessor != nil { + results, err := s.batchPostprocessor.Flush() + if err != nil { + return nil, fmt.Errorf("batch flush failed: %w", err) + } + return results, nil + } + return nil, nil +} + +func protoToJSON(resp *proto.GenerateResponse) (string, error) { + var sb strings.Builder + sb.Grow(500) + + sb.WriteString(`{"request_id":`) + if resp.RequestId == "" { + sb.WriteString(`""`) + } else { + requestIDJSON, err := json.Marshal(resp.RequestId) + if err != nil { + return "", err + } + sb.Write(requestIDJSON) + } + + switch r := resp.Response.(type) { + case *proto.GenerateResponse_Chunk: + sb.WriteString(`,"chunk":{`) + sb.WriteString(`"token_ids":`) + tokenIDsJSON, err := json.Marshal(r.Chunk.TokenIds) + if err != nil { + return "", err + } + sb.Write(tokenIDsJSON) + sb.WriteString(`,"prompt_tokens":`) + sb.WriteString(strconv.FormatInt(int64(r.Chunk.PromptTokens), 10)) + sb.WriteString(`,"completion_tokens":`) + sb.WriteString(strconv.FormatInt(int64(r.Chunk.CompletionTokens), 10)) + sb.WriteString(`,"cached_tokens":`) + 
sb.WriteString(strconv.FormatInt(int64(r.Chunk.CachedTokens), 10)) + sb.WriteString(`,"index":`) + sb.WriteString(strconv.FormatInt(int64(r.Chunk.Index), 10)) + sb.WriteString(`}`) + case *proto.GenerateResponse_Complete: + sb.WriteString(`,"complete":{`) + sb.WriteString(`"output_ids":`) + outputIDsJSON, err := json.Marshal(r.Complete.OutputIds) + if err != nil { + return "", err + } + sb.Write(outputIDsJSON) + sb.WriteString(`,"finish_reason":`) + finishReasonJSON, err := json.Marshal(r.Complete.FinishReason) + if err != nil { + return "", err + } + sb.Write(finishReasonJSON) + sb.WriteString(`,"prompt_tokens":`) + sb.WriteString(strconv.FormatInt(int64(r.Complete.PromptTokens), 10)) + sb.WriteString(`,"completion_tokens":`) + sb.WriteString(strconv.FormatInt(int64(r.Complete.CompletionTokens), 10)) + sb.WriteString(`,"cached_tokens":`) + sb.WriteString(strconv.FormatInt(int64(r.Complete.CachedTokens), 10)) + sb.WriteString(`}`) + case *proto.GenerateResponse_Error: + sb.WriteString(`,"error":{`) + sb.WriteString(`"message":`) + messageJSON, err := json.Marshal(r.Error.Message) + if err != nil { + return "", err + } + sb.Write(messageJSON) + sb.WriteString(`,"http_status_code":`) + httpStatusCodeJSON, err := json.Marshal(r.Error.HttpStatusCode) + if err != nil { + return "", err + } + sb.Write(httpStatusCodeJSON) + if r.Error.Details != "" { + sb.WriteString(`,"details":`) + detailsJSON, err := json.Marshal(r.Error.Details) + if err != nil { + return "", err + } + sb.Write(detailsJSON) + } + sb.WriteString(`}`) + } + + sb.WriteString(`}`) + return sb.String(), nil +} + +type ChatCompletionStreamResponse struct { + ID string `json:"id"` + Object string `json:"object"` + Created int64 `json:"created"` + Model string `json:"model"` + SystemFingerprint string `json:"system_fingerprint,omitempty"` + Choices []StreamChoice `json:"choices"` + Usage *Usage `json:"usage,omitempty"` +} + +// StreamChoice represents a choice in a streaming response +type StreamChoice struct { + Index int `json:"index"` + Delta MessageDelta `json:"delta"` + FinishReason string `json:"finish_reason,omitempty"` +} + +// MessageDelta represents incremental message updates +type MessageDelta struct { + Role string `json:"role,omitempty"` + Content string `json:"content,omitempty"` + ToolCalls []ToolCall `json:"tool_calls,omitempty"` +} + +// ToolCall represents a tool call in the response +type ToolCall struct { + ID string `json:"id"` + Type string `json:"type"` + Function FunctionCall `json:"function"` +} + +// FunctionCall represents a function call +type FunctionCall struct { + Name string `json:"name"` + Arguments string `json:"arguments"` +} + +// Usage represents token usage information +type Usage struct { + PromptTokens int `json:"prompt_tokens"` + CompletionTokens int `json:"completion_tokens"` + TotalTokens int `json:"total_tokens"` +} diff --git a/sgl-model-gateway/bindings/golang/internal/proto/sglang_scheduler.pb.go b/sgl-model-gateway/bindings/golang/internal/proto/sglang_scheduler.pb.go new file mode 100644 index 000000000000..732043042771 --- /dev/null +++ b/sgl-model-gateway/bindings/golang/internal/proto/sglang_scheduler.pb.go @@ -0,0 +1,3325 @@ +// Code generated by protoc-gen-go. DO NOT EDIT. 
+// versions: +// protoc-gen-go v1.36.10 +// protoc v3.21.12 +// source: sglang_scheduler.proto + +package proto + +import ( + protoreflect "google.golang.org/protobuf/reflect/protoreflect" + protoimpl "google.golang.org/protobuf/runtime/protoimpl" + structpb "google.golang.org/protobuf/types/known/structpb" + timestamppb "google.golang.org/protobuf/types/known/timestamppb" + reflect "reflect" + sync "sync" + unsafe "unsafe" +) + +const ( + // Verify that this generated code is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) + // Verify that runtime/protoimpl is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) +) + +// Sampling parameters matching SGLang's SamplingParams +// +// IMPORTANT: Do not use SamplingParams::default() directly! +// The proto3 defaults (0 for numeric fields) do NOT match the semantic defaults +// (temperature=1.0, top_p=1.0, top_k=-1, etc.). Always construct with explicit values +// or use the conversion functions in sglang_scheduler.rs / grpc_server.py. +type SamplingParams struct { + state protoimpl.MessageState `protogen:"open.v1"` + Temperature float32 `protobuf:"fixed32,1,opt,name=temperature,proto3" json:"temperature,omitempty"` + TopP float32 `protobuf:"fixed32,2,opt,name=top_p,json=topP,proto3" json:"top_p,omitempty"` + TopK int32 `protobuf:"varint,3,opt,name=top_k,json=topK,proto3" json:"top_k,omitempty"` + MinP float32 `protobuf:"fixed32,4,opt,name=min_p,json=minP,proto3" json:"min_p,omitempty"` + FrequencyPenalty float32 `protobuf:"fixed32,5,opt,name=frequency_penalty,json=frequencyPenalty,proto3" json:"frequency_penalty,omitempty"` + PresencePenalty float32 `protobuf:"fixed32,6,opt,name=presence_penalty,json=presencePenalty,proto3" json:"presence_penalty,omitempty"` + RepetitionPenalty float32 `protobuf:"fixed32,7,opt,name=repetition_penalty,json=repetitionPenalty,proto3" json:"repetition_penalty,omitempty"` + MaxNewTokens *int32 `protobuf:"varint,8,opt,name=max_new_tokens,json=maxNewTokens,proto3,oneof" json:"max_new_tokens,omitempty"` + Stop []string `protobuf:"bytes,9,rep,name=stop,proto3" json:"stop,omitempty"` + StopTokenIds []uint32 `protobuf:"varint,10,rep,packed,name=stop_token_ids,json=stopTokenIds,proto3" json:"stop_token_ids,omitempty"` + SkipSpecialTokens bool `protobuf:"varint,11,opt,name=skip_special_tokens,json=skipSpecialTokens,proto3" json:"skip_special_tokens,omitempty"` + SpacesBetweenSpecialTokens bool `protobuf:"varint,12,opt,name=spaces_between_special_tokens,json=spacesBetweenSpecialTokens,proto3" json:"spaces_between_special_tokens,omitempty"` + // Structured generation + // + // Types that are valid to be assigned to Constraint: + // + // *SamplingParams_Regex + // *SamplingParams_JsonSchema + // *SamplingParams_EbnfGrammar + // *SamplingParams_StructuralTag + Constraint isSamplingParams_Constraint `protobuf_oneof:"constraint"` + // Speculative decoding + N int32 `protobuf:"varint,17,opt,name=n,proto3" json:"n,omitempty"` // Number of samples + // Additional parameters + MinNewTokens int32 `protobuf:"varint,18,opt,name=min_new_tokens,json=minNewTokens,proto3" json:"min_new_tokens,omitempty"` + IgnoreEos bool `protobuf:"varint,19,opt,name=ignore_eos,json=ignoreEos,proto3" json:"ignore_eos,omitempty"` + NoStopTrim bool `protobuf:"varint,20,opt,name=no_stop_trim,json=noStopTrim,proto3" json:"no_stop_trim,omitempty"` + StreamInterval *int32 `protobuf:"varint,21,opt,name=stream_interval,json=streamInterval,proto3,oneof" json:"stream_interval,omitempty"` + LogitBias 
map[string]float32 `protobuf:"bytes,22,rep,name=logit_bias,json=logitBias,proto3" json:"logit_bias,omitempty" protobuf_key:"bytes,1,opt,name=key" protobuf_val:"fixed32,2,opt,name=value"` + // Custom parameters for extensibility + CustomParams *structpb.Struct `protobuf:"bytes,23,opt,name=custom_params,json=customParams,proto3" json:"custom_params,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *SamplingParams) Reset() { + *x = SamplingParams{} + mi := &file_sglang_scheduler_proto_msgTypes[0] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *SamplingParams) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*SamplingParams) ProtoMessage() {} + +func (x *SamplingParams) ProtoReflect() protoreflect.Message { + mi := &file_sglang_scheduler_proto_msgTypes[0] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use SamplingParams.ProtoReflect.Descriptor instead. +func (*SamplingParams) Descriptor() ([]byte, []int) { + return file_sglang_scheduler_proto_rawDescGZIP(), []int{0} +} + +func (x *SamplingParams) GetTemperature() float32 { + if x != nil { + return x.Temperature + } + return 0 +} + +func (x *SamplingParams) GetTopP() float32 { + if x != nil { + return x.TopP + } + return 0 +} + +func (x *SamplingParams) GetTopK() int32 { + if x != nil { + return x.TopK + } + return 0 +} + +func (x *SamplingParams) GetMinP() float32 { + if x != nil { + return x.MinP + } + return 0 +} + +func (x *SamplingParams) GetFrequencyPenalty() float32 { + if x != nil { + return x.FrequencyPenalty + } + return 0 +} + +func (x *SamplingParams) GetPresencePenalty() float32 { + if x != nil { + return x.PresencePenalty + } + return 0 +} + +func (x *SamplingParams) GetRepetitionPenalty() float32 { + if x != nil { + return x.RepetitionPenalty + } + return 0 +} + +func (x *SamplingParams) GetMaxNewTokens() int32 { + if x != nil && x.MaxNewTokens != nil { + return *x.MaxNewTokens + } + return 0 +} + +func (x *SamplingParams) GetStop() []string { + if x != nil { + return x.Stop + } + return nil +} + +func (x *SamplingParams) GetStopTokenIds() []uint32 { + if x != nil { + return x.StopTokenIds + } + return nil +} + +func (x *SamplingParams) GetSkipSpecialTokens() bool { + if x != nil { + return x.SkipSpecialTokens + } + return false +} + +func (x *SamplingParams) GetSpacesBetweenSpecialTokens() bool { + if x != nil { + return x.SpacesBetweenSpecialTokens + } + return false +} + +func (x *SamplingParams) GetConstraint() isSamplingParams_Constraint { + if x != nil { + return x.Constraint + } + return nil +} + +func (x *SamplingParams) GetRegex() string { + if x != nil { + if x, ok := x.Constraint.(*SamplingParams_Regex); ok { + return x.Regex + } + } + return "" +} + +func (x *SamplingParams) GetJsonSchema() string { + if x != nil { + if x, ok := x.Constraint.(*SamplingParams_JsonSchema); ok { + return x.JsonSchema + } + } + return "" +} + +func (x *SamplingParams) GetEbnfGrammar() string { + if x != nil { + if x, ok := x.Constraint.(*SamplingParams_EbnfGrammar); ok { + return x.EbnfGrammar + } + } + return "" +} + +func (x *SamplingParams) GetStructuralTag() string { + if x != nil { + if x, ok := x.Constraint.(*SamplingParams_StructuralTag); ok { + return x.StructuralTag + } + } + return "" +} + +func (x *SamplingParams) GetN() int32 { + if x != nil { + return x.N + } 
+ return 0 +} + +func (x *SamplingParams) GetMinNewTokens() int32 { + if x != nil { + return x.MinNewTokens + } + return 0 +} + +func (x *SamplingParams) GetIgnoreEos() bool { + if x != nil { + return x.IgnoreEos + } + return false +} + +func (x *SamplingParams) GetNoStopTrim() bool { + if x != nil { + return x.NoStopTrim + } + return false +} + +func (x *SamplingParams) GetStreamInterval() int32 { + if x != nil && x.StreamInterval != nil { + return *x.StreamInterval + } + return 0 +} + +func (x *SamplingParams) GetLogitBias() map[string]float32 { + if x != nil { + return x.LogitBias + } + return nil +} + +func (x *SamplingParams) GetCustomParams() *structpb.Struct { + if x != nil { + return x.CustomParams + } + return nil +} + +type isSamplingParams_Constraint interface { + isSamplingParams_Constraint() +} + +type SamplingParams_Regex struct { + Regex string `protobuf:"bytes,13,opt,name=regex,proto3,oneof"` +} + +type SamplingParams_JsonSchema struct { + JsonSchema string `protobuf:"bytes,14,opt,name=json_schema,json=jsonSchema,proto3,oneof"` +} + +type SamplingParams_EbnfGrammar struct { + EbnfGrammar string `protobuf:"bytes,15,opt,name=ebnf_grammar,json=ebnfGrammar,proto3,oneof"` +} + +type SamplingParams_StructuralTag struct { + StructuralTag string `protobuf:"bytes,16,opt,name=structural_tag,json=structuralTag,proto3,oneof"` +} + +func (*SamplingParams_Regex) isSamplingParams_Constraint() {} + +func (*SamplingParams_JsonSchema) isSamplingParams_Constraint() {} + +func (*SamplingParams_EbnfGrammar) isSamplingParams_Constraint() {} + +func (*SamplingParams_StructuralTag) isSamplingParams_Constraint() {} + +// Disaggregated serving parameters +type DisaggregatedParams struct { + state protoimpl.MessageState `protogen:"open.v1"` + BootstrapHost string `protobuf:"bytes,1,opt,name=bootstrap_host,json=bootstrapHost,proto3" json:"bootstrap_host,omitempty"` + BootstrapPort int32 `protobuf:"varint,2,opt,name=bootstrap_port,json=bootstrapPort,proto3" json:"bootstrap_port,omitempty"` + BootstrapRoom int32 `protobuf:"varint,3,opt,name=bootstrap_room,json=bootstrapRoom,proto3" json:"bootstrap_room,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *DisaggregatedParams) Reset() { + *x = DisaggregatedParams{} + mi := &file_sglang_scheduler_proto_msgTypes[1] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *DisaggregatedParams) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*DisaggregatedParams) ProtoMessage() {} + +func (x *DisaggregatedParams) ProtoReflect() protoreflect.Message { + mi := &file_sglang_scheduler_proto_msgTypes[1] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use DisaggregatedParams.ProtoReflect.Descriptor instead. 
+func (*DisaggregatedParams) Descriptor() ([]byte, []int) { + return file_sglang_scheduler_proto_rawDescGZIP(), []int{1} +} + +func (x *DisaggregatedParams) GetBootstrapHost() string { + if x != nil { + return x.BootstrapHost + } + return "" +} + +func (x *DisaggregatedParams) GetBootstrapPort() int32 { + if x != nil { + return x.BootstrapPort + } + return 0 +} + +func (x *DisaggregatedParams) GetBootstrapRoom() int32 { + if x != nil { + return x.BootstrapRoom + } + return 0 +} + +type GenerateRequest struct { + state protoimpl.MessageState `protogen:"open.v1"` + RequestId string `protobuf:"bytes,1,opt,name=request_id,json=requestId,proto3" json:"request_id,omitempty"` + // Input must be tokenized (no raw text) + Tokenized *TokenizedInput `protobuf:"bytes,2,opt,name=tokenized,proto3" json:"tokenized,omitempty"` + // Multimodal inputs + MmInputs *MultimodalInputs `protobuf:"bytes,3,opt,name=mm_inputs,json=mmInputs,proto3" json:"mm_inputs,omitempty"` + // Generation parameters + SamplingParams *SamplingParams `protobuf:"bytes,4,opt,name=sampling_params,json=samplingParams,proto3" json:"sampling_params,omitempty"` + // Return options + ReturnLogprob bool `protobuf:"varint,5,opt,name=return_logprob,json=returnLogprob,proto3" json:"return_logprob,omitempty"` + LogprobStartLen int32 `protobuf:"varint,6,opt,name=logprob_start_len,json=logprobStartLen,proto3" json:"logprob_start_len,omitempty"` + TopLogprobsNum int32 `protobuf:"varint,7,opt,name=top_logprobs_num,json=topLogprobsNum,proto3" json:"top_logprobs_num,omitempty"` + TokenIdsLogprob []uint32 `protobuf:"varint,8,rep,packed,name=token_ids_logprob,json=tokenIdsLogprob,proto3" json:"token_ids_logprob,omitempty"` + ReturnHiddenStates bool `protobuf:"varint,9,opt,name=return_hidden_states,json=returnHiddenStates,proto3" json:"return_hidden_states,omitempty"` + // For disaggregated serving + DisaggregatedParams *DisaggregatedParams `protobuf:"bytes,10,opt,name=disaggregated_params,json=disaggregatedParams,proto3" json:"disaggregated_params,omitempty"` + // Custom logit processor (serialized) + CustomLogitProcessor string `protobuf:"bytes,11,opt,name=custom_logit_processor,json=customLogitProcessor,proto3" json:"custom_logit_processor,omitempty"` + // Request metadata + Timestamp *timestamppb.Timestamp `protobuf:"bytes,12,opt,name=timestamp,proto3" json:"timestamp,omitempty"` + LogMetrics bool `protobuf:"varint,13,opt,name=log_metrics,json=logMetrics,proto3" json:"log_metrics,omitempty"` + // Input embeddings (alternative to text/tokens) + InputEmbeds []float32 `protobuf:"fixed32,14,rep,packed,name=input_embeds,json=inputEmbeds,proto3" json:"input_embeds,omitempty"` + // LoRA adapter ID (if pre-loaded) + LoraId string `protobuf:"bytes,15,opt,name=lora_id,json=loraId,proto3" json:"lora_id,omitempty"` + // Data parallel routing + DataParallelRank int32 `protobuf:"varint,16,opt,name=data_parallel_rank,json=dataParallelRank,proto3" json:"data_parallel_rank,omitempty"` + // Whether client wants streaming response + Stream bool `protobuf:"varint,17,opt,name=stream,proto3" json:"stream,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *GenerateRequest) Reset() { + *x = GenerateRequest{} + mi := &file_sglang_scheduler_proto_msgTypes[2] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *GenerateRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*GenerateRequest) ProtoMessage() {} + +func (x *GenerateRequest) ProtoReflect() 
protoreflect.Message { + mi := &file_sglang_scheduler_proto_msgTypes[2] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use GenerateRequest.ProtoReflect.Descriptor instead. +func (*GenerateRequest) Descriptor() ([]byte, []int) { + return file_sglang_scheduler_proto_rawDescGZIP(), []int{2} +} + +func (x *GenerateRequest) GetRequestId() string { + if x != nil { + return x.RequestId + } + return "" +} + +func (x *GenerateRequest) GetTokenized() *TokenizedInput { + if x != nil { + return x.Tokenized + } + return nil +} + +func (x *GenerateRequest) GetMmInputs() *MultimodalInputs { + if x != nil { + return x.MmInputs + } + return nil +} + +func (x *GenerateRequest) GetSamplingParams() *SamplingParams { + if x != nil { + return x.SamplingParams + } + return nil +} + +func (x *GenerateRequest) GetReturnLogprob() bool { + if x != nil { + return x.ReturnLogprob + } + return false +} + +func (x *GenerateRequest) GetLogprobStartLen() int32 { + if x != nil { + return x.LogprobStartLen + } + return 0 +} + +func (x *GenerateRequest) GetTopLogprobsNum() int32 { + if x != nil { + return x.TopLogprobsNum + } + return 0 +} + +func (x *GenerateRequest) GetTokenIdsLogprob() []uint32 { + if x != nil { + return x.TokenIdsLogprob + } + return nil +} + +func (x *GenerateRequest) GetReturnHiddenStates() bool { + if x != nil { + return x.ReturnHiddenStates + } + return false +} + +func (x *GenerateRequest) GetDisaggregatedParams() *DisaggregatedParams { + if x != nil { + return x.DisaggregatedParams + } + return nil +} + +func (x *GenerateRequest) GetCustomLogitProcessor() string { + if x != nil { + return x.CustomLogitProcessor + } + return "" +} + +func (x *GenerateRequest) GetTimestamp() *timestamppb.Timestamp { + if x != nil { + return x.Timestamp + } + return nil +} + +func (x *GenerateRequest) GetLogMetrics() bool { + if x != nil { + return x.LogMetrics + } + return false +} + +func (x *GenerateRequest) GetInputEmbeds() []float32 { + if x != nil { + return x.InputEmbeds + } + return nil +} + +func (x *GenerateRequest) GetLoraId() string { + if x != nil { + return x.LoraId + } + return "" +} + +func (x *GenerateRequest) GetDataParallelRank() int32 { + if x != nil { + return x.DataParallelRank + } + return 0 +} + +func (x *GenerateRequest) GetStream() bool { + if x != nil { + return x.Stream + } + return false +} + +type TokenizedInput struct { + state protoimpl.MessageState `protogen:"open.v1"` + OriginalText string `protobuf:"bytes,1,opt,name=original_text,json=originalText,proto3" json:"original_text,omitempty"` // For reference + InputIds []uint32 `protobuf:"varint,2,rep,packed,name=input_ids,json=inputIds,proto3" json:"input_ids,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *TokenizedInput) Reset() { + *x = TokenizedInput{} + mi := &file_sglang_scheduler_proto_msgTypes[3] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *TokenizedInput) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*TokenizedInput) ProtoMessage() {} + +func (x *TokenizedInput) ProtoReflect() protoreflect.Message { + mi := &file_sglang_scheduler_proto_msgTypes[3] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use 
TokenizedInput.ProtoReflect.Descriptor instead. +func (*TokenizedInput) Descriptor() ([]byte, []int) { + return file_sglang_scheduler_proto_rawDescGZIP(), []int{3} +} + +func (x *TokenizedInput) GetOriginalText() string { + if x != nil { + return x.OriginalText + } + return "" +} + +func (x *TokenizedInput) GetInputIds() []uint32 { + if x != nil { + return x.InputIds + } + return nil +} + +type MultimodalInputs struct { + state protoimpl.MessageState `protogen:"open.v1"` + // Simplified multimodal handling - actual data processed by tokenizer + ImageUrls []string `protobuf:"bytes,1,rep,name=image_urls,json=imageUrls,proto3" json:"image_urls,omitempty"` + VideoUrls []string `protobuf:"bytes,2,rep,name=video_urls,json=videoUrls,proto3" json:"video_urls,omitempty"` + AudioUrls []string `protobuf:"bytes,3,rep,name=audio_urls,json=audioUrls,proto3" json:"audio_urls,omitempty"` + // Pre-processed multimodal features (if available) + ProcessedFeatures *structpb.Struct `protobuf:"bytes,4,opt,name=processed_features,json=processedFeatures,proto3" json:"processed_features,omitempty"` + // Raw data for direct processing + ImageData [][]byte `protobuf:"bytes,5,rep,name=image_data,json=imageData,proto3" json:"image_data,omitempty"` + VideoData [][]byte `protobuf:"bytes,6,rep,name=video_data,json=videoData,proto3" json:"video_data,omitempty"` + AudioData [][]byte `protobuf:"bytes,7,rep,name=audio_data,json=audioData,proto3" json:"audio_data,omitempty"` + // Modality metadata + Modalities []string `protobuf:"bytes,8,rep,name=modalities,proto3" json:"modalities,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *MultimodalInputs) Reset() { + *x = MultimodalInputs{} + mi := &file_sglang_scheduler_proto_msgTypes[4] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *MultimodalInputs) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*MultimodalInputs) ProtoMessage() {} + +func (x *MultimodalInputs) ProtoReflect() protoreflect.Message { + mi := &file_sglang_scheduler_proto_msgTypes[4] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use MultimodalInputs.ProtoReflect.Descriptor instead. 
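+
+// Illustrative usage sketch (hand-written; not protoc output) for
+// GenerateRequest/TokenizedInput above. Input must arrive pre-tokenized;
+// OriginalText is carried for reference only. IDs and values are
+// hypothetical.
+//
+//	req := &GenerateRequest{
+//		RequestId: "req-123",
+//		Tokenized: &TokenizedInput{
+//			OriginalText: "Hello",
+//			InputIds:     []uint32{9906},
+//		},
+//		SamplingParams: &SamplingParams{Temperature: 0.7},
+//		Stream:         true,
+//	}
+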
+func (*MultimodalInputs) Descriptor() ([]byte, []int) { + return file_sglang_scheduler_proto_rawDescGZIP(), []int{4} +} + +func (x *MultimodalInputs) GetImageUrls() []string { + if x != nil { + return x.ImageUrls + } + return nil +} + +func (x *MultimodalInputs) GetVideoUrls() []string { + if x != nil { + return x.VideoUrls + } + return nil +} + +func (x *MultimodalInputs) GetAudioUrls() []string { + if x != nil { + return x.AudioUrls + } + return nil +} + +func (x *MultimodalInputs) GetProcessedFeatures() *structpb.Struct { + if x != nil { + return x.ProcessedFeatures + } + return nil +} + +func (x *MultimodalInputs) GetImageData() [][]byte { + if x != nil { + return x.ImageData + } + return nil +} + +func (x *MultimodalInputs) GetVideoData() [][]byte { + if x != nil { + return x.VideoData + } + return nil +} + +func (x *MultimodalInputs) GetAudioData() [][]byte { + if x != nil { + return x.AudioData + } + return nil +} + +func (x *MultimodalInputs) GetModalities() []string { + if x != nil { + return x.Modalities + } + return nil +} + +type GenerateResponse struct { + state protoimpl.MessageState `protogen:"open.v1"` + RequestId string `protobuf:"bytes,1,opt,name=request_id,json=requestId,proto3" json:"request_id,omitempty"` + // Response type + // + // Types that are valid to be assigned to Response: + // + // *GenerateResponse_Chunk + // *GenerateResponse_Complete + // *GenerateResponse_Error + Response isGenerateResponse_Response `protobuf_oneof:"response"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *GenerateResponse) Reset() { + *x = GenerateResponse{} + mi := &file_sglang_scheduler_proto_msgTypes[5] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *GenerateResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*GenerateResponse) ProtoMessage() {} + +func (x *GenerateResponse) ProtoReflect() protoreflect.Message { + mi := &file_sglang_scheduler_proto_msgTypes[5] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use GenerateResponse.ProtoReflect.Descriptor instead. 
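+
+// Illustrative usage sketch (hand-written; not protoc output) for
+// MultimodalInputs above: media may be referenced by URL or passed as raw
+// bytes, with Modalities naming each entry's kind. The URL and the
+// pngBytes variable are hypothetical.
+//
+//	mm := &MultimodalInputs{
+//		ImageUrls:  []string{"https://example.com/cat.png"},
+//		Modalities: []string{"image"},
+//	}
+//	// Raw bytes are an alternative to URLs:
+//	// mm.ImageData = [][]byte{pngBytes}
+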
+func (*GenerateResponse) Descriptor() ([]byte, []int) { + return file_sglang_scheduler_proto_rawDescGZIP(), []int{5} +} + +func (x *GenerateResponse) GetRequestId() string { + if x != nil { + return x.RequestId + } + return "" +} + +func (x *GenerateResponse) GetResponse() isGenerateResponse_Response { + if x != nil { + return x.Response + } + return nil +} + +func (x *GenerateResponse) GetChunk() *GenerateStreamChunk { + if x != nil { + if x, ok := x.Response.(*GenerateResponse_Chunk); ok { + return x.Chunk + } + } + return nil +} + +func (x *GenerateResponse) GetComplete() *GenerateComplete { + if x != nil { + if x, ok := x.Response.(*GenerateResponse_Complete); ok { + return x.Complete + } + } + return nil +} + +func (x *GenerateResponse) GetError() *GenerateError { + if x != nil { + if x, ok := x.Response.(*GenerateResponse_Error); ok { + return x.Error + } + } + return nil +} + +type isGenerateResponse_Response interface { + isGenerateResponse_Response() +} + +type GenerateResponse_Chunk struct { + Chunk *GenerateStreamChunk `protobuf:"bytes,2,opt,name=chunk,proto3,oneof"` +} + +type GenerateResponse_Complete struct { + Complete *GenerateComplete `protobuf:"bytes,3,opt,name=complete,proto3,oneof"` +} + +type GenerateResponse_Error struct { + Error *GenerateError `protobuf:"bytes,4,opt,name=error,proto3,oneof"` +} + +func (*GenerateResponse_Chunk) isGenerateResponse_Response() {} + +func (*GenerateResponse_Complete) isGenerateResponse_Response() {} + +func (*GenerateResponse_Error) isGenerateResponse_Response() {} + +type GenerateStreamChunk struct { + state protoimpl.MessageState `protogen:"open.v1"` + // Generated tokens (incremental chunk) + TokenIds []uint32 `protobuf:"varint,1,rep,packed,name=token_ids,json=tokenIds,proto3" json:"token_ids,omitempty"` + // Cumulative counts + PromptTokens int32 `protobuf:"varint,2,opt,name=prompt_tokens,json=promptTokens,proto3" json:"prompt_tokens,omitempty"` + CompletionTokens int32 `protobuf:"varint,3,opt,name=completion_tokens,json=completionTokens,proto3" json:"completion_tokens,omitempty"` + CachedTokens int32 `protobuf:"varint,4,opt,name=cached_tokens,json=cachedTokens,proto3" json:"cached_tokens,omitempty"` + // Output logprobs (if requested) - incremental for streaming + OutputLogprobs *OutputLogProbs `protobuf:"bytes,5,opt,name=output_logprobs,json=outputLogprobs,proto3" json:"output_logprobs,omitempty"` + // Hidden states (if requested) + HiddenStates []float32 `protobuf:"fixed32,6,rep,packed,name=hidden_states,json=hiddenStates,proto3" json:"hidden_states,omitempty"` + // Input logprobs (if requested) - only in first chunk + InputLogprobs *InputLogProbs `protobuf:"bytes,7,opt,name=input_logprobs,json=inputLogprobs,proto3" json:"input_logprobs,omitempty"` + // Index for ordering when n>1 (for parallel request multiplexing) + Index uint32 `protobuf:"varint,8,opt,name=index,proto3" json:"index,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *GenerateStreamChunk) Reset() { + *x = GenerateStreamChunk{} + mi := &file_sglang_scheduler_proto_msgTypes[6] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *GenerateStreamChunk) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*GenerateStreamChunk) ProtoMessage() {} + +func (x *GenerateStreamChunk) ProtoReflect() protoreflect.Message { + mi := &file_sglang_scheduler_proto_msgTypes[6] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() 
== nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use GenerateStreamChunk.ProtoReflect.Descriptor instead. +func (*GenerateStreamChunk) Descriptor() ([]byte, []int) { + return file_sglang_scheduler_proto_rawDescGZIP(), []int{6} +} + +func (x *GenerateStreamChunk) GetTokenIds() []uint32 { + if x != nil { + return x.TokenIds + } + return nil +} + +func (x *GenerateStreamChunk) GetPromptTokens() int32 { + if x != nil { + return x.PromptTokens + } + return 0 +} + +func (x *GenerateStreamChunk) GetCompletionTokens() int32 { + if x != nil { + return x.CompletionTokens + } + return 0 +} + +func (x *GenerateStreamChunk) GetCachedTokens() int32 { + if x != nil { + return x.CachedTokens + } + return 0 +} + +func (x *GenerateStreamChunk) GetOutputLogprobs() *OutputLogProbs { + if x != nil { + return x.OutputLogprobs + } + return nil +} + +func (x *GenerateStreamChunk) GetHiddenStates() []float32 { + if x != nil { + return x.HiddenStates + } + return nil +} + +func (x *GenerateStreamChunk) GetInputLogprobs() *InputLogProbs { + if x != nil { + return x.InputLogprobs + } + return nil +} + +func (x *GenerateStreamChunk) GetIndex() uint32 { + if x != nil { + return x.Index + } + return 0 +} + +type GenerateComplete struct { + state protoimpl.MessageState `protogen:"open.v1"` + // Final output + OutputIds []uint32 `protobuf:"varint,1,rep,packed,name=output_ids,json=outputIds,proto3" json:"output_ids,omitempty"` + // Finish reason as OpenAI-compatible string ("stop", "length", "abort") + FinishReason string `protobuf:"bytes,2,opt,name=finish_reason,json=finishReason,proto3" json:"finish_reason,omitempty"` + // Token usage counts + PromptTokens int32 `protobuf:"varint,3,opt,name=prompt_tokens,json=promptTokens,proto3" json:"prompt_tokens,omitempty"` + CompletionTokens int32 `protobuf:"varint,4,opt,name=completion_tokens,json=completionTokens,proto3" json:"completion_tokens,omitempty"` + CachedTokens int32 `protobuf:"varint,5,opt,name=cached_tokens,json=cachedTokens,proto3" json:"cached_tokens,omitempty"` + // Output logprobs if requested (cumulative) + OutputLogprobs *OutputLogProbs `protobuf:"bytes,6,opt,name=output_logprobs,json=outputLogprobs,proto3" json:"output_logprobs,omitempty"` + // All hidden states if requested + AllHiddenStates []*HiddenStates `protobuf:"bytes,7,rep,name=all_hidden_states,json=allHiddenStates,proto3" json:"all_hidden_states,omitempty"` + // Matched stop information (for stop sequences) + // + // Types that are valid to be assigned to MatchedStop: + // + // *GenerateComplete_MatchedTokenId + // *GenerateComplete_MatchedStopStr + MatchedStop isGenerateComplete_MatchedStop `protobuf_oneof:"matched_stop"` + // Input logprobs if requested (for prompt tokens) + InputLogprobs *InputLogProbs `protobuf:"bytes,10,opt,name=input_logprobs,json=inputLogprobs,proto3" json:"input_logprobs,omitempty"` + // Index for ordering when n>1 (for parallel request multiplexing) + Index uint32 `protobuf:"varint,11,opt,name=index,proto3" json:"index,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *GenerateComplete) Reset() { + *x = GenerateComplete{} + mi := &file_sglang_scheduler_proto_msgTypes[7] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *GenerateComplete) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*GenerateComplete) ProtoMessage() {} + +func (x *GenerateComplete) ProtoReflect() protoreflect.Message { + mi := 
&file_sglang_scheduler_proto_msgTypes[7] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use GenerateComplete.ProtoReflect.Descriptor instead. +func (*GenerateComplete) Descriptor() ([]byte, []int) { + return file_sglang_scheduler_proto_rawDescGZIP(), []int{7} +} + +func (x *GenerateComplete) GetOutputIds() []uint32 { + if x != nil { + return x.OutputIds + } + return nil +} + +func (x *GenerateComplete) GetFinishReason() string { + if x != nil { + return x.FinishReason + } + return "" +} + +func (x *GenerateComplete) GetPromptTokens() int32 { + if x != nil { + return x.PromptTokens + } + return 0 +} + +func (x *GenerateComplete) GetCompletionTokens() int32 { + if x != nil { + return x.CompletionTokens + } + return 0 +} + +func (x *GenerateComplete) GetCachedTokens() int32 { + if x != nil { + return x.CachedTokens + } + return 0 +} + +func (x *GenerateComplete) GetOutputLogprobs() *OutputLogProbs { + if x != nil { + return x.OutputLogprobs + } + return nil +} + +func (x *GenerateComplete) GetAllHiddenStates() []*HiddenStates { + if x != nil { + return x.AllHiddenStates + } + return nil +} + +func (x *GenerateComplete) GetMatchedStop() isGenerateComplete_MatchedStop { + if x != nil { + return x.MatchedStop + } + return nil +} + +func (x *GenerateComplete) GetMatchedTokenId() uint32 { + if x != nil { + if x, ok := x.MatchedStop.(*GenerateComplete_MatchedTokenId); ok { + return x.MatchedTokenId + } + } + return 0 +} + +func (x *GenerateComplete) GetMatchedStopStr() string { + if x != nil { + if x, ok := x.MatchedStop.(*GenerateComplete_MatchedStopStr); ok { + return x.MatchedStopStr + } + } + return "" +} + +func (x *GenerateComplete) GetInputLogprobs() *InputLogProbs { + if x != nil { + return x.InputLogprobs + } + return nil +} + +func (x *GenerateComplete) GetIndex() uint32 { + if x != nil { + return x.Index + } + return 0 +} + +type isGenerateComplete_MatchedStop interface { + isGenerateComplete_MatchedStop() +} + +type GenerateComplete_MatchedTokenId struct { + MatchedTokenId uint32 `protobuf:"varint,8,opt,name=matched_token_id,json=matchedTokenId,proto3,oneof"` +} + +type GenerateComplete_MatchedStopStr struct { + MatchedStopStr string `protobuf:"bytes,9,opt,name=matched_stop_str,json=matchedStopStr,proto3,oneof"` +} + +func (*GenerateComplete_MatchedTokenId) isGenerateComplete_MatchedStop() {} + +func (*GenerateComplete_MatchedStopStr) isGenerateComplete_MatchedStop() {} + +type GenerateError struct { + state protoimpl.MessageState `protogen:"open.v1"` + Message string `protobuf:"bytes,1,opt,name=message,proto3" json:"message,omitempty"` + HttpStatusCode string `protobuf:"bytes,2,opt,name=http_status_code,json=httpStatusCode,proto3" json:"http_status_code,omitempty"` + Details string `protobuf:"bytes,3,opt,name=details,proto3" json:"details,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *GenerateError) Reset() { + *x = GenerateError{} + mi := &file_sglang_scheduler_proto_msgTypes[8] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *GenerateError) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*GenerateError) ProtoMessage() {} + +func (x *GenerateError) ProtoReflect() protoreflect.Message { + mi := &file_sglang_scheduler_proto_msgTypes[8] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if 
ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use GenerateError.ProtoReflect.Descriptor instead. +func (*GenerateError) Descriptor() ([]byte, []int) { + return file_sglang_scheduler_proto_rawDescGZIP(), []int{8} +} + +func (x *GenerateError) GetMessage() string { + if x != nil { + return x.Message + } + return "" +} + +func (x *GenerateError) GetHttpStatusCode() string { + if x != nil { + return x.HttpStatusCode + } + return "" +} + +func (x *GenerateError) GetDetails() string { + if x != nil { + return x.Details + } + return "" +} + +// Output logprobs - all values are present (no None) +type OutputLogProbs struct { + state protoimpl.MessageState `protogen:"open.v1"` + TokenLogprobs []float32 `protobuf:"fixed32,1,rep,packed,name=token_logprobs,json=tokenLogprobs,proto3" json:"token_logprobs,omitempty"` + TokenIds []int32 `protobuf:"varint,2,rep,packed,name=token_ids,json=tokenIds,proto3" json:"token_ids,omitempty"` + // Top logprobs at each position + TopLogprobs []*TopLogProbs `protobuf:"bytes,3,rep,name=top_logprobs,json=topLogprobs,proto3" json:"top_logprobs,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *OutputLogProbs) Reset() { + *x = OutputLogProbs{} + mi := &file_sglang_scheduler_proto_msgTypes[9] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *OutputLogProbs) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*OutputLogProbs) ProtoMessage() {} + +func (x *OutputLogProbs) ProtoReflect() protoreflect.Message { + mi := &file_sglang_scheduler_proto_msgTypes[9] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use OutputLogProbs.ProtoReflect.Descriptor instead. 
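+
+// Illustrative handling sketch (hand-written; not protoc output) for the
+// GenerateResponse oneof above: a stream carries Chunk messages and ends
+// with either Complete or Error. resp is assumed to be a received
+// *GenerateResponse.
+//
+//	switch r := resp.Response.(type) {
+//	case *GenerateResponse_Chunk:
+//		_ = r.Chunk.GetTokenIds() // incremental token ids
+//	case *GenerateResponse_Complete:
+//		_ = r.Complete.GetFinishReason() // "stop", "length", or "abort"
+//	case *GenerateResponse_Error:
+//		_ = r.Error.GetMessage()
+//	}
+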
+func (*OutputLogProbs) Descriptor() ([]byte, []int) { + return file_sglang_scheduler_proto_rawDescGZIP(), []int{9} +} + +func (x *OutputLogProbs) GetTokenLogprobs() []float32 { + if x != nil { + return x.TokenLogprobs + } + return nil +} + +func (x *OutputLogProbs) GetTokenIds() []int32 { + if x != nil { + return x.TokenIds + } + return nil +} + +func (x *OutputLogProbs) GetTopLogprobs() []*TopLogProbs { + if x != nil { + return x.TopLogprobs + } + return nil +} + +// Input logprobs - first token has no logprob (None) +type InputLogProbs struct { + state protoimpl.MessageState `protogen:"open.v1"` + TokenLogprobs []*InputTokenLogProb `protobuf:"bytes,1,rep,name=token_logprobs,json=tokenLogprobs,proto3" json:"token_logprobs,omitempty"` + TokenIds []int32 `protobuf:"varint,2,rep,packed,name=token_ids,json=tokenIds,proto3" json:"token_ids,omitempty"` + // Top logprobs at each position + TopLogprobs []*TopLogProbs `protobuf:"bytes,3,rep,name=top_logprobs,json=topLogprobs,proto3" json:"top_logprobs,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *InputLogProbs) Reset() { + *x = InputLogProbs{} + mi := &file_sglang_scheduler_proto_msgTypes[10] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *InputLogProbs) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*InputLogProbs) ProtoMessage() {} + +func (x *InputLogProbs) ProtoReflect() protoreflect.Message { + mi := &file_sglang_scheduler_proto_msgTypes[10] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use InputLogProbs.ProtoReflect.Descriptor instead. +func (*InputLogProbs) Descriptor() ([]byte, []int) { + return file_sglang_scheduler_proto_rawDescGZIP(), []int{10} +} + +func (x *InputLogProbs) GetTokenLogprobs() []*InputTokenLogProb { + if x != nil { + return x.TokenLogprobs + } + return nil +} + +func (x *InputLogProbs) GetTokenIds() []int32 { + if x != nil { + return x.TokenIds + } + return nil +} + +func (x *InputLogProbs) GetTopLogprobs() []*TopLogProbs { + if x != nil { + return x.TopLogprobs + } + return nil +} + +// Wrapper to represent optional logprob (first input token has no logprob) +type InputTokenLogProb struct { + state protoimpl.MessageState `protogen:"open.v1"` + Value *float32 `protobuf:"fixed32,1,opt,name=value,proto3,oneof" json:"value,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *InputTokenLogProb) Reset() { + *x = InputTokenLogProb{} + mi := &file_sglang_scheduler_proto_msgTypes[11] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *InputTokenLogProb) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*InputTokenLogProb) ProtoMessage() {} + +func (x *InputTokenLogProb) ProtoReflect() protoreflect.Message { + mi := &file_sglang_scheduler_proto_msgTypes[11] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use InputTokenLogProb.ProtoReflect.Descriptor instead. 
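+
+// Illustrative sketch (hand-written; not protoc output) for InputLogProbs
+// above: the first input token has no logprob, which the InputTokenLogProb
+// wrapper encodes as a nil Value. inLp is assumed to be a received
+// *InputLogProbs, and fmt is assumed to be imported.
+//
+//	for i, lp := range inLp.GetTokenLogprobs() {
+//		if lp.Value == nil {
+//			continue // typically position 0
+//		}
+//		fmt.Printf("token %d logprob %f\n", inLp.GetTokenIds()[i], *lp.Value)
+//	}
+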
+func (*InputTokenLogProb) Descriptor() ([]byte, []int) { + return file_sglang_scheduler_proto_rawDescGZIP(), []int{11} +} + +func (x *InputTokenLogProb) GetValue() float32 { + if x != nil && x.Value != nil { + return *x.Value + } + return 0 +} + +type TopLogProbs struct { + state protoimpl.MessageState `protogen:"open.v1"` + Values []float32 `protobuf:"fixed32,1,rep,packed,name=values,proto3" json:"values,omitempty"` + TokenIds []int32 `protobuf:"varint,2,rep,packed,name=token_ids,json=tokenIds,proto3" json:"token_ids,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *TopLogProbs) Reset() { + *x = TopLogProbs{} + mi := &file_sglang_scheduler_proto_msgTypes[12] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *TopLogProbs) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*TopLogProbs) ProtoMessage() {} + +func (x *TopLogProbs) ProtoReflect() protoreflect.Message { + mi := &file_sglang_scheduler_proto_msgTypes[12] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use TopLogProbs.ProtoReflect.Descriptor instead. +func (*TopLogProbs) Descriptor() ([]byte, []int) { + return file_sglang_scheduler_proto_rawDescGZIP(), []int{12} +} + +func (x *TopLogProbs) GetValues() []float32 { + if x != nil { + return x.Values + } + return nil +} + +func (x *TopLogProbs) GetTokenIds() []int32 { + if x != nil { + return x.TokenIds + } + return nil +} + +type HiddenStates struct { + state protoimpl.MessageState `protogen:"open.v1"` + Values []float32 `protobuf:"fixed32,1,rep,packed,name=values,proto3" json:"values,omitempty"` + Layer int32 `protobuf:"varint,2,opt,name=layer,proto3" json:"layer,omitempty"` + Position int32 `protobuf:"varint,3,opt,name=position,proto3" json:"position,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *HiddenStates) Reset() { + *x = HiddenStates{} + mi := &file_sglang_scheduler_proto_msgTypes[13] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *HiddenStates) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*HiddenStates) ProtoMessage() {} + +func (x *HiddenStates) ProtoReflect() protoreflect.Message { + mi := &file_sglang_scheduler_proto_msgTypes[13] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use HiddenStates.ProtoReflect.Descriptor instead. 
+func (*HiddenStates) Descriptor() ([]byte, []int) { + return file_sglang_scheduler_proto_rawDescGZIP(), []int{13} +} + +func (x *HiddenStates) GetValues() []float32 { + if x != nil { + return x.Values + } + return nil +} + +func (x *HiddenStates) GetLayer() int32 { + if x != nil { + return x.Layer + } + return 0 +} + +func (x *HiddenStates) GetPosition() int32 { + if x != nil { + return x.Position + } + return 0 +} + +type EmbedRequest struct { + state protoimpl.MessageState `protogen:"open.v1"` + RequestId string `protobuf:"bytes,1,opt,name=request_id,json=requestId,proto3" json:"request_id,omitempty"` + // Input must be tokenized (no raw text) + Tokenized *TokenizedInput `protobuf:"bytes,2,opt,name=tokenized,proto3" json:"tokenized,omitempty"` + // Multimodal inputs + MmInputs *MultimodalInputs `protobuf:"bytes,4,opt,name=mm_inputs,json=mmInputs,proto3" json:"mm_inputs,omitempty"` + // Dummy sampling params for compatibility + // EmbedRequest doesn't use sampling_params + SamplingParams *SamplingParams `protobuf:"bytes,5,opt,name=sampling_params,json=samplingParams,proto3" json:"sampling_params,omitempty"` + LogMetrics bool `protobuf:"varint,6,opt,name=log_metrics,json=logMetrics,proto3" json:"log_metrics,omitempty"` + // Token type IDs for models that require them + TokenTypeIds []int32 `protobuf:"varint,7,rep,packed,name=token_type_ids,json=tokenTypeIds,proto3" json:"token_type_ids,omitempty"` + // Data parallel routing + DataParallelRank int32 `protobuf:"varint,8,opt,name=data_parallel_rank,json=dataParallelRank,proto3" json:"data_parallel_rank,omitempty"` + // For cross-encoder requests + IsCrossEncoder bool `protobuf:"varint,9,opt,name=is_cross_encoder,json=isCrossEncoder,proto3" json:"is_cross_encoder,omitempty"` + Texts []string `protobuf:"bytes,10,rep,name=texts,proto3" json:"texts,omitempty"` // For cross-encoder batch + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *EmbedRequest) Reset() { + *x = EmbedRequest{} + mi := &file_sglang_scheduler_proto_msgTypes[14] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *EmbedRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*EmbedRequest) ProtoMessage() {} + +func (x *EmbedRequest) ProtoReflect() protoreflect.Message { + mi := &file_sglang_scheduler_proto_msgTypes[14] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use EmbedRequest.ProtoReflect.Descriptor instead. 
+func (*EmbedRequest) Descriptor() ([]byte, []int) { + return file_sglang_scheduler_proto_rawDescGZIP(), []int{14} +} + +func (x *EmbedRequest) GetRequestId() string { + if x != nil { + return x.RequestId + } + return "" +} + +func (x *EmbedRequest) GetTokenized() *TokenizedInput { + if x != nil { + return x.Tokenized + } + return nil +} + +func (x *EmbedRequest) GetMmInputs() *MultimodalInputs { + if x != nil { + return x.MmInputs + } + return nil +} + +func (x *EmbedRequest) GetSamplingParams() *SamplingParams { + if x != nil { + return x.SamplingParams + } + return nil +} + +func (x *EmbedRequest) GetLogMetrics() bool { + if x != nil { + return x.LogMetrics + } + return false +} + +func (x *EmbedRequest) GetTokenTypeIds() []int32 { + if x != nil { + return x.TokenTypeIds + } + return nil +} + +func (x *EmbedRequest) GetDataParallelRank() int32 { + if x != nil { + return x.DataParallelRank + } + return 0 +} + +func (x *EmbedRequest) GetIsCrossEncoder() bool { + if x != nil { + return x.IsCrossEncoder + } + return false +} + +func (x *EmbedRequest) GetTexts() []string { + if x != nil { + return x.Texts + } + return nil +} + +type EmbedResponse struct { + state protoimpl.MessageState `protogen:"open.v1"` + RequestId string `protobuf:"bytes,1,opt,name=request_id,json=requestId,proto3" json:"request_id,omitempty"` + // Types that are valid to be assigned to Response: + // + // *EmbedResponse_Complete + // *EmbedResponse_Error + Response isEmbedResponse_Response `protobuf_oneof:"response"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *EmbedResponse) Reset() { + *x = EmbedResponse{} + mi := &file_sglang_scheduler_proto_msgTypes[15] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *EmbedResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*EmbedResponse) ProtoMessage() {} + +func (x *EmbedResponse) ProtoReflect() protoreflect.Message { + mi := &file_sglang_scheduler_proto_msgTypes[15] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use EmbedResponse.ProtoReflect.Descriptor instead. 
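+
+// Illustrative usage sketch (hand-written; not protoc output) for
+// EmbedRequest above. Like GenerateRequest, input must be pre-tokenized;
+// cross-encoder scoring instead sets IsCrossEncoder and raw Texts. Values
+// are hypothetical.
+//
+//	ereq := &EmbedRequest{
+//		RequestId: "embed-1",
+//		Tokenized: &TokenizedInput{InputIds: []uint32{101, 2023, 102}},
+//	}
+//	// Cross-encoder variant:
+//	// ereq.IsCrossEncoder = true
+//	// ereq.Texts = []string{"query", "candidate document"}
+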
+func (*EmbedResponse) Descriptor() ([]byte, []int) { + return file_sglang_scheduler_proto_rawDescGZIP(), []int{15} +} + +func (x *EmbedResponse) GetRequestId() string { + if x != nil { + return x.RequestId + } + return "" +} + +func (x *EmbedResponse) GetResponse() isEmbedResponse_Response { + if x != nil { + return x.Response + } + return nil +} + +func (x *EmbedResponse) GetComplete() *EmbedComplete { + if x != nil { + if x, ok := x.Response.(*EmbedResponse_Complete); ok { + return x.Complete + } + } + return nil +} + +func (x *EmbedResponse) GetError() *EmbedError { + if x != nil { + if x, ok := x.Response.(*EmbedResponse_Error); ok { + return x.Error + } + } + return nil +} + +type isEmbedResponse_Response interface { + isEmbedResponse_Response() +} + +type EmbedResponse_Complete struct { + Complete *EmbedComplete `protobuf:"bytes,2,opt,name=complete,proto3,oneof"` +} + +type EmbedResponse_Error struct { + Error *EmbedError `protobuf:"bytes,3,opt,name=error,proto3,oneof"` +} + +func (*EmbedResponse_Complete) isEmbedResponse_Response() {} + +func (*EmbedResponse_Error) isEmbedResponse_Response() {} + +type EmbedComplete struct { + state protoimpl.MessageState `protogen:"open.v1"` + Embedding []float32 `protobuf:"fixed32,1,rep,packed,name=embedding,proto3" json:"embedding,omitempty"` + PromptTokens int32 `protobuf:"varint,2,opt,name=prompt_tokens,json=promptTokens,proto3" json:"prompt_tokens,omitempty"` + CachedTokens int32 `protobuf:"varint,3,opt,name=cached_tokens,json=cachedTokens,proto3" json:"cached_tokens,omitempty"` + // Additional metadata + EmbeddingDim int32 `protobuf:"varint,4,opt,name=embedding_dim,json=embeddingDim,proto3" json:"embedding_dim,omitempty"` + // For batch embeddings + BatchEmbeddings []*Embedding `protobuf:"bytes,5,rep,name=batch_embeddings,json=batchEmbeddings,proto3" json:"batch_embeddings,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *EmbedComplete) Reset() { + *x = EmbedComplete{} + mi := &file_sglang_scheduler_proto_msgTypes[16] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *EmbedComplete) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*EmbedComplete) ProtoMessage() {} + +func (x *EmbedComplete) ProtoReflect() protoreflect.Message { + mi := &file_sglang_scheduler_proto_msgTypes[16] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use EmbedComplete.ProtoReflect.Descriptor instead. 
+func (*EmbedComplete) Descriptor() ([]byte, []int) { + return file_sglang_scheduler_proto_rawDescGZIP(), []int{16} +} + +func (x *EmbedComplete) GetEmbedding() []float32 { + if x != nil { + return x.Embedding + } + return nil +} + +func (x *EmbedComplete) GetPromptTokens() int32 { + if x != nil { + return x.PromptTokens + } + return 0 +} + +func (x *EmbedComplete) GetCachedTokens() int32 { + if x != nil { + return x.CachedTokens + } + return 0 +} + +func (x *EmbedComplete) GetEmbeddingDim() int32 { + if x != nil { + return x.EmbeddingDim + } + return 0 +} + +func (x *EmbedComplete) GetBatchEmbeddings() []*Embedding { + if x != nil { + return x.BatchEmbeddings + } + return nil +} + +type Embedding struct { + state protoimpl.MessageState `protogen:"open.v1"` + Values []float32 `protobuf:"fixed32,1,rep,packed,name=values,proto3" json:"values,omitempty"` + Index int32 `protobuf:"varint,2,opt,name=index,proto3" json:"index,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *Embedding) Reset() { + *x = Embedding{} + mi := &file_sglang_scheduler_proto_msgTypes[17] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *Embedding) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*Embedding) ProtoMessage() {} + +func (x *Embedding) ProtoReflect() protoreflect.Message { + mi := &file_sglang_scheduler_proto_msgTypes[17] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use Embedding.ProtoReflect.Descriptor instead. +func (*Embedding) Descriptor() ([]byte, []int) { + return file_sglang_scheduler_proto_rawDescGZIP(), []int{17} +} + +func (x *Embedding) GetValues() []float32 { + if x != nil { + return x.Values + } + return nil +} + +func (x *Embedding) GetIndex() int32 { + if x != nil { + return x.Index + } + return 0 +} + +type EmbedError struct { + state protoimpl.MessageState `protogen:"open.v1"` + Message string `protobuf:"bytes,1,opt,name=message,proto3" json:"message,omitempty"` + Code string `protobuf:"bytes,2,opt,name=code,proto3" json:"code,omitempty"` + Details string `protobuf:"bytes,3,opt,name=details,proto3" json:"details,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *EmbedError) Reset() { + *x = EmbedError{} + mi := &file_sglang_scheduler_proto_msgTypes[18] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *EmbedError) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*EmbedError) ProtoMessage() {} + +func (x *EmbedError) ProtoReflect() protoreflect.Message { + mi := &file_sglang_scheduler_proto_msgTypes[18] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use EmbedError.ProtoReflect.Descriptor instead. 
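+
+// Illustrative handling sketch (hand-written; not protoc output) for the
+// EmbedResponse oneof above: Complete carries either a single Embedding
+// vector or per-input BatchEmbeddings. eresp is assumed to be a received
+// *EmbedResponse.
+//
+//	switch r := eresp.Response.(type) {
+//	case *EmbedResponse_Complete:
+//		vec := r.Complete.GetEmbedding()          // single-input vector
+//		batch := r.Complete.GetBatchEmbeddings()  // batch results, if any
+//		_, _ = vec, batch
+//	case *EmbedResponse_Error:
+//		_ = r.Error.GetMessage()
+//	}
+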
+func (*EmbedError) Descriptor() ([]byte, []int) { + return file_sglang_scheduler_proto_rawDescGZIP(), []int{18} +} + +func (x *EmbedError) GetMessage() string { + if x != nil { + return x.Message + } + return "" +} + +func (x *EmbedError) GetCode() string { + if x != nil { + return x.Code + } + return "" +} + +func (x *EmbedError) GetDetails() string { + if x != nil { + return x.Details + } + return "" +} + +type HealthCheckRequest struct { + state protoimpl.MessageState `protogen:"open.v1"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *HealthCheckRequest) Reset() { + *x = HealthCheckRequest{} + mi := &file_sglang_scheduler_proto_msgTypes[19] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *HealthCheckRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*HealthCheckRequest) ProtoMessage() {} + +func (x *HealthCheckRequest) ProtoReflect() protoreflect.Message { + mi := &file_sglang_scheduler_proto_msgTypes[19] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use HealthCheckRequest.ProtoReflect.Descriptor instead. +func (*HealthCheckRequest) Descriptor() ([]byte, []int) { + return file_sglang_scheduler_proto_rawDescGZIP(), []int{19} +} + +type HealthCheckResponse struct { + state protoimpl.MessageState `protogen:"open.v1"` + Healthy bool `protobuf:"varint,1,opt,name=healthy,proto3" json:"healthy,omitempty"` + Message string `protobuf:"bytes,2,opt,name=message,proto3" json:"message,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *HealthCheckResponse) Reset() { + *x = HealthCheckResponse{} + mi := &file_sglang_scheduler_proto_msgTypes[20] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *HealthCheckResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*HealthCheckResponse) ProtoMessage() {} + +func (x *HealthCheckResponse) ProtoReflect() protoreflect.Message { + mi := &file_sglang_scheduler_proto_msgTypes[20] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use HealthCheckResponse.ProtoReflect.Descriptor instead. 
+func (*HealthCheckResponse) Descriptor() ([]byte, []int) { + return file_sglang_scheduler_proto_rawDescGZIP(), []int{20} +} + +func (x *HealthCheckResponse) GetHealthy() bool { + if x != nil { + return x.Healthy + } + return false +} + +func (x *HealthCheckResponse) GetMessage() string { + if x != nil { + return x.Message + } + return "" +} + +type AbortRequest struct { + state protoimpl.MessageState `protogen:"open.v1"` + RequestId string `protobuf:"bytes,1,opt,name=request_id,json=requestId,proto3" json:"request_id,omitempty"` + Reason string `protobuf:"bytes,2,opt,name=reason,proto3" json:"reason,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *AbortRequest) Reset() { + *x = AbortRequest{} + mi := &file_sglang_scheduler_proto_msgTypes[21] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *AbortRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*AbortRequest) ProtoMessage() {} + +func (x *AbortRequest) ProtoReflect() protoreflect.Message { + mi := &file_sglang_scheduler_proto_msgTypes[21] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use AbortRequest.ProtoReflect.Descriptor instead. +func (*AbortRequest) Descriptor() ([]byte, []int) { + return file_sglang_scheduler_proto_rawDescGZIP(), []int{21} +} + +func (x *AbortRequest) GetRequestId() string { + if x != nil { + return x.RequestId + } + return "" +} + +func (x *AbortRequest) GetReason() string { + if x != nil { + return x.Reason + } + return "" +} + +type AbortResponse struct { + state protoimpl.MessageState `protogen:"open.v1"` + Success bool `protobuf:"varint,1,opt,name=success,proto3" json:"success,omitempty"` + Message string `protobuf:"bytes,2,opt,name=message,proto3" json:"message,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *AbortResponse) Reset() { + *x = AbortResponse{} + mi := &file_sglang_scheduler_proto_msgTypes[22] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *AbortResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*AbortResponse) ProtoMessage() {} + +func (x *AbortResponse) ProtoReflect() protoreflect.Message { + mi := &file_sglang_scheduler_proto_msgTypes[22] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use AbortResponse.ProtoReflect.Descriptor instead. 
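+
+// Illustrative usage sketch (hand-written; not protoc output) for
+// AbortRequest above: cancellation is keyed by the RequestId of the
+// in-flight GenerateRequest, and the outcome is reported via
+// AbortResponse.Success. Values are hypothetical.
+//
+//	ab := &AbortRequest{
+//		RequestId: "req-123",
+//		Reason:    "client disconnected",
+//	}
+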
+func (*AbortResponse) Descriptor() ([]byte, []int) { + return file_sglang_scheduler_proto_rawDescGZIP(), []int{22} +} + +func (x *AbortResponse) GetSuccess() bool { + if x != nil { + return x.Success + } + return false +} + +func (x *AbortResponse) GetMessage() string { + if x != nil { + return x.Message + } + return "" +} + +// Load LoRA adapter +type LoadLoRARequest struct { + state protoimpl.MessageState `protogen:"open.v1"` + AdapterId string `protobuf:"bytes,1,opt,name=adapter_id,json=adapterId,proto3" json:"adapter_id,omitempty"` + AdapterPath string `protobuf:"bytes,2,opt,name=adapter_path,json=adapterPath,proto3" json:"adapter_path,omitempty"` + Rank int32 `protobuf:"varint,3,opt,name=rank,proto3" json:"rank,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *LoadLoRARequest) Reset() { + *x = LoadLoRARequest{} + mi := &file_sglang_scheduler_proto_msgTypes[23] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *LoadLoRARequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*LoadLoRARequest) ProtoMessage() {} + +func (x *LoadLoRARequest) ProtoReflect() protoreflect.Message { + mi := &file_sglang_scheduler_proto_msgTypes[23] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use LoadLoRARequest.ProtoReflect.Descriptor instead. +func (*LoadLoRARequest) Descriptor() ([]byte, []int) { + return file_sglang_scheduler_proto_rawDescGZIP(), []int{23} +} + +func (x *LoadLoRARequest) GetAdapterId() string { + if x != nil { + return x.AdapterId + } + return "" +} + +func (x *LoadLoRARequest) GetAdapterPath() string { + if x != nil { + return x.AdapterPath + } + return "" +} + +func (x *LoadLoRARequest) GetRank() int32 { + if x != nil { + return x.Rank + } + return 0 +} + +type LoadLoRAResponse struct { + state protoimpl.MessageState `protogen:"open.v1"` + Success bool `protobuf:"varint,1,opt,name=success,proto3" json:"success,omitempty"` + AdapterId string `protobuf:"bytes,2,opt,name=adapter_id,json=adapterId,proto3" json:"adapter_id,omitempty"` + Message string `protobuf:"bytes,3,opt,name=message,proto3" json:"message,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *LoadLoRAResponse) Reset() { + *x = LoadLoRAResponse{} + mi := &file_sglang_scheduler_proto_msgTypes[24] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *LoadLoRAResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*LoadLoRAResponse) ProtoMessage() {} + +func (x *LoadLoRAResponse) ProtoReflect() protoreflect.Message { + mi := &file_sglang_scheduler_proto_msgTypes[24] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use LoadLoRAResponse.ProtoReflect.Descriptor instead. 
+func (*LoadLoRAResponse) Descriptor() ([]byte, []int) { + return file_sglang_scheduler_proto_rawDescGZIP(), []int{24} +} + +func (x *LoadLoRAResponse) GetSuccess() bool { + if x != nil { + return x.Success + } + return false +} + +func (x *LoadLoRAResponse) GetAdapterId() string { + if x != nil { + return x.AdapterId + } + return "" +} + +func (x *LoadLoRAResponse) GetMessage() string { + if x != nil { + return x.Message + } + return "" +} + +// Unload LoRA adapter +type UnloadLoRARequest struct { + state protoimpl.MessageState `protogen:"open.v1"` + AdapterId string `protobuf:"bytes,1,opt,name=adapter_id,json=adapterId,proto3" json:"adapter_id,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *UnloadLoRARequest) Reset() { + *x = UnloadLoRARequest{} + mi := &file_sglang_scheduler_proto_msgTypes[25] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *UnloadLoRARequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*UnloadLoRARequest) ProtoMessage() {} + +func (x *UnloadLoRARequest) ProtoReflect() protoreflect.Message { + mi := &file_sglang_scheduler_proto_msgTypes[25] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use UnloadLoRARequest.ProtoReflect.Descriptor instead. +func (*UnloadLoRARequest) Descriptor() ([]byte, []int) { + return file_sglang_scheduler_proto_rawDescGZIP(), []int{25} +} + +func (x *UnloadLoRARequest) GetAdapterId() string { + if x != nil { + return x.AdapterId + } + return "" +} + +type UnloadLoRAResponse struct { + state protoimpl.MessageState `protogen:"open.v1"` + Success bool `protobuf:"varint,1,opt,name=success,proto3" json:"success,omitempty"` + Message string `protobuf:"bytes,2,opt,name=message,proto3" json:"message,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *UnloadLoRAResponse) Reset() { + *x = UnloadLoRAResponse{} + mi := &file_sglang_scheduler_proto_msgTypes[26] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *UnloadLoRAResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*UnloadLoRAResponse) ProtoMessage() {} + +func (x *UnloadLoRAResponse) ProtoReflect() protoreflect.Message { + mi := &file_sglang_scheduler_proto_msgTypes[26] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use UnloadLoRAResponse.ProtoReflect.Descriptor instead. 
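+
+// Illustrative usage sketch (hand-written; not protoc output) for the LoRA
+// messages above: load an adapter, reference it from GenerateRequest.LoraId,
+// and unload it when done. Paths and IDs are hypothetical.
+//
+//	load := &LoadLoRARequest{
+//		AdapterId:   "my-adapter",
+//		AdapterPath: "/models/lora/my-adapter",
+//		Rank:        16,
+//	}
+//	unload := &UnloadLoRARequest{AdapterId: "my-adapter"}
+//	_, _ = load, unload
+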
+func (*UnloadLoRAResponse) Descriptor() ([]byte, []int) { + return file_sglang_scheduler_proto_rawDescGZIP(), []int{26} +} + +func (x *UnloadLoRAResponse) GetSuccess() bool { + if x != nil { + return x.Success + } + return false +} + +func (x *UnloadLoRAResponse) GetMessage() string { + if x != nil { + return x.Message + } + return "" +} + +// Update weights +type UpdateWeightsRequest struct { + state protoimpl.MessageState `protogen:"open.v1"` + // Types that are valid to be assigned to Source: + // + // *UpdateWeightsRequest_DiskPath + // *UpdateWeightsRequest_TensorData + // *UpdateWeightsRequest_RemoteUrl + Source isUpdateWeightsRequest_Source `protobuf_oneof:"source"` + WeightName string `protobuf:"bytes,4,opt,name=weight_name,json=weightName,proto3" json:"weight_name,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *UpdateWeightsRequest) Reset() { + *x = UpdateWeightsRequest{} + mi := &file_sglang_scheduler_proto_msgTypes[27] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *UpdateWeightsRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*UpdateWeightsRequest) ProtoMessage() {} + +func (x *UpdateWeightsRequest) ProtoReflect() protoreflect.Message { + mi := &file_sglang_scheduler_proto_msgTypes[27] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use UpdateWeightsRequest.ProtoReflect.Descriptor instead. +func (*UpdateWeightsRequest) Descriptor() ([]byte, []int) { + return file_sglang_scheduler_proto_rawDescGZIP(), []int{27} +} + +func (x *UpdateWeightsRequest) GetSource() isUpdateWeightsRequest_Source { + if x != nil { + return x.Source + } + return nil +} + +func (x *UpdateWeightsRequest) GetDiskPath() string { + if x != nil { + if x, ok := x.Source.(*UpdateWeightsRequest_DiskPath); ok { + return x.DiskPath + } + } + return "" +} + +func (x *UpdateWeightsRequest) GetTensorData() []byte { + if x != nil { + if x, ok := x.Source.(*UpdateWeightsRequest_TensorData); ok { + return x.TensorData + } + } + return nil +} + +func (x *UpdateWeightsRequest) GetRemoteUrl() string { + if x != nil { + if x, ok := x.Source.(*UpdateWeightsRequest_RemoteUrl); ok { + return x.RemoteUrl + } + } + return "" +} + +func (x *UpdateWeightsRequest) GetWeightName() string { + if x != nil { + return x.WeightName + } + return "" +} + +type isUpdateWeightsRequest_Source interface { + isUpdateWeightsRequest_Source() +} + +type UpdateWeightsRequest_DiskPath struct { + DiskPath string `protobuf:"bytes,1,opt,name=disk_path,json=diskPath,proto3,oneof"` +} + +type UpdateWeightsRequest_TensorData struct { + TensorData []byte `protobuf:"bytes,2,opt,name=tensor_data,json=tensorData,proto3,oneof"` +} + +type UpdateWeightsRequest_RemoteUrl struct { + RemoteUrl string `protobuf:"bytes,3,opt,name=remote_url,json=remoteUrl,proto3,oneof"` +} + +func (*UpdateWeightsRequest_DiskPath) isUpdateWeightsRequest_Source() {} + +func (*UpdateWeightsRequest_TensorData) isUpdateWeightsRequest_Source() {} + +func (*UpdateWeightsRequest_RemoteUrl) isUpdateWeightsRequest_Source() {} + +type UpdateWeightsResponse struct { + state protoimpl.MessageState `protogen:"open.v1"` + Success bool `protobuf:"varint,1,opt,name=success,proto3" json:"success,omitempty"` + Message string `protobuf:"bytes,2,opt,name=message,proto3" json:"message,omitempty"` + unknownFields 
protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *UpdateWeightsResponse) Reset() { + *x = UpdateWeightsResponse{} + mi := &file_sglang_scheduler_proto_msgTypes[28] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *UpdateWeightsResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*UpdateWeightsResponse) ProtoMessage() {} + +func (x *UpdateWeightsResponse) ProtoReflect() protoreflect.Message { + mi := &file_sglang_scheduler_proto_msgTypes[28] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use UpdateWeightsResponse.ProtoReflect.Descriptor instead. +func (*UpdateWeightsResponse) Descriptor() ([]byte, []int) { + return file_sglang_scheduler_proto_rawDescGZIP(), []int{28} +} + +func (x *UpdateWeightsResponse) GetSuccess() bool { + if x != nil { + return x.Success + } + return false +} + +func (x *UpdateWeightsResponse) GetMessage() string { + if x != nil { + return x.Message + } + return "" +} + +// Get internal state for debugging +type GetInternalStateRequest struct { + state protoimpl.MessageState `protogen:"open.v1"` + StateKeys []string `protobuf:"bytes,1,rep,name=state_keys,json=stateKeys,proto3" json:"state_keys,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *GetInternalStateRequest) Reset() { + *x = GetInternalStateRequest{} + mi := &file_sglang_scheduler_proto_msgTypes[29] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *GetInternalStateRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*GetInternalStateRequest) ProtoMessage() {} + +func (x *GetInternalStateRequest) ProtoReflect() protoreflect.Message { + mi := &file_sglang_scheduler_proto_msgTypes[29] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use GetInternalStateRequest.ProtoReflect.Descriptor instead. +func (*GetInternalStateRequest) Descriptor() ([]byte, []int) { + return file_sglang_scheduler_proto_rawDescGZIP(), []int{29} +} + +func (x *GetInternalStateRequest) GetStateKeys() []string { + if x != nil { + return x.StateKeys + } + return nil +} + +type GetInternalStateResponse struct { + state protoimpl.MessageState `protogen:"open.v1"` + State *structpb.Struct `protobuf:"bytes,1,opt,name=state,proto3" json:"state,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *GetInternalStateResponse) Reset() { + *x = GetInternalStateResponse{} + mi := &file_sglang_scheduler_proto_msgTypes[30] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *GetInternalStateResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*GetInternalStateResponse) ProtoMessage() {} + +func (x *GetInternalStateResponse) ProtoReflect() protoreflect.Message { + mi := &file_sglang_scheduler_proto_msgTypes[30] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use GetInternalStateResponse.ProtoReflect.Descriptor instead. 
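+
+// Illustrative usage sketch (hand-written; not protoc output) for the
+// UpdateWeightsRequest oneof above: exactly one Source variant is set
+// (DiskPath, TensorData, or RemoteUrl). The path and weight name are
+// hypothetical.
+//
+//	upd := &UpdateWeightsRequest{
+//		Source:     &UpdateWeightsRequest_DiskPath{DiskPath: "/models/new-weights"},
+//		WeightName: "default",
+//	}
+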
+func (*GetInternalStateResponse) Descriptor() ([]byte, []int) { + return file_sglang_scheduler_proto_rawDescGZIP(), []int{30} +} + +func (x *GetInternalStateResponse) GetState() *structpb.Struct { + if x != nil { + return x.State + } + return nil +} + +// Set internal state for testing +type SetInternalStateRequest struct { + state protoimpl.MessageState `protogen:"open.v1"` + State *structpb.Struct `protobuf:"bytes,1,opt,name=state,proto3" json:"state,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *SetInternalStateRequest) Reset() { + *x = SetInternalStateRequest{} + mi := &file_sglang_scheduler_proto_msgTypes[31] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *SetInternalStateRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*SetInternalStateRequest) ProtoMessage() {} + +func (x *SetInternalStateRequest) ProtoReflect() protoreflect.Message { + mi := &file_sglang_scheduler_proto_msgTypes[31] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use SetInternalStateRequest.ProtoReflect.Descriptor instead. +func (*SetInternalStateRequest) Descriptor() ([]byte, []int) { + return file_sglang_scheduler_proto_rawDescGZIP(), []int{31} +} + +func (x *SetInternalStateRequest) GetState() *structpb.Struct { + if x != nil { + return x.State + } + return nil +} + +type SetInternalStateResponse struct { + state protoimpl.MessageState `protogen:"open.v1"` + Success bool `protobuf:"varint,1,opt,name=success,proto3" json:"success,omitempty"` + Message string `protobuf:"bytes,2,opt,name=message,proto3" json:"message,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *SetInternalStateResponse) Reset() { + *x = SetInternalStateResponse{} + mi := &file_sglang_scheduler_proto_msgTypes[32] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *SetInternalStateResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*SetInternalStateResponse) ProtoMessage() {} + +func (x *SetInternalStateResponse) ProtoReflect() protoreflect.Message { + mi := &file_sglang_scheduler_proto_msgTypes[32] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use SetInternalStateResponse.ProtoReflect.Descriptor instead. 
+func (*SetInternalStateResponse) Descriptor() ([]byte, []int) { + return file_sglang_scheduler_proto_rawDescGZIP(), []int{32} +} + +func (x *SetInternalStateResponse) GetSuccess() bool { + if x != nil { + return x.Success + } + return false +} + +func (x *SetInternalStateResponse) GetMessage() string { + if x != nil { + return x.Message + } + return "" +} + +// Get model information +type GetModelInfoRequest struct { + state protoimpl.MessageState `protogen:"open.v1"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *GetModelInfoRequest) Reset() { + *x = GetModelInfoRequest{} + mi := &file_sglang_scheduler_proto_msgTypes[33] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *GetModelInfoRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*GetModelInfoRequest) ProtoMessage() {} + +func (x *GetModelInfoRequest) ProtoReflect() protoreflect.Message { + mi := &file_sglang_scheduler_proto_msgTypes[33] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use GetModelInfoRequest.ProtoReflect.Descriptor instead. +func (*GetModelInfoRequest) Descriptor() ([]byte, []int) { + return file_sglang_scheduler_proto_rawDescGZIP(), []int{33} +} + +type GetModelInfoResponse struct { + state protoimpl.MessageState `protogen:"open.v1"` + ModelPath string `protobuf:"bytes,1,opt,name=model_path,json=modelPath,proto3" json:"model_path,omitempty"` + TokenizerPath string `protobuf:"bytes,2,opt,name=tokenizer_path,json=tokenizerPath,proto3" json:"tokenizer_path,omitempty"` + IsGeneration bool `protobuf:"varint,3,opt,name=is_generation,json=isGeneration,proto3" json:"is_generation,omitempty"` + PreferredSamplingParams string `protobuf:"bytes,4,opt,name=preferred_sampling_params,json=preferredSamplingParams,proto3" json:"preferred_sampling_params,omitempty"` // JSON string or empty + WeightVersion string `protobuf:"bytes,5,opt,name=weight_version,json=weightVersion,proto3" json:"weight_version,omitempty"` + ServedModelName string `protobuf:"bytes,6,opt,name=served_model_name,json=servedModelName,proto3" json:"served_model_name,omitempty"` + MaxContextLength int32 `protobuf:"varint,7,opt,name=max_context_length,json=maxContextLength,proto3" json:"max_context_length,omitempty"` + VocabSize int32 `protobuf:"varint,8,opt,name=vocab_size,json=vocabSize,proto3" json:"vocab_size,omitempty"` + SupportsVision bool `protobuf:"varint,9,opt,name=supports_vision,json=supportsVision,proto3" json:"supports_vision,omitempty"` + ModelType string `protobuf:"bytes,10,opt,name=model_type,json=modelType,proto3" json:"model_type,omitempty"` + EosTokenIds []int32 `protobuf:"varint,11,rep,packed,name=eos_token_ids,json=eosTokenIds,proto3" json:"eos_token_ids,omitempty"` + PadTokenId int32 `protobuf:"varint,12,opt,name=pad_token_id,json=padTokenId,proto3" json:"pad_token_id,omitempty"` + BosTokenId int32 `protobuf:"varint,13,opt,name=bos_token_id,json=bosTokenId,proto3" json:"bos_token_id,omitempty"` + MaxReqInputLen int32 `protobuf:"varint,14,opt,name=max_req_input_len,json=maxReqInputLen,proto3" json:"max_req_input_len,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *GetModelInfoResponse) Reset() { + *x = GetModelInfoResponse{} + mi := &file_sglang_scheduler_proto_msgTypes[34] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + 
ms.StoreMessageInfo(mi) +} + +func (x *GetModelInfoResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*GetModelInfoResponse) ProtoMessage() {} + +func (x *GetModelInfoResponse) ProtoReflect() protoreflect.Message { + mi := &file_sglang_scheduler_proto_msgTypes[34] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use GetModelInfoResponse.ProtoReflect.Descriptor instead. +func (*GetModelInfoResponse) Descriptor() ([]byte, []int) { + return file_sglang_scheduler_proto_rawDescGZIP(), []int{34} +} + +func (x *GetModelInfoResponse) GetModelPath() string { + if x != nil { + return x.ModelPath + } + return "" +} + +func (x *GetModelInfoResponse) GetTokenizerPath() string { + if x != nil { + return x.TokenizerPath + } + return "" +} + +func (x *GetModelInfoResponse) GetIsGeneration() bool { + if x != nil { + return x.IsGeneration + } + return false +} + +func (x *GetModelInfoResponse) GetPreferredSamplingParams() string { + if x != nil { + return x.PreferredSamplingParams + } + return "" +} + +func (x *GetModelInfoResponse) GetWeightVersion() string { + if x != nil { + return x.WeightVersion + } + return "" +} + +func (x *GetModelInfoResponse) GetServedModelName() string { + if x != nil { + return x.ServedModelName + } + return "" +} + +func (x *GetModelInfoResponse) GetMaxContextLength() int32 { + if x != nil { + return x.MaxContextLength + } + return 0 +} + +func (x *GetModelInfoResponse) GetVocabSize() int32 { + if x != nil { + return x.VocabSize + } + return 0 +} + +func (x *GetModelInfoResponse) GetSupportsVision() bool { + if x != nil { + return x.SupportsVision + } + return false +} + +func (x *GetModelInfoResponse) GetModelType() string { + if x != nil { + return x.ModelType + } + return "" +} + +func (x *GetModelInfoResponse) GetEosTokenIds() []int32 { + if x != nil { + return x.EosTokenIds + } + return nil +} + +func (x *GetModelInfoResponse) GetPadTokenId() int32 { + if x != nil { + return x.PadTokenId + } + return 0 +} + +func (x *GetModelInfoResponse) GetBosTokenId() int32 { + if x != nil { + return x.BosTokenId + } + return 0 +} + +func (x *GetModelInfoResponse) GetMaxReqInputLen() int32 { + if x != nil { + return x.MaxReqInputLen + } + return 0 +} + +// Get server information +type GetServerInfoRequest struct { + state protoimpl.MessageState `protogen:"open.v1"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *GetServerInfoRequest) Reset() { + *x = GetServerInfoRequest{} + mi := &file_sglang_scheduler_proto_msgTypes[35] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *GetServerInfoRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*GetServerInfoRequest) ProtoMessage() {} + +func (x *GetServerInfoRequest) ProtoReflect() protoreflect.Message { + mi := &file_sglang_scheduler_proto_msgTypes[35] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use GetServerInfoRequest.ProtoReflect.Descriptor instead. 
+func (*GetServerInfoRequest) Descriptor() ([]byte, []int) { + return file_sglang_scheduler_proto_rawDescGZIP(), []int{35} +} + +type GetServerInfoResponse struct { + state protoimpl.MessageState `protogen:"open.v1"` + // Server configuration (as structured data) + ServerArgs *structpb.Struct `protobuf:"bytes,1,opt,name=server_args,json=serverArgs,proto3" json:"server_args,omitempty"` + // Scheduler metrics (from scheduler initialization) + SchedulerInfo *structpb.Struct `protobuf:"bytes,2,opt,name=scheduler_info,json=schedulerInfo,proto3" json:"scheduler_info,omitempty"` + // Runtime state + ActiveRequests int32 `protobuf:"varint,3,opt,name=active_requests,json=activeRequests,proto3" json:"active_requests,omitempty"` + IsPaused bool `protobuf:"varint,4,opt,name=is_paused,json=isPaused,proto3" json:"is_paused,omitempty"` + LastReceiveTimestamp float64 `protobuf:"fixed64,5,opt,name=last_receive_timestamp,json=lastReceiveTimestamp,proto3" json:"last_receive_timestamp,omitempty"` + UptimeSeconds float64 `protobuf:"fixed64,6,opt,name=uptime_seconds,json=uptimeSeconds,proto3" json:"uptime_seconds,omitempty"` + // Version info + SglangVersion string `protobuf:"bytes,7,opt,name=sglang_version,json=sglangVersion,proto3" json:"sglang_version,omitempty"` + // Server metadata + ServerType string `protobuf:"bytes,8,opt,name=server_type,json=serverType,proto3" json:"server_type,omitempty"` // "grpc" + StartTime *timestamppb.Timestamp `protobuf:"bytes,9,opt,name=start_time,json=startTime,proto3" json:"start_time,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *GetServerInfoResponse) Reset() { + *x = GetServerInfoResponse{} + mi := &file_sglang_scheduler_proto_msgTypes[36] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *GetServerInfoResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*GetServerInfoResponse) ProtoMessage() {} + +func (x *GetServerInfoResponse) ProtoReflect() protoreflect.Message { + mi := &file_sglang_scheduler_proto_msgTypes[36] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use GetServerInfoResponse.ProtoReflect.Descriptor instead. 
+func (*GetServerInfoResponse) Descriptor() ([]byte, []int) { + return file_sglang_scheduler_proto_rawDescGZIP(), []int{36} +} + +func (x *GetServerInfoResponse) GetServerArgs() *structpb.Struct { + if x != nil { + return x.ServerArgs + } + return nil +} + +func (x *GetServerInfoResponse) GetSchedulerInfo() *structpb.Struct { + if x != nil { + return x.SchedulerInfo + } + return nil +} + +func (x *GetServerInfoResponse) GetActiveRequests() int32 { + if x != nil { + return x.ActiveRequests + } + return 0 +} + +func (x *GetServerInfoResponse) GetIsPaused() bool { + if x != nil { + return x.IsPaused + } + return false +} + +func (x *GetServerInfoResponse) GetLastReceiveTimestamp() float64 { + if x != nil { + return x.LastReceiveTimestamp + } + return 0 +} + +func (x *GetServerInfoResponse) GetUptimeSeconds() float64 { + if x != nil { + return x.UptimeSeconds + } + return 0 +} + +func (x *GetServerInfoResponse) GetSglangVersion() string { + if x != nil { + return x.SglangVersion + } + return "" +} + +func (x *GetServerInfoResponse) GetServerType() string { + if x != nil { + return x.ServerType + } + return "" +} + +func (x *GetServerInfoResponse) GetStartTime() *timestamppb.Timestamp { + if x != nil { + return x.StartTime + } + return nil +} + +var File_sglang_scheduler_proto protoreflect.FileDescriptor + +const file_sglang_scheduler_proto_rawDesc = "" + + "\n" + + "\x16sglang_scheduler.proto\x12\x15sglang.grpc.scheduler\x1a\x1fgoogle/protobuf/timestamp.proto\x1a\x1cgoogle/protobuf/struct.proto\"\x82\b\n" + + "\x0eSamplingParams\x12 \n" + + "\vtemperature\x18\x01 \x01(\x02R\vtemperature\x12\x13\n" + + "\x05top_p\x18\x02 \x01(\x02R\x04topP\x12\x13\n" + + "\x05top_k\x18\x03 \x01(\x05R\x04topK\x12\x13\n" + + "\x05min_p\x18\x04 \x01(\x02R\x04minP\x12+\n" + + "\x11frequency_penalty\x18\x05 \x01(\x02R\x10frequencyPenalty\x12)\n" + + "\x10presence_penalty\x18\x06 \x01(\x02R\x0fpresencePenalty\x12-\n" + + "\x12repetition_penalty\x18\a \x01(\x02R\x11repetitionPenalty\x12)\n" + + "\x0emax_new_tokens\x18\b \x01(\x05H\x01R\fmaxNewTokens\x88\x01\x01\x12\x12\n" + + "\x04stop\x18\t \x03(\tR\x04stop\x12$\n" + + "\x0estop_token_ids\x18\n" + + " \x03(\rR\fstopTokenIds\x12.\n" + + "\x13skip_special_tokens\x18\v \x01(\bR\x11skipSpecialTokens\x12A\n" + + "\x1dspaces_between_special_tokens\x18\f \x01(\bR\x1aspacesBetweenSpecialTokens\x12\x16\n" + + "\x05regex\x18\r \x01(\tH\x00R\x05regex\x12!\n" + + "\vjson_schema\x18\x0e \x01(\tH\x00R\n" + + "jsonSchema\x12#\n" + + "\febnf_grammar\x18\x0f \x01(\tH\x00R\vebnfGrammar\x12'\n" + + "\x0estructural_tag\x18\x10 \x01(\tH\x00R\rstructuralTag\x12\f\n" + + "\x01n\x18\x11 \x01(\x05R\x01n\x12$\n" + + "\x0emin_new_tokens\x18\x12 \x01(\x05R\fminNewTokens\x12\x1d\n" + + "\n" + + "ignore_eos\x18\x13 \x01(\bR\tignoreEos\x12 \n" + + "\fno_stop_trim\x18\x14 \x01(\bR\n" + + "noStopTrim\x12,\n" + + "\x0fstream_interval\x18\x15 \x01(\x05H\x02R\x0estreamInterval\x88\x01\x01\x12S\n" + + "\n" + + "logit_bias\x18\x16 \x03(\v24.sglang.grpc.scheduler.SamplingParams.LogitBiasEntryR\tlogitBias\x12<\n" + + "\rcustom_params\x18\x17 \x01(\v2\x17.google.protobuf.StructR\fcustomParams\x1a<\n" + + "\x0eLogitBiasEntry\x12\x10\n" + + "\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n" + + "\x05value\x18\x02 \x01(\x02R\x05value:\x028\x01B\f\n" + + "\n" + + "constraintB\x11\n" + + "\x0f_max_new_tokensB\x12\n" + + "\x10_stream_interval\"\x8a\x01\n" + + "\x13DisaggregatedParams\x12%\n" + + "\x0ebootstrap_host\x18\x01 \x01(\tR\rbootstrapHost\x12%\n" + + "\x0ebootstrap_port\x18\x02 \x01(\x05R\rbootstrapPort\x12%\n" 
+ + "\x0ebootstrap_room\x18\x03 \x01(\x05R\rbootstrapRoom\"\xd8\x06\n" + + "\x0fGenerateRequest\x12\x1d\n" + + "\n" + + "request_id\x18\x01 \x01(\tR\trequestId\x12C\n" + + "\ttokenized\x18\x02 \x01(\v2%.sglang.grpc.scheduler.TokenizedInputR\ttokenized\x12D\n" + + "\tmm_inputs\x18\x03 \x01(\v2'.sglang.grpc.scheduler.MultimodalInputsR\bmmInputs\x12N\n" + + "\x0fsampling_params\x18\x04 \x01(\v2%.sglang.grpc.scheduler.SamplingParamsR\x0esamplingParams\x12%\n" + + "\x0ereturn_logprob\x18\x05 \x01(\bR\rreturnLogprob\x12*\n" + + "\x11logprob_start_len\x18\x06 \x01(\x05R\x0flogprobStartLen\x12(\n" + + "\x10top_logprobs_num\x18\a \x01(\x05R\x0etopLogprobsNum\x12*\n" + + "\x11token_ids_logprob\x18\b \x03(\rR\x0ftokenIdsLogprob\x120\n" + + "\x14return_hidden_states\x18\t \x01(\bR\x12returnHiddenStates\x12]\n" + + "\x14disaggregated_params\x18\n" + + " \x01(\v2*.sglang.grpc.scheduler.DisaggregatedParamsR\x13disaggregatedParams\x124\n" + + "\x16custom_logit_processor\x18\v \x01(\tR\x14customLogitProcessor\x128\n" + + "\ttimestamp\x18\f \x01(\v2\x1a.google.protobuf.TimestampR\ttimestamp\x12\x1f\n" + + "\vlog_metrics\x18\r \x01(\bR\n" + + "logMetrics\x12!\n" + + "\finput_embeds\x18\x0e \x03(\x02R\vinputEmbeds\x12\x17\n" + + "\alora_id\x18\x0f \x01(\tR\x06loraId\x12,\n" + + "\x12data_parallel_rank\x18\x10 \x01(\x05R\x10dataParallelRank\x12\x16\n" + + "\x06stream\x18\x11 \x01(\bR\x06stream\"R\n" + + "\x0eTokenizedInput\x12#\n" + + "\roriginal_text\x18\x01 \x01(\tR\foriginalText\x12\x1b\n" + + "\tinput_ids\x18\x02 \x03(\rR\binputIds\"\xb4\x02\n" + + "\x10MultimodalInputs\x12\x1d\n" + + "\n" + + "image_urls\x18\x01 \x03(\tR\timageUrls\x12\x1d\n" + + "\n" + + "video_urls\x18\x02 \x03(\tR\tvideoUrls\x12\x1d\n" + + "\n" + + "audio_urls\x18\x03 \x03(\tR\taudioUrls\x12F\n" + + "\x12processed_features\x18\x04 \x01(\v2\x17.google.protobuf.StructR\x11processedFeatures\x12\x1d\n" + + "\n" + + "image_data\x18\x05 \x03(\fR\timageData\x12\x1d\n" + + "\n" + + "video_data\x18\x06 \x03(\fR\tvideoData\x12\x1d\n" + + "\n" + + "audio_data\x18\a \x03(\fR\taudioData\x12\x1e\n" + + "\n" + + "modalities\x18\b \x03(\tR\n" + + "modalities\"\x86\x02\n" + + "\x10GenerateResponse\x12\x1d\n" + + "\n" + + "request_id\x18\x01 \x01(\tR\trequestId\x12B\n" + + "\x05chunk\x18\x02 \x01(\v2*.sglang.grpc.scheduler.GenerateStreamChunkH\x00R\x05chunk\x12E\n" + + "\bcomplete\x18\x03 \x01(\v2'.sglang.grpc.scheduler.GenerateCompleteH\x00R\bcomplete\x12<\n" + + "\x05error\x18\x04 \x01(\v2$.sglang.grpc.scheduler.GenerateErrorH\x00R\x05errorB\n" + + "\n" + + "\bresponse\"\x81\x03\n" + + "\x13GenerateStreamChunk\x12\x1b\n" + + "\ttoken_ids\x18\x01 \x03(\rR\btokenIds\x12#\n" + + "\rprompt_tokens\x18\x02 \x01(\x05R\fpromptTokens\x12+\n" + + "\x11completion_tokens\x18\x03 \x01(\x05R\x10completionTokens\x12#\n" + + "\rcached_tokens\x18\x04 \x01(\x05R\fcachedTokens\x12N\n" + + "\x0foutput_logprobs\x18\x05 \x01(\v2%.sglang.grpc.scheduler.OutputLogProbsR\x0eoutputLogprobs\x12#\n" + + "\rhidden_states\x18\x06 \x03(\x02R\fhiddenStates\x12K\n" + + "\x0einput_logprobs\x18\a \x01(\v2$.sglang.grpc.scheduler.InputLogProbsR\rinputLogprobs\x12\x14\n" + + "\x05index\x18\b \x01(\rR\x05index\"\xb9\x04\n" + + "\x10GenerateComplete\x12\x1d\n" + + "\n" + + "output_ids\x18\x01 \x03(\rR\toutputIds\x12#\n" + + "\rfinish_reason\x18\x02 \x01(\tR\ffinishReason\x12#\n" + + "\rprompt_tokens\x18\x03 \x01(\x05R\fpromptTokens\x12+\n" + + "\x11completion_tokens\x18\x04 \x01(\x05R\x10completionTokens\x12#\n" + + "\rcached_tokens\x18\x05 \x01(\x05R\fcachedTokens\x12N\n" + + 
"\x0foutput_logprobs\x18\x06 \x01(\v2%.sglang.grpc.scheduler.OutputLogProbsR\x0eoutputLogprobs\x12O\n" + + "\x11all_hidden_states\x18\a \x03(\v2#.sglang.grpc.scheduler.HiddenStatesR\x0fallHiddenStates\x12*\n" + + "\x10matched_token_id\x18\b \x01(\rH\x00R\x0ematchedTokenId\x12*\n" + + "\x10matched_stop_str\x18\t \x01(\tH\x00R\x0ematchedStopStr\x12K\n" + + "\x0einput_logprobs\x18\n" + + " \x01(\v2$.sglang.grpc.scheduler.InputLogProbsR\rinputLogprobs\x12\x14\n" + + "\x05index\x18\v \x01(\rR\x05indexB\x0e\n" + + "\fmatched_stop\"m\n" + + "\rGenerateError\x12\x18\n" + + "\amessage\x18\x01 \x01(\tR\amessage\x12(\n" + + "\x10http_status_code\x18\x02 \x01(\tR\x0ehttpStatusCode\x12\x18\n" + + "\adetails\x18\x03 \x01(\tR\adetails\"\x9b\x01\n" + + "\x0eOutputLogProbs\x12%\n" + + "\x0etoken_logprobs\x18\x01 \x03(\x02R\rtokenLogprobs\x12\x1b\n" + + "\ttoken_ids\x18\x02 \x03(\x05R\btokenIds\x12E\n" + + "\ftop_logprobs\x18\x03 \x03(\v2\".sglang.grpc.scheduler.TopLogProbsR\vtopLogprobs\"\xc4\x01\n" + + "\rInputLogProbs\x12O\n" + + "\x0etoken_logprobs\x18\x01 \x03(\v2(.sglang.grpc.scheduler.InputTokenLogProbR\rtokenLogprobs\x12\x1b\n" + + "\ttoken_ids\x18\x02 \x03(\x05R\btokenIds\x12E\n" + + "\ftop_logprobs\x18\x03 \x03(\v2\".sglang.grpc.scheduler.TopLogProbsR\vtopLogprobs\"8\n" + + "\x11InputTokenLogProb\x12\x19\n" + + "\x05value\x18\x01 \x01(\x02H\x00R\x05value\x88\x01\x01B\b\n" + + "\x06_value\"B\n" + + "\vTopLogProbs\x12\x16\n" + + "\x06values\x18\x01 \x03(\x02R\x06values\x12\x1b\n" + + "\ttoken_ids\x18\x02 \x03(\x05R\btokenIds\"X\n" + + "\fHiddenStates\x12\x16\n" + + "\x06values\x18\x01 \x03(\x02R\x06values\x12\x14\n" + + "\x05layer\x18\x02 \x01(\x05R\x05layer\x12\x1a\n" + + "\bposition\x18\x03 \x01(\x05R\bposition\"\xbd\x03\n" + + "\fEmbedRequest\x12\x1d\n" + + "\n" + + "request_id\x18\x01 \x01(\tR\trequestId\x12C\n" + + "\ttokenized\x18\x02 \x01(\v2%.sglang.grpc.scheduler.TokenizedInputR\ttokenized\x12D\n" + + "\tmm_inputs\x18\x04 \x01(\v2'.sglang.grpc.scheduler.MultimodalInputsR\bmmInputs\x12N\n" + + "\x0fsampling_params\x18\x05 \x01(\v2%.sglang.grpc.scheduler.SamplingParamsR\x0esamplingParams\x12\x1f\n" + + "\vlog_metrics\x18\x06 \x01(\bR\n" + + "logMetrics\x12$\n" + + "\x0etoken_type_ids\x18\a \x03(\x05R\ftokenTypeIds\x12,\n" + + "\x12data_parallel_rank\x18\b \x01(\x05R\x10dataParallelRank\x12(\n" + + "\x10is_cross_encoder\x18\t \x01(\bR\x0eisCrossEncoder\x12\x14\n" + + "\x05texts\x18\n" + + " \x03(\tR\x05texts\"\xb9\x01\n" + + "\rEmbedResponse\x12\x1d\n" + + "\n" + + "request_id\x18\x01 \x01(\tR\trequestId\x12B\n" + + "\bcomplete\x18\x02 \x01(\v2$.sglang.grpc.scheduler.EmbedCompleteH\x00R\bcomplete\x129\n" + + "\x05error\x18\x03 \x01(\v2!.sglang.grpc.scheduler.EmbedErrorH\x00R\x05errorB\n" + + "\n" + + "\bresponse\"\xe9\x01\n" + + "\rEmbedComplete\x12\x1c\n" + + "\tembedding\x18\x01 \x03(\x02R\tembedding\x12#\n" + + "\rprompt_tokens\x18\x02 \x01(\x05R\fpromptTokens\x12#\n" + + "\rcached_tokens\x18\x03 \x01(\x05R\fcachedTokens\x12#\n" + + "\rembedding_dim\x18\x04 \x01(\x05R\fembeddingDim\x12K\n" + + "\x10batch_embeddings\x18\x05 \x03(\v2 .sglang.grpc.scheduler.EmbeddingR\x0fbatchEmbeddings\"9\n" + + "\tEmbedding\x12\x16\n" + + "\x06values\x18\x01 \x03(\x02R\x06values\x12\x14\n" + + "\x05index\x18\x02 \x01(\x05R\x05index\"T\n" + + "\n" + + "EmbedError\x12\x18\n" + + "\amessage\x18\x01 \x01(\tR\amessage\x12\x12\n" + + "\x04code\x18\x02 \x01(\tR\x04code\x12\x18\n" + + "\adetails\x18\x03 \x01(\tR\adetails\"\x14\n" + + "\x12HealthCheckRequest\"I\n" + + "\x13HealthCheckResponse\x12\x18\n" + + 
"\ahealthy\x18\x01 \x01(\bR\ahealthy\x12\x18\n" + + "\amessage\x18\x02 \x01(\tR\amessage\"E\n" + + "\fAbortRequest\x12\x1d\n" + + "\n" + + "request_id\x18\x01 \x01(\tR\trequestId\x12\x16\n" + + "\x06reason\x18\x02 \x01(\tR\x06reason\"C\n" + + "\rAbortResponse\x12\x18\n" + + "\asuccess\x18\x01 \x01(\bR\asuccess\x12\x18\n" + + "\amessage\x18\x02 \x01(\tR\amessage\"g\n" + + "\x0fLoadLoRARequest\x12\x1d\n" + + "\n" + + "adapter_id\x18\x01 \x01(\tR\tadapterId\x12!\n" + + "\fadapter_path\x18\x02 \x01(\tR\vadapterPath\x12\x12\n" + + "\x04rank\x18\x03 \x01(\x05R\x04rank\"e\n" + + "\x10LoadLoRAResponse\x12\x18\n" + + "\asuccess\x18\x01 \x01(\bR\asuccess\x12\x1d\n" + + "\n" + + "adapter_id\x18\x02 \x01(\tR\tadapterId\x12\x18\n" + + "\amessage\x18\x03 \x01(\tR\amessage\"2\n" + + "\x11UnloadLoRARequest\x12\x1d\n" + + "\n" + + "adapter_id\x18\x01 \x01(\tR\tadapterId\"H\n" + + "\x12UnloadLoRAResponse\x12\x18\n" + + "\asuccess\x18\x01 \x01(\bR\asuccess\x12\x18\n" + + "\amessage\x18\x02 \x01(\tR\amessage\"\xa4\x01\n" + + "\x14UpdateWeightsRequest\x12\x1d\n" + + "\tdisk_path\x18\x01 \x01(\tH\x00R\bdiskPath\x12!\n" + + "\vtensor_data\x18\x02 \x01(\fH\x00R\n" + + "tensorData\x12\x1f\n" + + "\n" + + "remote_url\x18\x03 \x01(\tH\x00R\tremoteUrl\x12\x1f\n" + + "\vweight_name\x18\x04 \x01(\tR\n" + + "weightNameB\b\n" + + "\x06source\"K\n" + + "\x15UpdateWeightsResponse\x12\x18\n" + + "\asuccess\x18\x01 \x01(\bR\asuccess\x12\x18\n" + + "\amessage\x18\x02 \x01(\tR\amessage\"8\n" + + "\x17GetInternalStateRequest\x12\x1d\n" + + "\n" + + "state_keys\x18\x01 \x03(\tR\tstateKeys\"I\n" + + "\x18GetInternalStateResponse\x12-\n" + + "\x05state\x18\x01 \x01(\v2\x17.google.protobuf.StructR\x05state\"H\n" + + "\x17SetInternalStateRequest\x12-\n" + + "\x05state\x18\x01 \x01(\v2\x17.google.protobuf.StructR\x05state\"N\n" + + "\x18SetInternalStateResponse\x12\x18\n" + + "\asuccess\x18\x01 \x01(\bR\asuccess\x12\x18\n" + + "\amessage\x18\x02 \x01(\tR\amessage\"\x15\n" + + "\x13GetModelInfoRequest\"\xb8\x04\n" + + "\x14GetModelInfoResponse\x12\x1d\n" + + "\n" + + "model_path\x18\x01 \x01(\tR\tmodelPath\x12%\n" + + "\x0etokenizer_path\x18\x02 \x01(\tR\rtokenizerPath\x12#\n" + + "\ris_generation\x18\x03 \x01(\bR\fisGeneration\x12:\n" + + "\x19preferred_sampling_params\x18\x04 \x01(\tR\x17preferredSamplingParams\x12%\n" + + "\x0eweight_version\x18\x05 \x01(\tR\rweightVersion\x12*\n" + + "\x11served_model_name\x18\x06 \x01(\tR\x0fservedModelName\x12,\n" + + "\x12max_context_length\x18\a \x01(\x05R\x10maxContextLength\x12\x1d\n" + + "\n" + + "vocab_size\x18\b \x01(\x05R\tvocabSize\x12'\n" + + "\x0fsupports_vision\x18\t \x01(\bR\x0esupportsVision\x12\x1d\n" + + "\n" + + "model_type\x18\n" + + " \x01(\tR\tmodelType\x12\"\n" + + "\reos_token_ids\x18\v \x03(\x05R\veosTokenIds\x12 \n" + + "\fpad_token_id\x18\f \x01(\x05R\n" + + "padTokenId\x12 \n" + + "\fbos_token_id\x18\r \x01(\x05R\n" + + "bosTokenId\x12)\n" + + "\x11max_req_input_len\x18\x0e \x01(\x05R\x0emaxReqInputLen\"\x16\n" + + "\x14GetServerInfoRequest\"\xb7\x03\n" + + "\x15GetServerInfoResponse\x128\n" + + "\vserver_args\x18\x01 \x01(\v2\x17.google.protobuf.StructR\n" + + "serverArgs\x12>\n" + + "\x0escheduler_info\x18\x02 \x01(\v2\x17.google.protobuf.StructR\rschedulerInfo\x12'\n" + + "\x0factive_requests\x18\x03 \x01(\x05R\x0eactiveRequests\x12\x1b\n" + + "\tis_paused\x18\x04 \x01(\bR\bisPaused\x124\n" + + "\x16last_receive_timestamp\x18\x05 \x01(\x01R\x14lastReceiveTimestamp\x12%\n" + + "\x0euptime_seconds\x18\x06 \x01(\x01R\ruptimeSeconds\x12%\n" + + "\x0esglang_version\x18\a 
\x01(\tR\rsglangVersion\x12\x1f\n" + + "\vserver_type\x18\b \x01(\tR\n" + + "serverType\x129\n" + + "\n" + + "start_time\x18\t \x01(\v2\x1a.google.protobuf.TimestampR\tstartTime2\xd3\x04\n" + + "\x0fSglangScheduler\x12]\n" + + "\bGenerate\x12&.sglang.grpc.scheduler.GenerateRequest\x1a'.sglang.grpc.scheduler.GenerateResponse0\x01\x12R\n" + + "\x05Embed\x12#.sglang.grpc.scheduler.EmbedRequest\x1a$.sglang.grpc.scheduler.EmbedResponse\x12d\n" + + "\vHealthCheck\x12).sglang.grpc.scheduler.HealthCheckRequest\x1a*.sglang.grpc.scheduler.HealthCheckResponse\x12R\n" + + "\x05Abort\x12#.sglang.grpc.scheduler.AbortRequest\x1a$.sglang.grpc.scheduler.AbortResponse\x12g\n" + + "\fGetModelInfo\x12*.sglang.grpc.scheduler.GetModelInfoRequest\x1a+.sglang.grpc.scheduler.GetModelInfoResponse\x12j\n" + + "\rGetServerInfo\x12+.sglang.grpc.scheduler.GetServerInfoRequest\x1a,.sglang.grpc.scheduler.GetServerInfoResponseb\x06proto3" + +var ( + file_sglang_scheduler_proto_rawDescOnce sync.Once + file_sglang_scheduler_proto_rawDescData []byte +) + +func file_sglang_scheduler_proto_rawDescGZIP() []byte { + file_sglang_scheduler_proto_rawDescOnce.Do(func() { + file_sglang_scheduler_proto_rawDescData = protoimpl.X.CompressGZIP(unsafe.Slice(unsafe.StringData(file_sglang_scheduler_proto_rawDesc), len(file_sglang_scheduler_proto_rawDesc))) + }) + return file_sglang_scheduler_proto_rawDescData +} + +var file_sglang_scheduler_proto_msgTypes = make([]protoimpl.MessageInfo, 38) +var file_sglang_scheduler_proto_goTypes = []any{ + (*SamplingParams)(nil), // 0: sglang.grpc.scheduler.SamplingParams + (*DisaggregatedParams)(nil), // 1: sglang.grpc.scheduler.DisaggregatedParams + (*GenerateRequest)(nil), // 2: sglang.grpc.scheduler.GenerateRequest + (*TokenizedInput)(nil), // 3: sglang.grpc.scheduler.TokenizedInput + (*MultimodalInputs)(nil), // 4: sglang.grpc.scheduler.MultimodalInputs + (*GenerateResponse)(nil), // 5: sglang.grpc.scheduler.GenerateResponse + (*GenerateStreamChunk)(nil), // 6: sglang.grpc.scheduler.GenerateStreamChunk + (*GenerateComplete)(nil), // 7: sglang.grpc.scheduler.GenerateComplete + (*GenerateError)(nil), // 8: sglang.grpc.scheduler.GenerateError + (*OutputLogProbs)(nil), // 9: sglang.grpc.scheduler.OutputLogProbs + (*InputLogProbs)(nil), // 10: sglang.grpc.scheduler.InputLogProbs + (*InputTokenLogProb)(nil), // 11: sglang.grpc.scheduler.InputTokenLogProb + (*TopLogProbs)(nil), // 12: sglang.grpc.scheduler.TopLogProbs + (*HiddenStates)(nil), // 13: sglang.grpc.scheduler.HiddenStates + (*EmbedRequest)(nil), // 14: sglang.grpc.scheduler.EmbedRequest + (*EmbedResponse)(nil), // 15: sglang.grpc.scheduler.EmbedResponse + (*EmbedComplete)(nil), // 16: sglang.grpc.scheduler.EmbedComplete + (*Embedding)(nil), // 17: sglang.grpc.scheduler.Embedding + (*EmbedError)(nil), // 18: sglang.grpc.scheduler.EmbedError + (*HealthCheckRequest)(nil), // 19: sglang.grpc.scheduler.HealthCheckRequest + (*HealthCheckResponse)(nil), // 20: sglang.grpc.scheduler.HealthCheckResponse + (*AbortRequest)(nil), // 21: sglang.grpc.scheduler.AbortRequest + (*AbortResponse)(nil), // 22: sglang.grpc.scheduler.AbortResponse + (*LoadLoRARequest)(nil), // 23: sglang.grpc.scheduler.LoadLoRARequest + (*LoadLoRAResponse)(nil), // 24: sglang.grpc.scheduler.LoadLoRAResponse + (*UnloadLoRARequest)(nil), // 25: sglang.grpc.scheduler.UnloadLoRARequest + (*UnloadLoRAResponse)(nil), // 26: sglang.grpc.scheduler.UnloadLoRAResponse + (*UpdateWeightsRequest)(nil), // 27: sglang.grpc.scheduler.UpdateWeightsRequest + (*UpdateWeightsResponse)(nil), // 28: 
sglang.grpc.scheduler.UpdateWeightsResponse + (*GetInternalStateRequest)(nil), // 29: sglang.grpc.scheduler.GetInternalStateRequest + (*GetInternalStateResponse)(nil), // 30: sglang.grpc.scheduler.GetInternalStateResponse + (*SetInternalStateRequest)(nil), // 31: sglang.grpc.scheduler.SetInternalStateRequest + (*SetInternalStateResponse)(nil), // 32: sglang.grpc.scheduler.SetInternalStateResponse + (*GetModelInfoRequest)(nil), // 33: sglang.grpc.scheduler.GetModelInfoRequest + (*GetModelInfoResponse)(nil), // 34: sglang.grpc.scheduler.GetModelInfoResponse + (*GetServerInfoRequest)(nil), // 35: sglang.grpc.scheduler.GetServerInfoRequest + (*GetServerInfoResponse)(nil), // 36: sglang.grpc.scheduler.GetServerInfoResponse + nil, // 37: sglang.grpc.scheduler.SamplingParams.LogitBiasEntry + (*structpb.Struct)(nil), // 38: google.protobuf.Struct + (*timestamppb.Timestamp)(nil), // 39: google.protobuf.Timestamp +} +var file_sglang_scheduler_proto_depIdxs = []int32{ + 37, // 0: sglang.grpc.scheduler.SamplingParams.logit_bias:type_name -> sglang.grpc.scheduler.SamplingParams.LogitBiasEntry + 38, // 1: sglang.grpc.scheduler.SamplingParams.custom_params:type_name -> google.protobuf.Struct + 3, // 2: sglang.grpc.scheduler.GenerateRequest.tokenized:type_name -> sglang.grpc.scheduler.TokenizedInput + 4, // 3: sglang.grpc.scheduler.GenerateRequest.mm_inputs:type_name -> sglang.grpc.scheduler.MultimodalInputs + 0, // 4: sglang.grpc.scheduler.GenerateRequest.sampling_params:type_name -> sglang.grpc.scheduler.SamplingParams + 1, // 5: sglang.grpc.scheduler.GenerateRequest.disaggregated_params:type_name -> sglang.grpc.scheduler.DisaggregatedParams + 39, // 6: sglang.grpc.scheduler.GenerateRequest.timestamp:type_name -> google.protobuf.Timestamp + 38, // 7: sglang.grpc.scheduler.MultimodalInputs.processed_features:type_name -> google.protobuf.Struct + 6, // 8: sglang.grpc.scheduler.GenerateResponse.chunk:type_name -> sglang.grpc.scheduler.GenerateStreamChunk + 7, // 9: sglang.grpc.scheduler.GenerateResponse.complete:type_name -> sglang.grpc.scheduler.GenerateComplete + 8, // 10: sglang.grpc.scheduler.GenerateResponse.error:type_name -> sglang.grpc.scheduler.GenerateError + 9, // 11: sglang.grpc.scheduler.GenerateStreamChunk.output_logprobs:type_name -> sglang.grpc.scheduler.OutputLogProbs + 10, // 12: sglang.grpc.scheduler.GenerateStreamChunk.input_logprobs:type_name -> sglang.grpc.scheduler.InputLogProbs + 9, // 13: sglang.grpc.scheduler.GenerateComplete.output_logprobs:type_name -> sglang.grpc.scheduler.OutputLogProbs + 13, // 14: sglang.grpc.scheduler.GenerateComplete.all_hidden_states:type_name -> sglang.grpc.scheduler.HiddenStates + 10, // 15: sglang.grpc.scheduler.GenerateComplete.input_logprobs:type_name -> sglang.grpc.scheduler.InputLogProbs + 12, // 16: sglang.grpc.scheduler.OutputLogProbs.top_logprobs:type_name -> sglang.grpc.scheduler.TopLogProbs + 11, // 17: sglang.grpc.scheduler.InputLogProbs.token_logprobs:type_name -> sglang.grpc.scheduler.InputTokenLogProb + 12, // 18: sglang.grpc.scheduler.InputLogProbs.top_logprobs:type_name -> sglang.grpc.scheduler.TopLogProbs + 3, // 19: sglang.grpc.scheduler.EmbedRequest.tokenized:type_name -> sglang.grpc.scheduler.TokenizedInput + 4, // 20: sglang.grpc.scheduler.EmbedRequest.mm_inputs:type_name -> sglang.grpc.scheduler.MultimodalInputs + 0, // 21: sglang.grpc.scheduler.EmbedRequest.sampling_params:type_name -> sglang.grpc.scheduler.SamplingParams + 16, // 22: sglang.grpc.scheduler.EmbedResponse.complete:type_name -> sglang.grpc.scheduler.EmbedComplete + 
18, // 23: sglang.grpc.scheduler.EmbedResponse.error:type_name -> sglang.grpc.scheduler.EmbedError + 17, // 24: sglang.grpc.scheduler.EmbedComplete.batch_embeddings:type_name -> sglang.grpc.scheduler.Embedding + 38, // 25: sglang.grpc.scheduler.GetInternalStateResponse.state:type_name -> google.protobuf.Struct + 38, // 26: sglang.grpc.scheduler.SetInternalStateRequest.state:type_name -> google.protobuf.Struct + 38, // 27: sglang.grpc.scheduler.GetServerInfoResponse.server_args:type_name -> google.protobuf.Struct + 38, // 28: sglang.grpc.scheduler.GetServerInfoResponse.scheduler_info:type_name -> google.protobuf.Struct + 39, // 29: sglang.grpc.scheduler.GetServerInfoResponse.start_time:type_name -> google.protobuf.Timestamp + 2, // 30: sglang.grpc.scheduler.SglangScheduler.Generate:input_type -> sglang.grpc.scheduler.GenerateRequest + 14, // 31: sglang.grpc.scheduler.SglangScheduler.Embed:input_type -> sglang.grpc.scheduler.EmbedRequest + 19, // 32: sglang.grpc.scheduler.SglangScheduler.HealthCheck:input_type -> sglang.grpc.scheduler.HealthCheckRequest + 21, // 33: sglang.grpc.scheduler.SglangScheduler.Abort:input_type -> sglang.grpc.scheduler.AbortRequest + 33, // 34: sglang.grpc.scheduler.SglangScheduler.GetModelInfo:input_type -> sglang.grpc.scheduler.GetModelInfoRequest + 35, // 35: sglang.grpc.scheduler.SglangScheduler.GetServerInfo:input_type -> sglang.grpc.scheduler.GetServerInfoRequest + 5, // 36: sglang.grpc.scheduler.SglangScheduler.Generate:output_type -> sglang.grpc.scheduler.GenerateResponse + 15, // 37: sglang.grpc.scheduler.SglangScheduler.Embed:output_type -> sglang.grpc.scheduler.EmbedResponse + 20, // 38: sglang.grpc.scheduler.SglangScheduler.HealthCheck:output_type -> sglang.grpc.scheduler.HealthCheckResponse + 22, // 39: sglang.grpc.scheduler.SglangScheduler.Abort:output_type -> sglang.grpc.scheduler.AbortResponse + 34, // 40: sglang.grpc.scheduler.SglangScheduler.GetModelInfo:output_type -> sglang.grpc.scheduler.GetModelInfoResponse + 36, // 41: sglang.grpc.scheduler.SglangScheduler.GetServerInfo:output_type -> sglang.grpc.scheduler.GetServerInfoResponse + 36, // [36:42] is the sub-list for method output_type + 30, // [30:36] is the sub-list for method input_type + 30, // [30:30] is the sub-list for extension type_name + 30, // [30:30] is the sub-list for extension extendee + 0, // [0:30] is the sub-list for field type_name +} + +func init() { file_sglang_scheduler_proto_init() } +func file_sglang_scheduler_proto_init() { + if File_sglang_scheduler_proto != nil { + return + } + file_sglang_scheduler_proto_msgTypes[0].OneofWrappers = []any{ + (*SamplingParams_Regex)(nil), + (*SamplingParams_JsonSchema)(nil), + (*SamplingParams_EbnfGrammar)(nil), + (*SamplingParams_StructuralTag)(nil), + } + file_sglang_scheduler_proto_msgTypes[5].OneofWrappers = []any{ + (*GenerateResponse_Chunk)(nil), + (*GenerateResponse_Complete)(nil), + (*GenerateResponse_Error)(nil), + } + file_sglang_scheduler_proto_msgTypes[7].OneofWrappers = []any{ + (*GenerateComplete_MatchedTokenId)(nil), + (*GenerateComplete_MatchedStopStr)(nil), + } + file_sglang_scheduler_proto_msgTypes[11].OneofWrappers = []any{} + file_sglang_scheduler_proto_msgTypes[15].OneofWrappers = []any{ + (*EmbedResponse_Complete)(nil), + (*EmbedResponse_Error)(nil), + } + file_sglang_scheduler_proto_msgTypes[27].OneofWrappers = []any{ + (*UpdateWeightsRequest_DiskPath)(nil), + (*UpdateWeightsRequest_TensorData)(nil), + (*UpdateWeightsRequest_RemoteUrl)(nil), + } + type x struct{} + out := protoimpl.TypeBuilder{ + File: 
protoimpl.DescBuilder{ + GoPackagePath: reflect.TypeOf(x{}).PkgPath(), + RawDescriptor: unsafe.Slice(unsafe.StringData(file_sglang_scheduler_proto_rawDesc), len(file_sglang_scheduler_proto_rawDesc)), + NumEnums: 0, + NumMessages: 38, + NumExtensions: 0, + NumServices: 1, + }, + GoTypes: file_sglang_scheduler_proto_goTypes, + DependencyIndexes: file_sglang_scheduler_proto_depIdxs, + MessageInfos: file_sglang_scheduler_proto_msgTypes, + }.Build() + File_sglang_scheduler_proto = out.File + file_sglang_scheduler_proto_goTypes = nil + file_sglang_scheduler_proto_depIdxs = nil +} diff --git a/sgl-model-gateway/bindings/golang/internal/proto/sglang_scheduler_grpc.pb.go b/sgl-model-gateway/bindings/golang/internal/proto/sglang_scheduler_grpc.pb.go new file mode 100644 index 000000000000..e8674cc25e26 --- /dev/null +++ b/sgl-model-gateway/bindings/golang/internal/proto/sglang_scheduler_grpc.pb.go @@ -0,0 +1,333 @@ +// Code generated by protoc-gen-go-grpc. DO NOT EDIT. +// versions: +// - protoc-gen-go-grpc v1.5.1 +// - protoc v3.21.12 +// source: sglang_scheduler.proto + +package proto + +import ( + context "context" + grpc "google.golang.org/grpc" + codes "google.golang.org/grpc/codes" + status "google.golang.org/grpc/status" +) + +// This is a compile-time assertion to ensure that this generated file +// is compatible with the grpc package it is being compiled against. +// Requires gRPC-Go v1.64.0 or later. +const _ = grpc.SupportPackageIsVersion9 + +const ( + SglangScheduler_Generate_FullMethodName = "/sglang.grpc.scheduler.SglangScheduler/Generate" + SglangScheduler_Embed_FullMethodName = "/sglang.grpc.scheduler.SglangScheduler/Embed" + SglangScheduler_HealthCheck_FullMethodName = "/sglang.grpc.scheduler.SglangScheduler/HealthCheck" + SglangScheduler_Abort_FullMethodName = "/sglang.grpc.scheduler.SglangScheduler/Abort" + SglangScheduler_GetModelInfo_FullMethodName = "/sglang.grpc.scheduler.SglangScheduler/GetModelInfo" + SglangScheduler_GetServerInfo_FullMethodName = "/sglang.grpc.scheduler.SglangScheduler/GetServerInfo" +) + +// SglangSchedulerClient is the client API for SglangScheduler service. +// +// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream. 
+// +// Service definition for SGLang scheduler communication +// This protocol bridges the Rust router and Python scheduler +type SglangSchedulerClient interface { + // Submit a generation request (supports streaming) + Generate(ctx context.Context, in *GenerateRequest, opts ...grpc.CallOption) (grpc.ServerStreamingClient[GenerateResponse], error) + // Submit an embedding request + Embed(ctx context.Context, in *EmbedRequest, opts ...grpc.CallOption) (*EmbedResponse, error) + // Health check and metrics + HealthCheck(ctx context.Context, in *HealthCheckRequest, opts ...grpc.CallOption) (*HealthCheckResponse, error) + // Abort a running request + Abort(ctx context.Context, in *AbortRequest, opts ...grpc.CallOption) (*AbortResponse, error) + // Get model information + GetModelInfo(ctx context.Context, in *GetModelInfoRequest, opts ...grpc.CallOption) (*GetModelInfoResponse, error) + // Get server information + GetServerInfo(ctx context.Context, in *GetServerInfoRequest, opts ...grpc.CallOption) (*GetServerInfoResponse, error) +} + +type sglangSchedulerClient struct { + cc grpc.ClientConnInterface +} + +func NewSglangSchedulerClient(cc grpc.ClientConnInterface) SglangSchedulerClient { + return &sglangSchedulerClient{cc} +} + +func (c *sglangSchedulerClient) Generate(ctx context.Context, in *GenerateRequest, opts ...grpc.CallOption) (grpc.ServerStreamingClient[GenerateResponse], error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) + stream, err := c.cc.NewStream(ctx, &SglangScheduler_ServiceDesc.Streams[0], SglangScheduler_Generate_FullMethodName, cOpts...) + if err != nil { + return nil, err + } + x := &grpc.GenericClientStream[GenerateRequest, GenerateResponse]{ClientStream: stream} + if err := x.ClientStream.SendMsg(in); err != nil { + return nil, err + } + if err := x.ClientStream.CloseSend(); err != nil { + return nil, err + } + return x, nil +} + +// This type alias is provided for backwards compatibility with existing code that references the prior non-generic stream type by name. +type SglangScheduler_GenerateClient = grpc.ServerStreamingClient[GenerateResponse] + +func (c *sglangSchedulerClient) Embed(ctx context.Context, in *EmbedRequest, opts ...grpc.CallOption) (*EmbedResponse, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) + out := new(EmbedResponse) + err := c.cc.Invoke(ctx, SglangScheduler_Embed_FullMethodName, in, out, cOpts...) + if err != nil { + return nil, err + } + return out, nil +} + +func (c *sglangSchedulerClient) HealthCheck(ctx context.Context, in *HealthCheckRequest, opts ...grpc.CallOption) (*HealthCheckResponse, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) + out := new(HealthCheckResponse) + err := c.cc.Invoke(ctx, SglangScheduler_HealthCheck_FullMethodName, in, out, cOpts...) + if err != nil { + return nil, err + } + return out, nil +} + +func (c *sglangSchedulerClient) Abort(ctx context.Context, in *AbortRequest, opts ...grpc.CallOption) (*AbortResponse, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) + out := new(AbortResponse) + err := c.cc.Invoke(ctx, SglangScheduler_Abort_FullMethodName, in, out, cOpts...) + if err != nil { + return nil, err + } + return out, nil +} + +func (c *sglangSchedulerClient) GetModelInfo(ctx context.Context, in *GetModelInfoRequest, opts ...grpc.CallOption) (*GetModelInfoResponse, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) 
+ out := new(GetModelInfoResponse) + err := c.cc.Invoke(ctx, SglangScheduler_GetModelInfo_FullMethodName, in, out, cOpts...) + if err != nil { + return nil, err + } + return out, nil +} + +func (c *sglangSchedulerClient) GetServerInfo(ctx context.Context, in *GetServerInfoRequest, opts ...grpc.CallOption) (*GetServerInfoResponse, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) + out := new(GetServerInfoResponse) + err := c.cc.Invoke(ctx, SglangScheduler_GetServerInfo_FullMethodName, in, out, cOpts...) + if err != nil { + return nil, err + } + return out, nil +} + +// SglangSchedulerServer is the server API for SglangScheduler service. +// All implementations must embed UnimplementedSglangSchedulerServer +// for forward compatibility. +// +// Service definition for SGLang scheduler communication +// This protocol bridges the Rust router and Python scheduler +type SglangSchedulerServer interface { + // Submit a generation request (supports streaming) + Generate(*GenerateRequest, grpc.ServerStreamingServer[GenerateResponse]) error + // Submit an embedding request + Embed(context.Context, *EmbedRequest) (*EmbedResponse, error) + // Health check and metrics + HealthCheck(context.Context, *HealthCheckRequest) (*HealthCheckResponse, error) + // Abort a running request + Abort(context.Context, *AbortRequest) (*AbortResponse, error) + // Get model information + GetModelInfo(context.Context, *GetModelInfoRequest) (*GetModelInfoResponse, error) + // Get server information + GetServerInfo(context.Context, *GetServerInfoRequest) (*GetServerInfoResponse, error) + mustEmbedUnimplementedSglangSchedulerServer() +} + +// UnimplementedSglangSchedulerServer must be embedded to have +// forward compatible implementations. +// +// NOTE: this should be embedded by value instead of pointer to avoid a nil +// pointer dereference when methods are called. +type UnimplementedSglangSchedulerServer struct{} + +func (UnimplementedSglangSchedulerServer) Generate(*GenerateRequest, grpc.ServerStreamingServer[GenerateResponse]) error { + return status.Errorf(codes.Unimplemented, "method Generate not implemented") +} +func (UnimplementedSglangSchedulerServer) Embed(context.Context, *EmbedRequest) (*EmbedResponse, error) { + return nil, status.Errorf(codes.Unimplemented, "method Embed not implemented") +} +func (UnimplementedSglangSchedulerServer) HealthCheck(context.Context, *HealthCheckRequest) (*HealthCheckResponse, error) { + return nil, status.Errorf(codes.Unimplemented, "method HealthCheck not implemented") +} +func (UnimplementedSglangSchedulerServer) Abort(context.Context, *AbortRequest) (*AbortResponse, error) { + return nil, status.Errorf(codes.Unimplemented, "method Abort not implemented") +} +func (UnimplementedSglangSchedulerServer) GetModelInfo(context.Context, *GetModelInfoRequest) (*GetModelInfoResponse, error) { + return nil, status.Errorf(codes.Unimplemented, "method GetModelInfo not implemented") +} +func (UnimplementedSglangSchedulerServer) GetServerInfo(context.Context, *GetServerInfoRequest) (*GetServerInfoResponse, error) { + return nil, status.Errorf(codes.Unimplemented, "method GetServerInfo not implemented") +} +func (UnimplementedSglangSchedulerServer) mustEmbedUnimplementedSglangSchedulerServer() {} +func (UnimplementedSglangSchedulerServer) testEmbeddedByValue() {} + +// UnsafeSglangSchedulerServer may be embedded to opt out of forward compatibility for this service. 
+// Use of this interface is not recommended, as added methods to SglangSchedulerServer will
+// result in compilation errors.
+type UnsafeSglangSchedulerServer interface {
+	mustEmbedUnimplementedSglangSchedulerServer()
+}
+
+func RegisterSglangSchedulerServer(s grpc.ServiceRegistrar, srv SglangSchedulerServer) {
+	// If the following call panics, it indicates UnimplementedSglangSchedulerServer was
+	// embedded by pointer and is nil. This will cause panics if an
+	// unimplemented method is ever invoked, so we test this at initialization
+	// time to prevent it from happening at runtime later due to I/O.
+	if t, ok := srv.(interface{ testEmbeddedByValue() }); ok {
+		t.testEmbeddedByValue()
+	}
+	s.RegisterService(&SglangScheduler_ServiceDesc, srv)
+}
+
+func _SglangScheduler_Generate_Handler(srv interface{}, stream grpc.ServerStream) error {
+	m := new(GenerateRequest)
+	if err := stream.RecvMsg(m); err != nil {
+		return err
+	}
+	return srv.(SglangSchedulerServer).Generate(m, &grpc.GenericServerStream[GenerateRequest, GenerateResponse]{ServerStream: stream})
+}
+
+// This type alias is provided for backwards compatibility with existing code that references the prior non-generic stream type by name.
+type SglangScheduler_GenerateServer = grpc.ServerStreamingServer[GenerateResponse]
+
+func _SglangScheduler_Embed_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
+	in := new(EmbedRequest)
+	if err := dec(in); err != nil {
+		return nil, err
+	}
+	if interceptor == nil {
+		return srv.(SglangSchedulerServer).Embed(ctx, in)
+	}
+	info := &grpc.UnaryServerInfo{
+		Server:     srv,
+		FullMethod: SglangScheduler_Embed_FullMethodName,
+	}
+	handler := func(ctx context.Context, req interface{}) (interface{}, error) {
+		return srv.(SglangSchedulerServer).Embed(ctx, req.(*EmbedRequest))
+	}
+	return interceptor(ctx, in, info, handler)
+}
+
+func _SglangScheduler_HealthCheck_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
+	in := new(HealthCheckRequest)
+	if err := dec(in); err != nil {
+		return nil, err
+	}
+	if interceptor == nil {
+		return srv.(SglangSchedulerServer).HealthCheck(ctx, in)
+	}
+	info := &grpc.UnaryServerInfo{
+		Server:     srv,
+		FullMethod: SglangScheduler_HealthCheck_FullMethodName,
+	}
+	handler := func(ctx context.Context, req interface{}) (interface{}, error) {
+		return srv.(SglangSchedulerServer).HealthCheck(ctx, req.(*HealthCheckRequest))
+	}
+	return interceptor(ctx, in, info, handler)
+}
+
+func _SglangScheduler_Abort_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
+	in := new(AbortRequest)
+	if err := dec(in); err != nil {
+		return nil, err
+	}
+	if interceptor == nil {
+		return srv.(SglangSchedulerServer).Abort(ctx, in)
+	}
+	info := &grpc.UnaryServerInfo{
+		Server:     srv,
+		FullMethod: SglangScheduler_Abort_FullMethodName,
+	}
+	handler := func(ctx context.Context, req interface{}) (interface{}, error) {
+		return srv.(SglangSchedulerServer).Abort(ctx, req.(*AbortRequest))
+	}
+	return interceptor(ctx, in, info, handler)
+}
+
+func _SglangScheduler_GetModelInfo_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
+	in := new(GetModelInfoRequest)
+	if err := dec(in); err != nil {
+		return nil, err
+	}
+	if interceptor == nil {
return srv.(SglangSchedulerServer).GetModelInfo(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: SglangScheduler_GetModelInfo_FullMethodName, + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(SglangSchedulerServer).GetModelInfo(ctx, req.(*GetModelInfoRequest)) + } + return interceptor(ctx, in, info, handler) +} + +func _SglangScheduler_GetServerInfo_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(GetServerInfoRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(SglangSchedulerServer).GetServerInfo(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: SglangScheduler_GetServerInfo_FullMethodName, + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(SglangSchedulerServer).GetServerInfo(ctx, req.(*GetServerInfoRequest)) + } + return interceptor(ctx, in, info, handler) +} + +// SglangScheduler_ServiceDesc is the grpc.ServiceDesc for SglangScheduler service. +// It's only intended for direct use with grpc.RegisterService, +// and not to be introspected or modified (even as a copy) +var SglangScheduler_ServiceDesc = grpc.ServiceDesc{ + ServiceName: "sglang.grpc.scheduler.SglangScheduler", + HandlerType: (*SglangSchedulerServer)(nil), + Methods: []grpc.MethodDesc{ + { + MethodName: "Embed", + Handler: _SglangScheduler_Embed_Handler, + }, + { + MethodName: "HealthCheck", + Handler: _SglangScheduler_HealthCheck_Handler, + }, + { + MethodName: "Abort", + Handler: _SglangScheduler_Abort_Handler, + }, + { + MethodName: "GetModelInfo", + Handler: _SglangScheduler_GetModelInfo_Handler, + }, + { + MethodName: "GetServerInfo", + Handler: _SglangScheduler_GetServerInfo_Handler, + }, + }, + Streams: []grpc.StreamDesc{ + { + StreamName: "Generate", + Handler: _SglangScheduler_Generate_Handler, + ServerStreams: true, + }, + }, + Metadata: "sglang_scheduler.proto", +} diff --git a/sgl-model-gateway/bindings/golang/src/lib.rs b/sgl-model-gateway/bindings/golang/src/lib.rs index 82a37e6eb87d..c4671a5a2479 100644 --- a/sgl-model-gateway/bindings/golang/src/lib.rs +++ b/sgl-model-gateway/bindings/golang/src/lib.rs @@ -63,6 +63,19 @@ pub use stream::{ // Re-export client stream function (defined in client.rs but used by stream) pub use client::sgl_client_chat_completion_stream; +// Re-export preprocessor functions +pub use preprocessor::{ + sgl_preprocess_chat_request, + sgl_preprocess_chat_request_with_tokenizer, + sgl_preprocessed_request_free, +}; + +// Re-export postprocessor functions +pub use postprocessor::{ + sgl_postprocess_stream_chunk, + sgl_postprocess_stream_chunks_batch, +}; + // Re-export utility functions pub use utils::sgl_generate_tool_constraints; @@ -75,6 +88,8 @@ mod grpc_converter; mod client; mod stream; mod utils; +mod preprocessor; +mod postprocessor; #[cfg(test)] mod tests { diff --git a/sgl-model-gateway/bindings/golang/src/postprocessor.rs b/sgl-model-gateway/bindings/golang/src/postprocessor.rs new file mode 100644 index 000000000000..7cebe0f53477 --- /dev/null +++ b/sgl-model-gateway/bindings/golang/src/postprocessor.rs @@ -0,0 +1,465 @@ +//! Postprocessing FFI functions for gRPC stream chunks +//! +//! This module provides C-compatible functions for postprocessing gRPC stream chunks: +//! - Parse tool calls from model output +//! - Convert proto format to OpenAI format +//! 
- Handle reasoning content parsing
+//!
+//! These functions are designed to be called for each stream chunk, but can be optimized
+//! with batching in the future.
+
+use std::ffi::{CStr, CString};
+use std::os::raw::{c_char, c_int};
+use std::ptr;
+use std::sync::Arc;
+use serde_json::Value;
+
+use sgl_model_gateway::grpc_client::sglang_proto as proto;
+
+use super::error::{SglErrorCode, set_error_message};
+use super::grpc_converter::GrpcResponseConverterHandle;
+
+use tokio::runtime::Runtime;
+use once_cell::sync::Lazy;
+
+/// Global tokio runtime for async operations
+static RUNTIME: Lazy<Runtime> = Lazy::new(|| {
+    Runtime::new().expect("Failed to create tokio runtime for postprocessor FFI")
+});
+
+/// Postprocess a gRPC stream chunk to OpenAI format
+///
+/// This function:
+/// 1. Parses the proto chunk from JSON
+/// 2. Converts it to OpenAI format using the converter handle
+/// 3. Returns the OpenAI format JSON
+///
+/// # Arguments
+/// * `converter_handle` - Converter handle (created with sgl_grpc_response_converter_create)
+/// * `proto_chunk_json` - JSON string of proto.GenerateResponse
+/// * `openai_json_out` - Pointer to receive OpenAI format JSON (must be freed with sgl_free_string)
+/// * `is_done_out` - Pointer to receive is_done flag (1 if stream is complete, 0 otherwise)
+/// * `error_out` - Optional pointer to receive error message
+///
+/// # Returns
+/// * SglErrorCode::Success on success, error code on failure
+#[no_mangle]
+pub unsafe extern "C" fn sgl_postprocess_stream_chunk(
+    converter_handle: *mut GrpcResponseConverterHandle,
+    proto_chunk_json: *const c_char,
+    openai_json_out: *mut *mut c_char,
+    is_done_out: *mut c_int,
+    error_out: *mut *mut c_char,
+) -> SglErrorCode {
+    if converter_handle.is_null()
+        || proto_chunk_json.is_null()
+        || openai_json_out.is_null()
+        || is_done_out.is_null()
+    {
+        set_error_message(error_out, "Invalid arguments: null pointer");
+        return SglErrorCode::InvalidArgument;
+    }
+
+    let proto_chunk_str = match CStr::from_ptr(proto_chunk_json).to_str() {
+        Ok(s) => s,
+        Err(_) => {
+            set_error_message(error_out, "Invalid UTF-8 in proto_chunk_json");
+            return SglErrorCode::InvalidArgument;
+        }
+    };
+
+    // Parse proto.GenerateResponse from JSON
+    let json_value: Value = match serde_json::from_str(proto_chunk_str) {
+        Ok(v) => v,
+        Err(e) => {
+            set_error_message(error_out, &format!("Failed to parse proto chunk JSON: {}", e));
+            return SglErrorCode::ParsingError;
+        }
+    };
+
+    // Build proto::GenerateResponse from JSON value
+    let mut proto_response = proto::GenerateResponse {
+        request_id: json_value
+            .get("request_id")
+            .and_then(|v| v.as_str())
+            .unwrap_or("")
+            .to_string(),
+        response: None,
+    };
+
+    // Parse the response oneof field
+    let is_done = if let Some(chunk_json) = json_value.get("chunk") {
+        let chunk = proto::GenerateStreamChunk {
+            token_ids: chunk_json
+                .get("token_ids")
+                .and_then(|v| v.as_array())
+                .map(|arr| {
+                    arr.iter()
+                        .filter_map(|v| v.as_u64().map(|n| n as u32))
+                        .collect()
+                })
+                .unwrap_or_default(),
+            prompt_tokens: chunk_json
+                .get("prompt_tokens")
+                .and_then(|v| v.as_i64())
+                .map(|n| n as i32)
+                .unwrap_or(0),
+            completion_tokens: chunk_json
+                .get("completion_tokens")
+                .and_then(|v| v.as_i64())
+                .map(|n| n as i32)
+                .unwrap_or(0),
+            cached_tokens: chunk_json
+                .get("cached_tokens")
+                .and_then(|v| v.as_i64())
+                .map(|n| n as i32)
+                .unwrap_or(0),
+            output_logprobs: None,
+            hidden_states: vec![],
+            input_logprobs: None,
+            index: 0,
+        };
+        proto_response.response =
Some(proto::generate_response::Response::Chunk(chunk)); + false + } else if let Some(complete_json) = json_value.get("complete") { + let complete = proto::GenerateComplete { + output_ids: complete_json + .get("output_ids") + .and_then(|v| v.as_array()) + .map(|arr| { + arr.iter() + .filter_map(|v| v.as_u64().map(|n| n as u32)) + .collect() + }) + .unwrap_or_default(), + finish_reason: complete_json + .get("finish_reason") + .and_then(|v| v.as_str()) + .unwrap_or("") + .to_string(), + prompt_tokens: complete_json + .get("prompt_tokens") + .and_then(|v| v.as_i64()) + .map(|n| n as i32) + .unwrap_or(0), + completion_tokens: complete_json + .get("completion_tokens") + .and_then(|v| v.as_i64()) + .map(|n| n as i32) + .unwrap_or(0), + cached_tokens: complete_json + .get("cached_tokens") + .and_then(|v| v.as_i64()) + .map(|n| n as i32) + .unwrap_or(0), + output_logprobs: None, + all_hidden_states: vec![], + input_logprobs: None, + matched_stop: None, + index: 0, + }; + proto_response.response = Some(proto::generate_response::Response::Complete(complete)); + true + } else if let Some(error_json) = json_value.get("error") { + let error = proto::GenerateError { + message: error_json + .get("message") + .and_then(|v| v.as_str()) + .unwrap_or("") + .to_string(), + http_status_code: error_json + .get("http_status_code") + .and_then(|v| v.as_str()) + .unwrap_or("500") + .to_string(), + details: error_json + .get("details") + .and_then(|v| v.as_str()) + .unwrap_or("") + .to_string(), + }; + proto_response.response = Some(proto::generate_response::Response::Error(error)); + true + } else { + set_error_message( + error_out, + "Proto chunk JSON must contain 'chunk', 'complete', or 'error' field", + ); + return SglErrorCode::ParsingError; + }; + + // Convert proto chunk to OpenAI format using the converter's convert_chunk function + // We'll use the existing converter API instead of calling the internal function directly + let proto_chunk_json_cstr = match CString::new(proto_chunk_str) { + Ok(s) => s, + Err(e) => { + set_error_message(error_out, &format!("Failed to create C string: {}", e)); + return SglErrorCode::MemoryError; + } + }; + + // Use the existing converter API + let mut openai_json_ptr: *mut c_char = ptr::null_mut(); + let result = super::grpc_converter::sgl_grpc_response_converter_convert_chunk( + converter_handle, + proto_chunk_json_cstr.as_ptr(), + &mut openai_json_ptr, + error_out, + ); + + if result == SglErrorCode::Success { + *openai_json_out = openai_json_ptr; + *is_done_out = if is_done { 1 } else { 0 }; + SglErrorCode::Success + } else { + *openai_json_out = ptr::null_mut(); + *is_done_out = if is_done { 1 } else { 0 }; + result + } +} + +/// Postprocess multiple gRPC stream chunks in batch (reduces FFI overhead) +/// +/// This function processes multiple chunks in a single FFI call, significantly reducing +/// FFI overhead in streaming scenarios. 
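+///
+/// # Example
+///
+/// A hypothetical C caller (sketch only; `conv` is a handle obtained from
+/// `sgl_grpc_response_converter_create`, and the numeric value of the success
+/// code is an assumption; check the generated C header):
+///
+/// ```c
+/// char *openai_chunks = NULL;
+/// char *err = NULL;
+/// int count = 0;
+/// /* chunks_json: JSON array string of proto.GenerateResponse objects */
+/// SglErrorCode rc = sgl_postprocess_stream_chunks_batch(
+///     conv, chunks_json, /*max_chunks=*/64, &openai_chunks, &count, &err);
+/// if (rc == 0 /* SglErrorCode::Success, assumed to be 0 */) {
+///     /* openai_chunks holds a JSON array of OpenAI-format chunks */
+///     sgl_free_string(openai_chunks);
+/// }
+/// ```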
+///
+/// # Arguments
+/// * `converter_handle` - Converter handle (created with sgl_grpc_response_converter_create)
+/// * `proto_chunks_json_array` - JSON array string of proto.GenerateResponse chunks
+/// * `max_chunks` - Maximum number of chunks to process (for safety)
+/// * `openai_chunks_json_array_out` - Pointer to receive JSON array of OpenAI format chunks (must be freed with sgl_free_string)
+/// * `chunks_count_out` - Pointer to receive number of processed chunks
+/// * `error_out` - Optional pointer to receive error message
+///
+/// # Returns
+/// * SglErrorCode::Success on success, error code on failure
+#[no_mangle]
+pub unsafe extern "C" fn sgl_postprocess_stream_chunks_batch(
+    converter_handle: *mut GrpcResponseConverterHandle,
+    proto_chunks_json_array: *const c_char,
+    max_chunks: c_int,
+    openai_chunks_json_array_out: *mut *mut c_char,
+    chunks_count_out: *mut c_int,
+    error_out: *mut *mut c_char,
+) -> SglErrorCode {
+    if converter_handle.is_null()
+        || proto_chunks_json_array.is_null()
+        || openai_chunks_json_array_out.is_null()
+        || chunks_count_out.is_null()
+    {
+        set_error_message(error_out, "Invalid arguments: null pointer");
+        return SglErrorCode::InvalidArgument;
+    }
+
+    let chunks_array_str = match CStr::from_ptr(proto_chunks_json_array).to_str() {
+        Ok(s) => s,
+        Err(_) => {
+            set_error_message(error_out, "Invalid UTF-8 in proto_chunks_json_array");
+            return SglErrorCode::InvalidArgument;
+        }
+    };
+
+    // Parse JSON array of chunks
+    let chunks_array: Vec<Value> = match serde_json::from_str(chunks_array_str) {
+        Ok(arr) => arr,
+        Err(e) => {
+            set_error_message(
+                error_out,
+                &format!("Failed to parse chunks JSON array: {}", e),
+            );
+            return SglErrorCode::ParsingError;
+        }
+    };
+
+    // Limit batch size for safety (a negative max_chunks is treated as zero)
+    let max_chunks_usize = max_chunks.max(0) as usize;
+    let chunks_to_process = if chunks_array.len() > max_chunks_usize {
+        &chunks_array[..max_chunks_usize]
+    } else {
+        &chunks_array
+    };
+
+    let handle_ref = &mut *converter_handle;
+    let tokenizer = Arc::clone(&handle_ref.tokenizer);
+    let model = handle_ref.model.clone();
+    let request_id = handle_ref.request_id.clone();
+    let created = handle_ref.created;
+    let system_fingerprint = handle_ref.system_fingerprint.clone();
+
+    // Process chunks in batch
+    let mut results = Vec::new();
+    let mut has_error = false;
+    let mut error_msg = String::new();
+
+    for chunk_json in chunks_to_process {
+        // Parse proto.GenerateResponse from JSON
+        let mut proto_response = proto::GenerateResponse {
+            request_id: chunk_json
+                .get("request_id")
+                .and_then(|v| v.as_str())
+                .unwrap_or("")
+                .to_string(),
+            response: None,
+        };
+
+        // Parse the response oneof field (same logic as single chunk processing)
+        let _is_done = if let Some(chunk_json) = chunk_json.get("chunk") {
+            let chunk = proto::GenerateStreamChunk {
+                token_ids: chunk_json
+                    .get("token_ids")
+                    .and_then(|v| v.as_array())
+                    .map(|arr| {
+                        arr.iter()
+                            .filter_map(|v| v.as_u64().map(|n| n as u32))
+                            .collect()
+                    })
+                    .unwrap_or_default(),
+                prompt_tokens: chunk_json
+                    .get("prompt_tokens")
+                    .and_then(|v| v.as_i64())
+                    .map(|n| n as i32)
+                    .unwrap_or(0),
+                completion_tokens: chunk_json
+                    .get("completion_tokens")
+                    .and_then(|v| v.as_i64())
+                    .map(|n| n as i32)
+                    .unwrap_or(0),
+                cached_tokens: chunk_json
+                    .get("cached_tokens")
+                    .and_then(|v| v.as_i64())
+                    .map(|n| n as i32)
+                    .unwrap_or(0),
+                output_logprobs: None,
+                hidden_states: vec![],
+                input_logprobs: None,
+                index: 0,
+            };
+            proto_response.response =
Some(proto::generate_response::Response::Chunk(chunk)); + false + } else if let Some(complete_json) = chunk_json.get("complete") { + let complete = proto::GenerateComplete { + output_ids: complete_json + .get("output_ids") + .and_then(|v| v.as_array()) + .map(|arr| { + arr.iter() + .filter_map(|v| v.as_u64().map(|n| n as u32)) + .collect() + }) + .unwrap_or_default(), + finish_reason: complete_json + .get("finish_reason") + .and_then(|v| v.as_str()) + .unwrap_or("") + .to_string(), + prompt_tokens: complete_json + .get("prompt_tokens") + .and_then(|v| v.as_i64()) + .map(|n| n as i32) + .unwrap_or(0), + completion_tokens: complete_json + .get("completion_tokens") + .and_then(|v| v.as_i64()) + .map(|n| n as i32) + .unwrap_or(0), + cached_tokens: complete_json + .get("cached_tokens") + .and_then(|v| v.as_i64()) + .map(|n| n as i32) + .unwrap_or(0), + output_logprobs: None, + all_hidden_states: vec![], + input_logprobs: None, + matched_stop: None, + index: 0, + }; + proto_response.response = Some(proto::generate_response::Response::Complete(complete)); + true + } else if let Some(error_json) = chunk_json.get("error") { + let error = proto::GenerateError { + message: error_json + .get("message") + .and_then(|v| v.as_str()) + .unwrap_or("") + .to_string(), + http_status_code: error_json + .get("http_status_code") + .and_then(|v| v.as_str()) + .unwrap_or("500") + .to_string(), + details: error_json + .get("details") + .and_then(|v| v.as_str()) + .unwrap_or("") + .to_string(), + }; + proto_response.response = Some(proto::generate_response::Response::Error(error)); + true + } else { + error_msg = format!( + "Chunk JSON must contain 'chunk', 'complete', or 'error' field: {}", + chunk_json + ); + has_error = true; + break; + }; + + // Convert proto chunk to OpenAI format + let result = RUNTIME.block_on(async { + super::grpc_converter::convert_proto_chunk_to_openai( + proto_response, + handle_ref, + &tokenizer, + &model, + &request_id, + created, + system_fingerprint.as_deref(), + ) + .await + }); + + match result { + Ok(Some(openai_response)) => { + results.push(openai_response); + } + Ok(None) => { + // Empty response, skip + } + Err(e) => { + error_msg = format!("Postprocessing failed for chunk: {}", e); + has_error = true; + break; + } + } + } + + if has_error { + set_error_message(error_out, &error_msg); + return SglErrorCode::ParsingError; + } + + // Serialize results to JSON array + let results_json = match serde_json::to_string(&results) { + Ok(s) => s, + Err(e) => { + set_error_message( + error_out, + &format!("Failed to serialize results JSON array: {}", e), + ); + return SglErrorCode::ParsingError; + } + }; + + let results_cstr = match CString::new(results_json) { + Ok(s) => s, + Err(e) => { + set_error_message(error_out, &format!("Failed to create C string: {}", e)); + return SglErrorCode::MemoryError; + } + }; + + *openai_chunks_json_array_out = results_cstr.into_raw(); + *chunks_count_out = results.len() as c_int; + + SglErrorCode::Success +} diff --git a/sgl-model-gateway/bindings/golang/src/preprocessor.rs b/sgl-model-gateway/bindings/golang/src/preprocessor.rs new file mode 100644 index 000000000000..1ec13de46af1 --- /dev/null +++ b/sgl-model-gateway/bindings/golang/src/preprocessor.rs @@ -0,0 +1,372 @@ +//! Preprocessing FFI functions for chat requests +//! +//! This module provides C-compatible functions for preprocessing chat completion requests: +//! - Apply chat_template to messages +//! - Tokenize the processed text +//! - Generate tool constraints +//! +//! 
These functions are designed to be called once per request, reducing FFI overhead.
+
+use std::ffi::{CStr, CString};
+use std::os::raw::{c_char, c_int, c_uint};
+use std::ptr;
+
+use sgl_model_gateway::tokenizer::create_tokenizer_from_file;
+use sgl_model_gateway::protocols::chat::ChatCompletionRequest;
+use sgl_model_gateway::routers::grpc::utils::{process_chat_messages, generate_tool_constraints};
+
+use super::error::{SglErrorCode, set_error_message};
+use super::memory::{sgl_free_string, sgl_free_token_ids};
+use super::tokenizer::TokenizerHandle;
+
+/// Handle for preprocessed request
+#[repr(C)]
+pub struct PreprocessedRequestHandle {
+    pub(crate) prompt_text: CString,
+    pub(crate) token_ids: Vec<u32>,
+    pub(crate) tool_constraints_json: Option<CString>,
+    pub(crate) prompt_tokens: i32,
+}
+
+/// Preprocess a chat completion request
+///
+/// This function:
+/// 1. Applies chat_template to messages
+/// 2. Tokenizes the processed text
+/// 3. Generates tool constraints (if tools are present)
+///
+/// # Arguments
+/// * `request_json` - OpenAI ChatCompletionRequest as JSON string
+/// * `tokenizer_path` - Path to tokenizer directory
+/// * `prompt_text_out` - Pointer to receive prompt text (C string, must be freed with sgl_free_string)
+/// * `token_ids_out` - Pointer to receive token IDs array (must be freed with sgl_free_token_ids)
+/// * `token_ids_len_out` - Pointer to receive token IDs array length
+/// * `tool_constraints_json_out` - Optional pointer to receive tool constraints JSON (must be freed with sgl_free_string)
+/// * `prompt_tokens_out` - Pointer to receive prompt token count
+/// * `error_out` - Optional pointer to receive error message
+///
+/// # Returns
+/// * SglErrorCode::Success on success, error code on failure
+#[no_mangle]
+pub unsafe extern "C" fn sgl_preprocess_chat_request(
+    request_json: *const c_char,
+    tokenizer_path: *const c_char,
+    prompt_text_out: *mut *mut c_char,
+    token_ids_out: *mut *mut c_uint,
+    token_ids_len_out: *mut usize,
+    tool_constraints_json_out: *mut *mut c_char,
+    prompt_tokens_out: *mut c_int,
+    error_out: *mut *mut c_char,
+) -> SglErrorCode {
+    if request_json.is_null()
+        || tokenizer_path.is_null()
+        || prompt_text_out.is_null()
+        || token_ids_out.is_null()
+        || token_ids_len_out.is_null()
+        || prompt_tokens_out.is_null()
+    {
+        set_error_message(error_out, "Invalid arguments: null pointer");
+        return SglErrorCode::InvalidArgument;
+    }
+
+    // Parse input strings
+    let request_str = match CStr::from_ptr(request_json).to_str() {
+        Ok(s) => s,
+        Err(_) => {
+            set_error_message(error_out, "Invalid UTF-8 in request_json");
+            return SglErrorCode::InvalidArgument;
+        }
+    };
+
+    let tokenizer_path_str = match CStr::from_ptr(tokenizer_path).to_str() {
+        Ok(s) => s,
+        Err(_) => {
+            set_error_message(error_out, "Invalid UTF-8 in tokenizer_path");
+            return SglErrorCode::InvalidArgument;
+        }
+    };
+
+    // Parse ChatCompletionRequest
+    let chat_request: ChatCompletionRequest = match serde_json::from_str(request_str) {
+        Ok(req) => req,
+        Err(e) => {
+            set_error_message(error_out, &format!("Failed to parse request JSON: {}", e));
+            return SglErrorCode::ParsingError;
+        }
+    };
+
+    // Create tokenizer
+    let tokenizer = match create_tokenizer_from_file(tokenizer_path_str) {
+        Ok(t) => t,
+        Err(e) => {
+            set_error_message(error_out, &format!("Failed to create tokenizer: {}", e));
+            return SglErrorCode::TokenizationError;
+        }
+    };
+
+    // Process chat messages (apply chat_template)
+    let processed_messages = match process_chat_messages(&chat_request, tokenizer.as_ref()) {
+        Ok(msgs) => msgs,
+        Err(e) => {
+            set_error_message(error_out, &format!("Failed to process chat messages: {}", e));
+            return SglErrorCode::ParsingError;
+        }
+    };
+
+    // Tokenize the processed text
+    let encoding = match tokenizer.encode(&processed_messages.text) {
+        Ok(enc) => enc,
+        Err(e) => {
+            set_error_message(error_out, &format!("Tokenization failed: {}", e));
+            return SglErrorCode::TokenizationError;
+        }
+    };
+
+    let token_ids_vec: Vec<i32> = encoding
+        .token_ids()
+        .iter()
+        .map(|&id| id as i32)
+        .collect();
+
+    let prompt_tokens = token_ids_vec.len() as i32;
+
+    // Generate tool constraints if tools are present
+    let tool_constraints_json = if let Some(tools) = chat_request.tools.as_ref() {
+        match generate_tool_constraints(tools, &chat_request.tool_choice, &chat_request.model) {
+            Ok(Some(constraints)) => {
+                match serde_json::to_string(&constraints) {
+                    Ok(json_str) => Some(CString::new(json_str).unwrap()),
+                    Err(e) => {
+                        set_error_message(
+                            error_out,
+                            &format!("Failed to serialize tool constraints: {}", e),
+                        );
+                        return SglErrorCode::ParsingError;
+                    }
+                }
+            }
+            Ok(None) => None,
+            Err(e) => {
+                set_error_message(
+                    error_out,
+                    &format!("Failed to generate tool constraints: {}", e),
+                );
+                return SglErrorCode::ParsingError;
+            }
+        }
+    } else {
+        None
+    };
+
+    // Allocate memory for outputs
+    let prompt_text_cstr = match CString::new(processed_messages.text) {
+        Ok(s) => s,
+        Err(e) => {
+            set_error_message(error_out, &format!("Failed to create C string: {}", e));
+            return SglErrorCode::MemoryError;
+        }
+    };
+
+    let token_ids_len = token_ids_vec.len();
+    // Convert i32 to u32 for token IDs (as expected by the memory management functions)
+    let token_ids_u32: Vec<u32> = token_ids_vec.iter().map(|&id| id as u32).collect();
+    let token_ids_ptr = if token_ids_u32.is_empty() {
+        ptr::null_mut()
+    } else {
+        let boxed = token_ids_u32.into_boxed_slice();
+        Box::into_raw(boxed) as *mut c_uint
+    };
+
+    // Set output values
+    *prompt_text_out = prompt_text_cstr.into_raw();
+    *token_ids_out = token_ids_ptr;
+    *token_ids_len_out = token_ids_len;
+    *prompt_tokens_out = prompt_tokens;
+
+    if !tool_constraints_json_out.is_null() {
+        if let Some(constraints) = tool_constraints_json {
+            *tool_constraints_json_out = constraints.into_raw();
+        } else {
+            *tool_constraints_json_out = ptr::null_mut();
+        }
+    }
+
+    SglErrorCode::Success
+}
+
+/// Preprocess a chat completion request using an existing tokenizer handle
+///
+/// This function is similar to sgl_preprocess_chat_request, but accepts a TokenizerHandle
+/// instead of creating a new tokenizer. This allows reusing a cached tokenizer instance,
+/// significantly reducing initialization overhead in concurrent scenarios.
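+///
+/// # Example
+///
+/// A minimal calling sketch (marked `ignore`): `handle` is assumed to come from
+/// this crate's tokenizer FFI, and the request body is a minimal illustration.
+///
+/// ```ignore
+/// use std::ffi::CString;
+/// use std::os::raw::{c_char, c_int, c_uint};
+/// use std::ptr;
+///
+/// let req = CString::new(r#"{"model":"m","messages":[{"role":"user","content":"hi"}]}"#).unwrap();
+/// let mut text: *mut c_char = ptr::null_mut();
+/// let mut ids: *mut c_uint = ptr::null_mut();
+/// let mut ids_len: usize = 0;
+/// let mut n_tokens: c_int = 0;
+/// let mut err: *mut c_char = ptr::null_mut();
+/// let rc = unsafe {
+///     sgl_preprocess_chat_request_with_tokenizer(
+///         req.as_ptr(), handle, &mut text, &mut ids, &mut ids_len,
+///         ptr::null_mut(), // tool-constraint output not requested here
+///         &mut n_tokens, &mut err,
+///     )
+/// };
+/// if matches!(rc, SglErrorCode::Success) {
+///     // All outputs are caller-owned; release them in one call.
+///     unsafe { sgl_preprocessed_request_free(text, ids, ids_len, ptr::null_mut()) };
+/// }
+/// ```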
+/// +/// # Arguments +/// * `request_json` - OpenAI ChatCompletionRequest as JSON string +/// * `tokenizer_handle` - Existing tokenizer handle (must be valid) +/// * `prompt_text_out` - Pointer to receive prompt text (C string, must be freed with sgl_free_string) +/// * `token_ids_out` - Pointer to receive token IDs array (must be freed with sgl_free_token_ids) +/// * `token_ids_len_out` - Pointer to receive token IDs array length +/// * `tool_constraints_json_out` - Optional pointer to receive tool constraints JSON (must be freed with sgl_free_string) +/// * `prompt_tokens_out` - Pointer to receive prompt token count +/// * `error_out` - Optional pointer to receive error message +/// +/// # Returns +/// * SglErrorCode::Success on success, error code on failure +#[no_mangle] +pub unsafe extern "C" fn sgl_preprocess_chat_request_with_tokenizer( + request_json: *const c_char, + tokenizer_handle: *mut TokenizerHandle, + prompt_text_out: *mut *mut c_char, + token_ids_out: *mut *mut c_uint, + token_ids_len_out: *mut usize, + tool_constraints_json_out: *mut *mut c_char, + prompt_tokens_out: *mut c_int, + error_out: *mut *mut c_char, +) -> SglErrorCode { + if request_json.is_null() + || tokenizer_handle.is_null() + || prompt_text_out.is_null() + || token_ids_out.is_null() + || token_ids_len_out.is_null() + || prompt_tokens_out.is_null() + { + set_error_message(error_out, "Invalid arguments: null pointer"); + return SglErrorCode::InvalidArgument; + } + + // Parse input string + let request_str = match CStr::from_ptr(request_json).to_str() { + Ok(s) => s, + Err(_) => { + set_error_message(error_out, "Invalid UTF-8 in request_json"); + return SglErrorCode::InvalidArgument; + } + }; + + // Parse ChatCompletionRequest + let chat_request: ChatCompletionRequest = match serde_json::from_str(request_str) { + Ok(req) => req, + Err(e) => { + set_error_message(error_out, &format!("Failed to parse request JSON: {}", e)); + return SglErrorCode::ParsingError; + } + }; + + // Use existing tokenizer from handle (no need to create new one!) 
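+    // Safety: the caller guarantees `tokenizer_handle` came from this crate's
+    // tokenizer FFI, is non-null (checked above), and stays valid for this call.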
+    let handle_ref = &*tokenizer_handle;
+    let tokenizer = &handle_ref.tokenizer;
+
+    // Process chat messages (apply chat_template)
+    let processed_messages = match process_chat_messages(&chat_request, tokenizer.as_ref()) {
+        Ok(msgs) => msgs,
+        Err(e) => {
+            set_error_message(error_out, &format!("Failed to process chat messages: {}", e));
+            return SglErrorCode::ParsingError;
+        }
+    };
+
+    // Tokenize the processed text
+    let encoding = match tokenizer.encode(&processed_messages.text) {
+        Ok(enc) => enc,
+        Err(e) => {
+            set_error_message(error_out, &format!("Tokenization failed: {}", e));
+            return SglErrorCode::TokenizationError;
+        }
+    };
+
+    let token_ids_vec: Vec<i32> = encoding
+        .token_ids()
+        .iter()
+        .map(|&id| id as i32)
+        .collect();
+
+    let prompt_tokens = token_ids_vec.len() as i32;
+
+    // Generate tool constraints if tools are present
+    let tool_constraints_json = if let Some(tools) = chat_request.tools.as_ref() {
+        match generate_tool_constraints(tools, &chat_request.tool_choice, &chat_request.model) {
+            Ok(Some(constraints)) => {
+                match serde_json::to_string(&constraints) {
+                    Ok(json_str) => Some(CString::new(json_str).unwrap()),
+                    Err(e) => {
+                        set_error_message(
+                            error_out,
+                            &format!("Failed to serialize tool constraints: {}", e),
+                        );
+                        return SglErrorCode::ParsingError;
+                    }
+                }
+            }
+            Ok(None) => None,
+            Err(e) => {
+                set_error_message(
+                    error_out,
+                    &format!("Failed to generate tool constraints: {}", e),
+                );
+                return SglErrorCode::ParsingError;
+            }
+        }
+    } else {
+        None
+    };
+
+    // Allocate memory for outputs
+    let prompt_text_cstr = match CString::new(processed_messages.text) {
+        Ok(s) => s,
+        Err(e) => {
+            set_error_message(error_out, &format!("Failed to create C string: {}", e));
+            return SglErrorCode::MemoryError;
+        }
+    };
+
+    let token_ids_len = token_ids_vec.len();
+    // Convert i32 to u32 for token IDs (as expected by the memory management functions)
+    let token_ids_u32: Vec<u32> = token_ids_vec.iter().map(|&id| id as u32).collect();
+    let token_ids_ptr = if token_ids_u32.is_empty() {
+        ptr::null_mut()
+    } else {
+        let boxed = token_ids_u32.into_boxed_slice();
+        Box::into_raw(boxed) as *mut c_uint
+    };
+
+    // Set output values
+    *prompt_text_out = prompt_text_cstr.into_raw();
+    *token_ids_out = token_ids_ptr;
+    *token_ids_len_out = token_ids_len;
+    *prompt_tokens_out = prompt_tokens;
+
+    if !tool_constraints_json_out.is_null() {
+        if let Some(constraints) = tool_constraints_json {
+            *tool_constraints_json_out = constraints.into_raw();
+        } else {
+            *tool_constraints_json_out = ptr::null_mut();
+        }
+    }
+
+    SglErrorCode::Success
+}
+
+/// Free a preprocessed request handle (cleanup function)
+///
+/// This function frees the memory allocated by sgl_preprocess_chat_request
+/// (and by its _with_tokenizer variant). It should be called once the
+/// preprocessed data is no longer needed.
+#[no_mangle]
+pub unsafe extern "C" fn sgl_preprocessed_request_free(
+    prompt_text: *mut c_char,
+    token_ids: *mut c_uint,
+    token_ids_len: usize,
+    tool_constraints_json: *mut c_char,
+) {
+    if !prompt_text.is_null() {
+        sgl_free_string(prompt_text);
+    }
+
+    if !token_ids.is_null() && token_ids_len > 0 {
+        sgl_free_token_ids(token_ids, token_ids_len);
+    }
+
+    if !tool_constraints_json.is_null() {
+        sgl_free_string(tool_constraints_json);
+    }
+}
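+
+// Illustrative sketch of the expected call/free sequence; it exercises only the
+// error path, since a real tokenizer directory may not be available in unit
+// tests. The request body below is a minimal assumed shape, not a fixture.
+#[cfg(test)]
+mod ffi_call_sketch {
+    use super::*;
+    use std::ffi::CString;
+    use std::ptr;
+
+    #[test]
+    fn preprocess_rejects_missing_tokenizer_path() {
+        let req = CString::new(r#"{"model":"m","messages":[{"role":"user","content":"hi"}]}"#).unwrap();
+        let path = CString::new("/nonexistent/tokenizer").unwrap();
+        let mut text: *mut c_char = ptr::null_mut();
+        let mut ids: *mut c_uint = ptr::null_mut();
+        let mut ids_len: usize = 0;
+        let mut n_tokens: c_int = 0;
+        let mut err: *mut c_char = ptr::null_mut();
+        let rc = unsafe {
+            sgl_preprocess_chat_request(
+                req.as_ptr(), path.as_ptr(), &mut text, &mut ids, &mut ids_len,
+                ptr::null_mut(), &mut n_tokens, &mut err,
+            )
+        };
+        // Either request parsing or tokenizer creation fails; both are non-Success.
+        assert!(!matches!(rc, SglErrorCode::Success));
+        if !err.is_null() {
+            unsafe { sgl_free_string(err) };
+        }
+    }
+}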