From 8536bd65de19404ae59d762d0923cc7e5116d875 Mon Sep 17 00:00:00 2001 From: skinny <39215611+crazyskinny@users.noreply.github.com> Date: Wed, 13 May 2026 16:55:17 +0800 Subject: [PATCH 01/11] feat: add auth proto option --- proto/buf.gen.yaml | 6 ++- proto/buf.lock | 4 +- proto/kit/auth/v1/auth.pb.go | 84 ++++++++++++++++++++++++++++++++ proto/kit/auth/v1/auth.proto | 12 +++++ proto/kit/redact/v1/redact.pb.go | 4 +- 5 files changed, 104 insertions(+), 6 deletions(-) create mode 100644 proto/kit/auth/v1/auth.pb.go create mode 100644 proto/kit/auth/v1/auth.proto diff --git a/proto/buf.gen.yaml b/proto/buf.gen.yaml index fa5facd..44e7c8b 100644 --- a/proto/buf.gen.yaml +++ b/proto/buf.gen.yaml @@ -8,9 +8,11 @@ plugins: - run - ../errors/protoc-gen-kit-errors/main.go out: . + exclude_types: + - kit.errors.v1 - local: - go - run - - ../protoc-gen-go-redact/main.go + - ../logging/protoc-gen-go-redact/main.go out: . - opt: paths=source_relative \ No newline at end of file + opt: paths=source_relative diff --git a/proto/buf.lock b/proto/buf.lock index ace68b7..8447589 100644 --- a/proto/buf.lock +++ b/proto/buf.lock @@ -2,5 +2,5 @@ version: v2 deps: - name: buf.build/googleapis/googleapis - commit: b30c5775bfb3485d9da2e87b26590ac9 - digest: b5:13f091a467b31c7f734307e6760d864e3319b9c47656f2ada6efa45c643864d9c9e7d5cd372c92cc8e0972deb63f41bc8fc88a5ca21ab2e9ea04d2144752857d + commit: c17df5b2beca46928cc87d5656bd5343 + digest: b5:648a01e0170d4512dea7d564016165decd1ed6e34bef79fe54753e51ad7e27545709ad9157d7551270147d551155c595a2fb0bf5bb33b1c83040ddbce915c604 diff --git a/proto/kit/auth/v1/auth.pb.go b/proto/kit/auth/v1/auth.pb.go new file mode 100644 index 0000000..5c7797f --- /dev/null +++ b/proto/kit/auth/v1/auth.pb.go @@ -0,0 +1,84 @@ +// Code generated by protoc-gen-go. DO NOT EDIT. +// versions: +// protoc-gen-go v1.36.11 +// protoc (unknown) +// source: kit/auth/v1/auth.proto + +package authv1 + +import ( + protoreflect "google.golang.org/protobuf/reflect/protoreflect" + protoimpl "google.golang.org/protobuf/runtime/protoimpl" + descriptorpb "google.golang.org/protobuf/types/descriptorpb" + reflect "reflect" + unsafe "unsafe" +) + +const ( + // Verify that this generated code is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) + // Verify that runtime/protoimpl is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) +) + +var file_kit_auth_v1_auth_proto_extTypes = []protoimpl.ExtensionInfo{ + { + ExtendedType: (*descriptorpb.MethodOptions)(nil), + ExtensionType: (*bool)(nil), + Field: 105001, + Name: "kit.auth.v1.public", + Tag: "varint,105001,opt,name=public", + Filename: "kit/auth/v1/auth.proto", + }, +} + +// Extension fields to descriptorpb.MethodOptions. +var ( + // public marks an RPC operation as not requiring authenticated user context. + // + // optional bool public = 105001; + E_Public = &file_kit_auth_v1_auth_proto_extTypes[0] +) + +var File_kit_auth_v1_auth_proto protoreflect.FileDescriptor + +const file_kit_auth_v1_auth_proto_rawDesc = "" + + "\n" + + "\x16kit/auth/v1/auth.proto\x12\vkit.auth.v1\x1a google/protobuf/descriptor.proto:8\n" + + "\x06public\x12\x1e.google.protobuf.MethodOptions\x18\xa9\xb4\x06 \x01(\bR\x06publicB8Z6github.com/crypto-zero/go-kit/proto/kit/auth/v1;authv1b\x06proto3" + +var file_kit_auth_v1_auth_proto_goTypes = []any{ + (*descriptorpb.MethodOptions)(nil), // 0: google.protobuf.MethodOptions +} +var file_kit_auth_v1_auth_proto_depIdxs = []int32{ + 0, // 0: kit.auth.v1.public:extendee -> google.protobuf.MethodOptions + 1, // [1:1] is the sub-list for method output_type + 1, // [1:1] is the sub-list for method input_type + 1, // [1:1] is the sub-list for extension type_name + 0, // [0:1] is the sub-list for extension extendee + 0, // [0:0] is the sub-list for field type_name +} + +func init() { file_kit_auth_v1_auth_proto_init() } +func file_kit_auth_v1_auth_proto_init() { + if File_kit_auth_v1_auth_proto != nil { + return + } + type x struct{} + out := protoimpl.TypeBuilder{ + File: protoimpl.DescBuilder{ + GoPackagePath: reflect.TypeOf(x{}).PkgPath(), + RawDescriptor: unsafe.Slice(unsafe.StringData(file_kit_auth_v1_auth_proto_rawDesc), len(file_kit_auth_v1_auth_proto_rawDesc)), + NumEnums: 0, + NumMessages: 0, + NumExtensions: 1, + NumServices: 0, + }, + GoTypes: file_kit_auth_v1_auth_proto_goTypes, + DependencyIndexes: file_kit_auth_v1_auth_proto_depIdxs, + ExtensionInfos: file_kit_auth_v1_auth_proto_extTypes, + }.Build() + File_kit_auth_v1_auth_proto = out.File + file_kit_auth_v1_auth_proto_goTypes = nil + file_kit_auth_v1_auth_proto_depIdxs = nil +} diff --git a/proto/kit/auth/v1/auth.proto b/proto/kit/auth/v1/auth.proto new file mode 100644 index 0000000..e296050 --- /dev/null +++ b/proto/kit/auth/v1/auth.proto @@ -0,0 +1,12 @@ +syntax = "proto3"; + +package kit.auth.v1; + +import "google/protobuf/descriptor.proto"; + +option go_package = "github.com/crypto-zero/go-kit/proto/kit/auth/v1;authv1"; + +extend google.protobuf.MethodOptions { + // public marks an RPC operation as not requiring authenticated user context. + bool public = 105001; +} diff --git a/proto/kit/redact/v1/redact.pb.go b/proto/kit/redact/v1/redact.pb.go index 02fcaef..3ec5c95 100644 --- a/proto/kit/redact/v1/redact.pb.go +++ b/proto/kit/redact/v1/redact.pb.go @@ -1,7 +1,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: -// protoc-gen-go v1.36.9 -// protoc v5.29.3 +// protoc-gen-go v1.36.11 +// protoc (unknown) // source: kit/redact/v1/redact.proto package redact From 66687fb2a9b207525fa84af7b1e8139ee31bf7b4 Mon Sep 17 00:00:00 2001 From: skinny <39215611+crazyskinny@users.noreply.github.com> Date: Wed, 13 May 2026 16:57:52 +0800 Subject: [PATCH 02/11] feat: add auth policy and sse helpers --- auth/kratos/policy.go | 113 +++++++ auth/kratos/policy_test.go | 65 ++++ errors/errors.go | 5 +- errors/protoc-gen-kit-errors/main.go | 4 +- sse/detach.go | 47 +++ sse/example_test.go | 59 ++++ sse/kratos/codec.go | 82 +++++ sse/kratos/go.mod | 24 ++ sse/kratos/go.sum | 48 +++ sse/kratos/handler.go | 148 +++++++++ sse/kratos/handler_test.go | 435 +++++++++++++++++++++++++++ sse/kratos/http_cors.go | 26 ++ sse/kratos/http_cors_test.go | 45 +++ sse/kratos/http_handler.go | 225 ++++++++++++++ sse/kratos/http_handler_test.go | 190 ++++++++++++ sse/kratos/server.go | 277 +++++++++++++++++ sse/kratos/server_options.go | 131 ++++++++ sse/kratos/server_test.go | 321 ++++++++++++++++++++ sse/kratos/transport.go | 113 +++++++ sse/snapshot_live.go | 72 +++++ sse/snapshot_live_test.go | 64 ++++ sse/sse.go | 278 +++++++++++++++++ sse/sse_test.go | 274 +++++++++++++++++ 23 files changed, 3040 insertions(+), 6 deletions(-) create mode 100644 auth/kratos/policy.go create mode 100644 auth/kratos/policy_test.go create mode 100644 sse/detach.go create mode 100644 sse/example_test.go create mode 100644 sse/kratos/codec.go create mode 100644 sse/kratos/go.mod create mode 100644 sse/kratos/go.sum create mode 100644 sse/kratos/handler.go create mode 100644 sse/kratos/handler_test.go create mode 100644 sse/kratos/http_cors.go create mode 100644 sse/kratos/http_cors_test.go create mode 100644 sse/kratos/http_handler.go create mode 100644 sse/kratos/http_handler_test.go create mode 100644 sse/kratos/server.go create mode 100644 sse/kratos/server_options.go create mode 100644 sse/kratos/server_test.go create mode 100644 sse/kratos/transport.go create mode 100644 sse/snapshot_live.go create mode 100644 sse/snapshot_live_test.go create mode 100644 sse/sse.go create mode 100644 sse/sse_test.go diff --git a/auth/kratos/policy.go b/auth/kratos/policy.go new file mode 100644 index 0000000..7758396 --- /dev/null +++ b/auth/kratos/policy.go @@ -0,0 +1,113 @@ +// Package kratos provides auth helpers for Kratos operation selectors. +package kratos + +import ( + authv1 "github.com/crypto-zero/go-kit/proto/kit/auth/v1" + "google.golang.org/protobuf/encoding/protowire" + "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/reflect/protoreflect" + "google.golang.org/protobuf/types/descriptorpb" +) + +// OperationPolicy reports whether a Kratos operation should run through +// authentication middleware. +type OperationPolicy struct { + public map[string]struct{} +} + +// OperationPolicyOption configures an OperationPolicy. +type OperationPolicyOption func(*OperationPolicy) + +// NewOperationPolicy constructs an auth policy from proto descriptors and +// optional manually registered operations. +func NewOperationPolicy(opts ...OperationPolicyOption) *OperationPolicy { + p := &OperationPolicy{public: make(map[string]struct{})} + for _, opt := range opts { + opt(p) + } + return p +} + +// WithPublicOperations marks explicit Kratos operations as public. +func WithPublicOperations(ops ...string) OperationPolicyOption { + return func(p *OperationPolicy) { + for _, op := range ops { + p.public[op] = struct{}{} + } + } +} + +// WithPublicFromProtoFiles scans file descriptors for methods tagged with +// `(kit.auth.v1.public) = true`. +func WithPublicFromProtoFiles(files ...protoreflect.FileDescriptor) OperationPolicyOption { + return func(p *OperationPolicy) { + for _, fd := range files { + registerPublicFromFile(p, fd) + } + } +} + +// RequiresAuth reports whether operation should run through authentication. +func (p *OperationPolicy) RequiresAuth(operation string) bool { + _, ok := p.public[operation] + return !ok +} + +// OperationName returns the Kratos operation string for a proto method. +func OperationName(m protoreflect.MethodDescriptor) string { + return "/" + string(m.Parent().FullName()) + "/" + string(m.Name()) +} + +func registerPublicFromFile(p *OperationPolicy, fd protoreflect.FileDescriptor) { + services := fd.Services() + for i := 0; i < services.Len(); i++ { + methods := services.Get(i).Methods() + for j := 0; j < methods.Len(); j++ { + m := methods.Get(j) + if methodIsPublic(m) { + p.public[OperationName(m)] = struct{}{} + } + } + } +} + +func methodIsPublic(m protoreflect.MethodDescriptor) bool { + opts, ok := m.Options().(*descriptorpb.MethodOptions) + if !ok || opts == nil { + return false + } + v := proto.GetExtension(opts, authv1.E_Public) + switch public := v.(type) { + case bool: + return public || methodOptionsUnknownBool(opts, authv1.E_Public.TypeDescriptor().Number()) + case *bool: + return (public != nil && *public) || methodOptionsUnknownBool(opts, authv1.E_Public.TypeDescriptor().Number()) + default: + return methodOptionsUnknownBool(opts, authv1.E_Public.TypeDescriptor().Number()) + } +} + +func methodOptionsUnknownBool(opts *descriptorpb.MethodOptions, number protoreflect.FieldNumber) bool { + raw := opts.ProtoReflect().GetUnknown() + for len(raw) > 0 { + num, typ, n := protowire.ConsumeTag(raw) + if n < 0 { + return false + } + raw = raw[n:] + if num != protowire.Number(number) { + n = protowire.ConsumeFieldValue(num, typ, raw) + if n < 0 { + return false + } + raw = raw[n:] + continue + } + if typ != protowire.VarintType { + return false + } + v, n := protowire.ConsumeVarint(raw) + return n >= 0 && v != 0 + } + return false +} diff --git a/auth/kratos/policy_test.go b/auth/kratos/policy_test.go new file mode 100644 index 0000000..4d77cf2 --- /dev/null +++ b/auth/kratos/policy_test.go @@ -0,0 +1,65 @@ +package kratos + +import ( + "testing" + + authv1 "github.com/crypto-zero/go-kit/proto/kit/auth/v1" + "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/reflect/protodesc" + "google.golang.org/protobuf/types/descriptorpb" +) + +func TestOperationPolicyRegistersPublicProtoMethods(t *testing.T) { + publicOpts := &descriptorpb.MethodOptions{} + proto.SetExtension(publicOpts, authv1.E_Public, true) + fd, err := protodesc.NewFile(&descriptorpb.FileDescriptorProto{ + Syntax: proto.String("proto3"), + Name: proto.String("test/auth/v1/service.proto"), + Package: proto.String("test.auth.v1"), + Service: []*descriptorpb.ServiceDescriptorProto{{ + Name: proto.String("AuthService"), + Method: []*descriptorpb.MethodDescriptorProto{ + { + Name: proto.String("Login"), + InputType: proto.String(".test.auth.v1.LoginRequest"), + OutputType: proto.String(".test.auth.v1.LoginResponse"), + Options: publicOpts, + }, + { + Name: proto.String("Profile"), + InputType: proto.String(".test.auth.v1.ProfileRequest"), + OutputType: proto.String(".test.auth.v1.ProfileResponse"), + }, + }, + }}, + MessageType: []*descriptorpb.DescriptorProto{ + {Name: proto.String("LoginRequest")}, + {Name: proto.String("LoginResponse")}, + {Name: proto.String("ProfileRequest")}, + {Name: proto.String("ProfileResponse")}, + }, + }, nil) + if err != nil { + t.Fatalf("NewFile: %v", err) + } + + policy := NewOperationPolicy(WithPublicFromProtoFiles(fd)) + + if policy.RequiresAuth("/test.auth.v1.AuthService/Login") { + t.Fatal("Login should be public from proto option") + } + if !policy.RequiresAuth("/test.auth.v1.AuthService/Profile") { + t.Fatal("Profile should require auth when untagged") + } +} + +func TestOperationPolicyManualPublicOperations(t *testing.T) { + policy := NewOperationPolicy(WithPublicOperations("/healthz", "/readyz")) + + if policy.RequiresAuth("/healthz") { + t.Fatal("/healthz should be public") + } + if !policy.RequiresAuth("/v1/private") { + t.Fatal("/v1/private should require auth") + } +} diff --git a/errors/errors.go b/errors/errors.go index 84e5701..50c7e50 100644 --- a/errors/errors.go +++ b/errors/errors.go @@ -3,6 +3,7 @@ package errors import ( "errors" "fmt" + "maps" "google.golang.org/genproto/googleapis/rpc/errdetails" spb "google.golang.org/genproto/googleapis/rpc/status" @@ -160,9 +161,7 @@ func FromError(err error) *Error { if info, ok := first.(*errdetails.ErrorInfo); ok { ret.Info.Reason, ret.Info.Domain = info.Reason, info.Domain ret.Info.Metadata = make(map[string]string, len(info.Metadata)) - for k, v := range info.Metadata { - ret.Info.Metadata[k] = v - } + maps.Copy(ret.Info.Metadata, info.Metadata) ret.Details = ret.Details[1:] } } diff --git a/errors/protoc-gen-kit-errors/main.go b/errors/protoc-gen-kit-errors/main.go index 3415b87..5d14584 100644 --- a/errors/protoc-gen-kit-errors/main.go +++ b/errors/protoc-gen-kit-errors/main.go @@ -107,9 +107,7 @@ func (g *GenerateErrorDeclare) generateFile( sinkVarName := varName lowSinkName, lowParentName := strings.ToLower(sinkVarName), strings.ToLower(parentDescName) - if strings.HasPrefix(lowSinkName, lowParentName) { - lowSinkName = strings.TrimPrefix(lowSinkName, lowParentName) - } + lowSinkName = strings.TrimPrefix(lowSinkName, lowParentName) sinkVarName = "Err" + strcase.ToCamel(lowSinkName) gf.P(strings.TrimSpace(ev.Comments.Leading.String())) diff --git a/sse/detach.go b/sse/detach.go new file mode 100644 index 0000000..3afa600 --- /dev/null +++ b/sse/detach.go @@ -0,0 +1,47 @@ +package sse + +import ( + "context" + "errors" + "net/http" + "time" +) + +// DetachWriteTimeout overrides the underlying connection's write deadline +// and returns a request whose context is detached from any server-injected +// deadline. +// +// It addresses a common collision between SSE handlers and HTTP servers that +// install a global write timeout for unary requests (Kratos' http.Timeout +// option is the motivating case). Two things happen: +// +// 1. The connection's write deadline is reset to writeTimeout from now via +// http.ResponseController, replacing the server-wide value for this +// connection. +// 2. The request context is replaced with one that cancels only on genuine +// client disconnect; if the original context's Err is DeadlineExceeded +// (i.e. the server-wide deadline fired), the replacement is not +// cancelled. +// +// The replacement context is cancelled when the original context completes +// for any non-deadline reason, so downstream goroutines observing +// ctx.Done() still see client disconnects. +func DetachWriteTimeout(w http.ResponseWriter, r *http.Request, writeTimeout time.Duration) *http.Request { + rc := http.NewResponseController(w) + _ = rc.SetWriteDeadline(time.Now().Add(writeTimeout)) + + parent := r.Context() + ctx, cancel := context.WithCancel(context.WithoutCancel(parent)) + go func() { + <-parent.Done() + // Forward only genuine client disconnects. When the parent's Err + // is DeadlineExceeded the server-wide timer fired — that is the + // signal we are explicitly detaching from, so the replacement + // context must stay alive. The cancel func is then released when + // the handler returns and references drop. + if !errors.Is(parent.Err(), context.DeadlineExceeded) { + cancel() + } + }() + return r.WithContext(ctx) +} diff --git a/sse/example_test.go b/sse/example_test.go new file mode 100644 index 0000000..1d2885b --- /dev/null +++ b/sse/example_test.go @@ -0,0 +1,59 @@ +package sse_test + +import ( + "context" + "net/http" + "time" + + "github.com/crypto-zero/go-kit/sse" +) + +// ExampleStream_Pump shows the typical streaming handler for a Kratos HTTP +// server. The handler is mounted as a raw net/http handler via srv.Handle, +// bypassing protobuf transcoding so tokens reach the client as soon as the +// upstream goroutine emits them. +// +// Kratos installs a server-wide write deadline through http.Timeout; the +// DetachWriteTimeout call replaces it with a per-connection deadline long +// enough for the SSE stream and detaches the request context from the +// server-injected DeadlineExceeded signal. +func ExampleStream_Pump() { + const streamTimeout = 300 * time.Second + + handler := func(w http.ResponseWriter, r *http.Request) { + r = sse.DetachWriteTimeout(w, r, streamTimeout) + + // produceTokens stands in for a biz-layer call that returns the + // token channel and an error channel. + chunks, errs := produceTokens(r.Context()) + + s := sse.NewStream(w) + _ = s.Pump(r.Context(), chunks, errs) + } + _ = handler +} + +// ExampleStream_WriteJSON shows the unary-over-SSE pattern: the handler +// computes a single response, writes it as one data frame, and terminates +// with [DONE]. Useful when the response is structured JSON but the client +// already expects an SSE-formatted endpoint. +func ExampleStream_WriteJSON() { + const callTimeout = 60 * time.Second + + handler := func(w http.ResponseWriter, r *http.Request) { + r = sse.DetachWriteTimeout(w, r, callTimeout) + + result, err := compute(r.Context()) + s := sse.NewStream(w) + if err != nil { + _ = s.Error(err.Error()) + return + } + _ = s.WriteJSON(result) + _ = s.Done() + } + _ = handler +} + +func produceTokens(context.Context) (<-chan string, <-chan error) { return nil, nil } +func compute(context.Context) (any, error) { return nil, nil } diff --git a/sse/kratos/codec.go b/sse/kratos/codec.go new file mode 100644 index 0000000..9f03138 --- /dev/null +++ b/sse/kratos/codec.go @@ -0,0 +1,82 @@ +package kratos + +import ( + "errors" + "fmt" + "io" + "net/http" + "strings" + + "github.com/go-kratos/kratos/v2/encoding" + kerrors "github.com/go-kratos/kratos/v2/errors" +) + +// DefaultRequestDecoder is the default body decoder: it picks a codec by +// the request's Content-Type and unmarshals the body into v. An empty +// body is a success. Unknown Content-Types fall back to JSON. +func DefaultRequestDecoder(r *http.Request, v any) error { + if r.Body == nil || r.Body == http.NoBody { + return nil + } + body, err := io.ReadAll(r.Body) + if err != nil { + return fmt.Errorf("sse: read body: %w", err) + } + if len(body) == 0 { + return nil + } + c := codecForContentType(r.Header.Get("Content-Type")) + if c == nil { + c = encoding.GetCodec("json") + } + if c == nil { + return errors.New("sse: no codec available") + } + if err := c.Unmarshal(body, v); err != nil { + return fmt.Errorf("sse: decode %s: %w", c.Name(), err) + } + return nil +} + +// DefaultErrorEncoder writes err as an HTTP error response, honoring +// the status code embedded in a Kratos *errors.Error if present. It is +// meant for the pre-stream phase — once SSE bytes have been flushed, +// errors should be reported via an SSE "error" event instead. +// +// The response body is the error's Message (or err.Error() when the +// underlying error is not a Kratos error). Content-Type is plain text; +// callers wanting JSON should install a custom EncodeErrorFunc via the +// ErrorEncoder option. +func DefaultErrorEncoder(w http.ResponseWriter, _ *http.Request, err error) { + se := kerrors.FromError(err) + code := int(se.Code) + if code <= 0 { + code = http.StatusInternalServerError + } + msg := se.Message + if msg == "" { + msg = err.Error() + } + http.Error(w, msg, code) +} + +// codecForContentType picks a Kratos codec by Content-Type, returning +// nil when the type is unrecognized. Parameters (";charset=utf-8") are +// stripped and a vendor prefix on the subtype is removed. +func codecForContentType(ct string) encoding.Codec { + if ct == "" { + return nil + } + if i := strings.IndexByte(ct, ';'); i >= 0 { + ct = ct[:i] + } + ct = strings.TrimSpace(ct) + subtype := ct + if _, after, ok := strings.Cut(ct, "/"); ok { + subtype = after + } + if i := strings.LastIndexByte(subtype, '.'); i >= 0 { + subtype = subtype[i+1:] + } + return encoding.GetCodec(subtype) +} diff --git a/sse/kratos/go.mod b/sse/kratos/go.mod new file mode 100644 index 0000000..a2ef001 --- /dev/null +++ b/sse/kratos/go.mod @@ -0,0 +1,24 @@ +module github.com/crypto-zero/go-kit/sse/kratos + +go 1.25.5 + +require ( + github.com/crypto-zero/go-kit v0.0.0-20260128101518-0545cf5a3fac + github.com/go-kratos/kratos/v2 v2.9.2 + google.golang.org/protobuf v1.36.11 +) + +require ( + github.com/go-kratos/aegis v0.2.0 // indirect + github.com/go-playground/form/v4 v4.2.0 // indirect + github.com/google/uuid v1.6.0 // indirect + github.com/gorilla/mux v1.8.1 // indirect + github.com/kr/text v0.2.0 // indirect + github.com/rogpeppe/go-internal v1.14.1 // indirect + golang.org/x/sys v0.39.0 // indirect + google.golang.org/genproto/googleapis/rpc v0.0.0-20251213004720-97cd9d5aeac2 // indirect + google.golang.org/grpc v1.77.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) + +replace github.com/crypto-zero/go-kit => ../.. diff --git a/sse/kratos/go.sum b/sse/kratos/go.sum new file mode 100644 index 0000000..bd79f80 --- /dev/null +++ b/sse/kratos/go.sum @@ -0,0 +1,48 @@ +github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/go-kratos/aegis v0.2.0 h1:dObzCDWn3XVjUkgxyBp6ZeWtx/do0DPZ7LY3yNSJLUQ= +github.com/go-kratos/aegis v0.2.0/go.mod h1:v0R2m73WgEEYB3XYu6aE2WcMwsZkJ/Rzuf5eVccm7bI= +github.com/go-kratos/kratos/v2 v2.9.2 h1:px8GJQBeLpquDKQWQ9zohEWiLA8n4D/pv7aH3asvUvo= +github.com/go-kratos/kratos/v2 v2.9.2/go.mod h1:Jc7jaeYd4RAPjetun2C+oFAOO7HNMHTT/Z4LxpuEDJM= +github.com/go-playground/assert/v2 v2.0.1 h1:MsBgLAaY856+nPRTKrp3/OZK38U/wa0CcBYNjji3q3A= +github.com/go-playground/assert/v2 v2.0.1/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4= +github.com/go-playground/form/v4 v4.2.0 h1:N1wh+Goz61e6w66vo8vJkQt+uwZSoLz50kZPJWR8eic= +github.com/go-playground/form/v4 v4.2.0/go.mod h1:q1a2BY+AQUUzhl6xA/6hBetay6dEIhMHjgvJiGo6K7U= +github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek= +github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps= +github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= +github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/gorilla/mux v1.8.1 h1:TuBL49tXwgrFYWhqrNgrUNEY92u81SPhu7sTdzQEiWY= +github.com/gorilla/mux v1.8.1/go.mod h1:AKf9I4AEqPTmMytcMc0KkNouC66V3BtZ4qD5fmWSiMQ= +github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= +github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ= +github.com/rogpeppe/go-internal v1.14.1/go.mod h1:MaRKkUm5W0goXpeCfT7UZI6fk/L7L7so1lCWt35ZSgc= +github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= +github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= +golang.org/x/net v0.47.0 h1:Mx+4dIFzqraBXUugkia1OOvlD6LemFo1ALMHjrXDOhY= +golang.org/x/net v0.47.0/go.mod h1:/jNxtkgq5yWUGYkaZGqo27cfGZ1c5Nen03aYrrKpVRU= +golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4= +golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= +golang.org/x/sys v0.39.0 h1:CvCKL8MeisomCi6qNZ+wbb0DN9E5AATixKsvNtMoMFk= +golang.org/x/sys v0.39.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= +golang.org/x/text v0.32.0 h1:ZD01bjUt1FQ9WJ0ClOL5vxgxOI/sVCNgX1YtKwcY0mU= +golang.org/x/text v0.32.0/go.mod h1:o/rUWzghvpD5TXrTIBuJU77MTaN0ljMWE47kxGJQ7jY= +google.golang.org/genproto/googleapis/rpc v0.0.0-20251213004720-97cd9d5aeac2 h1:2I6GHUeJ/4shcDpoUlLs/2WPnhg7yJwvXtqcMJt9liA= +google.golang.org/genproto/googleapis/rpc v0.0.0-20251213004720-97cd9d5aeac2/go.mod h1:7i2o+ce6H/6BluujYR+kqX3GKH+dChPTQU19wjRPiGk= +google.golang.org/grpc v1.77.0 h1:wVVY6/8cGA6vvffn+wWK5ToddbgdU3d8MNENr4evgXM= +google.golang.org/grpc v1.77.0/go.mod h1:z0BY1iVj0q8E1uSQCjL9cppRj+gnZjzDnzV0dHhrNig= +google.golang.org/protobuf v1.36.11 h1:fV6ZwhNocDyBLK0dj+fg8ektcVegBBuEolpbTQyBNVE= +google.golang.org/protobuf v1.36.11/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/sse/kratos/handler.go b/sse/kratos/handler.go new file mode 100644 index 0000000..2e0af66 --- /dev/null +++ b/sse/kratos/handler.go @@ -0,0 +1,148 @@ +package kratos + +import ( + "context" + "net/http" + + "github.com/go-kratos/kratos/v2/middleware" + + "github.com/crypto-zero/go-kit/sse" +) + +// StreamHandler builds an http.HandlerFunc that decodes a typed request, +// runs the Kratos middleware chain over (ctx, *Req), then invokes do +// with a live *sse.Stream. +// +// Lifecycle (in order): +// +// 1. Decode the request body into *Req via srv.Decode. Decode errors +// are reported via srv.EncodeError — a standard HTTP 4xx/5xx with +// no SSE bytes written. +// 2. Run the server-wide middleware chain (Middleware option) followed +// by the per-handler extras, with req exposed to middleware as the +// `req` argument. Errors here are likewise reported via +// srv.EncodeError before any streaming starts. +// 3. Create an *sse.Stream and call do. Errors returned by do are +// emitted as an SSE "error" event; do is responsible for any final +// Done frame on success. +// +// This is the right helper for auth/JWT verification, schema validation +// (protovalidate), per-request quota checks and similar pre-handler +// concerns. Tracing, recovery and metrics that must observe the full +// stream lifetime should be installed as Filters instead. +func StreamHandler[Req any]( + srv *Server, + do func(ctx context.Context, req *Req, s *sse.Stream) error, + mws ...middleware.Middleware, +) http.HandlerFunc { + chain := srv.chainFor(mws) + return func(w http.ResponseWriter, r *http.Request) { + req, ok := preStream[Req](srv, chain, w, r) + if !ok { + return + } + s, end := srv.beginStream(r.Context(), w) + defer end() + if err := do(r.Context(), req, s); err != nil { + _ = s.Error(err.Error()) + } + } +} + +// JSONHandler is the unary sibling of StreamHandler: do produces a +// single response value that is marshaled (via srv.Codec) and emitted +// as one SSE data frame followed by a Done terminator. +// +// Errors returned by do are written as an SSE "error" event — matching +// the convention that clients of an SSE endpoint always parse SSE, +// never raw HTTP errors. Decode and middleware errors still go through +// srv.EncodeError (no SSE bytes written yet). +func JSONHandler[Req any, Resp any]( + srv *Server, + do func(ctx context.Context, req *Req) (*Resp, error), + mws ...middleware.Middleware, +) http.HandlerFunc { + chain := srv.chainFor(mws) + return func(w http.ResponseWriter, r *http.Request) { + req, ok := preStream[Req](srv, chain, w, r) + if !ok { + return + } + result, err := do(r.Context(), req) + s, end := srv.beginStream(r.Context(), w) + defer end() + if err != nil { + _ = s.Error(err.Error()) + return + } + data, mErr := srv.Codec().Marshal(result) + if mErr != nil { + _ = s.Error(mErr.Error()) + return + } + _ = s.Write(string(data)) + _ = s.Done() + } +} + +// preStream runs the request through Decode and the middleware chain. +// On success it returns (req, true). On failure it has already written +// an HTTP error response and returns (_, false). +// +// Middleware sees the decoded req via the `req` argument of +// middleware.Handler. The inner handler is intentionally a no-op: +// streaming runs outside the chain so middleware errors translate to +// real HTTP statuses while no SSE bytes have yet hit the wire. +func preStream[Req any]( + srv *Server, chain middleware.Middleware, + w http.ResponseWriter, r *http.Request, +) (*Req, bool) { + req := new(Req) + if err := srv.Decode(r, req); err != nil { + srv.EncodeError(w, r, err) + return nil, false + } + h := chain(func(context.Context, any) (any, error) { return nil, nil }) + if _, err := h(r.Context(), req); err != nil { + srv.EncodeError(w, r, err) + return nil, false + } + return req, true +} + +// beginStream constructs an *sse.Stream, starts the configured heartbeat +// (if any), and bumps the active-stream counter. The returned end +// function tears these down in the inverse order — heartbeat first +// (must stop before the response writer is recycled), then the counter. +// Callers should defer end immediately after this call. +func (s *Server) beginStream(ctx context.Context, w http.ResponseWriter) (*sse.Stream, func()) { + st := sse.NewStream(w) + stopBeat := s.startHeartbeat(ctx, st) + s.active.Add(1) + return st, func() { + stopBeat() + s.active.Add(-1) + } +} + +// startHeartbeat fires a periodic comment frame on st when the server +// has Heartbeat enabled. Returns a stop function (a no-op when +// heartbeat is disabled). +func (s *Server) startHeartbeat(ctx context.Context, st *sse.Stream) func() { + if s.heartbeat <= 0 { + return func() {} + } + return st.Heartbeat(ctx, s.heartbeat) +} + +// chainFor composes the middleware chain for one handler: server-wide +// middlewares (outermost) followed by per-handler extras. The returned +// chain does not share backing storage with srv.middlewares, so later +// additions to the server's list cannot retroactively affect handlers +// that have already been built. +func (s *Server) chainFor(extras []middleware.Middleware) middleware.Middleware { + all := make([]middleware.Middleware, 0, len(s.middlewares)+len(extras)) + all = append(all, s.middlewares...) + all = append(all, extras...) + return middleware.Chain(all...) +} diff --git a/sse/kratos/handler_test.go b/sse/kratos/handler_test.go new file mode 100644 index 0000000..bfc4ffc --- /dev/null +++ b/sse/kratos/handler_test.go @@ -0,0 +1,435 @@ +package kratos_test + +import ( + "context" + "errors" + "io" + "net/http" + "strings" + "testing" + "time" + + kerrors "github.com/go-kratos/kratos/v2/errors" + "github.com/go-kratos/kratos/v2/middleware" + + "github.com/crypto-zero/go-kit/sse" + ksse "github.com/crypto-zero/go-kit/sse/kratos" +) + +type chatRequest struct { + Prompt string `json:"prompt"` +} + +type profileResponse struct { + Name string `json:"name"` +} + +// authMW is a tiny inline auth middleware that rejects requests missing +// the "X-Token" header by returning a Kratos Unauthorized error. +func authMW(token string) middleware.Middleware { + return func(next middleware.Handler) middleware.Handler { + return func(ctx context.Context, req any) (any, error) { + if got := tokenFromCtx(ctx); got != token { + return nil, kerrors.Unauthorized("AUTH", "bad token") + } + return next(ctx, req) + } + } +} + +type tokenKey struct{} + +func withToken(ctx context.Context, tok string) context.Context { + return context.WithValue(ctx, tokenKey{}, tok) +} + +func tokenFromCtx(ctx context.Context) string { + v, _ := ctx.Value(tokenKey{}).(string) + return v +} + +// tokenFilter copies the X-Token header into ctx so authMW can read it. +// Demonstrates the Filter + Middleware split: Filter touches HTTP-level +// concerns (headers), middleware sees the typed req. +func tokenFilter(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + r = r.WithContext(withToken(r.Context(), r.Header.Get("X-Token"))) + next.ServeHTTP(w, r) + }) +} + +func TestStreamHandler_AuthSucceeds(t *testing.T) { + srv, addr := newServerOnLoopback(t, + ksse.Filter(tokenFilter), + ksse.Middleware(authMW("good")), + ) + srv.HandleFunc("POST /v1/chat", ksse.StreamHandler(srv, + func(_ context.Context, req *chatRequest, s *sse.Stream) error { + _ = s.Write("got:" + req.Prompt) + return s.Done() + }, + )) + + stop := startServer(t, srv) + defer stop() + + req, _ := http.NewRequest("POST", "http://"+addr+"/v1/chat", + strings.NewReader(`{"prompt":"hi"}`)) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("X-Token", "good") + resp, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatalf("POST: %v", err) + } + defer func() { _ = resp.Body.Close() }() + if resp.StatusCode != 200 { + t.Fatalf("status = %d", resp.StatusCode) + } + body := readAll(t, resp.Body) + if !strings.Contains(body, "data: got:hi\n\n") { + t.Errorf("body missing chunk: %q", body) + } + if !strings.Contains(body, "data: [DONE]\n\n") { + t.Errorf("body missing done: %q", body) + } +} + +func TestStreamHandler_AuthRejects(t *testing.T) { + srv, addr := newServerOnLoopback(t, + ksse.Filter(tokenFilter), + ksse.Middleware(authMW("good")), + ) + srv.HandleFunc("POST /v1/chat", ksse.StreamHandler(srv, + func(_ context.Context, _ *chatRequest, s *sse.Stream) error { + t.Error("handler should not run when auth fails") + return s.Done() + }, + )) + + stop := startServer(t, srv) + defer stop() + + req, _ := http.NewRequest("POST", "http://"+addr+"/v1/chat", + strings.NewReader(`{"prompt":"hi"}`)) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("X-Token", "bad") + resp, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatalf("POST: %v", err) + } + defer func() { _ = resp.Body.Close() }() + if resp.StatusCode != http.StatusUnauthorized { + t.Errorf("status = %d, want 401", resp.StatusCode) + } + if got := resp.Header.Get("Content-Type"); strings.HasPrefix(got, "text/event-stream") { + t.Errorf("unexpected SSE response on auth failure: %q", got) + } +} + +func TestStreamHandler_DoErrorBecomesSSEEvent(t *testing.T) { + srv, addr := newServerOnLoopback(t) + sentinel := errors.New("biz blew up") + srv.HandleFunc("POST /v1/chat", ksse.StreamHandler(srv, + func(_ context.Context, _ *chatRequest, _ *sse.Stream) error { + return sentinel + }, + )) + + stop := startServer(t, srv) + defer stop() + + resp, err := http.Post("http://"+addr+"/v1/chat", "application/json", + strings.NewReader(`{}`)) + if err != nil { + t.Fatalf("POST: %v", err) + } + defer func() { _ = resp.Body.Close() }() + if resp.StatusCode != 200 { + t.Errorf("status = %d, want 200 (SSE error rides on 200)", resp.StatusCode) + } + body := readAll(t, resp.Body) + if !strings.Contains(body, "event: error") { + t.Errorf("body missing error event: %q", body) + } + if !strings.Contains(body, "biz blew up") { + t.Errorf("body missing error message: %q", body) + } +} + +func TestStreamHandler_DecodeError(t *testing.T) { + srv, addr := newServerOnLoopback(t) + srv.HandleFunc("POST /v1/chat", ksse.StreamHandler(srv, + func(context.Context, *chatRequest, *sse.Stream) error { + t.Error("handler should not run on decode error") + return nil + }, + )) + + stop := startServer(t, srv) + defer stop() + + // Malformed JSON. + resp, err := http.Post("http://"+addr+"/v1/chat", "application/json", + strings.NewReader(`{not json`)) + if err != nil { + t.Fatalf("POST: %v", err) + } + defer func() { _ = resp.Body.Close() }() + if resp.StatusCode == 200 { + t.Errorf("status = 200, want 4xx/5xx for decode failure") + } +} + +func TestJSONHandler_Roundtrip(t *testing.T) { + srv, addr := newServerOnLoopback(t) + srv.HandleFunc("POST /v1/profile", ksse.JSONHandler(srv, + func(_ context.Context, _ *chatRequest) (*profileResponse, error) { + return &profileResponse{Name: "karma"}, nil + }, + )) + + stop := startServer(t, srv) + defer stop() + + resp, err := http.Post("http://"+addr+"/v1/profile", "application/json", + strings.NewReader(`{}`)) + if err != nil { + t.Fatalf("POST: %v", err) + } + defer func() { _ = resp.Body.Close() }() + body := readAll(t, resp.Body) + if !strings.Contains(body, `data: {"name":"karma"}`) { + t.Errorf("body missing payload: %q", body) + } + if !strings.Contains(body, "data: [DONE]\n\n") { + t.Errorf("body missing done: %q", body) + } +} + +func TestJSONHandler_DoErrorBecomesSSEEvent(t *testing.T) { + srv, addr := newServerOnLoopback(t) + srv.HandleFunc("POST /v1/profile", ksse.JSONHandler(srv, + func(_ context.Context, _ *chatRequest) (*profileResponse, error) { + return nil, kerrors.NotFound("PROFILE", "no profile") + }, + )) + + stop := startServer(t, srv) + defer stop() + + resp, err := http.Post("http://"+addr+"/v1/profile", "application/json", + strings.NewReader(`{}`)) + if err != nil { + t.Fatalf("POST: %v", err) + } + defer func() { _ = resp.Body.Close() }() + if resp.StatusCode != 200 { + t.Errorf("status = %d, want 200", resp.StatusCode) + } + body := readAll(t, resp.Body) + if !strings.Contains(body, "event: error") { + t.Errorf("body missing error event: %q", body) + } +} + +func TestStreamHandler_PerHandlerMiddlewareAppends(t *testing.T) { + calls := make(chan string, 4) + mark := func(name string) middleware.Middleware { + return func(next middleware.Handler) middleware.Handler { + return func(ctx context.Context, req any) (any, error) { + calls <- name + return next(ctx, req) + } + } + } + srv, addr := newServerOnLoopback(t, ksse.Middleware(mark("server"))) + srv.HandleFunc("POST /v1/chat", ksse.StreamHandler(srv, + func(_ context.Context, _ *chatRequest, s *sse.Stream) error { + return s.Done() + }, + mark("handler"), + )) + + stop := startServer(t, srv) + defer stop() + + resp, err := http.Post("http://"+addr+"/v1/chat", "application/json", + strings.NewReader(`{}`)) + if err != nil { + t.Fatalf("POST: %v", err) + } + _ = resp.Body.Close() + close(calls) + var order []string + for s := range calls { + order = append(order, s) + } + want := []string{"server", "handler"} + if strings.Join(order, ",") != strings.Join(want, ",") { + t.Errorf("middleware order = %v, want %v", order, want) + } +} + +func TestStreamHandler_Heartbeat(t *testing.T) { + srv, addr := newServerOnLoopback(t, ksse.Heartbeat(10*time.Millisecond)) + srv.HandleFunc("POST /v1/chat", ksse.StreamHandler(srv, + func(ctx context.Context, _ *chatRequest, s *sse.Stream) error { + _ = s.Write("first") + // Hold the stream open long enough for several heartbeats. + select { + case <-ctx.Done(): + case <-time.After(80 * time.Millisecond): + } + return s.Done() + }, + )) + + stop := startServer(t, srv) + defer stop() + + resp, err := http.Post("http://"+addr+"/v1/chat", "application/json", + strings.NewReader(`{}`)) + if err != nil { + t.Fatalf("POST: %v", err) + } + defer func() { _ = resp.Body.Close() }() + body := readAll(t, resp.Body) + // Expect at least one comment frame between "first" and "[DONE]". + if !strings.Contains(body, ":\n\n") { + t.Errorf("body missing heartbeat comment frames: %q", body) + } +} + +func TestServer_ActiveStreams(t *testing.T) { + srv, addr := newServerOnLoopback(t) + release := make(chan struct{}) + srv.HandleFunc("POST /v1/chat", ksse.StreamHandler(srv, + func(_ context.Context, _ *chatRequest, s *sse.Stream) error { + _ = s.Write("hold") + <-release + return s.Done() + }, + )) + + stop := startServer(t, srv) + + if got := srv.ActiveStreams(); got != 0 { + t.Errorf("ActiveStreams before request = %d, want 0", got) + } + + done := make(chan struct{}) + go func() { + resp, err := http.Post("http://"+addr+"/v1/chat", "application/json", + strings.NewReader(`{}`)) + if err == nil { + // Drain to avoid the response goroutine wedging on close. + _, _ = http.NoBody.Read(make([]byte, 1)) + _ = resp.Body.Close() + } + close(done) + }() + + // Spin until the handler is in flight. + deadline := time.Now().Add(2 * time.Second) + for srv.ActiveStreams() == 0 && time.Now().Before(deadline) { + time.Sleep(2 * time.Millisecond) + } + if got := srv.ActiveStreams(); got != 1 { + t.Errorf("ActiveStreams during request = %d, want 1", got) + } + close(release) + <-done + stop() + + if got := srv.ActiveStreams(); got != 0 { + t.Errorf("ActiveStreams after request = %d, want 0", got) + } +} + +func TestServer_GracefulShutdownUnblocksHandler(t *testing.T) { + srv, addr := newServerOnLoopback(t) + srv.HandleFunc("POST /v1/chat", ksse.StreamHandler(srv, + func(ctx context.Context, _ *chatRequest, s *sse.Stream) error { + _ = s.Write("hi") + // Pump-style: respect ctx so shutdown can drain promptly. + <-ctx.Done() + return ctx.Err() + }, + )) + + ctx, cancel := context.WithCancel(context.Background()) + startErr := make(chan error, 1) + go func() { startErr <- srv.Start(ctx) }() + + // Client holds the connection open: it drains the body to keep the + // server-side r.Context() alive until shutdown explicitly cancels it. + // (If we closed resp.Body eagerly the server would see a client + // disconnect instead and we wouldn't exercise shutdown propagation.) + clientDone := make(chan struct{}) + go func() { + resp, err := http.Post("http://"+addr+"/v1/chat", "application/json", + strings.NewReader(`{}`)) + if err != nil { + close(clientDone) + return + } + _, _ = io.Copy(io.Discard, resp.Body) // returns when server closes + _ = resp.Body.Close() + close(clientDone) + }() + + // Wait for the stream to register as active. + deadline := time.Now().Add(2 * time.Second) + for srv.ActiveStreams() == 0 && time.Now().Before(deadline) { + time.Sleep(2 * time.Millisecond) + } + if got := srv.ActiveStreams(); got != 1 { + t.Fatalf("ActiveStreams before shutdown = %d, want 1", got) + } + + shutdownStart := time.Now() + shutdownCtx, c := context.WithTimeout(context.Background(), 2*time.Second) + defer c() + if err := srv.Stop(shutdownCtx); err != nil { + t.Errorf("Stop: %v", err) + } + elapsed := time.Since(shutdownStart) + if elapsed > 500*time.Millisecond { + t.Errorf("Stop took %v, handler did not unblock on shutdown ctx", elapsed) + } + <-clientDone + if err := <-startErr; err != nil { + t.Errorf("Start: %v", err) + } + cancel() +} + +func TestErrorEncoder_HonorsKratosCode(t *testing.T) { + srv, addr := newServerOnLoopback(t, + ksse.Middleware(func(middleware.Handler) middleware.Handler { + return func(context.Context, any) (any, error) { + return nil, kerrors.BadRequest("VALIDATION", "field x missing") + } + }), + ) + srv.HandleFunc("POST /v1/chat", ksse.StreamHandler(srv, + func(context.Context, *chatRequest, *sse.Stream) error { return nil }, + )) + + stop := startServer(t, srv) + defer stop() + + resp, err := http.Post("http://"+addr+"/v1/chat", "application/json", + strings.NewReader(`{}`)) + if err != nil { + t.Fatalf("POST: %v", err) + } + defer func() { _ = resp.Body.Close() }() + if resp.StatusCode != http.StatusBadRequest { + t.Errorf("status = %d, want 400", resp.StatusCode) + } + body := readAll(t, resp.Body) + if !strings.Contains(body, "field x missing") { + t.Errorf("body missing kratos message: %q", body) + } +} diff --git a/sse/kratos/http_cors.go b/sse/kratos/http_cors.go new file mode 100644 index 0000000..07db144 --- /dev/null +++ b/sse/kratos/http_cors.go @@ -0,0 +1,26 @@ +package kratos + +import ( + "net/http" + + khttp "github.com/go-kratos/kratos/v2/transport/http" +) + +// HTTPPermissiveCORS returns a route filter suitable for browser EventSource +// endpoints that do not use credentials. +func HTTPPermissiveCORS() khttp.FilterFunc { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + h := w.Header() + h.Set("Access-Control-Allow-Origin", "*") + h.Set("Access-Control-Allow-Methods", "GET, OPTIONS") + h.Set("Access-Control-Allow-Headers", "Cache-Control, Last-Event-ID") + h.Set("Access-Control-Expose-Headers", "Content-Type") + if r.Method == http.MethodOptions { + w.WriteHeader(http.StatusNoContent) + return + } + next.ServeHTTP(w, r) + }) + } +} diff --git a/sse/kratos/http_cors_test.go b/sse/kratos/http_cors_test.go new file mode 100644 index 0000000..970bcc0 --- /dev/null +++ b/sse/kratos/http_cors_test.go @@ -0,0 +1,45 @@ +package kratos + +import ( + "net/http" + "net/http/httptest" + "testing" +) + +func TestHTTPPermissiveCORSHandlesPreflight(t *testing.T) { + nextCalled := false + handler := HTTPPermissiveCORS()(http.HandlerFunc(func(http.ResponseWriter, *http.Request) { + nextCalled = true + })) + + req := httptest.NewRequest(http.MethodOptions, "/events", nil) + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + + if rec.Code != http.StatusNoContent { + t.Fatalf("status = %d, want %d", rec.Code, http.StatusNoContent) + } + if nextCalled { + t.Fatal("next handler should not be called for preflight") + } + if got := rec.Header().Get("Access-Control-Allow-Origin"); got != "*" { + t.Fatalf("allow origin = %q, want *", got) + } +} + +func TestHTTPPermissiveCORSPassesThroughGET(t *testing.T) { + handler := HTTPPermissiveCORS()(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusAccepted) + })) + + req := httptest.NewRequest(http.MethodGet, "/events", nil) + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + + if rec.Code != http.StatusAccepted { + t.Fatalf("status = %d, want %d", rec.Code, http.StatusAccepted) + } + if got := rec.Header().Get("Access-Control-Allow-Headers"); got != "Cache-Control, Last-Event-ID" { + t.Fatalf("allow headers = %q", got) + } +} diff --git a/sse/kratos/http_handler.go b/sse/kratos/http_handler.go new file mode 100644 index 0000000..7212944 --- /dev/null +++ b/sse/kratos/http_handler.go @@ -0,0 +1,225 @@ +package kratos + +import ( + "context" + "errors" + "fmt" + "net/http" + "strconv" + "sync" + "time" + + authkratos "github.com/crypto-zero/go-kit/auth/kratos" + khttp "github.com/go-kratos/kratos/v2/transport/http" + "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/reflect/protoreflect" + + "github.com/crypto-zero/go-kit/sse" +) + +// HTTPStreamOption configures a Kratos HTTP-attached SSE stream handler. +type HTTPStreamOption func(*httpStreamConfig) + +type httpStreamConfig struct { + heartbeat time.Duration + filters []khttp.FilterFunc +} + +// HTTPHeartbeat enables automatic SSE comment frames at interval for one +// Kratos HTTP stream handler. Set interval <= 0 to disable it. +func HTTPHeartbeat(interval time.Duration) HTTPStreamOption { + return func(c *httpStreamConfig) { c.heartbeat = interval } +} + +// HTTPFilter appends Kratos HTTP route filters around one stream endpoint. +func HTTPFilter(filters ...khttp.FilterFunc) HTTPStreamOption { + return func(c *httpStreamConfig) { c.filters = append(c.filters, filters...) } +} + +// RegisterHTTPStream mounts an SSE endpoint on an existing Kratos HTTP +// server. Unlike Server, it does not own a listener: lifecycle, filters, +// route walking, operation selection and service middleware all come from +// the supplied HTTP server. +// +// For GET and HEAD requests, proto.Message inputs are decoded from the query +// string by JSON name (with snake_case fallback). Non-proto GET/HEAD inputs +// use Kratos' configured query decoder. Other methods use Kratos' configured +// body decoder. +func RegisterHTTPStream[Req any]( + srv *khttp.Server, + method string, + path string, + operation string, + do func(ctx context.Context, req *Req, st *sse.Stream) error, + opts ...HTTPStreamOption, +) { + cfg := httpStreamConfig{} + for _, opt := range opts { + opt(&cfg) + } + srv.Route("/").Handle(method, path, func(ctx khttp.Context) error { + req := new(Req) + if err := bindHTTPStreamRequest(ctx, req); err != nil { + return err + } + if operation != "" { + khttp.SetOperation(ctx, operation) + } + streamCtx, stopStreamCtx := detachHTTPTimeout(ctx) + defer stopStreamCtx() + h := ctx.Middleware(func(mctx context.Context, raw any) (any, error) { + st := sse.NewStream(ctx.Response()) + stopBeat := startHTTPHeartbeat(mctx, st, cfg.heartbeat) + defer stopBeat() + if err := do(mctx, raw.(*Req), st); err != nil { + _ = st.Error(err.Error()) + } + return nil, nil + }) + _, err := h(streamCtx, req) + return err + }, cfg.filters...) +} + +// RegisterHTTPStreamMethod mounts an SSE endpoint for a proto method +// descriptor. The Kratos operation is derived from the method name +// (`/package.Service/Method`) so auth selectors, logging and tracing use the +// same operation identity as generated Kratos HTTP handlers. +func RegisterHTTPStreamMethod[Req any]( + srv *khttp.Server, + method protoreflect.MethodDescriptor, + httpMethod string, + path string, + do func(ctx context.Context, req *Req, st *sse.Stream) error, + opts ...HTTPStreamOption, +) { + RegisterHTTPStream(srv, httpMethod, path, authkratos.OperationName(method), do, opts...) +} + +func bindHTTPStreamRequest(ctx khttp.Context, target any) error { + switch ctx.Request().Method { + case http.MethodGet, http.MethodHead: + if msg, ok := target.(proto.Message); ok { + return decodeProtoQuery(ctx.Request(), msg) + } + return ctx.BindQuery(target) + default: + return ctx.Bind(target) + } +} + +func startHTTPHeartbeat(ctx context.Context, st *sse.Stream, interval time.Duration) func() { + if interval <= 0 { + return func() {} + } + return st.Heartbeat(ctx, interval) +} + +func detachHTTPTimeout(parent context.Context) (context.Context, func()) { + ctx, cancel := context.WithCancel(context.WithoutCancel(parent)) + done := make(chan struct{}) + go func() { + select { + case <-parent.Done(): + if !errors.Is(parent.Err(), context.DeadlineExceeded) { + cancel() + } + case <-done: + } + }() + var once sync.Once + return ctx, func() { + once.Do(func() { + close(done) + cancel() + }) + } +} + +func decodeProtoQuery(r *http.Request, msg proto.Message) error { + q := r.URL.Query() + if len(q) == 0 { + return nil + } + refl := msg.ProtoReflect() + fields := refl.Descriptor().Fields() + for i := 0; i < fields.Len(); i++ { + fd := fields.Get(i) + raw := q.Get(fd.JSONName()) + if raw == "" { + raw = q.Get(string(fd.Name())) + } + if raw == "" { + continue + } + if err := setProtoFieldFromString(refl, fd, raw); err != nil { + return fmt.Errorf("query %s: %w", fd.JSONName(), err) + } + } + return nil +} + +func setProtoFieldFromString(msg protoreflect.Message, fd protoreflect.FieldDescriptor, raw string) error { + if fd.IsList() || fd.IsMap() { + return fmt.Errorf("repeated/map fields not supported in query strings") + } + switch fd.Kind() { + case protoreflect.DoubleKind: + v, err := strconv.ParseFloat(raw, 64) + if err != nil { + return err + } + msg.Set(fd, protoreflect.ValueOfFloat64(v)) + case protoreflect.FloatKind: + v, err := strconv.ParseFloat(raw, 32) + if err != nil { + return err + } + msg.Set(fd, protoreflect.ValueOfFloat32(float32(v))) + case protoreflect.Int32Kind, protoreflect.Sint32Kind, protoreflect.Sfixed32Kind: + v, err := strconv.ParseInt(raw, 10, 32) + if err != nil { + return err + } + msg.Set(fd, protoreflect.ValueOfInt32(int32(v))) + case protoreflect.Uint32Kind, protoreflect.Fixed32Kind: + v, err := strconv.ParseUint(raw, 10, 32) + if err != nil { + return err + } + msg.Set(fd, protoreflect.ValueOfUint32(uint32(v))) + case protoreflect.Int64Kind, protoreflect.Sint64Kind, protoreflect.Sfixed64Kind: + v, err := strconv.ParseInt(raw, 10, 64) + if err != nil { + return err + } + msg.Set(fd, protoreflect.ValueOfInt64(v)) + case protoreflect.Uint64Kind, protoreflect.Fixed64Kind: + v, err := strconv.ParseUint(raw, 10, 64) + if err != nil { + return err + } + msg.Set(fd, protoreflect.ValueOfUint64(v)) + case protoreflect.BoolKind: + v, err := strconv.ParseBool(raw) + if err != nil { + return err + } + msg.Set(fd, protoreflect.ValueOfBool(v)) + case protoreflect.StringKind: + msg.Set(fd, protoreflect.ValueOfString(raw)) + case protoreflect.EnumKind: + if i, err := strconv.ParseInt(raw, 10, 32); err == nil { + msg.Set(fd, protoreflect.ValueOfEnum(protoreflect.EnumNumber(i))) + return nil + } + if ev := fd.Enum().Values().ByName(protoreflect.Name(raw)); ev != nil { + msg.Set(fd, protoreflect.ValueOfEnum(ev.Number())) + return nil + } + return fmt.Errorf("unknown enum value %q for %s", raw, fd.Enum().FullName()) + default: + return fmt.Errorf("unsupported field kind %s for query decoding", fd.Kind()) + } + return nil +} diff --git a/sse/kratos/http_handler_test.go b/sse/kratos/http_handler_test.go new file mode 100644 index 0000000..8bea158 --- /dev/null +++ b/sse/kratos/http_handler_test.go @@ -0,0 +1,190 @@ +package kratos_test + +import ( + "context" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/go-kratos/kratos/v2/middleware" + ktransport "github.com/go-kratos/kratos/v2/transport" + khttp "github.com/go-kratos/kratos/v2/transport/http" + "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/reflect/protodesc" + "google.golang.org/protobuf/reflect/protoreflect" + "google.golang.org/protobuf/types/descriptorpb" + "google.golang.org/protobuf/types/known/durationpb" + + "github.com/crypto-zero/go-kit/sse" + ksse "github.com/crypto-zero/go-kit/sse/kratos" +) + +func TestHTTPStreamHandler_BindsProtoQueryAndStreamsOnKratosHTTP(t *testing.T) { + srv := khttp.NewServer(khttp.Timeout(0)) + ksse.RegisterHTTPStream(srv, http.MethodGet, "/v1/duration", "/test.Duration/Watch", + func(_ context.Context, req *durationpb.Duration, st *sse.Stream) error { + if req.GetSeconds() != 12 || req.GetNanos() != 34 { + t.Fatalf("request = %ds/%dns, want 12s/34ns", req.GetSeconds(), req.GetNanos()) + } + _ = st.Write("bound") + return st.Done() + }, + ) + ts := httptest.NewServer(srv) + defer ts.Close() + + resp, err := http.Get(ts.URL + "/v1/duration?seconds=12&nanos=34") + if err != nil { + t.Fatalf("GET: %v", err) + } + defer func() { _ = resp.Body.Close() }() + if got, want := resp.Header.Get("Content-Type"), "text/event-stream"; got != want { + t.Errorf("Content-Type = %q, want %q", got, want) + } + body := readAll(t, resp.Body) + if !strings.Contains(body, "data: bound\n\n") { + t.Errorf("body missing payload: %q", body) + } + if !strings.Contains(body, "data: [DONE]\n\n") { + t.Errorf("body missing done marker: %q", body) + } +} + +func TestHTTPStreamHandler_SetsOperationBeforeMiddleware(t *testing.T) { + const operation = "/test.Live/Watch" + seen := make(chan string, 1) + srv := khttp.NewServer( + khttp.Timeout(0), + khttp.Middleware(func(next middleware.Handler) middleware.Handler { + return func(ctx context.Context, req any) (any, error) { + tr, ok := ktransport.FromServerContext(ctx) + if !ok { + t.Fatal("missing transport") + } + seen <- tr.Operation() + return next(ctx, req) + } + }), + ) + ksse.RegisterHTTPStream(srv, http.MethodGet, "/v1/live", operation, + func(_ context.Context, _ *durationpb.Duration, st *sse.Stream) error { + return st.Done() + }, + ) + ts := httptest.NewServer(srv) + defer ts.Close() + + resp, err := http.Get(ts.URL + "/v1/live") + if err != nil { + t.Fatalf("GET: %v", err) + } + _ = resp.Body.Close() + + select { + case got := <-seen: + if got != operation { + t.Errorf("operation = %q, want %q", got, operation) + } + case <-time.After(time.Second): + t.Fatal("middleware did not run") + } +} + +func TestHTTPStreamHandler_DetachesKratosHTTPTimeout(t *testing.T) { + srv := khttp.NewServer(khttp.Timeout(20 * time.Millisecond)) + ksse.RegisterHTTPStream(srv, http.MethodGet, "/v1/slow", "/test.Slow/Watch", + func(ctx context.Context, _ *durationpb.Duration, st *sse.Stream) error { + select { + case <-ctx.Done(): + return ctx.Err() + case <-time.After(80 * time.Millisecond): + } + _ = st.Write("still-open") + return st.Done() + }, + ) + ts := httptest.NewServer(srv) + defer ts.Close() + + resp, err := http.Get(ts.URL + "/v1/slow") + if err != nil { + t.Fatalf("GET: %v", err) + } + defer func() { _ = resp.Body.Close() }() + body := readAll(t, resp.Body) + if !strings.Contains(body, "data: still-open\n\n") { + t.Errorf("body missing delayed payload after HTTP timeout: %q", body) + } + if !strings.Contains(body, "data: [DONE]\n\n") { + t.Errorf("body missing done marker: %q", body) + } +} + +func TestHTTPStreamHandler_MethodDescriptorSetsOperation(t *testing.T) { + const operation = "/test.live.v1.LiveService/Watch" + seen := make(chan string, 1) + srv := khttp.NewServer( + khttp.Timeout(0), + khttp.Middleware(func(next middleware.Handler) middleware.Handler { + return func(ctx context.Context, req any) (any, error) { + tr, ok := ktransport.FromServerContext(ctx) + if !ok { + t.Fatal("missing transport") + } + seen <- tr.Operation() + return next(ctx, req) + } + }), + ) + method := testMethodDescriptor(t) + ksse.RegisterHTTPStreamMethod(srv, method, http.MethodGet, "/v1/method", + func(_ context.Context, _ *durationpb.Duration, st *sse.Stream) error { + return st.Done() + }, + ) + + ts := httptest.NewServer(srv) + defer ts.Close() + + resp, err := http.Get(ts.URL + "/v1/method") + if err != nil { + t.Fatalf("GET: %v", err) + } + _ = resp.Body.Close() + + select { + case got := <-seen: + if got != operation { + t.Errorf("operation = %q, want %q", got, operation) + } + case <-time.After(time.Second): + t.Fatal("middleware did not run") + } +} + +func testMethodDescriptor(t *testing.T) protoreflect.MethodDescriptor { + t.Helper() + fd, err := protodesc.NewFile(&descriptorpb.FileDescriptorProto{ + Syntax: proto.String("proto3"), + Name: proto.String("test/live/v1/live.proto"), + Package: proto.String("test.live.v1"), + Service: []*descriptorpb.ServiceDescriptorProto{{ + Name: proto.String("LiveService"), + Method: []*descriptorpb.MethodDescriptorProto{{ + Name: proto.String("Watch"), + InputType: proto.String(".test.live.v1.WatchRequest"), + OutputType: proto.String(".test.live.v1.WatchResponse"), + }}, + }}, + MessageType: []*descriptorpb.DescriptorProto{ + {Name: proto.String("WatchRequest")}, + {Name: proto.String("WatchResponse")}, + }, + }, nil) + if err != nil { + t.Fatalf("NewFile: %v", err) + } + return fd.Services().ByName("LiveService").Methods().ByName("Watch") +} diff --git a/sse/kratos/server.go b/sse/kratos/server.go new file mode 100644 index 0000000..8b16bc2 --- /dev/null +++ b/sse/kratos/server.go @@ -0,0 +1,277 @@ +// Package kratos provides a Server-Sent Events transport server that plugs +// into a Kratos application as a first-class transport.Server. +// +// The server owns its own net.Listener and http.ServeMux. Requests it +// serves carry a transport.Transporter with Kind="sse" in their context, +// so Kratos middleware can inspect and act on SSE traffic the same way it +// does for HTTP and gRPC. +// +// Streaming itself is handled by the parent github.com/crypto-zero/go-kit/sse +// package: handlers construct an *sse.Stream from the ResponseWriter and +// write events through its API. This package contributes the Kratos plumbing +// (lifecycle, endpoint registration, codec selection, filter chain) around +// that core. +package kratos + +import ( + "context" + "crypto/tls" + "errors" + "log/slog" + "net" + "net/http" + "net/url" + "sync/atomic" + "time" + + "github.com/go-kratos/kratos/v2/encoding" + // Register the JSON codec by default; users can pull additional + // codecs (proto, yaml, xml) by importing them at their main package. + _ "github.com/go-kratos/kratos/v2/encoding/json" + "github.com/go-kratos/kratos/v2/middleware" + ktransport "github.com/go-kratos/kratos/v2/transport" +) + +// KindSSE identifies this transport in the Kratos transport registry. +const KindSSE ktransport.Kind = "sse" + +// DefaultReadHeaderTimeout is applied when no ReadHeaderTimeout option is +// given. It protects the server from Slowloris-style attacks (clients that +// trickle request headers to hold connections open) without affecting the +// streaming response — write deadlines are managed separately. +const DefaultReadHeaderTimeout = 10 * time.Second + +var ( + _ ktransport.Server = (*Server)(nil) + _ ktransport.Endpointer = (*Server)(nil) + _ http.Handler = (*Server)(nil) +) + +// FilterFunc wraps an http.Handler. Filters compose right-to-left around +// the request: the first filter is the outermost wrapper. +type FilterFunc func(http.Handler) http.Handler + +// FilterChain composes filters into a single wrapper. +func FilterChain(filters ...FilterFunc) FilterFunc { + return func(next http.Handler) http.Handler { + for i := len(filters) - 1; i >= 0; i-- { + next = filters[i](next) + } + return next + } +} + +// DecodeRequestFunc decodes an inbound request body into v. +type DecodeRequestFunc func(*http.Request, any) error + +// EncodeErrorFunc reports an error to the client. For SSE handlers the +// default writes an SSE "error" event when headers have not yet been sent, +// otherwise falls through to http.Error. +type EncodeErrorFunc func(http.ResponseWriter, *http.Request, error) + +// Server is a Kratos transport.Server that serves Server-Sent Events. +type Server struct { + *http.Server + + lis net.Listener + tlsConf *tls.Config + endpoint *url.URL + + network string + address string + + mux *http.ServeMux + codec encoding.Codec + logger *slog.Logger + filters []FilterFunc + middlewares []middleware.Middleware + decBody DecodeRequestFunc + errEnc EncodeErrorFunc + patterns []string + readHeaderTimeout time.Duration + heartbeat time.Duration + + // shutdownCtx is the parent context handed to every request via + // http.Server.BaseContext. Stop cancels it before calling Shutdown so + // long-running SSE handlers observe ctx.Done() and can drain cleanly + // instead of blocking the shutdown. + shutdownCtx context.Context + shutdownCancel context.CancelFunc + + // active counts live SSE streams managed by StreamHandler / + // JSONHandler. Exposed via ActiveStreams. + active atomic.Int64 +} + +// NewServer constructs a Server. With no options it listens on a random +// TCP port, uses the JSON codec, and serves plaintext HTTP. +func NewServer(opts ...ServerOption) *Server { + s := &Server{ + network: "tcp", + address: ":0", + mux: http.NewServeMux(), + codec: encoding.GetCodec("json"), + logger: slog.Default(), + decBody: DefaultRequestDecoder, + errEnc: DefaultErrorEncoder, + readHeaderTimeout: DefaultReadHeaderTimeout, + } + for _, o := range opts { + o(s) + } + s.Server = &http.Server{ + Handler: FilterChain(s.filters...)(http.HandlerFunc(s.dispatch)), + TLSConfig: s.tlsConf, + ReadHeaderTimeout: s.readHeaderTimeout, + } + return s +} + +// Name returns the transport kind, "sse". +func (s *Server) Name() string { return string(KindSSE) } + +// Endpoint returns the address the server is (or will be) listening on, +// scheme "sse://", suitable for service-registry advertisement. +func (s *Server) Endpoint() (*url.URL, error) { + if err := s.listenAndEndpoint(); err != nil { + return nil, err + } + return s.endpoint, nil +} + +// Codec returns the codec configured for request/response payloads. +func (s *Server) Codec() encoding.Codec { return s.codec } + +// Start opens the listener (if not already) and serves until Stop is +// called or the listener fails. It implements transport.Server. +func (s *Server) Start(ctx context.Context) error { + if err := s.listenAndEndpoint(); err != nil { + return err + } + // Build a cancellable child context that Stop will tear down before + // http.Server.Shutdown runs. Handlers receive this ctx via + // r.Context(), so a Pump select-on-Done unblocks promptly during + // shutdown rather than holding the connection until its own write + // deadline expires. + s.shutdownCtx, s.shutdownCancel = context.WithCancel(ctx) + s.BaseContext = func(net.Listener) context.Context { return s.shutdownCtx } + s.logger.InfoContext(ctx, "sse server listening", "addr", s.lis.Addr().String()) + + var err error + if s.tlsConf != nil { + err = s.ServeTLS(s.lis, "", "") + } else { + err = s.Serve(s.lis) + } + if err != nil && !errors.Is(err, http.ErrServerClosed) { + return err + } + return nil +} + +// Stop gracefully shuts the server down. It first cancels the +// shutdown-aware context that every handler receives — long-running SSE +// streams that observe ctx.Done() will exit promptly — then calls +// http.Server.Shutdown to wait for in-flight requests to complete. If +// ctx expires before drainage finishes, Stop force-closes connections, +// matching the Kratos HTTP server's behavior. +func (s *Server) Stop(ctx context.Context) error { + s.logger.InfoContext(ctx, "sse server stopping", + "active_streams", s.ActiveStreams()) + if s.shutdownCancel != nil { + s.shutdownCancel() + } + if err := s.Shutdown(ctx); err != nil { + if ctx.Err() != nil { + s.logger.WarnContext(ctx, "sse server force-closing after shutdown timeout") + return s.Close() + } + return err + } + return nil +} + +// ActiveStreams reports the number of SSE streams currently in flight +// through StreamHandler or JSONHandler. Handlers mounted via plain +// Handle / HandleFunc are not counted. +func (s *Server) ActiveStreams() int64 { return s.active.Load() } + +// Handle mounts an http.Handler at pattern. Pattern syntax follows +// net/http.ServeMux (Go 1.22+ "METHOD /path/{var}" form). +func (s *Server) Handle(pattern string, h http.Handler) { + s.mux.Handle(pattern, h) + s.patterns = append(s.patterns, pattern) +} + +// HandleFunc mounts an http.HandlerFunc at pattern. +func (s *Server) HandleFunc(pattern string, h http.HandlerFunc) { + s.Handle(pattern, h) +} + +// WalkPattern visits every pattern registered with Handle/HandleFunc. +// The order matches registration order. +func (s *Server) WalkPattern(fn func(pattern string)) { + for _, p := range s.patterns { + fn(p) + } +} + +// Decode reads r.Body and unmarshals it via the configured request +// decoder (set with RequestDecoder; defaults to DefaultRequestDecoder). +func (s *Server) Decode(r *http.Request, v any) error { + return s.decBody(r, v) +} + +// EncodeError reports err to the client via the configured error +// encoder. +func (s *Server) EncodeError(w http.ResponseWriter, r *http.Request, err error) { + s.errEnc(w, r, err) +} + +// ServeHTTP runs the filter chain around the routing dispatch. +func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { + s.Handler.ServeHTTP(w, r) +} + +// dispatch is the innermost handler invoked after filters run. It installs +// the SSE transport.Transporter into the request context and dispatches +// the request via the mux. +func (s *Server) dispatch(w http.ResponseWriter, r *http.Request) { + // Resolve the matched pattern (e.g. "/v1/sse/chat:stream") so + // middleware that inspects Operation sees the route template rather + // than the request-specific path. + _, pattern := s.mux.Handler(r) + + tr := &Transport{ + endpoint: s.endpointString(), + operation: pattern, + pathTemplate: pattern, + request: r, + response: w, + reqHeader: headerCarrier(r.Header), + replyHeader: headerCarrier(w.Header()), + } + r = r.WithContext(ktransport.NewServerContext(r.Context(), tr)) + s.mux.ServeHTTP(w, r) +} + +func (s *Server) endpointString() string { + if s.endpoint == nil { + return "" + } + return s.endpoint.String() +} + +func (s *Server) listenAndEndpoint() error { + if s.lis == nil { + lis, err := net.Listen(s.network, s.address) + if err != nil { + return err + } + s.lis = lis + } + if s.endpoint == nil { + s.endpoint = &url.URL{Scheme: string(KindSSE), Host: s.lis.Addr().String()} + } + return nil +} diff --git a/sse/kratos/server_options.go b/sse/kratos/server_options.go new file mode 100644 index 0000000..c5b04f3 --- /dev/null +++ b/sse/kratos/server_options.go @@ -0,0 +1,131 @@ +package kratos + +import ( + "crypto/tls" + "log/slog" + "net" + "net/url" + "time" + + "github.com/go-kratos/kratos/v2/encoding" + "github.com/go-kratos/kratos/v2/middleware" +) + +// ServerOption configures a Server at construction time. +type ServerOption func(*Server) + +// Network sets the listener network (default "tcp"). +func Network(network string) ServerOption { + return func(s *Server) { s.network = network } +} + +// Address sets the listener address (default ":0", random port). +func Address(addr string) ServerOption { + return func(s *Server) { s.address = addr } +} + +// Listener supplies a pre-built listener, bypassing Network/Address. +func Listener(lis net.Listener) ServerOption { + return func(s *Server) { s.lis = lis } +} + +// Endpoint overrides the URL advertised via Endpoint(). Use this when the +// listener address (":8080") isn't the externally routable address. +func Endpoint(u *url.URL) ServerOption { + return func(s *Server) { s.endpoint = u } +} + +// TLSConfig configures TLS for the server. When set, Start serves via +// ServeTLS using the certificate(s) from cfg. +func TLSConfig(c *tls.Config) ServerOption { + return func(s *Server) { s.tlsConf = c } +} + +// Codec sets the default codec used by Decode when the request's +// Content-Type does not specify a recognized one. Pass a Kratos codec +// name such as "json" or "proto"; the codec must be registered +// (importing "github.com/go-kratos/kratos/v2/encoding/proto" suffices +// for proto). +// +// Passing a name with no registered codec is a programming error and +// panics — the misconfiguration would otherwise surface as an obscure +// runtime decode failure. +func Codec(name string) ServerOption { + c := encoding.GetCodec(name) + if c == nil { + panic("sse/kratos: codec not registered: " + name) + } + return func(s *Server) { s.codec = c } +} + +// Logger replaces the default slog.Logger (slog.Default()). A nil logger +// is ignored. +func Logger(l *slog.Logger) ServerOption { + return func(s *Server) { + if l != nil { + s.logger = l + } + } +} + +// RequestDecoder replaces the request body decoder used by Decode. +func RequestDecoder(dec DecodeRequestFunc) ServerOption { + return func(s *Server) { + if dec != nil { + s.decBody = dec + } + } +} + +// ErrorEncoder replaces the error encoder used by EncodeError. +func ErrorEncoder(enc EncodeErrorFunc) ServerOption { + return func(s *Server) { + if enc != nil { + s.errEnc = enc + } + } +} + +// Filter prepends HTTP middleware to the request pipeline. Filters run +// before the route is dispatched, so they may short-circuit auth, set +// cross-cutting headers, or wrap response writing. +func Filter(filters ...FilterFunc) ServerOption { + return func(s *Server) { s.filters = append(s.filters, filters...) } +} + +// Middleware appends Kratos middleware to the chain executed by +// StreamHandler and JSONHandler — between request decoding and stream +// start. Use this for auth, JWT/token verification, schema validation +// and other pre-handler concerns. +// +// These middlewares do NOT wrap the streaming portion of the response; +// any middleware that needs to observe the whole request (tracing, +// metrics, recovery) should be installed as a Filter instead. See the +// package doc for the full rationale. +func Middleware(mws ...middleware.Middleware) ServerOption { + return func(s *Server) { s.middlewares = append(s.middlewares, mws...) } +} + +// ReadHeaderTimeout overrides the time bound on reading request headers +// (default DefaultReadHeaderTimeout). Set to 0 to disable, accepting +// Slowloris risk. +// +// Note: WriteTimeout / IdleTimeout are intentionally not exposed here; +// SSE streams are long-lived and a server-wide write deadline would kill +// them. Per-handler deadlines should be set via http.ResponseController +// inside the handler instead. +func ReadHeaderTimeout(d time.Duration) ServerOption { + return func(s *Server) { s.readHeaderTimeout = d } +} + +// Heartbeat enables automatic SSE comment frames (": \n\n") at the given +// interval on every stream built via StreamHandler or JSONHandler. The +// keepalive is invisible to clients (comments are spec-defined to be +// ignored) but defeats idle-connection timers in upstream proxies +// (nginx, ALB, CloudFlare). +// +// Set to 0 (the default) to disable. Recommended value: 15s — under +// the typical 30-60s proxy idle timeout. +func Heartbeat(interval time.Duration) ServerOption { + return func(s *Server) { s.heartbeat = interval } +} diff --git a/sse/kratos/server_test.go b/sse/kratos/server_test.go new file mode 100644 index 0000000..5009d89 --- /dev/null +++ b/sse/kratos/server_test.go @@ -0,0 +1,321 @@ +package kratos_test + +import ( + "bufio" + "context" + "crypto/tls" + "errors" + "net" + "net/http" + "net/url" + "strings" + "testing" + "time" + + ktransport "github.com/go-kratos/kratos/v2/transport" + + "github.com/crypto-zero/go-kit/sse" + ksse "github.com/crypto-zero/go-kit/sse/kratos" +) + +func newServerOnLoopback(t *testing.T, opts ...ksse.ServerOption) (*ksse.Server, string) { + t.Helper() + lis, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("listen: %v", err) + } + opts = append([]ksse.ServerOption{ksse.Listener(lis)}, opts...) + srv := ksse.NewServer(opts...) + return srv, lis.Addr().String() +} + +func startServer(t *testing.T, srv *ksse.Server) func() { + t.Helper() + ctx, cancel := context.WithCancel(context.Background()) + startErr := make(chan error, 1) + go func() { startErr <- srv.Start(ctx) }() + return func() { + shutdownCtx, c := context.WithTimeout(context.Background(), time.Second) + defer c() + if err := srv.Stop(shutdownCtx); err != nil { + t.Errorf("Stop: %v", err) + } + if err := <-startErr; err != nil { + t.Errorf("Start: %v", err) + } + cancel() + } +} + +func TestServer_Name(t *testing.T) { + srv := ksse.NewServer() + if got, want := srv.Name(), string(ksse.KindSSE); got != want { + t.Errorf("Name() = %q, want %q", got, want) + } +} + +func TestServer_Endpoint(t *testing.T) { + srv, addr := newServerOnLoopback(t) + u, err := srv.Endpoint() + if err != nil { + t.Fatalf("Endpoint: %v", err) + } + if u.Scheme != string(ksse.KindSSE) { + t.Errorf("scheme = %q, want %q", u.Scheme, ksse.KindSSE) + } + if u.Host != addr { + t.Errorf("host = %q, want %q", u.Host, addr) + } +} + +func TestServer_EndpointOverride(t *testing.T) { + override := &url.URL{Scheme: "sse", Host: "api.example.com:443"} + srv := ksse.NewServer(ksse.Endpoint(override)) + got, err := srv.Endpoint() + if err != nil { + t.Fatalf("Endpoint: %v", err) + } + if got.String() != override.String() { + t.Errorf("Endpoint = %q, want %q", got, override) + } +} + +func TestServer_StartAndStreamsEvents(t *testing.T) { + srv, addr := newServerOnLoopback(t) + srv.HandleFunc("/v1/stream", func(w http.ResponseWriter, _ *http.Request) { + s := sse.NewStream(w) + _ = s.Write("hello") + _ = s.Write("world") + _ = s.Done() + }) + + stop := startServer(t, srv) + defer stop() + + resp, err := http.Get("http://" + addr + "/v1/stream") + if err != nil { + t.Fatalf("GET: %v", err) + } + defer func() { _ = resp.Body.Close() }() + if got, want := resp.Header.Get("Content-Type"), "text/event-stream"; got != want { + t.Errorf("Content-Type = %q, want %q", got, want) + } + + body := readAll(t, resp.Body) + for _, want := range []string{"data: hello\n\n", "data: world\n\n", "data: [DONE]\n\n"} { + if !strings.Contains(body, want) { + t.Errorf("body missing %q\nfull body: %q", want, body) + } + } +} + +func TestServer_TransporterPathTemplate(t *testing.T) { + srv, addr := newServerOnLoopback(t) + got := make(chan struct { + kind ktransport.Kind + op string + ep string + pathTemplate string + hasResponse bool + }, 1) + srv.HandleFunc("/v1/items/{id}", func(w http.ResponseWriter, r *http.Request) { + tr, ok := ktransport.FromServerContext(r.Context()) + if !ok { + t.Errorf("no transport in context") + return + } + var pt string + if ptr, ok := tr.(ksse.Transporter); ok { + pt = ptr.PathTemplate() + } + var hasResp bool + if _, ok := tr.(ksse.ResponseTransporter); ok { + hasResp = true + } + got <- struct { + kind ktransport.Kind + op string + ep string + pathTemplate string + hasResponse bool + }{tr.Kind(), tr.Operation(), tr.Endpoint(), pt, hasResp} + _ = sse.NewStream(w).Done() + }) + + stop := startServer(t, srv) + defer stop() + + resp, err := http.Get("http://" + addr + "/v1/items/42") + if err != nil { + t.Fatalf("GET: %v", err) + } + _ = resp.Body.Close() + select { + case v := <-got: + if v.kind != ksse.KindSSE { + t.Errorf("Kind = %q, want %q", v.kind, ksse.KindSSE) + } + if v.op != "/v1/items/{id}" { + t.Errorf("Operation = %q, want /v1/items/{id}", v.op) + } + if v.pathTemplate != "/v1/items/{id}" { + t.Errorf("PathTemplate = %q, want /v1/items/{id}", v.pathTemplate) + } + if !strings.HasPrefix(v.ep, "sse://") { + t.Errorf("Endpoint = %q, want sse:// prefix", v.ep) + } + if !v.hasResponse { + t.Errorf("transport does not satisfy ResponseTransporter") + } + case <-time.After(2 * time.Second): + t.Fatal("handler did not run") + } +} + +func TestServer_DecodeJSON(t *testing.T) { + srv, addr := newServerOnLoopback(t) + type req struct { + Name string `json:"name"` + } + got := make(chan string, 1) + srv.HandleFunc("POST /echo", func(w http.ResponseWriter, r *http.Request) { + var v req + if err := srv.Decode(r, &v); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + got <- v.Name + _ = sse.NewStream(w).Done() + }) + + stop := startServer(t, srv) + defer stop() + + body := strings.NewReader(`{"name":"karma"}`) + resp, err := http.Post("http://"+addr+"/echo", "application/json", body) + if err != nil { + t.Fatalf("POST: %v", err) + } + defer func() { _ = resp.Body.Close() }() + if resp.StatusCode != 200 { + t.Fatalf("status = %d", resp.StatusCode) + } + select { + case v := <-got: + if v != "karma" { + t.Errorf("decoded name = %q, want karma", v) + } + case <-time.After(2 * time.Second): + t.Fatal("handler did not run") + } +} + +func TestServer_RequestDecoderOverride(t *testing.T) { + sentinel := errors.New("custom decoder") + srv, addr := newServerOnLoopback(t, + ksse.RequestDecoder(func(*http.Request, any) error { return sentinel }), + ) + srv.HandleFunc("POST /x", func(w http.ResponseWriter, r *http.Request) { + var v any + if err := srv.Decode(r, &v); err == nil { + http.Error(w, "expected sentinel", http.StatusInternalServerError) + return + } + _ = sse.NewStream(w).Done() + }) + + stop := startServer(t, srv) + defer stop() + + resp, err := http.Post("http://"+addr+"/x", "application/json", strings.NewReader(`{}`)) + if err != nil { + t.Fatalf("POST: %v", err) + } + defer func() { _ = resp.Body.Close() }() + if resp.StatusCode != 200 { + t.Errorf("status = %d, want 200", resp.StatusCode) + } +} + +func TestServer_FilterChain(t *testing.T) { + srv, addr := newServerOnLoopback(t, + ksse.Filter( + func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("X-Outer", "1") + next.ServeHTTP(w, r) + }) + }, + func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("X-Inner", "1") + next.ServeHTTP(w, r) + }) + }, + ), + ) + srv.HandleFunc("/f", func(w http.ResponseWriter, _ *http.Request) { + _ = sse.NewStream(w).Done() + }) + + stop := startServer(t, srv) + defer stop() + + resp, err := http.Get("http://" + addr + "/f") + if err != nil { + t.Fatalf("GET: %v", err) + } + _ = resp.Body.Close() + if got := resp.Header.Get("X-Outer"); got != "1" { + t.Errorf("X-Outer = %q, want 1", got) + } + if got := resp.Header.Get("X-Inner"); got != "1" { + t.Errorf("X-Inner = %q, want 1", got) + } +} + +func TestServer_WalkPattern(t *testing.T) { + srv := ksse.NewServer() + srv.HandleFunc("/a", func(http.ResponseWriter, *http.Request) {}) + srv.HandleFunc("POST /b", func(http.ResponseWriter, *http.Request) {}) + + var seen []string + srv.WalkPattern(func(p string) { seen = append(seen, p) }) + want := []string{"/a", "POST /b"} + if strings.Join(seen, ",") != strings.Join(want, ",") { + t.Errorf("WalkPattern visited %v, want %v", seen, want) + } +} + +func TestServer_TLSConfigSelected(t *testing.T) { + // We only verify the option installs the TLS config; serving TLS + // requires a real cert that's out of scope for this test. + srv := ksse.NewServer(ksse.TLSConfig(&tls.Config{})) + if srv.TLSConfig == nil { + t.Errorf("TLSConfig not propagated to http.Server") + } +} + +func TestCodecOption_Panics(t *testing.T) { + defer func() { + if r := recover(); r == nil { + t.Error("Codec(unregistered) did not panic") + } + }() + ksse.Codec("nope") +} + +func readAll(t *testing.T, r interface { + Read([]byte) (int, error) +}) string { + t.Helper() + var sb strings.Builder + br := bufio.NewReader(r) + for { + line, err := br.ReadString('\n') + sb.WriteString(line) + if err != nil { + return sb.String() + } + } +} diff --git a/sse/kratos/transport.go b/sse/kratos/transport.go new file mode 100644 index 0000000..37ab4cc --- /dev/null +++ b/sse/kratos/transport.go @@ -0,0 +1,113 @@ +package kratos + +import ( + "context" + "net/http" + + ktransport "github.com/go-kratos/kratos/v2/transport" +) + +var ( + _ Transporter = (*Transport)(nil) + _ ResponseTransporter = (*Transport)(nil) +) + +// Transporter extends Kratos' transport.Transporter with the raw HTTP +// request, matching the shape that github.com/go-kratos/kratos/v2/transport/http +// exposes for its own transport. +type Transporter interface { + ktransport.Transporter + Request() *http.Request + PathTemplate() string +} + +// ResponseTransporter additionally exposes the response writer, for +// handlers that need to wrap or inspect it (e.g. SSE). +type ResponseTransporter interface { + Transporter + Response() http.ResponseWriter +} + +// Transport is the SSE transport context attached to each request. +type Transport struct { + endpoint string + operation string + pathTemplate string + request *http.Request + response http.ResponseWriter + reqHeader headerCarrier + replyHeader headerCarrier +} + +// Kind reports the transport kind, KindSSE. +func (t *Transport) Kind() ktransport.Kind { return KindSSE } + +// Endpoint reports the server's advertised endpoint. +func (t *Transport) Endpoint() string { return t.endpoint } + +// Operation reports the matched route template (e.g. "/v1/chat:stream"). +// Middleware that needs a more specific operation can override it via +// SetOperation. +func (t *Transport) Operation() string { return t.operation } + +// PathTemplate reports the matched ServeMux pattern. +func (t *Transport) PathTemplate() string { return t.pathTemplate } + +// Request returns the underlying *http.Request. +func (t *Transport) Request() *http.Request { return t.request } + +// Response returns the underlying http.ResponseWriter. +func (t *Transport) Response() http.ResponseWriter { return t.response } + +// RequestHeader returns the inbound HTTP headers. +func (t *Transport) RequestHeader() ktransport.Header { return t.reqHeader } + +// ReplyHeader returns the writable response headers. +func (t *Transport) ReplyHeader() ktransport.Header { return t.replyHeader } + +// SetOperation overrides the operation name on the SSE transport +// attached to ctx. It is a no-op when ctx does not carry an SSE +// transport. +func SetOperation(ctx context.Context, op string) { + if tr, ok := ktransport.FromServerContext(ctx); ok { + if t, ok := tr.(*Transport); ok { + t.operation = op + } + } +} + +// RequestFromServerContext returns the request stored in ctx by an SSE +// transport, or false if none is present. +func RequestFromServerContext(ctx context.Context) (*http.Request, bool) { + if tr, ok := ktransport.FromServerContext(ctx); ok { + if t, ok := tr.(Transporter); ok { + return t.Request(), true + } + } + return nil, false +} + +// ResponseWriterFromServerContext returns the response writer stored in +// ctx, or false if none is present. +func ResponseWriterFromServerContext(ctx context.Context) (http.ResponseWriter, bool) { + if tr, ok := ktransport.FromServerContext(ctx); ok { + if t, ok := tr.(ResponseTransporter); ok { + return t.Response(), true + } + } + return nil, false +} + +type headerCarrier http.Header + +func (hc headerCarrier) Get(key string) string { return http.Header(hc).Get(key) } +func (hc headerCarrier) Set(key, value string) { http.Header(hc).Set(key, value) } +func (hc headerCarrier) Add(key, value string) { http.Header(hc).Add(key, value) } +func (hc headerCarrier) Values(key string) []string { return http.Header(hc).Values(key) } +func (hc headerCarrier) Keys() []string { + keys := make([]string, 0, len(hc)) + for k := range http.Header(hc) { + keys = append(keys, k) + } + return keys +} diff --git a/sse/snapshot_live.go b/sse/snapshot_live.go new file mode 100644 index 0000000..f990969 --- /dev/null +++ b/sse/snapshot_live.go @@ -0,0 +1,72 @@ +package sse + +import "context" + +// SnapshotLiveOptions configures StreamSnapshotThenLive. +type SnapshotLiveOptions[T any] struct { + SnapshotEvent string + SnapshotEndEvent string + SnapshotEndData string + LiveEvent string + ID func(T) string + Data func(T) (any, error) +} + +// StreamSnapshotThenLive writes an initial snapshot batch, an optional +// snapshot-end marker, then live events from ch until ctx is canceled or ch +// closes. +func StreamSnapshotThenLive[T any]( + ctx context.Context, + st *Stream, + snapshot []T, + ch <-chan T, + opts SnapshotLiveOptions[T], +) error { + for _, item := range snapshot { + if err := writeSnapshotLiveEvent(st, opts.SnapshotEvent, item, opts); err != nil { + return err + } + } + + if opts.SnapshotEndEvent != "" { + data := opts.SnapshotEndData + if data == "" { + data = "{}" + } + if err := st.WriteEvent(Event{Event: opts.SnapshotEndEvent, Data: data}); err != nil { + return err + } + } + + for { + select { + case <-ctx.Done(): + return nil + case item, ok := <-ch: + if !ok { + return nil + } + if err := writeSnapshotLiveEvent(st, opts.LiveEvent, item, opts); err != nil { + return err + } + } + } +} + +func writeSnapshotLiveEvent[T any](st *Stream, event string, item T, opts SnapshotLiveOptions[T]) error { + id := "" + if opts.ID != nil { + id = opts.ID(item) + } + if opts.Data == nil { + return st.WriteJSONEvent(Event{Event: event, ID: id}, item) + } + data, err := opts.Data(item) + if err != nil { + return err + } + if s, ok := data.(string); ok { + return st.WriteEvent(Event{Event: event, ID: id, Data: s}) + } + return st.WriteJSONEvent(Event{Event: event, ID: id}, data) +} diff --git a/sse/snapshot_live_test.go b/sse/snapshot_live_test.go new file mode 100644 index 0000000..19e1dd9 --- /dev/null +++ b/sse/snapshot_live_test.go @@ -0,0 +1,64 @@ +package sse + +import ( + "context" + "net/http/httptest" + "strings" + "testing" +) + +func TestStreamSnapshotThenLiveWritesSnapshotEndAndLive(t *testing.T) { + rec := httptest.NewRecorder() + st := NewStream(rec) + live := make(chan int, 2) + live <- 3 + close(live) + + err := StreamSnapshotThenLive(context.Background(), st, []int{1, 2}, live, SnapshotLiveOptions[int]{ + SnapshotEvent: "snapshot", + SnapshotEndEvent: "snapshot-end", + LiveEvent: "point", + ID: func(v int) string { + return string(rune('a' + v - 1)) + }, + Data: func(v int) (any, error) { + return map[string]int{"value": v}, nil + }, + }) + if err != nil { + t.Fatalf("StreamSnapshotThenLive: %v", err) + } + + body := rec.Body.String() + for _, want := range []string{ + "event: snapshot\nid: a\ndata: {\"value\":1}\n\n", + "event: snapshot\nid: b\ndata: {\"value\":2}\n\n", + "event: snapshot-end\ndata: {}\n\n", + "event: point\nid: c\ndata: {\"value\":3}\n\n", + } { + if !strings.Contains(body, want) { + t.Fatalf("body missing %q\nbody:\n%s", want, body) + } + } +} + +func TestStreamSnapshotThenLiveDefaultsToJSON(t *testing.T) { + rec := httptest.NewRecorder() + st := NewStream(rec) + live := make(chan struct { + Name string `json:"name"` + }) + close(live) + + err := StreamSnapshotThenLive(context.Background(), st, []struct { + Name string `json:"name"` + }{{Name: "one"}}, live, SnapshotLiveOptions[struct { + Name string `json:"name"` + }]{}) + if err != nil { + t.Fatalf("StreamSnapshotThenLive: %v", err) + } + if got := rec.Body.String(); !strings.Contains(got, "data: {\"name\":\"one\"}\n\n") { + t.Fatalf("expected JSON payload, got:\n%s", got) + } +} diff --git a/sse/sse.go b/sse/sse.go new file mode 100644 index 0000000..69940a1 --- /dev/null +++ b/sse/sse.go @@ -0,0 +1,278 @@ +// Package sse implements a minimal Server-Sent Events writer for net/http. +// +// The package is transport-only: it owns the wire format (headers, framing, +// flushing, the [DONE] terminator) and a small helper for draining streaming +// channels. Authentication, request decoding, validation and routing are the +// caller's responsibility. +// +// For a Kratos transport.Server adapter built on top of this package, see +// the sub-package github.com/crypto-zero/go-kit/sse/kratos. +package sse + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "strings" + "sync" + "time" +) + +// DoneMarker is the conventional terminator sent as the final data frame of +// an SSE stream. It matches the OpenAI-style protocol that most browser and +// CLI clients already understand. +const DoneMarker = "[DONE]" + +// LastEventIDHeader is the HTTP header browsers send on reconnect to resume +// from the last received event. +const LastEventIDHeader = "Last-Event-ID" + +// Event is a single SSE event. All fields are optional; an event with only +// Data set produces the common "data: \n\n" frame. +type Event struct { + // Event is the event name. When empty, no "event:" field is written and + // browsers dispatch the frame as the default "message" event. + Event string + // ID populates the "id:" field, allowing clients to resume via the + // Last-Event-ID header on reconnect. + ID string + // Data is the event payload. It may contain newlines: each line is + // emitted as a separate "data:" field per the SSE spec. + Data string + // Retry, when non-zero, sets the client's reconnection delay in + // milliseconds via the "retry:" field. + Retry time.Duration +} + +// Stream writes Server-Sent Events to an HTTP response. Methods are safe for +// concurrent use, allowing a Heartbeat goroutine to coexist with the main +// writer. +type Stream struct { + mu sync.Mutex + w http.ResponseWriter + rc *http.ResponseController + started bool +} + +// NewStream wraps w for SSE output. Headers are not written until the first +// frame is sent (or Start is called explicitly), so callers may still return +// a non-SSE HTTP error after constructing the Stream. +// +// Flushes go through http.ResponseController, which walks any Unwrap() +// chain installed by middleware (metrics wrappers, response recorders) to +// reach the underlying Flusher. +func NewStream(w http.ResponseWriter) *Stream { + return &Stream{w: w, rc: http.NewResponseController(w)} +} + +// Start writes the SSE response headers and the 200 status line. It is safe +// to call multiple times; subsequent calls are no-ops. Callers normally do +// not need to invoke Start directly — any of the Write methods will trigger +// it on first use. +func (s *Stream) Start() { + s.mu.Lock() + defer s.mu.Unlock() + s.startLocked() +} + +func (s *Stream) startLocked() { + if s.started { + return + } + h := s.w.Header() + h.Set("Content-Type", "text/event-stream") + h.Set("Cache-Control", "no-cache") + h.Set("Connection", "keep-alive") + // Disable proxy buffering (nginx-specific but harmless elsewhere) so + // frames reach the client as soon as they are flushed. + h.Set("X-Accel-Buffering", "no") + s.w.WriteHeader(http.StatusOK) + s.started = true +} + +// Write sends a single default-event data frame and flushes immediately. +func (s *Stream) Write(data string) error { + return s.WriteEvent(Event{Data: data}) +} + +// WriteEvent sends a fully specified event and flushes immediately. +func (s *Stream) WriteEvent(e Event) error { + s.mu.Lock() + defer s.mu.Unlock() + s.startLocked() + + var b strings.Builder + if e.Event != "" { + b.WriteString("event: ") + b.WriteString(e.Event) + b.WriteByte('\n') + } + if e.ID != "" { + b.WriteString("id: ") + b.WriteString(e.ID) + b.WriteByte('\n') + } + if e.Retry > 0 { + fmt.Fprintf(&b, "retry: %d\n", e.Retry.Milliseconds()) + } + // Per spec, every newline in the payload starts a new "data:" field; + // the frame is terminated by a blank line. + for line := range strings.SplitSeq(e.Data, "\n") { + b.WriteString("data: ") + b.WriteString(line) + b.WriteByte('\n') + } + b.WriteByte('\n') + if _, err := io.WriteString(s.w, b.String()); err != nil { + return err + } + s.flushLocked() + return nil +} + +// WriteJSON marshals v and writes it as a single data frame. The caller is +// responsible for any terminating Done frame. +func (s *Stream) WriteJSON(v any) error { + return s.WriteJSONEvent(Event{}, v) +} + +// WriteJSONEvent marshals v and writes it as a named/id/retry event. +func (s *Stream) WriteJSONEvent(e Event, v any) error { + data, err := json.Marshal(v) + if err != nil { + return fmt.Errorf("sse: marshal json: %w", err) + } + e.Data = string(data) + return s.WriteEvent(e) +} + +// Error writes an "error"-named event whose data payload is the JSON object +// {"error": msg}, matching the convention used by most JS EventSource +// consumers. +func (s *Stream) Error(msg string) error { + return s.WriteJSONEvent(Event{Event: "error"}, map[string]string{"error": msg}) +} + +// Done sends the DoneMarker as a final data frame, signaling end-of-stream +// to clients that follow the OpenAI-style protocol. +func (s *Stream) Done() error { + return s.Write(DoneMarker) +} + +// Comment writes an SSE comment frame (": text\n\n"). Comments are ignored +// by clients and useful as keepalive packets through proxies that close +// idle connections (nginx, ALB, CloudFlare). +// +// An empty text writes a bare ":\n\n" — the minimal valid keepalive. +func (s *Stream) Comment(text string) error { + s.mu.Lock() + defer s.mu.Unlock() + s.startLocked() + var b strings.Builder + for line := range strings.SplitSeq(text, "\n") { + b.WriteByte(':') + if line != "" { + b.WriteByte(' ') + b.WriteString(line) + } + b.WriteByte('\n') + } + b.WriteByte('\n') + if _, err := io.WriteString(s.w, b.String()); err != nil { + return err + } + s.flushLocked() + return nil +} + +// Heartbeat starts a goroutine that emits a Comment frame every interval +// until ctx is cancelled or the returned stop function is called. +// +// stop is synchronous: it cancels the ticker and blocks until the +// goroutine has finished its current iteration. Callers MUST invoke +// stop before the underlying http.ResponseWriter is recycled (typically +// by deferring it in the handler) — writing a comment to a reclaimed +// response panics. stop is safe to call multiple times. +// +// Use this for long-lived streams that sit behind proxies with idle +// connection timeouts. +func (s *Stream) Heartbeat(ctx context.Context, interval time.Duration) (stop func()) { + ctx, cancel := context.WithCancel(ctx) + done := make(chan struct{}) + go func() { + defer close(done) + t := time.NewTicker(interval) + defer t.Stop() + for { + select { + case <-ctx.Done(): + return + case <-t.C: + if err := s.Comment(""); err != nil { + return + } + } + } + }() + var once sync.Once + return func() { + once.Do(func() { + cancel() + <-done + }) + } +} + +// Pump drains chunks into the stream as default data events. It returns +// when: +// +// - chunks is closed: Done is sent and nil is returned; +// - errs delivers a non-nil error: that error is forwarded via Error and +// returned (no Done is sent); +// - errs is closed: returns nil silently (no Done); +// - ctx is cancelled: returns ctx.Err() silently (no Done). +// +// Empty chunks are skipped. Callers typically use Pump to relay a +// (<-chan string, <-chan error) pair produced by a streaming biz call. +func (s *Stream) Pump(ctx context.Context, chunks <-chan string, errs <-chan error) error { + for { + select { + case chunk, ok := <-chunks: + if !ok { + return s.Done() + } + if chunk == "" { + continue + } + if err := s.Write(chunk); err != nil { + return err + } + case err, ok := <-errs: + if !ok { + return nil + } + if err != nil { + _ = s.Error(err.Error()) + return err + } + case <-ctx.Done(): + return ctx.Err() + } + } +} + +func (s *Stream) flushLocked() { + // Best-effort: ResponseController.Flush returns http.ErrNotSupported + // when no Flusher is reachable through the Unwrap chain. SSE without + // flushing degrades to "client receives nothing until close" — bad, + // but not something this layer can recover from. Silently ignore. + _ = s.rc.Flush() +} + +// LastEventID returns the value of the Last-Event-ID HTTP header sent by +// EventSource clients on reconnect. Returns "" when absent. +func LastEventID(r *http.Request) string { + return r.Header.Get(LastEventIDHeader) +} diff --git a/sse/sse_test.go b/sse/sse_test.go new file mode 100644 index 0000000..d34e906 --- /dev/null +++ b/sse/sse_test.go @@ -0,0 +1,274 @@ +package sse + +import ( + "context" + "errors" + "net/http/httptest" + "strings" + "testing" + "time" +) + +func TestStream_HeadersAreLazy(t *testing.T) { + rec := httptest.NewRecorder() + s := NewStream(rec) + if got := rec.Header().Get("Content-Type"); got != "" { + t.Errorf("Content-Type before first write = %q, want empty", got) + } + if err := s.Write("hello"); err != nil { + t.Fatalf("Write: %v", err) + } + if got, want := rec.Header().Get("Content-Type"), "text/event-stream"; got != want { + t.Errorf("Content-Type after Write = %q, want %q", got, want) + } + if got, want := rec.Header().Get("Cache-Control"), "no-cache"; got != want { + t.Errorf("Cache-Control = %q, want %q", got, want) + } + if got, want := rec.Header().Get("X-Accel-Buffering"), "no"; got != want { + t.Errorf("X-Accel-Buffering = %q, want %q", got, want) + } +} + +func TestStream_StartIdempotent(t *testing.T) { + rec := httptest.NewRecorder() + s := NewStream(rec) + s.Start() + s.Start() + if got, want := rec.Code, 200; got != want { + t.Errorf("status = %d, want %d", got, want) + } +} + +func TestStream_WriteFormat(t *testing.T) { + rec := httptest.NewRecorder() + s := NewStream(rec) + if err := s.Write("hello"); err != nil { + t.Fatalf("Write: %v", err) + } + if got, want := rec.Body.String(), "data: hello\n\n"; got != want { + t.Errorf("body = %q, want %q", got, want) + } +} + +func TestStream_WriteEventFull(t *testing.T) { + rec := httptest.NewRecorder() + s := NewStream(rec) + err := s.WriteEvent(Event{ + Event: "delta", + ID: "42", + Data: "line1\nline2", + Retry: 1500 * time.Millisecond, + }) + if err != nil { + t.Fatalf("WriteEvent: %v", err) + } + want := "event: delta\nid: 42\nretry: 1500\ndata: line1\ndata: line2\n\n" + if got := rec.Body.String(); got != want { + t.Errorf("body = %q, want %q", got, want) + } +} + +func TestStream_WriteJSON(t *testing.T) { + rec := httptest.NewRecorder() + s := NewStream(rec) + if err := s.WriteJSON(map[string]int{"n": 1}); err != nil { + t.Fatalf("WriteJSON: %v", err) + } + if got, want := rec.Body.String(), "data: {\"n\":1}\n\n"; got != want { + t.Errorf("body = %q, want %q", got, want) + } +} + +func TestStream_WriteJSONEvent(t *testing.T) { + rec := httptest.NewRecorder() + s := NewStream(rec) + if err := s.WriteJSONEvent(Event{Event: "point", ID: "42"}, map[string]int{"n": 1}); err != nil { + t.Fatalf("WriteJSONEvent: %v", err) + } + want := "event: point\nid: 42\ndata: {\"n\":1}\n\n" + if got := rec.Body.String(); got != want { + t.Fatalf("body = %q, want %q", got, want) + } +} + +func TestStream_Error(t *testing.T) { + rec := httptest.NewRecorder() + s := NewStream(rec) + if err := s.Error("boom"); err != nil { + t.Fatalf("Error: %v", err) + } + want := "event: error\ndata: {\"error\":\"boom\"}\n\n" + if got := rec.Body.String(); got != want { + t.Errorf("body = %q, want %q", got, want) + } +} + +func TestStream_Done(t *testing.T) { + rec := httptest.NewRecorder() + s := NewStream(rec) + if err := s.Done(); err != nil { + t.Fatalf("Done: %v", err) + } + if got, want := rec.Body.String(), "data: [DONE]\n\n"; got != want { + t.Errorf("body = %q, want %q", got, want) + } +} + +func TestPump_CleanFinish(t *testing.T) { + rec := httptest.NewRecorder() + s := NewStream(rec) + chunks := make(chan string, 3) + errs := make(chan error) + chunks <- "a" + chunks <- "" + chunks <- "b" + close(chunks) + if err := s.Pump(context.Background(), chunks, errs); err != nil { + t.Fatalf("Pump: %v", err) + } + got := rec.Body.String() + want := "data: a\n\ndata: b\n\ndata: [DONE]\n\n" + if got != want { + t.Errorf("body = %q, want %q", got, want) + } +} + +func TestPump_ErrorFromErrCh(t *testing.T) { + rec := httptest.NewRecorder() + s := NewStream(rec) + chunks := make(chan string) + errs := make(chan error, 1) + sentinel := errors.New("upstream failure") + errs <- sentinel + err := s.Pump(context.Background(), chunks, errs) + if !errors.Is(err, sentinel) { + t.Errorf("Pump err = %v, want %v", err, sentinel) + } + if !strings.Contains(rec.Body.String(), `"error":"upstream failure"`) { + t.Errorf("body missing error payload: %q", rec.Body.String()) + } + if strings.Contains(rec.Body.String(), DoneMarker) { + t.Errorf("body should not contain DONE on error, got %q", rec.Body.String()) + } +} + +func TestPump_ErrChClosedSilently(t *testing.T) { + rec := httptest.NewRecorder() + s := NewStream(rec) + chunks := make(chan string) + errs := make(chan error) + close(errs) + if err := s.Pump(context.Background(), chunks, errs); err != nil { + t.Errorf("Pump on closed errCh = %v, want nil", err) + } + if got := rec.Body.String(); got != "" { + t.Errorf("body = %q, want empty (no headers written either)", got) + } +} + +func TestPump_ContextCancel(t *testing.T) { + rec := httptest.NewRecorder() + s := NewStream(rec) + ctx, cancel := context.WithCancel(context.Background()) + chunks := make(chan string) + errs := make(chan error) + cancel() + if err := s.Pump(ctx, chunks, errs); !errors.Is(err, context.Canceled) { + t.Errorf("Pump err = %v, want context.Canceled", err) + } +} + +func TestStream_Comment(t *testing.T) { + rec := httptest.NewRecorder() + s := NewStream(rec) + if err := s.Comment("hello"); err != nil { + t.Fatalf("Comment: %v", err) + } + if got, want := rec.Body.String(), ": hello\n\n"; got != want { + t.Errorf("body = %q, want %q", got, want) + } +} + +func TestStream_CommentEmpty(t *testing.T) { + rec := httptest.NewRecorder() + s := NewStream(rec) + if err := s.Comment(""); err != nil { + t.Fatalf("Comment: %v", err) + } + if got, want := rec.Body.String(), ":\n\n"; got != want { + t.Errorf("body = %q, want %q", got, want) + } +} + +func TestStream_Heartbeat(t *testing.T) { + rec := httptest.NewRecorder() + s := NewStream(rec) + stop := s.Heartbeat(context.Background(), 5*time.Millisecond) + defer stop() + + // Concurrent writes: ensure mutex prevents interleaved frames. + done := make(chan struct{}) + go func() { + for i := 0; i < 20; i++ { + _ = s.Write("token") + time.Sleep(time.Millisecond) + } + close(done) + }() + <-done + stop() + + body := rec.Body.String() + if !strings.Contains(body, ":\n\n") { + t.Errorf("expected at least one comment frame, got %q", body) + } + if !strings.Contains(body, "data: token\n\n") { + t.Errorf("expected data frames, got %q", body) + } +} + +func TestLastEventID(t *testing.T) { + r := httptest.NewRequest("GET", "/x", nil) + r.Header.Set(LastEventIDHeader, "42") + if got := LastEventID(r); got != "42" { + t.Errorf("LastEventID(header) = %q, want %q", got, "42") + } + + r = httptest.NewRequest("GET", "/x", nil) + if got := LastEventID(r); got != "" { + t.Errorf("LastEventID(absent) = %q, want empty", got) + } +} + +func TestDetachWriteTimeout_DeadlineDoesNotPropagate(t *testing.T) { + rec := httptest.NewRecorder() + parent, cancel := context.WithDeadline(context.Background(), time.Now().Add(10*time.Millisecond)) + defer cancel() + req := httptest.NewRequest("POST", "/x", nil).WithContext(parent) + + r := DetachWriteTimeout(rec, req, time.Second) + <-parent.Done() + // Give the detach goroutine a moment to observe parent.Err(). + time.Sleep(20 * time.Millisecond) + + select { + case <-r.Context().Done(): + t.Errorf("detached context cancelled on parent DeadlineExceeded; want still alive") + default: + } +} + +func TestDetachWriteTimeout_PropagatesCancel(t *testing.T) { + rec := httptest.NewRecorder() + parent, cancel := context.WithCancel(context.Background()) + req := httptest.NewRequest("POST", "/x", nil).WithContext(parent) + + r := DetachWriteTimeout(rec, req, time.Second) + cancel() + + select { + case <-r.Context().Done(): + case <-time.After(100 * time.Millisecond): + t.Errorf("detached context not cancelled when parent was Canceled") + } +} From a840ea40e7cb57b52cdc64631a67d1624c29a545 Mon Sep 17 00:00:00 2001 From: skinny <39215611+crazyskinny@users.noreply.github.com> Date: Wed, 13 May 2026 17:07:12 +0800 Subject: [PATCH 03/11] chore: bump go to 1.26.3 --- go.mod | 2 +- kubernetes/election/go.mod | 2 +- logging/kratos/go.mod | 2 +- s3/go.mod | 2 +- sse/kratos/go.mod | 2 +- 5 files changed, 5 insertions(+), 5 deletions(-) diff --git a/go.mod b/go.mod index ad74865..3170641 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module github.com/crypto-zero/go-kit -go 1.25.5 +go 1.26.3 require ( entgo.io/ent v0.14.5 diff --git a/kubernetes/election/go.mod b/kubernetes/election/go.mod index e45236e..54788e3 100644 --- a/kubernetes/election/go.mod +++ b/kubernetes/election/go.mod @@ -1,6 +1,6 @@ module github.com/crypto-zero/go-kit/kubernetes/election -go 1.25.5 +go 1.26.3 require ( k8s.io/apimachinery v0.35.0 diff --git a/logging/kratos/go.mod b/logging/kratos/go.mod index 4387484..39223f1 100644 --- a/logging/kratos/go.mod +++ b/logging/kratos/go.mod @@ -1,6 +1,6 @@ module github.com/crypto-zero/go-kit/logging/kratos -go 1.25.5 +go 1.26.3 require ( github.com/go-kratos/kratos/v2 v2.9.2 diff --git a/s3/go.mod b/s3/go.mod index 7a332a8..9cdf048 100644 --- a/s3/go.mod +++ b/s3/go.mod @@ -1,6 +1,6 @@ module github.com/crypto-zero/go-kit/s3 -go 1.25.5 +go 1.26.3 require ( github.com/minio/minio-go/v7 v7.0.97 diff --git a/sse/kratos/go.mod b/sse/kratos/go.mod index a2ef001..e446d28 100644 --- a/sse/kratos/go.mod +++ b/sse/kratos/go.mod @@ -1,6 +1,6 @@ module github.com/crypto-zero/go-kit/sse/kratos -go 1.25.5 +go 1.26.3 require ( github.com/crypto-zero/go-kit v0.0.0-20260128101518-0545cf5a3fac From c15a84268572c7ecd32bb00ef8d5d19e3091a50c Mon Sep 17 00:00:00 2001 From: skinny <39215611+crazyskinny@users.noreply.github.com> Date: Wed, 13 May 2026 19:08:26 +0800 Subject: [PATCH 04/11] fix: harden sse heartbeat and deadline handling --- sse/detach.go | 38 +++++++++++++++++------- sse/kratos/handler.go | 12 +------- sse/kratos/http_handler.go | 33 ++------------------- sse/sse.go | 3 ++ sse/sse_test.go | 59 ++++++++++++++++++++++++++++++++++++++ 5 files changed, 92 insertions(+), 53 deletions(-) diff --git a/sse/detach.go b/sse/detach.go index 3afa600..1d3e45d 100644 --- a/sse/detach.go +++ b/sse/detach.go @@ -4,6 +4,7 @@ import ( "context" "errors" "net/http" + "sync" "time" ) @@ -28,20 +29,35 @@ import ( // ctx.Done() still see client disconnects. func DetachWriteTimeout(w http.ResponseWriter, r *http.Request, writeTimeout time.Duration) *http.Request { rc := http.NewResponseController(w) - _ = rc.SetWriteDeadline(time.Now().Add(writeTimeout)) + deadline := time.Time{} + if writeTimeout > 0 { + deadline = time.Now().Add(writeTimeout) + } + _ = rc.SetWriteDeadline(deadline) - parent := r.Context() + ctx, _ := DetachDeadlineContext(r.Context()) + return r.WithContext(ctx) +} + +// DetachDeadlineContext returns a context that keeps parent values but ignores +// parent DeadlineExceeded cancellation. Other parent cancellations, such as a +// client disconnect, are still forwarded. +// +// The returned stop function unregisters the parent callback and cancels the +// detached context. Callers that own a bounded streaming lifecycle should defer +// stop when the stream ends. +func DetachDeadlineContext(parent context.Context) (context.Context, func()) { ctx, cancel := context.WithCancel(context.WithoutCancel(parent)) - go func() { - <-parent.Done() - // Forward only genuine client disconnects. When the parent's Err - // is DeadlineExceeded the server-wide timer fired — that is the - // signal we are explicitly detaching from, so the replacement - // context must stay alive. The cancel func is then released when - // the handler returns and references drop. + stopParent := context.AfterFunc(parent, func() { if !errors.Is(parent.Err(), context.DeadlineExceeded) { cancel() } - }() - return r.WithContext(ctx) + }) + var once sync.Once + return ctx, func() { + once.Do(func() { + stopParent() + cancel() + }) + } } diff --git a/sse/kratos/handler.go b/sse/kratos/handler.go index 2e0af66..662e359 100644 --- a/sse/kratos/handler.go +++ b/sse/kratos/handler.go @@ -117,7 +117,7 @@ func preStream[Req any]( // Callers should defer end immediately after this call. func (s *Server) beginStream(ctx context.Context, w http.ResponseWriter) (*sse.Stream, func()) { st := sse.NewStream(w) - stopBeat := s.startHeartbeat(ctx, st) + stopBeat := st.Heartbeat(ctx, s.heartbeat) s.active.Add(1) return st, func() { stopBeat() @@ -125,16 +125,6 @@ func (s *Server) beginStream(ctx context.Context, w http.ResponseWriter) (*sse.S } } -// startHeartbeat fires a periodic comment frame on st when the server -// has Heartbeat enabled. Returns a stop function (a no-op when -// heartbeat is disabled). -func (s *Server) startHeartbeat(ctx context.Context, st *sse.Stream) func() { - if s.heartbeat <= 0 { - return func() {} - } - return st.Heartbeat(ctx, s.heartbeat) -} - // chainFor composes the middleware chain for one handler: server-wide // middlewares (outermost) followed by per-handler extras. The returned // chain does not share backing storage with srv.middlewares, so later diff --git a/sse/kratos/http_handler.go b/sse/kratos/http_handler.go index 7212944..a724f69 100644 --- a/sse/kratos/http_handler.go +++ b/sse/kratos/http_handler.go @@ -2,11 +2,9 @@ package kratos import ( "context" - "errors" "fmt" "net/http" "strconv" - "sync" "time" authkratos "github.com/crypto-zero/go-kit/auth/kratos" @@ -65,11 +63,11 @@ func RegisterHTTPStream[Req any]( if operation != "" { khttp.SetOperation(ctx, operation) } - streamCtx, stopStreamCtx := detachHTTPTimeout(ctx) + streamCtx, stopStreamCtx := sse.DetachDeadlineContext(ctx) defer stopStreamCtx() h := ctx.Middleware(func(mctx context.Context, raw any) (any, error) { st := sse.NewStream(ctx.Response()) - stopBeat := startHTTPHeartbeat(mctx, st, cfg.heartbeat) + stopBeat := st.Heartbeat(mctx, cfg.heartbeat) defer stopBeat() if err := do(mctx, raw.(*Req), st); err != nil { _ = st.Error(err.Error()) @@ -108,33 +106,6 @@ func bindHTTPStreamRequest(ctx khttp.Context, target any) error { } } -func startHTTPHeartbeat(ctx context.Context, st *sse.Stream, interval time.Duration) func() { - if interval <= 0 { - return func() {} - } - return st.Heartbeat(ctx, interval) -} - -func detachHTTPTimeout(parent context.Context) (context.Context, func()) { - ctx, cancel := context.WithCancel(context.WithoutCancel(parent)) - done := make(chan struct{}) - go func() { - select { - case <-parent.Done(): - if !errors.Is(parent.Err(), context.DeadlineExceeded) { - cancel() - } - case <-done: - } - }() - var once sync.Once - return ctx, func() { - once.Do(func() { - close(done) - cancel() - }) - } -} func decodeProtoQuery(r *http.Request, msg proto.Message) error { q := r.URL.Query() diff --git a/sse/sse.go b/sse/sse.go index 69940a1..5f7ba25 100644 --- a/sse/sse.go +++ b/sse/sse.go @@ -199,6 +199,9 @@ func (s *Stream) Comment(text string) error { // Use this for long-lived streams that sit behind proxies with idle // connection timeouts. func (s *Stream) Heartbeat(ctx context.Context, interval time.Duration) (stop func()) { + if interval <= 0 { + return func() {} + } ctx, cancel := context.WithCancel(ctx) done := make(chan struct{}) go func() { diff --git a/sse/sse_test.go b/sse/sse_test.go index d34e906..bb67316 100644 --- a/sse/sse_test.go +++ b/sse/sse_test.go @@ -1,8 +1,11 @@ package sse import ( + "bufio" "context" "errors" + "net" + "net/http" "net/http/httptest" "strings" "testing" @@ -227,6 +230,19 @@ func TestStream_Heartbeat(t *testing.T) { } } +func TestStream_HeartbeatNonPositiveIntervalIsNoop(t *testing.T) { + rec := httptest.NewRecorder() + s := NewStream(rec) + + stop := s.Heartbeat(context.Background(), 0) + stop() + stop() + + if got := rec.Body.String(); got != "" { + t.Errorf("body = %q, want empty", got) + } +} + func TestLastEventID(t *testing.T) { r := httptest.NewRequest("GET", "/x", nil) r.Header.Set(LastEventIDHeader, "42") @@ -272,3 +288,46 @@ func TestDetachWriteTimeout_PropagatesCancel(t *testing.T) { t.Errorf("detached context not cancelled when parent was Canceled") } } + +func TestDetachWriteTimeout_NonPositiveClearsWriteDeadline(t *testing.T) { + var got time.Time + w := deadlineResponseWriter{ + ResponseWriter: httptest.NewRecorder(), + setWriteDeadline: func(t time.Time) error { + got = t + return nil + }, + } + req := httptest.NewRequest("POST", "/x", nil) + + _ = DetachWriteTimeout(w, req, 0) + + if !got.IsZero() { + t.Errorf("write deadline = %v, want zero time", got) + } +} + +type deadlineResponseWriter struct { + http.ResponseWriter + setWriteDeadline func(time.Time) error +} + +func (w deadlineResponseWriter) SetWriteDeadline(t time.Time) error { + return w.setWriteDeadline(t) +} + +func (w deadlineResponseWriter) SetReadDeadline(time.Time) error { + return nil +} + +func (w deadlineResponseWriter) EnableFullDuplex() error { + return nil +} + +func (w deadlineResponseWriter) Flush() error { + return nil +} + +func (w deadlineResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { + return nil, nil, http.ErrNotSupported +} From 871a48d5fa76a25bf8aeab98e9d377cf79f6c740 Mon Sep 17 00:00:00 2001 From: skinny <39215611+crazyskinny@users.noreply.github.com> Date: Wed, 13 May 2026 19:09:35 +0800 Subject: [PATCH 05/11] feat: add sse proto generator --- proto/buf.gen.yaml | 2 + proto/kit/sse/v1/sse.pb.go | 191 +++++++++++++++++++++++++++ proto/kit/sse/v1/sse.proto | 18 +++ sse/kratos/http_bound_test.go | 55 ++++++++ sse/kratos/http_handler.go | 121 +++-------------- sse/kratos/protoc-gen-go-sse/main.go | 169 ++++++++++++++++++++++++ 6 files changed, 455 insertions(+), 101 deletions(-) create mode 100644 proto/kit/sse/v1/sse.pb.go create mode 100644 proto/kit/sse/v1/sse.proto create mode 100644 sse/kratos/http_bound_test.go create mode 100644 sse/kratos/protoc-gen-go-sse/main.go diff --git a/proto/buf.gen.yaml b/proto/buf.gen.yaml index 44e7c8b..1ecdf6a 100644 --- a/proto/buf.gen.yaml +++ b/proto/buf.gen.yaml @@ -16,3 +16,5 @@ plugins: - ../logging/protoc-gen-go-redact/main.go out: . opt: paths=source_relative + exclude_types: + - kit.sse.v1 diff --git a/proto/kit/sse/v1/sse.pb.go b/proto/kit/sse/v1/sse.pb.go new file mode 100644 index 0000000..edb5d1b --- /dev/null +++ b/proto/kit/sse/v1/sse.pb.go @@ -0,0 +1,191 @@ +// Code generated by protoc-gen-go. DO NOT EDIT. +// versions: +// protoc-gen-go v1.36.11 +// protoc (unknown) +// source: kit/sse/v1/sse.proto + +package ssev1 + +import ( + protoreflect "google.golang.org/protobuf/reflect/protoreflect" + protoimpl "google.golang.org/protobuf/runtime/protoimpl" + descriptorpb "google.golang.org/protobuf/types/descriptorpb" + reflect "reflect" + sync "sync" + unsafe "unsafe" +) + +const ( + // Verify that this generated code is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) + // Verify that runtime/protoimpl is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) +) + +type StreamRule struct { + state protoimpl.MessageState `protogen:"open.v1"` + // Types that are valid to be assigned to Pattern: + // + // *StreamRule_Get + // *StreamRule_Post + Pattern isStreamRule_Pattern `protobuf_oneof:"pattern"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *StreamRule) Reset() { + *x = StreamRule{} + mi := &file_kit_sse_v1_sse_proto_msgTypes[0] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *StreamRule) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*StreamRule) ProtoMessage() {} + +func (x *StreamRule) ProtoReflect() protoreflect.Message { + mi := &file_kit_sse_v1_sse_proto_msgTypes[0] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use StreamRule.ProtoReflect.Descriptor instead. +func (*StreamRule) Descriptor() ([]byte, []int) { + return file_kit_sse_v1_sse_proto_rawDescGZIP(), []int{0} +} + +func (x *StreamRule) GetPattern() isStreamRule_Pattern { + if x != nil { + return x.Pattern + } + return nil +} + +func (x *StreamRule) GetGet() string { + if x != nil { + if x, ok := x.Pattern.(*StreamRule_Get); ok { + return x.Get + } + } + return "" +} + +func (x *StreamRule) GetPost() string { + if x != nil { + if x, ok := x.Pattern.(*StreamRule_Post); ok { + return x.Post + } + } + return "" +} + +type isStreamRule_Pattern interface { + isStreamRule_Pattern() +} + +type StreamRule_Get struct { + Get string `protobuf:"bytes,1,opt,name=get,proto3,oneof"` +} + +type StreamRule_Post struct { + Post string `protobuf:"bytes,2,opt,name=post,proto3,oneof"` +} + +func (*StreamRule_Get) isStreamRule_Pattern() {} + +func (*StreamRule_Post) isStreamRule_Pattern() {} + +var file_kit_sse_v1_sse_proto_extTypes = []protoimpl.ExtensionInfo{ + { + ExtendedType: (*descriptorpb.MethodOptions)(nil), + ExtensionType: (*StreamRule)(nil), + Field: 81001, + Name: "kit.sse.v1.server_sent_event", + Tag: "bytes,81001,opt,name=server_sent_event", + Filename: "kit/sse/v1/sse.proto", + }, +} + +// Extension fields to descriptorpb.MethodOptions. +var ( + // optional kit.sse.v1.StreamRule server_sent_event = 81001; + E_ServerSentEvent = &file_kit_sse_v1_sse_proto_extTypes[0] +) + +var File_kit_sse_v1_sse_proto protoreflect.FileDescriptor + +const file_kit_sse_v1_sse_proto_rawDesc = "" + + "\n" + + "\x14kit/sse/v1/sse.proto\x12\n" + + "kit.sse.v1\x1a google/protobuf/descriptor.proto\"A\n" + + "\n" + + "StreamRule\x12\x12\n" + + "\x03get\x18\x01 \x01(\tH\x00R\x03get\x12\x14\n" + + "\x04post\x18\x02 \x01(\tH\x00R\x04postB\t\n" + + "\apattern:d\n" + + "\x11server_sent_event\x12\x1e.google.protobuf.MethodOptions\x18\xe9\xf8\x04 \x01(\v2\x16.kit.sse.v1.StreamRuleR\x0fserverSentEventB6Z4github.com/crypto-zero/go-kit/proto/kit/sse/v1;ssev1b\x06proto3" + +var ( + file_kit_sse_v1_sse_proto_rawDescOnce sync.Once + file_kit_sse_v1_sse_proto_rawDescData []byte +) + +func file_kit_sse_v1_sse_proto_rawDescGZIP() []byte { + file_kit_sse_v1_sse_proto_rawDescOnce.Do(func() { + file_kit_sse_v1_sse_proto_rawDescData = protoimpl.X.CompressGZIP(unsafe.Slice(unsafe.StringData(file_kit_sse_v1_sse_proto_rawDesc), len(file_kit_sse_v1_sse_proto_rawDesc))) + }) + return file_kit_sse_v1_sse_proto_rawDescData +} + +var file_kit_sse_v1_sse_proto_msgTypes = make([]protoimpl.MessageInfo, 1) +var file_kit_sse_v1_sse_proto_goTypes = []any{ + (*StreamRule)(nil), // 0: kit.sse.v1.StreamRule + (*descriptorpb.MethodOptions)(nil), // 1: google.protobuf.MethodOptions +} +var file_kit_sse_v1_sse_proto_depIdxs = []int32{ + 1, // 0: kit.sse.v1.server_sent_event:extendee -> google.protobuf.MethodOptions + 0, // 1: kit.sse.v1.server_sent_event:type_name -> kit.sse.v1.StreamRule + 2, // [2:2] is the sub-list for method output_type + 2, // [2:2] is the sub-list for method input_type + 1, // [1:2] is the sub-list for extension type_name + 0, // [0:1] is the sub-list for extension extendee + 0, // [0:0] is the sub-list for field type_name +} + +func init() { file_kit_sse_v1_sse_proto_init() } +func file_kit_sse_v1_sse_proto_init() { + if File_kit_sse_v1_sse_proto != nil { + return + } + file_kit_sse_v1_sse_proto_msgTypes[0].OneofWrappers = []any{ + (*StreamRule_Get)(nil), + (*StreamRule_Post)(nil), + } + type x struct{} + out := protoimpl.TypeBuilder{ + File: protoimpl.DescBuilder{ + GoPackagePath: reflect.TypeOf(x{}).PkgPath(), + RawDescriptor: unsafe.Slice(unsafe.StringData(file_kit_sse_v1_sse_proto_rawDesc), len(file_kit_sse_v1_sse_proto_rawDesc)), + NumEnums: 0, + NumMessages: 1, + NumExtensions: 1, + NumServices: 0, + }, + GoTypes: file_kit_sse_v1_sse_proto_goTypes, + DependencyIndexes: file_kit_sse_v1_sse_proto_depIdxs, + MessageInfos: file_kit_sse_v1_sse_proto_msgTypes, + ExtensionInfos: file_kit_sse_v1_sse_proto_extTypes, + }.Build() + File_kit_sse_v1_sse_proto = out.File + file_kit_sse_v1_sse_proto_goTypes = nil + file_kit_sse_v1_sse_proto_depIdxs = nil +} diff --git a/proto/kit/sse/v1/sse.proto b/proto/kit/sse/v1/sse.proto new file mode 100644 index 0000000..f881e2c --- /dev/null +++ b/proto/kit/sse/v1/sse.proto @@ -0,0 +1,18 @@ +syntax = "proto3"; + +package kit.sse.v1; + +import "google/protobuf/descriptor.proto"; + +option go_package = "github.com/crypto-zero/go-kit/proto/kit/sse/v1;ssev1"; + +extend google.protobuf.MethodOptions { + StreamRule server_sent_event = 81001; +} + +message StreamRule { + oneof pattern { + string get = 1; + string post = 2; + } +} diff --git a/sse/kratos/http_bound_test.go b/sse/kratos/http_bound_test.go new file mode 100644 index 0000000..fea6e54 --- /dev/null +++ b/sse/kratos/http_bound_test.go @@ -0,0 +1,55 @@ +package kratos_test + +import ( + "context" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" + + ksse "github.com/crypto-zero/go-kit/sse" + ssekratos "github.com/crypto-zero/go-kit/sse/kratos" + khttp "github.com/go-kratos/kratos/v2/transport/http" +) + +type boundRequest struct { + Name string +} + +func TestRegisterHTTPStreamBoundUsesSuppliedBinder(t *testing.T) { + srv := khttp.NewServer(khttp.Timeout(0)) + ssekratos.RegisterHTTPStreamBound( + srv, + http.MethodGet, + "/v1/bound", + "/test.Bound/Watch", + func(ctx khttp.Context, req *boundRequest) error { + req.Name = ctx.Query().Get("name") + return nil + }, + func(_ context.Context, req *boundRequest, st *ksse.Stream) error { + if err := st.WriteJSON(map[string]string{"name": req.Name}); err != nil { + return err + } + return st.Done() + }, + ) + + ts := httptest.NewServer(srv) + defer ts.Close() + + resp, err := http.Get(ts.URL + "/v1/bound?name=generated") + if err != nil { + t.Fatalf("GET: %v", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("read body: %v", err) + } + if !strings.Contains(string(body), `data: {"name":"generated"}`) { + t.Fatalf("bound response missing generated name:\n%s", string(body)) + } +} diff --git a/sse/kratos/http_handler.go b/sse/kratos/http_handler.go index a724f69..4543136 100644 --- a/sse/kratos/http_handler.go +++ b/sse/kratos/http_handler.go @@ -2,14 +2,11 @@ package kratos import ( "context" - "fmt" "net/http" - "strconv" "time" authkratos "github.com/crypto-zero/go-kit/auth/kratos" khttp "github.com/go-kratos/kratos/v2/transport/http" - "google.golang.org/protobuf/proto" "google.golang.org/protobuf/reflect/protoreflect" "github.com/crypto-zero/go-kit/sse" @@ -39,10 +36,9 @@ func HTTPFilter(filters ...khttp.FilterFunc) HTTPStreamOption { // route walking, operation selection and service middleware all come from // the supplied HTTP server. // -// For GET and HEAD requests, proto.Message inputs are decoded from the query -// string by JSON name (with snake_case fallback). Non-proto GET/HEAD inputs -// use Kratos' configured query decoder. Other methods use Kratos' configured -// body decoder. +// GET/HEAD requests decode through ctx.BindQuery; other methods use +// ctx.Bind. Proto messages get the full Kratos form codec handling +// (well-known types, repeated/map fields, nested paths) for free. func RegisterHTTPStream[Req any]( srv *khttp.Server, method string, @@ -50,6 +46,21 @@ func RegisterHTTPStream[Req any]( operation string, do func(ctx context.Context, req *Req, st *sse.Stream) error, opts ...HTTPStreamOption, +) { + RegisterHTTPStreamBound(srv, method, path, operation, bindHTTPStreamRequest[Req], do, opts...) +} + +// RegisterHTTPStreamBound mounts an SSE endpoint with caller-supplied request +// binding. Generated handlers use this to keep protobuf HTTP binding code in +// generated files while reusing the shared streaming lifecycle. +func RegisterHTTPStreamBound[Req any]( + srv *khttp.Server, + method string, + path string, + operation string, + bind func(khttp.Context, *Req) error, + do func(ctx context.Context, req *Req, st *sse.Stream) error, + opts ...HTTPStreamOption, ) { cfg := httpStreamConfig{} for _, opt := range opts { @@ -57,7 +68,7 @@ func RegisterHTTPStream[Req any]( } srv.Route("/").Handle(method, path, func(ctx khttp.Context) error { req := new(Req) - if err := bindHTTPStreamRequest(ctx, req); err != nil { + if err := bind(ctx, req); err != nil { return err } if operation != "" { @@ -94,103 +105,11 @@ func RegisterHTTPStreamMethod[Req any]( RegisterHTTPStream(srv, httpMethod, path, authkratos.OperationName(method), do, opts...) } -func bindHTTPStreamRequest(ctx khttp.Context, target any) error { +func bindHTTPStreamRequest[Req any](ctx khttp.Context, target *Req) error { switch ctx.Request().Method { case http.MethodGet, http.MethodHead: - if msg, ok := target.(proto.Message); ok { - return decodeProtoQuery(ctx.Request(), msg) - } return ctx.BindQuery(target) default: return ctx.Bind(target) } } - - -func decodeProtoQuery(r *http.Request, msg proto.Message) error { - q := r.URL.Query() - if len(q) == 0 { - return nil - } - refl := msg.ProtoReflect() - fields := refl.Descriptor().Fields() - for i := 0; i < fields.Len(); i++ { - fd := fields.Get(i) - raw := q.Get(fd.JSONName()) - if raw == "" { - raw = q.Get(string(fd.Name())) - } - if raw == "" { - continue - } - if err := setProtoFieldFromString(refl, fd, raw); err != nil { - return fmt.Errorf("query %s: %w", fd.JSONName(), err) - } - } - return nil -} - -func setProtoFieldFromString(msg protoreflect.Message, fd protoreflect.FieldDescriptor, raw string) error { - if fd.IsList() || fd.IsMap() { - return fmt.Errorf("repeated/map fields not supported in query strings") - } - switch fd.Kind() { - case protoreflect.DoubleKind: - v, err := strconv.ParseFloat(raw, 64) - if err != nil { - return err - } - msg.Set(fd, protoreflect.ValueOfFloat64(v)) - case protoreflect.FloatKind: - v, err := strconv.ParseFloat(raw, 32) - if err != nil { - return err - } - msg.Set(fd, protoreflect.ValueOfFloat32(float32(v))) - case protoreflect.Int32Kind, protoreflect.Sint32Kind, protoreflect.Sfixed32Kind: - v, err := strconv.ParseInt(raw, 10, 32) - if err != nil { - return err - } - msg.Set(fd, protoreflect.ValueOfInt32(int32(v))) - case protoreflect.Uint32Kind, protoreflect.Fixed32Kind: - v, err := strconv.ParseUint(raw, 10, 32) - if err != nil { - return err - } - msg.Set(fd, protoreflect.ValueOfUint32(uint32(v))) - case protoreflect.Int64Kind, protoreflect.Sint64Kind, protoreflect.Sfixed64Kind: - v, err := strconv.ParseInt(raw, 10, 64) - if err != nil { - return err - } - msg.Set(fd, protoreflect.ValueOfInt64(v)) - case protoreflect.Uint64Kind, protoreflect.Fixed64Kind: - v, err := strconv.ParseUint(raw, 10, 64) - if err != nil { - return err - } - msg.Set(fd, protoreflect.ValueOfUint64(v)) - case protoreflect.BoolKind: - v, err := strconv.ParseBool(raw) - if err != nil { - return err - } - msg.Set(fd, protoreflect.ValueOfBool(v)) - case protoreflect.StringKind: - msg.Set(fd, protoreflect.ValueOfString(raw)) - case protoreflect.EnumKind: - if i, err := strconv.ParseInt(raw, 10, 32); err == nil { - msg.Set(fd, protoreflect.ValueOfEnum(protoreflect.EnumNumber(i))) - return nil - } - if ev := fd.Enum().Values().ByName(protoreflect.Name(raw)); ev != nil { - msg.Set(fd, protoreflect.ValueOfEnum(ev.Number())) - return nil - } - return fmt.Errorf("unknown enum value %q for %s", raw, fd.Enum().FullName()) - default: - return fmt.Errorf("unsupported field kind %s for query decoding", fd.Kind()) - } - return nil -} diff --git a/sse/kratos/protoc-gen-go-sse/main.go b/sse/kratos/protoc-gen-go-sse/main.go new file mode 100644 index 0000000..c589efd --- /dev/null +++ b/sse/kratos/protoc-gen-go-sse/main.go @@ -0,0 +1,169 @@ +package main + +import ( + "flag" + "fmt" + "net/http" + "strconv" + "strings" + + ssev1 "github.com/crypto-zero/go-kit/proto/kit/sse/v1" + "google.golang.org/protobuf/compiler/protogen" + "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/types/pluginpb" +) + +const version = "v0.1.0" + +var ( + contextPackage = protogen.GoImportPath("context") + khttpPackage = protogen.GoImportPath("github.com/go-kratos/kratos/v2/transport/http") + ssePackage = protogen.GoImportPath("github.com/crypto-zero/go-kit/sse") + kratosPackage = protogen.GoImportPath("github.com/crypto-zero/go-kit/sse/kratos") +) + +func main() { + var flags flag.FlagSet + protogen.Options{ParamFunc: flags.Set}.Run(func(plugin *protogen.Plugin) error { + plugin.SupportedFeatures = uint64(pluginpb.CodeGeneratorResponse_FEATURE_PROTO3_OPTIONAL) + for _, file := range plugin.Files { + if file.Generate { + generateFile(plugin, file) + } + } + return nil + }) +} + +type sseMethod struct { + service *protogen.Service + method *protogen.Method + verb string + path string +} + +func generateFile(plugin *protogen.Plugin, file *protogen.File) { + methods := collectSSEMethods(file) + if len(methods) == 0 { + return + } + + filename := file.GeneratedFilenamePrefix + "_sse.pb.go" + g := plugin.NewGeneratedFile(filename, file.GoImportPath) + g.P("// Code generated by protoc-gen-go-sse. DO NOT EDIT.") + g.P("// versions:") + g.P("// - protoc-gen-go-sse ", version) + g.P("// - protoc ", protocVersion(plugin)) + g.P("// source: ", file.Desc.Path()) + g.P() + g.P("package ", file.GoPackageName) + g.P() + + grouped := map[*protogen.Service][]sseMethod{} + for _, m := range methods { + grouped[m.service] = append(grouped[m.service], m) + } + for _, service := range file.Services { + items := grouped[service] + if len(items) == 0 { + continue + } + genService(g, service, items) + } +} + +func collectSSEMethods(file *protogen.File) []sseMethod { + var out []sseMethod + for _, service := range file.Services { + for _, method := range service.Methods { + rule, ok := proto.GetExtension(method.Desc.Options(), ssev1.E_ServerSentEvent).(*ssev1.StreamRule) + if !ok || rule == nil { + continue + } + verb, path := streamRuleHTTP(rule) + if verb == "" || path == "" { + continue + } + out = append(out, sseMethod{service: service, method: method, verb: verb, path: path}) + } + } + return out +} + +func streamRuleHTTP(rule *ssev1.StreamRule) (string, string) { + switch pattern := rule.GetPattern().(type) { + case *ssev1.StreamRule_Get: + return http.MethodGet, pattern.Get + case *ssev1.StreamRule_Post: + return http.MethodPost, pattern.Post + default: + return "", "" + } +} + +func genService(g *protogen.GeneratedFile, service *protogen.Service, methods []sseMethod) { + serviceName := service.GoName + contextIdent := g.QualifiedGoIdent(contextPackage.Ident("Context")) + streamIdent := g.QualifiedGoIdent(ssePackage.Ident("Stream")) + serverIdent := g.QualifiedGoIdent(khttpPackage.Ident("Server")) + optionIdent := g.QualifiedGoIdent(kratosPackage.Ident("HTTPStreamOption")) + + g.P("type ", serviceName, "SSEServer interface {") + for _, item := range methods { + g.P(item.method.GoName, "(", contextIdent, ", *", item.method.Input.GoIdent, ", *", streamIdent, ") error") + } + g.P("}") + g.P() + + g.P("func Register", serviceName, "SSEServer(s *", serverIdent, ", srv ", serviceName, "SSEServer, opts ...", optionIdent, ") {") + for _, item := range methods { + g.P("_", serviceName, "_", item.method.GoName, "_SSE_Register(s, srv, opts...)") + } + g.P("}") + g.P() + + for _, item := range methods { + genMethod(g, item) + } +} + +func genMethod(g *protogen.GeneratedFile, item sseMethod) { + serviceName := item.service.GoName + methodName := item.method.GoName + input := item.method.Input.GoIdent + operationConst := "Operation" + serviceName + methodName + "SSE" + + serverIdent := g.QualifiedGoIdent(khttpPackage.Ident("Server")) + contextIdent := g.QualifiedGoIdent(khttpPackage.Ident("Context")) + optionIdent := g.QualifiedGoIdent(kratosPackage.Ident("HTTPStreamOption")) + registerIdent := g.QualifiedGoIdent(kratosPackage.Ident("RegisterHTTPStreamBound")) + + g.P("const ", operationConst, ` = "/`, item.method.Desc.Parent().FullName(), `/`, item.method.Desc.Name(), `"`) + g.P() + g.P("func _", serviceName, "_", methodName, "_SSE_Register(s *", serverIdent, ", srv ", serviceName, "SSEServer, opts ...", optionIdent, ") {") + g.P(registerIdent, "(s, ", strconv.Quote(item.verb), ", ", strconv.Quote(item.path), ", ", operationConst, ",") + g.P("func(ctx ", contextIdent, ", in *", input, ") error {") + if item.verb == http.MethodGet { + g.P("return ctx.BindQuery(in)") + } else { + g.P("return ctx.Bind(in)") + } + g.P("},") + g.P("srv.", methodName, ",") + g.P("opts...,") + g.P(")") + g.P("}") + g.P() +} + +func protocVersion(plugin *protogen.Plugin) string { + v := plugin.Request.GetCompilerVersion() + if v == nil { + return "(unknown)" + } + var suffix string + if v.GetSuffix() != "" { + suffix = "-" + v.GetSuffix() + } + return strings.TrimPrefix(fmt.Sprintf("v%d.%d.%d%s", v.GetMajor(), v.GetMinor(), v.GetPatch(), suffix), "v0.0.0") +} From 66d8d0bbd6470b65e787721a554df30c3088cd46 Mon Sep 17 00:00:00 2001 From: skinny <39215611+crazyskinny@users.noreply.github.com> Date: Wed, 13 May 2026 19:24:15 +0800 Subject: [PATCH 06/11] refactor: simplify sse kratos integration --- proto/buf.gen.yaml | 2 - proto/kit/sse/v1/sse.pb.go | 132 +------ proto/kit/sse/v1/sse.proto | 9 +- sse/kratos/codec.go | 82 ---- sse/kratos/doc.go | 2 + sse/kratos/go.mod | 1 + sse/kratos/go.sum | 2 + sse/kratos/handler.go | 138 ------- sse/kratos/handler_test.go | 435 ---------------------- sse/kratos/http_bound_test.go | 55 --- sse/kratos/http_cors.go | 26 -- sse/kratos/http_cors_test.go | 45 --- sse/kratos/http_handler.go | 24 +- sse/kratos/http_handler_test.go | 79 +--- sse/kratos/protoc-gen-go-sse/main.go | 76 ++-- sse/kratos/protoc-gen-go-sse/main_test.go | 90 +++++ sse/kratos/server.go | 277 -------------- sse/kratos/server_options.go | 131 ------- sse/kratos/server_test.go | 321 ---------------- sse/kratos/transport.go | 113 ------ sse/snapshot_live.go | 72 ---- sse/snapshot_live_test.go | 64 ---- sse/sse.go | 4 +- 23 files changed, 175 insertions(+), 2005 deletions(-) delete mode 100644 sse/kratos/codec.go create mode 100644 sse/kratos/doc.go delete mode 100644 sse/kratos/handler.go delete mode 100644 sse/kratos/handler_test.go delete mode 100644 sse/kratos/http_bound_test.go delete mode 100644 sse/kratos/http_cors.go delete mode 100644 sse/kratos/http_cors_test.go create mode 100644 sse/kratos/protoc-gen-go-sse/main_test.go delete mode 100644 sse/kratos/server.go delete mode 100644 sse/kratos/server_options.go delete mode 100644 sse/kratos/server_test.go delete mode 100644 sse/kratos/transport.go delete mode 100644 sse/snapshot_live.go delete mode 100644 sse/snapshot_live_test.go diff --git a/proto/buf.gen.yaml b/proto/buf.gen.yaml index 1ecdf6a..44e7c8b 100644 --- a/proto/buf.gen.yaml +++ b/proto/buf.gen.yaml @@ -16,5 +16,3 @@ plugins: - ../logging/protoc-gen-go-redact/main.go out: . opt: paths=source_relative - exclude_types: - - kit.sse.v1 diff --git a/proto/kit/sse/v1/sse.pb.go b/proto/kit/sse/v1/sse.pb.go index edb5d1b..e68c870 100644 --- a/proto/kit/sse/v1/sse.pb.go +++ b/proto/kit/sse/v1/sse.pb.go @@ -1,7 +1,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: // protoc-gen-go v1.36.11 -// protoc (unknown) +// protoc v7.34.1 // source: kit/sse/v1/sse.proto package ssev1 @@ -11,7 +11,6 @@ import ( protoimpl "google.golang.org/protobuf/runtime/protoimpl" descriptorpb "google.golang.org/protobuf/types/descriptorpb" reflect "reflect" - sync "sync" unsafe "unsafe" ) @@ -22,102 +21,20 @@ const ( _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) ) -type StreamRule struct { - state protoimpl.MessageState `protogen:"open.v1"` - // Types that are valid to be assigned to Pattern: - // - // *StreamRule_Get - // *StreamRule_Post - Pattern isStreamRule_Pattern `protobuf_oneof:"pattern"` - unknownFields protoimpl.UnknownFields - sizeCache protoimpl.SizeCache -} - -func (x *StreamRule) Reset() { - *x = StreamRule{} - mi := &file_kit_sse_v1_sse_proto_msgTypes[0] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) -} - -func (x *StreamRule) String() string { - return protoimpl.X.MessageStringOf(x) -} - -func (*StreamRule) ProtoMessage() {} - -func (x *StreamRule) ProtoReflect() protoreflect.Message { - mi := &file_kit_sse_v1_sse_proto_msgTypes[0] - if x != nil { - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - if ms.LoadMessageInfo() == nil { - ms.StoreMessageInfo(mi) - } - return ms - } - return mi.MessageOf(x) -} - -// Deprecated: Use StreamRule.ProtoReflect.Descriptor instead. -func (*StreamRule) Descriptor() ([]byte, []int) { - return file_kit_sse_v1_sse_proto_rawDescGZIP(), []int{0} -} - -func (x *StreamRule) GetPattern() isStreamRule_Pattern { - if x != nil { - return x.Pattern - } - return nil -} - -func (x *StreamRule) GetGet() string { - if x != nil { - if x, ok := x.Pattern.(*StreamRule_Get); ok { - return x.Get - } - } - return "" -} - -func (x *StreamRule) GetPost() string { - if x != nil { - if x, ok := x.Pattern.(*StreamRule_Post); ok { - return x.Post - } - } - return "" -} - -type isStreamRule_Pattern interface { - isStreamRule_Pattern() -} - -type StreamRule_Get struct { - Get string `protobuf:"bytes,1,opt,name=get,proto3,oneof"` -} - -type StreamRule_Post struct { - Post string `protobuf:"bytes,2,opt,name=post,proto3,oneof"` -} - -func (*StreamRule_Get) isStreamRule_Pattern() {} - -func (*StreamRule_Post) isStreamRule_Pattern() {} - var file_kit_sse_v1_sse_proto_extTypes = []protoimpl.ExtensionInfo{ { ExtendedType: (*descriptorpb.MethodOptions)(nil), - ExtensionType: (*StreamRule)(nil), + ExtensionType: (*bool)(nil), Field: 81001, Name: "kit.sse.v1.server_sent_event", - Tag: "bytes,81001,opt,name=server_sent_event", + Tag: "varint,81001,opt,name=server_sent_event", Filename: "kit/sse/v1/sse.proto", }, } // Extension fields to descriptorpb.MethodOptions. var ( - // optional kit.sse.v1.StreamRule server_sent_event = 81001; + // optional bool server_sent_event = 81001; E_ServerSentEvent = &file_kit_sse_v1_sse_proto_extTypes[0] ) @@ -126,37 +43,17 @@ var File_kit_sse_v1_sse_proto protoreflect.FileDescriptor const file_kit_sse_v1_sse_proto_rawDesc = "" + "\n" + "\x14kit/sse/v1/sse.proto\x12\n" + - "kit.sse.v1\x1a google/protobuf/descriptor.proto\"A\n" + - "\n" + - "StreamRule\x12\x12\n" + - "\x03get\x18\x01 \x01(\tH\x00R\x03get\x12\x14\n" + - "\x04post\x18\x02 \x01(\tH\x00R\x04postB\t\n" + - "\apattern:d\n" + - "\x11server_sent_event\x12\x1e.google.protobuf.MethodOptions\x18\xe9\xf8\x04 \x01(\v2\x16.kit.sse.v1.StreamRuleR\x0fserverSentEventB6Z4github.com/crypto-zero/go-kit/proto/kit/sse/v1;ssev1b\x06proto3" + "kit.sse.v1\x1a google/protobuf/descriptor.proto:L\n" + + "\x11server_sent_event\x12\x1e.google.protobuf.MethodOptions\x18\xe9\xf8\x04 \x01(\bR\x0fserverSentEventB6Z4github.com/crypto-zero/go-kit/proto/kit/sse/v1;ssev1b\x06proto3" -var ( - file_kit_sse_v1_sse_proto_rawDescOnce sync.Once - file_kit_sse_v1_sse_proto_rawDescData []byte -) - -func file_kit_sse_v1_sse_proto_rawDescGZIP() []byte { - file_kit_sse_v1_sse_proto_rawDescOnce.Do(func() { - file_kit_sse_v1_sse_proto_rawDescData = protoimpl.X.CompressGZIP(unsafe.Slice(unsafe.StringData(file_kit_sse_v1_sse_proto_rawDesc), len(file_kit_sse_v1_sse_proto_rawDesc))) - }) - return file_kit_sse_v1_sse_proto_rawDescData -} - -var file_kit_sse_v1_sse_proto_msgTypes = make([]protoimpl.MessageInfo, 1) var file_kit_sse_v1_sse_proto_goTypes = []any{ - (*StreamRule)(nil), // 0: kit.sse.v1.StreamRule - (*descriptorpb.MethodOptions)(nil), // 1: google.protobuf.MethodOptions + (*descriptorpb.MethodOptions)(nil), // 0: google.protobuf.MethodOptions } var file_kit_sse_v1_sse_proto_depIdxs = []int32{ - 1, // 0: kit.sse.v1.server_sent_event:extendee -> google.protobuf.MethodOptions - 0, // 1: kit.sse.v1.server_sent_event:type_name -> kit.sse.v1.StreamRule - 2, // [2:2] is the sub-list for method output_type - 2, // [2:2] is the sub-list for method input_type - 1, // [1:2] is the sub-list for extension type_name + 0, // 0: kit.sse.v1.server_sent_event:extendee -> google.protobuf.MethodOptions + 1, // [1:1] is the sub-list for method output_type + 1, // [1:1] is the sub-list for method input_type + 1, // [1:1] is the sub-list for extension type_name 0, // [0:1] is the sub-list for extension extendee 0, // [0:0] is the sub-list for field type_name } @@ -166,23 +63,18 @@ func file_kit_sse_v1_sse_proto_init() { if File_kit_sse_v1_sse_proto != nil { return } - file_kit_sse_v1_sse_proto_msgTypes[0].OneofWrappers = []any{ - (*StreamRule_Get)(nil), - (*StreamRule_Post)(nil), - } type x struct{} out := protoimpl.TypeBuilder{ File: protoimpl.DescBuilder{ GoPackagePath: reflect.TypeOf(x{}).PkgPath(), RawDescriptor: unsafe.Slice(unsafe.StringData(file_kit_sse_v1_sse_proto_rawDesc), len(file_kit_sse_v1_sse_proto_rawDesc)), NumEnums: 0, - NumMessages: 1, + NumMessages: 0, NumExtensions: 1, NumServices: 0, }, GoTypes: file_kit_sse_v1_sse_proto_goTypes, DependencyIndexes: file_kit_sse_v1_sse_proto_depIdxs, - MessageInfos: file_kit_sse_v1_sse_proto_msgTypes, ExtensionInfos: file_kit_sse_v1_sse_proto_extTypes, }.Build() File_kit_sse_v1_sse_proto = out.File diff --git a/proto/kit/sse/v1/sse.proto b/proto/kit/sse/v1/sse.proto index f881e2c..17a5ab6 100644 --- a/proto/kit/sse/v1/sse.proto +++ b/proto/kit/sse/v1/sse.proto @@ -7,12 +7,5 @@ import "google/protobuf/descriptor.proto"; option go_package = "github.com/crypto-zero/go-kit/proto/kit/sse/v1;ssev1"; extend google.protobuf.MethodOptions { - StreamRule server_sent_event = 81001; -} - -message StreamRule { - oneof pattern { - string get = 1; - string post = 2; - } + bool server_sent_event = 81001; } diff --git a/sse/kratos/codec.go b/sse/kratos/codec.go deleted file mode 100644 index 9f03138..0000000 --- a/sse/kratos/codec.go +++ /dev/null @@ -1,82 +0,0 @@ -package kratos - -import ( - "errors" - "fmt" - "io" - "net/http" - "strings" - - "github.com/go-kratos/kratos/v2/encoding" - kerrors "github.com/go-kratos/kratos/v2/errors" -) - -// DefaultRequestDecoder is the default body decoder: it picks a codec by -// the request's Content-Type and unmarshals the body into v. An empty -// body is a success. Unknown Content-Types fall back to JSON. -func DefaultRequestDecoder(r *http.Request, v any) error { - if r.Body == nil || r.Body == http.NoBody { - return nil - } - body, err := io.ReadAll(r.Body) - if err != nil { - return fmt.Errorf("sse: read body: %w", err) - } - if len(body) == 0 { - return nil - } - c := codecForContentType(r.Header.Get("Content-Type")) - if c == nil { - c = encoding.GetCodec("json") - } - if c == nil { - return errors.New("sse: no codec available") - } - if err := c.Unmarshal(body, v); err != nil { - return fmt.Errorf("sse: decode %s: %w", c.Name(), err) - } - return nil -} - -// DefaultErrorEncoder writes err as an HTTP error response, honoring -// the status code embedded in a Kratos *errors.Error if present. It is -// meant for the pre-stream phase — once SSE bytes have been flushed, -// errors should be reported via an SSE "error" event instead. -// -// The response body is the error's Message (or err.Error() when the -// underlying error is not a Kratos error). Content-Type is plain text; -// callers wanting JSON should install a custom EncodeErrorFunc via the -// ErrorEncoder option. -func DefaultErrorEncoder(w http.ResponseWriter, _ *http.Request, err error) { - se := kerrors.FromError(err) - code := int(se.Code) - if code <= 0 { - code = http.StatusInternalServerError - } - msg := se.Message - if msg == "" { - msg = err.Error() - } - http.Error(w, msg, code) -} - -// codecForContentType picks a Kratos codec by Content-Type, returning -// nil when the type is unrecognized. Parameters (";charset=utf-8") are -// stripped and a vendor prefix on the subtype is removed. -func codecForContentType(ct string) encoding.Codec { - if ct == "" { - return nil - } - if i := strings.IndexByte(ct, ';'); i >= 0 { - ct = ct[:i] - } - ct = strings.TrimSpace(ct) - subtype := ct - if _, after, ok := strings.Cut(ct, "/"); ok { - subtype = after - } - if i := strings.LastIndexByte(subtype, '.'); i >= 0 { - subtype = subtype[i+1:] - } - return encoding.GetCodec(subtype) -} diff --git a/sse/kratos/doc.go b/sse/kratos/doc.go new file mode 100644 index 0000000..9b38ee9 --- /dev/null +++ b/sse/kratos/doc.go @@ -0,0 +1,2 @@ +// Package kratos mounts Server-Sent Events endpoints on a Kratos HTTP server. +package kratos diff --git a/sse/kratos/go.mod b/sse/kratos/go.mod index e446d28..6e48668 100644 --- a/sse/kratos/go.mod +++ b/sse/kratos/go.mod @@ -5,6 +5,7 @@ go 1.26.3 require ( github.com/crypto-zero/go-kit v0.0.0-20260128101518-0545cf5a3fac github.com/go-kratos/kratos/v2 v2.9.2 + google.golang.org/genproto/googleapis/api v0.0.0-20251202230838-ff82c1b0f217 google.golang.org/protobuf v1.36.11 ) diff --git a/sse/kratos/go.sum b/sse/kratos/go.sum index bd79f80..fd55f3e 100644 --- a/sse/kratos/go.sum +++ b/sse/kratos/go.sum @@ -35,6 +35,8 @@ golang.org/x/sys v0.39.0 h1:CvCKL8MeisomCi6qNZ+wbb0DN9E5AATixKsvNtMoMFk= golang.org/x/sys v0.39.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= golang.org/x/text v0.32.0 h1:ZD01bjUt1FQ9WJ0ClOL5vxgxOI/sVCNgX1YtKwcY0mU= golang.org/x/text v0.32.0/go.mod h1:o/rUWzghvpD5TXrTIBuJU77MTaN0ljMWE47kxGJQ7jY= +google.golang.org/genproto/googleapis/api v0.0.0-20251202230838-ff82c1b0f217 h1:fCvbg86sFXwdrl5LgVcTEvNC+2txB5mgROGmRL5mrls= +google.golang.org/genproto/googleapis/api v0.0.0-20251202230838-ff82c1b0f217/go.mod h1:+rXWjjaukWZun3mLfjmVnQi18E1AsFbDN9QdJ5YXLto= google.golang.org/genproto/googleapis/rpc v0.0.0-20251213004720-97cd9d5aeac2 h1:2I6GHUeJ/4shcDpoUlLs/2WPnhg7yJwvXtqcMJt9liA= google.golang.org/genproto/googleapis/rpc v0.0.0-20251213004720-97cd9d5aeac2/go.mod h1:7i2o+ce6H/6BluujYR+kqX3GKH+dChPTQU19wjRPiGk= google.golang.org/grpc v1.77.0 h1:wVVY6/8cGA6vvffn+wWK5ToddbgdU3d8MNENr4evgXM= diff --git a/sse/kratos/handler.go b/sse/kratos/handler.go deleted file mode 100644 index 662e359..0000000 --- a/sse/kratos/handler.go +++ /dev/null @@ -1,138 +0,0 @@ -package kratos - -import ( - "context" - "net/http" - - "github.com/go-kratos/kratos/v2/middleware" - - "github.com/crypto-zero/go-kit/sse" -) - -// StreamHandler builds an http.HandlerFunc that decodes a typed request, -// runs the Kratos middleware chain over (ctx, *Req), then invokes do -// with a live *sse.Stream. -// -// Lifecycle (in order): -// -// 1. Decode the request body into *Req via srv.Decode. Decode errors -// are reported via srv.EncodeError — a standard HTTP 4xx/5xx with -// no SSE bytes written. -// 2. Run the server-wide middleware chain (Middleware option) followed -// by the per-handler extras, with req exposed to middleware as the -// `req` argument. Errors here are likewise reported via -// srv.EncodeError before any streaming starts. -// 3. Create an *sse.Stream and call do. Errors returned by do are -// emitted as an SSE "error" event; do is responsible for any final -// Done frame on success. -// -// This is the right helper for auth/JWT verification, schema validation -// (protovalidate), per-request quota checks and similar pre-handler -// concerns. Tracing, recovery and metrics that must observe the full -// stream lifetime should be installed as Filters instead. -func StreamHandler[Req any]( - srv *Server, - do func(ctx context.Context, req *Req, s *sse.Stream) error, - mws ...middleware.Middleware, -) http.HandlerFunc { - chain := srv.chainFor(mws) - return func(w http.ResponseWriter, r *http.Request) { - req, ok := preStream[Req](srv, chain, w, r) - if !ok { - return - } - s, end := srv.beginStream(r.Context(), w) - defer end() - if err := do(r.Context(), req, s); err != nil { - _ = s.Error(err.Error()) - } - } -} - -// JSONHandler is the unary sibling of StreamHandler: do produces a -// single response value that is marshaled (via srv.Codec) and emitted -// as one SSE data frame followed by a Done terminator. -// -// Errors returned by do are written as an SSE "error" event — matching -// the convention that clients of an SSE endpoint always parse SSE, -// never raw HTTP errors. Decode and middleware errors still go through -// srv.EncodeError (no SSE bytes written yet). -func JSONHandler[Req any, Resp any]( - srv *Server, - do func(ctx context.Context, req *Req) (*Resp, error), - mws ...middleware.Middleware, -) http.HandlerFunc { - chain := srv.chainFor(mws) - return func(w http.ResponseWriter, r *http.Request) { - req, ok := preStream[Req](srv, chain, w, r) - if !ok { - return - } - result, err := do(r.Context(), req) - s, end := srv.beginStream(r.Context(), w) - defer end() - if err != nil { - _ = s.Error(err.Error()) - return - } - data, mErr := srv.Codec().Marshal(result) - if mErr != nil { - _ = s.Error(mErr.Error()) - return - } - _ = s.Write(string(data)) - _ = s.Done() - } -} - -// preStream runs the request through Decode and the middleware chain. -// On success it returns (req, true). On failure it has already written -// an HTTP error response and returns (_, false). -// -// Middleware sees the decoded req via the `req` argument of -// middleware.Handler. The inner handler is intentionally a no-op: -// streaming runs outside the chain so middleware errors translate to -// real HTTP statuses while no SSE bytes have yet hit the wire. -func preStream[Req any]( - srv *Server, chain middleware.Middleware, - w http.ResponseWriter, r *http.Request, -) (*Req, bool) { - req := new(Req) - if err := srv.Decode(r, req); err != nil { - srv.EncodeError(w, r, err) - return nil, false - } - h := chain(func(context.Context, any) (any, error) { return nil, nil }) - if _, err := h(r.Context(), req); err != nil { - srv.EncodeError(w, r, err) - return nil, false - } - return req, true -} - -// beginStream constructs an *sse.Stream, starts the configured heartbeat -// (if any), and bumps the active-stream counter. The returned end -// function tears these down in the inverse order — heartbeat first -// (must stop before the response writer is recycled), then the counter. -// Callers should defer end immediately after this call. -func (s *Server) beginStream(ctx context.Context, w http.ResponseWriter) (*sse.Stream, func()) { - st := sse.NewStream(w) - stopBeat := st.Heartbeat(ctx, s.heartbeat) - s.active.Add(1) - return st, func() { - stopBeat() - s.active.Add(-1) - } -} - -// chainFor composes the middleware chain for one handler: server-wide -// middlewares (outermost) followed by per-handler extras. The returned -// chain does not share backing storage with srv.middlewares, so later -// additions to the server's list cannot retroactively affect handlers -// that have already been built. -func (s *Server) chainFor(extras []middleware.Middleware) middleware.Middleware { - all := make([]middleware.Middleware, 0, len(s.middlewares)+len(extras)) - all = append(all, s.middlewares...) - all = append(all, extras...) - return middleware.Chain(all...) -} diff --git a/sse/kratos/handler_test.go b/sse/kratos/handler_test.go deleted file mode 100644 index bfc4ffc..0000000 --- a/sse/kratos/handler_test.go +++ /dev/null @@ -1,435 +0,0 @@ -package kratos_test - -import ( - "context" - "errors" - "io" - "net/http" - "strings" - "testing" - "time" - - kerrors "github.com/go-kratos/kratos/v2/errors" - "github.com/go-kratos/kratos/v2/middleware" - - "github.com/crypto-zero/go-kit/sse" - ksse "github.com/crypto-zero/go-kit/sse/kratos" -) - -type chatRequest struct { - Prompt string `json:"prompt"` -} - -type profileResponse struct { - Name string `json:"name"` -} - -// authMW is a tiny inline auth middleware that rejects requests missing -// the "X-Token" header by returning a Kratos Unauthorized error. -func authMW(token string) middleware.Middleware { - return func(next middleware.Handler) middleware.Handler { - return func(ctx context.Context, req any) (any, error) { - if got := tokenFromCtx(ctx); got != token { - return nil, kerrors.Unauthorized("AUTH", "bad token") - } - return next(ctx, req) - } - } -} - -type tokenKey struct{} - -func withToken(ctx context.Context, tok string) context.Context { - return context.WithValue(ctx, tokenKey{}, tok) -} - -func tokenFromCtx(ctx context.Context) string { - v, _ := ctx.Value(tokenKey{}).(string) - return v -} - -// tokenFilter copies the X-Token header into ctx so authMW can read it. -// Demonstrates the Filter + Middleware split: Filter touches HTTP-level -// concerns (headers), middleware sees the typed req. -func tokenFilter(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - r = r.WithContext(withToken(r.Context(), r.Header.Get("X-Token"))) - next.ServeHTTP(w, r) - }) -} - -func TestStreamHandler_AuthSucceeds(t *testing.T) { - srv, addr := newServerOnLoopback(t, - ksse.Filter(tokenFilter), - ksse.Middleware(authMW("good")), - ) - srv.HandleFunc("POST /v1/chat", ksse.StreamHandler(srv, - func(_ context.Context, req *chatRequest, s *sse.Stream) error { - _ = s.Write("got:" + req.Prompt) - return s.Done() - }, - )) - - stop := startServer(t, srv) - defer stop() - - req, _ := http.NewRequest("POST", "http://"+addr+"/v1/chat", - strings.NewReader(`{"prompt":"hi"}`)) - req.Header.Set("Content-Type", "application/json") - req.Header.Set("X-Token", "good") - resp, err := http.DefaultClient.Do(req) - if err != nil { - t.Fatalf("POST: %v", err) - } - defer func() { _ = resp.Body.Close() }() - if resp.StatusCode != 200 { - t.Fatalf("status = %d", resp.StatusCode) - } - body := readAll(t, resp.Body) - if !strings.Contains(body, "data: got:hi\n\n") { - t.Errorf("body missing chunk: %q", body) - } - if !strings.Contains(body, "data: [DONE]\n\n") { - t.Errorf("body missing done: %q", body) - } -} - -func TestStreamHandler_AuthRejects(t *testing.T) { - srv, addr := newServerOnLoopback(t, - ksse.Filter(tokenFilter), - ksse.Middleware(authMW("good")), - ) - srv.HandleFunc("POST /v1/chat", ksse.StreamHandler(srv, - func(_ context.Context, _ *chatRequest, s *sse.Stream) error { - t.Error("handler should not run when auth fails") - return s.Done() - }, - )) - - stop := startServer(t, srv) - defer stop() - - req, _ := http.NewRequest("POST", "http://"+addr+"/v1/chat", - strings.NewReader(`{"prompt":"hi"}`)) - req.Header.Set("Content-Type", "application/json") - req.Header.Set("X-Token", "bad") - resp, err := http.DefaultClient.Do(req) - if err != nil { - t.Fatalf("POST: %v", err) - } - defer func() { _ = resp.Body.Close() }() - if resp.StatusCode != http.StatusUnauthorized { - t.Errorf("status = %d, want 401", resp.StatusCode) - } - if got := resp.Header.Get("Content-Type"); strings.HasPrefix(got, "text/event-stream") { - t.Errorf("unexpected SSE response on auth failure: %q", got) - } -} - -func TestStreamHandler_DoErrorBecomesSSEEvent(t *testing.T) { - srv, addr := newServerOnLoopback(t) - sentinel := errors.New("biz blew up") - srv.HandleFunc("POST /v1/chat", ksse.StreamHandler(srv, - func(_ context.Context, _ *chatRequest, _ *sse.Stream) error { - return sentinel - }, - )) - - stop := startServer(t, srv) - defer stop() - - resp, err := http.Post("http://"+addr+"/v1/chat", "application/json", - strings.NewReader(`{}`)) - if err != nil { - t.Fatalf("POST: %v", err) - } - defer func() { _ = resp.Body.Close() }() - if resp.StatusCode != 200 { - t.Errorf("status = %d, want 200 (SSE error rides on 200)", resp.StatusCode) - } - body := readAll(t, resp.Body) - if !strings.Contains(body, "event: error") { - t.Errorf("body missing error event: %q", body) - } - if !strings.Contains(body, "biz blew up") { - t.Errorf("body missing error message: %q", body) - } -} - -func TestStreamHandler_DecodeError(t *testing.T) { - srv, addr := newServerOnLoopback(t) - srv.HandleFunc("POST /v1/chat", ksse.StreamHandler(srv, - func(context.Context, *chatRequest, *sse.Stream) error { - t.Error("handler should not run on decode error") - return nil - }, - )) - - stop := startServer(t, srv) - defer stop() - - // Malformed JSON. - resp, err := http.Post("http://"+addr+"/v1/chat", "application/json", - strings.NewReader(`{not json`)) - if err != nil { - t.Fatalf("POST: %v", err) - } - defer func() { _ = resp.Body.Close() }() - if resp.StatusCode == 200 { - t.Errorf("status = 200, want 4xx/5xx for decode failure") - } -} - -func TestJSONHandler_Roundtrip(t *testing.T) { - srv, addr := newServerOnLoopback(t) - srv.HandleFunc("POST /v1/profile", ksse.JSONHandler(srv, - func(_ context.Context, _ *chatRequest) (*profileResponse, error) { - return &profileResponse{Name: "karma"}, nil - }, - )) - - stop := startServer(t, srv) - defer stop() - - resp, err := http.Post("http://"+addr+"/v1/profile", "application/json", - strings.NewReader(`{}`)) - if err != nil { - t.Fatalf("POST: %v", err) - } - defer func() { _ = resp.Body.Close() }() - body := readAll(t, resp.Body) - if !strings.Contains(body, `data: {"name":"karma"}`) { - t.Errorf("body missing payload: %q", body) - } - if !strings.Contains(body, "data: [DONE]\n\n") { - t.Errorf("body missing done: %q", body) - } -} - -func TestJSONHandler_DoErrorBecomesSSEEvent(t *testing.T) { - srv, addr := newServerOnLoopback(t) - srv.HandleFunc("POST /v1/profile", ksse.JSONHandler(srv, - func(_ context.Context, _ *chatRequest) (*profileResponse, error) { - return nil, kerrors.NotFound("PROFILE", "no profile") - }, - )) - - stop := startServer(t, srv) - defer stop() - - resp, err := http.Post("http://"+addr+"/v1/profile", "application/json", - strings.NewReader(`{}`)) - if err != nil { - t.Fatalf("POST: %v", err) - } - defer func() { _ = resp.Body.Close() }() - if resp.StatusCode != 200 { - t.Errorf("status = %d, want 200", resp.StatusCode) - } - body := readAll(t, resp.Body) - if !strings.Contains(body, "event: error") { - t.Errorf("body missing error event: %q", body) - } -} - -func TestStreamHandler_PerHandlerMiddlewareAppends(t *testing.T) { - calls := make(chan string, 4) - mark := func(name string) middleware.Middleware { - return func(next middleware.Handler) middleware.Handler { - return func(ctx context.Context, req any) (any, error) { - calls <- name - return next(ctx, req) - } - } - } - srv, addr := newServerOnLoopback(t, ksse.Middleware(mark("server"))) - srv.HandleFunc("POST /v1/chat", ksse.StreamHandler(srv, - func(_ context.Context, _ *chatRequest, s *sse.Stream) error { - return s.Done() - }, - mark("handler"), - )) - - stop := startServer(t, srv) - defer stop() - - resp, err := http.Post("http://"+addr+"/v1/chat", "application/json", - strings.NewReader(`{}`)) - if err != nil { - t.Fatalf("POST: %v", err) - } - _ = resp.Body.Close() - close(calls) - var order []string - for s := range calls { - order = append(order, s) - } - want := []string{"server", "handler"} - if strings.Join(order, ",") != strings.Join(want, ",") { - t.Errorf("middleware order = %v, want %v", order, want) - } -} - -func TestStreamHandler_Heartbeat(t *testing.T) { - srv, addr := newServerOnLoopback(t, ksse.Heartbeat(10*time.Millisecond)) - srv.HandleFunc("POST /v1/chat", ksse.StreamHandler(srv, - func(ctx context.Context, _ *chatRequest, s *sse.Stream) error { - _ = s.Write("first") - // Hold the stream open long enough for several heartbeats. - select { - case <-ctx.Done(): - case <-time.After(80 * time.Millisecond): - } - return s.Done() - }, - )) - - stop := startServer(t, srv) - defer stop() - - resp, err := http.Post("http://"+addr+"/v1/chat", "application/json", - strings.NewReader(`{}`)) - if err != nil { - t.Fatalf("POST: %v", err) - } - defer func() { _ = resp.Body.Close() }() - body := readAll(t, resp.Body) - // Expect at least one comment frame between "first" and "[DONE]". - if !strings.Contains(body, ":\n\n") { - t.Errorf("body missing heartbeat comment frames: %q", body) - } -} - -func TestServer_ActiveStreams(t *testing.T) { - srv, addr := newServerOnLoopback(t) - release := make(chan struct{}) - srv.HandleFunc("POST /v1/chat", ksse.StreamHandler(srv, - func(_ context.Context, _ *chatRequest, s *sse.Stream) error { - _ = s.Write("hold") - <-release - return s.Done() - }, - )) - - stop := startServer(t, srv) - - if got := srv.ActiveStreams(); got != 0 { - t.Errorf("ActiveStreams before request = %d, want 0", got) - } - - done := make(chan struct{}) - go func() { - resp, err := http.Post("http://"+addr+"/v1/chat", "application/json", - strings.NewReader(`{}`)) - if err == nil { - // Drain to avoid the response goroutine wedging on close. - _, _ = http.NoBody.Read(make([]byte, 1)) - _ = resp.Body.Close() - } - close(done) - }() - - // Spin until the handler is in flight. - deadline := time.Now().Add(2 * time.Second) - for srv.ActiveStreams() == 0 && time.Now().Before(deadline) { - time.Sleep(2 * time.Millisecond) - } - if got := srv.ActiveStreams(); got != 1 { - t.Errorf("ActiveStreams during request = %d, want 1", got) - } - close(release) - <-done - stop() - - if got := srv.ActiveStreams(); got != 0 { - t.Errorf("ActiveStreams after request = %d, want 0", got) - } -} - -func TestServer_GracefulShutdownUnblocksHandler(t *testing.T) { - srv, addr := newServerOnLoopback(t) - srv.HandleFunc("POST /v1/chat", ksse.StreamHandler(srv, - func(ctx context.Context, _ *chatRequest, s *sse.Stream) error { - _ = s.Write("hi") - // Pump-style: respect ctx so shutdown can drain promptly. - <-ctx.Done() - return ctx.Err() - }, - )) - - ctx, cancel := context.WithCancel(context.Background()) - startErr := make(chan error, 1) - go func() { startErr <- srv.Start(ctx) }() - - // Client holds the connection open: it drains the body to keep the - // server-side r.Context() alive until shutdown explicitly cancels it. - // (If we closed resp.Body eagerly the server would see a client - // disconnect instead and we wouldn't exercise shutdown propagation.) - clientDone := make(chan struct{}) - go func() { - resp, err := http.Post("http://"+addr+"/v1/chat", "application/json", - strings.NewReader(`{}`)) - if err != nil { - close(clientDone) - return - } - _, _ = io.Copy(io.Discard, resp.Body) // returns when server closes - _ = resp.Body.Close() - close(clientDone) - }() - - // Wait for the stream to register as active. - deadline := time.Now().Add(2 * time.Second) - for srv.ActiveStreams() == 0 && time.Now().Before(deadline) { - time.Sleep(2 * time.Millisecond) - } - if got := srv.ActiveStreams(); got != 1 { - t.Fatalf("ActiveStreams before shutdown = %d, want 1", got) - } - - shutdownStart := time.Now() - shutdownCtx, c := context.WithTimeout(context.Background(), 2*time.Second) - defer c() - if err := srv.Stop(shutdownCtx); err != nil { - t.Errorf("Stop: %v", err) - } - elapsed := time.Since(shutdownStart) - if elapsed > 500*time.Millisecond { - t.Errorf("Stop took %v, handler did not unblock on shutdown ctx", elapsed) - } - <-clientDone - if err := <-startErr; err != nil { - t.Errorf("Start: %v", err) - } - cancel() -} - -func TestErrorEncoder_HonorsKratosCode(t *testing.T) { - srv, addr := newServerOnLoopback(t, - ksse.Middleware(func(middleware.Handler) middleware.Handler { - return func(context.Context, any) (any, error) { - return nil, kerrors.BadRequest("VALIDATION", "field x missing") - } - }), - ) - srv.HandleFunc("POST /v1/chat", ksse.StreamHandler(srv, - func(context.Context, *chatRequest, *sse.Stream) error { return nil }, - )) - - stop := startServer(t, srv) - defer stop() - - resp, err := http.Post("http://"+addr+"/v1/chat", "application/json", - strings.NewReader(`{}`)) - if err != nil { - t.Fatalf("POST: %v", err) - } - defer func() { _ = resp.Body.Close() }() - if resp.StatusCode != http.StatusBadRequest { - t.Errorf("status = %d, want 400", resp.StatusCode) - } - body := readAll(t, resp.Body) - if !strings.Contains(body, "field x missing") { - t.Errorf("body missing kratos message: %q", body) - } -} diff --git a/sse/kratos/http_bound_test.go b/sse/kratos/http_bound_test.go deleted file mode 100644 index fea6e54..0000000 --- a/sse/kratos/http_bound_test.go +++ /dev/null @@ -1,55 +0,0 @@ -package kratos_test - -import ( - "context" - "io" - "net/http" - "net/http/httptest" - "strings" - "testing" - - ksse "github.com/crypto-zero/go-kit/sse" - ssekratos "github.com/crypto-zero/go-kit/sse/kratos" - khttp "github.com/go-kratos/kratos/v2/transport/http" -) - -type boundRequest struct { - Name string -} - -func TestRegisterHTTPStreamBoundUsesSuppliedBinder(t *testing.T) { - srv := khttp.NewServer(khttp.Timeout(0)) - ssekratos.RegisterHTTPStreamBound( - srv, - http.MethodGet, - "/v1/bound", - "/test.Bound/Watch", - func(ctx khttp.Context, req *boundRequest) error { - req.Name = ctx.Query().Get("name") - return nil - }, - func(_ context.Context, req *boundRequest, st *ksse.Stream) error { - if err := st.WriteJSON(map[string]string{"name": req.Name}); err != nil { - return err - } - return st.Done() - }, - ) - - ts := httptest.NewServer(srv) - defer ts.Close() - - resp, err := http.Get(ts.URL + "/v1/bound?name=generated") - if err != nil { - t.Fatalf("GET: %v", err) - } - defer resp.Body.Close() - - body, err := io.ReadAll(resp.Body) - if err != nil { - t.Fatalf("read body: %v", err) - } - if !strings.Contains(string(body), `data: {"name":"generated"}`) { - t.Fatalf("bound response missing generated name:\n%s", string(body)) - } -} diff --git a/sse/kratos/http_cors.go b/sse/kratos/http_cors.go deleted file mode 100644 index 07db144..0000000 --- a/sse/kratos/http_cors.go +++ /dev/null @@ -1,26 +0,0 @@ -package kratos - -import ( - "net/http" - - khttp "github.com/go-kratos/kratos/v2/transport/http" -) - -// HTTPPermissiveCORS returns a route filter suitable for browser EventSource -// endpoints that do not use credentials. -func HTTPPermissiveCORS() khttp.FilterFunc { - return func(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - h := w.Header() - h.Set("Access-Control-Allow-Origin", "*") - h.Set("Access-Control-Allow-Methods", "GET, OPTIONS") - h.Set("Access-Control-Allow-Headers", "Cache-Control, Last-Event-ID") - h.Set("Access-Control-Expose-Headers", "Content-Type") - if r.Method == http.MethodOptions { - w.WriteHeader(http.StatusNoContent) - return - } - next.ServeHTTP(w, r) - }) - } -} diff --git a/sse/kratos/http_cors_test.go b/sse/kratos/http_cors_test.go deleted file mode 100644 index 970bcc0..0000000 --- a/sse/kratos/http_cors_test.go +++ /dev/null @@ -1,45 +0,0 @@ -package kratos - -import ( - "net/http" - "net/http/httptest" - "testing" -) - -func TestHTTPPermissiveCORSHandlesPreflight(t *testing.T) { - nextCalled := false - handler := HTTPPermissiveCORS()(http.HandlerFunc(func(http.ResponseWriter, *http.Request) { - nextCalled = true - })) - - req := httptest.NewRequest(http.MethodOptions, "/events", nil) - rec := httptest.NewRecorder() - handler.ServeHTTP(rec, req) - - if rec.Code != http.StatusNoContent { - t.Fatalf("status = %d, want %d", rec.Code, http.StatusNoContent) - } - if nextCalled { - t.Fatal("next handler should not be called for preflight") - } - if got := rec.Header().Get("Access-Control-Allow-Origin"); got != "*" { - t.Fatalf("allow origin = %q, want *", got) - } -} - -func TestHTTPPermissiveCORSPassesThroughGET(t *testing.T) { - handler := HTTPPermissiveCORS()(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { - w.WriteHeader(http.StatusAccepted) - })) - - req := httptest.NewRequest(http.MethodGet, "/events", nil) - rec := httptest.NewRecorder() - handler.ServeHTTP(rec, req) - - if rec.Code != http.StatusAccepted { - t.Fatalf("status = %d, want %d", rec.Code, http.StatusAccepted) - } - if got := rec.Header().Get("Access-Control-Allow-Headers"); got != "Cache-Control, Last-Event-ID" { - t.Fatalf("allow headers = %q", got) - } -} diff --git a/sse/kratos/http_handler.go b/sse/kratos/http_handler.go index 4543136..0760e13 100644 --- a/sse/kratos/http_handler.go +++ b/sse/kratos/http_handler.go @@ -5,9 +5,7 @@ import ( "net/http" "time" - authkratos "github.com/crypto-zero/go-kit/auth/kratos" khttp "github.com/go-kratos/kratos/v2/transport/http" - "google.golang.org/protobuf/reflect/protoreflect" "github.com/crypto-zero/go-kit/sse" ) @@ -47,13 +45,10 @@ func RegisterHTTPStream[Req any]( do func(ctx context.Context, req *Req, st *sse.Stream) error, opts ...HTTPStreamOption, ) { - RegisterHTTPStreamBound(srv, method, path, operation, bindHTTPStreamRequest[Req], do, opts...) + registerHTTPStream(srv, method, path, operation, bindHTTPStreamRequest[Req], do, opts...) } -// RegisterHTTPStreamBound mounts an SSE endpoint with caller-supplied request -// binding. Generated handlers use this to keep protobuf HTTP binding code in -// generated files while reusing the shared streaming lifecycle. -func RegisterHTTPStreamBound[Req any]( +func registerHTTPStream[Req any]( srv *khttp.Server, method string, path string, @@ -90,21 +85,6 @@ func RegisterHTTPStreamBound[Req any]( }, cfg.filters...) } -// RegisterHTTPStreamMethod mounts an SSE endpoint for a proto method -// descriptor. The Kratos operation is derived from the method name -// (`/package.Service/Method`) so auth selectors, logging and tracing use the -// same operation identity as generated Kratos HTTP handlers. -func RegisterHTTPStreamMethod[Req any]( - srv *khttp.Server, - method protoreflect.MethodDescriptor, - httpMethod string, - path string, - do func(ctx context.Context, req *Req, st *sse.Stream) error, - opts ...HTTPStreamOption, -) { - RegisterHTTPStream(srv, httpMethod, path, authkratos.OperationName(method), do, opts...) -} - func bindHTTPStreamRequest[Req any](ctx khttp.Context, target *Req) error { switch ctx.Request().Method { case http.MethodGet, http.MethodHead: diff --git a/sse/kratos/http_handler_test.go b/sse/kratos/http_handler_test.go index 8bea158..bddeaaa 100644 --- a/sse/kratos/http_handler_test.go +++ b/sse/kratos/http_handler_test.go @@ -1,6 +1,7 @@ package kratos_test import ( + "bufio" "context" "net/http" "net/http/httptest" @@ -11,10 +12,6 @@ import ( "github.com/go-kratos/kratos/v2/middleware" ktransport "github.com/go-kratos/kratos/v2/transport" khttp "github.com/go-kratos/kratos/v2/transport/http" - "google.golang.org/protobuf/proto" - "google.golang.org/protobuf/reflect/protodesc" - "google.golang.org/protobuf/reflect/protoreflect" - "google.golang.org/protobuf/types/descriptorpb" "google.golang.org/protobuf/types/known/durationpb" "github.com/crypto-zero/go-kit/sse" @@ -122,69 +119,17 @@ func TestHTTPStreamHandler_DetachesKratosHTTPTimeout(t *testing.T) { } } -func TestHTTPStreamHandler_MethodDescriptorSetsOperation(t *testing.T) { - const operation = "/test.live.v1.LiveService/Watch" - seen := make(chan string, 1) - srv := khttp.NewServer( - khttp.Timeout(0), - khttp.Middleware(func(next middleware.Handler) middleware.Handler { - return func(ctx context.Context, req any) (any, error) { - tr, ok := ktransport.FromServerContext(ctx) - if !ok { - t.Fatal("missing transport") - } - seen <- tr.Operation() - return next(ctx, req) - } - }), - ) - method := testMethodDescriptor(t) - ksse.RegisterHTTPStreamMethod(srv, method, http.MethodGet, "/v1/method", - func(_ context.Context, _ *durationpb.Duration, st *sse.Stream) error { - return st.Done() - }, - ) - - ts := httptest.NewServer(srv) - defer ts.Close() - - resp, err := http.Get(ts.URL + "/v1/method") - if err != nil { - t.Fatalf("GET: %v", err) - } - _ = resp.Body.Close() - - select { - case got := <-seen: - if got != operation { - t.Errorf("operation = %q, want %q", got, operation) - } - case <-time.After(time.Second): - t.Fatal("middleware did not run") - } -} - -func testMethodDescriptor(t *testing.T) protoreflect.MethodDescriptor { +func readAll(t *testing.T, r interface { + Read([]byte) (int, error) +}) string { t.Helper() - fd, err := protodesc.NewFile(&descriptorpb.FileDescriptorProto{ - Syntax: proto.String("proto3"), - Name: proto.String("test/live/v1/live.proto"), - Package: proto.String("test.live.v1"), - Service: []*descriptorpb.ServiceDescriptorProto{{ - Name: proto.String("LiveService"), - Method: []*descriptorpb.MethodDescriptorProto{{ - Name: proto.String("Watch"), - InputType: proto.String(".test.live.v1.WatchRequest"), - OutputType: proto.String(".test.live.v1.WatchResponse"), - }}, - }}, - MessageType: []*descriptorpb.DescriptorProto{ - {Name: proto.String("WatchRequest")}, - {Name: proto.String("WatchResponse")}, - }, - }, nil) - if err != nil { - t.Fatalf("NewFile: %v", err) + var sb strings.Builder + br := bufio.NewReader(r) + for { + line, err := br.ReadString('\n') + sb.WriteString(line) + if err != nil { + return sb.String() + } } - return fd.Services().ByName("LiveService").Methods().ByName("Watch") } diff --git a/sse/kratos/protoc-gen-go-sse/main.go b/sse/kratos/protoc-gen-go-sse/main.go index c589efd..d5bce67 100644 --- a/sse/kratos/protoc-gen-go-sse/main.go +++ b/sse/kratos/protoc-gen-go-sse/main.go @@ -3,11 +3,11 @@ package main import ( "flag" "fmt" - "net/http" "strconv" "strings" ssev1 "github.com/crypto-zero/go-kit/proto/kit/sse/v1" + "google.golang.org/genproto/googleapis/api/annotations" "google.golang.org/protobuf/compiler/protogen" "google.golang.org/protobuf/proto" "google.golang.org/protobuf/types/pluginpb" @@ -38,8 +38,12 @@ func main() { type sseMethod struct { service *protogen.Service method *protogen.Method - verb string - path string + routes []sseRoute +} + +type sseRoute struct { + verb string + path string } func generateFile(plugin *protogen.Plugin, file *protogen.File) { @@ -76,26 +80,58 @@ func collectSSEMethods(file *protogen.File) []sseMethod { var out []sseMethod for _, service := range file.Services { for _, method := range service.Methods { - rule, ok := proto.GetExtension(method.Desc.Options(), ssev1.E_ServerSentEvent).(*ssev1.StreamRule) + opts := method.Desc.Options() + if !proto.HasExtension(opts, ssev1.E_ServerSentEvent) { + continue + } + enabled, ok := proto.GetExtension(opts, ssev1.E_ServerSentEvent).(bool) + if !ok || !enabled { + continue + } + if !proto.HasExtension(opts, annotations.E_Http) { + continue + } + rule, ok := proto.GetExtension(opts, annotations.E_Http).(*annotations.HttpRule) if !ok || rule == nil { continue } - verb, path := streamRuleHTTP(rule) - if verb == "" || path == "" { + var routes []sseRoute + for _, binding := range httpRuleBindings(rule) { + verb, path := httpRuleHTTP(binding) + if verb == "" || path == "" { + continue + } + routes = append(routes, sseRoute{verb: verb, path: path}) + } + if len(routes) == 0 { continue } - out = append(out, sseMethod{service: service, method: method, verb: verb, path: path}) + out = append(out, sseMethod{service: service, method: method, routes: routes}) } } return out } -func streamRuleHTTP(rule *ssev1.StreamRule) (string, string) { +func httpRuleBindings(rule *annotations.HttpRule) []*annotations.HttpRule { + out := []*annotations.HttpRule{rule} + out = append(out, rule.GetAdditionalBindings()...) + return out +} + +func httpRuleHTTP(rule *annotations.HttpRule) (string, string) { switch pattern := rule.GetPattern().(type) { - case *ssev1.StreamRule_Get: - return http.MethodGet, pattern.Get - case *ssev1.StreamRule_Post: - return http.MethodPost, pattern.Post + case *annotations.HttpRule_Get: + return "GET", pattern.Get + case *annotations.HttpRule_Put: + return "PUT", pattern.Put + case *annotations.HttpRule_Post: + return "POST", pattern.Post + case *annotations.HttpRule_Delete: + return "DELETE", pattern.Delete + case *annotations.HttpRule_Patch: + return "PATCH", pattern.Patch + case *annotations.HttpRule_Custom: + return pattern.Custom.Kind, pattern.Custom.Path default: return "", "" } @@ -130,28 +166,18 @@ func genService(g *protogen.GeneratedFile, service *protogen.Service, methods [] func genMethod(g *protogen.GeneratedFile, item sseMethod) { serviceName := item.service.GoName methodName := item.method.GoName - input := item.method.Input.GoIdent operationConst := "Operation" + serviceName + methodName + "SSE" serverIdent := g.QualifiedGoIdent(khttpPackage.Ident("Server")) - contextIdent := g.QualifiedGoIdent(khttpPackage.Ident("Context")) optionIdent := g.QualifiedGoIdent(kratosPackage.Ident("HTTPStreamOption")) - registerIdent := g.QualifiedGoIdent(kratosPackage.Ident("RegisterHTTPStreamBound")) + registerIdent := g.QualifiedGoIdent(kratosPackage.Ident("RegisterHTTPStream")) g.P("const ", operationConst, ` = "/`, item.method.Desc.Parent().FullName(), `/`, item.method.Desc.Name(), `"`) g.P() g.P("func _", serviceName, "_", methodName, "_SSE_Register(s *", serverIdent, ", srv ", serviceName, "SSEServer, opts ...", optionIdent, ") {") - g.P(registerIdent, "(s, ", strconv.Quote(item.verb), ", ", strconv.Quote(item.path), ", ", operationConst, ",") - g.P("func(ctx ", contextIdent, ", in *", input, ") error {") - if item.verb == http.MethodGet { - g.P("return ctx.BindQuery(in)") - } else { - g.P("return ctx.Bind(in)") + for _, route := range item.routes { + g.P(registerIdent, "(s, ", strconv.Quote(route.verb), ", ", strconv.Quote(route.path), ", ", operationConst, ", srv.", methodName, ", opts...)") } - g.P("},") - g.P("srv.", methodName, ",") - g.P("opts...,") - g.P(")") g.P("}") g.P() } diff --git a/sse/kratos/protoc-gen-go-sse/main_test.go b/sse/kratos/protoc-gen-go-sse/main_test.go new file mode 100644 index 0000000..2a333ca --- /dev/null +++ b/sse/kratos/protoc-gen-go-sse/main_test.go @@ -0,0 +1,90 @@ +package main + +import ( + "strings" + "testing" + + ssev1 "github.com/crypto-zero/go-kit/proto/kit/sse/v1" + "google.golang.org/genproto/googleapis/api/annotations" + "google.golang.org/protobuf/compiler/protogen" + "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/types/descriptorpb" + "google.golang.org/protobuf/types/pluginpb" +) + +func TestGenerateFileUsesDefaultHTTPStreamBinding(t *testing.T) { + req := testCodeGeneratorRequest(t) + plugin, err := protogen.Options{}.New(req) + if err != nil { + t.Fatalf("protogen.New: %v", err) + } + + for _, file := range plugin.Files { + if file.Generate { + generateFile(plugin, file) + } + } + + resp := plugin.Response() + if len(resp.File) != 1 { + t.Fatalf("generated files = %d, want 1", len(resp.File)) + } + got := resp.File[0].GetContent() + if strings.Contains(got, "RegisterHTTPStreamBound") { + t.Fatalf("generated code uses bound registration:\n%s", got) + } + if count := strings.Count(got, "const OperationLiveServiceWatchSSE"); count != 1 { + t.Fatalf("operation const count = %d, want 1:\n%s", count, got) + } + if count := strings.Count(got, "func _LiveService_Watch_SSE_Register"); count != 1 { + t.Fatalf("register function count = %d, want 1:\n%s", count, got) + } + for _, want := range []string{ + `RegisterHTTPStream(s, "GET", "/v1/watch", OperationLiveServiceWatchSSE, srv.Watch, opts...)`, + `RegisterHTTPStream(s, "POST", "/v1/watch:tail", OperationLiveServiceWatchSSE, srv.Watch, opts...)`, + } { + if !strings.Contains(got, want) { + t.Fatalf("generated code missing %q:\n%s", want, got) + } + } +} + +func testCodeGeneratorRequest(t *testing.T) *pluginpb.CodeGeneratorRequest { + t.Helper() + + opts := &descriptorpb.MethodOptions{} + proto.SetExtension(opts, ssev1.E_ServerSentEvent, true) + proto.SetExtension(opts, annotations.E_Http, &annotations.HttpRule{ + Pattern: &annotations.HttpRule_Get{Get: "/v1/watch"}, + AdditionalBindings: []*annotations.HttpRule{{ + Pattern: &annotations.HttpRule_Post{Post: "/v1/watch:tail"}, + Body: "*", + }}, + }) + + return &pluginpb.CodeGeneratorRequest{ + FileToGenerate: []string{"test/v1/live.proto"}, + ProtoFile: []*descriptorpb.FileDescriptorProto{{ + Syntax: proto.String("proto3"), + Name: proto.String("test/v1/live.proto"), + Package: proto.String("test.v1"), + Options: &descriptorpb.FileOptions{ + GoPackage: proto.String("example.com/test/v1;testv1"), + }, + MessageType: []*descriptorpb.DescriptorProto{{ + Name: proto.String("WatchRequest"), + }, { + Name: proto.String("WatchResponse"), + }}, + Service: []*descriptorpb.ServiceDescriptorProto{{ + Name: proto.String("LiveService"), + Method: []*descriptorpb.MethodDescriptorProto{{ + Name: proto.String("Watch"), + InputType: proto.String(".test.v1.WatchRequest"), + OutputType: proto.String(".test.v1.WatchResponse"), + Options: opts, + }}, + }}, + }}, + } +} diff --git a/sse/kratos/server.go b/sse/kratos/server.go deleted file mode 100644 index 8b16bc2..0000000 --- a/sse/kratos/server.go +++ /dev/null @@ -1,277 +0,0 @@ -// Package kratos provides a Server-Sent Events transport server that plugs -// into a Kratos application as a first-class transport.Server. -// -// The server owns its own net.Listener and http.ServeMux. Requests it -// serves carry a transport.Transporter with Kind="sse" in their context, -// so Kratos middleware can inspect and act on SSE traffic the same way it -// does for HTTP and gRPC. -// -// Streaming itself is handled by the parent github.com/crypto-zero/go-kit/sse -// package: handlers construct an *sse.Stream from the ResponseWriter and -// write events through its API. This package contributes the Kratos plumbing -// (lifecycle, endpoint registration, codec selection, filter chain) around -// that core. -package kratos - -import ( - "context" - "crypto/tls" - "errors" - "log/slog" - "net" - "net/http" - "net/url" - "sync/atomic" - "time" - - "github.com/go-kratos/kratos/v2/encoding" - // Register the JSON codec by default; users can pull additional - // codecs (proto, yaml, xml) by importing them at their main package. - _ "github.com/go-kratos/kratos/v2/encoding/json" - "github.com/go-kratos/kratos/v2/middleware" - ktransport "github.com/go-kratos/kratos/v2/transport" -) - -// KindSSE identifies this transport in the Kratos transport registry. -const KindSSE ktransport.Kind = "sse" - -// DefaultReadHeaderTimeout is applied when no ReadHeaderTimeout option is -// given. It protects the server from Slowloris-style attacks (clients that -// trickle request headers to hold connections open) without affecting the -// streaming response — write deadlines are managed separately. -const DefaultReadHeaderTimeout = 10 * time.Second - -var ( - _ ktransport.Server = (*Server)(nil) - _ ktransport.Endpointer = (*Server)(nil) - _ http.Handler = (*Server)(nil) -) - -// FilterFunc wraps an http.Handler. Filters compose right-to-left around -// the request: the first filter is the outermost wrapper. -type FilterFunc func(http.Handler) http.Handler - -// FilterChain composes filters into a single wrapper. -func FilterChain(filters ...FilterFunc) FilterFunc { - return func(next http.Handler) http.Handler { - for i := len(filters) - 1; i >= 0; i-- { - next = filters[i](next) - } - return next - } -} - -// DecodeRequestFunc decodes an inbound request body into v. -type DecodeRequestFunc func(*http.Request, any) error - -// EncodeErrorFunc reports an error to the client. For SSE handlers the -// default writes an SSE "error" event when headers have not yet been sent, -// otherwise falls through to http.Error. -type EncodeErrorFunc func(http.ResponseWriter, *http.Request, error) - -// Server is a Kratos transport.Server that serves Server-Sent Events. -type Server struct { - *http.Server - - lis net.Listener - tlsConf *tls.Config - endpoint *url.URL - - network string - address string - - mux *http.ServeMux - codec encoding.Codec - logger *slog.Logger - filters []FilterFunc - middlewares []middleware.Middleware - decBody DecodeRequestFunc - errEnc EncodeErrorFunc - patterns []string - readHeaderTimeout time.Duration - heartbeat time.Duration - - // shutdownCtx is the parent context handed to every request via - // http.Server.BaseContext. Stop cancels it before calling Shutdown so - // long-running SSE handlers observe ctx.Done() and can drain cleanly - // instead of blocking the shutdown. - shutdownCtx context.Context - shutdownCancel context.CancelFunc - - // active counts live SSE streams managed by StreamHandler / - // JSONHandler. Exposed via ActiveStreams. - active atomic.Int64 -} - -// NewServer constructs a Server. With no options it listens on a random -// TCP port, uses the JSON codec, and serves plaintext HTTP. -func NewServer(opts ...ServerOption) *Server { - s := &Server{ - network: "tcp", - address: ":0", - mux: http.NewServeMux(), - codec: encoding.GetCodec("json"), - logger: slog.Default(), - decBody: DefaultRequestDecoder, - errEnc: DefaultErrorEncoder, - readHeaderTimeout: DefaultReadHeaderTimeout, - } - for _, o := range opts { - o(s) - } - s.Server = &http.Server{ - Handler: FilterChain(s.filters...)(http.HandlerFunc(s.dispatch)), - TLSConfig: s.tlsConf, - ReadHeaderTimeout: s.readHeaderTimeout, - } - return s -} - -// Name returns the transport kind, "sse". -func (s *Server) Name() string { return string(KindSSE) } - -// Endpoint returns the address the server is (or will be) listening on, -// scheme "sse://", suitable for service-registry advertisement. -func (s *Server) Endpoint() (*url.URL, error) { - if err := s.listenAndEndpoint(); err != nil { - return nil, err - } - return s.endpoint, nil -} - -// Codec returns the codec configured for request/response payloads. -func (s *Server) Codec() encoding.Codec { return s.codec } - -// Start opens the listener (if not already) and serves until Stop is -// called or the listener fails. It implements transport.Server. -func (s *Server) Start(ctx context.Context) error { - if err := s.listenAndEndpoint(); err != nil { - return err - } - // Build a cancellable child context that Stop will tear down before - // http.Server.Shutdown runs. Handlers receive this ctx via - // r.Context(), so a Pump select-on-Done unblocks promptly during - // shutdown rather than holding the connection until its own write - // deadline expires. - s.shutdownCtx, s.shutdownCancel = context.WithCancel(ctx) - s.BaseContext = func(net.Listener) context.Context { return s.shutdownCtx } - s.logger.InfoContext(ctx, "sse server listening", "addr", s.lis.Addr().String()) - - var err error - if s.tlsConf != nil { - err = s.ServeTLS(s.lis, "", "") - } else { - err = s.Serve(s.lis) - } - if err != nil && !errors.Is(err, http.ErrServerClosed) { - return err - } - return nil -} - -// Stop gracefully shuts the server down. It first cancels the -// shutdown-aware context that every handler receives — long-running SSE -// streams that observe ctx.Done() will exit promptly — then calls -// http.Server.Shutdown to wait for in-flight requests to complete. If -// ctx expires before drainage finishes, Stop force-closes connections, -// matching the Kratos HTTP server's behavior. -func (s *Server) Stop(ctx context.Context) error { - s.logger.InfoContext(ctx, "sse server stopping", - "active_streams", s.ActiveStreams()) - if s.shutdownCancel != nil { - s.shutdownCancel() - } - if err := s.Shutdown(ctx); err != nil { - if ctx.Err() != nil { - s.logger.WarnContext(ctx, "sse server force-closing after shutdown timeout") - return s.Close() - } - return err - } - return nil -} - -// ActiveStreams reports the number of SSE streams currently in flight -// through StreamHandler or JSONHandler. Handlers mounted via plain -// Handle / HandleFunc are not counted. -func (s *Server) ActiveStreams() int64 { return s.active.Load() } - -// Handle mounts an http.Handler at pattern. Pattern syntax follows -// net/http.ServeMux (Go 1.22+ "METHOD /path/{var}" form). -func (s *Server) Handle(pattern string, h http.Handler) { - s.mux.Handle(pattern, h) - s.patterns = append(s.patterns, pattern) -} - -// HandleFunc mounts an http.HandlerFunc at pattern. -func (s *Server) HandleFunc(pattern string, h http.HandlerFunc) { - s.Handle(pattern, h) -} - -// WalkPattern visits every pattern registered with Handle/HandleFunc. -// The order matches registration order. -func (s *Server) WalkPattern(fn func(pattern string)) { - for _, p := range s.patterns { - fn(p) - } -} - -// Decode reads r.Body and unmarshals it via the configured request -// decoder (set with RequestDecoder; defaults to DefaultRequestDecoder). -func (s *Server) Decode(r *http.Request, v any) error { - return s.decBody(r, v) -} - -// EncodeError reports err to the client via the configured error -// encoder. -func (s *Server) EncodeError(w http.ResponseWriter, r *http.Request, err error) { - s.errEnc(w, r, err) -} - -// ServeHTTP runs the filter chain around the routing dispatch. -func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { - s.Handler.ServeHTTP(w, r) -} - -// dispatch is the innermost handler invoked after filters run. It installs -// the SSE transport.Transporter into the request context and dispatches -// the request via the mux. -func (s *Server) dispatch(w http.ResponseWriter, r *http.Request) { - // Resolve the matched pattern (e.g. "/v1/sse/chat:stream") so - // middleware that inspects Operation sees the route template rather - // than the request-specific path. - _, pattern := s.mux.Handler(r) - - tr := &Transport{ - endpoint: s.endpointString(), - operation: pattern, - pathTemplate: pattern, - request: r, - response: w, - reqHeader: headerCarrier(r.Header), - replyHeader: headerCarrier(w.Header()), - } - r = r.WithContext(ktransport.NewServerContext(r.Context(), tr)) - s.mux.ServeHTTP(w, r) -} - -func (s *Server) endpointString() string { - if s.endpoint == nil { - return "" - } - return s.endpoint.String() -} - -func (s *Server) listenAndEndpoint() error { - if s.lis == nil { - lis, err := net.Listen(s.network, s.address) - if err != nil { - return err - } - s.lis = lis - } - if s.endpoint == nil { - s.endpoint = &url.URL{Scheme: string(KindSSE), Host: s.lis.Addr().String()} - } - return nil -} diff --git a/sse/kratos/server_options.go b/sse/kratos/server_options.go deleted file mode 100644 index c5b04f3..0000000 --- a/sse/kratos/server_options.go +++ /dev/null @@ -1,131 +0,0 @@ -package kratos - -import ( - "crypto/tls" - "log/slog" - "net" - "net/url" - "time" - - "github.com/go-kratos/kratos/v2/encoding" - "github.com/go-kratos/kratos/v2/middleware" -) - -// ServerOption configures a Server at construction time. -type ServerOption func(*Server) - -// Network sets the listener network (default "tcp"). -func Network(network string) ServerOption { - return func(s *Server) { s.network = network } -} - -// Address sets the listener address (default ":0", random port). -func Address(addr string) ServerOption { - return func(s *Server) { s.address = addr } -} - -// Listener supplies a pre-built listener, bypassing Network/Address. -func Listener(lis net.Listener) ServerOption { - return func(s *Server) { s.lis = lis } -} - -// Endpoint overrides the URL advertised via Endpoint(). Use this when the -// listener address (":8080") isn't the externally routable address. -func Endpoint(u *url.URL) ServerOption { - return func(s *Server) { s.endpoint = u } -} - -// TLSConfig configures TLS for the server. When set, Start serves via -// ServeTLS using the certificate(s) from cfg. -func TLSConfig(c *tls.Config) ServerOption { - return func(s *Server) { s.tlsConf = c } -} - -// Codec sets the default codec used by Decode when the request's -// Content-Type does not specify a recognized one. Pass a Kratos codec -// name such as "json" or "proto"; the codec must be registered -// (importing "github.com/go-kratos/kratos/v2/encoding/proto" suffices -// for proto). -// -// Passing a name with no registered codec is a programming error and -// panics — the misconfiguration would otherwise surface as an obscure -// runtime decode failure. -func Codec(name string) ServerOption { - c := encoding.GetCodec(name) - if c == nil { - panic("sse/kratos: codec not registered: " + name) - } - return func(s *Server) { s.codec = c } -} - -// Logger replaces the default slog.Logger (slog.Default()). A nil logger -// is ignored. -func Logger(l *slog.Logger) ServerOption { - return func(s *Server) { - if l != nil { - s.logger = l - } - } -} - -// RequestDecoder replaces the request body decoder used by Decode. -func RequestDecoder(dec DecodeRequestFunc) ServerOption { - return func(s *Server) { - if dec != nil { - s.decBody = dec - } - } -} - -// ErrorEncoder replaces the error encoder used by EncodeError. -func ErrorEncoder(enc EncodeErrorFunc) ServerOption { - return func(s *Server) { - if enc != nil { - s.errEnc = enc - } - } -} - -// Filter prepends HTTP middleware to the request pipeline. Filters run -// before the route is dispatched, so they may short-circuit auth, set -// cross-cutting headers, or wrap response writing. -func Filter(filters ...FilterFunc) ServerOption { - return func(s *Server) { s.filters = append(s.filters, filters...) } -} - -// Middleware appends Kratos middleware to the chain executed by -// StreamHandler and JSONHandler — between request decoding and stream -// start. Use this for auth, JWT/token verification, schema validation -// and other pre-handler concerns. -// -// These middlewares do NOT wrap the streaming portion of the response; -// any middleware that needs to observe the whole request (tracing, -// metrics, recovery) should be installed as a Filter instead. See the -// package doc for the full rationale. -func Middleware(mws ...middleware.Middleware) ServerOption { - return func(s *Server) { s.middlewares = append(s.middlewares, mws...) } -} - -// ReadHeaderTimeout overrides the time bound on reading request headers -// (default DefaultReadHeaderTimeout). Set to 0 to disable, accepting -// Slowloris risk. -// -// Note: WriteTimeout / IdleTimeout are intentionally not exposed here; -// SSE streams are long-lived and a server-wide write deadline would kill -// them. Per-handler deadlines should be set via http.ResponseController -// inside the handler instead. -func ReadHeaderTimeout(d time.Duration) ServerOption { - return func(s *Server) { s.readHeaderTimeout = d } -} - -// Heartbeat enables automatic SSE comment frames (": \n\n") at the given -// interval on every stream built via StreamHandler or JSONHandler. The -// keepalive is invisible to clients (comments are spec-defined to be -// ignored) but defeats idle-connection timers in upstream proxies -// (nginx, ALB, CloudFlare). -// -// Set to 0 (the default) to disable. Recommended value: 15s — under -// the typical 30-60s proxy idle timeout. -func Heartbeat(interval time.Duration) ServerOption { - return func(s *Server) { s.heartbeat = interval } -} diff --git a/sse/kratos/server_test.go b/sse/kratos/server_test.go deleted file mode 100644 index 5009d89..0000000 --- a/sse/kratos/server_test.go +++ /dev/null @@ -1,321 +0,0 @@ -package kratos_test - -import ( - "bufio" - "context" - "crypto/tls" - "errors" - "net" - "net/http" - "net/url" - "strings" - "testing" - "time" - - ktransport "github.com/go-kratos/kratos/v2/transport" - - "github.com/crypto-zero/go-kit/sse" - ksse "github.com/crypto-zero/go-kit/sse/kratos" -) - -func newServerOnLoopback(t *testing.T, opts ...ksse.ServerOption) (*ksse.Server, string) { - t.Helper() - lis, err := net.Listen("tcp", "127.0.0.1:0") - if err != nil { - t.Fatalf("listen: %v", err) - } - opts = append([]ksse.ServerOption{ksse.Listener(lis)}, opts...) - srv := ksse.NewServer(opts...) - return srv, lis.Addr().String() -} - -func startServer(t *testing.T, srv *ksse.Server) func() { - t.Helper() - ctx, cancel := context.WithCancel(context.Background()) - startErr := make(chan error, 1) - go func() { startErr <- srv.Start(ctx) }() - return func() { - shutdownCtx, c := context.WithTimeout(context.Background(), time.Second) - defer c() - if err := srv.Stop(shutdownCtx); err != nil { - t.Errorf("Stop: %v", err) - } - if err := <-startErr; err != nil { - t.Errorf("Start: %v", err) - } - cancel() - } -} - -func TestServer_Name(t *testing.T) { - srv := ksse.NewServer() - if got, want := srv.Name(), string(ksse.KindSSE); got != want { - t.Errorf("Name() = %q, want %q", got, want) - } -} - -func TestServer_Endpoint(t *testing.T) { - srv, addr := newServerOnLoopback(t) - u, err := srv.Endpoint() - if err != nil { - t.Fatalf("Endpoint: %v", err) - } - if u.Scheme != string(ksse.KindSSE) { - t.Errorf("scheme = %q, want %q", u.Scheme, ksse.KindSSE) - } - if u.Host != addr { - t.Errorf("host = %q, want %q", u.Host, addr) - } -} - -func TestServer_EndpointOverride(t *testing.T) { - override := &url.URL{Scheme: "sse", Host: "api.example.com:443"} - srv := ksse.NewServer(ksse.Endpoint(override)) - got, err := srv.Endpoint() - if err != nil { - t.Fatalf("Endpoint: %v", err) - } - if got.String() != override.String() { - t.Errorf("Endpoint = %q, want %q", got, override) - } -} - -func TestServer_StartAndStreamsEvents(t *testing.T) { - srv, addr := newServerOnLoopback(t) - srv.HandleFunc("/v1/stream", func(w http.ResponseWriter, _ *http.Request) { - s := sse.NewStream(w) - _ = s.Write("hello") - _ = s.Write("world") - _ = s.Done() - }) - - stop := startServer(t, srv) - defer stop() - - resp, err := http.Get("http://" + addr + "/v1/stream") - if err != nil { - t.Fatalf("GET: %v", err) - } - defer func() { _ = resp.Body.Close() }() - if got, want := resp.Header.Get("Content-Type"), "text/event-stream"; got != want { - t.Errorf("Content-Type = %q, want %q", got, want) - } - - body := readAll(t, resp.Body) - for _, want := range []string{"data: hello\n\n", "data: world\n\n", "data: [DONE]\n\n"} { - if !strings.Contains(body, want) { - t.Errorf("body missing %q\nfull body: %q", want, body) - } - } -} - -func TestServer_TransporterPathTemplate(t *testing.T) { - srv, addr := newServerOnLoopback(t) - got := make(chan struct { - kind ktransport.Kind - op string - ep string - pathTemplate string - hasResponse bool - }, 1) - srv.HandleFunc("/v1/items/{id}", func(w http.ResponseWriter, r *http.Request) { - tr, ok := ktransport.FromServerContext(r.Context()) - if !ok { - t.Errorf("no transport in context") - return - } - var pt string - if ptr, ok := tr.(ksse.Transporter); ok { - pt = ptr.PathTemplate() - } - var hasResp bool - if _, ok := tr.(ksse.ResponseTransporter); ok { - hasResp = true - } - got <- struct { - kind ktransport.Kind - op string - ep string - pathTemplate string - hasResponse bool - }{tr.Kind(), tr.Operation(), tr.Endpoint(), pt, hasResp} - _ = sse.NewStream(w).Done() - }) - - stop := startServer(t, srv) - defer stop() - - resp, err := http.Get("http://" + addr + "/v1/items/42") - if err != nil { - t.Fatalf("GET: %v", err) - } - _ = resp.Body.Close() - select { - case v := <-got: - if v.kind != ksse.KindSSE { - t.Errorf("Kind = %q, want %q", v.kind, ksse.KindSSE) - } - if v.op != "/v1/items/{id}" { - t.Errorf("Operation = %q, want /v1/items/{id}", v.op) - } - if v.pathTemplate != "/v1/items/{id}" { - t.Errorf("PathTemplate = %q, want /v1/items/{id}", v.pathTemplate) - } - if !strings.HasPrefix(v.ep, "sse://") { - t.Errorf("Endpoint = %q, want sse:// prefix", v.ep) - } - if !v.hasResponse { - t.Errorf("transport does not satisfy ResponseTransporter") - } - case <-time.After(2 * time.Second): - t.Fatal("handler did not run") - } -} - -func TestServer_DecodeJSON(t *testing.T) { - srv, addr := newServerOnLoopback(t) - type req struct { - Name string `json:"name"` - } - got := make(chan string, 1) - srv.HandleFunc("POST /echo", func(w http.ResponseWriter, r *http.Request) { - var v req - if err := srv.Decode(r, &v); err != nil { - http.Error(w, err.Error(), http.StatusBadRequest) - return - } - got <- v.Name - _ = sse.NewStream(w).Done() - }) - - stop := startServer(t, srv) - defer stop() - - body := strings.NewReader(`{"name":"karma"}`) - resp, err := http.Post("http://"+addr+"/echo", "application/json", body) - if err != nil { - t.Fatalf("POST: %v", err) - } - defer func() { _ = resp.Body.Close() }() - if resp.StatusCode != 200 { - t.Fatalf("status = %d", resp.StatusCode) - } - select { - case v := <-got: - if v != "karma" { - t.Errorf("decoded name = %q, want karma", v) - } - case <-time.After(2 * time.Second): - t.Fatal("handler did not run") - } -} - -func TestServer_RequestDecoderOverride(t *testing.T) { - sentinel := errors.New("custom decoder") - srv, addr := newServerOnLoopback(t, - ksse.RequestDecoder(func(*http.Request, any) error { return sentinel }), - ) - srv.HandleFunc("POST /x", func(w http.ResponseWriter, r *http.Request) { - var v any - if err := srv.Decode(r, &v); err == nil { - http.Error(w, "expected sentinel", http.StatusInternalServerError) - return - } - _ = sse.NewStream(w).Done() - }) - - stop := startServer(t, srv) - defer stop() - - resp, err := http.Post("http://"+addr+"/x", "application/json", strings.NewReader(`{}`)) - if err != nil { - t.Fatalf("POST: %v", err) - } - defer func() { _ = resp.Body.Close() }() - if resp.StatusCode != 200 { - t.Errorf("status = %d, want 200", resp.StatusCode) - } -} - -func TestServer_FilterChain(t *testing.T) { - srv, addr := newServerOnLoopback(t, - ksse.Filter( - func(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("X-Outer", "1") - next.ServeHTTP(w, r) - }) - }, - func(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("X-Inner", "1") - next.ServeHTTP(w, r) - }) - }, - ), - ) - srv.HandleFunc("/f", func(w http.ResponseWriter, _ *http.Request) { - _ = sse.NewStream(w).Done() - }) - - stop := startServer(t, srv) - defer stop() - - resp, err := http.Get("http://" + addr + "/f") - if err != nil { - t.Fatalf("GET: %v", err) - } - _ = resp.Body.Close() - if got := resp.Header.Get("X-Outer"); got != "1" { - t.Errorf("X-Outer = %q, want 1", got) - } - if got := resp.Header.Get("X-Inner"); got != "1" { - t.Errorf("X-Inner = %q, want 1", got) - } -} - -func TestServer_WalkPattern(t *testing.T) { - srv := ksse.NewServer() - srv.HandleFunc("/a", func(http.ResponseWriter, *http.Request) {}) - srv.HandleFunc("POST /b", func(http.ResponseWriter, *http.Request) {}) - - var seen []string - srv.WalkPattern(func(p string) { seen = append(seen, p) }) - want := []string{"/a", "POST /b"} - if strings.Join(seen, ",") != strings.Join(want, ",") { - t.Errorf("WalkPattern visited %v, want %v", seen, want) - } -} - -func TestServer_TLSConfigSelected(t *testing.T) { - // We only verify the option installs the TLS config; serving TLS - // requires a real cert that's out of scope for this test. - srv := ksse.NewServer(ksse.TLSConfig(&tls.Config{})) - if srv.TLSConfig == nil { - t.Errorf("TLSConfig not propagated to http.Server") - } -} - -func TestCodecOption_Panics(t *testing.T) { - defer func() { - if r := recover(); r == nil { - t.Error("Codec(unregistered) did not panic") - } - }() - ksse.Codec("nope") -} - -func readAll(t *testing.T, r interface { - Read([]byte) (int, error) -}) string { - t.Helper() - var sb strings.Builder - br := bufio.NewReader(r) - for { - line, err := br.ReadString('\n') - sb.WriteString(line) - if err != nil { - return sb.String() - } - } -} diff --git a/sse/kratos/transport.go b/sse/kratos/transport.go deleted file mode 100644 index 37ab4cc..0000000 --- a/sse/kratos/transport.go +++ /dev/null @@ -1,113 +0,0 @@ -package kratos - -import ( - "context" - "net/http" - - ktransport "github.com/go-kratos/kratos/v2/transport" -) - -var ( - _ Transporter = (*Transport)(nil) - _ ResponseTransporter = (*Transport)(nil) -) - -// Transporter extends Kratos' transport.Transporter with the raw HTTP -// request, matching the shape that github.com/go-kratos/kratos/v2/transport/http -// exposes for its own transport. -type Transporter interface { - ktransport.Transporter - Request() *http.Request - PathTemplate() string -} - -// ResponseTransporter additionally exposes the response writer, for -// handlers that need to wrap or inspect it (e.g. SSE). -type ResponseTransporter interface { - Transporter - Response() http.ResponseWriter -} - -// Transport is the SSE transport context attached to each request. -type Transport struct { - endpoint string - operation string - pathTemplate string - request *http.Request - response http.ResponseWriter - reqHeader headerCarrier - replyHeader headerCarrier -} - -// Kind reports the transport kind, KindSSE. -func (t *Transport) Kind() ktransport.Kind { return KindSSE } - -// Endpoint reports the server's advertised endpoint. -func (t *Transport) Endpoint() string { return t.endpoint } - -// Operation reports the matched route template (e.g. "/v1/chat:stream"). -// Middleware that needs a more specific operation can override it via -// SetOperation. -func (t *Transport) Operation() string { return t.operation } - -// PathTemplate reports the matched ServeMux pattern. -func (t *Transport) PathTemplate() string { return t.pathTemplate } - -// Request returns the underlying *http.Request. -func (t *Transport) Request() *http.Request { return t.request } - -// Response returns the underlying http.ResponseWriter. -func (t *Transport) Response() http.ResponseWriter { return t.response } - -// RequestHeader returns the inbound HTTP headers. -func (t *Transport) RequestHeader() ktransport.Header { return t.reqHeader } - -// ReplyHeader returns the writable response headers. -func (t *Transport) ReplyHeader() ktransport.Header { return t.replyHeader } - -// SetOperation overrides the operation name on the SSE transport -// attached to ctx. It is a no-op when ctx does not carry an SSE -// transport. -func SetOperation(ctx context.Context, op string) { - if tr, ok := ktransport.FromServerContext(ctx); ok { - if t, ok := tr.(*Transport); ok { - t.operation = op - } - } -} - -// RequestFromServerContext returns the request stored in ctx by an SSE -// transport, or false if none is present. -func RequestFromServerContext(ctx context.Context) (*http.Request, bool) { - if tr, ok := ktransport.FromServerContext(ctx); ok { - if t, ok := tr.(Transporter); ok { - return t.Request(), true - } - } - return nil, false -} - -// ResponseWriterFromServerContext returns the response writer stored in -// ctx, or false if none is present. -func ResponseWriterFromServerContext(ctx context.Context) (http.ResponseWriter, bool) { - if tr, ok := ktransport.FromServerContext(ctx); ok { - if t, ok := tr.(ResponseTransporter); ok { - return t.Response(), true - } - } - return nil, false -} - -type headerCarrier http.Header - -func (hc headerCarrier) Get(key string) string { return http.Header(hc).Get(key) } -func (hc headerCarrier) Set(key, value string) { http.Header(hc).Set(key, value) } -func (hc headerCarrier) Add(key, value string) { http.Header(hc).Add(key, value) } -func (hc headerCarrier) Values(key string) []string { return http.Header(hc).Values(key) } -func (hc headerCarrier) Keys() []string { - keys := make([]string, 0, len(hc)) - for k := range http.Header(hc) { - keys = append(keys, k) - } - return keys -} diff --git a/sse/snapshot_live.go b/sse/snapshot_live.go deleted file mode 100644 index f990969..0000000 --- a/sse/snapshot_live.go +++ /dev/null @@ -1,72 +0,0 @@ -package sse - -import "context" - -// SnapshotLiveOptions configures StreamSnapshotThenLive. -type SnapshotLiveOptions[T any] struct { - SnapshotEvent string - SnapshotEndEvent string - SnapshotEndData string - LiveEvent string - ID func(T) string - Data func(T) (any, error) -} - -// StreamSnapshotThenLive writes an initial snapshot batch, an optional -// snapshot-end marker, then live events from ch until ctx is canceled or ch -// closes. -func StreamSnapshotThenLive[T any]( - ctx context.Context, - st *Stream, - snapshot []T, - ch <-chan T, - opts SnapshotLiveOptions[T], -) error { - for _, item := range snapshot { - if err := writeSnapshotLiveEvent(st, opts.SnapshotEvent, item, opts); err != nil { - return err - } - } - - if opts.SnapshotEndEvent != "" { - data := opts.SnapshotEndData - if data == "" { - data = "{}" - } - if err := st.WriteEvent(Event{Event: opts.SnapshotEndEvent, Data: data}); err != nil { - return err - } - } - - for { - select { - case <-ctx.Done(): - return nil - case item, ok := <-ch: - if !ok { - return nil - } - if err := writeSnapshotLiveEvent(st, opts.LiveEvent, item, opts); err != nil { - return err - } - } - } -} - -func writeSnapshotLiveEvent[T any](st *Stream, event string, item T, opts SnapshotLiveOptions[T]) error { - id := "" - if opts.ID != nil { - id = opts.ID(item) - } - if opts.Data == nil { - return st.WriteJSONEvent(Event{Event: event, ID: id}, item) - } - data, err := opts.Data(item) - if err != nil { - return err - } - if s, ok := data.(string); ok { - return st.WriteEvent(Event{Event: event, ID: id, Data: s}) - } - return st.WriteJSONEvent(Event{Event: event, ID: id}, data) -} diff --git a/sse/snapshot_live_test.go b/sse/snapshot_live_test.go deleted file mode 100644 index 19e1dd9..0000000 --- a/sse/snapshot_live_test.go +++ /dev/null @@ -1,64 +0,0 @@ -package sse - -import ( - "context" - "net/http/httptest" - "strings" - "testing" -) - -func TestStreamSnapshotThenLiveWritesSnapshotEndAndLive(t *testing.T) { - rec := httptest.NewRecorder() - st := NewStream(rec) - live := make(chan int, 2) - live <- 3 - close(live) - - err := StreamSnapshotThenLive(context.Background(), st, []int{1, 2}, live, SnapshotLiveOptions[int]{ - SnapshotEvent: "snapshot", - SnapshotEndEvent: "snapshot-end", - LiveEvent: "point", - ID: func(v int) string { - return string(rune('a' + v - 1)) - }, - Data: func(v int) (any, error) { - return map[string]int{"value": v}, nil - }, - }) - if err != nil { - t.Fatalf("StreamSnapshotThenLive: %v", err) - } - - body := rec.Body.String() - for _, want := range []string{ - "event: snapshot\nid: a\ndata: {\"value\":1}\n\n", - "event: snapshot\nid: b\ndata: {\"value\":2}\n\n", - "event: snapshot-end\ndata: {}\n\n", - "event: point\nid: c\ndata: {\"value\":3}\n\n", - } { - if !strings.Contains(body, want) { - t.Fatalf("body missing %q\nbody:\n%s", want, body) - } - } -} - -func TestStreamSnapshotThenLiveDefaultsToJSON(t *testing.T) { - rec := httptest.NewRecorder() - st := NewStream(rec) - live := make(chan struct { - Name string `json:"name"` - }) - close(live) - - err := StreamSnapshotThenLive(context.Background(), st, []struct { - Name string `json:"name"` - }{{Name: "one"}}, live, SnapshotLiveOptions[struct { - Name string `json:"name"` - }]{}) - if err != nil { - t.Fatalf("StreamSnapshotThenLive: %v", err) - } - if got := rec.Body.String(); !strings.Contains(got, "data: {\"name\":\"one\"}\n\n") { - t.Fatalf("expected JSON payload, got:\n%s", got) - } -} diff --git a/sse/sse.go b/sse/sse.go index 5f7ba25..8a39905 100644 --- a/sse/sse.go +++ b/sse/sse.go @@ -5,8 +5,8 @@ // channels. Authentication, request decoding, validation and routing are the // caller's responsibility. // -// For a Kratos transport.Server adapter built on top of this package, see -// the sub-package github.com/crypto-zero/go-kit/sse/kratos. +// For helpers that mount SSE endpoints on a Kratos HTTP server, see the +// sub-package github.com/crypto-zero/go-kit/sse/kratos. package sse import ( From 2f890dbe843d87a72735b7930c80a960609be2d8 Mon Sep 17 00:00:00 2001 From: skinny <39215611+crazyskinny@users.noreply.github.com> Date: Wed, 13 May 2026 20:17:10 +0800 Subject: [PATCH 07/11] feat: generate sse clients --- sse/kratos/http_client.go | 103 ++++++++++++++++++++++ sse/kratos/http_client_test.go | 38 ++++++++ sse/kratos/protoc-gen-go-sse/main.go | 31 +++++++ sse/kratos/protoc-gen-go-sse/main_test.go | 5 ++ sse/reader.go | 78 ++++++++++++++++ sse/reader_test.go | 23 +++++ 6 files changed, 278 insertions(+) create mode 100644 sse/kratos/http_client.go create mode 100644 sse/kratos/http_client_test.go create mode 100644 sse/reader.go create mode 100644 sse/reader_test.go diff --git a/sse/kratos/http_client.go b/sse/kratos/http_client.go new file mode 100644 index 0000000..7c10c1f --- /dev/null +++ b/sse/kratos/http_client.go @@ -0,0 +1,103 @@ +package kratos + +import ( + "context" + "fmt" + "net/http" + "net/url" + "strings" + + "github.com/crypto-zero/go-kit/sse" +) + +// HTTPClient opens SSE streams over HTTP. +type HTTPClient struct { + endpoint string + client *http.Client +} + +// HTTPClientOption configures an HTTPClient. +type HTTPClientOption func(*HTTPClient) + +// WithHTTPClient sets the underlying net/http client. +func WithHTTPClient(client *http.Client) HTTPClientOption { + return func(c *HTTPClient) { + if client != nil { + c.client = client + } + } +} + +// NewHTTPClient returns an SSE HTTP client for endpoint. +func NewHTTPClient(endpoint string, opts ...HTTPClientOption) *HTTPClient { + c := &HTTPClient{endpoint: strings.TrimRight(endpoint, "/"), client: http.DefaultClient} + for _, opt := range opts { + opt(c) + } + return c +} + +// HTTPStreamCallOption configures one SSE stream request. +type HTTPStreamCallOption func(*httpStreamCallConfig) + +type httpStreamCallConfig struct { + headers http.Header +} + +// WithRequestHeader adds one request header value. +func WithRequestHeader(key, value string) HTTPStreamCallOption { + return func(c *httpStreamCallConfig) { + c.headers.Add(key, value) + } +} + +// WithLastEventID sets the Last-Event-ID header for stream resumption. +func WithLastEventID(id string) HTTPStreamCallOption { + return WithRequestHeader(sse.LastEventIDHeader, id) +} + +// Open opens an SSE stream and returns a reader for response events. +func (c *HTTPClient) Open(ctx context.Context, method, path string, opts ...HTTPStreamCallOption) (*sse.Reader, error) { + cfg := httpStreamCallConfig{headers: make(http.Header)} + for _, opt := range opts { + opt(&cfg) + } + u, err := joinEndpointPath(c.endpoint, path) + if err != nil { + return nil, err + } + req, err := http.NewRequestWithContext(ctx, method, u, nil) + if err != nil { + return nil, err + } + req.Header.Set("Accept", "text/event-stream") + for key, values := range cfg.headers { + for _, value := range values { + req.Header.Add(key, value) + } + } + resp, err := c.client.Do(req) + if err != nil { + return nil, err + } + if resp.StatusCode < http.StatusOK || resp.StatusCode > 299 { + defer resp.Body.Close() + return nil, fmt.Errorf("sse: unexpected status %d", resp.StatusCode) + } + return sse.NewReader(resp.Body), nil +} + +func joinEndpointPath(endpoint, path string) (string, error) { + if endpoint == "" { + return "", fmt.Errorf("sse: endpoint is empty") + } + base, err := url.Parse(endpoint) + if err != nil { + return "", err + } + ref, err := url.Parse(path) + if err != nil { + return "", err + } + return base.ResolveReference(ref).String(), nil +} diff --git a/sse/kratos/http_client_test.go b/sse/kratos/http_client_test.go new file mode 100644 index 0000000..77d8e20 --- /dev/null +++ b/sse/kratos/http_client_test.go @@ -0,0 +1,38 @@ +package kratos_test + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" + + ksse "github.com/crypto-zero/go-kit/sse" + ssekratos "github.com/crypto-zero/go-kit/sse/kratos" +) + +func TestHTTPClientOpen(t *testing.T) { + var gotLastEventID string + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotLastEventID = r.Header.Get(ksse.LastEventIDHeader) + w.Header().Set("Content-Type", "text/event-stream") + _, _ = w.Write([]byte("id: 7\nevent: point\ndata: {}\n\n")) + })) + defer ts.Close() + + client := ssekratos.NewHTTPClient(ts.URL) + reader, err := client.Open(context.Background(), http.MethodGet, "/v1/watch?west=1", ssekratos.WithLastEventID("6")) + if err != nil { + t.Fatalf("Open: %v", err) + } + defer reader.Close() + if gotLastEventID != "6" { + t.Fatalf("Last-Event-ID = %q, want 6", gotLastEventID) + } + ev, err := reader.Next() + if err != nil { + t.Fatalf("Next: %v", err) + } + if ev.ID != "7" || ev.Event != "point" || ev.Data != "{}" { + t.Fatalf("unexpected event: %#v", ev) + } +} diff --git a/sse/kratos/protoc-gen-go-sse/main.go b/sse/kratos/protoc-gen-go-sse/main.go index d5bce67..68ef1c0 100644 --- a/sse/kratos/protoc-gen-go-sse/main.go +++ b/sse/kratos/protoc-gen-go-sse/main.go @@ -141,8 +141,11 @@ func genService(g *protogen.GeneratedFile, service *protogen.Service, methods [] serviceName := service.GoName contextIdent := g.QualifiedGoIdent(contextPackage.Ident("Context")) streamIdent := g.QualifiedGoIdent(ssePackage.Ident("Stream")) + readerIdent := g.QualifiedGoIdent(ssePackage.Ident("Reader")) serverIdent := g.QualifiedGoIdent(khttpPackage.Ident("Server")) optionIdent := g.QualifiedGoIdent(kratosPackage.Ident("HTTPStreamOption")) + clientIdent := g.QualifiedGoIdent(kratosPackage.Ident("HTTPClient")) + callOptionIdent := g.QualifiedGoIdent(kratosPackage.Ident("HTTPStreamCallOption")) g.P("type ", serviceName, "SSEServer interface {") for _, item := range methods { @@ -158,6 +161,23 @@ func genService(g *protogen.GeneratedFile, service *protogen.Service, methods [] g.P("}") g.P() + g.P("type ", serviceName, "SSEClient interface {") + for _, item := range methods { + g.P(item.method.GoName, "(", contextIdent, ", *", item.method.Input.GoIdent, ", ...", callOptionIdent, ") (*", readerIdent, ", error)") + } + g.P("}") + g.P() + + g.P("type ", serviceName, "SSEClientImpl struct {") + g.P("cc *", clientIdent) + g.P("}") + g.P() + + g.P("func New", serviceName, "SSEClient(client *", clientIdent, ") ", serviceName, "SSEClient {") + g.P("return &", serviceName, "SSEClientImpl{client}") + g.P("}") + g.P() + for _, item := range methods { genMethod(g, item) } @@ -180,6 +200,17 @@ func genMethod(g *protogen.GeneratedFile, item sseMethod) { } g.P("}") g.P() + + callOptionIdent := g.QualifiedGoIdent(kratosPackage.Ident("HTTPStreamCallOption")) + readerIdent := g.QualifiedGoIdent(ssePackage.Ident("Reader")) + bindingIdent := g.QualifiedGoIdent(protogen.GoImportPath("github.com/go-kratos/kratos/v2/transport/http/binding").Ident("EncodeURL")) + route := item.routes[0] + g.P("func (c *", serviceName, "SSEClientImpl) ", methodName, "(ctx ", g.QualifiedGoIdent(contextPackage.Ident("Context")), ", in *", item.method.Input.GoIdent, ", opts ...", callOptionIdent, ") (*", readerIdent, ", error) {") + g.P("pattern := ", strconv.Quote(route.path)) + g.P("path := ", bindingIdent, "(pattern, in, true)") + g.P("return c.cc.Open(ctx, ", strconv.Quote(route.verb), ", path, opts...)") + g.P("}") + g.P() } func protocVersion(plugin *protogen.Plugin) string { diff --git a/sse/kratos/protoc-gen-go-sse/main_test.go b/sse/kratos/protoc-gen-go-sse/main_test.go index 2a333ca..5aafdbb 100644 --- a/sse/kratos/protoc-gen-go-sse/main_test.go +++ b/sse/kratos/protoc-gen-go-sse/main_test.go @@ -42,6 +42,11 @@ func TestGenerateFileUsesDefaultHTTPStreamBinding(t *testing.T) { for _, want := range []string{ `RegisterHTTPStream(s, "GET", "/v1/watch", OperationLiveServiceWatchSSE, srv.Watch, opts...)`, `RegisterHTTPStream(s, "POST", "/v1/watch:tail", OperationLiveServiceWatchSSE, srv.Watch, opts...)`, + `type LiveServiceSSEClient interface`, + `func NewLiveServiceSSEClient(client *kratos.HTTPClient) LiveServiceSSEClient`, + `func (c *LiveServiceSSEClientImpl) Watch(ctx context.Context, in *WatchRequest, opts ...kratos.HTTPStreamCallOption) (*sse.Reader, error)`, + `path := binding.EncodeURL(pattern, in, true)`, + `return c.cc.Open(ctx, "GET", path, opts...)`, } { if !strings.Contains(got, want) { t.Fatalf("generated code missing %q:\n%s", want, got) diff --git a/sse/reader.go b/sse/reader.go new file mode 100644 index 0000000..100ce91 --- /dev/null +++ b/sse/reader.go @@ -0,0 +1,78 @@ +package sse + +import ( + "bufio" + "errors" + "io" + "strconv" + "strings" + "time" +) + +// Reader reads Server-Sent Events from an HTTP response body. +type Reader struct { + r *bufio.Reader + c io.Closer +} + +// NewReader returns an SSE event reader for r. +func NewReader(r io.Reader) *Reader { + er := &Reader{r: bufio.NewReader(r)} + if c, ok := r.(io.Closer); ok { + er.c = c + } + return er +} + +// Next blocks until the next complete event frame is read. +func (r *Reader) Next() (*Event, error) { + var ev Event + var data []string + seen := false + for { + line, err := r.r.ReadString('\n') + if err != nil { + if errors.Is(err, io.EOF) && seen { + ev.Data = strings.Join(data, "\n") + return &ev, nil + } + return nil, err + } + line = strings.TrimSuffix(strings.TrimSuffix(line, "\n"), "\r") + if line == "" { + if !seen { + continue + } + ev.Data = strings.Join(data, "\n") + return &ev, nil + } + if strings.HasPrefix(line, ":") { + continue + } + seen = true + field, value, ok := strings.Cut(line, ":") + if ok && strings.HasPrefix(value, " ") { + value = value[1:] + } + switch field { + case "event": + ev.Event = value + case "id": + ev.ID = value + case "data": + data = append(data, value) + case "retry": + if ms, err := strconv.ParseInt(value, 10, 64); err == nil { + ev.Retry = time.Duration(ms) * time.Millisecond + } + } + } +} + +// Close closes the underlying reader when it implements io.Closer. +func (r *Reader) Close() error { + if r.c == nil { + return nil + } + return r.c.Close() +} diff --git a/sse/reader_test.go b/sse/reader_test.go new file mode 100644 index 0000000..458dc87 --- /dev/null +++ b/sse/reader_test.go @@ -0,0 +1,23 @@ +package sse + +import ( + "errors" + "io" + "strings" + "testing" + "time" +) + +func TestReaderNext(t *testing.T) { + r := NewReader(strings.NewReader(": keepalive\nid: 42\nevent: point\nretry: 1500\ndata: {\"ok\":true}\ndata: tail\n\n")) + ev, err := r.Next() + if err != nil { + t.Fatalf("Next: %v", err) + } + if ev.ID != "42" || ev.Event != "point" || ev.Retry != 1500*time.Millisecond || ev.Data != "{\"ok\":true}\ntail" { + t.Fatalf("unexpected event: %#v", ev) + } + if _, err := r.Next(); !errors.Is(err, io.EOF) { + t.Fatalf("Next EOF = %v, want io.EOF", err) + } +} From 4c8711f53bd51fb53e47830500367b4a7a079fd5 Mon Sep 17 00:00:00 2001 From: skinny <39215611+crazyskinny@users.noreply.github.com> Date: Fri, 15 May 2026 14:06:10 +0800 Subject: [PATCH 08/11] fix: correct sse http stream binding --- sse/kratos/http_client.go | 52 ++++++++++++-- sse/kratos/http_client_test.go | 63 ++++++++++++++++- sse/kratos/http_handler.go | 3 + sse/kratos/http_handler_test.go | 4 +- sse/kratos/protoc-gen-go-sse/main.go | 85 +++++++++++++++++++++-- sse/kratos/protoc-gen-go-sse/main_test.go | 31 ++++++--- sse/reader.go | 27 +++++-- sse/reader_test.go | 18 ++++- sse/sse.go | 8 ++- sse/sse_test.go | 14 ++++ 10 files changed, 269 insertions(+), 36 deletions(-) diff --git a/sse/kratos/http_client.go b/sse/kratos/http_client.go index 7c10c1f..af7b22c 100644 --- a/sse/kratos/http_client.go +++ b/sse/kratos/http_client.go @@ -1,12 +1,16 @@ package kratos import ( + "bytes" "context" "fmt" + "io" "net/http" "net/url" "strings" + "github.com/go-kratos/kratos/v2/encoding" + "github.com/crypto-zero/go-kit/sse" ) @@ -22,9 +26,10 @@ type HTTPClientOption func(*HTTPClient) // WithHTTPClient sets the underlying net/http client. func WithHTTPClient(client *http.Client) HTTPClientOption { return func(c *HTTPClient) { - if client != nil { - c.client = client + if client == nil { + panic("sse/kratos: nil HTTP client") } + c.client = client } } @@ -56,8 +61,9 @@ func WithLastEventID(id string) HTTPStreamCallOption { return WithRequestHeader(sse.LastEventIDHeader, id) } -// Open opens an SSE stream and returns a reader for response events. -func (c *HTTPClient) Open(ctx context.Context, method, path string, opts ...HTTPStreamCallOption) (*sse.Reader, error) { +// Open opens an SSE stream and returns a reader for response events. Non-nil +// body values are JSON encoded unless body already implements io.Reader. +func (c *HTTPClient) Open(ctx context.Context, method, path string, body any, opts ...HTTPStreamCallOption) (*sse.Reader, error) { cfg := httpStreamCallConfig{headers: make(http.Header)} for _, opt := range opts { opt(&cfg) @@ -66,11 +72,18 @@ func (c *HTTPClient) Open(ctx context.Context, method, path string, opts ...HTTP if err != nil { return nil, err } - req, err := http.NewRequestWithContext(ctx, method, u, nil) + bodyReader, contentType, err := encodeRequestBody(body) + if err != nil { + return nil, err + } + req, err := http.NewRequestWithContext(ctx, method, u, bodyReader) if err != nil { return nil, err } req.Header.Set("Accept", "text/event-stream") + if contentType != "" { + req.Header.Set("Content-Type", contentType) + } for key, values := range cfg.headers { for _, value := range values { req.Header.Add(key, value) @@ -82,11 +95,30 @@ func (c *HTTPClient) Open(ctx context.Context, method, path string, opts ...HTTP } if resp.StatusCode < http.StatusOK || resp.StatusCode > 299 { defer resp.Body.Close() - return nil, fmt.Errorf("sse: unexpected status %d", resp.StatusCode) + body, _ := io.ReadAll(io.LimitReader(resp.Body, 4<<10)) + msg := strings.TrimSpace(string(body)) + if msg == "" { + return nil, fmt.Errorf("sse: unexpected status %d", resp.StatusCode) + } + return nil, fmt.Errorf("sse: unexpected status %d: %s", resp.StatusCode, msg) } return sse.NewReader(resp.Body), nil } +func encodeRequestBody(body any) (io.Reader, string, error) { + if body == nil { + return nil, "", nil + } + if r, ok := body.(io.Reader); ok { + return r, "", nil + } + b, err := encoding.GetCodec("json").Marshal(body) + if err != nil { + return nil, "", fmt.Errorf("sse: marshal request body: %w", err) + } + return bytes.NewReader(b), "application/json", nil +} + func joinEndpointPath(endpoint, path string) (string, error) { if endpoint == "" { return "", fmt.Errorf("sse: endpoint is empty") @@ -99,5 +131,11 @@ func joinEndpointPath(endpoint, path string) (string, error) { if err != nil { return "", err } - return base.ResolveReference(ref).String(), nil + if ref.IsAbs() { + return ref.String(), nil + } + base.Path = strings.TrimRight(base.Path, "/") + "/" + strings.TrimLeft(ref.Path, "/") + base.RawQuery = ref.RawQuery + base.Fragment = ref.Fragment + return base.String(), nil } diff --git a/sse/kratos/http_client_test.go b/sse/kratos/http_client_test.go index 77d8e20..f40d90e 100644 --- a/sse/kratos/http_client_test.go +++ b/sse/kratos/http_client_test.go @@ -2,8 +2,10 @@ package kratos_test import ( "context" + "io" "net/http" "net/http/httptest" + "strings" "testing" ksse "github.com/crypto-zero/go-kit/sse" @@ -12,19 +14,24 @@ import ( func TestHTTPClientOpen(t *testing.T) { var gotLastEventID string + var gotPath string ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { gotLastEventID = r.Header.Get(ksse.LastEventIDHeader) + gotPath = r.URL.RequestURI() w.Header().Set("Content-Type", "text/event-stream") _, _ = w.Write([]byte("id: 7\nevent: point\ndata: {}\n\n")) })) defer ts.Close() - client := ssekratos.NewHTTPClient(ts.URL) - reader, err := client.Open(context.Background(), http.MethodGet, "/v1/watch?west=1", ssekratos.WithLastEventID("6")) + client := ssekratos.NewHTTPClient(ts.URL + "/api") + reader, err := client.Open(context.Background(), http.MethodGet, "/v1/watch?west=1", nil, ssekratos.WithLastEventID("6")) if err != nil { t.Fatalf("Open: %v", err) } defer reader.Close() + if gotPath != "/api/v1/watch?west=1" { + t.Fatalf("path = %q, want /api/v1/watch?west=1", gotPath) + } if gotLastEventID != "6" { t.Fatalf("Last-Event-ID = %q, want 6", gotLastEventID) } @@ -36,3 +43,55 @@ func TestHTTPClientOpen(t *testing.T) { t.Fatalf("unexpected event: %#v", ev) } } + +func TestHTTPClientOpenSendsJSONBody(t *testing.T) { + var gotBody string + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + body, err := io.ReadAll(r.Body) + if err != nil { + t.Fatalf("ReadAll: %v", err) + } + gotBody = string(body) + if got := r.Header.Get("Content-Type"); got != "application/json" { + t.Fatalf("Content-Type = %q, want application/json", got) + } + w.Header().Set("Content-Type", "text/event-stream") + _, _ = w.Write([]byte("data: ok\n\n")) + })) + defer ts.Close() + + client := ssekratos.NewHTTPClient(ts.URL) + reader, err := client.Open(context.Background(), http.MethodPost, "/v1/watch", map[string]string{"name": "alice"}) + if err != nil { + t.Fatalf("Open: %v", err) + } + defer reader.Close() + if gotBody != `{"name":"alice"}` { + t.Fatalf("body = %q, want JSON payload", gotBody) + } +} + +func TestHTTPClientOpenStatusErrorIncludesBody(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + http.Error(w, `{"error":"bad token"}`, http.StatusUnauthorized) + })) + defer ts.Close() + + client := ssekratos.NewHTTPClient(ts.URL) + _, err := client.Open(context.Background(), http.MethodGet, "/v1/watch", nil) + if err == nil { + t.Fatal("Open err = nil, want error") + } + if got := err.Error(); !strings.Contains(got, "401") || !strings.Contains(got, "bad token") { + t.Fatalf("Open err = %q, want status and body", got) + } +} + +func TestWithHTTPClientNilPanics(t *testing.T) { + defer func() { + if recover() == nil { + t.Fatal("NewHTTPClient did not panic") + } + }() + _ = ssekratos.NewHTTPClient("http://example.com", ssekratos.WithHTTPClient(nil)) +} diff --git a/sse/kratos/http_handler.go b/sse/kratos/http_handler.go index 0760e13..5496b8b 100644 --- a/sse/kratos/http_handler.go +++ b/sse/kratos/http_handler.go @@ -86,6 +86,9 @@ func registerHTTPStream[Req any]( } func bindHTTPStreamRequest[Req any](ctx khttp.Context, target *Req) error { + if err := ctx.BindVars(target); err != nil { + return err + } switch ctx.Request().Method { case http.MethodGet, http.MethodHead: return ctx.BindQuery(target) diff --git a/sse/kratos/http_handler_test.go b/sse/kratos/http_handler_test.go index bddeaaa..7ac8bf5 100644 --- a/sse/kratos/http_handler_test.go +++ b/sse/kratos/http_handler_test.go @@ -20,7 +20,7 @@ import ( func TestHTTPStreamHandler_BindsProtoQueryAndStreamsOnKratosHTTP(t *testing.T) { srv := khttp.NewServer(khttp.Timeout(0)) - ksse.RegisterHTTPStream(srv, http.MethodGet, "/v1/duration", "/test.Duration/Watch", + ksse.RegisterHTTPStream(srv, http.MethodGet, "/v1/duration/{seconds}", "/test.Duration/Watch", func(_ context.Context, req *durationpb.Duration, st *sse.Stream) error { if req.GetSeconds() != 12 || req.GetNanos() != 34 { t.Fatalf("request = %ds/%dns, want 12s/34ns", req.GetSeconds(), req.GetNanos()) @@ -32,7 +32,7 @@ func TestHTTPStreamHandler_BindsProtoQueryAndStreamsOnKratosHTTP(t *testing.T) { ts := httptest.NewServer(srv) defer ts.Close() - resp, err := http.Get(ts.URL + "/v1/duration?seconds=12&nanos=34") + resp, err := http.Get(ts.URL + "/v1/duration/12?nanos=34") if err != nil { t.Fatalf("GET: %v", err) } diff --git a/sse/kratos/protoc-gen-go-sse/main.go b/sse/kratos/protoc-gen-go-sse/main.go index 68ef1c0..9d524d2 100644 --- a/sse/kratos/protoc-gen-go-sse/main.go +++ b/sse/kratos/protoc-gen-go-sse/main.go @@ -44,6 +44,7 @@ type sseMethod struct { type sseRoute struct { verb string path string + body string } func generateFile(plugin *protogen.Plugin, file *protogen.File) { @@ -101,7 +102,7 @@ func collectSSEMethods(file *protogen.File) []sseMethod { if verb == "" || path == "" { continue } - routes = append(routes, sseRoute{verb: verb, path: path}) + routes = append(routes, sseRoute{verb: verb, path: path, body: binding.GetBody()}) } if len(routes) == 0 { continue @@ -146,6 +147,7 @@ func genService(g *protogen.GeneratedFile, service *protogen.Service, methods [] optionIdent := g.QualifiedGoIdent(kratosPackage.Ident("HTTPStreamOption")) clientIdent := g.QualifiedGoIdent(kratosPackage.Ident("HTTPClient")) callOptionIdent := g.QualifiedGoIdent(kratosPackage.Ident("HTTPStreamCallOption")) + clientImpl := unexport(serviceName) + "SSEClient" g.P("type ", serviceName, "SSEServer interface {") for _, item := range methods { @@ -168,13 +170,13 @@ func genService(g *protogen.GeneratedFile, service *protogen.Service, methods [] g.P("}") g.P() - g.P("type ", serviceName, "SSEClientImpl struct {") + g.P("type ", clientImpl, " struct {") g.P("cc *", clientIdent) g.P("}") g.P() g.P("func New", serviceName, "SSEClient(client *", clientIdent, ") ", serviceName, "SSEClient {") - g.P("return &", serviceName, "SSEClientImpl{client}") + g.P("return &", clientImpl, "{cc: client}") g.P("}") g.P() @@ -204,15 +206,86 @@ func genMethod(g *protogen.GeneratedFile, item sseMethod) { callOptionIdent := g.QualifiedGoIdent(kratosPackage.Ident("HTTPStreamCallOption")) readerIdent := g.QualifiedGoIdent(ssePackage.Ident("Reader")) bindingIdent := g.QualifiedGoIdent(protogen.GoImportPath("github.com/go-kratos/kratos/v2/transport/http/binding").Ident("EncodeURL")) + clientImpl := unexport(serviceName) + "SSEClient" + // Match protoc-gen-go-http: generated clients call the primary + // google.api.http binding; additional_bindings are server aliases. route := item.routes[0] - g.P("func (c *", serviceName, "SSEClientImpl) ", methodName, "(ctx ", g.QualifiedGoIdent(contextPackage.Ident("Context")), ", in *", item.method.Input.GoIdent, ", opts ...", callOptionIdent, ") (*", readerIdent, ", error) {") + bodyArg := "nil" + needQuery := "true" + if route.body != "" { + bodyArg = bodyExpr(route.body) + needQuery = "false" + } + g.P("func (c *", clientImpl, ") ", methodName, "(ctx ", g.QualifiedGoIdent(contextPackage.Ident("Context")), ", in *", item.method.Input.GoIdent, ", opts ...", callOptionIdent, ") (*", readerIdent, ", error) {") g.P("pattern := ", strconv.Quote(route.path)) - g.P("path := ", bindingIdent, "(pattern, in, true)") - g.P("return c.cc.Open(ctx, ", strconv.Quote(route.verb), ", path, opts...)") + g.P("path := ", bindingIdent, "(pattern, in, ", needQuery, ")") + g.P("return c.cc.Open(ctx, ", strconv.Quote(route.verb), ", path, ", bodyArg, ", opts...)") g.P("}") g.P() } +func unexport(s string) string { + if s == "" { + return "" + } + return strings.ToLower(s[:1]) + s[1:] +} + +func bodyExpr(body string) string { + if body == "*" { + return "in" + } + return "in." + camelCaseVars(body) +} + +func camelCaseVars(s string) string { + subs := strings.Split(s, ".") + vars := make([]string, 0, len(subs)) + for _, sub := range subs { + vars = append(vars, camelCase(sub)) + } + return strings.Join(vars, ".") +} + +func camelCase(s string) string { + if s == "" { + return "" + } + t := make([]byte, 0, 32) + i := 0 + if s[0] == '_' { + t = append(t, 'X') + i++ + } + for ; i < len(s); i++ { + c := s[i] + if c == '_' && i+1 < len(s) && isASCIILower(s[i+1]) { + continue + } + if isASCIIDigit(c) { + t = append(t, c) + continue + } + if isASCIILower(c) { + c ^= ' ' + } + t = append(t, c) + for i+1 < len(s) && isASCIILower(s[i+1]) { + i++ + t = append(t, s[i]) + } + } + return string(t) +} + +func isASCIILower(c byte) bool { + return 'a' <= c && c <= 'z' +} + +func isASCIIDigit(c byte) bool { + return '0' <= c && c <= '9' +} + func protocVersion(plugin *protogen.Plugin) string { v := plugin.Request.GetCompilerVersion() if v == nil { diff --git a/sse/kratos/protoc-gen-go-sse/main_test.go b/sse/kratos/protoc-gen-go-sse/main_test.go index 5aafdbb..ae0923d 100644 --- a/sse/kratos/protoc-gen-go-sse/main_test.go +++ b/sse/kratos/protoc-gen-go-sse/main_test.go @@ -40,13 +40,13 @@ func TestGenerateFileUsesDefaultHTTPStreamBinding(t *testing.T) { t.Fatalf("register function count = %d, want 1:\n%s", count, got) } for _, want := range []string{ - `RegisterHTTPStream(s, "GET", "/v1/watch", OperationLiveServiceWatchSSE, srv.Watch, opts...)`, - `RegisterHTTPStream(s, "POST", "/v1/watch:tail", OperationLiveServiceWatchSSE, srv.Watch, opts...)`, + `RegisterHTTPStream(s, "POST", "/v1/watch", OperationLiveServiceWatchSSE, srv.Watch, opts...)`, + `RegisterHTTPStream(s, "GET", "/v1/watch:tail", OperationLiveServiceWatchSSE, srv.Watch, opts...)`, `type LiveServiceSSEClient interface`, `func NewLiveServiceSSEClient(client *kratos.HTTPClient) LiveServiceSSEClient`, - `func (c *LiveServiceSSEClientImpl) Watch(ctx context.Context, in *WatchRequest, opts ...kratos.HTTPStreamCallOption) (*sse.Reader, error)`, - `path := binding.EncodeURL(pattern, in, true)`, - `return c.cc.Open(ctx, "GET", path, opts...)`, + `func (c *liveServiceSSEClient) Watch(ctx context.Context, in *WatchRequest, opts ...kratos.HTTPStreamCallOption) (*sse.Reader, error)`, + `path := binding.EncodeURL(pattern, in, false)`, + `return c.cc.Open(ctx, "POST", path, in, opts...)`, } { if !strings.Contains(got, want) { t.Fatalf("generated code missing %q:\n%s", want, got) @@ -54,16 +54,31 @@ func TestGenerateFileUsesDefaultHTTPStreamBinding(t *testing.T) { } } +func TestBodyExpr(t *testing.T) { + for _, tc := range []struct { + body string + want string + }{ + {body: "*", want: "in"}, + {body: "payload", want: "in.Payload"}, + {body: "filter_box.zoom_level", want: "in.FilterBox.ZoomLevel"}, + } { + if got := bodyExpr(tc.body); got != tc.want { + t.Fatalf("bodyExpr(%q) = %q, want %q", tc.body, got, tc.want) + } + } +} + func testCodeGeneratorRequest(t *testing.T) *pluginpb.CodeGeneratorRequest { t.Helper() opts := &descriptorpb.MethodOptions{} proto.SetExtension(opts, ssev1.E_ServerSentEvent, true) proto.SetExtension(opts, annotations.E_Http, &annotations.HttpRule{ - Pattern: &annotations.HttpRule_Get{Get: "/v1/watch"}, + Pattern: &annotations.HttpRule_Post{Post: "/v1/watch"}, + Body: "*", AdditionalBindings: []*annotations.HttpRule{{ - Pattern: &annotations.HttpRule_Post{Post: "/v1/watch:tail"}, - Body: "*", + Pattern: &annotations.HttpRule_Get{Get: "/v1/watch:tail"}, }}, }) diff --git a/sse/reader.go b/sse/reader.go index 100ce91..55000df 100644 --- a/sse/reader.go +++ b/sse/reader.go @@ -11,8 +11,9 @@ import ( // Reader reads Server-Sent Events from an HTTP response body. type Reader struct { - r *bufio.Reader - c io.Closer + r *bufio.Reader + c io.Closer + lastID string } // NewReader returns an SSE event reader for r. @@ -29,12 +30,12 @@ func (r *Reader) Next() (*Event, error) { var ev Event var data []string seen := false + idSeen := false for { line, err := r.r.ReadString('\n') if err != nil { if errors.Is(err, io.EOF) && seen { - ev.Data = strings.Join(data, "\n") - return &ev, nil + return r.finishEvent(ev, data, idSeen), nil } return nil, err } @@ -43,15 +44,16 @@ func (r *Reader) Next() (*Event, error) { if !seen { continue } - ev.Data = strings.Join(data, "\n") - return &ev, nil + return r.finishEvent(ev, data, idSeen), nil } if strings.HasPrefix(line, ":") { continue } seen = true field, value, ok := strings.Cut(line, ":") - if ok && strings.HasPrefix(value, " ") { + if !ok { + value = "" + } else if strings.HasPrefix(value, " ") { value = value[1:] } switch field { @@ -59,6 +61,7 @@ func (r *Reader) Next() (*Event, error) { ev.Event = value case "id": ev.ID = value + idSeen = true case "data": data = append(data, value) case "retry": @@ -69,6 +72,16 @@ func (r *Reader) Next() (*Event, error) { } } +func (r *Reader) finishEvent(ev Event, data []string, idSeen bool) *Event { + if idSeen { + r.lastID = ev.ID + } else { + ev.ID = r.lastID + } + ev.Data = strings.Join(data, "\n") + return &ev +} + // Close closes the underlying reader when it implements io.Closer. func (r *Reader) Close() error { if r.c == nil { diff --git a/sse/reader_test.go b/sse/reader_test.go index 458dc87..e20981f 100644 --- a/sse/reader_test.go +++ b/sse/reader_test.go @@ -9,7 +9,9 @@ import ( ) func TestReaderNext(t *testing.T) { - r := NewReader(strings.NewReader(": keepalive\nid: 42\nevent: point\nretry: 1500\ndata: {\"ok\":true}\ndata: tail\n\n")) + r := NewReader(strings.NewReader(": keepalive\nid: 42\nevent: point\nretry: 1500\ndata: {\"ok\":true}\ndata: tail\n\n" + + "event: next\ndata\n\n" + + "id:\nevent: reset\ndata: done\n\n")) ev, err := r.Next() if err != nil { t.Fatalf("Next: %v", err) @@ -17,6 +19,20 @@ func TestReaderNext(t *testing.T) { if ev.ID != "42" || ev.Event != "point" || ev.Retry != 1500*time.Millisecond || ev.Data != "{\"ok\":true}\ntail" { t.Fatalf("unexpected event: %#v", ev) } + ev, err = r.Next() + if err != nil { + t.Fatalf("Next sticky ID: %v", err) + } + if ev.ID != "42" || ev.Event != "next" || ev.Data != "" { + t.Fatalf("unexpected sticky/no-colon event: %#v", ev) + } + ev, err = r.Next() + if err != nil { + t.Fatalf("Next reset ID: %v", err) + } + if ev.ID != "" || ev.Event != "reset" || ev.Data != "done" { + t.Fatalf("unexpected reset event: %#v", ev) + } if _, err := r.Next(); !errors.Is(err, io.EOF) { t.Fatalf("Next EOF = %v, want io.EOF", err) } diff --git a/sse/sse.go b/sse/sse.go index 8a39905..4846b08 100644 --- a/sse/sse.go +++ b/sse/sse.go @@ -234,6 +234,7 @@ func (s *Stream) Heartbeat(ctx context.Context, interval time.Duration) (stop fu // - chunks is closed: Done is sent and nil is returned; // - errs delivers a non-nil error: that error is forwarded via Error and // returned (no Done is sent); +// - errs delivers nil: returns nil silently (no Done); // - errs is closed: returns nil silently (no Done); // - ctx is cancelled: returns ctx.Err() silently (no Done). // @@ -256,10 +257,11 @@ func (s *Stream) Pump(ctx context.Context, chunks <-chan string, errs <-chan err if !ok { return nil } - if err != nil { - _ = s.Error(err.Error()) - return err + if err == nil { + return nil } + _ = s.Error(err.Error()) + return err case <-ctx.Done(): return ctx.Err() } diff --git a/sse/sse_test.go b/sse/sse_test.go index bb67316..d021f5b 100644 --- a/sse/sse_test.go +++ b/sse/sse_test.go @@ -169,6 +169,20 @@ func TestPump_ErrChClosedSilently(t *testing.T) { } } +func TestPump_NilErrReturnsSilently(t *testing.T) { + rec := httptest.NewRecorder() + s := NewStream(rec) + chunks := make(chan string) + errs := make(chan error, 1) + errs <- nil + if err := s.Pump(context.Background(), chunks, errs); err != nil { + t.Errorf("Pump on nil err = %v, want nil", err) + } + if got := rec.Body.String(); got != "" { + t.Errorf("body = %q, want empty", got) + } +} + func TestPump_ContextCancel(t *testing.T) { rec := httptest.NewRecorder() s := NewStream(rec) From 4b08c1ce5f0e010abdabc300ae8e823cee7c54fa Mon Sep 17 00:00:00 2001 From: skinny <39215611+crazyskinny@users.noreply.github.com> Date: Wed, 27 May 2026 17:15:13 +0800 Subject: [PATCH 09/11] feat: add kratos adapters and ratelimit --- .gitignore | 3 +- .../protoc-gen-go-redact/README.md | 6 +- .../internal/redact/benchmark_test.go | 2 +- .../internal/redact/integration_test.go | 2 +- .../internal/redact/redact.go | 0 .../internal/redact/redact_test.go | 0 .../redact/testdata/crossfile/container.pb.go | 4 +- .../redact/testdata/crossfile/container.proto | 2 +- .../testdata/crossfile/container_redact.pb.go | 2 +- .../testdata/crossfile/crossfile_test.go | 0 .../redact/testdata/crossfile/sensitive.pb.go | 4 +- .../redact/testdata/crossfile/sensitive.proto | 2 +- .../testdata/crossfile/sensitive_redact.pb.go | 2 +- .../internal/redact/testdata/example.pb.go | 4 +- .../internal/redact/testdata/example.proto | 2 +- .../redact/testdata/example_redact.pb.go | 2 +- .../internal/redact/testdata/generate.sh | 0 .../internal/redact/version.go | 0 {logging => cmd}/protoc-gen-go-redact/main.go | 2 +- {sse/kratos => cmd}/protoc-gen-go-sse/main.go | 2 +- .../protoc-gen-go-sse/main_test.go | 4 +- {errors => cmd}/protoc-gen-kit-errors/main.go | 0 go.mod | 2 +- {auth/kratos => kratos/auth}/policy.go | 62 +---- {auth/kratos => kratos/auth}/policy_test.go | 2 +- kratos/clientip/clientip.go | 97 +++++++ kratos/clientip/clientip_test.go | 66 +++++ {sse/kratos => kratos}/go.mod | 6 +- {sse/kratos => kratos}/go.sum | 4 +- kratos/internal/protoop/protoop.go | 79 ++++++ {logging/kratos => kratos/logging}/logging.go | 111 +------- .../kratos => kratos/logging}/logging_test.go | 20 +- kratos/ratelimit/policy.go | 147 ++++++++++ kratos/ratelimit/ratelimit.go | 142 ++++++++++ kratos/ratelimit/ratelimit_test.go | 183 +++++++++++++ kratos/sse/doc.go | 2 + {sse/kratos => kratos/sse}/http_client.go | 4 +- .../kratos => kratos/sse}/http_client_test.go | 4 +- {sse/kratos => kratos/sse}/http_handler.go | 2 +- .../sse}/http_handler_test.go | 4 +- logging/kratos/go.mod | 20 -- logging/kratos/go.sum | 48 ---- proto/buf.gen.yaml | 4 +- proto/kit/ratelimit/v1/ratelimit.pb.go | 243 +++++++++++++++++ proto/kit/ratelimit/v1/ratelimit.proto | 34 +++ ratelimit/ratelimit.go | 258 ++++++++++++++++++ ratelimit/ratelimit_test.go | 132 +++++++++ sse/kratos/doc.go | 2 - sse/sse.go | 2 +- 49 files changed, 1443 insertions(+), 282 deletions(-) rename {logging => cmd}/protoc-gen-go-redact/README.md (98%) rename {logging => cmd}/protoc-gen-go-redact/internal/redact/benchmark_test.go (98%) rename {logging => cmd}/protoc-gen-go-redact/internal/redact/integration_test.go (99%) rename {logging => cmd}/protoc-gen-go-redact/internal/redact/redact.go (100%) rename {logging => cmd}/protoc-gen-go-redact/internal/redact/redact_test.go (100%) rename {logging => cmd}/protoc-gen-go-redact/internal/redact/testdata/crossfile/container.pb.go (98%) rename {logging => cmd}/protoc-gen-go-redact/internal/redact/testdata/crossfile/container.proto (88%) rename {logging => cmd}/protoc-gen-go-redact/internal/redact/testdata/crossfile/container_redact.pb.go (99%) rename {logging => cmd}/protoc-gen-go-redact/internal/redact/testdata/crossfile/crossfile_test.go (100%) rename {logging => cmd}/protoc-gen-go-redact/internal/redact/testdata/crossfile/sensitive.pb.go (97%) rename {logging => cmd}/protoc-gen-go-redact/internal/redact/testdata/crossfile/sensitive.proto (83%) rename {logging => cmd}/protoc-gen-go-redact/internal/redact/testdata/crossfile/sensitive_redact.pb.go (97%) rename {logging => cmd}/protoc-gen-go-redact/internal/redact/testdata/example.pb.go (99%) rename {logging => cmd}/protoc-gen-go-redact/internal/redact/testdata/example.proto (99%) rename {logging => cmd}/protoc-gen-go-redact/internal/redact/testdata/example_redact.pb.go (99%) rename {logging => cmd}/protoc-gen-go-redact/internal/redact/testdata/generate.sh (100%) rename {logging => cmd}/protoc-gen-go-redact/internal/redact/version.go (100%) rename {logging => cmd}/protoc-gen-go-redact/main.go (92%) rename {sse/kratos => cmd}/protoc-gen-go-sse/main.go (99%) rename {sse/kratos => cmd}/protoc-gen-go-sse/main_test.go (95%) rename {errors => cmd}/protoc-gen-kit-errors/main.go (100%) rename {auth/kratos => kratos/auth}/policy.go (51%) rename {auth/kratos => kratos/auth}/policy_test.go (99%) create mode 100644 kratos/clientip/clientip.go create mode 100644 kratos/clientip/clientip_test.go rename {sse/kratos => kratos}/go.mod (83%) rename {sse/kratos => kratos}/go.sum (96%) create mode 100644 kratos/internal/protoop/protoop.go rename {logging/kratos => kratos/logging}/logging.go (69%) rename {logging/kratos => kratos/logging}/logging_test.go (96%) create mode 100644 kratos/ratelimit/policy.go create mode 100644 kratos/ratelimit/ratelimit.go create mode 100644 kratos/ratelimit/ratelimit_test.go create mode 100644 kratos/sse/doc.go rename {sse/kratos => kratos/sse}/http_client.go (98%) rename {sse/kratos => kratos/sse}/http_client_test.go (97%) rename {sse/kratos => kratos/sse}/http_handler.go (99%) rename {sse/kratos => kratos/sse}/http_handler_test.go (98%) delete mode 100644 logging/kratos/go.mod delete mode 100644 logging/kratos/go.sum create mode 100644 proto/kit/ratelimit/v1/ratelimit.pb.go create mode 100644 proto/kit/ratelimit/v1/ratelimit.proto create mode 100644 ratelimit/ratelimit.go create mode 100644 ratelimit/ratelimit_test.go delete mode 100644 sse/kratos/doc.go diff --git a/.gitignore b/.gitignore index 0981e25..06ef959 100644 --- a/.gitignore +++ b/.gitignore @@ -12,6 +12,7 @@ vendor/ # Go workspace file go.work +go.work.sum # Compiled Object files, Static and Dynamic libs (Shared Objects) *.o @@ -37,4 +38,4 @@ bin/ logging/protoc-gen-go-redact/protoc-gen-go-redact # cursor -.cursor \ No newline at end of file +.cursor diff --git a/logging/protoc-gen-go-redact/README.md b/cmd/protoc-gen-go-redact/README.md similarity index 98% rename from logging/protoc-gen-go-redact/README.md rename to cmd/protoc-gen-go-redact/README.md index 8869fec..8c32920 100644 --- a/logging/protoc-gen-go-redact/README.md +++ b/cmd/protoc-gen-go-redact/README.md @@ -16,13 +16,13 @@ A protoc plugin that generates `Redact()` methods for Protocol Buffer messages t ## Installation ```bash -go install github.com/crypto-zero/go-kit/logging/protoc-gen-go-redact@latest +go install github.com/crypto-zero/go-kit/cmd/protoc-gen-go-redact@latest ``` Or build from source: ```bash -cd logging/protoc-gen-go-redact +cd cmd/protoc-gen-go-redact go build -o protoc-gen-go-redact . ``` @@ -273,7 +273,7 @@ type Redacter interface { Use with the logging middleware: ```go -import "github.com/crypto-zero/go-kit/logging/kratos" +import "github.com/crypto-zero/go-kit/kratos/logging" // Server middleware srv := http.NewServer( diff --git a/logging/protoc-gen-go-redact/internal/redact/benchmark_test.go b/cmd/protoc-gen-go-redact/internal/redact/benchmark_test.go similarity index 98% rename from logging/protoc-gen-go-redact/internal/redact/benchmark_test.go rename to cmd/protoc-gen-go-redact/internal/redact/benchmark_test.go index db8a90d..360e6e2 100644 --- a/logging/protoc-gen-go-redact/internal/redact/benchmark_test.go +++ b/cmd/protoc-gen-go-redact/internal/redact/benchmark_test.go @@ -4,7 +4,7 @@ import ( "testing" "time" - "github.com/crypto-zero/go-kit/logging/protoc-gen-go-redact/internal/redact/testdata" + "github.com/crypto-zero/go-kit/cmd/protoc-gen-go-redact/internal/redact/testdata" "google.golang.org/protobuf/types/known/timestamppb" ) diff --git a/logging/protoc-gen-go-redact/internal/redact/integration_test.go b/cmd/protoc-gen-go-redact/internal/redact/integration_test.go similarity index 99% rename from logging/protoc-gen-go-redact/internal/redact/integration_test.go rename to cmd/protoc-gen-go-redact/internal/redact/integration_test.go index e7fbff2..9a0dfdb 100644 --- a/logging/protoc-gen-go-redact/internal/redact/integration_test.go +++ b/cmd/protoc-gen-go-redact/internal/redact/integration_test.go @@ -6,7 +6,7 @@ import ( "testing" "time" - "github.com/crypto-zero/go-kit/logging/protoc-gen-go-redact/internal/redact/testdata" + "github.com/crypto-zero/go-kit/cmd/protoc-gen-go-redact/internal/redact/testdata" "google.golang.org/protobuf/encoding/protojson" "google.golang.org/protobuf/types/known/durationpb" "google.golang.org/protobuf/types/known/timestamppb" diff --git a/logging/protoc-gen-go-redact/internal/redact/redact.go b/cmd/protoc-gen-go-redact/internal/redact/redact.go similarity index 100% rename from logging/protoc-gen-go-redact/internal/redact/redact.go rename to cmd/protoc-gen-go-redact/internal/redact/redact.go diff --git a/logging/protoc-gen-go-redact/internal/redact/redact_test.go b/cmd/protoc-gen-go-redact/internal/redact/redact_test.go similarity index 100% rename from logging/protoc-gen-go-redact/internal/redact/redact_test.go rename to cmd/protoc-gen-go-redact/internal/redact/redact_test.go diff --git a/logging/protoc-gen-go-redact/internal/redact/testdata/crossfile/container.pb.go b/cmd/protoc-gen-go-redact/internal/redact/testdata/crossfile/container.pb.go similarity index 98% rename from logging/protoc-gen-go-redact/internal/redact/testdata/crossfile/container.pb.go rename to cmd/protoc-gen-go-redact/internal/redact/testdata/crossfile/container.pb.go index 48d2c7b..7ec6c0d 100644 --- a/logging/protoc-gen-go-redact/internal/redact/testdata/crossfile/container.pb.go +++ b/cmd/protoc-gen-go-redact/internal/redact/testdata/crossfile/container.pb.go @@ -4,7 +4,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: // protoc-gen-go v1.36.11 -// protoc v6.33.2 +// protoc v7.34.1 // source: crossfile/container.proto package crossfile @@ -238,7 +238,7 @@ const file_crossfile_container_proto_rawDesc = "" + "\bdata_map\x18\x01 \x03(\v2-.testdata.crossfile.MapContainer.DataMapEntryR\adataMap\x1a]\n" + "\fDataMapEntry\x12\x10\n" + "\x03key\x18\x01 \x01(\tR\x03key\x127\n" + - "\x05value\x18\x02 \x01(\v2!.testdata.crossfile.SensitiveDataR\x05value:\x028\x01B_Z]github.com/crypto-zero/go-kit/logging/protoc-gen-go-redact/internal/redact/testdata/crossfileb\x06proto3" + "\x05value\x18\x02 \x01(\v2!.testdata.crossfile.SensitiveDataR\x05value:\x028\x01B[ZYgithub.com/crypto-zero/go-kit/cmd/protoc-gen-go-redact/internal/redact/testdata/crossfileb\x06proto3" var ( file_crossfile_container_proto_rawDescOnce sync.Once diff --git a/logging/protoc-gen-go-redact/internal/redact/testdata/crossfile/container.proto b/cmd/protoc-gen-go-redact/internal/redact/testdata/crossfile/container.proto similarity index 88% rename from logging/protoc-gen-go-redact/internal/redact/testdata/crossfile/container.proto rename to cmd/protoc-gen-go-redact/internal/redact/testdata/crossfile/container.proto index b3c96ed..e4e9552 100644 --- a/logging/protoc-gen-go-redact/internal/redact/testdata/crossfile/container.proto +++ b/cmd/protoc-gen-go-redact/internal/redact/testdata/crossfile/container.proto @@ -3,7 +3,7 @@ syntax = "proto3"; package testdata.crossfile; -option go_package = "github.com/crypto-zero/go-kit/logging/protoc-gen-go-redact/internal/redact/testdata/crossfile"; +option go_package = "github.com/crypto-zero/go-kit/cmd/protoc-gen-go-redact/internal/redact/testdata/crossfile"; import "crossfile/sensitive.proto"; diff --git a/logging/protoc-gen-go-redact/internal/redact/testdata/crossfile/container_redact.pb.go b/cmd/protoc-gen-go-redact/internal/redact/testdata/crossfile/container_redact.pb.go similarity index 99% rename from logging/protoc-gen-go-redact/internal/redact/testdata/crossfile/container_redact.pb.go rename to cmd/protoc-gen-go-redact/internal/redact/testdata/crossfile/container_redact.pb.go index e401d68..8d8ce9e 100644 --- a/logging/protoc-gen-go-redact/internal/redact/testdata/crossfile/container_redact.pb.go +++ b/cmd/protoc-gen-go-redact/internal/redact/testdata/crossfile/container_redact.pb.go @@ -1,7 +1,7 @@ // Code generated by protoc-gen-go-redact. DO NOT EDIT. // versions: // - protoc-gen-go-redact v1.1.0 -// - protoc v6.33.2 +// - protoc v7.34.1 // source: crossfile/container.proto package crossfile diff --git a/logging/protoc-gen-go-redact/internal/redact/testdata/crossfile/crossfile_test.go b/cmd/protoc-gen-go-redact/internal/redact/testdata/crossfile/crossfile_test.go similarity index 100% rename from logging/protoc-gen-go-redact/internal/redact/testdata/crossfile/crossfile_test.go rename to cmd/protoc-gen-go-redact/internal/redact/testdata/crossfile/crossfile_test.go diff --git a/logging/protoc-gen-go-redact/internal/redact/testdata/crossfile/sensitive.pb.go b/cmd/protoc-gen-go-redact/internal/redact/testdata/crossfile/sensitive.pb.go similarity index 97% rename from logging/protoc-gen-go-redact/internal/redact/testdata/crossfile/sensitive.pb.go rename to cmd/protoc-gen-go-redact/internal/redact/testdata/crossfile/sensitive.pb.go index 78c8fa6..3334f37 100644 --- a/logging/protoc-gen-go-redact/internal/redact/testdata/crossfile/sensitive.pb.go +++ b/cmd/protoc-gen-go-redact/internal/redact/testdata/crossfile/sensitive.pb.go @@ -3,7 +3,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: // protoc-gen-go v1.36.11 -// protoc v6.33.2 +// protoc v7.34.1 // source: crossfile/sensitive.proto package crossfile @@ -150,7 +150,7 @@ const file_crossfile_sensitive_proto_rawDesc = "" + "\x05phone\x18\x03 \x01(\tB\x0f¢3\v\b\x01\x12\a[PHONE]R\x05phone\"\\\n" + "\x0fNestedSensitive\x12\x12\n" + "\x04name\x18\x01 \x01(\tR\x04name\x125\n" + - "\x04data\x18\x02 \x01(\v2!.testdata.crossfile.SensitiveDataR\x04dataB_Z]github.com/crypto-zero/go-kit/logging/protoc-gen-go-redact/internal/redact/testdata/crossfileb\x06proto3" + "\x04data\x18\x02 \x01(\v2!.testdata.crossfile.SensitiveDataR\x04dataB[ZYgithub.com/crypto-zero/go-kit/cmd/protoc-gen-go-redact/internal/redact/testdata/crossfileb\x06proto3" var ( file_crossfile_sensitive_proto_rawDescOnce sync.Once diff --git a/logging/protoc-gen-go-redact/internal/redact/testdata/crossfile/sensitive.proto b/cmd/protoc-gen-go-redact/internal/redact/testdata/crossfile/sensitive.proto similarity index 83% rename from logging/protoc-gen-go-redact/internal/redact/testdata/crossfile/sensitive.proto rename to cmd/protoc-gen-go-redact/internal/redact/testdata/crossfile/sensitive.proto index 9720665..816ee68 100644 --- a/logging/protoc-gen-go-redact/internal/redact/testdata/crossfile/sensitive.proto +++ b/cmd/protoc-gen-go-redact/internal/redact/testdata/crossfile/sensitive.proto @@ -2,7 +2,7 @@ syntax = "proto3"; package testdata.crossfile; -option go_package = "github.com/crypto-zero/go-kit/logging/protoc-gen-go-redact/internal/redact/testdata/crossfile"; +option go_package = "github.com/crypto-zero/go-kit/cmd/protoc-gen-go-redact/internal/redact/testdata/crossfile"; import "kit/redact/v1/redact.proto"; diff --git a/logging/protoc-gen-go-redact/internal/redact/testdata/crossfile/sensitive_redact.pb.go b/cmd/protoc-gen-go-redact/internal/redact/testdata/crossfile/sensitive_redact.pb.go similarity index 97% rename from logging/protoc-gen-go-redact/internal/redact/testdata/crossfile/sensitive_redact.pb.go rename to cmd/protoc-gen-go-redact/internal/redact/testdata/crossfile/sensitive_redact.pb.go index 27cbca1..52abbfb 100644 --- a/logging/protoc-gen-go-redact/internal/redact/testdata/crossfile/sensitive_redact.pb.go +++ b/cmd/protoc-gen-go-redact/internal/redact/testdata/crossfile/sensitive_redact.pb.go @@ -1,7 +1,7 @@ // Code generated by protoc-gen-go-redact. DO NOT EDIT. // versions: // - protoc-gen-go-redact v1.1.0 -// - protoc v6.33.2 +// - protoc v7.34.1 // source: crossfile/sensitive.proto package crossfile diff --git a/logging/protoc-gen-go-redact/internal/redact/testdata/example.pb.go b/cmd/protoc-gen-go-redact/internal/redact/testdata/example.pb.go similarity index 99% rename from logging/protoc-gen-go-redact/internal/redact/testdata/example.pb.go rename to cmd/protoc-gen-go-redact/internal/redact/testdata/example.pb.go index 0c99200..615f113 100644 --- a/logging/protoc-gen-go-redact/internal/redact/testdata/example.pb.go +++ b/cmd/protoc-gen-go-redact/internal/redact/testdata/example.pb.go @@ -1,7 +1,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: // protoc-gen-go v1.36.11 -// protoc v6.33.2 +// protoc v7.34.1 // source: example.proto package testdata @@ -3924,7 +3924,7 @@ const file_example_proto_rawDesc = "" + "\fPRIORITY_LOW\x10\x01\x12\x13\n" + "\x0fPRIORITY_MEDIUM\x10\x02\x12\x11\n" + "\rPRIORITY_HIGH\x10\x03\x12\x15\n" + - "\x11PRIORITY_CRITICAL\x10\x04BEZCgithub.com/crypto-zero/go-kit/logging/protoc-gen-go-redact/testdatab\x06proto3" + "\x11PRIORITY_CRITICAL\x10\x04BAZ?github.com/crypto-zero/go-kit/cmd/protoc-gen-go-redact/testdatab\x06proto3" var ( file_example_proto_rawDescOnce sync.Once diff --git a/logging/protoc-gen-go-redact/internal/redact/testdata/example.proto b/cmd/protoc-gen-go-redact/internal/redact/testdata/example.proto similarity index 99% rename from logging/protoc-gen-go-redact/internal/redact/testdata/example.proto rename to cmd/protoc-gen-go-redact/internal/redact/testdata/example.proto index 45e9035..b5e25b4 100644 --- a/logging/protoc-gen-go-redact/internal/redact/testdata/example.proto +++ b/cmd/protoc-gen-go-redact/internal/redact/testdata/example.proto @@ -9,7 +9,7 @@ import "google/protobuf/any.proto"; import "google/protobuf/struct.proto"; import "kit/redact/v1/redact.proto"; -option go_package = "github.com/crypto-zero/go-kit/logging/protoc-gen-go-redact/testdata"; +option go_package = "github.com/crypto-zero/go-kit/cmd/protoc-gen-go-redact/testdata"; // ============================================================================ // Enum Definitions diff --git a/logging/protoc-gen-go-redact/internal/redact/testdata/example_redact.pb.go b/cmd/protoc-gen-go-redact/internal/redact/testdata/example_redact.pb.go similarity index 99% rename from logging/protoc-gen-go-redact/internal/redact/testdata/example_redact.pb.go rename to cmd/protoc-gen-go-redact/internal/redact/testdata/example_redact.pb.go index 3451665..1f60ff0 100644 --- a/logging/protoc-gen-go-redact/internal/redact/testdata/example_redact.pb.go +++ b/cmd/protoc-gen-go-redact/internal/redact/testdata/example_redact.pb.go @@ -1,7 +1,7 @@ // Code generated by protoc-gen-go-redact. DO NOT EDIT. // versions: // - protoc-gen-go-redact v1.1.0 -// - protoc v6.33.2 +// - protoc v7.34.1 // source: example.proto package testdata diff --git a/logging/protoc-gen-go-redact/internal/redact/testdata/generate.sh b/cmd/protoc-gen-go-redact/internal/redact/testdata/generate.sh similarity index 100% rename from logging/protoc-gen-go-redact/internal/redact/testdata/generate.sh rename to cmd/protoc-gen-go-redact/internal/redact/testdata/generate.sh diff --git a/logging/protoc-gen-go-redact/internal/redact/version.go b/cmd/protoc-gen-go-redact/internal/redact/version.go similarity index 100% rename from logging/protoc-gen-go-redact/internal/redact/version.go rename to cmd/protoc-gen-go-redact/internal/redact/version.go diff --git a/logging/protoc-gen-go-redact/main.go b/cmd/protoc-gen-go-redact/main.go similarity index 92% rename from logging/protoc-gen-go-redact/main.go rename to cmd/protoc-gen-go-redact/main.go index 51fccf9..0cb428f 100644 --- a/logging/protoc-gen-go-redact/main.go +++ b/cmd/protoc-gen-go-redact/main.go @@ -4,7 +4,7 @@ import ( "flag" "fmt" - "github.com/crypto-zero/go-kit/logging/protoc-gen-go-redact/internal/redact" + "github.com/crypto-zero/go-kit/cmd/protoc-gen-go-redact/internal/redact" "google.golang.org/protobuf/compiler/protogen" "google.golang.org/protobuf/types/pluginpb" ) diff --git a/sse/kratos/protoc-gen-go-sse/main.go b/cmd/protoc-gen-go-sse/main.go similarity index 99% rename from sse/kratos/protoc-gen-go-sse/main.go rename to cmd/protoc-gen-go-sse/main.go index 9d524d2..913e0e3 100644 --- a/sse/kratos/protoc-gen-go-sse/main.go +++ b/cmd/protoc-gen-go-sse/main.go @@ -19,7 +19,7 @@ var ( contextPackage = protogen.GoImportPath("context") khttpPackage = protogen.GoImportPath("github.com/go-kratos/kratos/v2/transport/http") ssePackage = protogen.GoImportPath("github.com/crypto-zero/go-kit/sse") - kratosPackage = protogen.GoImportPath("github.com/crypto-zero/go-kit/sse/kratos") + kratosPackage = protogen.GoImportPath("github.com/crypto-zero/go-kit/kratos/sse") ) func main() { diff --git a/sse/kratos/protoc-gen-go-sse/main_test.go b/cmd/protoc-gen-go-sse/main_test.go similarity index 95% rename from sse/kratos/protoc-gen-go-sse/main_test.go rename to cmd/protoc-gen-go-sse/main_test.go index ae0923d..cbb96a0 100644 --- a/sse/kratos/protoc-gen-go-sse/main_test.go +++ b/cmd/protoc-gen-go-sse/main_test.go @@ -43,8 +43,8 @@ func TestGenerateFileUsesDefaultHTTPStreamBinding(t *testing.T) { `RegisterHTTPStream(s, "POST", "/v1/watch", OperationLiveServiceWatchSSE, srv.Watch, opts...)`, `RegisterHTTPStream(s, "GET", "/v1/watch:tail", OperationLiveServiceWatchSSE, srv.Watch, opts...)`, `type LiveServiceSSEClient interface`, - `func NewLiveServiceSSEClient(client *kratos.HTTPClient) LiveServiceSSEClient`, - `func (c *liveServiceSSEClient) Watch(ctx context.Context, in *WatchRequest, opts ...kratos.HTTPStreamCallOption) (*sse.Reader, error)`, + `func NewLiveServiceSSEClient(client *sse1.HTTPClient) LiveServiceSSEClient`, + `func (c *liveServiceSSEClient) Watch(ctx context.Context, in *WatchRequest, opts ...sse1.HTTPStreamCallOption) (*sse.Reader, error)`, `path := binding.EncodeURL(pattern, in, false)`, `return c.cc.Open(ctx, "POST", path, in, opts...)`, } { diff --git a/errors/protoc-gen-kit-errors/main.go b/cmd/protoc-gen-kit-errors/main.go similarity index 100% rename from errors/protoc-gen-kit-errors/main.go rename to cmd/protoc-gen-kit-errors/main.go diff --git a/go.mod b/go.mod index 3170641..6ce7545 100644 --- a/go.mod +++ b/go.mod @@ -19,6 +19,7 @@ require ( go.uber.org/zap v1.27.1 go.uber.org/zap/exp v0.3.0 golang.org/x/text v0.32.0 + google.golang.org/genproto/googleapis/api v0.0.0-20251202230838-ff82c1b0f217 google.golang.org/genproto/googleapis/rpc v0.0.0-20251213004720-97cd9d5aeac2 google.golang.org/grpc v1.77.0 google.golang.org/protobuf v1.36.11 @@ -71,6 +72,5 @@ require ( golang.org/x/net v0.47.0 // indirect golang.org/x/sync v0.19.0 // indirect golang.org/x/sys v0.39.0 // indirect - google.golang.org/genproto/googleapis/api v0.0.0-20251202230838-ff82c1b0f217 // indirect gopkg.in/yaml.v2 v2.4.0 // indirect ) diff --git a/auth/kratos/policy.go b/kratos/auth/policy.go similarity index 51% rename from auth/kratos/policy.go rename to kratos/auth/policy.go index 7758396..56087db 100644 --- a/auth/kratos/policy.go +++ b/kratos/auth/policy.go @@ -1,12 +1,10 @@ -// Package kratos provides auth helpers for Kratos operation selectors. -package kratos +// Package auth provides auth helpers for Kratos operation selectors. +package auth import ( + "github.com/crypto-zero/go-kit/kratos/internal/protoop" authv1 "github.com/crypto-zero/go-kit/proto/kit/auth/v1" - "google.golang.org/protobuf/encoding/protowire" - "google.golang.org/protobuf/proto" "google.golang.org/protobuf/reflect/protoreflect" - "google.golang.org/protobuf/types/descriptorpb" ) // OperationPolicy reports whether a Kratos operation should run through @@ -55,59 +53,17 @@ func (p *OperationPolicy) RequiresAuth(operation string) bool { // OperationName returns the Kratos operation string for a proto method. func OperationName(m protoreflect.MethodDescriptor) string { - return "/" + string(m.Parent().FullName()) + "/" + string(m.Name()) + return protoop.OperationName(m) } func registerPublicFromFile(p *OperationPolicy, fd protoreflect.FileDescriptor) { - services := fd.Services() - for i := 0; i < services.Len(); i++ { - methods := services.Get(i).Methods() - for j := 0; j < methods.Len(); j++ { - m := methods.Get(j) - if methodIsPublic(m) { - p.public[OperationName(m)] = struct{}{} - } + protoop.WalkMethods([]protoreflect.FileDescriptor{fd}, func(m protoreflect.MethodDescriptor) { + if methodIsPublic(m) { + p.public[protoop.OperationName(m)] = struct{}{} } - } + }) } func methodIsPublic(m protoreflect.MethodDescriptor) bool { - opts, ok := m.Options().(*descriptorpb.MethodOptions) - if !ok || opts == nil { - return false - } - v := proto.GetExtension(opts, authv1.E_Public) - switch public := v.(type) { - case bool: - return public || methodOptionsUnknownBool(opts, authv1.E_Public.TypeDescriptor().Number()) - case *bool: - return (public != nil && *public) || methodOptionsUnknownBool(opts, authv1.E_Public.TypeDescriptor().Number()) - default: - return methodOptionsUnknownBool(opts, authv1.E_Public.TypeDescriptor().Number()) - } -} - -func methodOptionsUnknownBool(opts *descriptorpb.MethodOptions, number protoreflect.FieldNumber) bool { - raw := opts.ProtoReflect().GetUnknown() - for len(raw) > 0 { - num, typ, n := protowire.ConsumeTag(raw) - if n < 0 { - return false - } - raw = raw[n:] - if num != protowire.Number(number) { - n = protowire.ConsumeFieldValue(num, typ, raw) - if n < 0 { - return false - } - raw = raw[n:] - continue - } - if typ != protowire.VarintType { - return false - } - v, n := protowire.ConsumeVarint(raw) - return n >= 0 && v != 0 - } - return false + return protoop.BoolExtension(m, authv1.E_Public) } diff --git a/auth/kratos/policy_test.go b/kratos/auth/policy_test.go similarity index 99% rename from auth/kratos/policy_test.go rename to kratos/auth/policy_test.go index 4d77cf2..80b6a02 100644 --- a/auth/kratos/policy_test.go +++ b/kratos/auth/policy_test.go @@ -1,4 +1,4 @@ -package kratos +package auth import ( "testing" diff --git a/kratos/clientip/clientip.go b/kratos/clientip/clientip.go new file mode 100644 index 0000000..2cab7e3 --- /dev/null +++ b/kratos/clientip/clientip.go @@ -0,0 +1,97 @@ +// Package clientip extracts client IP addresses from Kratos request contexts. +package clientip + +import ( + "context" + "net" + "strings" + + "github.com/go-kratos/kratos/v2/transport" + kratoshttp "github.com/go-kratos/kratos/v2/transport/http" + "google.golang.org/grpc/metadata" + "google.golang.org/grpc/peer" +) + +// FromContext extracts the client IP address from a Kratos server context. +// +// Priority: X-Forwarded-For, X-Real-IP, then remote peer address. Returned +// values are normalized: IPv6 zone IDs are removed and IPv4-mapped IPv6 +// addresses are converted to IPv4. +func FromContext(ctx context.Context) string { + tr, ok := transport.FromServerContext(ctx) + if !ok { + return "" + } + if httpTr, ok := tr.(*kratoshttp.Transport); ok { + return fromHTTP(httpTr) + } + return fromGRPC(ctx) +} + +func fromHTTP(httpTr *kratoshttp.Transport) string { + req := httpTr.Request() + if req == nil { + return "" + } + if ip := extract(req.Header.Get("X-Forwarded-For")); ip != "" { + return ip + } + if ip := extract(req.Header.Get("X-Real-IP")); ip != "" { + return ip + } + host, _, err := net.SplitHostPort(req.RemoteAddr) + if err != nil { + return normalize(req.RemoteAddr) + } + return normalize(host) +} + +func fromGRPC(ctx context.Context) string { + if md, ok := metadata.FromIncomingContext(ctx); ok { + if xff := md.Get("x-forwarded-for"); len(xff) > 0 { + if ip := extract(xff[0]); ip != "" { + return ip + } + } + if xrip := md.Get("x-real-ip"); len(xrip) > 0 { + if ip := extract(xrip[0]); ip != "" { + return ip + } + } + } + if p, ok := peer.FromContext(ctx); ok && p.Addr != nil { + host, _, err := net.SplitHostPort(p.Addr.String()) + if err != nil { + return normalize(p.Addr.String()) + } + return normalize(host) + } + return "" +} + +func extract(headerVal string) string { + if headerVal == "" { + return "" + } + if idx := strings.IndexByte(headerVal, ','); idx != -1 { + headerVal = headerVal[:idx] + } + return normalize(strings.TrimSpace(headerVal)) +} + +func normalize(ip string) string { + if ip == "" { + return "" + } + if idx := strings.IndexByte(ip, '%'); idx != -1 { + ip = ip[:idx] + } + parsed := net.ParseIP(ip) + if parsed == nil { + return "" + } + if ipv4 := parsed.To4(); ipv4 != nil { + return ipv4.String() + } + return parsed.String() +} diff --git a/kratos/clientip/clientip_test.go b/kratos/clientip/clientip_test.go new file mode 100644 index 0000000..e87d557 --- /dev/null +++ b/kratos/clientip/clientip_test.go @@ -0,0 +1,66 @@ +package clientip + +import ( + "context" + "net" + "testing" + + "github.com/go-kratos/kratos/v2/transport" + "google.golang.org/grpc/metadata" + "google.golang.org/grpc/peer" +) + +type mockTransport struct{} + +func (m *mockTransport) Kind() transport.Kind { return transport.KindGRPC } +func (m *mockTransport) Endpoint() string { return "localhost:9000" } +func (m *mockTransport) Operation() string { return "/test.Service/Method" } +func (m *mockTransport) RequestHeader() transport.Header { + return &mockHeader{} +} +func (m *mockTransport) ReplyHeader() transport.Header { + return &mockHeader{} +} + +type mockHeader struct{} + +func (m *mockHeader) Get(string) string { return "" } +func (m *mockHeader) Set(string, string) {} +func (m *mockHeader) Add(string, string) {} +func (m *mockHeader) Keys() []string { return nil } +func (m *mockHeader) Values(string) []string { return nil } + +func TestFromContextUsesForwardedHeader(t *testing.T) { + ctx := transport.NewServerContext(context.Background(), &mockTransport{}) + ctx = metadata.NewIncomingContext(ctx, metadata.Pairs( + "x-forwarded-for", " ::ffff:192.168.1.10, 10.0.0.1", + "x-real-ip", "192.168.1.11", + )) + + if got := FromContext(ctx); got != "192.168.1.10" { + t.Fatalf("FromContext() = %q, want forwarded IP", got) + } +} + +func TestFromContextFallsBackToPeer(t *testing.T) { + ctx := transport.NewServerContext(context.Background(), &mockTransport{}) + ctx = peer.NewContext(ctx, &peer.Peer{Addr: &net.TCPAddr{IP: net.ParseIP("2001:db8::1"), Port: 443}}) + + if got := FromContext(ctx); got != "2001:db8::1" { + t.Fatalf("FromContext() = %q, want peer IP", got) + } +} + +func TestFromContextRequiresServerTransport(t *testing.T) { + ctx := metadata.NewIncomingContext(context.Background(), metadata.Pairs("x-real-ip", "192.168.1.1")) + + if got := FromContext(ctx); got != "" { + t.Fatalf("FromContext() = %q, want empty without server transport", got) + } +} + +func TestNormalizeRejectsInvalidIP(t *testing.T) { + if got := normalize("not-an-ip"); got != "" { + t.Fatalf("normalize() = %q, want empty", got) + } +} diff --git a/sse/kratos/go.mod b/kratos/go.mod similarity index 83% rename from sse/kratos/go.mod rename to kratos/go.mod index 6e48668..a4e4483 100644 --- a/sse/kratos/go.mod +++ b/kratos/go.mod @@ -1,4 +1,4 @@ -module github.com/crypto-zero/go-kit/sse/kratos +module github.com/crypto-zero/go-kit/kratos go 1.26.3 @@ -6,6 +6,7 @@ require ( github.com/crypto-zero/go-kit v0.0.0-20260128101518-0545cf5a3fac github.com/go-kratos/kratos/v2 v2.9.2 google.golang.org/genproto/googleapis/api v0.0.0-20251202230838-ff82c1b0f217 + google.golang.org/grpc v1.78.0 google.golang.org/protobuf v1.36.11 ) @@ -18,8 +19,5 @@ require ( github.com/rogpeppe/go-internal v1.14.1 // indirect golang.org/x/sys v0.39.0 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20251213004720-97cd9d5aeac2 // indirect - google.golang.org/grpc v1.77.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) - -replace github.com/crypto-zero/go-kit => ../.. diff --git a/sse/kratos/go.sum b/kratos/go.sum similarity index 96% rename from sse/kratos/go.sum rename to kratos/go.sum index fd55f3e..c5daace 100644 --- a/sse/kratos/go.sum +++ b/kratos/go.sum @@ -39,8 +39,8 @@ google.golang.org/genproto/googleapis/api v0.0.0-20251202230838-ff82c1b0f217 h1: google.golang.org/genproto/googleapis/api v0.0.0-20251202230838-ff82c1b0f217/go.mod h1:+rXWjjaukWZun3mLfjmVnQi18E1AsFbDN9QdJ5YXLto= google.golang.org/genproto/googleapis/rpc v0.0.0-20251213004720-97cd9d5aeac2 h1:2I6GHUeJ/4shcDpoUlLs/2WPnhg7yJwvXtqcMJt9liA= google.golang.org/genproto/googleapis/rpc v0.0.0-20251213004720-97cd9d5aeac2/go.mod h1:7i2o+ce6H/6BluujYR+kqX3GKH+dChPTQU19wjRPiGk= -google.golang.org/grpc v1.77.0 h1:wVVY6/8cGA6vvffn+wWK5ToddbgdU3d8MNENr4evgXM= -google.golang.org/grpc v1.77.0/go.mod h1:z0BY1iVj0q8E1uSQCjL9cppRj+gnZjzDnzV0dHhrNig= +google.golang.org/grpc v1.78.0 h1:K1XZG/yGDJnzMdd/uZHAkVqJE+xIDOcmdSFZkBUicNc= +google.golang.org/grpc v1.78.0/go.mod h1:I47qjTo4OKbMkjA/aOOwxDIiPSBofUtQUI5EfpWvW7U= google.golang.org/protobuf v1.36.11 h1:fV6ZwhNocDyBLK0dj+fg8ektcVegBBuEolpbTQyBNVE= google.golang.org/protobuf v1.36.11/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= diff --git a/kratos/internal/protoop/protoop.go b/kratos/internal/protoop/protoop.go new file mode 100644 index 0000000..b8a9107 --- /dev/null +++ b/kratos/internal/protoop/protoop.go @@ -0,0 +1,79 @@ +// Package protoop contains helpers for Kratos proto operation descriptors. +package protoop + +import ( + "google.golang.org/protobuf/encoding/protowire" + "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/reflect/protoreflect" + "google.golang.org/protobuf/types/descriptorpb" +) + +// WalkMethods calls fn for every method in files. +func WalkMethods(files []protoreflect.FileDescriptor, fn func(protoreflect.MethodDescriptor)) { + for _, fd := range files { + services := fd.Services() + for i := 0; i < services.Len(); i++ { + methods := services.Get(i).Methods() + for j := 0; j < methods.Len(); j++ { + fn(methods.Get(j)) + } + } + } +} + +// OperationName returns the Kratos operation string for a proto method. +func OperationName(m protoreflect.MethodDescriptor) string { + return "/" + string(m.Parent().FullName()) + "/" + string(m.Name()) +} + +// Extension returns the extension value from a proto method option. +func Extension(m protoreflect.MethodDescriptor, ext protoreflect.ExtensionType) (any, bool) { + opts, ok := m.Options().(*descriptorpb.MethodOptions) + if !ok || opts == nil { + return nil, false + } + return proto.GetExtension(opts, ext), true +} + +// BoolExtension reports whether a boolean extension is set to true. +func BoolExtension(m protoreflect.MethodDescriptor, ext protoreflect.ExtensionType) bool { + v, ok := Extension(m, ext) + if !ok { + return false + } + opts := m.Options().(*descriptorpb.MethodOptions) + number := ext.TypeDescriptor().Number() + switch value := v.(type) { + case bool: + return value || unknownBool(opts, number) + case *bool: + return (value != nil && *value) || unknownBool(opts, number) + default: + return unknownBool(opts, number) + } +} + +func unknownBool(opts *descriptorpb.MethodOptions, number protoreflect.FieldNumber) bool { + raw := opts.ProtoReflect().GetUnknown() + for len(raw) > 0 { + num, typ, n := protowire.ConsumeTag(raw) + if n < 0 { + return false + } + raw = raw[n:] + if num != protowire.Number(number) { + n = protowire.ConsumeFieldValue(num, typ, raw) + if n < 0 { + return false + } + raw = raw[n:] + continue + } + if typ != protowire.VarintType { + return false + } + v, n := protowire.ConsumeVarint(raw) + return n >= 0 && v != 0 + } + return false +} diff --git a/logging/kratos/logging.go b/kratos/logging/logging.go similarity index 69% rename from logging/kratos/logging.go rename to kratos/logging/logging.go index ac60472..192f1b1 100644 --- a/logging/kratos/logging.go +++ b/kratos/logging/logging.go @@ -5,17 +5,16 @@ import ( "encoding/json" "fmt" "log/slog" - "net" "reflect" "strings" "time" "google.golang.org/grpc/codes" "google.golang.org/grpc/metadata" - "google.golang.org/grpc/peer" "google.golang.org/protobuf/encoding/protojson" "google.golang.org/protobuf/proto" + "github.com/crypto-zero/go-kit/kratos/clientip" "github.com/go-kratos/kratos/v2/errors" "github.com/go-kratos/kratos/v2/middleware" "github.com/go-kratos/kratos/v2/transport" @@ -204,113 +203,7 @@ func extractError(err error) (slog.Level, string) { // Supports both HTTP and gRPC transports. // Returns normalized IP address (IPv4-mapped IPv6 converted to IPv4, Zone ID removed). func GetClientIP(ctx context.Context) string { - tr, ok := transport.FromServerContext(ctx) - if !ok { - return "" - } - - // Handle HTTP transport - if httpTr, ok := tr.(*kratoshttp.Transport); ok { - return getClientIPFromHTTP(httpTr) - } - - // Handle gRPC transport - return getClientIPFromGRPC(ctx) -} - -// getClientIPFromHTTP extracts client IP from HTTP request. -func getClientIPFromHTTP(httpTr *kratoshttp.Transport) string { - req := httpTr.Request() - if req == nil { - return "" - } - - if ip := extractIP(req.Header.Get("X-Forwarded-For")); ip != "" { - return ip - } - - if ip := extractIP(req.Header.Get("X-Real-IP")); ip != "" { - return ip - } - - host, _, err := net.SplitHostPort(req.RemoteAddr) - if err != nil { - return normalizeIP(req.RemoteAddr) - } - return normalizeIP(host) -} - -// getClientIPFromGRPC extracts client IP from gRPC context. -func getClientIPFromGRPC(ctx context.Context) string { - // Try to get forwarded headers from gRPC metadata first - if md, ok := metadata.FromIncomingContext(ctx); ok { - // Check X-Forwarded-For - if xff := md.Get("x-forwarded-for"); len(xff) > 0 { - if ip := extractIP(xff[0]); ip != "" { - return ip - } - } - // Check X-Real-IP - if xrip := md.Get("x-real-ip"); len(xrip) > 0 { - if ip := extractIP(xrip[0]); ip != "" { - return ip - } - } - } - - // Fall back to peer address - if p, ok := peer.FromContext(ctx); ok && p.Addr != nil { - host, _, err := net.SplitHostPort(p.Addr.String()) - if err != nil { - return normalizeIP(p.Addr.String()) - } - return normalizeIP(host) - } - - return "" -} - -// extractIP extracts the first valid IP from a header value (e.g., X-Forwarded-For). -// It handles comma-separated values (taking the first one) and applies normalization. -func extractIP(headerVal string) string { - if headerVal == "" { - return "" - } - // X-Forwarded-For can contain multiple IPs, the first one is the client IP. - // Use IndexByte instead of Split to avoid memory allocation. - if idx := strings.IndexByte(headerVal, ','); idx != -1 { - headerVal = headerVal[:idx] - } - return normalizeIP(strings.TrimSpace(headerVal)) -} - -// normalizeIP validates and normalizes an IP address string. -// It performs the following normalizations: -// - Removes IPv6 zone ID (e.g., "fe80::1%eth0" -> "fe80::1") -// - Converts IPv4-mapped IPv6 to IPv4 (e.g., "::ffff:192.168.1.1" -> "192.168.1.1") -// -// Returns empty string if the input is not a valid IP address. -func normalizeIP(ip string) string { - if ip == "" { - return "" - } - - // Remove IPv6 zone ID (e.g., fe80::1%eth0 -> fe80::1) - if idx := strings.IndexByte(ip, '%'); idx != -1 { - ip = ip[:idx] - } - - parsed := net.ParseIP(ip) - if parsed == nil { - return "" - } - - // Convert IPv4-mapped IPv6 to IPv4 (e.g., ::ffff:192.168.1.1 -> 192.168.1.1) - if ipv4 := parsed.To4(); ipv4 != nil { - return ipv4.String() - } - - return parsed.String() + return clientip.FromContext(ctx) } // getClientDevice extracts the client device info (User-Agent or custom header) from the request context. diff --git a/logging/kratos/logging_test.go b/kratos/logging/logging_test.go similarity index 96% rename from logging/kratos/logging_test.go rename to kratos/logging/logging_test.go index 9adf027..6926c63 100644 --- a/logging/kratos/logging_test.go +++ b/kratos/logging/logging_test.go @@ -158,12 +158,12 @@ func TestExtractArgs_PlainStruct(t *testing.T) { func TestExtractArgs_Nil(t *testing.T) { result := extractArgs(nil, false) - str, ok := result.(string) + raw, ok := result.(json.RawMessage) if !ok { - t.Fatalf("Expected string for nil, got %T", result) + t.Fatalf("Expected json.RawMessage for nil, got %T", result) } - if str != "" { - t.Errorf("Expected '', got: %s", str) + if string(raw) != "{}" { + t.Errorf("Expected '{}', got: %s", raw) } } @@ -240,8 +240,8 @@ func TestServer_Success(t *testing.T) { // Check log output logOutput := buf.String() - if !strings.Contains(logOutput, "server request") { - t.Errorf("Expected 'server request' in log, got: %s", logOutput) + if !strings.Contains(logOutput, `"msg":"server"`) { + t.Errorf("Expected server log message, got: %s", logOutput) } if !strings.Contains(logOutput, "/api/v1/test") { t.Errorf("Expected operation in log, got: %s", logOutput) @@ -314,8 +314,8 @@ func TestServer_WithoutTransport(t *testing.T) { } logOutput := buf.String() - if !strings.Contains(logOutput, "server request") { - t.Errorf("Expected 'server request' in log, got: %s", logOutput) + if !strings.Contains(logOutput, `"msg":"server"`) { + t.Errorf("Expected server log message, got: %s", logOutput) } } @@ -350,8 +350,8 @@ func TestClient_Success(t *testing.T) { } logOutput := buf.String() - if !strings.Contains(logOutput, "client request") { - t.Errorf("Expected 'client request' in log, got: %s", logOutput) + if !strings.Contains(logOutput, `"msg":"client"`) { + t.Errorf("Expected client log message, got: %s", logOutput) } if !strings.Contains(logOutput, "/api/external/call") { t.Errorf("Expected operation in log, got: %s", logOutput) diff --git a/kratos/ratelimit/policy.go b/kratos/ratelimit/policy.go new file mode 100644 index 0000000..5c4f30c --- /dev/null +++ b/kratos/ratelimit/policy.go @@ -0,0 +1,147 @@ +package ratelimit + +import ( + "time" + + "github.com/crypto-zero/go-kit/kratos/internal/protoop" + ratelimitv1 "github.com/crypto-zero/go-kit/proto/kit/ratelimit/v1" + coreratelimit "github.com/crypto-zero/go-kit/ratelimit" + "google.golang.org/protobuf/reflect/protoreflect" +) + +// OperationPolicy selects rate-limit behavior for Kratos operations. +type OperationPolicy struct { + operations map[string]operationLimit + store coreratelimit.Store +} + +type operationLimit struct { + limiter Limiter + keyFunc KeyFunc + config coreratelimit.Config +} + +// OperationPolicyOption configures an OperationPolicy. +type OperationPolicyOption func(*OperationPolicy) + +// NewOperationPolicy constructs a policy from proto descriptors and manual +// operation rules. +func NewOperationPolicy(opts ...OperationPolicyOption) *OperationPolicy { + p := &OperationPolicy{operations: make(map[string]operationLimit)} + for _, opt := range opts { + opt(p) + } + return p +} + +// WithStore sets the storage backend used by operation-specific limiters. +func WithStore(store coreratelimit.Store) OperationPolicyOption { + return func(p *OperationPolicy) { + if store != nil { + p.store = store + for operation, rule := range p.operations { + p.operations[operation] = p.build(rule.config, rule.keyFunc) + } + } + } +} + +// WithOperation registers a rate limit for one Kratos operation. +func WithOperation(operation string, cfg coreratelimit.Config, keyFunc KeyFunc) OperationPolicyOption { + return func(p *OperationPolicy) { + p.register(operation, cfg, keyFunc) + } +} + +// WithRateLimitFromProtoFiles scans file descriptors for methods tagged with +// `(kit.ratelimit.v1.rate_limit)`. +func WithRateLimitFromProtoFiles(files ...protoreflect.FileDescriptor) OperationPolicyOption { + return func(p *OperationPolicy) { + for _, fd := range files { + registerRateLimitsFromFile(p, fd) + } + } +} + +func (p *OperationPolicy) lookup(operation string) (Limiter, KeyFunc, bool) { + if p == nil { + return nil, nil, false + } + rule, ok := p.operations[operation] + if !ok { + return nil, nil, false + } + return rule.limiter, rule.keyFunc, true +} + +func (p *OperationPolicy) register(operation string, cfg coreratelimit.Config, keyFunc KeyFunc) { + if operation == "" || keyFunc == nil { + return + } + rule := p.build(cfg, keyFunc) + if rule.limiter == nil { + return + } + p.operations[operation] = rule +} + +func (p *OperationPolicy) build(cfg coreratelimit.Config, keyFunc KeyFunc) operationLimit { + opts := []coreratelimit.Option(nil) + if p.store != nil { + opts = append(opts, coreratelimit.WithStore(p.store)) + } + limiter, err := coreratelimit.New(cfg, opts...) + if err != nil { + return operationLimit{} + } + return operationLimit{ + limiter: limiter, + keyFunc: keyFunc, + config: cfg, + } +} + +func registerRateLimitsFromFile(p *OperationPolicy, fd protoreflect.FileDescriptor) { + protoop.WalkMethods([]protoreflect.FileDescriptor{fd}, func(m protoreflect.MethodDescriptor) { + rule, ok := methodRateLimit(m) + if !ok { + return + } + p.register(protoop.OperationName(m), configFromProto(rule), keyFuncFromProto(rule.GetKey())) + }) +} + +func methodRateLimit(m protoreflect.MethodDescriptor) (*ratelimitv1.RateLimit, bool) { + v, ok := protoop.Extension(m, ratelimitv1.E_RateLimit) + if !ok { + return nil, false + } + rule, ok := v.(*ratelimitv1.RateLimit) + return rule, ok && rule != nil +} + +func configFromProto(rule *ratelimitv1.RateLimit) coreratelimit.Config { + return coreratelimit.Config{ + Rate: int(rule.GetRate()), + Per: durationFromProto(rule), + Burst: int(rule.GetBurst()), + } +} + +func durationFromProto(rule *ratelimitv1.RateLimit) time.Duration { + if rule.GetPer() == nil { + return 0 + } + return rule.GetPer().AsDuration() +} + +func keyFuncFromProto(key ratelimitv1.Key) KeyFunc { + switch key { + case ratelimitv1.Key_KEY_CLIENT_IP: + return ClientIPKey + case ratelimitv1.Key_KEY_OPERATION_CLIENT_IP: + return CompositeKey(OperationKey, ClientIPKey) + default: + return OperationKey + } +} diff --git a/kratos/ratelimit/ratelimit.go b/kratos/ratelimit/ratelimit.go new file mode 100644 index 0000000..c8ec68d --- /dev/null +++ b/kratos/ratelimit/ratelimit.go @@ -0,0 +1,142 @@ +// Package ratelimit provides rate-limit middleware for Kratos services. +package ratelimit + +import ( + "context" + "strconv" + "strings" + + "github.com/crypto-zero/go-kit/kratos/clientip" + "github.com/crypto-zero/go-kit/ratelimit" + "github.com/go-kratos/kratos/v2/errors" + "github.com/go-kratos/kratos/v2/middleware" + "github.com/go-kratos/kratos/v2/transport" +) + +const ( + defaultKey = "global" + reason = "RATELIMIT" +) + +// ErrLimitExceed is returned when a request exceeds its rate limit. +var ErrLimitExceed = errors.New(429, reason, "service unavailable due to rate limit exceeded") + +// Limiter is the behavior required by the middleware. +type Limiter interface { + AllowContext(context.Context, string) (ratelimit.Result, error) +} + +// KeyFunc derives a rate-limit key from a request. +type KeyFunc func(context.Context, any) string + +// Option configures server middleware. +type Option func(*options) + +type options struct { + keyFunc KeyFunc + err *errors.Error + policy *OperationPolicy +} + +// WithKeyFunc sets how requests are grouped into buckets. +func WithKeyFunc(fn KeyFunc) Option { + return func(o *options) { + if fn != nil { + o.keyFunc = fn + } + } +} + +// WithError sets the error returned when a request is rejected. +func WithError(err *errors.Error) Option { + return func(o *options) { + if err != nil { + o.err = err + } + } +} + +// WithOperationPolicy sets per-operation rate-limit rules. +func WithOperationPolicy(policy *OperationPolicy) Option { + return func(o *options) { + o.policy = policy + } +} + +// Server returns a Kratos server middleware using limiter. +func Server(limiter Limiter, opts ...Option) middleware.Middleware { + if limiter == nil { + limiter = ratelimit.NewDefault() + } + o := &options{ + keyFunc: OperationKey, + err: ErrLimitExceed, + } + for _, opt := range opts { + opt(o) + } + return func(handler middleware.Handler) middleware.Handler { + return func(ctx context.Context, req any) (any, error) { + activeLimiter := limiter + keyFunc := o.keyFunc + if opLimiter, opKeyFunc, ok := o.policy.lookup(OperationKey(ctx, req)); ok { + activeLimiter = opLimiter + keyFunc = opKeyFunc + } + key := keyFunc(ctx, req) + res, err := activeLimiter.AllowContext(ctx, key) + if err != nil { + return nil, err + } + if !res.Allowed { + return nil, o.err.WithMetadata(retryMetadata(res)) + } + return handler(ctx, req) + } + } +} + +// OperationKey groups requests by Kratos operation. +func OperationKey(ctx context.Context, _ any) string { + if tr, ok := transport.FromServerContext(ctx); ok && tr.Operation() != "" { + return tr.Operation() + } + return defaultKey +} + +// ClientIPKey groups requests by client IP address. +func ClientIPKey(ctx context.Context, _ any) string { + if ip := clientip.FromContext(ctx); ip != "" { + return ip + } + return defaultKey +} + +// CompositeKey joins multiple key functions into one key. +func CompositeKey(fns ...KeyFunc) KeyFunc { + return func(ctx context.Context, req any) string { + parts := make([]string, 0, len(fns)) + for _, fn := range fns { + if fn == nil { + continue + } + if part := fn(ctx, req); part != "" { + parts = append(parts, part) + } + } + if len(parts) == 0 { + return defaultKey + } + return strings.Join(parts, ":") + } +} + +func retryMetadata(res ratelimit.Result) map[string]string { + md := map[string]string{ + "remaining": strconv.Itoa(res.Remaining), + } + if res.RetryAfter > 0 { + md["retry_after"] = strconv.FormatFloat(res.RetryAfter.Seconds(), 'f', 3, 64) + } + return md +} diff --git a/kratos/ratelimit/ratelimit_test.go b/kratos/ratelimit/ratelimit_test.go new file mode 100644 index 0000000..9b7a2c0 --- /dev/null +++ b/kratos/ratelimit/ratelimit_test.go @@ -0,0 +1,183 @@ +package ratelimit + +import ( + "context" + "errors" + "testing" + "time" + + ratelimitv1 "github.com/crypto-zero/go-kit/proto/kit/ratelimit/v1" + "github.com/crypto-zero/go-kit/ratelimit" + kratoserrors "github.com/go-kratos/kratos/v2/errors" + "github.com/go-kratos/kratos/v2/transport" + "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/reflect/protodesc" + "google.golang.org/protobuf/reflect/protoreflect" + "google.golang.org/protobuf/types/descriptorpb" + "google.golang.org/protobuf/types/known/durationpb" +) + +type mockTransport struct { + operation string +} + +func (m *mockTransport) Kind() transport.Kind { return transport.KindHTTP } +func (m *mockTransport) Endpoint() string { return "localhost:8000" } +func (m *mockTransport) Operation() string { return m.operation } +func (m *mockTransport) RequestHeader() transport.Header { + return &mockHeader{} +} +func (m *mockTransport) ReplyHeader() transport.Header { + return &mockHeader{} +} + +type mockHeader struct{} + +func (m *mockHeader) Get(string) string { return "" } +func (m *mockHeader) Set(string, string) {} +func (m *mockHeader) Add(string, string) {} +func (m *mockHeader) Keys() []string { return nil } +func (m *mockHeader) Values(string) []string { return nil } + +func TestServerRejectsWhenLimitExceeded(t *testing.T) { + limiter, err := ratelimit.New(ratelimit.Config{Rate: 1, Per: time.Second, Burst: 1}) + if err != nil { + t.Fatalf("New: %v", err) + } + wrapped := Server(limiter)(func(context.Context, any) (any, error) { + return "ok", nil + }) + ctx := transport.NewServerContext(context.Background(), &mockTransport{operation: "/svc/Test"}) + + if _, err := wrapped(ctx, nil); err != nil { + t.Fatalf("first request error = %v, want nil", err) + } + _, err = wrapped(ctx, nil) + if err == nil { + t.Fatal("second request error = nil, want rate-limit error") + } + if !errors.Is(err, ErrLimitExceed) { + t.Fatalf("second request error = %v, want ErrLimitExceed", err) + } + se := kratoserrors.FromError(err) + if se.Code != 429 || se.Reason != reason || se.Metadata["retry_after"] == "" { + t.Fatalf("kratos error = %+v, want 429 RATELIMIT with retry_after", se) + } +} + +func TestServerUsesOperationKeyByDefault(t *testing.T) { + limiter, err := ratelimit.New(ratelimit.Config{Rate: 1, Per: time.Second, Burst: 1}) + if err != nil { + t.Fatalf("New: %v", err) + } + wrapped := Server(limiter)(func(context.Context, any) (any, error) { + return "ok", nil + }) + + ctxA := transport.NewServerContext(context.Background(), &mockTransport{operation: "/svc/A"}) + ctxB := transport.NewServerContext(context.Background(), &mockTransport{operation: "/svc/B"}) + if _, err := wrapped(ctxA, nil); err != nil { + t.Fatalf("operation A error = %v, want nil", err) + } + if _, err := wrapped(ctxB, nil); err != nil { + t.Fatalf("operation B error = %v, want nil because it has a separate bucket", err) + } +} + +func TestServerUsesCustomKeyFunc(t *testing.T) { + limiter, err := ratelimit.New(ratelimit.Config{Rate: 1, Per: time.Second, Burst: 1}) + if err != nil { + t.Fatalf("New: %v", err) + } + wrapped := Server(limiter, WithKeyFunc(func(context.Context, any) string { + return "tenant-1" + }))(func(context.Context, any) (any, error) { + return "ok", nil + }) + + if _, err := wrapped(context.Background(), nil); err != nil { + t.Fatalf("first request error = %v, want nil", err) + } + if _, err := wrapped(context.Background(), nil); !errors.Is(err, ErrLimitExceed) { + t.Fatalf("second request error = %v, want ErrLimitExceed", err) + } +} + +func TestCompositeKey(t *testing.T) { + key := CompositeKey( + func(context.Context, any) string { return "/svc/A" }, + func(context.Context, any) string { return "127.0.0.1" }, + )(context.Background(), nil) + + if key != "/svc/A:127.0.0.1" { + t.Fatalf("CompositeKey = %q, want joined key", key) + } +} + +func TestServerUsesProtoOperationPolicy(t *testing.T) { + defaultLimiter, err := ratelimit.New(ratelimit.Config{Rate: 100, Per: time.Second, Burst: 100}) + if err != nil { + t.Fatalf("New default limiter: %v", err) + } + policy := NewOperationPolicy(WithRateLimitFromProtoFiles(rateLimitFile(t))) + wrapped := Server(defaultLimiter, WithOperationPolicy(policy))(func(context.Context, any) (any, error) { + return "ok", nil + }) + fastCtx := transport.NewServerContext(context.Background(), &mockTransport{operation: "/test.limit.v1.LimitService/Fast"}) + slowCtx := transport.NewServerContext(context.Background(), &mockTransport{operation: "/test.limit.v1.LimitService/Slow"}) + + if _, err := wrapped(fastCtx, nil); err != nil { + t.Fatalf("fast first request error = %v, want nil", err) + } + if _, err := wrapped(fastCtx, nil); !errors.Is(err, ErrLimitExceed) { + t.Fatalf("fast second request error = %v, want ErrLimitExceed from proto policy", err) + } + for i := 0; i < 3; i++ { + if _, err := wrapped(slowCtx, nil); err != nil { + t.Fatalf("slow request %d error = %v, want default limiter", i+1, err) + } + } +} + +func rateLimitFile(t *testing.T) protoreflect.FileDescriptor { + t.Helper() + + rateLimitOpts := &descriptorpb.MethodOptions{} + proto.SetExtension(rateLimitOpts, ratelimitv1.E_RateLimit, &ratelimitv1.RateLimit{ + Rate: 1, + Per: durationpb.New(time.Second), + Burst: 1, + Key: ratelimitv1.Key_KEY_OPERATION, + }) + fd, err := protodesc.NewFile(&descriptorpb.FileDescriptorProto{ + Syntax: proto.String("proto3"), + Name: proto.String("test/limit/v1/service.proto"), + Package: proto.String("test.limit.v1"), + Service: []*descriptorpb.ServiceDescriptorProto{{ + Name: proto.String("LimitService"), + Method: []*descriptorpb.MethodDescriptorProto{ + { + Name: proto.String("Fast"), + InputType: proto.String(".test.limit.v1.FastRequest"), + OutputType: proto.String(".test.limit.v1.FastResponse"), + Options: rateLimitOpts, + }, + { + Name: proto.String("Slow"), + InputType: proto.String(".test.limit.v1.SlowRequest"), + OutputType: proto.String(".test.limit.v1.SlowResponse"), + }, + }, + }}, + MessageType: []*descriptorpb.DescriptorProto{ + {Name: proto.String("FastRequest")}, + {Name: proto.String("FastResponse")}, + {Name: proto.String("SlowRequest")}, + {Name: proto.String("SlowResponse")}, + }, + }, nil) + if err != nil { + t.Fatalf("NewFile: %v", err) + } + return fd +} diff --git a/kratos/sse/doc.go b/kratos/sse/doc.go new file mode 100644 index 0000000..a2a12a9 --- /dev/null +++ b/kratos/sse/doc.go @@ -0,0 +1,2 @@ +// Package sse mounts Server-Sent Events endpoints on a Kratos HTTP server. +package sse diff --git a/sse/kratos/http_client.go b/kratos/sse/http_client.go similarity index 98% rename from sse/kratos/http_client.go rename to kratos/sse/http_client.go index af7b22c..8c425b0 100644 --- a/sse/kratos/http_client.go +++ b/kratos/sse/http_client.go @@ -1,4 +1,4 @@ -package kratos +package sse import ( "bytes" @@ -27,7 +27,7 @@ type HTTPClientOption func(*HTTPClient) func WithHTTPClient(client *http.Client) HTTPClientOption { return func(c *HTTPClient) { if client == nil { - panic("sse/kratos: nil HTTP client") + panic("kratos/sse: nil HTTP client") } c.client = client } diff --git a/sse/kratos/http_client_test.go b/kratos/sse/http_client_test.go similarity index 97% rename from sse/kratos/http_client_test.go rename to kratos/sse/http_client_test.go index f40d90e..1e525dc 100644 --- a/sse/kratos/http_client_test.go +++ b/kratos/sse/http_client_test.go @@ -1,4 +1,4 @@ -package kratos_test +package sse_test import ( "context" @@ -8,8 +8,8 @@ import ( "strings" "testing" + ssekratos "github.com/crypto-zero/go-kit/kratos/sse" ksse "github.com/crypto-zero/go-kit/sse" - ssekratos "github.com/crypto-zero/go-kit/sse/kratos" ) func TestHTTPClientOpen(t *testing.T) { diff --git a/sse/kratos/http_handler.go b/kratos/sse/http_handler.go similarity index 99% rename from sse/kratos/http_handler.go rename to kratos/sse/http_handler.go index 5496b8b..7097712 100644 --- a/sse/kratos/http_handler.go +++ b/kratos/sse/http_handler.go @@ -1,4 +1,4 @@ -package kratos +package sse import ( "context" diff --git a/sse/kratos/http_handler_test.go b/kratos/sse/http_handler_test.go similarity index 98% rename from sse/kratos/http_handler_test.go rename to kratos/sse/http_handler_test.go index 7ac8bf5..cf32b90 100644 --- a/sse/kratos/http_handler_test.go +++ b/kratos/sse/http_handler_test.go @@ -1,4 +1,4 @@ -package kratos_test +package sse_test import ( "bufio" @@ -14,8 +14,8 @@ import ( khttp "github.com/go-kratos/kratos/v2/transport/http" "google.golang.org/protobuf/types/known/durationpb" + ksse "github.com/crypto-zero/go-kit/kratos/sse" "github.com/crypto-zero/go-kit/sse" - ksse "github.com/crypto-zero/go-kit/sse/kratos" ) func TestHTTPStreamHandler_BindsProtoQueryAndStreamsOnKratosHTTP(t *testing.T) { diff --git a/logging/kratos/go.mod b/logging/kratos/go.mod deleted file mode 100644 index 39223f1..0000000 --- a/logging/kratos/go.mod +++ /dev/null @@ -1,20 +0,0 @@ -module github.com/crypto-zero/go-kit/logging/kratos - -go 1.26.3 - -require ( - github.com/go-kratos/kratos/v2 v2.9.2 - google.golang.org/grpc v1.78.0 - google.golang.org/protobuf v1.36.11 -) - -require ( - github.com/go-kratos/aegis v0.2.0 // indirect - github.com/go-playground/form/v4 v4.2.0 // indirect - github.com/google/uuid v1.6.0 // indirect - github.com/gorilla/mux v1.8.1 // indirect - github.com/kr/text v0.2.0 // indirect - golang.org/x/sys v0.38.0 // indirect - google.golang.org/genproto/googleapis/rpc v0.0.0-20251029180050-ab9386a59fda // indirect - gopkg.in/yaml.v3 v3.0.1 // indirect -) diff --git a/logging/kratos/go.sum b/logging/kratos/go.sum deleted file mode 100644 index 0229ab0..0000000 --- a/logging/kratos/go.sum +++ /dev/null @@ -1,48 +0,0 @@ -github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= -github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= -github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/go-kratos/aegis v0.2.0 h1:dObzCDWn3XVjUkgxyBp6ZeWtx/do0DPZ7LY3yNSJLUQ= -github.com/go-kratos/aegis v0.2.0/go.mod h1:v0R2m73WgEEYB3XYu6aE2WcMwsZkJ/Rzuf5eVccm7bI= -github.com/go-kratos/kratos/v2 v2.9.2 h1:px8GJQBeLpquDKQWQ9zohEWiLA8n4D/pv7aH3asvUvo= -github.com/go-kratos/kratos/v2 v2.9.2/go.mod h1:Jc7jaeYd4RAPjetun2C+oFAOO7HNMHTT/Z4LxpuEDJM= -github.com/go-playground/assert/v2 v2.0.1 h1:MsBgLAaY856+nPRTKrp3/OZK38U/wa0CcBYNjji3q3A= -github.com/go-playground/assert/v2 v2.0.1/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4= -github.com/go-playground/form/v4 v4.2.0 h1:N1wh+Goz61e6w66vo8vJkQt+uwZSoLz50kZPJWR8eic= -github.com/go-playground/form/v4 v4.2.0/go.mod h1:q1a2BY+AQUUzhl6xA/6hBetay6dEIhMHjgvJiGo6K7U= -github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek= -github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps= -github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= -github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= -github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= -github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= -github.com/gorilla/mux v1.8.1 h1:TuBL49tXwgrFYWhqrNgrUNEY92u81SPhu7sTdzQEiWY= -github.com/gorilla/mux v1.8.1/go.mod h1:AKf9I4AEqPTmMytcMc0KkNouC66V3BtZ4qD5fmWSiMQ= -github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= -github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= -github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= -github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= -github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= -github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/rogpeppe/go-internal v1.11.0 h1:cWPaGQEPrBb5/AsnsZesgZZ9yb1OQ+GOISoDNXVBh4M= -github.com/rogpeppe/go-internal v1.11.0/go.mod h1:ddIwULY96R17DhadqLgMfk9H9tvdUzkipdSkR5nkCZA= -github.com/stretchr/testify v1.8.2 h1:+h33VjcLVPDHtOdpUCuF+7gSuG3yGIftsP1YvFihtJ8= -github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= -golang.org/x/net v0.47.0 h1:Mx+4dIFzqraBXUugkia1OOvlD6LemFo1ALMHjrXDOhY= -golang.org/x/net v0.47.0/go.mod h1:/jNxtkgq5yWUGYkaZGqo27cfGZ1c5Nen03aYrrKpVRU= -golang.org/x/sync v0.18.0 h1:kr88TuHDroi+UVf+0hZnirlk8o8T+4MrK6mr60WkH/I= -golang.org/x/sync v0.18.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= -golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc= -golang.org/x/sys v0.38.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= -golang.org/x/text v0.31.0 h1:aC8ghyu4JhP8VojJ2lEHBnochRno1sgL6nEi9WGFGMM= -golang.org/x/text v0.31.0/go.mod h1:tKRAlv61yKIjGGHX/4tP1LTbc13YSec1pxVEWXzfoeM= -google.golang.org/genproto/googleapis/rpc v0.0.0-20251029180050-ab9386a59fda h1:i/Q+bfisr7gq6feoJnS/DlpdwEL4ihp41fvRiM3Ork0= -google.golang.org/genproto/googleapis/rpc v0.0.0-20251029180050-ab9386a59fda/go.mod h1:7i2o+ce6H/6BluujYR+kqX3GKH+dChPTQU19wjRPiGk= -google.golang.org/grpc v1.78.0 h1:K1XZG/yGDJnzMdd/uZHAkVqJE+xIDOcmdSFZkBUicNc= -google.golang.org/grpc v1.78.0/go.mod h1:I47qjTo4OKbMkjA/aOOwxDIiPSBofUtQUI5EfpWvW7U= -google.golang.org/protobuf v1.36.11 h1:fV6ZwhNocDyBLK0dj+fg8ektcVegBBuEolpbTQyBNVE= -google.golang.org/protobuf v1.36.11/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco= -gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= -gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= -gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= -gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/proto/buf.gen.yaml b/proto/buf.gen.yaml index 44e7c8b..e9a6ce4 100644 --- a/proto/buf.gen.yaml +++ b/proto/buf.gen.yaml @@ -6,13 +6,13 @@ plugins: - local: - go - run - - ../errors/protoc-gen-kit-errors/main.go + - ../cmd/protoc-gen-kit-errors/main.go out: . exclude_types: - kit.errors.v1 - local: - go - run - - ../logging/protoc-gen-go-redact/main.go + - ../cmd/protoc-gen-go-redact/main.go out: . opt: paths=source_relative diff --git a/proto/kit/ratelimit/v1/ratelimit.pb.go b/proto/kit/ratelimit/v1/ratelimit.pb.go new file mode 100644 index 0000000..6a1906a --- /dev/null +++ b/proto/kit/ratelimit/v1/ratelimit.pb.go @@ -0,0 +1,243 @@ +// Code generated by protoc-gen-go. DO NOT EDIT. +// versions: +// protoc-gen-go v1.36.11 +// protoc v7.34.1 +// source: proto/kit/ratelimit/v1/ratelimit.proto + +package ratelimitv1 + +import ( + protoreflect "google.golang.org/protobuf/reflect/protoreflect" + protoimpl "google.golang.org/protobuf/runtime/protoimpl" + descriptorpb "google.golang.org/protobuf/types/descriptorpb" + durationpb "google.golang.org/protobuf/types/known/durationpb" + reflect "reflect" + sync "sync" + unsafe "unsafe" +) + +const ( + // Verify that this generated code is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) + // Verify that runtime/protoimpl is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) +) + +type Key int32 + +const ( + Key_KEY_UNSPECIFIED Key = 0 + Key_KEY_OPERATION Key = 1 + Key_KEY_CLIENT_IP Key = 2 + Key_KEY_OPERATION_CLIENT_IP Key = 3 +) + +// Enum value maps for Key. +var ( + Key_name = map[int32]string{ + 0: "KEY_UNSPECIFIED", + 1: "KEY_OPERATION", + 2: "KEY_CLIENT_IP", + 3: "KEY_OPERATION_CLIENT_IP", + } + Key_value = map[string]int32{ + "KEY_UNSPECIFIED": 0, + "KEY_OPERATION": 1, + "KEY_CLIENT_IP": 2, + "KEY_OPERATION_CLIENT_IP": 3, + } +) + +func (x Key) Enum() *Key { + p := new(Key) + *p = x + return p +} + +func (x Key) String() string { + return protoimpl.X.EnumStringOf(x.Descriptor(), protoreflect.EnumNumber(x)) +} + +func (Key) Descriptor() protoreflect.EnumDescriptor { + return file_proto_kit_ratelimit_v1_ratelimit_proto_enumTypes[0].Descriptor() +} + +func (Key) Type() protoreflect.EnumType { + return &file_proto_kit_ratelimit_v1_ratelimit_proto_enumTypes[0] +} + +func (x Key) Number() protoreflect.EnumNumber { + return protoreflect.EnumNumber(x) +} + +// Deprecated: Use Key.Descriptor instead. +func (Key) EnumDescriptor() ([]byte, []int) { + return file_proto_kit_ratelimit_v1_ratelimit_proto_rawDescGZIP(), []int{0} +} + +type RateLimit struct { + state protoimpl.MessageState `protogen:"open.v1"` + // rate is the number of tokens replenished every per duration. + Rate int32 `protobuf:"varint,1,opt,name=rate,proto3" json:"rate,omitempty"` + // per is the refill window for rate tokens. + Per *durationpb.Duration `protobuf:"bytes,2,opt,name=per,proto3" json:"per,omitempty"` + // burst is the maximum number of tokens a key can accumulate. + Burst int32 `protobuf:"varint,3,opt,name=burst,proto3" json:"burst,omitempty"` + // key selects how requests are grouped into buckets. + Key Key `protobuf:"varint,4,opt,name=key,proto3,enum=kit.ratelimit.v1.Key" json:"key,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *RateLimit) Reset() { + *x = RateLimit{} + mi := &file_proto_kit_ratelimit_v1_ratelimit_proto_msgTypes[0] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *RateLimit) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*RateLimit) ProtoMessage() {} + +func (x *RateLimit) ProtoReflect() protoreflect.Message { + mi := &file_proto_kit_ratelimit_v1_ratelimit_proto_msgTypes[0] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use RateLimit.ProtoReflect.Descriptor instead. +func (*RateLimit) Descriptor() ([]byte, []int) { + return file_proto_kit_ratelimit_v1_ratelimit_proto_rawDescGZIP(), []int{0} +} + +func (x *RateLimit) GetRate() int32 { + if x != nil { + return x.Rate + } + return 0 +} + +func (x *RateLimit) GetPer() *durationpb.Duration { + if x != nil { + return x.Per + } + return nil +} + +func (x *RateLimit) GetBurst() int32 { + if x != nil { + return x.Burst + } + return 0 +} + +func (x *RateLimit) GetKey() Key { + if x != nil { + return x.Key + } + return Key_KEY_UNSPECIFIED +} + +var file_proto_kit_ratelimit_v1_ratelimit_proto_extTypes = []protoimpl.ExtensionInfo{ + { + ExtendedType: (*descriptorpb.MethodOptions)(nil), + ExtensionType: (*RateLimit)(nil), + Field: 106001, + Name: "kit.ratelimit.v1.rate_limit", + Tag: "bytes,106001,opt,name=rate_limit", + Filename: "proto/kit/ratelimit/v1/ratelimit.proto", + }, +} + +// Extension fields to descriptorpb.MethodOptions. +var ( + // rate_limit configures per-method rate limiting for Kratos middleware. + // + // optional kit.ratelimit.v1.RateLimit rate_limit = 106001; + E_RateLimit = &file_proto_kit_ratelimit_v1_ratelimit_proto_extTypes[0] +) + +var File_proto_kit_ratelimit_v1_ratelimit_proto protoreflect.FileDescriptor + +const file_proto_kit_ratelimit_v1_ratelimit_proto_rawDesc = "" + + "\n" + + "&proto/kit/ratelimit/v1/ratelimit.proto\x12\x10kit.ratelimit.v1\x1a google/protobuf/descriptor.proto\x1a\x1egoogle/protobuf/duration.proto\"\x8b\x01\n" + + "\tRateLimit\x12\x12\n" + + "\x04rate\x18\x01 \x01(\x05R\x04rate\x12+\n" + + "\x03per\x18\x02 \x01(\v2\x19.google.protobuf.DurationR\x03per\x12\x14\n" + + "\x05burst\x18\x03 \x01(\x05R\x05burst\x12'\n" + + "\x03key\x18\x04 \x01(\x0e2\x15.kit.ratelimit.v1.KeyR\x03key*]\n" + + "\x03Key\x12\x13\n" + + "\x0fKEY_UNSPECIFIED\x10\x00\x12\x11\n" + + "\rKEY_OPERATION\x10\x01\x12\x11\n" + + "\rKEY_CLIENT_IP\x10\x02\x12\x1b\n" + + "\x17KEY_OPERATION_CLIENT_IP\x10\x03:\\\n" + + "\n" + + "rate_limit\x12\x1e.google.protobuf.MethodOptions\x18\x91\xbc\x06 \x01(\v2\x1b.kit.ratelimit.v1.RateLimitR\trateLimitBBZ@github.com/crypto-zero/go-kit/proto/kit/ratelimit/v1;ratelimitv1b\x06proto3" + +var ( + file_proto_kit_ratelimit_v1_ratelimit_proto_rawDescOnce sync.Once + file_proto_kit_ratelimit_v1_ratelimit_proto_rawDescData []byte +) + +func file_proto_kit_ratelimit_v1_ratelimit_proto_rawDescGZIP() []byte { + file_proto_kit_ratelimit_v1_ratelimit_proto_rawDescOnce.Do(func() { + file_proto_kit_ratelimit_v1_ratelimit_proto_rawDescData = protoimpl.X.CompressGZIP(unsafe.Slice(unsafe.StringData(file_proto_kit_ratelimit_v1_ratelimit_proto_rawDesc), len(file_proto_kit_ratelimit_v1_ratelimit_proto_rawDesc))) + }) + return file_proto_kit_ratelimit_v1_ratelimit_proto_rawDescData +} + +var file_proto_kit_ratelimit_v1_ratelimit_proto_enumTypes = make([]protoimpl.EnumInfo, 1) +var file_proto_kit_ratelimit_v1_ratelimit_proto_msgTypes = make([]protoimpl.MessageInfo, 1) +var file_proto_kit_ratelimit_v1_ratelimit_proto_goTypes = []any{ + (Key)(0), // 0: kit.ratelimit.v1.Key + (*RateLimit)(nil), // 1: kit.ratelimit.v1.RateLimit + (*durationpb.Duration)(nil), // 2: google.protobuf.Duration + (*descriptorpb.MethodOptions)(nil), // 3: google.protobuf.MethodOptions +} +var file_proto_kit_ratelimit_v1_ratelimit_proto_depIdxs = []int32{ + 2, // 0: kit.ratelimit.v1.RateLimit.per:type_name -> google.protobuf.Duration + 0, // 1: kit.ratelimit.v1.RateLimit.key:type_name -> kit.ratelimit.v1.Key + 3, // 2: kit.ratelimit.v1.rate_limit:extendee -> google.protobuf.MethodOptions + 1, // 3: kit.ratelimit.v1.rate_limit:type_name -> kit.ratelimit.v1.RateLimit + 4, // [4:4] is the sub-list for method output_type + 4, // [4:4] is the sub-list for method input_type + 3, // [3:4] is the sub-list for extension type_name + 2, // [2:3] is the sub-list for extension extendee + 0, // [0:2] is the sub-list for field type_name +} + +func init() { file_proto_kit_ratelimit_v1_ratelimit_proto_init() } +func file_proto_kit_ratelimit_v1_ratelimit_proto_init() { + if File_proto_kit_ratelimit_v1_ratelimit_proto != nil { + return + } + type x struct{} + out := protoimpl.TypeBuilder{ + File: protoimpl.DescBuilder{ + GoPackagePath: reflect.TypeOf(x{}).PkgPath(), + RawDescriptor: unsafe.Slice(unsafe.StringData(file_proto_kit_ratelimit_v1_ratelimit_proto_rawDesc), len(file_proto_kit_ratelimit_v1_ratelimit_proto_rawDesc)), + NumEnums: 1, + NumMessages: 1, + NumExtensions: 1, + NumServices: 0, + }, + GoTypes: file_proto_kit_ratelimit_v1_ratelimit_proto_goTypes, + DependencyIndexes: file_proto_kit_ratelimit_v1_ratelimit_proto_depIdxs, + EnumInfos: file_proto_kit_ratelimit_v1_ratelimit_proto_enumTypes, + MessageInfos: file_proto_kit_ratelimit_v1_ratelimit_proto_msgTypes, + ExtensionInfos: file_proto_kit_ratelimit_v1_ratelimit_proto_extTypes, + }.Build() + File_proto_kit_ratelimit_v1_ratelimit_proto = out.File + file_proto_kit_ratelimit_v1_ratelimit_proto_goTypes = nil + file_proto_kit_ratelimit_v1_ratelimit_proto_depIdxs = nil +} diff --git a/proto/kit/ratelimit/v1/ratelimit.proto b/proto/kit/ratelimit/v1/ratelimit.proto new file mode 100644 index 0000000..8b33eef --- /dev/null +++ b/proto/kit/ratelimit/v1/ratelimit.proto @@ -0,0 +1,34 @@ +syntax = "proto3"; + +package kit.ratelimit.v1; + +import "google/protobuf/descriptor.proto"; +import "google/protobuf/duration.proto"; + +option go_package = "github.com/crypto-zero/go-kit/proto/kit/ratelimit/v1;ratelimitv1"; + +extend google.protobuf.MethodOptions { + // rate_limit configures per-method rate limiting for Kratos middleware. + RateLimit rate_limit = 106001; +} + +message RateLimit { + // rate is the number of tokens replenished every per duration. + int32 rate = 1; + + // per is the refill window for rate tokens. + google.protobuf.Duration per = 2; + + // burst is the maximum number of tokens a key can accumulate. + int32 burst = 3; + + // key selects how requests are grouped into buckets. + Key key = 4; +} + +enum Key { + KEY_UNSPECIFIED = 0; + KEY_OPERATION = 1; + KEY_CLIENT_IP = 2; + KEY_OPERATION_CLIENT_IP = 3; +} diff --git a/ratelimit/ratelimit.go b/ratelimit/ratelimit.go new file mode 100644 index 0000000..a1472bf --- /dev/null +++ b/ratelimit/ratelimit.go @@ -0,0 +1,258 @@ +// Package ratelimit provides an in-memory token-bucket rate limiter. +package ratelimit + +import ( + "context" + "errors" + "math" + "sync" + "time" +) + +const defaultKey = "global" + +var ( + // ErrInvalidConfig reports a limiter configuration that cannot be applied. + ErrInvalidConfig = errors.New("invalid ratelimit config") + + // DefaultConfig is suitable for service-level protection. + DefaultConfig = Config{ + Rate: 100, + Per: time.Minute, + Burst: 100, + MaxKeys: 10000, + } +) + +// Limit describes the token-bucket parameters used for one Allow call. +type Limit struct { + // Rate is the number of tokens replenished every Per duration. + Rate int + // Per is the refill window for Rate tokens. + Per time.Duration + // Burst is the maximum number of tokens a key can accumulate. + Burst int +} + +// Config controls token-bucket behavior. +type Config struct { + // Rate is the number of tokens replenished every Per duration. + Rate int + // Per is the refill window for Rate tokens. + Per time.Duration + // Burst is the maximum number of tokens a key can accumulate. + Burst int + // MaxKeys bounds the number of tracked buckets. Zero means unbounded. + MaxKeys int +} + +// Result describes the outcome of an Allow call. +type Result struct { + Allowed bool + Remaining int + RetryAfter time.Duration +} + +// Store persists token-bucket state. +// +// Implementations must apply the operation atomically for key. Distributed +// stores such as Redis should use ctx for cancellation and deadlines. +type Store interface { + Take(ctx context.Context, key string, now time.Time, limit Limit, n int) (Result, error) + Len(ctx context.Context) (int, error) +} + +// Option configures a Limiter. +type Option func(*Limiter) + +// WithNow sets the clock used by the limiter. +func WithNow(now func() time.Time) Option { + return func(l *Limiter) { + if now != nil { + l.now = now + } + } +} + +// WithStore sets the storage backend used by the limiter. +func WithStore(store Store) Option { + return func(l *Limiter) { + if store != nil { + l.store = store + } + } +} + +// Limiter applies token-bucket rate limits independently per key. +type Limiter struct { + limit Limit + store Store + now func() time.Time +} + +// New constructs a Limiter. +func New(cfg Config, opts ...Option) (*Limiter, error) { + if cfg.Rate <= 0 || cfg.Per <= 0 || cfg.Burst <= 0 || cfg.MaxKeys < 0 { + return nil, ErrInvalidConfig + } + l := &Limiter{ + limit: Limit{ + Rate: cfg.Rate, + Per: cfg.Per, + Burst: cfg.Burst, + }, + store: NewMemoryStore(cfg.MaxKeys), + now: time.Now, + } + for _, opt := range opts { + opt(l) + } + return l, nil +} + +// NewDefault constructs a Limiter with DefaultConfig. +func NewDefault(opts ...Option) *Limiter { + l, err := New(DefaultConfig, opts...) + if err != nil { + panic(err) + } + return l +} + +// Allow consumes one token for key if capacity is available. +func (l *Limiter) Allow(key string) Result { + return l.AllowN(key, 1) +} + +// AllowContext consumes one token for key if capacity is available. +func (l *Limiter) AllowContext(ctx context.Context, key string) (Result, error) { + return l.AllowNContext(ctx, key, 1) +} + +// AllowN consumes n tokens for key if capacity is available. +func (l *Limiter) AllowN(key string, n int) Result { + res, _ := l.AllowNContext(context.Background(), key, n) + return res +} + +// AllowNContext consumes n tokens for key if capacity is available. +func (l *Limiter) AllowNContext(ctx context.Context, key string, n int) (Result, error) { + if n <= 0 { + return Result{Allowed: true}, nil + } + if n > l.limit.Burst { + return Result{RetryAfter: retryAfter(l.limit, float64(n))}, nil + } + if key == "" { + key = defaultKey + } + return l.store.Take(ctx, key, l.now(), l.limit, n) +} + +// Len returns the number of tracked buckets when the store supports counting. +func (l *Limiter) Len() int { + n, _ := l.LenContext(context.Background()) + return n +} + +// LenContext returns the number of tracked buckets when the store supports counting. +func (l *Limiter) LenContext(ctx context.Context) (int, error) { + return l.store.Len(ctx) +} + +// MemoryStore stores token buckets in memory. +type MemoryStore struct { + mu sync.Mutex + buckets map[string]*bucket + maxKeys int +} + +type bucket struct { + tokens float64 + seen time.Time +} + +// NewMemoryStore constructs an in-memory Store. +func NewMemoryStore(maxKeys int) *MemoryStore { + return &MemoryStore{ + buckets: make(map[string]*bucket), + maxKeys: maxKeys, + } +} + +// Take consumes n tokens from key if capacity is available. +func (s *MemoryStore) Take(_ context.Context, key string, now time.Time, limit Limit, n int) (Result, error) { + s.mu.Lock() + defer s.mu.Unlock() + + b := s.bucketFor(key, now, limit.Burst) + refill(b, now, limit) + need := float64(n) + if b.tokens < need { + return Result{ + Remaining: int(math.Floor(b.tokens)), + RetryAfter: retryAfter(limit, need-b.tokens), + }, nil + } + b.tokens -= need + return Result{ + Allowed: true, + Remaining: int(math.Floor(b.tokens)), + }, nil +} + +// Len returns the number of tracked buckets. +func (s *MemoryStore) Len(_ context.Context) (int, error) { + s.mu.Lock() + defer s.mu.Unlock() + return len(s.buckets), nil +} + +func (s *MemoryStore) bucketFor(key string, now time.Time, burst int) *bucket { + if b, ok := s.buckets[key]; ok { + return b + } + if s.maxKeys > 0 && len(s.buckets) >= s.maxKeys { + s.evictOldest() + } + b := &bucket{ + tokens: float64(burst), + seen: now, + } + s.buckets[key] = b + return b +} + +func refill(b *bucket, now time.Time, limit Limit) { + elapsed := now.Sub(b.seen) + if elapsed <= 0 { + b.seen = now + return + } + rate := float64(limit.Rate) / limit.Per.Seconds() + b.tokens = math.Min(float64(limit.Burst), b.tokens+elapsed.Seconds()*rate) + b.seen = now +} + +func retryAfter(limit Limit, tokens float64) time.Duration { + perToken := time.Duration(float64(limit.Per) / float64(limit.Rate)) + d := time.Duration(math.Ceil(float64(perToken) * tokens)) + if d < 0 { + return 0 + } + return d +} + +func (s *MemoryStore) evictOldest() { + var ( + oldestKey string + oldest time.Time + ) + for key, b := range s.buckets { + if oldestKey == "" || b.seen.Before(oldest) { + oldestKey = key + oldest = b.seen + } + } + delete(s.buckets, oldestKey) +} diff --git a/ratelimit/ratelimit_test.go b/ratelimit/ratelimit_test.go new file mode 100644 index 0000000..90e53f2 --- /dev/null +++ b/ratelimit/ratelimit_test.go @@ -0,0 +1,132 @@ +package ratelimit + +import ( + "context" + "errors" + "testing" + "time" +) + +type recordingStore struct { + key string + n int + limit Limit +} + +func (s *recordingStore) Take(_ context.Context, key string, _ time.Time, limit Limit, n int) (Result, error) { + s.key = key + s.n = n + s.limit = limit + return Result{Allowed: true, Remaining: limit.Burst - n}, nil +} + +func (s *recordingStore) Len(context.Context) (int, error) { + return 0, nil +} + +func TestLimiterAllowsBurstThenRejects(t *testing.T) { + now := time.Unix(0, 0) + limiter, err := New(Config{Rate: 2, Per: time.Second, Burst: 2}, WithNow(func() time.Time { + return now + })) + if err != nil { + t.Fatalf("New: %v", err) + } + + if res := limiter.Allow("user-1"); !res.Allowed || res.Remaining != 1 { + t.Fatalf("first request = %+v, want allowed with one remaining", res) + } + if res := limiter.Allow("user-1"); !res.Allowed || res.Remaining != 0 { + t.Fatalf("second request = %+v, want allowed with zero remaining", res) + } + if res := limiter.Allow("user-1"); res.Allowed || res.RetryAfter != 500*time.Millisecond { + t.Fatalf("third request = %+v, want rejected with 500ms retry", res) + } +} + +func TestLimiterRefillsByElapsedTime(t *testing.T) { + now := time.Unix(0, 0) + limiter, err := New(Config{Rate: 2, Per: time.Second, Burst: 2}, WithNow(func() time.Time { + return now + })) + if err != nil { + t.Fatalf("New: %v", err) + } + + limiter.Allow("user-1") + limiter.Allow("user-1") + now = now.Add(500 * time.Millisecond) + + if res := limiter.Allow("user-1"); !res.Allowed || res.Remaining != 0 { + t.Fatalf("refilled request = %+v, want allowed with zero remaining", res) + } +} + +func TestLimiterSeparatesKeys(t *testing.T) { + limiter, err := New(Config{Rate: 1, Per: time.Second, Burst: 1}) + if err != nil { + t.Fatalf("New: %v", err) + } + + if res := limiter.Allow("user-1"); !res.Allowed { + t.Fatalf("user-1 first request = %+v, want allowed", res) + } + if res := limiter.Allow("user-1"); res.Allowed { + t.Fatalf("user-1 second request = %+v, want rejected", res) + } + if res := limiter.Allow("user-2"); !res.Allowed { + t.Fatalf("user-2 first request = %+v, want allowed", res) + } +} + +func TestLimiterEvictsOldestWhenMaxKeysReached(t *testing.T) { + now := time.Unix(0, 0) + limiter, err := New(Config{Rate: 1, Per: time.Second, Burst: 1, MaxKeys: 2}, WithNow(func() time.Time { + return now + })) + if err != nil { + t.Fatalf("New: %v", err) + } + + limiter.Allow("a") + now = now.Add(time.Millisecond) + limiter.Allow("b") + now = now.Add(time.Millisecond) + limiter.Allow("c") + + if got := limiter.Len(); got != 2 { + t.Fatalf("Len() = %d, want 2", got) + } + if res := limiter.Allow("a"); !res.Allowed { + t.Fatalf("a should have been evicted and recreated with full burst, got %+v", res) + } +} + +func TestLimiterRejectsInvalidConfig(t *testing.T) { + _, err := New(Config{Rate: 0, Per: time.Second, Burst: 1}) + if !errors.Is(err, ErrInvalidConfig) { + t.Fatalf("New error = %v, want ErrInvalidConfig", err) + } +} + +func TestLimiterUsesStore(t *testing.T) { + store := &recordingStore{} + limiter, err := New(Config{Rate: 5, Per: time.Second, Burst: 10}, WithStore(store)) + if err != nil { + t.Fatalf("New: %v", err) + } + + res, err := limiter.AllowNContext(context.Background(), "tenant-1", 3) + if err != nil { + t.Fatalf("AllowNContext: %v", err) + } + if !res.Allowed || res.Remaining != 7 { + t.Fatalf("result = %+v, want allowed with 7 remaining", res) + } + if store.key != "tenant-1" || store.n != 3 { + t.Fatalf("store saw key=%q n=%d, want tenant-1 and 3", store.key, store.n) + } + if store.limit.Rate != 5 || store.limit.Per != time.Second || store.limit.Burst != 10 { + t.Fatalf("store limit = %+v, want configured limit", store.limit) + } +} diff --git a/sse/kratos/doc.go b/sse/kratos/doc.go deleted file mode 100644 index 9b38ee9..0000000 --- a/sse/kratos/doc.go +++ /dev/null @@ -1,2 +0,0 @@ -// Package kratos mounts Server-Sent Events endpoints on a Kratos HTTP server. -package kratos diff --git a/sse/sse.go b/sse/sse.go index 4846b08..4f03d72 100644 --- a/sse/sse.go +++ b/sse/sse.go @@ -6,7 +6,7 @@ // caller's responsibility. // // For helpers that mount SSE endpoints on a Kratos HTTP server, see the -// sub-package github.com/crypto-zero/go-kit/sse/kratos. +// module package github.com/crypto-zero/go-kit/kratos/sse. package sse import ( From 88b3a71c09e69513bf18b733548af4e9b07df6a8 Mon Sep 17 00:00:00 2001 From: skinny <39215611+crazyskinny@users.noreply.github.com> Date: Thu, 28 May 2026 17:15:38 +0800 Subject: [PATCH 10/11] refactor redis ratelimit store --- .gitignore | 3 +- Taskfile.yml | 56 +++ proto/buf.lock => buf.lock | 0 proto/buf.yaml => buf.yaml | 6 +- errors/client.go | 4 +- errors/errors.go | 26 +- errors/errors_test.go | 24 ++ errors/status.go | 6 +- go.work | 8 + kratos/go.mod | 5 + kratos/go.sum | 8 + kratos/ratelimit/example/README.md | 101 ++++++ kratos/ratelimit/example/main.go | 269 ++++++++++++++ kratos/ratelimit/example/main_test.go | 62 ++++ kratos/ratelimit/inmem_store_test.go | 126 +++++++ kratos/ratelimit/limiter.go | 133 +++++++ kratos/ratelimit/limiter_test.go | 153 ++++++++ kratos/ratelimit/policy.go | 271 +++++++++----- kratos/ratelimit/ratelimit.go | 198 ++++++++--- kratos/ratelimit/ratelimit_test.go | 473 ++++++++++++++++++++----- kratos/ratelimit/redis/store.go | 339 ++++++++++++++++++ kratos/ratelimit/redis/store_test.go | 299 ++++++++++++++++ kubernetes/election/election.go | 84 +++-- kubernetes/election/election_test.go | 27 ++ kubernetes/kubernetes.go | 9 +- lifecycle/loop_runner.go | 206 +++++++++++ lifecycle/loop_runner_test.go | 153 ++++++++ maxmind/maxmind.go | 34 +- maxmind/maxmind_test.go | 11 +- otel/resouce.go | 8 +- otel/trace_provider.go | 44 ++- pgx/pgx.go | 47 +-- pgx/pgx_test.go | 44 ++- pprof/pprof.go | 42 ++- proto/kit/ratelimit/v1/ratelimit.pb.go | 243 ------------- proto/kit/ratelimit/v1/ratelimit.proto | 34 -- query/convert.go | 2 +- query/paging.go | 6 +- ratelimit/ratelimit.go | 258 -------------- ratelimit/ratelimit_test.go | 132 ------- s3/event.go | 18 +- s3/go.sum | 22 -- s3/s3.go | 157 +++++--- s3/s3_minio_test.go | 7 + snowflake/snowflake.go | 9 +- text/text.go | 25 +- 46 files changed, 3025 insertions(+), 1167 deletions(-) create mode 100644 Taskfile.yml rename proto/buf.lock => buf.lock (100%) rename proto/buf.yaml => buf.yaml (70%) create mode 100644 go.work create mode 100644 kratos/ratelimit/example/README.md create mode 100644 kratos/ratelimit/example/main.go create mode 100644 kratos/ratelimit/example/main_test.go create mode 100644 kratos/ratelimit/inmem_store_test.go create mode 100644 kratos/ratelimit/limiter.go create mode 100644 kratos/ratelimit/limiter_test.go create mode 100644 kratos/ratelimit/redis/store.go create mode 100644 kratos/ratelimit/redis/store_test.go create mode 100644 kubernetes/election/election_test.go create mode 100644 lifecycle/loop_runner.go create mode 100644 lifecycle/loop_runner_test.go delete mode 100644 proto/kit/ratelimit/v1/ratelimit.pb.go delete mode 100644 proto/kit/ratelimit/v1/ratelimit.proto delete mode 100644 ratelimit/ratelimit.go delete mode 100644 ratelimit/ratelimit_test.go diff --git a/.gitignore b/.gitignore index 06ef959..a432141 100644 --- a/.gitignore +++ b/.gitignore @@ -10,8 +10,7 @@ # Dependency directories (remove the comment below to include it) vendor/ -# Go workspace file -go.work +# Go workspace checksum file go.work.sum # Compiled Object files, Static and Dynamic libs (Shared Objects) diff --git a/Taskfile.yml b/Taskfile.yml new file mode 100644 index 0000000..fcb1764 --- /dev/null +++ b/Taskfile.yml @@ -0,0 +1,56 @@ +version: "3" + +vars: + GO_MODULES: + sh: find . -name go.mod -not -path './.git/*' -exec dirname {} \; | sort + +tasks: + default: + desc: Run all repository verification checks. + cmds: + - task: all + + all: + desc: Run all repository verification checks. + deps: + - test + - vet + - proto + + test: + desc: Run tests for all Go modules. + cmds: + - for: { var: GO_MODULES } + task: test:module + vars: + MODULE: "{{.ITEM}}" + + test:module: + internal: true + dir: "{{.MODULE}}" + cmds: + - go test ./... + + vet: + desc: Run go vet for all Go modules. + cmds: + - for: { var: GO_MODULES } + task: vet:module + vars: + MODULE: "{{.ITEM}}" + + vet:module: + internal: true + dir: "{{.MODULE}}" + cmds: + - go vet ./... + + proto: + desc: Lint proto files with buf. + cmds: + - buf lint + + verify: + desc: Run all repository verification checks. + cmds: + - task: all diff --git a/proto/buf.lock b/buf.lock similarity index 100% rename from proto/buf.lock rename to buf.lock diff --git a/proto/buf.yaml b/buf.yaml similarity index 70% rename from proto/buf.yaml rename to buf.yaml index c453ea7..442e551 100644 --- a/proto/buf.yaml +++ b/buf.yaml @@ -1,5 +1,7 @@ version: v2 -name: buf.build/crypto-zero/go-kit +modules: + - path: proto + name: buf.build/crypto-zero/go-kit deps: - buf.build/googleapis/googleapis:main lint: @@ -10,4 +12,4 @@ lint: breaking: except: - EXTENSION_NO_DELETE - - FIELD_SAME_DEFAULT \ No newline at end of file + - FIELD_SAME_DEFAULT diff --git a/errors/client.go b/errors/client.go index f2b30c5..c70678a 100644 --- a/errors/client.go +++ b/errors/client.go @@ -22,10 +22,10 @@ func HttpServerErrorEncoder( _, _ = w.Write(body) } -// RPCHandler is a rpc handler for grpc/http client. +// RPCHandler is an RPC handler for gRPC/HTTP clients. type RPCHandler func(ctx context.Context, req any) (any, error) -// RPCClientErrorParser is a rpc client error parser. +// RPCClientErrorParser converts client errors into this package's Error type. func RPCClientErrorParser(handler RPCHandler) RPCHandler { return func(ctx context.Context, req any) (any, error) { reply, err := handler(ctx, req) diff --git a/errors/errors.go b/errors/errors.go index 50c7e50..d51bda7 100644 --- a/errors/errors.go +++ b/errors/errors.go @@ -18,9 +18,10 @@ import ( type PBError = pberrors.Error +// Error is the concrete error type used by this package. type Error PBError -// Error return text message and http status code. +// Error returns a text representation of e. func (e *Error) Error() string { return fmt.Sprintf("error: code = %d reason = %s message = %s", e.Status, e.Info.Reason, e.Message) } @@ -33,7 +34,7 @@ func (e *Error) Is(err error) bool { return false } -// GRPCStatus returns the Status represented by se. +// GRPCStatus returns the gRPC status represented by e. func (e *Error) GRPCStatus() *status.Status { s := &spb.Status{Code: int32(ToGRPCCode(int(e.Status))), Message: e.Message} if codes.Code(s.Code) == codes.OK { @@ -48,13 +49,13 @@ func (e *Error) GRPCStatus() *status.Status { return status.FromProto(s) } -// MarshalJSON marshals se to JSON. +// MarshalJSON marshals e to JSON. func (e *Error) MarshalJSON() ([]byte, error) { pbErr := (*PBError)(e) return protojson.Marshal(pbErr) } -// Clone returns a deep copy of se. +// Clone returns a deep copy of e. func (e *Error) Clone() *Error { if e == nil { return nil @@ -64,14 +65,17 @@ func (e *Error) Clone() *Error { return (*Error)(pbErr) } -// SetMetadata set metadata for error info. +// SetMetadata returns a copy of e with an ErrorInfo metadata value set. func (e *Error) SetMetadata(key, value string) *Error { copied := e.Clone() + if copied.Info.Metadata == nil { + copied.Info.Metadata = make(map[string]string) + } copied.Info.Metadata[key] = value return copied } -// SetCause set cause for error info. +// SetCause returns a copy of e with err recorded as ErrorInfo metadata. func (e *Error) SetCause(err error) *Error { if err == nil { return e @@ -79,7 +83,7 @@ func (e *Error) SetCause(err error) *Error { return e.SetMetadata("cause", err.Error()) } -// SetDomainAndCode set domain and code for info without clone. +// SetDomainAndCode sets the ErrorInfo domain and numeric code in place. func (e *Error) SetDomainAndCode(domain string, code int) *Error { if e.Info.Metadata == nil { e.Info.Metadata = make(map[string]string) @@ -96,7 +100,7 @@ const ( UnknownReason = "" ) -// New returns an error object for the code, message. +// New returns an error object for code, reason, and message. func New(code int, reason, message string) *Error { return &Error{ Status: int32(code), @@ -105,17 +109,17 @@ func New(code int, reason, message string) *Error { } } -// Newf New(code fmt.Sprintf(format, a...)) +// Newf returns an error object with a formatted message. func Newf(code int, reason, format string, a ...any) *Error { return New(code, reason, fmt.Sprintf(format, a...)) } -// Errorf returns an error object for the code, message and error info. +// Errorf returns an error object for code, reason, and a formatted message. func Errorf(code int, reason, format string, a ...any) error { return New(code, reason, fmt.Sprintf(format, a...)) } -// Code returns the http code for an error. +// Code returns the HTTP status code for an error. // It supports wrapped errors. func Code(err error) int { if err == nil { diff --git a/errors/errors_test.go b/errors/errors_test.go index f80d746..dd18086 100644 --- a/errors/errors_test.go +++ b/errors/errors_test.go @@ -1,6 +1,7 @@ package errors import ( + stderrors "errors" "reflect" "testing" @@ -54,3 +55,26 @@ func TestError_Clone(t *testing.T) { }) } } + +func TestError_SetMetadataInitializesMap(t *testing.T) { + err := New(400, "bad_request", "bad request") + + got := err.SetMetadata("key", "value") + + if got.Info.Metadata["key"] != "value" { + t.Fatalf("SetMetadata() metadata = %v; want key=value", got.Info.Metadata) + } + if err.Info.Metadata != nil { + t.Fatalf("SetMetadata() mutated original metadata = %v; want nil", err.Info.Metadata) + } +} + +func TestError_SetCauseInitializesMap(t *testing.T) { + err := New(400, "bad_request", "bad request") + + got := err.SetCause(stderrors.New("root cause")) + + if got.Info.Metadata["cause"] != "root cause" { + t.Fatalf("SetCause() metadata = %v; want cause=root cause", got.Info.Metadata) + } +} diff --git a/errors/status.go b/errors/status.go index a9d91c7..e43bd4c 100644 --- a/errors/status.go +++ b/errors/status.go @@ -7,8 +7,8 @@ import ( ) const ( - // HttpCodeClientClosed is non-standard http status code, - // which defined by nginx. + // HttpCodeClientClosed is the non-standard HTTP status code used by nginx + // when a client closes the request. // https://httpstatus.in/499/ HttpCodeClientClosed = 499 ) @@ -24,7 +24,7 @@ type Converter interface { type statusConverter struct{} -// DefaultConverter default converter. +// DefaultConverter is the default status converter. var DefaultConverter Converter = statusConverter{} // ToGRPCCode converts an HTTP error code into the corresponding gRPC response status. diff --git a/go.work b/go.work new file mode 100644 index 0000000..7e30e87 --- /dev/null +++ b/go.work @@ -0,0 +1,8 @@ +go 1.26.3 + +use ( + . + ./kratos + ./kubernetes/election + ./s3 +) diff --git a/kratos/go.mod b/kratos/go.mod index a4e4483..6c12dbd 100644 --- a/kratos/go.mod +++ b/kratos/go.mod @@ -2,21 +2,26 @@ module github.com/crypto-zero/go-kit/kratos go 1.26.3 +replace github.com/crypto-zero/go-kit => .. + require ( github.com/crypto-zero/go-kit v0.0.0-20260128101518-0545cf5a3fac github.com/go-kratos/kratos/v2 v2.9.2 + github.com/redis/go-redis/v9 v9.19.0 google.golang.org/genproto/googleapis/api v0.0.0-20251202230838-ff82c1b0f217 google.golang.org/grpc v1.78.0 google.golang.org/protobuf v1.36.11 ) require ( + github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/go-kratos/aegis v0.2.0 // indirect github.com/go-playground/form/v4 v4.2.0 // indirect github.com/google/uuid v1.6.0 // indirect github.com/gorilla/mux v1.8.1 // indirect github.com/kr/text v0.2.0 // indirect github.com/rogpeppe/go-internal v1.14.1 // indirect + go.uber.org/atomic v1.11.0 // indirect golang.org/x/sys v0.39.0 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20251213004720-97cd9d5aeac2 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect diff --git a/kratos/go.sum b/kratos/go.sum index c5daace..e380f83 100644 --- a/kratos/go.sum +++ b/kratos/go.sum @@ -1,4 +1,8 @@ +github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= +github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= +github.com/crypto-zero/go-kit v0.0.0-20260128101518-0545cf5a3fac h1:mAUwaMQHRclziH/e4xbuNEHyGrr5DBalMqFtw8wu2uw= +github.com/crypto-zero/go-kit v0.0.0-20260128101518-0545cf5a3fac/go.mod h1:2V6ihklJZoXNZzlmeX+sk5RMYXcwJwq5kIrubxbsUjo= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/go-kratos/aegis v0.2.0 h1:dObzCDWn3XVjUkgxyBp6ZeWtx/do0DPZ7LY3yNSJLUQ= @@ -23,10 +27,14 @@ github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/redis/go-redis/v9 v9.19.0 h1:XPVaaPSnG6RhYf7p+rmSa9zZfeVAnWsH5h3lxthOm/k= +github.com/redis/go-redis/v9 v9.19.0/go.mod h1:v/M13XI1PVCDcm01VtPFOADfZtHf8YW3baQf57KlIkA= github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ= github.com/rogpeppe/go-internal v1.14.1/go.mod h1:MaRKkUm5W0goXpeCfT7UZI6fk/L7L7so1lCWt35ZSgc= github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= +go.uber.org/atomic v1.11.0 h1:ZvwS0R+56ePWxUNi+Atn9dWONBPp/AUETXlHW0DxSjE= +go.uber.org/atomic v1.11.0/go.mod h1:LUxbIzbOniOlMKjJjyPfpl4v+PKK2cNJn91OQbhoJI0= golang.org/x/net v0.47.0 h1:Mx+4dIFzqraBXUugkia1OOvlD6LemFo1ALMHjrXDOhY= golang.org/x/net v0.47.0/go.mod h1:/jNxtkgq5yWUGYkaZGqo27cfGZ1c5Nen03aYrrKpVRU= golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4= diff --git a/kratos/ratelimit/example/README.md b/kratos/ratelimit/example/README.md new file mode 100644 index 0000000..8d78eaf --- /dev/null +++ b/kratos/ratelimit/example/README.md @@ -0,0 +1,101 @@ +# Kratos Rate Limit Example + +This example shows operation-level rate limiting driven only by external config. + +Kratos already provides the current operation name at request time. The +middleware uses that operation as the namespace, then applies the configured +business key parts: + +- `CreateOrder`: limit by `user_id`, and separately by `client_ip` +- `GetOrder`: limit by `user_id` + +In a real Kratos service, define this config shape in that service's own +`internal/conf/conf.proto`, then load it from YAML and convert it to +`ratelimit.OperationRules`. The config proto should live with the service +because operations and business dimensions are service-owned. + +```proto +message Bootstrap { + Server server = 1; + Data data = 2; + RateLimit rate_limit = 3; +} + +message RateLimit { + map operations = 1; + + message Operation { + repeated Rule rules = 1; + } + + message Rule { + repeated string key_parts = 1; + int32 rate = 2; + google.protobuf.Duration per = 3; + int32 burst = 4; + } +} +``` + +```yaml +rate_limit: + operations: + /ratelimit.example.order.v1.OrderService/CreateOrder: + rules: + - key_parts: [user_id] + rate: 10 + per: 60s + burst: 10 + - key_parts: [client_ip] + rate: 30 + per: 60s + burst: 30 + /ratelimit.example.order.v1.OrderService/GetOrder: + rules: + - key_parts: [user_id] + rate: 100 + per: 60s + burst: 100 +``` + +Convert the generated service config before wiring the middleware: + +```go +operationRules, err := convertRateLimitConfig(conf.RateLimit) +if err != nil { + return err +} +``` + +The server wires external rules into the middleware: + +```go +mw, err := ratelimit.Server( + ratelimit.WithRuleStore(store), + ratelimit.WithOperationRules(operationRules), + ratelimit.WithClientIPKeyFunc(ratelimit.ClientIPKey), + ratelimit.WithUserKeyFunc(userIDFromHeader), +) +if err != nil { + return err +} +``` + +Run Redis locally, then start the example: + +```bash +go run ./kratos/ratelimit/example +``` + +Try the limited endpoints: + +```bash +curl -H 'X-User-ID: user_123' \ + -H 'X-Real-IP: 127.0.0.1' \ + -H 'Content-Type: application/json' \ + -d '{"sku":"book"}' \ + http://127.0.0.1:8000/v1/orders + +curl -H 'X-User-ID: user_123' \ + http://127.0.0.1:8000/v1/orders/order_for_book +``` diff --git a/kratos/ratelimit/example/main.go b/kratos/ratelimit/example/main.go new file mode 100644 index 0000000..9a92955 --- /dev/null +++ b/kratos/ratelimit/example/main.go @@ -0,0 +1,269 @@ +package main + +import ( + "context" + "fmt" + "log" + "net/http" + "os" + "os/signal" + "syscall" + "time" + + "github.com/crypto-zero/go-kit/kratos/ratelimit" + redisstore "github.com/crypto-zero/go-kit/kratos/ratelimit/redis" + khttp "github.com/go-kratos/kratos/v2/transport/http" + goredis "github.com/redis/go-redis/v9" + "google.golang.org/protobuf/types/known/durationpb" +) + +const ( + createOrderOperation = "/ratelimit.example.order.v1.OrderService/CreateOrder" + getOrderOperation = "/ratelimit.example.order.v1.OrderService/GetOrder" +) + +type config struct { + HTTPAddr string + RedisAddr string + RedisPrefix string + + RateLimit *rateLimitConfig +} + +// These config structs mirror the shape a business service would define in +// internal/conf/conf.proto and receive from Kratos config loading. +type rateLimitConfig struct { + Operations map[string]*rateLimitOperation +} + +type rateLimitOperation struct { + Rules []*rateLimitRule +} + +type rateLimitRule struct { + KeyParts []string + Rate int32 + Per *durationpb.Duration + Burst int32 +} + +func main() { + cfg := config{ + HTTPAddr: ":8000", + RedisAddr: "127.0.0.1:6379", + RedisPrefix: "ratelimit-example:order-api", + RateLimit: &rateLimitConfig{ + Operations: map[string]*rateLimitOperation{ + createOrderOperation: { + Rules: []*rateLimitRule{ + { + KeyParts: []string{"user_id"}, + Rate: 10, + Per: durationpb.New(time.Minute), + Burst: 10, + }, + { + KeyParts: []string{"client_ip"}, + Rate: 30, + Per: durationpb.New(time.Minute), + Burst: 30, + }, + }, + }, + getOrderOperation: { + Rules: []*rateLimitRule{ + { + KeyParts: []string{"user_id"}, + Rate: 100, + Per: durationpb.New(time.Minute), + Burst: 100, + }, + }, + }, + }, + }, + } + + ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM) + defer stop() + + if err := run(ctx, cfg); err != nil { + log.Fatal(err) + } +} + +func run(ctx context.Context, cfg config) error { + redisClient := goredis.NewClient(&goredis.Options{Addr: cfg.RedisAddr}) + defer redisClient.Close() + + store, err := redisstore.NewStore(redisClient, cfg.RedisPrefix) + if err != nil { + return err + } + operationRules, err := convertRateLimitConfig(cfg.RateLimit) + if err != nil { + return err + } + mw, err := ratelimit.Server( + ratelimit.WithRuleStore(store), + ratelimit.WithOperationRules(operationRules), + ratelimit.WithClientIPKeyFunc(ratelimit.ClientIPKey), + ratelimit.WithUserKeyFunc(userIDFromHeader), + ) + if err != nil { + return err + } + srv := khttp.NewServer( + khttp.Address(cfg.HTTPAddr), + khttp.Middleware(mw), + ) + registerOrderHTTPServer(srv) + + errc := make(chan error, 1) + go func() { errc <- srv.Start(ctx) }() + + select { + case <-ctx.Done(): + stopCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + return srv.Stop(stopCtx) + case err := <-errc: + return err + } +} + +func convertRateLimitConfig(cfg *rateLimitConfig) (ratelimit.OperationRules, error) { + if cfg == nil || len(cfg.Operations) == 0 { + return nil, nil + } + rules := make(ratelimit.OperationRules, len(cfg.Operations)) + for operation, op := range cfg.Operations { + if operation == "" { + return nil, fmt.Errorf("rate_limit operation must not be empty") + } + if op == nil || len(op.Rules) == 0 { + return nil, fmt.Errorf("%s: rate_limit rules must not be empty", operation) + } + for i, rule := range op.Rules { + converted, err := convertRateLimitRule(rule) + if err != nil { + return nil, fmt.Errorf("%s rules[%d]: %w", operation, i, err) + } + rules[operation] = append(rules[operation], converted) + } + } + return rules, nil +} + +func convertRateLimitRule(rule *rateLimitRule) (ratelimit.RuleConfig, error) { + if rule == nil { + return ratelimit.RuleConfig{}, fmt.Errorf("rule must not be nil") + } + if rule.Rate <= 0 || rule.Per == nil || rule.Per.AsDuration() <= 0 || rule.Burst <= 0 { + return ratelimit.RuleConfig{}, fmt.Errorf("rate, per, and burst must be explicitly positive") + } + keyParts, err := convertKeyParts(rule.KeyParts) + if err != nil { + return ratelimit.RuleConfig{}, err + } + return ratelimit.RuleConfig{ + Config: ratelimit.Config{ + Rate: int(rule.Rate), + Per: rule.Per.AsDuration(), + Burst: int(rule.Burst), + }, + KeyParts: keyParts, + }, nil +} + +func convertKeyParts(parts []string) ([]ratelimit.KeyPart, error) { + if len(parts) == 0 { + return nil, fmt.Errorf("key_parts must not be empty") + } + keyParts := make([]ratelimit.KeyPart, 0, len(parts)) + for _, part := range parts { + kp, err := ratelimit.ParseKeyPart(part) + if err != nil { + return nil, err + } + keyParts = append(keyParts, kp) + } + return keyParts, nil +} + +func userIDFromHeader(ctx context.Context, _ any) string { + if req, ok := khttp.RequestFromServerContext(ctx); ok { + return req.Header.Get("X-User-ID") + } + return "" +} + +func registerOrderHTTPServer(srv *khttp.Server) { + route := srv.Route("/") + route.POST("/v1/orders", createOrder) + route.GET("/v1/orders/{order_id}", getOrder) +} + +func createOrder(ctx khttp.Context) error { + khttp.SetOperation(ctx, createOrderOperation) + + req := new(createOrderRequest) + if err := ctx.Bind(req); err != nil { + return err + } + handler := ctx.Middleware(func(context.Context, any) (any, error) { + return &createOrderResponse{ + OrderId: fmt.Sprintf("order_for_%s", req.GetSku()), + }, nil + }) + return ctx.Returns(handler(ctx, req)) +} + +func getOrder(ctx khttp.Context) error { + khttp.SetOperation(ctx, getOrderOperation) + + req := new(getOrderRequest) + if err := ctx.BindVars(req); err != nil { + return err + } + handler := ctx.Middleware(func(context.Context, any) (any, error) { + return &getOrderResponse{ + OrderId: req.GetOrderId(), + Status: "created", + }, nil + }) + return ctx.Returns(handler(ctx, req)) +} + +var _ http.Handler = (*khttp.Server)(nil) + +type createOrderRequest struct { + Sku string `json:"sku"` +} + +func (r *createOrderRequest) GetSku() string { + if r == nil { + return "" + } + return r.Sku +} + +type createOrderResponse struct { + OrderId string `json:"order_id"` +} + +type getOrderRequest struct { + OrderId string `json:"order_id" form:"order_id"` +} + +func (r *getOrderRequest) GetOrderId() string { + if r == nil { + return "" + } + return r.OrderId +} + +type getOrderResponse struct { + OrderId string `json:"order_id"` + Status string `json:"status"` +} diff --git a/kratos/ratelimit/example/main_test.go b/kratos/ratelimit/example/main_test.go new file mode 100644 index 0000000..8b88746 --- /dev/null +++ b/kratos/ratelimit/example/main_test.go @@ -0,0 +1,62 @@ +package main + +import ( + "testing" + "time" + + "github.com/crypto-zero/go-kit/kratos/ratelimit" + "google.golang.org/protobuf/types/known/durationpb" +) + +func TestConvertRateLimitConfig(t *testing.T) { + rules, err := convertRateLimitConfig(&rateLimitConfig{ + Operations: map[string]*rateLimitOperation{ + createOrderOperation: { + Rules: []*rateLimitRule{ + { + KeyParts: []string{"user_id", "client_ip"}, + Rate: 10, + Per: durationpb.New(time.Minute), + Burst: 20, + }, + }, + }, + }, + }) + if err != nil { + t.Fatalf("convertRateLimitConfig error = %v, want nil", err) + } + if got := len(rules[createOrderOperation]); got != 1 { + t.Fatalf("len(rules[%q]) = %d, want 1", createOrderOperation, got) + } + rule := rules[createOrderOperation][0] + if rule.Config.Rate != 10 || rule.Config.Per != time.Minute || rule.Config.Burst != 20 { + t.Fatalf("rule.Config = %+v, want rate 10 per 1m burst 20", rule.Config) + } + wantKeyParts := []ratelimit.KeyPart{ratelimit.KeyPartUserID, ratelimit.KeyPartClientIP} + for i, want := range wantKeyParts { + if rule.KeyParts[i] != want { + t.Fatalf("rule.KeyParts[%d] = %q, want %q", i, rule.KeyParts[i], want) + } + } +} + +func TestConvertRateLimitConfigRejectsUnknownKeyPart(t *testing.T) { + _, err := convertRateLimitConfig(&rateLimitConfig{ + Operations: map[string]*rateLimitOperation{ + createOrderOperation: { + Rules: []*rateLimitRule{ + { + KeyParts: []string{"device_id"}, + Rate: 10, + Per: durationpb.New(time.Minute), + Burst: 20, + }, + }, + }, + }, + }) + if err == nil { + t.Fatal("convertRateLimitConfig error = nil, want error") + } +} diff --git a/kratos/ratelimit/inmem_store_test.go b/kratos/ratelimit/inmem_store_test.go new file mode 100644 index 0000000..b822fa3 --- /dev/null +++ b/kratos/ratelimit/inmem_store_test.go @@ -0,0 +1,126 @@ +package ratelimit + +import ( + "context" + "fmt" + "math" + "sync" + "time" +) + +// inMemStore is the token-bucket reference implementation shared by tests. +// It is the only Store fixture in the package; Take and TakeMany honor the +// atomicity guarantees in the Store interface doc. +type inMemStore struct { + mu sync.Mutex + buckets map[string]*inMemBucket +} + +type inMemBucket struct { + tokens float64 + seen time.Time +} + +func newInMemStore() *inMemStore { + return &inMemStore{buckets: make(map[string]*inMemBucket)} +} + +func (s *inMemStore) Take(ctx context.Context, key string, now time.Time, limit Limit, n int) (Result, error) { + results, err := s.TakeMany(ctx, []string{key}, now, []Limit{limit}, n) + if err != nil { + return Result{}, err + } + return results[0], nil +} + +func (s *inMemStore) TakeMany(_ context.Context, keys []string, now time.Time, limits []Limit, n int) ([]Result, error) { + if len(keys) != len(limits) { + return nil, fmt.Errorf("keys/limits length mismatch") + } + if n <= 0 { + results := make([]Result, len(keys)) + for i := range results { + results[i] = Result{Allowed: true} + } + return results, nil + } + for i, key := range keys { + if key == "" { + return nil, ErrMissingKey + } + if err := limits[i].Validate(); err != nil { + return nil, err + } + } + s.mu.Lock() + defer s.mu.Unlock() + + tokens := make([]float64, len(keys)) + buckets := make([]*inMemBucket, len(keys)) + for i, key := range keys { + b := s.bucketFor(key, now, limits[i].Burst) + refillInMemBucket(b, now, limits[i]) + buckets[i] = b + tokens[i] = b.tokens + } + + need := float64(n) + allow := true + for i := range buckets { + if need > float64(limits[i].Burst) || tokens[i] < need { + allow = false + break + } + } + + results := make([]Result, len(keys)) + for i, b := range buckets { + if allow { + b.tokens -= need + results[i] = Result{Allowed: true, Remaining: int(math.Floor(b.tokens))} + continue + } + var retry time.Duration + if need > float64(limits[i].Burst) { + retry = retryAfter(limits[i], need-float64(limits[i].Burst)) + } else if tokens[i] < need { + retry = retryAfter(limits[i], need-tokens[i]) + } + results[i] = Result{ + Remaining: int(math.Floor(tokens[i])), + RetryAfter: retry, + } + } + return results, nil +} + +func (s *inMemStore) bucketFor(key string, now time.Time, burst int) *inMemBucket { + if b, ok := s.buckets[key]; ok { + return b + } + b := &inMemBucket{ + tokens: float64(burst), + seen: now, + } + s.buckets[key] = b + return b +} + +func refillInMemBucket(b *inMemBucket, now time.Time, limit Limit) { + elapsed := now.Sub(b.seen) + if elapsed <= 0 { + return + } + per := time.Duration(inMemDurationMillis(limit.Per)) * time.Millisecond + rate := float64(limit.Rate) / per.Seconds() + b.tokens = math.Min(float64(limit.Burst), b.tokens+elapsed.Seconds()*rate) + b.seen = now +} + +func inMemDurationMillis(d time.Duration) int64 { + ms := d.Milliseconds() + if ms <= 0 { + return 1 + } + return ms +} diff --git a/kratos/ratelimit/limiter.go b/kratos/ratelimit/limiter.go new file mode 100644 index 0000000..1d0da13 --- /dev/null +++ b/kratos/ratelimit/limiter.go @@ -0,0 +1,133 @@ +package ratelimit + +import ( + "context" + "errors" + "math" + "time" +) + +var ( + // ErrInvalidConfig reports a limiter configuration that cannot be applied. + ErrInvalidConfig = errors.New("invalid ratelimit config") + // ErrMissingKey reports a request without a rate-limit key. + ErrMissingKey = errors.New("missing ratelimit key") + // ErrMissingStore reports a limiter constructed without a storage backend. + ErrMissingStore = errors.New("missing ratelimit store") +) + +// Config controls token-bucket behavior. It is also the limit type Store +// implementations consume; the Limit alias below preserves the name used on +// the Store interface. +type Config struct { + // Rate is the number of tokens replenished every Per duration. + Rate int + // Per is the refill window for Rate tokens. + Per time.Duration + // Burst is the maximum number of tokens a key can accumulate. + Burst int +} + +// Validate reports whether c carries usable token-bucket parameters. +func (c Config) Validate() error { + if c.Rate <= 0 || c.Per < time.Millisecond || c.Burst <= 0 { + return ErrInvalidConfig + } + return nil +} + +// Limit is the configuration carried into the Store. It is an alias of Config +// to keep the Store signature readable without introducing a second type. +type Limit = Config + +// Result describes the outcome of an Allow call. +type Result struct { + Allowed bool + Remaining int + RetryAfter time.Duration +} + +// Store persists token-bucket state. +// +// Implementations must be safe for concurrent use. Take must apply atomically +// to one key. TakeMany must atomically check every key and either consume n +// tokens from all of them or from none of them, returning one Result per key +// in input order. +type Store interface { + Take(ctx context.Context, key string, now time.Time, limit Limit, n int) (Result, error) + TakeMany(ctx context.Context, keys []string, now time.Time, limits []Limit, n int) ([]Result, error) +} + +// LimiterOption configures a Limiter. +type LimiterOption func(*Limiter) + +// WithNow sets the clock used by the limiter. +func WithNow(now func() time.Time) LimiterOption { + return func(l *Limiter) { + if now != nil { + l.now = now + } + } +} + +// WithStore sets the storage backend used by the limiter. +func WithStore(store Store) LimiterOption { + return func(l *Limiter) { + if store != nil { + l.store = store + } + } +} + +// Limiter applies token-bucket rate limits independently per key. +type Limiter struct { + limit Limit + store Store + now func() time.Time +} + +// New constructs a Limiter. +func New(cfg Config, opts ...LimiterOption) (*Limiter, error) { + if err := cfg.Validate(); err != nil { + return nil, err + } + l := &Limiter{ + limit: cfg, + now: time.Now, + } + for _, opt := range opts { + opt(l) + } + if l.store == nil { + return nil, ErrMissingStore + } + return l, nil +} + +// AllowContext consumes one token for key if capacity is available. +func (l *Limiter) AllowContext(ctx context.Context, key string) (Result, error) { + return l.AllowNContext(ctx, key, 1) +} + +// AllowNContext consumes n tokens for key if capacity is available. +func (l *Limiter) AllowNContext(ctx context.Context, key string, n int) (Result, error) { + if n <= 0 { + return Result{Allowed: true}, nil + } + if key == "" { + return Result{}, ErrMissingKey + } + if n > l.limit.Burst { + return Result{RetryAfter: retryAfter(l.limit, float64(n))}, nil + } + return l.store.Take(ctx, key, l.now(), l.limit, n) +} + +func retryAfter(limit Limit, tokens float64) time.Duration { + perToken := time.Duration(float64(limit.Per) / float64(limit.Rate)) + d := time.Duration(math.Ceil(float64(perToken) * tokens)) + if d < 0 { + return 0 + } + return d +} diff --git a/kratos/ratelimit/limiter_test.go b/kratos/ratelimit/limiter_test.go new file mode 100644 index 0000000..608f068 --- /dev/null +++ b/kratos/ratelimit/limiter_test.go @@ -0,0 +1,153 @@ +package ratelimit + +import ( + "context" + "errors" + "testing" + "time" +) + +type recordingStore struct { + key string + keys []string + n int + limit Limit +} + +func (s *recordingStore) Take(_ context.Context, key string, _ time.Time, limit Limit, n int) (Result, error) { + s.key = key + s.n = n + s.limit = limit + return Result{Allowed: true, Remaining: limit.Burst - n}, nil +} + +func (s *recordingStore) TakeMany(_ context.Context, keys []string, _ time.Time, limits []Limit, n int) ([]Result, error) { + s.keys = append([]string(nil), keys...) + s.n = n + results := make([]Result, len(keys)) + for i, limit := range limits { + results[i] = Result{Allowed: true, Remaining: limit.Burst - n} + } + return results, nil +} + +func TestLimiterAllowsBurstThenRejects(t *testing.T) { + now := time.Unix(0, 0) + limiter, err := New( + Config{Rate: 2, Per: time.Second, Burst: 2}, + WithStore(newInMemStore()), + WithNow(func() time.Time { return now }), + ) + if err != nil { + t.Fatalf("New: %v", err) + } + + if res, err := limiter.AllowContext(context.Background(), "user-1"); err != nil || !res.Allowed || res.Remaining != 1 { + t.Fatalf("first request = %+v, want allowed with one remaining", res) + } + if res, err := limiter.AllowContext(context.Background(), "user-1"); err != nil || !res.Allowed || res.Remaining != 0 { + t.Fatalf("second request = %+v, want allowed with zero remaining", res) + } + if res, err := limiter.AllowContext(context.Background(), "user-1"); err != nil || res.Allowed || res.RetryAfter != 500*time.Millisecond { + t.Fatalf("third request = %+v, want rejected with 500ms retry", res) + } +} + +func TestLimiterRefillsByElapsedTime(t *testing.T) { + now := time.Unix(0, 0) + limiter, err := New( + Config{Rate: 2, Per: time.Second, Burst: 2}, + WithStore(newInMemStore()), + WithNow(func() time.Time { return now }), + ) + if err != nil { + t.Fatalf("New: %v", err) + } + + if _, err := limiter.AllowContext(context.Background(), "user-1"); err != nil { + t.Fatalf("first request: %v", err) + } + if _, err := limiter.AllowContext(context.Background(), "user-1"); err != nil { + t.Fatalf("second request: %v", err) + } + now = now.Add(500 * time.Millisecond) + + if res, err := limiter.AllowContext(context.Background(), "user-1"); err != nil || !res.Allowed || res.Remaining != 0 { + t.Fatalf("refilled request = %+v, want allowed with zero remaining", res) + } +} + +func TestLimiterSeparatesKeys(t *testing.T) { + limiter, err := New(Config{Rate: 1, Per: time.Second, Burst: 1}, WithStore(newInMemStore())) + if err != nil { + t.Fatalf("New: %v", err) + } + + if res, err := limiter.AllowContext(context.Background(), "user-1"); err != nil || !res.Allowed { + t.Fatalf("user-1 first request = %+v, want allowed", res) + } + if res, err := limiter.AllowContext(context.Background(), "user-1"); err != nil || res.Allowed { + t.Fatalf("user-1 second request = %+v, want rejected", res) + } + if res, err := limiter.AllowContext(context.Background(), "user-2"); err != nil || !res.Allowed { + t.Fatalf("user-2 first request = %+v, want allowed", res) + } +} + +func TestLimiterRejectsInvalidConfig(t *testing.T) { + _, err := New(Config{Rate: 0, Per: time.Second, Burst: 1}, WithStore(newInMemStore())) + if !errors.Is(err, ErrInvalidConfig) { + t.Fatalf("New error = %v, want ErrInvalidConfig", err) + } + + _, err = New(Config{Rate: 1, Per: time.Nanosecond, Burst: 1}, WithStore(newInMemStore())) + if !errors.Is(err, ErrInvalidConfig) { + t.Fatalf("New sub-ms config error = %v, want ErrInvalidConfig", err) + } +} + +func TestLimiterRejectsMissingStore(t *testing.T) { + _, err := New(Config{Rate: 1, Per: time.Second, Burst: 1}) + if !errors.Is(err, ErrMissingStore) { + t.Fatalf("New error = %v, want ErrMissingStore", err) + } +} + +func TestLimiterRejectsMissingKey(t *testing.T) { + limiter, err := New(Config{Rate: 1, Per: time.Second, Burst: 1}, WithStore(newInMemStore())) + if err != nil { + t.Fatalf("New: %v", err) + } + + _, err = limiter.AllowContext(context.Background(), "") + if !errors.Is(err, ErrMissingKey) { + t.Fatalf("AllowContext error = %v, want ErrMissingKey", err) + } + + _, err = limiter.AllowNContext(context.Background(), "", 2) + if !errors.Is(err, ErrMissingKey) { + t.Fatalf("AllowNContext over burst error = %v, want ErrMissingKey", err) + } +} + +func TestLimiterUsesStore(t *testing.T) { + store := &recordingStore{} + limiter, err := New(Config{Rate: 5, Per: time.Second, Burst: 10}, WithStore(store)) + if err != nil { + t.Fatalf("New: %v", err) + } + + res, err := limiter.AllowNContext(context.Background(), "tenant-1", 3) + if err != nil { + t.Fatalf("AllowNContext: %v", err) + } + if !res.Allowed || res.Remaining != 7 { + t.Fatalf("result = %+v, want allowed with 7 remaining", res) + } + if store.key != "tenant-1" || store.n != 3 { + t.Fatalf("store saw key=%q n=%d, want tenant-1 and 3", store.key, store.n) + } + if store.limit.Rate != 5 || store.limit.Per != time.Second || store.limit.Burst != 10 { + t.Fatalf("store limit = %+v, want configured limit", store.limit) + } +} diff --git a/kratos/ratelimit/policy.go b/kratos/ratelimit/policy.go index 5c4f30c..44146c8 100644 --- a/kratos/ratelimit/policy.go +++ b/kratos/ratelimit/policy.go @@ -1,147 +1,240 @@ package ratelimit import ( + "context" + "fmt" + "strings" "time" +) + +// KeyPart identifies one business dimension used to build a rate-limit key. +type KeyPart string - "github.com/crypto-zero/go-kit/kratos/internal/protoop" - ratelimitv1 "github.com/crypto-zero/go-kit/proto/kit/ratelimit/v1" - coreratelimit "github.com/crypto-zero/go-kit/ratelimit" - "google.golang.org/protobuf/reflect/protoreflect" +const ( + KeyPartClientIP KeyPart = "client_ip" + KeyPartUserID KeyPart = "user_id" ) +// ParseKeyPart converts a canonical key-part name into a KeyPart. Callers that +// load rules from external config can use it to avoid re-implementing the +// allow-list of supported dimensions. +func ParseKeyPart(s string) (KeyPart, error) { + switch KeyPart(s) { + case KeyPartClientIP, KeyPartUserID: + return KeyPart(s), nil + default: + return "", fmt.Errorf("unsupported key part %q", s) + } +} + +// RuleConfig provides the runtime limit for one operation rule. +type RuleConfig struct { + Config Config + KeyParts []KeyPart +} + +// OperationRules maps Kratos operation names to their runtime limit rules. +type OperationRules map[string][]RuleConfig + // OperationPolicy selects rate-limit behavior for Kratos operations. type OperationPolicy struct { - operations map[string]operationLimit - store coreratelimit.Store + operations map[string][]operationLimit + store Store + now func() time.Time + clientIPKeyFunc KeyFunc + userKeyFunc KeyFunc } type operationLimit struct { - limiter Limiter keyFunc KeyFunc - config coreratelimit.Config + limit Limit } // OperationPolicyOption configures an OperationPolicy. type OperationPolicyOption func(*OperationPolicy) -// NewOperationPolicy constructs a policy from proto descriptors and manual -// operation rules. -func NewOperationPolicy(opts ...OperationPolicyOption) *OperationPolicy { - p := &OperationPolicy{operations: make(map[string]operationLimit)} - for _, opt := range opts { - opt(p) +// WithPolicyUserKeyFunc sets how operation policies extract the business user id. +func WithPolicyUserKeyFunc(fn KeyFunc) OperationPolicyOption { + return func(p *OperationPolicy) { + p.userKeyFunc = fn } - return p } -// WithStore sets the storage backend used by operation-specific limiters. -func WithStore(store coreratelimit.Store) OperationPolicyOption { +// WithPolicyClientIPKeyFunc sets how operation policies extract the client IP. +func WithPolicyClientIPKeyFunc(fn KeyFunc) OperationPolicyOption { return func(p *OperationPolicy) { - if store != nil { - p.store = store - for operation, rule := range p.operations { - p.operations[operation] = p.build(rule.config, rule.keyFunc) - } - } + p.clientIPKeyFunc = fn } } -// WithOperation registers a rate limit for one Kratos operation. -func WithOperation(operation string, cfg coreratelimit.Config, keyFunc KeyFunc) OperationPolicyOption { +// WithPolicyNow overrides the clock used to stamp Store calls. Mostly useful in tests. +func WithPolicyNow(now func() time.Time) OperationPolicyOption { return func(p *OperationPolicy) { - p.register(operation, cfg, keyFunc) + if now != nil { + p.now = now + } } } -// WithRateLimitFromProtoFiles scans file descriptors for methods tagged with -// `(kit.ratelimit.v1.rate_limit)`. -func WithRateLimitFromProtoFiles(files ...protoreflect.FileDescriptor) OperationPolicyOption { - return func(p *OperationPolicy) { - for _, fd := range files { - registerRateLimitsFromFile(p, fd) +// NewOperationPolicy constructs a policy from operation rules. All construction +// errors are returned eagerly; the resulting policy never silently disables +// limiting at runtime. +func NewOperationPolicy( + store Store, + rules OperationRules, + opts ...OperationPolicyOption, +) (*OperationPolicy, error) { + if store == nil { + return nil, ErrMissingStore + } + if len(rules) == 0 { + return nil, ErrMissingRules + } + p := &OperationPolicy{ + operations: make(map[string][]operationLimit, len(rules)), + store: store, + now: time.Now, + } + for _, opt := range opts { + opt(p) + } + for operation, configs := range rules { + if operation == "" { + return nil, fmt.Errorf("ratelimit operation must not be empty") + } + if len(configs) == 0 { + return nil, fmt.Errorf("%s: ratelimit rules must not be empty", operation) + } + seen := make(map[string]struct{}, len(configs)) + for _, cfg := range configs { + sig := keyPartsSignature(cfg.KeyParts) + label := ruleLabel(operation, sig) + if err := validateKeyParts(cfg.KeyParts); err != nil { + return nil, fmt.Errorf("%s: %w", label, err) + } + if err := cfg.Config.Validate(); err != nil { + return nil, fmt.Errorf("%s: %w", label, err) + } + if _, dup := seen[sig]; dup { + return nil, fmt.Errorf("%s: duplicate ratelimit rule %s", operation, sig) + } + seen[sig] = struct{}{} + + keyFunc, err := p.keyFuncFromParts(operation, cfg.KeyParts) + if err != nil { + return nil, fmt.Errorf("%s: %w", operation, err) + } + p.operations[operation] = append(p.operations[operation], operationLimit{ + keyFunc: keyFunc, + limit: cfg.Config, + }) } } + return p, nil } -func (p *OperationPolicy) lookup(operation string) (Limiter, KeyFunc, bool) { +// allow runs every rule for operation in one atomic Store call and returns the +// per-rule Results. It returns (nil, nil) when no rule matches. +func (p *OperationPolicy) allow(ctx context.Context, operation string, req any) ([]Result, error) { if p == nil { - return nil, nil, false + return nil, nil } - rule, ok := p.operations[operation] - if !ok { - return nil, nil, false + rules := p.operations[operation] + if len(rules) == 0 { + return nil, nil } - return rule.limiter, rule.keyFunc, true + keys := make([]string, len(rules)) + limits := make([]Limit, len(rules)) + for i, rule := range rules { + key := rule.keyFunc(ctx, req) + if key == "" { + return nil, ErrMissingKey + } + keys[i] = key + limits[i] = rule.limit + } + return p.store.TakeMany(ctx, keys, p.now(), limits, 1) } -func (p *OperationPolicy) register(operation string, cfg coreratelimit.Config, keyFunc KeyFunc) { - if operation == "" || keyFunc == nil { - return - } - rule := p.build(cfg, keyFunc) - if rule.limiter == nil { - return +func (p *OperationPolicy) validate() error { + // NewOperationPolicy already performs full validation. This method catches + // zero-value policies passed through WithOperationPolicy. + if p == nil || p.store == nil || len(p.operations) == 0 { + return ErrMissingRules } - p.operations[operation] = rule + return nil } -func (p *OperationPolicy) build(cfg coreratelimit.Config, keyFunc KeyFunc) operationLimit { - opts := []coreratelimit.Option(nil) - if p.store != nil { - opts = append(opts, coreratelimit.WithStore(p.store)) - } - limiter, err := coreratelimit.New(cfg, opts...) - if err != nil { - return operationLimit{} - } - return operationLimit{ - limiter: limiter, - keyFunc: keyFunc, - config: cfg, +func (p *OperationPolicy) keyFuncFromParts(operation string, parts []KeyPart) (KeyFunc, error) { + fns := make([]KeyFunc, 0, len(parts)) + for _, part := range parts { + fn, err := p.keyFuncForPart(part) + if err != nil { + return nil, fmt.Errorf("%s: %w", keyPartsSignature(parts), err) + } + fns = append(fns, namedKeyPart(part, fn)) } + return operationScopedKey(operation, CompositeKey(fns...)), nil } -func registerRateLimitsFromFile(p *OperationPolicy, fd protoreflect.FileDescriptor) { - protoop.WalkMethods([]protoreflect.FileDescriptor{fd}, func(m protoreflect.MethodDescriptor) { - rule, ok := methodRateLimit(m) - if !ok { - return +func (p *OperationPolicy) keyFuncForPart(part KeyPart) (KeyFunc, error) { + switch part { + case KeyPartClientIP: + if p.clientIPKeyFunc == nil { + return nil, fmt.Errorf("client IP key function is required") + } + return p.clientIPKeyFunc, nil + case KeyPartUserID: + if p.userKeyFunc == nil { + return nil, fmt.Errorf("user key function is required") } - p.register(protoop.OperationName(m), configFromProto(rule), keyFuncFromProto(rule.GetKey())) - }) + return p.userKeyFunc, nil + default: + return nil, fmt.Errorf("unsupported key part %s", part) + } } -func methodRateLimit(m protoreflect.MethodDescriptor) (*ratelimitv1.RateLimit, bool) { - v, ok := protoop.Extension(m, ratelimitv1.E_RateLimit) - if !ok { - return nil, false +func namedKeyPart(part KeyPart, fn KeyFunc) KeyFunc { + return func(ctx context.Context, req any) string { + value := fn(ctx, req) + if value == "" { + return "" + } + return string(part) + ":" + escapeKeyPartValue(value) } - rule, ok := v.(*ratelimitv1.RateLimit) - return rule, ok && rule != nil } -func configFromProto(rule *ratelimitv1.RateLimit) coreratelimit.Config { - return coreratelimit.Config{ - Rate: int(rule.GetRate()), - Per: durationFromProto(rule), - Burst: int(rule.GetBurst()), +var keyPartEscaper = strings.NewReplacer("%", "%25", ":", "%3A") + +func escapeKeyPartValue(value string) string { + return keyPartEscaper.Replace(value) +} + +func validateKeyParts(parts []KeyPart) error { + if len(parts) == 0 { + return fmt.Errorf("key_parts must not be empty") } + for _, part := range parts { + switch part { + case KeyPartClientIP, KeyPartUserID: + default: + return fmt.Errorf("unsupported key part %s", part) + } + } + return nil } -func durationFromProto(rule *ratelimitv1.RateLimit) time.Duration { - if rule.GetPer() == nil { - return 0 +func keyPartsSignature(parts []KeyPart) string { + names := make([]string, 0, len(parts)) + for _, part := range parts { + names = append(names, string(part)) } - return rule.GetPer().AsDuration() + return strings.Join(names, "+") } -func keyFuncFromProto(key ratelimitv1.Key) KeyFunc { - switch key { - case ratelimitv1.Key_KEY_CLIENT_IP: - return ClientIPKey - case ratelimitv1.Key_KEY_OPERATION_CLIENT_IP: - return CompositeKey(OperationKey, ClientIPKey) - default: - return OperationKey +func ruleLabel(operation, sig string) string { + if sig == "" { + return operation } + return operation + " " + sig } diff --git a/kratos/ratelimit/ratelimit.go b/kratos/ratelimit/ratelimit.go index c8ec68d..d9b1849 100644 --- a/kratos/ratelimit/ratelimit.go +++ b/kratos/ratelimit/ratelimit.go @@ -3,52 +3,70 @@ package ratelimit import ( "context" + "errors" "strconv" "strings" "github.com/crypto-zero/go-kit/kratos/clientip" - "github.com/crypto-zero/go-kit/ratelimit" - "github.com/go-kratos/kratos/v2/errors" + kratoserrors "github.com/go-kratos/kratos/v2/errors" "github.com/go-kratos/kratos/v2/middleware" "github.com/go-kratos/kratos/v2/transport" ) +// Reason is the Kratos error reason emitted when a request is rejected. +const Reason = "RATELIMIT" + +// Metadata keys carried on the Kratos error returned by Server. const ( - defaultKey = "global" - reason = "RATELIMIT" + MetadataRemaining = "remaining" + MetadataRetryAfter = "retry_after" ) // ErrLimitExceed is returned when a request exceeds its rate limit. -var ErrLimitExceed = errors.New(429, reason, "service unavailable due to rate limit exceeded") +var ErrLimitExceed = kratoserrors.New(429, Reason, "service unavailable due to rate limit exceeded") -// Limiter is the behavior required by the middleware. -type Limiter interface { - AllowContext(context.Context, string) (ratelimit.Result, error) -} +// ErrStoreUnavailable is returned when the backing store cannot evaluate a +// limit decision. +var ErrStoreUnavailable = kratoserrors.New(503, "RATELIMIT_UNAVAILABLE", "service unavailable due to rate limit store unavailable") + +// ErrPolicyConflict reports that incompatible options were combined — most +// commonly passing both WithOperationPolicy and one of WithOperationRules, +// WithRuleStore, WithUserKeyFunc, or WithClientIPKeyFunc. +var ErrPolicyConflict = errors.New("ratelimit: WithOperationPolicy conflicts with rule-building options") + +// ErrMissingRules reports rule-building options that do not include any +// operation rules. +var ErrMissingRules = errors.New("ratelimit: missing operation rules") // KeyFunc derives a rate-limit key from a request. type KeyFunc func(context.Context, any) string -// Option configures server middleware. +// Option configures Server. type Option func(*options) type options struct { - keyFunc KeyFunc - err *errors.Error - policy *OperationPolicy + err *kratoserrors.Error + policy *OperationPolicy + store Store + rules OperationRules + clientIPKeyFunc KeyFunc + userKeyFunc KeyFunc + storeSet bool + rulesSet bool + clientIPSet bool + userSet bool } -// WithKeyFunc sets how requests are grouped into buckets. -func WithKeyFunc(fn KeyFunc) Option { +// WithRuleStore sets the storage backend used to build operation rules. +func WithRuleStore(store Store) Option { return func(o *options) { - if fn != nil { - o.keyFunc = fn - } + o.store = store + o.storeSet = true } } // WithError sets the error returned when a request is rejected. -func WithError(err *errors.Error) Option { +func WithError(err *kratoserrors.Error) Option { return func(o *options) { if err != nil { o.err = err @@ -56,63 +74,119 @@ func WithError(err *errors.Error) Option { } } -// WithOperationPolicy sets per-operation rate-limit rules. +// WithOperationPolicy installs a pre-built policy. Mutually exclusive with +// WithOperationRules / WithRuleStore / WithUserKeyFunc / WithClientIPKeyFunc. func WithOperationPolicy(policy *OperationPolicy) Option { + return func(o *options) { o.policy = policy } +} + +// WithOperationRules sets per-operation rate-limit rules from external config. +func WithOperationRules(rules OperationRules) Option { return func(o *options) { - o.policy = policy + o.rules = rules + o.rulesSet = true } } -// Server returns a Kratos server middleware using limiter. -func Server(limiter Limiter, opts ...Option) middleware.Middleware { - if limiter == nil { - limiter = ratelimit.NewDefault() +// WithUserKeyFunc sets how user_id key parts are extracted from requests. +func WithUserKeyFunc(fn KeyFunc) Option { + return func(o *options) { + o.userKeyFunc = fn + o.userSet = true } - o := &options{ - keyFunc: OperationKey, - err: ErrLimitExceed, +} + +// WithClientIPKeyFunc sets how client_ip key parts are extracted from requests. +func WithClientIPKeyFunc(fn KeyFunc) Option { + return func(o *options) { + o.clientIPKeyFunc = fn + o.clientIPSet = true } +} + +// Server returns a Kratos server middleware that enforces rate limits. +// +// Construction errors are returned eagerly so callers fail-fast at startup +// instead of crashing on the first request. HTTP handlers must set a Kratos +// operation with http.SetOperation; requests without an operation are treated +// as unconfigured and are not limited. +func Server(opts ...Option) (middleware.Middleware, error) { + o := &options{err: ErrLimitExceed} for _, opt := range opts { opt(o) } + if o.policy != nil { + if o.rulesSet || o.storeSet || o.userSet || o.clientIPSet { + return nil, ErrPolicyConflict + } + if err := o.policy.validate(); err != nil { + return nil, err + } + } else if o.rulesSet { + if len(o.rules) == 0 { + return nil, ErrMissingRules + } + policy, err := NewOperationPolicy(o.store, o.rules, + WithPolicyClientIPKeyFunc(o.clientIPKeyFunc), + WithPolicyUserKeyFunc(o.userKeyFunc), + ) + if err != nil { + return nil, err + } + o.policy = policy + } else { + return nil, ErrMissingRules + } + policy := o.policy + errResp := o.err return func(handler middleware.Handler) middleware.Handler { return func(ctx context.Context, req any) (any, error) { - activeLimiter := limiter - keyFunc := o.keyFunc - if opLimiter, opKeyFunc, ok := o.policy.lookup(OperationKey(ctx, req)); ok { - activeLimiter = opLimiter - keyFunc = opKeyFunc - } - key := keyFunc(ctx, req) - res, err := activeLimiter.AllowContext(ctx, key) + results, err := policy.allow(ctx, OperationKey(ctx, req), req) if err != nil { - return nil, err + if errors.Is(err, ErrMissingKey) { + return nil, errResp.WithMetadata(map[string]string{ + MetadataRemaining: "0", + }).WithCause(err) + } + return nil, ErrStoreUnavailable.WithCause(err) } - if !res.Allowed { - return nil, o.err.WithMetadata(retryMetadata(res)) + if rejected, ok := rejectedResult(results); ok { + return nil, errResp.WithMetadata(retryMetadata(rejected)) } return handler(ctx, req) } + }, nil +} + +func rejectedResult(results []Result) (Result, bool) { + var rejected Result + var ok bool + for _, res := range results { + if !res.Allowed && (!ok || res.RetryAfter > rejected.RetryAfter) { + rejected = res + ok = true + } } + return rejected, ok } -// OperationKey groups requests by Kratos operation. +// OperationKey returns the Kratos operation from the server context. func OperationKey(ctx context.Context, _ any) string { - if tr, ok := transport.FromServerContext(ctx); ok && tr.Operation() != "" { + if tr, ok := transport.FromServerContext(ctx); ok { return tr.Operation() } - return defaultKey + return "" } -// ClientIPKey groups requests by client IP address. +// ClientIPKey returns the client IP from the server context. func ClientIPKey(ctx context.Context, _ any) string { - if ip := clientip.FromContext(ctx); ip != "" { - return ip - } - return defaultKey + return clientip.FromContext(ctx) } -// CompositeKey joins multiple key functions into one key. +// CompositeKey joins multiple key functions into one. If any underlying +// function returns the empty string, the composite returns the empty string — +// callers can distinguish "key fully derived" from "at least one dimension +// missing" without silently collapsing into a narrower bucket. func CompositeKey(fns ...KeyFunc) KeyFunc { return func(ctx context.Context, req any) string { parts := make([]string, 0, len(fns)) @@ -120,23 +194,35 @@ func CompositeKey(fns ...KeyFunc) KeyFunc { if fn == nil { continue } - if part := fn(ctx, req); part != "" { - parts = append(parts, part) + part := fn(ctx, req) + if part == "" { + return "" } - } - if len(parts) == 0 { - return defaultKey + parts = append(parts, part) } return strings.Join(parts, ":") } } -func retryMetadata(res ratelimit.Result) map[string]string { +func operationScopedKey(operation string, fn KeyFunc) KeyFunc { + return func(ctx context.Context, req any) string { + if fn == nil { + return operation + } + key := fn(ctx, req) + if key == "" { + return "" + } + return operation + ":" + key + } +} + +func retryMetadata(res Result) map[string]string { md := map[string]string{ - "remaining": strconv.Itoa(res.Remaining), + MetadataRemaining: strconv.Itoa(res.Remaining), } if res.RetryAfter > 0 { - md["retry_after"] = strconv.FormatFloat(res.RetryAfter.Seconds(), 'f', 3, 64) + md[MetadataRetryAfter] = strconv.FormatFloat(res.RetryAfter.Seconds(), 'f', 3, 64) } return md } diff --git a/kratos/ratelimit/ratelimit_test.go b/kratos/ratelimit/ratelimit_test.go index 9b7a2c0..8838366 100644 --- a/kratos/ratelimit/ratelimit_test.go +++ b/kratos/ratelimit/ratelimit_test.go @@ -3,33 +3,24 @@ package ratelimit import ( "context" "errors" + "fmt" "testing" "time" - ratelimitv1 "github.com/crypto-zero/go-kit/proto/kit/ratelimit/v1" - "github.com/crypto-zero/go-kit/ratelimit" kratoserrors "github.com/go-kratos/kratos/v2/errors" "github.com/go-kratos/kratos/v2/transport" - "google.golang.org/protobuf/proto" - "google.golang.org/protobuf/reflect/protodesc" - "google.golang.org/protobuf/reflect/protoreflect" - "google.golang.org/protobuf/types/descriptorpb" - "google.golang.org/protobuf/types/known/durationpb" + "google.golang.org/grpc/metadata" ) type mockTransport struct { operation string } -func (m *mockTransport) Kind() transport.Kind { return transport.KindHTTP } -func (m *mockTransport) Endpoint() string { return "localhost:8000" } -func (m *mockTransport) Operation() string { return m.operation } -func (m *mockTransport) RequestHeader() transport.Header { - return &mockHeader{} -} -func (m *mockTransport) ReplyHeader() transport.Header { - return &mockHeader{} -} +func (m *mockTransport) Kind() transport.Kind { return transport.KindHTTP } +func (m *mockTransport) Endpoint() string { return "localhost:8000" } +func (m *mockTransport) Operation() string { return m.operation } +func (m *mockTransport) RequestHeader() transport.Header { return &mockHeader{} } +func (m *mockTransport) ReplyHeader() transport.Header { return &mockHeader{} } type mockHeader struct{} @@ -39,67 +30,148 @@ func (m *mockHeader) Add(string, string) {} func (m *mockHeader) Keys() []string { return nil } func (m *mockHeader) Values(string) []string { return nil } -func TestServerRejectsWhenLimitExceeded(t *testing.T) { - limiter, err := ratelimit.New(ratelimit.Config{Rate: 1, Per: time.Second, Burst: 1}) +type errorStore struct { + err error +} + +func (s errorStore) Take(context.Context, string, time.Time, Limit, int) (Result, error) { + return Result{}, s.err +} + +func (s errorStore) TakeMany(context.Context, []string, time.Time, []Limit, int) ([]Result, error) { + return nil, s.err +} + +func mustServer(t *testing.T, opts ...Option) func(context.Context, any) (any, error) { + t.Helper() + mw, err := Server(opts...) if err != nil { - t.Fatalf("New: %v", err) + t.Fatalf("Server: %v", err) } - wrapped := Server(limiter)(func(context.Context, any) (any, error) { - return "ok", nil - }) - ctx := transport.NewServerContext(context.Background(), &mockTransport{operation: "/svc/Test"}) + return mw(func(context.Context, any) (any, error) { return "ok", nil }) +} + +func TestServerRejectsWhenLimitExceeded(t *testing.T) { + store := newInMemStore() + wrapped := mustServer(t, + WithRuleStore(store), + WithOperationRules(OperationRules{ + "/svc/Test": {{ + Config: Config{Rate: 1, Per: time.Second, Burst: 1}, + KeyParts: []KeyPart{KeyPartClientIP}, + }}, + }), + WithClientIPKeyFunc(ClientIPKey), + ) + ctx := clientIPContext("/svc/Test", "192.168.1.10") if _, err := wrapped(ctx, nil); err != nil { t.Fatalf("first request error = %v, want nil", err) } - _, err = wrapped(ctx, nil) - if err == nil { - t.Fatal("second request error = nil, want rate-limit error") - } + _, err := wrapped(ctx, nil) if !errors.Is(err, ErrLimitExceed) { t.Fatalf("second request error = %v, want ErrLimitExceed", err) } se := kratoserrors.FromError(err) - if se.Code != 429 || se.Reason != reason || se.Metadata["retry_after"] == "" { + if se.Code != 429 || se.Reason != Reason || se.Metadata[MetadataRetryAfter] == "" { t.Fatalf("kratos error = %+v, want 429 RATELIMIT with retry_after", se) } } -func TestServerUsesOperationKeyByDefault(t *testing.T) { - limiter, err := ratelimit.New(ratelimit.Config{Rate: 1, Per: time.Second, Burst: 1}) - if err != nil { - t.Fatalf("New: %v", err) +func TestServerRejectsMissingRules(t *testing.T) { + if _, err := Server(); !errors.Is(err, ErrMissingRules) { + t.Fatalf("Server error = %v, want ErrMissingRules", err) } - wrapped := Server(limiter)(func(context.Context, any) (any, error) { - return "ok", nil - }) +} - ctxA := transport.NewServerContext(context.Background(), &mockTransport{operation: "/svc/A"}) - ctxB := transport.NewServerContext(context.Background(), &mockTransport{operation: "/svc/B"}) - if _, err := wrapped(ctxA, nil); err != nil { - t.Fatalf("operation A error = %v, want nil", err) +func TestServerRejectsMissingKey(t *testing.T) { + wrapped := mustServer(t, + WithRuleStore(newInMemStore()), + WithOperationRules(OperationRules{ + "/svc/Test": {{ + Config: Config{Rate: 1, Per: time.Second, Burst: 1}, + KeyParts: []KeyPart{KeyPartClientIP}, + }}, + }), + WithClientIPKeyFunc(ClientIPKey), + ) + ctx := transport.NewServerContext(context.Background(), &mockTransport{operation: "/svc/Test"}) + + _, err := wrapped(ctx, nil) + if !errors.Is(err, ErrMissingKey) || !errors.Is(err, ErrLimitExceed) { + t.Fatalf("request error = %v, want ErrMissingKey cause on ErrLimitExceed", err) } - if _, err := wrapped(ctxB, nil); err != nil { - t.Fatalf("operation B error = %v, want nil because it has a separate bucket", err) + se := kratoserrors.FromError(err) + if se.Code != 429 || se.Reason != Reason { + t.Fatalf("kratos error = %+v, want 429 RATELIMIT", se) } } -func TestServerUsesCustomKeyFunc(t *testing.T) { - limiter, err := ratelimit.New(ratelimit.Config{Rate: 1, Per: time.Second, Burst: 1}) - if err != nil { - t.Fatalf("New: %v", err) +func TestServerMapsStoreError(t *testing.T) { + storeErr := errors.New("redis unavailable") + wrapped := mustServer(t, + WithRuleStore(errorStore{err: storeErr}), + WithOperationRules(OperationRules{ + "/svc/Test": {{ + Config: Config{Rate: 1, Per: time.Second, Burst: 1}, + KeyParts: []KeyPart{KeyPartClientIP}, + }}, + }), + WithClientIPKeyFunc(ClientIPKey), + ) + ctx := clientIPContext("/svc/Test", "192.168.1.10") + + _, err := wrapped(ctx, nil) + if !errors.Is(err, ErrStoreUnavailable) || !errors.Is(err, storeErr) { + t.Fatalf("request error = %v, want ErrStoreUnavailable with store cause", err) } - wrapped := Server(limiter, WithKeyFunc(func(context.Context, any) string { - return "tenant-1" - }))(func(context.Context, any) (any, error) { - return "ok", nil - }) + se := kratoserrors.FromError(err) + if se.Code != 503 || se.Reason != "RATELIMIT_UNAVAILABLE" { + t.Fatalf("kratos error = %+v, want 503 RATELIMIT_UNAVAILABLE", se) + } +} - if _, err := wrapped(context.Background(), nil); err != nil { - t.Fatalf("first request error = %v, want nil", err) +func TestServerMultiPartKeyMissingDimensionRejects(t *testing.T) { + // With KeyParts=[user_id, client_ip], a missing user_id must not silently + // degrade into a client_ip-only bucket — the composite key must be empty + // and the request must be rejected with ErrMissingKey. + wrapped := mustServer(t, + WithRuleStore(newInMemStore()), + WithOperationRules(OperationRules{ + "/svc/Test": {{ + Config: Config{Rate: 1, Per: time.Second, Burst: 1}, + KeyParts: []KeyPart{KeyPartUserID, KeyPartClientIP}, + }}, + }), + WithClientIPKeyFunc(ClientIPKey), + WithUserKeyFunc(func(context.Context, any) string { return "" }), + ) + ctx := clientIPContext("/svc/Test", "192.168.1.10") + + if _, err := wrapped(ctx, nil); !errors.Is(err, ErrMissingKey) { + t.Fatalf("request error = %v, want ErrMissingKey", err) } - if _, err := wrapped(context.Background(), nil); !errors.Is(err, ErrLimitExceed) { - t.Fatalf("second request error = %v, want ErrLimitExceed", err) +} + +func TestOperationPolicyEscapesKeyPartValues(t *testing.T) { + store := newInMemStore() + wrapped := mustServer(t, + WithRuleStore(store), + WithOperationRules(OperationRules{ + "/svc/Test": {{ + Config: Config{Rate: 1, Per: time.Second, Burst: 1}, + KeyParts: []KeyPart{KeyPartUserID}, + }}, + }), + WithUserKeyFunc(func(context.Context, any) string { return "u%1:admin" }), + ) + ctx := transport.NewServerContext(context.Background(), &mockTransport{operation: "/svc/Test"}) + + if _, err := wrapped(ctx, nil); err != nil { + t.Fatalf("request error = %v, want nil", err) + } + if _, ok := store.buckets["/svc/Test:user_id:u%251%3Aadmin"]; !ok { + t.Fatalf("store buckets = %#v, want escaped key part value", store.buckets) } } @@ -114,23 +186,41 @@ func TestCompositeKey(t *testing.T) { } } -func TestServerUsesProtoOperationPolicy(t *testing.T) { - defaultLimiter, err := ratelimit.New(ratelimit.Config{Rate: 100, Per: time.Second, Burst: 100}) +func TestCompositeKeyReturnsEmptyWhenAnyPartIsEmpty(t *testing.T) { + key := CompositeKey( + func(context.Context, any) string { return "/svc/A" }, + func(context.Context, any) string { return "" }, + )(context.Background(), nil) + + if key != "" { + t.Fatalf("CompositeKey = %q, want empty when any part is empty", key) + } +} + +func TestServerUsesOperationRules(t *testing.T) { + store := newInMemStore() + policy, err := NewOperationPolicy( + store, + OperationRules{ + "/test.limit.v1.LimitService/Fast": {{ + Config: Config{Rate: 1, Per: time.Second, Burst: 1}, + KeyParts: []KeyPart{KeyPartClientIP}, + }}, + }, + WithPolicyClientIPKeyFunc(ClientIPKey), + ) if err != nil { - t.Fatalf("New default limiter: %v", err) + t.Fatalf("NewOperationPolicy: %v", err) } - policy := NewOperationPolicy(WithRateLimitFromProtoFiles(rateLimitFile(t))) - wrapped := Server(defaultLimiter, WithOperationPolicy(policy))(func(context.Context, any) (any, error) { - return "ok", nil - }) - fastCtx := transport.NewServerContext(context.Background(), &mockTransport{operation: "/test.limit.v1.LimitService/Fast"}) + wrapped := mustServer(t, WithOperationPolicy(policy)) + fastCtx := clientIPContext("/test.limit.v1.LimitService/Fast", "192.168.1.10") slowCtx := transport.NewServerContext(context.Background(), &mockTransport{operation: "/test.limit.v1.LimitService/Slow"}) if _, err := wrapped(fastCtx, nil); err != nil { t.Fatalf("fast first request error = %v, want nil", err) } if _, err := wrapped(fastCtx, nil); !errors.Is(err, ErrLimitExceed) { - t.Fatalf("fast second request error = %v, want ErrLimitExceed from proto policy", err) + t.Fatalf("fast second request error = %v, want ErrLimitExceed from operation policy", err) } for i := 0; i < 3; i++ { if _, err := wrapped(slowCtx, nil); err != nil { @@ -139,45 +229,238 @@ func TestServerUsesProtoOperationPolicy(t *testing.T) { } } -func rateLimitFile(t *testing.T) protoreflect.FileDescriptor { - t.Helper() +func TestOperationPolicyRejectsMissingKeyParts(t *testing.T) { + _, err := NewOperationPolicy( + newInMemStore(), + OperationRules{ + "/test.limit.v1.LimitService/Fast": {{ + Config: Config{Rate: 1, Per: time.Second, Burst: 1}, + KeyParts: nil, + }}, + }, + ) + if err == nil { + t.Fatal("NewOperationPolicy error = nil, want error") + } +} - rateLimitOpts := &descriptorpb.MethodOptions{} - proto.SetExtension(rateLimitOpts, ratelimitv1.E_RateLimit, &ratelimitv1.RateLimit{ - Rate: 1, - Per: durationpb.New(time.Second), - Burst: 1, - Key: ratelimitv1.Key_KEY_OPERATION, - }) - fd, err := protodesc.NewFile(&descriptorpb.FileDescriptorProto{ - Syntax: proto.String("proto3"), - Name: proto.String("test/limit/v1/service.proto"), - Package: proto.String("test.limit.v1"), - Service: []*descriptorpb.ServiceDescriptorProto{{ - Name: proto.String("LimitService"), - Method: []*descriptorpb.MethodDescriptorProto{ +func TestServerUsesOperationRulesOption(t *testing.T) { + store := newInMemStore() + wrapped := mustServer(t, + WithRuleStore(store), + WithOperationRules(OperationRules{ + "/test.limit.v1.LimitService/Fast": {{ + Config: Config{Rate: 1, Per: time.Second, Burst: 1}, + KeyParts: []KeyPart{KeyPartClientIP}, + }}, + }), + WithClientIPKeyFunc(ClientIPKey), + ) + ctx := clientIPContext("/test.limit.v1.LimitService/Fast", "192.168.1.10") + + if _, err := wrapped(ctx, nil); err != nil { + t.Fatalf("first request error = %v, want nil", err) + } + if _, err := wrapped(ctx, nil); !errors.Is(err, ErrLimitExceed) { + t.Fatalf("second request error = %v, want ErrLimitExceed from operation rule", err) + } +} + +func TestServerAppliesMultipleOperationRules(t *testing.T) { + store := newInMemStore() + wrapped := mustServer(t, + WithRuleStore(store), + WithOperationRules(OperationRules{ + "/test.limit.v1.LimitService/Fast": { { - Name: proto.String("Fast"), - InputType: proto.String(".test.limit.v1.FastRequest"), - OutputType: proto.String(".test.limit.v1.FastResponse"), - Options: rateLimitOpts, + Config: Config{Rate: 1, Per: time.Second, Burst: 1}, + KeyParts: []KeyPart{KeyPartUserID}, }, { - Name: proto.String("Slow"), - InputType: proto.String(".test.limit.v1.SlowRequest"), - OutputType: proto.String(".test.limit.v1.SlowResponse"), + Config: Config{Rate: 100, Per: time.Second, Burst: 100}, + KeyParts: []KeyPart{KeyPartClientIP}, }, }, - }}, - MessageType: []*descriptorpb.DescriptorProto{ - {Name: proto.String("FastRequest")}, - {Name: proto.String("FastResponse")}, - {Name: proto.String("SlowRequest")}, - {Name: proto.String("SlowResponse")}, + }), + WithUserKeyFunc(func(context.Context, any) string { return "user-1" }), + WithClientIPKeyFunc(ClientIPKey), + ) + ctx := clientIPContext("/test.limit.v1.LimitService/Fast", "192.168.1.10") + + if _, err := wrapped(ctx, nil); err != nil { + t.Fatalf("first request error = %v, want nil", err) + } + if _, err := wrapped(ctx, nil); !errors.Is(err, ErrLimitExceed) { + t.Fatalf("second request error = %v, want ErrLimitExceed from user rule", err) + } +} + +// Multi-rule consumption must be atomic. Rule[0] is a tight per-user limit; +// rule[1] is a generous per-IP limit. Once the user rule starts rejecting, +// the IP rule must NOT have been silently decremented by the rejected +// attempts — otherwise a different user from the same IP would be wrongly +// throttled. +func TestServerMultiRuleConsumptionIsAtomic(t *testing.T) { + store := newInMemStore() + buildServer := func(userID string) func(context.Context, any) (any, error) { + return mustServer(t, + WithRuleStore(store), + WithOperationRules(OperationRules{ + "/svc/Multi": { + { + Config: Config{Rate: 1, Per: time.Second, Burst: 1}, + KeyParts: []KeyPart{KeyPartUserID}, + }, + { + Config: Config{Rate: 100, Per: time.Second, Burst: 100}, + KeyParts: []KeyPart{KeyPartClientIP}, + }, + }, + }), + WithUserKeyFunc(func(context.Context, any) string { return userID }), + WithClientIPKeyFunc(ClientIPKey), + ) + } + + user1 := buildServer("user-1") + ctx := clientIPContext("/svc/Multi", "192.168.1.10") + + if _, err := user1(ctx, nil); err != nil { + t.Fatalf("first request error = %v, want nil", err) + } + for i := 0; i < 200; i++ { + if _, err := user1(ctx, nil); !errors.Is(err, ErrLimitExceed) { + t.Fatalf("rejected request %d error = %v, want ErrLimitExceed", i+2, err) + } + } + + for i := 0; i < 99; i++ { + user := buildServer(fmt.Sprintf("user-%d", i+2)) + if _, err := user(ctx, nil); err != nil { + t.Fatalf("user-2 request %d error = %v, want IP bucket still has 99 tokens", i+1, err) + } + } +} + +func TestOperationPolicyRequiresUserKeyFunc(t *testing.T) { + _, err := NewOperationPolicy( + newInMemStore(), + OperationRules{ + "/test.limit.v1.LimitService/Fast": {{ + Config: Config{Rate: 1, Per: time.Second, Burst: 1}, + KeyParts: []KeyPart{KeyPartUserID}, + }}, + }, + ) + if err == nil { + t.Fatal("NewOperationPolicy error = nil, want missing user key func error") + } +} + +func TestOperationPolicyRequiresClientIPKeyFunc(t *testing.T) { + _, err := NewOperationPolicy( + newInMemStore(), + OperationRules{ + "/test.limit.v1.LimitService/Fast": {{ + Config: Config{Rate: 1, Per: time.Second, Burst: 1}, + KeyParts: []KeyPart{KeyPartClientIP}, + }}, }, - }, nil) + ) + if err == nil { + t.Fatal("NewOperationPolicy error = nil, want missing client IP key func error") + } +} + +func TestOperationPolicyRequiresStore(t *testing.T) { + _, err := NewOperationPolicy(nil, OperationRules{ + "/svc/A": {{Config: Config{Rate: 1, Per: time.Second, Burst: 1}, KeyParts: []KeyPart{KeyPartClientIP}}}, + }) + if !errors.Is(err, ErrMissingStore) { + t.Fatalf("NewOperationPolicy error = %v, want ErrMissingStore", err) + } +} + +func TestOperationPolicyRejectsMissingRules(t *testing.T) { + _, err := NewOperationPolicy(newInMemStore(), nil) + if !errors.Is(err, ErrMissingRules) { + t.Fatalf("NewOperationPolicy error = %v, want ErrMissingRules", err) + } + + _, err = NewOperationPolicy(newInMemStore(), OperationRules{}) + if !errors.Is(err, ErrMissingRules) { + t.Fatalf("NewOperationPolicy error = %v, want ErrMissingRules for empty rules", err) + } +} + +func TestServerRejectsConflictingOptions(t *testing.T) { + policy, err := NewOperationPolicy(newInMemStore(), OperationRules{ + "/svc/A": {{Config: Config{Rate: 1, Per: time.Second, Burst: 1}, KeyParts: []KeyPart{KeyPartClientIP}}}, + }, WithPolicyClientIPKeyFunc(ClientIPKey)) + if err != nil { + t.Fatalf("NewOperationPolicy: %v", err) + } + + if _, err := Server(WithOperationPolicy(policy), WithClientIPKeyFunc(ClientIPKey)); !errors.Is(err, ErrPolicyConflict) { + t.Fatalf("Server error = %v, want ErrPolicyConflict when both policy and key func given", err) + } +} + +func TestServerRejectsInvalidOperationPolicy(t *testing.T) { + if _, err := Server(WithOperationPolicy(&OperationPolicy{})); !errors.Is(err, ErrMissingRules) { + t.Fatalf("Server error = %v, want ErrMissingRules for zero-value policy", err) + } +} + +func TestServerRejectsRuleBuildingOptionsWithoutRules(t *testing.T) { + _, err := Server(WithRuleStore(newInMemStore())) + if !errors.Is(err, ErrMissingRules) { + t.Fatalf("Server error = %v, want ErrMissingRules", err) + } + + _, err = Server(WithRuleStore(newInMemStore()), WithOperationRules(nil)) + if !errors.Is(err, ErrMissingRules) { + t.Fatalf("Server error = %v, want ErrMissingRules for nil rules", err) + } + + _, err = Server(WithRuleStore(newInMemStore()), WithOperationRules(OperationRules{})) + if !errors.Is(err, ErrMissingRules) { + t.Fatalf("Server error = %v, want ErrMissingRules for empty rules", err) + } +} + +func TestOperationPolicyScopesClientIPByOperation(t *testing.T) { + store := newInMemStore() + rules := OperationRules{ + "/svc/A": {{ + Config: Config{Rate: 1, Per: time.Second, Burst: 1}, + KeyParts: []KeyPart{KeyPartClientIP}, + }}, + "/svc/B": {{ + Config: Config{Rate: 1, Per: time.Second, Burst: 1}, + KeyParts: []KeyPart{KeyPartClientIP}, + }}, + } + policy, err := NewOperationPolicy(store, rules, WithPolicyClientIPKeyFunc(ClientIPKey)) if err != nil { - t.Fatalf("NewFile: %v", err) + t.Fatalf("NewOperationPolicy: %v", err) } - return fd + wrapped := mustServer(t, WithOperationPolicy(policy)) + ctxA := clientIPContext("/svc/A", "192.168.1.10") + ctxB := clientIPContext("/svc/B", "192.168.1.10") + + if _, err := wrapped(ctxA, nil); err != nil { + t.Fatalf("operation A first request error = %v, want nil", err) + } + if _, err := wrapped(ctxA, nil); !errors.Is(err, ErrLimitExceed) { + t.Fatalf("operation A second request error = %v, want ErrLimitExceed", err) + } + if _, err := wrapped(ctxB, nil); err != nil { + t.Fatalf("operation B first request error = %v, want nil with operation-scoped client IP key", err) + } +} + +func clientIPContext(operation, ip string) context.Context { + ctx := transport.NewServerContext(context.Background(), &mockTransport{operation: operation}) + return metadata.NewIncomingContext(ctx, metadata.Pairs("x-real-ip", ip)) } diff --git a/kratos/ratelimit/redis/store.go b/kratos/ratelimit/redis/store.go new file mode 100644 index 0000000..5e08ac6 --- /dev/null +++ b/kratos/ratelimit/redis/store.go @@ -0,0 +1,339 @@ +// Package redis provides a Redis-backed ratelimit store. +package redis + +import ( + "context" + "errors" + "fmt" + "math" + "strconv" + "strings" + "time" + + "github.com/crypto-zero/go-kit/kratos/ratelimit" + goredis "github.com/redis/go-redis/v9" +) + +const ( + minBucketTTLMillis = 1000 + scriptArgsPerKey = 4 + scriptResultWidth = 3 +) + +var ( + // ErrMissingClient reports a store constructed without a Redis client. + ErrMissingClient = errors.New("redis ratelimit: missing client") + // ErrMissingPrefix reports a store constructed without an explicit Redis key prefix. + ErrMissingPrefix = errors.New("redis ratelimit: missing prefix") + // ErrInvalidScriptResult reports an unexpected Lua script return value. + ErrInvalidScriptResult = errors.New("redis ratelimit: invalid script result") + // ErrKeyLimitMismatch reports a TakeMany call with mismatched keys/limits. + ErrKeyLimitMismatch = errors.New("redis ratelimit: keys and limits length mismatch") + // ErrKeyGroupMismatch reports keys that cannot be evaluated in one Redis slot. + ErrKeyGroupMismatch = errors.New("redis ratelimit: keys must share a group") +) + +// Store persists rate-limit buckets in Redis. +// +// The script uses Redis-side TIME, so it tolerates client clock skew across +// multiple Kratos instances. It performs a two-phase check-and-commit across +// every key so multi-rule operations consume tokens atomically or not at all. +type Store struct { + client goredis.Scripter + prefix string + script *goredis.Script +} + +// NewStore constructs a Redis-backed store. +func NewStore(client goredis.UniversalClient, prefix string) (*Store, error) { + return newScriptStore(client, prefix) +} + +func newScriptStore(client goredis.Scripter, prefix string) (*Store, error) { + if client == nil { + return nil, ErrMissingClient + } + prefix = strings.TrimSpace(prefix) + if prefix == "" { + return nil, ErrMissingPrefix + } + return &Store{ + client: client, + prefix: prefix, + script: goredis.NewScript(takeScript), + }, nil +} + +// Take consumes n tokens from key if capacity is available. +func (s *Store) Take(ctx context.Context, key string, now time.Time, limit ratelimit.Limit, n int) (ratelimit.Result, error) { + results, err := s.TakeMany(ctx, []string{key}, now, []ratelimit.Limit{limit}, n) + if err != nil { + return ratelimit.Result{}, err + } + return results[0], nil +} + +// TakeMany consumes n tokens from every key atomically: either all keys +// commit, or none do. Order of returned Results matches input order. +// +// The now argument is ignored — the script uses Redis-side TIME. +func (s *Store) TakeMany(ctx context.Context, keys []string, _ time.Time, limits []ratelimit.Limit, n int) ([]ratelimit.Result, error) { + if err := s.validateTake(keys, limits, n); err != nil { + return nil, err + } + if len(keys) == 0 { + return nil, nil + } + if n <= 0 { + return allowedResults(len(keys)), nil + } + + redisKeys, args, err := s.buildScriptCall(keys, limits, n) + if err != nil { + return nil, err + } + values, err := s.script.Run(ctx, s.client, redisKeys, args...).Slice() + if err != nil { + return nil, err + } + return parseResults(values, len(keys)) +} + +func (s *Store) validateTake(keys []string, limits []ratelimit.Limit, n int) error { + if s == nil || s.client == nil { + return ErrMissingClient + } + if s.prefix == "" { + return ErrMissingPrefix + } + if len(keys) != len(limits) { + return ErrKeyLimitMismatch + } + if n <= 0 { + return nil + } + for i, key := range keys { + if key == "" { + return ratelimit.ErrMissingKey + } + if err := limits[i].Validate(); err != nil { + return err + } + } + return nil +} + +func allowedResults(n int) []ratelimit.Result { + results := make([]ratelimit.Result, n) + for i := range results { + results[i] = ratelimit.Result{Allowed: true} + } + return results +} + +func (s *Store) buildScriptCall(keys []string, limits []ratelimit.Limit, n int) ([]string, []any, error) { + slotTag, err := sharedSlotTag(keys) + if err != nil { + return nil, nil, err + } + redisKeys := make([]string, len(keys)) + args := make([]any, 0, 1+len(limits)*scriptArgsPerKey) + args = append(args, n) + for i, key := range keys { + redisKeys[i] = s.redisKey(key, slotTag) + args = appendLimitArgs(args, limits[i]) + } + return redisKeys, args, nil +} + +func sharedSlotTag(keys []string) (string, error) { + group := keyGroup(keys[0]) + for _, key := range keys[1:] { + if keyGroup(key) != group { + return "", ErrKeyGroupMismatch + } + } + return sanitizeHashTag(group), nil +} + +func (s *Store) redisKey(key, slotTag string) string { + return "{" + slotTag + "}:" + s.prefix + ":" + key +} + +func keyGroup(key string) string { + group, _, ok := strings.Cut(key, ":") + if ok && group != "" { + return group + } + return key +} + +func sanitizeHashTag(tag string) string { + return strings.NewReplacer("{", "_", "}", "_").Replace(tag) +} + +func appendLimitArgs(args []any, limit ratelimit.Limit) []any { + return append(args, + limit.Rate, + durationMillis(limit.Per), + limit.Burst, + ttl(limit), + ) +} + +func parseResults(values []any, want int) ([]ratelimit.Result, error) { + if len(values) != want*scriptResultWidth { + return nil, fmt.Errorf("%w: got %d values, want %d", ErrInvalidScriptResult, len(values), want*scriptResultWidth) + } + results := make([]ratelimit.Result, want) + for i := 0; i < want; i++ { + res, err := parseResult(values[i*scriptResultWidth : (i+1)*scriptResultWidth]) + if err != nil { + return nil, err + } + results[i] = res + } + return results, nil +} + +func parseResult(values []any) (ratelimit.Result, error) { + allowed, err := int64Value(values[0]) + if err != nil { + return ratelimit.Result{}, err + } + remaining, err := int64Value(values[1]) + if err != nil { + return ratelimit.Result{}, err + } + retryMillis, err := int64Value(values[2]) + if err != nil { + return ratelimit.Result{}, err + } + return ratelimit.Result{ + Allowed: allowed == 1, + Remaining: int(remaining), + RetryAfter: time.Duration(retryMillis) * time.Millisecond, + }, nil +} + +func int64Value(v any) (int64, error) { + switch v := v.(type) { + case int64: + return v, nil + case int: + return int64(v), nil + case string: + n, err := strconv.ParseInt(v, 10, 64) + if err != nil { + return 0, fmt.Errorf("%w: %q", ErrInvalidScriptResult, v) + } + return n, nil + default: + return 0, fmt.Errorf("%w: %T", ErrInvalidScriptResult, v) + } +} + +// ttl returns the bucket's PEXPIRE in milliseconds. The bucket must outlive a +// full refill from empty so idle keys don't reset to burst and bypass the +// in-progress retry-after. A 1s floor avoids degenerate sub-second TTLs. +func ttl(limit ratelimit.Limit) int64 { + perMs := durationMillis(limit.Per) + refillFullMs := int64(math.Ceil(float64(limit.Burst) * float64(perMs) / float64(limit.Rate))) + expireMillis := 2 * perMs + if refillFullMs > expireMillis { + expireMillis = refillFullMs + } + if expireMillis < minBucketTTLMillis { + expireMillis = minBucketTTLMillis + } + return expireMillis +} + +func durationMillis(d time.Duration) int64 { + ms := d.Milliseconds() + if ms <= 0 { + return 1 + } + return ms +} + +// takeScript drives one two-phase token consumption across N keys. +// +// ARGV layout: ARGV[1] = n; then 4 args per key starting at ARGV[2]: +// rate, per_ms, burst, ttl_ms. KEYS[i] pairs with ARGV[2+(i-1)*4 .. +3]. +// +// Phase 1 refills every bucket from its stored state and snapshots the tokens. +// Phase 2 checks whether every key has capacity. Phase 3 commits the refilled +// state (and the n-token decrement only when every key passed), then emits +// three values per key: allowed (0|1), floor(remaining_tokens), retry_after_ms. +const takeScript = ` +local n = tonumber(ARGV[1]) +local count = #KEYS + +local t = redis.call('TIME') +local now = t[1] * 1000 + math.floor(t[2] / 1000) + +local snap = {} +for i = 1, count do + local base = 2 + (i - 1) * 4 + local rate = tonumber(ARGV[base]) + local per = tonumber(ARGV[base + 1]) + local burst = tonumber(ARGV[base + 2]) + local ttl = tonumber(ARGV[base + 3]) + + local bucket = redis.call("HMGET", KEYS[i], "tokens", "seen") + local tokens = tonumber(bucket[1]) + local seen = tonumber(bucket[2]) + if tokens == nil or seen == nil then + tokens = burst + seen = now + else + local elapsed = now - seen + if elapsed > 0 then + tokens = math.min(burst, tokens + (elapsed * rate / per)) + seen = now + end + end + + snap[i] = {tokens = tokens, seen = seen, rate = rate, per = per, burst = burst, ttl = ttl} +end + +local allow = true +for i = 1, count do + local s = snap[i] + if n > s.burst or s.tokens < n then + allow = false + break + end +end + +local out = {} +for i = 1, count do + local s = snap[i] + local committed = s.tokens + if allow then + committed = s.tokens - n + end + redis.call("HSET", KEYS[i], "tokens", committed, "seen", s.seen) + redis.call("PEXPIRE", KEYS[i], s.ttl) + + if allow then + table.insert(out, 1) + table.insert(out, math.floor(committed)) + table.insert(out, 0) + else + local retry + if n > s.burst then + retry = math.ceil((n - s.burst) * s.per / s.rate) + elseif s.tokens < n then + retry = math.ceil((n - s.tokens) * s.per / s.rate) + else + retry = 0 + end + table.insert(out, 0) + table.insert(out, math.floor(committed)) + table.insert(out, retry) + end +end +return out +` diff --git a/kratos/ratelimit/redis/store_test.go b/kratos/ratelimit/redis/store_test.go new file mode 100644 index 0000000..6ab1705 --- /dev/null +++ b/kratos/ratelimit/redis/store_test.go @@ -0,0 +1,299 @@ +package redis + +import ( + "context" + "errors" + "testing" + "time" + + "github.com/crypto-zero/go-kit/kratos/ratelimit" + goredis "github.com/redis/go-redis/v9" +) + +// fakeClient implements goredis.Scripter for tests. It forces the EVALSHA +// fast-path to miss so Eval is called, which lets the test inspect the script +// body and arg layout. +type fakeClient struct { + script string + keys []string + args []any + cmd *goredis.Cmd + evals int +} + +func (f *fakeClient) Eval(_ context.Context, script string, keys []string, args ...any) *goredis.Cmd { + f.evals++ + f.script = script + f.keys = append([]string(nil), keys...) + f.args = append([]any(nil), args...) + if f.cmd != nil { + return f.cmd + } + return goredis.NewCmdResult([]any{int64(1), int64(4), int64(0)}, nil) +} + +func (f *fakeClient) EvalRO(ctx context.Context, script string, keys []string, args ...any) *goredis.Cmd { + return f.Eval(ctx, script, keys, args...) +} + +func (f *fakeClient) EvalSha(_ context.Context, _ string, _ []string, _ ...any) *goredis.Cmd { + return goredis.NewCmdResult(nil, goredis.ErrNoScript) +} + +func (f *fakeClient) EvalShaRO(ctx context.Context, sha1 string, keys []string, args ...any) *goredis.Cmd { + return f.EvalSha(ctx, sha1, keys, args...) +} + +func (f *fakeClient) ScriptExists(_ context.Context, _ ...string) *goredis.BoolSliceCmd { + return goredis.NewBoolSliceResult(nil, nil) +} + +func (f *fakeClient) ScriptLoad(_ context.Context, _ string) *goredis.StringCmd { + return goredis.NewStringResult("", nil) +} + +func TestStoreTakeEvaluatesScript(t *testing.T) { + client := &fakeClient{} + store, err := newScriptStore(client, "api") + if err != nil { + t.Fatalf("NewStore: %v", err) + } + + res, err := store.Take(context.Background(), "tenant-1", time.UnixMilli(1000), ratelimit.Limit{ + Rate: 5, + Per: time.Second, + Burst: 10, + }, 3) + if err != nil { + t.Fatalf("Take: %v", err) + } + if !res.Allowed || res.Remaining != 4 || res.RetryAfter != 0 { + t.Fatalf("result = %+v, want allowed with 4 remaining", res) + } + if client.script != takeScript { + t.Fatal("Eval script mismatch") + } + if len(client.keys) != 1 || client.keys[0] != "{tenant-1}:api:tenant-1" { + t.Fatalf("keys = %#v, want hash-tagged api key", client.keys) + } + wantArgs := []any{3, 5, int64(1000), 10, int64(2000)} + if len(client.args) != len(wantArgs) { + t.Fatalf("args = %#v, want %#v", client.args, wantArgs) + } + for i := range wantArgs { + if client.args[i] != wantArgs[i] { + t.Fatalf("arg %d = %#v, want %#v", i, client.args[i], wantArgs[i]) + } + } +} + +func TestStoreTakeParsesRejectedResult(t *testing.T) { + client := &fakeClient{ + cmd: goredis.NewCmdResult([]any{int64(0), int64(0), int64(500)}, nil), + } + store, err := newScriptStore(client, "api") + if err != nil { + t.Fatalf("NewStore: %v", err) + } + + res, err := store.Take(context.Background(), "tenant-1", time.UnixMilli(1000), ratelimit.Limit{ + Rate: 2, + Per: time.Second, + Burst: 1, + }, 1) + if err != nil { + t.Fatalf("Take: %v", err) + } + if res.Allowed || res.Remaining != 0 || res.RetryAfter != 500*time.Millisecond { + t.Fatalf("result = %+v, want rejected with 500ms retry", res) + } +} + +func TestStoreTakeManyParsesPerKeyResults(t *testing.T) { + client := &fakeClient{ + cmd: goredis.NewCmdResult([]any{ + int64(1), int64(9), int64(0), + int64(0), int64(0), int64(750), + }, nil), + } + store, err := newScriptStore(client, "api") + if err != nil { + t.Fatalf("NewStore: %v", err) + } + + results, err := store.TakeMany(context.Background(), + []string{"/svc/A:user_id:user-1", "/svc/A:client_ip:ip-1"}, + time.UnixMilli(1000), + []ratelimit.Limit{ + {Rate: 10, Per: time.Second, Burst: 10}, + {Rate: 2, Per: time.Second, Burst: 1}, + }, + 1, + ) + if err != nil { + t.Fatalf("TakeMany: %v", err) + } + if len(results) != 2 { + t.Fatalf("got %d results, want 2", len(results)) + } + if !results[0].Allowed || results[0].Remaining != 9 || results[0].RetryAfter != 0 { + t.Fatalf("results[0] = %+v, want allowed with 9 remaining", results[0]) + } + if results[1].Allowed || results[1].RetryAfter != 750*time.Millisecond { + t.Fatalf("results[1] = %+v, want rejected with 750ms retry", results[1]) + } + if len(client.keys) != 2 || client.keys[0] != "{/svc/A}:api:/svc/A:user_id:user-1" || client.keys[1] != "{/svc/A}:api:/svc/A:client_ip:ip-1" { + t.Fatalf("keys = %#v, want shared hash tag from first key in order", client.keys) + } +} + +func TestRedisKeyUsesSanitizedGroupHashTag(t *testing.T) { + store, err := newScriptStore(&fakeClient{}, "api{prod}") + if err != nil { + t.Fatalf("NewStore: %v", err) + } + + tag := sanitizeHashTag(keyGroup("/svc/{Order}:Create")) + if got := store.redisKey("/svc/{Order}:Create:user_id:u1", tag); got != "{/svc/_Order_}:api{prod}:/svc/{Order}:Create:user_id:u1" { + t.Fatalf("redisKey = %q, want sanitized group hash tag", got) + } +} + +func TestStoreTakeManyRejectsKeyLimitMismatch(t *testing.T) { + store, err := newScriptStore(&fakeClient{}, "api") + if err != nil { + t.Fatalf("NewStore: %v", err) + } + + _, err = store.TakeMany(context.Background(), + []string{"a", "b"}, + time.UnixMilli(1000), + []ratelimit.Limit{{Rate: 1, Per: time.Second, Burst: 1}}, + 1, + ) + if !errors.Is(err, ErrKeyLimitMismatch) { + t.Fatalf("TakeMany error = %v, want ErrKeyLimitMismatch", err) + } + + _, err = store.TakeMany(context.Background(), + nil, + time.UnixMilli(1000), + []ratelimit.Limit{{Rate: 1, Per: time.Second, Burst: 1}}, + 1, + ) + if !errors.Is(err, ErrKeyLimitMismatch) { + t.Fatalf("TakeMany empty keys mismatch error = %v, want ErrKeyLimitMismatch", err) + } +} + +func TestStoreTakeManyRejectsKeyGroupMismatch(t *testing.T) { + store, err := newScriptStore(&fakeClient{}, "api") + if err != nil { + t.Fatalf("NewStore: %v", err) + } + + _, err = store.TakeMany(context.Background(), + []string{"/svc/A:user_id:u1", "/svc/B:client_ip:127.0.0.1"}, + time.UnixMilli(1000), + []ratelimit.Limit{ + {Rate: 1, Per: time.Second, Burst: 1}, + {Rate: 1, Per: time.Second, Burst: 1}, + }, + 1, + ) + if !errors.Is(err, ErrKeyGroupMismatch) { + t.Fatalf("TakeMany error = %v, want ErrKeyGroupMismatch", err) + } +} + +func TestStoreTakeManyAllowsNonPositiveNWithoutRedis(t *testing.T) { + client := &fakeClient{} + store, err := newScriptStore(client, "api") + if err != nil { + t.Fatalf("NewStore: %v", err) + } + + results, err := store.TakeMany(context.Background(), + []string{"tenant-1"}, + time.UnixMilli(1000), + []ratelimit.Limit{{Rate: 1, Per: time.Second, Burst: 1}}, + -1, + ) + if err != nil { + t.Fatalf("TakeMany: %v", err) + } + if len(results) != 1 || !results[0].Allowed { + t.Fatalf("results = %+v, want one allowed result", results) + } + if client.evals != 0 { + t.Fatalf("Eval calls = %d, want 0 for non-positive n", client.evals) + } +} + +func TestStoreTakeReturnsRedisError(t *testing.T) { + redisErr := errors.New("redis unavailable") + store, err := newScriptStore(&fakeClient{cmd: goredis.NewCmdResult(nil, redisErr)}, "api") + if err != nil { + t.Fatalf("NewStore: %v", err) + } + + _, err = store.Take(context.Background(), "tenant-1", time.UnixMilli(1000), ratelimit.Limit{ + Rate: 1, + Per: time.Second, + Burst: 1, + }, 1) + if !errors.Is(err, redisErr) { + t.Fatalf("Take error = %v, want redis error", err) + } +} + +func TestNewStoreRejectsMissingClient(t *testing.T) { + _, err := NewStore(nil, "api") + if !errors.Is(err, ErrMissingClient) { + t.Fatalf("NewStore error = %v, want ErrMissingClient", err) + } +} + +func TestNewStoreRejectsMissingPrefix(t *testing.T) { + _, err := newScriptStore(&fakeClient{}, "") + if !errors.Is(err, ErrMissingPrefix) { + t.Fatalf("NewStore error = %v, want ErrMissingPrefix", err) + } +} + +func TestStoreTakeRejectsMissingKey(t *testing.T) { + store, err := newScriptStore(&fakeClient{}, "api") + if err != nil { + t.Fatalf("NewStore: %v", err) + } + + _, err = store.Take(context.Background(), "", time.UnixMilli(1000), ratelimit.Limit{ + Rate: 1, + Per: time.Second, + Burst: 1, + }, 1) + if !errors.Is(err, ratelimit.ErrMissingKey) { + t.Fatalf("Take error = %v, want ErrMissingKey", err) + } +} + +func TestParseResultsRejectsInvalidLength(t *testing.T) { + _, err := parseResults([]any{int64(1)}, 1) + if !errors.Is(err, ErrInvalidScriptResult) { + t.Fatalf("parseResults error = %v, want ErrInvalidScriptResult", err) + } +} + +func TestDurationMillisRoundsTinyDurationUp(t *testing.T) { + if got := durationMillis(time.Nanosecond); got != 1 { + t.Fatalf("durationMillis(time.Nanosecond) = %d, want 1", got) + } +} + +func TestTTLCoversFullRefillFromEmpty(t *testing.T) { + got := ttl(ratelimit.Limit{Rate: 1, Per: time.Minute, Burst: 600}) + wantMin := int64(600) * time.Minute.Milliseconds() / 1 // refill-from-empty time + if got < wantMin { + t.Fatalf("ttl = %d ms, want >= %d ms so idle keys do not reset to burst", got, wantMin) + } +} diff --git a/kubernetes/election/election.go b/kubernetes/election/election.go index 1ec10f4..06b64ed 100644 --- a/kubernetes/election/election.go +++ b/kubernetes/election/election.go @@ -3,11 +3,11 @@ package election import ( "context" "fmt" - "io" "log/slog" "os" "strings" "sync" + "sync/atomic" "time" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" @@ -17,6 +17,11 @@ import ( "k8s.io/client-go/tools/leaderelection/resourcelock" ) +const ( + loggerNamed = "__LOGGER.NAMED__" + namespacePath = "/var/run/secrets/kubernetes.io/serviceaccount/namespace" +) + type ( // StateMachine is the state machine interface. StateMachine interface { @@ -26,15 +31,13 @@ type ( EnsureMaster(ctx context.Context) error // EnsureSlave ensures the state machine is slave. EnsureSlave(ctx context.Context) error - // Do the state machine. + // Do runs one state machine iteration and returns the delay before the next run. Do(ctx context.Context) (after time.Duration) - // Cleanup the state machine. + // Cleanup releases resources held by the state machine. Cleanup() } - // StateMachineRunner is the state machine runner interface + // StateMachineRunner runs leader-elected state machines. StateMachineRunner interface { - // implements kratos.Server - // Start starts the state machine runner. Start(context.Context) error // Stop stops the state machine runner. @@ -45,11 +48,12 @@ type ( } ) -// StateMachiRunnerImpl is the state machine runner implementation. -type StateMachiRunnerImpl struct { - ctx context.Context - cancel func() - closed bool +// StateMachineRunnerImpl is the state machine runner implementation. +type StateMachineRunnerImpl struct { + ctx context.Context + cancel func() + closed atomic.Bool + cleanupOnce sync.Once wg sync.WaitGroup @@ -60,29 +64,35 @@ type StateMachiRunnerImpl struct { logger *slog.Logger } +// StateMachiRunnerImpl is deprecated. +// +// Deprecated: Use StateMachineRunnerImpl. +type StateMachiRunnerImpl = StateMachineRunnerImpl + // Start starts the state machine runner. -func (s *StateMachiRunnerImpl) Start(context.Context) error { return nil } +func (s *StateMachineRunnerImpl) Start(context.Context) error { return nil } // Stop stops the state machine runner. -func (s *StateMachiRunnerImpl) Stop(context.Context) error { return nil } +func (s *StateMachineRunnerImpl) Stop(context.Context) error { + s.cleanup() + return nil +} // cleanup cleans up the state machine runner. -func (s *StateMachiRunnerImpl) cleanup() { - s.closed = true - s.cancel() - s.wg.Wait() +func (s *StateMachineRunnerImpl) cleanup() { + s.cleanupOnce.Do(func() { + s.closed.Store(true) + s.cancel() + s.wg.Wait() + }) } // serveMachine serves the state machine. -func (s *StateMachiRunnerImpl) serveMachine(machine StateMachine) { - // The logger name is conventionally assigned to the key "__LOGGER.NAMED__" defined in go-kit/zap. - const ( - LoggerNamed = "__LOGGER.NAMED__" - ) +func (s *StateMachineRunnerImpl) serveMachine(machine StateMachine) { + defer s.wg.Done() - // type hint here that can be omitted name := fmt.Sprintf("state-machine-runner-%s", machine.Name()) - logger := s.logger.With(LoggerNamed, name) + logger := s.logger.With(loggerNamed, name) // lease lock name rule: [a-z0-9]([-a-z0-9]*[a-z0-9])?(\.[a-z0-9]([-a-z0-9]*[a-z0-9])?)* isLeaderChan := make(chan bool, 10) @@ -127,7 +137,6 @@ func (s *StateMachiRunnerImpl) serveMachine(machine StateMachine) { ctx, cancel := context.WithCancel(s.ctx) - defer s.wg.Done() defer func() { logger.Info("stopped") }() defer machine.Cleanup() defer cancel() @@ -161,7 +170,7 @@ func (s *StateMachiRunnerImpl) serveMachine(machine StateMachine) { } } - for !s.closed { + for !s.closed.Load() { after := machine.Do(ctx) if after <= 0 { return @@ -178,18 +187,13 @@ func (s *StateMachiRunnerImpl) serveMachine(machine StateMachine) { } } -func (s *StateMachiRunnerImpl) AddMachine(machine StateMachine) { +func (s *StateMachineRunnerImpl) AddMachine(machine StateMachine) { s.wg.Add(1) go s.serveMachine(machine) } // NewStateMachineRunnerImpl creates a new StateMachineRunner. func NewStateMachineRunnerImpl(logger *slog.Logger) (StateMachineRunner, func(), error) { - out := &StateMachiRunnerImpl{ - logger: logger, - } - out.ctx, out.cancel = context.WithCancel(context.Background()) - config, err := rest.InClusterConfig() if err != nil { return nil, nil, err @@ -202,17 +206,19 @@ func NewStateMachineRunnerImpl(logger *slog.Logger) (StateMachineRunner, func(), if err != nil { return nil, nil, err } - out.cli, out.namespace, out.pod = cli, GetCurrentNamespace(), pod - return out, sync.OnceFunc(out.cleanup), nil + out := &StateMachineRunnerImpl{ + logger: logger, + cli: cli, + namespace: GetCurrentNamespace(), + pod: pod, + } + out.ctx, out.cancel = context.WithCancel(context.Background()) + return out, out.cleanup, nil } // GetCurrentNamespace returns the current namespace in the kubernetes cluster. func GetCurrentNamespace() (namespace string) { - namespaceFile, err := os.Open("/var/run/secrets/kubernetes.io/serviceaccount/namespace") - if err != nil { - return "" - } - d, err := io.ReadAll(namespaceFile) + d, err := os.ReadFile(namespacePath) if err != nil { return "" } diff --git a/kubernetes/election/election_test.go b/kubernetes/election/election_test.go new file mode 100644 index 0000000..cdf1a2d --- /dev/null +++ b/kubernetes/election/election_test.go @@ -0,0 +1,27 @@ +package election + +import ( + "context" + "testing" +) + +func TestStateMachineRunnerStopIsIdempotent(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + runner := &StateMachineRunnerImpl{ + ctx: ctx, + cancel: cancel, + } + + if err := runner.Stop(context.Background()); err != nil { + t.Fatalf("first Stop: %v", err) + } + if err := runner.Stop(context.Background()); err != nil { + t.Fatalf("second Stop: %v", err) + } + + select { + case <-ctx.Done(): + default: + t.Fatal("Stop did not cancel runner context") + } +} diff --git a/kubernetes/kubernetes.go b/kubernetes/kubernetes.go index 907fbe2..39fd061 100644 --- a/kubernetes/kubernetes.go +++ b/kubernetes/kubernetes.go @@ -1,18 +1,15 @@ package kubernetes import ( - "io" "os" "strings" ) +const namespacePath = "/var/run/secrets/kubernetes.io/serviceaccount/namespace" + // GetCurrentNamespace returns the current namespace in the kubernetes cluster. func GetCurrentNamespace() (namespace string) { - namespaceFile, err := os.Open("/var/run/secrets/kubernetes.io/serviceaccount/namespace") - if err != nil { - return "" - } - d, err := io.ReadAll(namespaceFile) + d, err := os.ReadFile(namespacePath) if err != nil { return "" } diff --git a/lifecycle/loop_runner.go b/lifecycle/loop_runner.go new file mode 100644 index 0000000..7fbc415 --- /dev/null +++ b/lifecycle/loop_runner.go @@ -0,0 +1,206 @@ +// Package lifecycle provides helpers for service lifecycles. +package lifecycle + +import ( + "context" + "errors" + "log/slog" + "sync" + "time" +) + +// ErrAlreadyStarted reports a runner configuration change after Start. +var ErrAlreadyStarted = errors.New("lifecycle: runner already started") + +// Service is the common lifecycle interface used by Kratos servers. +type Service interface { + Start(context.Context) error + Stop(context.Context) error +} + +// LoopRunner manages a group of background loops. +type LoopRunner interface { + Service + + // OnStart registers a callback that runs during Start before any loops are + // spawned. Calling OnStart twice overrides the previous callback. + OnStart(func(context.Context) error) LoopRunner + + // Add registers a loop to spawn during Start. + Add(name string, fn func(context.Context)) LoopRunner + + // AddTick registers a periodic loop that invokes fn every interval. + AddTick(name string, interval time.Duration, fn func(context.Context) error) LoopRunner +} + +type namedLoop struct { + name string + fn func(context.Context) +} + +type loopRunner struct { + name string + logger *slog.Logger + + mu sync.Mutex + onStart func(context.Context) error + loops []namedLoop + cancel context.CancelFunc + started bool + + wg sync.WaitGroup +} + +// NewLoopRunner constructs a runner scoped to name. +func NewLoopRunner(name string, logger *slog.Logger) LoopRunner { + if logger == nil { + panic("lifecycle: logger is nil") + } + if name == "" { + panic("lifecycle: service name is empty") + } + return &loopRunner{ + name: name, + logger: logger.With("service", name), + } +} + +func (r *loopRunner) OnStart(fn func(context.Context) error) LoopRunner { + if err := r.configure(func() { + r.onStart = fn + }); err != nil { + panic(err) + } + return r +} + +func (r *loopRunner) Add(name string, fn func(context.Context)) LoopRunner { + if name == "" { + panic("lifecycle: loop name is empty") + } + if fn == nil { + panic("lifecycle: loop function is nil") + } + if err := r.configure(func() { + r.loops = append(r.loops, namedLoop{name: name, fn: fn}) + }); err != nil { + panic(err) + } + return r +} + +func (r *loopRunner) AddTick(name string, interval time.Duration, fn func(context.Context) error) LoopRunner { + if interval <= 0 { + panic("lifecycle: tick interval must be positive") + } + if fn == nil { + panic("lifecycle: tick function is nil") + } + return r.Add(name, func(ctx context.Context) { + ticker := time.NewTicker(interval) + defer ticker.Stop() + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + if err := fn(ctx); err != nil { + r.logger.ErrorContext(ctx, "loop tick failed", "loop", name, "err", err) + } + } + } + }) +} + +// Start implements Service. +func (r *loopRunner) Start(ctx context.Context) error { + runCtx, onStart, loops, err := r.start() + if err != nil { + return err + } + if onStart != nil { + if err := onStart(ctx); err != nil { + _ = r.Stop(context.Background()) + return err + } + } + for _, l := range loops { + r.spawn(runCtx, l.name, l.fn) + } + r.logger.InfoContext(ctx, "service started", "loops", len(loops)) + return nil +} + +// Stop implements Service. +func (r *loopRunner) Stop(ctx context.Context) error { + cancel := r.stop() + if cancel == nil { + return nil + } + cancel() + + done := make(chan struct{}) + go func() { + r.wg.Wait() + r.finishStop() + close(done) + }() + + select { + case <-done: + r.logger.InfoContext(ctx, "service stopped") + case <-ctx.Done(): + r.logger.WarnContext(ctx, "stop deadline exceeded, goroutines may still be running") + } + return nil +} + +func (r *loopRunner) configure(fn func()) error { + r.mu.Lock() + defer r.mu.Unlock() + if r.started { + return ErrAlreadyStarted + } + fn() + return nil +} + +func (r *loopRunner) start() (context.Context, func(context.Context) error, []namedLoop, error) { + r.mu.Lock() + defer r.mu.Unlock() + if r.started { + return nil, nil, nil, ErrAlreadyStarted + } + runCtx, cancel := context.WithCancel(context.Background()) + r.cancel = cancel + r.started = true + return runCtx, r.onStart, append([]namedLoop(nil), r.loops...), nil +} + +func (r *loopRunner) stop() context.CancelFunc { + r.mu.Lock() + defer r.mu.Unlock() + cancel := r.cancel + r.cancel = nil + return cancel +} + +func (r *loopRunner) finishStop() { + r.mu.Lock() + defer r.mu.Unlock() + r.started = false +} + +func (r *loopRunner) spawn(ctx context.Context, name string, fn func(context.Context)) { + r.wg.Go(func() { + defer func() { + if rec := recover(); rec != nil { + r.logger.ErrorContext(ctx, "background goroutine panicked", "loop", name, "panic", rec) + } + }() + if fn == nil { + return + } + fn(ctx) + }) +} diff --git a/lifecycle/loop_runner_test.go b/lifecycle/loop_runner_test.go new file mode 100644 index 0000000..1f32bbf --- /dev/null +++ b/lifecycle/loop_runner_test.go @@ -0,0 +1,153 @@ +package lifecycle + +import ( + "bytes" + "context" + "errors" + "io" + "log/slog" + "strings" + "sync/atomic" + "testing" + "time" +) + +func TestLoopRunnerStartsAndStopsLoops(t *testing.T) { + var buf bytes.Buffer + logger := slog.New(slog.NewTextHandler(&buf, nil)) + var loaded atomic.Bool + started := make(chan struct{}) + + runner := NewLoopRunner("test", logger). + OnStart(func(context.Context) error { + loaded.Store(true) + return nil + }). + Add("main", func(ctx context.Context) { + if !loaded.Load() { + t.Error("loop started before OnStart completed") + } + close(started) + <-ctx.Done() + }) + + if err := runner.Start(context.Background()); err != nil { + t.Fatalf("Start: %v", err) + } + waitFor(t, started) + + stopCtx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + if err := runner.Stop(stopCtx); err != nil { + t.Fatalf("Stop: %v", err) + } + if !strings.Contains(buf.String(), "service stopped") { + t.Fatalf("logs = %q, want service stopped", buf.String()) + } +} + +func TestLoopRunnerRejectsConfigurationAfterStart(t *testing.T) { + runner := NewLoopRunner("test", testLogger()).Add("main", func(ctx context.Context) { + <-ctx.Done() + }) + if err := runner.Start(context.Background()); err != nil { + t.Fatalf("Start: %v", err) + } + defer func() { + _ = runner.Stop(context.Background()) + }() + + err := recoverPanic(func() { + runner.Add("late", func(context.Context) {}) + }) + if !errors.Is(err, ErrAlreadyStarted) { + t.Fatalf("panic = %v, want ErrAlreadyStarted", err) + } +} + +func TestLoopRunnerRecoversLoopPanic(t *testing.T) { + var buf bytes.Buffer + logger := slog.New(slog.NewTextHandler(&buf, nil)) + runner := NewLoopRunner("test", logger).Add("panic", func(context.Context) { + panic("boom") + }) + + if err := runner.Start(context.Background()); err != nil { + t.Fatalf("Start: %v", err) + } + stopCtx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + if err := runner.Stop(stopCtx); err != nil { + t.Fatalf("Stop: %v", err) + } + if !strings.Contains(buf.String(), "background goroutine panicked") { + t.Fatalf("logs = %q, want panic log", buf.String()) + } +} + +func TestLoopRunnerTickLoop(t *testing.T) { + ticked := make(chan struct{}, 1) + runner := NewLoopRunner("test", testLogger()).AddTick("tick", time.Millisecond, func(context.Context) error { + select { + case ticked <- struct{}{}: + default: + } + return nil + }) + + if err := runner.Start(context.Background()); err != nil { + t.Fatalf("Start: %v", err) + } + waitFor(t, ticked) + + stopCtx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + if err := runner.Stop(stopCtx); err != nil { + t.Fatalf("Stop: %v", err) + } +} + +func TestNewLoopRunnerRejectsMissingName(t *testing.T) { + err := recoverPanic(func() { + NewLoopRunner("", testLogger()) + }) + if err == nil || err.Error() != "panic" { + t.Fatalf("panic = %v, want panic for missing name", err) + } +} + +func TestNewLoopRunnerRejectsMissingLogger(t *testing.T) { + err := recoverPanic(func() { + NewLoopRunner("test", nil) + }) + if err == nil || err.Error() != "panic" { + t.Fatalf("panic = %v, want panic for missing logger", err) + } +} + +func testLogger() *slog.Logger { + return slog.New(slog.NewTextHandler(io.Discard, nil)) +} + +func waitFor(t *testing.T, ch <-chan struct{}) { + t.Helper() + select { + case <-ch: + case <-time.After(time.Second): + t.Fatal("timed out waiting for channel") + } +} + +func recoverPanic(fn func()) (err error) { + defer func() { + if rec := recover(); rec != nil { + if e, ok := rec.(error); ok { + err = e + return + } + err = errors.New("panic") + } + }() + fn() + return nil +} diff --git a/maxmind/maxmind.go b/maxmind/maxmind.go index 3b54191..800eb70 100644 --- a/maxmind/maxmind.go +++ b/maxmind/maxmind.go @@ -7,7 +7,7 @@ import ( "github.com/oschwald/maxminddb-golang" ) -// GeoNames is a struct for multiple languages +// GeoNames is a struct for multiple languages. type GeoNames struct { German string `maxminddb:"de"` English string `maxminddb:"en"` @@ -19,7 +19,7 @@ type GeoNames struct { Chinese string `maxminddb:"zh-CN"` } -// GeoCity is a struct for maxminddb city result +// GeoCity is a struct for maxminddb city result. type GeoCity struct { City struct { Name GeoNames `maxminddb:"names"` @@ -53,17 +53,20 @@ type GeoCity struct { var emptyGeoCity = GeoCity{} -// Database is an interface for maxminddb +// Database reads GeoCity records from a MaxMind database. type Database interface { - // Lookup returns GeoCity for given IP + // Lookup returns GeoCity for given IP. Lookup(ip net.IP) (*GeoCity, error) } -// DatabaseImpl is an implementation of Database +// DatabaseImpl implements Database. type DatabaseImpl struct { db *maxminddb.Reader } +var _ Database = (*DatabaseImpl)(nil) + +// Lookup returns GeoCity for given IP. func (d *DatabaseImpl) Lookup(ip net.IP) (*GeoCity, error) { var record GeoCity if err := d.db.Lookup(ip, &record); err != nil { @@ -75,16 +78,16 @@ func (d *DatabaseImpl) Lookup(ip net.IP) (*GeoCity, error) { return &record, nil } -// Path is a type for maxminddb path +// Path is a type for maxminddb path. type Path string -// ContainerPath returns path to maxminddb container +// ContainerPath returns path to maxminddb container. func ContainerPath() Path { return "/app/bin/GeoLite2-City.mmdb" } -// NewDatabaseImpl returns implementation of Database -func NewDatabaseImpl(path Path) (Database, func(), error) { +// NewDatabase opens a MaxMind database. +func NewDatabase(path Path) (*DatabaseImpl, func(), error) { db, err := maxminddb.Open(string(path)) if err != nil { return nil, nil, err @@ -94,7 +97,18 @@ func NewDatabaseImpl(path Path) (Database, func(), error) { }, nil } -// IsEmptyGeoCity checks if GeoCity is empty +// NewDatabaseImpl returns implementation of Database. +// +// Deprecated: Use NewDatabase when a concrete *DatabaseImpl is acceptable. +func NewDatabaseImpl(path Path) (Database, func(), error) { + database, cleanup, err := NewDatabase(path) + if err != nil { + return nil, nil, err + } + return database, cleanup, nil +} + +// IsEmptyGeoCity checks if GeoCity is empty. func IsEmptyGeoCity(geoCity GeoCity) bool { return reflect.DeepEqual(geoCity, emptyGeoCity) } diff --git a/maxmind/maxmind_test.go b/maxmind/maxmind_test.go index 989edb4..217363d 100644 --- a/maxmind/maxmind_test.go +++ b/maxmind/maxmind_test.go @@ -2,7 +2,9 @@ package maxmind import ( "encoding/json" + "errors" "net" + "os" "testing" "github.com/oschwald/maxminddb-golang" @@ -11,14 +13,21 @@ import ( func TestMaxmindRead(t *testing.T) { reader, err := maxminddb.Open("./GeoLite2-City.mmdb") if err != nil { + if errors.Is(err, os.ErrNotExist) { + t.Skip("GeoLite2-City.mmdb is not available") + } t.Fatal(err) } + defer reader.Close() var record GeoCity internalIP := net.ParseIP("81.2.69.142") if err = reader.Lookup(internalIP, &record); err != nil { t.Fatal(err) } - b, _ := json.Marshal(record) + b, err := json.Marshal(record) + if err != nil { + t.Fatal(err) + } t.Log(string(b), IsEmptyGeoCity(record)) } diff --git a/otel/resouce.go b/otel/resouce.go index fcd046b..dd9f946 100644 --- a/otel/resouce.go +++ b/otel/resouce.go @@ -3,22 +3,22 @@ package otel import "go.opentelemetry.io/otel/attribute" const ( - // SigNozSystemDBKey span type database call + // SigNozSystemDBKey is the span attribute key for database calls. // https://signoz.io/docs/userguide/metrics/ SigNozSystemDBKey = attribute.Key("db.system") ) -// SigNozSystemDB return db system attribute +// SigNozSystemDB returns a database system attribute. func SigNozSystemDB(system string) attribute.KeyValue { return SigNozSystemDBKey.String(system) } -// SigNozSystemDBPostgres return db system attribute for postgres +// SigNozSystemDBPostgres returns a database system attribute for PostgreSQL. func SigNozSystemDBPostgres() attribute.KeyValue { return SigNozSystemDB("postgresql") } -// SigNozSystemDBNats return db system attribute for nats +// SigNozSystemDBNats returns a database system attribute for NATS. func SigNozSystemDBNats() attribute.KeyValue { return SigNozSystemDB("nats") } diff --git a/otel/trace_provider.go b/otel/trace_provider.go index a8fa8e7..24a0f61 100644 --- a/otel/trace_provider.go +++ b/otel/trace_provider.go @@ -5,6 +5,7 @@ import ( "fmt" "os" "strings" + "time" "go.opentelemetry.io/otel" "go.opentelemetry.io/otel/attribute" @@ -17,7 +18,15 @@ import ( "github.com/crypto-zero/go-kit/kubernetes" ) -// TraceProviderConfig is an open telemetry trace provider config. +const traceShutdownTimeout = 5 * time.Second + +// TraceProvider is an open telemetry trace service. +// +// Deprecated: This broad compatibility type is an alias-shaped service token. +// Consumers should depend on the behavior they need instead of this type. +type TraceProvider any + +// TraceProviderConfig configures an OpenTelemetry trace provider. type TraceProviderConfig struct { Context context.Context Name string @@ -28,7 +37,7 @@ type TraceProviderConfig struct { SampleFraction float64 } -// FromEnv load config from env. +// FromEnv loads trace provider config from environment variables. func (c *TraceProviderConfig) FromEnv() { value := os.Getenv("OTEL_EXPORTER_OTLP_ENDPOINT") // The env var may contain a scheme, which we need to remove. @@ -39,12 +48,12 @@ func (c *TraceProviderConfig) FromEnv() { } } -// TraceProvider is an open telemetry trace service. -type TraceProvider any - +// TraceProviderImpl is an OpenTelemetry trace service. type TraceProviderImpl struct{} -// NewTraceProvider new an open telemetry trace provider. +// NewTraceProvider creates an OpenTelemetry trace provider. +// +// It returns TraceProvider for backward compatibility with earlier releases. func NewTraceProvider(c *TraceProviderConfig) ( TraceProvider, func(), error, ) { @@ -57,7 +66,11 @@ func NewTraceProvider(c *TraceProviderConfig) ( exportGrpcOptions = append(exportGrpcOptions, otlptracegrpc.WithInsecure()) } exportGrpcOptions = append(exportGrpcOptions, otlptracegrpc.WithEndpoint(c.Endpoint)) - exporter, err := otlptrace.New(c.Context, otlptracegrpc.NewClient(exportGrpcOptions...)) + ctx := c.Context + if ctx == nil { + ctx = context.Background() + } + exporter, err := otlptrace.New(ctx, otlptracegrpc.NewClient(exportGrpcOptions...)) if err != nil { return nil, nil, fmt.Errorf("failed to create the collector exporter: %w", err) } @@ -79,12 +92,15 @@ func NewTraceProvider(c *TraceProviderConfig) ( } } - otel.SetTracerProvider( - sdktrace.NewTracerProvider( - sdktrace.WithSampler(sdktrace.TraceIDRatioBased(c.SampleFraction)), - sdktrace.WithBatcher(exporter), - sdktrace.WithResource(resource.NewSchemaless(attrs...)), - ), + provider := sdktrace.NewTracerProvider( + sdktrace.WithSampler(sdktrace.TraceIDRatioBased(c.SampleFraction)), + sdktrace.WithBatcher(exporter), + sdktrace.WithResource(resource.NewSchemaless(attrs...)), ) - return &TraceProviderImpl{}, func() {}, nil + otel.SetTracerProvider(provider) + return &TraceProviderImpl{}, func() { + ctx, cancel := context.WithTimeout(context.Background(), traceShutdownTimeout) + defer cancel() + _ = provider.Shutdown(ctx) + }, nil } diff --git a/pgx/pgx.go b/pgx/pgx.go index c44d8b4..a22532b 100644 --- a/pgx/pgx.go +++ b/pgx/pgx.go @@ -10,7 +10,8 @@ import ( "github.com/jackc/pgx/v5/pgtype" ) -// ent not support for generic types, so we need to declare wrapper types for each type +// Ent does not support generic field types, so concrete wrapper types are +// declared for each supported database value. type ( // CIDRWrapper is a wrapper for pgx standard sql library types. CIDRWrapper StdWrapper[netip.Prefix] @@ -32,23 +33,15 @@ type ( ) // Value implements the database/sql/driver Valuer interface. -// -//goland:noinspection GoMixedReceiverTypes func (w CIDRWrapper) Value() (driver.Value, error) { return StdWrapper[netip.Prefix](w).Value() } // Scan implements the database/sql Scanner interface. -// -//goland:noinspection GoMixedReceiverTypes func (w *CIDRWrapper) Scan(src any) error { return (*StdWrapper[netip.Prefix])(w).Scan(src) } // Value implements the database/sql/driver Valuer interface. -// -//goland:noinspection GoMixedReceiverTypes func (w DurationWrapper) Value() (driver.Value, error) { return StdWrapper[time.Duration](w).Value() } // Scan implements the database/sql Scanner interface. -// -//goland:noinspection GoMixedReceiverTypes func (w *DurationWrapper) Scan(src any) error { return (*StdWrapper[time.Duration])(w).Scan(src) } @@ -57,52 +50,36 @@ func (w *DurationWrapper) Scan(src any) error { func NewIntsWrapper() IntsWrapper { return IntsWrapper{V: make([]int, 0)} } // Value implements the database/sql/driver Valuer interface. -// -//goland:noinspection GoMixedReceiverTypes func (w IntsWrapper) Value() (driver.Value, error) { return SliceWrapper[int](w).Value() } // Scan implements the database/sql Scanner interface. -// -//goland:noinspection GoMixedReceiverTypes func (w *IntsWrapper) Scan(src any) error { return (*SliceWrapper[int])(w).Scan(src) } // NewFloatsWrapper returns a new FloatsWrapper. func NewFloatsWrapper() FloatsWrapper { return FloatsWrapper{V: make([]float64, 0)} } // Value implements the database/sql/driver Valuer interface. -// -//goland:noinspection GoMixedReceiverTypes func (w FloatsWrapper) Value() (driver.Value, error) { return SliceWrapper[float64](w).Value() } // Scan implements the database/sql Scanner interface. -// -//goland:noinspection GoMixedReceiverTypes func (w *FloatsWrapper) Scan(src any) error { return (*SliceWrapper[float64])(w).Scan(src) } // NewStringsWrapper returns a new StringsWrapper. func NewStringsWrapper() StringsWrapper { return StringsWrapper{V: make([]string, 0)} } // Value implements the database/sql/driver Valuer interface. -// -//goland:noinspection GoMixedReceiverTypes func (w StringsWrapper) Value() (driver.Value, error) { return SliceWrapper[string](w).Value() } // Scan implements the database/sql Scanner interface. -// -//goland:noinspection GoMixedReceiverTypes func (w *StringsWrapper) Scan(src any) error { return (*SliceWrapper[string])(w).Scan(src) } // NewCIDRsWrapper returns a new CIDRsWrapper. func NewCIDRsWrapper() CIDRsWrapper { return CIDRsWrapper{V: make([]netip.Prefix, 0)} } // Value implements the database/sql/driver Valuer interface. -// -//goland:noinspection GoMixedReceiverTypes func (w CIDRsWrapper) Value() (driver.Value, error) { return SliceWrapper[netip.Prefix](w).Value() } // Scan implements the database/sql Scanner interface. -// -//goland:noinspection GoMixedReceiverTypes func (w *CIDRsWrapper) Scan(src any) error { return (*SliceWrapper[netip.Prefix])(w).Scan(src) } // NewDurationsWrapper returns a new DurationsWrapper. @@ -111,15 +88,11 @@ func NewDurationsWrapper() DurationsWrapper { } // Value implements the database/sql/driver Valuer interface. -// -//goland:noinspection GoMixedReceiverTypes func (w DurationsWrapper) Value() (driver.Value, error) { return SliceWrapper[time.Duration](w).Value() } // Scan implements the database/sql Scanner interface. -// -//goland:noinspection GoMixedReceiverTypes func (w *DurationsWrapper) Scan(src any) error { return (*SliceWrapper[time.Duration])(w).Scan(src) } @@ -130,15 +103,11 @@ func NewTimestampsWrapper() TimestampsWrapper { } // Value implements the database/sql/driver Valuer interface. -// -//goland:noinspection GoMixedReceiverTypes func (w TimestampsWrapper) Value() (driver.Value, error) { return SliceWrapper[time.Time](w).Value() } // Scan implements the database/sql Scanner interface. -// -//goland:noinspection GoMixedReceiverTypes func (w *TimestampsWrapper) Scan(src any) error { return (*SliceWrapper[time.Time])(w).Scan(src) } @@ -156,15 +125,11 @@ type StdWrapper[T any] struct { } // Value implements the database/sql/driver Valuer interface. -// -//goland:noinspection GoMixedReceiverTypes func (w StdWrapper[T]) Value() (driver.Value, error) { return w.V, nil } // Scan implements the database/sql Scanner interface. -// -//goland:noinspection GoMixedReceiverTypes func (w *StdWrapper[T]) Scan(src any) (err error) { return typeMapScan(src, &w.V) } @@ -175,10 +140,8 @@ type SliceWrapper[T any] struct { } // Value implements the database/sql/driver Valuer interface. -// -//goland:noinspection GoMixedReceiverTypes func (w SliceWrapper[T]) Value() (driver.Value, error) { - // pgx treats nil slice as NULL, but we want to treat it as an empty array + // pgx treats nil slices as NULL, but callers expect an empty array. if w.V == nil { return make([]T, 0), nil } @@ -186,8 +149,6 @@ func (w SliceWrapper[T]) Value() (driver.Value, error) { } // Scan implements the database/sql Scanner interface. -// -//goland:noinspection GoMixedReceiverTypes func (w *SliceWrapper[T]) Scan(src any) (err error) { return typeMapScan(src, &w.V) } @@ -220,7 +181,7 @@ func guessingScan[T any](src any) (value T, err error) { case []byte: bufSrc = src default: - bufSrc = []byte(fmt.Sprint(bufSrc)) + bufSrc = []byte(fmt.Sprint(src)) } } diff --git a/pgx/pgx_test.go b/pgx/pgx_test.go index c83ebb9..3fe923f 100644 --- a/pgx/pgx_test.go +++ b/pgx/pgx_test.go @@ -16,18 +16,21 @@ import ( ) var db *sql.DB +var skipPGXTests string func TestMain(m *testing.M) { // uses a sensible default on windows (tcp/http) and linux/osx (socket) pool, err := dockertest.NewPool("") if err != nil { - log.Fatalf("Could not construct pool: %s", err) + skipPGXTests = fmt.Sprintf("could not construct Docker pool: %s", err) + os.Exit(m.Run()) } // uses pool to try to connect to Docker err = pool.Client.Ping() if err != nil { - log.Fatalf("Could not connect to Docker: %s", err) + skipPGXTests = fmt.Sprintf("could not connect to Docker: %s", err) + os.Exit(m.Run()) } // pulls an image, creates a container based on it and runs it @@ -47,7 +50,8 @@ func TestMain(m *testing.M) { } }) if err != nil { - log.Fatalf("Could not start resource: %s", err) + skipPGXTests = fmt.Sprintf("could not start PostgreSQL container: %s", err) + os.Exit(m.Run()) } // exponential backoff-retry, because the application in the container might not be ready to accept connections yet @@ -59,12 +63,16 @@ func TestMain(m *testing.M) { } return db.Ping() }); err != nil { - log.Fatalf("Could not connect to database: %s", err) + skipPGXTests = fmt.Sprintf("could not connect to database: %s", err) + os.Exit(m.Run()) } code := m.Run() // You can't defer this because os.Exit doesn't care for defer + if err := db.Close(); err != nil { + log.Fatalf("Could not close database: %s", err) + } if err := pool.Purge(resource); err != nil { log.Fatalf("Could not purge resource: %s", err) } @@ -72,7 +80,31 @@ func TestMain(m *testing.M) { os.Exit(code) } +func requirePGXDB(t *testing.T) *sql.DB { + t.Helper() + if skipPGXTests != "" { + t.Skip(skipPGXTests) + } + return db +} + +func TestGuessingScanUsesFallbackSource(t *testing.T) { + type textJSON string + type person struct { + Name string `json:"name"` + } + + got, err := guessingScan[person](textJSON(`{"name":"John"}`)) + if err != nil { + t.Fatal(err) + } + if got.Name != "John" { + t.Fatalf("guessingScan() = %+v; want name John", got) + } +} + func TestPGXIntArray(t *testing.T) { + db := requirePGXDB(t) input := []int{1, 2, 3} var output StdWrapper[[]int] if err := db.QueryRow("select $1::int[]", input).Scan(&output); err != nil { @@ -84,6 +116,7 @@ func TestPGXIntArray(t *testing.T) { } func TestPGXJSON(t *testing.T) { + db := requirePGXDB(t) type person struct { Name string `json:"name"` Age int `json:"age"` @@ -99,6 +132,7 @@ func TestPGXJSON(t *testing.T) { } func TestPGXJSONArray(t *testing.T) { + db := requirePGXDB(t) type person struct { Name string `json:"name"` Age int `json:"age"` @@ -122,6 +156,7 @@ func TestPGXJSONArray(t *testing.T) { } func TestPGXNetPrefix(t *testing.T) { + db := requirePGXDB(t) input := netip.MustParsePrefix("255.255.255.255/32") var output StdWrapper[netip.Prefix] if err := db.QueryRow("select $1::cidr", input).Scan(&output); err != nil { @@ -133,6 +168,7 @@ func TestPGXNetPrefix(t *testing.T) { } func TestPGXNetPrefixArray(t *testing.T) { + db := requirePGXDB(t) input := []netip.Prefix{ netip.MustParsePrefix("127.0.0.1/32"), netip.MustParsePrefix("10.0.0.0/8"), diff --git a/pprof/pprof.go b/pprof/pprof.go index 2d18a5d..41d7954 100644 --- a/pprof/pprof.go +++ b/pprof/pprof.go @@ -1,36 +1,62 @@ package pprof import ( + "errors" "fmt" - "log" + "log/slog" "net" "net/http" _ "net/http/pprof" + "sync" "github.com/google/gops/agent" ) // Pprof is a pprof service. +// +// Deprecated: This broad compatibility type is an alias-shaped service token. +// Consumers should depend on the behavior they need instead of this type. type Pprof any // PprofImpl is a pprof service implementation. -type PprofImpl struct{} +type PprofImpl struct { + listener net.Listener + once sync.Once +} // NewPProfImpl returns a new PprofImpl. -// it provides gops agent and pprof service. +// It provides gops agent and pprof service. +// +// It returns Pprof for backward compatibility with earlier releases. func NewPProfImpl() (Pprof, func(), error) { ln, err := net.Listen("tcp", "localhost:0") if err != nil { - return nil, func() {}, fmt.Errorf("start pprof failed: %v", err) + return nil, func() {}, fmt.Errorf("start pprof failed: %w", err) } - log.Println("start pprof service on:", ln.Addr()) + service := &PprofImpl{listener: ln} + cleanup := service.Close + + slog.Info("start pprof service", "addr", ln.Addr().String()) go func() { - _ = http.Serve(ln, nil) + if err := http.Serve(ln, nil); err != nil && !errors.Is(err, net.ErrClosed) { + slog.Error("pprof service stopped", "err", err) + } }() if err := agent.Listen(agent.Options{ShutdownCleanup: false}); err != nil { - return nil, func() {}, fmt.Errorf("start gops agent failed: %v", err) + cleanup() + return nil, func() {}, fmt.Errorf("start gops agent failed: %w", err) } - return &PprofImpl{}, func() {}, nil + return service, cleanup, nil +} + +// Close stops the pprof service and gops agent. +func (p *PprofImpl) Close() { + p.once.Do(func() { + agent.Close() + if p.listener != nil { + _ = p.listener.Close() + } + }) } diff --git a/proto/kit/ratelimit/v1/ratelimit.pb.go b/proto/kit/ratelimit/v1/ratelimit.pb.go deleted file mode 100644 index 6a1906a..0000000 --- a/proto/kit/ratelimit/v1/ratelimit.pb.go +++ /dev/null @@ -1,243 +0,0 @@ -// Code generated by protoc-gen-go. DO NOT EDIT. -// versions: -// protoc-gen-go v1.36.11 -// protoc v7.34.1 -// source: proto/kit/ratelimit/v1/ratelimit.proto - -package ratelimitv1 - -import ( - protoreflect "google.golang.org/protobuf/reflect/protoreflect" - protoimpl "google.golang.org/protobuf/runtime/protoimpl" - descriptorpb "google.golang.org/protobuf/types/descriptorpb" - durationpb "google.golang.org/protobuf/types/known/durationpb" - reflect "reflect" - sync "sync" - unsafe "unsafe" -) - -const ( - // Verify that this generated code is sufficiently up-to-date. - _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) - // Verify that runtime/protoimpl is sufficiently up-to-date. - _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) -) - -type Key int32 - -const ( - Key_KEY_UNSPECIFIED Key = 0 - Key_KEY_OPERATION Key = 1 - Key_KEY_CLIENT_IP Key = 2 - Key_KEY_OPERATION_CLIENT_IP Key = 3 -) - -// Enum value maps for Key. -var ( - Key_name = map[int32]string{ - 0: "KEY_UNSPECIFIED", - 1: "KEY_OPERATION", - 2: "KEY_CLIENT_IP", - 3: "KEY_OPERATION_CLIENT_IP", - } - Key_value = map[string]int32{ - "KEY_UNSPECIFIED": 0, - "KEY_OPERATION": 1, - "KEY_CLIENT_IP": 2, - "KEY_OPERATION_CLIENT_IP": 3, - } -) - -func (x Key) Enum() *Key { - p := new(Key) - *p = x - return p -} - -func (x Key) String() string { - return protoimpl.X.EnumStringOf(x.Descriptor(), protoreflect.EnumNumber(x)) -} - -func (Key) Descriptor() protoreflect.EnumDescriptor { - return file_proto_kit_ratelimit_v1_ratelimit_proto_enumTypes[0].Descriptor() -} - -func (Key) Type() protoreflect.EnumType { - return &file_proto_kit_ratelimit_v1_ratelimit_proto_enumTypes[0] -} - -func (x Key) Number() protoreflect.EnumNumber { - return protoreflect.EnumNumber(x) -} - -// Deprecated: Use Key.Descriptor instead. -func (Key) EnumDescriptor() ([]byte, []int) { - return file_proto_kit_ratelimit_v1_ratelimit_proto_rawDescGZIP(), []int{0} -} - -type RateLimit struct { - state protoimpl.MessageState `protogen:"open.v1"` - // rate is the number of tokens replenished every per duration. - Rate int32 `protobuf:"varint,1,opt,name=rate,proto3" json:"rate,omitempty"` - // per is the refill window for rate tokens. - Per *durationpb.Duration `protobuf:"bytes,2,opt,name=per,proto3" json:"per,omitempty"` - // burst is the maximum number of tokens a key can accumulate. - Burst int32 `protobuf:"varint,3,opt,name=burst,proto3" json:"burst,omitempty"` - // key selects how requests are grouped into buckets. - Key Key `protobuf:"varint,4,opt,name=key,proto3,enum=kit.ratelimit.v1.Key" json:"key,omitempty"` - unknownFields protoimpl.UnknownFields - sizeCache protoimpl.SizeCache -} - -func (x *RateLimit) Reset() { - *x = RateLimit{} - mi := &file_proto_kit_ratelimit_v1_ratelimit_proto_msgTypes[0] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) -} - -func (x *RateLimit) String() string { - return protoimpl.X.MessageStringOf(x) -} - -func (*RateLimit) ProtoMessage() {} - -func (x *RateLimit) ProtoReflect() protoreflect.Message { - mi := &file_proto_kit_ratelimit_v1_ratelimit_proto_msgTypes[0] - if x != nil { - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - if ms.LoadMessageInfo() == nil { - ms.StoreMessageInfo(mi) - } - return ms - } - return mi.MessageOf(x) -} - -// Deprecated: Use RateLimit.ProtoReflect.Descriptor instead. -func (*RateLimit) Descriptor() ([]byte, []int) { - return file_proto_kit_ratelimit_v1_ratelimit_proto_rawDescGZIP(), []int{0} -} - -func (x *RateLimit) GetRate() int32 { - if x != nil { - return x.Rate - } - return 0 -} - -func (x *RateLimit) GetPer() *durationpb.Duration { - if x != nil { - return x.Per - } - return nil -} - -func (x *RateLimit) GetBurst() int32 { - if x != nil { - return x.Burst - } - return 0 -} - -func (x *RateLimit) GetKey() Key { - if x != nil { - return x.Key - } - return Key_KEY_UNSPECIFIED -} - -var file_proto_kit_ratelimit_v1_ratelimit_proto_extTypes = []protoimpl.ExtensionInfo{ - { - ExtendedType: (*descriptorpb.MethodOptions)(nil), - ExtensionType: (*RateLimit)(nil), - Field: 106001, - Name: "kit.ratelimit.v1.rate_limit", - Tag: "bytes,106001,opt,name=rate_limit", - Filename: "proto/kit/ratelimit/v1/ratelimit.proto", - }, -} - -// Extension fields to descriptorpb.MethodOptions. -var ( - // rate_limit configures per-method rate limiting for Kratos middleware. - // - // optional kit.ratelimit.v1.RateLimit rate_limit = 106001; - E_RateLimit = &file_proto_kit_ratelimit_v1_ratelimit_proto_extTypes[0] -) - -var File_proto_kit_ratelimit_v1_ratelimit_proto protoreflect.FileDescriptor - -const file_proto_kit_ratelimit_v1_ratelimit_proto_rawDesc = "" + - "\n" + - "&proto/kit/ratelimit/v1/ratelimit.proto\x12\x10kit.ratelimit.v1\x1a google/protobuf/descriptor.proto\x1a\x1egoogle/protobuf/duration.proto\"\x8b\x01\n" + - "\tRateLimit\x12\x12\n" + - "\x04rate\x18\x01 \x01(\x05R\x04rate\x12+\n" + - "\x03per\x18\x02 \x01(\v2\x19.google.protobuf.DurationR\x03per\x12\x14\n" + - "\x05burst\x18\x03 \x01(\x05R\x05burst\x12'\n" + - "\x03key\x18\x04 \x01(\x0e2\x15.kit.ratelimit.v1.KeyR\x03key*]\n" + - "\x03Key\x12\x13\n" + - "\x0fKEY_UNSPECIFIED\x10\x00\x12\x11\n" + - "\rKEY_OPERATION\x10\x01\x12\x11\n" + - "\rKEY_CLIENT_IP\x10\x02\x12\x1b\n" + - "\x17KEY_OPERATION_CLIENT_IP\x10\x03:\\\n" + - "\n" + - "rate_limit\x12\x1e.google.protobuf.MethodOptions\x18\x91\xbc\x06 \x01(\v2\x1b.kit.ratelimit.v1.RateLimitR\trateLimitBBZ@github.com/crypto-zero/go-kit/proto/kit/ratelimit/v1;ratelimitv1b\x06proto3" - -var ( - file_proto_kit_ratelimit_v1_ratelimit_proto_rawDescOnce sync.Once - file_proto_kit_ratelimit_v1_ratelimit_proto_rawDescData []byte -) - -func file_proto_kit_ratelimit_v1_ratelimit_proto_rawDescGZIP() []byte { - file_proto_kit_ratelimit_v1_ratelimit_proto_rawDescOnce.Do(func() { - file_proto_kit_ratelimit_v1_ratelimit_proto_rawDescData = protoimpl.X.CompressGZIP(unsafe.Slice(unsafe.StringData(file_proto_kit_ratelimit_v1_ratelimit_proto_rawDesc), len(file_proto_kit_ratelimit_v1_ratelimit_proto_rawDesc))) - }) - return file_proto_kit_ratelimit_v1_ratelimit_proto_rawDescData -} - -var file_proto_kit_ratelimit_v1_ratelimit_proto_enumTypes = make([]protoimpl.EnumInfo, 1) -var file_proto_kit_ratelimit_v1_ratelimit_proto_msgTypes = make([]protoimpl.MessageInfo, 1) -var file_proto_kit_ratelimit_v1_ratelimit_proto_goTypes = []any{ - (Key)(0), // 0: kit.ratelimit.v1.Key - (*RateLimit)(nil), // 1: kit.ratelimit.v1.RateLimit - (*durationpb.Duration)(nil), // 2: google.protobuf.Duration - (*descriptorpb.MethodOptions)(nil), // 3: google.protobuf.MethodOptions -} -var file_proto_kit_ratelimit_v1_ratelimit_proto_depIdxs = []int32{ - 2, // 0: kit.ratelimit.v1.RateLimit.per:type_name -> google.protobuf.Duration - 0, // 1: kit.ratelimit.v1.RateLimit.key:type_name -> kit.ratelimit.v1.Key - 3, // 2: kit.ratelimit.v1.rate_limit:extendee -> google.protobuf.MethodOptions - 1, // 3: kit.ratelimit.v1.rate_limit:type_name -> kit.ratelimit.v1.RateLimit - 4, // [4:4] is the sub-list for method output_type - 4, // [4:4] is the sub-list for method input_type - 3, // [3:4] is the sub-list for extension type_name - 2, // [2:3] is the sub-list for extension extendee - 0, // [0:2] is the sub-list for field type_name -} - -func init() { file_proto_kit_ratelimit_v1_ratelimit_proto_init() } -func file_proto_kit_ratelimit_v1_ratelimit_proto_init() { - if File_proto_kit_ratelimit_v1_ratelimit_proto != nil { - return - } - type x struct{} - out := protoimpl.TypeBuilder{ - File: protoimpl.DescBuilder{ - GoPackagePath: reflect.TypeOf(x{}).PkgPath(), - RawDescriptor: unsafe.Slice(unsafe.StringData(file_proto_kit_ratelimit_v1_ratelimit_proto_rawDesc), len(file_proto_kit_ratelimit_v1_ratelimit_proto_rawDesc)), - NumEnums: 1, - NumMessages: 1, - NumExtensions: 1, - NumServices: 0, - }, - GoTypes: file_proto_kit_ratelimit_v1_ratelimit_proto_goTypes, - DependencyIndexes: file_proto_kit_ratelimit_v1_ratelimit_proto_depIdxs, - EnumInfos: file_proto_kit_ratelimit_v1_ratelimit_proto_enumTypes, - MessageInfos: file_proto_kit_ratelimit_v1_ratelimit_proto_msgTypes, - ExtensionInfos: file_proto_kit_ratelimit_v1_ratelimit_proto_extTypes, - }.Build() - File_proto_kit_ratelimit_v1_ratelimit_proto = out.File - file_proto_kit_ratelimit_v1_ratelimit_proto_goTypes = nil - file_proto_kit_ratelimit_v1_ratelimit_proto_depIdxs = nil -} diff --git a/proto/kit/ratelimit/v1/ratelimit.proto b/proto/kit/ratelimit/v1/ratelimit.proto deleted file mode 100644 index 8b33eef..0000000 --- a/proto/kit/ratelimit/v1/ratelimit.proto +++ /dev/null @@ -1,34 +0,0 @@ -syntax = "proto3"; - -package kit.ratelimit.v1; - -import "google/protobuf/descriptor.proto"; -import "google/protobuf/duration.proto"; - -option go_package = "github.com/crypto-zero/go-kit/proto/kit/ratelimit/v1;ratelimitv1"; - -extend google.protobuf.MethodOptions { - // rate_limit configures per-method rate limiting for Kratos middleware. - RateLimit rate_limit = 106001; -} - -message RateLimit { - // rate is the number of tokens replenished every per duration. - int32 rate = 1; - - // per is the refill window for rate tokens. - google.protobuf.Duration per = 2; - - // burst is the maximum number of tokens a key can accumulate. - int32 burst = 3; - - // key selects how requests are grouped into buckets. - Key key = 4; -} - -enum Key { - KEY_UNSPECIFIED = 0; - KEY_OPERATION = 1; - KEY_CLIENT_IP = 2; - KEY_OPERATION_CLIENT_IP = 3; -} diff --git a/query/convert.go b/query/convert.go index f3ad4b1..577e76d 100644 --- a/query/convert.go +++ b/query/convert.go @@ -1,6 +1,6 @@ package query -// ConvertList convert A list to B list +// ConvertList converts a slice of A values into B values. func ConvertList[A, B any](from []A, convert func(A) B) []B { results := make([]B, 0, len(from)) for _, v := range from { diff --git a/query/paging.go b/query/paging.go index 0d8193e..1beed02 100644 --- a/query/paging.go +++ b/query/paging.go @@ -1,13 +1,13 @@ package query var ( - // DefaultPageSize 默认每页条数 + // DefaultPageSize is the default number of items per page. DefaultPageSize int32 = 10 - // MaxPageSize 最大每页条数 + // MaxPageSize is the maximum number of items per page. MaxPageSize int32 = 1000 ) -// ResizePage 修正分页参数, 页码从0开始. +// ResizePage normalizes page parameters. Page numbers start at zero. func ResizePage(page, pageSize int32) (int32, int32) { if page < 0 { page = 0 diff --git a/ratelimit/ratelimit.go b/ratelimit/ratelimit.go deleted file mode 100644 index a1472bf..0000000 --- a/ratelimit/ratelimit.go +++ /dev/null @@ -1,258 +0,0 @@ -// Package ratelimit provides an in-memory token-bucket rate limiter. -package ratelimit - -import ( - "context" - "errors" - "math" - "sync" - "time" -) - -const defaultKey = "global" - -var ( - // ErrInvalidConfig reports a limiter configuration that cannot be applied. - ErrInvalidConfig = errors.New("invalid ratelimit config") - - // DefaultConfig is suitable for service-level protection. - DefaultConfig = Config{ - Rate: 100, - Per: time.Minute, - Burst: 100, - MaxKeys: 10000, - } -) - -// Limit describes the token-bucket parameters used for one Allow call. -type Limit struct { - // Rate is the number of tokens replenished every Per duration. - Rate int - // Per is the refill window for Rate tokens. - Per time.Duration - // Burst is the maximum number of tokens a key can accumulate. - Burst int -} - -// Config controls token-bucket behavior. -type Config struct { - // Rate is the number of tokens replenished every Per duration. - Rate int - // Per is the refill window for Rate tokens. - Per time.Duration - // Burst is the maximum number of tokens a key can accumulate. - Burst int - // MaxKeys bounds the number of tracked buckets. Zero means unbounded. - MaxKeys int -} - -// Result describes the outcome of an Allow call. -type Result struct { - Allowed bool - Remaining int - RetryAfter time.Duration -} - -// Store persists token-bucket state. -// -// Implementations must apply the operation atomically for key. Distributed -// stores such as Redis should use ctx for cancellation and deadlines. -type Store interface { - Take(ctx context.Context, key string, now time.Time, limit Limit, n int) (Result, error) - Len(ctx context.Context) (int, error) -} - -// Option configures a Limiter. -type Option func(*Limiter) - -// WithNow sets the clock used by the limiter. -func WithNow(now func() time.Time) Option { - return func(l *Limiter) { - if now != nil { - l.now = now - } - } -} - -// WithStore sets the storage backend used by the limiter. -func WithStore(store Store) Option { - return func(l *Limiter) { - if store != nil { - l.store = store - } - } -} - -// Limiter applies token-bucket rate limits independently per key. -type Limiter struct { - limit Limit - store Store - now func() time.Time -} - -// New constructs a Limiter. -func New(cfg Config, opts ...Option) (*Limiter, error) { - if cfg.Rate <= 0 || cfg.Per <= 0 || cfg.Burst <= 0 || cfg.MaxKeys < 0 { - return nil, ErrInvalidConfig - } - l := &Limiter{ - limit: Limit{ - Rate: cfg.Rate, - Per: cfg.Per, - Burst: cfg.Burst, - }, - store: NewMemoryStore(cfg.MaxKeys), - now: time.Now, - } - for _, opt := range opts { - opt(l) - } - return l, nil -} - -// NewDefault constructs a Limiter with DefaultConfig. -func NewDefault(opts ...Option) *Limiter { - l, err := New(DefaultConfig, opts...) - if err != nil { - panic(err) - } - return l -} - -// Allow consumes one token for key if capacity is available. -func (l *Limiter) Allow(key string) Result { - return l.AllowN(key, 1) -} - -// AllowContext consumes one token for key if capacity is available. -func (l *Limiter) AllowContext(ctx context.Context, key string) (Result, error) { - return l.AllowNContext(ctx, key, 1) -} - -// AllowN consumes n tokens for key if capacity is available. -func (l *Limiter) AllowN(key string, n int) Result { - res, _ := l.AllowNContext(context.Background(), key, n) - return res -} - -// AllowNContext consumes n tokens for key if capacity is available. -func (l *Limiter) AllowNContext(ctx context.Context, key string, n int) (Result, error) { - if n <= 0 { - return Result{Allowed: true}, nil - } - if n > l.limit.Burst { - return Result{RetryAfter: retryAfter(l.limit, float64(n))}, nil - } - if key == "" { - key = defaultKey - } - return l.store.Take(ctx, key, l.now(), l.limit, n) -} - -// Len returns the number of tracked buckets when the store supports counting. -func (l *Limiter) Len() int { - n, _ := l.LenContext(context.Background()) - return n -} - -// LenContext returns the number of tracked buckets when the store supports counting. -func (l *Limiter) LenContext(ctx context.Context) (int, error) { - return l.store.Len(ctx) -} - -// MemoryStore stores token buckets in memory. -type MemoryStore struct { - mu sync.Mutex - buckets map[string]*bucket - maxKeys int -} - -type bucket struct { - tokens float64 - seen time.Time -} - -// NewMemoryStore constructs an in-memory Store. -func NewMemoryStore(maxKeys int) *MemoryStore { - return &MemoryStore{ - buckets: make(map[string]*bucket), - maxKeys: maxKeys, - } -} - -// Take consumes n tokens from key if capacity is available. -func (s *MemoryStore) Take(_ context.Context, key string, now time.Time, limit Limit, n int) (Result, error) { - s.mu.Lock() - defer s.mu.Unlock() - - b := s.bucketFor(key, now, limit.Burst) - refill(b, now, limit) - need := float64(n) - if b.tokens < need { - return Result{ - Remaining: int(math.Floor(b.tokens)), - RetryAfter: retryAfter(limit, need-b.tokens), - }, nil - } - b.tokens -= need - return Result{ - Allowed: true, - Remaining: int(math.Floor(b.tokens)), - }, nil -} - -// Len returns the number of tracked buckets. -func (s *MemoryStore) Len(_ context.Context) (int, error) { - s.mu.Lock() - defer s.mu.Unlock() - return len(s.buckets), nil -} - -func (s *MemoryStore) bucketFor(key string, now time.Time, burst int) *bucket { - if b, ok := s.buckets[key]; ok { - return b - } - if s.maxKeys > 0 && len(s.buckets) >= s.maxKeys { - s.evictOldest() - } - b := &bucket{ - tokens: float64(burst), - seen: now, - } - s.buckets[key] = b - return b -} - -func refill(b *bucket, now time.Time, limit Limit) { - elapsed := now.Sub(b.seen) - if elapsed <= 0 { - b.seen = now - return - } - rate := float64(limit.Rate) / limit.Per.Seconds() - b.tokens = math.Min(float64(limit.Burst), b.tokens+elapsed.Seconds()*rate) - b.seen = now -} - -func retryAfter(limit Limit, tokens float64) time.Duration { - perToken := time.Duration(float64(limit.Per) / float64(limit.Rate)) - d := time.Duration(math.Ceil(float64(perToken) * tokens)) - if d < 0 { - return 0 - } - return d -} - -func (s *MemoryStore) evictOldest() { - var ( - oldestKey string - oldest time.Time - ) - for key, b := range s.buckets { - if oldestKey == "" || b.seen.Before(oldest) { - oldestKey = key - oldest = b.seen - } - } - delete(s.buckets, oldestKey) -} diff --git a/ratelimit/ratelimit_test.go b/ratelimit/ratelimit_test.go deleted file mode 100644 index 90e53f2..0000000 --- a/ratelimit/ratelimit_test.go +++ /dev/null @@ -1,132 +0,0 @@ -package ratelimit - -import ( - "context" - "errors" - "testing" - "time" -) - -type recordingStore struct { - key string - n int - limit Limit -} - -func (s *recordingStore) Take(_ context.Context, key string, _ time.Time, limit Limit, n int) (Result, error) { - s.key = key - s.n = n - s.limit = limit - return Result{Allowed: true, Remaining: limit.Burst - n}, nil -} - -func (s *recordingStore) Len(context.Context) (int, error) { - return 0, nil -} - -func TestLimiterAllowsBurstThenRejects(t *testing.T) { - now := time.Unix(0, 0) - limiter, err := New(Config{Rate: 2, Per: time.Second, Burst: 2}, WithNow(func() time.Time { - return now - })) - if err != nil { - t.Fatalf("New: %v", err) - } - - if res := limiter.Allow("user-1"); !res.Allowed || res.Remaining != 1 { - t.Fatalf("first request = %+v, want allowed with one remaining", res) - } - if res := limiter.Allow("user-1"); !res.Allowed || res.Remaining != 0 { - t.Fatalf("second request = %+v, want allowed with zero remaining", res) - } - if res := limiter.Allow("user-1"); res.Allowed || res.RetryAfter != 500*time.Millisecond { - t.Fatalf("third request = %+v, want rejected with 500ms retry", res) - } -} - -func TestLimiterRefillsByElapsedTime(t *testing.T) { - now := time.Unix(0, 0) - limiter, err := New(Config{Rate: 2, Per: time.Second, Burst: 2}, WithNow(func() time.Time { - return now - })) - if err != nil { - t.Fatalf("New: %v", err) - } - - limiter.Allow("user-1") - limiter.Allow("user-1") - now = now.Add(500 * time.Millisecond) - - if res := limiter.Allow("user-1"); !res.Allowed || res.Remaining != 0 { - t.Fatalf("refilled request = %+v, want allowed with zero remaining", res) - } -} - -func TestLimiterSeparatesKeys(t *testing.T) { - limiter, err := New(Config{Rate: 1, Per: time.Second, Burst: 1}) - if err != nil { - t.Fatalf("New: %v", err) - } - - if res := limiter.Allow("user-1"); !res.Allowed { - t.Fatalf("user-1 first request = %+v, want allowed", res) - } - if res := limiter.Allow("user-1"); res.Allowed { - t.Fatalf("user-1 second request = %+v, want rejected", res) - } - if res := limiter.Allow("user-2"); !res.Allowed { - t.Fatalf("user-2 first request = %+v, want allowed", res) - } -} - -func TestLimiterEvictsOldestWhenMaxKeysReached(t *testing.T) { - now := time.Unix(0, 0) - limiter, err := New(Config{Rate: 1, Per: time.Second, Burst: 1, MaxKeys: 2}, WithNow(func() time.Time { - return now - })) - if err != nil { - t.Fatalf("New: %v", err) - } - - limiter.Allow("a") - now = now.Add(time.Millisecond) - limiter.Allow("b") - now = now.Add(time.Millisecond) - limiter.Allow("c") - - if got := limiter.Len(); got != 2 { - t.Fatalf("Len() = %d, want 2", got) - } - if res := limiter.Allow("a"); !res.Allowed { - t.Fatalf("a should have been evicted and recreated with full burst, got %+v", res) - } -} - -func TestLimiterRejectsInvalidConfig(t *testing.T) { - _, err := New(Config{Rate: 0, Per: time.Second, Burst: 1}) - if !errors.Is(err, ErrInvalidConfig) { - t.Fatalf("New error = %v, want ErrInvalidConfig", err) - } -} - -func TestLimiterUsesStore(t *testing.T) { - store := &recordingStore{} - limiter, err := New(Config{Rate: 5, Per: time.Second, Burst: 10}, WithStore(store)) - if err != nil { - t.Fatalf("New: %v", err) - } - - res, err := limiter.AllowNContext(context.Background(), "tenant-1", 3) - if err != nil { - t.Fatalf("AllowNContext: %v", err) - } - if !res.Allowed || res.Remaining != 7 { - t.Fatalf("result = %+v, want allowed with 7 remaining", res) - } - if store.key != "tenant-1" || store.n != 3 { - t.Fatalf("store saw key=%q n=%d, want tenant-1 and 3", store.key, store.n) - } - if store.limit.Rate != 5 || store.limit.Per != time.Second || store.limit.Burst != 10 { - t.Fatalf("store limit = %+v, want configured limit", store.limit) - } -} diff --git a/s3/event.go b/s3/event.go index 09c0c6f..aa66022 100644 --- a/s3/event.go +++ b/s3/event.go @@ -6,7 +6,7 @@ import ( "time" ) -// EventName that wraps the event name +// EventName wraps an S3 event name. type EventName string const ( @@ -39,14 +39,14 @@ const ( EventS3ObjectTaggingDelete EventName = "s3:ObjectTagging:Delete" ) -// Event that wraps an array of EventRecord +// Event wraps a batch of S3 event records. type Event struct { EventName EventName `json:"EventName"` Key string `json:"Key"` Records []EventRecord `json:"Records"` } -// EventRecord which wrap record data +// EventRecord wraps S3 event record data. type EventRecord struct { EventVersion string `json:"eventVersion"` EventSource string `json:"eventSource"` @@ -60,19 +60,19 @@ type EventRecord struct { Source Source `json:"source"` } -// UserIdentity that wraps the principal ID +// UserIdentity wraps the principal ID. type UserIdentity struct { PrincipalID string `json:"principalId"` } -// RequestParameters that wraps the principal ID, region, and source IP address +// RequestParameters wraps the principal ID, region, and source IP address. type RequestParameters struct { PrincipalID string `json:"principalId"` Region string `json:"region"` SourceIPAddress string `json:"sourceIPAddress"` } -// Entity that wraps the bucket and object +// Entity wraps the bucket and object. type Entity struct { SchemaVersion string `json:"s3SchemaVersion"` ConfigurationID string `json:"configurationId"` @@ -80,7 +80,7 @@ type Entity struct { Object Object `json:"object"` } -// Bucket that wraps the bucket name, owner identity, and ARN +// Bucket wraps the bucket name, owner identity, and ARN. type Bucket struct { Name string `json:"name"` OwnerIdentity UserIdentity `json:"ownerIdentity"` @@ -88,7 +88,7 @@ type Bucket struct { } // Object that wraps the object key, size, ETag, content type, user metadata, -// version ID, sequencer, and URL-decoded key +// version ID, sequencer, and URL-decoded key. type Object struct { Key string `json:"key"` Size int64 `json:"size,omitempty"` @@ -113,7 +113,7 @@ func (o *Object) UnmarshalJSON(data []byte) error { return nil } -// Source that wraps the source IP address and user agent +// Source wraps the source IP address and user agent. type Source struct { Host string `json:"host"` Port string `json:"port"` diff --git a/s3/go.sum b/s3/go.sum index d78ded4..5af9eb9 100644 --- a/s3/go.sum +++ b/s3/go.sum @@ -8,19 +8,13 @@ github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo= github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ= -github.com/klauspost/compress v1.18.2 h1:iiPHWW0YrcFgpBYhsA6D1+fqHssJscY/Tm/y2Uqnapk= -github.com/klauspost/compress v1.18.2/go.mod h1:R0h/fSBs8DE4ENlcrlib3PsXS61voFxhIs2DeRhCvJ4= github.com/klauspost/cpuid/v2 v2.0.1/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg= github.com/klauspost/cpuid/v2 v2.2.11 h1:0OwqZRYI2rFrjS4kvkDnqJkKHdHaRnCm68/DY4OxRzU= github.com/klauspost/cpuid/v2 v2.2.11/go.mod h1:hqwkgyIinND0mEev00jJYCxPNVRVXFQeu1XKlok6oO0= -github.com/klauspost/cpuid/v2 v2.3.0 h1:S4CRMLnYUhGeDFDqkGriYKdfoFlDnMtqTiI/sFzhA9Y= -github.com/klauspost/cpuid/v2 v2.3.0/go.mod h1:hqwkgyIinND0mEev00jJYCxPNVRVXFQeu1XKlok6oO0= github.com/klauspost/crc32 v1.3.0 h1:sSmTt3gUt81RP655XGZPElI0PelVTZ6YwCRnPSupoFM= github.com/klauspost/crc32 v1.3.0/go.mod h1:D7kQaZhnkX/Y0tstFGf8VUzv2UofNGqCjnC3zdHB0Hw= github.com/minio/crc64nvme v1.1.0 h1:e/tAguZ+4cw32D+IO/8GSf5UVr9y+3eJcxZI2WOO/7Q= github.com/minio/crc64nvme v1.1.0/go.mod h1:eVfm2fAzLlxMdUGc0EEBGSMmPwmXD5XiNRpnu9J3bvg= -github.com/minio/crc64nvme v1.1.1 h1:8dwx/Pz49suywbO+auHCBpCtlW1OfpcLN7wYgVR6wAI= -github.com/minio/crc64nvme v1.1.1/go.mod h1:eVfm2fAzLlxMdUGc0EEBGSMmPwmXD5XiNRpnu9J3bvg= github.com/minio/md5-simd v1.1.2 h1:Gdi1DZK69+ZVMoNHRXJyNcxrMA4dSxoYHZSQbirFg34= github.com/minio/md5-simd v1.1.2/go.mod h1:MzdKDxYpY2BT9XQFocsiZf/NKVtR7nkE4RoEpN+20RM= github.com/minio/minio-go/v7 v7.0.97 h1:lqhREPyfgHTB/ciX8k2r8k0D93WaFqxbJX36UZq5occ= @@ -35,30 +29,14 @@ github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= github.com/tinylib/msgp v1.3.0 h1:ULuf7GPooDaIlbyvgAxBV/FI7ynli6LZ1/nVUNu+0ww= github.com/tinylib/msgp v1.3.0/go.mod h1:ykjzy2wzgrlvpDCRc4LA8UXy6D8bzMSuAF3WD57Gok0= -github.com/tinylib/msgp v1.6.1 h1:ESRv8eL3u+DNHUoSAAQRE50Hm162zqAnBoGv9PzScPY= -github.com/tinylib/msgp v1.6.1/go.mod h1:RSp0LW9oSxFut3KzESt5Voq4GVWyS+PSulT77roAqEA= golang.org/x/crypto v0.36.0 h1:AnAEvhDddvBdpY+uR+MyHmuZzzNqXSe/GvuDeob5L34= golang.org/x/crypto v0.36.0/go.mod h1:Y4J0ReaxCR1IMaabaSMugxJES1EpwhBHhv2bDHklZvc= -golang.org/x/crypto v0.46.0 h1:cKRW/pmt1pKAfetfu+RCEvjvZkA9RimPbh7bhFjGVBU= -golang.org/x/crypto v0.46.0/go.mod h1:Evb/oLKmMraqjZ2iQTwDwvCtJkczlDuTmdJXoZVzqU0= -golang.org/x/mod v0.30.0 h1:fDEXFVZ/fmCKProc/yAXXUijritrDzahmwwefnjoPFk= -golang.org/x/mod v0.30.0/go.mod h1:lAsf5O2EvJeSFMiBxXDki7sCgAxEUcZHXoXMKT4GJKc= golang.org/x/net v0.38.0 h1:vRMAPTMaeGqVhG5QyLJHqNDwecKTomGeqbnfZyKlBI8= golang.org/x/net v0.38.0/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8= -golang.org/x/net v0.48.0 h1:zyQRTTrjc33Lhh0fBgT/H3oZq9WuvRR5gPC70xpDiQU= -golang.org/x/net v0.48.0/go.mod h1:+ndRgGjkh8FGtu1w1FGbEC31if4VrNVMuKTgcAAnQRY= -golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4= -golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= golang.org/x/sys v0.34.0 h1:H5Y5sJ2L2JRdyv7ROF1he/lPdvFsd0mJHFw2ThKHxLA= golang.org/x/sys v0.34.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= -golang.org/x/sys v0.39.0 h1:CvCKL8MeisomCi6qNZ+wbb0DN9E5AATixKsvNtMoMFk= -golang.org/x/sys v0.39.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= golang.org/x/text v0.26.0 h1:P42AVeLghgTYr4+xUnTRKDMqpar+PtX7KWuNQL21L8M= golang.org/x/text v0.26.0/go.mod h1:QK15LZJUUQVJxhz7wXgxSy/CJaTFjd0G+YLonydOVQA= -golang.org/x/text v0.32.0 h1:ZD01bjUt1FQ9WJ0ClOL5vxgxOI/sVCNgX1YtKwcY0mU= -golang.org/x/text v0.32.0/go.mod h1:o/rUWzghvpD5TXrTIBuJU77MTaN0ljMWE47kxGJQ7jY= -golang.org/x/tools v0.39.0 h1:ik4ho21kwuQln40uelmciQPp9SipgNDdrafrYA4TmQQ= -golang.org/x/tools v0.39.0/go.mod h1:JnefbkDPyD8UU2kI5fuf8ZX4/yUeh9W877ZeBONxUqQ= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= diff --git a/s3/s3.go b/s3/s3.go index e6f130a..2b3eccc 100644 --- a/s3/s3.go +++ b/s3/s3.go @@ -16,78 +16,116 @@ import ( "github.com/minio/minio-go/v7/pkg/credentials" ) -// S3 provides operations on s3 bucket +const ( + serviceAccountCAPath = "/var/run/secrets/kubernetes.io/serviceaccount/ca.crt" + serviceAccountTokenPath = "/var/run/secrets/kubernetes.io/serviceaccount/token" +) + +// S3 provides operations on an S3-compatible bucket. +// +// Deprecated: Consumers should prefer defining a narrow interface in the +// package that consumes S3 behavior. This broad interface remains for +// compatibility with existing go-kit users. type S3 interface { - // PresignGetURL returns a presigned url for get object operation - PresignGetURL(ctx context.Context, bucket, key string, expire time.Duration, - ) (*url.URL, error) - // PresignPutURL returns a presigned url for put object operation - PresignPutURL(ctx context.Context, bucket, key, contentType, sha256 string, - size int, expire time.Duration) (*url.URL, http.Header, error) - // GetObject gets an object from bucket - GetObject(ctx context.Context, bucket, key string, opt minio.GetObjectOptions) ( - *minio.Object, error) - // PutObject uploads an object to bucket - PutObject(ctx context.Context, bucket, key, contentType string, size int, - body io.Reader, opts minio.PutObjectOptions) (minio.UploadInfo, error) - // CopyObject copies an object from srcKey to destKey - CopyObject(ctx context.Context, bucket, srcKey, destKey string) (out minio.UploadInfo, - err error) - // DeleteObject deletes an object from bucket + // PresignGetURL returns a presigned URL for a get-object operation. + PresignGetURL(ctx context.Context, bucket, key string, expire time.Duration) (*url.URL, error) + // PresignPutURL returns a presigned URL and headers for a put-object operation. + PresignPutURL( + ctx context.Context, + bucket string, + key string, + contentType string, + sha256 string, + size int, + expire time.Duration, + ) (*url.URL, http.Header, error) + // GetObject gets an object from bucket. + GetObject(ctx context.Context, bucket, key string, opts minio.GetObjectOptions) (*minio.Object, error) + // PutObject uploads an object to bucket. + PutObject( + ctx context.Context, + bucket string, + key string, + contentType string, + size int, + body io.Reader, + opts minio.PutObjectOptions, + ) (minio.UploadInfo, error) + // CopyObject copies an object from srcKey to destKey in bucket. + CopyObject(ctx context.Context, bucket, srcKey, destKey string) (minio.UploadInfo, error) + // DeleteObject deletes an object from bucket. DeleteObject(ctx context.Context, bucket, key string) error - // StatObject stats an object in bucket + // StatObject stats an object in bucket. StatObject(ctx context.Context, bucket, key string) (minio.ObjectInfo, error) } -// MinioS3Impl provides operations on AWS/s3 and minio for implementing S3 interface +// MinioS3Impl provides operations on AWS S3 and MinIO. type MinioS3Impl struct { client *minio.Client } -func (m *MinioS3Impl) PresignGetURL(ctx context.Context, bucket, key string, expire time.Duration, -) (out *url.URL, err error) { - if out, err = m.client.PresignedGetObject(ctx, bucket, key, expire, nil); err != nil { +var _ S3 = (*MinioS3Impl)(nil) + +// PresignGetURL returns a presigned URL for a get-object operation. +func (m *MinioS3Impl) PresignGetURL(ctx context.Context, bucket, key string, expire time.Duration) (*url.URL, error) { + out, err := m.client.PresignedGetObject(ctx, bucket, key, expire, nil) + if err != nil { return nil, fmt.Errorf("failed to presign get object: %w", err) } - return + return out, nil } -func (m *MinioS3Impl) PresignPutURL(ctx context.Context, bucket, key, contentType, - sha256 string, size int, expire time.Duration, -) (out *url.URL, headers http.Header, err error) { - headers = http.Header{ +// PresignPutURL returns a presigned URL and headers for a put-object operation. +func (m *MinioS3Impl) PresignPutURL( + ctx context.Context, + bucket string, + key string, + contentType string, + sha256 string, + size int, + expire time.Duration, +) (*url.URL, http.Header, error) { + headers := http.Header{ "Content-Type": []string{contentType}, "Content-Length": []string{fmt.Sprint(size)}, "x-amz-checksum-sha256": []string{sha256}, } - out, err = m.client.PresignHeader(ctx, http.MethodPut, bucket, key, expire, nil, headers) + out, err := m.client.PresignHeader(ctx, http.MethodPut, bucket, key, expire, nil, headers) if err != nil { return nil, nil, fmt.Errorf("failed to presign put object: %w", err) } - return + return out, headers, nil } -func (m *MinioS3Impl) GetObject(ctx context.Context, bucket, key string, opts minio.GetObjectOptions, -) (out *minio.Object, err error) { - if out, err = m.client.GetObject(ctx, bucket, key, opts); err != nil { +// GetObject gets an object from bucket. +func (m *MinioS3Impl) GetObject(ctx context.Context, bucket, key string, opts minio.GetObjectOptions) (*minio.Object, error) { + out, err := m.client.GetObject(ctx, bucket, key, opts) + if err != nil { return nil, fmt.Errorf("failed to get object: %w", err) } - return + return out, nil } -func (m *MinioS3Impl) PutObject(ctx context.Context, bucket, key, contentType string, - size int, body io.Reader, opts minio.PutObjectOptions, -) (out minio.UploadInfo, err error) { +// PutObject uploads an object to bucket. +func (m *MinioS3Impl) PutObject( + ctx context.Context, + bucket string, + key string, + contentType string, + size int, + body io.Reader, + opts minio.PutObjectOptions, +) (minio.UploadInfo, error) { opts.ContentType = contentType - if out, err = m.client.PutObject(ctx, bucket, key, body, int64(size), opts); err != nil { + out, err := m.client.PutObject(ctx, bucket, key, body, int64(size), opts) + if err != nil { return out, fmt.Errorf("failed to put object: %w", err) } - return + return out, nil } -func (m *MinioS3Impl) CopyObject(ctx context.Context, bucket, srcKey, destKey string) ( - out minio.UploadInfo, err error, -) { +// CopyObject copies an object from srcKey to destKey in bucket. +func (m *MinioS3Impl) CopyObject(ctx context.Context, bucket, srcKey, destKey string) (minio.UploadInfo, error) { copySourceOpts := minio.CopySrcOptions{ Bucket: bucket, Object: srcKey, @@ -96,12 +134,14 @@ func (m *MinioS3Impl) CopyObject(ctx context.Context, bucket, srcKey, destKey st Bucket: bucket, Object: destKey, } - if out, err = m.client.CopyObject(ctx, copyDestOpts, copySourceOpts); err != nil { + out, err := m.client.CopyObject(ctx, copyDestOpts, copySourceOpts) + if err != nil { return out, fmt.Errorf("failed to copy object: %w", err) } - return + return out, nil } +// DeleteObject deletes an object from bucket. func (m *MinioS3Impl) DeleteObject(ctx context.Context, bucket, key string) error { if err := m.client.RemoveObject(ctx, bucket, key, minio.RemoveObjectOptions{}); err != nil { return fmt.Errorf("failed to delete object: %w", err) @@ -109,6 +149,7 @@ func (m *MinioS3Impl) DeleteObject(ctx context.Context, bucket, key string) erro return nil } +// StatObject stats an object in bucket. func (m *MinioS3Impl) StatObject(ctx context.Context, bucket, key string) (minio.ObjectInfo, error) { info, err := m.client.StatObject(ctx, bucket, key, minio.StatObjectOptions{}) if IsNoSuchKeyErr(err) { @@ -120,7 +161,9 @@ func (m *MinioS3Impl) StatObject(ctx context.Context, bucket, key string) (minio return info, nil } -// NewMinioS3Impl creates a new MinioS3Impl +// NewMinioS3Impl creates a new MinioS3Impl. +// +// It returns S3 for backward compatibility with earlier releases. func NewMinioS3Impl(endpoint, accessKeyID, secretAccessKey, sessionToken string) (S3, error) { return NewMinioS3ImplWithSTS(endpoint, &credentials.Static{ Value: credentials.Value{ @@ -132,7 +175,9 @@ func NewMinioS3Impl(endpoint, accessKeyID, secretAccessKey, sessionToken string) }) } -// NewMinioS3ImplWithSTS creates a new MinioS3Impl with STSProvider +// NewMinioS3ImplWithSTS creates a new MinioS3Impl with STSProvider. +// +// It returns S3 for backward compatibility with earlier releases. func NewMinioS3ImplWithSTS(endpoint string, sts STSProvider) (S3, error) { uri, err := url.Parse(endpoint) if err != nil { @@ -155,19 +200,19 @@ func NewMinioS3ImplWithSTS(endpoint string, sts STSProvider) (S3, error) { return &MinioS3Impl{client: c}, nil } -// DefaultSTSTokenExpirySeconds is the default expiry duration for STS token +// DefaultSTSTokenExpirySeconds is the default expiry duration for an STS token. const DefaultSTSTokenExpirySeconds = 3 * 24 * 60 * 60 // Three days -// STSProvider provides temporary credentials +// STSProvider provides temporary credentials. type STSProvider = credentials.Provider -// WindowedSTSIdentityProvider provides temporary credentials with a windowed expiry +// WindowedSTSIdentityProvider provides temporary credentials with a windowed expiry. type WindowedSTSIdentityProvider struct { Window time.Duration *credentials.STSWebIdentity } -// Retrieve returns the credential value +// Retrieve returns the credential value. func (w *WindowedSTSIdentityProvider) Retrieve() (credentials.Value, error) { value, err := w.STSWebIdentity.Retrieve() if err != nil { @@ -180,18 +225,18 @@ func (w *WindowedSTSIdentityProvider) Retrieve() (credentials.Value, error) { // NewMinioSTSProviderImpl creates a new instance of the STSProvider. func NewMinioSTSProviderImpl(endpoint string, expirySeconds int, expiryWindow time.Duration, ) (STSProvider, error) { - // Read kubernetes service account ca certificate file - caCert, err := os.ReadFile("/var/run/secrets/kubernetes.io/serviceaccount/ca.crt") + if expirySeconds <= 0 { + return nil, fmt.Errorf("sts token expiry seconds must be positive") + } + caCert, err := os.ReadFile(serviceAccountCAPath) if err != nil { return nil, fmt.Errorf("failed to read service account ca certificate: %w", err) } - // Read kubernetes service account token file - token, err := os.ReadFile("/var/run/secrets/kubernetes.io/serviceaccount/token") + token, err := os.ReadFile(serviceAccountTokenPath) if err != nil { return nil, fmt.Errorf("failed to read service account token: %w", err) } - // Create an HttpTransport with the service account token and ca certificate transport, err := minio.DefaultTransport(true) if err != nil { return nil, fmt.Errorf("failed to create minio transport: %w", err) @@ -207,7 +252,6 @@ func NewMinioSTSProviderImpl(endpoint string, expirySeconds int, expiryWindow ti return nil, fmt.Errorf("failed to append kubernetes service account ca certificate") } - // Create sts credentials credential := &credentials.STSWebIdentity{ Client: &http.Client{Transport: transport}, STSEndpoint: endpoint, @@ -222,7 +266,7 @@ func NewMinioSTSProviderImpl(endpoint string, expirySeconds int, expiryWindow ti return &WindowedSTSIdentityProvider{Window: expiryWindow, STSWebIdentity: credential}, nil } -// IsNoSuchKeyErr checks if the error is a NoSuchKey error +// IsNoSuchKeyErr checks if the error is a NoSuchKey error. func IsNoSuchKeyErr(err error) bool { if minioError := minio.ToErrorResponse(err); minioError.Code == "NoSuchKey" { return true @@ -230,4 +274,5 @@ func IsNoSuchKeyErr(err error) bool { return false } +// ErrNoSuchKey reports a missing S3 object. var ErrNoSuchKey = errors.New("no such key") diff --git a/s3/s3_minio_test.go b/s3/s3_minio_test.go index 5aad2e8..50984ac 100644 --- a/s3/s3_minio_test.go +++ b/s3/s3_minio_test.go @@ -33,6 +33,13 @@ func TestMinio(t *testing.T) { }) } +func TestNewMinioSTSProviderImplRejectsMissingExpiry(t *testing.T) { + _, err := NewMinioSTSProviderImpl("https://sts.example.com", 0, time.Minute) + if err == nil { + t.Fatal("NewMinioSTSProviderImpl error = nil, want error") + } +} + type TestMinioSuite struct { suite.Suite diff --git a/snowflake/snowflake.go b/snowflake/snowflake.go index cc075c3..cb2374e 100644 --- a/snowflake/snowflake.go +++ b/snowflake/snowflake.go @@ -7,17 +7,16 @@ import ( ) func init() { - // change epoch from 2024-01-01 and 42 time bits - // approximately 139 years 5 months 18 days + // Use a 2024-01-01 epoch with 42 time bits, covering roughly 139 years. snowflake.Epoch = time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC).UnixMilli() - // node used 9 bits approximately 512 nodes + // Use 9 node bits, supporting roughly 512 nodes. snowflake.NodeBits = 9 } -// Node is a snowflake Node +// Node is a snowflake node. type Node = snowflake.Node -// NewNode returns a new snowflake Node +// NewNode returns a new snowflake node. func NewNode(node int) *snowflake.Node { n, err := snowflake.NewNode(int64(node)) if err != nil { diff --git a/text/text.go b/text/text.go index 603856b..95b9953 100644 --- a/text/text.go +++ b/text/text.go @@ -7,31 +7,26 @@ import ( "encoding/binary" "encoding/hex" "fmt" + "io" "math" "math/big" mrand "math/rand" - "os" "strings" "unicode" "golang.org/x/text/width" ) -// GeneratePassword generate password from /dev/urandom. +// GeneratePassword generates a password from cryptographically secure random bytes. func GeneratePassword(size int, accept func(byte) bool) (string, error) { - f, err := os.Open("/dev/urandom") - if err != nil { - return "", fmt.Errorf("open /dev/urandom: %w", err) - } password := make([]byte, 0, size) for len(password) < size { buf := make([]byte, size*2) - n, err := f.Read(buf) + n, err := io.ReadFull(crand.Reader, buf) if err != nil { - return "", fmt.Errorf("read /dev/urandom: %w", err) + return "", fmt.Errorf("read random bytes: %w", err) } for idx := 0; idx < n; idx++ { - // Ascii printable characters if accept(buf[idx]) && len(password) < size { password = append(password, buf[idx]) } @@ -40,7 +35,7 @@ func GeneratePassword(size int, accept func(byte) bool) (string, error) { return string(password), nil } -// GeneratePasswordLitterNumbers generate password with litter and numbers. +// GeneratePasswordLitterNumbers generates a password with letters and numbers. func GeneratePasswordLitterNumbers(size int) (string, error) { return GeneratePassword(size, func(b byte) bool { return b >= '0' && b <= '9' || b >= 'A' && b <= 'Z' || b >= 'a' && b <= 'z' @@ -64,8 +59,8 @@ func RandString(length int) string { // maxInt64 is the maximum value of int64. var maxInt64 = big.NewInt(math.MaxInt64) -// RandStringWithCharset returns a random string with given length and charset. -// it uses crypto/rand to generate random string. +// RandStringWithCharset returns a random string with the given length and charset. +// It uses crypto/rand when available. func RandStringWithCharset(length int, charset string) string { var seed int64 if err := binary.Read(crand.Reader, binary.BigEndian, &seed); err != nil { @@ -99,17 +94,17 @@ func CleanAllSpace(s string) string { }, s) } -// NarrowString 全角转半角 +// NarrowString converts full-width characters to half-width characters. func NarrowString(s string) string { return width.Narrow.String(s) } -// CleanString clean all space and narrow string +// CleanString removes all spaces and narrows full-width characters. func CleanString(s string) string { return CleanAllSpace(NarrowString(strings.ToValidUTF8(s, ""))) } -// TrimString trim prefix and suffix space and narrow string +// TrimString trims surrounding spaces and narrows full-width characters. func TrimString(s string) string { return strings.TrimSpace(NarrowString(strings.ToValidUTF8(s, ""))) } From 722303f028fa0f625eee5ead8b58078c91712a01 Mon Sep 17 00:00:00 2001 From: skinny <39215611+crazyskinny@users.noreply.github.com> Date: Thu, 28 May 2026 17:27:21 +0800 Subject: [PATCH 11/11] refactor: modernize Go idioms --- .../internal/redact/benchmark_test.go | 30 +++++++------------ .../internal/redact/integration_test.go | 6 ++-- cmd/protoc-gen-go-sse/main_test.go | 20 ++++++------- kent/encrypt_test.go | 6 ++-- kent/ordering.go | 3 +- kratos/auth/policy_test.go | 28 ++++++++--------- kratos/internal/protoop/protoop.go | 4 +-- kratos/ratelimit/ratelimit_test.go | 6 ++-- kratos/ratelimit/redis/store.go | 11 ++----- otel/trace_provider.go | 2 +- pgx/pgx.go | 2 +- sse/sse_test.go | 2 +- text/text.go | 2 +- 13 files changed, 53 insertions(+), 69 deletions(-) diff --git a/cmd/protoc-gen-go-redact/internal/redact/benchmark_test.go b/cmd/protoc-gen-go-redact/internal/redact/benchmark_test.go index 360e6e2..e742a90 100644 --- a/cmd/protoc-gen-go-redact/internal/redact/benchmark_test.go +++ b/cmd/protoc-gen-go-redact/internal/redact/benchmark_test.go @@ -22,8 +22,7 @@ func BenchmarkRedact_SimpleMessage(b *testing.B) { Age: 30, } - b.ResetTimer() - for i := 0; i < b.N; i++ { + for b.Loop() { _ = user.Redact() } } @@ -49,8 +48,7 @@ func BenchmarkRedact_NestedMessage(b *testing.B) { }, } - b.ResetTimer() - for i := 0; i < b.N; i++ { + for b.Loop() { _ = account.Redact() } } @@ -69,8 +67,7 @@ func BenchmarkRedact_DeepNested(b *testing.B) { }, } - b.ResetTimer() - for i := 0; i < b.N; i++ { + for b.Loop() { _ = level1.Redact() } } @@ -110,15 +107,14 @@ func BenchmarkRedact_AllScalarTypes(b *testing.B) { RedactBytes: []byte("secret-bytes"), } - b.ResetTimer() - for i := 0; i < b.N; i++ { + for b.Loop() { _ = scalar.Redact() } } func BenchmarkRedact_RepeatedMessages(b *testing.B) { users := make([]*testdata.User, 100) - for i := 0; i < 100; i++ { + for i := range 100 { users[i] = &testdata.User{ Id: "user-" + string(rune(i)), Name: "User " + string(rune(i)), @@ -131,20 +127,19 @@ func BenchmarkRedact_RepeatedMessages(b *testing.B) { Users: users, } - b.ResetTimer() - for i := 0; i < b.N; i++ { + for b.Loop() { _ = repeated.Redact() } } func BenchmarkRedact_MapWithStringKey(b *testing.B) { stringMap := make(map[string]string, 100) - for i := 0; i < 100; i++ { + for i := range 100 { stringMap["key"+string(rune(i))] = "value" + string(rune(i)) } userMap := make(map[string]*testdata.User, 10) - for i := 0; i < 10; i++ { + for i := range 10 { userMap["user"+string(rune(i))] = &testdata.User{ Id: "id-" + string(rune(i)), Email: "email@test.com", @@ -156,8 +151,7 @@ func BenchmarkRedact_MapWithStringKey(b *testing.B) { UserMap: userMap, } - b.ResetTimer() - for i := 0; i < b.N; i++ { + for b.Loop() { _ = m.Redact() } } @@ -195,8 +189,7 @@ func BenchmarkRedact_ComplexMessage(b *testing.B) { SecretExtra: &testdata.ComplexMessage_SecretNote{SecretNote: "secret note"}, } - b.ResetTimer() - for i := 0; i < b.N; i++ { + for b.Loop() { _ = complex.Redact() } } @@ -204,8 +197,7 @@ func BenchmarkRedact_ComplexMessage(b *testing.B) { func BenchmarkRedact_NilMessage(b *testing.B) { var user *testdata.User - b.ResetTimer() - for i := 0; i < b.N; i++ { + for b.Loop() { _ = user.Redact() } } diff --git a/cmd/protoc-gen-go-redact/internal/redact/integration_test.go b/cmd/protoc-gen-go-redact/internal/redact/integration_test.go index 9a0dfdb..6f43127 100644 --- a/cmd/protoc-gen-go-redact/internal/redact/integration_test.go +++ b/cmd/protoc-gen-go-redact/internal/redact/integration_test.go @@ -639,7 +639,7 @@ func TestSpecialCharacters_Redact(t *testing.T) { t.Logf("SpecialCharacters.Redact(): %s", result) // Verify the output is valid JSON - var parsed map[string]interface{} + var parsed map[string]any if err := json.Unmarshal([]byte(result), &parsed); err != nil { t.Errorf("Result should be valid JSON: %v", err) } @@ -853,7 +853,7 @@ func TestCustomMaskTypes_AllFields(t *testing.T) { assertContains(t, result, "publicStatus", "2") // Verify it's valid JSON - var parsed map[string]interface{} + var parsed map[string]any if err := json.Unmarshal([]byte(result), &parsed); err != nil { t.Errorf("Result should be valid JSON: %v", err) } @@ -879,7 +879,7 @@ func TestAllMessages_ValidJSON(t *testing.T) { for i, msg := range messages { result := msg.Redact() - var parsed map[string]interface{} + var parsed map[string]any if err := json.Unmarshal([]byte(result), &parsed); err != nil { t.Errorf("Message %d produced invalid JSON: %v\nResult: %s", i, err, result) } diff --git a/cmd/protoc-gen-go-sse/main_test.go b/cmd/protoc-gen-go-sse/main_test.go index cbb96a0..1f5dc47 100644 --- a/cmd/protoc-gen-go-sse/main_test.go +++ b/cmd/protoc-gen-go-sse/main_test.go @@ -85,23 +85,23 @@ func testCodeGeneratorRequest(t *testing.T) *pluginpb.CodeGeneratorRequest { return &pluginpb.CodeGeneratorRequest{ FileToGenerate: []string{"test/v1/live.proto"}, ProtoFile: []*descriptorpb.FileDescriptorProto{{ - Syntax: proto.String("proto3"), - Name: proto.String("test/v1/live.proto"), - Package: proto.String("test.v1"), + Syntax: new("proto3"), + Name: new("test/v1/live.proto"), + Package: new("test.v1"), Options: &descriptorpb.FileOptions{ - GoPackage: proto.String("example.com/test/v1;testv1"), + GoPackage: new("example.com/test/v1;testv1"), }, MessageType: []*descriptorpb.DescriptorProto{{ - Name: proto.String("WatchRequest"), + Name: new("WatchRequest"), }, { - Name: proto.String("WatchResponse"), + Name: new("WatchResponse"), }}, Service: []*descriptorpb.ServiceDescriptorProto{{ - Name: proto.String("LiveService"), + Name: new("LiveService"), Method: []*descriptorpb.MethodDescriptorProto{{ - Name: proto.String("Watch"), - InputType: proto.String(".test.v1.WatchRequest"), - OutputType: proto.String(".test.v1.WatchResponse"), + Name: new("Watch"), + InputType: new(".test.v1.WatchRequest"), + OutputType: new(".test.v1.WatchResponse"), Options: opts, }}, }}, diff --git a/kent/encrypt_test.go b/kent/encrypt_test.go index cd8e0d7..8323922 100644 --- a/kent/encrypt_test.go +++ b/kent/encrypt_test.go @@ -216,7 +216,7 @@ func TestEncryptDecrypt_MultipleRoundTrips(t *testing.T) { plaintext := "test message" // Perform multiple encrypt-decrypt cycles - for i := 0; i < 100; i++ { + for i := range 100 { ciphertext, err := encryptor.Encrypt(plaintext) if err != nil { t.Fatalf("Encrypt() iteration %d error = %v", i, err) @@ -302,7 +302,7 @@ func TestEncryptDecrypt_Concurrent(t *testing.T) { done := make(chan bool, iterations) // Concurrent encryption - for i := 0; i < iterations; i++ { + for i := range iterations { go func(id int) { ciphertext, err := encryptor.Encrypt(plaintext) if err != nil { @@ -330,7 +330,7 @@ func TestEncryptDecrypt_Concurrent(t *testing.T) { // Wait for all goroutines successCount := 0 - for i := 0; i < iterations; i++ { + for range iterations { if <-done { successCount++ } diff --git a/kent/ordering.go b/kent/ordering.go index 4ae453d..db0eef8 100644 --- a/kent/ordering.go +++ b/kent/ordering.go @@ -19,8 +19,7 @@ func ProcessOrdering(orderBy string, fieldMap map[string]string, defaultOrdering return defaultOrdering } - orderByTerms := strings.Split(orderBy, ",") - for _, term := range orderByTerms { + for term := range strings.SplitSeq(orderBy, ",") { if term == "" { continue } diff --git a/kratos/auth/policy_test.go b/kratos/auth/policy_test.go index 80b6a02..df2e165 100644 --- a/kratos/auth/policy_test.go +++ b/kratos/auth/policy_test.go @@ -13,30 +13,30 @@ func TestOperationPolicyRegistersPublicProtoMethods(t *testing.T) { publicOpts := &descriptorpb.MethodOptions{} proto.SetExtension(publicOpts, authv1.E_Public, true) fd, err := protodesc.NewFile(&descriptorpb.FileDescriptorProto{ - Syntax: proto.String("proto3"), - Name: proto.String("test/auth/v1/service.proto"), - Package: proto.String("test.auth.v1"), + Syntax: new("proto3"), + Name: new("test/auth/v1/service.proto"), + Package: new("test.auth.v1"), Service: []*descriptorpb.ServiceDescriptorProto{{ - Name: proto.String("AuthService"), + Name: new("AuthService"), Method: []*descriptorpb.MethodDescriptorProto{ { - Name: proto.String("Login"), - InputType: proto.String(".test.auth.v1.LoginRequest"), - OutputType: proto.String(".test.auth.v1.LoginResponse"), + Name: new("Login"), + InputType: new(".test.auth.v1.LoginRequest"), + OutputType: new(".test.auth.v1.LoginResponse"), Options: publicOpts, }, { - Name: proto.String("Profile"), - InputType: proto.String(".test.auth.v1.ProfileRequest"), - OutputType: proto.String(".test.auth.v1.ProfileResponse"), + Name: new("Profile"), + InputType: new(".test.auth.v1.ProfileRequest"), + OutputType: new(".test.auth.v1.ProfileResponse"), }, }, }}, MessageType: []*descriptorpb.DescriptorProto{ - {Name: proto.String("LoginRequest")}, - {Name: proto.String("LoginResponse")}, - {Name: proto.String("ProfileRequest")}, - {Name: proto.String("ProfileResponse")}, + {Name: new("LoginRequest")}, + {Name: new("LoginResponse")}, + {Name: new("ProfileRequest")}, + {Name: new("ProfileResponse")}, }, }, nil) if err != nil { diff --git a/kratos/internal/protoop/protoop.go b/kratos/internal/protoop/protoop.go index b8a9107..636bf01 100644 --- a/kratos/internal/protoop/protoop.go +++ b/kratos/internal/protoop/protoop.go @@ -12,9 +12,9 @@ import ( func WalkMethods(files []protoreflect.FileDescriptor, fn func(protoreflect.MethodDescriptor)) { for _, fd := range files { services := fd.Services() - for i := 0; i < services.Len(); i++ { + for i := range services.Len() { methods := services.Get(i).Methods() - for j := 0; j < methods.Len(); j++ { + for j := range methods.Len() { fn(methods.Get(j)) } } diff --git a/kratos/ratelimit/ratelimit_test.go b/kratos/ratelimit/ratelimit_test.go index 8838366..b8f85d1 100644 --- a/kratos/ratelimit/ratelimit_test.go +++ b/kratos/ratelimit/ratelimit_test.go @@ -222,7 +222,7 @@ func TestServerUsesOperationRules(t *testing.T) { if _, err := wrapped(fastCtx, nil); !errors.Is(err, ErrLimitExceed) { t.Fatalf("fast second request error = %v, want ErrLimitExceed from operation policy", err) } - for i := 0; i < 3; i++ { + for i := range 3 { if _, err := wrapped(slowCtx, nil); err != nil { t.Fatalf("slow request %d error = %v, want default limiter", i+1, err) } @@ -328,13 +328,13 @@ func TestServerMultiRuleConsumptionIsAtomic(t *testing.T) { if _, err := user1(ctx, nil); err != nil { t.Fatalf("first request error = %v, want nil", err) } - for i := 0; i < 200; i++ { + for i := range 200 { if _, err := user1(ctx, nil); !errors.Is(err, ErrLimitExceed) { t.Fatalf("rejected request %d error = %v, want ErrLimitExceed", i+2, err) } } - for i := 0; i < 99; i++ { + for i := range 99 { user := buildServer(fmt.Sprintf("user-%d", i+2)) if _, err := user(ctx, nil); err != nil { t.Fatalf("user-2 request %d error = %v, want IP bucket still has 99 tokens", i+1, err) diff --git a/kratos/ratelimit/redis/store.go b/kratos/ratelimit/redis/store.go index 5e08ac6..8f403cd 100644 --- a/kratos/ratelimit/redis/store.go +++ b/kratos/ratelimit/redis/store.go @@ -186,7 +186,7 @@ func parseResults(values []any, want int) ([]ratelimit.Result, error) { return nil, fmt.Errorf("%w: got %d values, want %d", ErrInvalidScriptResult, len(values), want*scriptResultWidth) } results := make([]ratelimit.Result, want) - for i := 0; i < want; i++ { + for i := range want { res, err := parseResult(values[i*scriptResultWidth : (i+1)*scriptResultWidth]) if err != nil { return nil, err @@ -239,14 +239,7 @@ func int64Value(v any) (int64, error) { func ttl(limit ratelimit.Limit) int64 { perMs := durationMillis(limit.Per) refillFullMs := int64(math.Ceil(float64(limit.Burst) * float64(perMs) / float64(limit.Rate))) - expireMillis := 2 * perMs - if refillFullMs > expireMillis { - expireMillis = refillFullMs - } - if expireMillis < minBucketTTLMillis { - expireMillis = minBucketTTLMillis - } - return expireMillis + return max(max(2*perMs, refillFullMs), minBucketTTLMillis) } func durationMillis(d time.Duration) int64 { diff --git a/otel/trace_provider.go b/otel/trace_provider.go index 24a0f61..4f00c23 100644 --- a/otel/trace_provider.go +++ b/otel/trace_provider.go @@ -84,7 +84,7 @@ func NewTraceProvider(c *TraceProviderConfig) ( semconv.K8SNamespaceName(kubernetes.GetCurrentNamespace()), } if resourceInEnv := os.Getenv("OTEL_RESOURCE_ATTRIBUTES"); resourceInEnv != "" { - for _, attr := range strings.Split(resourceInEnv, ",") { + for attr := range strings.SplitSeq(resourceInEnv, ",") { parts := strings.Split(attr, "=") if len(parts) == 2 { attrs = append(attrs, attribute.String(parts[0], parts[1])) diff --git a/pgx/pgx.go b/pgx/pgx.go index a22532b..d951e67 100644 --- a/pgx/pgx.go +++ b/pgx/pgx.go @@ -181,7 +181,7 @@ func guessingScan[T any](src any) (value T, err error) { case []byte: bufSrc = src default: - bufSrc = []byte(fmt.Sprint(src)) + bufSrc = fmt.Append(nil, src) } } diff --git a/sse/sse_test.go b/sse/sse_test.go index d021f5b..8441820 100644 --- a/sse/sse_test.go +++ b/sse/sse_test.go @@ -226,7 +226,7 @@ func TestStream_Heartbeat(t *testing.T) { // Concurrent writes: ensure mutex prevents interleaved frames. done := make(chan struct{}) go func() { - for i := 0; i < 20; i++ { + for range 20 { _ = s.Write("token") time.Sleep(time.Millisecond) } diff --git a/text/text.go b/text/text.go index 95b9953..8449b60 100644 --- a/text/text.go +++ b/text/text.go @@ -26,7 +26,7 @@ func GeneratePassword(size int, accept func(byte) bool) (string, error) { if err != nil { return "", fmt.Errorf("read random bytes: %w", err) } - for idx := 0; idx < n; idx++ { + for idx := range n { if accept(buf[idx]) && len(password) < size { password = append(password, buf[idx]) }