Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 76 additions & 0 deletions pkg/plugins/gateway/gateway.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,13 +37,48 @@ import (
"github.com/vllm-project/aibrix/pkg/metrics"
routing "github.com/vllm-project/aibrix/pkg/plugins/gateway/algorithms"
"github.com/vllm-project/aibrix/pkg/plugins/gateway/ratelimiter"
"github.com/vllm-project/aibrix/pkg/plugins/gateway/scheduler"
"github.com/vllm-project/aibrix/pkg/plugins/gateway/scheduler/sessioninfo"
"github.com/vllm-project/aibrix/pkg/types"
"github.com/vllm-project/aibrix/pkg/utils"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
gatewayv1 "sigs.k8s.io/gateway-api/apis/v1"
gatewayapi "sigs.k8s.io/gateway-api/pkg/client/clientset/versioned"
)

// requestState represents the state of a single request processing flow
type requestState int

const (
stateAwaitingHeaders requestState = iota
stateAwaitingBody
stateAwaitingDecision
stateForwarding
stateDone
)

// perRequestState holds all the state for a single Process() invocation
type perRequestState struct {
currentState requestState
sessionID string
requestID string
user utils.User
rpm int64
model string
routerCtx *types.RoutingContext
stream bool
traceTerm int64
completed bool
isRespError bool
respErrorCode int

// For timing and scheduling
requestStartTime time.Time
submissionTime time.Time
dispatchTime time.Time // When scheduler granted permission
schedulingDecision *scheduler.Decision
}

const (
defaultAIBrixNamespace = "aibrix-system"
)
Expand All @@ -56,6 +91,13 @@ type Server struct {
requestCountTracker map[string]int
cache cache.Cache
metricsServer *metrics.Server

// Scheduler and session management
scheduler scheduler.Scheduler
sessionCache *sessioninfo.MutexSessionCache

// Cleanup function for session cache
sessionCleanupStop func()
}

