diff --git a/examples/postgres-query/go.mod b/examples/postgres-query/go.mod index e5c7e8ea..f1142a07 100644 --- a/examples/postgres-query/go.mod +++ b/examples/postgres-query/go.mod @@ -6,7 +6,9 @@ toolchain go1.24.2 replace github.com/hypertrace/goagent => ../.. -replace github.com/hypertrace/goagent/instrumentation/hypertrace/github.com/jackc/hyperpgx => ../../instrumentation/opentelemetry/github.com/jackc/hyperpgx +replace github.com/hypertrace/goagent/instrumentation/hypertrace/github.com/jackc/hyperpgx => ../../instrumentation/hypertrace/github.com/jackc/hyperpgx + +replace github.com/hypertrace/goagent/instrumentation/opentelemetry/github.com/jackc/hyperpgx => ../../instrumentation/opentelemetry/github.com/jackc/hyperpgx require github.com/hypertrace/goagent/instrumentation/hypertrace/github.com/jackc/hyperpgx v0.0.0-00010101000000-000000000000 @@ -21,6 +23,7 @@ require ( github.com/grpc-ecosystem/grpc-gateway/v2 v2.26.3 // indirect github.com/hypertrace/agent-config/gen/go v0.0.0-20240523214336-1259231da906 // indirect github.com/hypertrace/goagent v0.0.0-00010101000000-000000000000 // indirect + github.com/hypertrace/goagent/instrumentation/opentelemetry/github.com/jackc/hyperpgx v0.0.0-00010101000000-000000000000 // indirect github.com/jackc/chunkreader/v2 v2.0.1 // indirect github.com/jackc/pgconn v1.8.1 // indirect github.com/jackc/pgio v1.0.0 // indirect diff --git a/instrumentation/hypertrace/database/hypersql/README.md b/instrumentation/hypertrace/database/hypersql/README.md index c456771a..66d3d61b 100644 --- a/instrumentation/hypertrace/database/hypersql/README.md +++ b/instrumentation/hypertrace/database/hypersql/README.md @@ -39,3 +39,42 @@ sql.Register("ht-mysql", driver) // Connect to a MySQL database using the hypersql driver wrapper db, err = sql.Open("ht-mysql", "user:password@/dbname") ``` + +For adding a filter implementation to the instrumentation, there's an option to use hypersql.WithFilter to add filters to the instrumentation. +```go + +import ( + "github.com/go-sql-driver/mysql" + "github.com/hypertrace/goagent/instrumentation/hypertrace/database/hypersql" + "github.com/hypertrace/goagent/sdk/filter" +) + +// Explicitly wrap the MySQL driver with hypersql +driver := hypersql.Wrap(&mysql.MySQLDriver{}, hypersql.WithFilter(filter.NoopFilter{})) + +// Register our hypersql wrapper as a database driver +sql.Register("ht-mysql", driver) + +// Connect to a MySQL database using the hypersql driver wrapper +db, err = sql.Open("ht-mysql", "user:password@/dbname") +``` + +OR + +```go +import ( + "database/sql" + "github.com/hypertrace/goagent/instrumentation/hypertrace/database/hypersql" + "github.com/hypertrace/goagent/sdk/filter" +) + +// Register our hypersql wrapper for the provided MySQL driver. +driverName, err = hypersql.Register("mysql", hypersql.WithFilter(filter.NoopFilter{})) +if err != nil { + log.Fatalf("unable to register goagent driver: %v\n", err) +} + +// Connect to a MySQL database using the hypersql driver wrapper. +db, err = sql.Open(driverName, "user:password@/dbname") + +``` diff --git a/instrumentation/hypertrace/database/hypersql/options.go b/instrumentation/hypertrace/database/hypersql/options.go new file mode 100644 index 00000000..2d986307 --- /dev/null +++ b/instrumentation/hypertrace/database/hypersql/options.go @@ -0,0 +1,24 @@ +package hypersql // import "github.com/hypertrace/goagent/instrumentation/hypertrace/database/hypersql" + +import ( + "github.com/hypertrace/goagent/sdk/filter" + sdkSQL "github.com/hypertrace/goagent/sdk/instrumentation/database/sql" +) + +type options struct { + Filter filter.Filter +} + +func (o *options) toSDKOptions() *sdkSQL.Options { + opts := (sdkSQL.Options)(*o) + return &opts +} + +type Option func(o *options) + +// WithFilter adds a filter to the GRPC option. +func WithFilter(f filter.Filter) Option { + return func(o *options) { + o.Filter = f + } +} diff --git a/instrumentation/hypertrace/database/hypersql/options_test.go b/instrumentation/hypertrace/database/hypersql/options_test.go new file mode 100644 index 00000000..b2e60183 --- /dev/null +++ b/instrumentation/hypertrace/database/hypersql/options_test.go @@ -0,0 +1,15 @@ +package hypersql // import "github.com/hypertrace/goagent/instrumentation/hypertrace/database/hypersql" + +import ( + "testing" + + "github.com/hypertrace/goagent/sdk/filter" + "github.com/stretchr/testify/assert" +) + +func TestOptionsToSDK(t *testing.T) { + o := &options{ + Filter: filter.NoopFilter{}, + } + assert.Equal(t, filter.NoopFilter{}, o.toSDKOptions().Filter) +} diff --git a/instrumentation/hypertrace/database/hypersql/sql.go b/instrumentation/hypertrace/database/hypersql/sql.go index 8bdbd75b..6a0102bb 100644 --- a/instrumentation/hypertrace/database/hypersql/sql.go +++ b/instrumentation/hypertrace/database/hypersql/sql.go @@ -1,13 +1,28 @@ package hypersql // import "github.com/hypertrace/goagent/instrumentation/hypertrace/database/hypersql" import ( + "database/sql/driver" otelsql "github.com/hypertrace/goagent/instrumentation/opentelemetry/database/hypersql" ) // Wrap takes a SQL driver and wraps it with Hypertrace instrumentation. -var Wrap = otelsql.Wrap +func Wrap(d driver.Driver, opts ...Option) driver.Driver { + o := &options{} + for _, opt := range opts { + opt(o) + } + + return otelsql.Wrap(d, o.toSDKOptions()) +} // Register initializes and registers the hypersql wrapped database driver // identified by its driverName. On success it // returns the generated driverName to use when calling sql.Open. -var Register = otelsql.Register +func Register(driverName string, opts ...Option) (string, error) { + o := &options{} + for _, opt := range opts { + opt(o) + } + return otelsql.Register(driverName, o.toSDKOptions()) + +} diff --git a/instrumentation/hypertrace/github.com/jackc/hyperpgx/go.mod b/instrumentation/hypertrace/github.com/jackc/hyperpgx/go.mod index a7131182..323b3d2d 100644 --- a/instrumentation/hypertrace/github.com/jackc/hyperpgx/go.mod +++ b/instrumentation/hypertrace/github.com/jackc/hyperpgx/go.mod @@ -8,10 +8,15 @@ replace github.com/hypertrace/goagent => ../../../../.. replace github.com/hypertrace/goagent/instrumentation/opentelemetry/github.com/jackc/hyperpgx => ../../../../../instrumentation/opentelemetry/github.com/jackc/hyperpgx -require github.com/hypertrace/goagent/instrumentation/opentelemetry/github.com/jackc/hyperpgx v0.0.0-00010101000000-000000000000 +require ( + github.com/hypertrace/goagent v0.0.0-00010101000000-000000000000 + github.com/hypertrace/goagent/instrumentation/opentelemetry/github.com/jackc/hyperpgx v0.0.0-00010101000000-000000000000 + github.com/stretchr/testify v1.10.0 +) require ( github.com/cenkalti/backoff/v5 v5.0.2 // indirect + github.com/davecgh/go-spew v1.1.1 // indirect github.com/felixge/httpsnoop v1.0.4 // indirect github.com/ghodss/yaml v1.0.0 // indirect github.com/go-logr/logr v1.4.3 // indirect @@ -20,7 +25,6 @@ require ( github.com/google/uuid v1.6.0 // indirect github.com/grpc-ecosystem/grpc-gateway/v2 v2.26.3 // indirect github.com/hypertrace/agent-config/gen/go v0.0.0-20240523214336-1259231da906 // indirect - github.com/hypertrace/goagent v0.0.0-00010101000000-000000000000 // indirect github.com/jackc/chunkreader/v2 v2.0.1 // indirect github.com/jackc/pgconn v1.8.1 // indirect github.com/jackc/pgio v1.0.0 // indirect @@ -30,6 +34,7 @@ require ( github.com/jackc/pgtype v1.7.0 // indirect github.com/jackc/pgx/v4 v4.11.0 // indirect github.com/openzipkin/zipkin-go v0.4.3 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect github.com/tklauser/go-sysconf v0.3.14 // indirect github.com/tklauser/numcpus v0.8.0 // indirect go.opentelemetry.io/auto/sdk v1.1.0 // indirect @@ -59,4 +64,5 @@ require ( google.golang.org/grpc v1.72.2 // indirect google.golang.org/protobuf v1.36.6 // indirect gopkg.in/yaml.v2 v2.4.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/instrumentation/hypertrace/github.com/jackc/hyperpgx/options.go b/instrumentation/hypertrace/github.com/jackc/hyperpgx/options.go new file mode 100644 index 00000000..2b1ae380 --- /dev/null +++ b/instrumentation/hypertrace/github.com/jackc/hyperpgx/options.go @@ -0,0 +1,24 @@ +package hyperpgx // import "github.com/hypertrace/goagent/instrumentation/hypertrace/github.com/jackc/hyperpgx" + +import ( + otelpgx "github.com/hypertrace/goagent/instrumentation/opentelemetry/github.com/jackc/hyperpgx" + "github.com/hypertrace/goagent/sdk/filter" +) + +type options struct { + Filter filter.Filter +} + +func (o *options) toSDKOptions() *otelpgx.Options { + opts := (otelpgx.Options)(*o) + return &opts +} + +type Option func(o *options) + +// WithFilter adds a filter to the GRPC option. +func WithFilter(f filter.Filter) Option { + return func(o *options) { + o.Filter = f + } +} diff --git a/instrumentation/hypertrace/github.com/jackc/hyperpgx/options_test.go b/instrumentation/hypertrace/github.com/jackc/hyperpgx/options_test.go new file mode 100644 index 00000000..6829d791 --- /dev/null +++ b/instrumentation/hypertrace/github.com/jackc/hyperpgx/options_test.go @@ -0,0 +1,15 @@ +package hyperpgx // import "github.com/hypertrace/goagent/instrumentation/hypertrace/github.com/jackc/hyperpgx" + +import ( + "testing" + + "github.com/hypertrace/goagent/sdk/filter" + "github.com/stretchr/testify/assert" +) + +func TestOptionsToSDK(t *testing.T) { + o := &options{ + Filter: filter.NoopFilter{}, + } + assert.Equal(t, filter.NoopFilter{}, o.toSDKOptions().Filter) +} diff --git a/instrumentation/hypertrace/github.com/jackc/hyperpgx/pgx.go b/instrumentation/hypertrace/github.com/jackc/hyperpgx/pgx.go index 7a04ecf8..809bfbca 100644 --- a/instrumentation/hypertrace/github.com/jackc/hyperpgx/pgx.go +++ b/instrumentation/hypertrace/github.com/jackc/hyperpgx/pgx.go @@ -1,5 +1,16 @@ package hyperpgx // import "github.com/hypertrace/goagent/instrumentation/hypertrace/github.com/jackc/hyperpgx" -import otelpgx "github.com/hypertrace/goagent/instrumentation/opentelemetry/github.com/jackc/hyperpgx" +import ( + "context" -var Connect = otelpgx.Connect + otelpgx "github.com/hypertrace/goagent/instrumentation/opentelemetry/github.com/jackc/hyperpgx" +) + +func Connect(ctx context.Context, connString string, opts ...Option) (otelpgx.PGXConn, error) { + o := &options{} + for _, opt := range opts { + opt(o) + } + + return otelpgx.Connect(ctx, connString, o.toSDKOptions()) +} diff --git a/instrumentation/hypertrace/google.golang.org/hypergrpc/client.go b/instrumentation/hypertrace/google.golang.org/hypergrpc/client.go index 0a5fd0bc..1e9c2ea1 100644 --- a/instrumentation/hypertrace/google.golang.org/hypergrpc/client.go +++ b/instrumentation/hypertrace/google.golang.org/hypergrpc/client.go @@ -11,10 +11,16 @@ import ( // for use in a grpc.Dial call. // Interceptor format will be replaced with the stats.Handler since instrumentation has moved to the stats.Handler. // See: https://github.com/open-telemetry/opentelemetry-go-contrib/blob/v1.36.0/instrumentation/google.golang.org/grpc/otelgrpc/example_test.go -func UnaryClientInterceptor() grpc.UnaryClientInterceptor { +func UnaryClientInterceptor(opts ...Option) grpc.UnaryClientInterceptor { + o := &options{} + for _, opt := range opts { + opt(o) + } + return sdkgrpc.WrapUnaryClientInterceptor( grpcunaryinterceptors.UnaryClientInterceptor(), opentelemetry.SpanFromContext, + o.toSDKOptions(), map[string]string{}, ) } diff --git a/instrumentation/hypertrace/net/hyperhttp/transport.go b/instrumentation/hypertrace/net/hyperhttp/transport.go index baf2373d..2e38addd 100644 --- a/instrumentation/hypertrace/net/hyperhttp/transport.go +++ b/instrumentation/hypertrace/net/hyperhttp/transport.go @@ -10,8 +10,13 @@ import ( // NewTransport wraps the provided http.RoundTripper with one that // starts a span and injects the span context into the outbound request headers. -func NewTransport(base http.RoundTripper) http.RoundTripper { +func NewTransport(base http.RoundTripper, opts ...Option) http.RoundTripper { + o := &options{} + for _, opt := range opts { + opt(o) + } + return otelhttp.NewTransport( - sdkhttp.WrapTransport(base, opentelemetry.SpanFromContext, map[string]string{}), + sdkhttp.WrapTransport(base, opentelemetry.SpanFromContext, o.toSDKOptions(), map[string]string{}), ) } diff --git a/instrumentation/opentelemetry/database/hypersql/sql.go b/instrumentation/opentelemetry/database/hypersql/sql.go index 28ec9e5a..4231d073 100644 --- a/instrumentation/opentelemetry/database/hypersql/sql.go +++ b/instrumentation/opentelemetry/database/hypersql/sql.go @@ -2,19 +2,18 @@ package hypersql // import "github.com/hypertrace/goagent/instrumentation/opente import ( "database/sql/driver" - "github.com/hypertrace/goagent/instrumentation/opentelemetry" sdkSQL "github.com/hypertrace/goagent/sdk/instrumentation/database/sql" ) // Wrap takes a SQL driver and wraps it with Hypertrace instrumentation. -func Wrap(d driver.Driver) driver.Driver { - return sdkSQL.Wrap(d, opentelemetry.StartSpan) +func Wrap(d driver.Driver, options *sdkSQL.Options) driver.Driver { + return sdkSQL.Wrap(d, opentelemetry.StartSpan, options) } // Register initializes and registers the hypersql wrapped database driver // identified by its driverName. On success it // returns the generated driverName to use when calling hypersql.Open. -func Register(driverName string) (string, error) { - return sdkSQL.Register(driverName, opentelemetry.StartSpan) +func Register(driverName string, options *sdkSQL.Options) (string, error) { + return sdkSQL.Register(driverName, opentelemetry.StartSpan, options) } diff --git a/instrumentation/opentelemetry/database/hypersql/sql_test.go b/instrumentation/opentelemetry/database/hypersql/sql_test.go index 0a7a02ce..73a2ad2f 100644 --- a/instrumentation/opentelemetry/database/hypersql/sql_test.go +++ b/instrumentation/opentelemetry/database/hypersql/sql_test.go @@ -8,16 +8,27 @@ import ( "testing" "github.com/hypertrace/goagent/instrumentation/opentelemetry/internal/tracetesting" + "github.com/hypertrace/goagent/sdk" + "github.com/hypertrace/goagent/sdk/filter/result" + sdkSQL "github.com/hypertrace/goagent/sdk/instrumentation/database/sql" _ "github.com/mattn/go-sqlite3" "github.com/stretchr/testify/assert" sdktrace "go.opentelemetry.io/otel/sdk/trace" apitrace "go.opentelemetry.io/otel/trace" ) +type mockFilter struct { + evaluator func(span sdk.Span) result.FilterResult +} + +func (f *mockFilter) Evaluate(span sdk.Span) result.FilterResult { + return f.evaluator(span) +} + func createDB(t *testing.T) (*sql.DB, func() []sdktrace.ReadOnlySpan) { _, flusher := tracetesting.InitTracer() - driverName, err := Register("sqlite3") + driverName, err := Register("sqlite3", nil) if err != nil { t.Fatalf("unable to register driver") } @@ -187,3 +198,57 @@ func TestTxWithRollbackSuccess(t *testing.T) { db.Close() } + +func TestFilter(t *testing.T) { + _, flusher := tracetesting.InitTracer() + + driverName, err := Register("sqlite3", &sdkSQL.Options{ + Filter: &mockFilter{ + evaluator: func(span sdk.Span) result.FilterResult { + assert.Equal(t, span.GetAttributes().GetValue("span.kind"), "client") + + span.SetAttribute("span.type", "nospan") + return result.FilterResult{} + }, + }, + }) + if err != nil { + t.Fatalf("unable to register driver") + } + + db, err := sql.Open(driverName, "file:test.db?cache=shared&mode=memory") + if err != nil { + t.Fatal(err) + } + + rows, err := db.Query("SELECT 1 WHERE 1 = ?", 1) + if err != nil { + t.Fatalf("unexpected error: %s", err.Error()) + } + defer rows.Close() + + for rows.Next() { + var n int + if err = rows.Scan(&n); err != nil { + t.Fatalf("unexpected error: %s", err.Error()) + } + } + if err = rows.Err(); err != nil { + t.Fatalf("unexpected error: %s", err.Error()) + } + + spans := flusher() + assert.Equal(t, 1, len(spans)) + + span := spans[0] + assert.Equal(t, "db:query", span.Name()) + assert.Equal(t, apitrace.SpanKindClient, span.SpanKind()) + + attrs := tracetesting.LookupAttributes(span.Attributes()) + assert.Equal(t, "SELECT 1 WHERE 1 = ?", attrs.Get("db.statement").AsString()) + assert.Equal(t, "sqlite", attrs.Get("db.system").AsString()) + assert.False(t, attrs.Has("error")) + assert.Equal(t, "nospan", attrs.Get("span.type").AsString()) + + db.Close() +} diff --git a/instrumentation/opentelemetry/github.com/jackc/hyperpgx/pgx.go b/instrumentation/opentelemetry/github.com/jackc/hyperpgx/pgx.go index 9443d9d2..0a0a7983 100644 --- a/instrumentation/opentelemetry/github.com/jackc/hyperpgx/pgx.go +++ b/instrumentation/opentelemetry/github.com/jackc/hyperpgx/pgx.go @@ -2,10 +2,11 @@ package hyperpgx // import "github.com/hypertrace/goagent/instrumentation/opente import ( "context" - "database/sql/driver" + "database/sql/driver" "github.com/hypertrace/goagent/instrumentation/opentelemetry" "github.com/hypertrace/goagent/sdk" + "github.com/hypertrace/goagent/sdk/filter" "github.com/jackc/pgconn" "github.com/jackc/pgtype/pgxtype" "github.com/jackc/pgx/v4" @@ -34,11 +35,17 @@ type PGXConn interface { Close(ctx context.Context) error } +type Options struct { + Filter filter.Filter +} + var _ PGXConn = (*wrappedConn)(nil) type wrappedConn struct { delegate *pgx.Conn connAttrs map[string]string + filter filter.Filter + startSpan func(ctx context.Context, name string) (context.Context, sdk.Span, func()) } var _ pgx.Row = (*wrappedRow)(nil) @@ -58,7 +65,7 @@ func (r *wrappedRow) Scan(dest ...interface{}) error { } func (w *wrappedConn) Query(ctx context.Context, query string, optionsAndArgs ...interface{}) (pgx.Rows, error) { - ctx, span, closer := opentelemetry.StartSpan(ctx, "db:query", &sdk.SpanOptions{Kind: sdk.SpanKindClient}) + ctx, span, closer := w.startSpan(ctx, "db:query") defer closer() for k, v := range w.connAttrs { @@ -75,7 +82,7 @@ func (w *wrappedConn) Query(ctx context.Context, query string, optionsAndArgs .. } func (w *wrappedConn) QueryRow(ctx context.Context, sql string, optionsAndArgs ...interface{}) pgx.Row { - ctx, span, closer := opentelemetry.StartSpan(ctx, "db:query", &sdk.SpanOptions{Kind: sdk.SpanKindClient}) + ctx, span, closer := w.startSpan(ctx, "db:query") defer closer() for k, v := range w.connAttrs { @@ -87,7 +94,7 @@ func (w *wrappedConn) QueryRow(ctx context.Context, sql string, optionsAndArgs . } func (w *wrappedConn) Exec(ctx context.Context, sql string, arguments ...interface{}) (pgconn.CommandTag, error) { - ctx, span, closer := opentelemetry.StartSpan(ctx, "exec", &sdk.SpanOptions{Kind: sdk.SpanKindClient}) + ctx, span, closer := w.startSpan(ctx, "exec") defer closer() for k, v := range w.connAttrs { @@ -108,7 +115,7 @@ func (w *wrappedConn) Ping(ctx context.Context) error { } func (w *wrappedConn) QueryFunc(ctx context.Context, sql string, args []interface{}, scans []interface{}, f func(pgx.QueryFuncRow) error) (pgconn.CommandTag, error) { - ctx, span, closer := opentelemetry.StartSpan(ctx, "exec", &sdk.SpanOptions{Kind: sdk.SpanKindClient}) + ctx, span, closer := w.startSpan(ctx, "exec") defer closer() for k, v := range w.connAttrs { @@ -134,7 +141,7 @@ func (w *wrappedConn) Close(ctx context.Context) error { var _ PGXConn = (*wrappedConn)(nil) -func Connect(ctx context.Context, connString string) (PGXConn, error) { +func Connect(ctx context.Context, connString string, options *Options) (PGXConn, error) { conn, err := pgx.Connect(ctx, connString) if err != nil { return conn, err @@ -145,5 +152,23 @@ func Connect(ctx context.Context, connString string) (PGXConn, error) { connAttrs["db.system"] = "postgres" } - return &wrappedConn{conn, connAttrs}, nil + var filter filter.Filter = filter.NoopFilter{} + if options != nil && options.Filter != nil { + filter = options.Filter + } + samplingSpanStarter := func(ctx context.Context, name string) (context.Context, sdk.Span, func()) { + ctx, span, closer := opentelemetry.StartSpan(ctx, name, &sdk.SpanOptions{Kind: sdk.SpanKindClient}) + span.SetAttribute("span.kind", "client") + + return ctx, span, func() { + _ = filter.Evaluate(span) + closer() + } + } + + return &wrappedConn{ + delegate: conn, + connAttrs: connAttrs, + startSpan: samplingSpanStarter, + }, nil } diff --git a/instrumentation/opentelemetry/google.golang.org/hypergrpc/client.go b/instrumentation/opentelemetry/google.golang.org/hypergrpc/client.go index d910d022..e7baad9f 100644 --- a/instrumentation/opentelemetry/google.golang.org/hypergrpc/client.go +++ b/instrumentation/opentelemetry/google.golang.org/hypergrpc/client.go @@ -8,6 +8,6 @@ import ( // WrapUnaryClientInterceptor returns a new unary client interceptor that will // complement existing OpenTelemetry instrumentation -func WrapUnaryClientInterceptor(delegate grpc.UnaryClientInterceptor) grpc.UnaryClientInterceptor { - return sdkgrpc.WrapUnaryClientInterceptor(delegate, opentelemetry.SpanFromContext, map[string]string{}) +func WrapUnaryClientInterceptor(delegate grpc.UnaryClientInterceptor, options *sdkgrpc.Options) grpc.UnaryClientInterceptor { + return sdkgrpc.WrapUnaryClientInterceptor(delegate, opentelemetry.SpanFromContext, options, map[string]string{}) } diff --git a/instrumentation/opentelemetry/google.golang.org/hypergrpc/client_test.go b/instrumentation/opentelemetry/google.golang.org/hypergrpc/client_test.go index cac30c91..01b69253 100644 --- a/instrumentation/opentelemetry/google.golang.org/hypergrpc/client_test.go +++ b/instrumentation/opentelemetry/google.golang.org/hypergrpc/client_test.go @@ -7,6 +7,7 @@ import ( "github.com/hypertrace/goagent/instrumentation/opentelemetry/google.golang.org/hypergrpc/internal/helloworld" "github.com/hypertrace/goagent/instrumentation/opentelemetry/grpcunaryinterceptors" "github.com/hypertrace/goagent/instrumentation/opentelemetry/internal/tracetesting" + sdkgrpc "github.com/hypertrace/goagent/sdk/instrumentation/google.golang.org/grpc" "github.com/stretchr/testify/assert" otelcodes "go.opentelemetry.io/otel/codes" "google.golang.org/grpc" @@ -39,6 +40,7 @@ func TestClientHelloWorldSuccess(t *testing.T) { grpc.WithUnaryInterceptor( WrapUnaryClientInterceptor( grpcunaryinterceptors.UnaryClientInterceptor(), + &sdkgrpc.Options{}, ), ), ) @@ -113,6 +115,7 @@ func TestClientRegisterPersonFails(t *testing.T) { grpc.WithUnaryInterceptor( WrapUnaryClientInterceptor( grpcunaryinterceptors.UnaryClientInterceptor(), + &sdkgrpc.Options{}, ), ), ) @@ -159,6 +162,7 @@ func BenchmarkClientRequestResponseBodyMarshaling(b *testing.B) { grpc.WithUnaryInterceptor( WrapUnaryClientInterceptor( grpcunaryinterceptors.UnaryClientInterceptor(), + &sdkgrpc.Options{}, ), ), ) diff --git a/instrumentation/opentelemetry/net/hyperhttp/transport.go b/instrumentation/opentelemetry/net/hyperhttp/transport.go index d659b42e..ee834c87 100644 --- a/instrumentation/opentelemetry/net/hyperhttp/transport.go +++ b/instrumentation/opentelemetry/net/hyperhttp/transport.go @@ -10,6 +10,6 @@ import ( // WrapTransport wraps an uninstrumented RoundTripper (e.g. http.DefaultTransport) // and returns an instrumented RoundTripper that has to be used as base for the // OTel's RoundTripper. -func WrapTransport(delegate http.RoundTripper) http.RoundTripper { - return sdkhttp.WrapTransport(delegate, opentelemetry.SpanFromContext, map[string]string{}) +func WrapTransport(delegate http.RoundTripper, options *sdkhttp.Options) http.RoundTripper { + return sdkhttp.WrapTransport(delegate, opentelemetry.SpanFromContext, options, map[string]string{}) } diff --git a/instrumentation/opentelemetry/net/hyperhttp/transport_test.go b/instrumentation/opentelemetry/net/hyperhttp/transport_test.go index 4424d7bb..16be6264 100644 --- a/instrumentation/opentelemetry/net/hyperhttp/transport_test.go +++ b/instrumentation/opentelemetry/net/hyperhttp/transport_test.go @@ -10,17 +10,16 @@ import ( "net/http/httptest" "testing" - "google.golang.org/protobuf/types/known/wrapperspb" - config "github.com/hypertrace/agent-config/gen/go/v1" "github.com/hypertrace/goagent/instrumentation/opentelemetry/internal/tracetesting" sdkconfig "github.com/hypertrace/goagent/sdk/config" - "go.opentelemetry.io/otel/propagation" - + sdkhttp "github.com/hypertrace/goagent/sdk/instrumentation/net/http" "github.com/stretchr/testify/assert" "go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp" "go.opentelemetry.io/contrib/propagators/b3" + "go.opentelemetry.io/otel/propagation" "go.opentelemetry.io/otel/trace" + "google.golang.org/protobuf/types/known/wrapperspb" ) func TestClientRequestIsSuccessfullyTraced(t *testing.T) { @@ -53,7 +52,7 @@ func TestClientRequestIsSuccessfullyTraced(t *testing.T) { client := &http.Client{ Transport: otelhttp.NewTransport( - WrapTransport(http.DefaultTransport), + WrapTransport(http.DefaultTransport, &sdkhttp.Options{}), ), } @@ -104,7 +103,7 @@ func TestClientFailureRequestIsSuccessfullyTraced(t *testing.T) { expectedErr := errors.New("roundtrip error") client := &http.Client{ Transport: otelhttp.NewTransport( - WrapTransport(failingTransport{expectedErr}), + WrapTransport(failingTransport{expectedErr}, &sdkhttp.Options{}), ), } @@ -164,7 +163,7 @@ func TestClientRecordsRequestAndResponseBodyAccordingly(t *testing.T) { client := &http.Client{ Transport: otelhttp.NewTransport( - WrapTransport(http.DefaultTransport), + WrapTransport(http.DefaultTransport, &sdkhttp.Options{}), ), } @@ -216,7 +215,7 @@ func TestTransportRequestInjectsHeadersSuccessfully(t *testing.T) { client := &http.Client{ Transport: otelhttp.NewTransport( - WrapTransport(http.DefaultTransport), + WrapTransport(http.DefaultTransport, &sdkhttp.Options{}), ), } diff --git a/sdk/instrumentation/database/sql/sql.go b/sdk/instrumentation/database/sql/sql.go index 854432f3..67a79da4 100644 --- a/sdk/instrumentation/database/sql/sql.go +++ b/sdk/instrumentation/database/sql/sql.go @@ -2,19 +2,23 @@ package sql // import "github.com/hypertrace/goagent/sdk/instrumentation/databas import ( "context" - stdSQL "database/sql" - "database/sql/driver" "fmt" + "reflect" "sync" + stdSQL "database/sql" + "database/sql/driver" "github.com/hypertrace/goagent/sdk" + "github.com/hypertrace/goagent/sdk/filter" "github.com/ngrok/sqlmw" - - "reflect" ) var regMu sync.Mutex +type Options struct { + Filter filter.Filter +} + type interceptor struct { sqlmw.NullInterceptor startSpan sdk.StartSpan @@ -188,16 +192,32 @@ func (w *dsnReadWrapper) parseDSNAttributes(dsn string) map[string]string { } // Wrap takes a SQL driver and wraps it with Hypertrace instrumentation. -func Wrap(d driver.Driver, startSpan sdk.StartSpan) driver.Driver { +func Wrap(d driver.Driver, startSpan sdk.StartSpan, options *Options) driver.Driver { driverName := getDriverName(d) - in := &interceptor{startSpan: startSpan} + + var filter filter.Filter = filter.NoopFilter{} + if options != nil && options.Filter != nil { + filter = options.Filter + } + filteringSpanStarter := func(ctx context.Context, name string, opts *sdk.SpanOptions) (context.Context, sdk.Span, func()) { + ctx, span, end := startSpan(ctx, name, opts) + span.SetAttribute("span.kind", "client") + + return ctx, span, func() { + _ = filter.Evaluate(span) + end() + } + } + in := &interceptor{ + startSpan: filteringSpanStarter, + } return &dsnReadWrapper{Driver: sqlmw.Driver(d, in), driverName: driverName, inDefaultAttributes: &in.defaultAttributes} } // Register initializes and registers the hypersql wrapped database driver // identified by its driverName. On success it // returns the generated driverName to use when calling hypersql.Open. -func Register(driverName string, startSpan sdk.StartSpan) (string, error) { +func Register(driverName string, startSpan sdk.StartSpan, options *Options) (string, error) { // retrieve the driver implementation we need to wrap with instrumentation db, err := stdSQL.Open(driverName, "") if err != nil { @@ -212,6 +232,6 @@ func Register(driverName string, startSpan sdk.StartSpan) (string, error) { defer regMu.Unlock() hyperDriverName := fmt.Sprintf("hyper-%s-%d", driverName, len(stdSQL.Drivers())) - stdSQL.Register(hyperDriverName, Wrap(dri, startSpan)) + stdSQL.Register(hyperDriverName, Wrap(dri, startSpan, options)) return hyperDriverName, nil } diff --git a/sdk/instrumentation/database/sql/sql_test.go b/sdk/instrumentation/database/sql/sql_test.go index 14090f80..d8204727 100644 --- a/sdk/instrumentation/database/sql/sql_test.go +++ b/sdk/instrumentation/database/sql/sql_test.go @@ -8,6 +8,7 @@ import ( "testing" "github.com/hypertrace/goagent/sdk" + "github.com/hypertrace/goagent/sdk/filter/result" "github.com/hypertrace/goagent/sdk/internal/mock" _ "github.com/mattn/go-sqlite3" "github.com/stretchr/testify/assert" @@ -29,7 +30,7 @@ func (sb *spansBuffer) StartSpan(ctx context.Context, name string, opts *sdk.Spa func createDB(t *testing.T) (*sql.DB, func() []*mock.Span) { b := &spansBuffer{} - driverName, err := Register("sqlite3", b.StartSpan) + driverName, err := Register("sqlite3", b.StartSpan, nil) if err != nil { t.Fatalf("unable to register driver") } @@ -72,6 +73,7 @@ func TestQuerySuccess(t *testing.T) { assert.Equal(t, "SELECT 1 WHERE 1 = ?", span.ReadAttribute("db.statement").(string)) assert.Equal(t, "sqlite", span.ReadAttribute("db.system").(string)) assert.Nil(t, span.ReadAttribute("error")) + assert.Equal(t, span.ReadAttribute("span.kind"), "client") assert.Zero(t, span.RemainingAttributes()) db.Close() @@ -119,6 +121,7 @@ func TestExecSuccess(t *testing.T) { assert.Equal(t, sdk.SpanKindClient, span.Options.Kind) assert.Equal(t, "db:exec", span.Name) assert.Nil(t, span.ReadAttribute("error")) + assert.Equal(t, span.ReadAttribute("span.kind"), "client") } func TestTxWithCommitSuccess(t *testing.T) { @@ -165,6 +168,7 @@ func TestTxWithCommitSuccess(t *testing.T) { assert.Equal(t, sdk.SpanKindClient, spans[i].Options.Kind) assert.Equal(t, sdk.StatusCodeOk, spans[i].Status.Code) assert.Nil(t, spans[i].ReadAttribute("error")) + assert.Equal(t, spans[i].ReadAttribute("span.kind"), "client") } db.Close() @@ -211,7 +215,47 @@ func TestTxWithRollbackSuccess(t *testing.T) { assert.Equal(t, sdk.SpanKindClient, spans[i].Options.Kind) assert.Equal(t, sdk.StatusCodeOk, spans[i].Status.Code) assert.Nil(t, spans[i].ReadAttribute("error")) + assert.Equal(t, spans[i].ReadAttribute("span.kind"), "client") } db.Close() } + +func TestFilter(t *testing.T) { + b := &spansBuffer{} + + driverName, err := Register("sqlite3", b.StartSpan, &Options{ + Filter: mock.Filter{ + Evaluator: func(span sdk.Span) result.FilterResult { + assert.Equal(t, span.GetAttributes().GetValue("span.kind"), "client") + span.SetAttribute("span.type", "nospan") + return result.FilterResult{} + }, + }, + }) + if err != nil { + t.Fatalf("unable to register driver") + } + + db, err := sql.Open(driverName, "file:test.db?cache=shared&mode=memory") + if err != nil { + t.Fatal(err) + } + + flusher := func() []*mock.Span { return b.spans } + + _, err = db.Query("SELECT * FROM unexistent") + require.Error(t, err) + + spans := flusher() + assert.Equal(t, 1, len(spans)) + + span := spans[0] + assert.Equal(t, "db:query", span.Name) + assert.Equal(t, "no such table: unexistent", span.Err.Error()) + assert.Equal(t, sdk.SpanKindClient, span.Options.Kind) + assert.Equal(t, sdk.StatusCodeError, span.Status.Code) + assert.Equal(t, "nospan", span.GetAttributes().GetValue("span.type")) + + db.Close() +} diff --git a/sdk/instrumentation/google.golang.org/grpc/client.go b/sdk/instrumentation/google.golang.org/grpc/client.go index 6fb080a6..63dd9d51 100644 --- a/sdk/instrumentation/google.golang.org/grpc/client.go +++ b/sdk/instrumentation/google.golang.org/grpc/client.go @@ -5,16 +5,27 @@ import ( "strings" "github.com/hypertrace/goagent/sdk" + codes "github.com/hypertrace/goagent/sdk" + "github.com/hypertrace/goagent/sdk/filter" internalconfig "github.com/hypertrace/goagent/sdk/internal/config" "github.com/hypertrace/goagent/sdk/internal/container" "google.golang.org/grpc" "google.golang.org/grpc/metadata" + "google.golang.org/grpc/status" ) // WrapUnaryClientInterceptor returns an interceptor that records the request and response message's body // and serialize it as JSON. -func WrapUnaryClientInterceptor(delegateInterceptor grpc.UnaryClientInterceptor, spanFromContext sdk.SpanFromContext, +func WrapUnaryClientInterceptor( + delegateInterceptor grpc.UnaryClientInterceptor, + spanFromContext sdk.SpanFromContext, + options *Options, spanAttributes map[string]string) grpc.UnaryClientInterceptor { + var filter filter.Filter = filter.NoopFilter{} + if options != nil && options.Filter != nil { + filter = options.Filter + } + defaultAttributes := map[string]string{ "rpc.system": "grpc", } @@ -36,6 +47,7 @@ func WrapUnaryClientInterceptor(delegateInterceptor grpc.UnaryClientInterceptor, // span). wrappedInvoker := func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, opts ...grpc.CallOption) error { span := spanFromContext(ctx) + span.SetAttribute("span.kind", "client") if span.IsNoop() || span == nil { // isNoop means either the span is not sampled or there was no span // in the request context which means this invoker is not used @@ -60,6 +72,23 @@ func WrapUnaryClientInterceptor(delegateInterceptor grpc.UnaryClientInterceptor, setAttributesFromRequestOutgoingMetadata(ctx, span) } + fr := filter.Evaluate(span) + if fr.Block { + statusText := StatusText(int(fr.ResponseStatusCode)) + statusCode := StatusCode(int(fr.ResponseStatusCode)) + span.SetStatus(codes.StatusCodeError, statusText) + span.SetAttribute("rpc.grpc.status_code", statusCode) + return status.Error(statusCode, statusText) + } else if fr.Decorations != nil { + if md, ok := metadata.FromOutgoingContext(ctx); ok { + for _, header := range fr.Decorations.RequestHeaderInjections { + md.Append(header.Key, header.Value) + span.SetAttribute("rpc.request.metadata."+header.Key, header.Value) + } + ctx = metadata.NewIncomingContext(ctx, md) + } + } + err = invoker(ctx, method, req, reply, cc, opts...) if err != nil { return err diff --git a/sdk/instrumentation/google.golang.org/grpc/client_test.go b/sdk/instrumentation/google.golang.org/grpc/client_test.go index f20bf78f..e1d9cc2e 100644 --- a/sdk/instrumentation/google.golang.org/grpc/client_test.go +++ b/sdk/instrumentation/google.golang.org/grpc/client_test.go @@ -6,12 +6,16 @@ import ( "strings" "testing" + "github.com/hypertrace/goagent/sdk" + "github.com/hypertrace/goagent/sdk/filter/result" "github.com/hypertrace/goagent/sdk/instrumentation/google.golang.org/grpc/internal/helloworld" "github.com/hypertrace/goagent/sdk/internal/mock" "github.com/stretchr/testify/assert" "google.golang.org/grpc" + "google.golang.org/grpc/codes" "google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/metadata" + "google.golang.org/grpc/status" ) func makeMockUnaryClientInterceptor(mockSpans *[]*mock.Span) grpc.UnaryClientInterceptor { @@ -46,6 +50,7 @@ func TestUnaryClientHelloWorldSuccess(t *testing.T) { WrapUnaryClientInterceptor( makeMockUnaryClientInterceptor(&spans), mock.SpanFromContext, + &Options{}, map[string]string{"foo": "bar"}, ), ), @@ -96,6 +101,7 @@ func TestUnaryClientHelloWorldSuccess(t *testing.T) { } else { t.Fatalf("unexpected error: %v", err) } + assert.Equal(t, "client", span.ReadAttribute("span.kind").(string)) _ = span.ReadAttribute("container_id") // needed in containarized envs assert.Zero(t, span.RemainingAttributes(), "unexpected remaining attribute: %v", span.Attributes) @@ -191,6 +197,7 @@ func TestBodyTruncation(t *testing.T) { WrapUnaryClientInterceptor( makeMockUnaryClientInterceptor(&spans), mock.SpanFromContext, + &Options{}, map[string]string{"foo": "bar"}, ), ), @@ -240,7 +247,93 @@ func TestBodyTruncation(t *testing.T) { actualBody = span.ReadAttribute("rpc.response.body").(string) // direct comparison of the body since it will be truncated assert.Equal(t, expectedBody, actualBody) + assert.Equal(t, "client", span.ReadAttribute("span.kind").(string)) _ = span.ReadAttribute("container_id") // needed in containarized envs assert.Zero(t, span.RemainingAttributes(), "unexpected remaining attribute: %v", span.Attributes) } + +func TestClientFilter(t *testing.T) { + + s := grpc.NewServer() + defer s.Stop() + + helloworld.RegisterGreeterServer(s, &server{ + replyHeader: metadata.Pairs("test_header_key", "test_header_value"), + replyTrailer: metadata.Pairs("test_trailer_key", "test_trailer_value"), + }) + + dialer := createDialer(s) + + tests := []struct { + name string + block bool + }{ + { + name: "blocking disabled", + block: false, + }, + { + name: "blocking enabled", + block: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + spans := []*mock.Span{} + ctx := context.Background() + conn, err := grpc.DialContext( + ctx, + "bufnet", + grpc.WithContextDialer(dialer), + grpc.WithTransportCredentials(insecure.NewCredentials()), + grpc.WithUnaryInterceptor( + WrapUnaryClientInterceptor( + makeMockUnaryClientInterceptor(&spans), + mock.SpanFromContext, + &Options{ + Filter: mock.Filter{ + Evaluator: func(span sdk.Span) result.FilterResult { + span.SetAttribute("filter.evaluated", true) + return result.FilterResult{ + Block: tt.block, + ResponseStatusCode: 403, + ResponseMessage: "Access Denied", + } + }, + }, + }, + map[string]string{"foo": "bar"}, + ), + ), + ) + if err != nil { + t.Fatalf("failed to dial bufnet: %v", err) + } + defer conn.Close() + + client := helloworld.NewGreeterClient(conn) + + ctx = metadata.NewOutgoingContext(ctx, metadata.Pairs("test_key_1", "test_value_1")) + _, err = client.SayHello( + ctx, + &helloworld.HelloRequest{ + Name: "Cuchi", + }, + ) + + if !tt.block { + assert.NoError(t, err) + } else { + assert.Error(t, err) + assert.Equal(t, codes.PermissionDenied, status.Code(err)) + } + + assert.Equal(t, 1, len(spans)) + span := spans[0] + assert.True(t, span.ReadAttribute("filter.evaluated").(bool)) + assert.Equal(t, "client", span.ReadAttribute("span.kind")) + }) + } +} diff --git a/sdk/instrumentation/net/http/transport.go b/sdk/instrumentation/net/http/transport.go index 72d03812..3c81a7b8 100644 --- a/sdk/instrumentation/net/http/transport.go +++ b/sdk/instrumentation/net/http/transport.go @@ -4,9 +4,12 @@ import ( "bytes" "io" "net/http" + "strings" config "github.com/hypertrace/agent-config/gen/go/v1" "github.com/hypertrace/goagent/sdk" + codes "github.com/hypertrace/goagent/sdk" + "github.com/hypertrace/goagent/sdk/filter" internalconfig "github.com/hypertrace/goagent/sdk/internal/config" "github.com/hypertrace/goagent/sdk/internal/container" ) @@ -18,6 +21,7 @@ type roundTripper struct { defaultAttributes map[string]string spanFromContextRetriever sdk.SpanFromContext dataCaptureConfig *config.DataCapture + filter filter.Filter } func (rt *roundTripper) RoundTrip(req *http.Request) (*http.Response, error) { @@ -30,6 +34,7 @@ func (rt *roundTripper) RoundTrip(req *http.Request) (*http.Response, error) { return rt.delegate.RoundTrip(req) } reqHeadersAccessor := NewHeaderMapAccessor(req.Header) + span.SetAttribute("span.kind", "client") for key, value := range rt.defaultAttributes { span.SetAttribute(key, value) @@ -58,6 +63,29 @@ func (rt *roundTripper) RoundTrip(req *http.Request) (*http.Response, error) { req.Body = io.NopCloser(bytes.NewBuffer(body)) } + filterResult := rt.filter.Evaluate(span) + if filterResult.Block { + span.SetStatus(codes.StatusCodeError, "Access Denied") + span.SetAttribute("http.status_code", filterResult.ResponseStatusCode) + return &http.Response{ + Status: "Access Denied", + StatusCode: int(filterResult.ResponseStatusCode), + Proto: "HTTP/1.1", + ProtoMajor: 1, + ProtoMinor: 1, + Request: req, + Header: map[string][]string{ + "Content-Type": {"text/plain"}, + }, + Body: io.NopCloser(strings.NewReader(filterResult.ResponseMessage)), + }, nil + } else if filterResult.Decorations != nil { + for _, header := range filterResult.Decorations.RequestHeaderInjections { + req.Header.Add(header.Key, header.Value) + span.SetAttribute("http.request.header."+header.Key, header.Value) + } + } + res, err := rt.delegate.RoundTrip(req) if err != nil { return res, err @@ -90,7 +118,7 @@ func (rt *roundTripper) RoundTrip(req *http.Request) (*http.Response, error) { // WrapTransport returns a new http.RoundTripper that should be wrapped // by an instrumented http.RoundTripper -func WrapTransport(delegate http.RoundTripper, spanFromContextRetriever sdk.SpanFromContext, spanAttributes map[string]string) http.RoundTripper { +func WrapTransport(delegate http.RoundTripper, spanFromContextRetriever sdk.SpanFromContext, options *Options, spanAttributes map[string]string) http.RoundTripper { defaultAttributes := make(map[string]string) for k, v := range spanAttributes { defaultAttributes[k] = v @@ -99,5 +127,15 @@ func WrapTransport(delegate http.RoundTripper, spanFromContextRetriever sdk.Span defaultAttributes["container_id"] = containerID } - return &roundTripper{delegate, defaultAttributes, spanFromContextRetriever, internalconfig.GetConfig().GetDataCapture()} + var filter filter.Filter = filter.NoopFilter{} + if options != nil && options.Filter != nil { + filter = options.Filter + } + return &roundTripper{ + delegate: delegate, + defaultAttributes: defaultAttributes, + spanFromContextRetriever: spanFromContextRetriever, + dataCaptureConfig: internalconfig.GetConfig().GetDataCapture(), + filter: filter, + } } diff --git a/sdk/instrumentation/net/http/transport_test.go b/sdk/instrumentation/net/http/transport_test.go index 093aa4b5..54362819 100644 --- a/sdk/instrumentation/net/http/transport_test.go +++ b/sdk/instrumentation/net/http/transport_test.go @@ -12,6 +12,8 @@ import ( "testing" config "github.com/hypertrace/agent-config/gen/go/v1" + "github.com/hypertrace/goagent/sdk" + "github.com/hypertrace/goagent/sdk/filter/result" internalconfig "github.com/hypertrace/goagent/sdk/internal/config" "github.com/hypertrace/goagent/sdk/internal/mock" "github.com/stretchr/testify/assert" @@ -36,7 +38,7 @@ func TestClientRequestIsSuccessfullyTraced(t *testing.T) { })) defer srv.Close() - rt := WrapTransport(http.DefaultTransport, mock.SpanFromContext, map[string]string{"foo": "bar"}).(*roundTripper) + rt := WrapTransport(http.DefaultTransport, mock.SpanFromContext, &Options{}, map[string]string{"foo": "bar"}).(*roundTripper) rt.dataCaptureConfig = &config.DataCapture{ HttpHeaders: &config.Message{ Request: config.Bool(false), @@ -76,6 +78,7 @@ func TestClientRequestIsSuccessfullyTraced(t *testing.T) { _ = span.ReadAttribute("container_id") // needed in containarized envs // custom attribute assert.Equal(t, "bar", span.ReadAttribute("foo").(string)) + assert.Equal(t, "client", span.ReadAttribute("span.kind")) // We make sure we read all attributes and covered them with tests assert.Zero(t, span.RemainingAttributes(), "unexpected remaining attribute: %v", span.Attributes) } @@ -100,7 +103,7 @@ func TestClientRequestHeadersAreCapturedAccordingly(t *testing.T) { })) defer srv.Close() - rt := WrapTransport(http.DefaultTransport, mock.SpanFromContext, map[string]string{"foo": "bar"}).(*roundTripper) + rt := WrapTransport(http.DefaultTransport, mock.SpanFromContext, &Options{}, map[string]string{"foo": "bar"}).(*roundTripper) rt.dataCaptureConfig = &config.DataCapture{ HttpHeaders: &config.Message{ Request: config.Bool(tCase.captureHTTPHeadersRequestConfig), @@ -168,7 +171,7 @@ func TestClientFailureRequestIsSuccessfullyTraced(t *testing.T) { expectedErr := errors.New("roundtrip error") client := &http.Client{ Transport: &mockTransport{ - baseRoundTripper: WrapTransport(failingTransport{expectedErr}, mock.SpanFromContext, map[string]string{}), + baseRoundTripper: WrapTransport(failingTransport{expectedErr}, mock.SpanFromContext, &Options{}, map[string]string{}), }, } @@ -297,7 +300,7 @@ func TestClientRecordsRequestAndResponseBodyAccordingly(t *testing.T) { })) defer srv.Close() - rt := WrapTransport(http.DefaultTransport, mock.SpanFromContext, map[string]string{}).(*roundTripper) + rt := WrapTransport(http.DefaultTransport, mock.SpanFromContext, &Options{}, map[string]string{}).(*roundTripper) rt.dataCaptureConfig = &config.DataCapture{ HttpBody: &config.Message{ Request: config.Bool(tCase.captureHTTPBodyConfig), @@ -377,3 +380,110 @@ func TestClientRecordsRequestAndResponseBodyAccordingly(t *testing.T) { }) } } + +func TestFilter(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + rw.WriteHeader(202) + rw.Write([]byte(`{"id":123}`)) + })) + defer srv.Close() + + dcCfg := &config.DataCapture{ + HttpHeaders: &config.Message{ + Request: config.Bool(false), + Response: config.Bool(false), + }, + HttpBody: &config.Message{ + Request: config.Bool(false), + Response: config.Bool(false), + }, + BodyMaxSizeBytes: config.Int32(1000), + } + + tests := []struct { + name string + block bool + }{ + { + name: "blocking enabled", + block: true, + }, + { + name: "blocking disabled", + block: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + filter := mock.Filter{ + Evaluator: func(span sdk.Span) result.FilterResult { + span.SetAttribute("filter.evaluated", true) + return result.FilterResult{ + Block: tt.block, + ResponseStatusCode: 403, + ResponseMessage: "Access Denied", + } + }, + } + rt := WrapTransport(http.DefaultTransport, mock.SpanFromContext, &Options{ + Filter: filter, + }, map[string]string{"foo": "bar"}).(*roundTripper) + rt.dataCaptureConfig = dcCfg + + tr := &mockTransport{ + baseRoundTripper: rt, + } + client := &http.Client{ + Transport: tr, + } + + req, _ := http.NewRequest("POST", srv.URL, bytes.NewBufferString(`{"name":"Jacinto"}`)) + res, err := client.Do(req) + if err != nil { + t.Errorf("unexpected error: %v", err) + } + + if !tt.block { + assert.Equal(t, 202, res.StatusCode) + resBody, err := io.ReadAll(res.Body) + assert.Nil(t, err) + assert.Equal(t, `{"id":123}`, string(resBody)) + + spans := tr.spans + assert.Equal(t, 1, len(spans), "unexpected number of spans") + + span := spans[0] + + _ = span.ReadAttribute("container_id") // needed in containarized envs + // custom attribute + assert.Equal(t, "bar", span.ReadAttribute("foo").(string)) + assert.True(t, span.ReadAttribute("filter.evaluated").(bool)) + assert.Equal(t, "client", span.ReadAttribute("span.kind").(string)) + // We make sure we read all attributes and covered them with tests + assert.Zero(t, span.RemainingAttributes(), "unexpected remaining attribute: %v", span.Attributes) + } else { + assert.Equal(t, 403, res.StatusCode) + resBody, err := io.ReadAll(res.Body) + assert.Nil(t, err) + assert.Equal(t, `Access Denied`, string(resBody)) + + spans := tr.spans + assert.Equal(t, 1, len(spans), "unexpected number of spans") + + span := spans[0] + + _ = span.ReadAttribute("container_id") // needed in containarized envs + // custom attribute + assert.Equal(t, "bar", span.ReadAttribute("foo").(string)) + assert.Equal(t, int32(403), span.ReadAttribute("http.status_code").(int32)) + assert.True(t, span.ReadAttribute("filter.evaluated").(bool)) + assert.Equal(t, "client", span.ReadAttribute("span.kind").(string)) + // We make sure we read all attributes and covered them with tests + assert.Zero(t, span.RemainingAttributes(), "unexpected remaining attribute: %v", span.Attributes) + + } + }) + } + +}