diff --git a/pkg/plugins/gateway/gateway.go b/pkg/plugins/gateway/gateway.go index ac0c36223..ecef78389 100644 --- a/pkg/plugins/gateway/gateway.go +++ b/pkg/plugins/gateway/gateway.go @@ -37,6 +37,8 @@ import ( "github.com/vllm-project/aibrix/pkg/metrics" routing "github.com/vllm-project/aibrix/pkg/plugins/gateway/algorithms" "github.com/vllm-project/aibrix/pkg/plugins/gateway/ratelimiter" + "github.com/vllm-project/aibrix/pkg/plugins/gateway/scheduler" + "github.com/vllm-project/aibrix/pkg/plugins/gateway/scheduler/sessioninfo" "github.com/vllm-project/aibrix/pkg/types" "github.com/vllm-project/aibrix/pkg/utils" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" @@ -44,6 +46,39 @@ import ( gatewayapi "sigs.k8s.io/gateway-api/pkg/client/clientset/versioned" ) +// requestState represents the state of a single request processing flow +type requestState int + +const ( + stateAwaitingHeaders requestState = iota + stateAwaitingBody + stateAwaitingDecision + stateForwarding + stateDone +) + +// perRequestState holds all the state for a single Process() invocation +type perRequestState struct { + currentState requestState + sessionID string + requestID string + user utils.User + rpm int64 + model string + routerCtx *types.RoutingContext + stream bool + traceTerm int64 + completed bool + isRespError bool + respErrorCode int + + // For timing and scheduling + requestStartTime time.Time + submissionTime time.Time + dispatchTime time.Time // When scheduler granted permission + schedulingDecision *scheduler.Decision +} + const ( defaultAIBrixNamespace = "aibrix-system" ) @@ -56,6 +91,13 @@ type Server struct { requestCountTracker map[string]int cache cache.Cache metricsServer *metrics.Server + + // Scheduler and session management + scheduler scheduler.Scheduler + sessionCache *sessioninfo.MutexSessionCache + + // Cleanup function for session cache + sessionCleanupStop func() } func NewServer(redisClient *redis.Client, client kubernetes.Interface, gatewayClient gatewayapi.Interface) *Server { @@ -68,6 +110,13 @@ func NewServer(redisClient *redis.Client, client kubernetes.Interface, gatewayCl // Initialize the routers routing.Init() + // Initialize session cache and scheduler + sessionCache := sessioninfo.NewMutexSessionCache() + sched := scheduler.NewScheduler(client, sessionCache, c) + + // Start session cleanup routine (cleanup every 5 minutes, timeout after 30 minutes) + sessionCleanupStop := sessionCache.StartCleanupRoutine(5*time.Minute, 30*time.Minute) + return &Server{ redisClient: redisClient, ratelimiter: r, @@ -76,10 +125,19 @@ func NewServer(redisClient *redis.Client, client kubernetes.Interface, gatewayCl requestCountTracker: map[string]int{}, cache: c, metricsServer: nil, + scheduler: sched, + sessionCache: sessionCache, + sessionCleanupStop: sessionCleanupStop, } } +// Process delegates to the state machine implementation func (s *Server) Process(srv extProcPb.ExternalProcessor_ProcessServer) error { + return s.ProcessStateMachine(srv) +} + +// ProcessLegacy is the original implementation kept for reference +func (s *Server) ProcessLegacy(srv extProcPb.ExternalProcessor_ProcessServer) error { var user utils.User var rpm, traceTerm int64 var respErrorCode int @@ -227,11 +285,29 @@ func (s *Server) StartMetricsServer(addr string) error { } func (s *Server) Shutdown() { + klog.InfoS("Starting graceful shutdown of Gateway Server") + + // Stop scheduler first to prevent new jobs + if s.scheduler != nil { + klog.InfoS("Stopping scheduler") + s.scheduler.Stop() + } + + // Stop session cache cleanup 
routine + if s.sessionCleanupStop != nil { + klog.InfoS("Stopping session cache cleanup routine") + s.sessionCleanupStop() + } + + // Stop metrics server if s.metricsServer != nil { + klog.InfoS("Stopping metrics server") if err := s.metricsServer.Stop(); err != nil { klog.ErrorS(err, "Error stopping metrics server") } } + + klog.InfoS("Gateway Server shutdown complete") } func (s *Server) responseErrorProcessing(ctx context.Context, resp *extProcPb.ProcessingResponse, respErrorCode int, diff --git a/pkg/plugins/gateway/gateway_req_body.go b/pkg/plugins/gateway/gateway_req_body.go index 413c1d0fc..8b9b47d04 100644 --- a/pkg/plugins/gateway/gateway_req_body.go +++ b/pkg/plugins/gateway/gateway_req_body.go @@ -34,7 +34,14 @@ import ( func (s *Server) HandleRequestBody(ctx context.Context, requestID string, req *extProcPb.ProcessingRequest, user utils.User) (*extProcPb.ProcessingResponse, string, *types.RoutingContext, bool, int64) { var term int64 // Identify the trace window - routingCtx, _ := ctx.(*types.RoutingContext) + routingCtx, ok := ctx.(*types.RoutingContext) + if !ok || routingCtx == nil { + klog.ErrorS(nil, "CRITICAL: context is not RoutingContext or is nil", "requestID", requestID, "contextType", fmt.Sprintf("%T", ctx)) + return generateErrorResponse(envoyTypePb.StatusCode_InternalServerError, + []*configPb.HeaderValueOption{{Header: &configPb.HeaderValue{ + Key: HeaderErrorRouting, RawValue: []byte("true")}}}, + "internal routing context error"), "", nil, false, term + } requestPath := routingCtx.ReqPath routingAlgorithm := routingCtx.Algorithm @@ -66,6 +73,15 @@ func (s *Server) HandleRequestBody(ctx context.Context, requestID string, req *e fmt.Sprintf("error on getting pods for model %s", model)), model, routingCtx, stream, term } + // Check if scheduler is enabled - if so, defer routing to the scheduler + if s.scheduler != nil { + // With scheduler enabled, we don't perform routing here + // Just validate the model exists and return nil to let Process handle scheduling + klog.InfoS("request body processed, deferring to scheduler", "requestID", requestID, "requestPath", requestPath, "model", model, "stream", stream) + return nil, model, routingCtx, stream, term + } + + // Legacy routing logic (when scheduler is not enabled) headers := []*configPb.HeaderValueOption{} if routingAlgorithm == routing.RouterNotSet { if err := s.validateHTTPRouteStatus(ctx, model); err != nil { diff --git a/pkg/plugins/gateway/gateway_req_headers.go b/pkg/plugins/gateway/gateway_req_headers.go index 026c59f8b..38a07ba58 100644 --- a/pkg/plugins/gateway/gateway_req_headers.go +++ b/pkg/plugins/gateway/gateway_req_headers.go @@ -23,6 +23,7 @@ import ( "k8s.io/klog/v2" configPb "github.com/envoyproxy/go-control-plane/envoy/config/core/v3" + extProcFilterPb "github.com/envoyproxy/go-control-plane/envoy/extensions/filters/http/ext_proc/v3" extProcPb "github.com/envoyproxy/go-control-plane/envoy/service/ext_proc/v3" envoyTypePb "github.com/envoyproxy/go-control-plane/envoy/type/v3" routing "github.com/vllm-project/aibrix/pkg/plugins/gateway/algorithms" @@ -86,23 +87,75 @@ func (s *Server) HandleRequestHeaders(ctx context.Context, requestID string, req routingCtx.ReqPath = requestPath routingCtx.ReqHeaders = reqHeaders - return &extProcPb.ProcessingResponse{ - Response: &extProcPb.ProcessingResponse_RequestHeaders{ - RequestHeaders: &extProcPb.HeadersResponse{ - Response: &extProcPb.CommonResponse{ - HeaderMutation: &extProcPb.HeaderMutation{ - SetHeaders: []*configPb.HeaderValueOption{ - { - Header: 
&configPb.HeaderValue{ - Key: HeaderWentIntoReqHeaders, - RawValue: []byte("true"), - }, - }, + // For legacy Process function (non-state machine), we need to handle routing here + // For state machine version, this function is only used to extract headers info + if s.scheduler == nil { + // Legacy mode: complete routing in RequestHeaders phase + klog.InfoS("legacy mode: completing processing in RequestHeaders phase", "requestID", requestID) + + // early reject the request if model doesn't exist. + model := "default" // TODO: Extract model from headers if available + if !s.cache.HasModel(model) { + klog.ErrorS(nil, "model doesn't exist in cache", "requestID", requestID, "model", model) + return generateErrorResponse(envoyTypePb.StatusCode_BadRequest, + []*configPb.HeaderValueOption{{Header: &configPb.HeaderValue{ + Key: HeaderErrorNoModelBackends, RawValue: []byte(model)}}}, + "model does not exist"), user, rpm, routingCtx + } + + // early reject if no pods are ready to accept request for a model + podsArr, err := s.cache.ListPodsByModel(model) + if err != nil || podsArr == nil || utils.CountRoutablePods(podsArr.All()) == 0 { + klog.ErrorS(err, "no ready pod available", "requestID", requestID, "model", model) + return generateErrorResponse(envoyTypePb.StatusCode_ServiceUnavailable, + []*configPb.HeaderValueOption{{Header: &configPb.HeaderValue{ + Key: HeaderErrorNoModelBackends, RawValue: []byte("true")}}}, + "no ready pods available"), user, rpm, routingCtx + } + + headers := []*configPb.HeaderValueOption{} + if routingAlgorithm == routing.RouterNotSet { + if err := s.validateHTTPRouteStatus(ctx, model); err != nil { + return buildErrorResponse(envoyTypePb.StatusCode_ServiceUnavailable, err.Error(), HeaderErrorRouting, "true"), user, rpm, routingCtx + } + headers = buildEnvoyProxyHeaders(headers, HeaderModel, model) + klog.InfoS("request start", "requestID", requestID, "requestPath", routingCtx.ReqPath, "model", model) + } else { + targetPodIP, err := s.selectTargetPod(routingCtx, podsArr) + if targetPodIP == "" || err != nil { + klog.ErrorS(err, "failed to select target pod", "requestID", requestID, "routingStrategy", routingAlgorithm, "model", model, "routingDuration", routingCtx.GetRoutingDelay()) + return generateErrorResponse( + envoyTypePb.StatusCode_ServiceUnavailable, + []*configPb.HeaderValueOption{{Header: &configPb.HeaderValue{ + Key: HeaderErrorRouting, RawValue: []byte("true")}}}, + "error on selecting target pod"), user, rpm, routingCtx + } + headers = buildEnvoyProxyHeaders(headers, + HeaderRoutingStrategy, string(routingAlgorithm), + HeaderTargetPod, targetPodIP, + "X-Request-Id", routingCtx.RequestID) + klog.InfoS("request start", "requestID", requestID, "requestPath", routingCtx.ReqPath, "routingAlgorithm", routingAlgorithm, "targetPodIP", targetPodIP, "routingDuration", routingCtx.GetRoutingDelay()) + } + + return &extProcPb.ProcessingResponse{ + Response: &extProcPb.ProcessingResponse_RequestHeaders{ + RequestHeaders: &extProcPb.HeadersResponse{ + Response: &extProcPb.CommonResponse{ + HeaderMutation: &extProcPb.HeaderMutation{ + SetHeaders: headers, }, + ClearRouteCache: true, }, - ClearRouteCache: true, }, }, - }, - }, user, rpm, routingCtx + // Don't request RequestBody in legacy mode + ModeOverride: &extProcFilterPb.ProcessingMode{ + RequestBodyMode: extProcFilterPb.ProcessingMode_NONE, + }, + }, user, rpm, routingCtx + } + + // State machine mode: just return basic info, don't do routing here + klog.InfoS("state machine mode: headers processed, waiting for body", 
"requestID", requestID) + return nil, user, rpm, routingCtx } diff --git a/pkg/plugins/gateway/gateway_scheduler_integration_test.go b/pkg/plugins/gateway/gateway_scheduler_integration_test.go new file mode 100644 index 000000000..0f06f4ae1 --- /dev/null +++ b/pkg/plugins/gateway/gateway_scheduler_integration_test.go @@ -0,0 +1,417 @@ +/* +Copyright 2025 The Aibrix Team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package gateway + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/vllm-project/aibrix/pkg/cache" + "github.com/vllm-project/aibrix/pkg/plugins/gateway/scheduler" + "github.com/vllm-project/aibrix/pkg/plugins/gateway/scheduler/sessioninfo" + v1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/client-go/kubernetes/fake" +) + +func TestSchedulerComponents_Initialization(t *testing.T) { + // Test session cache initialization + sessionCache := sessioninfo.NewMutexSessionCache() + assert.NotNil(t, sessionCache, "Session cache should be initialized") + + // Test that we can create and use session cache + cst, waitTime := sessionCache.GetOrCreateForScheduler("test-session") + assert.GreaterOrEqual(t, cst.Nanoseconds(), int64(0), "CST should be non-negative") + assert.GreaterOrEqual(t, waitTime.Nanoseconds(), int64(0), "Wait time should be non-negative") +} + +func TestExtractSessionID(t *testing.T) { + tests := []struct { + name string + requestID string + requestPath string + requestBody []byte + headers map[string]string + expected string + }{ + { + name: "session ID from header (lowercase)", + requestID: "req-123", + requestPath: "/v1/chat/completions", + requestBody: []byte(`{"model":"test-model","messages":[{"role":"user","content":"hello"}]}`), + headers: map[string]string{"x-session-id": "session-456"}, + expected: "session-456", + }, + { + name: "session ID from header (uppercase)", + requestID: "req-123", + requestPath: "/v1/chat/completions", + requestBody: []byte(`{"model":"test-model","messages":[{"role":"user","content":"hello"}]}`), + headers: map[string]string{"X-Session-ID": "session-789"}, + expected: "session-789", + }, + // COMMENTED OUT: Body-based session ID extraction is disabled + // { + // name: "session ID from request body", + // requestID: "req-123", + // requestPath: "/v1/chat/completions", + // requestBody: []byte(`{"model":"test-model","session_id":"session-body-123","messages":[{"role":"user","content":"hello"}]}`), + // headers: map[string]string{}, + // expected: "session-body-123", + // }, + { + name: "session ID from request body - now falls back to request ID", + requestID: "req-123", + requestPath: "/v1/chat/completions", + requestBody: []byte(`{"model":"test-model","session_id":"session-body-123","messages":[{"role":"user","content":"hello"}]}`), + headers: map[string]string{}, + expected: "req-123", // MODIFIED: Now falls back to requestID since body parsing is disabled + }, + { + name: "fallback to request ID", + requestID: "req-fallback", + requestPath: "/v1/chat/completions", + requestBody: 
[]byte(`{"model":"test-model","messages":[{"role":"user","content":"hello"}]}`), + headers: map[string]string{}, + expected: "req-fallback", + }, + { + name: "header takes precedence over body", + requestID: "req-123", + requestPath: "/v1/chat/completions", + requestBody: []byte(`{"model":"test-model","session_id":"session-body-123","messages":[{"role":"user","content":"hello"}]}`), + headers: map[string]string{"x-session-id": "session-header-456"}, + expected: "session-header-456", + }, + { + name: "invalid JSON body falls back to request ID", + requestID: "req-invalid", + requestPath: "/v1/chat/completions", + requestBody: []byte(`{invalid json`), + headers: map[string]string{}, + expected: "req-invalid", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := extractSessionID(tt.requestID, tt.requestPath, tt.requestBody, tt.headers) + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestExtractSessionIDFromHeaders(t *testing.T) { + tests := []struct { + name string + requestID string + headers map[string]string + expected string + }{ + { + name: "session ID from header (lowercase)", + requestID: "req-123", + headers: map[string]string{"x-session-id": "session-456"}, + expected: "session-456", + }, + { + name: "session ID from header (uppercase)", + requestID: "req-123", + headers: map[string]string{"X-Session-ID": "session-789"}, + expected: "session-789", + }, + { + name: "no session ID in headers - fallback to requestID", + requestID: "req-fallback", + headers: map[string]string{}, + expected: "req-fallback", + }, + { + name: "empty session ID in headers - fallback to requestID", + requestID: "req-empty", + headers: map[string]string{"x-session-id": ""}, + expected: "req-empty", + }, + { + name: "lowercase takes precedence when both present", + requestID: "req-123", + headers: map[string]string{"x-session-id": "session-lower", "X-Session-ID": "session-upper"}, + expected: "session-lower", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := extractSessionIDFromHeaders(tt.requestID, tt.headers) + assert.Equal(t, tt.expected, result) + }) + } +} + +// TestExtractSessionIDFromBody tests the extractSessionID function behavior +// NOTE: Body parsing is now disabled, so these tests verify headers-only behavior +// and proper fallback to requestID when no headers are present +func TestExtractSessionIDFromBody(t *testing.T) { + tests := []struct { + name string + requestID string + requestPath string + requestBody []byte + headers map[string]string + expected string + }{ + { + name: "header session_id takes priority over body", + requestID: "req-123", + requestPath: "/v1/chat/completions", + requestBody: []byte(`{"model":"test-model","session_id":"session-body-123","messages":[{"role":"user","content":"hello"}]}`), + headers: map[string]string{"x-session-id": "session-header-456"}, + expected: "session-header-456", // Headers are checked first in extractSessionID + }, + { + name: "body session_id when no header session_id - now falls back to request ID", + requestID: "req-123", + requestPath: "/v1/chat/completions", + requestBody: []byte(`{"model":"test-model","session_id":"session-body-123","messages":[{"role":"user","content":"hello"}]}`), + headers: map[string]string{}, + expected: "req-123", // MODIFIED: Now falls back to requestID since body parsing is disabled + }, + { + name: "header session_id when no body session_id", + requestID: "req-123", + requestPath: "/v1/chat/completions", + requestBody: 
[]byte(`{"model":"test-model","messages":[{"role":"user","content":"hello"}]}`), + headers: map[string]string{"x-session-id": "session-header-456"}, + expected: "session-header-456", + }, + { + name: "fallback to requestID when no session_id anywhere", + requestID: "req-fallback", + requestPath: "/v1/chat/completions", + requestBody: []byte(`{"model":"test-model","messages":[{"role":"user","content":"hello"}]}`), + headers: map[string]string{}, + expected: "req-fallback", + }, + { + name: "invalid JSON body falls back to header session_id", + requestID: "req-invalid", + requestPath: "/v1/chat/completions", + requestBody: []byte(`{invalid json`), + headers: map[string]string{"x-session-id": "session-header-789"}, + expected: "session-header-789", + }, + { + name: "invalid JSON body with no header falls back to requestID", + requestID: "req-invalid", + requestPath: "/v1/chat/completions", + requestBody: []byte(`{invalid json`), + headers: map[string]string{}, + expected: "req-invalid", + }, + { + name: "non-chat-completion path ignores body", + requestID: "req-123", + requestPath: "/v1/models", + requestBody: []byte(`{"session_id":"session-body-123"}`), + headers: map[string]string{"x-session-id": "session-header-456"}, + expected: "session-header-456", + }, + { + name: "empty body session_id falls back to header", + requestID: "req-123", + requestPath: "/v1/chat/completions", + requestBody: []byte(`{"model":"test-model","session_id":"","messages":[{"role":"user","content":"hello"}]}`), + headers: map[string]string{"x-session-id": "session-header-456"}, + expected: "session-header-456", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := extractSessionID(tt.requestID, tt.requestPath, tt.requestBody, tt.headers) + assert.Equal(t, tt.expected, result) + }) + } +} + +// TestExtractSessionID_HeadersOnlyBehavior specifically tests the headers-only behavior +// after disabling body parsing to ensure the modification works correctly +func TestExtractSessionID_HeadersOnlyBehavior(t *testing.T) { + tests := []struct { + name string + requestID string + headers map[string]string + requestBody []byte // This should be ignored + expected string + }{ + { + name: "headers-only: lowercase header with body present (body ignored)", + requestID: "req-123", + headers: map[string]string{"x-session-id": "session-from-header"}, + requestBody: []byte(`{"session_id":"session-from-body"}`), // Should be ignored + expected: "session-from-header", + }, + { + name: "headers-only: uppercase header with body present (body ignored)", + requestID: "req-456", + headers: map[string]string{"X-Session-ID": "session-from-header-upper"}, + requestBody: []byte(`{"session_id":"session-from-body"}`), // Should be ignored + expected: "session-from-header-upper", + }, + { + name: "headers-only: no header, body present (fallback to requestID)", + requestID: "req-789", + headers: map[string]string{}, + requestBody: []byte(`{"session_id":"session-from-body"}`), // Should be ignored + expected: "req-789", // Falls back to requestID + }, + { + name: "headers-only: empty header, body present (fallback to requestID)", + requestID: "req-empty", + headers: map[string]string{"x-session-id": ""}, + requestBody: []byte(`{"session_id":"session-from-body"}`), // Should be ignored + expected: "req-empty", // Falls back to requestID + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Test the main extractSessionID function (which now ignores body) + result := 
extractSessionID(tt.requestID, "/v1/chat/completions", tt.requestBody, tt.headers) + assert.Equal(t, tt.expected, result, "extractSessionID should ignore body and use headers-only logic") + + // Also test the headers-only function for consistency + headerResult := extractSessionIDFromHeaders(tt.requestID, tt.headers) + assert.Equal(t, tt.expected, headerResult, "extractSessionIDFromHeaders should match extractSessionID behavior") + }) + } +} + +func TestScheduler_CacheIntegration(t *testing.T) { + // Test that scheduler correctly integrates with cache when available + k8sClient := fake.NewSimpleClientset() + sessionCache := sessioninfo.NewMutexSessionCache() + + // Test with nil cache (fallback mode) + schedulerWithoutCache := scheduler.NewScheduler(k8sClient, sessionCache, nil) + assert.NotNil(t, schedulerWithoutCache, "Scheduler should be created even without cache") + + // Test that scheduler can be stopped gracefully + schedulerWithoutCache.Stop() + + // Give it a moment to stop + time.Sleep(10 * time.Millisecond) + + t.Log("Scheduler cache integration test completed") +} + +func TestScheduler_LoadAwarenessWithRealCache(t *testing.T) { + // This test would work if cache was properly initialized + // For now, we test the fallback behavior + + k8sClient := fake.NewSimpleClientset() + sessionCache := sessioninfo.NewMutexSessionCache() + + // Try to get cache (will fail in test environment) + cacheInstance, err := cache.Get() + if err != nil { + t.Logf("Cache not available in test environment: %v", err) + cacheInstance = nil + } + + // Create scheduler with cache (or nil) + sched := scheduler.NewScheduler(k8sClient, sessionCache, cacheInstance) + defer sched.Stop() + + // Verify scheduler was created successfully + assert.NotNil(t, sched, "Scheduler should be created") + + t.Log("Load awareness with real cache test completed") +} + +func TestScheduler_PodCapacityEstimation(t *testing.T) { + // Test pod capacity estimation logic + pods := []*v1.Pod{ + { + ObjectMeta: metav1.ObjectMeta{ + Name: "high-capacity-pod", + Namespace: "default", + Annotations: map[string]string{ + "aibrix.io/max-concurrent-requests": "200", + }, + }, + Status: v1.PodStatus{Phase: v1.PodRunning, PodIP: "1.1.1.1"}, + }, + { + ObjectMeta: metav1.ObjectMeta{ + Name: "medium-capacity-pod", + Namespace: "default", + Annotations: map[string]string{ + "aibrix.io/max-concurrent-requests": "100", + }, + }, + Status: v1.PodStatus{Phase: v1.PodRunning, PodIP: "2.2.2.2"}, + }, + { + ObjectMeta: metav1.ObjectMeta{ + Name: "default-capacity-pod", + Namespace: "default", + }, + Status: v1.PodStatus{Phase: v1.PodRunning, PodIP: "3.3.3.3"}, + }, + { + ObjectMeta: metav1.ObjectMeta{ + Name: "invalid-annotation-pod", + Namespace: "default", + Annotations: map[string]string{ + "aibrix.io/max-concurrent-requests": "invalid", + }, + }, + Status: v1.PodStatus{Phase: v1.PodRunning, PodIP: "4.4.4.4"}, + }, + } + + // Test that we can create pods with different capacity annotations + for i, pod := range pods { + assert.NotNil(t, pod, "Pod %d should not be nil", i) + assert.NotEmpty(t, pod.Name, "Pod %d should have a name", i) + + // Check annotation parsing logic + if pod.Annotations != nil { + if maxConcurrency, exists := pod.Annotations["aibrix.io/max-concurrent-requests"]; exists { + t.Logf("Pod %s has max-concurrent-requests: %s", pod.Name, maxConcurrency) + } + } + } + + t.Log("Pod capacity estimation test completed") +} + +func TestScheduler_Integration(t *testing.T) { + // Test that scheduler components work together + k8sClient := 
fake.NewSimpleClientset() + sessionCache := sessioninfo.NewMutexSessionCache() + + // This should not panic + assert.NotPanics(t, func() { + // Note: We can't easily test NewScheduler here because it starts background goroutines + // and requires proper cleanup. The actual integration is tested in the main test suite. + _ = sessionCache + _ = k8sClient + }) +} diff --git a/pkg/plugins/gateway/gateway_state_machine.go b/pkg/plugins/gateway/gateway_state_machine.go new file mode 100644 index 000000000..8bb2c6191 --- /dev/null +++ b/pkg/plugins/gateway/gateway_state_machine.go @@ -0,0 +1,472 @@ +/* +Copyright 2025 The Aibrix Team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package gateway + +import ( + "fmt" + "io" + "strconv" + "time" + + "github.com/google/uuid" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + "k8s.io/klog/v2" + + configPb "github.com/envoyproxy/go-control-plane/envoy/config/core/v3" + extProcFilterPb "github.com/envoyproxy/go-control-plane/envoy/extensions/filters/http/ext_proc/v3" + extProcPb "github.com/envoyproxy/go-control-plane/envoy/service/ext_proc/v3" + envoyTypePb "github.com/envoyproxy/go-control-plane/envoy/type/v3" + "github.com/vllm-project/aibrix/pkg/plugins/gateway/scheduler" + "github.com/vllm-project/aibrix/pkg/types" +) + +// ProcessStateMachine implements the state machine-based request processing with scheduler integration +func (s *Server) ProcessStateMachine(srv extProcPb.ExternalProcessor_ProcessServer) error { + ctx := srv.Context() + + // Initialize the state for this specific request stream + state := &perRequestState{ + currentState: stateAwaitingHeaders, + requestID: uuid.New().String(), + } + + // This channel will receive the scheduler's decision without blocking the main loop + // Buffer size of 1 is sufficient since we only expect one decision per request + decisionChan := make(chan *scheduler.Decision, 1) + + // Channel to receive messages from Envoy in a non-blocking way + // Buffer size of 10 to handle burst of messages in high load scenarios + msgChan := make(chan *extProcPb.ProcessingRequest, 10) + errChan := make(chan error, 1) + + klog.InfoS("new request stream started", "requestID", state.requestID) + klog.InfoS("DEBUG: line 57 reached", "requestID", state.requestID) + klog.InfoS("about to start message receive goroutine", "requestID", state.requestID) + klog.InfoS("DEBUG: line 59 reached", "requestID", state.requestID) + defer func() { + if r := recover(); r != nil { + klog.ErrorS(nil, "panic in Process function", "requestID", state.requestID, "panic", r) + } + klog.InfoS("request stream finished", "requestID", state.requestID) + + // CRITICAL: Ensure FinalizeJob is always called if a decision was made. 
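+ // Skipping this would leave the scheduler's in-flight accounting elevated and the
+ // session's cumulative service time (CST) stale after a client disconnect or stream error.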
+ if state.schedulingDecision != nil && !state.completed { + klog.InfoS("finalizing job due to unexpected stream termination", "requestID", state.requestID) + // Calculate approximate times + if !state.dispatchTime.IsZero() { + waitTime := state.dispatchTime.Sub(state.submissionTime) + executionTime := time.Since(state.dispatchTime) + inheritedCST := state.schedulingDecision.Job.InheritedCST + s.scheduler.FinalizeJob(state.sessionID, inheritedCST, executionTime, waitTime) + } + } + + s.cache.DoneRequestCount(state.routerCtx, state.requestID, state.model, state.traceTerm) + if state.routerCtx != nil { + state.routerCtx.Delete() + } + }() + + // Start goroutine to receive messages from Envoy + klog.InfoS("launching message receive goroutine", "requestID", state.requestID) + go func() { + klog.InfoS("starting message receive goroutine", "requestID", state.requestID) + for { + klog.InfoS("waiting for message from Envoy", "requestID", state.requestID) + req, err := srv.Recv() + klog.InfoS("received message from Envoy", "requestID", state.requestID, "hasError", err != nil) + + // IMPORTANT: Check for context cancellation *before* sending to channel + select { + case <-ctx.Done(): + klog.InfoS("context cancelled in receive goroutine", "requestID", state.requestID) + // The main loop will handle the error, just exit. + return + default: + } + + if err != nil { + klog.InfoS("sending error to error channel", "requestID", state.requestID, "error", err) + errChan <- err + return + } + klog.InfoS("sending message to message channel", "requestID", state.requestID, "messageType", fmt.Sprintf("%T", req.Request)) + msgChan <- req + } + }() + + klog.InfoS("entering main loop", "requestID", state.requestID, "initialState", state.currentState) + for state.currentState != stateDone { + klog.InfoS("main loop iteration", "requestID", state.requestID, "currentState", state.currentState) + select { + case <-ctx.Done(): + klog.InfoS("request context cancelled", "requestID", state.requestID) + return ctx.Err() + + case decision := <-decisionChan: + // Event: Scheduler has made a decision + if state.currentState != stateAwaitingDecision { + klog.ErrorS(nil, "received scheduler decision in unexpected state", "requestID", state.requestID, "state", state.currentState) + continue + } + + klog.InfoS("received scheduling decision", "requestID", state.requestID, "sessionID", state.sessionID) + state.schedulingDecision = decision + state.dispatchTime = time.Now() // Record when scheduler granted permission + + if decision.Err != nil { + // Handle scheduling failure + errResp := generateErrorResponse(envoyTypePb.StatusCode_ServiceUnavailable, + []*configPb.HeaderValueOption{{Header: &configPb.HeaderValue{ + Key: HeaderErrorRouting, RawValue: []byte("true")}}}, + "scheduler error") + if err := srv.Send(errResp); err != nil { + klog.ErrorS(err, "failed to send scheduler error response", "requestID", state.requestID) + } + state.currentState = stateDone + continue + } + + // We have permission to proceed. 
Handle the actual routing + resp := s.handleScheduledRequest(state.routerCtx, state) + if err := srv.Send(resp); err != nil { + klog.ErrorS(err, "failed to send routing response", "requestID", state.requestID) + state.currentState = stateDone + } else { + state.currentState = stateForwarding + } + + case req := <-msgChan: + // Event: Received a message from Envoy + klog.InfoS("received message in main loop", "requestID", state.requestID, "messageType", fmt.Sprintf("%T", req.Request)) + if err := s.handleEnvoyMessage(srv, req, state, decisionChan); err != nil { + klog.ErrorS(err, "failed to handle envoy message", "requestID", state.requestID) + return err + } + klog.InfoS("successfully handled envoy message", "requestID", state.requestID, "newState", state.currentState) + + case err := <-errChan: + // Event: Error receiving from Envoy + if err == io.EOF { + klog.InfoS("envoy stream closed", "requestID", state.requestID) + state.currentState = stateDone + } else { + klog.ErrorS(err, "stream recv error", "requestID", state.requestID) + return status.Errorf(codes.Unknown, "cannot receive stream request: %v", err) + } + } + } + + return nil +} + +// handleEnvoyMessage processes messages received from Envoy based on current state +func (s *Server) handleEnvoyMessage(srv extProcPb.ExternalProcessor_ProcessServer, req *extProcPb.ProcessingRequest, state *perRequestState, decisionChan chan<- *scheduler.Decision) error { + klog.InfoS("received message from Envoy", "requestID", state.requestID, "messageType", fmt.Sprintf("%T", req.Request)) + + switch v := req.Request.(type) { + case *extProcPb.ProcessingRequest_RequestHeaders: + klog.InfoS("processing request headers", "requestID", state.requestID) + return s.handleRequestHeaders(srv, req, state) + + case *extProcPb.ProcessingRequest_RequestBody: + klog.InfoS("processing request body", "requestID", state.requestID) + return s.handleRequestBody(srv, req, state, decisionChan) + + case *extProcPb.ProcessingRequest_ResponseHeaders: + return s.handleResponseHeaders(srv, req, state) + + case *extProcPb.ProcessingRequest_ResponseBody: + return s.handleResponseBody(srv, req, state) + + default: + klog.InfoS("unknown request type", "requestID", state.requestID, "type", v) + return nil + } +} + +// handleRequestHeaders processes RequestHeaders message +func (s *Server) handleRequestHeaders(srv extProcPb.ExternalProcessor_ProcessServer, req *extProcPb.ProcessingRequest, state *perRequestState) error { + if state.currentState != stateAwaitingHeaders { + klog.ErrorS(nil, "received RequestHeaders in unexpected state", "requestID", state.requestID, "state", state.currentState) + return nil + } + + state.requestStartTime = time.Now() + _, user, rpm, routerCtx := s.HandleRequestHeaders(srv.Context(), state.requestID, req) + state.user, state.rpm, state.routerCtx = user, rpm, routerCtx + + // DEBUG: Check if routerCtx is nil after HandleRequestHeaders + klog.InfoS("DEBUG: HandleRequestHeaders returned", "requestID", state.requestID, "routerCtxIsNil", routerCtx == nil) + + // Extract temporary Session ID from headers only (no SessionCache interaction yet) + if routerCtx != nil { + state.sessionID = extractSessionIDFromHeaders(state.requestID, routerCtx.ReqHeaders) + klog.InfoS("extracted temporary session ID from headers", "requestID", state.requestID, "tempSessionID", state.sessionID) + } + + // CRITICAL: Send response to Envoy telling it to send us the RequestBody + resp := s.buildRequestHeadersResponse() + if err := srv.Send(resp); err != nil { + klog.ErrorS(err, 
"failed to send request headers response", "requestID", state.requestID) + state.currentState = stateDone + return err + } + + // Now transition to wait for the body + state.currentState = stateAwaitingBody + klog.InfoS("sent headers response, waiting for body", "requestID", state.requestID) + return nil +} + +// handleRequestBody processes RequestBody message and initiates scheduling +func (s *Server) handleRequestBody(srv extProcPb.ExternalProcessor_ProcessServer, req *extProcPb.ProcessingRequest, state *perRequestState, decisionChan chan<- *scheduler.Decision) error { + klog.InfoS("handleRequestBody called", "requestID", state.requestID, "currentState", state.currentState) + + if state.currentState != stateAwaitingBody { + klog.ErrorS(nil, "received RequestBody in unexpected state", "requestID", state.requestID, "state", state.currentState) + return nil + } + + klog.InfoS("calling HandleRequestBody", "requestID", state.requestID, "schedulerEnabled", s.scheduler != nil) + // CRITICAL FIX: Pass state.routerCtx directly - it's already *types.RoutingContext which implements context.Context + // Don't cast to context.Context as that loses the original type information needed for type assertion + resp, model, routerCtx, stream, traceTerm := s.HandleRequestBody(state.routerCtx, state.requestID, req, state.user) + state.model, state.routerCtx, state.stream, state.traceTerm = model, routerCtx, stream, traceTerm + + klog.InfoS("HandleRequestBody returned", "requestID", state.requestID, "hasResponse", resp != nil, "model", model) + + if resp != nil { + // Handle cases where HandleRequestBody returns an error response + if err := srv.Send(resp); err != nil { + klog.ErrorS(err, "failed to send request body error response", "requestID", state.requestID) + } + state.currentState = stateDone + return nil + } + + // CRITICAL: state.routerCtx must not be nil at this point + // If it's nil, it indicates a serious bug in HandleRequestBody + if state.routerCtx == nil { + klog.ErrorS(nil, "CRITICAL BUG: state.routerCtx is nil after HandleRequestBody", + "requestID", state.requestID, "hasResponse", resp != nil) + return fmt.Errorf("critical error: routerCtx is nil after HandleRequestBody for request %s", state.requestID) + } + + // MODIFIED: Use only headers-based session ID (no body parsing) + // The session ID was already extracted from headers in the headers phase + finalSessionID := state.sessionID + + // CRITICAL: Session ID is required - client must provide it in headers + // If finalSessionID equals requestID, it means no real session ID was found (fallback was used) + if finalSessionID == state.requestID { + klog.ErrorS(nil, "session ID is required but not provided in headers", + "requestID", state.requestID, "path", state.routerCtx.ReqPath, + "headerSessionID", state.sessionID) + return fmt.Errorf("session ID is required but not found in headers for request %s", state.requestID) + } + + klog.InfoS("using session ID from headers", "requestID", state.requestID, "sessionID", finalSessionID) + // No need to reassign since finalSessionID is already state.sessionID + + // Check if scheduler is enabled and submit job + if s.scheduler != nil { + // Submit job to scheduler asynchronously with timeout protection + state.submissionTime = time.Now() + go func() { + klog.InfoS("submitting job to scheduler", "requestID", state.requestID, "sessionID", state.sessionID) + + // TODO: Add timeout protection when scheduler supports context + decision, err := s.scheduler.SubmitJob(state.routerCtx, state.sessionID) + if 
err != nil { + decision = &scheduler.Decision{Err: err} + } + + // Non-blocking send to avoid goroutine leak if main loop exits + select { + case decisionChan <- decision: + case <-srv.Context().Done(): + // Main context cancelled, don't send decision + } + }() + + state.currentState = stateAwaitingDecision + return nil + } + + // If we reach here, scheduler is disabled but HandleRequestBody returned nil + // This should not happen in normal operation + klog.ErrorS(nil, "HandleRequestBody returned nil but scheduler is disabled", "requestID", state.requestID) + state.currentState = stateDone + return nil +} + +// handleResponseHeaders processes ResponseHeaders message +func (s *Server) handleResponseHeaders(srv extProcPb.ExternalProcessor_ProcessServer, req *extProcPb.ProcessingRequest, state *perRequestState) error { + if state.currentState != stateForwarding { + klog.ErrorS(nil, "received ResponseHeaders in unexpected state", "requestID", state.requestID, "state", state.currentState) + return nil + } + + resp, isRespError, respErrorCode := s.HandleResponseHeaders(srv.Context(), state.requestID, state.model, req) + state.isRespError, state.respErrorCode = isRespError, respErrorCode + + if isRespError && respErrorCode == 500 { + // for error code 500, ProcessingRequest_ResponseBody is not invoked + resp = s.responseErrorProcessing(srv.Context(), resp, respErrorCode, state.model, state.requestID, "") + } + + if isRespError && respErrorCode == 401 { + // Early return due to unauthorized or canceled context + resp = s.responseErrorProcessing(srv.Context(), resp, respErrorCode, state.model, state.requestID, `{"error":"unauthorized"}`) + } + + if err := srv.Send(resp); err != nil { + klog.ErrorS(err, "failed to send response headers", "requestID", state.requestID) + state.currentState = stateDone + } + + return nil +} + +// handleResponseBody processes ResponseBody message and finalizes the job +func (s *Server) handleResponseBody(srv extProcPb.ExternalProcessor_ProcessServer, req *extProcPb.ProcessingRequest, state *perRequestState) error { + if state.currentState != stateForwarding { + klog.ErrorS(nil, "received ResponseBody in unexpected state", "requestID", state.requestID, "state", state.currentState) + return nil + } + + var resp *extProcPb.ProcessingResponse + if state.isRespError { + resp = s.responseErrorProcessing(srv.Context(), resp, state.respErrorCode, state.model, state.requestID, + string(req.Request.(*extProcPb.ProcessingRequest_ResponseBody).ResponseBody.GetBody())) + } else { + resp, state.completed = s.HandleResponseBody(srv.Context(), state.requestID, req, state.user, state.rpm, state.model, state.stream, state.traceTerm, state.completed) + } + + // Finalize the job upon completion + if state.completed && state.sessionID != "" && state.schedulingDecision != nil { + // WaitTime is the duration from submission to dispatch. + waitTime := state.dispatchTime.Sub(state.submissionTime) + // ExecutionTime is the duration from dispatch to completion. 
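+ // For streamed responses, completion is only reached on the final body chunk, so the
+ // measured execution time includes the full streaming duration.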
+ executionTime := time.Since(state.dispatchTime) + inheritedCST := state.schedulingDecision.Job.InheritedCST + + klog.InfoS("finalizing job", "requestID", state.requestID, "sessionID", state.sessionID, + "executionTime", executionTime, "waitTime", waitTime, "inheritedCST", inheritedCST) + + s.scheduler.FinalizeJob(state.sessionID, inheritedCST, executionTime, waitTime) + state.currentState = stateDone + } + + if err := srv.Send(resp); err != nil { + klog.ErrorS(err, "failed to send response body", "requestID", state.requestID) + state.currentState = stateDone + } + + return nil +} + +// handleScheduledRequest performs routing after receiving scheduler permission +func (s *Server) handleScheduledRequest(routingCtx *types.RoutingContext, state *perRequestState) *extProcPb.ProcessingResponse { + // Get affinity hint from the session cache + if sessionState, exists := s.sessionCache.GetState(state.sessionID); exists && sessionState.PodAffinity != "" { + klog.InfoS("using pod affinity hint", "requestID", state.requestID, "sessionID", state.sessionID, "podAffinity", sessionState.PodAffinity) + // TODO: Use pod affinity hint in routing decision + } + + // Get pods for the model + podsArr, err := s.cache.ListPodsByModel(state.model) + if err != nil || podsArr == nil || len(podsArr.All()) == 0 { + klog.ErrorS(err, "no ready pod available", "requestID", state.requestID, "model", state.model) + return generateErrorResponse(envoyTypePb.StatusCode_ServiceUnavailable, + []*configPb.HeaderValueOption{{Header: &configPb.HeaderValue{ + Key: HeaderErrorNoModelBackends, RawValue: []byte("true")}}}, + "no ready pods available") + } + + // Perform routing + targetPodIP, err := s.selectTargetPod(routingCtx, podsArr) + if targetPodIP == "" || err != nil { + klog.ErrorS(err, "failed to select target pod", "requestID", state.requestID, "model", state.model) + return generateErrorResponse(envoyTypePb.StatusCode_ServiceUnavailable, + []*configPb.HeaderValueOption{{Header: &configPb.HeaderValue{ + Key: HeaderErrorRouting, RawValue: []byte("true")}}}, + "error on selecting target pod") + } + + // Update session cache with pod affinity + if routingCtx.HasRouted() { + podName := routingCtx.TargetPod().Name + s.sessionCache.UpdateAffinity(state.sessionID, podName) + } + + // Build response headers + headers := buildEnvoyProxyHeaders([]*configPb.HeaderValueOption{}, + HeaderRoutingStrategy, string(routingCtx.Algorithm), + HeaderTargetPod, targetPodIP, + "content-length", strconv.Itoa(len(routingCtx.ReqBody))) + + klog.InfoS("request start", "requestID", state.requestID, "requestPath", routingCtx.ReqPath, "model", state.model, "stream", state.stream, "routingAlgorithm", routingCtx.Algorithm, "targetPodIP", targetPodIP, "routingDuration", routingCtx.GetRoutingDelay()) + + return &extProcPb.ProcessingResponse{ + Response: &extProcPb.ProcessingResponse_RequestBody{ + RequestBody: &extProcPb.BodyResponse{ + Response: &extProcPb.CommonResponse{ + HeaderMutation: &extProcPb.HeaderMutation{ + SetHeaders: headers, + }, + BodyMutation: &extProcPb.BodyMutation{ + Mutation: &extProcPb.BodyMutation_Body{ + Body: routingCtx.ReqBody, + }, + }, + }, + }, + }, + } +} + +// buildRequestHeadersResponse creates a response that tells Envoy to send us the RequestBody +func (s *Server) buildRequestHeadersResponse() *extProcPb.ProcessingResponse { + return &extProcPb.ProcessingResponse{ + Response: &extProcPb.ProcessingResponse_RequestHeaders{ + RequestHeaders: &extProcPb.HeadersResponse{ + Response: &extProcPb.CommonResponse{ + HeaderMutation: 
&extProcPb.HeaderMutation{ + SetHeaders: []*configPb.HeaderValueOption{ + { + Header: &configPb.HeaderValue{ + Key: HeaderWentIntoReqHeaders, + RawValue: []byte("true"), + }, + }, + }, + }, + ClearRouteCache: true, + }, + }, + }, + // This is the MOST IMPORTANT part. We are telling Envoy: + // "I'm interested in the body. Please stream it to me." + ModeOverride: &extProcFilterPb.ProcessingMode{ + RequestBodyMode: extProcFilterPb.ProcessingMode_BUFFERED, // Use BUFFERED for complete body + ResponseBodyMode: extProcFilterPb.ProcessingMode_STREAMED, + }, + } +} diff --git a/pkg/plugins/gateway/gateway_state_machine_test.go b/pkg/plugins/gateway/gateway_state_machine_test.go new file mode 100644 index 000000000..f19a4d58c --- /dev/null +++ b/pkg/plugins/gateway/gateway_state_machine_test.go @@ -0,0 +1,274 @@ +/* +Copyright 2025 The Aibrix Team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package gateway + +import ( + "fmt" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestPerRequestState_StateTransitions(t *testing.T) { + tests := []struct { + name string + initialState requestState + expectedStates []requestState + }{ + { + name: "normal flow", + initialState: stateAwaitingHeaders, + expectedStates: []requestState{ + stateAwaitingHeaders, + stateAwaitingBody, + stateAwaitingDecision, + stateForwarding, + stateDone, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + state := &perRequestState{ + currentState: tt.initialState, + requestID: "test-request", + } + + assert.Equal(t, tt.expectedStates[0], state.currentState) + + // Test state transitions + for i := 1; i < len(tt.expectedStates); i++ { + state.currentState = tt.expectedStates[i] + assert.Equal(t, tt.expectedStates[i], state.currentState) + } + }) + } +} + +func TestSchedulerIntegration_StateTransitions(t *testing.T) { + // Test that scheduler integration works with state machine + state := &perRequestState{ + currentState: stateAwaitingDecision, + requestID: "test-scheduler-request", + submissionTime: time.Now(), + } + + // Simulate scheduler decision received + state.dispatchTime = time.Now().Add(10 * time.Millisecond) + state.currentState = stateForwarding + + assert.Equal(t, stateForwarding, state.currentState) + assert.False(t, state.dispatchTime.IsZero()) + + // Test timing calculations + waitTime := state.dispatchTime.Sub(state.submissionTime) + assert.Greater(t, waitTime, time.Duration(0)) + + t.Log("Scheduler integration state transitions test completed") +} + +func TestLoadAwareScheduling_StateTracking(t *testing.T) { + // Test that load-aware scheduling properly tracks state + states := []*perRequestState{ + { + currentState: stateAwaitingDecision, + requestID: "high-load-request-1", + submissionTime: time.Now(), + }, + { + currentState: stateAwaitingDecision, + requestID: "high-load-request-2", + submissionTime: time.Now().Add(1 * time.Millisecond), + }, + { + currentState: stateAwaitingDecision, + requestID: "high-load-request-3", + submissionTime: time.Now().Add(2 * 
time.Millisecond), + }, + } + + // Simulate batch scheduling decision + batchDispatchTime := time.Now().Add(5 * time.Millisecond) + for _, state := range states { + state.dispatchTime = batchDispatchTime + state.currentState = stateForwarding + } + + // Verify all states transitioned correctly + for i, state := range states { + assert.Equal(t, stateForwarding, state.currentState, "State %d should be forwarding", i) + assert.Equal(t, batchDispatchTime, state.dispatchTime, "State %d should have batch dispatch time", i) + + waitTime := state.dispatchTime.Sub(state.submissionTime) + assert.Greater(t, waitTime, time.Duration(0), "State %d should have positive wait time", i) + } + + t.Log("Load-aware scheduling state tracking test completed") +} + +func TestCapacityAwareScheduling_Metrics(t *testing.T) { + // Test capacity-aware scheduling metrics tracking + now := time.Now() + + // Simulate different capacity scenarios + scenarios := []struct { + name string + podCapacity int + requestCount int + expectedBatchSize int + }{ + { + name: "low capacity pod", + podCapacity: 1, + requestCount: 10, + expectedBatchSize: 1, + }, + { + name: "high capacity pod", + podCapacity: 100, + requestCount: 10, + expectedBatchSize: 10, + }, + { + name: "overloaded scenario", + podCapacity: 50, + requestCount: 100, + expectedBatchSize: 50, + }, + } + + for _, scenario := range scenarios { + t.Run(scenario.name, func(t *testing.T) { + // Create states for requests + states := make([]*perRequestState, scenario.requestCount) + for i := 0; i < scenario.requestCount; i++ { + states[i] = &perRequestState{ + currentState: stateAwaitingDecision, + requestID: fmt.Sprintf("capacity-test-request-%d", i), + submissionTime: now.Add(time.Duration(i) * time.Millisecond), + } + } + + // Simulate capacity-aware batch scheduling + batchSize := min(scenario.expectedBatchSize, len(states)) + batchDispatchTime := now.Add(10 * time.Millisecond) + + // Only dispatch up to capacity + for i := 0; i < batchSize; i++ { + states[i].dispatchTime = batchDispatchTime + states[i].currentState = stateForwarding + } + + // Verify correct number were dispatched + dispatchedCount := 0 + for _, state := range states { + if state.currentState == stateForwarding { + dispatchedCount++ + } + } + + assert.Equal(t, batchSize, dispatchedCount, "Should dispatch exactly batch size") + + // Verify remaining are still waiting + waitingCount := 0 + for _, state := range states { + if state.currentState == stateAwaitingDecision { + waitingCount++ + } + } + + expectedWaiting := scenario.requestCount - batchSize + assert.Equal(t, expectedWaiting, waitingCount, "Remaining should still be waiting") + }) + } + + t.Log("Capacity-aware scheduling metrics test completed") +} + +func TestServer_WithStateMachine(t *testing.T) { + // Test that state machine components are properly defined + assert.Equal(t, requestState(0), stateAwaitingHeaders) + assert.Equal(t, requestState(1), stateAwaitingBody) + assert.Equal(t, requestState(2), stateAwaitingDecision) + assert.Equal(t, requestState(3), stateForwarding) + assert.Equal(t, requestState(4), stateDone) + + // Test perRequestState initialization + state := &perRequestState{ + currentState: stateAwaitingHeaders, + requestID: "test-123", + sessionID: "session-456", + } + + assert.Equal(t, stateAwaitingHeaders, state.currentState) + assert.Equal(t, "test-123", state.requestID) + assert.Equal(t, "session-456", state.sessionID) +} + +func TestRequestState_Constants(t *testing.T) { + // Test that state constants are properly defined 
+ assert.Equal(t, requestState(0), stateAwaitingHeaders) + assert.Equal(t, requestState(1), stateAwaitingBody) + assert.Equal(t, requestState(2), stateAwaitingDecision) + assert.Equal(t, requestState(3), stateForwarding) + assert.Equal(t, requestState(4), stateDone) +} + +func TestPerRequestState_Initialization(t *testing.T) { + state := &perRequestState{ + currentState: stateAwaitingHeaders, + requestID: "test-123", + sessionID: "session-456", + } + + assert.Equal(t, stateAwaitingHeaders, state.currentState) + assert.Equal(t, "test-123", state.requestID) + assert.Equal(t, "session-456", state.sessionID) + assert.False(t, state.completed) + assert.False(t, state.isRespError) + assert.Equal(t, int64(0), state.rpm) + assert.Equal(t, int64(0), state.traceTerm) + + // Test new timing fields + assert.True(t, state.requestStartTime.IsZero()) + assert.True(t, state.submissionTime.IsZero()) + assert.True(t, state.dispatchTime.IsZero()) + assert.Nil(t, state.schedulingDecision) +} + +func TestPerRequestState_TimingCalculations(t *testing.T) { + now := time.Now() + state := &perRequestState{ + requestStartTime: now, + submissionTime: now.Add(10 * time.Millisecond), + dispatchTime: now.Add(50 * time.Millisecond), + } + + // Test wait time calculation (submission to dispatch) + expectedWaitTime := 40 * time.Millisecond + actualWaitTime := state.dispatchTime.Sub(state.submissionTime) + assert.Equal(t, expectedWaitTime, actualWaitTime) + + // Test execution time would be calculated from dispatchTime to completion + // (this would be done with time.Since(state.dispatchTime) in real code) + completionTime := now.Add(100 * time.Millisecond) + expectedExecutionTime := 50 * time.Millisecond + actualExecutionTime := completionTime.Sub(state.dispatchTime) + assert.Equal(t, expectedExecutionTime, actualExecutionTime) +} diff --git a/pkg/plugins/gateway/scheduler/interface.go b/pkg/plugins/gateway/scheduler/interface.go new file mode 100644 index 000000000..cf166a56d --- /dev/null +++ b/pkg/plugins/gateway/scheduler/interface.go @@ -0,0 +1,49 @@ +/* +Copyright 2025 The Aibrix Team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package scheduler + +import ( + "time" + + "github.com/vllm-project/aibrix/pkg/types" +) + +// Decision is the output from the scheduler, +// containing the job to be processed. +type Decision struct { + Job *SchedulingJob // The job to be processed. + Err error // Any error encountered during scheduling. +} + +// Scheduler defines the interface for a session-aware request scheduler. +// It is designed to be pluggable into the AIBrix Gateway. +type Scheduler interface { + // SubmitJob adds a new job to the scheduler's queue and blocks until a + // decision is made. The decision contains the job itself, + // empowering the caller to proceed. + // Plus, one job is one request. + SubmitJob(ctx *types.RoutingContext, sessionID string) (*Decision, error) + + // FinalizeJob updates the session state after a job has been completed. 
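+ // inheritedCST is the cumulative service time the job carried at submission
+ // (Decision.Job.InheritedCST); waitTime covers submission to dispatch and
+ // executionTime covers dispatch to completion, both measured by the gateway.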
+ FinalizeJob(sessionID string, inheritedCST, executionTime, waitTime time.Duration) + + // Stop gracefully shuts down the scheduler's background processing loop. + Stop() + + // TODO: Add a GetStats() method for monitoring in the future. + // GetSchedulerStats() *SchedulerStats +} diff --git a/pkg/plugins/gateway/scheduler/job.go b/pkg/plugins/gateway/scheduler/job.go new file mode 100644 index 000000000..6c31f8aec --- /dev/null +++ b/pkg/plugins/gateway/scheduler/job.go @@ -0,0 +1,121 @@ +/* +Copyright 2025 The Aibrix Team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package scheduler + +import ( + "time" + + "github.com/vllm-project/aibrix/pkg/types" +) + +// SchedulingJob represents a request waiting in the priority queue. +// It holds all context needed for scheduling and final execution. +type SchedulingJob struct { + SessionID string + RequestContext *types.RoutingContext + + // Priority-related fields, populated on submission. + InheritedCST time.Duration + TotalWaitTime time.Duration + + // Timestamps for aging and starvation detection. + SubmissionTime time.Time + + // ResultChan is used by the scheduling loop to send the decision back. + ResultChan chan *Decision + + // for heap.Interface + index int +} + +const ( + STARVATION_THRESHOLD = 2.0 +) + +// PriorityQueue implements heap.Interface for SchedulingJob pointers. +// ----------------------------------------------------------------------------------- +// In Autellix paper, they use MLFQ instead of a single priority queue. +// ----------------------------------------------------------------------------------- +// - Avoiding thrashing: If priorities are sequential, +// a request with a priority of 100.0 might be preempted by a new request +// with a priority of 99.9. +// This extremely small difference in priority can lead to preemption, +// resulting in context switching overhead (KV cache swap) that +// far outweighs the benefits of preemption. +// This can cause system thrashing. +// - Avoiding worst-case scenarios: The paper notes that in some cases, +// the performance of preemptive scheduling with sequential priorities can degrade, +// even worse than simple FCFS. +// - Batch-friendly: Discretizing priorities into several queues +// makes it natural to batch requests within each queue. +// +// ----------------------------------------------------------------------------------- +// HOWEVER, I plan to use a single priority queue for the following reasons: +// - We perform request-level scheduling at the Gateway layer. +// We do not interrupt requests already executing on a Pod. +// Our "preemption" simply determines which Pod should be sent the next request. +// In this model, the context switch cost is zero. +// - The implementation of a single Heap is much simpler and easier to maintain. +// - MLFQ requires you to predefine the number of queues and the priority range +// for each queue (for example, queues with a CST of 0-1 seconds go to Q1...). +// This requires fine-tuning and is not very flexible. 
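+//
+// The effective priority used by Less() below starts from the job's InheritedCST and is
+// scaled down (multiplied by 0.1) once TotalWaitTime exceeds STARVATION_THRESHOLD times
+// the CST, so long-waiting jobs are pulled forward; the smallest value is popped first.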
+ +type PriorityQueue []*SchedulingJob + +func (pq PriorityQueue) Len() int { return len(pq) } + +// Less is the core of the priority logic, including anti-starvation. +func (pq PriorityQueue) Less(i, j int) bool { + + // Calculate effective priority for item i + priority_i := float64(pq[i].InheritedCST) + // Add a small epsilon to avoid division by zero + if priority_i > 1 && float64(pq[i].TotalWaitTime)/(priority_i+1e-9) > STARVATION_THRESHOLD { + priority_i *= 0.1 // Apply a significant priority boost + } + + // Calculate effective priority for item j + priority_j := float64(pq[j].InheritedCST) + if priority_j > 1 && float64(pq[j].TotalWaitTime)/(priority_j+1e-9) > STARVATION_THRESHOLD { + priority_j *= 0.1 + } + + // The smaller the effective priority, the higher the actual priority + return priority_i < priority_j +} + +func (pq PriorityQueue) Swap(i, j int) { + pq[i], pq[j] = pq[j], pq[i] + pq[i].index = i + pq[j].index = j +} + +func (pq *PriorityQueue) Push(x any) { + job := x.(*SchedulingJob) + job.index = len(*pq) + *pq = append(*pq, job) +} + +func (pq *PriorityQueue) Pop() any { + old := *pq + n := len(old) + job := old[n-1] + old[n-1] = nil // avoid memory leak + job.index = -1 + *pq = old[0 : n-1] + return job +} diff --git a/pkg/plugins/gateway/scheduler/scheduler.go b/pkg/plugins/gateway/scheduler/scheduler.go new file mode 100644 index 000000000..e86b3a268 --- /dev/null +++ b/pkg/plugins/gateway/scheduler/scheduler.go @@ -0,0 +1,529 @@ +package scheduler + +import ( + "container/heap" + "strconv" + "sync/atomic" + "time" + + "github.com/vllm-project/aibrix/pkg/cache" + "github.com/vllm-project/aibrix/pkg/metrics" + "github.com/vllm-project/aibrix/pkg/plugins/gateway/scheduler/sessioninfo" + "github.com/vllm-project/aibrix/pkg/types" + v1 "k8s.io/api/core/v1" + "k8s.io/client-go/kubernetes" + "k8s.io/klog/v2" +) + +// inProcessScheduler implements a high-throughput, low-latency scheduler. +// It uses a lock-free channel for job submission and a single goroutine +// for all stateful processing to eliminate lock contention on the hot path. +type inProcessScheduler struct { + // A buffered channel for lock-free submission from multiple goroutines. + // This is the main ingress point for new requests. + submitChan chan *SchedulingJob + + // The session cache, which is thread-safe itself. + sessionCache *sessioninfo.MutexSessionCache + + // A channel to signal all background goroutines to stop gracefully. + stopChan chan struct{} + + // Kubernetes client to get information about the cluster state. + k8sClient kubernetes.Interface + + // --- Dynamic Batching Fields --- + // A thread-safe container for the list of currently healthy and routable pods. + // Updated periodically by the podWatcherLoop. + healthyPods atomic.Value + // A thread-safe counter for requests that have been scheduled but not yet finalized. + // Used to calculate available capacity. + inflightRequests atomic.Int64 + + // --- Load Awareness Integration --- + // Cache for accessing real-time load metrics from Router + cache cache.Cache + // Load provider for advanced capacity calculation (reuses Router's logic) + loadProvider cache.CappedLoadProvider + + // --- Batch Size Smoothing --- + // Last calculated batch size for smoothing (atomic for thread safety) + lastBatchSize atomic.Int32 +} + +const ( + // The size of the channel buffer for job submission. + // This should be large enough to absorb bursts of requests. 
+ SUBMIT_CHAN_BUFFER_SIZE = 1024 + // The interval at which the processing loop checks for new jobs and tries to schedule them. + PROCESSING_LOOP_INTERVAL = 5 * time.Millisecond + + // Default estimated concurrent capacity per pod when no annotation is provided + DEFAULT_POD_CONCURRENT_CAPACITY = 1 + + // Multiplier for estimating capacity from current running requests + CAPACITY_ESTIMATION_MULTIPLIER = 2 + + // Oversubscription factor - Router doesn't allow this, so we set to 1.0 (no oversubscription) + SCHEDULER_OVERSUBSCRIPTION_FACTOR = 1.0 + + // Batch size smoothing factor (exponential moving average alpha) + BATCH_SIZE_SMOOTHING_ALPHA = 0.3 + + // Pass-through mode batch size (essentially disables batching) + PASS_THROUGH_BATCH_SIZE = 1000 +) + +// NewScheduler creates and starts the new high-throughput scheduler. +func NewScheduler(k8sClient kubernetes.Interface, sessionCache *sessioninfo.MutexSessionCache, cacheInstance cache.Cache) Scheduler { + // Use provided cache instance for load awareness + if cacheInstance == nil { + klog.InfoS("no cache instance provided, falling back to basic scheduling") + } + + // Initialize load provider for advanced capacity calculation + var loadProvider cache.CappedLoadProvider + if cacheInstance != nil { + pendingLoadProvider, err := cache.NewPendingLoadProvider() + if err != nil { + klog.ErrorS(err, "failed to create PendingLoadProvider, using basic capacity calculation") + } else { + loadProvider = pendingLoadProvider + // klog.InfoS("scheduler initialized with advanced load awareness") + } + } + + s := &inProcessScheduler{ + // A large buffer can absorb bursts of requests without blocking SubmitJob. + submitChan: make(chan *SchedulingJob, SUBMIT_CHAN_BUFFER_SIZE), + sessionCache: sessionCache, + stopChan: make(chan struct{}), + k8sClient: k8sClient, + cache: cacheInstance, + loadProvider: loadProvider, + } + + // Initialize atomic values with empty/zero state. + s.healthyPods.Store([]*v1.Pod{}) + s.inflightRequests.Store(0) + + // Start the single, powerful processing loop. + go s.processingLoop() + // Start background goroutine for pod watching. + go s.podWatcherLoop() + + return s +} + +// podWatcherLoop periodically aggregates healthy pods from cache across all models +// to determine the cluster's processing capacity. This avoids direct K8s API calls +// and leverages the existing cache infrastructure. +func (s *inProcessScheduler) podWatcherLoop() { + // A ticker to trigger the pod list refresh at regular intervals. 
+ ticker := time.NewTicker(5 * time.Second) + defer ticker.Stop() + + // klog.InfoS("Pod watcher loop started") + + for { + select { + case <-s.stopChan: + // klog.InfoS("Pod watcher loop stopping") + return + case <-ticker.C: + // Use cache to get all healthy pods across all models + // This is the correct approach - no hardcoded labels or direct K8s API calls + var allPods []*v1.Pod + // totalPods := 0 + + if s.cache != nil { + // Get all models from cache + models := s.cache.ListModels() + klog.V(4).InfoS("Scheduler found models in cache", "modelCount", len(models)) + + // Aggregate pods from all models + podMap := make(map[string]*v1.Pod) // Use map to deduplicate pods + for _, model := range models { + podList, err := s.cache.ListPodsByModel(model) + if err != nil { + klog.V(4).InfoS("Failed to get pods for model", "model", model, "error", err) + continue + } + + for _, pod := range podList.All() { + podMap[pod.Name] = pod + } + } + + // Convert map to slice + for _, pod := range podMap { + allPods = append(allPods, pod) + } + // totalPods = len(allPods) + } else { + // klog.V(3).InfoS("Cache not available, scheduler will use fallback capacity calculation") + } + + // klog.V(3).InfoS("Pod watcher found pods from cache", "totalPods", totalPods) + + readyPods := s.filterReadyPods(allPods) + s.healthyPods.Store(readyPods) + + // klog.InfoS("Pod watcher updated healthy pods", "readyPods", len(readyPods), "totalPods", totalPods) + + // Log pod details for debugging + // for _, pod := range readyPods { + // klog.V(4).InfoS("Ready pod found", + // "name", pod.Name, + // "ip", pod.Status.PodIP, + // "phase", pod.Status.Phase, + // "model", pod.Labels["model.aibrix.ai/name"]) + // } + } + } +} + +// filterReadyPods filters pods that are ready and available for scheduling +func (s *inProcessScheduler) filterReadyPods(pods []*v1.Pod) []*v1.Pod { + var readyPods []*v1.Pod + for _, pod := range pods { + if pod.Status.Phase == v1.PodRunning { + // Check if all containers are ready + allReady := true + for _, condition := range pod.Status.Conditions { + if condition.Type == v1.PodReady && condition.Status != v1.ConditionTrue { + allReady = false + break + } + } + if allReady { + readyPods = append(readyPods, pod) + } + } + } + return readyPods +} + +// SubmitJob is now extremely lightweight and lock-free. +// It simply sends the job to a channel and immediately returns a channel +// on which the caller can wait for the final decision. +func (s *inProcessScheduler) SubmitJob(ctx *types.RoutingContext, sessionID string) (*Decision, error) { + // The job no longer contains state from the cache. + // State enrichment happens inside the processingLoop. + job := &SchedulingJob{ + SessionID: sessionID, + RequestContext: ctx, + SubmissionTime: time.Now(), + ResultChan: make(chan *Decision, 1), + } + + // This send operation is thread-safe and highly optimized. + // It will only block if the channel buffer is full, which acts as a natural backpressure mechanism. + s.submitChan <- job + + // The goroutine now blocks waiting for the decision, NOT on a lock. + decision := <-job.ResultChan + if decision.Err == nil { + s.inflightRequests.Add(1) + } + return decision, decision.Err +} + +// FinalizeJob is also made asynchronous by sending a message. +// We can use a different type or a special field in SchedulingJob to signify finalization. 
+func (s *inProcessScheduler) FinalizeJob(sessionID string, inheritedCST, executionTime, waitTime time.Duration) { + s.inflightRequests.Add(-1) + // The state update is still directly on the cache, which is thread-safe. + // This is fine because it doesn't contend with the scheduling logic. + s.sessionCache.UpdateState(sessionID, inheritedCST, executionTime, waitTime) + + // A crucial step: a finished job might have freed up capacity. + // We need to trigger the scheduling loop to check again. + // We can send a nil job as a signal. + select { + case s.submitChan <- nil: // Send a special signal to wake up the loop + default: // Don't block if the channel is full; the loop is already busy. + } +} + +func (s *inProcessScheduler) Stop() { + close(s.stopChan) +} + +// processingLoop is the single-threaded owner of the priority queue. +// This eliminates ALL lock contention for scheduling logic. +func (s *inProcessScheduler) processingLoop() { + // The priority queue is now a local variable, not shared. + priorityQueue := make(PriorityQueue, 0) + heap.Init(&priorityQueue) + + // Ticker for periodic, throughput-oriented scheduling. + ticker := time.NewTicker(PROCESSING_LOOP_INTERVAL) + defer ticker.Stop() + + // klog.InfoS("Scheduler processing loop started") + + for { + select { + case <-s.stopChan: + // klog.InfoS("Scheduler processing loop stopping") + return + + case job := <-s.submitChan: + if job == nil { // This is a wakeup signal from FinalizeJob + klog.V(5).InfoS("Received wakeup signal from FinalizeJob") + // Just continue to the unified scheduling point + break + } + + // klog.V(4).InfoS("Received new job", + // "sessionID", job.SessionID, + // "queueLength", len(priorityQueue)) + + // --- State Enrichment --- + cst, waitTime := s.sessionCache.GetOrCreateForScheduler(job.SessionID) + job.InheritedCST = cst + job.TotalWaitTime = waitTime + + // --- Enqueue --- + heap.Push(&priorityQueue, job) + + // klog.V(4).InfoS("Job enqueued", + // "sessionID", job.SessionID, + // "newQueueLength", len(priorityQueue)) + + case <-ticker.C: + // --- Periodic Scheduling --- + // klog.V(5).InfoS("Periodic scheduling tick", "queueLength", len(priorityQueue)) + } + + // --- Unified Scheduling Point --- + // At the end of every loop iteration, regardless of what woke it up, + // try to schedule a batch. This eliminates code duplication and + // ensures consistent scheduling behavior. + s.scheduleBatch(&priorityQueue) + } +} + +// scheduleBatch is now a private method that operates on the local priority queue. +func (s *inProcessScheduler) scheduleBatch(pq *PriorityQueue) { + queueLength := len(*pq) + if queueLength == 0 { + return + } + + // The batching logic with smoothing: based on dynamic capacity with exponential moving average. + batchSize := s.calculateSmoothedBatchSize() + + // klog.V(4).InfoS("Scheduling batch", + // "queueLength", queueLength, + // "calculatedBatchSize", batchSize) + + if batchSize <= 0 { + // klog.V(3).InfoS("Batch size is 0, skipping scheduling", + // "queueLength", queueLength) + return + } + + numToPop := min(batchSize, queueLength) + if numToPop <= 0 { + return + } + + // klog.InfoS("Dispatching batch of requests", + // "batchSize", numToPop, + // "queueLength", queueLength) + + // Pop and dispatch concurrently. 
+ for i := 0; i < numToPop; i++ { + job := heap.Pop(pq).(*SchedulingJob) + decision := &Decision{ + Job: job, + Err: nil, + } + + // klog.V(4).InfoS("Dispatching job", + // "sessionID", job.SessionID, + // "waitTime", time.Since(job.SubmissionTime)) + + job.ResultChan <- decision + } +} + +// calculateSmoothedBatchSize determines how many new requests can be scheduled in this cycle. +// Uses exponential moving average to smooth out capacity fluctuations. +func (s *inProcessScheduler) calculateSmoothedBatchSize() int { + rawBatchSize := s.calculateBatchSize() + lastBatchSize := int(s.lastBatchSize.Load()) + + var smoothed int + if lastBatchSize == 0 { + // First time or after reset, use raw value + smoothed = rawBatchSize + } else { + // Apply exponential moving average smoothing + // smoothed = (1-alpha) * last + alpha * current + smoothed = int((1.0-BATCH_SIZE_SMOOTHING_ALPHA)*float64(lastBatchSize) + + BATCH_SIZE_SMOOTHING_ALPHA*float64(rawBatchSize)) + } + + // Store the smoothed value for next iteration + s.lastBatchSize.Store(int32(smoothed)) + + // klog.V(5).InfoS("batch size smoothing", + // "rawBatchSize", rawBatchSize, + // "lastBatchSize", lastBatchSize, + // "smoothedBatchSize", smoothed, + // "alpha", BATCH_SIZE_SMOOTHING_ALPHA) + + return smoothed +} + +// calculateBatchSize determines how many new requests can be scheduled in this cycle. +// Now uses Router's load awareness for accurate capacity calculation. +func (s *inProcessScheduler) calculateBatchSize() int { + // Load the list of healthy pods atomically. + pods, _ := s.healthyPods.Load().([]*v1.Pod) + + // klog.V(4).InfoS("Calculating batch size", "healthyPods", len(pods)) + + if len(pods) == 0 { + klog.V(3).InfoS("No healthy pods available, batch size = 0") + return 0 + } + + var batchSize int + // Use advanced load awareness if available + if s.cache != nil && s.loadProvider != nil { + klog.V(4).InfoS("Using advanced batch size calculation with load provider") + batchSize = s.calculateAdvancedBatchSize(pods) + } else { + // klog.V(4).InfoS("Using fallback batch size calculation") + batchSize = s.calculateImprovedFallbackBatchSize(pods) + } + + // klog.V(3).InfoS("Calculated batch size", "batchSize", batchSize, "podCount", len(pods)) + return batchSize +} + +// calculateAdvancedBatchSize uses Router's load awareness for precise capacity calculation +func (s *inProcessScheduler) calculateAdvancedBatchSize(pods []*v1.Pod) int { + totalAvailableCapacity := 0 + + for _, pod := range pods { + // Get current utilization using Router's load provider + utilization, err := s.loadProvider.GetUtilization(nil, pod) + if err != nil { + // If we can't get utilization, assume pod is available + // klog.V(4).InfoS("failed to get pod utilization, assuming available", + // "pod", pod.Name, "error", err) + totalAvailableCapacity += 1 + continue + } + + // Get pod's capacity limit (typically 1.0 for normalized load) + podCapacity := s.loadProvider.Cap() + + // Calculate available capacity for this pod (following Router's strict capacity limits) + availableCapacity := podCapacity - utilization + if availableCapacity > 0 { + // Convert normalized capacity to request slots + // Use Router's approach: strict capacity limits without oversubscription + estimatedConcurrentCapacity := s.getEstimatedConcurrentCapacity(pod) + podAvailableSlots := int(availableCapacity * float64(estimatedConcurrentCapacity)) + totalAvailableCapacity += podAvailableSlots + } + + // klog.V(4).InfoS("pod capacity analysis", + // "pod", pod.Name, + // "utilization", 
utilization, + // "capacity", podCapacity, + // "availableCapacity", availableCapacity, + // "estimatedSlots", int(availableCapacity*float64(s.getEstimatedConcurrentCapacity(pod)))) + } + + // Apply scheduling strategy following Router's approach (no oversubscription) + // Router strictly enforces capacity limits, so we follow the same principle + finalCapacity := int(float64(totalAvailableCapacity) * SCHEDULER_OVERSUBSCRIPTION_FACTOR) + + // klog.V(4).InfoS("advanced batch size calculation", + // "totalPods", len(pods), + // "totalAvailableCapacity", totalAvailableCapacity, + // "finalCapacity", finalCapacity, + // "oversubscriptionFactor", SCHEDULER_OVERSUBSCRIPTION_FACTOR) + + return max(0, finalCapacity) +} + +// calculateImprovedFallbackBatchSize provides intelligent fallback without cache +func (s *inProcessScheduler) calculateImprovedFallbackBatchSize(pods []*v1.Pod) int { + totalEstimatedCapacity := 0 + + // Calculate total estimated capacity from pod annotations + for _, pod := range pods { + podCapacity := s.getEstimatedConcurrentCapacity(pod) + totalEstimatedCapacity += podCapacity + + // klog.V(5).InfoS("Pod capacity estimation", + // "podName", pod.Name, + // "estimatedCapacity", podCapacity) + } + + // Calculate available capacity + inflight := s.inflightRequests.Load() + availableSlots := totalEstimatedCapacity - int(inflight) + + // klog.V(4).InfoS("improved fallback batch size calculation", + // "totalPods", len(pods), + // "totalEstimatedCapacity", totalEstimatedCapacity, + // "inflight", inflight, + // "availableSlots", availableSlots) + + // If we get very low capacity (indicating no annotations or default values), + // switch to pass-through mode to behave like no scheduler exists + if totalEstimatedCapacity <= len(pods) { + // klog.V(3).InfoS("switching to pass-through mode - no pod capacity annotations found", + // "totalEstimatedCapacity", totalEstimatedCapacity, + // "podCount", len(pods)) + // Return a large number to essentially disable batching + // This allows requests to be routed immediately without scheduler bottleneck + return PASS_THROUGH_BATCH_SIZE + } + + return max(0, availableSlots) +} + +// getEstimatedConcurrentCapacity estimates how many concurrent requests a pod can handle +func (s *inProcessScheduler) getEstimatedConcurrentCapacity(pod *v1.Pod) int { + // Try to get from pod annotations first + if pod.Annotations != nil { + if maxConcurrency, exists := pod.Annotations["aibrix.io/max-concurrent-requests"]; exists { + if val, err := strconv.Atoi(maxConcurrency); err == nil && val > 0 { + return val + } + } + } + + // Use real-time metrics if available + if s.cache != nil { + // Try to get current running requests and estimate capacity + if runningReq, err := s.cache.GetMetricValueByPod(pod.Name, pod.Namespace, metrics.RealtimeNumRequestsRunning); err == nil { + currentRunning := int(runningReq.GetSimpleValue()) + // Use a multiple of current running as estimate + if currentRunning > 0 { + return max(DEFAULT_POD_CONCURRENT_CAPACITY, currentRunning*CAPACITY_ESTIMATION_MULTIPLIER) + } + } + } + + // Default estimate: conservative approach, similar to original scheduler logic + return DEFAULT_POD_CONCURRENT_CAPACITY +} + +// min is a simple utility function. 
+func min(a, b int) int { + if a < b { + return a + } + return b +} diff --git a/pkg/plugins/gateway/scheduler/scheduler_bench_test.go b/pkg/plugins/gateway/scheduler/scheduler_bench_test.go new file mode 100644 index 000000000..cd0974a94 --- /dev/null +++ b/pkg/plugins/gateway/scheduler/scheduler_bench_test.go @@ -0,0 +1,459 @@ +/* +Copyright 2025 The Aibrix Team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package scheduler + +import ( + "context" + "fmt" + "sync" + "testing" + "time" + + "github.com/vllm-project/aibrix/pkg/plugins/gateway/scheduler/sessioninfo" + "github.com/vllm-project/aibrix/pkg/types" + v1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/client-go/kubernetes/fake" +) + +// BenchmarkScheduler_SubmitJob_HighConcurrency tests the performance of the new +// lock-free scheduler under high concurrency scenarios. +func BenchmarkScheduler_SubmitJob_HighConcurrency(b *testing.B) { + cache := sessioninfo.NewMutexSessionCache() + k8sClient := fake.NewSimpleClientset() + scheduler := NewScheduler(k8sClient, cache, nil) + defer scheduler.Stop() + + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + sessionID := fmt.Sprintf("session-%d", time.Now().UnixNano()) + for pb.Next() { + ctx := types.NewRoutingContext(context.Background(), "test-algorithm", "test-model", "test message", "test-request", "test-user") + decision, err := scheduler.SubmitJob(ctx, sessionID) + if err != nil { + b.Fatalf("SubmitJob failed: %v", err) + } + if decision == nil { + b.Fatal("Decision is nil") + } + // Simulate job completion + scheduler.FinalizeJob(sessionID, 100*time.Millisecond, 50*time.Millisecond, 10*time.Millisecond) + } + }) +} + +// BenchmarkScheduler_BurstLoad simulates the scenario you described: +// thousands of requests arriving within milliseconds. 
+func BenchmarkScheduler_BurstLoad(b *testing.B) { + cache := sessioninfo.NewMutexSessionCache() + k8sClient := fake.NewSimpleClientset() + scheduler := NewScheduler(k8sClient, cache, nil) + defer scheduler.Stop() + + // Simulate burst scenarios + burstSizes := []int{100, 1000, 5000, 10000} + + for _, burstSize := range burstSizes { + b.Run(fmt.Sprintf("Burst-%d", burstSize), func(b *testing.B) { + b.ResetTimer() + for i := 0; i < b.N; i++ { + var wg sync.WaitGroup + start := time.Now() + + // Launch burst of requests + for j := 0; j < burstSize; j++ { + wg.Add(1) + go func(requestID int) { + defer wg.Done() + sessionID := fmt.Sprintf("session-%d-%d", i, requestID) + ctx := types.NewRoutingContext(context.Background(), "test-algorithm", "test-model", "test message", fmt.Sprintf("request-%d", requestID), "test-user") + + decision, err := scheduler.SubmitJob(ctx, sessionID) + if err != nil { + b.Errorf("SubmitJob failed: %v", err) + return + } + if decision == nil { + b.Error("Decision is nil") + return + } + + // Simulate job completion + scheduler.FinalizeJob(sessionID, 100*time.Millisecond, 50*time.Millisecond, 10*time.Millisecond) + }(j) + } + + wg.Wait() + elapsed := time.Since(start) + + // Report throughput + throughput := float64(burstSize) / elapsed.Seconds() + b.ReportMetric(throughput, "requests/sec") + b.ReportMetric(float64(elapsed.Nanoseconds())/float64(burstSize), "ns/request") + } + }) + } +} + +// TestScheduler_LoadAwareness tests the load awareness functionality +func TestScheduler_LoadAwareness(t *testing.T) { + // Create fake k8s client + k8sClient := fake.NewSimpleClientset() + + // Create session cache + sessionCache := sessioninfo.NewMutexSessionCache() + + // Create scheduler with load awareness (no cache in test environment) + scheduler := NewScheduler(k8sClient, sessionCache, nil).(*inProcessScheduler) + defer scheduler.Stop() + + // Test basic initialization + if scheduler.cache == nil { + t.Log("Cache not available in test environment - this is expected") + } + + if scheduler.loadProvider == nil { + t.Log("LoadProvider not available in test environment - this is expected") + } + + // Test capacity calculation with no pods + batchSize := scheduler.calculateBatchSize() + if batchSize != 0 { + t.Errorf("Expected batch size 0 with no pods, got %d", batchSize) + } + + // Test with mock pods + pods := []*v1.Pod{ + { + ObjectMeta: metav1.ObjectMeta{ + Name: "test-pod-1", + Namespace: "default", + Annotations: map[string]string{ + "aibrix.io/max-concurrent-requests": "100", + }, + }, + Status: v1.PodStatus{ + Phase: v1.PodRunning, + PodIP: "1.2.3.4", + }, + }, + { + ObjectMeta: metav1.ObjectMeta{ + Name: "test-pod-2", + Namespace: "default", + }, + Status: v1.PodStatus{ + Phase: v1.PodRunning, + PodIP: "5.6.7.8", + }, + }, + } + + scheduler.healthyPods.Store(pods) + + // Test improved fallback batch size calculation + batchSize = scheduler.calculateBatchSize() + expectedBatchSize := 100 + DEFAULT_POD_CONCURRENT_CAPACITY // Pod1(100) + Pod2(1) + if batchSize != expectedBatchSize { + t.Errorf("Expected batch size %d in improved fallback mode, got %d", expectedBatchSize, batchSize) + } + + // Test estimated concurrent capacity + capacity1 := scheduler.getEstimatedConcurrentCapacity(pods[0]) + if capacity1 != 100 { + t.Errorf("Expected capacity 100 from annotation, got %d", capacity1) + } + + capacity2 := scheduler.getEstimatedConcurrentCapacity(pods[1]) + if capacity2 != DEFAULT_POD_CONCURRENT_CAPACITY { + t.Errorf("Expected default capacity %d, got %d", 
DEFAULT_POD_CONCURRENT_CAPACITY, capacity2) + } + + t.Logf("Load awareness test completed successfully") + t.Logf(" - Pod 1 capacity: %d (from annotation)", capacity1) + t.Logf(" - Pod 2 capacity: %d (default)", capacity2) + t.Logf(" - Batch size: %d", batchSize) +} + +func TestScheduler_BatchSizeSmoothing(t *testing.T) { + k8sClient := fake.NewSimpleClientset() + sessionCache := sessioninfo.NewMutexSessionCache() + scheduler := NewScheduler(k8sClient, sessionCache, nil).(*inProcessScheduler) + defer scheduler.Stop() + + // Create pods for testing + pods := []*v1.Pod{ + { + ObjectMeta: metav1.ObjectMeta{ + Name: "test-pod-1", + Namespace: "default", + }, + Status: v1.PodStatus{ + Phase: v1.PodRunning, + PodIP: "1.2.3.4", + }, + }, + { + ObjectMeta: metav1.ObjectMeta{ + Name: "test-pod-2", + Namespace: "default", + }, + Status: v1.PodStatus{ + Phase: v1.PodRunning, + PodIP: "5.6.7.8", + }, + }, + } + + scheduler.healthyPods.Store(pods) + + // Test smoothing behavior + // First call should return raw value (no previous history) + smoothed1 := scheduler.calculateSmoothedBatchSize() + raw1 := scheduler.calculateBatchSize() + + // First smoothed value should be close to raw value + if smoothed1 != raw1 { + // Allow small difference due to smoothing initialization + diff := smoothed1 - raw1 + if diff < -1 || diff > 1 { + t.Errorf("First smoothed batch size should be close to raw value, got smoothed=%d, raw=%d", smoothed1, raw1) + } + } + + // Second call should show smoothing effect + smoothed2 := scheduler.calculateSmoothedBatchSize() + + // Verify that lastBatchSize is being updated + lastBatchSize := int(scheduler.lastBatchSize.Load()) + if lastBatchSize != smoothed2 { + t.Errorf("lastBatchSize should be updated to smoothed value, got %d, expected %d", lastBatchSize, smoothed2) + } + + t.Logf("Batch size smoothing test completed successfully") + t.Logf(" - Raw batch size: %d", raw1) + t.Logf(" - First smoothed: %d", smoothed1) + t.Logf(" - Second smoothed: %d", smoothed2) + t.Logf(" - Last batch size stored: %d", lastBatchSize) +} + +func TestScheduler_CornerCases(t *testing.T) { + k8sClient := fake.NewSimpleClientset() + sessionCache := sessioninfo.NewMutexSessionCache() + scheduler := NewScheduler(k8sClient, sessionCache, nil).(*inProcessScheduler) + defer scheduler.Stop() + + t.Run("no_pods", func(t *testing.T) { + // Test with no pods + scheduler.healthyPods.Store([]*v1.Pod{}) + batchSize := scheduler.calculateBatchSize() + if batchSize != 0 { + t.Errorf("Expected batch size 0 with no pods, got %d", batchSize) + } + }) + + t.Run("pods_without_annotations_pass_through", func(t *testing.T) { + // Test with pods that have no capacity annotations - should trigger pass-through mode + pods := []*v1.Pod{ + { + ObjectMeta: metav1.ObjectMeta{ + Name: "no-annotation-pod-1", + Namespace: "default", + }, + Status: v1.PodStatus{Phase: v1.PodRunning, PodIP: "1.1.1.1"}, + }, + { + ObjectMeta: metav1.ObjectMeta{ + Name: "no-annotation-pod-2", + Namespace: "default", + }, + Status: v1.PodStatus{Phase: v1.PodRunning, PodIP: "2.2.2.2"}, + }, + } + + scheduler.healthyPods.Store(pods) + batchSize := scheduler.calculateBatchSize() + + // Should trigger pass-through mode (totalEstimatedCapacity = 2, podCount = 2) + if batchSize != PASS_THROUGH_BATCH_SIZE { + t.Errorf("Expected pass-through batch size %d, got %d", PASS_THROUGH_BATCH_SIZE, batchSize) + } + + t.Logf("Pass-through mode triggered: batch size = %d", batchSize) + }) + + t.Run("mixed_pods_with_and_without_annotations", func(t *testing.T) { + // Test 
with mixed pods - some with annotations, some without + pods := []*v1.Pod{ + { + ObjectMeta: metav1.ObjectMeta{ + Name: "annotated-pod", + Namespace: "default", + Annotations: map[string]string{ + "aibrix.io/max-concurrent-requests": "50", + }, + }, + Status: v1.PodStatus{Phase: v1.PodRunning, PodIP: "1.1.1.1"}, + }, + { + ObjectMeta: metav1.ObjectMeta{ + Name: "no-annotation-pod", + Namespace: "default", + }, + Status: v1.PodStatus{Phase: v1.PodRunning, PodIP: "2.2.2.2"}, + }, + } + + scheduler.healthyPods.Store(pods) + batchSize := scheduler.calculateBatchSize() + + // Should use improved fallback: 50 + 1 = 51 + expectedBatchSize := 50 + DEFAULT_POD_CONCURRENT_CAPACITY + if batchSize != expectedBatchSize { + t.Errorf("Expected batch size %d for mixed pods, got %d", expectedBatchSize, batchSize) + } + + t.Logf("Mixed pods batch size: %d", batchSize) + }) + + t.Run("high_inflight_requests", func(t *testing.T) { + // Test with high inflight requests + pods := []*v1.Pod{ + { + ObjectMeta: metav1.ObjectMeta{ + Name: "test-pod", + Namespace: "default", + Annotations: map[string]string{ + "aibrix.io/max-concurrent-requests": "10", + }, + }, + Status: v1.PodStatus{Phase: v1.PodRunning, PodIP: "1.1.1.1"}, + }, + } + + scheduler.healthyPods.Store(pods) + + // Set high inflight requests + scheduler.inflightRequests.Store(15) // More than capacity + + batchSize := scheduler.calculateBatchSize() + + // Should return 0 (no available capacity) + if batchSize != 0 { + t.Errorf("Expected batch size 0 with overloaded pods, got %d", batchSize) + } + + // Reset inflight requests + scheduler.inflightRequests.Store(0) + + t.Logf("Overloaded scenario handled correctly: batch size = %d", batchSize) + }) + + t.Run("invalid_annotation_values", func(t *testing.T) { + // Test with invalid annotation values + pods := []*v1.Pod{ + { + ObjectMeta: metav1.ObjectMeta{ + Name: "invalid-annotation-pod", + Namespace: "default", + Annotations: map[string]string{ + "aibrix.io/max-concurrent-requests": "invalid-number", + }, + }, + Status: v1.PodStatus{Phase: v1.PodRunning, PodIP: "1.1.1.1"}, + }, + { + ObjectMeta: metav1.ObjectMeta{ + Name: "negative-annotation-pod", + Namespace: "default", + Annotations: map[string]string{ + "aibrix.io/max-concurrent-requests": "-5", + }, + }, + Status: v1.PodStatus{Phase: v1.PodRunning, PodIP: "2.2.2.2"}, + }, + } + + scheduler.healthyPods.Store(pods) + batchSize := scheduler.calculateBatchSize() + + // Should trigger pass-through mode (both pods fall back to default capacity 1) + if batchSize != PASS_THROUGH_BATCH_SIZE { + t.Errorf("Expected pass-through batch size %d for invalid annotations, got %d", PASS_THROUGH_BATCH_SIZE, batchSize) + } + + t.Logf("Invalid annotations handled correctly: batch size = %d", batchSize) + }) + + t.Log("All corner cases tested successfully") +} + +// BenchmarkScheduler_ChannelVsMutex compares the channel-based approach +// with a hypothetical mutex-based approach for job submission. 
+func BenchmarkScheduler_ChannelVsMutex(b *testing.B) { + cache := sessioninfo.NewMutexSessionCache() + k8sClient := fake.NewSimpleClientset() + scheduler := NewScheduler(k8sClient, cache, nil) + defer scheduler.Stop() + + b.Run("ChannelBased", func(b *testing.B) { + b.RunParallel(func(pb *testing.PB) { + sessionID := fmt.Sprintf("session-%d", time.Now().UnixNano()) + for pb.Next() { + ctx := types.NewRoutingContext(context.Background(), "test-algorithm", "test-model", "test message", "test-request", "test-user") + decision, err := scheduler.SubmitJob(ctx, sessionID) + if err != nil { + b.Fatalf("SubmitJob failed: %v", err) + } + if decision == nil { + b.Fatal("Decision is nil") + } + scheduler.FinalizeJob(sessionID, 100*time.Millisecond, 50*time.Millisecond, 10*time.Millisecond) + } + }) + }) +} + +// BenchmarkScheduler_SessionCachePerformance tests the performance of session cache operations +func BenchmarkScheduler_SessionCachePerformance(b *testing.B) { + cache := sessioninfo.NewMutexSessionCache() + k8sClient := fake.NewSimpleClientset() + scheduler := NewScheduler(k8sClient, cache, nil) + defer scheduler.Stop() + + // Pre-populate some sessions + for i := 0; i < 1000; i++ { + sessionID := fmt.Sprintf("session-%d", i) + cache.GetOrCreateForScheduler(sessionID) + } + + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + sessionID := fmt.Sprintf("session-%d", time.Now().UnixNano()%1000) + for pb.Next() { + ctx := types.NewRoutingContext(context.Background(), "test-algorithm", "test-model", "test message", "test-request", "test-user") + decision, err := scheduler.SubmitJob(ctx, sessionID) + if err != nil { + b.Fatalf("SubmitJob failed: %v", err) + } + if decision == nil { + b.Fatal("Decision is nil") + } + scheduler.FinalizeJob(sessionID, 100*time.Millisecond, 50*time.Millisecond, 10*time.Millisecond) + } + }) +} diff --git a/pkg/plugins/gateway/scheduler/sessioninfo/sessioncache.go b/pkg/plugins/gateway/scheduler/sessioninfo/sessioncache.go new file mode 100644 index 000000000..739d02666 --- /dev/null +++ b/pkg/plugins/gateway/scheduler/sessioninfo/sessioncache.go @@ -0,0 +1,250 @@ +/* +Copyright 2025 The Aibrix Team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + +package sessioninfo + +import ( + "hash" + "hash/fnv" + "sync" + "time" +) + +var hasherPool = sync.Pool{ + New: func() interface{} { + return fnv.New64a() + }, +} + +// SessionState holds all the scheduling-relevant information for a single session +type SessionState struct { + SessionID string // The session ID + CriticalPathServiceTime time.Duration // The critical path service time + TotalWaitTime time.Duration // The total wait time (anti-starvation) + PodAffinity string // The pod affinity (later may needed) + LastActivityTimestamp time.Time // The last activity timestamp +} + +// --- Internal channel communication structs --- +type cacheOp int // operation code for cacheRequest + +const ( + opGetForScheduler cacheOp = iota + opGetFullState + opUpdateState + opUpdateAffinity + opCleanup +) + +// cacheRequest is the message format for all shard channels. +type cacheRequest struct { + op cacheOp + sessionID string + updatePayload updatePayload + affinityPayload string + cleanupPayload cleanupPayload + schedulerInfoRespChan chan schedulerInfoResponse + fullStateResponseChan chan fullStateResponse +} + +// updatePayload is the payload for opUpdateState +type updatePayload struct { + inheritedCST time.Duration + executionTime time.Duration + waitTime time.Duration +} + +// schedulerInfoResponse is the response for opGetForScheduler +type schedulerInfoResponse struct { + cst time.Duration + waitTime time.Duration +} + +// cleanupPayload is the payload for opCleanup +type cleanupPayload struct { + timeout time.Duration +} + +// fullStateResponse is the response for opGetFullState +type fullStateResponse struct { + state *SessionState +} + +// --- Shard and ShardedCache implementation --- +const shardCount = 256 // Must be a power of 2 for bitwise AND optimization + +// cacheShard is a single shard of the sharded cache. +type cacheShard struct { + sessions map[string]*SessionState // sessionID -> *SessionState + requests chan cacheRequest // Channel for requests to this shard + done chan struct{} // Channel for shutdown +} + +// run is the main loop for each shard goroutine. +func (s *cacheShard) run(wg *sync.WaitGroup) { + defer wg.Done() + for req := range s.requests { + state, exists := s.sessions[req.sessionID] + if !exists { + state = &SessionState{ + SessionID: req.sessionID, + LastActivityTimestamp: time.Now(), + } + s.sessions[req.sessionID] = state + } + switch req.op { + case opGetForScheduler: + // Return a struct with the required values. 
+ req.schedulerInfoRespChan <- schedulerInfoResponse{ + cst: state.CriticalPathServiceTime, + waitTime: state.TotalWaitTime, + } + case opGetFullState: + state, exists := s.sessions[req.sessionID] + if !exists { + req.fullStateResponseChan <- fullStateResponse{state: nil} + continue + } + stateCopy := *state + req.fullStateResponseChan <- fullStateResponse{state: &stateCopy} + case opUpdateState: + payload := req.updatePayload + state.TotalWaitTime += payload.waitTime + newPathLength := payload.inheritedCST + payload.executionTime + if newPathLength > state.CriticalPathServiceTime { + state.CriticalPathServiceTime = newPathLength + } + state.LastActivityTimestamp = time.Now() + case opUpdateAffinity: + state.PodAffinity = req.affinityPayload + case opCleanup: + payload := req.cleanupPayload + now := time.Now() + for sessionID, state := range s.sessions { + if now.Sub(state.LastActivityTimestamp) > payload.timeout { + delete(s.sessions, sessionID) + } + } + } + } +} + +// ShardedSessionCache is a highly concurrent, channel-based session cache. +type ShardedSessionCache struct { + shards []*cacheShard + wg sync.WaitGroup +} + +// NewShardedSessionCache creates and starts all shard goroutines. +func NewShardedSessionCache() *ShardedSessionCache { + sc := &ShardedSessionCache{ + shards: make([]*cacheShard, shardCount), + } + for i := 0; i < shardCount; i++ { + shard := &cacheShard{ + sessions: make(map[string]*SessionState), + requests: make(chan cacheRequest, 128), // Buffered channel per shard + done: make(chan struct{}), + } + sc.shards[i] = shard + sc.wg.Add(1) + go shard.run(&sc.wg) + } + return sc +} + +// getShard returns the shard for a given sessionID. +func (sc *ShardedSessionCache) getShard(sessionID string) *cacheShard { + hasher := hasherPool.Get().(hash.Hash64) + defer hasherPool.Put(hasher) + hasher.Reset() + hasher.Write([]byte(sessionID)) + return sc.shards[hasher.Sum64()&uint64(shardCount-1)] +} + +// --- Public API --- + +// GetOrCreateForScheduler is the primary method for the scheduler +func (sc *ShardedSessionCache) GetOrCreateForScheduler(sessionID string) (time.Duration, time.Duration) { + shard := sc.getShard(sessionID) + respChan := make(chan schedulerInfoResponse, 1) + shard.requests <- cacheRequest{ + op: opGetForScheduler, + sessionID: sessionID, + schedulerInfoRespChan: respChan, + } + info := <-respChan + return info.cst, info.waitTime +} + +// UpdateState is the primary method for the executor +func (sc *ShardedSessionCache) UpdateState(sessionID string, inheritedCST, executionTime, waitTime time.Duration) { + shard := sc.getShard(sessionID) + shard.requests <- cacheRequest{ + op: opUpdateState, + sessionID: sessionID, + updatePayload: updatePayload{ + inheritedCST: inheritedCST, + executionTime: executionTime, + waitTime: waitTime, + }, + } +} + +// UpdateAffinity is the primary method for the executor +func (sc *ShardedSessionCache) UpdateAffinity(sessionID, podName string) { + shard := sc.getShard(sessionID) + shard.requests <- cacheRequest{ + op: opUpdateAffinity, + sessionID: sessionID, + affinityPayload: podName, + } +} + +// GetState is provided for testing and debugging. 
+func (sc *ShardedSessionCache) GetState(sessionID string) (*SessionState, bool) { + shard := sc.getShard(sessionID) + respChan := make(chan fullStateResponse, 1) + shard.requests <- cacheRequest{ + op: opGetFullState, + sessionID: sessionID, + fullStateResponseChan: respChan, + } + info := <-respChan + state, exists := info.state, info.state != nil + return state, exists +} + +// StartCleanupRoutine starts a background goroutine that periodically +func (sc *ShardedSessionCache) StartCleanupRoutine(timeout time.Duration) { + req := cacheRequest{ + op: opCleanup, + cleanupPayload: cleanupPayload{ + timeout: timeout, + }, + } + for _, shard := range sc.shards { + shard.requests <- req + } +} + +// Close shuts down all shard goroutines, not elegantly yet. +func (sc *ShardedSessionCache) Close() { + for _, shard := range sc.shards { + close(shard.requests) + } + sc.wg.Wait() +} diff --git a/pkg/plugins/gateway/scheduler/sessioninfo/sessioncache_test.go b/pkg/plugins/gateway/scheduler/sessioninfo/sessioncache_test.go new file mode 100644 index 000000000..a6ff4b731 --- /dev/null +++ b/pkg/plugins/gateway/scheduler/sessioninfo/sessioncache_test.go @@ -0,0 +1,157 @@ +/* +Copyright 2025 The Aibrix Team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package sessioninfo + +import ( + "fmt" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +// The sharded tests are structurally similar to the mutex tests, +// as they are testing the same public API contract. 
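Before the tests, a small usage sketch (assumed behavior, not part of this change) of the ATLAS-style update that both cache implementations share: each completed request extends the session's critical path by inheritedCST + executionTime, the cache keeps the maximum, and wait times accumulate. The durations below are invented for illustration:

package sessioninfo

import (
	"fmt"
	"time"
)

func ExampleShardedSessionCache() {
	cache := NewShardedSessionCache()
	defer cache.Close()

	// First request: inherited 0s of critical path, executed for 5s, waited 2s.
	cache.UpdateState("session1", 0, 5*time.Second, 2*time.Second)
	// Second request inherited the 5s critical path and executed for 3s more.
	cache.UpdateState("session1", 5*time.Second, 3*time.Second, 1*time.Second)

	// Both updates and this read are sent on the same shard's channel from the
	// same goroutine, so the shard goroutine processes them in order.
	cst, wait := cache.GetOrCreateForScheduler("session1")
	fmt.Println(cst, wait)
	// Output: 8s 3s
}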
+ +func TestShardedCache_GetOrCreateForScheduler_NewSession(t *testing.T) { + cache := NewShardedSessionCache() + defer cache.Close() + + cst, waitTime := cache.GetOrCreateForScheduler("session1") + + assert.Equal(t, time.Duration(0), cst) + assert.Equal(t, time.Duration(0), waitTime) +} + +func TestShardedCache_UpdateState_Single(t *testing.T) { + cache := NewShardedSessionCache() + defer cache.Close() + + cache.UpdateState("session1", 0, 5*time.Second, 2*time.Second) + assert.Eventually(t, func() bool { + cst, wait := cache.GetOrCreateForScheduler("session1") + return cst == 5*time.Second && wait == 2*time.Second + }, 250*time.Millisecond, 5*time.Millisecond) + + cache.UpdateState("session1", 5*time.Second, 3*time.Second, 1*time.Second) + assert.Eventually(t, func() bool { + cst, wait := cache.GetOrCreateForScheduler("session1") + return cst == 8*time.Second && wait == 3*time.Second + }, 250*time.Millisecond, 5*time.Millisecond) +} + +func TestShardedCache_UpdateState_Concurrent(t *testing.T) { + cache := NewShardedSessionCache() + defer cache.Close() + + concurrency := 100 + var done atomic.Int32 + for i := 0; i < concurrency; i++ { + go func(execTimeMs int) { + cache.UpdateState("session1", 0, time.Duration(execTimeMs)*time.Millisecond, 10*time.Millisecond) + done.Add(1) + }(i + 1) + } + assert.Eventually(t, func() bool { + return done.Load() == int32(concurrency) + }, time.Second, 10*time.Millisecond) + assert.Eventually(t, func() bool { + cst, wait := cache.GetOrCreateForScheduler("session1") + return cst == 100*time.Millisecond && wait == 1000*time.Millisecond + }, 250*time.Millisecond, 5*time.Millisecond) +} + +func TestShardedCache_UpdateAffinity_Concurrent(t *testing.T) { + cache := NewShardedSessionCache() + defer cache.Close() + + // This test is harder to write for the channel version without a proper GetState that returns affinity. + // We'll skip the detailed check for now as the main purpose is to test the update mechanism. + concurrency := 10 + var wg sync.WaitGroup + wg.Add(concurrency) + + for i := 0; i < concurrency; i++ { + go func(podNum int) { + defer wg.Done() + cache.UpdateAffinity("session1", fmt.Sprintf("pod%d", podNum)) + }(i) + } + wg.Wait() + time.Sleep(50 * time.Millisecond) + + // We can't easily verify the result without a full GetState op. + // This test mainly serves to ensure no deadlocks occur. + t.Log("Concurrent affinity update test completed without deadlock.") +} + +func TestShardedCache_GetFullState(t *testing.T) { + cache := NewShardedSessionCache() + defer cache.Close() + + cache.UpdateState("session1", 0, 5*time.Second, 2*time.Second) + state, exists := cache.GetState("session1") + assert.True(t, exists) + assert.Equal(t, "session1", state.SessionID) + assert.Equal(t, 5*time.Second, state.CriticalPathServiceTime) + assert.Equal(t, 2*time.Second, state.TotalWaitTime) +} + +func TestShardedCache_Cleanup(t *testing.T) { + cache := NewShardedSessionCache() + defer cache.Close() + + // Ensure session1 and session2 hash to different shards if possible, or test with one + // For simplicity, we assume they might hash to the same shard, which is a valid test case. 
+ + // Create session1 + cache.UpdateState("session1", 0, 1*time.Second, 0) + time.Sleep(10 * time.Millisecond) // wait for channel to process + + // Wait to make session1 stale + time.Sleep(2 * time.Second) + + // Create session2, making it fresh + cache.UpdateState("session2", 0, 1*time.Second, 0) + time.Sleep(10 * time.Millisecond) + + // Trigger cleanup on all shards + cache.StartCleanupRoutine(1500 * time.Millisecond) + + // Wait for the cleanup commands to be processed by all shards + time.Sleep(100 * time.Millisecond) + + // To check existence, we need a GetState that returns a bool + // Let's assume we have implemented opGetFullState as discussed + + // opGetFullState for session1 + shard1 := cache.getShard("session1") + respChan1 := make(chan fullStateResponse, 1) + shard1.requests <- cacheRequest{op: opGetFullState, sessionID: "session1", fullStateResponseChan: respChan1} + response1 := <-respChan1 + // The manager should have created a new empty state, because the old one was deleted. + assert.Equal(t, time.Duration(0), response1.state.CriticalPathServiceTime, "session1 should have been cleaned and recreated as empty") + + // opGetFullState for session2 + shard2 := cache.getShard("session2") + respChan2 := make(chan fullStateResponse, 1) + shard2.requests <- cacheRequest{op: opGetFullState, sessionID: "session2", fullStateResponseChan: respChan2} + response2 := <-respChan2 + assert.NotEqual(t, time.Duration(0), response2.state.CriticalPathServiceTime, "session2 should be fresh and remain") +} diff --git a/pkg/plugins/gateway/scheduler/sessioninfo/sessioncachemutex.go b/pkg/plugins/gateway/scheduler/sessioninfo/sessioncachemutex.go new file mode 100644 index 000000000..1dce92cc1 --- /dev/null +++ b/pkg/plugins/gateway/scheduler/sessioninfo/sessioncachemutex.go @@ -0,0 +1,157 @@ +/* +Copyright 2025 The Aibrix Team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package sessioninfo + +import ( + "sync" + "time" +) + +// // SessionState holds all the scheduling-relevant information for a single session +// type SessionState struct { +// SessionID string // The session ID +// CriticalPathServiceTime time.Duration // The critical path service time +// TotalWaitTime time.Duration // The total wait time (anti-starvation) +// PodAffinity string // The pod affinity (later may needed) +// LastActivityTimestamp time.Time // The last activity timestamp +// } + +// MutexSessionCache is a thread-safe, in-memory store for session states +// using a sync.RWMutex. +type MutexSessionCache struct { + mu sync.RWMutex // Protects the sessions map + sessions map[string]*SessionState // sessionID -> *SessionState +} + +// NewMutexSessionCache creates a new in-memory session cache protected by a mutex. +func NewMutexSessionCache() *MutexSessionCache { + return &MutexSessionCache{ + sessions: make(map[string]*SessionState), + } +} + +// getState is a private helper that assumes a write lock is already held. +// It ensures a session state exists before any operation. 
+func (sc *MutexSessionCache) getState(sessionID string) *SessionState { + state, exists := sc.sessions[sessionID] + if !exists { + state = &SessionState{ + SessionID: sessionID, + LastActivityTimestamp: time.Now(), + } + sc.sessions[sessionID] = state + } + return state +} + +// GetState retrieves a copy of the state for a given sessionID +// for read-only purposes. +// It returns false if the session does not exist. +func (sc *MutexSessionCache) GetState(sessionID string) (SessionState, bool) { + sc.mu.RLock() + defer sc.mu.RUnlock() + + state, exists := sc.sessions[sessionID] + if !exists { + return SessionState{}, false + } + + // Return a copy to ensure + // the caller cannot modify the internal state without a lock, + // which would cause a data race. + return *state, true +} + +// GetOrCreateForScheduler is the primary method for the scheduler +// to get the necessary info. +// It returns the inherited CST and total wait time for a new job. +func (sc *MutexSessionCache) GetOrCreateForScheduler(sessionID string) ( + time.Duration, time.Duration) { + sc.mu.Lock() // Use a write lock because we might create a session. + defer sc.mu.Unlock() + + state := sc.getState(sessionID) + return state.CriticalPathServiceTime, state.TotalWaitTime +} + +// UpdateState atomically updates the session state after a request is finished. +func (sc *MutexSessionCache) UpdateState(sessionID string, inheritedCST, + executionTime, waitTime time.Duration) { + sc.mu.Lock() + defer sc.mu.Unlock() + + state := sc.getState(sessionID) + + // Atomically update total wait time + state.TotalWaitTime += waitTime + + // Atomically update CriticalPathServiceTime (ATLAS logic) + newPathLength := inheritedCST + executionTime + if newPathLength > state.CriticalPathServiceTime { + state.CriticalPathServiceTime = newPathLength + } + + state.LastActivityTimestamp = time.Now() +} + +// UpdateAffinity updates the pod affinity for a session. +func (sc *MutexSessionCache) UpdateAffinity(sessionID, podName string) { + sc.mu.Lock() + defer sc.mu.Unlock() + + state := sc.getState(sessionID) + state.PodAffinity = podName +} + +// StartCleanupRoutine starts a background goroutine that periodically +// cleans up stale sessions. +// It returns a function that can be called to stop the routine. +func (sc *MutexSessionCache) StartCleanupRoutine(interval, + timeout time.Duration) (stop func()) { + ticker := time.NewTicker(interval) + done := make(chan struct{}) + + go func() { + for { + select { + case <-ticker.C: + sc.cleanup(timeout) + case <-done: + ticker.Stop() + return + } + } + }() + + return func() { + close(done) + } +} + +// cleanup removes sessions that have been inactive for longer than the timeout. +// This is a private method that assumes the caller handles locking. +func (sc *MutexSessionCache) cleanup(timeout time.Duration) { + sc.mu.Lock() + defer sc.mu.Unlock() + + now := time.Now() + for sessionID, state := range sc.sessions { + if now.Sub(state.LastActivityTimestamp) > timeout { + delete(sc.sessions, sessionID) + } + } +} diff --git a/pkg/plugins/gateway/scheduler/sessioninfo/sessioncachemutex_test.go b/pkg/plugins/gateway/scheduler/sessioninfo/sessioncachemutex_test.go new file mode 100644 index 000000000..545cee695 --- /dev/null +++ b/pkg/plugins/gateway/scheduler/sessioninfo/sessioncachemutex_test.go @@ -0,0 +1,138 @@ +/* +Copyright 2025 The Aibrix Team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package sessioninfo + +import ( + "fmt" + "sync" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +// TestMutexCache_GetOrCreateForScheduler_NewSession tests the GetOrCreateForScheduler method. +func TestMutexCache_GetOrCreateForScheduler_NewSession(t *testing.T) { + cache := NewMutexSessionCache() + cst, waitTime := cache.GetOrCreateForScheduler("session1") + + assert.Equal(t, time.Duration(0), cst) + assert.Equal(t, time.Duration(0), waitTime) +} + +// TestMutexCache_UpdateState_Single tests the UpdateState method. +func TestMutexCache_UpdateState_Single(t *testing.T) { + cache := NewMutexSessionCache() + + // First update (like a serial request) + cache.UpdateState("session1", 0, 5*time.Second, 2*time.Second) + state, _ := cache.GetState("session1") + assert.Equal(t, 5*time.Second, state.CriticalPathServiceTime) + assert.Equal(t, 2*time.Second, state.TotalWaitTime) + + // Second update (another serial request) + // InheritedCST should be the CST from the previous state (5s) + cache.UpdateState("session1", 5*time.Second, 3*time.Second, 1*time.Second) + state, _ = cache.GetState("session1") + assert.Equal(t, 8*time.Second, state.CriticalPathServiceTime) // 5s + 3s + assert.Equal(t, 3*time.Second, state.TotalWaitTime) // 2s + 1s +} + +// TestMutexCache_UpdateState_Concurrent tests the UpdateState method. +func TestMutexCache_UpdateState_Concurrent(t *testing.T) { + cache := NewMutexSessionCache() + concurrency := 1000 + var wg sync.WaitGroup + wg.Add(concurrency) + + // Simulate 100 parallel requests for the same session finishing. + // All inherited CST=0, as they started when the session's CST was 0. + for i := 0; i < concurrency; i++ { + go func(execTimeMs int) { + defer wg.Done() + cache.UpdateState("session1", 0, + time.Duration(execTimeMs)*time.Millisecond, + 10*time.Millisecond) + }(i + 1) + } + wg.Wait() + + state, exists := cache.GetState("session1") + assert.True(t, exists) + + // Final CST should be the max of all new path lengths, + // which is max(0+1ms, 0+2ms, ... 0+100ms, ..., 0+1000ms) = 1000ms + assert.Equal(t, 1000*time.Millisecond, state.CriticalPathServiceTime) + + // Total wait time should be the sum of all wait times: + // 1000 * 10ms = 10000ms + assert.Equal(t, 10000*time.Millisecond, state.TotalWaitTime) +} + +// TestMutexCache_UpdateAffinity_Concurrent tests the UpdateAffinity method. +func TestMutexCache_UpdateAffinity_Concurrent(t *testing.T) { + cache := NewMutexSessionCache() + concurrency := 10 + var wg sync.WaitGroup + wg.Add(concurrency) + + for i := 0; i < concurrency; i++ { + go func(podNum int) { + defer wg.Done() + cache.UpdateAffinity("session1", + fmt.Sprintf("pod%d", podNum)) + }(i) + } + wg.Wait() + + state, exists := cache.GetState("session1") + assert.True(t, exists) + // Due to the race, we can't know the final value, + // but it must be one of the values we set. + assert.Contains(t, []string{"pod0", "pod1", "pod2", "pod3", + "pod4", "pod5", "pod6", "pod7", "pod8", "pod9"}, + state.PodAffinity) +} + +// TestMutexCache_Cleanup tests the Cleanup method. 
+func TestMutexCache_Cleanup(t *testing.T) { + cache := NewMutexSessionCache() + + // Create session1 + cache.UpdateState("session1", 0, 1*time.Second, 0) + // state1, _ := cache.GetState("session1") + // t.Logf("Session 1 LastActivity: %v", state1.LastActivityTimestamp) + + // Wait for 2 seconds, making session1 stale relative to a 1.5s timeout + time.Sleep(2 * time.Second) + + // Create/update session2, making it fresh + cache.UpdateState("session2", 0, 1*time.Second, 0) + // state2, _ := cache.GetState("session2") + // t.Logf("Session 2 LastActivity: %v", state2.LastActivityTimestamp) + + // Now, cleanup sessions older than 1.5 seconds + cache.cleanup(1500 * time.Millisecond) + + // session1 should be gone because it's ~2 seconds old + _, exists := cache.GetState("session1") + assert.False(t, exists, "session1 should be stale and cleaned up") + + // session2 should still exist because it's very fresh + _, exists = cache.GetState("session2") + assert.True(t, exists, "session2 should be fresh and remain") +} diff --git a/pkg/plugins/gateway/util.go b/pkg/plugins/gateway/util.go index 8af38818b..fb347bace 100644 --- a/pkg/plugins/gateway/util.go +++ b/pkg/plugins/gateway/util.go @@ -28,6 +28,53 @@ import ( "k8s.io/klog/v2" ) +// extractSessionIDFromHeaders extracts session ID only from headers. +// This is used in handleRequestHeaders phase for temporary session ID extraction. +// If no session ID is found in headers, it returns the requestID as a temporary fallback. +func extractSessionIDFromHeaders(requestID string, headers map[string]string) string { + // Try to get session ID from headers + if sessionID, exists := headers["x-session-id"]; exists && sessionID != "" { + return sessionID + } + if sessionID, exists := headers["X-Session-ID"]; exists && sessionID != "" { + return sessionID + } + + // Return requestID as temporary fallback + return requestID +} + +// extractSessionID extracts session ID from request body or headers. +// If no session ID is found, it returns the requestID as a fallback. +// DEPRECATED: This function is kept for backward compatibility and existing tests. +// New code should use extractSessionIDFromHeaders and extractFinalSessionID instead. +func extractSessionID(requestID, requestPath string, requestBody []byte, headers map[string]string) string { + // First, try to get session ID from headers + if sessionID, exists := headers["x-session-id"]; exists && sessionID != "" { + return sessionID + } + if sessionID, exists := headers["X-Session-ID"]; exists && sessionID != "" { + return sessionID + } + + // COMMENTED OUT: Body-based session ID extraction (moved to headers-only approach) + // Then, try to extract from request body + // if requestPath == "/v1/chat/completions" || requestPath == "/v1/completions" { + // var jsonMap map[string]json.RawMessage + // if err := json.Unmarshal(requestBody, &jsonMap); err == nil { + // if sessionData, exists := jsonMap["session_id"]; exists { + // var sessionID string + // if err := json.Unmarshal(sessionData, &sessionID); err == nil && sessionID != "" { + // return sessionID + // } + // } + // } + // } + + // Fallback to requestID if no session ID found + return requestID +} + // validateRequestBody validates input by unmarshaling request body into respective openai-golang struct based on requestpath. // nolint:nakedret func validateRequestBody(requestID, requestPath string, requestBody []byte, user utils.User) (model, message string, stream bool, errRes *extProcPb.ProcessingResponse) {