Skip to content

Expand scheduling.CycleState to the request lifecycle #1062

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion conformance/testing-epp/plugins/filter/filter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (

"github.com/google/go-cmp/cmp"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
)
Expand Down Expand Up @@ -117,7 +118,7 @@ func TestFilter(t *testing.T) {

for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
got := test.filter.Filter(context.Background(), types.NewCycleState(), test.req, test.input)
got := test.filter.Filter(context.Background(), plugins.NewCycleState(), test.req, test.input)

if diff := cmp.Diff(test.output, got); diff != "" {
t.Errorf("Unexpected output (-want +got): %v", diff)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"context"
"strings"

"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
)
Expand Down Expand Up @@ -50,7 +51,7 @@ func (f *HeaderBasedTestingFilter) Type() string {
}

// Filter selects pods that match the IP addresses specified in the request header.
func (f *HeaderBasedTestingFilter) Filter(_ context.Context, _ *types.CycleState, request *types.LLMRequest, pods []types.Pod) []types.Pod {
func (f *HeaderBasedTestingFilter) Filter(_ context.Context, _ *plugins.CycleState, request *types.LLMRequest, pods []types.Pod) []types.Pod {
headerValue, ok := request.Headers[headerTestEppEndPointSelectionKey]
if !ok || headerValue == "" {
return []types.Pod{}
Expand Down
3 changes: 2 additions & 1 deletion conformance/testing-epp/sheduler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (
"github.com/google/uuid"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend"
backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics" // Import config for thresholds
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
)

Expand Down Expand Up @@ -100,7 +101,7 @@ func TestSchedule(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
scheduler := NewReqHeaderBasedScheduler()
got, err := scheduler.Schedule(context.Background(), test.req, types.ToSchedulerPodMetrics(test.input))
got, err := scheduler.Schedule(context.Background(), plugins.NewCycleState(), test.req, types.ToSchedulerPodMetrics(test.input))
if test.err != (err != nil) {
t.Errorf("Unexpected error, got %v, want %v", err, test.err)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ package framework
import (
"context"

scheduling "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins"
)

type Endpoint struct {
Expand Down Expand Up @@ -114,15 +114,15 @@ type ProfileHandler interface {
// The framework will return an error to the client if the endpoints are filtered to zero.
type Filter interface {
Plugin
Filter(ctx context.Context, request *Request, state *scheduling.CycleState, endpoints []*Endpoint) []*Endpoint
Filter(ctx context.Context, request *Request, state *plugins.CycleState, endpoints []*Endpoint) []*Endpoint
}

// Scorer applies a score to each remaining endpoint provided.
// Scorers SHOULD keep their score values in a normalized range: [0-1].
// Any weighting should be added at the SchedulerProfile configuration level.
type Scorer interface {
Plugin
Score(ctx context.Context, request *Request, state *scheduling.CycleState, endpoints []*Endpoint) []*ScoredEndpoint
Score(ctx context.Context, request *Request, state *plugins.CycleState, endpoints []*Endpoint) []*ScoredEndpoint
}

// WeightedScorer is a struct that encapsulates a scorer with its weight.
Expand All @@ -138,5 +138,5 @@ type WeightedScorer struct {
// Picker MUST return at least one endpoint.
type Picker interface {
Plugin
Pick(ctx context.Context, state *scheduling.CycleState, endpoints []*ScoredEndpoint) []*ScoredEndpoint
Pick(ctx context.Context, state *plugins.CycleState, endpoints []*ScoredEndpoint) []*ScoredEndpoint
}
12 changes: 6 additions & 6 deletions pkg/epp/common/config/loader/configloader_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -557,7 +557,7 @@ func (f *test1) Type() string {
}

// Filter filters out pods that don't meet the filter criteria.
func (f *test1) Filter(_ context.Context, _ *types.CycleState, _ *types.LLMRequest, pods []types.Pod) []types.Pod {
func (f *test1) Filter(_ context.Context, _ *plugins.CycleState, _ *types.LLMRequest, pods []types.Pod) []types.Pod {
return pods
}

Expand All @@ -571,11 +571,11 @@ func (f *test2) Type() string {
return test2Type
}

func (m *test2) Score(_ context.Context, _ *types.CycleState, _ *types.LLMRequest, _ []types.Pod) map[types.Pod]float64 {
func (m *test2) Score(_ context.Context, _ *plugins.CycleState, _ *types.LLMRequest, _ []types.Pod) map[types.Pod]float64 {
return map[types.Pod]float64{}
}

func (m *test2) PostCycle(_ context.Context, _ *types.CycleState, _ *types.ProfileRunResult) {}
func (m *test2) PostCycle(_ context.Context, _ *plugins.CycleState, _ *types.ProfileRunResult) {}

// compile-time type validation
var _ framework.Picker = &testPicker{}
Expand All @@ -586,7 +586,7 @@ func (p *testPicker) Type() string {
return testPickerType
}

func (p *testPicker) Pick(_ context.Context, _ *types.CycleState, _ []*types.ScoredPod) *types.ProfileRunResult {
func (p *testPicker) Pick(_ context.Context, _ *plugins.CycleState, _ []*types.ScoredPod) *types.ProfileRunResult {
return nil
}

Expand All @@ -599,11 +599,11 @@ func (p *testProfileHandler) Type() string {
return testProfileHandlerType
}

func (p *testProfileHandler) Pick(_ context.Context, _ *types.CycleState, _ *types.LLMRequest, _ map[string]*framework.SchedulerProfile, _ map[string]*types.ProfileRunResult) map[string]*framework.SchedulerProfile {
func (p *testProfileHandler) Pick(_ context.Context, _ *plugins.CycleState, _ *types.LLMRequest, _ map[string]*framework.SchedulerProfile, _ map[string]*types.ProfileRunResult) map[string]*framework.SchedulerProfile {
return nil
}

func (p *testProfileHandler) ProcessResults(_ context.Context, _ *types.CycleState, _ *types.LLMRequest, _ map[string]*types.ProfileRunResult) (*types.SchedulingResult, error) {
func (p *testProfileHandler) ProcessResults(_ context.Context, _ *plugins.CycleState, _ *types.LLMRequest, _ map[string]*types.ProfileRunResult) (*types.SchedulingResult, error) {
return nil, nil
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ See the License for the specific language governing permissions and
limitations under the License.
*/

package types
package plugins

import (
"errors"
Expand Down
18 changes: 11 additions & 7 deletions pkg/epp/requestcontrol/director.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ import (
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datastore"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/handlers"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/metrics"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins"
schedulingtypes "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
errutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/error"
logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging"
Expand All @@ -41,7 +42,7 @@ import (

// Scheduler defines the interface required by the Director for scheduling.
type Scheduler interface {
Schedule(ctx context.Context, request *schedulingtypes.LLMRequest, candidatePods []schedulingtypes.Pod) (result *schedulingtypes.SchedulingResult, err error)
Schedule(ctx context.Context, cycleState *plugins.CycleState, request *schedulingtypes.LLMRequest, candidatePods []schedulingtypes.Pod) (result *schedulingtypes.SchedulingResult, err error)
}

// SaturationDetector provides a signal indicating whether the backends are considered saturated.
Expand Down Expand Up @@ -135,19 +136,22 @@ func (d *Director) HandleRequest(ctx context.Context, reqCtx *handlers.RequestCo
}

// --- 3. Call Scheduler ---
// Initialize cycleState once per request; it is passed to the scheduler and then on to the PreRequest plugins.
cycleState := plugins.NewCycleState()

// Snapshot pod metrics from the datastore to:
// 1. Reduce concurrent access to the datastore.
// 2. Ensure consistent data during the scheduling operation of a request between all scheduling cycles.
candidatePods := schedulingtypes.ToSchedulerPodMetrics(d.datastore.PodGetAll())
results, err := d.scheduler.Schedule(ctx, reqCtx.SchedulingRequest, candidatePods)
results, err := d.scheduler.Schedule(ctx, cycleState, reqCtx.SchedulingRequest, candidatePods)
if err != nil {
return reqCtx, errutil.Error{Code: errutil.InferencePoolResourceExhausted, Msg: fmt.Errorf("failed to find target pod: %w", err).Error()}
}

// --- 4. Prepare Request (Populates RequestContext and call PreRequest plugins) ---
// Insert target endpoint to instruct Envoy to route requests to the specified target pod and attach the port number.
// Invoke PreRequest registered plugins.
reqCtx, err = d.prepareRequest(ctx, reqCtx, results)
reqCtx, err = d.prepareRequest(ctx, reqCtx, cycleState, results)
if err != nil {
return reqCtx, err
}
Expand Down Expand Up @@ -178,7 +182,7 @@ func (d *Director) admitRequest(ctx context.Context, requestCriticality v1alpha2

// prepareRequest populates the RequestContext and calls the registered PreRequest plugins
// for allowing plugging customized logic based on the scheduling results.
func (d *Director) prepareRequest(ctx context.Context, reqCtx *handlers.RequestContext, result *schedulingtypes.SchedulingResult) (*handlers.RequestContext, error) {
func (d *Director) prepareRequest(ctx context.Context, reqCtx *handlers.RequestContext, cycleState *plugins.CycleState, result *schedulingtypes.SchedulingResult) (*handlers.RequestContext, error) {
logger := log.FromContext(ctx)
if result == nil || len(result.ProfileResults) == 0 {
return reqCtx, errutil.Error{Code: errutil.Internal, Msg: "results must be greater than zero"}
Expand All @@ -198,7 +202,7 @@ func (d *Director) prepareRequest(ctx context.Context, reqCtx *handlers.RequestC
reqCtx.TargetPod = targetPod
reqCtx.TargetEndpoint = endpoint

d.runPreRequestPlugins(ctx, reqCtx.SchedulingRequest, result, targetPort)
d.runPreRequestPlugins(ctx, cycleState, reqCtx.SchedulingRequest, result, targetPort)

return reqCtx, nil
}
Expand Down Expand Up @@ -255,12 +259,12 @@ func RandomWeightedDraw(logger logr.Logger, model *v1alpha2.InferenceModel, seed
return ""
}

func (d *Director) runPreRequestPlugins(ctx context.Context, request *schedulingtypes.LLMRequest, schedulingResult *schedulingtypes.SchedulingResult,
func (d *Director) runPreRequestPlugins(ctx context.Context, cycleState *plugins.CycleState, request *schedulingtypes.LLMRequest, schedulingResult *schedulingtypes.SchedulingResult,
targetPort int) {
for _, plugin := range d.preRequestPlugins {
log.FromContext(ctx).V(logutil.DEBUG).Info("Running pre-request plugin", "plugin", plugin.Type())
before := time.Now()
plugin.PreRequest(ctx, request, schedulingResult, targetPort)
plugin.PreRequest(ctx, cycleState, request, schedulingResult, targetPort)
metrics.RecordRequestControlPluginProcessingLatency(PreRequestPluginType, plugin.Type(), time.Since(before))
}
}
Expand Down
3 changes: 2 additions & 1 deletion pkg/epp/requestcontrol/director_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ import (
backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datastore"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/handlers"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins"
schedulingtypes "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
errutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/error"
logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging"
Expand All @@ -58,7 +59,7 @@ type mockScheduler struct {
scheduleErr error
}

func (m *mockScheduler) Schedule(_ context.Context, _ *schedulingtypes.LLMRequest, _ []schedulingtypes.Pod) (*schedulingtypes.SchedulingResult, error) {
func (m *mockScheduler) Schedule(_ context.Context, _ *plugins.CycleState, _ *schedulingtypes.LLMRequest, _ []schedulingtypes.Pod) (*schedulingtypes.SchedulingResult, error) {
return m.scheduleResults, m.scheduleErr
}

Expand Down
2 changes: 1 addition & 1 deletion pkg/epp/requestcontrol/plugins.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ const (
// before a request is sent to the selected model server.
type PreRequest interface {
plugins.Plugin
PreRequest(ctx context.Context, request *types.LLMRequest, schedulingResult *types.SchedulingResult, targetPort int)
PreRequest(ctx context.Context, cycleState *plugins.CycleState, request *types.LLMRequest, schedulingResult *types.SchedulingResult, targetPort int)
}

// PostResponse is called by the director after a successful response was sent.
Expand Down
12 changes: 6 additions & 6 deletions pkg/epp/scheduling/framework/plugins.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,39 +38,39 @@ type ProfileHandler interface {
plugins.Plugin
// Pick selects the SchedulingProfiles to run from a list of candidate profiles, while taking into consideration the request properties
// and the previously executed SchedulerProfile cycles along with their results.
Pick(ctx context.Context, cycleState *types.CycleState, request *types.LLMRequest, profiles map[string]*SchedulerProfile,
Pick(ctx context.Context, cycleState *plugins.CycleState, request *types.LLMRequest, profiles map[string]*SchedulerProfile,
profileResults map[string]*types.ProfileRunResult) map[string]*SchedulerProfile

// ProcessResults handles the outcome of the profile runs after all profiles ran.
// It may aggregate results, log test profile outputs, or apply custom logic. It specifies in the SchedulingResult the
// key of the primary profile that should be used to get the request selected destination.
// When a profile run fails, its result in the profileResults map is nil.
ProcessResults(ctx context.Context, cycleState *types.CycleState, request *types.LLMRequest,
ProcessResults(ctx context.Context, cycleState *plugins.CycleState, request *types.LLMRequest,
profileResults map[string]*types.ProfileRunResult) (*types.SchedulingResult, error)
}

// Filter defines the interface for filtering a list of pods based on context.
type Filter interface {
plugins.Plugin
Filter(ctx context.Context, cycleState *types.CycleState, request *types.LLMRequest, pods []types.Pod) []types.Pod
Filter(ctx context.Context, cycleState *plugins.CycleState, request *types.LLMRequest, pods []types.Pod) []types.Pod
}

// Scorer defines the interface for scoring a list of pods based on context.
// Scorers must score pods with a value within the range of [0,1] where 1 is the highest score.
type Scorer interface {
plugins.Plugin
Score(ctx context.Context, cycleState *types.CycleState, request *types.LLMRequest, pods []types.Pod) map[types.Pod]float64
Score(ctx context.Context, cycleState *plugins.CycleState, request *types.LLMRequest, pods []types.Pod) map[types.Pod]float64
}

// Picker picks the final pod(s) to send the request to.
type Picker interface {
plugins.Plugin
Pick(ctx context.Context, cycleState *types.CycleState, scoredPods []*types.ScoredPod) *types.ProfileRunResult
Pick(ctx context.Context, cycleState *plugins.CycleState, scoredPods []*types.ScoredPod) *types.ProfileRunResult
}

// PostCycle is called by the scheduler after it selects a targetPod for the request in the SchedulerProfile cycle.
// DEPRECATED - do not use PostCycle. This is in the process of deprecation.
type PostCycle interface {
plugins.Plugin
PostCycle(ctx context.Context, cycleState *types.CycleState, res *types.ProfileRunResult)
PostCycle(ctx context.Context, cycleState *plugins.CycleState, res *types.ProfileRunResult)
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"context"

"sigs.k8s.io/controller-runtime/pkg/log"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging"
Expand Down Expand Up @@ -56,7 +57,7 @@ func (f *DecisionTreeFilter) Type() string {
}

// Filter filters out pods that don't meet the filter criteria.
func (f *DecisionTreeFilter) Filter(ctx context.Context, cycleState *types.CycleState, request *types.LLMRequest, pods []types.Pod) []types.Pod {
func (f *DecisionTreeFilter) Filter(ctx context.Context, cycleState *plugins.CycleState, request *types.LLMRequest, pods []types.Pod) []types.Pod {
loggerTrace := log.FromContext(ctx).V(logutil.TRACE)
filteredPod := f.Current.Filter(ctx, cycleState, request, pods)

Expand Down
7 changes: 4 additions & 3 deletions pkg/epp/scheduling/framework/plugins/filter/filter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import (
k8stypes "k8s.io/apimachinery/pkg/types"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend"
backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/config"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
Expand All @@ -39,7 +40,7 @@ func (f *filterAll) Type() string {
return "filter-all"
}

func (f *filterAll) Filter(_ context.Context, _ *types.CycleState, _ *types.LLMRequest, pods []types.Pod) []types.Pod {
func (f *filterAll) Filter(_ context.Context, _ *plugins.CycleState, _ *types.LLMRequest, pods []types.Pod) []types.Pod {
return []types.Pod{}
}

Expand Down Expand Up @@ -138,7 +139,7 @@ func TestFilter(t *testing.T) {

for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
got := test.filter.Filter(context.Background(), types.NewCycleState(), test.req, test.input)
got := test.filter.Filter(context.Background(), plugins.NewCycleState(), test.req, test.input)

if diff := cmp.Diff(test.output, got); diff != "" {
t.Errorf("Unexpected output (-want +got): %v", diff)
Expand Down Expand Up @@ -206,7 +207,7 @@ func TestLoRASoftAffinityDistribution(t *testing.T) {
LoraAffinityFilter := NewLoraAffinityFilter(config.Conf.LoraAffinityThreshold)

for range numIterations {
result := LoraAffinityFilter.Filter(context.Background(), types.NewCycleState(), req, pods)
result := LoraAffinityFilter.Filter(context.Background(), plugins.NewCycleState(), req, pods)

// Check which type of pod was returned
if len(result) != 1 {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ func (f *LeastKVCacheFilter) Type() string {
}

// Filter filters out pods that don't meet the filter criteria.
func (f *LeastKVCacheFilter) Filter(_ context.Context, _ *types.CycleState, _ *types.LLMRequest, pods []types.Pod) []types.Pod {
func (f *LeastKVCacheFilter) Filter(_ context.Context, _ *plugins.CycleState, _ *types.LLMRequest, pods []types.Pod) []types.Pod {
filteredPods := []types.Pod{}

min := math.MaxFloat64
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ func (f *LeastQueueFilter) Type() string {
}

// Filter filters out pods that don't meet the filter criteria.
func (f *LeastQueueFilter) Filter(_ context.Context, _ *types.CycleState, _ *types.LLMRequest, pods []types.Pod) []types.Pod {
func (f *LeastQueueFilter) Filter(_ context.Context, _ *plugins.CycleState, _ *types.LLMRequest, pods []types.Pod) []types.Pod {
filteredPods := []types.Pod{}

min := math.MaxInt
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ func (f *LoraAffinityFilter) Type() string {
}

// Filter filters out pods that don't meet the filter criteria.
func (f *LoraAffinityFilter) Filter(_ context.Context, _ *types.CycleState, request *types.LLMRequest, pods []types.Pod) []types.Pod {
func (f *LoraAffinityFilter) Filter(_ context.Context, _ *plugins.CycleState, request *types.LLMRequest, pods []types.Pod) []types.Pod {
// Pre-allocate slices with estimated capacity
filtered_affinity := make([]types.Pod, 0, len(pods))
filtered_available := make([]types.Pod, 0, len(pods))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ func (f *LowQueueFilter) Type() string {
}

// Filter filters out pods that don't meet the filter criteria.
func (f *LowQueueFilter) Filter(_ context.Context, _ *types.CycleState, _ *types.LLMRequest, pods []types.Pod) []types.Pod {
func (f *LowQueueFilter) Filter(_ context.Context, _ *plugins.CycleState, _ *types.LLMRequest, pods []types.Pod) []types.Pod {
filteredPods := []types.Pod{}

for _, pod := range pods {
Expand Down
Loading