Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 70 additions & 0 deletions common/rpc/interceptor/business_id_extractor.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@ import (
"strings"

commonpb "go.temporal.io/api/common/v1"
deploymentpb "go.temporal.io/api/deployment/v1"
taskqueuepb "go.temporal.io/api/taskqueue/v1"
updatepb "go.temporal.io/api/update/v1"
"go.temporal.io/api/workflowservice/v1"
"go.temporal.io/server/common/api"
"go.temporal.io/server/common/namespace"
Expand Down Expand Up @@ -57,6 +60,34 @@ type (
taskTokenGetter interface {
GetTaskToken() []byte
}

taskQueueNameGetter interface {
GetTaskQueue() string
}

taskQueueNameFromMessageGetter interface {
GetTaskQueue() *taskqueuepb.TaskQueue
}

deploymentNameGetter interface {
GetDeploymentName() string
}

deploymentVersionGetter interface {
GetDeploymentVersion() *deploymentpb.WorkerDeploymentVersion
}

pollerGroupIDGetter interface {
GetPollerGroupId() string
}

namespaceGetter interface {
GetNamespace() string
}

updateRefGetter interface {
GetUpdateRef() *updatepb.UpdateRef
}
)

// Extract extracts business ID from the request using the specified pattern.
Expand Down Expand Up @@ -99,6 +130,45 @@ func (e BusinessIDExtractor) Extract(req any, pattern BusinessIDPattern) string
case PatternMultiOperation:
return e.extractMultiOperation(req)

case PatternTaskQueueName:
if getter, ok := req.(taskQueueNameGetter); ok {
return getter.GetTaskQueue()
}

case PatternTaskQueueNameFromMessage:
if getter, ok := req.(taskQueueNameFromMessageGetter); ok {
if tq := getter.GetTaskQueue(); tq != nil {
return tq.GetName()
}
}

case PatternDeploymentName:
if getter, ok := req.(deploymentNameGetter); ok {
return getter.GetDeploymentName()
}

case PatternDeploymentVersion:
if getter, ok := req.(deploymentVersionGetter); ok {
if dv := getter.GetDeploymentVersion(); dv != nil {
return dv.GetDeploymentName()
}
}

case PatternPollerGroupID:
if getter, ok := req.(pollerGroupIDGetter); ok {
return getter.GetPollerGroupId()
}

case PatternNamespace:
if getter, ok := req.(namespaceGetter); ok {
return getter.GetNamespace()
}

case PatternUpdateRef:
if getter, ok := req.(updateRefGetter); ok {
return getter.GetUpdateRef().GetWorkflowExecution().GetWorkflowId()
}

case PatternNone:
// No extraction needed

Expand Down
48 changes: 48 additions & 0 deletions common/rpc/interceptor/business_id_interceptor.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,20 @@ const (
PatternTaskToken
// PatternMultiOperation indicates extraction from ExecuteMultiOperationRequest
PatternMultiOperation
// PatternTaskQueueName indicates extraction via GetTaskQueue() string method
PatternTaskQueueName
// PatternTaskQueueNameFromMessage indicates extraction via GetTaskQueue().GetName() (TaskQueue message)
PatternTaskQueueNameFromMessage
// PatternDeploymentName indicates extraction via GetDeploymentName() method
PatternDeploymentName
// PatternDeploymentVersion indicates extraction via GetDeploymentVersion().GetDeploymentName()
PatternDeploymentVersion
// PatternPollerGroupID indicates extraction via GetPollerGroupId() directly
PatternPollerGroupID
// PatternNamespace indicates extraction via GetNamespace() - used when any cell for the namespace is acceptable
PatternNamespace
// PatternUpdateRef indicates extraction via GetUpdateRef().GetWorkflowExecution().GetWorkflowId()
PatternUpdateRef
)

// methodToPattern maps API method names to their expected business ID extraction pattern.
Expand Down Expand Up @@ -88,6 +102,40 @@ var methodToPattern = map[string]BusinessIDPattern{

// Pattern: ExecuteMultiOperation special handling
"ExecuteMultiOperation": PatternMultiOperation,

// task queue name
"UpdateTaskQueueConfig": PatternTaskQueueName,

// task queue name (from TaskQueue message)
"ListTaskQueuePartitions": PatternTaskQueueNameFromMessage,

// deployment name
"DescribeWorkerDeployment": PatternDeploymentName,
"DeleteWorkerDeployment": PatternDeploymentName,
"SetWorkerDeploymentCurrentVersion": PatternDeploymentName,
"SetWorkerDeploymentManager": PatternDeploymentName,
"SetWorkerDeploymentRampingVersion": PatternDeploymentName,

// deployment name (from WorkerDeploymentVersion message)
"DescribeWorkerDeploymentVersion": PatternDeploymentVersion,
"DeleteWorkerDeploymentVersion": PatternDeploymentVersion,
"UpdateWorkerDeploymentVersionMetadata": PatternDeploymentVersion,

// namespace (deterministic routing to any cell for the namespace)
// TODO(mcn): Switch to worker_grouping_key when available for load balancing
"FetchWorkerConfig": PatternNamespace,
"UpdateWorkerConfig": PatternNamespace,
"DescribeWorker": PatternNamespace,
"RecordWorkerHeartbeat": PatternNamespace,

// workflow ID (from UpdateRef)
"PollWorkflowExecutionUpdate": PatternUpdateRef,

// TODO(mcn): Uncomment when poller_group_id field is added to requests
// "PollWorkflowTaskQueue": PatternPollerGroupID,
// "PollActivityTaskQueue": PatternPollerGroupID,
// "PollNexusTaskQueue": PatternPollerGroupID,

}

// NewBusinessIDInterceptor creates a new BusinessIDInterceptor with the given extractor functions.
Expand Down
118 changes: 118 additions & 0 deletions common/rpc/interceptor/business_id_interceptor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ import (

"github.com/stretchr/testify/require"
commonpb "go.temporal.io/api/common/v1"
deploymentpb "go.temporal.io/api/deployment/v1"
taskqueuepb "go.temporal.io/api/taskqueue/v1"
updatepb "go.temporal.io/api/update/v1"
"go.temporal.io/api/workflowservice/v1"
tokenspb "go.temporal.io/server/api/token/v1"
Expand Down Expand Up @@ -215,6 +217,96 @@ func TestBusinessIDInterceptor_AllMethods(t *testing.T) {
},
expectedBusinessID: "wf-id",
},

// task queue name
{
methodName: "UpdateTaskQueueConfig",
request: &workflowservice.UpdateTaskQueueConfigRequest{TaskQueue: "test-task-queue"},
expectedBusinessID: "test-task-queue",
},

// task queue name (from TaskQueue message)
{
methodName: "ListTaskQueuePartitions",
request: &workflowservice.ListTaskQueuePartitionsRequest{TaskQueue: &taskqueuepb.TaskQueue{Name: "test-task-queue"}},
expectedBusinessID: "test-task-queue",
},

// deployment name
{
methodName: "DescribeWorkerDeployment",
request: &workflowservice.DescribeWorkerDeploymentRequest{DeploymentName: "test-deployment"},
expectedBusinessID: "test-deployment",
},
{
methodName: "DeleteWorkerDeployment",
request: &workflowservice.DeleteWorkerDeploymentRequest{DeploymentName: "test-deployment"},
expectedBusinessID: "test-deployment",
},
{
methodName: "SetWorkerDeploymentCurrentVersion",
request: &workflowservice.SetWorkerDeploymentCurrentVersionRequest{DeploymentName: "test-deployment"},
expectedBusinessID: "test-deployment",
},
{
methodName: "SetWorkerDeploymentManager",
request: &workflowservice.SetWorkerDeploymentManagerRequest{DeploymentName: "test-deployment"},
expectedBusinessID: "test-deployment",
},
{
methodName: "SetWorkerDeploymentRampingVersion",
request: &workflowservice.SetWorkerDeploymentRampingVersionRequest{DeploymentName: "test-deployment"},
expectedBusinessID: "test-deployment",
},

// deployment name (from WorkerDeploymentVersion message)
{
methodName: "DescribeWorkerDeploymentVersion",
request: &workflowservice.DescribeWorkerDeploymentVersionRequest{DeploymentVersion: &deploymentpb.WorkerDeploymentVersion{DeploymentName: "test-deployment"}},
expectedBusinessID: "test-deployment",
},
{
methodName: "DeleteWorkerDeploymentVersion",
request: &workflowservice.DeleteWorkerDeploymentVersionRequest{DeploymentVersion: &deploymentpb.WorkerDeploymentVersion{DeploymentName: "test-deployment"}},
expectedBusinessID: "test-deployment",
},
{
methodName: "UpdateWorkerDeploymentVersionMetadata",
request: &workflowservice.UpdateWorkerDeploymentVersionMetadataRequest{DeploymentVersion: &deploymentpb.WorkerDeploymentVersion{DeploymentName: "test-deployment"}},
expectedBusinessID: "test-deployment",
},

// namespace
{
methodName: "FetchWorkerConfig",
request: &workflowservice.FetchWorkerConfigRequest{Namespace: "test-namespace"},
expectedBusinessID: "test-namespace",
},
{
methodName: "UpdateWorkerConfig",
request: &workflowservice.UpdateWorkerConfigRequest{Namespace: "test-namespace"},
expectedBusinessID: "test-namespace",
},
{
methodName: "DescribeWorker",
request: &workflowservice.DescribeWorkerRequest{Namespace: "test-namespace"},
expectedBusinessID: "test-namespace",
},
{
methodName: "RecordWorkerHeartbeat",
request: &workflowservice.RecordWorkerHeartbeatRequest{Namespace: "test-namespace"},
expectedBusinessID: "test-namespace",
},
// workflow ID (from UpdateRef)
{
methodName: "PollWorkflowExecutionUpdate",
request: &workflowservice.PollWorkflowExecutionUpdateRequest{
UpdateRef: &updatepb.UpdateRef{
WorkflowExecution: &commonpb.WorkflowExecution{WorkflowId: "test-workflow-id"},
},
},
expectedBusinessID: "test-workflow-id",
},
}

