Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 3 additions & 7 deletions pkg/epp/flowcontrol/registry/connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,14 +33,10 @@ var _ contracts.ActiveFlowConnection = &connection{}

// Shards returns a stable snapshot of accessors for all internal state shards.
func (c *connection) ActiveShards() []contracts.RegistryShard {
c.registry.mu.RLock()
defer c.registry.mu.RUnlock()

// Return a copy to ensure the caller cannot modify the registry's internal slice.
shardsCopy := make([]contracts.RegistryShard, len(c.registry.activeShards))
for i, s := range c.registry.activeShards {
shardsCopy[i] = s
}
shardsCopy := make([]contracts.RegistryShard, 1)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are we going to remove the concept of a shard in a later PR?

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, there will be no "data-parallel" concept anymore. See the attached issue for more context.

For these PRs though, as long as the FC layer has 0 regressions between revisions, I want to get these in even if there are some minor stylistic/semantic improvements that could be made. This code should look significantly different after these are all in, so I think it is most expedient to focus on polish at the end of the refactoring effort rather than at each intermediary step.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's totally fine with me, just making sure we are all pointed in the same direction

shardsCopy[0] = c.registry.shard

return shardsCopy
}

Expand Down
169 changes: 25 additions & 144 deletions pkg/epp/flowcontrol/registry/registry.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,8 @@ limitations under the License.
package registry

