Skip to content

feat(prefix-cache): use post response instead of postCycle #1176

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 0 additions & 3 deletions pkg/epp/common/config/loader/configloader_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -550,7 +550,6 @@ func (f *test1) Filter(_ context.Context, _ *types.CycleState, _ *types.LLMReque

// compile-time type validation
var _ framework.Scorer = &test2{}
var _ framework.PostCycle = &test2{}

type test2 struct {
typedName plugins.TypedName
Expand All @@ -570,8 +569,6 @@ func (m *test2) Score(_ context.Context, _ *types.CycleState, _ *types.LLMReques
return map[types.Pod]float64{}
}

// PostCycle is a no-op; it exists only so that test2 satisfies the framework.PostCycle interface.
func (m *test2) PostCycle(context.Context, *types.CycleState, *types.ProfileRunResult) {
}

// compile-time type validation
var _ framework.Picker = &testPicker{}

Expand Down
8 changes: 0 additions & 8 deletions pkg/epp/scheduling/framework/plugins.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ const (
FilterPluginType = "Filter"
ScorerPluginType = "Scorer"
PickerPluginType = "Picker"
PostCyclePluginType = "PostCycle"
ProcessProfilesResultsType = "ProcessProfilesResults"
)

Expand Down Expand Up @@ -67,10 +66,3 @@ type Picker interface {
plugins.Plugin
Pick(ctx context.Context, cycleState *types.CycleState, scoredPods []*types.ScoredPod) *types.ProfileRunResult
}

// PostCycle is called by the scheduler after it selects a targetPod for the request in the SchedulerProfile cycle.
//
// Deprecated: do not use PostCycle; it is in the process of being removed.
type PostCycle interface {
plugins.Plugin
// PostCycle receives the cycle state of the SchedulerProfile cycle together with the result of that cycle.
PostCycle(ctx context.Context, cycleState *types.CycleState, res *types.ProfileRunResult)
}
27 changes: 17 additions & 10 deletions pkg/epp/scheduling/framework/plugins/multi/prefix/plugin.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,10 @@ import (
"github.com/cespare/xxhash/v2"
k8stypes "k8s.io/apimachinery/pkg/types"
"sigs.k8s.io/controller-runtime/pkg/log"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/metrics"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/requestcontrol"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging"
Expand Down Expand Up @@ -116,7 +118,7 @@ func (s *SchedulingContextState) Clone() types.StateData {

// compile-time type assertion
var _ framework.Scorer = &Plugin{}
var _ framework.PostCycle = &Plugin{}
var _ requestcontrol.PostResponse = &Plugin{}

// PrefixCachePluginFactory defines the factory function for Prefix plugin.
func PrefixCachePluginFactory(name string, rawParameters json.RawMessage, _ plugins.Handle) (plugins.Plugin, error) {
Expand Down Expand Up @@ -194,19 +196,24 @@ func (m *Plugin) Score(ctx context.Context, cycleState *types.CycleState, reques
return scores
}

// PostCycle records in the plugin cache the result of the scheduling selection.
func (m *Plugin) PostCycle(ctx context.Context, cycleState *types.CycleState, res *types.ProfileRunResult) {
targetPod := res.TargetPods[0].GetPod()
state, err := types.ReadCycleStateKey[*SchedulingContextState](cycleState, PrefixCachePluginType)
if err != nil {
log.FromContext(ctx).Error(err, "failed to read prefix plugin cycle state")
// PostResponse records in the plugin cache the result of the request processing.
// This method recomputes the prefix hashes from the request since the scheduling cycle state
// is not available in the PostResponse phase.
func (m *Plugin) PostResponse(ctx context.Context, request *types.LLMRequest, response *requestcontrol.Response, targetPod *backend.Pod) {
// Recompute the prefix hashes from the request
hashes := hashPrompt(ctx, request, m.HashBlockSize, m.MaxPrefixBlocksToMatch)
if len(hashes) == 0 {
// No hashes to cache, skip processing
return
}

m.indexer.Add(state.PrefixHashes, ServerID(targetPod.NamespacedName))
// Add the hashes to the indexer for the target pod
m.indexer.Add(hashes, ServerID(targetPod.NamespacedName))

total := len(state.PrefixHashes)
matchLen := state.PrefixCacheServers[ServerID(targetPod.NamespacedName)]
// Record metrics - we need to compute the match length for this pod
prefixCacheServers := m.matchLongestPrefix(ctx, hashes)
total := len(hashes)
matchLen := prefixCacheServers[ServerID(targetPod.NamespacedName)]
metrics.RecordPrefixCacheMatch(matchLen*m.HashBlockSize, total*m.HashBlockSize)
}

Expand Down
109 changes: 101 additions & 8 deletions pkg/epp/scheduling/framework/plugins/multi/prefix/plugin_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import (
"github.com/stretchr/testify/assert"
k8stypes "k8s.io/apimachinery/pkg/types"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/requestcontrol"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
)

Expand Down Expand Up @@ -60,8 +61,12 @@ func TestPrefixPlugin(t *testing.T) {
assert.Equal(t, float64(0), scores[pod1], "score for pod1")
assert.Equal(t, float64(0), scores[pod2], "score for pod2")

// Simulate pod1 was picked.
plugin.PostCycle(context.Background(), cycleState1, &types.ProfileRunResult{TargetPods: []types.Pod{pod1}})
// Simulate pod1 was picked - use PostResponse instead
response1 := &requestcontrol.Response{
RequestId: "req1",
Headers: map[string]string{"content-type": "application/json"},
}
plugin.PostResponse(context.Background(), req1, response1, pod1.Pod)

// Second request doesn't share any prefix with first one. It should be added to the cache but
// the pod score should be 0.
Expand All @@ -81,8 +86,12 @@ func TestPrefixPlugin(t *testing.T) {
assert.Equal(t, float64(0), scores[pod1], "score for pod1")
assert.Equal(t, float64(0), scores[pod2], "score for pod2")

// Simulate pod2 was picked.
plugin.PostCycle(context.Background(), cycleState2, &types.ProfileRunResult{TargetPods: []types.Pod{pod2}})
// Simulate pod2 was picked - use PostResponse instead
response2 := &requestcontrol.Response{
RequestId: "req2",
Headers: map[string]string{"content-type": "application/json"},
}
plugin.PostResponse(context.Background(), req2, response2, pod2.Pod)

// Third request shares partial prefix with first one.
req3 := &types.LLMRequest{
Expand All @@ -101,7 +110,12 @@ func TestPrefixPlugin(t *testing.T) {
assert.Equal(t, float64(2)/float64(3), scores[pod1], "score should be 2/3 - the model and the first prefix block match")
assert.Equal(t, float64(0), scores[pod2], "score for pod2")

plugin.PostCycle(context.Background(), cycleState3, &types.ProfileRunResult{TargetPods: []types.Pod{pod1}})
// Simulate pod1 was picked - use PostResponse instead
response3 := &requestcontrol.Response{
RequestId: "req3",
Headers: map[string]string{"content-type": "application/json"},
}
plugin.PostResponse(context.Background(), req3, response3, pod1.Pod)

// 4th request is same as req3 except the model is different, still no match.
req4 := &types.LLMRequest{
Expand All @@ -120,7 +134,12 @@ func TestPrefixPlugin(t *testing.T) {
assert.Equal(t, float64(0), scores[pod1], "score for pod1")
assert.Equal(t, float64(0), scores[pod2], "score for pod2")

plugin.PostCycle(context.Background(), cycleState4, &types.ProfileRunResult{TargetPods: []types.Pod{pod1}})
// Simulate pod1 was picked - use PostResponse instead
response4 := &requestcontrol.Response{
RequestId: "req4",
Headers: map[string]string{"content-type": "application/json"},
}
plugin.PostResponse(context.Background(), req4, response4, pod1.Pod)

// 5th request shares partial prefix with 3rd one.
req5 := &types.LLMRequest{
Expand All @@ -139,7 +158,76 @@ func TestPrefixPlugin(t *testing.T) {
assert.Equal(t, 0.75, scores[pod1], "score should be 0.75 - the model and the first 2 prefix blocks match")
assert.Equal(t, float64(0), scores[pod2], "score for pod2")

plugin.PostCycle(context.Background(), cycleState5, &types.ProfileRunResult{TargetPods: []types.Pod{pod1}})
// Simulate pod1 was picked - use PostResponse instead
response5 := &requestcontrol.Response{
RequestId: "req5",
Headers: map[string]string{"content-type": "application/json"},
}
plugin.PostResponse(context.Background(), req5, response5, pod1.Pod)
}

// TestPrefixPluginPostResponse tests the PostResponse method functionality
func TestPrefixPluginPostResponse(t *testing.T) {
config := Config{
HashBlockSize: 4,
MaxPrefixBlocksToMatch: DefaultMaxPrefixBlocks,
LRUCapacityPerServer: DefaultLRUCapacityPerServer,
}
plugin := New(config)

pod1 := &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod1"}}
pod2 := &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod2"}}

// First request - use PostResponse to cache prefix
req1 := &types.LLMRequest{
TargetModel: "test-model1",
Prompt: "aaaaaa",
}
response1 := &requestcontrol.Response{
RequestId: "req1",
Headers: map[string]string{"content-type": "application/json"},
}

// Call PostResponse to cache the prefix for pod1
plugin.PostResponse(context.Background(), req1, response1, pod1)

// Second request with same prefix - should get cache hit on pod1
req2 := &types.LLMRequest{
TargetModel: "test-model1",
Prompt: "aaaabbbb",
}

// Test scoring to verify cache hit
pods := []types.Pod{
&types.PodMetrics{Pod: pod1},
&types.PodMetrics{Pod: pod2},
}
cycleState := types.NewCycleState()
scores := plugin.Score(context.Background(), cycleState, req2, pods)

// pod1 should have a higher score due to prefix cache hit
assert.Greater(t, scores[pods[0]], scores[pods[1]], "pod1 should have higher score due to prefix cache")
assert.Greater(t, scores[pods[0]], float64(0), "pod1 should have non-zero score")
assert.Equal(t, float64(0), scores[pods[1]], "pod2 should have zero score")

// Use PostResponse for the second request on pod1
response2 := &requestcontrol.Response{
RequestId: "req2",
Headers: map[string]string{"content-type": "application/json"},
}
plugin.PostResponse(context.Background(), req2, response2, pod1)

// Third request with different model - should not get cache hit
req3 := &types.LLMRequest{
TargetModel: "test-model2", // Different model
Prompt: "aaaaaa",
}
cycleState3 := types.NewCycleState()
scores3 := plugin.Score(context.Background(), cycleState3, req3, pods)

// Both pods should have zero score since model is different
assert.Equal(t, float64(0), scores3[pods[0]], "pod1 should have zero score for different model")
assert.Equal(t, float64(0), scores3[pods[1]], "pod2 should have zero score for different model")
}

// TestPrefixPluginStress is a stress test for the prefix scoring plugin, using prompts of increasing length.
Expand Down Expand Up @@ -180,7 +268,12 @@ func BenchmarkPrefixPluginStress(b *testing.B) {
// First cycle: simulate scheduling and insert prefix info into the cache
cycleState := types.NewCycleState()
plugin.Score(context.Background(), cycleState, req, pods)
plugin.PostCycle(context.Background(), cycleState, &types.ProfileRunResult{TargetPods: []types.Pod{pod}})
// Use PostResponse instead of PostCycle
response := &requestcontrol.Response{
RequestId: fmt.Sprintf("req-%d", i),
Headers: map[string]string{"content-type": "application/json"},
}
plugin.PostResponse(context.Background(), req, response, pod.Pod)

// Second cycle: validate internal state
state, err := types.ReadCycleStateKey[*SchedulingContextState](cycleState, PrefixCachePluginType)
Expand Down
36 changes: 7 additions & 29 deletions pkg/epp/scheduling/framework/scheduler_profile.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,19 +32,17 @@ import (
// NewSchedulerProfile creates a new SchedulerProfile object and returns its pointer.
func NewSchedulerProfile() *SchedulerProfile {
return &SchedulerProfile{
filters: []Filter{},
scorers: []*WeightedScorer{},
postCyclePlugins: []PostCycle{},
filters: []Filter{},
scorers: []*WeightedScorer{},
// picker remains nil since profile doesn't support multiple pickers
}
}

// SchedulerProfile provides a profile configuration for the scheduler which influences routing decisions.
type SchedulerProfile struct {
filters []Filter
scorers []*WeightedScorer
picker Picker
postCyclePlugins []PostCycle
filters []Filter
scorers []*WeightedScorer
picker Picker
}

// WithFilters sets the given filter plugins as the Filter plugins.
Expand All @@ -68,13 +66,6 @@ func (p *SchedulerProfile) WithPicker(picker Picker) *SchedulerProfile {
return p
}

// WithPostCyclePlugins installs the given plugins as the profile's PostCycle plugins.
// Any PostCycle plugins already set on the SchedulerProfile are replaced, not extended.
// The receiver is returned.
func (p *SchedulerProfile) WithPostCyclePlugins(plugins ...PostCycle) *SchedulerProfile {
	profile := p
	profile.postCyclePlugins = plugins
	return profile
}

// AddPlugins adds the given plugins to all scheduler plugins according to the interfaces each plugin implements.
// A plugin may implement more than one scheduler plugin interface.
// Special Case: In order to add a scorer, one must use the scorer.NewWeightedScorer function in order to provide a weight.
Expand All @@ -97,15 +88,13 @@ func (p *SchedulerProfile) AddPlugins(pluginObjects ...plugins.Plugin) error {
}
p.picker = picker
}
if postCyclePlugin, ok := plugin.(PostCycle); ok {
p.postCyclePlugins = append(p.postCyclePlugins, postCyclePlugin)
}

}
return nil
}

// Run runs a SchedulerProfile cycle. In other words, it invokes all the SchedulerProfile plugins in this
// order - Filters, Scorers, Picker, PostCyclePlugins. After completing all, it returns the result.
// order - Filters, Scorers, Picker. After completing all, it returns the result.
func (p *SchedulerProfile) Run(ctx context.Context, request *types.LLMRequest, cycleState *types.CycleState, candidatePods []types.Pod) (*types.ProfileRunResult, error) {
pods := p.runFilterPlugins(ctx, request, cycleState, candidatePods)
if len(pods) == 0 {
Expand All @@ -116,8 +105,6 @@ func (p *SchedulerProfile) Run(ctx context.Context, request *types.LLMRequest, c

result := p.runPickerPlugin(ctx, cycleState, weightedScorePerPod)

p.runPostCyclePlugins(ctx, cycleState, result)

return result, nil
}

Expand Down Expand Up @@ -182,12 +169,3 @@ func (p *SchedulerProfile) runPickerPlugin(ctx context.Context, cycleState *type

return result
}

// runPostCyclePlugins invokes each of the profile's PostCycle plugins, in slice order, with the cycle
// state and the profile run result. For every plugin it logs a debug message before the call and
// records the call's duration as a scheduler plugin processing latency metric afterwards.
func (p *SchedulerProfile) runPostCyclePlugins(ctx context.Context, cycleState *types.CycleState, result *types.ProfileRunResult) {
	for idx := range p.postCyclePlugins {
		current := p.postCyclePlugins[idx]

		debugLog := log.FromContext(ctx).V(logutil.DEBUG)
		debugLog.Info("Running post-cycle plugin", "plugin", current.TypedName().Type)

		startedAt := time.Now()
		current.PostCycle(ctx, cycleState, result)

		pluginType := current.TypedName().Type
		metrics.RecordSchedulerPluginProcessingLatency(PostCyclePluginType, pluginType, time.Since(startedAt))
	}
}
Loading