diff --git a/pkg/epp/flowcontrol/contracts/errors.go b/pkg/epp/flowcontrol/contracts/errors.go new file mode 100644 index 000000000..fd46ec710 --- /dev/null +++ b/pkg/epp/flowcontrol/contracts/errors.go @@ -0,0 +1,35 @@ +/* +Copyright 2025 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package contracts + +import "errors" + +// Registry Errors +var ( + // ErrFlowInstanceNotFound indicates that a requested flow instance (a `ManagedQueue`) does not exist in the registry + // shard, either because the flow is not registered or the specific instance (e.g., a draining queue at a particular + // priority) is not present. + ErrFlowInstanceNotFound = errors.New("flow instance not found") + + // ErrPriorityBandNotFound indicates that a requested priority band does not exist in the registry because it was not + // part of the initial configuration. + ErrPriorityBandNotFound = errors.New("priority band not found") + + // ErrPolicyQueueIncompatible indicates that a selected policy is not compatible with the capabilities of the queue it + // is intended to operate on. For example, a policy requiring priority-based peeking is used with a simple FIFO queue. + ErrPolicyQueueIncompatible = errors.New("policy is not compatible with queue capabilities") +) diff --git a/pkg/epp/flowcontrol/contracts/registry.go b/pkg/epp/flowcontrol/contracts/registry.go index dbc855bef..843f501ff 100644 --- a/pkg/epp/flowcontrol/contracts/registry.go +++ b/pkg/epp/flowcontrol/contracts/registry.go @@ -18,22 +18,78 @@ limitations under the License. // primary dependencies. In alignment with a "Ports and Adapters" (or "Hexagonal") architectural style, these // interfaces represent the "ports" through which the engine communicates. // -// This package contains the primary service contracts for the Flow Registry and Saturation Detector. +// This package contains the primary service contracts for the Flow Registry, which acts as the control plane for all +// flow state and configuration. package contracts import ( "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/flowcontrol/framework" ) +// RegistryShard defines the read-oriented interface that a `controller.FlowController` worker uses to access its +// specific slice (shard) of the `FlowRegistry`'s state. It provides the necessary methods for a worker to perform its +// dispatch operations by accessing queues and policies in a concurrent-safe manner. +// +// # Conformance +// +// All methods MUST be goroutine-safe. +type RegistryShard interface { + // ID returns a unique identifier for this shard, which must remain stable for the shard's lifetime. + ID() string + + // IsActive returns true if the shard should accept new requests for enqueueing. A false value indicates the shard is + // being gracefully drained and should not be given new work. + IsActive() bool + + // ActiveManagedQueue returns the currently active `ManagedQueue` for a given flow on this shard. This is the queue to + // which new requests for the flow should be enqueued. 
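+	// For example, after a flow update moves a flow to a new priority band, this method returns the replacement
+	// queue, while the old instance continues to drain and remains reachable only via `ManagedQueue`.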
+ // Returns an error wrapping `ErrFlowInstanceNotFound` if no active instance exists for the given `flowID`. + ActiveManagedQueue(flowID string) (ManagedQueue, error) + + // ManagedQueue retrieves a specific (potentially draining) `ManagedQueue` instance from this shard. This allows a + // worker to continue dispatching items from queues that are draining as part of a flow update. + // Returns an error wrapping `ErrFlowInstanceNotFound` if no instance for the given flowID and priority exists. + ManagedQueue(flowID string, priority uint) (ManagedQueue, error) + + // IntraFlowDispatchPolicy retrieves a flow's configured `framework.IntraFlowDispatchPolicy` for this shard. + // The registry guarantees that a non-nil default policy (as configured at the priority-band level) is returned if + // none is specified on the flow itself. + // Returns an error wrapping `ErrFlowInstanceNotFound` if the flow instance does not exist. + IntraFlowDispatchPolicy(flowID string, priority uint) (framework.IntraFlowDispatchPolicy, error) + + // InterFlowDispatchPolicy retrieves a priority band's configured `framework.InterFlowDispatchPolicy` for this shard. + // The registry guarantees that a non-nil default policy is returned if none is configured for the band. + // Returns an error wrapping `ErrPriorityBandNotFound` if the priority level is not configured. + InterFlowDispatchPolicy(priority uint) (framework.InterFlowDispatchPolicy, error) + + // PriorityBandAccessor retrieves a read-only accessor for a given priority level, providing a view of the band's + // state as seen by this specific shard. This is the primary entry point for inter-flow dispatch policies that + // need to inspect and compare multiple flow queues within the same priority band. + // Returns an error wrapping `ErrPriorityBandNotFound` if the priority level is not configured. + PriorityBandAccessor(priority uint) (framework.PriorityBandAccessor, error) + + // AllOrderedPriorityLevels returns all configured priority levels that this shard is aware of, sorted in ascending + // numerical order. This order corresponds to highest priority (lowest numeric value) to lowest priority (highest + // numeric value). + // The returned slice provides a definitive, ordered list of priority levels for iteration, for example, by a + // `controller.FlowController` worker's dispatch loop. + AllOrderedPriorityLevels() []uint + + // Stats returns a snapshot of the statistics for this specific shard. + Stats() ShardStats +} + // ManagedQueue defines the interface for a flow's queue instance on a specific shard. // It wraps an underlying `framework.SafeQueue`, augmenting it with lifecycle validation against the `FlowRegistry` and // integrating atomic statistics updates. // -// Conformance: +// # Conformance +// // - All methods (including those embedded from `framework.SafeQueue`) MUST be goroutine-safe. -// - Mutating methods (`Add()`, `Remove()`, `CleanupExpired()`, `Drain()`) MUST ensure the flow instance still exists -// and is valid within the `FlowRegistry` before proceeding. They MUST also atomically update relevant statistics -// (e.g., queue length, byte size) at both the queue and priority-band levels. +// - The `Add()` method MUST reject new items if the queue has been marked as "draining" by the `FlowRegistry`, +// ensuring that lifecycle changes are respected even by consumers holding a stale pointer to the queue. 
+// - All mutating methods (`Add()`, `Remove()`, `Cleanup()`, `Drain()`) MUST atomically update relevant statistics +// (e.g., queue length, byte size). type ManagedQueue interface { framework.SafeQueue @@ -43,3 +99,61 @@ type ManagedQueue interface { // Conformance: This method MUST NOT return nil. FlowQueueAccessor() framework.FlowQueueAccessor } + +// ShardStats holds statistics for a single internal shard within the `FlowRegistry`. +type ShardStats struct { + // TotalCapacityBytes is the optional, maximum total byte size limit aggregated across all priority bands within this + // shard. Its value represents the globally configured limit for the `FlowRegistry` partitioned for this shard. + // The `controller.FlowController` enforces this limit in addition to any per-band capacity limits. + // A value of 0 signifies that this global limit is ignored, and only per-band limits apply. + TotalCapacityBytes uint64 + // TotalByteSize is the total byte size of all items currently queued across all priority bands within this shard. + TotalByteSize uint64 + // TotalLen is the total number of items currently queued across all priority bands within this shard. + TotalLen uint64 + // PerPriorityBandStats maps each configured priority level to its statistics within this shard. + // The key is the numerical priority level. + // All configured priority levels are guaranteed to be represented. + PerPriorityBandStats map[uint]PriorityBandStats +} + +// DeepCopy returns a deep copy of the `ShardStats`. +func (s *ShardStats) DeepCopy() ShardStats { + if s == nil { + return ShardStats{} + } + newStats := *s + if s.PerPriorityBandStats != nil { + newStats.PerPriorityBandStats = make(map[uint]PriorityBandStats, len(s.PerPriorityBandStats)) + for k, v := range s.PerPriorityBandStats { + newStats.PerPriorityBandStats[k] = v.DeepCopy() + } + } + return newStats +} + +// PriorityBandStats holds aggregated statistics for a single priority band. +type PriorityBandStats struct { + // Priority is the numerical priority level this struct describes. + Priority uint + // PriorityName is an optional, human-readable name for the priority level (e.g., "Critical", "Sheddable"). + PriorityName string + // CapacityBytes is the configured maximum total byte size for this priority band, aggregated across all items in + // all flow queues within this band. If scoped to a shard, its value represents the configured band limit for the + // `FlowRegistry` partitioned for this shard. + // The `controller.FlowController` enforces this limit. + // A default non-zero value is guaranteed if not configured. + CapacityBytes uint64 + // ByteSize is the total byte size of items currently queued in this priority band. + ByteSize uint64 + // Len is the total number of items currently queued in this priority band. + Len uint64 +} + +// DeepCopy returns a deep copy of the `PriorityBandStats`. +func (s *PriorityBandStats) DeepCopy() PriorityBandStats { + if s == nil { + return PriorityBandStats{} + } + return *s +} diff --git a/pkg/epp/flowcontrol/registry/config.go b/pkg/epp/flowcontrol/registry/config.go new file mode 100644 index 000000000..ed98ab830 --- /dev/null +++ b/pkg/epp/flowcontrol/registry/config.go @@ -0,0 +1,204 @@ +/* +Copyright 2025 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package registry + +import ( + "errors" + "fmt" + + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/flowcontrol/contracts" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/flowcontrol/framework" + inter "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/flowcontrol/framework/plugins/policies/interflow/dispatch" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/flowcontrol/framework/plugins/policies/interflow/dispatch/besthead" + intra "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/flowcontrol/framework/plugins/policies/intraflow/dispatch" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/flowcontrol/framework/plugins/policies/intraflow/dispatch/fcfs" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/flowcontrol/framework/plugins/queue" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/flowcontrol/framework/plugins/queue/listqueue" +) + +// Config holds the master configuration for the entire `FlowRegistry`. It serves as the top-level blueprint, defining +// global capacity limits and the structure of its priority bands. +// +// This master configuration is validated and defaulted once at startup. It is then partitioned and distributed to each +// internal `registryShard`, ensuring a consistent and predictable state across the system. +type Config struct { + // MaxBytes defines an optional, global maximum total byte size limit aggregated across all priority bands and shards. + // The `controller.FlowController` enforces this limit in addition to per-band capacity limits. + // + // Optional: Defaults to 0, which signifies that the global limit is ignored. + MaxBytes uint64 + + // PriorityBands defines the set of priority bands managed by the `FlowRegistry`. The configuration for each band, + // including its default policies and queue types, is specified here. + // + // Required: At least one `PriorityBandConfig` must be provided for a functional registry. + PriorityBands []PriorityBandConfig +} + +// partition calculates and returns a new `Config` with capacity values partitioned for a specific shard. +// This method ensures that the total capacity is distributed as evenly as possible across all shards. 
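+//
+// For example, `MaxBytes = 10` partitioned across 3 shards yields capacities of 4, 3, and 3 for shard indexes 0, 1,
+// and 2: each shard receives the integer share (10 / 3 = 3), and the remainder (10 % 3 = 1) is granted one byte at a
+// time to the lowest-indexed shards.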
+func (c *Config) partition(shardIndex, totalShards int) (*Config, error) { + if totalShards <= 0 || shardIndex < 0 || shardIndex >= totalShards { + return nil, fmt.Errorf("invalid shard partitioning arguments: shardIndex=%d, totalShards=%d", + shardIndex, totalShards) + } + + partitionValue := func(total uint64) uint64 { + if total == 0 { + return 0 + } + base := total / uint64(totalShards) + remainder := total % uint64(totalShards) + if uint64(shardIndex) < remainder { + return base + 1 + } + return base + } + + newCfg := &Config{ + MaxBytes: partitionValue(c.MaxBytes), + PriorityBands: make([]PriorityBandConfig, len(c.PriorityBands)), + } + + for i, band := range c.PriorityBands { + newBand := band // Copy the original config + newBand.MaxBytes = partitionValue(band.MaxBytes) // Overwrite with the partitioned value + newCfg.PriorityBands[i] = newBand + } + + return newCfg, nil +} + +// validateAndApplyDefaults checks the configuration for validity and populates any empty fields with system defaults. +// This method should be called once by the registry before it initializes any shards. +func (c *Config) validateAndApplyDefaults() error { + if len(c.PriorityBands) == 0 { + return errors.New("config validation failed: at least one priority band must be defined") + } + + priorities := make(map[uint]struct{}) // Keep track of seen priorities + + for i := range c.PriorityBands { + band := &c.PriorityBands[i] + if _, exists := priorities[band.Priority]; exists { + return fmt.Errorf("config validation failed: duplicate priority level %d found", band.Priority) + } + priorities[band.Priority] = struct{}{} + + if band.PriorityName == "" { + return errors.New("config validation failed: PriorityName is required for all priority bands") + } + if band.IntraFlowDispatchPolicy == "" { + band.IntraFlowDispatchPolicy = fcfs.FCFSPolicyName + } + if band.InterFlowDispatchPolicy == "" { + band.InterFlowDispatchPolicy = besthead.BestHeadPolicyName + } + if band.Queue == "" { + band.Queue = listqueue.ListQueueName + } + + // After defaulting, validate that the chosen plugins are compatible. + if err := validateBandCompatibility(*band); err != nil { + return err + } + } + return nil +} + +// validateBandCompatibility verifies that a band's default policy is compatible with its default queue type. +func validateBandCompatibility(band PriorityBandConfig) error { + policy, err := intra.NewPolicyFromName(band.IntraFlowDispatchPolicy) + if err != nil { + return fmt.Errorf("failed to validate policy %q for priority band %d: %w", + band.IntraFlowDispatchPolicy, band.Priority, err) + } + + requiredCapabilities := policy.RequiredQueueCapabilities() + if len(requiredCapabilities) == 0 { + return nil // Policy has no specific requirements. + } + + // Create a temporary queue instance to inspect its capabilities. + tempQueue, err := queue.NewQueueFromName(band.Queue, nil) + if err != nil { + return fmt.Errorf("failed to inspect queue type %q for priority band %d: %w", band.Queue, band.Priority, err) + } + queueCapabilities := tempQueue.Capabilities() + + // Build a set of the queue's capabilities for efficient lookup. + capabilitySet := make(map[framework.QueueCapability]struct{}, len(queueCapabilities)) + for _, cap := range queueCapabilities { + capabilitySet[cap] = struct{}{} + } + + // Check if all required capabilities are present. 
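+	// Every required capability must be present; the first missing one fails validation with an error wrapping
+	// `ErrPolicyQueueIncompatible`, so misconfigurations surface at startup rather than at dispatch time.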
+ for _, req := range requiredCapabilities { + if _, ok := capabilitySet[req]; !ok { + return fmt.Errorf( + "policy %q is not compatible with queue %q for priority band %d (%s): missing capability %q: %w", + policy.Name(), + tempQueue.Name(), + band.Priority, + band.PriorityName, + req, + contracts.ErrPolicyQueueIncompatible, + ) + } + } + + return nil +} + +// PriorityBandConfig defines the configuration for a single priority band within the `FlowRegistry`. It establishes the +// default behaviors (such as queueing and dispatch policies) and capacity limits for all flows that operate at this +// priority level. +type PriorityBandConfig struct { + // Priority is the numerical priority level for this band. + // Convention: Lower numerical values indicate higher priority (e.g., 0 is highest). + // + // Required. + Priority uint + + // PriorityName is a human-readable name for this priority band (e.g., "Critical", "Standard", "Sheddable"). + // + // Required. + PriorityName string + + // IntraFlowDispatchPolicy specifies the default name of the registered policy used to select a specific request to + // dispatch next from within a single flow's queue in this band. This default can be overridden on a per-flow basis. + // + // Optional: If empty, a system default (e.g., "FCFS") is used. + IntraFlowDispatchPolicy intra.RegisteredPolicyName + + // InterFlowDispatchPolicy specifies the name of the registered policy used to select which flow's queue to service + // next from this band. + // + // Optional: If empty, a system default (e.g., "BestHead") is used. + InterFlowDispatchPolicy inter.RegisteredPolicyName + + // Queue specifies the default name of the registered SafeQueue implementation to be used for flow queues within this + // band. + // + // Optional: If empty, a system default (e.g., "ListQueue") is used. + Queue queue.RegisteredQueueName + + // MaxBytes defines the maximum total byte size for this specific priority band, aggregated across all shards. + // + // Optional: If not set, a system default (e.g., 1 GB) is applied. + MaxBytes uint64 +} diff --git a/pkg/epp/flowcontrol/registry/config_test.go b/pkg/epp/flowcontrol/registry/config_test.go new file mode 100644 index 000000000..0cf9f93d8 --- /dev/null +++ b/pkg/epp/flowcontrol/registry/config_test.go @@ -0,0 +1,292 @@ +/* +Copyright 2025 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + +package registry + +import ( + "errors" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/flowcontrol/contracts" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/flowcontrol/framework" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/flowcontrol/framework/mocks" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/flowcontrol/framework/plugins/policies/interflow/dispatch/besthead" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/flowcontrol/framework/plugins/policies/interflow/dispatch/roundrobin" + intra "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/flowcontrol/framework/plugins/policies/intraflow/dispatch" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/flowcontrol/framework/plugins/policies/intraflow/dispatch/fcfs" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/flowcontrol/framework/plugins/queue" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/flowcontrol/framework/plugins/queue/listqueue" +) + +func TestConfig_ValidateAndApplyDefaults(t *testing.T) { + t.Parallel() + + // Setup for failure injection tests + failingPolicyName := intra.RegisteredPolicyName("failing-policy-for-config-test") + intra.MustRegisterPolicy(failingPolicyName, func() (framework.IntraFlowDispatchPolicy, error) { + return nil, errors.New("policy instantiation failed") + }) + failingQueueName := queue.RegisteredQueueName("failing-queue-for-config-test") + queue.MustRegisterQueue(failingQueueName, func(_ framework.ItemComparator) (framework.SafeQueue, error) { + return nil, errors.New("queue instantiation failed") + }) + + // Setup a mock policy with a specific capability requirement to test the compatibility check. + const mockCapability = framework.QueueCapability("TEST_CAPABILITY_FOR_CONFIG") + policyWithReqName := intra.RegisteredPolicyName("policy-with-req-for-config-test") + intra.MustRegisterPolicy(policyWithReqName, func() (framework.IntraFlowDispatchPolicy, error) { + return &mocks.MockIntraFlowDispatchPolicy{ + NameV: string(policyWithReqName), + RequiredQueueCapabilitiesV: []framework.QueueCapability{ + mockCapability, + }, + }, nil + }) + + testCases := []struct { + name string + input *Config + expectErr bool + expectedErrIs error + expectedCfg *Config + }{ + { + name: "Valid config with missing defaults", + input: &Config{ + PriorityBands: []PriorityBandConfig{ + {Priority: 1, PriorityName: "High"}, + {Priority: 2, PriorityName: "Low", InterFlowDispatchPolicy: roundrobin.RoundRobinPolicyName}, + }, + }, + expectErr: false, + expectedCfg: &Config{ + PriorityBands: []PriorityBandConfig{ + { + Priority: 1, + PriorityName: "High", + IntraFlowDispatchPolicy: fcfs.FCFSPolicyName, + InterFlowDispatchPolicy: besthead.BestHeadPolicyName, + Queue: listqueue.ListQueueName, + }, + { + Priority: 2, + PriorityName: "Low", + IntraFlowDispatchPolicy: fcfs.FCFSPolicyName, + InterFlowDispatchPolicy: roundrobin.RoundRobinPolicyName, + Queue: listqueue.ListQueueName, + }, + }, + }, + }, + { + name: "Config with all fields specified and compatible", + input: &Config{ + MaxBytes: 1000, + PriorityBands: []PriorityBandConfig{ + { + Priority: 1, + PriorityName: "High", + IntraFlowDispatchPolicy: fcfs.FCFSPolicyName, // Compatible with ListQueue + InterFlowDispatchPolicy: besthead.BestHeadPolicyName, + Queue: listqueue.ListQueueName, + MaxBytes: 500, + }, + }, + }, + expectErr: false, + expectedCfg: &Config{ // Should be unchanged + MaxBytes: 1000, + PriorityBands: 
[]PriorityBandConfig{ + { + Priority: 1, + PriorityName: "High", + IntraFlowDispatchPolicy: fcfs.FCFSPolicyName, + InterFlowDispatchPolicy: besthead.BestHeadPolicyName, + Queue: listqueue.ListQueueName, + MaxBytes: 500, + }, + }, + }, + }, + { + name: "Error: No priority bands", + input: &Config{PriorityBands: []PriorityBandConfig{}}, + expectErr: true, + }, + { + name: "Error: Missing PriorityName", + input: &Config{ + PriorityBands: []PriorityBandConfig{ + {Priority: 1}, + }, + }, + expectErr: true, + }, + { + name: "Error: Duplicate priority level", + input: &Config{ + PriorityBands: []PriorityBandConfig{ + {Priority: 1, PriorityName: "High"}, + {Priority: 1, PriorityName: "Also High"}, + }, + }, + expectErr: true, + }, + { + name: "Error: Incompatible policy and queue", + input: &Config{ + PriorityBands: []PriorityBandConfig{ + { + Priority: 1, + PriorityName: "High", + IntraFlowDispatchPolicy: policyWithReqName, // Requires mock capability + Queue: listqueue.ListQueueName, // Does not provide it + }, + }, + }, + expectErr: true, + expectedErrIs: contracts.ErrPolicyQueueIncompatible, + }, + { + name: "Error: Failing policy instantiation", + input: &Config{ + PriorityBands: []PriorityBandConfig{ + { + Priority: 1, + PriorityName: "High", + IntraFlowDispatchPolicy: failingPolicyName, + Queue: listqueue.ListQueueName, + }, + }, + }, + expectErr: true, + }, + { + name: "Error: Failing queue instantiation", + input: &Config{ + PriorityBands: []PriorityBandConfig{ + { + Priority: 1, + PriorityName: "High", + IntraFlowDispatchPolicy: fcfs.FCFSPolicyName, + Queue: failingQueueName, + }, + }, + }, + expectErr: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + err := tc.input.validateAndApplyDefaults() + if tc.expectErr { + require.Error(t, err, "Expected an error for this test case") + if tc.expectedErrIs != nil { + assert.ErrorIs(t, err, tc.expectedErrIs, "Error should be of the expected type") + } + } else { + require.NoError(t, err, "Did not expect an error for this test case") + assert.Equal(t, tc.expectedCfg, tc.input, "Config after applying defaults does not match expected config") + } + }) + } +} + +func TestConfig_Partition(t *testing.T) { + t.Parallel() + + baseConfig := &Config{ + MaxBytes: 103, + PriorityBands: []PriorityBandConfig{ + {Priority: 1, PriorityName: "High", MaxBytes: 55}, + {Priority: 2, PriorityName: "Low", MaxBytes: 0}, // Should remain 0 + }, + } + + t.Run("EvenDistributionWithRemainder", func(t *testing.T) { + t.Parallel() + totalShards := 10 + // Global: 103 / 10 = 10 remainder 3. First 3 shards get 11, rest get 10. + // Band 1: 55 / 10 = 5 remainder 5. First 5 shards get 6, rest get 5. 
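+		// The expectations below mirror `Config.partition`: the first `remainder` shards receive base+1, the rest
+		// receive base.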
+ expectedGlobalBytes := []uint64{11, 11, 11, 10, 10, 10, 10, 10, 10, 10} + expectedBand1Bytes := []uint64{6, 6, 6, 6, 6, 5, 5, 5, 5, 5} + + var totalGlobal, totalBand1 uint64 + for i := range totalShards { + partitioned, err := baseConfig.partition(i, totalShards) + require.NoError(t, err, "Partitioning should not fail for shard %d", i) + assert.Equal(t, expectedGlobalBytes[i], partitioned.MaxBytes, "Global MaxBytes for shard %d is incorrect", i) + require.Len(t, partitioned.PriorityBands, 2, "Partitioned config should have the same number of bands") + assert.Equal(t, expectedBand1Bytes[i], partitioned.PriorityBands[0].MaxBytes, + "Band 1 MaxBytes for shard %d is incorrect", i) + assert.Zero(t, partitioned.PriorityBands[1].MaxBytes, "Band 2 MaxBytes should remain zero for shard %d", i) + totalGlobal += partitioned.MaxBytes + totalBand1 += partitioned.PriorityBands[0].MaxBytes + } + assert.Equal(t, baseConfig.MaxBytes, totalGlobal, "Sum of partitioned global MaxBytes should equal original") + assert.Equal(t, baseConfig.PriorityBands[0].MaxBytes, totalBand1, + "Sum of partitioned band 1 MaxBytes should equal original") + }) + + t.Run("SingleShard", func(t *testing.T) { + t.Parallel() + partitioned, err := baseConfig.partition(0, 1) + require.NoError(t, err, "Partitioning for a single shard should not fail") + assert.Equal(t, baseConfig.MaxBytes, partitioned.MaxBytes, "Global MaxBytes should be unchanged for a single shard") + require.Len(t, partitioned.PriorityBands, 2, "Partitioned config should have the same number of bands") + assert.Equal(t, baseConfig.PriorityBands[0].MaxBytes, partitioned.PriorityBands[0].MaxBytes, + "Band 1 MaxBytes should be unchanged for a single shard") + }) + + t.Run("EmptyPriorityBands", func(t *testing.T) { + t.Parallel() + config := &Config{ + MaxBytes: 100, + PriorityBands: []PriorityBandConfig{}, + } + partitioned, err := config.partition(1, 3) + require.NoError(t, err, "Partitioning should not fail for empty priority bands") + assert.Equal(t, uint64(33), partitioned.MaxBytes, "Global MaxBytes should be partitioned correctly") + assert.Empty(t, partitioned.PriorityBands, "PriorityBands slice should be empty") + assert.NotNil(t, partitioned.PriorityBands, "PriorityBands slice should not be nil") + }) + + t.Run("ErrorHandling", func(t *testing.T) { + t.Parallel() + testCases := []struct { + name string + shardIndex int + totalShards int + }{ + {"NegativeShardIndex", -1, 5}, + {"ShardIndexOutOfBounds", 5, 5}, + {"ZeroTotalShards", 0, 0}, + {"NegativeTotalShards", 0, -1}, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + _, err := baseConfig.partition(tc.shardIndex, tc.totalShards) + assert.Error(t, err, "Expected an error for invalid partitioning arguments") + }) + } + }) +} diff --git a/pkg/epp/flowcontrol/registry/doc.go b/pkg/epp/flowcontrol/registry/doc.go new file mode 100644 index 000000000..521ad6e7b --- /dev/null +++ b/pkg/epp/flowcontrol/registry/doc.go @@ -0,0 +1,40 @@ +/* +Copyright 2025 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +*/ + +// Package registry provides the concrete implementation of the Flow Registry. +// +// As the stateful control plane for the entire Flow Control system, this package is responsible for managing the +// lifecycle of all flows, queues, and policies. It serves as the "adapter" that implements the service "ports" +// (interfaces) defined in the `contracts` package. It provides a sharded, concurrent-safe view of its state to the +// `controller.FlowController` workers, enabling scalable, parallel request processing. +// +// # Key Components +// +// - `FlowRegistry`: The future top-level administrative object that will manage the entire system, including shard +// lifecycles and flow registration. (Not yet implemented). +// +// - `registryShard`: A concrete implementation of the `contracts.RegistryShard` interface. It represents a single, +// concurrent-safe slice of the registry's state, containing a set of priority bands and the flow queues within +// them. This is the primary object a `controller.FlowController` worker interacts with. +// +// - `managedQueue`: A concrete implementation of the `contracts.ManagedQueue` interface. It acts as a stateful +// decorator around a `framework.SafeQueue`, adding critical registry-level functionality such as atomic statistics +// tracking and lifecycle state enforcement (active vs. draining). +// +// - `Config`: The top-level configuration object that defines the structure and default behaviors of the registry, +// including the definition of priority bands and default policy selections. This configuration is partitioned and +// distributed to each `registryShard`. +package registry diff --git a/pkg/epp/flowcontrol/registry/managedqueue.go b/pkg/epp/flowcontrol/registry/managedqueue.go index b8bc1022f..bea7e7e4f 100644 --- a/pkg/epp/flowcontrol/registry/managedqueue.go +++ b/pkg/epp/flowcontrol/registry/managedqueue.go @@ -14,14 +14,10 @@ See the License for the specific language governing permissions and limitations under the License. */ -// Package registry provides the concrete implementation of the Flow Registry, which is the stateful control plane for -// the Flow Control system. It implements the service interfaces defined in the `contracts` package. -// -// This initial version includes the implementation of the `contracts.ManagedQueue`, a stateful wrapper that adds atomic -// statistics tracking to a `framework.SafeQueue`. package registry import ( + "fmt" "sync/atomic" "github.com/go-logr/logr" @@ -29,18 +25,25 @@ import ( "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/flowcontrol/contracts" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/flowcontrol/framework" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/flowcontrol/types" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" ) +// parentStatsReconciler defines the callback function that a `managedQueue` uses to propagate its statistics changes up +// to its parent `registryShard`. type parentStatsReconciler func(lenDelta, byteSizeDelta int64) -// managedQueue implements `contracts.ManagedQueue`. It wraps a `framework.SafeQueue` and is responsible for maintaining -// accurate, atomically-updated statistics that are aggregated at the shard level. +// managedQueue implements `contracts.ManagedQueue`. It is a stateful decorator that wraps a `framework.SafeQueue`, +// augmenting it with two critical, registry-level responsibilities: +// 1. 
Atomic Statistics: It maintains its own `len` and `byteSize` counters, which are updated atomically. This allows +// the parent `registryShard` to aggregate statistics across many queues without locks. +// 2. Lifecycle Enforcement: It tracks the queue's lifecycle state (active vs. draining) via an `isActive` flag. This +// is crucial for graceful flow updates, as it allows the registry to stop new requests from being enqueued while +// allowing existing items to be drained. // // # Statistical Integrity // // For performance, `managedQueue` maintains its own `len` and `byteSize` fields using atomic operations. This provides -// O(1) access for the parent `contracts.RegistryShard`'s aggregated statistics without needing to lock the underlying -// queue. +// O(1) access for the parent `registryShard`'s aggregated statistics without needing to lock the underlying queue. // // This design is predicated on two critical assumptions: // 1. Exclusive Access: All mutating operations on the underlying `framework.SafeQueue` MUST be performed exclusively @@ -60,6 +63,7 @@ type managedQueue struct { flowSpec types.FlowSpecification byteSize atomic.Uint64 len atomic.Uint64 + isActive atomic.Bool reconcileShardStats parentStatsReconciler logger logr.Logger } @@ -70,45 +74,68 @@ func newManagedQueue( dispatchPolicy framework.IntraFlowDispatchPolicy, flowSpec types.FlowSpecification, logger logr.Logger, - reconcileShardStats func(lenDelta, byteSizeDelta int64), + reconcileShardStats parentStatsReconciler, ) *managedQueue { mqLogger := logger.WithName("managed-queue").WithValues( "flowID", flowSpec.ID, "priority", flowSpec.Priority, "queueType", queue.Name(), ) - return &managedQueue{ + mq := &managedQueue{ queue: queue, dispatchPolicy: dispatchPolicy, flowSpec: flowSpec, reconcileShardStats: reconcileShardStats, logger: mqLogger, } + mq.isActive.Store(true) + return mq +} + +// markAsDraining is an internal method called by the parent shard to transition this queue to a draining state. +// Once a queue is marked as draining, it will no longer accept new items via `Add`. +func (mq *managedQueue) markAsDraining() { + // Use CompareAndSwap to ensure we only log the transition once. + if mq.isActive.CompareAndSwap(true, false) { + mq.logger.V(logging.DEFAULT).Info("Queue marked as draining") + } } -// FlowQueueAccessor returns a new `flowQueueAccessor` instance. +// FlowQueueAccessor returns a new `flowQueueAccessor` instance, which provides a read-only, policy-facing view of the +// queue. func (mq *managedQueue) FlowQueueAccessor() framework.FlowQueueAccessor { return &flowQueueAccessor{mq: mq} } +// Add first checks if the queue is active. If it is, it wraps the underlying `framework.SafeQueue.Add` call and +// atomically updates the queue's and the parent shard's statistics. func (mq *managedQueue) Add(item types.QueueItemAccessor) error { + if !mq.isActive.Load() { + return fmt.Errorf("flow instance %q is not active and cannot accept new requests: %w", + mq.flowSpec.ID, contracts.ErrFlowInstanceNotFound) + } if err := mq.queue.Add(item); err != nil { return err } mq.reconcileStats(1, int64(item.OriginalRequest().ByteSize())) + mq.logger.V(logging.TRACE).Info("Request added to queue", "requestID", item.OriginalRequest().ID()) return nil } +// Remove wraps the underlying `framework.SafeQueue.Remove` call and atomically updates statistics upon successful +// removal. 
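+// The resulting negative deltas are applied via `reconcileStats`, which relies on two's complement wrap-around of
+// the unsigned atomic counters.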
func (mq *managedQueue) Remove(handle types.QueueItemHandle) (types.QueueItemAccessor, error) { removedItem, err := mq.queue.Remove(handle) if err != nil { return nil, err } mq.reconcileStats(-1, -int64(removedItem.OriginalRequest().ByteSize())) - // TODO: If mq.len.Load() == 0, signal shard for optimistic instance cleanup. + mq.logger.V(logging.TRACE).Info("Request removed from queue", "requestID", removedItem.OriginalRequest().ID()) return removedItem, nil } +// Cleanup wraps the underlying `framework.SafeQueue.Cleanup` call and atomically updates statistics for all removed +// items. func (mq *managedQueue) Cleanup(predicate framework.PredicateFunc) (cleanedItems []types.QueueItemAccessor, err error) { cleanedItems, err = mq.queue.Cleanup(predicate) if err != nil || len(cleanedItems) == 0 { @@ -122,10 +149,12 @@ func (mq *managedQueue) Cleanup(predicate framework.PredicateFunc) (cleanedItems byteSizeDelta -= int64(item.OriginalRequest().ByteSize()) } mq.reconcileStats(lenDelta, byteSizeDelta) - // TODO: If mq.len.Load() == 0, signal shard for optimistic instance cleanup. + mq.logger.V(logging.DEBUG).Info("Cleaned up queue", "removedItemCount", len(cleanedItems), + "lenDelta", lenDelta, "byteSizeDelta", byteSizeDelta) return cleanedItems, nil } +// Drain wraps the underlying `framework.SafeQueue.Drain` call and atomically updates statistics for all removed items. func (mq *managedQueue) Drain() ([]types.QueueItemAccessor, error) { drainedItems, err := mq.queue.Drain() if err != nil || len(drainedItems) == 0 { @@ -139,7 +168,8 @@ func (mq *managedQueue) Drain() ([]types.QueueItemAccessor, error) { byteSizeDelta -= int64(item.OriginalRequest().ByteSize()) } mq.reconcileStats(lenDelta, byteSizeDelta) - // TODO: If mq.len.Load() == 0, signal shard for optimistic instance cleanup. + mq.logger.V(logging.DEBUG).Info("Drained queue", "itemCount", len(drainedItems), + "lenDelta", lenDelta, "byteSizeDelta", byteSizeDelta) return drainedItems, nil } @@ -150,18 +180,34 @@ func (mq *managedQueue) reconcileStats(lenDelta, byteSizeDelta int64) { // two's complement arithmetic. mq.len.Add(uint64(lenDelta)) mq.byteSize.Add(uint64(byteSizeDelta)) - mq.reconcileShardStats(lenDelta, byteSizeDelta) + if mq.reconcileShardStats != nil { + mq.reconcileShardStats(lenDelta, byteSizeDelta) + } } // --- Pass-through and accessor methods --- -func (mq *managedQueue) Name() string { return mq.queue.Name() } -func (mq *managedQueue) Capabilities() []framework.QueueCapability { return mq.queue.Capabilities() } -func (mq *managedQueue) Len() int { return int(mq.len.Load()) } -func (mq *managedQueue) ByteSize() uint64 { return mq.byteSize.Load() } +// Name returns the name of the underlying queue implementation. +func (mq *managedQueue) Name() string { return mq.queue.Name() } + +// Capabilities returns the capabilities of the underlying queue implementation. +func (mq *managedQueue) Capabilities() []framework.QueueCapability { return mq.queue.Capabilities() } + +// Len returns the number of items in the queue. +func (mq *managedQueue) Len() int { return int(mq.len.Load()) } + +// ByteSize returns the total byte size of all items in the queue. +func (mq *managedQueue) ByteSize() uint64 { return mq.byteSize.Load() } + +// PeekHead returns the item at the front of the queue without removing it. func (mq *managedQueue) PeekHead() (types.QueueItemAccessor, error) { return mq.queue.PeekHead() } + +// PeekTail returns the item at the back of the queue without removing it. 
func (mq *managedQueue) PeekTail() (types.QueueItemAccessor, error) { return mq.queue.PeekTail() } -func (mq *managedQueue) Comparator() framework.ItemComparator { return mq.dispatchPolicy.Comparator() } + +// Comparator returns the `framework.ItemComparator` that defines this queue's item ordering logic, as dictated by its +// configured dispatch policy. +func (mq *managedQueue) Comparator() framework.ItemComparator { return mq.dispatchPolicy.Comparator() } var _ contracts.ManagedQueue = &managedQueue{} @@ -173,13 +219,28 @@ type flowQueueAccessor struct { mq *managedQueue } -func (a *flowQueueAccessor) Name() string { return a.mq.Name() } -func (a *flowQueueAccessor) Capabilities() []framework.QueueCapability { return a.mq.Capabilities() } -func (a *flowQueueAccessor) Len() int { return a.mq.Len() } -func (a *flowQueueAccessor) ByteSize() uint64 { return a.mq.ByteSize() } +// Name returns the name of the queue. +func (a *flowQueueAccessor) Name() string { return a.mq.Name() } + +// Capabilities returns the capabilities of the queue. +func (a *flowQueueAccessor) Capabilities() []framework.QueueCapability { return a.mq.Capabilities() } + +// Len returns the number of items in the queue. +func (a *flowQueueAccessor) Len() int { return a.mq.Len() } + +// ByteSize returns the total byte size of all items in the queue. +func (a *flowQueueAccessor) ByteSize() uint64 { return a.mq.ByteSize() } + +// PeekHead returns the item at the front of the queue without removing it. func (a *flowQueueAccessor) PeekHead() (types.QueueItemAccessor, error) { return a.mq.PeekHead() } + +// PeekTail returns the item at the back of the queue without removing it. func (a *flowQueueAccessor) PeekTail() (types.QueueItemAccessor, error) { return a.mq.PeekTail() } -func (a *flowQueueAccessor) Comparator() framework.ItemComparator { return a.mq.Comparator() } -func (a *flowQueueAccessor) FlowSpec() types.FlowSpecification { return a.mq.flowSpec } + +// Comparator returns the `framework.ItemComparator` that defines this queue's item ordering logic. +func (a *flowQueueAccessor) Comparator() framework.ItemComparator { return a.mq.Comparator() } + +// FlowSpec returns the `types.FlowSpecification` of the flow this queue accessor is associated with. 
+func (a *flowQueueAccessor) FlowSpec() types.FlowSpecification { return a.mq.flowSpec } var _ framework.FlowQueueAccessor = &flowQueueAccessor{} diff --git a/pkg/epp/flowcontrol/registry/managedqueue_test.go b/pkg/epp/flowcontrol/registry/managedqueue_test.go index b94c055a3..6fc172feb 100644 --- a/pkg/epp/flowcontrol/registry/managedqueue_test.go +++ b/pkg/epp/flowcontrol/registry/managedqueue_test.go @@ -26,6 +26,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/flowcontrol/contracts" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/flowcontrol/framework" frameworkmocks "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/flowcontrol/framework/mocks" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/flowcontrol/framework/plugins/queue" @@ -104,6 +105,7 @@ func TestManagedQueue_New(t *testing.T) { assert.Zero(t, f.mq.Len(), "A new managedQueue should have a length of 0") assert.Zero(t, f.mq.ByteSize(), "A new managedQueue should have a byte size of 0") + assert.True(t, f.mq.isActive.Load(), "A new managedQueue should be active") } func TestManagedQueue_Add(t *testing.T) { @@ -113,7 +115,9 @@ func TestManagedQueue_Add(t *testing.T) { name string itemByteSize uint64 mockAddError error + markAsDraining bool expectError bool + expectedErrorIs error expectedLen int expectedByteSize uint64 expectedLenDelta int64 @@ -123,7 +127,6 @@ func TestManagedQueue_Add(t *testing.T) { { name: "Success", itemByteSize: 100, - mockAddError: nil, expectError: false, expectedLen: 1, expectedByteSize: 100, @@ -142,6 +145,18 @@ func TestManagedQueue_Add(t *testing.T) { expectedByteSizeDelta: 0, expectedReconcile: false, }, + { + name: "Error on inactive queue", + itemByteSize: 100, + markAsDraining: true, + expectError: true, + expectedErrorIs: contracts.ErrFlowInstanceNotFound, + expectedLen: 0, + expectedByteSize: 0, + expectedLenDelta: 0, + expectedByteSizeDelta: 0, + expectedReconcile: false, + }, } for _, tc := range testCases { @@ -154,11 +169,19 @@ func TestManagedQueue_Add(t *testing.T) { return tc.mockAddError } + if tc.markAsDraining { + f.mq.markAsDraining() + assert.False(t, f.mq.isActive.Load(), "Setup: queue should be marked as inactive") + } + item := typesmocks.NewMockQueueItemAccessor(tc.itemByteSize, "req-1", "test-flow") err := f.mq.Add(item) if tc.expectError { require.Error(t, err, "Add should have returned an error") + if tc.expectedErrorIs != nil { + assert.ErrorIs(t, err, tc.expectedErrorIs, "Error should wrap the expected sentinel error") + } } else { require.NoError(t, err, "Add should not have returned an error") } @@ -308,7 +331,7 @@ func TestManagedQueue_CleanupAndDrain(t *testing.T) { // --- Test Error Paths --- t.Run("ErrorPaths", func(t *testing.T) { f := setupTestManagedQueue(t) - require.NoError(t, f.mq.Add(item1)) + require.NoError(t, f.mq.Add(item1), "Setup: Adding an item should not fail") initialLen, initialByteSize := f.mq.Len(), f.mq.ByteSize() expectedErr := errors.New("internal error") diff --git a/pkg/epp/flowcontrol/registry/shard.go b/pkg/epp/flowcontrol/registry/shard.go new file mode 100644 index 000000000..1e894f07a --- /dev/null +++ b/pkg/epp/flowcontrol/registry/shard.go @@ -0,0 +1,340 @@ +/* +Copyright 2025 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package registry + +import ( + "fmt" + "slices" + "sync" + "sync/atomic" + + "github.com/go-logr/logr" + + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/flowcontrol/contracts" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/flowcontrol/framework" + inter "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/flowcontrol/framework/plugins/policies/interflow/dispatch" + intra "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/flowcontrol/framework/plugins/policies/intraflow/dispatch" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" +) + +// registryShard implements the `contracts.RegistryShard` interface. It represents a single, concurrent-safe slice of +// the `FlowRegistry`'s state, providing an operational view for a single `controller.FlowController` worker. +// +// # Responsibilities +// +// - Holding the partitioned configuration and state (queues, policies) for its assigned shard. +// - Providing read-only access to its state for the `controller.FlowController`'s dispatch loop. +// - Aggregating statistics from its `managedQueue` instances. +// +// # Concurrency +// +// The `registryShard` uses a combination of an `RWMutex` and atomic operations to manage concurrency. +// - The `mu` RWMutex protects the shard's internal maps (`priorityBands`, `activeFlows`) during administrative +// operations like flow registration or updates. This ensures that the set of active or draining queues appears +// atomic to a `controller.FlowController` worker. All read-oriented methods on the shard take a read lock. +// - All statistics (`totalByteSize`, `totalLen`, etc.) are implemented as `atomic.Uint64` to allow for lock-free, +// high-performance updates from many concurrent queue operations. +type registryShard struct { + id string + logger logr.Logger + config *Config // Holds the *partitioned* config for this shard. + isActive bool + reconcileFun parentStatsReconciler + + // mu protects the shard's internal maps (`priorityBands` and `activeFlows`). + mu sync.RWMutex + + // priorityBands is the primary lookup table for all managed queues on this shard, organized by `priority`, then by + // `flowID`. This map contains BOTH active and draining queues. + priorityBands map[uint]*priorityBand + + // activeFlows is a flattened map for O(1) access to the SINGLE active queue for a given logical flow ID. + // This is the critical lookup for the `Enqueue` path. If a `flowID` is not in this map, it has no active queue on + // this shard. + activeFlows map[string]*managedQueue + + // orderedPriorityLevels is a cached, sorted list of `priority` levels. + // It is populated at initialization to avoid repeated map key iteration and sorting during the dispatch loop, + // ensuring a deterministic, ordered traversal from highest to lowest priority. + orderedPriorityLevels []uint + + // Shard-level statistics, which are updated atomically to ensure they are safe for concurrent access without locks. + totalByteSize atomic.Uint64 + totalLen atomic.Uint64 +} + +// priorityBand holds all the `managedQueues` and configuration for a single priority level within a shard. 
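+// Policy instances are resolved once at shard construction and cached on the band, so the dispatch hot path never
+// pays plugin-instantiation costs.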
+type priorityBand struct { + // config holds the partitioned config for this specific band. + config PriorityBandConfig + + // queues holds all `managedQueue` instances within this band, keyed by `flowID`. This includes both active and + // draining queues. + queues map[string]*managedQueue + + // Band-level statistics, which are updated atomically. + byteSize atomic.Uint64 + len atomic.Uint64 + + // Cached policy instances for this band, created at initialization. + interFlowDispatchPolicy framework.InterFlowDispatchPolicy + defaultIntraFlowDispatchPolicy framework.IntraFlowDispatchPolicy +} + +// newShard creates a new `registryShard` instance from a partitioned configuration. +func newShard( + id string, + partitionedConfig *Config, + logger logr.Logger, + reconcileFunc parentStatsReconciler, +) (*registryShard, error) { + shardLogger := logger.WithName("registry-shard").WithValues("shardID", id) + s := ®istryShard{ + id: id, + logger: shardLogger, + config: partitionedConfig, + isActive: true, + reconcileFun: reconcileFunc, + priorityBands: make(map[uint]*priorityBand, len(partitionedConfig.PriorityBands)), + activeFlows: make(map[string]*managedQueue), + } + + for _, bandConfig := range partitionedConfig.PriorityBands { + interPolicy, err := inter.NewPolicyFromName(bandConfig.InterFlowDispatchPolicy) + if err != nil { + return nil, fmt.Errorf("failed to create inter-flow policy %q for priority band %d: %w", + bandConfig.InterFlowDispatchPolicy, bandConfig.Priority, err) + } + + intraPolicy, err := intra.NewPolicyFromName(bandConfig.IntraFlowDispatchPolicy) + if err != nil { + return nil, fmt.Errorf("failed to create intra-flow policy %q for priority band %d: %w", + bandConfig.IntraFlowDispatchPolicy, bandConfig.Priority, err) + } + + s.priorityBands[bandConfig.Priority] = &priorityBand{ + config: bandConfig, + queues: make(map[string]*managedQueue), + interFlowDispatchPolicy: interPolicy, + defaultIntraFlowDispatchPolicy: intraPolicy, + } + s.orderedPriorityLevels = append(s.orderedPriorityLevels, bandConfig.Priority) + } + + // Sort the priority levels to ensure deterministic iteration order. + slices.Sort(s.orderedPriorityLevels) + s.logger.V(logging.DEFAULT).Info("Registry shard initialized successfully", + "priorityBandCount", len(s.priorityBands), "orderedPriorities", s.orderedPriorityLevels) + return s, nil +} + +// reconcileStats is the single point of entry for all statistics changes within the shard. It updates the relevant +// band's stats, the shard's total stats, and propagates the delta to the parent registry. +func (s *registryShard) reconcileStats(priority uint, lenDelta, byteSizeDelta int64) { + s.totalLen.Add(uint64(lenDelta)) + s.totalByteSize.Add(uint64(byteSizeDelta)) + + if band, ok := s.priorityBands[priority]; ok { + band.len.Add(uint64(lenDelta)) + band.byteSize.Add(uint64(byteSizeDelta)) + } + + s.logger.V(logging.TRACE).Info("Reconciled shard stats", "priority", priority, + "lenDelta", lenDelta, "byteSizeDelta", byteSizeDelta) + + if s.reconcileFun != nil { + s.reconcileFun(lenDelta, byteSizeDelta) + } +} + +// ID returns the unique identifier for this shard. +func (s *registryShard) ID() string { return s.id } + +// IsActive returns true if the shard is active and accepting new requests. +func (s *registryShard) IsActive() bool { + s.mu.RLock() + defer s.mu.RUnlock() + return s.isActive +} + +// ActiveManagedQueue returns the currently active `ManagedQueue` for a given flow. 
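+//
+// A minimal sketch of the intended enqueue path (illustrative only; `flowID` and `item` are assumed to be supplied
+// by the caller):
+//
+//	mq, err := shard.ActiveManagedQueue(flowID)
+//	if err != nil {
+//		return err // Wraps contracts.ErrFlowInstanceNotFound if the flow has no active instance here.
+//	}
+//	return mq.Add(item) // Fails the same way if the instance began draining in the interim.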
+func (s *registryShard) ActiveManagedQueue(flowID string) (contracts.ManagedQueue, error) { + s.mu.RLock() + defer s.mu.RUnlock() + + mq, ok := s.activeFlows[flowID] + if !ok { + return nil, fmt.Errorf("failed to get active queue for flow %q: %w", flowID, contracts.ErrFlowInstanceNotFound) + } + return mq, nil +} + +// ManagedQueue retrieves a specific (potentially draining) `ManagedQueue` instance from this shard. +func (s *registryShard) ManagedQueue(flowID string, priority uint) (contracts.ManagedQueue, error) { + s.mu.RLock() + defer s.mu.RUnlock() + + band, ok := s.priorityBands[priority] + if !ok { + return nil, fmt.Errorf("failed to get managed queue for flow %q: %w", flowID, contracts.ErrPriorityBandNotFound) + } + mq, ok := band.queues[flowID] + if !ok { + return nil, fmt.Errorf("failed to get managed queue for flow %q at priority %d: %w", + flowID, priority, contracts.ErrFlowInstanceNotFound) + } + return mq, nil +} + +// IntraFlowDispatchPolicy retrieves a flow's configured `framework.IntraFlowDispatchPolicy`. +func (s *registryShard) IntraFlowDispatchPolicy( + flowID string, + priority uint, +) (framework.IntraFlowDispatchPolicy, error) { + s.mu.RLock() + defer s.mu.RUnlock() + + band, ok := s.priorityBands[priority] + if !ok { + return nil, fmt.Errorf("failed to get intra-flow policy for flow %q: %w", flowID, contracts.ErrPriorityBandNotFound) + } + mq, ok := band.queues[flowID] + if !ok { + return nil, fmt.Errorf("failed to get intra-flow policy for flow %q at priority %d: %w", + flowID, priority, contracts.ErrFlowInstanceNotFound) + } + // The policy is stored on the managed queue. + return mq.dispatchPolicy, nil +} + +// InterFlowDispatchPolicy retrieves a priority band's configured `framework.InterFlowDispatchPolicy`. +func (s *registryShard) InterFlowDispatchPolicy(priority uint) (framework.InterFlowDispatchPolicy, error) { + s.mu.RLock() + defer s.mu.RUnlock() + + band, ok := s.priorityBands[priority] + if !ok { + return nil, fmt.Errorf("failed to get inter-flow policy for priority %d: %w", + priority, contracts.ErrPriorityBandNotFound) + } + return band.interFlowDispatchPolicy, nil +} + +// PriorityBandAccessor retrieves a read-only accessor for a given priority level. +func (s *registryShard) PriorityBandAccessor(priority uint) (framework.PriorityBandAccessor, error) { + s.mu.RLock() + defer s.mu.RUnlock() + + band, ok := s.priorityBands[priority] + if !ok { + return nil, fmt.Errorf("failed to get priority band accessor for priority %d: %w", + priority, contracts.ErrPriorityBandNotFound) + } + return &priorityBandAccessor{shard: s, band: band}, nil +} + +// AllOrderedPriorityLevels returns all configured priority levels for this shard, sorted from highest to lowest +// priority (ascending numerical order). +func (s *registryShard) AllOrderedPriorityLevels() []uint { + // This is cached and read-only, so no lock is needed. + return s.orderedPriorityLevels +} + +// Stats returns a snapshot of the statistics for this specific shard. 
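+// The snapshot is assembled from atomic counter loads without pausing queue operations, so individual values may be
+// marginally stale by the time the caller inspects them.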
+func (s *registryShard) Stats() contracts.ShardStats { + s.mu.RLock() + defer s.mu.RUnlock() + + stats := contracts.ShardStats{ + TotalCapacityBytes: s.config.MaxBytes, + TotalByteSize: s.totalByteSize.Load(), + TotalLen: s.totalLen.Load(), + PerPriorityBandStats: make(map[uint]contracts.PriorityBandStats, len(s.priorityBands)), + } + + for priority, band := range s.priorityBands { + stats.PerPriorityBandStats[priority] = contracts.PriorityBandStats{ + Priority: priority, + PriorityName: band.config.PriorityName, + CapacityBytes: band.config.MaxBytes, // This is the partitioned capacity + ByteSize: band.byteSize.Load(), + Len: band.len.Load(), + } + } + return stats +} + +var _ contracts.RegistryShard = ®istryShard{} + +// --- priorityBandAccessor --- + +// priorityBandAccessor implements `framework.PriorityBandAccessor`. It provides a read-only, concurrent-safe view of a +// single priority band within a shard. +type priorityBandAccessor struct { + shard *registryShard + band *priorityBand +} + +// Priority returns the numerical priority level of this band. +func (a *priorityBandAccessor) Priority() uint { + return a.band.config.Priority +} + +// PriorityName returns the human-readable name of this priority band. +func (a *priorityBandAccessor) PriorityName() string { + return a.band.config.PriorityName +} + +// FlowIDs returns a slice of all flow IDs within this priority band. +func (a *priorityBandAccessor) FlowIDs() []string { + a.shard.mu.RLock() + defer a.shard.mu.RUnlock() + + flowIDs := make([]string, 0, len(a.band.queues)) + for id := range a.band.queues { + flowIDs = append(flowIDs, id) + } + return flowIDs +} + +// Queue returns a `framework.FlowQueueAccessor` for the specified `flowID` within this priority band. +func (a *priorityBandAccessor) Queue(flowID string) framework.FlowQueueAccessor { + a.shard.mu.RLock() + defer a.shard.mu.RUnlock() + + mq, ok := a.band.queues[flowID] + if !ok { + return nil + } + return mq.FlowQueueAccessor() +} + +// IterateQueues executes the given `callback` for each `framework.FlowQueueAccessor` in this priority band. +// The callback is executed under the shard's read lock, so it should be efficient and non-blocking. +// If the callback returns false, iteration stops. +func (a *priorityBandAccessor) IterateQueues(callback func(queue framework.FlowQueueAccessor) bool) { + a.shard.mu.RLock() + defer a.shard.mu.RUnlock() + + for _, mq := range a.band.queues { + if !callback(mq.FlowQueueAccessor()) { + return + } + } +} + +var _ framework.PriorityBandAccessor = &priorityBandAccessor{} diff --git a/pkg/epp/flowcontrol/registry/shard_test.go b/pkg/epp/flowcontrol/registry/shard_test.go new file mode 100644 index 000000000..fe7f96e77 --- /dev/null +++ b/pkg/epp/flowcontrol/registry/shard_test.go @@ -0,0 +1,355 @@ +/* +Copyright 2025 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + +package registry + +import ( + "errors" + "sort" + "testing" + + "github.com/go-logr/logr" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/flowcontrol/contracts" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/flowcontrol/framework" + inter "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/flowcontrol/framework/plugins/policies/interflow/dispatch" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/flowcontrol/framework/plugins/policies/interflow/dispatch/besthead" + intra "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/flowcontrol/framework/plugins/policies/intraflow/dispatch" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/flowcontrol/framework/plugins/policies/intraflow/dispatch/fcfs" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/flowcontrol/framework/plugins/queue" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/flowcontrol/framework/plugins/queue/listqueue" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/flowcontrol/types" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/flowcontrol/types/mocks" +) + +// shardTestFixture holds the components needed for a `registryShard` test. +type shardTestFixture struct { + config *Config + shard *registryShard +} + +// setupTestShard creates a new test fixture for testing the `registryShard`. +func setupTestShard(t *testing.T) *shardTestFixture { + t.Helper() + + config := &Config{ + PriorityBands: []PriorityBandConfig{ + { + Priority: 10, + PriorityName: "High", + }, + { + Priority: 20, + PriorityName: "Low", + }, + }, + } + // Apply defaults to the master config first, as the parent registry would. + err := config.validateAndApplyDefaults() + require.NoError(t, err, "Setup: validating and defaulting config should not fail") + + // The parent registry would partition the config. For a single shard test, we can use the defaulted one directly. + shard, err := newShard("test-shard-1", config, logr.Discard(), nil) + require.NoError(t, err, "Setup: newShard should not return an error") + require.NotNil(t, shard, "Setup: newShard should return a non-nil shard") + + return &shardTestFixture{ + config: config, + shard: shard, + } +} + +// _reconcileFlow_testOnly is a test helper that simulates the future admin logic for adding or updating a flow. +// It creates a `managedQueue` and correctly populates the `priorityBands` and `activeFlows` maps. +// This helper is intended to be replaced by the real `reconcileFlow` method in a future PR. 
+func (s *registryShard) _reconcileFlow_testOnly( + t *testing.T, + flowSpec types.FlowSpecification, + isActive bool, +) *managedQueue { + t.Helper() + + band, ok := s.priorityBands[flowSpec.Priority] + require.True(t, ok, "Setup: priority band %d should exist", flowSpec.Priority) + + lq, err := queue.NewQueueFromName(listqueue.ListQueueName, nil) + require.NoError(t, err, "Setup: creating a real listqueue should not fail") + + mq := newManagedQueue( + lq, + band.defaultIntraFlowDispatchPolicy, + flowSpec, + logr.Discard(), + func(lenDelta, byteSizeDelta int64) { s.reconcileStats(flowSpec.Priority, lenDelta, byteSizeDelta) }, + ) + require.NotNil(t, mq, "Setup: newManagedQueue should not return nil") + + band.queues[flowSpec.ID] = mq + if isActive { + s.activeFlows[flowSpec.ID] = mq + } + + return mq +} + +func TestNewShard(t *testing.T) { + t.Parallel() + f := setupTestShard(t) + + assert.Equal(t, "test-shard-1", f.shard.ID(), "ID should be set correctly") + assert.True(t, f.shard.IsActive(), "A new shard should be active") + require.Len(t, f.shard.priorityBands, 2, "Should have 2 priority bands") + + // Check that priority levels are sorted correctly + assert.Equal(t, []uint{10, 20}, f.shard.AllOrderedPriorityLevels(), "Priority levels should be ordered") + + // Check band 10 + band10, ok := f.shard.priorityBands[10] + require.True(t, ok, "Priority band 10 should exist") + assert.Equal(t, uint(10), band10.config.Priority, "Band 10 should have correct priority") + assert.Equal(t, "High", band10.config.PriorityName, "Band 10 should have correct name") + assert.NotNil(t, band10.interFlowDispatchPolicy, "Inter-flow policy for band 10 should be instantiated") + assert.NotNil(t, band10.defaultIntraFlowDispatchPolicy, + "Default intra-flow policy for band 10 should be instantiated") + assert.Equal(t, besthead.BestHeadPolicyName, band10.interFlowDispatchPolicy.Name(), + "Correct default inter-flow policy should be used") + assert.Equal(t, fcfs.FCFSPolicyName, band10.defaultIntraFlowDispatchPolicy.Name(), + "Correct default intra-flow policy should be used") + + // Check band 20 + band20, ok := f.shard.priorityBands[20] + require.True(t, ok, "Priority band 20 should exist") + assert.Equal(t, uint(20), band20.config.Priority, "Band 20 should have correct priority") + assert.Equal(t, "Low", band20.config.PriorityName, "Band 20 should have correct name") + assert.NotNil(t, band20.interFlowDispatchPolicy, "Inter-flow policy for band 20 should be instantiated") + assert.NotNil(t, band20.defaultIntraFlowDispatchPolicy, + "Default intra-flow policy for band 20 should be instantiated") +} + +// TestNewShard_ErrorPaths modifies global plugin registries, so it cannot be run in parallel with other tests. 
+func TestNewShard_ErrorPaths(t *testing.T) {
+	baseConfig := &Config{
+		PriorityBands: []PriorityBandConfig{{
+			Priority:                10,
+			PriorityName:            "High",
+			IntraFlowDispatchPolicy: fcfs.FCFSPolicyName,
+			InterFlowDispatchPolicy: besthead.BestHeadPolicyName,
+			Queue:                   listqueue.ListQueueName,
+		}},
+	}
+	require.NoError(t, baseConfig.validateAndApplyDefaults(), "Setup: base config should be valid")
+
+	// A plain struct copy of `baseConfig` would share the underlying `PriorityBands` slice, so a mutation made by one
+	// subtest would leak into the next. Each subtest therefore works on its own deep copy.
+	cloneConfig := func() *Config {
+		clone := *baseConfig
+		clone.PriorityBands = append([]PriorityBandConfig(nil), baseConfig.PriorityBands...)
+		return &clone
+	}
+
+	t.Run("Invalid InterFlow Policy", func(t *testing.T) {
+		// Register a mock policy that always fails to instantiate
+		failingPolicyName := inter.RegisteredPolicyName("failing-inter-policy")
+		inter.MustRegisterPolicy(failingPolicyName, func() (framework.InterFlowDispatchPolicy, error) {
+			return nil, errors.New("inter-flow instantiation failed")
+		})
+
+		badConfig := cloneConfig()
+		badConfig.PriorityBands[0].InterFlowDispatchPolicy = failingPolicyName
+
+		_, err := newShard("test", badConfig, logr.Discard(), nil)
+		require.Error(t, err, "newShard should fail with an invalid inter-flow policy")
+	})
+
+	t.Run("Invalid IntraFlow Policy", func(t *testing.T) {
+		// Register a mock policy that always fails to instantiate
+		failingPolicyName := intra.RegisteredPolicyName("failing-intra-policy")
+		intra.MustRegisterPolicy(failingPolicyName, func() (framework.IntraFlowDispatchPolicy, error) {
+			return nil, errors.New("intra-flow instantiation failed")
+		})
+
+		badConfig := cloneConfig()
+		badConfig.PriorityBands[0].IntraFlowDispatchPolicy = failingPolicyName
+
+		_, err := newShard("test", badConfig, logr.Discard(), nil)
+		require.Error(t, err, "newShard should fail with an invalid intra-flow policy")
+	})
+}
+
+func TestShard_Stats(t *testing.T) {
+	t.Parallel()
+	f := setupTestShard(t)
+
+	// Add a queue and some items to test stats aggregation
+	mq := f.shard._reconcileFlow_testOnly(t, types.FlowSpecification{ID: "flow1", Priority: 10}, true)
+
+	// Add items
+	require.NoError(t, mq.Add(mocks.NewMockQueueItemAccessor(100, "req1", "flow1")), "Adding item should not fail")
+	require.NoError(t, mq.Add(mocks.NewMockQueueItemAccessor(50, "req2", "flow1")), "Adding item should not fail")
+
+	stats := f.shard.Stats()
+
+	// Check shard-level stats
+	assert.Equal(t, uint64(2), stats.TotalLen, "Total length should be 2")
+	assert.Equal(t, uint64(150), stats.TotalByteSize, "Total byte size should be 150")
+
+	// Check per-band stats
+	require.Len(t, stats.PerPriorityBandStats, 2, "Should have stats for 2 bands")
+	band10Stats := stats.PerPriorityBandStats[10]
+	assert.Equal(t, uint(10), band10Stats.Priority, "Band 10 stats should have correct priority")
+	assert.Equal(t, uint64(2), band10Stats.Len, "Band 10 length should be 2")
+	assert.Equal(t, uint64(150), band10Stats.ByteSize, "Band 10 byte size should be 150")
+
+	band20Stats := stats.PerPriorityBandStats[20]
+	assert.Equal(t, uint(20), band20Stats.Priority, "Band 20 stats should have correct priority")
+	assert.Zero(t, band20Stats.Len, "Band 20 length should be 0")
+	assert.Zero(t, band20Stats.ByteSize, "Band 20 byte size should be 0")
+}
+
+func TestShard_Accessors(t *testing.T) {
+	t.Parallel()
+	f := setupTestShard(t)
+
+	flowID := "test-flow"
+	activePriority := uint(10)
+	drainingPriority := uint(20)
+
+	// Setup state with one active and one draining queue for the same flow
+	activeQueue := f.shard._reconcileFlow_testOnly(t, types.FlowSpecification{
+		ID:       flowID,
+		Priority: activePriority,
+	}, true)
+	drainingQueue := f.shard._reconcileFlow_testOnly(t, types.FlowSpecification{
+		ID:       flowID,
+		Priority: drainingPriority,
+	}, false)
+
+	
t.Run("ActiveManagedQueue", func(t *testing.T) { + t.Parallel() + retrievedActiveQueue, err := f.shard.ActiveManagedQueue(flowID) + require.NoError(t, err, "ActiveManagedQueue should not error for an existing flow") + assert.Same(t, activeQueue, retrievedActiveQueue, "Should return the correct active queue") + + _, err = f.shard.ActiveManagedQueue("non-existent-flow") + require.Error(t, err, "ActiveManagedQueue should error for a non-existent flow") + assert.ErrorIs(t, err, contracts.ErrFlowInstanceNotFound, "Error should be ErrFlowInstanceNotFound") + }) + + t.Run("ManagedQueue", func(t *testing.T) { + t.Parallel() + retrievedDrainingQueue, err := f.shard.ManagedQueue(flowID, drainingPriority) + require.NoError(t, err, "ManagedQueue should not error for a draining queue") + assert.Same(t, drainingQueue, retrievedDrainingQueue, "Should return the correct draining queue") + + _, err = f.shard.ManagedQueue(flowID, 99) // Non-existent priority + require.Error(t, err, "ManagedQueue should error for a non-existent priority") + assert.ErrorIs(t, err, contracts.ErrPriorityBandNotFound, "Error should be ErrPriorityBandNotFound") + + _, err = f.shard.ManagedQueue("non-existent-flow", activePriority) + require.Error(t, err, "ManagedQueue should error for a non-existent flow in an existing priority") + assert.ErrorIs(t, err, contracts.ErrFlowInstanceNotFound, "Error should be ErrFlowInstanceNotFound") + }) + + t.Run("IntraFlowDispatchPolicy", func(t *testing.T) { + t.Parallel() + retrievedActivePolicy, err := f.shard.IntraFlowDispatchPolicy(flowID, activePriority) + require.NoError(t, err, "IntraFlowDispatchPolicy should not error for an active instance") + assert.Same(t, activeQueue.dispatchPolicy, retrievedActivePolicy, + "Should return the policy from the active instance") + + _, err = f.shard.IntraFlowDispatchPolicy("non-existent-flow", activePriority) + require.Error(t, err, "IntraFlowDispatchPolicy should error for a non-existent flow") + assert.ErrorIs(t, err, contracts.ErrFlowInstanceNotFound, "Error should be ErrFlowInstanceNotFound") + + _, err = f.shard.IntraFlowDispatchPolicy(flowID, 99) // Non-existent priority + require.Error(t, err, "IntraFlowDispatchPolicy should error for a non-existent priority") + assert.ErrorIs(t, err, contracts.ErrPriorityBandNotFound, "Error should be ErrPriorityBandNotFound") + }) + + t.Run("InterFlowDispatchPolicy", func(t *testing.T) { + t.Parallel() + retrievedInterPolicy, err := f.shard.InterFlowDispatchPolicy(activePriority) + require.NoError(t, err, "InterFlowDispatchPolicy should not error for an existing priority") + assert.Same(t, f.shard.priorityBands[activePriority].interFlowDispatchPolicy, retrievedInterPolicy, + "Should return the correct inter-flow policy") + + _, err = f.shard.InterFlowDispatchPolicy(99) // Non-existent priority + require.Error(t, err, "InterFlowDispatchPolicy should error for a non-existent priority") + assert.ErrorIs(t, err, contracts.ErrPriorityBandNotFound, "Error should be ErrPriorityBandNotFound") + }) +} + +func TestShard_PriorityBandAccessor(t *testing.T) { + t.Parallel() + f := setupTestShard(t) + + // Setup shard state for the tests + p1, p2 := uint(10), uint(20) + f.shard._reconcileFlow_testOnly(t, types.FlowSpecification{ID: "flow1", Priority: p1}, true) + f.shard._reconcileFlow_testOnly(t, types.FlowSpecification{ID: "flow1", Priority: p2}, false) + f.shard._reconcileFlow_testOnly(t, types.FlowSpecification{ID: "flow2", Priority: p1}, true) + + t.Run("Accessor for existing priority", func(t *testing.T) { + 
t.Parallel() + accessor, err := f.shard.PriorityBandAccessor(p1) + require.NoError(t, err, "PriorityBandAccessor should not fail for existing priority") + require.NotNil(t, accessor, "Accessor should not be nil") + + t.Run("Properties", func(t *testing.T) { + t.Parallel() + assert.Equal(t, p1, accessor.Priority(), "Accessor should have correct priority") + assert.Equal(t, "High", accessor.PriorityName(), "Accessor should have correct priority name") + }) + + t.Run("FlowIDs", func(t *testing.T) { + t.Parallel() + flowIDs := accessor.FlowIDs() + sort.Strings(flowIDs) + assert.Equal(t, []string{"flow1", "flow2"}, flowIDs, + "Accessor should return correct flow IDs for the priority band") + }) + + t.Run("Queue", func(t *testing.T) { + t.Parallel() + q := accessor.Queue("flow1") + require.NotNil(t, q, "Accessor should return queue for flow1") + assert.Equal(t, p1, q.FlowSpec().Priority, "Queue should have the correct priority") + assert.Nil(t, accessor.Queue("non-existent"), "Accessor should return nil for non-existent flow") + }) + + t.Run("IterateQueues", func(t *testing.T) { + t.Parallel() + var iteratedFlows []string + accessor.IterateQueues(func(queue framework.FlowQueueAccessor) bool { + iteratedFlows = append(iteratedFlows, queue.FlowSpec().ID) + return true + }) + sort.Strings(iteratedFlows) + assert.Equal(t, []string{"flow1", "flow2"}, iteratedFlows, "IterateQueues should visit all flows in the band") + }) + + t.Run("IterateQueues with early exit", func(t *testing.T) { + t.Parallel() + var iteratedFlows []string + accessor.IterateQueues(func(queue framework.FlowQueueAccessor) bool { + iteratedFlows = append(iteratedFlows, queue.FlowSpec().ID) + return false // Exit after first item + }) + assert.Len(t, iteratedFlows, 1, "IterateQueues should exit early if callback returns false") + }) + }) + + t.Run("Error on non-existent priority", func(t *testing.T) { + t.Parallel() + _, err := f.shard.PriorityBandAccessor(99) + require.Error(t, err, "PriorityBandAccessor should fail for non-existent priority") + assert.ErrorIs(t, err, contracts.ErrPriorityBandNotFound, "Error should be ErrPriorityBandNotFound") + }) +}
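+
+// TestShard_ConcurrentReads is a light smoke test backing the goroutine-safety conformance claim on the shard's
+// read-oriented accessors. It is a minimal sketch rather than an exhaustive concurrency test: it only exercises
+// concurrent reads against a fixed shard state and relies on `go test -race` to surface data races.
+func TestShard_ConcurrentReads(t *testing.T) {
+	t.Parallel()
+	f := setupTestShard(t)
+	f.shard._reconcileFlow_testOnly(t, types.FlowSpecification{ID: "flow1", Priority: 10}, true)
+
+	const readers = 8
+	done := make(chan struct{})
+	for i := 0; i < readers; i++ {
+		go func() {
+			defer func() { done <- struct{}{} }()
+			_, _ = f.shard.ActiveManagedQueue("flow1")
+			_ = f.shard.Stats()
+			if accessor, err := f.shard.PriorityBandAccessor(10); err == nil {
+				_ = accessor.FlowIDs()
+			}
+		}()
+	}
+	for i := 0; i < readers; i++ {
+		<-done
+	}
+}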