Skip to content

chore(go): deprecate dev action wrapper #3183

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 11 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions go/ai/tools.go
Original file line number Diff line number Diff line change
Expand Up @@ -203,13 +203,13 @@ func runAction(ctx context.Context, def *ToolDefinition, action core.Action, inp
if err != nil {
return nil, fmt.Errorf("error marshalling tool input for %v: %v", def.Name, err)
}
output, err := action.RunJSON(ctx, mi, nil)
result, err := action.RunJSON(ctx, mi, nil, nil)
if err != nil {
return nil, fmt.Errorf("error calling tool %v: %w", def.Name, err)
}

var uo any
err = json.Unmarshal(output, &uo)
err = json.Unmarshal(result.Result, &uo)
if err != nil {
return nil, fmt.Errorf("error parsing tool output for %v: %v", def.Name, err)
}
Expand Down
106 changes: 93 additions & 13 deletions go/core/action.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ import (
"github.com/firebase/genkit/go/internal/metrics"
"github.com/firebase/genkit/go/internal/registry"
"github.com/invopop/jsonschema"
"go.opentelemetry.io/otel/trace"
)

// Func is an alias for non-streaming functions with input of type In and output of type Out.
Expand All @@ -40,18 +41,65 @@ type StreamingFunc[In, Out, Stream any] = func(context.Context, In, StreamCallba
// StreamCallback is a function that is called during streaming to return the next chunk of the stream.
type StreamCallback[Stream any] = func(context.Context, Stream) error

// ActionResult contains the result of an action along with telemetry information.
type ActionResult[Out any] struct {
Result Out `json:"result"`
Telemetry Telemetry `json:"telemetry"`
}

// Telemetry contains tracing information for an action execution.
type Telemetry struct {
TraceID string `json:"traceId"`
SpanID string `json:"spanId"`
}

// Action is the interface that all Genkit primitives (e.g. flows, models, tools) have in common.
type Action interface {
// Name returns the name of the action.
Name() string
// RunJSON runs the action with the given JSON input and streaming callback and returns the output as JSON.
RunJSON(ctx context.Context, input json.RawMessage, cb func(context.Context, json.RawMessage) error) (json.RawMessage, error)
// RunJSON runs the action with the given JSON input, streaming callback, and telemetry labels, returning both result and telemetry information.
RunJSON(ctx context.Context, input json.RawMessage, cb func(context.Context, json.RawMessage) error, telemetryLabels map[string]string) (ActionResult[json.RawMessage], error)
// Desc returns a descriptor of the action.
Desc() ActionDesc
// SetTracingState sets the tracing state on the action.
SetTracingState(tstate *tracing.State)
}

// Context helpers for telemetry capture and labels
type telemetryCollector struct {
traceID string
spanID string
}

type telemetryCollectorKey struct{}
type telemetryLabelsKey struct{}

func withTelemetryCapture(ctx context.Context) (context.Context, *telemetryCollector) {
collector := &telemetryCollector{}
return context.WithValue(ctx, telemetryCollectorKey{}, collector), collector
}

func getTelemetryFromContext(ctx context.Context) *telemetryCollector {
if collector, ok := ctx.Value(telemetryCollectorKey{}).(*telemetryCollector); ok {
return collector
}
return nil
}

func withTelemetryLabels(ctx context.Context, labels map[string]string) context.Context {
if labels == nil {
return ctx
}
return context.WithValue(ctx, telemetryLabelsKey{}, labels)
}

func getTelemetryLabelsFromContext(ctx context.Context) map[string]string {
if labels, ok := ctx.Value(telemetryLabelsKey{}).(map[string]string); ok {
return labels
}
return nil
}

// An ActionType is the kind of an action.
type ActionType string

Expand Down Expand Up @@ -97,8 +145,7 @@ type noStream = func(context.Context, struct{}) error
// DefineAction creates a new non-streaming Action and registers it.
func DefineAction[In, Out any](
r *registry.Registry,
provider,
name string,
provider, name string,
atype ActionType,
metadata map[string]any,
fn Func[In, Out],
Expand Down Expand Up @@ -250,6 +297,22 @@ func (a *ActionDef[In, Out, Stream]) Run(ctx context.Context, input In, cb Strea

return tracing.RunInNewSpan(ctx, a.tstate, a.desc.Name, "action", false, input,
func(ctx context.Context, input In) (Out, error) {
// Apply telemetry labels if present
if labels := getTelemetryLabelsFromContext(ctx); labels != nil {
for k, v := range labels {
tracing.SetCustomMetadataAttr(ctx, k, v)
}
}

// Capture telemetry if collector present
if collector := getTelemetryFromContext(ctx); collector != nil {
span := trace.SpanFromContext(ctx)
if span.SpanContext().IsValid() {
collector.traceID = span.SpanContext().TraceID().String()
collector.spanID = span.SpanContext().SpanID().String()
}
}

start := time.Now()
var err error
if err = base.ValidateValue(input, a.desc.InputSchema); err != nil {
Expand All @@ -275,16 +338,16 @@ func (a *ActionDef[In, Out, Stream]) Run(ctx context.Context, input In, cb Strea
})
}

// RunJSON runs the action with a JSON input, and returns a JSON result.
func (a *ActionDef[In, Out, Stream]) RunJSON(ctx context.Context, input json.RawMessage, cb StreamCallback[json.RawMessage]) (json.RawMessage, error) {
// RunJSON runs the action with a JSON input, and returns a JSON result with telemetry information.
func (a *ActionDef[In, Out, Stream]) RunJSON(ctx context.Context, input json.RawMessage, cb func(context.Context, json.RawMessage) error, telemetryLabels map[string]string) (ActionResult[json.RawMessage], error) {
// Validate input before unmarshaling it because invalid or unknown fields will be discarded in the process.
if err := base.ValidateJSON(input, a.desc.InputSchema); err != nil {
return nil, NewError(INVALID_ARGUMENT, err.Error())
return ActionResult[json.RawMessage]{}, NewError(INVALID_ARGUMENT, err.Error())
}
var in In
if input != nil {
if err := json.Unmarshal(input, &in); err != nil {
return nil, err
return ActionResult[json.RawMessage]{}, err
}
}
var callback func(context.Context, Stream) error
Expand All @@ -297,15 +360,32 @@ func (a *ActionDef[In, Out, Stream]) RunJSON(ctx context.Context, input json.Raw
return cb(ctx, json.RawMessage(bytes))
}
}
out, err := a.Run(ctx, in, callback)

// Set up telemetry capture and labels
ctx, collector := withTelemetryCapture(ctx)
ctx = withTelemetryLabels(ctx, telemetryLabels)

output, err := a.Run(ctx, in, callback)
if err != nil {
return nil, err
return ActionResult[json.RawMessage]{}, err
}
bytes, err := json.Marshal(out)

// Get captured telemetry
telemetry := Telemetry{
TraceID: collector.traceID,
SpanID: collector.spanID,
}

// Marshal output and return
bytes, err := json.Marshal(output)
if err != nil {
return nil, err
return ActionResult[json.RawMessage]{}, err
}
return json.RawMessage(bytes), nil

return ActionResult[json.RawMessage]{
Result: json.RawMessage(bytes),
Telemetry: telemetry,
}, nil
}

// Desc returns a descriptor of the action.
Expand Down
3 changes: 2 additions & 1 deletion go/core/action_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,8 @@ func TestActionRunJSON(t *testing.T) {
a := defineAction(r, "test", "inc", ActionTypeCustom, nil, nil, inc)
input := []byte("3")
want := []byte("4")
got, err := a.RunJSON(context.Background(), input, nil)
output, err := a.RunJSON(context.Background(), input, nil, nil)
got := output.Result
if err != nil {
t.Fatal(err)
}
Expand Down
4 changes: 2 additions & 2 deletions go/core/flow.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,8 +101,8 @@ func (f *Flow[In, Out, Stream]) Name() string {
}

// RunJSON runs the flow with JSON input and streaming callback and returns the output as JSON.
func (f *Flow[In, Out, Stream]) RunJSON(ctx context.Context, input json.RawMessage, cb StreamCallback[json.RawMessage]) (json.RawMessage, error) {
return (*ActionDef[In, Out, Stream])(f).RunJSON(ctx, input, cb)
func (f *Flow[In, Out, Stream]) RunJSON(ctx context.Context, input json.RawMessage, cb StreamCallback[json.RawMessage], telemetryLabels map[string]string) (ActionResult[json.RawMessage], error) {
return (*ActionDef[In, Out, Stream])(f).RunJSON(ctx, input, cb, telemetryLabels)
}

// Desc returns the descriptor of the flow.
Expand Down
38 changes: 19 additions & 19 deletions go/genkit/reflection.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@ import (
"github.com/firebase/genkit/go/core/logger"
"github.com/firebase/genkit/go/core/tracing"
"github.com/firebase/genkit/go/internal"
"go.opentelemetry.io/otel/trace"
)

type streamingCallback[Stream any] = func(context.Context, Stream) error
Expand Down Expand Up @@ -436,30 +435,31 @@ func runAction(ctx context.Context, g *Genkit, key string, input json.RawMessage
ctx = core.WithActionContext(ctx, runtimeContext)
}

var traceID string
output, err := tracing.RunInNewSpan(ctx, g.reg.TracingState(), "dev-run-action-wrapper", "", true, input, func(ctx context.Context, input json.RawMessage) (json.RawMessage, error) {
tracing.SetCustomMetadataAttr(ctx, "genkit-dev-internal", "true")
// Set telemetry labels from payload to span
if telemetryLabels != nil {
var telemetryAttributes map[string]string
err := json.Unmarshal(telemetryLabels, &telemetryAttributes)
if err != nil {
return nil, core.NewError(core.INTERNAL, "Error unmarshalling telemetryLabels: %v", err)
}
for k, v := range telemetryAttributes {
tracing.SetCustomMetadataAttr(ctx, k, v)
}
// Prepare telemetry labels
telemetryMap := make(map[string]string)
telemetryMap["genkit-dev-internal"] = "true"

// Add custom telemetry labels from the request
if telemetryLabels != nil {
var telemetryAttributes map[string]string
err := json.Unmarshal(telemetryLabels, &telemetryAttributes)
if err != nil {
return nil, core.NewError(core.INTERNAL, "Error unmarshalling telemetryLabels: %v", err)
}
traceID = trace.SpanContextFromContext(ctx).TraceID().String()
return action.(core.Action).RunJSON(ctx, input, cb)
})
for k, v := range telemetryAttributes {
telemetryMap[k] = v
}
}

// Call action directly with telemetry labels - no wrapper span needed!
result, err := action.(core.Action).RunJSON(ctx, input, cb, telemetryMap)
if err != nil {
return nil, err
}

return &runActionResponse{
Result: output,
Telemetry: telemetry{TraceID: traceID},
Result: result.Result,
Telemetry: telemetry{TraceID: result.Telemetry.TraceID},
}, nil
}

Expand Down
6 changes: 3 additions & 3 deletions go/genkit/servers.go
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ func handler(a core.Action, params *handlerParams) func(http.ResponseWriter, *ht
}
}

out, err := a.RunJSON(ctx, body.Data, callback)
out, err := a.RunJSON(ctx, body.Data, callback, nil) // No telemetry labels for production HTTP handlers
if err != nil {
if stream {
_, err = fmt.Fprintf(w, "data: {\"error\": {\"status\": \"INTERNAL\", \"message\": \"stream flow error\", \"details\": \"%v\"}}\n\n", err)
Expand All @@ -178,11 +178,11 @@ func handler(a core.Action, params *handlerParams) func(http.ResponseWriter, *ht
return err
}
if stream {
_, err = fmt.Fprintf(w, "data: {\"result\": %s}\n\n", out)
_, err = fmt.Fprintf(w, "data: {\"result\": %s}\n\n", out.Result)
return err
}

_, err = fmt.Fprintf(w, "{\"result\": %s}\n", out)
_, err = fmt.Fprintf(w, "{\"result\": %s}\n", out.Result)
return err
}
}
Expand Down
Loading