diff --git a/pkg/epp/flowcontrol/contracts/mocks/mocks.go b/pkg/epp/flowcontrol/contracts/mocks/mocks.go new file mode 100644 index 000000000..5431e2c83 --- /dev/null +++ b/pkg/epp/flowcontrol/contracts/mocks/mocks.go @@ -0,0 +1,291 @@ +/* +Copyright 2025 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +// Package mocks provides mocks for the interfaces defined in the `contracts` package. +// +// # Testing Philosophy: High-Fidelity Mocks +// +// The components that consume these contracts, particularly the `controller.ShardProcessor`, are complex, concurrent +// orchestrators. Testing them reliably requires more than simple stubs. It requires high-fidelity mocks that allow for +// the deterministic simulation of race conditions and specific failure modes. +// +// For this reason, mocks like `MockManagedQueue` are deliberately stateful and thread-safe. They provide a reliable, +// in-memory simulation of the real component's behavior, while also providing function-based overrides +// (e.g., `AddFunc`) that allow tests to inject specific errors or pause execution at critical moments. This strategy is +// essential for creating the robust, non-flaky tests needed to verify the correctness of the system's concurrent logic. +// For a more detailed defense of this strategy, see the comment at the top of `controller/internal/processor_test.go`. +package mocks + +import ( + "context" + "fmt" + "sync" + + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/flowcontrol/contracts" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/flowcontrol/framework" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/flowcontrol/types" + typesmocks "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/flowcontrol/types/mocks" +) + +// MockRegistryShard is a simple "stub-style" mock for testing. +// Its methods are implemented as function fields (e.g., `IDFunc`). A test can inject behavior by setting the desired +// function field in the test setup. If a func is nil, the method will return a zero value. 
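+//
+// For illustration, a test can wire up just the behavior it needs (a minimal sketch; the injected values are
+// arbitrary):
+//
+//	shard := &MockRegistryShard{
+//		IsActiveFunc: func() bool { return true },
+//		StatsFunc:    func() contracts.ShardStats { return contracts.ShardStats{} },
+//	}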
+type MockRegistryShard struct { + IDFunc func() string + IsActiveFunc func() bool + ActiveManagedQueueFunc func(flowID string) (contracts.ManagedQueue, error) + ManagedQueueFunc func(flowID string, priority uint) (contracts.ManagedQueue, error) + IntraFlowDispatchPolicyFunc func(flowID string, priority uint) (framework.IntraFlowDispatchPolicy, error) + InterFlowDispatchPolicyFunc func(priority uint) (framework.InterFlowDispatchPolicy, error) + PriorityBandAccessorFunc func(priority uint) (framework.PriorityBandAccessor, error) + AllOrderedPriorityLevelsFunc func() []uint + StatsFunc func() contracts.ShardStats +} + +func (m *MockRegistryShard) ID() string { + if m.IDFunc != nil { + return m.IDFunc() + } + return "" +} + +func (m *MockRegistryShard) IsActive() bool { + if m.IsActiveFunc != nil { + return m.IsActiveFunc() + } + return false +} + +func (m *MockRegistryShard) ActiveManagedQueue(flowID string) (contracts.ManagedQueue, error) { + if m.ActiveManagedQueueFunc != nil { + return m.ActiveManagedQueueFunc(flowID) + } + return nil, nil +} + +func (m *MockRegistryShard) ManagedQueue(flowID string, priority uint) (contracts.ManagedQueue, error) { + if m.ManagedQueueFunc != nil { + return m.ManagedQueueFunc(flowID, priority) + } + return nil, nil +} + +func (m *MockRegistryShard) IntraFlowDispatchPolicy(flowID string, priority uint) (framework.IntraFlowDispatchPolicy, error) { + if m.IntraFlowDispatchPolicyFunc != nil { + return m.IntraFlowDispatchPolicyFunc(flowID, priority) + } + return nil, nil +} + +func (m *MockRegistryShard) InterFlowDispatchPolicy(priority uint) (framework.InterFlowDispatchPolicy, error) { + if m.InterFlowDispatchPolicyFunc != nil { + return m.InterFlowDispatchPolicyFunc(priority) + } + return nil, nil +} + +func (m *MockRegistryShard) PriorityBandAccessor(priority uint) (framework.PriorityBandAccessor, error) { + if m.PriorityBandAccessorFunc != nil { + return m.PriorityBandAccessorFunc(priority) + } + return nil, nil +} + +func (m *MockRegistryShard) AllOrderedPriorityLevels() []uint { + if m.AllOrderedPriorityLevelsFunc != nil { + return m.AllOrderedPriorityLevelsFunc() + } + return nil +} + +func (m *MockRegistryShard) Stats() contracts.ShardStats { + if m.StatsFunc != nil { + return m.StatsFunc() + } + return contracts.ShardStats{} +} + +// MockSaturationDetector is a simple "stub-style" mock for testing. +type MockSaturationDetector struct { + IsSaturatedFunc func(ctx context.Context) bool +} + +func (m *MockSaturationDetector) IsSaturated(ctx context.Context) bool { + if m.IsSaturatedFunc != nil { + return m.IsSaturatedFunc(ctx) + } + return false +} + +// MockManagedQueue is a high-fidelity, thread-safe mock of the `contracts.ManagedQueue` interface, designed +// specifically for testing the concurrent `controller/internal.ShardProcessor`. +// +// This mock is essential for creating deterministic and focused unit tests. It allows for precise control over queue +// behavior and enables the testing of critical edge cases (e.g., empty queues, dispatch failures) in complete +// isolation, which would be difficult and unreliable to achieve with the concrete `registry.managedQueue` +// implementation. +// +// ### Design Philosophy +// +// 1. **Stateful**: The mock maintains an internal map of items to accurately reflect a real queue's state. Its `Len()` +// and `ByteSize()` methods are derived directly from this state. +// 2. **Deadlock-Safe Overrides**: Test-specific logic (e.g., `AddFunc`) is executed instead of the default +// implementation. 
The override function is fully responsible for its own logic and synchronization, as the mock's +// internal mutex will *not* be held during its execution. +// 3. **Self-Wiring**: The `FlowQueueAccessor()` method returns the mock itself, ensuring the accessor is always +// correctly connected to the queue's state without manual wiring in tests. +type MockManagedQueue struct { + // FlowSpecV defines the flow specification for this mock queue. It should be set by the test. + FlowSpecV types.FlowSpecification + + // AddFunc allows a test to completely override the default Add behavior. + AddFunc func(item types.QueueItemAccessor) error + // RemoveFunc allows a test to completely override the default Remove behavior. + RemoveFunc func(handle types.QueueItemHandle) (types.QueueItemAccessor, error) + // CleanupFunc allows a test to completely override the default Cleanup behavior. + CleanupFunc func(predicate framework.PredicateFunc) ([]types.QueueItemAccessor, error) + // DrainFunc allows a test to completely override the default Drain behavior. + DrainFunc func() ([]types.QueueItemAccessor, error) + + // mu protects access to the internal `items` map. + mu sync.Mutex + initOnce sync.Once + items map[types.QueueItemHandle]types.QueueItemAccessor +} + +func (m *MockManagedQueue) init() { + m.initOnce.Do(func() { + m.items = make(map[types.QueueItemHandle]types.QueueItemAccessor) + }) +} + +// FlowQueueAccessor returns the mock itself, as it fully implements the `framework.FlowQueueAccessor` interface. +func (m *MockManagedQueue) FlowQueueAccessor() framework.FlowQueueAccessor { + return m +} + +// Add adds an item to the queue. +// It checks for a test override before locking. If no override is present, it executes the default stateful logic, +// which includes fulfilling the `SafeQueue.Add` contract. +func (m *MockManagedQueue) Add(item types.QueueItemAccessor) error { + // If an override is provided, it is responsible for the full contract, including setting the handle. + if m.AddFunc != nil { + return m.AddFunc(item) + } + + m.mu.Lock() + defer m.mu.Unlock() + m.init() + + // Fulfill the `SafeQueue.Add` contract: the queue is responsible for setting the handle. + if item.Handle() == nil { + item.SetHandle(&typesmocks.MockQueueItemHandle{}) + } + + m.items[item.Handle()] = item + return nil +} + +// Remove removes an item from the queue. It checks for a test override before locking. +func (m *MockManagedQueue) Remove(handle types.QueueItemHandle) (types.QueueItemAccessor, error) { + if m.RemoveFunc != nil { + return m.RemoveFunc(handle) + } + m.mu.Lock() + defer m.mu.Unlock() + m.init() + item, ok := m.items[handle] + if !ok { + return nil, fmt.Errorf("item with handle %v not found", handle) + } + delete(m.items, handle) + return item, nil +} + +// Cleanup removes items matching a predicate. It checks for a test override before locking. +func (m *MockManagedQueue) Cleanup(predicate framework.PredicateFunc) ([]types.QueueItemAccessor, error) { + if m.CleanupFunc != nil { + return m.CleanupFunc(predicate) + } + m.mu.Lock() + defer m.mu.Unlock() + m.init() + var removed []types.QueueItemAccessor + for handle, item := range m.items { + if predicate(item) { + removed = append(removed, item) + delete(m.items, handle) + } + } + return removed, nil +} + +// Drain removes all items from the queue. It checks for a test override before locking. 
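+// For example, a test could simulate a drain failure through a hypothetical override (sketch only):
+//
+//	q := &MockManagedQueue{}
+//	q.DrainFunc = func() ([]types.QueueItemAccessor, error) {
+//		return nil, errors.New("injected drain failure")
+//	}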
+func (m *MockManagedQueue) Drain() ([]types.QueueItemAccessor, error) { + if m.DrainFunc != nil { + return m.DrainFunc() + } + m.mu.Lock() + defer m.mu.Unlock() + m.init() + drained := make([]types.QueueItemAccessor, 0, len(m.items)) + for _, item := range m.items { + drained = append(drained, item) + } + m.items = make(map[types.QueueItemHandle]types.QueueItemAccessor) + return drained, nil +} + +func (m *MockManagedQueue) FlowSpec() types.FlowSpecification { return m.FlowSpecV } +func (m *MockManagedQueue) Name() string { return "" } +func (m *MockManagedQueue) Capabilities() []framework.QueueCapability { return nil } +func (m *MockManagedQueue) Comparator() framework.ItemComparator { return nil } + +// Len returns the actual number of items currently in the mock queue. +func (m *MockManagedQueue) Len() int { + m.mu.Lock() + defer m.mu.Unlock() + m.init() + return len(m.items) +} + +// ByteSize returns the actual total byte size of all items in the mock queue. +func (m *MockManagedQueue) ByteSize() uint64 { + m.mu.Lock() + defer m.mu.Unlock() + m.init() + var size uint64 + for _, item := range m.items { + size += item.OriginalRequest().ByteSize() + } + return size +} + +// PeekHead returns the first item found in the mock queue. Note: map iteration order is not guaranteed. +func (m *MockManagedQueue) PeekHead() (types.QueueItemAccessor, error) { + m.mu.Lock() + defer m.mu.Unlock() + m.init() + for _, item := range m.items { + return item, nil // Return first item found + } + return nil, nil // Queue is empty +} + +// PeekTail is not implemented for this mock. +func (m *MockManagedQueue) PeekTail() (types.QueueItemAccessor, error) { + return nil, nil +} diff --git a/pkg/epp/flowcontrol/contracts/saturationdetector.go b/pkg/epp/flowcontrol/contracts/saturationdetector.go new file mode 100644 index 000000000..91d2406c5 --- /dev/null +++ b/pkg/epp/flowcontrol/contracts/saturationdetector.go @@ -0,0 +1,39 @@ +/* +Copyright 2025 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package contracts + +import "context" + +// SaturationDetector defines the contract for a component that provides real-time load signals to the +// `controller.FlowController`. +// +// This interface abstracts away the complexity of determining system load. An implementation would consume various +// backend metrics (e.g., queue depths, KV cache utilization, observed latencies) and translate them into a simple +// boolean signal. +// +// This decoupling is important because it allows the saturation detection logic to evolve independently of the core +// `controller.FlowController` engine, which is only concerned with the final true/false signal. +// +// # Conformance +// +// Implementations MUST be goroutine-safe. +type SaturationDetector interface { + // IsSaturated returns true if the system's backend resources are considered saturated. + // `controller.FlowController`'s dispatch workers call this method to decide whether to pause or throttle dispatch + // operations to prevent overwhelming the backends. 
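+	//
+	// A minimal sketch of the expected call pattern in a dispatch worker (the `detector` name is illustrative):
+	//
+	//	if detector.IsSaturated(ctx) {
+	//		return // Pause dispatching for this cycle.
+	//	}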
+ IsSaturated(ctx context.Context) bool +} diff --git a/pkg/epp/flowcontrol/controller/doc.go b/pkg/epp/flowcontrol/controller/doc.go new file mode 100644 index 000000000..8c96bbc18 --- /dev/null +++ b/pkg/epp/flowcontrol/controller/doc.go @@ -0,0 +1,122 @@ +/* +Copyright 2025 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +// Package controller contains the implementation of the `FlowController` engine. +// +// # Overview +// +// The `FlowController` is the central processing engine of the flow control system. It is a sharded, high-throughput +// component responsible for managing the lifecycle of all incoming requests—from initial submission via the synchronous +// `EnqueueAndWait` method to a terminal outcome (dispatch, rejection, or eviction). It achieves this by orchestrating +// its dependencies—the `contracts.FlowRegistry`, the pluggable `Policy` framework, and the +// `contracts.SaturationDetector`—to make continuous, state-aware decisions. +// +// # Architecture: The Processor-Shard Relationship +// +// The `FlowController` engine is designed around a clear separation of state and execution. This "control plane vs. +// data plane" separation is key to enabling dynamic, concurrent-safe configuration updates. +// +// - The `contracts.FlowRegistry` is the **control plane**. It is the single source of truth for all configuration. +// When an administrative action occurs (e.g., `RegisterOrUpdateFlow`), the `contracts.FlowRegistry` is responsible +// for safely applying that change to each of its managed `contracts.RegistryShard` instances. +// +// - The `contracts.RegistryShard` is the **concurrent-safe state port**. It defines the contract for a state store +// that holds the `contracts.ManagedQueue` and framework `Policy` instances for a single shard. +// +// - The `internal.ShardProcessor` is the **data plane worker**. It is given a single `contracts.RegistryShard` to +// operate on. Its main `dispatchCycle` continuously acquires a read lock on the shard to get a consistent view of +// the active queues and policies, and then executes its dispatch logic. +// +// This separation is what enables dynamic updates. The `internal.ShardProcessor` is stateless; it simply executes +// against the state presented by its `contracts.RegistryShard` on each cycle. This allows the control plane +// (`contracts.FlowRegistry`) to safely change that state in the background. +// +// # Architectural Deep Dive: The `EnqueueAndWait` Model +// +// A fundamental design choice is the synchronous, blocking `EnqueueAndWait` method. In the context of the Gateway API +// Inference Extension's Endpoint Picker (EPP), which operates as an Envoy External Processing (`ext_proc`) server, this +// model is deliberately chosen for its simplicity and robustness. +// +// - Alignment with `ext_proc`: The `ext_proc` protocol is stream-based. A single goroutine within the EPP manages the +// stream for a given HTTP request. 
`EnqueueAndWait` fits this perfectly: the request-handling goroutine calls it, +// blocks, and upon return, has the definitive outcome. It can then immediately act on that outcome, maintaining +// clear request-goroutine affinity. +// +// - Simplified State Management: The state of a "waiting" request is implicitly managed by the blocked goroutine's +// stack and its `context.Context`. The `FlowController` only needs to signal this specific goroutine to unblock it. +// +// - Direct Backpressure: If queues are full, `EnqueueAndWait` returns `types.ErrQueueAtCapacity`. This provides +// immediate, direct backpressure to the earliest point of contact. +// +// # Architectural Deep Dive: The Sharded Model & JSQ-Bytes +// +// The `FlowController` is built on a sharded architecture to enable parallel processing and prevent a central dispatch +// loop from becoming a bottleneck. The `FlowController` consists of a top-level manager and a pool of independent +// `internal.ShardProcessor` workers. The `contracts.FlowRegistry` guarantees that every logical flow is represented by +// a distinct queue instance on every active shard. +// +// This architecture trades deterministic global state for high throughput and scalability. The key challenge, and the +// system's most critical assumption, revolves around ensuring this distributed model can still achieve global fairness +// objectives. +// +// ## The Critical Assumption: Homogeneity Within Flows +// +// The effectiveness of the sharded model hinges on a critical assumption: while the system as a whole manages a +// heterogeneous set of flows, the traffic *within a single logical flow* is assumed to be roughly homogeneous in its +// characteristics. A logical flow is intended to represent a single workload or tenant; therefore, the most +// unpredictable variables (affecting decode behavior) are expected to be statistically similar *within* that flow. +// +// ## The Hedge: Join the Shortest Queue by Bytes (JSQ-Bytes) +// +// To make this assumption as robust as possible, the `FlowController` uses a "Join the Shortest Queue by Bytes +// (JSQ-Bytes)" algorithm. `ByteSize` is an excellent proxy for the resources the `FlowController` explicitly manages +// (host memory pressure and queuing capacity) and is also a reasonable proxy for prefill compute time. +// +// Crucially, the goal of the distributor is not to perfectly predict backend compute time, but to intelligently balance +// the load at the controller level. JSQ-Bytes achieves this by: +// +// 1. Reflecting True Load: It distributes work based on each shard's current queue size in bytes—a direct measure of +// its memory and capacity congestion. +// +// 2. Adapting to Congestion: The byte-size of a queue is a real-time signal of a shard's overall congestion. If a +// shard is slow (e.g., due to long-decoding downstream requests), its queues will remain full, and JSQ-Bytes will +// adaptively steer new work away. +// +// 3. Hedging Against Assumption Violations: This adaptive, self-correcting nature makes it a powerful hedge. It +// doesn't just distribute; it actively *load balances* based on the most relevant feedback available. +// +// # Architectural Deep Dive: Pre-Policy Gating +// +// Before policies are invoked, the `internal.ShardProcessor` applies an `internal.BandFilter`. This function determines +// which flows within a priority band are eligible for a given operation (e.g., dispatch). 
This pattern is a deliberate +// architectural choice to decouple the logic of *viability* from the logic of *selection*. +// +// - An `internal.BandFilter` (e.g., the `internal.NewSaturationFilter`) determines if a flow is viable based on +// external signals like backend load. +// - The `framework.InterFlowDispatchPolicy` then selects from among the viable candidates based on its own fairness +// logic. +// +// This abstraction provides two major benefits: +// +// 1. Low Contributor Burden: It makes the mental model for policy contributors significantly simpler. An author of a +// new fairness policy does not need to be concerned with the complexities of saturation detection or other gating +// concerns. They are given a simple, pre-filtered view of the world and can focus solely on their selection logic. +// +// 2. Correctness by Construction: The `internal.subsetPriorityBandAccessor` wrapper guarantees that a policy operates +// on a consistent, filtered view, regardless of which accessor method it calls (`FlowIDs`, `Queue`, etc.). This +// prevents an entire class of subtle bugs where a policy might otherwise see a stale or unfiltered view of the +// system state. +package controller diff --git a/pkg/epp/flowcontrol/controller/internal/doc.go b/pkg/epp/flowcontrol/controller/internal/doc.go new file mode 100644 index 000000000..3f39b5791 --- /dev/null +++ b/pkg/epp/flowcontrol/controller/internal/doc.go @@ -0,0 +1,47 @@ +/* +Copyright 2025 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +// Package internal provides the core worker implementation for the `controller.FlowController`. +// +// The components in this package are the private, internal building blocks of the `controller` package. This separation +// enforces a clean public API at the `controller` level and allows the internal mechanics of the engine to evolve +// independently. +// +// # Design Philosophy: A Single-Writer Actor Model +// +// The concurrency model for this package is deliberately built around a single-writer, channel-based actor pattern, as +// implemented in the `ShardProcessor`. While a simple lock-based approach might seem easier, it is insufficient for the +// system's requirements. The "enqueue" operation is a complex, stateful transaction that requires a **hierarchical +// capacity check** against both the overall shard and a request's specific priority band. +// +// A coarse, shard-wide lock would be required to make this transaction atomic, creating a major performance bottleneck +// and causing head-of-line blocking at the top-level `controller.FlowController`. The single-writer model, where all +// state mutations are funneled through a single goroutine, makes this transaction atomic *without locks*. +// +// This design provides two critical benefits: +// 1. **Decoupling:** The `controller.FlowController` is decoupled via a non-blocking channel send, allowing for much +// higher throughput. +// 2. 
**Backpressure:** The state of the channel buffer serves as a high-fidelity, real-time backpressure signal, +// enabling more intelligent load balancing. +// +// # Future-Proofing for Complex Transactions +// +// This model's true power is that it provides a robust foundation for future features like **displacement** (a +// high-priority item evicting lower-priority ones). This is an "all-or-nothing" atomic transaction that is +// exceptionally difficult to implement correctly in a lock-free or coarse-grained locking model without significant +// performance penalties. The single-writer model contains the performance cost of such a potentially long transaction +// to the single `ShardProcessor`, preventing it from blocking the entire `controller.FlowController`. +package internal diff --git a/pkg/epp/flowcontrol/controller/internal/filter.go b/pkg/epp/flowcontrol/controller/internal/filter.go new file mode 100644 index 000000000..f7e731379 --- /dev/null +++ b/pkg/epp/flowcontrol/controller/internal/filter.go @@ -0,0 +1,143 @@ +/* +Copyright 2025 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package internal + +import ( + "context" + + "github.com/go-logr/logr" + + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/flowcontrol/contracts" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/flowcontrol/framework" + logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" +) + +// BandFilter is a function that acts as a pre-policy gate. It takes a complete view of a priority band and returns a +// subset of flow IDs that are considered viable candidates for a subsequent policy decision. It can also return a +// boolean signal to pause the entire operation for the band. +// +// This abstraction decouples the logic of determining *viability* (e.g., is a flow subject to backpressure?) from the +// logic of *selection* (e.g., which of the viable flows is the fairest to pick next?). This separation simplifies the +// mental model for policy authors, who can focus solely on selection logic without needing to account for external +// gating signals. +// +// Because filters are applied before the band is passed to a policy, they are inherently composable. Multiple filters +// can be chained to apply different viability criteria. For example, a future filter could be developed to temporarily +// exclude a "misbehaving" flow that is causing repeated errors, quarantining it from policy consideration. +// +// A nil `allowedFlows` map indicates that no filtering is necessary and all flows in the band are visible. +// This provides a zero-allocation fast path for the common case where no flows are being filtered. +type BandFilter func( + ctx context.Context, + band framework.PriorityBandAccessor, + logger logr.Logger, +) (allowedFlows map[string]struct{}, shouldPause bool) + +// NewSaturationFilter creates a `BandFilter` that uses the provided `contracts.SaturationDetector` to determine which +// flows are dispatchable. 
This is the standard filter used in the production `FlowController` for the dispatch +// operation. +func NewSaturationFilter(sd contracts.SaturationDetector) BandFilter { + return func( + ctx context.Context, + band framework.PriorityBandAccessor, + logger logr.Logger, + ) (map[string]struct{}, bool) { + // Phase 1: Implement the current global saturation check. + if sd.IsSaturated(ctx) { + logger.V(logutil.VERBOSE).Info("System saturated, pausing dispatch for this shard.") + return nil, true // Pause dispatching for all bands. + } + + // Phase 2 (Future): This is where per-flow saturation logic would go. + // It would iterate `band`, call `IsSaturated(ctx, flowID)`, and build a filtered map of allowed flows. + // For now, no per-flow filtering is done. Return nil to signal the fast path. + return nil, false // Do not pause, and do not filter any flows. + } +} + +// subsetPriorityBandAccessor provides a view of a priority band that is restricted to a specific subset of flows. +// It implements the `framework.PriorityBandAccessor` interface, ensuring that any policy operating on it will only +// see the allowed flows, regardless of which accessor method is used. This provides correctness by construction. +// +// For performance, it pre-computes a slice of the allowed flow IDs at creation time, making subsequent calls to +// `FlowIDs()` an O(1) operation with zero allocations. +type subsetPriorityBandAccessor struct { + originalAccessor framework.PriorityBandAccessor + allowedFlows map[string]struct{} + allowedFlowsSlice []string +} + +var _ framework.PriorityBandAccessor = &subsetPriorityBandAccessor{} + +// newSubsetPriorityBandAccessor creates a new filtered view of a priority band. +func newSubsetPriorityBandAccessor( + original framework.PriorityBandAccessor, + allowed map[string]struct{}, +) *subsetPriorityBandAccessor { + // Pre-compute the slice of flow IDs for performance. + ids := make([]string, 0, len(allowed)) + for id := range allowed { + ids = append(ids, id) + } + + return &subsetPriorityBandAccessor{ + originalAccessor: original, + allowedFlows: allowed, + allowedFlowsSlice: ids, + } +} + +// Priority returns the numerical priority level of this band. +func (s *subsetPriorityBandAccessor) Priority() uint { + return s.originalAccessor.Priority() +} + +// PriorityName returns the human-readable name of this priority band. +func (s *subsetPriorityBandAccessor) PriorityName() string { + return s.originalAccessor.PriorityName() +} + +// FlowIDs returns a slice of all flow IDs within this priority band that are in the allowed subset. +// This is an O(1) operation because the slice is pre-computed at creation. +func (s *subsetPriorityBandAccessor) FlowIDs() []string { + return s.allowedFlowsSlice +} + +// Queue returns a `framework.FlowQueueAccessor` for the specified `flowID` within this priority band, but only if it is +// in the allowed subset. This is an O(1) map lookup. If the flow is not in the allowed subset, it returns nil. +func (s *subsetPriorityBandAccessor) Queue(flowID string) framework.FlowQueueAccessor { + if _, ok := s.allowedFlows[flowID]; !ok { + return nil + } + return s.originalAccessor.Queue(flowID) +} + +// IterateQueues executes the given `callback` for each `framework.FlowQueueAccessor` in the allowed subset of this +// priority band. The iteration stops if the callback returns false. 
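+// For instance, a policy could sum the backlog across the visible flows (an illustrative sketch that assumes the
+// accessor exposes `Len()`):
+//
+//	var queued int
+//	band.IterateQueues(func(q framework.FlowQueueAccessor) bool {
+//		queued += q.Len()
+//		return true
+//	})
+//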
+// This implementation delegates to the original accessor's iterator and applies the filter, which is more robust and +// efficient than iterating over a pre-computed slice of IDs. +func (s *subsetPriorityBandAccessor) IterateQueues(callback func(queue framework.FlowQueueAccessor) bool) { + s.originalAccessor.IterateQueues(func(queue framework.FlowQueueAccessor) bool { + if _, ok := s.allowedFlows[queue.FlowSpec().ID]; ok { + // This queue is in the allowed set, so execute the callback. + if !callback(queue) { + return false // The callback requested to stop, so we stop the outer iteration too. + } + } + return true // Continue iterating over the original set. + }) +} diff --git a/pkg/epp/flowcontrol/controller/internal/filter_test.go b/pkg/epp/flowcontrol/controller/internal/filter_test.go new file mode 100644 index 000000000..f51581eaf --- /dev/null +++ b/pkg/epp/flowcontrol/controller/internal/filter_test.go @@ -0,0 +1,171 @@ +/* +Copyright 2025 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package internal + +import ( + "context" + "sort" + "testing" + + "github.com/go-logr/logr" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/flowcontrol/contracts/mocks" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/flowcontrol/framework" + frameworkmocks "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/flowcontrol/framework/mocks" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/flowcontrol/types" +) + +func TestNewSaturationFilter(t *testing.T) { + t.Parallel() + + testCases := []struct { + name string + isSaturated bool + expectShouldPause bool + expectAllowed map[string]struct{} + }{ + { + name: "should not pause or filter when system is not saturated", + isSaturated: false, + expectShouldPause: false, + expectAllowed: nil, // nil map signals the fast path + }, + { + name: "should pause when system is saturated", + isSaturated: true, + expectShouldPause: true, + expectAllowed: nil, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + // --- ARRANGE --- + mockSD := &mocks.MockSaturationDetector{IsSaturatedFunc: func(ctx context.Context) bool { return tc.isSaturated }} + filter := NewSaturationFilter(mockSD) + require.NotNil(t, filter, "NewSaturationFilter should not return nil") + + mockBand := &frameworkmocks.MockPriorityBandAccessor{} + + // --- ACT --- + allowed, shouldPause := filter(context.Background(), mockBand, logr.Discard()) + + // --- ASSERT --- + assert.Equal(t, tc.expectShouldPause, shouldPause, "The filter's pause signal should match the expected value") + + if tc.expectAllowed == nil { + assert.Nil(t, allowed, "Expected allowed map to be nil for the fast path") + } else { + assert.Equal(t, tc.expectAllowed, allowed, "The set of allowed flows should match the expected value") + } + }) + } +} + +func TestSubsetPriorityBandAccessor(t *testing.T) { + t.Parallel() + + // --- ARRANGE --- + // Setup a mock original accessor that 
knows about three flows. + mockQueueA := &frameworkmocks.MockFlowQueueAccessor{FlowSpecV: types.FlowSpecification{ID: "flow-a"}} + mockQueueB := &frameworkmocks.MockFlowQueueAccessor{FlowSpecV: types.FlowSpecification{ID: "flow-b"}} + mockQueueC := &frameworkmocks.MockFlowQueueAccessor{FlowSpecV: types.FlowSpecification{ID: "flow-c"}} + + originalAccessor := &frameworkmocks.MockPriorityBandAccessor{ + PriorityV: 10, + PriorityNameV: "High", + FlowIDsFunc: func() []string { + return []string{"flow-a", "flow-b", "flow-c"} + }, + QueueFunc: func(id string) framework.FlowQueueAccessor { + switch id { + case "flow-a": + return mockQueueA + case "flow-b": + return mockQueueB + case "flow-c": + return mockQueueC + } + return nil + }, + IterateQueuesFunc: func(callback func(queue framework.FlowQueueAccessor) bool) { + if !callback(mockQueueA) { + return + } + if !callback(mockQueueB) { + return + } + callback(mockQueueC) + }, + } + + // Create a subset view that only allows two of the flows. + allowedFlows := map[string]struct{}{ + "flow-a": {}, + "flow-c": {}, + } + subsetAccessor := newSubsetPriorityBandAccessor(originalAccessor, allowedFlows) + require.NotNil(t, subsetAccessor, "newSubsetPriorityBandAccessor should not return nil") + + t.Run("should pass through priority and name", func(t *testing.T) { + t.Parallel() + assert.Equal(t, uint(10), subsetAccessor.Priority(), "Priority() should pass through from the original accessor") + assert.Equal(t, "High", subsetAccessor.PriorityName(), + "PriorityName() should pass through from the original accessor") + }) + + t.Run("should only return allowed flow IDs", func(t *testing.T) { + t.Parallel() + flowIDs := subsetAccessor.FlowIDs() + // Sort for consistent comparison, as the pre-computed slice order is not guaranteed. + sort.Strings(flowIDs) + assert.Equal(t, []string{"flow-a", "flow-c"}, flowIDs, "FlowIDs() should only return the allowed subset") + }) + + t.Run("should only return queues for allowed flows", func(t *testing.T) { + t.Parallel() + assert.Same(t, mockQueueA, subsetAccessor.Queue("flow-a"), "Should return queue for allowed flow 'a'") + assert.Nil(t, subsetAccessor.Queue("flow-b"), "Should not return queue for disallowed flow 'b'") + assert.Same(t, mockQueueC, subsetAccessor.Queue("flow-c"), "Should return queue for allowed flow 'c'") + }) + + t.Run("should only iterate over allowed queues", func(t *testing.T) { + t.Parallel() + var iterated []string + subsetAccessor.IterateQueues(func(queue framework.FlowQueueAccessor) bool { + iterated = append(iterated, queue.FlowSpec().ID) + return true + }) + // Sort for consistent comparison, as iteration order is not guaranteed. + sort.Strings(iterated) + assert.Equal(t, []string{"flow-a", "flow-c"}, iterated, "IterateQueues() should only visit allowed flows") + }) + + t.Run("should stop iteration if callback returns false", func(t *testing.T) { + t.Parallel() + var iterated []string + subsetAccessor.IterateQueues(func(queue framework.FlowQueueAccessor) bool { + iterated = append(iterated, queue.FlowSpec().ID) + return false // Exit after the first item. + }) + assert.Len(t, iterated, 1, "Iteration should have stopped after one item") + }) +} diff --git a/pkg/epp/flowcontrol/controller/internal/item.go b/pkg/epp/flowcontrol/controller/internal/item.go new file mode 100644 index 000000000..86aeb8a0c --- /dev/null +++ b/pkg/epp/flowcontrol/controller/internal/item.go @@ -0,0 +1,157 @@ +/* +Copyright 2025 The Kubernetes Authors. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package internal + +import ( + "sync" + "sync/atomic" + "time" + + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/flowcontrol/types" +) + +// flowItem is the internal representation of a request managed by the `FlowController`. It implements the +// `types.QueueItemAccessor` interface, which is the primary view of the item used by queue and policy implementations. +// It wraps the original `types.FlowControlRequest` and adds metadata for queuing, lifecycle management, and policy +// interaction. +// +// # Concurrency +// +// The `finalize` method is the primary point of concurrency concern. It is designed to be atomic and idempotent through +// the use of `sync.Once`. This guarantees that an item's final outcome can be set exactly once, even if multiple +// goroutines (e.g., the main dispatch loop and the expiry cleanup loop) race to finalize it. All other fields are set +// at creation time and are not modified thereafter, making them safe for concurrent access. +type flowItem struct { + // enqueueTime is the timestamp when the item was logically accepted by the `FlowController`. + enqueueTime time.Time + // effectiveTTL is the actual time-to-live assigned to this item. + effectiveTTL time.Duration + // originalRequest is the underlying request object. + originalRequest types.FlowControlRequest + // handle is the unique identifier for this item within a specific queue instance. + handle types.QueueItemHandle + + // done is closed exactly once when the item is finalized (dispatched or evicted/rejected). + done chan struct{} + // err stores the final error state if the item was not successfully dispatched. + // It is written to exactly once, protected by `onceFinalize`. + err atomic.Value // Stores error + // outcome stores the final `types.QueueOutcome` of the item's lifecycle. + // It is written to exactly once, protected by `onceFinalize`. + outcome atomic.Value // Stores `types.QueueOutcome` + // onceFinalize ensures the `finalize()` logic is idempotent. + onceFinalize sync.Once +} + +// ensure flowItem implements the interface. +var _ types.QueueItemAccessor = &flowItem{} + +// NewItem creates a new `flowItem`, which is the internal representation of a request inside the `FlowController`. +// This constructor is exported so that the parent `controller` package can create items to be passed into the +// `internal` package's processors. It initializes the item with a "NotYetFinalized" outcome and an open `done` channel. +func NewItem(req types.FlowControlRequest, effectiveTTL time.Duration, enqueueTime time.Time) *flowItem { + fi := &flowItem{ + enqueueTime: enqueueTime, + effectiveTTL: effectiveTTL, + originalRequest: req, + done: make(chan struct{}), + } + // Initialize the outcome to its zero state. + fi.outcome.Store(types.QueueOutcomeNotYetFinalized) + return fi +} + +// EnqueueTime returns the time the item was logically accepted by the `FlowController` for queuing. This is used as the +// basis for TTL calculations. 
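+// An expiry check derived from these values might look like the following (illustrative; not necessarily the exact
+// cleanup logic):
+//
+//	expired := now.Sub(item.EnqueueTime()) > item.EffectiveTTL()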
+func (fi *flowItem) EnqueueTime() time.Time { return fi.enqueueTime } + +// EffectiveTTL returns the actual time-to-live assigned to this item by the `FlowController`. +func (fi *flowItem) EffectiveTTL() time.Duration { return fi.effectiveTTL } + +// OriginalRequest returns the original, underlying `types.FlowControlRequest` object. +func (fi *flowItem) OriginalRequest() types.FlowControlRequest { return fi.originalRequest } + +// Handle returns the `types.QueueItemHandle` that uniquely identifies this item within a specific queue instance. It +// returns nil if the item has not yet been added to a queue. +func (fi *flowItem) Handle() types.QueueItemHandle { return fi.handle } + +// SetHandle associates a `types.QueueItemHandle` with this item. This method is called by a `framework.SafeQueue` +// implementation immediately after the item is added to the queue. +func (fi *flowItem) SetHandle(handle types.QueueItemHandle) { fi.handle = handle } + +// Done returns a channel that is closed when the item has been finalized (e.g., dispatched or evicted). +// This is the primary mechanism for consumers to wait for an item's outcome. It is designed to be used in a `select` +// statement, allowing the caller to simultaneously wait for other events, such as context cancellation. +// +// # Example Usage +// +// select { +// case <-item.Done(): +// outcome, err := item.FinalState() +// // ... handle outcome +// case <-ctx.Done(): +// // ... handle cancellation +// } +func (fi *flowItem) Done() <-chan struct{} { + return fi.done +} + +// FinalState returns the terminal outcome and error for the item. +// +// CRITICAL: This method must only be called after the channel returned by `Done()` has been closed. Calling it before +// the item is finalized may result in a race condition where the final state has not yet been written. +func (fi *flowItem) FinalState() (types.QueueOutcome, error) { + outcomeVal := fi.outcome.Load() + errVal := fi.err.Load() + + var finalOutcome types.QueueOutcome + if oc, ok := outcomeVal.(types.QueueOutcome); ok { + finalOutcome = oc + } else { + // This case should not happen if finalize is always called correctly, but we default to a safe value. + finalOutcome = types.QueueOutcomeNotYetFinalized + } + + var finalErr error + if e, ok := errVal.(error); ok { + finalErr = e + } + return finalOutcome, finalErr +} + +// finalize sets the item's terminal state (`outcome`, `error`) and closes its `done` channel idempotently using +// `sync.Once`. This is the single, internal point where an item's lifecycle within the `FlowController` concludes. +func (fi *flowItem) finalize(outcome types.QueueOutcome, err error) { + fi.onceFinalize.Do(func() { + if err != nil { + fi.err.Store(err) + } + fi.outcome.Store(outcome) + close(fi.done) + }) +} + +// isFinalized checks if the item has been finalized without blocking. It is used internally by the `ShardProcessor` as +// a defensive check to avoid operating on items that have already been completed. +func (fi *flowItem) isFinalized() bool { + select { + case <-fi.done: + return true + default: + return false + } +} diff --git a/pkg/epp/flowcontrol/controller/internal/item_test.go b/pkg/epp/flowcontrol/controller/internal/item_test.go new file mode 100644 index 000000000..11e08ae56 --- /dev/null +++ b/pkg/epp/flowcontrol/controller/internal/item_test.go @@ -0,0 +1,51 @@ +/* +Copyright 2025 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package internal + +import ( + "context" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/flowcontrol/types" + typesmocks "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/flowcontrol/types/mocks" +) + +func TestItem(t *testing.T) { + t.Parallel() + + t.Run("should correctly set and get handle", func(t *testing.T) { + t.Parallel() + item := &flowItem{} + handle := &typesmocks.MockQueueItemHandle{} + item.SetHandle(handle) + assert.Same(t, handle, item.Handle(), "Handle() should retrieve the same handle instance set by SetHandle()") + }) + + t.Run("should have a non-finalized state upon creation", func(t *testing.T) { + t.Parallel() + req := typesmocks.NewMockFlowControlRequest(100, "req-1", "flow-a", context.Background()) + item := NewItem(req, time.Minute, time.Now()) + require.NotNil(t, item, "NewItem should not return nil") + outcome, err := item.FinalState() + assert.Equal(t, types.QueueOutcomeNotYetFinalized, outcome, "A new item's outcome should be NotYetFinalized") + assert.NoError(t, err, "A new item should have a nil error") + }) +} diff --git a/pkg/epp/flowcontrol/controller/internal/processor.go b/pkg/epp/flowcontrol/controller/internal/processor.go new file mode 100644 index 000000000..800ca640b --- /dev/null +++ b/pkg/epp/flowcontrol/controller/internal/processor.go @@ -0,0 +1,677 @@ +/* +Copyright 2025 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package internal + +import ( + "context" + "errors" + "fmt" + "runtime" + "sync" + "sync/atomic" + "time" + + "github.com/go-logr/logr" + + "sigs.k8s.io/controller-runtime/pkg/log" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/flowcontrol/contracts" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/flowcontrol/framework" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/flowcontrol/types" + logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" +) + +const ( + // enqueueChannelBufferSize sets the size of the buffered channel that accepts incoming requests for the shard + // processor. This buffer acts as a "shock absorber," decoupling the upstream distributor from the processor's main + // loop and allowing the system to handle short, intense bursts of traffic without blocking the distributor. + enqueueChannelBufferSize = 100 + + // maxCleanupWorkers caps the number of concurrent workers for background cleanup tasks. This prevents a single shard + // from overwhelming the Go scheduler with too many goroutines. 
+ maxCleanupWorkers = 4 +) + +var ( + // errInterFlow is a sentinel error for failures during the inter-flow dispatch phase (e.g., a + // `framework.InterFlowDispatchPolicy` fails to select a queue). + // + // Strategy: When this error is encountered, the dispatch cycle aborts processing for the current priority band and + // immediately moves to the next, promoting work conservation. A failure in one band should not halt progress in + // others. + errInterFlow = errors.New("inter-flow policy failure") + + // errIntraFlow is a sentinel error for failures *after* a specific flow's queue has been selected (e.g., a + // `framework.IntraFlowDispatchPolicy` fails or a queue `Remove` fails). + // + // Strategy: When this error is encountered, the dispatch cycle aborts processing for the entire priority band for the + // current cycle. This acts as a critical circuit breaker. A stateless inter-flow policy could otherwise repeatedly + // select the same problematic queue in a tight loop of failures. Halting the band for one cycle prevents this. + errIntraFlow = errors.New("intra-flow operation failure") +) + +// clock defines an interface for getting the current time, allowing for dependency injection in tests. +type clock interface { + Now() time.Time +} + +// ShardProcessor is the core worker of the `controller.FlowController`. It is paired one-to-one with a +// `contracts.RegistryShard` instance and is responsible for all request lifecycle operations on that shard, including +// enqueueing, dispatching, and expiry cleanup. It acts as the "data plane" worker that executes against the +// concurrent-safe state provided by its shard. +// +// For a full rationale on the single-writer concurrency model, see the package-level documentation in `doc.go`. +// +// # Concurrency Guarantees and Race Conditions +// +// This model provides two key guarantees: +// +// 1. **Safe Enqueueing**: The `Run` method's goroutine has exclusive ownership of all operations that *add* items to +// queues. This makes the "check-then-act" sequence in `enqueue` (calling `hasCapacity` then `managedQ.Add`) +// inherently atomic from a writer's perspective, preventing capacity breaches. While the background +// `runExpiryCleanup` goroutine can concurrently *remove* items, this is a benign race; a concurrent removal only +// creates more available capacity, ensuring the `hasCapacity` check remains valid. +// +// 2. **Idempotent Finalization**: The primary internal race is between the main `dispatchCycle` and the background +// `runExpiryCleanup` goroutine, which might try to finalize the same `flowItem` simultaneously. This race is +// resolved by the `flowItem.finalize` method, which uses `sync.Once` to guarantee that only one of these goroutines +// can set the item's final state. +type ShardProcessor struct { + shard contracts.RegistryShard + dispatchFilter BandFilter + clock clock + expiryCleanupInterval time.Duration + logger logr.Logger + + // enqueueChan is the entry point for new requests to be processed by this shard's `Run` loop. + enqueueChan chan *flowItem + // wg is used to wait for background tasks like expiry cleanup to complete on shutdown. + wg sync.WaitGroup + isShuttingDown atomic.Bool + shutdownOnce sync.Once +} + +// NewShardProcessor creates a new `ShardProcessor` instance. 
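+//
+// A sketch of how the parent controller might wire and start a processor (`detector`, `realClock`, and the interval
+// are illustrative placeholders, not the controller's actual configuration):
+//
+//	p := NewShardProcessor(shard, NewSaturationFilter(detector), realClock, time.Second, logger)
+//	go p.Run(ctx)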
+func NewShardProcessor( + shard contracts.RegistryShard, + dispatchFilter BandFilter, + clock clock, + expiryCleanupInterval time.Duration, + logger logr.Logger, +) *ShardProcessor { + return &ShardProcessor{ + shard: shard, + dispatchFilter: dispatchFilter, + clock: clock, + expiryCleanupInterval: expiryCleanupInterval, + logger: logger, + // A buffered channel decouples the processor from the distributor, allowing for a fast, asynchronous handoff of new + // requests. + enqueueChan: make(chan *flowItem, enqueueChannelBufferSize), + } +} + +// Run is the main operational loop for the shard processor. It must be run as a goroutine. +// +// # Loop Strategy: Interleaving Enqueue and Dispatch +// +// The loop uses a `select` statement to interleave two primary tasks: +// 1. Accepting new requests from the `enqueueChan`. +// 2. Attempting to dispatch existing requests from queues via `dispatchCycle`. +// +// This strategy is crucial for balancing responsiveness and throughput. When a new item arrives, it is immediately +// enqueued, and a dispatch cycle is triggered. This gives high-priority new arrivals a chance to be dispatched quickly. +// When no new items are arriving, the loop's `default` case continuously calls `dispatchCycle` to drain the existing +// backlog, ensuring work continues. +func (sp *ShardProcessor) Run(ctx context.Context) { + sp.logger.V(logutil.DEFAULT).Info("Shard processor run loop starting.") + defer sp.logger.V(logutil.DEFAULT).Info("Shard processor run loop stopped.") + + sp.wg.Add(1) + go sp.runExpiryCleanup(ctx) + + // This is the main worker loop. It continuously processes incoming requests and dispatches queued requests until the + // context is cancelled. The `select` statement has three cases: + // + // 1. Context Cancellation: The highest priority is shutting down. If the context's `Done` channel is closed, the + // loop will drain all queues and exit. This is the primary exit condition. + // + // 2. New Item Arrival: If an item is available on `enqueueChan`, it will be processed. This ensures that the + // processor is responsive to new work. + // + // 3. Default (Dispatch): If neither of the above cases is ready, the `default` case executes, ensuring the loop is + // non-blocking. It continuously attempts to dispatch items from the existing backlog, preventing starvation and + // ensuring queues are drained. + for { + select { + case <-ctx.Done(): + sp.shutdown() + sp.wg.Wait() + return + case item, ok := <-sp.enqueueChan: + if !ok { // Should not happen in practice, but is a clean shutdown signal. + sp.shutdown() + sp.wg.Wait() + return + } + // This is a safeguard against logic errors in the distributor. + if item == nil { + sp.logger.Error(nil, "Logic error: nil item received on shard processor enqueue channel, ignoring.") + continue + } + sp.enqueue(item) + sp.dispatchCycle(ctx) + default: + if !sp.dispatchCycle(ctx) { + // If no work was done, yield to other goroutines to prevent a tight, busy-loop when idle, but allow for + // immediate rescheduling. + runtime.Gosched() + } + } + } +} + +// Enqueue sends a new flow item to the processor's internal channel for asynchronous processing by its main `Run` loop. +// If the processor is shutting down, it immediately finalizes the item with a shutdown error. 
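+//
+// The expected caller pattern is an asynchronous handoff followed by waiting on the item's outcome (illustrative
+// sketch; a real caller would typically also select on its request context):
+//
+//	processor.Enqueue(item)
+//	<-item.Done()
+//	outcome, err := item.FinalState()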
+func (sp *ShardProcessor) Enqueue(item *flowItem) { + if sp.isShuttingDown.Load() { + item.finalize(types.QueueOutcomeRejectedOther, + fmt.Errorf("%w: %w", types.ErrRejected, types.ErrFlowControllerShutdown)) + return + } + sp.enqueueChan <- item +} + +// enqueue is the internal implementation for adding a new item to a managed queue. It is always run from the single +// main `Run` goroutine, making its "check-then-act" logic for capacity safe. +func (sp *ShardProcessor) enqueue(item *flowItem) { + req := item.OriginalRequest() + logger := log.FromContext(req.Context()).WithName("enqueue").WithValues( + "flowID", req.FlowID(), + "reqID", req.ID(), + "reqByteSize", req.ByteSize(), + ) + + managedQ, err := sp.shard.ActiveManagedQueue(req.FlowID()) + if err != nil { + // This is a significant configuration error; an active queue should exist for a valid flow. + finalErr := fmt.Errorf("configuration error: failed to get active queue for flow %q: %w", req.FlowID(), err) + logger.Error(finalErr, "Rejecting item.") + item.finalize(types.QueueOutcomeRejectedOther, fmt.Errorf("%w: %w", types.ErrRejected, finalErr)) + return + } + priority := managedQ.FlowQueueAccessor().FlowSpec().Priority + logger = logger.WithValues("priority", priority) + + band, err := sp.shard.PriorityBandAccessor(priority) + if err != nil { + finalErr := fmt.Errorf("configuration error: failed to get priority band for priority %d: %w", priority, err) + logger.Error(finalErr, "Rejecting item.") + item.finalize(types.QueueOutcomeRejectedOther, fmt.Errorf("%w: %w", types.ErrRejected, finalErr)) + return + } + logger = logger.WithValues("priorityName", band.PriorityName()) + + if !sp.hasCapacity(priority, req.ByteSize()) { + // This is an expected outcome, not a system error. Log at the default level with rich context. + stats := sp.shard.Stats() + bandStats := stats.PerPriorityBandStats[priority] + logger.V(logutil.DEFAULT).Info("Rejecting request, queue at capacity", + "outcome", types.QueueOutcomeRejectedCapacity, + "shardTotalBytes", stats.TotalByteSize, + "shardCapacityBytes", stats.TotalCapacityBytes, + "bandTotalBytes", bandStats.ByteSize, + "bandCapacityBytes", bandStats.CapacityBytes, + ) + item.finalize(types.QueueOutcomeRejectedCapacity, fmt.Errorf("%w: %w", types.ErrRejected, types.ErrQueueAtCapacity)) + return + } + + // This is an optimistic check to prevent a needless add/remove cycle for an item that was finalized (e.g., context + // cancelled) during the handoff to this processor. A race condition still exists where an item can be finalized + // after this check but before the `Add` call completes. + // + // This is considered acceptable because: + // 1. The race window is extremely small. + // 2. The background `runExpiryCleanup` goroutine acts as the ultimate guarantor of correctness, as it will + // eventually find and evict any finalized item that slips through this check and is added to a queue. + if item.isFinalized() { + outcome, err := item.FinalState() + logger.V(logutil.VERBOSE).Info("Item finalized before adding to queue, ignoring.", + "outcome", outcome, "err", err) + return + } + + // This is the point of commitment. After this call, the item is officially in the queue and is the responsibility of + // the dispatch or cleanup loops to finalize. 
+ if err := managedQ.Add(item); err != nil { + finalErr := fmt.Errorf("failed to add item to queue for flow %q: %w", req.FlowID(), err) + logger.Error(finalErr, "Rejecting item.") + item.finalize(types.QueueOutcomeRejectedOther, fmt.Errorf("%w: %w", types.ErrRejected, finalErr)) + return + } + logger.V(logutil.TRACE).Info("Item enqueued.") +} + +// hasCapacity checks if the shard and the specific priority band have enough capacity to accommodate an item of a given +// size. +func (sp *ShardProcessor) hasCapacity(priority uint, itemByteSize uint64) bool { + if itemByteSize == 0 { + return true + } + stats := sp.shard.Stats() + if stats.TotalCapacityBytes > 0 && stats.TotalByteSize+itemByteSize > stats.TotalCapacityBytes { + return false + } + bandStats, ok := stats.PerPriorityBandStats[priority] + if !ok { + // This should not happen if the registry is consistent, but we fail closed just in case. + return false + } + return bandStats.ByteSize+itemByteSize <= bandStats.CapacityBytes +} + +// dispatchCycle attempts to dispatch a single item by iterating through all priority bands from highest to lowest. +// It applies the configured policies for each band to select an item and then attempts to dispatch it. +// It returns true if an item was successfully dispatched, and false otherwise. +// +// # Error Handling Philosophy +// +// The engine employs a robust, two-tiered error handling strategy to isolate failures and maximize system availability. +// This is managed via the `errInterFlow` and `errIntraFlow` sentinel errors. +// +// - Inter-Flow Failures: If a failure occurs while selecting a flow (e.g., the `InterFlowDispatchPolicy` fails), the +// processor aborts the *current priority band* and immediately moves to the next one. This promotes work +// conservation, ensuring a single misconfigured band does not halt progress for the entire system. +// +// - Intra-Flow Failures: If a failure occurs *after* a flow has been selected (e.g., the `IntraFlowDispatchPolicy` +// fails), the processor aborts the *entire priority band* for the current cycle. This is a critical circuit +// breaker. An inter-flow policy that is not stateful with respect to past failures could otherwise repeatedly +// select the same problematic queue, causing a tight loop of failures. Halting the band for one cycle prevents +// this. +func (sp *ShardProcessor) dispatchCycle(ctx context.Context) bool { + baseLogger := sp.logger.WithName("dispatchCycle") + + // FUTURE EXTENSION POINT: The iteration over priority bands is currently a simple, strict-priority loop. + // This could be abstracted into a third policy tier (e.g., an `InterBandDispatchPolicy`) if more complex scheduling + // between bands, such as Weighted Fair Queuing (WFQ), is ever required. For now, strict priority is sufficient. + for _, priority := range sp.shard.AllOrderedPriorityLevels() { + band, err := sp.shard.PriorityBandAccessor(priority) + if err != nil { + baseLogger.Error(err, "Failed to get PriorityBandAccessor, skipping band", "priority", priority) + continue + } + logger := baseLogger.WithValues("priority", priority, "priorityName", band.PriorityName()) + + // Apply the configured filter to get a view of only the dispatchable flows. + allowedFlows, shouldPause := sp.dispatchFilter(ctx, band, logger) + if shouldPause { + return false // A global gate told us to stop the entire cycle. + } + + dispatchableBand := band + if allowedFlows != nil { + // An explicit subset of flows is allowed; create a filtered view. 
+ dispatchableBand = newSubsetPriorityBandAccessor(band, allowedFlows) + } + + // Pass the (potentially filtered) band to the policies. + item, dispatchPriority, err := sp.selectItem(dispatchableBand, logger) + if err != nil { + // The error handling strategy depends on the type of failure (inter- vs. intra-flow). + if errors.Is(err, errIntraFlow) { + logger.Error(err, "Intra-flow policy failure, skipping priority band for this cycle") + } else { + logger.Error(err, "Inter-flow policy or configuration failure, skipping priority band for this cycle") + } + continue + } + if item == nil { + // This is the common case where a priority band has no items to dispatch. + logger.V(logutil.TRACE).Info("No item selected by dispatch policies, skipping band") + continue + } + logger = logger.WithValues("flowID", item.OriginalRequest().FlowID(), "reqID", item.OriginalRequest().ID()) + + if err := sp.dispatchItem(item, dispatchPriority, logger); err != nil { + // All errors from dispatchItem are considered intra-flow and unrecoverable for this band in this cycle. + logger.Error(err, "Failed to dispatch item, skipping priority band for this cycle") + continue + } + // A successful dispatch occurred, so we return true to signal that work was done. + return true + } + // No items were dispatched in this cycle across all priority bands. + return false +} + +// selectItem applies the configured inter- and intra-flow dispatch policies to select a single item from a priority +// band. +func (sp *ShardProcessor) selectItem( + band framework.PriorityBandAccessor, + logger logr.Logger, +) (types.QueueItemAccessor, uint, error) { + interP, err := sp.shard.InterFlowDispatchPolicy(band.Priority()) + if err != nil { + return nil, 0, fmt.Errorf("%w: could not get InterFlowDispatchPolicy: %w", errInterFlow, err) + } + queue, err := interP.SelectQueue(band) + if err != nil { + return nil, 0, fmt.Errorf("%w: InterFlowDispatchPolicy %q failed to select queue: %w", + errInterFlow, interP.Name(), err) + } + if queue == nil { + logger.V(logutil.TRACE).Info("No queue selected by InterFlowDispatchPolicy") + return nil, 0, nil + } + logger = logger.WithValues("selectedFlowID", queue.FlowSpec().ID) + + priority := queue.FlowSpec().Priority + intraP, err := sp.shard.IntraFlowDispatchPolicy(queue.FlowSpec().ID, priority) + if err != nil { + // This is an intra-flow failure because we have already successfully selected a queue. + return nil, 0, fmt.Errorf("%w: could not get IntraFlowDispatchPolicy for flow %q: %w", + errIntraFlow, queue.FlowSpec().ID, err) + } + item, err := intraP.SelectItem(queue) + if err != nil { + return nil, 0, fmt.Errorf("%w: IntraFlowDispatchPolicy %q failed to select item for flow %q: %w", + errIntraFlow, intraP.Name(), queue.FlowSpec().ID, err) + } + if item == nil { + logger.V(logutil.TRACE).Info("No item selected by IntraFlowDispatchPolicy") + return nil, 0, nil + } + return item, priority, nil +} + +// dispatchItem handles the final steps of dispatching an item after it has been selected by policies. This includes +// removing it from its queue, checking for last-minute expiry, and finalizing its outcome. +func (sp *ShardProcessor) dispatchItem(itemAcc types.QueueItemAccessor, priority uint, logger logr.Logger) error { + logger = logger.WithName("dispatchItem") + + req := itemAcc.OriginalRequest() + // We must look up the queue by its specific priority, as a flow might have draining queues at other levels. 
+ managedQ, err := sp.shard.ManagedQueue(req.FlowID(), priority) + if err != nil { + return fmt.Errorf("%w: failed to get ManagedQueue for flow %q at priority %d: %w", + errIntraFlow, req.FlowID(), priority, err) + } + + // The core mutation: remove the item from the queue. + removedItemAcc, err := managedQ.Remove(itemAcc.Handle()) + if err != nil { + // This can happen benignly if the item was already removed by the expiry cleanup loop between the time it was + // selected by the policy and the time this function is called. + logger.V(logutil.VERBOSE).Info("Item already removed from queue, likely by expiry cleanup", "err", err) + return fmt.Errorf("%w: failed to remove item %q from queue for flow %q: %w", + errIntraFlow, req.ID(), req.FlowID(), err) + } + + removedItem, ok := removedItemAcc.(*flowItem) + if !ok { + // This indicates a severe logic error where a queue returns an item of an unexpected type. This violates a + // core system invariant: all items managed by the processor must be of type *flowItem. This is an unrecoverable + // state for this shard. + unexpectedItemErr := fmt.Errorf("%w: internal error: item %q of type %T is not a *flowItem", + errIntraFlow, removedItemAcc.OriginalRequest().ID(), removedItemAcc) + panic(unexpectedItemErr) + } + + // Final check for expiry/cancellation right before dispatch. + isExpired, outcome, expiryErr := checkItemExpiry(removedItem, sp.clock.Now()) + if isExpired { + // Ensure we always have a non-nil error to wrap for consistent logging and error handling. + finalErr := expiryErr + if finalErr == nil { + finalErr = errors.New("item finalized before dispatch") + } + logger.V(logutil.VERBOSE).Info("Item expired at time of dispatch, evicting", "outcome", outcome, + "err", finalErr) + removedItem.finalize(outcome, fmt.Errorf("%w: %w", types.ErrEvicted, finalErr)) + // Return an error to signal that the dispatch did not succeed. + return fmt.Errorf("%w: item %q expired before dispatch: %w", errIntraFlow, req.ID(), finalErr) + } + + // Finalize the item as dispatched. + removedItem.finalize(types.QueueOutcomeDispatched, nil) + logger.V(logutil.TRACE).Info("Item dispatched.") + return nil +} + +// checkItemExpiry checks if an item has been cancelled (via its context) or has exceeded its TTL. It returns true if +// the item is expired, along with the corresponding outcome and error. +// +// This function provides "defense in depth" against race conditions. It is the authoritative check that is called from +// multiple locations (the dispatch loop and the cleanup loop) to determine if an item should be evicted. Its first +// action is to check if the item has *already* been finalized by a competing goroutine, ensuring that the final outcome +// is decided exactly once. +func checkItemExpiry( + itemAcc types.QueueItemAccessor, + now time.Time, +) (isExpired bool, outcome types.QueueOutcome, err error) { + item, ok := itemAcc.(*flowItem) + if !ok { + // This indicates a severe logic error where a queue returns an item of an unexpected type. This violates a + // core system invariant: all items managed by the processor must be of type *flowItem. This is an unrecoverable + // state for this shard. + unexpectedItemErr := fmt.Errorf("internal error: item %q of type %T is not a *flowItem", + itemAcc.OriginalRequest().ID(), itemAcc) + panic(unexpectedItemErr) + } + + // This check is a critical defense against race conditions. If another goroutine (e.g., the cleanup loop) has + // already finalized this item, we must respect that outcome. 
+ if item.isFinalized() { + outcome, err := item.FinalState() + return true, outcome, err + } + + // Check if the request's context has been cancelled. + if ctxErr := item.OriginalRequest().Context().Err(); ctxErr != nil { + return true, types.QueueOutcomeEvictedContextCancelled, fmt.Errorf("%w: %w", types.ErrContextCancelled, ctxErr) + } + + // Check if the item has outlived its TTL. + if item.EffectiveTTL() > 0 && now.Sub(item.EnqueueTime()) > item.EffectiveTTL() { + return true, types.QueueOutcomeEvictedTTL, types.ErrTTLExpired + } + + return false, types.QueueOutcomeNotYetFinalized, nil +} + +// runExpiryCleanup starts a background goroutine that periodically scans all queues on the shard for expired items. +func (sp *ShardProcessor) runExpiryCleanup(ctx context.Context) { + defer sp.wg.Done() + logger := sp.logger.WithName("runExpiryCleanup") + logger.V(logutil.DEFAULT).Info("Shard expiry cleanup goroutine starting.") + defer logger.V(logutil.DEFAULT).Info("Shard expiry cleanup goroutine stopped.") + + ticker := time.NewTicker(sp.expiryCleanupInterval) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return + case now := <-ticker.C: + sp.cleanupExpired(now) + } + } +} + +// cleanupExpired performs a single scan of all queues on the shard, removing and finalizing any items that have +// expired due to TTL or context cancellation. +func (sp *ShardProcessor) cleanupExpired(now time.Time) { + processFn := func(managedQ contracts.ManagedQueue, queueLogger logr.Logger) { + // This predicate identifies items to be removed by the Cleanup call. + predicate := func(item types.QueueItemAccessor) bool { + isExpired, _, _ := checkItemExpiry(item, now) + return isExpired + } + + removedItems, err := managedQ.Cleanup(predicate) + if err != nil { + queueLogger.Error(err, "Error during ManagedQueue Cleanup") + } + + // Finalize all the items that were removed. + sp.finalizeExpiredItems(removedItems, now, queueLogger) + } + sp.processAllQueuesConcurrently("cleanupExpired", processFn) +} + +// shutdown handles the graceful termination of the processor. It uses sync.Once to guarantee that the shutdown logic is +// executed exactly once, regardless of whether it's triggered by context cancellation or the closing of the enqueue +// channel. +func (sp *ShardProcessor) shutdown() { + sp.shutdownOnce.Do(func() { + // Set the atomic bool so that any new calls to Enqueue will fail fast. + sp.isShuttingDown.Store(true) + sp.logger.V(logutil.DEFAULT).Info("Shard processor shutting down.") + + // Drain the channel BEFORE closing it. This prevents a panic from any goroutine that is currently blocked trying to + // send to the channel. We read until it's empty. + DrainLoop: + for { + select { + case item := <-sp.enqueueChan: + item.finalize(types.QueueOutcomeRejectedOther, + fmt.Errorf("%w: %w", types.ErrRejected, types.ErrFlowControllerShutdown)) + default: + // The channel is empty, we can now safely close it. + break DrainLoop + } + } + close(sp.enqueueChan) + + // Evict all remaining items from the queues. + sp.evictAll() + }) +} + +// evictAll drains all queues on the shard and finalizes every item with a shutdown error. This is called when the +// processor is shutting down to ensure no requests are left in a pending state. 
+func (sp *ShardProcessor) evictAll() { + processFn := func(managedQ contracts.ManagedQueue, queueLogger logr.Logger) { + removedItems, err := managedQ.Drain() + if err != nil { + queueLogger.Error(err, "Error during ManagedQueue Drain") + } + + // Finalize all the items that were removed. + getOutcome := func(_ types.QueueItemAccessor) (types.QueueOutcome, error) { + return types.QueueOutcomeEvictedOther, fmt.Errorf("%w: %w", types.ErrEvicted, types.ErrFlowControllerShutdown) + } + sp.finalizeItems(removedItems, queueLogger, getOutcome) + } + sp.processAllQueuesConcurrently("evictAll", processFn) +} + +// processAllQueuesConcurrently iterates over all queues in all priority bands on the shard and executes the given +// `processFn` for each queue using a dynamically sized worker pool. +func (sp *ShardProcessor) processAllQueuesConcurrently( + ctxName string, + processFn func(mq contracts.ManagedQueue, logger logr.Logger), +) { + logger := sp.logger.WithName(ctxName) + + // Phase 1: Collect all queues to be processed into a single slice. + // This avoids holding locks on the shard while processing, and allows us to determine the optimal number of workers. + var queuesToProcess []framework.FlowQueueAccessor + for _, priority := range sp.shard.AllOrderedPriorityLevels() { + band, err := sp.shard.PriorityBandAccessor(priority) + if err != nil { + logger.Error(err, "Failed to get PriorityBandAccessor", "priority", priority) + continue + } + band.IterateQueues(func(queue framework.FlowQueueAccessor) bool { + queuesToProcess = append(queuesToProcess, queue) + return true // Continue iterating. + }) + } + + if len(queuesToProcess) == 0 { + return + } + + // Phase 2: Determine the optimal number of workers. + // We cap the number of workers to a reasonable fixed number to avoid overwhelming the scheduler when many shards are + // running. We also don't need more workers than there are queues. + numWorkers := min(maxCleanupWorkers, len(queuesToProcess)) + + // Phase 3: Create a worker pool to process the queues. + tasks := make(chan framework.FlowQueueAccessor) + + var wg sync.WaitGroup + for range numWorkers { + wg.Add(1) + go func() { + defer wg.Done() + for q := range tasks { + queueLogger := logger.WithValues("flowID", q.FlowSpec().ID, "priority", q.FlowSpec().Priority) + managedQ, err := sp.shard.ManagedQueue(q.FlowSpec().ID, q.FlowSpec().Priority) + if err != nil { + queueLogger.Error(err, "Failed to get ManagedQueue") + continue + } + processFn(managedQ, queueLogger) + } + }() + } + + // Feed the channel with all the queues to be processed. + for _, q := range queuesToProcess { + tasks <- q + } + close(tasks) // Close the channel to signal workers to exit. + wg.Wait() // Wait for all workers to finish. +} + +// finalizeItems is a helper to iterate over a slice of items, safely cast them, and finalize them with an outcome +// determined by the `getOutcome` function. 
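+//
+// The `getOutcome` callback decouples the iteration from the finalization reason. A sketch of a caller-supplied
+// closure, mirroring the shutdown path in `evictAll` above:
+//
+//	getOutcome := func(_ types.QueueItemAccessor) (types.QueueOutcome, error) {
+//		return types.QueueOutcomeEvictedOther, fmt.Errorf("%w: %w", types.ErrEvicted, types.ErrFlowControllerShutdown)
+//	}
+//	sp.finalizeItems(removedItems, logger, getOutcome)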
+func (sp *ShardProcessor) finalizeItems( + items []types.QueueItemAccessor, + logger logr.Logger, + getOutcome func(item types.QueueItemAccessor) (types.QueueOutcome, error), +) { + for _, i := range items { + item, ok := i.(*flowItem) + if !ok { + unexpectedItemErr := fmt.Errorf("internal error: item %q of type %T is not a *flowItem", + i.OriginalRequest().ID(), i) + logger.Error(unexpectedItemErr, "Panic condition detected during finalization", "item", i) + continue + } + + outcome, err := getOutcome(i) + item.finalize(outcome, err) + logger.V(logutil.TRACE).Info("Item finalized", "reqID", item.OriginalRequest().ID(), + "outcome", outcome, "err", err) + } +} + +// finalizeExpiredItems is a specialized version of finalizeItems for items that are known to be expired. It determines +// the precise reason for expiry and finalizes the item accordingly. +func (sp *ShardProcessor) finalizeExpiredItems(items []types.QueueItemAccessor, now time.Time, logger logr.Logger) { + getOutcome := func(item types.QueueItemAccessor) (types.QueueOutcome, error) { + // We don't need the `isExpired` boolean here because we know it's true, but this function conveniently returns the + // precise outcome and error. + _, outcome, expiryErr := checkItemExpiry(item, now) + return outcome, fmt.Errorf("%w: %w", types.ErrEvicted, expiryErr) + } + sp.finalizeItems(items, logger, getOutcome) +} diff --git a/pkg/epp/flowcontrol/controller/internal/processor_test.go b/pkg/epp/flowcontrol/controller/internal/processor_test.go new file mode 100644 index 000000000..aeb46b636 --- /dev/null +++ b/pkg/epp/flowcontrol/controller/internal/processor_test.go @@ -0,0 +1,1390 @@ +/* +Copyright 2025 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +// +// A Note on the Testing Strategy for `ShardProcessor` +// +// The `ShardProcessor` is a complex concurrent orchestrator. Testing it with concrete implementations would lead to +// flaky, non-deterministic tests. Therefore, we use a high-fidelity `testHarness` with stateful mocks to enable +// reliable and deterministic testing. This is a deliberate and necessary choice for several key reasons: +// +// 1. Deterministic Race Simulation: The harness allows us to pause mock execution at critical moments, making it +// possible to deterministically simulate and verify the processor's behavior during race conditions (e.g., the +// dispatch vs. expiry race). This is impossible with concrete implementations without resorting to unreliable +// sleeps. +// +// 2. Failure Mode Simulation: We can trigger specific, on-demand errors from dependencies to verify the processor's +// resilience and complex error-handling logic (e.g., the `errIntraFlow` circuit breaker). +// +// 3. Interaction and Isolation Testing: Mocks allow us to isolate the `ShardProcessor` from its dependencies. This +// ensures that tests are verifying the processor's orchestration logic (i.e., that it calls its dependencies +// correctly) and are not affected by confounding bugs in those dependencies. 
+// +// In summary, this strategy is a prerequisite for reliably testing a concurrent engine, not just a simple data +// structure. +// + +package internal + +import ( + "context" + "errors" + "fmt" + "os" + "slices" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/go-logr/logr" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "sigs.k8s.io/controller-runtime/pkg/log" + "sigs.k8s.io/controller-runtime/pkg/log/zap" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/flowcontrol/contracts" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/flowcontrol/contracts/mocks" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/flowcontrol/framework" + frameworkmocks "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/flowcontrol/framework/mocks" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/flowcontrol/types" + typesmocks "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/flowcontrol/types/mocks" +) + +const ( + testTTL = 1 * time.Minute + testShortTTL = 20 * time.Millisecond + testCleanupTick = 10 * time.Millisecond + testWaitTimeout = 1 * time.Second +) + +var testFlow = types.FlowSpecification{ID: "flow-a", Priority: 10} + +// TestMain sets up the logger for all tests in the package. +func TestMain(m *testing.M) { + log.SetLogger(zap.New(zap.WriteTo(os.Stderr), zap.UseDevMode(true))) + os.Exit(m.Run()) +} + +// mockClock allows for controlling time in tests. +type mockClock struct { + mu sync.Mutex + currentTime time.Time +} + +func newMockClock() *mockClock { + return &mockClock{currentTime: time.Now()} +} + +func (c *mockClock) Now() time.Time { + c.mu.Lock() + defer c.mu.Unlock() + return c.currentTime +} + +func (c *mockClock) Advance(d time.Duration) { + c.mu.Lock() + defer c.mu.Unlock() + c.currentTime = c.currentTime.Add(d) +} + +// testHarness provides a unified, mock-based testing environment for the ShardProcessor. It centralizes all mock state +// and provides helper methods for setting up tests and managing the processor's lifecycle. +type testHarness struct { + t *testing.T + *mocks.MockRegistryShard + + // Concurrency and Lifecycle + ctx context.Context + cancel context.CancelFunc + wg sync.WaitGroup + startSignal chan struct{} + + // Core components under test + processor *ShardProcessor + mockClock *mockClock + logger logr.Logger + + // --- Centralized Mock State --- + // The harness's mutex protects the single source of truth for all mock state. + mu sync.Mutex + queues map[string]*mocks.MockManagedQueue // Key: `flowID` + priorityFlows map[uint][]string // Key: `priority`, Val: slice of `flowIDs` + + // Customizable policy logic for tests to override. + interFlowPolicySelectQueue func(band framework.PriorityBandAccessor) (framework.FlowQueueAccessor, error) + intraFlowPolicySelectItem func(fqa framework.FlowQueueAccessor) (types.QueueItemAccessor, error) +} + +// newTestHarness creates and wires up a complete testing harness. +func newTestHarness(t *testing.T, expiryCleanupInterval time.Duration) *testHarness { + t.Helper() + h := &testHarness{ + t: t, + MockRegistryShard: &mocks.MockRegistryShard{}, + mockClock: newMockClock(), + logger: logr.Discard(), + startSignal: make(chan struct{}), + queues: make(map[string]*mocks.MockManagedQueue), + priorityFlows: make(map[uint][]string), + } + + // Wire up the harness to provide the mock implementations for the shard's dependencies. 
+ h.ActiveManagedQueueFunc = h.activeManagedQueue + h.ManagedQueueFunc = h.managedQueue + h.AllOrderedPriorityLevelsFunc = h.allOrderedPriorityLevels + h.PriorityBandAccessorFunc = h.priorityBandAccessor + h.InterFlowDispatchPolicyFunc = h.interFlowDispatchPolicy + h.IntraFlowDispatchPolicyFunc = h.intraFlowDispatchPolicy + + // Provide a default stats implementation that is effectively infinite. + h.StatsFunc = func() contracts.ShardStats { + return contracts.ShardStats{ + TotalCapacityBytes: 1e9, + PerPriorityBandStats: map[uint]contracts.PriorityBandStats{ + testFlow.Priority: {CapacityBytes: 1e9}, + }, + } + } + + // Use a default pass-through filter. + filter := func( + ctx context.Context, + band framework.PriorityBandAccessor, + logger logr.Logger, + ) (map[string]struct{}, bool) { + return nil, false + } + h.processor = NewShardProcessor(h, filter, h.mockClock, expiryCleanupInterval, h.logger) + require.NotNil(t, h.processor, "NewShardProcessor should not return nil") + return h +} + +// --- Test Lifecycle and Helpers --- + +// Start prepares the processor to run in a background goroutine but pauses it until Go() is called. +func (h *testHarness) Start() { + h.t.Helper() + h.ctx, h.cancel = context.WithCancel(context.Background()) + h.wg.Add(1) + go func() { + defer h.wg.Done() + <-h.startSignal // Wait for the signal to begin execution. + h.processor.Run(h.ctx) + }() +} + +// Go unpauses the processor's main Run loop. +func (h *testHarness) Go() { + h.t.Helper() + close(h.startSignal) +} + +// Stop gracefully shuts down the processor and waits for it to terminate. +func (h *testHarness) Stop() { + h.t.Helper() + if h.cancel != nil { + h.cancel() + } + h.wg.Wait() +} + +// waitForFinalization blocks until an item is finalized or a timeout is reached. +func (h *testHarness) waitForFinalization(item *flowItem) (types.QueueOutcome, error) { + h.t.Helper() + select { + case <-item.Done(): + return item.FinalState() + case <-time.After(testWaitTimeout): + h.t.Fatalf("Timed out waiting for item %q to be finalized", item.OriginalRequest().ID()) + return types.QueueOutcomeNotYetFinalized, nil + } +} + +// newTestItem creates a new flowItem for testing purposes. +func (h *testHarness) newTestItem(id, flowID string, ttl time.Duration) *flowItem { + h.t.Helper() + ctx := log.IntoContext(context.Background(), h.logger) + req := typesmocks.NewMockFlowControlRequest(100, id, flowID, ctx) + return NewItem(req, ttl, h.mockClock.Now()) +} + +// addQueue centrally registers a new mock queue for a given flow, ensuring all harness components are aware of it. +func (h *testHarness) addQueue(spec types.FlowSpecification) *mocks.MockManagedQueue { + h.t.Helper() + h.mu.Lock() + defer h.mu.Unlock() + + mockQueue := &mocks.MockManagedQueue{FlowSpecV: spec} + h.queues[spec.ID] = mockQueue + + // Add the `flowID` to the correct priority band, creating the band if needed. + h.priorityFlows[spec.Priority] = append(h.priorityFlows[spec.Priority], spec.ID) + + return mockQueue +} + +// --- Mock Interface Implementations --- + +// activeManagedQueue provides the mock implementation for the `RegistryShard` interface. +func (h *testHarness) activeManagedQueue(flowID string) (contracts.ManagedQueue, error) { + h.mu.Lock() + defer h.mu.Unlock() + if q, ok := h.queues[flowID]; ok { + return q, nil + } + return nil, fmt.Errorf("test setup error: no active queue for flow %q", flowID) +} + +// managedQueue provides the mock implementation for the `RegistryShard` interface. 
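+// It returns a queue only when both the flow ID and the priority match, mirroring how the contract keys queues by
+// flow and priority level (a flow may have draining queues at other priorities).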
+func (h *testHarness) managedQueue(flowID string, priority uint) (contracts.ManagedQueue, error) { + h.mu.Lock() + defer h.mu.Unlock() + if q, ok := h.queues[flowID]; ok && q.FlowSpec().Priority == priority { + return q, nil + } + return nil, fmt.Errorf("test setup error: no queue for %q at priority %d", flowID, priority) +} + +// allOrderedPriorityLevels provides the mock implementation for the `RegistryShard` interface. +func (h *testHarness) allOrderedPriorityLevels() []uint { + h.mu.Lock() + defer h.mu.Unlock() + prios := make([]uint, 0, len(h.priorityFlows)) + for p := range h.priorityFlows { + prios = append(prios, p) + } + slices.Sort(prios) + return prios +} + +// priorityBandAccessor provides the mock implementation for the `RegistryShard` interface. It acts as a factory for a +// fully-configured, stateless mock that is safe for concurrent use. +func (h *testHarness) priorityBandAccessor(p uint) (framework.PriorityBandAccessor, error) { + band := &frameworkmocks.MockPriorityBandAccessor{PriorityV: p} + + // Safely get a snapshot of the flow IDs under a lock. + h.mu.Lock() + flowIDsForPriority := h.priorityFlows[p] + h.mu.Unlock() + + // Configure the mock's behavior with a closure that reads from the harness's centralized, thread-safe state. + band.IterateQueuesFunc = func(cb func(fqa framework.FlowQueueAccessor) bool) { + // This closure safely iterates over the snapshot of flow IDs. + for _, id := range flowIDsForPriority { + // Get the queue using the thread-safe `managedQueue` method. + q, err := h.managedQueue(id, p) + if err == nil && q != nil { + mq := q.(*mocks.MockManagedQueue) + if !cb(mq.FlowQueueAccessor()) { + break + } + } + } + } + return band, nil +} + +// interFlowDispatchPolicy provides the mock implementation for the `contracts.RegistryShard` interface. +func (h *testHarness) interFlowDispatchPolicy(p uint) (framework.InterFlowDispatchPolicy, error) { + policy := &frameworkmocks.MockInterFlowDispatchPolicy{} + // If the test provided a custom implementation, use it. + if h.interFlowPolicySelectQueue != nil { + policy.SelectQueueFunc = h.interFlowPolicySelectQueue + return policy, nil + } + + // Otherwise, use a default implementation that selects the first non-empty queue. + policy.SelectQueueFunc = func(band framework.PriorityBandAccessor) (framework.FlowQueueAccessor, error) { + var selectedQueue framework.FlowQueueAccessor + band.IterateQueues(func(fqa framework.FlowQueueAccessor) bool { + if fqa.Len() > 0 { + selectedQueue = fqa + return false // stop iterating + } + return true // continue + }) + return selectedQueue, nil + } + return policy, nil +} + +// intraFlowDispatchPolicy provides the mock implementation for the `contracts.RegistryShard` interface. +func (h *testHarness) intraFlowDispatchPolicy(flowID string, priority uint) (framework.IntraFlowDispatchPolicy, error) { + policy := &frameworkmocks.MockIntraFlowDispatchPolicy{} + // If the test provided a custom implementation, use it. + if h.intraFlowPolicySelectItem != nil { + policy.SelectItemFunc = h.intraFlowPolicySelectItem + return policy, nil + } + + // Otherwise, use a default implementation that selects the head of the queue. + policy.SelectItemFunc = func(fqa framework.FlowQueueAccessor) (types.QueueItemAccessor, error) { + return fqa.PeekHead() + } + return policy, nil +} + +// TestShardProcessor contains all tests for the `ShardProcessor`. 
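+//
+// Most subtests share the same harness pattern (a sketch; the concrete flows, items, and overrides vary per test):
+//
+//	h := newTestHarness(t, testCleanupTick)
+//	h.addQueue(types.FlowSpecification{ID: testFlow.ID, Priority: testFlow.Priority})
+//	item := h.newTestItem("req-id", testFlow.ID, testTTL)
+//	h.Start()      // launch the processor goroutine, paused until Go()
+//	defer h.Stop()
+//	h.processor.Enqueue(item)
+//	h.Go()         // release the Run loop
+//	outcome, err := h.waitForFinalization(item)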
+func TestShardProcessor(t *testing.T) { + t.Parallel() + + // Lifecycle tests use the processor's main `Run` loop to verify the complete end-to-end lifecycle of a request, from + // `Enqueue` to its final outcome. + t.Run("Lifecycle", func(t *testing.T) { + t.Parallel() + + t.Run("should dispatch item successfully", func(t *testing.T) { + t.Parallel() + // --- ARRANGE --- + h := newTestHarness(t, testCleanupTick) + item := h.newTestItem("req-dispatch-success", testFlow.ID, testTTL) + h.addQueue(types.FlowSpecification{ID: testFlow.ID, Priority: testFlow.Priority}) + + // --- ACT --- + h.Start() + defer h.Stop() + h.processor.Enqueue(item) + h.Go() + + // --- ASSERT --- + outcome, err := h.waitForFinalization(item) + assert.Equal(t, types.QueueOutcomeDispatched, outcome, "The final outcome should be Dispatched") + require.NoError(t, err, "A successful dispatch should not produce an error") + }) + + t.Run("should reject item when at capacity", func(t *testing.T) { + t.Parallel() + // --- ARRANGE --- + h := newTestHarness(t, testCleanupTick) + item := h.newTestItem("req-capacity-reject", testFlow.ID, testTTL) + h.addQueue(types.FlowSpecification{ID: testFlow.ID, Priority: testFlow.Priority}) + h.StatsFunc = func() contracts.ShardStats { + return contracts.ShardStats{PerPriorityBandStats: map[uint]contracts.PriorityBandStats{ + testFlow.Priority: {CapacityBytes: 50}, // 50 is less than item size of 100 + }} + } + + // --- ACT --- + h.Start() + defer h.Stop() + h.processor.Enqueue(item) + h.Go() + + // --- ASSERT --- + outcome, err := h.waitForFinalization(item) + assert.Equal(t, types.QueueOutcomeRejectedCapacity, outcome, "The final outcome should be RejectedCapacity") + require.Error(t, err, "A capacity rejection should produce an error") + assert.ErrorIs(t, err, types.ErrQueueAtCapacity, "The error should be of type ErrQueueAtCapacity") + }) + + t.Run("should reject item on registry lookup failure", func(t *testing.T) { + t.Parallel() + // --- ARRANGE --- + h := newTestHarness(t, testCleanupTick) + item := h.newTestItem("req-lookup-fail-reject", testFlow.ID, testTTL) + registryErr := errors.New("test registry lookup error") + h.ActiveManagedQueueFunc = func(flowID string) (contracts.ManagedQueue, error) { + return nil, registryErr + } + + // --- ACT --- + h.Start() + defer h.Stop() + h.processor.Enqueue(item) + h.Go() + + // --- ASSERT --- + outcome, err := h.waitForFinalization(item) + assert.Equal(t, types.QueueOutcomeRejectedOther, outcome, "The final outcome should be RejectedOther") + require.Error(t, err, "A rejection from a registry failure should produce an error") + assert.ErrorIs(t, err, registryErr, "The underlying registry error should be preserved") + }) + + t.Run("should reject item if enqueued during shutdown", func(t *testing.T) { + t.Parallel() + // --- ARRANGE --- + h := newTestHarness(t, testCleanupTick) + item := h.newTestItem("req-shutdown-reject", testFlow.ID, testTTL) + h.addQueue(types.FlowSpecification{ID: testFlow.ID, Priority: testFlow.Priority}) + + // --- ACT --- + h.Start() + h.Go() + // Stop the processor, then immediately try to enqueue. 
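+ // Because Stop() waits for the processor to fully shut down, `isShuttingDown` is already set when Enqueue
+ // runs, so this exercises the fast-fail rejection path rather than the channel send.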
+ h.Stop() + h.processor.Enqueue(item) + + // --- ASSERT --- + outcome, err := h.waitForFinalization(item) + assert.Equal(t, types.QueueOutcomeRejectedOther, outcome, "The outcome should be RejectedOther") + require.Error(t, err, "An eviction on shutdown should produce an error") + assert.ErrorIs(t, err, types.ErrFlowControllerShutdown, "The error should be of type ErrFlowControllerShutdown") + }) + + t.Run("should evict item on TTL expiry via background cleanup", func(t *testing.T) { + t.Parallel() + // --- ARRANGE --- + h := newTestHarness(t, testCleanupTick) + item := h.newTestItem("req-expired-evict", testFlow.ID, testShortTTL) + h.addQueue(types.FlowSpecification{ID: testFlow.ID, Priority: testFlow.Priority}) + + // --- ACT --- + h.Start() + defer h.Stop() + h.processor.Enqueue(item) + h.Go() + + // Let time pass for the item to expire and for the background cleanup to run. + h.mockClock.Advance(testShortTTL * 2) + time.Sleep(testCleanupTick * 3) // Allow the cleanup goroutine time to run. + + // --- ASSERT --- + outcome, err := h.waitForFinalization(item) + assert.Equal(t, types.QueueOutcomeEvictedTTL, outcome, "The final outcome should be EvictedTTL") + require.Error(t, err, "A TTL eviction should produce an error") + assert.ErrorIs(t, err, types.ErrTTLExpired, "The error should be of type ErrTTLExpired") + }) + + t.Run("should evict item at moment of dispatch if TTL has expired", func(t *testing.T) { + t.Parallel() + // --- ARRANGE --- + h := newTestHarness(t, 1*time.Hour) // Disable background cleanup to isolate dispatch logic. + item := h.newTestItem("req-expired-dispatch-evict", testFlow.ID, testShortTTL) + mockQueue := h.addQueue(types.FlowSpecification{ID: testFlow.ID, Priority: testFlow.Priority}) + require.NoError(t, mockQueue.Add(item), "Adding item to mock queue should not fail") + + // Have the policy select the item, but then advance time so it's expired by the time dispatchItem actually runs. + h.interFlowPolicySelectQueue = func(band framework.PriorityBandAccessor) (framework.FlowQueueAccessor, error) { + h.mockClock.Advance(testShortTTL * 2) + return mockQueue.FlowQueueAccessor(), nil + } + + // --- ACT --- + h.Start() + defer h.Stop() + h.Go() + + // The run loop will pick up the item and attempt dispatch, which will fail internally. + // We need a small sleep to allow the non-blocking run loop to process. + time.Sleep(50 * time.Millisecond) + + // --- ASSERT --- + outcome, err := h.waitForFinalization(item) + assert.Equal(t, types.QueueOutcomeEvictedTTL, outcome, "The final outcome should be EvictedTTL") + require.Error(t, err, "An eviction at dispatch time should produce an error") + assert.ErrorIs(t, err, types.ErrTTLExpired, "The error should be of type ErrTTLExpired") + }) + + t.Run("should evict item on context cancellation", func(t *testing.T) { + t.Parallel() + // --- ARRANGE --- + h := newTestHarness(t, testCleanupTick) + ctx, cancel := context.WithCancel(context.Background()) + req := typesmocks.NewMockFlowControlRequest(100, "req-ctx-cancel", testFlow.ID, ctx) + item := NewItem(req, testTTL, h.mockClock.Now()) + h.addQueue(types.FlowSpecification{ID: testFlow.ID, Priority: testFlow.Priority}) + + // --- ACT --- + h.Start() + defer h.Stop() + h.processor.Enqueue(item) + h.Go() + cancel() // Cancel the context *after* the item is enqueued. + time.Sleep(testCleanupTick * 3) // Allow the cleanup goroutine time to run. 
+ + // --- ASSERT --- + outcome, err := h.waitForFinalization(item) + assert.Equal(t, types.QueueOutcomeEvictedContextCancelled, outcome, + "The outcome should be EvictedContextCancelled") + require.Error(t, err, "A context cancellation eviction should produce an error") + assert.ErrorIs(t, err, types.ErrContextCancelled, "The error should be of type ErrContextCancelled") + }) + + t.Run("should evict a queued item on shutdown", func(t *testing.T) { + t.Parallel() + // --- ARRANGE --- + h := newTestHarness(t, testCleanupTick) + item := h.newTestItem("req-shutdown-evict", testFlow.ID, testTTL) + mockQueue := h.addQueue(types.FlowSpecification{ID: testFlow.ID, Priority: testFlow.Priority}) + require.NoError(t, mockQueue.Add(item), "Adding item to mock queue should not fail") + + // Prevent dispatch to ensure we test shutdown eviction, not a successful dispatch. + h.interFlowPolicySelectQueue = func(band framework.PriorityBandAccessor) (framework.FlowQueueAccessor, error) { + return nil, nil + } + + // --- ACT --- + h.Start() + h.Go() + h.Stop() // Stop immediately to trigger eviction. + + // --- ASSERT --- + outcome, err := h.waitForFinalization(item) + assert.Equal(t, types.QueueOutcomeEvictedOther, outcome, "The outcome should be EvictedOther") + require.Error(t, err, "An eviction on shutdown should produce an error") + assert.ErrorIs(t, err, types.ErrFlowControllerShutdown, "The error should be of type ErrFlowControllerShutdown") + }) + + t.Run("should handle concurrent enqueues and dispatch all items", func(t *testing.T) { + t.Parallel() + // --- ARRANGE --- + h := newTestHarness(t, testCleanupTick) + const numConcurrentItems = 20 + q := h.addQueue(types.FlowSpecification{ID: testFlow.ID, Priority: testFlow.Priority}) + itemsToTest := make([]*flowItem, 0, numConcurrentItems) + for i := 0; i < numConcurrentItems; i++ { + item := h.newTestItem(fmt.Sprintf("req-concurrent-%d", i), testFlow.ID, testTTL) + itemsToTest = append(itemsToTest, item) + } + + // --- ACT --- + h.Start() + defer h.Stop() + var wg sync.WaitGroup + for _, item := range itemsToTest { + wg.Add(1) + go func(fi *flowItem) { + defer wg.Done() + h.processor.Enqueue(fi) + }(item) + } + h.Go() + wg.Wait() // Wait for all enqueues to finish. + + // --- ASSERT --- + for _, item := range itemsToTest { + outcome, err := h.waitForFinalization(item) + assert.Equal(t, types.QueueOutcomeDispatched, outcome, + "Item %q should have been dispatched", item.OriginalRequest().ID()) + assert.NoError(t, err, + "A successful dispatch of item %q should not produce an error", item.OriginalRequest().ID()) + } + assert.Equal(t, 0, q.Len(), "The mock queue should be empty at the end of the test") + }) + + t.Run("should guarantee exactly-once finalization during dispatch vs. expiry race", func(t *testing.T) { + t.Parallel() + + // --- ARRANGE --- + h := newTestHarness(t, 1*time.Hour) // Disable background cleanup to isolate the race. + item := h.newTestItem("req-race", testFlow.ID, testShortTTL) + q := h.addQueue(types.FlowSpecification{ID: testFlow.ID, Priority: testFlow.Priority}) + + // Use channels to pause the dispatch cycle right before it would remove the item. + policyCanProceed := make(chan struct{}) + itemIsBeingDispatched := make(chan struct{}) + + require.NoError(t, q.Add(item)) // Add the item directly to the queue. + + // Override the queue's `RemoveFunc` to pause the dispatch goroutine at a critical moment. + q.RemoveFunc = func(h types.QueueItemHandle) (types.QueueItemAccessor, error) { + close(itemIsBeingDispatched) // 1. 
Signal that dispatch is happening. + <-policyCanProceed // 2. Wait for the test to tell us to continue. + // 4. After we unblock, the item will have already been finalized by the cleanup logic, so we simulate the + // real-world outcome of a failed remove. + return nil, fmt.Errorf("item with handle %v not found", h) + } + + // --- ACT --- + h.Start() + defer h.Stop() + h.Go() + + // Wait for the dispatch cycle to select our item and pause inside our mock `RemoveFunc`. + <-itemIsBeingDispatched + + // 3. The dispatch goroutine is now paused. We can now safely win the "race" by running cleanup logic. + h.mockClock.Advance(testShortTTL * 2) + h.processor.cleanupExpired(h.mockClock.Now()) // This will remove and finalize the item. + + // 5. Un-pause the dispatch goroutine. It will now fail to remove the item and the `dispatchCycle` will + // correctly conclude without finalizing the item a second time. + close(policyCanProceed) + + // --- ASSERT --- + // The item's final state should be from the cleanup logic (EvictedTTL), not the dispatch logic. + outcome, err := h.waitForFinalization(item) + assert.Equal(t, types.QueueOutcomeEvictedTTL, outcome, "The outcome should be EvictedTTL from the cleanup routine") + require.Error(t, err, "A TTL eviction should produce an error") + assert.ErrorIs(t, err, types.ErrTTLExpired, "The error should be of type ErrTTLExpired") + }) + + t.Run("should shut down cleanly on context cancellation", func(t *testing.T) { + t.Parallel() + // --- ARRANGE --- + h := newTestHarness(t, testCleanupTick) + stopped := make(chan struct{}) + + // --- ACT --- + h.Start() + h.Go() + + // Use a separate goroutine to wait for the processor to fully stop. + go func() { + h.Stop() // This cancels the context and waits on the WaitGroup. + close(stopped) + }() + + // --- ASSERT --- + select { + case <-stopped: + // Success: The Stop() call completed without a deadlock. + case <-time.After(testWaitTimeout): + t.Fatal("Test timed out waiting for processor to stop") + } + }) + + t.Run("should not panic on nil item from enqueue channel", func(t *testing.T) { + t.Parallel() + // --- ARRANGE --- + h := newTestHarness(t, testCleanupTick) + // This test is primarily checking that the processor doesn't panic or error on a nil input. + + // --- ACT --- + h.Start() + defer h.Stop() + h.Go() + h.processor.Enqueue(nil) + + // --- ASSERT --- + // Allow a moment for the processor to potentially process the nil item. + // A successful test is one that completes without panicking. 
+ time.Sleep(50 * time.Millisecond) + }) + }) + + t.Run("Unit", func(t *testing.T) { + t.Parallel() + + t.Run("enqueue", func(t *testing.T) { + t.Parallel() + testErr := errors.New("something went wrong") + + testCases := []struct { + name string + setupHarness func(h *testHarness) + item *flowItem + assert func(t *testing.T, h *testHarness, item *flowItem) + }{ + { + name: "should reject item on registry queue lookup failure", + setupHarness: func(h *testHarness) { + h.ActiveManagedQueueFunc = func(string) (contracts.ManagedQueue, error) { return nil, testErr } + }, + assert: func(t *testing.T, h *testHarness, item *flowItem) { + outcome, err := item.FinalState() + assert.Equal(t, types.QueueOutcomeRejectedOther, outcome, "Outcome should be RejectedOther") + require.Error(t, err, "An error should be returned") + assert.ErrorIs(t, err, testErr, "The underlying error should be preserved") + }, + }, + { + name: "should reject item on registry priority band lookup failure", + setupHarness: func(h *testHarness) { + h.addQueue(types.FlowSpecification{ID: testFlow.ID, Priority: testFlow.Priority}) + h.PriorityBandAccessorFunc = func(uint) (framework.PriorityBandAccessor, error) { return nil, testErr } + }, + assert: func(t *testing.T, h *testHarness, item *flowItem) { + outcome, err := item.FinalState() + assert.Equal(t, types.QueueOutcomeRejectedOther, outcome, "Outcome should be RejectedOther") + require.Error(t, err, "An error should be returned") + assert.ErrorIs(t, err, testErr, "The underlying error should be preserved") + }, + }, + { + name: "should reject item on queue add failure", + setupHarness: func(h *testHarness) { + mockQueue := h.addQueue(types.FlowSpecification{ID: testFlow.ID, Priority: testFlow.Priority}) + mockQueue.AddFunc = func(types.QueueItemAccessor) error { return testErr } + }, + assert: func(t *testing.T, h *testHarness, item *flowItem) { + outcome, err := item.FinalState() + assert.Equal(t, types.QueueOutcomeRejectedOther, outcome, "Outcome should be RejectedOther") + require.Error(t, err, "An error should be returned") + assert.ErrorIs(t, err, testErr, "The underlying error should be preserved") + }, + }, + { + name: "should ignore an already-finalized item", + setupHarness: func(h *testHarness) { + mockQueue := h.addQueue(types.FlowSpecification{ID: testFlow.ID, Priority: testFlow.Priority}) + var addCallCount int + mockQueue.AddFunc = func(item types.QueueItemAccessor) error { + addCallCount++ + return nil + } + // Use Cleanup to assert after the test logic has run. + t.Cleanup(func() { + assert.Equal(t, 0, addCallCount, "Queue.Add should not have been called for a finalized item") + }) + }, + item: func() *flowItem { + // Create a pre-finalized item. + item := newTestHarness(t, 0).newTestItem("req-finalized", testFlow.ID, testTTL) + item.finalize(types.QueueOutcomeDispatched, nil) + return item + }(), + assert: func(t *testing.T, h *testHarness, item *flowItem) { + // The item was already finalized, so its state should not change. 
+ outcome, err := item.FinalState() + assert.Equal(t, types.QueueOutcomeDispatched, outcome, "Outcome should remain unchanged") + assert.NoError(t, err, "Error should remain unchanged") + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + h := newTestHarness(t, testCleanupTick) + tc.setupHarness(h) + item := tc.item + if item == nil { + item = h.newTestItem("req-enqueue-test", testFlow.ID, testTTL) + } + h.processor.enqueue(item) + tc.assert(t, h, item) + }) + } + }) + + t.Run("hasCapacity", func(t *testing.T) { + t.Parallel() + testCases := []struct { + name string + itemByteSize uint64 + stats contracts.ShardStats + expectHasCap bool + }{ + { + name: "should allow zero-size item even if full", + itemByteSize: 0, + stats: contracts.ShardStats{TotalByteSize: 100, TotalCapacityBytes: 100}, + expectHasCap: true, + }, + { + name: "should deny item if shard capacity exceeded", + itemByteSize: 1, + stats: contracts.ShardStats{TotalByteSize: 100, TotalCapacityBytes: 100}, + expectHasCap: false, + }, + { + name: "should deny item if band capacity exceeded", + itemByteSize: 1, + stats: contracts.ShardStats{ + TotalCapacityBytes: 200, TotalByteSize: 100, + PerPriorityBandStats: map[uint]contracts.PriorityBandStats{ + testFlow.Priority: {ByteSize: 50, CapacityBytes: 50}, + }, + }, + expectHasCap: false, + }, + { + name: "should deny item if band stats are missing", + itemByteSize: 1, + stats: contracts.ShardStats{ + TotalCapacityBytes: 200, TotalByteSize: 100, + PerPriorityBandStats: map[uint]contracts.PriorityBandStats{}, // Missing stats for priority 10 + }, + expectHasCap: false, + }, + { + name: "should allow item if both shard and band have capacity", + itemByteSize: 10, + stats: contracts.ShardStats{ + TotalCapacityBytes: 200, TotalByteSize: 100, + PerPriorityBandStats: map[uint]contracts.PriorityBandStats{ + testFlow.Priority: {ByteSize: 50, CapacityBytes: 100}, + }, + }, + expectHasCap: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + h := newTestHarness(t, testCleanupTick) + h.StatsFunc = func() contracts.ShardStats { return tc.stats } + hasCap := h.processor.hasCapacity(testFlow.Priority, tc.itemByteSize) + assert.Equal(t, tc.expectHasCap, hasCap, "Capacity check result should match expected value") + }) + } + }) + + t.Run("dispatchCycle", func(t *testing.T) { + t.Parallel() + + t.Run("should handle various policy and registry scenarios", func(t *testing.T) { + t.Parallel() + policyErr := errors.New("policy failure") + registryErr := errors.New("registry error") + specA := types.FlowSpecification{ID: testFlow.ID, Priority: testFlow.Priority} + + testCases := []struct { + name string + setupHarness func(h *testHarness) + expectDidDispatch bool + }{ + { + name: "should do nothing if no items are queued", + setupHarness: func(h *testHarness) { + h.addQueue(specA) // Add a queue, but no items. + }, + expectDidDispatch: false, + }, + { + name: "should stop dispatching when filter signals pause", + setupHarness: func(h *testHarness) { + // Add an item that *could* be dispatched to prove the pause is effective. + q := h.addQueue(types.FlowSpecification{ID: testFlow.ID, Priority: testFlow.Priority}) + require.NoError(t, q.Add(h.newTestItem("item", testFlow.ID, testTTL))) + h.processor.dispatchFilter = func( + _ context.Context, + _ framework.PriorityBandAccessor, + _ logr.Logger, + ) (map[string]struct{}, bool) { + return nil, true // Signal pause. 
+ } + }, + expectDidDispatch: false, + }, + { + name: "should skip band on priority band accessor error", + setupHarness: func(h *testHarness) { + h.PriorityBandAccessorFunc = func(uint) (framework.PriorityBandAccessor, error) { + return nil, registryErr + } + }, + expectDidDispatch: false, + }, + { + name: "should skip band on inter-flow policy error", + setupHarness: func(h *testHarness) { + h.addQueue(specA) + h.interFlowPolicySelectQueue = func( + _ framework.PriorityBandAccessor, + ) (framework.FlowQueueAccessor, error) { + return nil, policyErr + } + }, + expectDidDispatch: false, + }, + { + name: "should skip band if inter-flow policy returns no queue", + setupHarness: func(h *testHarness) { + q := h.addQueue(specA) + require.NoError(t, q.Add(h.newTestItem("item", testFlow.ID, testTTL))) + h.interFlowPolicySelectQueue = func( + _ framework.PriorityBandAccessor, + ) (framework.FlowQueueAccessor, error) { + return nil, nil // Simulate band being empty or policy choosing to pause. + } + }, + expectDidDispatch: false, + }, + { + name: "should skip band on intra-flow policy error", + setupHarness: func(h *testHarness) { + q := h.addQueue(specA) + require.NoError(t, q.Add(h.newTestItem("item", testFlow.ID, testTTL))) + h.interFlowPolicySelectQueue = func( + _ framework.PriorityBandAccessor, + ) (framework.FlowQueueAccessor, error) { + return q.FlowQueueAccessor(), nil + } + h.intraFlowPolicySelectItem = func(_ framework.FlowQueueAccessor) (types.QueueItemAccessor, error) { + return nil, policyErr + } + }, + expectDidDispatch: false, + }, + { + name: "should skip band if intra-flow policy returns no item", + setupHarness: func(h *testHarness) { + q := h.addQueue(specA) + require.NoError(t, q.Add(h.newTestItem("item", testFlow.ID, testTTL))) + h.interFlowPolicySelectQueue = func( + _ framework.PriorityBandAccessor, + ) (framework.FlowQueueAccessor, error) { + return q.FlowQueueAccessor(), nil + } + h.intraFlowPolicySelectItem = func(_ framework.FlowQueueAccessor) (types.QueueItemAccessor, error) { + return nil, nil // Simulate queue being empty or policy choosing to pause. + } + }, + expectDidDispatch: false, + }, + { + name: "should continue to lower priority band on inter-flow policy error", + setupHarness: func(h *testHarness) { + // Create a failing high-priority queue and a working low-priority queue. + specHigh := types.FlowSpecification{ID: "flow-high", Priority: testFlow.Priority} + specLow := types.FlowSpecification{ID: "flow-low", Priority: 20} + h.addQueue(specHigh) + qLow := h.addQueue(specLow) + + itemLow := h.newTestItem("item-low", specLow.ID, testTTL) + require.NoError(t, qLow.Add(itemLow)) + + h.interFlowPolicySelectQueue = func( + band framework.PriorityBandAccessor, + ) (framework.FlowQueueAccessor, error) { + if band.Priority() == testFlow.Priority { + return nil, errors.New("policy failure") // Fail high-priority. + } + // Succeed for low-priority. 
+ q, _ := h.managedQueue(specLow.ID, specLow.Priority) + return q.FlowQueueAccessor(), nil + } + }, + expectDidDispatch: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + h := newTestHarness(t, testCleanupTick) + tc.setupHarness(h) + dispatched := h.processor.dispatchCycle(context.Background()) + assert.Equal(t, tc.expectDidDispatch, dispatched, "Dispatch result should match expected value") + }) + } + }) + + t.Run("should use filtered view of queues when filter is active", func(t *testing.T) { + t.Parallel() + // --- ARRANGE --- + h := newTestHarness(t, testCleanupTick) + specA := types.FlowSpecification{ID: testFlow.ID, Priority: testFlow.Priority} + specB := types.FlowSpecification{ID: "flow-b", Priority: testFlow.Priority} + h.addQueue(specA) + qB := h.addQueue(specB) + itemB := h.newTestItem("item-b", specB.ID, testTTL) + require.NoError(t, qB.Add(itemB)) + + // This filter only allows flow-b. + h.processor.dispatchFilter = func( + _ context.Context, + _ framework.PriorityBandAccessor, + _ logr.Logger, + ) (map[string]struct{}, bool) { + return map[string]struct{}{specB.ID: {}}, false + } + + // This policy will be given the filtered view, so it should only see flow-b. + h.interFlowPolicySelectQueue = func(band framework.PriorityBandAccessor) (framework.FlowQueueAccessor, error) { + var flowIDs []string + band.IterateQueues(func(fqa framework.FlowQueueAccessor) bool { + flowIDs = append(flowIDs, fqa.FlowSpec().ID) + return true + }) + // This is the core assertion of the test. + require.ElementsMatch(t, []string{specB.ID}, flowIDs, "Policy should only see the filtered flow") + + // Select flow-b to prove the chain works. + q, _ := h.managedQueue(specB.ID, specB.Priority) + return q.FlowQueueAccessor(), nil + } + + // --- ACT --- + dispatched := h.processor.dispatchCycle(context.Background()) + + // --- ASSERT --- + assert.True(t, dispatched, "An item should have been dispatched from the filtered flow") + outcome, err := itemB.FinalState() + assert.Equal(t, types.QueueOutcomeDispatched, outcome, "The dispatched item's outcome should be correct") + assert.NoError(t, err, "The dispatched item should not have an error") + }) + + t.Run("should guarantee strict priority by starving lower priority items", func(t *testing.T) { + t.Parallel() + // --- ARRANGE --- + h := newTestHarness(t, testCleanupTick) + specHigh := types.FlowSpecification{ID: "flow-high", Priority: 10} + specLow := types.FlowSpecification{ID: "flow-low", Priority: 20} + qHigh := h.addQueue(specHigh) + qLow := h.addQueue(specLow) + + const numItems = 3 + highPrioItems := make([]*flowItem, numItems) + lowPrioItems := make([]*flowItem, numItems) + for i := range numItems { + // Add high priority items. + itemH := h.newTestItem(fmt.Sprintf("req-high-%d", i), specHigh.ID, testTTL) + require.NoError(t, qHigh.Add(itemH)) + highPrioItems[i] = itemH + + // Add low priority items. + itemL := h.newTestItem(fmt.Sprintf("req-low-%d", i), specLow.ID, testTTL) + require.NoError(t, qLow.Add(itemL)) + lowPrioItems[i] = itemL + } + + // --- ACT & ASSERT --- + // First, dispatch all high-priority items. + for i := range numItems { + dispatched := h.processor.dispatchCycle(context.Background()) + require.True(t, dispatched, "Expected a high-priority dispatch on cycle %d", i+1) + } + + // Verify all high-priority items are gone and low-priority items remain. 
+ for _, item := range highPrioItems { + outcome, err := item.FinalState() + assert.Equal(t, types.QueueOutcomeDispatched, outcome, "High-priority item should be dispatched") + assert.NoError(t, err, "Dispatched high-priority item should not have an error") + } + assert.Equal(t, numItems, qLow.Len(), "Low-priority queue should still be full") + + // Next, dispatch all low-priority items. + for i := range numItems { + dispatched := h.processor.dispatchCycle(context.Background()) + require.True(t, dispatched, "Expected a low-priority dispatch on cycle %d", i+1) + } + assert.Equal(t, 0, qLow.Len(), "Low-priority queue should be empty") + }) + }) + + t.Run("dispatchItem", func(t *testing.T) { + t.Parallel() + + t.Run("should fail on registry errors", func(t *testing.T) { + t.Parallel() + registryErr := errors.New("registry error") + + testCases := []struct { + name string + setupMocks func(h *testHarness) + expectedErr error + }{ + { + name: "on ManagedQueue lookup failure", + setupMocks: func(h *testHarness) { + h.ManagedQueueFunc = func(string, uint) (contracts.ManagedQueue, error) { return nil, registryErr } + }, + expectedErr: registryErr, + }, + { + name: "on queue remove failure", + setupMocks: func(h *testHarness) { + h.ManagedQueueFunc = func(string, uint) (contracts.ManagedQueue, error) { + return &mocks.MockManagedQueue{ + RemoveFunc: func(types.QueueItemHandle) (types.QueueItemAccessor, error) { + return nil, registryErr + }, + }, nil + } + }, + expectedErr: registryErr, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + h := newTestHarness(t, testCleanupTick) + tc.setupMocks(h) + item := h.newTestItem("req-dispatch-fail", testFlow.ID, testTTL) + err := h.processor.dispatchItem(item, testFlow.Priority, h.logger) + require.Error(t, err, "dispatchItem should return an error") + assert.ErrorIs(t, err, tc.expectedErr, "The underlying registry error should be preserved") + }) + } + }) + + t.Run("should evict item that expires at moment of dispatch", func(t *testing.T) { + t.Parallel() + // --- ARRANGE --- + h := newTestHarness(t, testCleanupTick) + item := h.newTestItem("req-expired-dispatch", testFlow.ID, testShortTTL) + + h.ManagedQueueFunc = func(string, uint) (contracts.ManagedQueue, error) { + return &mocks.MockManagedQueue{ + RemoveFunc: func(types.QueueItemHandle) (types.QueueItemAccessor, error) { + return item, nil + }, + }, nil + } + + // --- ACT --- + h.mockClock.Advance(testShortTTL * 2) // Make the item expire. + err := h.processor.dispatchItem(item, testFlow.Priority, h.logger) + + // --- ASSERT --- + // First, check the error returned by `dispatchItem`. + require.Error(t, err, "dispatchItem should return an error for an expired item") + assert.ErrorIs(t, err, types.ErrTTLExpired, "The error should be of type ErrTTLExpired") + + // Second, check the final state of the item itself. 
+ outcome, finalErr := item.FinalState() + assert.Equal(t, types.QueueOutcomeEvictedTTL, outcome, "The item's final outcome should be EvictedTTL") + require.Error(t, finalErr, "The item's final state should contain an error") + assert.ErrorIs(t, finalErr, types.ErrTTLExpired, "The item's final error should be of type ErrTTLExpired") + }) + + t.Run("should panic if queue returns item of wrong type", func(t *testing.T) { + t.Parallel() + // --- ARRANGE --- + h := newTestHarness(t, testCleanupTick) + badItem := &typesmocks.MockQueueItemAccessor{ + OriginalRequestV: typesmocks.NewMockFlowControlRequest(0, "bad-item", "", context.Background()), + } + + h.ManagedQueueFunc = func(string, uint) (contracts.ManagedQueue, error) { + return &mocks.MockManagedQueue{ + RemoveFunc: func(types.QueueItemHandle) (types.QueueItemAccessor, error) { + return badItem, nil + }, + }, nil + } + + itemToDispatch := h.newTestItem("req-dispatch-panic", testFlow.ID, testTTL) + expectedPanicMsg := fmt.Sprintf("%s: internal error: item %q of type %T is not a *flowItem", + errIntraFlow, "bad-item", badItem) + + // --- ACT & ASSERT --- + assert.PanicsWithError(t, expectedPanicMsg, func() { + _ = h.processor.dispatchItem(itemToDispatch, testFlow.Priority, h.logger) + }, "A type mismatch from a queue should cause a panic") + }) + }) + + t.Run("cleanup and utility methods", func(t *testing.T) { + t.Parallel() + + t.Run("should remove and finalize expired items", func(t *testing.T) { + t.Parallel() + // --- ARRANGE --- + h := newTestHarness(t, testCleanupTick) + spec := types.FlowSpecification{ID: testFlow.ID, Priority: testFlow.Priority} + // Create an item that is already expired relative to the cleanup time. + item := h.newTestItem("req-expired", testFlow.ID, 1*time.Millisecond) + q := h.addQueue(spec) + require.NoError(t, q.Add(item)) + cleanupTime := h.mockClock.Now().Add(10 * time.Millisecond) + + // --- ACT --- + h.processor.cleanupExpired(cleanupTime) + + // --- ASSERT --- + outcome, err := item.FinalState() + assert.Equal(t, types.QueueOutcomeEvictedTTL, outcome, "Item outcome should be EvictedTTL") + require.Error(t, err, "Item should have an error") + assert.ErrorIs(t, err, types.ErrTTLExpired, "Item error should be ErrTTLExpired") + }) + + t.Run("should evict all items on shutdown", func(t *testing.T) { + t.Parallel() + // --- ARRANGE --- + h := newTestHarness(t, testCleanupTick) + spec := types.FlowSpecification{ID: testFlow.ID, Priority: testFlow.Priority} + item := h.newTestItem("req-pending", testFlow.ID, testTTL) + q := h.addQueue(spec) + require.NoError(t, q.Add(item)) + + // --- ACT --- + h.processor.evictAll() + + // --- ASSERT --- + outcome, err := item.FinalState() + assert.Equal(t, types.QueueOutcomeEvictedOther, outcome, "Item outcome should be EvictedOther") + require.Error(t, err, "Item should have an error") + assert.ErrorIs(t, err, types.ErrFlowControllerShutdown, "Item error should be ErrFlowControllerShutdown") + }) + + t.Run("should handle registry errors gracefully during concurrent processing", func(t *testing.T) { + t.Parallel() + // --- ARRANGE --- + h := newTestHarness(t, testCleanupTick) + h.AllOrderedPriorityLevelsFunc = func() []uint { return []uint{testFlow.Priority} } + h.PriorityBandAccessorFunc = func(p uint) (framework.PriorityBandAccessor, error) { + return nil, errors.New("registry error") + } + + // --- ACT & ASSERT --- + // The test passes if this call completes without panicking. 
+ assert.NotPanics(t, func() { + h.processor.processAllQueuesConcurrently("test", func(mq contracts.ManagedQueue, logger logr.Logger) {}) + }, "processAllQueuesConcurrently should not panic on registry errors") + }) + + t.Run("should handle items of an unexpected type gracefully during finalization", func(t *testing.T) { + t.Parallel() + // --- ARRANGE --- + h := newTestHarness(t, testCleanupTick) + item := &typesmocks.MockQueueItemAccessor{ + OriginalRequestV: typesmocks.NewMockFlowControlRequest(0, "bad-item", testFlow.ID, context.Background()), + } + items := []types.QueueItemAccessor{item} + + // --- ACT & ASSERT --- + // The test passes if this call completes without panicking. + assert.NotPanics(t, func() { + getOutcome := func(types.QueueItemAccessor) (types.QueueOutcome, error) { + return types.QueueOutcomeEvictedOther, nil + } + h.processor.finalizeItems(items, h.logger, getOutcome) + }, "finalizeItems should not panic on unexpected item types") + }) + + t.Run("should process all queues with a worker pool", func(t *testing.T) { + t.Parallel() + // --- ARRANGE --- + h := newTestHarness(t, testCleanupTick) + + // Create more queues than the fixed number of cleanup workers to ensure the pooling logic is exercised. + const numQueues = maxCleanupWorkers + 5 + var processedCount atomic.Int32 + + for i := range numQueues { + spec := types.FlowSpecification{ + ID: fmt.Sprintf("flow-%d", i), + Priority: testFlow.Priority, + } + h.addQueue(spec) + } + + processFn := func(mq contracts.ManagedQueue, logger logr.Logger) { + processedCount.Add(1) + } + + // --- ACT --- + h.processor.processAllQueuesConcurrently("test-worker-pool", processFn) + + // --- ASSERT --- + assert.Equal(t, int32(numQueues), processedCount.Load(), + "The number of processed queues should match the number created") + }) + }) + }) +} + +func TestCheckItemExpiry(t *testing.T) { + t.Parallel() + + // --- ARRANGE --- + now := time.Now() + ctxCancelled, cancel := context.WithCancel(context.Background()) + cancel() // Cancel the context immediately. 
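+	// The pre-cancelled context above is reused by the "should be expired if context is cancelled" case below.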
+ + testCases := []struct { + name string + item types.QueueItemAccessor + now time.Time + expectExpired bool + expectOutcome types.QueueOutcome + expectErr error + }{ + { + name: "should not be expired if TTL is not reached and context is active", + item: NewItem( + typesmocks.NewMockFlowControlRequest(100, "req-not-expired", "", context.Background()), + testTTL, + now), + now: now.Add(30 * time.Second), + expectExpired: false, + expectOutcome: types.QueueOutcomeNotYetFinalized, + expectErr: nil, + }, + { + name: "should not be expired if TTL is disabled (0)", + item: NewItem( + typesmocks.NewMockFlowControlRequest(100, "req-not-expired-no-ttl", "", context.Background()), + 0, + now), + now: now.Add(30 * time.Second), + expectExpired: false, + expectOutcome: types.QueueOutcomeNotYetFinalized, + expectErr: nil, + }, + { + name: "should be expired if TTL is exceeded", + item: NewItem( + typesmocks.NewMockFlowControlRequest(100, "req-ttl-expired", "", context.Background()), + time.Second, + now), + now: now.Add(2 * time.Second), + expectExpired: true, + expectOutcome: types.QueueOutcomeEvictedTTL, + expectErr: types.ErrTTLExpired, + }, + { + name: "should be expired if context is cancelled", + item: NewItem( + typesmocks.NewMockFlowControlRequest(100, "req-ctx-cancelled", "", ctxCancelled), + testTTL, + now), + now: now, + expectExpired: true, + expectOutcome: types.QueueOutcomeEvictedContextCancelled, + expectErr: types.ErrContextCancelled, + }, + { + name: "should be expired if already finalized", + item: func() types.QueueItemAccessor { + i := NewItem(typesmocks.NewMockFlowControlRequest(100, "req-finalized", "", context.Background()), testTTL, now) + i.finalize(types.QueueOutcomeDispatched, nil) + return i + }(), + now: now, + expectExpired: true, + expectOutcome: types.QueueOutcomeDispatched, + expectErr: nil, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + // --- ACT --- + isExpired, outcome, err := checkItemExpiry(tc.item, tc.now) + + // --- ASSERT --- + assert.Equal(t, tc.expectExpired, isExpired, "Expired status should match expected value") + assert.Equal(t, tc.expectOutcome, outcome, "Outcome should match expected value") + + if tc.expectErr != nil { + require.Error(t, err, "An error was expected") + // Use ErrorIs for sentinel errors, ErrorContains for general messages. 
+ if errors.Is(tc.expectErr, types.ErrTTLExpired) || errors.Is(tc.expectErr, types.ErrContextCancelled) { + assert.ErrorIs(t, err, tc.expectErr, "The specific error type should be correct") + } else { + assert.ErrorContains(t, err, tc.expectErr.Error(), "The error message should contain the expected text") + } + } else { + assert.NoError(t, err, "No error was expected") + } + }) + } + + t.Run("should panic on item of an unexpected type", func(t *testing.T) { + t.Parallel() + // --- ARRANGE --- + badItem := &typesmocks.MockQueueItemAccessor{ + OriginalRequestV: typesmocks.NewMockFlowControlRequest(0, "item-bad-type", "", context.Background()), + } + + expectedPanicMsg := fmt.Sprintf("internal error: item %q of type %T is not a *flowItem", + badItem.OriginalRequestV.ID(), badItem) + + // --- ACT & ASSERT --- + assert.PanicsWithError(t, expectedPanicMsg, func() { + _, _, _ = checkItemExpiry(badItem, time.Now()) + }, "A type mismatch from a queue should cause a panic") + }) +} diff --git a/pkg/epp/flowcontrol/framework/plugins/policies/interflow/dispatch/besthead/besthead_test.go b/pkg/epp/flowcontrol/framework/plugins/policies/interflow/dispatch/besthead/besthead_test.go index eb3077b6e..4905a2157 100644 --- a/pkg/epp/flowcontrol/framework/plugins/policies/interflow/dispatch/besthead/besthead_test.go +++ b/pkg/epp/flowcontrol/framework/plugins/policies/interflow/dispatch/besthead/besthead_test.go @@ -48,19 +48,21 @@ func newTestComparator() *frameworkmocks.MockItemComparator { } } -func newTestBand(queues map[string]framework.FlowQueueAccessor) *frameworkmocks.MockPriorityBandAccessor { +func newTestBand(queues ...framework.FlowQueueAccessor) *frameworkmocks.MockPriorityBandAccessor { flowIDs := make([]string, 0, len(queues)) - for id := range queues { - flowIDs = append(flowIDs, id) + queuesByID := make(map[string]framework.FlowQueueAccessor, len(queues)) + for _, q := range queues { + flowIDs = append(flowIDs, q.FlowSpec().ID) + queuesByID[q.FlowSpec().ID] = q } return &frameworkmocks.MockPriorityBandAccessor{ FlowIDsFunc: func() []string { return flowIDs }, QueueFunc: func(id string) framework.FlowQueueAccessor { - return queues[id] + return queuesByID[id] }, IterateQueuesFunc: func(iterator func(queue framework.FlowQueueAccessor) bool) { for _, id := range flowIDs { - if !iterator(queues[id]) { + if !iterator(queuesByID[id]) { break } } @@ -111,90 +113,88 @@ func TestBestHead_SelectQueue(t *testing.T) { shouldPanic bool }{ { - name: "BasicSelection_TwoQueues", - band: newTestBand(map[string]framework.FlowQueueAccessor{ - flow1: queue1, - flow2: queue2, - }), + name: "BasicSelection_TwoQueues", + band: newTestBand(queue1, queue2), expectedQueueID: flow1, }, { - name: "IgnoresEmptyQueues", - band: newTestBand(map[string]framework.FlowQueueAccessor{ - flow1: queue1, - "flowEmpty": queueEmpty, - flow2: queue2, - }), + name: "IgnoresEmptyQueues", + band: newTestBand(queue1, queueEmpty, queue2), expectedQueueID: flow1, }, { name: "SingleNonEmptyQueue", - band: newTestBand(map[string]framework.FlowQueueAccessor{flow1: queue1}), + band: newTestBand(queue1), expectedQueueID: flow1, }, { name: "ComparatorCompatibility", - band: newTestBand(map[string]framework.FlowQueueAccessor{ - flow1: &frameworkmocks.MockFlowQueueAccessor{ + band: newTestBand( + &frameworkmocks.MockFlowQueueAccessor{ LenV: 1, PeekHeadV: itemBetter, FlowSpecV: types.FlowSpecification{ID: flow1}, ComparatorV: &frameworkmocks.MockItemComparator{ScoreTypeV: "typeA", FuncV: enqueueTimeComparatorFunc}, }, - flow2: 
&frameworkmocks.MockFlowQueueAccessor{ + &frameworkmocks.MockFlowQueueAccessor{ LenV: 1, PeekHeadV: itemWorse, FlowSpecV: types.FlowSpecification{ID: flow2}, ComparatorV: &frameworkmocks.MockItemComparator{ScoreTypeV: "typeB", FuncV: enqueueTimeComparatorFunc}, }, - }), + ), expectedErr: framework.ErrIncompatiblePriorityType, }, { name: "QueuePeekHeadErrors", - band: newTestBand(map[string]framework.FlowQueueAccessor{ - flow1: &frameworkmocks.MockFlowQueueAccessor{ + band: newTestBand( + &frameworkmocks.MockFlowQueueAccessor{ LenV: 1, PeekHeadErrV: errors.New("peek error"), FlowSpecV: types.FlowSpecification{ID: flow1}, ComparatorV: newTestComparator(), }, - flow2: queue2, - }), + queue2, + ), expectedQueueID: flow2, }, { name: "QueueComparatorIsNil", - band: newTestBand(map[string]framework.FlowQueueAccessor{ - flow1: &frameworkmocks.MockFlowQueueAccessor{ + band: newTestBand( + &frameworkmocks.MockFlowQueueAccessor{ LenV: 1, PeekHeadV: itemBetter, FlowSpecV: types.FlowSpecification{ID: flow1}, ComparatorV: nil, }, - flow2: queue2, - }), + queue2, + ), shouldPanic: true, }, { name: "ComparatorFuncIsNil", - band: newTestBand(map[string]framework.FlowQueueAccessor{ - flow1: &frameworkmocks.MockFlowQueueAccessor{ + band: newTestBand( + &frameworkmocks.MockFlowQueueAccessor{ LenV: 1, PeekHeadV: itemBetter, FlowSpecV: types.FlowSpecification{ID: flow1}, ComparatorV: &frameworkmocks.MockItemComparator{ScoreTypeV: commonScoreType, FuncV: nil}, }, - flow2: queue2, - }), + queue2, + ), shouldPanic: true, }, { name: "AllQueuesEmpty", - band: newTestBand(map[string]framework.FlowQueueAccessor{ - "empty1": queueEmpty, - "empty2": queueEmpty, - }), + band: newTestBand( + queueEmpty, + &frameworkmocks.MockFlowQueueAccessor{ + LenV: 0, + PeekHeadErrV: framework.ErrQueueEmpty, + FlowSpecV: types.FlowSpecification{ID: "flowEmpty2"}, + ComparatorV: newTestComparator(), + }, + ), }, { name: "NilBand",