diff --git a/conformance/testing-epp/plugins/filter/filter_test.go b/conformance/testing-epp/plugins/filter/filter_test.go index 2c0082189..f059c8f33 100644 --- a/conformance/testing-epp/plugins/filter/filter_test.go +++ b/conformance/testing-epp/plugins/filter/filter_test.go @@ -22,6 +22,7 @@ import ( "github.com/google/go-cmp/cmp" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" ) @@ -117,7 +118,7 @@ func TestFilter(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { - got := test.filter.Filter(context.Background(), types.NewCycleState(), test.req, test.input) + got := test.filter.Filter(context.Background(), plugins.NewCycleState(), test.req, test.input) if diff := cmp.Diff(test.output, got); diff != "" { t.Errorf("Unexpected output (-want +got): %v", diff) diff --git a/conformance/testing-epp/plugins/filter/request_header_based_filter.go b/conformance/testing-epp/plugins/filter/request_header_based_filter.go index 41194b55d..7e20d79ab 100644 --- a/conformance/testing-epp/plugins/filter/request_header_based_filter.go +++ b/conformance/testing-epp/plugins/filter/request_header_based_filter.go @@ -20,6 +20,7 @@ import ( "context" "strings" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" ) @@ -50,7 +51,7 @@ func (f *HeaderBasedTestingFilter) Type() string { } // Filter selects pods that match the IP addresses specified in the request header. 
-func (f *HeaderBasedTestingFilter) Filter(_ context.Context, _ *types.CycleState, request *types.LLMRequest, pods []types.Pod) []types.Pod { +func (f *HeaderBasedTestingFilter) Filter(_ context.Context, _ *plugins.CycleState, request *types.LLMRequest, pods []types.Pod) []types.Pod { headerValue, ok := request.Headers[headerTestEppEndPointSelectionKey] if !ok || headerValue == "" { return []types.Pod{} diff --git a/conformance/testing-epp/sheduler_test.go b/conformance/testing-epp/sheduler_test.go index 4901e0380..fe00c431c 100644 --- a/conformance/testing-epp/sheduler_test.go +++ b/conformance/testing-epp/sheduler_test.go @@ -24,6 +24,7 @@ import ( "github.com/google/uuid" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend" backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics" // Import config for thresholds + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" ) @@ -100,7 +101,7 @@ func TestSchedule(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { scheduler := NewReqHeaderBasedScheduler() - got, err := scheduler.Schedule(context.Background(), test.req, types.ToSchedulerPodMetrics(test.input)) + got, err := scheduler.Schedule(context.Background(), plugins.NewCycleState(), test.req, types.ToSchedulerPodMetrics(test.input)) if test.err != (err != nil) { t.Errorf("Unexpected error, got %v, want %v", err, test.err) } diff --git a/docs/proposals/0845-scheduler-architecture-proposal/interfaces/interface.go b/docs/proposals/0845-scheduler-architecture-proposal/interfaces/interface.go index 35b787b35..fe5b98d4a 100644 --- a/docs/proposals/0845-scheduler-architecture-proposal/interfaces/interface.go +++ b/docs/proposals/0845-scheduler-architecture-proposal/interfaces/interface.go @@ -19,7 +19,7 @@ package framework import ( "context" - scheduling "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" + 
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins" ) type Endpoint struct { @@ -114,7 +114,7 @@ type ProfileHandler interface { // The framework will return an error to the client if the endpoints are filtered to zero. type Filter interface { Plugin - Filter(ctx context.Context, request *Request, state *scheduling.CycleState, endpoints []*Endpoint) []*Endpoint + Filter(ctx context.Context, request *Request, state *plugins.CycleState, endpoints []*Endpoint) []*Endpoint } // Scorer applies a score to each remaining endpoint provided. @@ -122,7 +122,7 @@ type Filter interface { // Any weighting should be added at the SchedulerProfile configuration level. type Scorer interface { Plugin - Score(ctx context.Context, request *Request, state *scheduling.CycleState, endpoints []*Endpoint) []*ScoredEndpoint + Score(ctx context.Context, request *Request, state *plugins.CycleState, endpoints []*Endpoint) []*ScoredEndpoint } // WeightedScorer is a struct that encapsulates a scorer with its weight. @@ -138,5 +138,5 @@ type WeightedScorer struct { // Picker MUST return, one endpoint at minimum. type Picker interface { Plugin - Pick(ctx context.Context, state *scheduling.CycleState, endpoints []*ScoredEndpoint) []*ScoredEndpoint + Pick(ctx context.Context, state *plugins.CycleState, endpoints []*ScoredEndpoint) []*ScoredEndpoint } diff --git a/pkg/epp/common/config/loader/configloader_test.go b/pkg/epp/common/config/loader/configloader_test.go index 475a594a2..f675785f1 100644 --- a/pkg/epp/common/config/loader/configloader_test.go +++ b/pkg/epp/common/config/loader/configloader_test.go @@ -557,7 +557,7 @@ func (f *test1) Type() string { } // Filter filters out pods that doesn't meet the filter criteria. 
-func (f *test1) Filter(_ context.Context, _ *types.CycleState, _ *types.LLMRequest, pods []types.Pod) []types.Pod { +func (f *test1) Filter(_ context.Context, _ *plugins.CycleState, _ *types.LLMRequest, pods []types.Pod) []types.Pod { return pods } @@ -571,11 +571,11 @@ func (f *test2) Type() string { return test2Type } -func (m *test2) Score(_ context.Context, _ *types.CycleState, _ *types.LLMRequest, _ []types.Pod) map[types.Pod]float64 { +func (m *test2) Score(_ context.Context, _ *plugins.CycleState, _ *types.LLMRequest, _ []types.Pod) map[types.Pod]float64 { return map[types.Pod]float64{} } -func (m *test2) PostCycle(_ context.Context, _ *types.CycleState, _ *types.ProfileRunResult) {} +func (m *test2) PostCycle(_ context.Context, _ *plugins.CycleState, _ *types.ProfileRunResult) {} // compile-time type validation var _ framework.Picker = &testPicker{} @@ -586,7 +586,7 @@ func (p *testPicker) Type() string { return testPickerType } -func (p *testPicker) Pick(_ context.Context, _ *types.CycleState, _ []*types.ScoredPod) *types.ProfileRunResult { +func (p *testPicker) Pick(_ context.Context, _ *plugins.CycleState, _ []*types.ScoredPod) *types.ProfileRunResult { return nil } @@ -599,11 +599,11 @@ func (p *testProfileHandler) Type() string { return testProfileHandlerType } -func (p *testProfileHandler) Pick(_ context.Context, _ *types.CycleState, _ *types.LLMRequest, _ map[string]*framework.SchedulerProfile, _ map[string]*types.ProfileRunResult) map[string]*framework.SchedulerProfile { +func (p *testProfileHandler) Pick(_ context.Context, _ *plugins.CycleState, _ *types.LLMRequest, _ map[string]*framework.SchedulerProfile, _ map[string]*types.ProfileRunResult) map[string]*framework.SchedulerProfile { return nil } -func (p *testProfileHandler) ProcessResults(_ context.Context, _ *types.CycleState, _ *types.LLMRequest, _ map[string]*types.ProfileRunResult) (*types.SchedulingResult, error) { +func (p *testProfileHandler) ProcessResults(_ context.Context, _ 
*plugins.CycleState, _ *types.LLMRequest, _ map[string]*types.ProfileRunResult) (*types.SchedulingResult, error) { return nil, nil } diff --git a/pkg/epp/scheduling/types/cycle_state.go b/pkg/epp/plugins/cycle_state.go similarity index 99% rename from pkg/epp/scheduling/types/cycle_state.go rename to pkg/epp/plugins/cycle_state.go index 9f0a67f6e..49a5c5397 100644 --- a/pkg/epp/scheduling/types/cycle_state.go +++ b/pkg/epp/plugins/cycle_state.go @@ -14,7 +14,7 @@ See the License for the specific language governing permissions and limitations under the License. */ -package types +package plugins import ( "errors" diff --git a/pkg/epp/requestcontrol/director.go b/pkg/epp/requestcontrol/director.go index 5bc5ca4cd..4b8ae7e30 100644 --- a/pkg/epp/requestcontrol/director.go +++ b/pkg/epp/requestcontrol/director.go @@ -33,6 +33,7 @@ import ( "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datastore" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/handlers" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/metrics" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins" schedulingtypes "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" errutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/error" logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" @@ -41,7 +42,7 @@ import ( // Scheduler defines the interface required by the Director for scheduling. type Scheduler interface { - Schedule(ctx context.Context, request *schedulingtypes.LLMRequest, candidatePods []schedulingtypes.Pod) (result *schedulingtypes.SchedulingResult, err error) + Schedule(ctx context.Context, cycleState *plugins.CycleState, request *schedulingtypes.LLMRequest, candidatePods []schedulingtypes.Pod) (result *schedulingtypes.SchedulingResult, err error) } // SaturationDetector provides a signal indicating whether the backends are considered saturated. 
@@ -135,11 +136,14 @@ func (d *Director) HandleRequest(ctx context.Context, reqCtx *handlers.RequestCo } // --- 3. Call Scheduler --- + // Initialize the CycleState shared across this request's scheduling cycle and the PreRequest plugins. + cycleState := plugins.NewCycleState() + // Snapshot pod metrics from the datastore to: // 1. Reduce concurrent access to the datastore. // 2. Ensure consistent data during the scheduling operation of a request between all scheduling cycles. candidatePods := schedulingtypes.ToSchedulerPodMetrics(d.datastore.PodGetAll()) - results, err := d.scheduler.Schedule(ctx, reqCtx.SchedulingRequest, candidatePods) + results, err := d.scheduler.Schedule(ctx, cycleState, reqCtx.SchedulingRequest, candidatePods) if err != nil { return reqCtx, errutil.Error{Code: errutil.InferencePoolResourceExhausted, Msg: fmt.Errorf("failed to find target pod: %w", err).Error()} } @@ -147,7 +151,7 @@ func (d *Director) HandleRequest(ctx context.Context, reqCtx *handlers.RequestCo // --- 4. Prepare Request (Populates RequestContext and call PreRequest plugins) --- // Insert target endpoint to instruct Envoy to route requests to the specified target pod and attach the port number. // Invoke PreRequest registered plugins. - reqCtx, err = d.prepareRequest(ctx, reqCtx, results) + reqCtx, err = d.prepareRequest(ctx, reqCtx, cycleState, results) if err != nil { return reqCtx, err } @@ -178,7 +182,7 @@ func (d *Director) admitRequest(ctx context.Context, requestCriticality v1alpha2 // prepareRequest populates the RequestContext and calls the registered PreRequest plugins // for allowing plugging customized logic based on the scheduling results. 
-func (d *Director) prepareRequest(ctx context.Context, reqCtx *handlers.RequestContext, result *schedulingtypes.SchedulingResult) (*handlers.RequestContext, error) { +func (d *Director) prepareRequest(ctx context.Context, reqCtx *handlers.RequestContext, cycleState *plugins.CycleState, result *schedulingtypes.SchedulingResult) (*handlers.RequestContext, error) { logger := log.FromContext(ctx) if result == nil || len(result.ProfileResults) == 0 { return reqCtx, errutil.Error{Code: errutil.Internal, Msg: "results must be greater than zero"} @@ -198,7 +202,7 @@ func (d *Director) prepareRequest(ctx context.Context, reqCtx *handlers.RequestC reqCtx.TargetPod = targetPod reqCtx.TargetEndpoint = endpoint - d.runPreRequestPlugins(ctx, reqCtx.SchedulingRequest, result, targetPort) + d.runPreRequestPlugins(ctx, cycleState, reqCtx.SchedulingRequest, result, targetPort) return reqCtx, nil } @@ -255,12 +259,12 @@ func RandomWeightedDraw(logger logr.Logger, model *v1alpha2.InferenceModel, seed return "" } -func (d *Director) runPreRequestPlugins(ctx context.Context, request *schedulingtypes.LLMRequest, schedulingResult *schedulingtypes.SchedulingResult, +func (d *Director) runPreRequestPlugins(ctx context.Context, cycleState *plugins.CycleState, request *schedulingtypes.LLMRequest, schedulingResult *schedulingtypes.SchedulingResult, targetPort int) { for _, plugin := range d.preRequestPlugins { log.FromContext(ctx).V(logutil.DEBUG).Info("Running pre-request plugin", "plugin", plugin.Type()) before := time.Now() - plugin.PreRequest(ctx, request, schedulingResult, targetPort) + plugin.PreRequest(ctx, cycleState, request, schedulingResult, targetPort) metrics.RecordRequestControlPluginProcessingLatency(PreRequestPluginType, plugin.Type(), time.Since(before)) } } diff --git a/pkg/epp/requestcontrol/director_test.go b/pkg/epp/requestcontrol/director_test.go index 9f16ed117..2437dd146 100644 --- a/pkg/epp/requestcontrol/director_test.go +++ b/pkg/epp/requestcontrol/director_test.go 
@@ -36,6 +36,7 @@ import ( backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datastore" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/handlers" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins" schedulingtypes "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" errutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/error" logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" @@ -58,7 +59,7 @@ type mockScheduler struct { scheduleErr error } -func (m *mockScheduler) Schedule(_ context.Context, _ *schedulingtypes.LLMRequest, _ []schedulingtypes.Pod) (*schedulingtypes.SchedulingResult, error) { +func (m *mockScheduler) Schedule(_ context.Context, _ *plugins.CycleState, _ *schedulingtypes.LLMRequest, _ []schedulingtypes.Pod) (*schedulingtypes.SchedulingResult, error) { return m.scheduleResults, m.scheduleErr } diff --git a/pkg/epp/requestcontrol/plugins.go b/pkg/epp/requestcontrol/plugins.go index ba51c2afb..794c631a7 100644 --- a/pkg/epp/requestcontrol/plugins.go +++ b/pkg/epp/requestcontrol/plugins.go @@ -33,7 +33,7 @@ const ( // before a request is sent to the selected model server. type PreRequest interface { plugins.Plugin - PreRequest(ctx context.Context, request *types.LLMRequest, schedulingResult *types.SchedulingResult, targetPort int) + PreRequest(ctx context.Context, cycleState *plugins.CycleState, request *types.LLMRequest, schedulingResult *types.SchedulingResult, targetPort int) } // PostResponse is called by the director after a successful response was sent. 
diff --git a/pkg/epp/scheduling/framework/plugins.go b/pkg/epp/scheduling/framework/plugins.go index 7e22d8618..31698b0b6 100644 --- a/pkg/epp/scheduling/framework/plugins.go +++ b/pkg/epp/scheduling/framework/plugins.go @@ -38,39 +38,39 @@ type ProfileHandler interface { plugins.Plugin // Pick selects the SchedulingProfiles to run from a list of candidate profiles, while taking into consideration the request properties // and the previously executed SchedluderProfile cycles along with their results. - Pick(ctx context.Context, cycleState *types.CycleState, request *types.LLMRequest, profiles map[string]*SchedulerProfile, + Pick(ctx context.Context, cycleState *plugins.CycleState, request *types.LLMRequest, profiles map[string]*SchedulerProfile, profileResults map[string]*types.ProfileRunResult) map[string]*SchedulerProfile // ProcessResults handles the outcome of the profile runs after all profiles ran. // It may aggregate results, log test profile outputs, or apply custom logic. It specifies in the SchedulingResult the // key of the primary profile that should be used to get the request selected destination. // When a profile run fails, its result in the profileResults map is nil. - ProcessResults(ctx context.Context, cycleState *types.CycleState, request *types.LLMRequest, + ProcessResults(ctx context.Context, cycleState *plugins.CycleState, request *types.LLMRequest, profileResults map[string]*types.ProfileRunResult) (*types.SchedulingResult, error) } // Filter defines the interface for filtering a list of pods based on context. type Filter interface { plugins.Plugin - Filter(ctx context.Context, cycleState *types.CycleState, request *types.LLMRequest, pods []types.Pod) []types.Pod + Filter(ctx context.Context, cycleState *plugins.CycleState, request *types.LLMRequest, pods []types.Pod) []types.Pod } // Scorer defines the interface for scoring a list of pods based on context. 
// Scorers must score pods with a value within the range of [0,1] where 1 is the highest score. type Scorer interface { plugins.Plugin - Score(ctx context.Context, cycleState *types.CycleState, request *types.LLMRequest, pods []types.Pod) map[types.Pod]float64 + Score(ctx context.Context, cycleState *plugins.CycleState, request *types.LLMRequest, pods []types.Pod) map[types.Pod]float64 } // Picker picks the final pod(s) to send the request to. type Picker interface { plugins.Plugin - Pick(ctx context.Context, cycleState *types.CycleState, scoredPods []*types.ScoredPod) *types.ProfileRunResult + Pick(ctx context.Context, cycleState *plugins.CycleState, scoredPods []*types.ScoredPod) *types.ProfileRunResult } // PostCycle is called by the scheduler after it selects a targetPod for the request in the SchedulerProfile cycle. // DEPRECATED - do not use PostCycle. this is in the process of deprecation. type PostCycle interface { plugins.Plugin - PostCycle(ctx context.Context, cycleState *types.CycleState, res *types.ProfileRunResult) + PostCycle(ctx context.Context, cycleState *plugins.CycleState, res *types.ProfileRunResult) } diff --git a/pkg/epp/scheduling/framework/plugins/filter/decision_tree_filter.go b/pkg/epp/scheduling/framework/plugins/filter/decision_tree_filter.go index 3f7c88fcf..6799b8053 100644 --- a/pkg/epp/scheduling/framework/plugins/filter/decision_tree_filter.go +++ b/pkg/epp/scheduling/framework/plugins/filter/decision_tree_filter.go @@ -20,6 +20,7 @@ import ( "context" "sigs.k8s.io/controller-runtime/pkg/log" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" @@ -56,7 +57,7 @@ func (f *DecisionTreeFilter) Type() string { } // Filter filters out pods that doesn't meet the filter criteria. 
-func (f *DecisionTreeFilter) Filter(ctx context.Context, cycleState *types.CycleState, request *types.LLMRequest, pods []types.Pod) []types.Pod { +func (f *DecisionTreeFilter) Filter(ctx context.Context, cycleState *plugins.CycleState, request *types.LLMRequest, pods []types.Pod) []types.Pod { loggerTrace := log.FromContext(ctx).V(logutil.TRACE) filteredPod := f.Current.Filter(ctx, cycleState, request, pods) diff --git a/pkg/epp/scheduling/framework/plugins/filter/filter_test.go b/pkg/epp/scheduling/framework/plugins/filter/filter_test.go index 1f07c1d37..000c674ae 100644 --- a/pkg/epp/scheduling/framework/plugins/filter/filter_test.go +++ b/pkg/epp/scheduling/framework/plugins/filter/filter_test.go @@ -25,6 +25,7 @@ import ( k8stypes "k8s.io/apimachinery/pkg/types" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend" backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/config" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" @@ -39,7 +40,7 @@ func (f *filterAll) Type() string { return "filter-all" } -func (f *filterAll) Filter(_ context.Context, _ *types.CycleState, _ *types.LLMRequest, pods []types.Pod) []types.Pod { +func (f *filterAll) Filter(_ context.Context, _ *plugins.CycleState, _ *types.LLMRequest, pods []types.Pod) []types.Pod { return []types.Pod{} } @@ -138,7 +139,7 @@ func TestFilter(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { - got := test.filter.Filter(context.Background(), types.NewCycleState(), test.req, test.input) + got := test.filter.Filter(context.Background(), plugins.NewCycleState(), test.req, test.input) if diff := cmp.Diff(test.output, got); diff != "" { t.Errorf("Unexpected output (-want +got): %v", diff) @@ -206,7 +207,7 @@ func 
TestLoRASoftAffinityDistribution(t *testing.T) { LoraAffinityFilter := NewLoraAffinityFilter(config.Conf.LoraAffinityThreshold) for range numIterations { - result := LoraAffinityFilter.Filter(context.Background(), types.NewCycleState(), req, pods) + result := LoraAffinityFilter.Filter(context.Background(), plugins.NewCycleState(), req, pods) // Check which type of pod was returned if len(result) != 1 { diff --git a/pkg/epp/scheduling/framework/plugins/filter/least_kvcache_filter.go b/pkg/epp/scheduling/framework/plugins/filter/least_kvcache_filter.go index 0b4c1ebec..6ec6180af 100644 --- a/pkg/epp/scheduling/framework/plugins/filter/least_kvcache_filter.go +++ b/pkg/epp/scheduling/framework/plugins/filter/least_kvcache_filter.go @@ -56,7 +56,7 @@ func (f *LeastKVCacheFilter) Type() string { } // Filter filters out pods that doesn't meet the filter criteria. -func (f *LeastKVCacheFilter) Filter(_ context.Context, _ *types.CycleState, _ *types.LLMRequest, pods []types.Pod) []types.Pod { +func (f *LeastKVCacheFilter) Filter(_ context.Context, _ *plugins.CycleState, _ *types.LLMRequest, pods []types.Pod) []types.Pod { filteredPods := []types.Pod{} min := math.MaxFloat64 diff --git a/pkg/epp/scheduling/framework/plugins/filter/least_queue_filter.go b/pkg/epp/scheduling/framework/plugins/filter/least_queue_filter.go index c43d4b3df..218eda907 100644 --- a/pkg/epp/scheduling/framework/plugins/filter/least_queue_filter.go +++ b/pkg/epp/scheduling/framework/plugins/filter/least_queue_filter.go @@ -56,7 +56,7 @@ func (f *LeastQueueFilter) Type() string { } // Filter filters out pods that doesn't meet the filter criteria. 
-func (f *LeastQueueFilter) Filter(_ context.Context, _ *types.CycleState, _ *types.LLMRequest, pods []types.Pod) []types.Pod { +func (f *LeastQueueFilter) Filter(_ context.Context, _ *plugins.CycleState, _ *types.LLMRequest, pods []types.Pod) []types.Pod { filteredPods := []types.Pod{} min := math.MaxInt diff --git a/pkg/epp/scheduling/framework/plugins/filter/lora_affinity_filter.go b/pkg/epp/scheduling/framework/plugins/filter/lora_affinity_filter.go index 2150cdb08..9228bf813 100644 --- a/pkg/epp/scheduling/framework/plugins/filter/lora_affinity_filter.go +++ b/pkg/epp/scheduling/framework/plugins/filter/lora_affinity_filter.go @@ -73,7 +73,7 @@ func (f *LoraAffinityFilter) Type() string { } // Filter filters out pods that doesn't meet the filter criteria. -func (f *LoraAffinityFilter) Filter(_ context.Context, _ *types.CycleState, request *types.LLMRequest, pods []types.Pod) []types.Pod { +func (f *LoraAffinityFilter) Filter(_ context.Context, _ *plugins.CycleState, request *types.LLMRequest, pods []types.Pod) []types.Pod { // Pre-allocate slices with estimated capacity filtered_affinity := make([]types.Pod, 0, len(pods)) filtered_available := make([]types.Pod, 0, len(pods)) diff --git a/pkg/epp/scheduling/framework/plugins/filter/low_queue_filter.go b/pkg/epp/scheduling/framework/plugins/filter/low_queue_filter.go index 2d8aa3d79..76223033a 100644 --- a/pkg/epp/scheduling/framework/plugins/filter/low_queue_filter.go +++ b/pkg/epp/scheduling/framework/plugins/filter/low_queue_filter.go @@ -67,7 +67,7 @@ func (f *LowQueueFilter) Type() string { } // Filter filters out pods that doesn't meet the filter criteria. 
-func (f *LowQueueFilter) Filter(_ context.Context, _ *types.CycleState, _ *types.LLMRequest, pods []types.Pod) []types.Pod { +func (f *LowQueueFilter) Filter(_ context.Context, _ *plugins.CycleState, _ *types.LLMRequest, pods []types.Pod) []types.Pod { filteredPods := []types.Pod{} for _, pod := range pods { diff --git a/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin.go b/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin.go index f79122ad8..c06935898 100644 --- a/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin.go +++ b/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin.go @@ -89,7 +89,7 @@ func (s ServerID) String() string { } // compile-time type validation -var _ types.StateData = &schedulingContextState{} +var _ plugins.StateData = &schedulingContextState{} // This is the state of this plugin to be used during a scheduling cycle. type schedulingContextState struct { @@ -99,7 +99,7 @@ type schedulingContextState struct { PrefixCacheServers map[ServerID]int } -func (s *schedulingContextState) Clone() types.StateData { +func (s *schedulingContextState) Clone() plugins.StateData { prefixHashes := make([]BlockHash, len(s.PrefixHashes)) copy(prefixHashes, s.PrefixHashes) prefixCacheServers := make(map[ServerID]int, len(s.PrefixCacheServers)) @@ -158,7 +158,7 @@ func (m *Plugin) Type() string { } // Score returns the scoring result for the given list of pods based on context. -func (m *Plugin) Score(ctx context.Context, cycleState *types.CycleState, request *types.LLMRequest, pods []types.Pod) map[types.Pod]float64 { +func (m *Plugin) Score(ctx context.Context, cycleState *plugins.CycleState, request *types.LLMRequest, pods []types.Pod) map[types.Pod]float64 { loggerTrace := log.FromContext(ctx).V(logutil.TRACE) // pre score step, hashing prompt and find longest prefix match. 
hashes := hashPrompt(ctx, request, m.HashBlockSize, m.MaxPrefixBlocksToMatch) @@ -167,7 +167,7 @@ func (m *Plugin) Score(ctx context.Context, cycleState *types.CycleState, reques PrefixCacheServers: m.matchLongestPrefix(ctx, hashes), } - cycleState.Write(types.StateKey(m.Type()), state) + cycleState.Write(plugins.StateKey(m.Type()), state) loggerTrace.Info(fmt.Sprintf("cached servers: %+v", state.PrefixCacheServers), "hashes", state.PrefixHashes) // calculate the scores of pods scores := make(map[types.Pod]float64, len(pods)) @@ -188,7 +188,7 @@ func (m *Plugin) Score(ctx context.Context, cycleState *types.CycleState, reques } // PostCycle records in the plugin cache the result of the scheduling selection. -func (m *Plugin) PostCycle(ctx context.Context, cycleState *types.CycleState, res *types.ProfileRunResult) { +func (m *Plugin) PostCycle(ctx context.Context, cycleState *plugins.CycleState, res *types.ProfileRunResult) { targetPod := res.TargetPod.GetPod() state, err := m.getPrefixState(cycleState) if err != nil { @@ -227,11 +227,11 @@ func (m *Plugin) matchLongestPrefix(ctx context.Context, hashes []BlockHash) map } // getPrefixState returns the cycle state as a schedulingContextState. 
-func (m *Plugin) getPrefixState(cycleState *types.CycleState) (*schedulingContextState, error) { - prefixStateKey := types.StateKey(m.Type()) +func (m *Plugin) getPrefixState(cycleState *plugins.CycleState) (*schedulingContextState, error) { + prefixStateKey := plugins.StateKey(m.Type()) state, err := cycleState.Read(prefixStateKey) if err != nil { - return nil, fmt.Errorf("failed reading %q from CycleState: %w", prefixStateKey, err) + return nil, fmt.Errorf("failed reading %q from cycleState: %w", prefixStateKey, err) } prefixSchedulingState, ok := state.(*schedulingContextState) diff --git a/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin_test.go b/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin_test.go index 73f02c373..e3f247fa3 100644 --- a/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin_test.go +++ b/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin_test.go @@ -27,6 +27,7 @@ import ( "github.com/stretchr/testify/assert" k8stypes "k8s.io/apimachinery/pkg/types" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" ) @@ -48,7 +49,7 @@ func TestPrefixPlugin(t *testing.T) { TargetModel: "test-model1", Prompt: "aaaaaa", } - cycleState1 := types.NewCycleState() + cycleState1 := plugins.NewCycleState() scores := plugin.Score(context.Background(), cycleState1, req1, pods) state, err := plugin.getPrefixState(cycleState1) assert.NoError(t, err) @@ -69,7 +70,7 @@ func TestPrefixPlugin(t *testing.T) { TargetModel: "test-model2", Prompt: "bbbbbb", } - cycleState2 := types.NewCycleState() + cycleState2 := plugins.NewCycleState() scores = plugin.Score(context.Background(), cycleState2, req2, pods) state, err = plugin.getPrefixState(cycleState2) assert.NoError(t, err) @@ -89,7 +90,7 @@ func TestPrefixPlugin(t *testing.T) { TargetModel: "test-model1", Prompt: "aaaabbbb", } - cycleState3 := 
types.NewCycleState() + cycleState3 := plugins.NewCycleState() scores = plugin.Score(context.Background(), cycleState3, req3, pods) state, err = plugin.getPrefixState(cycleState3) assert.NoError(t, err) @@ -108,7 +109,7 @@ func TestPrefixPlugin(t *testing.T) { TargetModel: "test-model-new", Prompt: "aaaabbbb", } - cycleState4 := types.NewCycleState() + cycleState4 := plugins.NewCycleState() scores = plugin.Score(context.Background(), cycleState4, req4, pods) state, err = plugin.getPrefixState(cycleState4) assert.NoError(t, err) @@ -127,7 +128,7 @@ func TestPrefixPlugin(t *testing.T) { TargetModel: "test-model1", Prompt: "aaaabbbbcccc", } - cycleState5 := types.NewCycleState() + cycleState5 := plugins.NewCycleState() scores = plugin.Score(context.Background(), cycleState5, req5, pods) state, err = plugin.getPrefixState(cycleState5) assert.NoError(t, err) @@ -153,7 +154,6 @@ func BenchmarkPrefixPluginStress(b *testing.B) { } plugin := New(config) - types.NewCycleState() var promptLen []int for i := 1; i <= 1024; i++ { promptLen = append(promptLen, i) @@ -178,7 +178,7 @@ func BenchmarkPrefixPluginStress(b *testing.B) { } // First cycle: simulate scheduling and insert prefix info into the cache - cycleState := types.NewCycleState() + cycleState := plugins.NewCycleState() plugin.Score(context.Background(), cycleState, req, pods) plugin.PostCycle(context.Background(), cycleState, &types.ProfileRunResult{TargetPod: pod}) diff --git a/pkg/epp/scheduling/framework/plugins/picker/max_score_picker.go b/pkg/epp/scheduling/framework/plugins/picker/max_score_picker.go index 43c438de6..c3c7ea79c 100644 --- a/pkg/epp/scheduling/framework/plugins/picker/max_score_picker.go +++ b/pkg/epp/scheduling/framework/plugins/picker/max_score_picker.go @@ -58,7 +58,7 @@ func (p *MaxScorePicker) Type() string { } // Pick selects the pod with the maximum score from the list of candidates. 
-func (p *MaxScorePicker) Pick(ctx context.Context, cycleState *types.CycleState, scoredPods []*types.ScoredPod) *types.ProfileRunResult { +func (p *MaxScorePicker) Pick(ctx context.Context, cycleState *plugins.CycleState, scoredPods []*types.ScoredPod) *types.ProfileRunResult { log.FromContext(ctx).V(logutil.DEBUG).Info(fmt.Sprintf("Selecting a pod with the max score from %d candidates: %+v", len(scoredPods), scoredPods)) highestScorePods := []*types.ScoredPod{} diff --git a/pkg/epp/scheduling/framework/plugins/picker/random_picker.go b/pkg/epp/scheduling/framework/plugins/picker/random_picker.go index f12a05b84..ba201d880 100644 --- a/pkg/epp/scheduling/framework/plugins/picker/random_picker.go +++ b/pkg/epp/scheduling/framework/plugins/picker/random_picker.go @@ -55,7 +55,7 @@ func (p *RandomPicker) Type() string { } // Pick selects a random pod from the list of candidates. -func (p *RandomPicker) Pick(ctx context.Context, _ *types.CycleState, scoredPods []*types.ScoredPod) *types.ProfileRunResult { +func (p *RandomPicker) Pick(ctx context.Context, _ *plugins.CycleState, scoredPods []*types.ScoredPod) *types.ProfileRunResult { log.FromContext(ctx).V(logutil.DEBUG).Info(fmt.Sprintf("Selecting a random pod from %d candidates: %+v", len(scoredPods), scoredPods)) i := rand.Intn(len(scoredPods)) return &types.ProfileRunResult{TargetPod: scoredPods[i]} diff --git a/pkg/epp/scheduling/framework/plugins/profile/single_profile_handler.go b/pkg/epp/scheduling/framework/plugins/profile/single_profile_handler.go index ca87f5e8b..ef7b7f0ad 100644 --- a/pkg/epp/scheduling/framework/plugins/profile/single_profile_handler.go +++ b/pkg/epp/scheduling/framework/plugins/profile/single_profile_handler.go @@ -54,7 +54,7 @@ func (h *SingleProfileHandler) Type() string { // Pick selects the SchedulingProfiles to run from the list of candidate profiles, while taking into consideration the request properties and the // previously executed cycles along with their results. 
-func (h *SingleProfileHandler) Pick(_ context.Context, _ *types.CycleState, request *types.LLMRequest, profiles map[string]*framework.SchedulerProfile, +func (h *SingleProfileHandler) Pick(_ context.Context, _ *plugins.CycleState, request *types.LLMRequest, profiles map[string]*framework.SchedulerProfile, profileResults map[string]*types.ProfileRunResult) map[string]*framework.SchedulerProfile { if len(profiles) == len(profileResults) { // all profiles have been executed already in previous call return map[string]*framework.SchedulerProfile{} @@ -67,7 +67,7 @@ func (h *SingleProfileHandler) Pick(_ context.Context, _ *types.CycleState, requ // It may aggregate results, log test profile outputs, or apply custom logic. It specifies in the SchedulingResult the // key of the primary profile that should be used to get the request selected destination. // When a profile run fails, its result in the profileResults map is nil. -func (h *SingleProfileHandler) ProcessResults(_ context.Context, _ *types.CycleState, _ *types.LLMRequest, +func (h *SingleProfileHandler) ProcessResults(_ context.Context, _ *plugins.CycleState, _ *types.LLMRequest, profileResults map[string]*types.ProfileRunResult) (*types.SchedulingResult, error) { if len(profileResults) != 1 { return nil, errors.New("single profile handler is intended to be used with a single profile, failed to process multiple profiles") diff --git a/pkg/epp/scheduling/framework/plugins/scorer/kvcache.go b/pkg/epp/scheduling/framework/plugins/scorer/kvcache.go index 6bab369b5..f0022bdbd 100644 --- a/pkg/epp/scheduling/framework/plugins/scorer/kvcache.go +++ b/pkg/epp/scheduling/framework/plugins/scorer/kvcache.go @@ -52,7 +52,7 @@ func (s *KVCacheScorer) Type() string { } // Score returns the scoring result for the given list of pods based on context. 
-func (s *KVCacheScorer) Score(_ context.Context, _ *types.CycleState, _ *types.LLMRequest, pods []types.Pod) map[types.Pod]float64 { +func (s *KVCacheScorer) Score(_ context.Context, _ *plugins.CycleState, _ *types.LLMRequest, pods []types.Pod) map[types.Pod]float64 { scores := make(map[types.Pod]float64, len(pods)) for _, pod := range pods { scores[pod] = 1 - pod.GetMetrics().KVCacheUsagePercent diff --git a/pkg/epp/scheduling/framework/plugins/scorer/kvcache_test.go b/pkg/epp/scheduling/framework/plugins/scorer/kvcache_test.go index c0eeb5210..2049fdb29 100644 --- a/pkg/epp/scheduling/framework/plugins/scorer/kvcache_test.go +++ b/pkg/epp/scheduling/framework/plugins/scorer/kvcache_test.go @@ -23,6 +23,7 @@ import ( "github.com/stretchr/testify/assert" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend" backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" ) @@ -83,7 +84,7 @@ func TestKvCacheScorer(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { scorer := &KVCacheScorer{} - scores := scorer.Score(context.Background(), types.NewCycleState(), &types.LLMRequest{}, test.pods) + scores := scorer.Score(context.Background(), plugins.NewCycleState(), &types.LLMRequest{}, test.pods) for i, pod := range test.pods { expectedScore := test.expectedScoresPod[i] diff --git a/pkg/epp/scheduling/framework/plugins/scorer/queue.go b/pkg/epp/scheduling/framework/plugins/scorer/queue.go index a3c960e0f..bdbc376a9 100644 --- a/pkg/epp/scheduling/framework/plugins/scorer/queue.go +++ b/pkg/epp/scheduling/framework/plugins/scorer/queue.go @@ -54,7 +54,7 @@ func (s *QueueScorer) Type() string { } // Score returns the scoring result for the given list of pods based on context. 
-func (s *QueueScorer) Score(_ context.Context, _ *types.CycleState, _ *types.LLMRequest, pods []types.Pod) map[types.Pod]float64 { +func (s *QueueScorer) Score(_ context.Context, _ *plugins.CycleState, _ *types.LLMRequest, pods []types.Pod) map[types.Pod]float64 { minQueueSize := math.MaxInt maxQueueSize := math.MinInt diff --git a/pkg/epp/scheduling/framework/plugins/scorer/queue_test.go b/pkg/epp/scheduling/framework/plugins/scorer/queue_test.go index a9a8115b3..4a29c5305 100644 --- a/pkg/epp/scheduling/framework/plugins/scorer/queue_test.go +++ b/pkg/epp/scheduling/framework/plugins/scorer/queue_test.go @@ -23,6 +23,7 @@ import ( "github.com/stretchr/testify/assert" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend" backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" ) @@ -73,7 +74,7 @@ func TestQueueScorer(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { - scores := scorer.Score(context.Background(), types.NewCycleState(), &types.LLMRequest{}, test.pods) + scores := scorer.Score(context.Background(), plugins.NewCycleState(), &types.LLMRequest{}, test.pods) for i, pod := range test.pods { expectedScore := test.expectedScoresPod[i] diff --git a/pkg/epp/scheduling/framework/scheduler_profile.go b/pkg/epp/scheduling/framework/scheduler_profile.go index f41a915f0..817c5d8cb 100644 --- a/pkg/epp/scheduling/framework/scheduler_profile.go +++ b/pkg/epp/scheduling/framework/scheduler_profile.go @@ -106,7 +106,7 @@ func (p *SchedulerProfile) AddPlugins(pluginObjects ...plugins.Plugin) error { // RunCycle runs a SchedulerProfile cycle. In other words, it invokes all the SchedulerProfile plugins in this // order - Filters, Scorers, Picker, PostCyclePlugins. After completing all, it returns the result. 
-func (p *SchedulerProfile) Run(ctx context.Context, request *types.LLMRequest, cycleState *types.CycleState, candidatePods []types.Pod) (*types.ProfileRunResult, error) { +func (p *SchedulerProfile) Run(ctx context.Context, request *types.LLMRequest, cycleState *plugins.CycleState, candidatePods []types.Pod) (*types.ProfileRunResult, error) { pods := p.runFilterPlugins(ctx, request, cycleState, candidatePods) if len(pods) == 0 { return nil, errutil.Error{Code: errutil.Internal, Msg: "no pods available for the given request"} @@ -121,7 +121,7 @@ func (p *SchedulerProfile) Run(ctx context.Context, request *types.LLMRequest, c return result, nil } -func (p *SchedulerProfile) runFilterPlugins(ctx context.Context, request *types.LLMRequest, cycleState *types.CycleState, pods []types.Pod) []types.Pod { +func (p *SchedulerProfile) runFilterPlugins(ctx context.Context, request *types.LLMRequest, cycleState *plugins.CycleState, pods []types.Pod) []types.Pod { loggerDebug := log.FromContext(ctx).V(logutil.DEBUG) filteredPods := pods loggerDebug.Info("Before running filter plugins", "pods", filteredPods) @@ -141,7 +141,7 @@ func (p *SchedulerProfile) runFilterPlugins(ctx context.Context, request *types. return filteredPods } -func (p *SchedulerProfile) runScorerPlugins(ctx context.Context, request *types.LLMRequest, cycleState *types.CycleState, pods []types.Pod) map[types.Pod]float64 { +func (p *SchedulerProfile) runScorerPlugins(ctx context.Context, request *types.LLMRequest, cycleState *plugins.CycleState, pods []types.Pod) map[types.Pod]float64 { loggerDebug := log.FromContext(ctx).V(logutil.DEBUG) loggerDebug.Info("Before running scorer plugins", "pods", pods) @@ -165,7 +165,7 @@ func (p *SchedulerProfile) runScorerPlugins(ctx context.Context, request *types. 
return weightedScorePerPod } -func (p *SchedulerProfile) runPickerPlugin(ctx context.Context, cycleState *types.CycleState, weightedScorePerPod map[types.Pod]float64) *types.ProfileRunResult { +func (p *SchedulerProfile) runPickerPlugin(ctx context.Context, cycleState *plugins.CycleState, weightedScorePerPod map[types.Pod]float64) *types.ProfileRunResult { loggerDebug := log.FromContext(ctx).V(logutil.DEBUG) scoredPods := make([]*types.ScoredPod, len(weightedScorePerPod)) i := 0 @@ -183,7 +183,7 @@ func (p *SchedulerProfile) runPickerPlugin(ctx context.Context, cycleState *type return result } -func (p *SchedulerProfile) runPostCyclePlugins(ctx context.Context, cycleState *types.CycleState, result *types.ProfileRunResult) { +func (p *SchedulerProfile) runPostCyclePlugins(ctx context.Context, cycleState *plugins.CycleState, result *types.ProfileRunResult) { for _, plugin := range p.postCyclePlugins { log.FromContext(ctx).V(logutil.DEBUG).Info("Running post-cycle plugin", "plugin", plugin.Type()) before := time.Now() diff --git a/pkg/epp/scheduling/framework/scheduler_profile_test.go b/pkg/epp/scheduling/framework/scheduler_profile_test.go index 7fbc9e893..605c1ac43 100644 --- a/pkg/epp/scheduling/framework/scheduler_profile_test.go +++ b/pkg/epp/scheduling/framework/scheduler_profile_test.go @@ -25,6 +25,7 @@ import ( k8stypes "k8s.io/apimachinery/pkg/types" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend" backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics" // Import config for thresholds + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" ) @@ -129,7 +130,7 @@ func TestSchedulePlugins(t *testing.T) { RequestId: uuid.NewString(), } // Run profile cycle - got, err := test.profile.Run(context.Background(), request, types.NewCycleState(), types.ToSchedulerPodMetrics(test.input)) + got, err := test.profile.Run(context.Background(), request, 
plugins.NewCycleState(), types.ToSchedulerPodMetrics(test.input)) // Validate error state if test.err != (err != nil) { @@ -210,13 +211,12 @@ type testPlugin struct { func (tp *testPlugin) Type() string { return tp.TypeRes } -func (tp *testPlugin) Filter(_ context.Context, _ *types.CycleState, _ *types.LLMRequest, pods []types.Pod) []types.Pod { +func (tp *testPlugin) Filter(_ context.Context, _ *plugins.CycleState, _ *types.LLMRequest, pods []types.Pod) []types.Pod { tp.FilterCallCount++ return findPods(pods, tp.FilterRes...) - } -func (tp *testPlugin) Score(_ context.Context, _ *types.CycleState, _ *types.LLMRequest, pods []types.Pod) map[types.Pod]float64 { +func (tp *testPlugin) Score(_ context.Context, _ *plugins.CycleState, _ *types.LLMRequest, pods []types.Pod) map[types.Pod]float64 { tp.ScoreCallCount++ scoredPods := make(map[types.Pod]float64, len(pods)) for _, pod := range pods { @@ -226,7 +226,7 @@ func (tp *testPlugin) Score(_ context.Context, _ *types.CycleState, _ *types.LLM return scoredPods } -func (tp *testPlugin) Pick(_ context.Context, _ *types.CycleState, scoredPods []*types.ScoredPod) *types.ProfileRunResult { +func (tp *testPlugin) Pick(_ context.Context, _ *plugins.CycleState, scoredPods []*types.ScoredPod) *types.ProfileRunResult { tp.PickCallCount++ tp.NumOfPickerCandidates = len(scoredPods) @@ -241,7 +241,7 @@ func (tp *testPlugin) Pick(_ context.Context, _ *types.CycleState, scoredPods [] return &types.ProfileRunResult{TargetPod: winnerPod} } -func (tp *testPlugin) PostCycle(_ context.Context, _ *types.CycleState, res *types.ProfileRunResult) { +func (tp *testPlugin) PostCycle(_ context.Context, _ *plugins.CycleState, res *types.ProfileRunResult) { tp.PostScheduleCallCount++ } diff --git a/pkg/epp/scheduling/scheduler.go b/pkg/epp/scheduling/scheduler.go index b848b26dc..849aa9fe8 100644 --- a/pkg/epp/scheduling/scheduler.go +++ b/pkg/epp/scheduling/scheduler.go @@ -25,6 +25,7 @@ import ( "sigs.k8s.io/controller-runtime/pkg/log" 
backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/metrics" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/config" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework/plugins/filter" @@ -92,7 +93,7 @@ type Scheduler struct { } // Schedule finds the target pod based on metrics and the requested lora adapter. -func (s *Scheduler) Schedule(ctx context.Context, request *types.LLMRequest, candidatePods []types.Pod) (*types.SchedulingResult, error) { +func (s *Scheduler) Schedule(ctx context.Context, cycleState *plugins.CycleState, request *types.LLMRequest, candidatePods []types.Pod) (*types.SchedulingResult, error) { logger := log.FromContext(ctx).WithValues("request", request) loggerDebug := logger.V(logutil.DEBUG) @@ -102,7 +103,6 @@ func (s *Scheduler) Schedule(ctx context.Context, request *types.LLMRequest, can }() profileRunResults := map[string]*types.ProfileRunResult{} - cycleState := types.NewCycleState() for { // get the next set of profiles to run iteratively based on the request and the previous execution results before := time.Now() diff --git a/pkg/epp/scheduling/scheduler_test.go b/pkg/epp/scheduling/scheduler_test.go index 720a6b4f5..855fde74c 100644 --- a/pkg/epp/scheduling/scheduler_test.go +++ b/pkg/epp/scheduling/scheduler_test.go @@ -25,6 +25,7 @@ import ( k8stypes "k8s.io/apimachinery/pkg/types" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend" backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics" // Import config for thresholds + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" ) @@ -120,7 +121,7 @@ func TestSchedule(t *testing.T) { for _, test := range tests { 
t.Run(test.name, func(t *testing.T) { scheduler := NewScheduler() - got, err := scheduler.Schedule(context.Background(), test.req, types.ToSchedulerPodMetrics(test.input)) + got, err := scheduler.Schedule(context.Background(), plugins.NewCycleState(), test.req, types.ToSchedulerPodMetrics(test.input)) if test.err != (err != nil) { t.Errorf("Unexpected error, got %v, want %v", err, test.err) }