diff --git a/pkg/epp/common/config/loader/configloader_test.go b/pkg/epp/common/config/loader/configloader_test.go index 233f76aa7..4eddb7324 100644 --- a/pkg/epp/common/config/loader/configloader_test.go +++ b/pkg/epp/common/config/loader/configloader_test.go @@ -550,7 +550,6 @@ func (f *test1) Filter(_ context.Context, _ *types.CycleState, _ *types.LLMReque // compile-time type validation var _ framework.Scorer = &test2{} -var _ framework.PostCycle = &test2{} type test2 struct { typedName plugins.TypedName @@ -570,8 +569,6 @@ func (m *test2) Score(_ context.Context, _ *types.CycleState, _ *types.LLMReques return map[types.Pod]float64{} } -func (m *test2) PostCycle(_ context.Context, _ *types.CycleState, _ *types.ProfileRunResult) {} - // compile-time type validation var _ framework.Picker = &testPicker{} diff --git a/pkg/epp/scheduling/framework/plugins.go b/pkg/epp/scheduling/framework/plugins.go index 7e22d8618..4e59fe7c5 100644 --- a/pkg/epp/scheduling/framework/plugins.go +++ b/pkg/epp/scheduling/framework/plugins.go @@ -28,7 +28,6 @@ const ( FilterPluginType = "Filter" ScorerPluginType = "Scorer" PickerPluginType = "Picker" - PostCyclePluginType = "PostCycle" ProcessProfilesResultsType = "ProcessProfilesResults" ) @@ -67,10 +66,3 @@ type Picker interface { plugins.Plugin Pick(ctx context.Context, cycleState *types.CycleState, scoredPods []*types.ScoredPod) *types.ProfileRunResult } - -// PostCycle is called by the scheduler after it selects a targetPod for the request in the SchedulerProfile cycle. -// DEPRECATED - do not use PostCycle. this is in the process of deprecation. -type PostCycle interface { - plugins.Plugin - PostCycle(ctx context.Context, cycleState *types.CycleState, res *types.ProfileRunResult) -} diff --git a/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin.go b/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin.go index 42e630354..191bcd3aa 100644 --- a/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin.go +++ b/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin.go @@ -25,8 +25,10 @@ import ( "github.com/cespare/xxhash/v2" k8stypes "k8s.io/apimachinery/pkg/types" "sigs.k8s.io/controller-runtime/pkg/log" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/metrics" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/requestcontrol" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" @@ -116,7 +118,7 @@ func (s *SchedulingContextState) Clone() types.StateData { // compile-time type assertion var _ framework.Scorer = &Plugin{} -var _ framework.PostCycle = &Plugin{} +var _ requestcontrol.PostResponse = &Plugin{} // PrefixCachePluginFactory defines the factory function for Prefix plugin. func PrefixCachePluginFactory(name string, rawParameters json.RawMessage, _ plugins.Handle) (plugins.Plugin, error) { @@ -194,19 +196,24 @@ func (m *Plugin) Score(ctx context.Context, cycleState *types.CycleState, reques return scores } -// PostCycle records in the plugin cache the result of the scheduling selection. -func (m *Plugin) PostCycle(ctx context.Context, cycleState *types.CycleState, res *types.ProfileRunResult) { - targetPod := res.TargetPods[0].GetPod() - state, err := types.ReadCycleStateKey[*SchedulingContextState](cycleState, PrefixCachePluginType) - if err != nil { - log.FromContext(ctx).Error(err, "failed to read prefix plugin cycle state") +// PostResponse records in the plugin cache the result of the request processing. +// This method recomputes the prefix hashes from the request since the scheduling cycle state +// is not available in the PostResponse phase. +func (m *Plugin) PostResponse(ctx context.Context, request *types.LLMRequest, response *requestcontrol.Response, targetPod *backend.Pod) { + // Recompute the prefix hashes from the request + hashes := hashPrompt(ctx, request, m.HashBlockSize, m.MaxPrefixBlocksToMatch) + if len(hashes) == 0 { + // No hashes to cache, skip processing return } - m.indexer.Add(state.PrefixHashes, ServerID(targetPod.NamespacedName)) + // Add the hashes to the indexer for the target pod + m.indexer.Add(hashes, ServerID(targetPod.NamespacedName)) - total := len(state.PrefixHashes) - matchLen := state.PrefixCacheServers[ServerID(targetPod.NamespacedName)] + // Record metrics - we need to compute the match length for this pod + prefixCacheServers := m.matchLongestPrefix(ctx, hashes) + total := len(hashes) + matchLen := prefixCacheServers[ServerID(targetPod.NamespacedName)] metrics.RecordPrefixCacheMatch(matchLen*m.HashBlockSize, total*m.HashBlockSize) } diff --git a/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin_test.go b/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin_test.go index 27e13d685..494ccfdef 100644 --- a/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin_test.go +++ b/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin_test.go @@ -27,6 +27,7 @@ import ( "github.com/stretchr/testify/assert" k8stypes "k8s.io/apimachinery/pkg/types" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/requestcontrol" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" ) @@ -60,8 +61,12 @@ func TestPrefixPlugin(t *testing.T) { assert.Equal(t, float64(0), scores[pod1], "score for pod1") assert.Equal(t, float64(0), scores[pod2], "score for pod2") - // Simulate pod1 was picked. - plugin.PostCycle(context.Background(), cycleState1, &types.ProfileRunResult{TargetPods: []types.Pod{pod1}}) + // Simulate pod1 was picked - use PostResponse instead + response1 := &requestcontrol.Response{ + RequestId: "req1", + Headers: map[string]string{"content-type": "application/json"}, + } + plugin.PostResponse(context.Background(), req1, response1, pod1.Pod) // Second request doesn't share any prefix with first one. It should be added to the cache but // the pod score should be 0. @@ -81,8 +86,12 @@ func TestPrefixPlugin(t *testing.T) { assert.Equal(t, float64(0), scores[pod1], "score for pod1") assert.Equal(t, float64(0), scores[pod2], "score for pod2") - // Simulate pod2 was picked. - plugin.PostCycle(context.Background(), cycleState2, &types.ProfileRunResult{TargetPods: []types.Pod{pod2}}) + // Simulate pod2 was picked - use PostResponse instead + response2 := &requestcontrol.Response{ + RequestId: "req2", + Headers: map[string]string{"content-type": "application/json"}, + } + plugin.PostResponse(context.Background(), req2, response2, pod2.Pod) // Third request shares partial prefix with first one. req3 := &types.LLMRequest{ @@ -101,7 +110,12 @@ func TestPrefixPlugin(t *testing.T) { assert.Equal(t, float64(2)/float64(3), scores[pod1], "score should be 2/3 - the model and the first prefix block match") assert.Equal(t, float64(0), scores[pod2], "score for pod2") - plugin.PostCycle(context.Background(), cycleState3, &types.ProfileRunResult{TargetPods: []types.Pod{pod1}}) + // Simulate pod1 was picked - use PostResponse instead + response3 := &requestcontrol.Response{ + RequestId: "req3", + Headers: map[string]string{"content-type": "application/json"}, + } + plugin.PostResponse(context.Background(), req3, response3, pod1.Pod) // 4th request is same as req3 except the model is different, still no match. req4 := &types.LLMRequest{ @@ -120,7 +134,12 @@ func TestPrefixPlugin(t *testing.T) { assert.Equal(t, float64(0), scores[pod1], "score for pod1") assert.Equal(t, float64(0), scores[pod2], "score for pod2") - plugin.PostCycle(context.Background(), cycleState4, &types.ProfileRunResult{TargetPods: []types.Pod{pod1}}) + // Simulate pod1 was picked - use PostResponse instead + response4 := &requestcontrol.Response{ + RequestId: "req4", + Headers: map[string]string{"content-type": "application/json"}, + } + plugin.PostResponse(context.Background(), req4, response4, pod1.Pod) // 5th request shares partial prefix with 3rd one. req5 := &types.LLMRequest{ @@ -139,7 +158,76 @@ func TestPrefixPlugin(t *testing.T) { assert.Equal(t, 0.75, scores[pod1], "score should be 0.75 - the model and the first 2 prefix blocks match") assert.Equal(t, float64(0), scores[pod2], "score for pod2") - plugin.PostCycle(context.Background(), cycleState5, &types.ProfileRunResult{TargetPods: []types.Pod{pod1}}) + // Simulate pod1 was picked - use PostResponse instead + response5 := &requestcontrol.Response{ + RequestId: "req5", + Headers: map[string]string{"content-type": "application/json"}, + } + plugin.PostResponse(context.Background(), req5, response5, pod1.Pod) +} + +// TestPrefixPluginPostResponse tests the PostResponse method functionality +func TestPrefixPluginPostResponse(t *testing.T) { + config := Config{ + HashBlockSize: 4, + MaxPrefixBlocksToMatch: DefaultMaxPrefixBlocks, + LRUCapacityPerServer: DefaultLRUCapacityPerServer, + } + plugin := New(config) + + pod1 := &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod1"}} + pod2 := &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod2"}} + + // First request - use PostResponse to cache prefix + req1 := &types.LLMRequest{ + TargetModel: "test-model1", + Prompt: "aaaaaa", + } + response1 := &requestcontrol.Response{ + RequestId: "req1", + Headers: map[string]string{"content-type": "application/json"}, + } + + // Call PostResponse to cache the prefix for pod1 + plugin.PostResponse(context.Background(), req1, response1, pod1) + + // Second request with same prefix - should get cache hit on pod1 + req2 := &types.LLMRequest{ + TargetModel: "test-model1", + Prompt: "aaaabbbb", + } + + // Test scoring to verify cache hit + pods := []types.Pod{ + &types.PodMetrics{Pod: pod1}, + &types.PodMetrics{Pod: pod2}, + } + cycleState := types.NewCycleState() + scores := plugin.Score(context.Background(), cycleState, req2, pods) + + // pod1 should have a higher score due to prefix cache hit + assert.Greater(t, scores[pods[0]], scores[pods[1]], "pod1 should have higher score due to prefix cache") + assert.Greater(t, scores[pods[0]], float64(0), "pod1 should have non-zero score") + assert.Equal(t, float64(0), scores[pods[1]], "pod2 should have zero score") + + // Use PostResponse for the second request on pod1 + response2 := &requestcontrol.Response{ + RequestId: "req2", + Headers: map[string]string{"content-type": "application/json"}, + } + plugin.PostResponse(context.Background(), req2, response2, pod1) + + // Third request with different model - should not get cache hit + req3 := &types.LLMRequest{ + TargetModel: "test-model2", // Different model + Prompt: "aaaaaa", + } + cycleState3 := types.NewCycleState() + scores3 := plugin.Score(context.Background(), cycleState3, req3, pods) + + // Both pods should have zero score since model is different + assert.Equal(t, float64(0), scores3[pods[0]], "pod1 should have zero score for different model") + assert.Equal(t, float64(0), scores3[pods[1]], "pod2 should have zero score for different model") } // TestPrefixPluginStress is a stress test for the prefix scoring plugin, using prompts of increasing length. @@ -180,7 +268,12 @@ func BenchmarkPrefixPluginStress(b *testing.B) { // First cycle: simulate scheduling and insert prefix info into the cache cycleState := types.NewCycleState() plugin.Score(context.Background(), cycleState, req, pods) - plugin.PostCycle(context.Background(), cycleState, &types.ProfileRunResult{TargetPods: []types.Pod{pod}}) + // Use PostResponse instead of PostCycle + response := &requestcontrol.Response{ + RequestId: fmt.Sprintf("req-%d", i), + Headers: map[string]string{"content-type": "application/json"}, + } + plugin.PostResponse(context.Background(), req, response, pod.Pod) // Second cycle: validate internal state state, err := types.ReadCycleStateKey[*SchedulingContextState](cycleState, PrefixCachePluginType) diff --git a/pkg/epp/scheduling/framework/scheduler_profile.go b/pkg/epp/scheduling/framework/scheduler_profile.go index c65d20156..d9c92d9ae 100644 --- a/pkg/epp/scheduling/framework/scheduler_profile.go +++ b/pkg/epp/scheduling/framework/scheduler_profile.go @@ -32,19 +32,17 @@ import ( // NewSchedulerProfile creates a new SchedulerProfile object and returns its pointer. func NewSchedulerProfile() *SchedulerProfile { return &SchedulerProfile{ - filters: []Filter{}, - scorers: []*WeightedScorer{}, - postCyclePlugins: []PostCycle{}, + filters: []Filter{}, + scorers: []*WeightedScorer{}, // picker remains nil since profile doesn't support multiple pickers } } // SchedulerProfile provides a profile configuration for the scheduler which influence routing decisions. type SchedulerProfile struct { - filters []Filter - scorers []*WeightedScorer - picker Picker - postCyclePlugins []PostCycle + filters []Filter + scorers []*WeightedScorer + picker Picker } // WithFilters sets the given filter plugins as the Filter plugins. @@ -68,13 +66,6 @@ func (p *SchedulerProfile) WithPicker(picker Picker) *SchedulerProfile { return p } -// WithPostCyclePlugins sets the given plugins as the PostCycle plugins. -// If the SchedulerProfile has PostCycle plugins, this call replaces the existing plugins with the given ones. -func (p *SchedulerProfile) WithPostCyclePlugins(plugins ...PostCycle) *SchedulerProfile { - p.postCyclePlugins = plugins - return p -} - // AddPlugins adds the given plugins to all scheduler plugins according to the interfaces each plugin implements. // A plugin may implement more than one scheduler plugin interface. // Special Case: In order to add a scorer, one must use the scorer.NewWeightedScorer function in order to provide a weight. @@ -97,15 +88,13 @@ func (p *SchedulerProfile) AddPlugins(pluginObjects ...plugins.Plugin) error { } p.picker = picker } - if postCyclePlugin, ok := plugin.(PostCycle); ok { - p.postCyclePlugins = append(p.postCyclePlugins, postCyclePlugin) - } + } return nil } // RunCycle runs a SchedulerProfile cycle. In other words, it invokes all the SchedulerProfile plugins in this -// order - Filters, Scorers, Picker, PostCyclePlugins. After completing all, it returns the result. +// order - Filters, Scorers, Picker. After completing all, it returns the result. func (p *SchedulerProfile) Run(ctx context.Context, request *types.LLMRequest, cycleState *types.CycleState, candidatePods []types.Pod) (*types.ProfileRunResult, error) { pods := p.runFilterPlugins(ctx, request, cycleState, candidatePods) if len(pods) == 0 { @@ -116,8 +105,6 @@ func (p *SchedulerProfile) Run(ctx context.Context, request *types.LLMRequest, c result := p.runPickerPlugin(ctx, cycleState, weightedScorePerPod) - p.runPostCyclePlugins(ctx, cycleState, result) - return result, nil } @@ -182,12 +169,3 @@ func (p *SchedulerProfile) runPickerPlugin(ctx context.Context, cycleState *type return result } - -func (p *SchedulerProfile) runPostCyclePlugins(ctx context.Context, cycleState *types.CycleState, result *types.ProfileRunResult) { - for _, plugin := range p.postCyclePlugins { - log.FromContext(ctx).V(logutil.DEBUG).Info("Running post-cycle plugin", "plugin", plugin.TypedName().Type) - before := time.Now() - plugin.PostCycle(ctx, cycleState, result) - metrics.RecordSchedulerPluginProcessingLatency(PostCyclePluginType, plugin.TypedName().Type, time.Since(before)) - } -} diff --git a/pkg/epp/scheduling/framework/scheduler_profile_test.go b/pkg/epp/scheduling/framework/scheduler_profile_test.go index 020223c00..bf7a21fb4 100644 --- a/pkg/epp/scheduling/framework/scheduler_profile_test.go +++ b/pkg/epp/scheduling/framework/scheduler_profile_test.go @@ -64,8 +64,7 @@ func TestSchedulePlugins(t *testing.T) { profile: NewSchedulerProfile(). WithFilters(tp1, tp2). WithScorers(NewWeightedScorer(tp1, 1), NewWeightedScorer(tp2, 1)). - WithPicker(pickerPlugin). - WithPostCyclePlugins(tp1, tp2), + WithPicker(pickerPlugin), input: []types.Pod{ &types.PodMetrics{Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod1"}}}, &types.PodMetrics{Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod2"}}}, @@ -81,8 +80,7 @@ func TestSchedulePlugins(t *testing.T) { profile: NewSchedulerProfile(). WithFilters(tp1, tp2). WithScorers(NewWeightedScorer(tp1, 60), NewWeightedScorer(tp2, 40)). - WithPicker(pickerPlugin). - WithPostCyclePlugins(tp1, tp2), + WithPicker(pickerPlugin), input: []types.Pod{ &types.PodMetrics{Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod1"}}}, &types.PodMetrics{Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod2"}}}, @@ -98,8 +96,7 @@ func TestSchedulePlugins(t *testing.T) { profile: NewSchedulerProfile(). WithFilters(tp1, tp_filterAll). WithScorers(NewWeightedScorer(tp1, 1), NewWeightedScorer(tp2, 1)). - WithPicker(pickerPlugin). - WithPostCyclePlugins(tp1, tp2), + WithPicker(pickerPlugin), input: []types.Pod{ &types.PodMetrics{Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod1"}}}, &types.PodMetrics{Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod2"}}}, @@ -120,9 +117,6 @@ func TestSchedulePlugins(t *testing.T) { plugin.Scorer.(*testPlugin).reset() } test.profile.picker.(*testPlugin).reset() - for _, plugin := range test.profile.postCyclePlugins { - plugin.(*testPlugin).reset() - } // Initialize the scheduling context request := &types.LLMRequest{ @@ -179,12 +173,7 @@ func TestSchedulePlugins(t *testing.T) { if tp.WinnerPodScore != test.targetPodScore { t.Errorf("winner pod score %v, expected %v", tp.WinnerPodScore, test.targetPodScore) } - for _, plugin := range test.profile.postCyclePlugins { - tp, _ := plugin.(*testPlugin) - if tp.PostCycleCallCount != 1 { - t.Errorf("Plugin '%s' PostCycle() called %d times, expected 1", plugin.TypedName(), tp.PostCycleCallCount) - } - } + }) } } @@ -193,7 +182,6 @@ func TestSchedulePlugins(t *testing.T) { var _ Filter = &testPlugin{} var _ Scorer = &testPlugin{} var _ Picker = &testPlugin{} -var _ PostCycle = &testPlugin{} // testPlugin is an implementation useful in unit tests. type testPlugin struct { @@ -204,7 +192,6 @@ type testPlugin struct { ScoreRes float64 FilterCallCount int FilterRes []k8stypes.NamespacedName - PostCycleCallCount int PickCallCount int NumOfPickerCandidates int PickRes k8stypes.NamespacedName @@ -246,15 +233,10 @@ func (tp *testPlugin) Pick(_ context.Context, _ *types.CycleState, scoredPods [] return &types.ProfileRunResult{TargetPods: winnerPods} } -func (tp *testPlugin) PostCycle(_ context.Context, _ *types.CycleState, res *types.ProfileRunResult) { - tp.PostCycleCallCount++ -} - func (tp *testPlugin) reset() { tp.FilterCallCount = 0 tp.ScoreCallCount = 0 tp.NumOfScoredPods = 0 - tp.PostCycleCallCount = 0 tp.PickCallCount = 0 tp.NumOfPickerCandidates = 0 }