for _, tc := range testCases {
Expand Down Expand Up @@ -567,6 +659,32 @@ func TestMethodToPatternMapping(t *testing.T) {

// PatternMultiOperation
"ExecuteMultiOperation": PatternMultiOperation,

// PatternTaskQueueName
"UpdateTaskQueueConfig": PatternTaskQueueName,

// PatternTaskQueueNameFromMessage
"ListTaskQueuePartitions": PatternTaskQueueNameFromMessage,

// PatternDeploymentName
"DescribeWorkerDeployment": PatternDeploymentName,
"DeleteWorkerDeployment": PatternDeploymentName,
"SetWorkerDeploymentCurrentVersion": PatternDeploymentName,
"SetWorkerDeploymentManager": PatternDeploymentName,
"SetWorkerDeploymentRampingVersion": PatternDeploymentName,

// PatternDeploymentVersion
"DescribeWorkerDeploymentVersion": PatternDeploymentVersion,
"DeleteWorkerDeploymentVersion": PatternDeploymentVersion,
"UpdateWorkerDeploymentVersionMetadata": PatternDeploymentVersion,

// PatternNamespace
"FetchWorkerConfig": PatternNamespace,
"UpdateWorkerConfig": PatternNamespace,
"DescribeWorker": PatternNamespace,
"RecordWorkerHeartbeat": PatternNamespace,

"PollWorkflowExecutionUpdate": PatternUpdateRef,
}

require.Equal(t, expectedMappings, methodToPattern)
Expand Down
52 changes: 52 additions & 0 deletions common/taskqueue/stats.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
package taskqueue

import (
taskqueuepb "go.temporal.io/api/taskqueue/v1"
"google.golang.org/protobuf/types/known/durationpb"
)

// MergeStats merges from into into. Mutates into.
func MergeStats(into, from *taskqueuepb.TaskQueueStats) {
if from == nil {
return
}
into.ApproximateBacklogCount += from.ApproximateBacklogCount
into.ApproximateBacklogAge = oldestBacklogAge(into.ApproximateBacklogAge, from.ApproximateBacklogAge)
into.TasksAddRate += from.TasksAddRate
into.TasksDispatchRate += from.TasksDispatchRate
}

// AggregateStats merges all stats from the map into a single TaskQueueStats.
func AggregateStats(stats map[int32]*taskqueuepb.TaskQueueStats) *taskqueuepb.TaskQueueStats {
result := &taskqueuepb.TaskQueueStats{ApproximateBacklogAge: durationpb.New(0)}
for _, s := range stats {
MergeStats(result, s)
}
return result
}

// DedupPollers removes duplicate pollers by identity.
func DedupPollers(pollerInfos []*taskqueuepb.PollerInfo) []*taskqueuepb.PollerInfo {
allKeys := make(map[string]bool)
var list []*taskqueuepb.PollerInfo
for _, item := range pollerInfos {
if _, value := allKeys[item.GetIdentity()]; !value {
allKeys[item.GetIdentity()] = true
list = append(list, item)
}
}
return list
}

func oldestBacklogAge(left, right *durationpb.Duration) *durationpb.Duration {
if left == nil {
left = durationpb.New(0)
}
if right == nil {
right = durationpb.New(0)
}
if left.AsDuration() > right.AsDuration() {
return left
}
return right
}
52 changes: 52 additions & 0 deletions common/taskqueue/stats_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
package taskqueue

import (
"testing"
"time"

"github.com/stretchr/testify/require"
taskqueuepb "go.temporal.io/api/taskqueue/v1"
"google.golang.org/protobuf/types/known/durationpb"
)

func TestMergeStats(t *testing.T) {
into := &taskqueuepb.TaskQueueStats{
ApproximateBacklogCount: 10,
ApproximateBacklogAge: durationpb.New(100 * time.Second),
TasksAddRate: 5,
TasksDispatchRate: 3,
}
from := &taskqueuepb.TaskQueueStats{
ApproximateBacklogCount: 20,
ApproximateBacklogAge: durationpb.New(50 * time.Second),
TasksAddRate: 2,
TasksDispatchRate: 1,
}

MergeStats(into, from)

require.Equal(t, int64(30), into.ApproximateBacklogCount)
require.Equal(t, 100*time.Second, into.ApproximateBacklogAge.AsDuration())
require.InDelta(t, 7, into.TasksAddRate, 1e-9)
require.InDelta(t, 4, into.TasksDispatchRate, 1e-9)
}

func TestDedupPollers(t *testing.T) {
pollers := []*taskqueuepb.PollerInfo{
{Identity: "worker-1"},
{Identity: "worker-2"},
{Identity: "worker-1"},
{Identity: "worker-3"},
}

result := DedupPollers(pollers)

require.Len(t, result, 3)
idents := make(map[string]bool)
for _, p := range result {
idents[p.GetIdentity()] = true
}
require.True(t, idents["worker-1"])
require.True(t, idents["worker-2"])
require.True(t, idents["worker-3"])
}
Loading
Loading