func NewServer(redisClient *redis.Client, client kubernetes.Interface, gatewayClient gatewayapi.Interface) *Server {
Expand All @@ -68,6 +110,13 @@ func NewServer(redisClient *redis.Client, client kubernetes.Interface, gatewayCl
// Initialize the routers
routing.Init()

// Initialize session cache and scheduler
sessionCache := sessioninfo.NewMutexSessionCache()
sched := scheduler.NewScheduler(client, sessionCache, c)

// Start session cleanup routine (cleanup every 5 minutes, timeout after 30 minutes)
sessionCleanupStop := sessionCache.StartCleanupRoutine(5*time.Minute, 30*time.Minute)

return &Server{
redisClient: redisClient,
ratelimiter: r,
Expand All @@ -76,10 +125,19 @@ func NewServer(redisClient *redis.Client, client kubernetes.Interface, gatewayCl
requestCountTracker: map[string]int{},
cache: c,
metricsServer: nil,
scheduler: sched,
sessionCache: sessionCache,
sessionCleanupStop: sessionCleanupStop,
}
}

// Process delegates to the state machine implementation
func (s *Server) Process(srv extProcPb.ExternalProcessor_ProcessServer) error {
return s.ProcessStateMachine(srv)
}

// ProcessLegacy is the original implementation kept for reference
func (s *Server) ProcessLegacy(srv extProcPb.ExternalProcessor_ProcessServer) error {
var user utils.User
var rpm, traceTerm int64
var respErrorCode int
Expand Down Expand Up @@ -227,11 +285,29 @@ func (s *Server) StartMetricsServer(addr string) error {
}

func (s *Server) Shutdown() {
klog.InfoS("Starting graceful shutdown of Gateway Server")

// Stop scheduler first to prevent new jobs
if s.scheduler != nil {
klog.InfoS("Stopping scheduler")
s.scheduler.Stop()
}

// Stop session cache cleanup routine
if s.sessionCleanupStop != nil {
klog.InfoS("Stopping session cache cleanup routine")
s.sessionCleanupStop()
}

// Stop metrics server
if s.metricsServer != nil {
klog.InfoS("Stopping metrics server")
if err := s.metricsServer.Stop(); err != nil {
klog.ErrorS(err, "Error stopping metrics server")
}
}

klog.InfoS("Gateway Server shutdown complete")
}

func (s *Server) responseErrorProcessing(ctx context.Context, resp *extProcPb.ProcessingResponse, respErrorCode int,
Expand Down
18 changes: 17 additions & 1 deletion pkg/plugins/gateway/gateway_req_body.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,14 @@ import (
func (s *Server) HandleRequestBody(ctx context.Context, requestID string, req *extProcPb.ProcessingRequest, user utils.User) (*extProcPb.ProcessingResponse, string, *types.RoutingContext, bool, int64) {
var term int64 // Identify the trace window

routingCtx, _ := ctx.(*types.RoutingContext)
routingCtx, ok := ctx.(*types.RoutingContext)
if !ok || routingCtx == nil {
klog.ErrorS(nil, "CRITICAL: context is not RoutingContext or is nil", "requestID", requestID, "contextType", fmt.Sprintf("%T", ctx))
return generateErrorResponse(envoyTypePb.StatusCode_InternalServerError,
[]*configPb.HeaderValueOption{{Header: &configPb.HeaderValue{
Key: HeaderErrorRouting, RawValue: []byte("true")}}},
"internal routing context error"), "", nil, false, term
}
requestPath := routingCtx.ReqPath
routingAlgorithm := routingCtx.Algorithm

Expand Down Expand Up @@ -66,6 +73,15 @@ func (s *Server) HandleRequestBody(ctx context.Context, requestID string, req *e
fmt.Sprintf("error on getting pods for model %s", model)), model, routingCtx, stream, term
}

// Check if scheduler is enabled - if so, defer routing to the scheduler
if s.scheduler != nil {
// With scheduler enabled, we don't perform routing here
// Just validate the model exists and return nil to let Process handle scheduling
klog.InfoS("request body processed, deferring to scheduler", "requestID", requestID, "requestPath", requestPath, "model", model, "stream", stream)
return nil, model, routingCtx, stream, term
}

// Legacy routing logic (when scheduler is not enabled)
headers := []*configPb.HeaderValueOption{}
if routingAlgorithm == routing.RouterNotSet {
if err := s.validateHTTPRouteStatus(ctx, model); err != nil {
Expand Down
83 changes: 68 additions & 15 deletions pkg/plugins/gateway/gateway_req_headers.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (
"k8s.io/klog/v2"

configPb "github.com/envoyproxy/go-control-plane/envoy/config/core/v3"
extProcFilterPb "github.com/envoyproxy/go-control-plane/envoy/extensions/filters/http/ext_proc/v3"
extProcPb "github.com/envoyproxy/go-control-plane/envoy/service/ext_proc/v3"
envoyTypePb "github.com/envoyproxy/go-control-plane/envoy/type/v3"
routing "github.com/vllm-project/aibrix/pkg/plugins/gateway/algorithms"
Expand Down Expand Up @@ -86,23 +87,75 @@ func (s *Server) HandleRequestHeaders(ctx context.Context, requestID string, req
routingCtx.ReqPath = requestPath
routingCtx.ReqHeaders = reqHeaders

return &extProcPb.ProcessingResponse{
Response: &extProcPb.ProcessingResponse_RequestHeaders{
RequestHeaders: &extProcPb.HeadersResponse{
Response: &extProcPb.CommonResponse{
HeaderMutation: &extProcPb.HeaderMutation{
SetHeaders: []*configPb.HeaderValueOption{
{
Header: &configPb.HeaderValue{
Key: HeaderWentIntoReqHeaders,
RawValue: []byte("true"),
},
},
// For legacy Process function (non-state machine), we need to handle routing here
// For state machine version, this function is only used to extract headers info
if s.scheduler == nil {
// Legacy mode: complete routing in RequestHeaders phase
klog.InfoS("legacy mode: completing processing in RequestHeaders phase", "requestID", requestID)

// early reject the request if model doesn't exist.
model := "default" // TODO: Extract model from headers if available
if !s.cache.HasModel(model) {
klog.ErrorS(nil, "model doesn't exist in cache", "requestID", requestID, "model", model)
return generateErrorResponse(envoyTypePb.StatusCode_BadRequest,
[]*configPb.HeaderValueOption{{Header: &configPb.HeaderValue{
Key: HeaderErrorNoModelBackends, RawValue: []byte(model)}}},
"model does not exist"), user, rpm, routingCtx
}

// early reject if no pods are ready to accept request for a model
podsArr, err := s.cache.ListPodsByModel(model)
if err != nil || podsArr == nil || utils.CountRoutablePods(podsArr.All()) == 0 {
klog.ErrorS(err, "no ready pod available", "requestID", requestID, "model", model)
return generateErrorResponse(envoyTypePb.StatusCode_ServiceUnavailable,
[]*configPb.HeaderValueOption{{Header: &configPb.HeaderValue{
Key: HeaderErrorNoModelBackends, RawValue: []byte("true")}}},
"no ready pods available"), user, rpm, routingCtx
}

headers := []*configPb.HeaderValueOption{}
if routingAlgorithm == routing.RouterNotSet {
if err := s.validateHTTPRouteStatus(ctx, model); err != nil {
return buildErrorResponse(envoyTypePb.StatusCode_ServiceUnavailable, err.Error(), HeaderErrorRouting, "true"), user, rpm, routingCtx
}
headers = buildEnvoyProxyHeaders(headers, HeaderModel, model)
klog.InfoS("request start", "requestID", requestID, "requestPath", routingCtx.ReqPath, "model", model)
} else {
targetPodIP, err := s.selectTargetPod(routingCtx, podsArr)
if targetPodIP == "" || err != nil {
klog.ErrorS(err, "failed to select target pod", "requestID", requestID, "routingStrategy", routingAlgorithm, "model", model, "routingDuration", routingCtx.GetRoutingDelay())
return generateErrorResponse(
envoyTypePb.StatusCode_ServiceUnavailable,
[]*configPb.HeaderValueOption{{Header: &configPb.HeaderValue{
Key: HeaderErrorRouting, RawValue: []byte("true")}}},
"error on selecting target pod"), user, rpm, routingCtx
}
headers = buildEnvoyProxyHeaders(headers,
HeaderRoutingStrategy, string(routingAlgorithm),
HeaderTargetPod, targetPodIP,
"X-Request-Id", routingCtx.RequestID)
klog.InfoS("request start", "requestID", requestID, "requestPath", routingCtx.ReqPath, "routingAlgorithm", routingAlgorithm, "targetPodIP", targetPodIP, "routingDuration", routingCtx.GetRoutingDelay())
}

return &extProcPb.ProcessingResponse{
Response: &extProcPb.ProcessingResponse_RequestHeaders{
RequestHeaders: &extProcPb.HeadersResponse{
Response: &extProcPb.CommonResponse{
HeaderMutation: &extProcPb.HeaderMutation{
SetHeaders: headers,
},
ClearRouteCache: true,
},
ClearRouteCache: true,
},
},
},
}, user, rpm, routingCtx
// Don't request RequestBody in legacy mode
ModeOverride: &extProcFilterPb.ProcessingMode{
RequestBodyMode: extProcFilterPb.ProcessingMode_NONE,
},
}, user, rpm, routingCtx
}

// State machine mode: just return basic info, don't do routing here
klog.InfoS("state machine mode: headers processed, waiting for body", "requestID", requestID)
return nil, user, rpm, routingCtx
}
Loading