Skip to content

Commit 53500d6

Browse files
author
BrianPark314
committed
feat: add prefix aware picker
Signed-off-by: BrianPark314 <[email protected]>
1 parent e7edcbd commit 53500d6

File tree

1 file changed

+152
-0
lines changed

1 file changed

+152
-0
lines changed
Lines changed: 152 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,152 @@
1+
/*
2+
Copyright 2025 The vLLM Production Stack Authors.
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
*/
10+
11+
package picker
12+
13+
import (
14+
"math/rand"
15+
"sync"
16+
"time"
17+
18+
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/plugins"
19+
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
20+
)
21+
22+
// Compile-time assertion that *PrefixMatchPicker satisfies plugins.Picker.
var _ plugins.Picker = (*PrefixMatchPicker)(nil)
23+
24+
// PrefixMatchPicker selects the engine whose URL was returned by the
25+
// longest-prefix match against previously-seen prompts (same idea as the
26+
// Python `route_request`). Ties are broken at random.
27+
type PrefixMatchPicker struct {
28+
trie *hashTrie
29+
rnd *rand.Rand
30+
}
31+
32+
// NewPrefixMatchPicker returns a ready-to-use picker instance.
33+
func NewPrefixMatchPicker() *PrefixMatchPicker {
34+
return &PrefixMatchPicker{
35+
trie: newHashTrie(),
36+
rnd: rand.New(rand.NewSource(time.Now().UnixNano())),
37+
}
38+
}
39+
40+
func (p *PrefixMatchPicker) Name() string { return "prefixmatch" }
41+
42+
// Pick implements plugins.Picker.
43+
//
44+
// SchedulingContext is assumed to carry the inference request body in
45+
// ctx.RequestBody (map[string]any) with the prompt at key "prompt". Adjust
46+
// the accessor if your integration differs.
47+
func (p *PrefixMatchPicker) Pick(
48+
ctx *types.SchedulingContext,
49+
scoredPods []*types.ScoredPod,
50+
) *types.Result {
51+
if len(scoredPods) == 0 {
52+
return &types.Result{}
53+
}
54+
55+
prompt, _ := ctx.RequestBody["prompt"].(string)
56+
57+
// 1. Build the set of available endpoints.
58+
available := make(map[string]struct{}, len(scoredPods))
59+
for _, sp := range scoredPods {
60+
ep := sp.GetPod().EndpointURL // <-- adapt this accessor
61+
available[ep] = struct{}{}
62+
}
63+
64+
// 2. Longest-prefix match within the trie.
65+
matched := p.trie.longestPrefixMatch(prompt, available)
66+
67+
// 3. Fallback: no match --> all endpoints are candidates.
68+
if len(matched) == 0 {
69+
for ep := range available {
70+
matched[ep] = struct{}{}
71+
}
72+
}
73+
74+
// 4. Convert the matched set to a slice and pick randomly.
75+
endpoints := make([]string, 0, len(matched))
76+
for ep := range matched {
77+
endpoints = append(endpoints, ep)
78+
}
79+
selected := endpoints[p.rnd.Intn(len(endpoints))]
80+
81+
// 5. Cache the decision for future prefix look-ups.
82+
p.trie.insert(prompt, selected)
83+
84+
// 6. Return the pod whose URL matches `selected`.
85+
for _, sp := range scoredPods {
86+
if sp.GetPod().EndpointURL == selected { // same accessor as above
87+
return &types.Result{TargetPod: sp}
88+
}
89+
}
90+
// Should never hit; safe fallback.
91+
return &types.Result{TargetPod: scoredPods[0]}
92+
}
93+
94+
/*---------------------------- trie implementation ---------------------------*/
95+
96+
// hashTrie is a rune-keyed trie mapping inserted keys to the set of endpoints
// recorded for them. Every node is itself a *hashTrie.
//
// Only the mutex of the node a method is called on is taken; child nodes are
// walked without locking their own mu. All calls must therefore go through
// the same root node for the locking to be effective.
type hashTrie struct {
	mu sync.RWMutex
	// children maps the next rune of a key to the sub-trie for that rune.
	children map[rune]*hashTrie
	// endpoints is the set recorded for the key ending at this node; it is nil
	// until some key ending here has been inserted.
	endpoints map[string]struct{}
}

// newHashTrie returns an empty trie node with an initialised children map.
func newHashTrie() *hashTrie {
	return &hashTrie{children: make(map[rune]*hashTrie)}
}

// insert records endpoint for key, creating any missing nodes along the path.
// The endpoint is attached only to the node where key ends, not to the
// intermediate nodes. An empty key attaches the endpoint to t itself.
func (t *hashTrie) insert(key, endpoint string) {
	t.mu.Lock()
	defer t.mu.Unlock()

	node := t
	for _, r := range key {
		child, ok := node.children[r]
		if !ok {
			child = newHashTrie()
			node.children[r] = child
		}
		node = child
	}
	if node.endpoints == nil {
		node.endpoints = make(map[string]struct{})
	}
	node.endpoints[endpoint] = struct{}{}
}

// longestPrefixMatch finds the longest inserted key that is a prefix of key
// (including key itself and the empty key) and returns the endpoints recorded
// for it, restricted to those present in available.
//
// The result is always a fresh, non-nil map that the caller may modify. It is
// empty when no inserted key is a prefix of key, or when none of the matched
// endpoints is in available.
func (t *hashTrie) longestPrefixMatch(
	key string,
	available map[string]struct{},
) map[string]struct{} {
	t.mu.RLock()
	defer t.mu.RUnlock()

	// Start with the root: the empty key is a prefix of everything.
	node := t
	lastMatch := node.endpoints

	for _, r := range key {
		child, ok := node.children[r]
		if !ok {
			break
		}
		node = child
		// Check each node after stepping into it, so the node reached by the
		// final rune (an exact match of a stored key) is considered too.
		if node.endpoints != nil {
			lastMatch = node.endpoints
		}
	}

	// Filter by `available`; copying also keeps the trie's own set private.
	res := make(map[string]struct{}, len(lastMatch))
	for ep := range lastMatch {
		if _, ok := available[ep]; ok {
			res[ep] = struct{}{}
		}
	}
	return res
}

0 commit comments

Comments
 (0)