import (
"cmp"
"context"
"fmt"
"slices"
"sync"
"sync/atomic"

Expand Down Expand Up @@ -105,11 +103,8 @@ type FlowRegistry struct {

// --- Administrative state (protected by `mu`) ---

mu sync.RWMutex
activeShards []*registryShard
drainingShards map[string]*registryShard
allShards []*registryShard // Cached, sorted combination of Active and Draining shards
nextShardID uint64
mu sync.RWMutex
shard *registryShard
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: This is currently read in ActiveShards() and ShardStats() without acquiring fr.mu.RLock(). With dynamic sharding removed, fr.shard is initialized once and never mutated, making these lock-free reads completely safe from data races.

Could we move the shard *registryShard field of the "Administrative state (protected by mu)" block and up to the "Immutable dependencies" block?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I will do this in the next PR

}

var _ contracts.FlowRegistry = &FlowRegistry{}
Expand All @@ -131,10 +126,8 @@ func withClock(clk clock.WithTickerAndDelayedExecution) RegistryOption {
func NewFlowRegistry(config *Config, logger logr.Logger, opts ...RegistryOption) (*FlowRegistry, error) {
cfg := config.Clone()
fr := &FlowRegistry{
config: cfg,
logger: logger.WithName("flow-registry"),
activeShards: []*registryShard{},
drainingShards: make(map[string]*registryShard),
config: cfg,
logger: logger.WithName("flow-registry"),
}

for _, opt := range opts {
Expand All @@ -148,8 +141,8 @@ func NewFlowRegistry(config *Config, logger logr.Logger, opts ...RegistryOption)
fr.perPriorityBandStats.Store(prio, &bandStats{})
}

if err := fr.updateShardCount(cfg.InitialShardCount); err != nil {
return nil, fmt.Errorf("failed to initialize shards: %w", err)
if err := fr.createShard(); err != nil {
return nil, fmt.Errorf("failed to initialize shard: %w", err)
}
fr.logger.V(logging.DEFAULT).Info("FlowRegistry initialized successfully")
return fr, nil
Expand Down Expand Up @@ -261,14 +254,12 @@ func (fr *FlowRegistry) ensureFlowInfrastructure(key flowcontrol.FlowKey) error
fr.mu.RLock()
defer fr.mu.RUnlock()

components, err := fr.buildFlowComponents(key, len(fr.allShards))
components, err := fr.buildFlowComponents(key, 1)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: Consider updating buildFlowComponents to drop the numInstances arg and just return a single (flowComponents, error) tuple. Fine if we want to defer this to a different PR though.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I will do this in the next PR

if err != nil {
return err
}

for i, shard := range fr.allShards {
shard.synchronizeFlow(key, components[i].policy, components[i].queue)
}
fr.shard.synchronizeFlow(key, components[0].policy, components[0].queue)

fr.logger.V(logging.DEBUG).Info("JIT provisioned flow infrastructure", "flowKey", key)
return nil
Expand Down Expand Up @@ -298,9 +289,7 @@ func (fr *FlowRegistry) ensurePriorityBand(priority int) error {

fr.repartitionShardConfigsLocked()

for _, shard := range fr.activeShards {
shard.addPriorityBand(priority)
}
fr.shard.addPriorityBand(priority)

return nil
}
Expand Down Expand Up @@ -344,14 +333,8 @@ func (fr *FlowRegistry) Stats() contracts.AggregateStats {

// ShardStats returns a slice of statistics, one for each internal shard.
func (fr *FlowRegistry) ShardStats() []contracts.ShardStats {
fr.mu.RLock()
allShards := fr.allShards
fr.mu.RUnlock()

shardStats := make([]contracts.ShardStats, len(allShards))
for i, s := range allShards {
shardStats[i] = s.Stats()
}
shardStats := make([]contracts.ShardStats, 1)
shardStats[0] = fr.shard.Stats()
return shardStats
}

Expand All @@ -362,7 +345,6 @@ func (fr *FlowRegistry) executeGCCycle() {
fr.logger.V(logging.DEBUG).Info("Starting periodic GC scan")
fr.gcFlows()
fr.gcPriorityBands()
fr.sweepDrainingShards()
}

// gcFlows removes idle flows.
Expand Down Expand Up @@ -398,9 +380,7 @@ func (fr *FlowRegistry) cleanupFlowResources(keys []flowcontrol.FlowKey) {
if _, exists := fr.flowStates.Load(key); exists {
continue // 'Zombie' flow
}
for _, shard := range fr.allShards {
shard.deleteFlow(key)
}
fr.shard.deleteFlow(key)
}
}

Expand Down Expand Up @@ -440,78 +420,24 @@ func (fr *FlowRegistry) cleanupPriorityBandResources(priorities []int) {
// Delete from stats tracking
fr.perPriorityBandStats.Delete(priority)

// Delete from all shards (both active and draining)
for _, shard := range fr.allShards {
shard.deletePriorityBand(priority)
}
// Delete from the shard
fr.shard.deletePriorityBand(priority)

fr.logger.Info("Successfully deleted priority band", "priority", priority)
}
}

// sweepDrainingShards finalizes the removal of drained shards.
func (fr *FlowRegistry) sweepDrainingShards() {
// Acquire a full write lock on the registry as we may be modifying the shard topology.
fr.mu.Lock()
defer fr.mu.Unlock()

var shardsToDelete []string
for id, shard := range fr.drainingShards {
// A Draining shard is ready for GC once it is completely empty.
// Draining shards do not accept new work (enforced at `managedQueue.Add`), so `shard.totalLen.Load()` can only
// monotonically decrease.
if shard.totalLen.Load() == 0 {
shardsToDelete = append(shardsToDelete, id)
}
}

if len(shardsToDelete) > 0 {
fr.logger.V(logging.DEBUG).Info("Garbage collecting drained shards", "shardIDs", shardsToDelete)
for _, id := range shardsToDelete {
delete(fr.drainingShards, id)
}
fr.updateAllShardsCacheLocked()
}
}

// --- Shard Management (Scaling) ---

// updateShardCount dynamically adjusts the number of internal state shards.
func (fr *FlowRegistry) updateShardCount(n int) error {
if n <= 0 {
return fmt.Errorf("%w: shard count must be a positive integer, but got %d", contracts.ErrInvalidShardCount, n)
}

// createShard creates the shard.
func (fr *FlowRegistry) createShard() error {
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The entire block of code in here that iterates over fr.flowStates.Range(...) to build allComponents and synchronizeFlow is actually dead code.

Because createShard() is exclusively called during NewFlowRegistry() before the EPP accepts any connections, fr.flowStates is guaranteed to be empty.

You can simplify this initialization method to just:

func (fr *FlowRegistry) createShard() error {
	fr.mu.Lock()
	defer fr.mu.Unlock()
	partitionedConfig := fr.config.partition(0, 1)
	fr.shard = newShard("shard-0", partitionedConfig, fr.logger, fr.propagateStatsDelta)
	return nil
}

// Use a full write lock as this is a major structural change to the shard topology.
fr.mu.Lock()
defer fr.mu.Unlock()

currentActiveShards := len(fr.activeShards)
if n == currentActiveShards {
return nil
}

if n > currentActiveShards {
return fr.executeScaleUpLocked(n)
}
fr.executeScaleDownLocked(n)
return nil
}

// executeScaleUpLocked handles adding new shards.
// It pre-provisions all existing active flows onto the new shards to ensure continuity.
func (fr *FlowRegistry) executeScaleUpLocked(newTotalActive int) error {
currentActive := len(fr.activeShards)
numToAdd := newTotalActive - currentActive
fr.logger.Info("Scaling up shards", "currentActive", currentActive, "newTotalActive", newTotalActive)

// Prepare All New Shard Objects (Infallible):
newShards := make([]*registryShard, numToAdd)
for i := range numToAdd {
shardID := fmt.Sprintf("shard-%04d", fr.nextShardID+uint64(i))
partitionedConfig := fr.config.partition(currentActive+i, newTotalActive)
newShards[i] = newShard(shardID, partitionedConfig, fr.logger, fr.propagateStatsDelta)
}
// Prepare Shard Object (Infallible)
partitionedConfig := fr.config.partition(0, 1)
shard := newShard("shard-0", partitionedConfig, fr.logger, fr.propagateStatsDelta)

// Prepare All Components for All New Shards (Fallible):
// Pre-build every component for every existing flow on every new shard.
Expand All @@ -521,7 +447,7 @@ func (fr *FlowRegistry) executeScaleUpLocked(newTotalActive int) error {
var rangeErr error
fr.flowStates.Range(func(key, _ interface{}) bool {
flowKey := key.(flowcontrol.FlowKey)
components, err := fr.buildFlowComponents(flowKey, len(newShards))
components, err := fr.buildFlowComponents(flowKey, 1)
if err != nil {
rangeErr = fmt.Errorf("failed to prepare components for flow %s on new shards: %w", flowKey, err)
return false
Expand All @@ -534,43 +460,19 @@ func (fr *FlowRegistry) executeScaleUpLocked(newTotalActive int) error {
}

// Commit (Infallible):
for i, shard := range newShards {
for key, components := range allComponents {
shard.synchronizeFlow(key, components[i].policy, components[i].queue)
}
for key, components := range allComponents {
shard.synchronizeFlow(key, components[0].policy, components[0].queue)
}
fr.activeShards = append(fr.activeShards, newShards...)
fr.nextShardID += uint64(numToAdd)
fr.shard = shard
fr.repartitionShardConfigsLocked()
fr.updateAllShardsCacheLocked()
return nil
}

// executeScaleDownLocked handles marking shards for graceful draining.
// Expects the registry's write lock to be held.
func (fr *FlowRegistry) executeScaleDownLocked(newTotalActive int) {
currentActive := len(fr.activeShards)
fr.logger.Info("Scaling down shards", "currentActive", currentActive, "newTotalActive", newTotalActive)

shardsToDrain := fr.activeShards[newTotalActive:]
fr.activeShards = fr.activeShards[:newTotalActive]
for _, shard := range shardsToDrain {
fr.drainingShards[shard.id] = shard
shard.markAsDraining()
}

fr.repartitionShardConfigsLocked()
fr.updateAllShardsCacheLocked()
}

// repartitionShardConfigsLocked updates the configuration for all active shards.
// Expects the registry's write lock to be held.
func (fr *FlowRegistry) repartitionShardConfigsLocked() {
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note for other reviewers... This looks weird considering a single-shard view, but we must preserve it for now. When ensurePriorityBand dynamically creates a new band, this partition path acts as a deep-copy mechanism to push the mutated registry config down to the isolated shard state.

In a follow-up PR (if/when we eliminate the boundary between registryShard and FlowRegistry entirely), we can have everything reference a single unified Config, allowing us to drop this path entirely.

numActive := len(fr.activeShards)
for i, shard := range fr.activeShards {
newPartitionedConfig := fr.config.partition(i, numActive)
shard.updateConfig(newPartitionedConfig)
}
newPartitionedConfig := fr.config.partition(0, 1)
fr.shard.updateConfig(newPartitionedConfig)
}

// --- Internal Helpers ---
Expand Down Expand Up @@ -601,27 +503,6 @@ func (fr *FlowRegistry) buildFlowComponents(key flowcontrol.FlowKey, numInstance
return allComponents, nil
}

// updateAllShardsCacheLocked recalculates the cached `allShards` slice.
// It ensures the slice is sorted by shard ID to maintain a deterministic order.
// Expects the registry's write lock to be held.
func (fr *FlowRegistry) updateAllShardsCacheLocked() {
allShards := make([]*registryShard, 0, len(fr.activeShards)+len(fr.drainingShards))
allShards = append(allShards, fr.activeShards...)
for _, shard := range fr.drainingShards {
allShards = append(allShards, shard)
}

// Sort the combined slice by shard ID.
// This provides a stable, deterministic order for all consumers of the shard list, which is critical because map
// iteration for `drainingShards` is non-deterministic.
// While this is a lexicographical sort, our shard ID format is padded with leading zeros (e.g., "shard-0001"),
// ensuring that the string sort produces the same result as a natural numerical sort.
slices.SortFunc(allShards, func(a, b *registryShard) int {
return cmp.Compare(a.id, b.id)
})
fr.allShards = allShards
}

// propagateStatsDelta is the top-level, lock-free aggregator for all statistics.
func (fr *FlowRegistry) propagateStatsDelta(priority int, lenDelta, byteSizeDelta int64) {
val, _ := fr.perPriorityBandStats.Load(priority)
Expand Down
Loading
Loading