diff --git a/.gitignore b/.gitignore index 0981e25..a432141 100644 --- a/.gitignore +++ b/.gitignore @@ -10,8 +10,8 @@ # Dependency directories (remove the comment below to include it) vendor/ -# Go workspace file -go.work +# Go workspace checksum file +go.work.sum # Compiled Object files, Static and Dynamic libs (Shared Objects) *.o @@ -37,4 +37,4 @@ bin/ logging/protoc-gen-go-redact/protoc-gen-go-redact # cursor -.cursor \ No newline at end of file +.cursor diff --git a/Taskfile.yml b/Taskfile.yml new file mode 100644 index 0000000..fcb1764 --- /dev/null +++ b/Taskfile.yml @@ -0,0 +1,56 @@ +version: "3" + +vars: + GO_MODULES: + sh: find . -name go.mod -not -path './.git/*' -exec dirname {} \; | sort + +tasks: + default: + desc: Run all repository verification checks. + cmds: + - task: all + + all: + desc: Run all repository verification checks. + deps: + - test + - vet + - proto + + test: + desc: Run tests for all Go modules. + cmds: + - for: { var: GO_MODULES } + task: test:module + vars: + MODULE: "{{.ITEM}}" + + test:module: + internal: true + dir: "{{.MODULE}}" + cmds: + - go test ./... + + vet: + desc: Run go vet for all Go modules. + cmds: + - for: { var: GO_MODULES } + task: vet:module + vars: + MODULE: "{{.ITEM}}" + + vet:module: + internal: true + dir: "{{.MODULE}}" + cmds: + - go vet ./... + + proto: + desc: Lint proto files with buf. + cmds: + - buf lint + + verify: + desc: Run all repository verification checks. + cmds: + - task: all diff --git a/buf.lock b/buf.lock new file mode 100644 index 0000000..8447589 --- /dev/null +++ b/buf.lock @@ -0,0 +1,6 @@ +# Generated by buf. DO NOT EDIT. +version: v2 +deps: + - name: buf.build/googleapis/googleapis + commit: c17df5b2beca46928cc87d5656bd5343 + digest: b5:648a01e0170d4512dea7d564016165decd1ed6e34bef79fe54753e51ad7e27545709ad9157d7551270147d551155c595a2fb0bf5bb33b1c83040ddbce915c604 diff --git a/proto/buf.yaml b/buf.yaml similarity index 70% rename from proto/buf.yaml rename to buf.yaml index c453ea7..442e551 100644 --- a/proto/buf.yaml +++ b/buf.yaml @@ -1,5 +1,7 @@ version: v2 -name: buf.build/crypto-zero/go-kit +modules: + - path: proto + name: buf.build/crypto-zero/go-kit deps: - buf.build/googleapis/googleapis:main lint: @@ -10,4 +12,4 @@ lint: breaking: except: - EXTENSION_NO_DELETE - - FIELD_SAME_DEFAULT \ No newline at end of file + - FIELD_SAME_DEFAULT diff --git a/logging/protoc-gen-go-redact/README.md b/cmd/protoc-gen-go-redact/README.md similarity index 98% rename from logging/protoc-gen-go-redact/README.md rename to cmd/protoc-gen-go-redact/README.md index 8869fec..8c32920 100644 --- a/logging/protoc-gen-go-redact/README.md +++ b/cmd/protoc-gen-go-redact/README.md @@ -16,13 +16,13 @@ A protoc plugin that generates `Redact()` methods for Protocol Buffer messages t ## Installation ```bash -go install github.com/crypto-zero/go-kit/logging/protoc-gen-go-redact@latest +go install github.com/crypto-zero/go-kit/cmd/protoc-gen-go-redact@latest ``` Or build from source: ```bash -cd logging/protoc-gen-go-redact +cd cmd/protoc-gen-go-redact go build -o protoc-gen-go-redact . ``` @@ -273,7 +273,7 @@ type Redacter interface { Use with the logging middleware: ```go -import "github.com/crypto-zero/go-kit/logging/kratos" +import "github.com/crypto-zero/go-kit/kratos/logging" // Server middleware srv := http.NewServer( diff --git a/logging/protoc-gen-go-redact/internal/redact/benchmark_test.go b/cmd/protoc-gen-go-redact/internal/redact/benchmark_test.go similarity index 89% rename from logging/protoc-gen-go-redact/internal/redact/benchmark_test.go rename to cmd/protoc-gen-go-redact/internal/redact/benchmark_test.go index db8a90d..e742a90 100644 --- a/logging/protoc-gen-go-redact/internal/redact/benchmark_test.go +++ b/cmd/protoc-gen-go-redact/internal/redact/benchmark_test.go @@ -4,7 +4,7 @@ import ( "testing" "time" - "github.com/crypto-zero/go-kit/logging/protoc-gen-go-redact/internal/redact/testdata" + "github.com/crypto-zero/go-kit/cmd/protoc-gen-go-redact/internal/redact/testdata" "google.golang.org/protobuf/types/known/timestamppb" ) @@ -22,8 +22,7 @@ func BenchmarkRedact_SimpleMessage(b *testing.B) { Age: 30, } - b.ResetTimer() - for i := 0; i < b.N; i++ { + for b.Loop() { _ = user.Redact() } } @@ -49,8 +48,7 @@ func BenchmarkRedact_NestedMessage(b *testing.B) { }, } - b.ResetTimer() - for i := 0; i < b.N; i++ { + for b.Loop() { _ = account.Redact() } } @@ -69,8 +67,7 @@ func BenchmarkRedact_DeepNested(b *testing.B) { }, } - b.ResetTimer() - for i := 0; i < b.N; i++ { + for b.Loop() { _ = level1.Redact() } } @@ -110,15 +107,14 @@ func BenchmarkRedact_AllScalarTypes(b *testing.B) { RedactBytes: []byte("secret-bytes"), } - b.ResetTimer() - for i := 0; i < b.N; i++ { + for b.Loop() { _ = scalar.Redact() } } func BenchmarkRedact_RepeatedMessages(b *testing.B) { users := make([]*testdata.User, 100) - for i := 0; i < 100; i++ { + for i := range 100 { users[i] = &testdata.User{ Id: "user-" + string(rune(i)), Name: "User " + string(rune(i)), @@ -131,20 +127,19 @@ func BenchmarkRedact_RepeatedMessages(b *testing.B) { Users: users, } - b.ResetTimer() - for i := 0; i < b.N; i++ { + for b.Loop() { _ = repeated.Redact() } } func BenchmarkRedact_MapWithStringKey(b *testing.B) { stringMap := make(map[string]string, 100) - for i := 0; i < 100; i++ { + for i := range 100 { stringMap["key"+string(rune(i))] = "value" + string(rune(i)) } userMap := make(map[string]*testdata.User, 10) - for i := 0; i < 10; i++ { + for i := range 10 { userMap["user"+string(rune(i))] = &testdata.User{ Id: "id-" + string(rune(i)), Email: "email@test.com", @@ -156,8 +151,7 @@ func BenchmarkRedact_MapWithStringKey(b *testing.B) { UserMap: userMap, } - b.ResetTimer() - for i := 0; i < b.N; i++ { + for b.Loop() { _ = m.Redact() } } @@ -195,8 +189,7 @@ func BenchmarkRedact_ComplexMessage(b *testing.B) { SecretExtra: &testdata.ComplexMessage_SecretNote{SecretNote: "secret note"}, } - b.ResetTimer() - for i := 0; i < b.N; i++ { + for b.Loop() { _ = complex.Redact() } } @@ -204,8 +197,7 @@ func BenchmarkRedact_ComplexMessage(b *testing.B) { func BenchmarkRedact_NilMessage(b *testing.B) { var user *testdata.User - b.ResetTimer() - for i := 0; i < b.N; i++ { + for b.Loop() { _ = user.Redact() } } diff --git a/logging/protoc-gen-go-redact/internal/redact/integration_test.go b/cmd/protoc-gen-go-redact/internal/redact/integration_test.go similarity index 99% rename from logging/protoc-gen-go-redact/internal/redact/integration_test.go rename to cmd/protoc-gen-go-redact/internal/redact/integration_test.go index e7fbff2..6f43127 100644 --- a/logging/protoc-gen-go-redact/internal/redact/integration_test.go +++ b/cmd/protoc-gen-go-redact/internal/redact/integration_test.go @@ -6,7 +6,7 @@ import ( "testing" "time" - "github.com/crypto-zero/go-kit/logging/protoc-gen-go-redact/internal/redact/testdata" + "github.com/crypto-zero/go-kit/cmd/protoc-gen-go-redact/internal/redact/testdata" "google.golang.org/protobuf/encoding/protojson" "google.golang.org/protobuf/types/known/durationpb" "google.golang.org/protobuf/types/known/timestamppb" @@ -639,7 +639,7 @@ func TestSpecialCharacters_Redact(t *testing.T) { t.Logf("SpecialCharacters.Redact(): %s", result) // Verify the output is valid JSON - var parsed map[string]interface{} + var parsed map[string]any if err := json.Unmarshal([]byte(result), &parsed); err != nil { t.Errorf("Result should be valid JSON: %v", err) } @@ -853,7 +853,7 @@ func TestCustomMaskTypes_AllFields(t *testing.T) { assertContains(t, result, "publicStatus", "2") // Verify it's valid JSON - var parsed map[string]interface{} + var parsed map[string]any if err := json.Unmarshal([]byte(result), &parsed); err != nil { t.Errorf("Result should be valid JSON: %v", err) } @@ -879,7 +879,7 @@ func TestAllMessages_ValidJSON(t *testing.T) { for i, msg := range messages { result := msg.Redact() - var parsed map[string]interface{} + var parsed map[string]any if err := json.Unmarshal([]byte(result), &parsed); err != nil { t.Errorf("Message %d produced invalid JSON: %v\nResult: %s", i, err, result) } diff --git a/logging/protoc-gen-go-redact/internal/redact/redact.go b/cmd/protoc-gen-go-redact/internal/redact/redact.go similarity index 100% rename from logging/protoc-gen-go-redact/internal/redact/redact.go rename to cmd/protoc-gen-go-redact/internal/redact/redact.go diff --git a/logging/protoc-gen-go-redact/internal/redact/redact_test.go b/cmd/protoc-gen-go-redact/internal/redact/redact_test.go similarity index 100% rename from logging/protoc-gen-go-redact/internal/redact/redact_test.go rename to cmd/protoc-gen-go-redact/internal/redact/redact_test.go diff --git a/logging/protoc-gen-go-redact/internal/redact/testdata/crossfile/container.pb.go b/cmd/protoc-gen-go-redact/internal/redact/testdata/crossfile/container.pb.go similarity index 98% rename from logging/protoc-gen-go-redact/internal/redact/testdata/crossfile/container.pb.go rename to cmd/protoc-gen-go-redact/internal/redact/testdata/crossfile/container.pb.go index 48d2c7b..7ec6c0d 100644 --- a/logging/protoc-gen-go-redact/internal/redact/testdata/crossfile/container.pb.go +++ b/cmd/protoc-gen-go-redact/internal/redact/testdata/crossfile/container.pb.go @@ -4,7 +4,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: // protoc-gen-go v1.36.11 -// protoc v6.33.2 +// protoc v7.34.1 // source: crossfile/container.proto package crossfile @@ -238,7 +238,7 @@ const file_crossfile_container_proto_rawDesc = "" + "\bdata_map\x18\x01 \x03(\v2-.testdata.crossfile.MapContainer.DataMapEntryR\adataMap\x1a]\n" + "\fDataMapEntry\x12\x10\n" + "\x03key\x18\x01 \x01(\tR\x03key\x127\n" + - "\x05value\x18\x02 \x01(\v2!.testdata.crossfile.SensitiveDataR\x05value:\x028\x01B_Z]github.com/crypto-zero/go-kit/logging/protoc-gen-go-redact/internal/redact/testdata/crossfileb\x06proto3" + "\x05value\x18\x02 \x01(\v2!.testdata.crossfile.SensitiveDataR\x05value:\x028\x01B[ZYgithub.com/crypto-zero/go-kit/cmd/protoc-gen-go-redact/internal/redact/testdata/crossfileb\x06proto3" var ( file_crossfile_container_proto_rawDescOnce sync.Once diff --git a/logging/protoc-gen-go-redact/internal/redact/testdata/crossfile/container.proto b/cmd/protoc-gen-go-redact/internal/redact/testdata/crossfile/container.proto similarity index 88% rename from logging/protoc-gen-go-redact/internal/redact/testdata/crossfile/container.proto rename to cmd/protoc-gen-go-redact/internal/redact/testdata/crossfile/container.proto index b3c96ed..e4e9552 100644 --- a/logging/protoc-gen-go-redact/internal/redact/testdata/crossfile/container.proto +++ b/cmd/protoc-gen-go-redact/internal/redact/testdata/crossfile/container.proto @@ -3,7 +3,7 @@ syntax = "proto3"; package testdata.crossfile; -option go_package = "github.com/crypto-zero/go-kit/logging/protoc-gen-go-redact/internal/redact/testdata/crossfile"; +option go_package = "github.com/crypto-zero/go-kit/cmd/protoc-gen-go-redact/internal/redact/testdata/crossfile"; import "crossfile/sensitive.proto"; diff --git a/logging/protoc-gen-go-redact/internal/redact/testdata/crossfile/container_redact.pb.go b/cmd/protoc-gen-go-redact/internal/redact/testdata/crossfile/container_redact.pb.go similarity index 99% rename from logging/protoc-gen-go-redact/internal/redact/testdata/crossfile/container_redact.pb.go rename to cmd/protoc-gen-go-redact/internal/redact/testdata/crossfile/container_redact.pb.go index e401d68..8d8ce9e 100644 --- a/logging/protoc-gen-go-redact/internal/redact/testdata/crossfile/container_redact.pb.go +++ b/cmd/protoc-gen-go-redact/internal/redact/testdata/crossfile/container_redact.pb.go @@ -1,7 +1,7 @@ // Code generated by protoc-gen-go-redact. DO NOT EDIT. // versions: // - protoc-gen-go-redact v1.1.0 -// - protoc v6.33.2 +// - protoc v7.34.1 // source: crossfile/container.proto package crossfile diff --git a/logging/protoc-gen-go-redact/internal/redact/testdata/crossfile/crossfile_test.go b/cmd/protoc-gen-go-redact/internal/redact/testdata/crossfile/crossfile_test.go similarity index 100% rename from logging/protoc-gen-go-redact/internal/redact/testdata/crossfile/crossfile_test.go rename to cmd/protoc-gen-go-redact/internal/redact/testdata/crossfile/crossfile_test.go diff --git a/logging/protoc-gen-go-redact/internal/redact/testdata/crossfile/sensitive.pb.go b/cmd/protoc-gen-go-redact/internal/redact/testdata/crossfile/sensitive.pb.go similarity index 97% rename from logging/protoc-gen-go-redact/internal/redact/testdata/crossfile/sensitive.pb.go rename to cmd/protoc-gen-go-redact/internal/redact/testdata/crossfile/sensitive.pb.go index 78c8fa6..3334f37 100644 --- a/logging/protoc-gen-go-redact/internal/redact/testdata/crossfile/sensitive.pb.go +++ b/cmd/protoc-gen-go-redact/internal/redact/testdata/crossfile/sensitive.pb.go @@ -3,7 +3,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: // protoc-gen-go v1.36.11 -// protoc v6.33.2 +// protoc v7.34.1 // source: crossfile/sensitive.proto package crossfile @@ -150,7 +150,7 @@ const file_crossfile_sensitive_proto_rawDesc = "" + "\x05phone\x18\x03 \x01(\tB\x0f¢3\v\b\x01\x12\a[PHONE]R\x05phone\"\\\n" + "\x0fNestedSensitive\x12\x12\n" + "\x04name\x18\x01 \x01(\tR\x04name\x125\n" + - "\x04data\x18\x02 \x01(\v2!.testdata.crossfile.SensitiveDataR\x04dataB_Z]github.com/crypto-zero/go-kit/logging/protoc-gen-go-redact/internal/redact/testdata/crossfileb\x06proto3" + "\x04data\x18\x02 \x01(\v2!.testdata.crossfile.SensitiveDataR\x04dataB[ZYgithub.com/crypto-zero/go-kit/cmd/protoc-gen-go-redact/internal/redact/testdata/crossfileb\x06proto3" var ( file_crossfile_sensitive_proto_rawDescOnce sync.Once diff --git a/logging/protoc-gen-go-redact/internal/redact/testdata/crossfile/sensitive.proto b/cmd/protoc-gen-go-redact/internal/redact/testdata/crossfile/sensitive.proto similarity index 83% rename from logging/protoc-gen-go-redact/internal/redact/testdata/crossfile/sensitive.proto rename to cmd/protoc-gen-go-redact/internal/redact/testdata/crossfile/sensitive.proto index 9720665..816ee68 100644 --- a/logging/protoc-gen-go-redact/internal/redact/testdata/crossfile/sensitive.proto +++ b/cmd/protoc-gen-go-redact/internal/redact/testdata/crossfile/sensitive.proto @@ -2,7 +2,7 @@ syntax = "proto3"; package testdata.crossfile; -option go_package = "github.com/crypto-zero/go-kit/logging/protoc-gen-go-redact/internal/redact/testdata/crossfile"; +option go_package = "github.com/crypto-zero/go-kit/cmd/protoc-gen-go-redact/internal/redact/testdata/crossfile"; import "kit/redact/v1/redact.proto"; diff --git a/logging/protoc-gen-go-redact/internal/redact/testdata/crossfile/sensitive_redact.pb.go b/cmd/protoc-gen-go-redact/internal/redact/testdata/crossfile/sensitive_redact.pb.go similarity index 97% rename from logging/protoc-gen-go-redact/internal/redact/testdata/crossfile/sensitive_redact.pb.go rename to cmd/protoc-gen-go-redact/internal/redact/testdata/crossfile/sensitive_redact.pb.go index 27cbca1..52abbfb 100644 --- a/logging/protoc-gen-go-redact/internal/redact/testdata/crossfile/sensitive_redact.pb.go +++ b/cmd/protoc-gen-go-redact/internal/redact/testdata/crossfile/sensitive_redact.pb.go @@ -1,7 +1,7 @@ // Code generated by protoc-gen-go-redact. DO NOT EDIT. // versions: // - protoc-gen-go-redact v1.1.0 -// - protoc v6.33.2 +// - protoc v7.34.1 // source: crossfile/sensitive.proto package crossfile diff --git a/logging/protoc-gen-go-redact/internal/redact/testdata/example.pb.go b/cmd/protoc-gen-go-redact/internal/redact/testdata/example.pb.go similarity index 99% rename from logging/protoc-gen-go-redact/internal/redact/testdata/example.pb.go rename to cmd/protoc-gen-go-redact/internal/redact/testdata/example.pb.go index 0c99200..615f113 100644 --- a/logging/protoc-gen-go-redact/internal/redact/testdata/example.pb.go +++ b/cmd/protoc-gen-go-redact/internal/redact/testdata/example.pb.go @@ -1,7 +1,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: // protoc-gen-go v1.36.11 -// protoc v6.33.2 +// protoc v7.34.1 // source: example.proto package testdata @@ -3924,7 +3924,7 @@ const file_example_proto_rawDesc = "" + "\fPRIORITY_LOW\x10\x01\x12\x13\n" + "\x0fPRIORITY_MEDIUM\x10\x02\x12\x11\n" + "\rPRIORITY_HIGH\x10\x03\x12\x15\n" + - "\x11PRIORITY_CRITICAL\x10\x04BEZCgithub.com/crypto-zero/go-kit/logging/protoc-gen-go-redact/testdatab\x06proto3" + "\x11PRIORITY_CRITICAL\x10\x04BAZ?github.com/crypto-zero/go-kit/cmd/protoc-gen-go-redact/testdatab\x06proto3" var ( file_example_proto_rawDescOnce sync.Once diff --git a/logging/protoc-gen-go-redact/internal/redact/testdata/example.proto b/cmd/protoc-gen-go-redact/internal/redact/testdata/example.proto similarity index 99% rename from logging/protoc-gen-go-redact/internal/redact/testdata/example.proto rename to cmd/protoc-gen-go-redact/internal/redact/testdata/example.proto index 45e9035..b5e25b4 100644 --- a/logging/protoc-gen-go-redact/internal/redact/testdata/example.proto +++ b/cmd/protoc-gen-go-redact/internal/redact/testdata/example.proto @@ -9,7 +9,7 @@ import "google/protobuf/any.proto"; import "google/protobuf/struct.proto"; import "kit/redact/v1/redact.proto"; -option go_package = "github.com/crypto-zero/go-kit/logging/protoc-gen-go-redact/testdata"; +option go_package = "github.com/crypto-zero/go-kit/cmd/protoc-gen-go-redact/testdata"; // ============================================================================ // Enum Definitions diff --git a/logging/protoc-gen-go-redact/internal/redact/testdata/example_redact.pb.go b/cmd/protoc-gen-go-redact/internal/redact/testdata/example_redact.pb.go similarity index 99% rename from logging/protoc-gen-go-redact/internal/redact/testdata/example_redact.pb.go rename to cmd/protoc-gen-go-redact/internal/redact/testdata/example_redact.pb.go index 3451665..1f60ff0 100644 --- a/logging/protoc-gen-go-redact/internal/redact/testdata/example_redact.pb.go +++ b/cmd/protoc-gen-go-redact/internal/redact/testdata/example_redact.pb.go @@ -1,7 +1,7 @@ // Code generated by protoc-gen-go-redact. DO NOT EDIT. // versions: // - protoc-gen-go-redact v1.1.0 -// - protoc v6.33.2 +// - protoc v7.34.1 // source: example.proto package testdata diff --git a/logging/protoc-gen-go-redact/internal/redact/testdata/generate.sh b/cmd/protoc-gen-go-redact/internal/redact/testdata/generate.sh similarity index 100% rename from logging/protoc-gen-go-redact/internal/redact/testdata/generate.sh rename to cmd/protoc-gen-go-redact/internal/redact/testdata/generate.sh diff --git a/logging/protoc-gen-go-redact/internal/redact/version.go b/cmd/protoc-gen-go-redact/internal/redact/version.go similarity index 100% rename from logging/protoc-gen-go-redact/internal/redact/version.go rename to cmd/protoc-gen-go-redact/internal/redact/version.go diff --git a/logging/protoc-gen-go-redact/main.go b/cmd/protoc-gen-go-redact/main.go similarity index 92% rename from logging/protoc-gen-go-redact/main.go rename to cmd/protoc-gen-go-redact/main.go index 51fccf9..0cb428f 100644 --- a/logging/protoc-gen-go-redact/main.go +++ b/cmd/protoc-gen-go-redact/main.go @@ -4,7 +4,7 @@ import ( "flag" "fmt" - "github.com/crypto-zero/go-kit/logging/protoc-gen-go-redact/internal/redact" + "github.com/crypto-zero/go-kit/cmd/protoc-gen-go-redact/internal/redact" "google.golang.org/protobuf/compiler/protogen" "google.golang.org/protobuf/types/pluginpb" ) diff --git a/cmd/protoc-gen-go-sse/main.go b/cmd/protoc-gen-go-sse/main.go new file mode 100644 index 0000000..913e0e3 --- /dev/null +++ b/cmd/protoc-gen-go-sse/main.go @@ -0,0 +1,299 @@ +package main + +import ( + "flag" + "fmt" + "strconv" + "strings" + + ssev1 "github.com/crypto-zero/go-kit/proto/kit/sse/v1" + "google.golang.org/genproto/googleapis/api/annotations" + "google.golang.org/protobuf/compiler/protogen" + "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/types/pluginpb" +) + +const version = "v0.1.0" + +var ( + contextPackage = protogen.GoImportPath("context") + khttpPackage = protogen.GoImportPath("github.com/go-kratos/kratos/v2/transport/http") + ssePackage = protogen.GoImportPath("github.com/crypto-zero/go-kit/sse") + kratosPackage = protogen.GoImportPath("github.com/crypto-zero/go-kit/kratos/sse") +) + +func main() { + var flags flag.FlagSet + protogen.Options{ParamFunc: flags.Set}.Run(func(plugin *protogen.Plugin) error { + plugin.SupportedFeatures = uint64(pluginpb.CodeGeneratorResponse_FEATURE_PROTO3_OPTIONAL) + for _, file := range plugin.Files { + if file.Generate { + generateFile(plugin, file) + } + } + return nil + }) +} + +type sseMethod struct { + service *protogen.Service + method *protogen.Method + routes []sseRoute +} + +type sseRoute struct { + verb string + path string + body string +} + +func generateFile(plugin *protogen.Plugin, file *protogen.File) { + methods := collectSSEMethods(file) + if len(methods) == 0 { + return + } + + filename := file.GeneratedFilenamePrefix + "_sse.pb.go" + g := plugin.NewGeneratedFile(filename, file.GoImportPath) + g.P("// Code generated by protoc-gen-go-sse. DO NOT EDIT.") + g.P("// versions:") + g.P("// - protoc-gen-go-sse ", version) + g.P("// - protoc ", protocVersion(plugin)) + g.P("// source: ", file.Desc.Path()) + g.P() + g.P("package ", file.GoPackageName) + g.P() + + grouped := map[*protogen.Service][]sseMethod{} + for _, m := range methods { + grouped[m.service] = append(grouped[m.service], m) + } + for _, service := range file.Services { + items := grouped[service] + if len(items) == 0 { + continue + } + genService(g, service, items) + } +} + +func collectSSEMethods(file *protogen.File) []sseMethod { + var out []sseMethod + for _, service := range file.Services { + for _, method := range service.Methods { + opts := method.Desc.Options() + if !proto.HasExtension(opts, ssev1.E_ServerSentEvent) { + continue + } + enabled, ok := proto.GetExtension(opts, ssev1.E_ServerSentEvent).(bool) + if !ok || !enabled { + continue + } + if !proto.HasExtension(opts, annotations.E_Http) { + continue + } + rule, ok := proto.GetExtension(opts, annotations.E_Http).(*annotations.HttpRule) + if !ok || rule == nil { + continue + } + var routes []sseRoute + for _, binding := range httpRuleBindings(rule) { + verb, path := httpRuleHTTP(binding) + if verb == "" || path == "" { + continue + } + routes = append(routes, sseRoute{verb: verb, path: path, body: binding.GetBody()}) + } + if len(routes) == 0 { + continue + } + out = append(out, sseMethod{service: service, method: method, routes: routes}) + } + } + return out +} + +func httpRuleBindings(rule *annotations.HttpRule) []*annotations.HttpRule { + out := []*annotations.HttpRule{rule} + out = append(out, rule.GetAdditionalBindings()...) + return out +} + +func httpRuleHTTP(rule *annotations.HttpRule) (string, string) { + switch pattern := rule.GetPattern().(type) { + case *annotations.HttpRule_Get: + return "GET", pattern.Get + case *annotations.HttpRule_Put: + return "PUT", pattern.Put + case *annotations.HttpRule_Post: + return "POST", pattern.Post + case *annotations.HttpRule_Delete: + return "DELETE", pattern.Delete + case *annotations.HttpRule_Patch: + return "PATCH", pattern.Patch + case *annotations.HttpRule_Custom: + return pattern.Custom.Kind, pattern.Custom.Path + default: + return "", "" + } +} + +func genService(g *protogen.GeneratedFile, service *protogen.Service, methods []sseMethod) { + serviceName := service.GoName + contextIdent := g.QualifiedGoIdent(contextPackage.Ident("Context")) + streamIdent := g.QualifiedGoIdent(ssePackage.Ident("Stream")) + readerIdent := g.QualifiedGoIdent(ssePackage.Ident("Reader")) + serverIdent := g.QualifiedGoIdent(khttpPackage.Ident("Server")) + optionIdent := g.QualifiedGoIdent(kratosPackage.Ident("HTTPStreamOption")) + clientIdent := g.QualifiedGoIdent(kratosPackage.Ident("HTTPClient")) + callOptionIdent := g.QualifiedGoIdent(kratosPackage.Ident("HTTPStreamCallOption")) + clientImpl := unexport(serviceName) + "SSEClient" + + g.P("type ", serviceName, "SSEServer interface {") + for _, item := range methods { + g.P(item.method.GoName, "(", contextIdent, ", *", item.method.Input.GoIdent, ", *", streamIdent, ") error") + } + g.P("}") + g.P() + + g.P("func Register", serviceName, "SSEServer(s *", serverIdent, ", srv ", serviceName, "SSEServer, opts ...", optionIdent, ") {") + for _, item := range methods { + g.P("_", serviceName, "_", item.method.GoName, "_SSE_Register(s, srv, opts...)") + } + g.P("}") + g.P() + + g.P("type ", serviceName, "SSEClient interface {") + for _, item := range methods { + g.P(item.method.GoName, "(", contextIdent, ", *", item.method.Input.GoIdent, ", ...", callOptionIdent, ") (*", readerIdent, ", error)") + } + g.P("}") + g.P() + + g.P("type ", clientImpl, " struct {") + g.P("cc *", clientIdent) + g.P("}") + g.P() + + g.P("func New", serviceName, "SSEClient(client *", clientIdent, ") ", serviceName, "SSEClient {") + g.P("return &", clientImpl, "{cc: client}") + g.P("}") + g.P() + + for _, item := range methods { + genMethod(g, item) + } +} + +func genMethod(g *protogen.GeneratedFile, item sseMethod) { + serviceName := item.service.GoName + methodName := item.method.GoName + operationConst := "Operation" + serviceName + methodName + "SSE" + + serverIdent := g.QualifiedGoIdent(khttpPackage.Ident("Server")) + optionIdent := g.QualifiedGoIdent(kratosPackage.Ident("HTTPStreamOption")) + registerIdent := g.QualifiedGoIdent(kratosPackage.Ident("RegisterHTTPStream")) + + g.P("const ", operationConst, ` = "/`, item.method.Desc.Parent().FullName(), `/`, item.method.Desc.Name(), `"`) + g.P() + g.P("func _", serviceName, "_", methodName, "_SSE_Register(s *", serverIdent, ", srv ", serviceName, "SSEServer, opts ...", optionIdent, ") {") + for _, route := range item.routes { + g.P(registerIdent, "(s, ", strconv.Quote(route.verb), ", ", strconv.Quote(route.path), ", ", operationConst, ", srv.", methodName, ", opts...)") + } + g.P("}") + g.P() + + callOptionIdent := g.QualifiedGoIdent(kratosPackage.Ident("HTTPStreamCallOption")) + readerIdent := g.QualifiedGoIdent(ssePackage.Ident("Reader")) + bindingIdent := g.QualifiedGoIdent(protogen.GoImportPath("github.com/go-kratos/kratos/v2/transport/http/binding").Ident("EncodeURL")) + clientImpl := unexport(serviceName) + "SSEClient" + // Match protoc-gen-go-http: generated clients call the primary + // google.api.http binding; additional_bindings are server aliases. + route := item.routes[0] + bodyArg := "nil" + needQuery := "true" + if route.body != "" { + bodyArg = bodyExpr(route.body) + needQuery = "false" + } + g.P("func (c *", clientImpl, ") ", methodName, "(ctx ", g.QualifiedGoIdent(contextPackage.Ident("Context")), ", in *", item.method.Input.GoIdent, ", opts ...", callOptionIdent, ") (*", readerIdent, ", error) {") + g.P("pattern := ", strconv.Quote(route.path)) + g.P("path := ", bindingIdent, "(pattern, in, ", needQuery, ")") + g.P("return c.cc.Open(ctx, ", strconv.Quote(route.verb), ", path, ", bodyArg, ", opts...)") + g.P("}") + g.P() +} + +func unexport(s string) string { + if s == "" { + return "" + } + return strings.ToLower(s[:1]) + s[1:] +} + +func bodyExpr(body string) string { + if body == "*" { + return "in" + } + return "in." + camelCaseVars(body) +} + +func camelCaseVars(s string) string { + subs := strings.Split(s, ".") + vars := make([]string, 0, len(subs)) + for _, sub := range subs { + vars = append(vars, camelCase(sub)) + } + return strings.Join(vars, ".") +} + +func camelCase(s string) string { + if s == "" { + return "" + } + t := make([]byte, 0, 32) + i := 0 + if s[0] == '_' { + t = append(t, 'X') + i++ + } + for ; i < len(s); i++ { + c := s[i] + if c == '_' && i+1 < len(s) && isASCIILower(s[i+1]) { + continue + } + if isASCIIDigit(c) { + t = append(t, c) + continue + } + if isASCIILower(c) { + c ^= ' ' + } + t = append(t, c) + for i+1 < len(s) && isASCIILower(s[i+1]) { + i++ + t = append(t, s[i]) + } + } + return string(t) +} + +func isASCIILower(c byte) bool { + return 'a' <= c && c <= 'z' +} + +func isASCIIDigit(c byte) bool { + return '0' <= c && c <= '9' +} + +func protocVersion(plugin *protogen.Plugin) string { + v := plugin.Request.GetCompilerVersion() + if v == nil { + return "(unknown)" + } + var suffix string + if v.GetSuffix() != "" { + suffix = "-" + v.GetSuffix() + } + return strings.TrimPrefix(fmt.Sprintf("v%d.%d.%d%s", v.GetMajor(), v.GetMinor(), v.GetPatch(), suffix), "v0.0.0") +} diff --git a/cmd/protoc-gen-go-sse/main_test.go b/cmd/protoc-gen-go-sse/main_test.go new file mode 100644 index 0000000..1f5dc47 --- /dev/null +++ b/cmd/protoc-gen-go-sse/main_test.go @@ -0,0 +1,110 @@ +package main + +import ( + "strings" + "testing" + + ssev1 "github.com/crypto-zero/go-kit/proto/kit/sse/v1" + "google.golang.org/genproto/googleapis/api/annotations" + "google.golang.org/protobuf/compiler/protogen" + "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/types/descriptorpb" + "google.golang.org/protobuf/types/pluginpb" +) + +func TestGenerateFileUsesDefaultHTTPStreamBinding(t *testing.T) { + req := testCodeGeneratorRequest(t) + plugin, err := protogen.Options{}.New(req) + if err != nil { + t.Fatalf("protogen.New: %v", err) + } + + for _, file := range plugin.Files { + if file.Generate { + generateFile(plugin, file) + } + } + + resp := plugin.Response() + if len(resp.File) != 1 { + t.Fatalf("generated files = %d, want 1", len(resp.File)) + } + got := resp.File[0].GetContent() + if strings.Contains(got, "RegisterHTTPStreamBound") { + t.Fatalf("generated code uses bound registration:\n%s", got) + } + if count := strings.Count(got, "const OperationLiveServiceWatchSSE"); count != 1 { + t.Fatalf("operation const count = %d, want 1:\n%s", count, got) + } + if count := strings.Count(got, "func _LiveService_Watch_SSE_Register"); count != 1 { + t.Fatalf("register function count = %d, want 1:\n%s", count, got) + } + for _, want := range []string{ + `RegisterHTTPStream(s, "POST", "/v1/watch", OperationLiveServiceWatchSSE, srv.Watch, opts...)`, + `RegisterHTTPStream(s, "GET", "/v1/watch:tail", OperationLiveServiceWatchSSE, srv.Watch, opts...)`, + `type LiveServiceSSEClient interface`, + `func NewLiveServiceSSEClient(client *sse1.HTTPClient) LiveServiceSSEClient`, + `func (c *liveServiceSSEClient) Watch(ctx context.Context, in *WatchRequest, opts ...sse1.HTTPStreamCallOption) (*sse.Reader, error)`, + `path := binding.EncodeURL(pattern, in, false)`, + `return c.cc.Open(ctx, "POST", path, in, opts...)`, + } { + if !strings.Contains(got, want) { + t.Fatalf("generated code missing %q:\n%s", want, got) + } + } +} + +func TestBodyExpr(t *testing.T) { + for _, tc := range []struct { + body string + want string + }{ + {body: "*", want: "in"}, + {body: "payload", want: "in.Payload"}, + {body: "filter_box.zoom_level", want: "in.FilterBox.ZoomLevel"}, + } { + if got := bodyExpr(tc.body); got != tc.want { + t.Fatalf("bodyExpr(%q) = %q, want %q", tc.body, got, tc.want) + } + } +} + +func testCodeGeneratorRequest(t *testing.T) *pluginpb.CodeGeneratorRequest { + t.Helper() + + opts := &descriptorpb.MethodOptions{} + proto.SetExtension(opts, ssev1.E_ServerSentEvent, true) + proto.SetExtension(opts, annotations.E_Http, &annotations.HttpRule{ + Pattern: &annotations.HttpRule_Post{Post: "/v1/watch"}, + Body: "*", + AdditionalBindings: []*annotations.HttpRule{{ + Pattern: &annotations.HttpRule_Get{Get: "/v1/watch:tail"}, + }}, + }) + + return &pluginpb.CodeGeneratorRequest{ + FileToGenerate: []string{"test/v1/live.proto"}, + ProtoFile: []*descriptorpb.FileDescriptorProto{{ + Syntax: new("proto3"), + Name: new("test/v1/live.proto"), + Package: new("test.v1"), + Options: &descriptorpb.FileOptions{ + GoPackage: new("example.com/test/v1;testv1"), + }, + MessageType: []*descriptorpb.DescriptorProto{{ + Name: new("WatchRequest"), + }, { + Name: new("WatchResponse"), + }}, + Service: []*descriptorpb.ServiceDescriptorProto{{ + Name: new("LiveService"), + Method: []*descriptorpb.MethodDescriptorProto{{ + Name: new("Watch"), + InputType: new(".test.v1.WatchRequest"), + OutputType: new(".test.v1.WatchResponse"), + Options: opts, + }}, + }}, + }}, + } +} diff --git a/errors/protoc-gen-kit-errors/main.go b/cmd/protoc-gen-kit-errors/main.go similarity index 96% rename from errors/protoc-gen-kit-errors/main.go rename to cmd/protoc-gen-kit-errors/main.go index 3415b87..5d14584 100644 --- a/errors/protoc-gen-kit-errors/main.go +++ b/cmd/protoc-gen-kit-errors/main.go @@ -107,9 +107,7 @@ func (g *GenerateErrorDeclare) generateFile( sinkVarName := varName lowSinkName, lowParentName := strings.ToLower(sinkVarName), strings.ToLower(parentDescName) - if strings.HasPrefix(lowSinkName, lowParentName) { - lowSinkName = strings.TrimPrefix(lowSinkName, lowParentName) - } + lowSinkName = strings.TrimPrefix(lowSinkName, lowParentName) sinkVarName = "Err" + strcase.ToCamel(lowSinkName) gf.P(strings.TrimSpace(ev.Comments.Leading.String())) diff --git a/errors/client.go b/errors/client.go index f2b30c5..c70678a 100644 --- a/errors/client.go +++ b/errors/client.go @@ -22,10 +22,10 @@ func HttpServerErrorEncoder( _, _ = w.Write(body) } -// RPCHandler is a rpc handler for grpc/http client. +// RPCHandler is an RPC handler for gRPC/HTTP clients. type RPCHandler func(ctx context.Context, req any) (any, error) -// RPCClientErrorParser is a rpc client error parser. +// RPCClientErrorParser converts client errors into this package's Error type. func RPCClientErrorParser(handler RPCHandler) RPCHandler { return func(ctx context.Context, req any) (any, error) { reply, err := handler(ctx, req) diff --git a/errors/errors.go b/errors/errors.go index 84e5701..d51bda7 100644 --- a/errors/errors.go +++ b/errors/errors.go @@ -3,6 +3,7 @@ package errors import ( "errors" "fmt" + "maps" "google.golang.org/genproto/googleapis/rpc/errdetails" spb "google.golang.org/genproto/googleapis/rpc/status" @@ -17,9 +18,10 @@ import ( type PBError = pberrors.Error +// Error is the concrete error type used by this package. type Error PBError -// Error return text message and http status code. +// Error returns a text representation of e. func (e *Error) Error() string { return fmt.Sprintf("error: code = %d reason = %s message = %s", e.Status, e.Info.Reason, e.Message) } @@ -32,7 +34,7 @@ func (e *Error) Is(err error) bool { return false } -// GRPCStatus returns the Status represented by se. +// GRPCStatus returns the gRPC status represented by e. func (e *Error) GRPCStatus() *status.Status { s := &spb.Status{Code: int32(ToGRPCCode(int(e.Status))), Message: e.Message} if codes.Code(s.Code) == codes.OK { @@ -47,13 +49,13 @@ func (e *Error) GRPCStatus() *status.Status { return status.FromProto(s) } -// MarshalJSON marshals se to JSON. +// MarshalJSON marshals e to JSON. func (e *Error) MarshalJSON() ([]byte, error) { pbErr := (*PBError)(e) return protojson.Marshal(pbErr) } -// Clone returns a deep copy of se. +// Clone returns a deep copy of e. func (e *Error) Clone() *Error { if e == nil { return nil @@ -63,14 +65,17 @@ func (e *Error) Clone() *Error { return (*Error)(pbErr) } -// SetMetadata set metadata for error info. +// SetMetadata returns a copy of e with an ErrorInfo metadata value set. func (e *Error) SetMetadata(key, value string) *Error { copied := e.Clone() + if copied.Info.Metadata == nil { + copied.Info.Metadata = make(map[string]string) + } copied.Info.Metadata[key] = value return copied } -// SetCause set cause for error info. +// SetCause returns a copy of e with err recorded as ErrorInfo metadata. func (e *Error) SetCause(err error) *Error { if err == nil { return e @@ -78,7 +83,7 @@ func (e *Error) SetCause(err error) *Error { return e.SetMetadata("cause", err.Error()) } -// SetDomainAndCode set domain and code for info without clone. +// SetDomainAndCode sets the ErrorInfo domain and numeric code in place. func (e *Error) SetDomainAndCode(domain string, code int) *Error { if e.Info.Metadata == nil { e.Info.Metadata = make(map[string]string) @@ -95,7 +100,7 @@ const ( UnknownReason = "" ) -// New returns an error object for the code, message. +// New returns an error object for code, reason, and message. func New(code int, reason, message string) *Error { return &Error{ Status: int32(code), @@ -104,17 +109,17 @@ func New(code int, reason, message string) *Error { } } -// Newf New(code fmt.Sprintf(format, a...)) +// Newf returns an error object with a formatted message. func Newf(code int, reason, format string, a ...any) *Error { return New(code, reason, fmt.Sprintf(format, a...)) } -// Errorf returns an error object for the code, message and error info. +// Errorf returns an error object for code, reason, and a formatted message. func Errorf(code int, reason, format string, a ...any) error { return New(code, reason, fmt.Sprintf(format, a...)) } -// Code returns the http code for an error. +// Code returns the HTTP status code for an error. // It supports wrapped errors. func Code(err error) int { if err == nil { @@ -160,9 +165,7 @@ func FromError(err error) *Error { if info, ok := first.(*errdetails.ErrorInfo); ok { ret.Info.Reason, ret.Info.Domain = info.Reason, info.Domain ret.Info.Metadata = make(map[string]string, len(info.Metadata)) - for k, v := range info.Metadata { - ret.Info.Metadata[k] = v - } + maps.Copy(ret.Info.Metadata, info.Metadata) ret.Details = ret.Details[1:] } } diff --git a/errors/errors_test.go b/errors/errors_test.go index f80d746..dd18086 100644 --- a/errors/errors_test.go +++ b/errors/errors_test.go @@ -1,6 +1,7 @@ package errors import ( + stderrors "errors" "reflect" "testing" @@ -54,3 +55,26 @@ func TestError_Clone(t *testing.T) { }) } } + +func TestError_SetMetadataInitializesMap(t *testing.T) { + err := New(400, "bad_request", "bad request") + + got := err.SetMetadata("key", "value") + + if got.Info.Metadata["key"] != "value" { + t.Fatalf("SetMetadata() metadata = %v; want key=value", got.Info.Metadata) + } + if err.Info.Metadata != nil { + t.Fatalf("SetMetadata() mutated original metadata = %v; want nil", err.Info.Metadata) + } +} + +func TestError_SetCauseInitializesMap(t *testing.T) { + err := New(400, "bad_request", "bad request") + + got := err.SetCause(stderrors.New("root cause")) + + if got.Info.Metadata["cause"] != "root cause" { + t.Fatalf("SetCause() metadata = %v; want cause=root cause", got.Info.Metadata) + } +} diff --git a/errors/status.go b/errors/status.go index a9d91c7..e43bd4c 100644 --- a/errors/status.go +++ b/errors/status.go @@ -7,8 +7,8 @@ import ( ) const ( - // HttpCodeClientClosed is non-standard http status code, - // which defined by nginx. + // HttpCodeClientClosed is the non-standard HTTP status code used by nginx + // when a client closes the request. // https://httpstatus.in/499/ HttpCodeClientClosed = 499 ) @@ -24,7 +24,7 @@ type Converter interface { type statusConverter struct{} -// DefaultConverter default converter. +// DefaultConverter is the default status converter. var DefaultConverter Converter = statusConverter{} // ToGRPCCode converts an HTTP error code into the corresponding gRPC response status. diff --git a/go.mod b/go.mod index ad74865..6ce7545 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module github.com/crypto-zero/go-kit -go 1.25.5 +go 1.26.3 require ( entgo.io/ent v0.14.5 @@ -19,6 +19,7 @@ require ( go.uber.org/zap v1.27.1 go.uber.org/zap/exp v0.3.0 golang.org/x/text v0.32.0 + google.golang.org/genproto/googleapis/api v0.0.0-20251202230838-ff82c1b0f217 google.golang.org/genproto/googleapis/rpc v0.0.0-20251213004720-97cd9d5aeac2 google.golang.org/grpc v1.77.0 google.golang.org/protobuf v1.36.11 @@ -71,6 +72,5 @@ require ( golang.org/x/net v0.47.0 // indirect golang.org/x/sync v0.19.0 // indirect golang.org/x/sys v0.39.0 // indirect - google.golang.org/genproto/googleapis/api v0.0.0-20251202230838-ff82c1b0f217 // indirect gopkg.in/yaml.v2 v2.4.0 // indirect ) diff --git a/go.work b/go.work new file mode 100644 index 0000000..7e30e87 --- /dev/null +++ b/go.work @@ -0,0 +1,8 @@ +go 1.26.3 + +use ( + . + ./kratos + ./kubernetes/election + ./s3 +) diff --git a/kent/encrypt_test.go b/kent/encrypt_test.go index cd8e0d7..8323922 100644 --- a/kent/encrypt_test.go +++ b/kent/encrypt_test.go @@ -216,7 +216,7 @@ func TestEncryptDecrypt_MultipleRoundTrips(t *testing.T) { plaintext := "test message" // Perform multiple encrypt-decrypt cycles - for i := 0; i < 100; i++ { + for i := range 100 { ciphertext, err := encryptor.Encrypt(plaintext) if err != nil { t.Fatalf("Encrypt() iteration %d error = %v", i, err) @@ -302,7 +302,7 @@ func TestEncryptDecrypt_Concurrent(t *testing.T) { done := make(chan bool, iterations) // Concurrent encryption - for i := 0; i < iterations; i++ { + for i := range iterations { go func(id int) { ciphertext, err := encryptor.Encrypt(plaintext) if err != nil { @@ -330,7 +330,7 @@ func TestEncryptDecrypt_Concurrent(t *testing.T) { // Wait for all goroutines successCount := 0 - for i := 0; i < iterations; i++ { + for range iterations { if <-done { successCount++ } diff --git a/kent/ordering.go b/kent/ordering.go index 4ae453d..db0eef8 100644 --- a/kent/ordering.go +++ b/kent/ordering.go @@ -19,8 +19,7 @@ func ProcessOrdering(orderBy string, fieldMap map[string]string, defaultOrdering return defaultOrdering } - orderByTerms := strings.Split(orderBy, ",") - for _, term := range orderByTerms { + for term := range strings.SplitSeq(orderBy, ",") { if term == "" { continue } diff --git a/kratos/auth/policy.go b/kratos/auth/policy.go new file mode 100644 index 0000000..56087db --- /dev/null +++ b/kratos/auth/policy.go @@ -0,0 +1,69 @@ +// Package auth provides auth helpers for Kratos operation selectors. +package auth + +import ( + "github.com/crypto-zero/go-kit/kratos/internal/protoop" + authv1 "github.com/crypto-zero/go-kit/proto/kit/auth/v1" + "google.golang.org/protobuf/reflect/protoreflect" +) + +// OperationPolicy reports whether a Kratos operation should run through +// authentication middleware. +type OperationPolicy struct { + public map[string]struct{} +} + +// OperationPolicyOption configures an OperationPolicy. +type OperationPolicyOption func(*OperationPolicy) + +// NewOperationPolicy constructs an auth policy from proto descriptors and +// optional manually registered operations. +func NewOperationPolicy(opts ...OperationPolicyOption) *OperationPolicy { + p := &OperationPolicy{public: make(map[string]struct{})} + for _, opt := range opts { + opt(p) + } + return p +} + +// WithPublicOperations marks explicit Kratos operations as public. +func WithPublicOperations(ops ...string) OperationPolicyOption { + return func(p *OperationPolicy) { + for _, op := range ops { + p.public[op] = struct{}{} + } + } +} + +// WithPublicFromProtoFiles scans file descriptors for methods tagged with +// `(kit.auth.v1.public) = true`. +func WithPublicFromProtoFiles(files ...protoreflect.FileDescriptor) OperationPolicyOption { + return func(p *OperationPolicy) { + for _, fd := range files { + registerPublicFromFile(p, fd) + } + } +} + +// RequiresAuth reports whether operation should run through authentication. +func (p *OperationPolicy) RequiresAuth(operation string) bool { + _, ok := p.public[operation] + return !ok +} + +// OperationName returns the Kratos operation string for a proto method. +func OperationName(m protoreflect.MethodDescriptor) string { + return protoop.OperationName(m) +} + +func registerPublicFromFile(p *OperationPolicy, fd protoreflect.FileDescriptor) { + protoop.WalkMethods([]protoreflect.FileDescriptor{fd}, func(m protoreflect.MethodDescriptor) { + if methodIsPublic(m) { + p.public[protoop.OperationName(m)] = struct{}{} + } + }) +} + +func methodIsPublic(m protoreflect.MethodDescriptor) bool { + return protoop.BoolExtension(m, authv1.E_Public) +} diff --git a/kratos/auth/policy_test.go b/kratos/auth/policy_test.go new file mode 100644 index 0000000..df2e165 --- /dev/null +++ b/kratos/auth/policy_test.go @@ -0,0 +1,65 @@ +package auth + +import ( + "testing" + + authv1 "github.com/crypto-zero/go-kit/proto/kit/auth/v1" + "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/reflect/protodesc" + "google.golang.org/protobuf/types/descriptorpb" +) + +func TestOperationPolicyRegistersPublicProtoMethods(t *testing.T) { + publicOpts := &descriptorpb.MethodOptions{} + proto.SetExtension(publicOpts, authv1.E_Public, true) + fd, err := protodesc.NewFile(&descriptorpb.FileDescriptorProto{ + Syntax: new("proto3"), + Name: new("test/auth/v1/service.proto"), + Package: new("test.auth.v1"), + Service: []*descriptorpb.ServiceDescriptorProto{{ + Name: new("AuthService"), + Method: []*descriptorpb.MethodDescriptorProto{ + { + Name: new("Login"), + InputType: new(".test.auth.v1.LoginRequest"), + OutputType: new(".test.auth.v1.LoginResponse"), + Options: publicOpts, + }, + { + Name: new("Profile"), + InputType: new(".test.auth.v1.ProfileRequest"), + OutputType: new(".test.auth.v1.ProfileResponse"), + }, + }, + }}, + MessageType: []*descriptorpb.DescriptorProto{ + {Name: new("LoginRequest")}, + {Name: new("LoginResponse")}, + {Name: new("ProfileRequest")}, + {Name: new("ProfileResponse")}, + }, + }, nil) + if err != nil { + t.Fatalf("NewFile: %v", err) + } + + policy := NewOperationPolicy(WithPublicFromProtoFiles(fd)) + + if policy.RequiresAuth("/test.auth.v1.AuthService/Login") { + t.Fatal("Login should be public from proto option") + } + if !policy.RequiresAuth("/test.auth.v1.AuthService/Profile") { + t.Fatal("Profile should require auth when untagged") + } +} + +func TestOperationPolicyManualPublicOperations(t *testing.T) { + policy := NewOperationPolicy(WithPublicOperations("/healthz", "/readyz")) + + if policy.RequiresAuth("/healthz") { + t.Fatal("/healthz should be public") + } + if !policy.RequiresAuth("/v1/private") { + t.Fatal("/v1/private should require auth") + } +} diff --git a/kratos/clientip/clientip.go b/kratos/clientip/clientip.go new file mode 100644 index 0000000..2cab7e3 --- /dev/null +++ b/kratos/clientip/clientip.go @@ -0,0 +1,97 @@ +// Package clientip extracts client IP addresses from Kratos request contexts. +package clientip + +import ( + "context" + "net" + "strings" + + "github.com/go-kratos/kratos/v2/transport" + kratoshttp "github.com/go-kratos/kratos/v2/transport/http" + "google.golang.org/grpc/metadata" + "google.golang.org/grpc/peer" +) + +// FromContext extracts the client IP address from a Kratos server context. +// +// Priority: X-Forwarded-For, X-Real-IP, then remote peer address. Returned +// values are normalized: IPv6 zone IDs are removed and IPv4-mapped IPv6 +// addresses are converted to IPv4. +func FromContext(ctx context.Context) string { + tr, ok := transport.FromServerContext(ctx) + if !ok { + return "" + } + if httpTr, ok := tr.(*kratoshttp.Transport); ok { + return fromHTTP(httpTr) + } + return fromGRPC(ctx) +} + +func fromHTTP(httpTr *kratoshttp.Transport) string { + req := httpTr.Request() + if req == nil { + return "" + } + if ip := extract(req.Header.Get("X-Forwarded-For")); ip != "" { + return ip + } + if ip := extract(req.Header.Get("X-Real-IP")); ip != "" { + return ip + } + host, _, err := net.SplitHostPort(req.RemoteAddr) + if err != nil { + return normalize(req.RemoteAddr) + } + return normalize(host) +} + +func fromGRPC(ctx context.Context) string { + if md, ok := metadata.FromIncomingContext(ctx); ok { + if xff := md.Get("x-forwarded-for"); len(xff) > 0 { + if ip := extract(xff[0]); ip != "" { + return ip + } + } + if xrip := md.Get("x-real-ip"); len(xrip) > 0 { + if ip := extract(xrip[0]); ip != "" { + return ip + } + } + } + if p, ok := peer.FromContext(ctx); ok && p.Addr != nil { + host, _, err := net.SplitHostPort(p.Addr.String()) + if err != nil { + return normalize(p.Addr.String()) + } + return normalize(host) + } + return "" +} + +func extract(headerVal string) string { + if headerVal == "" { + return "" + } + if idx := strings.IndexByte(headerVal, ','); idx != -1 { + headerVal = headerVal[:idx] + } + return normalize(strings.TrimSpace(headerVal)) +} + +func normalize(ip string) string { + if ip == "" { + return "" + } + if idx := strings.IndexByte(ip, '%'); idx != -1 { + ip = ip[:idx] + } + parsed := net.ParseIP(ip) + if parsed == nil { + return "" + } + if ipv4 := parsed.To4(); ipv4 != nil { + return ipv4.String() + } + return parsed.String() +} diff --git a/kratos/clientip/clientip_test.go b/kratos/clientip/clientip_test.go new file mode 100644 index 0000000..e87d557 --- /dev/null +++ b/kratos/clientip/clientip_test.go @@ -0,0 +1,66 @@ +package clientip + +import ( + "context" + "net" + "testing" + + "github.com/go-kratos/kratos/v2/transport" + "google.golang.org/grpc/metadata" + "google.golang.org/grpc/peer" +) + +type mockTransport struct{} + +func (m *mockTransport) Kind() transport.Kind { return transport.KindGRPC } +func (m *mockTransport) Endpoint() string { return "localhost:9000" } +func (m *mockTransport) Operation() string { return "/test.Service/Method" } +func (m *mockTransport) RequestHeader() transport.Header { + return &mockHeader{} +} +func (m *mockTransport) ReplyHeader() transport.Header { + return &mockHeader{} +} + +type mockHeader struct{} + +func (m *mockHeader) Get(string) string { return "" } +func (m *mockHeader) Set(string, string) {} +func (m *mockHeader) Add(string, string) {} +func (m *mockHeader) Keys() []string { return nil } +func (m *mockHeader) Values(string) []string { return nil } + +func TestFromContextUsesForwardedHeader(t *testing.T) { + ctx := transport.NewServerContext(context.Background(), &mockTransport{}) + ctx = metadata.NewIncomingContext(ctx, metadata.Pairs( + "x-forwarded-for", " ::ffff:192.168.1.10, 10.0.0.1", + "x-real-ip", "192.168.1.11", + )) + + if got := FromContext(ctx); got != "192.168.1.10" { + t.Fatalf("FromContext() = %q, want forwarded IP", got) + } +} + +func TestFromContextFallsBackToPeer(t *testing.T) { + ctx := transport.NewServerContext(context.Background(), &mockTransport{}) + ctx = peer.NewContext(ctx, &peer.Peer{Addr: &net.TCPAddr{IP: net.ParseIP("2001:db8::1"), Port: 443}}) + + if got := FromContext(ctx); got != "2001:db8::1" { + t.Fatalf("FromContext() = %q, want peer IP", got) + } +} + +func TestFromContextRequiresServerTransport(t *testing.T) { + ctx := metadata.NewIncomingContext(context.Background(), metadata.Pairs("x-real-ip", "192.168.1.1")) + + if got := FromContext(ctx); got != "" { + t.Fatalf("FromContext() = %q, want empty without server transport", got) + } +} + +func TestNormalizeRejectsInvalidIP(t *testing.T) { + if got := normalize("not-an-ip"); got != "" { + t.Fatalf("normalize() = %q, want empty", got) + } +} diff --git a/kratos/go.mod b/kratos/go.mod new file mode 100644 index 0000000..6c12dbd --- /dev/null +++ b/kratos/go.mod @@ -0,0 +1,28 @@ +module github.com/crypto-zero/go-kit/kratos + +go 1.26.3 + +replace github.com/crypto-zero/go-kit => .. + +require ( + github.com/crypto-zero/go-kit v0.0.0-20260128101518-0545cf5a3fac + github.com/go-kratos/kratos/v2 v2.9.2 + github.com/redis/go-redis/v9 v9.19.0 + google.golang.org/genproto/googleapis/api v0.0.0-20251202230838-ff82c1b0f217 + google.golang.org/grpc v1.78.0 + google.golang.org/protobuf v1.36.11 +) + +require ( + github.com/cespare/xxhash/v2 v2.3.0 // indirect + github.com/go-kratos/aegis v0.2.0 // indirect + github.com/go-playground/form/v4 v4.2.0 // indirect + github.com/google/uuid v1.6.0 // indirect + github.com/gorilla/mux v1.8.1 // indirect + github.com/kr/text v0.2.0 // indirect + github.com/rogpeppe/go-internal v1.14.1 // indirect + go.uber.org/atomic v1.11.0 // indirect + golang.org/x/sys v0.39.0 // indirect + google.golang.org/genproto/googleapis/rpc v0.0.0-20251213004720-97cd9d5aeac2 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) diff --git a/logging/kratos/go.sum b/kratos/go.sum similarity index 60% rename from logging/kratos/go.sum rename to kratos/go.sum index 0229ab0..e380f83 100644 --- a/logging/kratos/go.sum +++ b/kratos/go.sum @@ -1,4 +1,8 @@ +github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= +github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= +github.com/crypto-zero/go-kit v0.0.0-20260128101518-0545cf5a3fac h1:mAUwaMQHRclziH/e4xbuNEHyGrr5DBalMqFtw8wu2uw= +github.com/crypto-zero/go-kit v0.0.0-20260128101518-0545cf5a3fac/go.mod h1:2V6ihklJZoXNZzlmeX+sk5RMYXcwJwq5kIrubxbsUjo= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/go-kratos/aegis v0.2.0 h1:dObzCDWn3XVjUkgxyBp6ZeWtx/do0DPZ7LY3yNSJLUQ= @@ -23,20 +27,26 @@ github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/rogpeppe/go-internal v1.11.0 h1:cWPaGQEPrBb5/AsnsZesgZZ9yb1OQ+GOISoDNXVBh4M= -github.com/rogpeppe/go-internal v1.11.0/go.mod h1:ddIwULY96R17DhadqLgMfk9H9tvdUzkipdSkR5nkCZA= -github.com/stretchr/testify v1.8.2 h1:+h33VjcLVPDHtOdpUCuF+7gSuG3yGIftsP1YvFihtJ8= -github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= +github.com/redis/go-redis/v9 v9.19.0 h1:XPVaaPSnG6RhYf7p+rmSa9zZfeVAnWsH5h3lxthOm/k= +github.com/redis/go-redis/v9 v9.19.0/go.mod h1:v/M13XI1PVCDcm01VtPFOADfZtHf8YW3baQf57KlIkA= +github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ= +github.com/rogpeppe/go-internal v1.14.1/go.mod h1:MaRKkUm5W0goXpeCfT7UZI6fk/L7L7so1lCWt35ZSgc= +github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= +github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= +go.uber.org/atomic v1.11.0 h1:ZvwS0R+56ePWxUNi+Atn9dWONBPp/AUETXlHW0DxSjE= +go.uber.org/atomic v1.11.0/go.mod h1:LUxbIzbOniOlMKjJjyPfpl4v+PKK2cNJn91OQbhoJI0= golang.org/x/net v0.47.0 h1:Mx+4dIFzqraBXUugkia1OOvlD6LemFo1ALMHjrXDOhY= golang.org/x/net v0.47.0/go.mod h1:/jNxtkgq5yWUGYkaZGqo27cfGZ1c5Nen03aYrrKpVRU= -golang.org/x/sync v0.18.0 h1:kr88TuHDroi+UVf+0hZnirlk8o8T+4MrK6mr60WkH/I= -golang.org/x/sync v0.18.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= -golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc= -golang.org/x/sys v0.38.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= -golang.org/x/text v0.31.0 h1:aC8ghyu4JhP8VojJ2lEHBnochRno1sgL6nEi9WGFGMM= -golang.org/x/text v0.31.0/go.mod h1:tKRAlv61yKIjGGHX/4tP1LTbc13YSec1pxVEWXzfoeM= -google.golang.org/genproto/googleapis/rpc v0.0.0-20251029180050-ab9386a59fda h1:i/Q+bfisr7gq6feoJnS/DlpdwEL4ihp41fvRiM3Ork0= -google.golang.org/genproto/googleapis/rpc v0.0.0-20251029180050-ab9386a59fda/go.mod h1:7i2o+ce6H/6BluujYR+kqX3GKH+dChPTQU19wjRPiGk= +golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4= +golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= +golang.org/x/sys v0.39.0 h1:CvCKL8MeisomCi6qNZ+wbb0DN9E5AATixKsvNtMoMFk= +golang.org/x/sys v0.39.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= +golang.org/x/text v0.32.0 h1:ZD01bjUt1FQ9WJ0ClOL5vxgxOI/sVCNgX1YtKwcY0mU= +golang.org/x/text v0.32.0/go.mod h1:o/rUWzghvpD5TXrTIBuJU77MTaN0ljMWE47kxGJQ7jY= +google.golang.org/genproto/googleapis/api v0.0.0-20251202230838-ff82c1b0f217 h1:fCvbg86sFXwdrl5LgVcTEvNC+2txB5mgROGmRL5mrls= +google.golang.org/genproto/googleapis/api v0.0.0-20251202230838-ff82c1b0f217/go.mod h1:+rXWjjaukWZun3mLfjmVnQi18E1AsFbDN9QdJ5YXLto= +google.golang.org/genproto/googleapis/rpc v0.0.0-20251213004720-97cd9d5aeac2 h1:2I6GHUeJ/4shcDpoUlLs/2WPnhg7yJwvXtqcMJt9liA= +google.golang.org/genproto/googleapis/rpc v0.0.0-20251213004720-97cd9d5aeac2/go.mod h1:7i2o+ce6H/6BluujYR+kqX3GKH+dChPTQU19wjRPiGk= google.golang.org/grpc v1.78.0 h1:K1XZG/yGDJnzMdd/uZHAkVqJE+xIDOcmdSFZkBUicNc= google.golang.org/grpc v1.78.0/go.mod h1:I47qjTo4OKbMkjA/aOOwxDIiPSBofUtQUI5EfpWvW7U= google.golang.org/protobuf v1.36.11 h1:fV6ZwhNocDyBLK0dj+fg8ektcVegBBuEolpbTQyBNVE= diff --git a/kratos/internal/protoop/protoop.go b/kratos/internal/protoop/protoop.go new file mode 100644 index 0000000..636bf01 --- /dev/null +++ b/kratos/internal/protoop/protoop.go @@ -0,0 +1,79 @@ +// Package protoop contains helpers for Kratos proto operation descriptors. +package protoop + +import ( + "google.golang.org/protobuf/encoding/protowire" + "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/reflect/protoreflect" + "google.golang.org/protobuf/types/descriptorpb" +) + +// WalkMethods calls fn for every method in files. +func WalkMethods(files []protoreflect.FileDescriptor, fn func(protoreflect.MethodDescriptor)) { + for _, fd := range files { + services := fd.Services() + for i := range services.Len() { + methods := services.Get(i).Methods() + for j := range methods.Len() { + fn(methods.Get(j)) + } + } + } +} + +// OperationName returns the Kratos operation string for a proto method. +func OperationName(m protoreflect.MethodDescriptor) string { + return "/" + string(m.Parent().FullName()) + "/" + string(m.Name()) +} + +// Extension returns the extension value from a proto method option. +func Extension(m protoreflect.MethodDescriptor, ext protoreflect.ExtensionType) (any, bool) { + opts, ok := m.Options().(*descriptorpb.MethodOptions) + if !ok || opts == nil { + return nil, false + } + return proto.GetExtension(opts, ext), true +} + +// BoolExtension reports whether a boolean extension is set to true. +func BoolExtension(m protoreflect.MethodDescriptor, ext protoreflect.ExtensionType) bool { + v, ok := Extension(m, ext) + if !ok { + return false + } + opts := m.Options().(*descriptorpb.MethodOptions) + number := ext.TypeDescriptor().Number() + switch value := v.(type) { + case bool: + return value || unknownBool(opts, number) + case *bool: + return (value != nil && *value) || unknownBool(opts, number) + default: + return unknownBool(opts, number) + } +} + +func unknownBool(opts *descriptorpb.MethodOptions, number protoreflect.FieldNumber) bool { + raw := opts.ProtoReflect().GetUnknown() + for len(raw) > 0 { + num, typ, n := protowire.ConsumeTag(raw) + if n < 0 { + return false + } + raw = raw[n:] + if num != protowire.Number(number) { + n = protowire.ConsumeFieldValue(num, typ, raw) + if n < 0 { + return false + } + raw = raw[n:] + continue + } + if typ != protowire.VarintType { + return false + } + v, n := protowire.ConsumeVarint(raw) + return n >= 0 && v != 0 + } + return false +} diff --git a/logging/kratos/logging.go b/kratos/logging/logging.go similarity index 69% rename from logging/kratos/logging.go rename to kratos/logging/logging.go index ac60472..192f1b1 100644 --- a/logging/kratos/logging.go +++ b/kratos/logging/logging.go @@ -5,17 +5,16 @@ import ( "encoding/json" "fmt" "log/slog" - "net" "reflect" "strings" "time" "google.golang.org/grpc/codes" "google.golang.org/grpc/metadata" - "google.golang.org/grpc/peer" "google.golang.org/protobuf/encoding/protojson" "google.golang.org/protobuf/proto" + "github.com/crypto-zero/go-kit/kratos/clientip" "github.com/go-kratos/kratos/v2/errors" "github.com/go-kratos/kratos/v2/middleware" "github.com/go-kratos/kratos/v2/transport" @@ -204,113 +203,7 @@ func extractError(err error) (slog.Level, string) { // Supports both HTTP and gRPC transports. // Returns normalized IP address (IPv4-mapped IPv6 converted to IPv4, Zone ID removed). func GetClientIP(ctx context.Context) string { - tr, ok := transport.FromServerContext(ctx) - if !ok { - return "" - } - - // Handle HTTP transport - if httpTr, ok := tr.(*kratoshttp.Transport); ok { - return getClientIPFromHTTP(httpTr) - } - - // Handle gRPC transport - return getClientIPFromGRPC(ctx) -} - -// getClientIPFromHTTP extracts client IP from HTTP request. -func getClientIPFromHTTP(httpTr *kratoshttp.Transport) string { - req := httpTr.Request() - if req == nil { - return "" - } - - if ip := extractIP(req.Header.Get("X-Forwarded-For")); ip != "" { - return ip - } - - if ip := extractIP(req.Header.Get("X-Real-IP")); ip != "" { - return ip - } - - host, _, err := net.SplitHostPort(req.RemoteAddr) - if err != nil { - return normalizeIP(req.RemoteAddr) - } - return normalizeIP(host) -} - -// getClientIPFromGRPC extracts client IP from gRPC context. -func getClientIPFromGRPC(ctx context.Context) string { - // Try to get forwarded headers from gRPC metadata first - if md, ok := metadata.FromIncomingContext(ctx); ok { - // Check X-Forwarded-For - if xff := md.Get("x-forwarded-for"); len(xff) > 0 { - if ip := extractIP(xff[0]); ip != "" { - return ip - } - } - // Check X-Real-IP - if xrip := md.Get("x-real-ip"); len(xrip) > 0 { - if ip := extractIP(xrip[0]); ip != "" { - return ip - } - } - } - - // Fall back to peer address - if p, ok := peer.FromContext(ctx); ok && p.Addr != nil { - host, _, err := net.SplitHostPort(p.Addr.String()) - if err != nil { - return normalizeIP(p.Addr.String()) - } - return normalizeIP(host) - } - - return "" -} - -// extractIP extracts the first valid IP from a header value (e.g., X-Forwarded-For). -// It handles comma-separated values (taking the first one) and applies normalization. -func extractIP(headerVal string) string { - if headerVal == "" { - return "" - } - // X-Forwarded-For can contain multiple IPs, the first one is the client IP. - // Use IndexByte instead of Split to avoid memory allocation. - if idx := strings.IndexByte(headerVal, ','); idx != -1 { - headerVal = headerVal[:idx] - } - return normalizeIP(strings.TrimSpace(headerVal)) -} - -// normalizeIP validates and normalizes an IP address string. -// It performs the following normalizations: -// - Removes IPv6 zone ID (e.g., "fe80::1%eth0" -> "fe80::1") -// - Converts IPv4-mapped IPv6 to IPv4 (e.g., "::ffff:192.168.1.1" -> "192.168.1.1") -// -// Returns empty string if the input is not a valid IP address. -func normalizeIP(ip string) string { - if ip == "" { - return "" - } - - // Remove IPv6 zone ID (e.g., fe80::1%eth0 -> fe80::1) - if idx := strings.IndexByte(ip, '%'); idx != -1 { - ip = ip[:idx] - } - - parsed := net.ParseIP(ip) - if parsed == nil { - return "" - } - - // Convert IPv4-mapped IPv6 to IPv4 (e.g., ::ffff:192.168.1.1 -> 192.168.1.1) - if ipv4 := parsed.To4(); ipv4 != nil { - return ipv4.String() - } - - return parsed.String() + return clientip.FromContext(ctx) } // getClientDevice extracts the client device info (User-Agent or custom header) from the request context. diff --git a/logging/kratos/logging_test.go b/kratos/logging/logging_test.go similarity index 96% rename from logging/kratos/logging_test.go rename to kratos/logging/logging_test.go index 9adf027..6926c63 100644 --- a/logging/kratos/logging_test.go +++ b/kratos/logging/logging_test.go @@ -158,12 +158,12 @@ func TestExtractArgs_PlainStruct(t *testing.T) { func TestExtractArgs_Nil(t *testing.T) { result := extractArgs(nil, false) - str, ok := result.(string) + raw, ok := result.(json.RawMessage) if !ok { - t.Fatalf("Expected string for nil, got %T", result) + t.Fatalf("Expected json.RawMessage for nil, got %T", result) } - if str != "" { - t.Errorf("Expected '', got: %s", str) + if string(raw) != "{}" { + t.Errorf("Expected '{}', got: %s", raw) } } @@ -240,8 +240,8 @@ func TestServer_Success(t *testing.T) { // Check log output logOutput := buf.String() - if !strings.Contains(logOutput, "server request") { - t.Errorf("Expected 'server request' in log, got: %s", logOutput) + if !strings.Contains(logOutput, `"msg":"server"`) { + t.Errorf("Expected server log message, got: %s", logOutput) } if !strings.Contains(logOutput, "/api/v1/test") { t.Errorf("Expected operation in log, got: %s", logOutput) @@ -314,8 +314,8 @@ func TestServer_WithoutTransport(t *testing.T) { } logOutput := buf.String() - if !strings.Contains(logOutput, "server request") { - t.Errorf("Expected 'server request' in log, got: %s", logOutput) + if !strings.Contains(logOutput, `"msg":"server"`) { + t.Errorf("Expected server log message, got: %s", logOutput) } } @@ -350,8 +350,8 @@ func TestClient_Success(t *testing.T) { } logOutput := buf.String() - if !strings.Contains(logOutput, "client request") { - t.Errorf("Expected 'client request' in log, got: %s", logOutput) + if !strings.Contains(logOutput, `"msg":"client"`) { + t.Errorf("Expected client log message, got: %s", logOutput) } if !strings.Contains(logOutput, "/api/external/call") { t.Errorf("Expected operation in log, got: %s", logOutput) diff --git a/kratos/ratelimit/example/README.md b/kratos/ratelimit/example/README.md new file mode 100644 index 0000000..8d78eaf --- /dev/null +++ b/kratos/ratelimit/example/README.md @@ -0,0 +1,101 @@ +# Kratos Rate Limit Example + +This example shows operation-level rate limiting driven only by external config. + +Kratos already provides the current operation name at request time. The +middleware uses that operation as the namespace, then applies the configured +business key parts: + +- `CreateOrder`: limit by `user_id`, and separately by `client_ip` +- `GetOrder`: limit by `user_id` + +In a real Kratos service, define this config shape in that service's own +`internal/conf/conf.proto`, then load it from YAML and convert it to +`ratelimit.OperationRules`. The config proto should live with the service +because operations and business dimensions are service-owned. + +```proto +message Bootstrap { + Server server = 1; + Data data = 2; + RateLimit rate_limit = 3; +} + +message RateLimit { + map operations = 1; + + message Operation { + repeated Rule rules = 1; + } + + message Rule { + repeated string key_parts = 1; + int32 rate = 2; + google.protobuf.Duration per = 3; + int32 burst = 4; + } +} +``` + +```yaml +rate_limit: + operations: + /ratelimit.example.order.v1.OrderService/CreateOrder: + rules: + - key_parts: [user_id] + rate: 10 + per: 60s + burst: 10 + - key_parts: [client_ip] + rate: 30 + per: 60s + burst: 30 + /ratelimit.example.order.v1.OrderService/GetOrder: + rules: + - key_parts: [user_id] + rate: 100 + per: 60s + burst: 100 +``` + +Convert the generated service config before wiring the middleware: + +```go +operationRules, err := convertRateLimitConfig(conf.RateLimit) +if err != nil { + return err +} +``` + +The server wires external rules into the middleware: + +```go +mw, err := ratelimit.Server( + ratelimit.WithRuleStore(store), + ratelimit.WithOperationRules(operationRules), + ratelimit.WithClientIPKeyFunc(ratelimit.ClientIPKey), + ratelimit.WithUserKeyFunc(userIDFromHeader), +) +if err != nil { + return err +} +``` + +Run Redis locally, then start the example: + +```bash +go run ./kratos/ratelimit/example +``` + +Try the limited endpoints: + +```bash +curl -H 'X-User-ID: user_123' \ + -H 'X-Real-IP: 127.0.0.1' \ + -H 'Content-Type: application/json' \ + -d '{"sku":"book"}' \ + http://127.0.0.1:8000/v1/orders + +curl -H 'X-User-ID: user_123' \ + http://127.0.0.1:8000/v1/orders/order_for_book +``` diff --git a/kratos/ratelimit/example/main.go b/kratos/ratelimit/example/main.go new file mode 100644 index 0000000..9a92955 --- /dev/null +++ b/kratos/ratelimit/example/main.go @@ -0,0 +1,269 @@ +package main + +import ( + "context" + "fmt" + "log" + "net/http" + "os" + "os/signal" + "syscall" + "time" + + "github.com/crypto-zero/go-kit/kratos/ratelimit" + redisstore "github.com/crypto-zero/go-kit/kratos/ratelimit/redis" + khttp "github.com/go-kratos/kratos/v2/transport/http" + goredis "github.com/redis/go-redis/v9" + "google.golang.org/protobuf/types/known/durationpb" +) + +const ( + createOrderOperation = "/ratelimit.example.order.v1.OrderService/CreateOrder" + getOrderOperation = "/ratelimit.example.order.v1.OrderService/GetOrder" +) + +type config struct { + HTTPAddr string + RedisAddr string + RedisPrefix string + + RateLimit *rateLimitConfig +} + +// These config structs mirror the shape a business service would define in +// internal/conf/conf.proto and receive from Kratos config loading. +type rateLimitConfig struct { + Operations map[string]*rateLimitOperation +} + +type rateLimitOperation struct { + Rules []*rateLimitRule +} + +type rateLimitRule struct { + KeyParts []string + Rate int32 + Per *durationpb.Duration + Burst int32 +} + +func main() { + cfg := config{ + HTTPAddr: ":8000", + RedisAddr: "127.0.0.1:6379", + RedisPrefix: "ratelimit-example:order-api", + RateLimit: &rateLimitConfig{ + Operations: map[string]*rateLimitOperation{ + createOrderOperation: { + Rules: []*rateLimitRule{ + { + KeyParts: []string{"user_id"}, + Rate: 10, + Per: durationpb.New(time.Minute), + Burst: 10, + }, + { + KeyParts: []string{"client_ip"}, + Rate: 30, + Per: durationpb.New(time.Minute), + Burst: 30, + }, + }, + }, + getOrderOperation: { + Rules: []*rateLimitRule{ + { + KeyParts: []string{"user_id"}, + Rate: 100, + Per: durationpb.New(time.Minute), + Burst: 100, + }, + }, + }, + }, + }, + } + + ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM) + defer stop() + + if err := run(ctx, cfg); err != nil { + log.Fatal(err) + } +} + +func run(ctx context.Context, cfg config) error { + redisClient := goredis.NewClient(&goredis.Options{Addr: cfg.RedisAddr}) + defer redisClient.Close() + + store, err := redisstore.NewStore(redisClient, cfg.RedisPrefix) + if err != nil { + return err + } + operationRules, err := convertRateLimitConfig(cfg.RateLimit) + if err != nil { + return err + } + mw, err := ratelimit.Server( + ratelimit.WithRuleStore(store), + ratelimit.WithOperationRules(operationRules), + ratelimit.WithClientIPKeyFunc(ratelimit.ClientIPKey), + ratelimit.WithUserKeyFunc(userIDFromHeader), + ) + if err != nil { + return err + } + srv := khttp.NewServer( + khttp.Address(cfg.HTTPAddr), + khttp.Middleware(mw), + ) + registerOrderHTTPServer(srv) + + errc := make(chan error, 1) + go func() { errc <- srv.Start(ctx) }() + + select { + case <-ctx.Done(): + stopCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + return srv.Stop(stopCtx) + case err := <-errc: + return err + } +} + +func convertRateLimitConfig(cfg *rateLimitConfig) (ratelimit.OperationRules, error) { + if cfg == nil || len(cfg.Operations) == 0 { + return nil, nil + } + rules := make(ratelimit.OperationRules, len(cfg.Operations)) + for operation, op := range cfg.Operations { + if operation == "" { + return nil, fmt.Errorf("rate_limit operation must not be empty") + } + if op == nil || len(op.Rules) == 0 { + return nil, fmt.Errorf("%s: rate_limit rules must not be empty", operation) + } + for i, rule := range op.Rules { + converted, err := convertRateLimitRule(rule) + if err != nil { + return nil, fmt.Errorf("%s rules[%d]: %w", operation, i, err) + } + rules[operation] = append(rules[operation], converted) + } + } + return rules, nil +} + +func convertRateLimitRule(rule *rateLimitRule) (ratelimit.RuleConfig, error) { + if rule == nil { + return ratelimit.RuleConfig{}, fmt.Errorf("rule must not be nil") + } + if rule.Rate <= 0 || rule.Per == nil || rule.Per.AsDuration() <= 0 || rule.Burst <= 0 { + return ratelimit.RuleConfig{}, fmt.Errorf("rate, per, and burst must be explicitly positive") + } + keyParts, err := convertKeyParts(rule.KeyParts) + if err != nil { + return ratelimit.RuleConfig{}, err + } + return ratelimit.RuleConfig{ + Config: ratelimit.Config{ + Rate: int(rule.Rate), + Per: rule.Per.AsDuration(), + Burst: int(rule.Burst), + }, + KeyParts: keyParts, + }, nil +} + +func convertKeyParts(parts []string) ([]ratelimit.KeyPart, error) { + if len(parts) == 0 { + return nil, fmt.Errorf("key_parts must not be empty") + } + keyParts := make([]ratelimit.KeyPart, 0, len(parts)) + for _, part := range parts { + kp, err := ratelimit.ParseKeyPart(part) + if err != nil { + return nil, err + } + keyParts = append(keyParts, kp) + } + return keyParts, nil +} + +func userIDFromHeader(ctx context.Context, _ any) string { + if req, ok := khttp.RequestFromServerContext(ctx); ok { + return req.Header.Get("X-User-ID") + } + return "" +} + +func registerOrderHTTPServer(srv *khttp.Server) { + route := srv.Route("/") + route.POST("/v1/orders", createOrder) + route.GET("/v1/orders/{order_id}", getOrder) +} + +func createOrder(ctx khttp.Context) error { + khttp.SetOperation(ctx, createOrderOperation) + + req := new(createOrderRequest) + if err := ctx.Bind(req); err != nil { + return err + } + handler := ctx.Middleware(func(context.Context, any) (any, error) { + return &createOrderResponse{ + OrderId: fmt.Sprintf("order_for_%s", req.GetSku()), + }, nil + }) + return ctx.Returns(handler(ctx, req)) +} + +func getOrder(ctx khttp.Context) error { + khttp.SetOperation(ctx, getOrderOperation) + + req := new(getOrderRequest) + if err := ctx.BindVars(req); err != nil { + return err + } + handler := ctx.Middleware(func(context.Context, any) (any, error) { + return &getOrderResponse{ + OrderId: req.GetOrderId(), + Status: "created", + }, nil + }) + return ctx.Returns(handler(ctx, req)) +} + +var _ http.Handler = (*khttp.Server)(nil) + +type createOrderRequest struct { + Sku string `json:"sku"` +} + +func (r *createOrderRequest) GetSku() string { + if r == nil { + return "" + } + return r.Sku +} + +type createOrderResponse struct { + OrderId string `json:"order_id"` +} + +type getOrderRequest struct { + OrderId string `json:"order_id" form:"order_id"` +} + +func (r *getOrderRequest) GetOrderId() string { + if r == nil { + return "" + } + return r.OrderId +} + +type getOrderResponse struct { + OrderId string `json:"order_id"` + Status string `json:"status"` +} diff --git a/kratos/ratelimit/example/main_test.go b/kratos/ratelimit/example/main_test.go new file mode 100644 index 0000000..8b88746 --- /dev/null +++ b/kratos/ratelimit/example/main_test.go @@ -0,0 +1,62 @@ +package main + +import ( + "testing" + "time" + + "github.com/crypto-zero/go-kit/kratos/ratelimit" + "google.golang.org/protobuf/types/known/durationpb" +) + +func TestConvertRateLimitConfig(t *testing.T) { + rules, err := convertRateLimitConfig(&rateLimitConfig{ + Operations: map[string]*rateLimitOperation{ + createOrderOperation: { + Rules: []*rateLimitRule{ + { + KeyParts: []string{"user_id", "client_ip"}, + Rate: 10, + Per: durationpb.New(time.Minute), + Burst: 20, + }, + }, + }, + }, + }) + if err != nil { + t.Fatalf("convertRateLimitConfig error = %v, want nil", err) + } + if got := len(rules[createOrderOperation]); got != 1 { + t.Fatalf("len(rules[%q]) = %d, want 1", createOrderOperation, got) + } + rule := rules[createOrderOperation][0] + if rule.Config.Rate != 10 || rule.Config.Per != time.Minute || rule.Config.Burst != 20 { + t.Fatalf("rule.Config = %+v, want rate 10 per 1m burst 20", rule.Config) + } + wantKeyParts := []ratelimit.KeyPart{ratelimit.KeyPartUserID, ratelimit.KeyPartClientIP} + for i, want := range wantKeyParts { + if rule.KeyParts[i] != want { + t.Fatalf("rule.KeyParts[%d] = %q, want %q", i, rule.KeyParts[i], want) + } + } +} + +func TestConvertRateLimitConfigRejectsUnknownKeyPart(t *testing.T) { + _, err := convertRateLimitConfig(&rateLimitConfig{ + Operations: map[string]*rateLimitOperation{ + createOrderOperation: { + Rules: []*rateLimitRule{ + { + KeyParts: []string{"device_id"}, + Rate: 10, + Per: durationpb.New(time.Minute), + Burst: 20, + }, + }, + }, + }, + }) + if err == nil { + t.Fatal("convertRateLimitConfig error = nil, want error") + } +} diff --git a/kratos/ratelimit/inmem_store_test.go b/kratos/ratelimit/inmem_store_test.go new file mode 100644 index 0000000..b822fa3 --- /dev/null +++ b/kratos/ratelimit/inmem_store_test.go @@ -0,0 +1,126 @@ +package ratelimit + +import ( + "context" + "fmt" + "math" + "sync" + "time" +) + +// inMemStore is the token-bucket reference implementation shared by tests. +// It is the only Store fixture in the package; Take and TakeMany honor the +// atomicity guarantees in the Store interface doc. +type inMemStore struct { + mu sync.Mutex + buckets map[string]*inMemBucket +} + +type inMemBucket struct { + tokens float64 + seen time.Time +} + +func newInMemStore() *inMemStore { + return &inMemStore{buckets: make(map[string]*inMemBucket)} +} + +func (s *inMemStore) Take(ctx context.Context, key string, now time.Time, limit Limit, n int) (Result, error) { + results, err := s.TakeMany(ctx, []string{key}, now, []Limit{limit}, n) + if err != nil { + return Result{}, err + } + return results[0], nil +} + +func (s *inMemStore) TakeMany(_ context.Context, keys []string, now time.Time, limits []Limit, n int) ([]Result, error) { + if len(keys) != len(limits) { + return nil, fmt.Errorf("keys/limits length mismatch") + } + if n <= 0 { + results := make([]Result, len(keys)) + for i := range results { + results[i] = Result{Allowed: true} + } + return results, nil + } + for i, key := range keys { + if key == "" { + return nil, ErrMissingKey + } + if err := limits[i].Validate(); err != nil { + return nil, err + } + } + s.mu.Lock() + defer s.mu.Unlock() + + tokens := make([]float64, len(keys)) + buckets := make([]*inMemBucket, len(keys)) + for i, key := range keys { + b := s.bucketFor(key, now, limits[i].Burst) + refillInMemBucket(b, now, limits[i]) + buckets[i] = b + tokens[i] = b.tokens + } + + need := float64(n) + allow := true + for i := range buckets { + if need > float64(limits[i].Burst) || tokens[i] < need { + allow = false + break + } + } + + results := make([]Result, len(keys)) + for i, b := range buckets { + if allow { + b.tokens -= need + results[i] = Result{Allowed: true, Remaining: int(math.Floor(b.tokens))} + continue + } + var retry time.Duration + if need > float64(limits[i].Burst) { + retry = retryAfter(limits[i], need-float64(limits[i].Burst)) + } else if tokens[i] < need { + retry = retryAfter(limits[i], need-tokens[i]) + } + results[i] = Result{ + Remaining: int(math.Floor(tokens[i])), + RetryAfter: retry, + } + } + return results, nil +} + +func (s *inMemStore) bucketFor(key string, now time.Time, burst int) *inMemBucket { + if b, ok := s.buckets[key]; ok { + return b + } + b := &inMemBucket{ + tokens: float64(burst), + seen: now, + } + s.buckets[key] = b + return b +} + +func refillInMemBucket(b *inMemBucket, now time.Time, limit Limit) { + elapsed := now.Sub(b.seen) + if elapsed <= 0 { + return + } + per := time.Duration(inMemDurationMillis(limit.Per)) * time.Millisecond + rate := float64(limit.Rate) / per.Seconds() + b.tokens = math.Min(float64(limit.Burst), b.tokens+elapsed.Seconds()*rate) + b.seen = now +} + +func inMemDurationMillis(d time.Duration) int64 { + ms := d.Milliseconds() + if ms <= 0 { + return 1 + } + return ms +} diff --git a/kratos/ratelimit/limiter.go b/kratos/ratelimit/limiter.go new file mode 100644 index 0000000..1d0da13 --- /dev/null +++ b/kratos/ratelimit/limiter.go @@ -0,0 +1,133 @@ +package ratelimit + +import ( + "context" + "errors" + "math" + "time" +) + +var ( + // ErrInvalidConfig reports a limiter configuration that cannot be applied. + ErrInvalidConfig = errors.New("invalid ratelimit config") + // ErrMissingKey reports a request without a rate-limit key. + ErrMissingKey = errors.New("missing ratelimit key") + // ErrMissingStore reports a limiter constructed without a storage backend. + ErrMissingStore = errors.New("missing ratelimit store") +) + +// Config controls token-bucket behavior. It is also the limit type Store +// implementations consume; the Limit alias below preserves the name used on +// the Store interface. +type Config struct { + // Rate is the number of tokens replenished every Per duration. + Rate int + // Per is the refill window for Rate tokens. + Per time.Duration + // Burst is the maximum number of tokens a key can accumulate. + Burst int +} + +// Validate reports whether c carries usable token-bucket parameters. +func (c Config) Validate() error { + if c.Rate <= 0 || c.Per < time.Millisecond || c.Burst <= 0 { + return ErrInvalidConfig + } + return nil +} + +// Limit is the configuration carried into the Store. It is an alias of Config +// to keep the Store signature readable without introducing a second type. +type Limit = Config + +// Result describes the outcome of an Allow call. +type Result struct { + Allowed bool + Remaining int + RetryAfter time.Duration +} + +// Store persists token-bucket state. +// +// Implementations must be safe for concurrent use. Take must apply atomically +// to one key. TakeMany must atomically check every key and either consume n +// tokens from all of them or from none of them, returning one Result per key +// in input order. +type Store interface { + Take(ctx context.Context, key string, now time.Time, limit Limit, n int) (Result, error) + TakeMany(ctx context.Context, keys []string, now time.Time, limits []Limit, n int) ([]Result, error) +} + +// LimiterOption configures a Limiter. +type LimiterOption func(*Limiter) + +// WithNow sets the clock used by the limiter. +func WithNow(now func() time.Time) LimiterOption { + return func(l *Limiter) { + if now != nil { + l.now = now + } + } +} + +// WithStore sets the storage backend used by the limiter. +func WithStore(store Store) LimiterOption { + return func(l *Limiter) { + if store != nil { + l.store = store + } + } +} + +// Limiter applies token-bucket rate limits independently per key. +type Limiter struct { + limit Limit + store Store + now func() time.Time +} + +// New constructs a Limiter. +func New(cfg Config, opts ...LimiterOption) (*Limiter, error) { + if err := cfg.Validate(); err != nil { + return nil, err + } + l := &Limiter{ + limit: cfg, + now: time.Now, + } + for _, opt := range opts { + opt(l) + } + if l.store == nil { + return nil, ErrMissingStore + } + return l, nil +} + +// AllowContext consumes one token for key if capacity is available. +func (l *Limiter) AllowContext(ctx context.Context, key string) (Result, error) { + return l.AllowNContext(ctx, key, 1) +} + +// AllowNContext consumes n tokens for key if capacity is available. +func (l *Limiter) AllowNContext(ctx context.Context, key string, n int) (Result, error) { + if n <= 0 { + return Result{Allowed: true}, nil + } + if key == "" { + return Result{}, ErrMissingKey + } + if n > l.limit.Burst { + return Result{RetryAfter: retryAfter(l.limit, float64(n))}, nil + } + return l.store.Take(ctx, key, l.now(), l.limit, n) +} + +func retryAfter(limit Limit, tokens float64) time.Duration { + perToken := time.Duration(float64(limit.Per) / float64(limit.Rate)) + d := time.Duration(math.Ceil(float64(perToken) * tokens)) + if d < 0 { + return 0 + } + return d +} diff --git a/kratos/ratelimit/limiter_test.go b/kratos/ratelimit/limiter_test.go new file mode 100644 index 0000000..608f068 --- /dev/null +++ b/kratos/ratelimit/limiter_test.go @@ -0,0 +1,153 @@ +package ratelimit + +import ( + "context" + "errors" + "testing" + "time" +) + +type recordingStore struct { + key string + keys []string + n int + limit Limit +} + +func (s *recordingStore) Take(_ context.Context, key string, _ time.Time, limit Limit, n int) (Result, error) { + s.key = key + s.n = n + s.limit = limit + return Result{Allowed: true, Remaining: limit.Burst - n}, nil +} + +func (s *recordingStore) TakeMany(_ context.Context, keys []string, _ time.Time, limits []Limit, n int) ([]Result, error) { + s.keys = append([]string(nil), keys...) + s.n = n + results := make([]Result, len(keys)) + for i, limit := range limits { + results[i] = Result{Allowed: true, Remaining: limit.Burst - n} + } + return results, nil +} + +func TestLimiterAllowsBurstThenRejects(t *testing.T) { + now := time.Unix(0, 0) + limiter, err := New( + Config{Rate: 2, Per: time.Second, Burst: 2}, + WithStore(newInMemStore()), + WithNow(func() time.Time { return now }), + ) + if err != nil { + t.Fatalf("New: %v", err) + } + + if res, err := limiter.AllowContext(context.Background(), "user-1"); err != nil || !res.Allowed || res.Remaining != 1 { + t.Fatalf("first request = %+v, want allowed with one remaining", res) + } + if res, err := limiter.AllowContext(context.Background(), "user-1"); err != nil || !res.Allowed || res.Remaining != 0 { + t.Fatalf("second request = %+v, want allowed with zero remaining", res) + } + if res, err := limiter.AllowContext(context.Background(), "user-1"); err != nil || res.Allowed || res.RetryAfter != 500*time.Millisecond { + t.Fatalf("third request = %+v, want rejected with 500ms retry", res) + } +} + +func TestLimiterRefillsByElapsedTime(t *testing.T) { + now := time.Unix(0, 0) + limiter, err := New( + Config{Rate: 2, Per: time.Second, Burst: 2}, + WithStore(newInMemStore()), + WithNow(func() time.Time { return now }), + ) + if err != nil { + t.Fatalf("New: %v", err) + } + + if _, err := limiter.AllowContext(context.Background(), "user-1"); err != nil { + t.Fatalf("first request: %v", err) + } + if _, err := limiter.AllowContext(context.Background(), "user-1"); err != nil { + t.Fatalf("second request: %v", err) + } + now = now.Add(500 * time.Millisecond) + + if res, err := limiter.AllowContext(context.Background(), "user-1"); err != nil || !res.Allowed || res.Remaining != 0 { + t.Fatalf("refilled request = %+v, want allowed with zero remaining", res) + } +} + +func TestLimiterSeparatesKeys(t *testing.T) { + limiter, err := New(Config{Rate: 1, Per: time.Second, Burst: 1}, WithStore(newInMemStore())) + if err != nil { + t.Fatalf("New: %v", err) + } + + if res, err := limiter.AllowContext(context.Background(), "user-1"); err != nil || !res.Allowed { + t.Fatalf("user-1 first request = %+v, want allowed", res) + } + if res, err := limiter.AllowContext(context.Background(), "user-1"); err != nil || res.Allowed { + t.Fatalf("user-1 second request = %+v, want rejected", res) + } + if res, err := limiter.AllowContext(context.Background(), "user-2"); err != nil || !res.Allowed { + t.Fatalf("user-2 first request = %+v, want allowed", res) + } +} + +func TestLimiterRejectsInvalidConfig(t *testing.T) { + _, err := New(Config{Rate: 0, Per: time.Second, Burst: 1}, WithStore(newInMemStore())) + if !errors.Is(err, ErrInvalidConfig) { + t.Fatalf("New error = %v, want ErrInvalidConfig", err) + } + + _, err = New(Config{Rate: 1, Per: time.Nanosecond, Burst: 1}, WithStore(newInMemStore())) + if !errors.Is(err, ErrInvalidConfig) { + t.Fatalf("New sub-ms config error = %v, want ErrInvalidConfig", err) + } +} + +func TestLimiterRejectsMissingStore(t *testing.T) { + _, err := New(Config{Rate: 1, Per: time.Second, Burst: 1}) + if !errors.Is(err, ErrMissingStore) { + t.Fatalf("New error = %v, want ErrMissingStore", err) + } +} + +func TestLimiterRejectsMissingKey(t *testing.T) { + limiter, err := New(Config{Rate: 1, Per: time.Second, Burst: 1}, WithStore(newInMemStore())) + if err != nil { + t.Fatalf("New: %v", err) + } + + _, err = limiter.AllowContext(context.Background(), "") + if !errors.Is(err, ErrMissingKey) { + t.Fatalf("AllowContext error = %v, want ErrMissingKey", err) + } + + _, err = limiter.AllowNContext(context.Background(), "", 2) + if !errors.Is(err, ErrMissingKey) { + t.Fatalf("AllowNContext over burst error = %v, want ErrMissingKey", err) + } +} + +func TestLimiterUsesStore(t *testing.T) { + store := &recordingStore{} + limiter, err := New(Config{Rate: 5, Per: time.Second, Burst: 10}, WithStore(store)) + if err != nil { + t.Fatalf("New: %v", err) + } + + res, err := limiter.AllowNContext(context.Background(), "tenant-1", 3) + if err != nil { + t.Fatalf("AllowNContext: %v", err) + } + if !res.Allowed || res.Remaining != 7 { + t.Fatalf("result = %+v, want allowed with 7 remaining", res) + } + if store.key != "tenant-1" || store.n != 3 { + t.Fatalf("store saw key=%q n=%d, want tenant-1 and 3", store.key, store.n) + } + if store.limit.Rate != 5 || store.limit.Per != time.Second || store.limit.Burst != 10 { + t.Fatalf("store limit = %+v, want configured limit", store.limit) + } +} diff --git a/kratos/ratelimit/policy.go b/kratos/ratelimit/policy.go new file mode 100644 index 0000000..44146c8 --- /dev/null +++ b/kratos/ratelimit/policy.go @@ -0,0 +1,240 @@ +package ratelimit + +import ( + "context" + "fmt" + "strings" + "time" +) + +// KeyPart identifies one business dimension used to build a rate-limit key. +type KeyPart string + +const ( + KeyPartClientIP KeyPart = "client_ip" + KeyPartUserID KeyPart = "user_id" +) + +// ParseKeyPart converts a canonical key-part name into a KeyPart. Callers that +// load rules from external config can use it to avoid re-implementing the +// allow-list of supported dimensions. +func ParseKeyPart(s string) (KeyPart, error) { + switch KeyPart(s) { + case KeyPartClientIP, KeyPartUserID: + return KeyPart(s), nil + default: + return "", fmt.Errorf("unsupported key part %q", s) + } +} + +// RuleConfig provides the runtime limit for one operation rule. +type RuleConfig struct { + Config Config + KeyParts []KeyPart +} + +// OperationRules maps Kratos operation names to their runtime limit rules. +type OperationRules map[string][]RuleConfig + +// OperationPolicy selects rate-limit behavior for Kratos operations. +type OperationPolicy struct { + operations map[string][]operationLimit + store Store + now func() time.Time + clientIPKeyFunc KeyFunc + userKeyFunc KeyFunc +} + +type operationLimit struct { + keyFunc KeyFunc + limit Limit +} + +// OperationPolicyOption configures an OperationPolicy. +type OperationPolicyOption func(*OperationPolicy) + +// WithPolicyUserKeyFunc sets how operation policies extract the business user id. +func WithPolicyUserKeyFunc(fn KeyFunc) OperationPolicyOption { + return func(p *OperationPolicy) { + p.userKeyFunc = fn + } +} + +// WithPolicyClientIPKeyFunc sets how operation policies extract the client IP. +func WithPolicyClientIPKeyFunc(fn KeyFunc) OperationPolicyOption { + return func(p *OperationPolicy) { + p.clientIPKeyFunc = fn + } +} + +// WithPolicyNow overrides the clock used to stamp Store calls. Mostly useful in tests. +func WithPolicyNow(now func() time.Time) OperationPolicyOption { + return func(p *OperationPolicy) { + if now != nil { + p.now = now + } + } +} + +// NewOperationPolicy constructs a policy from operation rules. All construction +// errors are returned eagerly; the resulting policy never silently disables +// limiting at runtime. +func NewOperationPolicy( + store Store, + rules OperationRules, + opts ...OperationPolicyOption, +) (*OperationPolicy, error) { + if store == nil { + return nil, ErrMissingStore + } + if len(rules) == 0 { + return nil, ErrMissingRules + } + p := &OperationPolicy{ + operations: make(map[string][]operationLimit, len(rules)), + store: store, + now: time.Now, + } + for _, opt := range opts { + opt(p) + } + for operation, configs := range rules { + if operation == "" { + return nil, fmt.Errorf("ratelimit operation must not be empty") + } + if len(configs) == 0 { + return nil, fmt.Errorf("%s: ratelimit rules must not be empty", operation) + } + seen := make(map[string]struct{}, len(configs)) + for _, cfg := range configs { + sig := keyPartsSignature(cfg.KeyParts) + label := ruleLabel(operation, sig) + if err := validateKeyParts(cfg.KeyParts); err != nil { + return nil, fmt.Errorf("%s: %w", label, err) + } + if err := cfg.Config.Validate(); err != nil { + return nil, fmt.Errorf("%s: %w", label, err) + } + if _, dup := seen[sig]; dup { + return nil, fmt.Errorf("%s: duplicate ratelimit rule %s", operation, sig) + } + seen[sig] = struct{}{} + + keyFunc, err := p.keyFuncFromParts(operation, cfg.KeyParts) + if err != nil { + return nil, fmt.Errorf("%s: %w", operation, err) + } + p.operations[operation] = append(p.operations[operation], operationLimit{ + keyFunc: keyFunc, + limit: cfg.Config, + }) + } + } + return p, nil +} + +// allow runs every rule for operation in one atomic Store call and returns the +// per-rule Results. It returns (nil, nil) when no rule matches. +func (p *OperationPolicy) allow(ctx context.Context, operation string, req any) ([]Result, error) { + if p == nil { + return nil, nil + } + rules := p.operations[operation] + if len(rules) == 0 { + return nil, nil + } + keys := make([]string, len(rules)) + limits := make([]Limit, len(rules)) + for i, rule := range rules { + key := rule.keyFunc(ctx, req) + if key == "" { + return nil, ErrMissingKey + } + keys[i] = key + limits[i] = rule.limit + } + return p.store.TakeMany(ctx, keys, p.now(), limits, 1) +} + +func (p *OperationPolicy) validate() error { + // NewOperationPolicy already performs full validation. This method catches + // zero-value policies passed through WithOperationPolicy. + if p == nil || p.store == nil || len(p.operations) == 0 { + return ErrMissingRules + } + return nil +} + +func (p *OperationPolicy) keyFuncFromParts(operation string, parts []KeyPart) (KeyFunc, error) { + fns := make([]KeyFunc, 0, len(parts)) + for _, part := range parts { + fn, err := p.keyFuncForPart(part) + if err != nil { + return nil, fmt.Errorf("%s: %w", keyPartsSignature(parts), err) + } + fns = append(fns, namedKeyPart(part, fn)) + } + return operationScopedKey(operation, CompositeKey(fns...)), nil +} + +func (p *OperationPolicy) keyFuncForPart(part KeyPart) (KeyFunc, error) { + switch part { + case KeyPartClientIP: + if p.clientIPKeyFunc == nil { + return nil, fmt.Errorf("client IP key function is required") + } + return p.clientIPKeyFunc, nil + case KeyPartUserID: + if p.userKeyFunc == nil { + return nil, fmt.Errorf("user key function is required") + } + return p.userKeyFunc, nil + default: + return nil, fmt.Errorf("unsupported key part %s", part) + } +} + +func namedKeyPart(part KeyPart, fn KeyFunc) KeyFunc { + return func(ctx context.Context, req any) string { + value := fn(ctx, req) + if value == "" { + return "" + } + return string(part) + ":" + escapeKeyPartValue(value) + } +} + +var keyPartEscaper = strings.NewReplacer("%", "%25", ":", "%3A") + +func escapeKeyPartValue(value string) string { + return keyPartEscaper.Replace(value) +} + +func validateKeyParts(parts []KeyPart) error { + if len(parts) == 0 { + return fmt.Errorf("key_parts must not be empty") + } + for _, part := range parts { + switch part { + case KeyPartClientIP, KeyPartUserID: + default: + return fmt.Errorf("unsupported key part %s", part) + } + } + return nil +} + +func keyPartsSignature(parts []KeyPart) string { + names := make([]string, 0, len(parts)) + for _, part := range parts { + names = append(names, string(part)) + } + return strings.Join(names, "+") +} + +func ruleLabel(operation, sig string) string { + if sig == "" { + return operation + } + return operation + " " + sig +} diff --git a/kratos/ratelimit/ratelimit.go b/kratos/ratelimit/ratelimit.go new file mode 100644 index 0000000..d9b1849 --- /dev/null +++ b/kratos/ratelimit/ratelimit.go @@ -0,0 +1,228 @@ +// Package ratelimit provides rate-limit middleware for Kratos services. +package ratelimit + +import ( + "context" + "errors" + "strconv" + "strings" + + "github.com/crypto-zero/go-kit/kratos/clientip" + kratoserrors "github.com/go-kratos/kratos/v2/errors" + "github.com/go-kratos/kratos/v2/middleware" + "github.com/go-kratos/kratos/v2/transport" +) + +// Reason is the Kratos error reason emitted when a request is rejected. +const Reason = "RATELIMIT" + +// Metadata keys carried on the Kratos error returned by Server. +const ( + MetadataRemaining = "remaining" + MetadataRetryAfter = "retry_after" +) + +// ErrLimitExceed is returned when a request exceeds its rate limit. +var ErrLimitExceed = kratoserrors.New(429, Reason, "service unavailable due to rate limit exceeded") + +// ErrStoreUnavailable is returned when the backing store cannot evaluate a +// limit decision. +var ErrStoreUnavailable = kratoserrors.New(503, "RATELIMIT_UNAVAILABLE", "service unavailable due to rate limit store unavailable") + +// ErrPolicyConflict reports that incompatible options were combined — most +// commonly passing both WithOperationPolicy and one of WithOperationRules, +// WithRuleStore, WithUserKeyFunc, or WithClientIPKeyFunc. +var ErrPolicyConflict = errors.New("ratelimit: WithOperationPolicy conflicts with rule-building options") + +// ErrMissingRules reports rule-building options that do not include any +// operation rules. +var ErrMissingRules = errors.New("ratelimit: missing operation rules") + +// KeyFunc derives a rate-limit key from a request. +type KeyFunc func(context.Context, any) string + +// Option configures Server. +type Option func(*options) + +type options struct { + err *kratoserrors.Error + policy *OperationPolicy + store Store + rules OperationRules + clientIPKeyFunc KeyFunc + userKeyFunc KeyFunc + storeSet bool + rulesSet bool + clientIPSet bool + userSet bool +} + +// WithRuleStore sets the storage backend used to build operation rules. +func WithRuleStore(store Store) Option { + return func(o *options) { + o.store = store + o.storeSet = true + } +} + +// WithError sets the error returned when a request is rejected. +func WithError(err *kratoserrors.Error) Option { + return func(o *options) { + if err != nil { + o.err = err + } + } +} + +// WithOperationPolicy installs a pre-built policy. Mutually exclusive with +// WithOperationRules / WithRuleStore / WithUserKeyFunc / WithClientIPKeyFunc. +func WithOperationPolicy(policy *OperationPolicy) Option { + return func(o *options) { o.policy = policy } +} + +// WithOperationRules sets per-operation rate-limit rules from external config. +func WithOperationRules(rules OperationRules) Option { + return func(o *options) { + o.rules = rules + o.rulesSet = true + } +} + +// WithUserKeyFunc sets how user_id key parts are extracted from requests. +func WithUserKeyFunc(fn KeyFunc) Option { + return func(o *options) { + o.userKeyFunc = fn + o.userSet = true + } +} + +// WithClientIPKeyFunc sets how client_ip key parts are extracted from requests. +func WithClientIPKeyFunc(fn KeyFunc) Option { + return func(o *options) { + o.clientIPKeyFunc = fn + o.clientIPSet = true + } +} + +// Server returns a Kratos server middleware that enforces rate limits. +// +// Construction errors are returned eagerly so callers fail-fast at startup +// instead of crashing on the first request. HTTP handlers must set a Kratos +// operation with http.SetOperation; requests without an operation are treated +// as unconfigured and are not limited. +func Server(opts ...Option) (middleware.Middleware, error) { + o := &options{err: ErrLimitExceed} + for _, opt := range opts { + opt(o) + } + if o.policy != nil { + if o.rulesSet || o.storeSet || o.userSet || o.clientIPSet { + return nil, ErrPolicyConflict + } + if err := o.policy.validate(); err != nil { + return nil, err + } + } else if o.rulesSet { + if len(o.rules) == 0 { + return nil, ErrMissingRules + } + policy, err := NewOperationPolicy(o.store, o.rules, + WithPolicyClientIPKeyFunc(o.clientIPKeyFunc), + WithPolicyUserKeyFunc(o.userKeyFunc), + ) + if err != nil { + return nil, err + } + o.policy = policy + } else { + return nil, ErrMissingRules + } + policy := o.policy + errResp := o.err + return func(handler middleware.Handler) middleware.Handler { + return func(ctx context.Context, req any) (any, error) { + results, err := policy.allow(ctx, OperationKey(ctx, req), req) + if err != nil { + if errors.Is(err, ErrMissingKey) { + return nil, errResp.WithMetadata(map[string]string{ + MetadataRemaining: "0", + }).WithCause(err) + } + return nil, ErrStoreUnavailable.WithCause(err) + } + if rejected, ok := rejectedResult(results); ok { + return nil, errResp.WithMetadata(retryMetadata(rejected)) + } + return handler(ctx, req) + } + }, nil +} + +func rejectedResult(results []Result) (Result, bool) { + var rejected Result + var ok bool + for _, res := range results { + if !res.Allowed && (!ok || res.RetryAfter > rejected.RetryAfter) { + rejected = res + ok = true + } + } + return rejected, ok +} + +// OperationKey returns the Kratos operation from the server context. +func OperationKey(ctx context.Context, _ any) string { + if tr, ok := transport.FromServerContext(ctx); ok { + return tr.Operation() + } + return "" +} + +// ClientIPKey returns the client IP from the server context. +func ClientIPKey(ctx context.Context, _ any) string { + return clientip.FromContext(ctx) +} + +// CompositeKey joins multiple key functions into one. If any underlying +// function returns the empty string, the composite returns the empty string — +// callers can distinguish "key fully derived" from "at least one dimension +// missing" without silently collapsing into a narrower bucket. +func CompositeKey(fns ...KeyFunc) KeyFunc { + return func(ctx context.Context, req any) string { + parts := make([]string, 0, len(fns)) + for _, fn := range fns { + if fn == nil { + continue + } + part := fn(ctx, req) + if part == "" { + return "" + } + parts = append(parts, part) + } + return strings.Join(parts, ":") + } +} + +func operationScopedKey(operation string, fn KeyFunc) KeyFunc { + return func(ctx context.Context, req any) string { + if fn == nil { + return operation + } + key := fn(ctx, req) + if key == "" { + return "" + } + return operation + ":" + key + } +} + +func retryMetadata(res Result) map[string]string { + md := map[string]string{ + MetadataRemaining: strconv.Itoa(res.Remaining), + } + if res.RetryAfter > 0 { + md[MetadataRetryAfter] = strconv.FormatFloat(res.RetryAfter.Seconds(), 'f', 3, 64) + } + return md +} diff --git a/kratos/ratelimit/ratelimit_test.go b/kratos/ratelimit/ratelimit_test.go new file mode 100644 index 0000000..b8f85d1 --- /dev/null +++ b/kratos/ratelimit/ratelimit_test.go @@ -0,0 +1,466 @@ +package ratelimit + +import ( + "context" + "errors" + "fmt" + "testing" + "time" + + kratoserrors "github.com/go-kratos/kratos/v2/errors" + "github.com/go-kratos/kratos/v2/transport" + "google.golang.org/grpc/metadata" +) + +type mockTransport struct { + operation string +} + +func (m *mockTransport) Kind() transport.Kind { return transport.KindHTTP } +func (m *mockTransport) Endpoint() string { return "localhost:8000" } +func (m *mockTransport) Operation() string { return m.operation } +func (m *mockTransport) RequestHeader() transport.Header { return &mockHeader{} } +func (m *mockTransport) ReplyHeader() transport.Header { return &mockHeader{} } + +type mockHeader struct{} + +func (m *mockHeader) Get(string) string { return "" } +func (m *mockHeader) Set(string, string) {} +func (m *mockHeader) Add(string, string) {} +func (m *mockHeader) Keys() []string { return nil } +func (m *mockHeader) Values(string) []string { return nil } + +type errorStore struct { + err error +} + +func (s errorStore) Take(context.Context, string, time.Time, Limit, int) (Result, error) { + return Result{}, s.err +} + +func (s errorStore) TakeMany(context.Context, []string, time.Time, []Limit, int) ([]Result, error) { + return nil, s.err +} + +func mustServer(t *testing.T, opts ...Option) func(context.Context, any) (any, error) { + t.Helper() + mw, err := Server(opts...) + if err != nil { + t.Fatalf("Server: %v", err) + } + return mw(func(context.Context, any) (any, error) { return "ok", nil }) +} + +func TestServerRejectsWhenLimitExceeded(t *testing.T) { + store := newInMemStore() + wrapped := mustServer(t, + WithRuleStore(store), + WithOperationRules(OperationRules{ + "/svc/Test": {{ + Config: Config{Rate: 1, Per: time.Second, Burst: 1}, + KeyParts: []KeyPart{KeyPartClientIP}, + }}, + }), + WithClientIPKeyFunc(ClientIPKey), + ) + ctx := clientIPContext("/svc/Test", "192.168.1.10") + + if _, err := wrapped(ctx, nil); err != nil { + t.Fatalf("first request error = %v, want nil", err) + } + _, err := wrapped(ctx, nil) + if !errors.Is(err, ErrLimitExceed) { + t.Fatalf("second request error = %v, want ErrLimitExceed", err) + } + se := kratoserrors.FromError(err) + if se.Code != 429 || se.Reason != Reason || se.Metadata[MetadataRetryAfter] == "" { + t.Fatalf("kratos error = %+v, want 429 RATELIMIT with retry_after", se) + } +} + +func TestServerRejectsMissingRules(t *testing.T) { + if _, err := Server(); !errors.Is(err, ErrMissingRules) { + t.Fatalf("Server error = %v, want ErrMissingRules", err) + } +} + +func TestServerRejectsMissingKey(t *testing.T) { + wrapped := mustServer(t, + WithRuleStore(newInMemStore()), + WithOperationRules(OperationRules{ + "/svc/Test": {{ + Config: Config{Rate: 1, Per: time.Second, Burst: 1}, + KeyParts: []KeyPart{KeyPartClientIP}, + }}, + }), + WithClientIPKeyFunc(ClientIPKey), + ) + ctx := transport.NewServerContext(context.Background(), &mockTransport{operation: "/svc/Test"}) + + _, err := wrapped(ctx, nil) + if !errors.Is(err, ErrMissingKey) || !errors.Is(err, ErrLimitExceed) { + t.Fatalf("request error = %v, want ErrMissingKey cause on ErrLimitExceed", err) + } + se := kratoserrors.FromError(err) + if se.Code != 429 || se.Reason != Reason { + t.Fatalf("kratos error = %+v, want 429 RATELIMIT", se) + } +} + +func TestServerMapsStoreError(t *testing.T) { + storeErr := errors.New("redis unavailable") + wrapped := mustServer(t, + WithRuleStore(errorStore{err: storeErr}), + WithOperationRules(OperationRules{ + "/svc/Test": {{ + Config: Config{Rate: 1, Per: time.Second, Burst: 1}, + KeyParts: []KeyPart{KeyPartClientIP}, + }}, + }), + WithClientIPKeyFunc(ClientIPKey), + ) + ctx := clientIPContext("/svc/Test", "192.168.1.10") + + _, err := wrapped(ctx, nil) + if !errors.Is(err, ErrStoreUnavailable) || !errors.Is(err, storeErr) { + t.Fatalf("request error = %v, want ErrStoreUnavailable with store cause", err) + } + se := kratoserrors.FromError(err) + if se.Code != 503 || se.Reason != "RATELIMIT_UNAVAILABLE" { + t.Fatalf("kratos error = %+v, want 503 RATELIMIT_UNAVAILABLE", se) + } +} + +func TestServerMultiPartKeyMissingDimensionRejects(t *testing.T) { + // With KeyParts=[user_id, client_ip], a missing user_id must not silently + // degrade into a client_ip-only bucket — the composite key must be empty + // and the request must be rejected with ErrMissingKey. + wrapped := mustServer(t, + WithRuleStore(newInMemStore()), + WithOperationRules(OperationRules{ + "/svc/Test": {{ + Config: Config{Rate: 1, Per: time.Second, Burst: 1}, + KeyParts: []KeyPart{KeyPartUserID, KeyPartClientIP}, + }}, + }), + WithClientIPKeyFunc(ClientIPKey), + WithUserKeyFunc(func(context.Context, any) string { return "" }), + ) + ctx := clientIPContext("/svc/Test", "192.168.1.10") + + if _, err := wrapped(ctx, nil); !errors.Is(err, ErrMissingKey) { + t.Fatalf("request error = %v, want ErrMissingKey", err) + } +} + +func TestOperationPolicyEscapesKeyPartValues(t *testing.T) { + store := newInMemStore() + wrapped := mustServer(t, + WithRuleStore(store), + WithOperationRules(OperationRules{ + "/svc/Test": {{ + Config: Config{Rate: 1, Per: time.Second, Burst: 1}, + KeyParts: []KeyPart{KeyPartUserID}, + }}, + }), + WithUserKeyFunc(func(context.Context, any) string { return "u%1:admin" }), + ) + ctx := transport.NewServerContext(context.Background(), &mockTransport{operation: "/svc/Test"}) + + if _, err := wrapped(ctx, nil); err != nil { + t.Fatalf("request error = %v, want nil", err) + } + if _, ok := store.buckets["/svc/Test:user_id:u%251%3Aadmin"]; !ok { + t.Fatalf("store buckets = %#v, want escaped key part value", store.buckets) + } +} + +func TestCompositeKey(t *testing.T) { + key := CompositeKey( + func(context.Context, any) string { return "/svc/A" }, + func(context.Context, any) string { return "127.0.0.1" }, + )(context.Background(), nil) + + if key != "/svc/A:127.0.0.1" { + t.Fatalf("CompositeKey = %q, want joined key", key) + } +} + +func TestCompositeKeyReturnsEmptyWhenAnyPartIsEmpty(t *testing.T) { + key := CompositeKey( + func(context.Context, any) string { return "/svc/A" }, + func(context.Context, any) string { return "" }, + )(context.Background(), nil) + + if key != "" { + t.Fatalf("CompositeKey = %q, want empty when any part is empty", key) + } +} + +func TestServerUsesOperationRules(t *testing.T) { + store := newInMemStore() + policy, err := NewOperationPolicy( + store, + OperationRules{ + "/test.limit.v1.LimitService/Fast": {{ + Config: Config{Rate: 1, Per: time.Second, Burst: 1}, + KeyParts: []KeyPart{KeyPartClientIP}, + }}, + }, + WithPolicyClientIPKeyFunc(ClientIPKey), + ) + if err != nil { + t.Fatalf("NewOperationPolicy: %v", err) + } + wrapped := mustServer(t, WithOperationPolicy(policy)) + fastCtx := clientIPContext("/test.limit.v1.LimitService/Fast", "192.168.1.10") + slowCtx := transport.NewServerContext(context.Background(), &mockTransport{operation: "/test.limit.v1.LimitService/Slow"}) + + if _, err := wrapped(fastCtx, nil); err != nil { + t.Fatalf("fast first request error = %v, want nil", err) + } + if _, err := wrapped(fastCtx, nil); !errors.Is(err, ErrLimitExceed) { + t.Fatalf("fast second request error = %v, want ErrLimitExceed from operation policy", err) + } + for i := range 3 { + if _, err := wrapped(slowCtx, nil); err != nil { + t.Fatalf("slow request %d error = %v, want default limiter", i+1, err) + } + } +} + +func TestOperationPolicyRejectsMissingKeyParts(t *testing.T) { + _, err := NewOperationPolicy( + newInMemStore(), + OperationRules{ + "/test.limit.v1.LimitService/Fast": {{ + Config: Config{Rate: 1, Per: time.Second, Burst: 1}, + KeyParts: nil, + }}, + }, + ) + if err == nil { + t.Fatal("NewOperationPolicy error = nil, want error") + } +} + +func TestServerUsesOperationRulesOption(t *testing.T) { + store := newInMemStore() + wrapped := mustServer(t, + WithRuleStore(store), + WithOperationRules(OperationRules{ + "/test.limit.v1.LimitService/Fast": {{ + Config: Config{Rate: 1, Per: time.Second, Burst: 1}, + KeyParts: []KeyPart{KeyPartClientIP}, + }}, + }), + WithClientIPKeyFunc(ClientIPKey), + ) + ctx := clientIPContext("/test.limit.v1.LimitService/Fast", "192.168.1.10") + + if _, err := wrapped(ctx, nil); err != nil { + t.Fatalf("first request error = %v, want nil", err) + } + if _, err := wrapped(ctx, nil); !errors.Is(err, ErrLimitExceed) { + t.Fatalf("second request error = %v, want ErrLimitExceed from operation rule", err) + } +} + +func TestServerAppliesMultipleOperationRules(t *testing.T) { + store := newInMemStore() + wrapped := mustServer(t, + WithRuleStore(store), + WithOperationRules(OperationRules{ + "/test.limit.v1.LimitService/Fast": { + { + Config: Config{Rate: 1, Per: time.Second, Burst: 1}, + KeyParts: []KeyPart{KeyPartUserID}, + }, + { + Config: Config{Rate: 100, Per: time.Second, Burst: 100}, + KeyParts: []KeyPart{KeyPartClientIP}, + }, + }, + }), + WithUserKeyFunc(func(context.Context, any) string { return "user-1" }), + WithClientIPKeyFunc(ClientIPKey), + ) + ctx := clientIPContext("/test.limit.v1.LimitService/Fast", "192.168.1.10") + + if _, err := wrapped(ctx, nil); err != nil { + t.Fatalf("first request error = %v, want nil", err) + } + if _, err := wrapped(ctx, nil); !errors.Is(err, ErrLimitExceed) { + t.Fatalf("second request error = %v, want ErrLimitExceed from user rule", err) + } +} + +// Multi-rule consumption must be atomic. Rule[0] is a tight per-user limit; +// rule[1] is a generous per-IP limit. Once the user rule starts rejecting, +// the IP rule must NOT have been silently decremented by the rejected +// attempts — otherwise a different user from the same IP would be wrongly +// throttled. +func TestServerMultiRuleConsumptionIsAtomic(t *testing.T) { + store := newInMemStore() + buildServer := func(userID string) func(context.Context, any) (any, error) { + return mustServer(t, + WithRuleStore(store), + WithOperationRules(OperationRules{ + "/svc/Multi": { + { + Config: Config{Rate: 1, Per: time.Second, Burst: 1}, + KeyParts: []KeyPart{KeyPartUserID}, + }, + { + Config: Config{Rate: 100, Per: time.Second, Burst: 100}, + KeyParts: []KeyPart{KeyPartClientIP}, + }, + }, + }), + WithUserKeyFunc(func(context.Context, any) string { return userID }), + WithClientIPKeyFunc(ClientIPKey), + ) + } + + user1 := buildServer("user-1") + ctx := clientIPContext("/svc/Multi", "192.168.1.10") + + if _, err := user1(ctx, nil); err != nil { + t.Fatalf("first request error = %v, want nil", err) + } + for i := range 200 { + if _, err := user1(ctx, nil); !errors.Is(err, ErrLimitExceed) { + t.Fatalf("rejected request %d error = %v, want ErrLimitExceed", i+2, err) + } + } + + for i := range 99 { + user := buildServer(fmt.Sprintf("user-%d", i+2)) + if _, err := user(ctx, nil); err != nil { + t.Fatalf("user-2 request %d error = %v, want IP bucket still has 99 tokens", i+1, err) + } + } +} + +func TestOperationPolicyRequiresUserKeyFunc(t *testing.T) { + _, err := NewOperationPolicy( + newInMemStore(), + OperationRules{ + "/test.limit.v1.LimitService/Fast": {{ + Config: Config{Rate: 1, Per: time.Second, Burst: 1}, + KeyParts: []KeyPart{KeyPartUserID}, + }}, + }, + ) + if err == nil { + t.Fatal("NewOperationPolicy error = nil, want missing user key func error") + } +} + +func TestOperationPolicyRequiresClientIPKeyFunc(t *testing.T) { + _, err := NewOperationPolicy( + newInMemStore(), + OperationRules{ + "/test.limit.v1.LimitService/Fast": {{ + Config: Config{Rate: 1, Per: time.Second, Burst: 1}, + KeyParts: []KeyPart{KeyPartClientIP}, + }}, + }, + ) + if err == nil { + t.Fatal("NewOperationPolicy error = nil, want missing client IP key func error") + } +} + +func TestOperationPolicyRequiresStore(t *testing.T) { + _, err := NewOperationPolicy(nil, OperationRules{ + "/svc/A": {{Config: Config{Rate: 1, Per: time.Second, Burst: 1}, KeyParts: []KeyPart{KeyPartClientIP}}}, + }) + if !errors.Is(err, ErrMissingStore) { + t.Fatalf("NewOperationPolicy error = %v, want ErrMissingStore", err) + } +} + +func TestOperationPolicyRejectsMissingRules(t *testing.T) { + _, err := NewOperationPolicy(newInMemStore(), nil) + if !errors.Is(err, ErrMissingRules) { + t.Fatalf("NewOperationPolicy error = %v, want ErrMissingRules", err) + } + + _, err = NewOperationPolicy(newInMemStore(), OperationRules{}) + if !errors.Is(err, ErrMissingRules) { + t.Fatalf("NewOperationPolicy error = %v, want ErrMissingRules for empty rules", err) + } +} + +func TestServerRejectsConflictingOptions(t *testing.T) { + policy, err := NewOperationPolicy(newInMemStore(), OperationRules{ + "/svc/A": {{Config: Config{Rate: 1, Per: time.Second, Burst: 1}, KeyParts: []KeyPart{KeyPartClientIP}}}, + }, WithPolicyClientIPKeyFunc(ClientIPKey)) + if err != nil { + t.Fatalf("NewOperationPolicy: %v", err) + } + + if _, err := Server(WithOperationPolicy(policy), WithClientIPKeyFunc(ClientIPKey)); !errors.Is(err, ErrPolicyConflict) { + t.Fatalf("Server error = %v, want ErrPolicyConflict when both policy and key func given", err) + } +} + +func TestServerRejectsInvalidOperationPolicy(t *testing.T) { + if _, err := Server(WithOperationPolicy(&OperationPolicy{})); !errors.Is(err, ErrMissingRules) { + t.Fatalf("Server error = %v, want ErrMissingRules for zero-value policy", err) + } +} + +func TestServerRejectsRuleBuildingOptionsWithoutRules(t *testing.T) { + _, err := Server(WithRuleStore(newInMemStore())) + if !errors.Is(err, ErrMissingRules) { + t.Fatalf("Server error = %v, want ErrMissingRules", err) + } + + _, err = Server(WithRuleStore(newInMemStore()), WithOperationRules(nil)) + if !errors.Is(err, ErrMissingRules) { + t.Fatalf("Server error = %v, want ErrMissingRules for nil rules", err) + } + + _, err = Server(WithRuleStore(newInMemStore()), WithOperationRules(OperationRules{})) + if !errors.Is(err, ErrMissingRules) { + t.Fatalf("Server error = %v, want ErrMissingRules for empty rules", err) + } +} + +func TestOperationPolicyScopesClientIPByOperation(t *testing.T) { + store := newInMemStore() + rules := OperationRules{ + "/svc/A": {{ + Config: Config{Rate: 1, Per: time.Second, Burst: 1}, + KeyParts: []KeyPart{KeyPartClientIP}, + }}, + "/svc/B": {{ + Config: Config{Rate: 1, Per: time.Second, Burst: 1}, + KeyParts: []KeyPart{KeyPartClientIP}, + }}, + } + policy, err := NewOperationPolicy(store, rules, WithPolicyClientIPKeyFunc(ClientIPKey)) + if err != nil { + t.Fatalf("NewOperationPolicy: %v", err) + } + wrapped := mustServer(t, WithOperationPolicy(policy)) + ctxA := clientIPContext("/svc/A", "192.168.1.10") + ctxB := clientIPContext("/svc/B", "192.168.1.10") + + if _, err := wrapped(ctxA, nil); err != nil { + t.Fatalf("operation A first request error = %v, want nil", err) + } + if _, err := wrapped(ctxA, nil); !errors.Is(err, ErrLimitExceed) { + t.Fatalf("operation A second request error = %v, want ErrLimitExceed", err) + } + if _, err := wrapped(ctxB, nil); err != nil { + t.Fatalf("operation B first request error = %v, want nil with operation-scoped client IP key", err) + } +} + +func clientIPContext(operation, ip string) context.Context { + ctx := transport.NewServerContext(context.Background(), &mockTransport{operation: operation}) + return metadata.NewIncomingContext(ctx, metadata.Pairs("x-real-ip", ip)) +} diff --git a/kratos/ratelimit/redis/store.go b/kratos/ratelimit/redis/store.go new file mode 100644 index 0000000..8f403cd --- /dev/null +++ b/kratos/ratelimit/redis/store.go @@ -0,0 +1,332 @@ +// Package redis provides a Redis-backed ratelimit store. +package redis + +import ( + "context" + "errors" + "fmt" + "math" + "strconv" + "strings" + "time" + + "github.com/crypto-zero/go-kit/kratos/ratelimit" + goredis "github.com/redis/go-redis/v9" +) + +const ( + minBucketTTLMillis = 1000 + scriptArgsPerKey = 4 + scriptResultWidth = 3 +) + +var ( + // ErrMissingClient reports a store constructed without a Redis client. + ErrMissingClient = errors.New("redis ratelimit: missing client") + // ErrMissingPrefix reports a store constructed without an explicit Redis key prefix. + ErrMissingPrefix = errors.New("redis ratelimit: missing prefix") + // ErrInvalidScriptResult reports an unexpected Lua script return value. + ErrInvalidScriptResult = errors.New("redis ratelimit: invalid script result") + // ErrKeyLimitMismatch reports a TakeMany call with mismatched keys/limits. + ErrKeyLimitMismatch = errors.New("redis ratelimit: keys and limits length mismatch") + // ErrKeyGroupMismatch reports keys that cannot be evaluated in one Redis slot. + ErrKeyGroupMismatch = errors.New("redis ratelimit: keys must share a group") +) + +// Store persists rate-limit buckets in Redis. +// +// The script uses Redis-side TIME, so it tolerates client clock skew across +// multiple Kratos instances. It performs a two-phase check-and-commit across +// every key so multi-rule operations consume tokens atomically or not at all. +type Store struct { + client goredis.Scripter + prefix string + script *goredis.Script +} + +// NewStore constructs a Redis-backed store. +func NewStore(client goredis.UniversalClient, prefix string) (*Store, error) { + return newScriptStore(client, prefix) +} + +func newScriptStore(client goredis.Scripter, prefix string) (*Store, error) { + if client == nil { + return nil, ErrMissingClient + } + prefix = strings.TrimSpace(prefix) + if prefix == "" { + return nil, ErrMissingPrefix + } + return &Store{ + client: client, + prefix: prefix, + script: goredis.NewScript(takeScript), + }, nil +} + +// Take consumes n tokens from key if capacity is available. +func (s *Store) Take(ctx context.Context, key string, now time.Time, limit ratelimit.Limit, n int) (ratelimit.Result, error) { + results, err := s.TakeMany(ctx, []string{key}, now, []ratelimit.Limit{limit}, n) + if err != nil { + return ratelimit.Result{}, err + } + return results[0], nil +} + +// TakeMany consumes n tokens from every key atomically: either all keys +// commit, or none do. Order of returned Results matches input order. +// +// The now argument is ignored — the script uses Redis-side TIME. +func (s *Store) TakeMany(ctx context.Context, keys []string, _ time.Time, limits []ratelimit.Limit, n int) ([]ratelimit.Result, error) { + if err := s.validateTake(keys, limits, n); err != nil { + return nil, err + } + if len(keys) == 0 { + return nil, nil + } + if n <= 0 { + return allowedResults(len(keys)), nil + } + + redisKeys, args, err := s.buildScriptCall(keys, limits, n) + if err != nil { + return nil, err + } + values, err := s.script.Run(ctx, s.client, redisKeys, args...).Slice() + if err != nil { + return nil, err + } + return parseResults(values, len(keys)) +} + +func (s *Store) validateTake(keys []string, limits []ratelimit.Limit, n int) error { + if s == nil || s.client == nil { + return ErrMissingClient + } + if s.prefix == "" { + return ErrMissingPrefix + } + if len(keys) != len(limits) { + return ErrKeyLimitMismatch + } + if n <= 0 { + return nil + } + for i, key := range keys { + if key == "" { + return ratelimit.ErrMissingKey + } + if err := limits[i].Validate(); err != nil { + return err + } + } + return nil +} + +func allowedResults(n int) []ratelimit.Result { + results := make([]ratelimit.Result, n) + for i := range results { + results[i] = ratelimit.Result{Allowed: true} + } + return results +} + +func (s *Store) buildScriptCall(keys []string, limits []ratelimit.Limit, n int) ([]string, []any, error) { + slotTag, err := sharedSlotTag(keys) + if err != nil { + return nil, nil, err + } + redisKeys := make([]string, len(keys)) + args := make([]any, 0, 1+len(limits)*scriptArgsPerKey) + args = append(args, n) + for i, key := range keys { + redisKeys[i] = s.redisKey(key, slotTag) + args = appendLimitArgs(args, limits[i]) + } + return redisKeys, args, nil +} + +func sharedSlotTag(keys []string) (string, error) { + group := keyGroup(keys[0]) + for _, key := range keys[1:] { + if keyGroup(key) != group { + return "", ErrKeyGroupMismatch + } + } + return sanitizeHashTag(group), nil +} + +func (s *Store) redisKey(key, slotTag string) string { + return "{" + slotTag + "}:" + s.prefix + ":" + key +} + +func keyGroup(key string) string { + group, _, ok := strings.Cut(key, ":") + if ok && group != "" { + return group + } + return key +} + +func sanitizeHashTag(tag string) string { + return strings.NewReplacer("{", "_", "}", "_").Replace(tag) +} + +func appendLimitArgs(args []any, limit ratelimit.Limit) []any { + return append(args, + limit.Rate, + durationMillis(limit.Per), + limit.Burst, + ttl(limit), + ) +} + +func parseResults(values []any, want int) ([]ratelimit.Result, error) { + if len(values) != want*scriptResultWidth { + return nil, fmt.Errorf("%w: got %d values, want %d", ErrInvalidScriptResult, len(values), want*scriptResultWidth) + } + results := make([]ratelimit.Result, want) + for i := range want { + res, err := parseResult(values[i*scriptResultWidth : (i+1)*scriptResultWidth]) + if err != nil { + return nil, err + } + results[i] = res + } + return results, nil +} + +func parseResult(values []any) (ratelimit.Result, error) { + allowed, err := int64Value(values[0]) + if err != nil { + return ratelimit.Result{}, err + } + remaining, err := int64Value(values[1]) + if err != nil { + return ratelimit.Result{}, err + } + retryMillis, err := int64Value(values[2]) + if err != nil { + return ratelimit.Result{}, err + } + return ratelimit.Result{ + Allowed: allowed == 1, + Remaining: int(remaining), + RetryAfter: time.Duration(retryMillis) * time.Millisecond, + }, nil +} + +func int64Value(v any) (int64, error) { + switch v := v.(type) { + case int64: + return v, nil + case int: + return int64(v), nil + case string: + n, err := strconv.ParseInt(v, 10, 64) + if err != nil { + return 0, fmt.Errorf("%w: %q", ErrInvalidScriptResult, v) + } + return n, nil + default: + return 0, fmt.Errorf("%w: %T", ErrInvalidScriptResult, v) + } +} + +// ttl returns the bucket's PEXPIRE in milliseconds. The bucket must outlive a +// full refill from empty so idle keys don't reset to burst and bypass the +// in-progress retry-after. A 1s floor avoids degenerate sub-second TTLs. +func ttl(limit ratelimit.Limit) int64 { + perMs := durationMillis(limit.Per) + refillFullMs := int64(math.Ceil(float64(limit.Burst) * float64(perMs) / float64(limit.Rate))) + return max(max(2*perMs, refillFullMs), minBucketTTLMillis) +} + +func durationMillis(d time.Duration) int64 { + ms := d.Milliseconds() + if ms <= 0 { + return 1 + } + return ms +} + +// takeScript drives one two-phase token consumption across N keys. +// +// ARGV layout: ARGV[1] = n; then 4 args per key starting at ARGV[2]: +// rate, per_ms, burst, ttl_ms. KEYS[i] pairs with ARGV[2+(i-1)*4 .. +3]. +// +// Phase 1 refills every bucket from its stored state and snapshots the tokens. +// Phase 2 checks whether every key has capacity. Phase 3 commits the refilled +// state (and the n-token decrement only when every key passed), then emits +// three values per key: allowed (0|1), floor(remaining_tokens), retry_after_ms. +const takeScript = ` +local n = tonumber(ARGV[1]) +local count = #KEYS + +local t = redis.call('TIME') +local now = t[1] * 1000 + math.floor(t[2] / 1000) + +local snap = {} +for i = 1, count do + local base = 2 + (i - 1) * 4 + local rate = tonumber(ARGV[base]) + local per = tonumber(ARGV[base + 1]) + local burst = tonumber(ARGV[base + 2]) + local ttl = tonumber(ARGV[base + 3]) + + local bucket = redis.call("HMGET", KEYS[i], "tokens", "seen") + local tokens = tonumber(bucket[1]) + local seen = tonumber(bucket[2]) + if tokens == nil or seen == nil then + tokens = burst + seen = now + else + local elapsed = now - seen + if elapsed > 0 then + tokens = math.min(burst, tokens + (elapsed * rate / per)) + seen = now + end + end + + snap[i] = {tokens = tokens, seen = seen, rate = rate, per = per, burst = burst, ttl = ttl} +end + +local allow = true +for i = 1, count do + local s = snap[i] + if n > s.burst or s.tokens < n then + allow = false + break + end +end + +local out = {} +for i = 1, count do + local s = snap[i] + local committed = s.tokens + if allow then + committed = s.tokens - n + end + redis.call("HSET", KEYS[i], "tokens", committed, "seen", s.seen) + redis.call("PEXPIRE", KEYS[i], s.ttl) + + if allow then + table.insert(out, 1) + table.insert(out, math.floor(committed)) + table.insert(out, 0) + else + local retry + if n > s.burst then + retry = math.ceil((n - s.burst) * s.per / s.rate) + elseif s.tokens < n then + retry = math.ceil((n - s.tokens) * s.per / s.rate) + else + retry = 0 + end + table.insert(out, 0) + table.insert(out, math.floor(committed)) + table.insert(out, retry) + end +end +return out +` diff --git a/kratos/ratelimit/redis/store_test.go b/kratos/ratelimit/redis/store_test.go new file mode 100644 index 0000000..6ab1705 --- /dev/null +++ b/kratos/ratelimit/redis/store_test.go @@ -0,0 +1,299 @@ +package redis + +import ( + "context" + "errors" + "testing" + "time" + + "github.com/crypto-zero/go-kit/kratos/ratelimit" + goredis "github.com/redis/go-redis/v9" +) + +// fakeClient implements goredis.Scripter for tests. It forces the EVALSHA +// fast-path to miss so Eval is called, which lets the test inspect the script +// body and arg layout. +type fakeClient struct { + script string + keys []string + args []any + cmd *goredis.Cmd + evals int +} + +func (f *fakeClient) Eval(_ context.Context, script string, keys []string, args ...any) *goredis.Cmd { + f.evals++ + f.script = script + f.keys = append([]string(nil), keys...) + f.args = append([]any(nil), args...) + if f.cmd != nil { + return f.cmd + } + return goredis.NewCmdResult([]any{int64(1), int64(4), int64(0)}, nil) +} + +func (f *fakeClient) EvalRO(ctx context.Context, script string, keys []string, args ...any) *goredis.Cmd { + return f.Eval(ctx, script, keys, args...) +} + +func (f *fakeClient) EvalSha(_ context.Context, _ string, _ []string, _ ...any) *goredis.Cmd { + return goredis.NewCmdResult(nil, goredis.ErrNoScript) +} + +func (f *fakeClient) EvalShaRO(ctx context.Context, sha1 string, keys []string, args ...any) *goredis.Cmd { + return f.EvalSha(ctx, sha1, keys, args...) +} + +func (f *fakeClient) ScriptExists(_ context.Context, _ ...string) *goredis.BoolSliceCmd { + return goredis.NewBoolSliceResult(nil, nil) +} + +func (f *fakeClient) ScriptLoad(_ context.Context, _ string) *goredis.StringCmd { + return goredis.NewStringResult("", nil) +} + +func TestStoreTakeEvaluatesScript(t *testing.T) { + client := &fakeClient{} + store, err := newScriptStore(client, "api") + if err != nil { + t.Fatalf("NewStore: %v", err) + } + + res, err := store.Take(context.Background(), "tenant-1", time.UnixMilli(1000), ratelimit.Limit{ + Rate: 5, + Per: time.Second, + Burst: 10, + }, 3) + if err != nil { + t.Fatalf("Take: %v", err) + } + if !res.Allowed || res.Remaining != 4 || res.RetryAfter != 0 { + t.Fatalf("result = %+v, want allowed with 4 remaining", res) + } + if client.script != takeScript { + t.Fatal("Eval script mismatch") + } + if len(client.keys) != 1 || client.keys[0] != "{tenant-1}:api:tenant-1" { + t.Fatalf("keys = %#v, want hash-tagged api key", client.keys) + } + wantArgs := []any{3, 5, int64(1000), 10, int64(2000)} + if len(client.args) != len(wantArgs) { + t.Fatalf("args = %#v, want %#v", client.args, wantArgs) + } + for i := range wantArgs { + if client.args[i] != wantArgs[i] { + t.Fatalf("arg %d = %#v, want %#v", i, client.args[i], wantArgs[i]) + } + } +} + +func TestStoreTakeParsesRejectedResult(t *testing.T) { + client := &fakeClient{ + cmd: goredis.NewCmdResult([]any{int64(0), int64(0), int64(500)}, nil), + } + store, err := newScriptStore(client, "api") + if err != nil { + t.Fatalf("NewStore: %v", err) + } + + res, err := store.Take(context.Background(), "tenant-1", time.UnixMilli(1000), ratelimit.Limit{ + Rate: 2, + Per: time.Second, + Burst: 1, + }, 1) + if err != nil { + t.Fatalf("Take: %v", err) + } + if res.Allowed || res.Remaining != 0 || res.RetryAfter != 500*time.Millisecond { + t.Fatalf("result = %+v, want rejected with 500ms retry", res) + } +} + +func TestStoreTakeManyParsesPerKeyResults(t *testing.T) { + client := &fakeClient{ + cmd: goredis.NewCmdResult([]any{ + int64(1), int64(9), int64(0), + int64(0), int64(0), int64(750), + }, nil), + } + store, err := newScriptStore(client, "api") + if err != nil { + t.Fatalf("NewStore: %v", err) + } + + results, err := store.TakeMany(context.Background(), + []string{"/svc/A:user_id:user-1", "/svc/A:client_ip:ip-1"}, + time.UnixMilli(1000), + []ratelimit.Limit{ + {Rate: 10, Per: time.Second, Burst: 10}, + {Rate: 2, Per: time.Second, Burst: 1}, + }, + 1, + ) + if err != nil { + t.Fatalf("TakeMany: %v", err) + } + if len(results) != 2 { + t.Fatalf("got %d results, want 2", len(results)) + } + if !results[0].Allowed || results[0].Remaining != 9 || results[0].RetryAfter != 0 { + t.Fatalf("results[0] = %+v, want allowed with 9 remaining", results[0]) + } + if results[1].Allowed || results[1].RetryAfter != 750*time.Millisecond { + t.Fatalf("results[1] = %+v, want rejected with 750ms retry", results[1]) + } + if len(client.keys) != 2 || client.keys[0] != "{/svc/A}:api:/svc/A:user_id:user-1" || client.keys[1] != "{/svc/A}:api:/svc/A:client_ip:ip-1" { + t.Fatalf("keys = %#v, want shared hash tag from first key in order", client.keys) + } +} + +func TestRedisKeyUsesSanitizedGroupHashTag(t *testing.T) { + store, err := newScriptStore(&fakeClient{}, "api{prod}") + if err != nil { + t.Fatalf("NewStore: %v", err) + } + + tag := sanitizeHashTag(keyGroup("/svc/{Order}:Create")) + if got := store.redisKey("/svc/{Order}:Create:user_id:u1", tag); got != "{/svc/_Order_}:api{prod}:/svc/{Order}:Create:user_id:u1" { + t.Fatalf("redisKey = %q, want sanitized group hash tag", got) + } +} + +func TestStoreTakeManyRejectsKeyLimitMismatch(t *testing.T) { + store, err := newScriptStore(&fakeClient{}, "api") + if err != nil { + t.Fatalf("NewStore: %v", err) + } + + _, err = store.TakeMany(context.Background(), + []string{"a", "b"}, + time.UnixMilli(1000), + []ratelimit.Limit{{Rate: 1, Per: time.Second, Burst: 1}}, + 1, + ) + if !errors.Is(err, ErrKeyLimitMismatch) { + t.Fatalf("TakeMany error = %v, want ErrKeyLimitMismatch", err) + } + + _, err = store.TakeMany(context.Background(), + nil, + time.UnixMilli(1000), + []ratelimit.Limit{{Rate: 1, Per: time.Second, Burst: 1}}, + 1, + ) + if !errors.Is(err, ErrKeyLimitMismatch) { + t.Fatalf("TakeMany empty keys mismatch error = %v, want ErrKeyLimitMismatch", err) + } +} + +func TestStoreTakeManyRejectsKeyGroupMismatch(t *testing.T) { + store, err := newScriptStore(&fakeClient{}, "api") + if err != nil { + t.Fatalf("NewStore: %v", err) + } + + _, err = store.TakeMany(context.Background(), + []string{"/svc/A:user_id:u1", "/svc/B:client_ip:127.0.0.1"}, + time.UnixMilli(1000), + []ratelimit.Limit{ + {Rate: 1, Per: time.Second, Burst: 1}, + {Rate: 1, Per: time.Second, Burst: 1}, + }, + 1, + ) + if !errors.Is(err, ErrKeyGroupMismatch) { + t.Fatalf("TakeMany error = %v, want ErrKeyGroupMismatch", err) + } +} + +func TestStoreTakeManyAllowsNonPositiveNWithoutRedis(t *testing.T) { + client := &fakeClient{} + store, err := newScriptStore(client, "api") + if err != nil { + t.Fatalf("NewStore: %v", err) + } + + results, err := store.TakeMany(context.Background(), + []string{"tenant-1"}, + time.UnixMilli(1000), + []ratelimit.Limit{{Rate: 1, Per: time.Second, Burst: 1}}, + -1, + ) + if err != nil { + t.Fatalf("TakeMany: %v", err) + } + if len(results) != 1 || !results[0].Allowed { + t.Fatalf("results = %+v, want one allowed result", results) + } + if client.evals != 0 { + t.Fatalf("Eval calls = %d, want 0 for non-positive n", client.evals) + } +} + +func TestStoreTakeReturnsRedisError(t *testing.T) { + redisErr := errors.New("redis unavailable") + store, err := newScriptStore(&fakeClient{cmd: goredis.NewCmdResult(nil, redisErr)}, "api") + if err != nil { + t.Fatalf("NewStore: %v", err) + } + + _, err = store.Take(context.Background(), "tenant-1", time.UnixMilli(1000), ratelimit.Limit{ + Rate: 1, + Per: time.Second, + Burst: 1, + }, 1) + if !errors.Is(err, redisErr) { + t.Fatalf("Take error = %v, want redis error", err) + } +} + +func TestNewStoreRejectsMissingClient(t *testing.T) { + _, err := NewStore(nil, "api") + if !errors.Is(err, ErrMissingClient) { + t.Fatalf("NewStore error = %v, want ErrMissingClient", err) + } +} + +func TestNewStoreRejectsMissingPrefix(t *testing.T) { + _, err := newScriptStore(&fakeClient{}, "") + if !errors.Is(err, ErrMissingPrefix) { + t.Fatalf("NewStore error = %v, want ErrMissingPrefix", err) + } +} + +func TestStoreTakeRejectsMissingKey(t *testing.T) { + store, err := newScriptStore(&fakeClient{}, "api") + if err != nil { + t.Fatalf("NewStore: %v", err) + } + + _, err = store.Take(context.Background(), "", time.UnixMilli(1000), ratelimit.Limit{ + Rate: 1, + Per: time.Second, + Burst: 1, + }, 1) + if !errors.Is(err, ratelimit.ErrMissingKey) { + t.Fatalf("Take error = %v, want ErrMissingKey", err) + } +} + +func TestParseResultsRejectsInvalidLength(t *testing.T) { + _, err := parseResults([]any{int64(1)}, 1) + if !errors.Is(err, ErrInvalidScriptResult) { + t.Fatalf("parseResults error = %v, want ErrInvalidScriptResult", err) + } +} + +func TestDurationMillisRoundsTinyDurationUp(t *testing.T) { + if got := durationMillis(time.Nanosecond); got != 1 { + t.Fatalf("durationMillis(time.Nanosecond) = %d, want 1", got) + } +} + +func TestTTLCoversFullRefillFromEmpty(t *testing.T) { + got := ttl(ratelimit.Limit{Rate: 1, Per: time.Minute, Burst: 600}) + wantMin := int64(600) * time.Minute.Milliseconds() / 1 // refill-from-empty time + if got < wantMin { + t.Fatalf("ttl = %d ms, want >= %d ms so idle keys do not reset to burst", got, wantMin) + } +} diff --git a/kratos/sse/doc.go b/kratos/sse/doc.go new file mode 100644 index 0000000..a2a12a9 --- /dev/null +++ b/kratos/sse/doc.go @@ -0,0 +1,2 @@ +// Package sse mounts Server-Sent Events endpoints on a Kratos HTTP server. +package sse diff --git a/kratos/sse/http_client.go b/kratos/sse/http_client.go new file mode 100644 index 0000000..8c425b0 --- /dev/null +++ b/kratos/sse/http_client.go @@ -0,0 +1,141 @@ +package sse + +import ( + "bytes" + "context" + "fmt" + "io" + "net/http" + "net/url" + "strings" + + "github.com/go-kratos/kratos/v2/encoding" + + "github.com/crypto-zero/go-kit/sse" +) + +// HTTPClient opens SSE streams over HTTP. +type HTTPClient struct { + endpoint string + client *http.Client +} + +// HTTPClientOption configures an HTTPClient. +type HTTPClientOption func(*HTTPClient) + +// WithHTTPClient sets the underlying net/http client. +func WithHTTPClient(client *http.Client) HTTPClientOption { + return func(c *HTTPClient) { + if client == nil { + panic("kratos/sse: nil HTTP client") + } + c.client = client + } +} + +// NewHTTPClient returns an SSE HTTP client for endpoint. +func NewHTTPClient(endpoint string, opts ...HTTPClientOption) *HTTPClient { + c := &HTTPClient{endpoint: strings.TrimRight(endpoint, "/"), client: http.DefaultClient} + for _, opt := range opts { + opt(c) + } + return c +} + +// HTTPStreamCallOption configures one SSE stream request. +type HTTPStreamCallOption func(*httpStreamCallConfig) + +type httpStreamCallConfig struct { + headers http.Header +} + +// WithRequestHeader adds one request header value. +func WithRequestHeader(key, value string) HTTPStreamCallOption { + return func(c *httpStreamCallConfig) { + c.headers.Add(key, value) + } +} + +// WithLastEventID sets the Last-Event-ID header for stream resumption. +func WithLastEventID(id string) HTTPStreamCallOption { + return WithRequestHeader(sse.LastEventIDHeader, id) +} + +// Open opens an SSE stream and returns a reader for response events. Non-nil +// body values are JSON encoded unless body already implements io.Reader. +func (c *HTTPClient) Open(ctx context.Context, method, path string, body any, opts ...HTTPStreamCallOption) (*sse.Reader, error) { + cfg := httpStreamCallConfig{headers: make(http.Header)} + for _, opt := range opts { + opt(&cfg) + } + u, err := joinEndpointPath(c.endpoint, path) + if err != nil { + return nil, err + } + bodyReader, contentType, err := encodeRequestBody(body) + if err != nil { + return nil, err + } + req, err := http.NewRequestWithContext(ctx, method, u, bodyReader) + if err != nil { + return nil, err + } + req.Header.Set("Accept", "text/event-stream") + if contentType != "" { + req.Header.Set("Content-Type", contentType) + } + for key, values := range cfg.headers { + for _, value := range values { + req.Header.Add(key, value) + } + } + resp, err := c.client.Do(req) + if err != nil { + return nil, err + } + if resp.StatusCode < http.StatusOK || resp.StatusCode > 299 { + defer resp.Body.Close() + body, _ := io.ReadAll(io.LimitReader(resp.Body, 4<<10)) + msg := strings.TrimSpace(string(body)) + if msg == "" { + return nil, fmt.Errorf("sse: unexpected status %d", resp.StatusCode) + } + return nil, fmt.Errorf("sse: unexpected status %d: %s", resp.StatusCode, msg) + } + return sse.NewReader(resp.Body), nil +} + +func encodeRequestBody(body any) (io.Reader, string, error) { + if body == nil { + return nil, "", nil + } + if r, ok := body.(io.Reader); ok { + return r, "", nil + } + b, err := encoding.GetCodec("json").Marshal(body) + if err != nil { + return nil, "", fmt.Errorf("sse: marshal request body: %w", err) + } + return bytes.NewReader(b), "application/json", nil +} + +func joinEndpointPath(endpoint, path string) (string, error) { + if endpoint == "" { + return "", fmt.Errorf("sse: endpoint is empty") + } + base, err := url.Parse(endpoint) + if err != nil { + return "", err + } + ref, err := url.Parse(path) + if err != nil { + return "", err + } + if ref.IsAbs() { + return ref.String(), nil + } + base.Path = strings.TrimRight(base.Path, "/") + "/" + strings.TrimLeft(ref.Path, "/") + base.RawQuery = ref.RawQuery + base.Fragment = ref.Fragment + return base.String(), nil +} diff --git a/kratos/sse/http_client_test.go b/kratos/sse/http_client_test.go new file mode 100644 index 0000000..1e525dc --- /dev/null +++ b/kratos/sse/http_client_test.go @@ -0,0 +1,97 @@ +package sse_test + +import ( + "context" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" + + ssekratos "github.com/crypto-zero/go-kit/kratos/sse" + ksse "github.com/crypto-zero/go-kit/sse" +) + +func TestHTTPClientOpen(t *testing.T) { + var gotLastEventID string + var gotPath string + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotLastEventID = r.Header.Get(ksse.LastEventIDHeader) + gotPath = r.URL.RequestURI() + w.Header().Set("Content-Type", "text/event-stream") + _, _ = w.Write([]byte("id: 7\nevent: point\ndata: {}\n\n")) + })) + defer ts.Close() + + client := ssekratos.NewHTTPClient(ts.URL + "/api") + reader, err := client.Open(context.Background(), http.MethodGet, "/v1/watch?west=1", nil, ssekratos.WithLastEventID("6")) + if err != nil { + t.Fatalf("Open: %v", err) + } + defer reader.Close() + if gotPath != "/api/v1/watch?west=1" { + t.Fatalf("path = %q, want /api/v1/watch?west=1", gotPath) + } + if gotLastEventID != "6" { + t.Fatalf("Last-Event-ID = %q, want 6", gotLastEventID) + } + ev, err := reader.Next() + if err != nil { + t.Fatalf("Next: %v", err) + } + if ev.ID != "7" || ev.Event != "point" || ev.Data != "{}" { + t.Fatalf("unexpected event: %#v", ev) + } +} + +func TestHTTPClientOpenSendsJSONBody(t *testing.T) { + var gotBody string + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + body, err := io.ReadAll(r.Body) + if err != nil { + t.Fatalf("ReadAll: %v", err) + } + gotBody = string(body) + if got := r.Header.Get("Content-Type"); got != "application/json" { + t.Fatalf("Content-Type = %q, want application/json", got) + } + w.Header().Set("Content-Type", "text/event-stream") + _, _ = w.Write([]byte("data: ok\n\n")) + })) + defer ts.Close() + + client := ssekratos.NewHTTPClient(ts.URL) + reader, err := client.Open(context.Background(), http.MethodPost, "/v1/watch", map[string]string{"name": "alice"}) + if err != nil { + t.Fatalf("Open: %v", err) + } + defer reader.Close() + if gotBody != `{"name":"alice"}` { + t.Fatalf("body = %q, want JSON payload", gotBody) + } +} + +func TestHTTPClientOpenStatusErrorIncludesBody(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + http.Error(w, `{"error":"bad token"}`, http.StatusUnauthorized) + })) + defer ts.Close() + + client := ssekratos.NewHTTPClient(ts.URL) + _, err := client.Open(context.Background(), http.MethodGet, "/v1/watch", nil) + if err == nil { + t.Fatal("Open err = nil, want error") + } + if got := err.Error(); !strings.Contains(got, "401") || !strings.Contains(got, "bad token") { + t.Fatalf("Open err = %q, want status and body", got) + } +} + +func TestWithHTTPClientNilPanics(t *testing.T) { + defer func() { + if recover() == nil { + t.Fatal("NewHTTPClient did not panic") + } + }() + _ = ssekratos.NewHTTPClient("http://example.com", ssekratos.WithHTTPClient(nil)) +} diff --git a/kratos/sse/http_handler.go b/kratos/sse/http_handler.go new file mode 100644 index 0000000..7097712 --- /dev/null +++ b/kratos/sse/http_handler.go @@ -0,0 +1,98 @@ +package sse + +import ( + "context" + "net/http" + "time" + + khttp "github.com/go-kratos/kratos/v2/transport/http" + + "github.com/crypto-zero/go-kit/sse" +) + +// HTTPStreamOption configures a Kratos HTTP-attached SSE stream handler. +type HTTPStreamOption func(*httpStreamConfig) + +type httpStreamConfig struct { + heartbeat time.Duration + filters []khttp.FilterFunc +} + +// HTTPHeartbeat enables automatic SSE comment frames at interval for one +// Kratos HTTP stream handler. Set interval <= 0 to disable it. +func HTTPHeartbeat(interval time.Duration) HTTPStreamOption { + return func(c *httpStreamConfig) { c.heartbeat = interval } +} + +// HTTPFilter appends Kratos HTTP route filters around one stream endpoint. +func HTTPFilter(filters ...khttp.FilterFunc) HTTPStreamOption { + return func(c *httpStreamConfig) { c.filters = append(c.filters, filters...) } +} + +// RegisterHTTPStream mounts an SSE endpoint on an existing Kratos HTTP +// server. Unlike Server, it does not own a listener: lifecycle, filters, +// route walking, operation selection and service middleware all come from +// the supplied HTTP server. +// +// GET/HEAD requests decode through ctx.BindQuery; other methods use +// ctx.Bind. Proto messages get the full Kratos form codec handling +// (well-known types, repeated/map fields, nested paths) for free. +func RegisterHTTPStream[Req any]( + srv *khttp.Server, + method string, + path string, + operation string, + do func(ctx context.Context, req *Req, st *sse.Stream) error, + opts ...HTTPStreamOption, +) { + registerHTTPStream(srv, method, path, operation, bindHTTPStreamRequest[Req], do, opts...) +} + +func registerHTTPStream[Req any]( + srv *khttp.Server, + method string, + path string, + operation string, + bind func(khttp.Context, *Req) error, + do func(ctx context.Context, req *Req, st *sse.Stream) error, + opts ...HTTPStreamOption, +) { + cfg := httpStreamConfig{} + for _, opt := range opts { + opt(&cfg) + } + srv.Route("/").Handle(method, path, func(ctx khttp.Context) error { + req := new(Req) + if err := bind(ctx, req); err != nil { + return err + } + if operation != "" { + khttp.SetOperation(ctx, operation) + } + streamCtx, stopStreamCtx := sse.DetachDeadlineContext(ctx) + defer stopStreamCtx() + h := ctx.Middleware(func(mctx context.Context, raw any) (any, error) { + st := sse.NewStream(ctx.Response()) + stopBeat := st.Heartbeat(mctx, cfg.heartbeat) + defer stopBeat() + if err := do(mctx, raw.(*Req), st); err != nil { + _ = st.Error(err.Error()) + } + return nil, nil + }) + _, err := h(streamCtx, req) + return err + }, cfg.filters...) +} + +func bindHTTPStreamRequest[Req any](ctx khttp.Context, target *Req) error { + if err := ctx.BindVars(target); err != nil { + return err + } + switch ctx.Request().Method { + case http.MethodGet, http.MethodHead: + return ctx.BindQuery(target) + default: + return ctx.Bind(target) + } +} diff --git a/kratos/sse/http_handler_test.go b/kratos/sse/http_handler_test.go new file mode 100644 index 0000000..cf32b90 --- /dev/null +++ b/kratos/sse/http_handler_test.go @@ -0,0 +1,135 @@ +package sse_test + +import ( + "bufio" + "context" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/go-kratos/kratos/v2/middleware" + ktransport "github.com/go-kratos/kratos/v2/transport" + khttp "github.com/go-kratos/kratos/v2/transport/http" + "google.golang.org/protobuf/types/known/durationpb" + + ksse "github.com/crypto-zero/go-kit/kratos/sse" + "github.com/crypto-zero/go-kit/sse" +) + +func TestHTTPStreamHandler_BindsProtoQueryAndStreamsOnKratosHTTP(t *testing.T) { + srv := khttp.NewServer(khttp.Timeout(0)) + ksse.RegisterHTTPStream(srv, http.MethodGet, "/v1/duration/{seconds}", "/test.Duration/Watch", + func(_ context.Context, req *durationpb.Duration, st *sse.Stream) error { + if req.GetSeconds() != 12 || req.GetNanos() != 34 { + t.Fatalf("request = %ds/%dns, want 12s/34ns", req.GetSeconds(), req.GetNanos()) + } + _ = st.Write("bound") + return st.Done() + }, + ) + ts := httptest.NewServer(srv) + defer ts.Close() + + resp, err := http.Get(ts.URL + "/v1/duration/12?nanos=34") + if err != nil { + t.Fatalf("GET: %v", err) + } + defer func() { _ = resp.Body.Close() }() + if got, want := resp.Header.Get("Content-Type"), "text/event-stream"; got != want { + t.Errorf("Content-Type = %q, want %q", got, want) + } + body := readAll(t, resp.Body) + if !strings.Contains(body, "data: bound\n\n") { + t.Errorf("body missing payload: %q", body) + } + if !strings.Contains(body, "data: [DONE]\n\n") { + t.Errorf("body missing done marker: %q", body) + } +} + +func TestHTTPStreamHandler_SetsOperationBeforeMiddleware(t *testing.T) { + const operation = "/test.Live/Watch" + seen := make(chan string, 1) + srv := khttp.NewServer( + khttp.Timeout(0), + khttp.Middleware(func(next middleware.Handler) middleware.Handler { + return func(ctx context.Context, req any) (any, error) { + tr, ok := ktransport.FromServerContext(ctx) + if !ok { + t.Fatal("missing transport") + } + seen <- tr.Operation() + return next(ctx, req) + } + }), + ) + ksse.RegisterHTTPStream(srv, http.MethodGet, "/v1/live", operation, + func(_ context.Context, _ *durationpb.Duration, st *sse.Stream) error { + return st.Done() + }, + ) + ts := httptest.NewServer(srv) + defer ts.Close() + + resp, err := http.Get(ts.URL + "/v1/live") + if err != nil { + t.Fatalf("GET: %v", err) + } + _ = resp.Body.Close() + + select { + case got := <-seen: + if got != operation { + t.Errorf("operation = %q, want %q", got, operation) + } + case <-time.After(time.Second): + t.Fatal("middleware did not run") + } +} + +func TestHTTPStreamHandler_DetachesKratosHTTPTimeout(t *testing.T) { + srv := khttp.NewServer(khttp.Timeout(20 * time.Millisecond)) + ksse.RegisterHTTPStream(srv, http.MethodGet, "/v1/slow", "/test.Slow/Watch", + func(ctx context.Context, _ *durationpb.Duration, st *sse.Stream) error { + select { + case <-ctx.Done(): + return ctx.Err() + case <-time.After(80 * time.Millisecond): + } + _ = st.Write("still-open") + return st.Done() + }, + ) + ts := httptest.NewServer(srv) + defer ts.Close() + + resp, err := http.Get(ts.URL + "/v1/slow") + if err != nil { + t.Fatalf("GET: %v", err) + } + defer func() { _ = resp.Body.Close() }() + body := readAll(t, resp.Body) + if !strings.Contains(body, "data: still-open\n\n") { + t.Errorf("body missing delayed payload after HTTP timeout: %q", body) + } + if !strings.Contains(body, "data: [DONE]\n\n") { + t.Errorf("body missing done marker: %q", body) + } +} + +func readAll(t *testing.T, r interface { + Read([]byte) (int, error) +}) string { + t.Helper() + var sb strings.Builder + br := bufio.NewReader(r) + for { + line, err := br.ReadString('\n') + sb.WriteString(line) + if err != nil { + return sb.String() + } + } +} diff --git a/kubernetes/election/election.go b/kubernetes/election/election.go index 1ec10f4..06b64ed 100644 --- a/kubernetes/election/election.go +++ b/kubernetes/election/election.go @@ -3,11 +3,11 @@ package election import ( "context" "fmt" - "io" "log/slog" "os" "strings" "sync" + "sync/atomic" "time" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" @@ -17,6 +17,11 @@ import ( "k8s.io/client-go/tools/leaderelection/resourcelock" ) +const ( + loggerNamed = "__LOGGER.NAMED__" + namespacePath = "/var/run/secrets/kubernetes.io/serviceaccount/namespace" +) + type ( // StateMachine is the state machine interface. StateMachine interface { @@ -26,15 +31,13 @@ type ( EnsureMaster(ctx context.Context) error // EnsureSlave ensures the state machine is slave. EnsureSlave(ctx context.Context) error - // Do the state machine. + // Do runs one state machine iteration and returns the delay before the next run. Do(ctx context.Context) (after time.Duration) - // Cleanup the state machine. + // Cleanup releases resources held by the state machine. Cleanup() } - // StateMachineRunner is the state machine runner interface + // StateMachineRunner runs leader-elected state machines. StateMachineRunner interface { - // implements kratos.Server - // Start starts the state machine runner. Start(context.Context) error // Stop stops the state machine runner. @@ -45,11 +48,12 @@ type ( } ) -// StateMachiRunnerImpl is the state machine runner implementation. -type StateMachiRunnerImpl struct { - ctx context.Context - cancel func() - closed bool +// StateMachineRunnerImpl is the state machine runner implementation. +type StateMachineRunnerImpl struct { + ctx context.Context + cancel func() + closed atomic.Bool + cleanupOnce sync.Once wg sync.WaitGroup @@ -60,29 +64,35 @@ type StateMachiRunnerImpl struct { logger *slog.Logger } +// StateMachiRunnerImpl is deprecated. +// +// Deprecated: Use StateMachineRunnerImpl. +type StateMachiRunnerImpl = StateMachineRunnerImpl + // Start starts the state machine runner. -func (s *StateMachiRunnerImpl) Start(context.Context) error { return nil } +func (s *StateMachineRunnerImpl) Start(context.Context) error { return nil } // Stop stops the state machine runner. -func (s *StateMachiRunnerImpl) Stop(context.Context) error { return nil } +func (s *StateMachineRunnerImpl) Stop(context.Context) error { + s.cleanup() + return nil +} // cleanup cleans up the state machine runner. -func (s *StateMachiRunnerImpl) cleanup() { - s.closed = true - s.cancel() - s.wg.Wait() +func (s *StateMachineRunnerImpl) cleanup() { + s.cleanupOnce.Do(func() { + s.closed.Store(true) + s.cancel() + s.wg.Wait() + }) } // serveMachine serves the state machine. -func (s *StateMachiRunnerImpl) serveMachine(machine StateMachine) { - // The logger name is conventionally assigned to the key "__LOGGER.NAMED__" defined in go-kit/zap. - const ( - LoggerNamed = "__LOGGER.NAMED__" - ) +func (s *StateMachineRunnerImpl) serveMachine(machine StateMachine) { + defer s.wg.Done() - // type hint here that can be omitted name := fmt.Sprintf("state-machine-runner-%s", machine.Name()) - logger := s.logger.With(LoggerNamed, name) + logger := s.logger.With(loggerNamed, name) // lease lock name rule: [a-z0-9]([-a-z0-9]*[a-z0-9])?(\.[a-z0-9]([-a-z0-9]*[a-z0-9])?)* isLeaderChan := make(chan bool, 10) @@ -127,7 +137,6 @@ func (s *StateMachiRunnerImpl) serveMachine(machine StateMachine) { ctx, cancel := context.WithCancel(s.ctx) - defer s.wg.Done() defer func() { logger.Info("stopped") }() defer machine.Cleanup() defer cancel() @@ -161,7 +170,7 @@ func (s *StateMachiRunnerImpl) serveMachine(machine StateMachine) { } } - for !s.closed { + for !s.closed.Load() { after := machine.Do(ctx) if after <= 0 { return @@ -178,18 +187,13 @@ func (s *StateMachiRunnerImpl) serveMachine(machine StateMachine) { } } -func (s *StateMachiRunnerImpl) AddMachine(machine StateMachine) { +func (s *StateMachineRunnerImpl) AddMachine(machine StateMachine) { s.wg.Add(1) go s.serveMachine(machine) } // NewStateMachineRunnerImpl creates a new StateMachineRunner. func NewStateMachineRunnerImpl(logger *slog.Logger) (StateMachineRunner, func(), error) { - out := &StateMachiRunnerImpl{ - logger: logger, - } - out.ctx, out.cancel = context.WithCancel(context.Background()) - config, err := rest.InClusterConfig() if err != nil { return nil, nil, err @@ -202,17 +206,19 @@ func NewStateMachineRunnerImpl(logger *slog.Logger) (StateMachineRunner, func(), if err != nil { return nil, nil, err } - out.cli, out.namespace, out.pod = cli, GetCurrentNamespace(), pod - return out, sync.OnceFunc(out.cleanup), nil + out := &StateMachineRunnerImpl{ + logger: logger, + cli: cli, + namespace: GetCurrentNamespace(), + pod: pod, + } + out.ctx, out.cancel = context.WithCancel(context.Background()) + return out, out.cleanup, nil } // GetCurrentNamespace returns the current namespace in the kubernetes cluster. func GetCurrentNamespace() (namespace string) { - namespaceFile, err := os.Open("/var/run/secrets/kubernetes.io/serviceaccount/namespace") - if err != nil { - return "" - } - d, err := io.ReadAll(namespaceFile) + d, err := os.ReadFile(namespacePath) if err != nil { return "" } diff --git a/kubernetes/election/election_test.go b/kubernetes/election/election_test.go new file mode 100644 index 0000000..cdf1a2d --- /dev/null +++ b/kubernetes/election/election_test.go @@ -0,0 +1,27 @@ +package election + +import ( + "context" + "testing" +) + +func TestStateMachineRunnerStopIsIdempotent(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + runner := &StateMachineRunnerImpl{ + ctx: ctx, + cancel: cancel, + } + + if err := runner.Stop(context.Background()); err != nil { + t.Fatalf("first Stop: %v", err) + } + if err := runner.Stop(context.Background()); err != nil { + t.Fatalf("second Stop: %v", err) + } + + select { + case <-ctx.Done(): + default: + t.Fatal("Stop did not cancel runner context") + } +} diff --git a/kubernetes/election/go.mod b/kubernetes/election/go.mod index e45236e..54788e3 100644 --- a/kubernetes/election/go.mod +++ b/kubernetes/election/go.mod @@ -1,6 +1,6 @@ module github.com/crypto-zero/go-kit/kubernetes/election -go 1.25.5 +go 1.26.3 require ( k8s.io/apimachinery v0.35.0 diff --git a/kubernetes/kubernetes.go b/kubernetes/kubernetes.go index 907fbe2..39fd061 100644 --- a/kubernetes/kubernetes.go +++ b/kubernetes/kubernetes.go @@ -1,18 +1,15 @@ package kubernetes import ( - "io" "os" "strings" ) +const namespacePath = "/var/run/secrets/kubernetes.io/serviceaccount/namespace" + // GetCurrentNamespace returns the current namespace in the kubernetes cluster. func GetCurrentNamespace() (namespace string) { - namespaceFile, err := os.Open("/var/run/secrets/kubernetes.io/serviceaccount/namespace") - if err != nil { - return "" - } - d, err := io.ReadAll(namespaceFile) + d, err := os.ReadFile(namespacePath) if err != nil { return "" } diff --git a/lifecycle/loop_runner.go b/lifecycle/loop_runner.go new file mode 100644 index 0000000..7fbc415 --- /dev/null +++ b/lifecycle/loop_runner.go @@ -0,0 +1,206 @@ +// Package lifecycle provides helpers for service lifecycles. +package lifecycle + +import ( + "context" + "errors" + "log/slog" + "sync" + "time" +) + +// ErrAlreadyStarted reports a runner configuration change after Start. +var ErrAlreadyStarted = errors.New("lifecycle: runner already started") + +// Service is the common lifecycle interface used by Kratos servers. +type Service interface { + Start(context.Context) error + Stop(context.Context) error +} + +// LoopRunner manages a group of background loops. +type LoopRunner interface { + Service + + // OnStart registers a callback that runs during Start before any loops are + // spawned. Calling OnStart twice overrides the previous callback. + OnStart(func(context.Context) error) LoopRunner + + // Add registers a loop to spawn during Start. + Add(name string, fn func(context.Context)) LoopRunner + + // AddTick registers a periodic loop that invokes fn every interval. + AddTick(name string, interval time.Duration, fn func(context.Context) error) LoopRunner +} + +type namedLoop struct { + name string + fn func(context.Context) +} + +type loopRunner struct { + name string + logger *slog.Logger + + mu sync.Mutex + onStart func(context.Context) error + loops []namedLoop + cancel context.CancelFunc + started bool + + wg sync.WaitGroup +} + +// NewLoopRunner constructs a runner scoped to name. +func NewLoopRunner(name string, logger *slog.Logger) LoopRunner { + if logger == nil { + panic("lifecycle: logger is nil") + } + if name == "" { + panic("lifecycle: service name is empty") + } + return &loopRunner{ + name: name, + logger: logger.With("service", name), + } +} + +func (r *loopRunner) OnStart(fn func(context.Context) error) LoopRunner { + if err := r.configure(func() { + r.onStart = fn + }); err != nil { + panic(err) + } + return r +} + +func (r *loopRunner) Add(name string, fn func(context.Context)) LoopRunner { + if name == "" { + panic("lifecycle: loop name is empty") + } + if fn == nil { + panic("lifecycle: loop function is nil") + } + if err := r.configure(func() { + r.loops = append(r.loops, namedLoop{name: name, fn: fn}) + }); err != nil { + panic(err) + } + return r +} + +func (r *loopRunner) AddTick(name string, interval time.Duration, fn func(context.Context) error) LoopRunner { + if interval <= 0 { + panic("lifecycle: tick interval must be positive") + } + if fn == nil { + panic("lifecycle: tick function is nil") + } + return r.Add(name, func(ctx context.Context) { + ticker := time.NewTicker(interval) + defer ticker.Stop() + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + if err := fn(ctx); err != nil { + r.logger.ErrorContext(ctx, "loop tick failed", "loop", name, "err", err) + } + } + } + }) +} + +// Start implements Service. +func (r *loopRunner) Start(ctx context.Context) error { + runCtx, onStart, loops, err := r.start() + if err != nil { + return err + } + if onStart != nil { + if err := onStart(ctx); err != nil { + _ = r.Stop(context.Background()) + return err + } + } + for _, l := range loops { + r.spawn(runCtx, l.name, l.fn) + } + r.logger.InfoContext(ctx, "service started", "loops", len(loops)) + return nil +} + +// Stop implements Service. +func (r *loopRunner) Stop(ctx context.Context) error { + cancel := r.stop() + if cancel == nil { + return nil + } + cancel() + + done := make(chan struct{}) + go func() { + r.wg.Wait() + r.finishStop() + close(done) + }() + + select { + case <-done: + r.logger.InfoContext(ctx, "service stopped") + case <-ctx.Done(): + r.logger.WarnContext(ctx, "stop deadline exceeded, goroutines may still be running") + } + return nil +} + +func (r *loopRunner) configure(fn func()) error { + r.mu.Lock() + defer r.mu.Unlock() + if r.started { + return ErrAlreadyStarted + } + fn() + return nil +} + +func (r *loopRunner) start() (context.Context, func(context.Context) error, []namedLoop, error) { + r.mu.Lock() + defer r.mu.Unlock() + if r.started { + return nil, nil, nil, ErrAlreadyStarted + } + runCtx, cancel := context.WithCancel(context.Background()) + r.cancel = cancel + r.started = true + return runCtx, r.onStart, append([]namedLoop(nil), r.loops...), nil +} + +func (r *loopRunner) stop() context.CancelFunc { + r.mu.Lock() + defer r.mu.Unlock() + cancel := r.cancel + r.cancel = nil + return cancel +} + +func (r *loopRunner) finishStop() { + r.mu.Lock() + defer r.mu.Unlock() + r.started = false +} + +func (r *loopRunner) spawn(ctx context.Context, name string, fn func(context.Context)) { + r.wg.Go(func() { + defer func() { + if rec := recover(); rec != nil { + r.logger.ErrorContext(ctx, "background goroutine panicked", "loop", name, "panic", rec) + } + }() + if fn == nil { + return + } + fn(ctx) + }) +} diff --git a/lifecycle/loop_runner_test.go b/lifecycle/loop_runner_test.go new file mode 100644 index 0000000..1f32bbf --- /dev/null +++ b/lifecycle/loop_runner_test.go @@ -0,0 +1,153 @@ +package lifecycle + +import ( + "bytes" + "context" + "errors" + "io" + "log/slog" + "strings" + "sync/atomic" + "testing" + "time" +) + +func TestLoopRunnerStartsAndStopsLoops(t *testing.T) { + var buf bytes.Buffer + logger := slog.New(slog.NewTextHandler(&buf, nil)) + var loaded atomic.Bool + started := make(chan struct{}) + + runner := NewLoopRunner("test", logger). + OnStart(func(context.Context) error { + loaded.Store(true) + return nil + }). + Add("main", func(ctx context.Context) { + if !loaded.Load() { + t.Error("loop started before OnStart completed") + } + close(started) + <-ctx.Done() + }) + + if err := runner.Start(context.Background()); err != nil { + t.Fatalf("Start: %v", err) + } + waitFor(t, started) + + stopCtx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + if err := runner.Stop(stopCtx); err != nil { + t.Fatalf("Stop: %v", err) + } + if !strings.Contains(buf.String(), "service stopped") { + t.Fatalf("logs = %q, want service stopped", buf.String()) + } +} + +func TestLoopRunnerRejectsConfigurationAfterStart(t *testing.T) { + runner := NewLoopRunner("test", testLogger()).Add("main", func(ctx context.Context) { + <-ctx.Done() + }) + if err := runner.Start(context.Background()); err != nil { + t.Fatalf("Start: %v", err) + } + defer func() { + _ = runner.Stop(context.Background()) + }() + + err := recoverPanic(func() { + runner.Add("late", func(context.Context) {}) + }) + if !errors.Is(err, ErrAlreadyStarted) { + t.Fatalf("panic = %v, want ErrAlreadyStarted", err) + } +} + +func TestLoopRunnerRecoversLoopPanic(t *testing.T) { + var buf bytes.Buffer + logger := slog.New(slog.NewTextHandler(&buf, nil)) + runner := NewLoopRunner("test", logger).Add("panic", func(context.Context) { + panic("boom") + }) + + if err := runner.Start(context.Background()); err != nil { + t.Fatalf("Start: %v", err) + } + stopCtx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + if err := runner.Stop(stopCtx); err != nil { + t.Fatalf("Stop: %v", err) + } + if !strings.Contains(buf.String(), "background goroutine panicked") { + t.Fatalf("logs = %q, want panic log", buf.String()) + } +} + +func TestLoopRunnerTickLoop(t *testing.T) { + ticked := make(chan struct{}, 1) + runner := NewLoopRunner("test", testLogger()).AddTick("tick", time.Millisecond, func(context.Context) error { + select { + case ticked <- struct{}{}: + default: + } + return nil + }) + + if err := runner.Start(context.Background()); err != nil { + t.Fatalf("Start: %v", err) + } + waitFor(t, ticked) + + stopCtx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + if err := runner.Stop(stopCtx); err != nil { + t.Fatalf("Stop: %v", err) + } +} + +func TestNewLoopRunnerRejectsMissingName(t *testing.T) { + err := recoverPanic(func() { + NewLoopRunner("", testLogger()) + }) + if err == nil || err.Error() != "panic" { + t.Fatalf("panic = %v, want panic for missing name", err) + } +} + +func TestNewLoopRunnerRejectsMissingLogger(t *testing.T) { + err := recoverPanic(func() { + NewLoopRunner("test", nil) + }) + if err == nil || err.Error() != "panic" { + t.Fatalf("panic = %v, want panic for missing logger", err) + } +} + +func testLogger() *slog.Logger { + return slog.New(slog.NewTextHandler(io.Discard, nil)) +} + +func waitFor(t *testing.T, ch <-chan struct{}) { + t.Helper() + select { + case <-ch: + case <-time.After(time.Second): + t.Fatal("timed out waiting for channel") + } +} + +func recoverPanic(fn func()) (err error) { + defer func() { + if rec := recover(); rec != nil { + if e, ok := rec.(error); ok { + err = e + return + } + err = errors.New("panic") + } + }() + fn() + return nil +} diff --git a/logging/kratos/go.mod b/logging/kratos/go.mod deleted file mode 100644 index 4387484..0000000 --- a/logging/kratos/go.mod +++ /dev/null @@ -1,20 +0,0 @@ -module github.com/crypto-zero/go-kit/logging/kratos - -go 1.25.5 - -require ( - github.com/go-kratos/kratos/v2 v2.9.2 - google.golang.org/grpc v1.78.0 - google.golang.org/protobuf v1.36.11 -) - -require ( - github.com/go-kratos/aegis v0.2.0 // indirect - github.com/go-playground/form/v4 v4.2.0 // indirect - github.com/google/uuid v1.6.0 // indirect - github.com/gorilla/mux v1.8.1 // indirect - github.com/kr/text v0.2.0 // indirect - golang.org/x/sys v0.38.0 // indirect - google.golang.org/genproto/googleapis/rpc v0.0.0-20251029180050-ab9386a59fda // indirect - gopkg.in/yaml.v3 v3.0.1 // indirect -) diff --git a/maxmind/maxmind.go b/maxmind/maxmind.go index 3b54191..800eb70 100644 --- a/maxmind/maxmind.go +++ b/maxmind/maxmind.go @@ -7,7 +7,7 @@ import ( "github.com/oschwald/maxminddb-golang" ) -// GeoNames is a struct for multiple languages +// GeoNames is a struct for multiple languages. type GeoNames struct { German string `maxminddb:"de"` English string `maxminddb:"en"` @@ -19,7 +19,7 @@ type GeoNames struct { Chinese string `maxminddb:"zh-CN"` } -// GeoCity is a struct for maxminddb city result +// GeoCity is a struct for maxminddb city result. type GeoCity struct { City struct { Name GeoNames `maxminddb:"names"` @@ -53,17 +53,20 @@ type GeoCity struct { var emptyGeoCity = GeoCity{} -// Database is an interface for maxminddb +// Database reads GeoCity records from a MaxMind database. type Database interface { - // Lookup returns GeoCity for given IP + // Lookup returns GeoCity for given IP. Lookup(ip net.IP) (*GeoCity, error) } -// DatabaseImpl is an implementation of Database +// DatabaseImpl implements Database. type DatabaseImpl struct { db *maxminddb.Reader } +var _ Database = (*DatabaseImpl)(nil) + +// Lookup returns GeoCity for given IP. func (d *DatabaseImpl) Lookup(ip net.IP) (*GeoCity, error) { var record GeoCity if err := d.db.Lookup(ip, &record); err != nil { @@ -75,16 +78,16 @@ func (d *DatabaseImpl) Lookup(ip net.IP) (*GeoCity, error) { return &record, nil } -// Path is a type for maxminddb path +// Path is a type for maxminddb path. type Path string -// ContainerPath returns path to maxminddb container +// ContainerPath returns path to maxminddb container. func ContainerPath() Path { return "/app/bin/GeoLite2-City.mmdb" } -// NewDatabaseImpl returns implementation of Database -func NewDatabaseImpl(path Path) (Database, func(), error) { +// NewDatabase opens a MaxMind database. +func NewDatabase(path Path) (*DatabaseImpl, func(), error) { db, err := maxminddb.Open(string(path)) if err != nil { return nil, nil, err @@ -94,7 +97,18 @@ func NewDatabaseImpl(path Path) (Database, func(), error) { }, nil } -// IsEmptyGeoCity checks if GeoCity is empty +// NewDatabaseImpl returns implementation of Database. +// +// Deprecated: Use NewDatabase when a concrete *DatabaseImpl is acceptable. +func NewDatabaseImpl(path Path) (Database, func(), error) { + database, cleanup, err := NewDatabase(path) + if err != nil { + return nil, nil, err + } + return database, cleanup, nil +} + +// IsEmptyGeoCity checks if GeoCity is empty. func IsEmptyGeoCity(geoCity GeoCity) bool { return reflect.DeepEqual(geoCity, emptyGeoCity) } diff --git a/maxmind/maxmind_test.go b/maxmind/maxmind_test.go index 989edb4..217363d 100644 --- a/maxmind/maxmind_test.go +++ b/maxmind/maxmind_test.go @@ -2,7 +2,9 @@ package maxmind import ( "encoding/json" + "errors" "net" + "os" "testing" "github.com/oschwald/maxminddb-golang" @@ -11,14 +13,21 @@ import ( func TestMaxmindRead(t *testing.T) { reader, err := maxminddb.Open("./GeoLite2-City.mmdb") if err != nil { + if errors.Is(err, os.ErrNotExist) { + t.Skip("GeoLite2-City.mmdb is not available") + } t.Fatal(err) } + defer reader.Close() var record GeoCity internalIP := net.ParseIP("81.2.69.142") if err = reader.Lookup(internalIP, &record); err != nil { t.Fatal(err) } - b, _ := json.Marshal(record) + b, err := json.Marshal(record) + if err != nil { + t.Fatal(err) + } t.Log(string(b), IsEmptyGeoCity(record)) } diff --git a/otel/resouce.go b/otel/resouce.go index fcd046b..dd9f946 100644 --- a/otel/resouce.go +++ b/otel/resouce.go @@ -3,22 +3,22 @@ package otel import "go.opentelemetry.io/otel/attribute" const ( - // SigNozSystemDBKey span type database call + // SigNozSystemDBKey is the span attribute key for database calls. // https://signoz.io/docs/userguide/metrics/ SigNozSystemDBKey = attribute.Key("db.system") ) -// SigNozSystemDB return db system attribute +// SigNozSystemDB returns a database system attribute. func SigNozSystemDB(system string) attribute.KeyValue { return SigNozSystemDBKey.String(system) } -// SigNozSystemDBPostgres return db system attribute for postgres +// SigNozSystemDBPostgres returns a database system attribute for PostgreSQL. func SigNozSystemDBPostgres() attribute.KeyValue { return SigNozSystemDB("postgresql") } -// SigNozSystemDBNats return db system attribute for nats +// SigNozSystemDBNats returns a database system attribute for NATS. func SigNozSystemDBNats() attribute.KeyValue { return SigNozSystemDB("nats") } diff --git a/otel/trace_provider.go b/otel/trace_provider.go index a8fa8e7..4f00c23 100644 --- a/otel/trace_provider.go +++ b/otel/trace_provider.go @@ -5,6 +5,7 @@ import ( "fmt" "os" "strings" + "time" "go.opentelemetry.io/otel" "go.opentelemetry.io/otel/attribute" @@ -17,7 +18,15 @@ import ( "github.com/crypto-zero/go-kit/kubernetes" ) -// TraceProviderConfig is an open telemetry trace provider config. +const traceShutdownTimeout = 5 * time.Second + +// TraceProvider is an open telemetry trace service. +// +// Deprecated: This broad compatibility type is an alias-shaped service token. +// Consumers should depend on the behavior they need instead of this type. +type TraceProvider any + +// TraceProviderConfig configures an OpenTelemetry trace provider. type TraceProviderConfig struct { Context context.Context Name string @@ -28,7 +37,7 @@ type TraceProviderConfig struct { SampleFraction float64 } -// FromEnv load config from env. +// FromEnv loads trace provider config from environment variables. func (c *TraceProviderConfig) FromEnv() { value := os.Getenv("OTEL_EXPORTER_OTLP_ENDPOINT") // The env var may contain a scheme, which we need to remove. @@ -39,12 +48,12 @@ func (c *TraceProviderConfig) FromEnv() { } } -// TraceProvider is an open telemetry trace service. -type TraceProvider any - +// TraceProviderImpl is an OpenTelemetry trace service. type TraceProviderImpl struct{} -// NewTraceProvider new an open telemetry trace provider. +// NewTraceProvider creates an OpenTelemetry trace provider. +// +// It returns TraceProvider for backward compatibility with earlier releases. func NewTraceProvider(c *TraceProviderConfig) ( TraceProvider, func(), error, ) { @@ -57,7 +66,11 @@ func NewTraceProvider(c *TraceProviderConfig) ( exportGrpcOptions = append(exportGrpcOptions, otlptracegrpc.WithInsecure()) } exportGrpcOptions = append(exportGrpcOptions, otlptracegrpc.WithEndpoint(c.Endpoint)) - exporter, err := otlptrace.New(c.Context, otlptracegrpc.NewClient(exportGrpcOptions...)) + ctx := c.Context + if ctx == nil { + ctx = context.Background() + } + exporter, err := otlptrace.New(ctx, otlptracegrpc.NewClient(exportGrpcOptions...)) if err != nil { return nil, nil, fmt.Errorf("failed to create the collector exporter: %w", err) } @@ -71,7 +84,7 @@ func NewTraceProvider(c *TraceProviderConfig) ( semconv.K8SNamespaceName(kubernetes.GetCurrentNamespace()), } if resourceInEnv := os.Getenv("OTEL_RESOURCE_ATTRIBUTES"); resourceInEnv != "" { - for _, attr := range strings.Split(resourceInEnv, ",") { + for attr := range strings.SplitSeq(resourceInEnv, ",") { parts := strings.Split(attr, "=") if len(parts) == 2 { attrs = append(attrs, attribute.String(parts[0], parts[1])) @@ -79,12 +92,15 @@ func NewTraceProvider(c *TraceProviderConfig) ( } } - otel.SetTracerProvider( - sdktrace.NewTracerProvider( - sdktrace.WithSampler(sdktrace.TraceIDRatioBased(c.SampleFraction)), - sdktrace.WithBatcher(exporter), - sdktrace.WithResource(resource.NewSchemaless(attrs...)), - ), + provider := sdktrace.NewTracerProvider( + sdktrace.WithSampler(sdktrace.TraceIDRatioBased(c.SampleFraction)), + sdktrace.WithBatcher(exporter), + sdktrace.WithResource(resource.NewSchemaless(attrs...)), ) - return &TraceProviderImpl{}, func() {}, nil + otel.SetTracerProvider(provider) + return &TraceProviderImpl{}, func() { + ctx, cancel := context.WithTimeout(context.Background(), traceShutdownTimeout) + defer cancel() + _ = provider.Shutdown(ctx) + }, nil } diff --git a/pgx/pgx.go b/pgx/pgx.go index c44d8b4..d951e67 100644 --- a/pgx/pgx.go +++ b/pgx/pgx.go @@ -10,7 +10,8 @@ import ( "github.com/jackc/pgx/v5/pgtype" ) -// ent not support for generic types, so we need to declare wrapper types for each type +// Ent does not support generic field types, so concrete wrapper types are +// declared for each supported database value. type ( // CIDRWrapper is a wrapper for pgx standard sql library types. CIDRWrapper StdWrapper[netip.Prefix] @@ -32,23 +33,15 @@ type ( ) // Value implements the database/sql/driver Valuer interface. -// -//goland:noinspection GoMixedReceiverTypes func (w CIDRWrapper) Value() (driver.Value, error) { return StdWrapper[netip.Prefix](w).Value() } // Scan implements the database/sql Scanner interface. -// -//goland:noinspection GoMixedReceiverTypes func (w *CIDRWrapper) Scan(src any) error { return (*StdWrapper[netip.Prefix])(w).Scan(src) } // Value implements the database/sql/driver Valuer interface. -// -//goland:noinspection GoMixedReceiverTypes func (w DurationWrapper) Value() (driver.Value, error) { return StdWrapper[time.Duration](w).Value() } // Scan implements the database/sql Scanner interface. -// -//goland:noinspection GoMixedReceiverTypes func (w *DurationWrapper) Scan(src any) error { return (*StdWrapper[time.Duration])(w).Scan(src) } @@ -57,52 +50,36 @@ func (w *DurationWrapper) Scan(src any) error { func NewIntsWrapper() IntsWrapper { return IntsWrapper{V: make([]int, 0)} } // Value implements the database/sql/driver Valuer interface. -// -//goland:noinspection GoMixedReceiverTypes func (w IntsWrapper) Value() (driver.Value, error) { return SliceWrapper[int](w).Value() } // Scan implements the database/sql Scanner interface. -// -//goland:noinspection GoMixedReceiverTypes func (w *IntsWrapper) Scan(src any) error { return (*SliceWrapper[int])(w).Scan(src) } // NewFloatsWrapper returns a new FloatsWrapper. func NewFloatsWrapper() FloatsWrapper { return FloatsWrapper{V: make([]float64, 0)} } // Value implements the database/sql/driver Valuer interface. -// -//goland:noinspection GoMixedReceiverTypes func (w FloatsWrapper) Value() (driver.Value, error) { return SliceWrapper[float64](w).Value() } // Scan implements the database/sql Scanner interface. -// -//goland:noinspection GoMixedReceiverTypes func (w *FloatsWrapper) Scan(src any) error { return (*SliceWrapper[float64])(w).Scan(src) } // NewStringsWrapper returns a new StringsWrapper. func NewStringsWrapper() StringsWrapper { return StringsWrapper{V: make([]string, 0)} } // Value implements the database/sql/driver Valuer interface. -// -//goland:noinspection GoMixedReceiverTypes func (w StringsWrapper) Value() (driver.Value, error) { return SliceWrapper[string](w).Value() } // Scan implements the database/sql Scanner interface. -// -//goland:noinspection GoMixedReceiverTypes func (w *StringsWrapper) Scan(src any) error { return (*SliceWrapper[string])(w).Scan(src) } // NewCIDRsWrapper returns a new CIDRsWrapper. func NewCIDRsWrapper() CIDRsWrapper { return CIDRsWrapper{V: make([]netip.Prefix, 0)} } // Value implements the database/sql/driver Valuer interface. -// -//goland:noinspection GoMixedReceiverTypes func (w CIDRsWrapper) Value() (driver.Value, error) { return SliceWrapper[netip.Prefix](w).Value() } // Scan implements the database/sql Scanner interface. -// -//goland:noinspection GoMixedReceiverTypes func (w *CIDRsWrapper) Scan(src any) error { return (*SliceWrapper[netip.Prefix])(w).Scan(src) } // NewDurationsWrapper returns a new DurationsWrapper. @@ -111,15 +88,11 @@ func NewDurationsWrapper() DurationsWrapper { } // Value implements the database/sql/driver Valuer interface. -// -//goland:noinspection GoMixedReceiverTypes func (w DurationsWrapper) Value() (driver.Value, error) { return SliceWrapper[time.Duration](w).Value() } // Scan implements the database/sql Scanner interface. -// -//goland:noinspection GoMixedReceiverTypes func (w *DurationsWrapper) Scan(src any) error { return (*SliceWrapper[time.Duration])(w).Scan(src) } @@ -130,15 +103,11 @@ func NewTimestampsWrapper() TimestampsWrapper { } // Value implements the database/sql/driver Valuer interface. -// -//goland:noinspection GoMixedReceiverTypes func (w TimestampsWrapper) Value() (driver.Value, error) { return SliceWrapper[time.Time](w).Value() } // Scan implements the database/sql Scanner interface. -// -//goland:noinspection GoMixedReceiverTypes func (w *TimestampsWrapper) Scan(src any) error { return (*SliceWrapper[time.Time])(w).Scan(src) } @@ -156,15 +125,11 @@ type StdWrapper[T any] struct { } // Value implements the database/sql/driver Valuer interface. -// -//goland:noinspection GoMixedReceiverTypes func (w StdWrapper[T]) Value() (driver.Value, error) { return w.V, nil } // Scan implements the database/sql Scanner interface. -// -//goland:noinspection GoMixedReceiverTypes func (w *StdWrapper[T]) Scan(src any) (err error) { return typeMapScan(src, &w.V) } @@ -175,10 +140,8 @@ type SliceWrapper[T any] struct { } // Value implements the database/sql/driver Valuer interface. -// -//goland:noinspection GoMixedReceiverTypes func (w SliceWrapper[T]) Value() (driver.Value, error) { - // pgx treats nil slice as NULL, but we want to treat it as an empty array + // pgx treats nil slices as NULL, but callers expect an empty array. if w.V == nil { return make([]T, 0), nil } @@ -186,8 +149,6 @@ func (w SliceWrapper[T]) Value() (driver.Value, error) { } // Scan implements the database/sql Scanner interface. -// -//goland:noinspection GoMixedReceiverTypes func (w *SliceWrapper[T]) Scan(src any) (err error) { return typeMapScan(src, &w.V) } @@ -220,7 +181,7 @@ func guessingScan[T any](src any) (value T, err error) { case []byte: bufSrc = src default: - bufSrc = []byte(fmt.Sprint(bufSrc)) + bufSrc = fmt.Append(nil, src) } } diff --git a/pgx/pgx_test.go b/pgx/pgx_test.go index c83ebb9..3fe923f 100644 --- a/pgx/pgx_test.go +++ b/pgx/pgx_test.go @@ -16,18 +16,21 @@ import ( ) var db *sql.DB +var skipPGXTests string func TestMain(m *testing.M) { // uses a sensible default on windows (tcp/http) and linux/osx (socket) pool, err := dockertest.NewPool("") if err != nil { - log.Fatalf("Could not construct pool: %s", err) + skipPGXTests = fmt.Sprintf("could not construct Docker pool: %s", err) + os.Exit(m.Run()) } // uses pool to try to connect to Docker err = pool.Client.Ping() if err != nil { - log.Fatalf("Could not connect to Docker: %s", err) + skipPGXTests = fmt.Sprintf("could not connect to Docker: %s", err) + os.Exit(m.Run()) } // pulls an image, creates a container based on it and runs it @@ -47,7 +50,8 @@ func TestMain(m *testing.M) { } }) if err != nil { - log.Fatalf("Could not start resource: %s", err) + skipPGXTests = fmt.Sprintf("could not start PostgreSQL container: %s", err) + os.Exit(m.Run()) } // exponential backoff-retry, because the application in the container might not be ready to accept connections yet @@ -59,12 +63,16 @@ func TestMain(m *testing.M) { } return db.Ping() }); err != nil { - log.Fatalf("Could not connect to database: %s", err) + skipPGXTests = fmt.Sprintf("could not connect to database: %s", err) + os.Exit(m.Run()) } code := m.Run() // You can't defer this because os.Exit doesn't care for defer + if err := db.Close(); err != nil { + log.Fatalf("Could not close database: %s", err) + } if err := pool.Purge(resource); err != nil { log.Fatalf("Could not purge resource: %s", err) } @@ -72,7 +80,31 @@ func TestMain(m *testing.M) { os.Exit(code) } +func requirePGXDB(t *testing.T) *sql.DB { + t.Helper() + if skipPGXTests != "" { + t.Skip(skipPGXTests) + } + return db +} + +func TestGuessingScanUsesFallbackSource(t *testing.T) { + type textJSON string + type person struct { + Name string `json:"name"` + } + + got, err := guessingScan[person](textJSON(`{"name":"John"}`)) + if err != nil { + t.Fatal(err) + } + if got.Name != "John" { + t.Fatalf("guessingScan() = %+v; want name John", got) + } +} + func TestPGXIntArray(t *testing.T) { + db := requirePGXDB(t) input := []int{1, 2, 3} var output StdWrapper[[]int] if err := db.QueryRow("select $1::int[]", input).Scan(&output); err != nil { @@ -84,6 +116,7 @@ func TestPGXIntArray(t *testing.T) { } func TestPGXJSON(t *testing.T) { + db := requirePGXDB(t) type person struct { Name string `json:"name"` Age int `json:"age"` @@ -99,6 +132,7 @@ func TestPGXJSON(t *testing.T) { } func TestPGXJSONArray(t *testing.T) { + db := requirePGXDB(t) type person struct { Name string `json:"name"` Age int `json:"age"` @@ -122,6 +156,7 @@ func TestPGXJSONArray(t *testing.T) { } func TestPGXNetPrefix(t *testing.T) { + db := requirePGXDB(t) input := netip.MustParsePrefix("255.255.255.255/32") var output StdWrapper[netip.Prefix] if err := db.QueryRow("select $1::cidr", input).Scan(&output); err != nil { @@ -133,6 +168,7 @@ func TestPGXNetPrefix(t *testing.T) { } func TestPGXNetPrefixArray(t *testing.T) { + db := requirePGXDB(t) input := []netip.Prefix{ netip.MustParsePrefix("127.0.0.1/32"), netip.MustParsePrefix("10.0.0.0/8"), diff --git a/pprof/pprof.go b/pprof/pprof.go index 2d18a5d..41d7954 100644 --- a/pprof/pprof.go +++ b/pprof/pprof.go @@ -1,36 +1,62 @@ package pprof import ( + "errors" "fmt" - "log" + "log/slog" "net" "net/http" _ "net/http/pprof" + "sync" "github.com/google/gops/agent" ) // Pprof is a pprof service. +// +// Deprecated: This broad compatibility type is an alias-shaped service token. +// Consumers should depend on the behavior they need instead of this type. type Pprof any // PprofImpl is a pprof service implementation. -type PprofImpl struct{} +type PprofImpl struct { + listener net.Listener + once sync.Once +} // NewPProfImpl returns a new PprofImpl. -// it provides gops agent and pprof service. +// It provides gops agent and pprof service. +// +// It returns Pprof for backward compatibility with earlier releases. func NewPProfImpl() (Pprof, func(), error) { ln, err := net.Listen("tcp", "localhost:0") if err != nil { - return nil, func() {}, fmt.Errorf("start pprof failed: %v", err) + return nil, func() {}, fmt.Errorf("start pprof failed: %w", err) } - log.Println("start pprof service on:", ln.Addr()) + service := &PprofImpl{listener: ln} + cleanup := service.Close + + slog.Info("start pprof service", "addr", ln.Addr().String()) go func() { - _ = http.Serve(ln, nil) + if err := http.Serve(ln, nil); err != nil && !errors.Is(err, net.ErrClosed) { + slog.Error("pprof service stopped", "err", err) + } }() if err := agent.Listen(agent.Options{ShutdownCleanup: false}); err != nil { - return nil, func() {}, fmt.Errorf("start gops agent failed: %v", err) + cleanup() + return nil, func() {}, fmt.Errorf("start gops agent failed: %w", err) } - return &PprofImpl{}, func() {}, nil + return service, cleanup, nil +} + +// Close stops the pprof service and gops agent. +func (p *PprofImpl) Close() { + p.once.Do(func() { + agent.Close() + if p.listener != nil { + _ = p.listener.Close() + } + }) } diff --git a/proto/buf.gen.yaml b/proto/buf.gen.yaml index fa5facd..e9a6ce4 100644 --- a/proto/buf.gen.yaml +++ b/proto/buf.gen.yaml @@ -6,11 +6,13 @@ plugins: - local: - go - run - - ../errors/protoc-gen-kit-errors/main.go + - ../cmd/protoc-gen-kit-errors/main.go out: . + exclude_types: + - kit.errors.v1 - local: - go - run - - ../protoc-gen-go-redact/main.go + - ../cmd/protoc-gen-go-redact/main.go out: . - opt: paths=source_relative \ No newline at end of file + opt: paths=source_relative diff --git a/proto/buf.lock b/proto/buf.lock deleted file mode 100644 index ace68b7..0000000 --- a/proto/buf.lock +++ /dev/null @@ -1,6 +0,0 @@ -# Generated by buf. DO NOT EDIT. -version: v2 -deps: - - name: buf.build/googleapis/googleapis - commit: b30c5775bfb3485d9da2e87b26590ac9 - digest: b5:13f091a467b31c7f734307e6760d864e3319b9c47656f2ada6efa45c643864d9c9e7d5cd372c92cc8e0972deb63f41bc8fc88a5ca21ab2e9ea04d2144752857d diff --git a/proto/kit/auth/v1/auth.pb.go b/proto/kit/auth/v1/auth.pb.go new file mode 100644 index 0000000..5c7797f --- /dev/null +++ b/proto/kit/auth/v1/auth.pb.go @@ -0,0 +1,84 @@ +// Code generated by protoc-gen-go. DO NOT EDIT. +// versions: +// protoc-gen-go v1.36.11 +// protoc (unknown) +// source: kit/auth/v1/auth.proto + +package authv1 + +import ( + protoreflect "google.golang.org/protobuf/reflect/protoreflect" + protoimpl "google.golang.org/protobuf/runtime/protoimpl" + descriptorpb "google.golang.org/protobuf/types/descriptorpb" + reflect "reflect" + unsafe "unsafe" +) + +const ( + // Verify that this generated code is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) + // Verify that runtime/protoimpl is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) +) + +var file_kit_auth_v1_auth_proto_extTypes = []protoimpl.ExtensionInfo{ + { + ExtendedType: (*descriptorpb.MethodOptions)(nil), + ExtensionType: (*bool)(nil), + Field: 105001, + Name: "kit.auth.v1.public", + Tag: "varint,105001,opt,name=public", + Filename: "kit/auth/v1/auth.proto", + }, +} + +// Extension fields to descriptorpb.MethodOptions. +var ( + // public marks an RPC operation as not requiring authenticated user context. + // + // optional bool public = 105001; + E_Public = &file_kit_auth_v1_auth_proto_extTypes[0] +) + +var File_kit_auth_v1_auth_proto protoreflect.FileDescriptor + +const file_kit_auth_v1_auth_proto_rawDesc = "" + + "\n" + + "\x16kit/auth/v1/auth.proto\x12\vkit.auth.v1\x1a google/protobuf/descriptor.proto:8\n" + + "\x06public\x12\x1e.google.protobuf.MethodOptions\x18\xa9\xb4\x06 \x01(\bR\x06publicB8Z6github.com/crypto-zero/go-kit/proto/kit/auth/v1;authv1b\x06proto3" + +var file_kit_auth_v1_auth_proto_goTypes = []any{ + (*descriptorpb.MethodOptions)(nil), // 0: google.protobuf.MethodOptions +} +var file_kit_auth_v1_auth_proto_depIdxs = []int32{ + 0, // 0: kit.auth.v1.public:extendee -> google.protobuf.MethodOptions + 1, // [1:1] is the sub-list for method output_type + 1, // [1:1] is the sub-list for method input_type + 1, // [1:1] is the sub-list for extension type_name + 0, // [0:1] is the sub-list for extension extendee + 0, // [0:0] is the sub-list for field type_name +} + +func init() { file_kit_auth_v1_auth_proto_init() } +func file_kit_auth_v1_auth_proto_init() { + if File_kit_auth_v1_auth_proto != nil { + return + } + type x struct{} + out := protoimpl.TypeBuilder{ + File: protoimpl.DescBuilder{ + GoPackagePath: reflect.TypeOf(x{}).PkgPath(), + RawDescriptor: unsafe.Slice(unsafe.StringData(file_kit_auth_v1_auth_proto_rawDesc), len(file_kit_auth_v1_auth_proto_rawDesc)), + NumEnums: 0, + NumMessages: 0, + NumExtensions: 1, + NumServices: 0, + }, + GoTypes: file_kit_auth_v1_auth_proto_goTypes, + DependencyIndexes: file_kit_auth_v1_auth_proto_depIdxs, + ExtensionInfos: file_kit_auth_v1_auth_proto_extTypes, + }.Build() + File_kit_auth_v1_auth_proto = out.File + file_kit_auth_v1_auth_proto_goTypes = nil + file_kit_auth_v1_auth_proto_depIdxs = nil +} diff --git a/proto/kit/auth/v1/auth.proto b/proto/kit/auth/v1/auth.proto new file mode 100644 index 0000000..e296050 --- /dev/null +++ b/proto/kit/auth/v1/auth.proto @@ -0,0 +1,12 @@ +syntax = "proto3"; + +package kit.auth.v1; + +import "google/protobuf/descriptor.proto"; + +option go_package = "github.com/crypto-zero/go-kit/proto/kit/auth/v1;authv1"; + +extend google.protobuf.MethodOptions { + // public marks an RPC operation as not requiring authenticated user context. + bool public = 105001; +} diff --git a/proto/kit/redact/v1/redact.pb.go b/proto/kit/redact/v1/redact.pb.go index 02fcaef..3ec5c95 100644 --- a/proto/kit/redact/v1/redact.pb.go +++ b/proto/kit/redact/v1/redact.pb.go @@ -1,7 +1,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: -// protoc-gen-go v1.36.9 -// protoc v5.29.3 +// protoc-gen-go v1.36.11 +// protoc (unknown) // source: kit/redact/v1/redact.proto package redact diff --git a/proto/kit/sse/v1/sse.pb.go b/proto/kit/sse/v1/sse.pb.go new file mode 100644 index 0000000..e68c870 --- /dev/null +++ b/proto/kit/sse/v1/sse.pb.go @@ -0,0 +1,83 @@ +// Code generated by protoc-gen-go. DO NOT EDIT. +// versions: +// protoc-gen-go v1.36.11 +// protoc v7.34.1 +// source: kit/sse/v1/sse.proto + +package ssev1 + +import ( + protoreflect "google.golang.org/protobuf/reflect/protoreflect" + protoimpl "google.golang.org/protobuf/runtime/protoimpl" + descriptorpb "google.golang.org/protobuf/types/descriptorpb" + reflect "reflect" + unsafe "unsafe" +) + +const ( + // Verify that this generated code is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) + // Verify that runtime/protoimpl is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) +) + +var file_kit_sse_v1_sse_proto_extTypes = []protoimpl.ExtensionInfo{ + { + ExtendedType: (*descriptorpb.MethodOptions)(nil), + ExtensionType: (*bool)(nil), + Field: 81001, + Name: "kit.sse.v1.server_sent_event", + Tag: "varint,81001,opt,name=server_sent_event", + Filename: "kit/sse/v1/sse.proto", + }, +} + +// Extension fields to descriptorpb.MethodOptions. +var ( + // optional bool server_sent_event = 81001; + E_ServerSentEvent = &file_kit_sse_v1_sse_proto_extTypes[0] +) + +var File_kit_sse_v1_sse_proto protoreflect.FileDescriptor + +const file_kit_sse_v1_sse_proto_rawDesc = "" + + "\n" + + "\x14kit/sse/v1/sse.proto\x12\n" + + "kit.sse.v1\x1a google/protobuf/descriptor.proto:L\n" + + "\x11server_sent_event\x12\x1e.google.protobuf.MethodOptions\x18\xe9\xf8\x04 \x01(\bR\x0fserverSentEventB6Z4github.com/crypto-zero/go-kit/proto/kit/sse/v1;ssev1b\x06proto3" + +var file_kit_sse_v1_sse_proto_goTypes = []any{ + (*descriptorpb.MethodOptions)(nil), // 0: google.protobuf.MethodOptions +} +var file_kit_sse_v1_sse_proto_depIdxs = []int32{ + 0, // 0: kit.sse.v1.server_sent_event:extendee -> google.protobuf.MethodOptions + 1, // [1:1] is the sub-list for method output_type + 1, // [1:1] is the sub-list for method input_type + 1, // [1:1] is the sub-list for extension type_name + 0, // [0:1] is the sub-list for extension extendee + 0, // [0:0] is the sub-list for field type_name +} + +func init() { file_kit_sse_v1_sse_proto_init() } +func file_kit_sse_v1_sse_proto_init() { + if File_kit_sse_v1_sse_proto != nil { + return + } + type x struct{} + out := protoimpl.TypeBuilder{ + File: protoimpl.DescBuilder{ + GoPackagePath: reflect.TypeOf(x{}).PkgPath(), + RawDescriptor: unsafe.Slice(unsafe.StringData(file_kit_sse_v1_sse_proto_rawDesc), len(file_kit_sse_v1_sse_proto_rawDesc)), + NumEnums: 0, + NumMessages: 0, + NumExtensions: 1, + NumServices: 0, + }, + GoTypes: file_kit_sse_v1_sse_proto_goTypes, + DependencyIndexes: file_kit_sse_v1_sse_proto_depIdxs, + ExtensionInfos: file_kit_sse_v1_sse_proto_extTypes, + }.Build() + File_kit_sse_v1_sse_proto = out.File + file_kit_sse_v1_sse_proto_goTypes = nil + file_kit_sse_v1_sse_proto_depIdxs = nil +} diff --git a/proto/kit/sse/v1/sse.proto b/proto/kit/sse/v1/sse.proto new file mode 100644 index 0000000..17a5ab6 --- /dev/null +++ b/proto/kit/sse/v1/sse.proto @@ -0,0 +1,11 @@ +syntax = "proto3"; + +package kit.sse.v1; + +import "google/protobuf/descriptor.proto"; + +option go_package = "github.com/crypto-zero/go-kit/proto/kit/sse/v1;ssev1"; + +extend google.protobuf.MethodOptions { + bool server_sent_event = 81001; +} diff --git a/query/convert.go b/query/convert.go index f3ad4b1..577e76d 100644 --- a/query/convert.go +++ b/query/convert.go @@ -1,6 +1,6 @@ package query -// ConvertList convert A list to B list +// ConvertList converts a slice of A values into B values. func ConvertList[A, B any](from []A, convert func(A) B) []B { results := make([]B, 0, len(from)) for _, v := range from { diff --git a/query/paging.go b/query/paging.go index 0d8193e..1beed02 100644 --- a/query/paging.go +++ b/query/paging.go @@ -1,13 +1,13 @@ package query var ( - // DefaultPageSize 默认每页条数 + // DefaultPageSize is the default number of items per page. DefaultPageSize int32 = 10 - // MaxPageSize 最大每页条数 + // MaxPageSize is the maximum number of items per page. MaxPageSize int32 = 1000 ) -// ResizePage 修正分页参数, 页码从0开始. +// ResizePage normalizes page parameters. Page numbers start at zero. func ResizePage(page, pageSize int32) (int32, int32) { if page < 0 { page = 0 diff --git a/s3/event.go b/s3/event.go index 09c0c6f..aa66022 100644 --- a/s3/event.go +++ b/s3/event.go @@ -6,7 +6,7 @@ import ( "time" ) -// EventName that wraps the event name +// EventName wraps an S3 event name. type EventName string const ( @@ -39,14 +39,14 @@ const ( EventS3ObjectTaggingDelete EventName = "s3:ObjectTagging:Delete" ) -// Event that wraps an array of EventRecord +// Event wraps a batch of S3 event records. type Event struct { EventName EventName `json:"EventName"` Key string `json:"Key"` Records []EventRecord `json:"Records"` } -// EventRecord which wrap record data +// EventRecord wraps S3 event record data. type EventRecord struct { EventVersion string `json:"eventVersion"` EventSource string `json:"eventSource"` @@ -60,19 +60,19 @@ type EventRecord struct { Source Source `json:"source"` } -// UserIdentity that wraps the principal ID +// UserIdentity wraps the principal ID. type UserIdentity struct { PrincipalID string `json:"principalId"` } -// RequestParameters that wraps the principal ID, region, and source IP address +// RequestParameters wraps the principal ID, region, and source IP address. type RequestParameters struct { PrincipalID string `json:"principalId"` Region string `json:"region"` SourceIPAddress string `json:"sourceIPAddress"` } -// Entity that wraps the bucket and object +// Entity wraps the bucket and object. type Entity struct { SchemaVersion string `json:"s3SchemaVersion"` ConfigurationID string `json:"configurationId"` @@ -80,7 +80,7 @@ type Entity struct { Object Object `json:"object"` } -// Bucket that wraps the bucket name, owner identity, and ARN +// Bucket wraps the bucket name, owner identity, and ARN. type Bucket struct { Name string `json:"name"` OwnerIdentity UserIdentity `json:"ownerIdentity"` @@ -88,7 +88,7 @@ type Bucket struct { } // Object that wraps the object key, size, ETag, content type, user metadata, -// version ID, sequencer, and URL-decoded key +// version ID, sequencer, and URL-decoded key. type Object struct { Key string `json:"key"` Size int64 `json:"size,omitempty"` @@ -113,7 +113,7 @@ func (o *Object) UnmarshalJSON(data []byte) error { return nil } -// Source that wraps the source IP address and user agent +// Source wraps the source IP address and user agent. type Source struct { Host string `json:"host"` Port string `json:"port"` diff --git a/s3/go.mod b/s3/go.mod index 7a332a8..9cdf048 100644 --- a/s3/go.mod +++ b/s3/go.mod @@ -1,6 +1,6 @@ module github.com/crypto-zero/go-kit/s3 -go 1.25.5 +go 1.26.3 require ( github.com/minio/minio-go/v7 v7.0.97 diff --git a/s3/go.sum b/s3/go.sum index d78ded4..5af9eb9 100644 --- a/s3/go.sum +++ b/s3/go.sum @@ -8,19 +8,13 @@ github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo= github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ= -github.com/klauspost/compress v1.18.2 h1:iiPHWW0YrcFgpBYhsA6D1+fqHssJscY/Tm/y2Uqnapk= -github.com/klauspost/compress v1.18.2/go.mod h1:R0h/fSBs8DE4ENlcrlib3PsXS61voFxhIs2DeRhCvJ4= github.com/klauspost/cpuid/v2 v2.0.1/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg= github.com/klauspost/cpuid/v2 v2.2.11 h1:0OwqZRYI2rFrjS4kvkDnqJkKHdHaRnCm68/DY4OxRzU= github.com/klauspost/cpuid/v2 v2.2.11/go.mod h1:hqwkgyIinND0mEev00jJYCxPNVRVXFQeu1XKlok6oO0= -github.com/klauspost/cpuid/v2 v2.3.0 h1:S4CRMLnYUhGeDFDqkGriYKdfoFlDnMtqTiI/sFzhA9Y= -github.com/klauspost/cpuid/v2 v2.3.0/go.mod h1:hqwkgyIinND0mEev00jJYCxPNVRVXFQeu1XKlok6oO0= github.com/klauspost/crc32 v1.3.0 h1:sSmTt3gUt81RP655XGZPElI0PelVTZ6YwCRnPSupoFM= github.com/klauspost/crc32 v1.3.0/go.mod h1:D7kQaZhnkX/Y0tstFGf8VUzv2UofNGqCjnC3zdHB0Hw= github.com/minio/crc64nvme v1.1.0 h1:e/tAguZ+4cw32D+IO/8GSf5UVr9y+3eJcxZI2WOO/7Q= github.com/minio/crc64nvme v1.1.0/go.mod h1:eVfm2fAzLlxMdUGc0EEBGSMmPwmXD5XiNRpnu9J3bvg= -github.com/minio/crc64nvme v1.1.1 h1:8dwx/Pz49suywbO+auHCBpCtlW1OfpcLN7wYgVR6wAI= -github.com/minio/crc64nvme v1.1.1/go.mod h1:eVfm2fAzLlxMdUGc0EEBGSMmPwmXD5XiNRpnu9J3bvg= github.com/minio/md5-simd v1.1.2 h1:Gdi1DZK69+ZVMoNHRXJyNcxrMA4dSxoYHZSQbirFg34= github.com/minio/md5-simd v1.1.2/go.mod h1:MzdKDxYpY2BT9XQFocsiZf/NKVtR7nkE4RoEpN+20RM= github.com/minio/minio-go/v7 v7.0.97 h1:lqhREPyfgHTB/ciX8k2r8k0D93WaFqxbJX36UZq5occ= @@ -35,30 +29,14 @@ github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= github.com/tinylib/msgp v1.3.0 h1:ULuf7GPooDaIlbyvgAxBV/FI7ynli6LZ1/nVUNu+0ww= github.com/tinylib/msgp v1.3.0/go.mod h1:ykjzy2wzgrlvpDCRc4LA8UXy6D8bzMSuAF3WD57Gok0= -github.com/tinylib/msgp v1.6.1 h1:ESRv8eL3u+DNHUoSAAQRE50Hm162zqAnBoGv9PzScPY= -github.com/tinylib/msgp v1.6.1/go.mod h1:RSp0LW9oSxFut3KzESt5Voq4GVWyS+PSulT77roAqEA= golang.org/x/crypto v0.36.0 h1:AnAEvhDddvBdpY+uR+MyHmuZzzNqXSe/GvuDeob5L34= golang.org/x/crypto v0.36.0/go.mod h1:Y4J0ReaxCR1IMaabaSMugxJES1EpwhBHhv2bDHklZvc= -golang.org/x/crypto v0.46.0 h1:cKRW/pmt1pKAfetfu+RCEvjvZkA9RimPbh7bhFjGVBU= -golang.org/x/crypto v0.46.0/go.mod h1:Evb/oLKmMraqjZ2iQTwDwvCtJkczlDuTmdJXoZVzqU0= -golang.org/x/mod v0.30.0 h1:fDEXFVZ/fmCKProc/yAXXUijritrDzahmwwefnjoPFk= -golang.org/x/mod v0.30.0/go.mod h1:lAsf5O2EvJeSFMiBxXDki7sCgAxEUcZHXoXMKT4GJKc= golang.org/x/net v0.38.0 h1:vRMAPTMaeGqVhG5QyLJHqNDwecKTomGeqbnfZyKlBI8= golang.org/x/net v0.38.0/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8= -golang.org/x/net v0.48.0 h1:zyQRTTrjc33Lhh0fBgT/H3oZq9WuvRR5gPC70xpDiQU= -golang.org/x/net v0.48.0/go.mod h1:+ndRgGjkh8FGtu1w1FGbEC31if4VrNVMuKTgcAAnQRY= -golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4= -golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= golang.org/x/sys v0.34.0 h1:H5Y5sJ2L2JRdyv7ROF1he/lPdvFsd0mJHFw2ThKHxLA= golang.org/x/sys v0.34.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= -golang.org/x/sys v0.39.0 h1:CvCKL8MeisomCi6qNZ+wbb0DN9E5AATixKsvNtMoMFk= -golang.org/x/sys v0.39.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= golang.org/x/text v0.26.0 h1:P42AVeLghgTYr4+xUnTRKDMqpar+PtX7KWuNQL21L8M= golang.org/x/text v0.26.0/go.mod h1:QK15LZJUUQVJxhz7wXgxSy/CJaTFjd0G+YLonydOVQA= -golang.org/x/text v0.32.0 h1:ZD01bjUt1FQ9WJ0ClOL5vxgxOI/sVCNgX1YtKwcY0mU= -golang.org/x/text v0.32.0/go.mod h1:o/rUWzghvpD5TXrTIBuJU77MTaN0ljMWE47kxGJQ7jY= -golang.org/x/tools v0.39.0 h1:ik4ho21kwuQln40uelmciQPp9SipgNDdrafrYA4TmQQ= -golang.org/x/tools v0.39.0/go.mod h1:JnefbkDPyD8UU2kI5fuf8ZX4/yUeh9W877ZeBONxUqQ= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= diff --git a/s3/s3.go b/s3/s3.go index e6f130a..2b3eccc 100644 --- a/s3/s3.go +++ b/s3/s3.go @@ -16,78 +16,116 @@ import ( "github.com/minio/minio-go/v7/pkg/credentials" ) -// S3 provides operations on s3 bucket +const ( + serviceAccountCAPath = "/var/run/secrets/kubernetes.io/serviceaccount/ca.crt" + serviceAccountTokenPath = "/var/run/secrets/kubernetes.io/serviceaccount/token" +) + +// S3 provides operations on an S3-compatible bucket. +// +// Deprecated: Consumers should prefer defining a narrow interface in the +// package that consumes S3 behavior. This broad interface remains for +// compatibility with existing go-kit users. type S3 interface { - // PresignGetURL returns a presigned url for get object operation - PresignGetURL(ctx context.Context, bucket, key string, expire time.Duration, - ) (*url.URL, error) - // PresignPutURL returns a presigned url for put object operation - PresignPutURL(ctx context.Context, bucket, key, contentType, sha256 string, - size int, expire time.Duration) (*url.URL, http.Header, error) - // GetObject gets an object from bucket - GetObject(ctx context.Context, bucket, key string, opt minio.GetObjectOptions) ( - *minio.Object, error) - // PutObject uploads an object to bucket - PutObject(ctx context.Context, bucket, key, contentType string, size int, - body io.Reader, opts minio.PutObjectOptions) (minio.UploadInfo, error) - // CopyObject copies an object from srcKey to destKey - CopyObject(ctx context.Context, bucket, srcKey, destKey string) (out minio.UploadInfo, - err error) - // DeleteObject deletes an object from bucket + // PresignGetURL returns a presigned URL for a get-object operation. + PresignGetURL(ctx context.Context, bucket, key string, expire time.Duration) (*url.URL, error) + // PresignPutURL returns a presigned URL and headers for a put-object operation. + PresignPutURL( + ctx context.Context, + bucket string, + key string, + contentType string, + sha256 string, + size int, + expire time.Duration, + ) (*url.URL, http.Header, error) + // GetObject gets an object from bucket. + GetObject(ctx context.Context, bucket, key string, opts minio.GetObjectOptions) (*minio.Object, error) + // PutObject uploads an object to bucket. + PutObject( + ctx context.Context, + bucket string, + key string, + contentType string, + size int, + body io.Reader, + opts minio.PutObjectOptions, + ) (minio.UploadInfo, error) + // CopyObject copies an object from srcKey to destKey in bucket. + CopyObject(ctx context.Context, bucket, srcKey, destKey string) (minio.UploadInfo, error) + // DeleteObject deletes an object from bucket. DeleteObject(ctx context.Context, bucket, key string) error - // StatObject stats an object in bucket + // StatObject stats an object in bucket. StatObject(ctx context.Context, bucket, key string) (minio.ObjectInfo, error) } -// MinioS3Impl provides operations on AWS/s3 and minio for implementing S3 interface +// MinioS3Impl provides operations on AWS S3 and MinIO. type MinioS3Impl struct { client *minio.Client } -func (m *MinioS3Impl) PresignGetURL(ctx context.Context, bucket, key string, expire time.Duration, -) (out *url.URL, err error) { - if out, err = m.client.PresignedGetObject(ctx, bucket, key, expire, nil); err != nil { +var _ S3 = (*MinioS3Impl)(nil) + +// PresignGetURL returns a presigned URL for a get-object operation. +func (m *MinioS3Impl) PresignGetURL(ctx context.Context, bucket, key string, expire time.Duration) (*url.URL, error) { + out, err := m.client.PresignedGetObject(ctx, bucket, key, expire, nil) + if err != nil { return nil, fmt.Errorf("failed to presign get object: %w", err) } - return + return out, nil } -func (m *MinioS3Impl) PresignPutURL(ctx context.Context, bucket, key, contentType, - sha256 string, size int, expire time.Duration, -) (out *url.URL, headers http.Header, err error) { - headers = http.Header{ +// PresignPutURL returns a presigned URL and headers for a put-object operation. +func (m *MinioS3Impl) PresignPutURL( + ctx context.Context, + bucket string, + key string, + contentType string, + sha256 string, + size int, + expire time.Duration, +) (*url.URL, http.Header, error) { + headers := http.Header{ "Content-Type": []string{contentType}, "Content-Length": []string{fmt.Sprint(size)}, "x-amz-checksum-sha256": []string{sha256}, } - out, err = m.client.PresignHeader(ctx, http.MethodPut, bucket, key, expire, nil, headers) + out, err := m.client.PresignHeader(ctx, http.MethodPut, bucket, key, expire, nil, headers) if err != nil { return nil, nil, fmt.Errorf("failed to presign put object: %w", err) } - return + return out, headers, nil } -func (m *MinioS3Impl) GetObject(ctx context.Context, bucket, key string, opts minio.GetObjectOptions, -) (out *minio.Object, err error) { - if out, err = m.client.GetObject(ctx, bucket, key, opts); err != nil { +// GetObject gets an object from bucket. +func (m *MinioS3Impl) GetObject(ctx context.Context, bucket, key string, opts minio.GetObjectOptions) (*minio.Object, error) { + out, err := m.client.GetObject(ctx, bucket, key, opts) + if err != nil { return nil, fmt.Errorf("failed to get object: %w", err) } - return + return out, nil } -func (m *MinioS3Impl) PutObject(ctx context.Context, bucket, key, contentType string, - size int, body io.Reader, opts minio.PutObjectOptions, -) (out minio.UploadInfo, err error) { +// PutObject uploads an object to bucket. +func (m *MinioS3Impl) PutObject( + ctx context.Context, + bucket string, + key string, + contentType string, + size int, + body io.Reader, + opts minio.PutObjectOptions, +) (minio.UploadInfo, error) { opts.ContentType = contentType - if out, err = m.client.PutObject(ctx, bucket, key, body, int64(size), opts); err != nil { + out, err := m.client.PutObject(ctx, bucket, key, body, int64(size), opts) + if err != nil { return out, fmt.Errorf("failed to put object: %w", err) } - return + return out, nil } -func (m *MinioS3Impl) CopyObject(ctx context.Context, bucket, srcKey, destKey string) ( - out minio.UploadInfo, err error, -) { +// CopyObject copies an object from srcKey to destKey in bucket. +func (m *MinioS3Impl) CopyObject(ctx context.Context, bucket, srcKey, destKey string) (minio.UploadInfo, error) { copySourceOpts := minio.CopySrcOptions{ Bucket: bucket, Object: srcKey, @@ -96,12 +134,14 @@ func (m *MinioS3Impl) CopyObject(ctx context.Context, bucket, srcKey, destKey st Bucket: bucket, Object: destKey, } - if out, err = m.client.CopyObject(ctx, copyDestOpts, copySourceOpts); err != nil { + out, err := m.client.CopyObject(ctx, copyDestOpts, copySourceOpts) + if err != nil { return out, fmt.Errorf("failed to copy object: %w", err) } - return + return out, nil } +// DeleteObject deletes an object from bucket. func (m *MinioS3Impl) DeleteObject(ctx context.Context, bucket, key string) error { if err := m.client.RemoveObject(ctx, bucket, key, minio.RemoveObjectOptions{}); err != nil { return fmt.Errorf("failed to delete object: %w", err) @@ -109,6 +149,7 @@ func (m *MinioS3Impl) DeleteObject(ctx context.Context, bucket, key string) erro return nil } +// StatObject stats an object in bucket. func (m *MinioS3Impl) StatObject(ctx context.Context, bucket, key string) (minio.ObjectInfo, error) { info, err := m.client.StatObject(ctx, bucket, key, minio.StatObjectOptions{}) if IsNoSuchKeyErr(err) { @@ -120,7 +161,9 @@ func (m *MinioS3Impl) StatObject(ctx context.Context, bucket, key string) (minio return info, nil } -// NewMinioS3Impl creates a new MinioS3Impl +// NewMinioS3Impl creates a new MinioS3Impl. +// +// It returns S3 for backward compatibility with earlier releases. func NewMinioS3Impl(endpoint, accessKeyID, secretAccessKey, sessionToken string) (S3, error) { return NewMinioS3ImplWithSTS(endpoint, &credentials.Static{ Value: credentials.Value{ @@ -132,7 +175,9 @@ func NewMinioS3Impl(endpoint, accessKeyID, secretAccessKey, sessionToken string) }) } -// NewMinioS3ImplWithSTS creates a new MinioS3Impl with STSProvider +// NewMinioS3ImplWithSTS creates a new MinioS3Impl with STSProvider. +// +// It returns S3 for backward compatibility with earlier releases. func NewMinioS3ImplWithSTS(endpoint string, sts STSProvider) (S3, error) { uri, err := url.Parse(endpoint) if err != nil { @@ -155,19 +200,19 @@ func NewMinioS3ImplWithSTS(endpoint string, sts STSProvider) (S3, error) { return &MinioS3Impl{client: c}, nil } -// DefaultSTSTokenExpirySeconds is the default expiry duration for STS token +// DefaultSTSTokenExpirySeconds is the default expiry duration for an STS token. const DefaultSTSTokenExpirySeconds = 3 * 24 * 60 * 60 // Three days -// STSProvider provides temporary credentials +// STSProvider provides temporary credentials. type STSProvider = credentials.Provider -// WindowedSTSIdentityProvider provides temporary credentials with a windowed expiry +// WindowedSTSIdentityProvider provides temporary credentials with a windowed expiry. type WindowedSTSIdentityProvider struct { Window time.Duration *credentials.STSWebIdentity } -// Retrieve returns the credential value +// Retrieve returns the credential value. func (w *WindowedSTSIdentityProvider) Retrieve() (credentials.Value, error) { value, err := w.STSWebIdentity.Retrieve() if err != nil { @@ -180,18 +225,18 @@ func (w *WindowedSTSIdentityProvider) Retrieve() (credentials.Value, error) { // NewMinioSTSProviderImpl creates a new instance of the STSProvider. func NewMinioSTSProviderImpl(endpoint string, expirySeconds int, expiryWindow time.Duration, ) (STSProvider, error) { - // Read kubernetes service account ca certificate file - caCert, err := os.ReadFile("/var/run/secrets/kubernetes.io/serviceaccount/ca.crt") + if expirySeconds <= 0 { + return nil, fmt.Errorf("sts token expiry seconds must be positive") + } + caCert, err := os.ReadFile(serviceAccountCAPath) if err != nil { return nil, fmt.Errorf("failed to read service account ca certificate: %w", err) } - // Read kubernetes service account token file - token, err := os.ReadFile("/var/run/secrets/kubernetes.io/serviceaccount/token") + token, err := os.ReadFile(serviceAccountTokenPath) if err != nil { return nil, fmt.Errorf("failed to read service account token: %w", err) } - // Create an HttpTransport with the service account token and ca certificate transport, err := minio.DefaultTransport(true) if err != nil { return nil, fmt.Errorf("failed to create minio transport: %w", err) @@ -207,7 +252,6 @@ func NewMinioSTSProviderImpl(endpoint string, expirySeconds int, expiryWindow ti return nil, fmt.Errorf("failed to append kubernetes service account ca certificate") } - // Create sts credentials credential := &credentials.STSWebIdentity{ Client: &http.Client{Transport: transport}, STSEndpoint: endpoint, @@ -222,7 +266,7 @@ func NewMinioSTSProviderImpl(endpoint string, expirySeconds int, expiryWindow ti return &WindowedSTSIdentityProvider{Window: expiryWindow, STSWebIdentity: credential}, nil } -// IsNoSuchKeyErr checks if the error is a NoSuchKey error +// IsNoSuchKeyErr checks if the error is a NoSuchKey error. func IsNoSuchKeyErr(err error) bool { if minioError := minio.ToErrorResponse(err); minioError.Code == "NoSuchKey" { return true @@ -230,4 +274,5 @@ func IsNoSuchKeyErr(err error) bool { return false } +// ErrNoSuchKey reports a missing S3 object. var ErrNoSuchKey = errors.New("no such key") diff --git a/s3/s3_minio_test.go b/s3/s3_minio_test.go index 5aad2e8..50984ac 100644 --- a/s3/s3_minio_test.go +++ b/s3/s3_minio_test.go @@ -33,6 +33,13 @@ func TestMinio(t *testing.T) { }) } +func TestNewMinioSTSProviderImplRejectsMissingExpiry(t *testing.T) { + _, err := NewMinioSTSProviderImpl("https://sts.example.com", 0, time.Minute) + if err == nil { + t.Fatal("NewMinioSTSProviderImpl error = nil, want error") + } +} + type TestMinioSuite struct { suite.Suite diff --git a/snowflake/snowflake.go b/snowflake/snowflake.go index cc075c3..cb2374e 100644 --- a/snowflake/snowflake.go +++ b/snowflake/snowflake.go @@ -7,17 +7,16 @@ import ( ) func init() { - // change epoch from 2024-01-01 and 42 time bits - // approximately 139 years 5 months 18 days + // Use a 2024-01-01 epoch with 42 time bits, covering roughly 139 years. snowflake.Epoch = time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC).UnixMilli() - // node used 9 bits approximately 512 nodes + // Use 9 node bits, supporting roughly 512 nodes. snowflake.NodeBits = 9 } -// Node is a snowflake Node +// Node is a snowflake node. type Node = snowflake.Node -// NewNode returns a new snowflake Node +// NewNode returns a new snowflake node. func NewNode(node int) *snowflake.Node { n, err := snowflake.NewNode(int64(node)) if err != nil { diff --git a/sse/detach.go b/sse/detach.go new file mode 100644 index 0000000..1d3e45d --- /dev/null +++ b/sse/detach.go @@ -0,0 +1,63 @@ +package sse + +import ( + "context" + "errors" + "net/http" + "sync" + "time" +) + +// DetachWriteTimeout overrides the underlying connection's write deadline +// and returns a request whose context is detached from any server-injected +// deadline. +// +// It addresses a common collision between SSE handlers and HTTP servers that +// install a global write timeout for unary requests (Kratos' http.Timeout +// option is the motivating case). Two things happen: +// +// 1. The connection's write deadline is reset to writeTimeout from now via +// http.ResponseController, replacing the server-wide value for this +// connection. +// 2. The request context is replaced with one that cancels only on genuine +// client disconnect; if the original context's Err is DeadlineExceeded +// (i.e. the server-wide deadline fired), the replacement is not +// cancelled. +// +// The replacement context is cancelled when the original context completes +// for any non-deadline reason, so downstream goroutines observing +// ctx.Done() still see client disconnects. +func DetachWriteTimeout(w http.ResponseWriter, r *http.Request, writeTimeout time.Duration) *http.Request { + rc := http.NewResponseController(w) + deadline := time.Time{} + if writeTimeout > 0 { + deadline = time.Now().Add(writeTimeout) + } + _ = rc.SetWriteDeadline(deadline) + + ctx, _ := DetachDeadlineContext(r.Context()) + return r.WithContext(ctx) +} + +// DetachDeadlineContext returns a context that keeps parent values but ignores +// parent DeadlineExceeded cancellation. Other parent cancellations, such as a +// client disconnect, are still forwarded. +// +// The returned stop function unregisters the parent callback and cancels the +// detached context. Callers that own a bounded streaming lifecycle should defer +// stop when the stream ends. +func DetachDeadlineContext(parent context.Context) (context.Context, func()) { + ctx, cancel := context.WithCancel(context.WithoutCancel(parent)) + stopParent := context.AfterFunc(parent, func() { + if !errors.Is(parent.Err(), context.DeadlineExceeded) { + cancel() + } + }) + var once sync.Once + return ctx, func() { + once.Do(func() { + stopParent() + cancel() + }) + } +} diff --git a/sse/example_test.go b/sse/example_test.go new file mode 100644 index 0000000..1d2885b --- /dev/null +++ b/sse/example_test.go @@ -0,0 +1,59 @@ +package sse_test + +import ( + "context" + "net/http" + "time" + + "github.com/crypto-zero/go-kit/sse" +) + +// ExampleStream_Pump shows the typical streaming handler for a Kratos HTTP +// server. The handler is mounted as a raw net/http handler via srv.Handle, +// bypassing protobuf transcoding so tokens reach the client as soon as the +// upstream goroutine emits them. +// +// Kratos installs a server-wide write deadline through http.Timeout; the +// DetachWriteTimeout call replaces it with a per-connection deadline long +// enough for the SSE stream and detaches the request context from the +// server-injected DeadlineExceeded signal. +func ExampleStream_Pump() { + const streamTimeout = 300 * time.Second + + handler := func(w http.ResponseWriter, r *http.Request) { + r = sse.DetachWriteTimeout(w, r, streamTimeout) + + // produceTokens stands in for a biz-layer call that returns the + // token channel and an error channel. + chunks, errs := produceTokens(r.Context()) + + s := sse.NewStream(w) + _ = s.Pump(r.Context(), chunks, errs) + } + _ = handler +} + +// ExampleStream_WriteJSON shows the unary-over-SSE pattern: the handler +// computes a single response, writes it as one data frame, and terminates +// with [DONE]. Useful when the response is structured JSON but the client +// already expects an SSE-formatted endpoint. +func ExampleStream_WriteJSON() { + const callTimeout = 60 * time.Second + + handler := func(w http.ResponseWriter, r *http.Request) { + r = sse.DetachWriteTimeout(w, r, callTimeout) + + result, err := compute(r.Context()) + s := sse.NewStream(w) + if err != nil { + _ = s.Error(err.Error()) + return + } + _ = s.WriteJSON(result) + _ = s.Done() + } + _ = handler +} + +func produceTokens(context.Context) (<-chan string, <-chan error) { return nil, nil } +func compute(context.Context) (any, error) { return nil, nil } diff --git a/sse/reader.go b/sse/reader.go new file mode 100644 index 0000000..55000df --- /dev/null +++ b/sse/reader.go @@ -0,0 +1,91 @@ +package sse + +import ( + "bufio" + "errors" + "io" + "strconv" + "strings" + "time" +) + +// Reader reads Server-Sent Events from an HTTP response body. +type Reader struct { + r *bufio.Reader + c io.Closer + lastID string +} + +// NewReader returns an SSE event reader for r. +func NewReader(r io.Reader) *Reader { + er := &Reader{r: bufio.NewReader(r)} + if c, ok := r.(io.Closer); ok { + er.c = c + } + return er +} + +// Next blocks until the next complete event frame is read. +func (r *Reader) Next() (*Event, error) { + var ev Event + var data []string + seen := false + idSeen := false + for { + line, err := r.r.ReadString('\n') + if err != nil { + if errors.Is(err, io.EOF) && seen { + return r.finishEvent(ev, data, idSeen), nil + } + return nil, err + } + line = strings.TrimSuffix(strings.TrimSuffix(line, "\n"), "\r") + if line == "" { + if !seen { + continue + } + return r.finishEvent(ev, data, idSeen), nil + } + if strings.HasPrefix(line, ":") { + continue + } + seen = true + field, value, ok := strings.Cut(line, ":") + if !ok { + value = "" + } else if strings.HasPrefix(value, " ") { + value = value[1:] + } + switch field { + case "event": + ev.Event = value + case "id": + ev.ID = value + idSeen = true + case "data": + data = append(data, value) + case "retry": + if ms, err := strconv.ParseInt(value, 10, 64); err == nil { + ev.Retry = time.Duration(ms) * time.Millisecond + } + } + } +} + +func (r *Reader) finishEvent(ev Event, data []string, idSeen bool) *Event { + if idSeen { + r.lastID = ev.ID + } else { + ev.ID = r.lastID + } + ev.Data = strings.Join(data, "\n") + return &ev +} + +// Close closes the underlying reader when it implements io.Closer. +func (r *Reader) Close() error { + if r.c == nil { + return nil + } + return r.c.Close() +} diff --git a/sse/reader_test.go b/sse/reader_test.go new file mode 100644 index 0000000..e20981f --- /dev/null +++ b/sse/reader_test.go @@ -0,0 +1,39 @@ +package sse + +import ( + "errors" + "io" + "strings" + "testing" + "time" +) + +func TestReaderNext(t *testing.T) { + r := NewReader(strings.NewReader(": keepalive\nid: 42\nevent: point\nretry: 1500\ndata: {\"ok\":true}\ndata: tail\n\n" + + "event: next\ndata\n\n" + + "id:\nevent: reset\ndata: done\n\n")) + ev, err := r.Next() + if err != nil { + t.Fatalf("Next: %v", err) + } + if ev.ID != "42" || ev.Event != "point" || ev.Retry != 1500*time.Millisecond || ev.Data != "{\"ok\":true}\ntail" { + t.Fatalf("unexpected event: %#v", ev) + } + ev, err = r.Next() + if err != nil { + t.Fatalf("Next sticky ID: %v", err) + } + if ev.ID != "42" || ev.Event != "next" || ev.Data != "" { + t.Fatalf("unexpected sticky/no-colon event: %#v", ev) + } + ev, err = r.Next() + if err != nil { + t.Fatalf("Next reset ID: %v", err) + } + if ev.ID != "" || ev.Event != "reset" || ev.Data != "done" { + t.Fatalf("unexpected reset event: %#v", ev) + } + if _, err := r.Next(); !errors.Is(err, io.EOF) { + t.Fatalf("Next EOF = %v, want io.EOF", err) + } +} diff --git a/sse/sse.go b/sse/sse.go new file mode 100644 index 0000000..4f03d72 --- /dev/null +++ b/sse/sse.go @@ -0,0 +1,283 @@ +// Package sse implements a minimal Server-Sent Events writer for net/http. +// +// The package is transport-only: it owns the wire format (headers, framing, +// flushing, the [DONE] terminator) and a small helper for draining streaming +// channels. Authentication, request decoding, validation and routing are the +// caller's responsibility. +// +// For helpers that mount SSE endpoints on a Kratos HTTP server, see the +// module package github.com/crypto-zero/go-kit/kratos/sse. +package sse + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "strings" + "sync" + "time" +) + +// DoneMarker is the conventional terminator sent as the final data frame of +// an SSE stream. It matches the OpenAI-style protocol that most browser and +// CLI clients already understand. +const DoneMarker = "[DONE]" + +// LastEventIDHeader is the HTTP header browsers send on reconnect to resume +// from the last received event. +const LastEventIDHeader = "Last-Event-ID" + +// Event is a single SSE event. All fields are optional; an event with only +// Data set produces the common "data: \n\n" frame. +type Event struct { + // Event is the event name. When empty, no "event:" field is written and + // browsers dispatch the frame as the default "message" event. + Event string + // ID populates the "id:" field, allowing clients to resume via the + // Last-Event-ID header on reconnect. + ID string + // Data is the event payload. It may contain newlines: each line is + // emitted as a separate "data:" field per the SSE spec. + Data string + // Retry, when non-zero, sets the client's reconnection delay in + // milliseconds via the "retry:" field. + Retry time.Duration +} + +// Stream writes Server-Sent Events to an HTTP response. Methods are safe for +// concurrent use, allowing a Heartbeat goroutine to coexist with the main +// writer. +type Stream struct { + mu sync.Mutex + w http.ResponseWriter + rc *http.ResponseController + started bool +} + +// NewStream wraps w for SSE output. Headers are not written until the first +// frame is sent (or Start is called explicitly), so callers may still return +// a non-SSE HTTP error after constructing the Stream. +// +// Flushes go through http.ResponseController, which walks any Unwrap() +// chain installed by middleware (metrics wrappers, response recorders) to +// reach the underlying Flusher. +func NewStream(w http.ResponseWriter) *Stream { + return &Stream{w: w, rc: http.NewResponseController(w)} +} + +// Start writes the SSE response headers and the 200 status line. It is safe +// to call multiple times; subsequent calls are no-ops. Callers normally do +// not need to invoke Start directly — any of the Write methods will trigger +// it on first use. +func (s *Stream) Start() { + s.mu.Lock() + defer s.mu.Unlock() + s.startLocked() +} + +func (s *Stream) startLocked() { + if s.started { + return + } + h := s.w.Header() + h.Set("Content-Type", "text/event-stream") + h.Set("Cache-Control", "no-cache") + h.Set("Connection", "keep-alive") + // Disable proxy buffering (nginx-specific but harmless elsewhere) so + // frames reach the client as soon as they are flushed. + h.Set("X-Accel-Buffering", "no") + s.w.WriteHeader(http.StatusOK) + s.started = true +} + +// Write sends a single default-event data frame and flushes immediately. +func (s *Stream) Write(data string) error { + return s.WriteEvent(Event{Data: data}) +} + +// WriteEvent sends a fully specified event and flushes immediately. +func (s *Stream) WriteEvent(e Event) error { + s.mu.Lock() + defer s.mu.Unlock() + s.startLocked() + + var b strings.Builder + if e.Event != "" { + b.WriteString("event: ") + b.WriteString(e.Event) + b.WriteByte('\n') + } + if e.ID != "" { + b.WriteString("id: ") + b.WriteString(e.ID) + b.WriteByte('\n') + } + if e.Retry > 0 { + fmt.Fprintf(&b, "retry: %d\n", e.Retry.Milliseconds()) + } + // Per spec, every newline in the payload starts a new "data:" field; + // the frame is terminated by a blank line. + for line := range strings.SplitSeq(e.Data, "\n") { + b.WriteString("data: ") + b.WriteString(line) + b.WriteByte('\n') + } + b.WriteByte('\n') + if _, err := io.WriteString(s.w, b.String()); err != nil { + return err + } + s.flushLocked() + return nil +} + +// WriteJSON marshals v and writes it as a single data frame. The caller is +// responsible for any terminating Done frame. +func (s *Stream) WriteJSON(v any) error { + return s.WriteJSONEvent(Event{}, v) +} + +// WriteJSONEvent marshals v and writes it as a named/id/retry event. +func (s *Stream) WriteJSONEvent(e Event, v any) error { + data, err := json.Marshal(v) + if err != nil { + return fmt.Errorf("sse: marshal json: %w", err) + } + e.Data = string(data) + return s.WriteEvent(e) +} + +// Error writes an "error"-named event whose data payload is the JSON object +// {"error": msg}, matching the convention used by most JS EventSource +// consumers. +func (s *Stream) Error(msg string) error { + return s.WriteJSONEvent(Event{Event: "error"}, map[string]string{"error": msg}) +} + +// Done sends the DoneMarker as a final data frame, signaling end-of-stream +// to clients that follow the OpenAI-style protocol. +func (s *Stream) Done() error { + return s.Write(DoneMarker) +} + +// Comment writes an SSE comment frame (": text\n\n"). Comments are ignored +// by clients and useful as keepalive packets through proxies that close +// idle connections (nginx, ALB, CloudFlare). +// +// An empty text writes a bare ":\n\n" — the minimal valid keepalive. +func (s *Stream) Comment(text string) error { + s.mu.Lock() + defer s.mu.Unlock() + s.startLocked() + var b strings.Builder + for line := range strings.SplitSeq(text, "\n") { + b.WriteByte(':') + if line != "" { + b.WriteByte(' ') + b.WriteString(line) + } + b.WriteByte('\n') + } + b.WriteByte('\n') + if _, err := io.WriteString(s.w, b.String()); err != nil { + return err + } + s.flushLocked() + return nil +} + +// Heartbeat starts a goroutine that emits a Comment frame every interval +// until ctx is cancelled or the returned stop function is called. +// +// stop is synchronous: it cancels the ticker and blocks until the +// goroutine has finished its current iteration. Callers MUST invoke +// stop before the underlying http.ResponseWriter is recycled (typically +// by deferring it in the handler) — writing a comment to a reclaimed +// response panics. stop is safe to call multiple times. +// +// Use this for long-lived streams that sit behind proxies with idle +// connection timeouts. +func (s *Stream) Heartbeat(ctx context.Context, interval time.Duration) (stop func()) { + if interval <= 0 { + return func() {} + } + ctx, cancel := context.WithCancel(ctx) + done := make(chan struct{}) + go func() { + defer close(done) + t := time.NewTicker(interval) + defer t.Stop() + for { + select { + case <-ctx.Done(): + return + case <-t.C: + if err := s.Comment(""); err != nil { + return + } + } + } + }() + var once sync.Once + return func() { + once.Do(func() { + cancel() + <-done + }) + } +} + +// Pump drains chunks into the stream as default data events. It returns +// when: +// +// - chunks is closed: Done is sent and nil is returned; +// - errs delivers a non-nil error: that error is forwarded via Error and +// returned (no Done is sent); +// - errs delivers nil: returns nil silently (no Done); +// - errs is closed: returns nil silently (no Done); +// - ctx is cancelled: returns ctx.Err() silently (no Done). +// +// Empty chunks are skipped. Callers typically use Pump to relay a +// (<-chan string, <-chan error) pair produced by a streaming biz call. +func (s *Stream) Pump(ctx context.Context, chunks <-chan string, errs <-chan error) error { + for { + select { + case chunk, ok := <-chunks: + if !ok { + return s.Done() + } + if chunk == "" { + continue + } + if err := s.Write(chunk); err != nil { + return err + } + case err, ok := <-errs: + if !ok { + return nil + } + if err == nil { + return nil + } + _ = s.Error(err.Error()) + return err + case <-ctx.Done(): + return ctx.Err() + } + } +} + +func (s *Stream) flushLocked() { + // Best-effort: ResponseController.Flush returns http.ErrNotSupported + // when no Flusher is reachable through the Unwrap chain. SSE without + // flushing degrades to "client receives nothing until close" — bad, + // but not something this layer can recover from. Silently ignore. + _ = s.rc.Flush() +} + +// LastEventID returns the value of the Last-Event-ID HTTP header sent by +// EventSource clients on reconnect. Returns "" when absent. +func LastEventID(r *http.Request) string { + return r.Header.Get(LastEventIDHeader) +} diff --git a/sse/sse_test.go b/sse/sse_test.go new file mode 100644 index 0000000..8441820 --- /dev/null +++ b/sse/sse_test.go @@ -0,0 +1,347 @@ +package sse + +import ( + "bufio" + "context" + "errors" + "net" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" +) + +func TestStream_HeadersAreLazy(t *testing.T) { + rec := httptest.NewRecorder() + s := NewStream(rec) + if got := rec.Header().Get("Content-Type"); got != "" { + t.Errorf("Content-Type before first write = %q, want empty", got) + } + if err := s.Write("hello"); err != nil { + t.Fatalf("Write: %v", err) + } + if got, want := rec.Header().Get("Content-Type"), "text/event-stream"; got != want { + t.Errorf("Content-Type after Write = %q, want %q", got, want) + } + if got, want := rec.Header().Get("Cache-Control"), "no-cache"; got != want { + t.Errorf("Cache-Control = %q, want %q", got, want) + } + if got, want := rec.Header().Get("X-Accel-Buffering"), "no"; got != want { + t.Errorf("X-Accel-Buffering = %q, want %q", got, want) + } +} + +func TestStream_StartIdempotent(t *testing.T) { + rec := httptest.NewRecorder() + s := NewStream(rec) + s.Start() + s.Start() + if got, want := rec.Code, 200; got != want { + t.Errorf("status = %d, want %d", got, want) + } +} + +func TestStream_WriteFormat(t *testing.T) { + rec := httptest.NewRecorder() + s := NewStream(rec) + if err := s.Write("hello"); err != nil { + t.Fatalf("Write: %v", err) + } + if got, want := rec.Body.String(), "data: hello\n\n"; got != want { + t.Errorf("body = %q, want %q", got, want) + } +} + +func TestStream_WriteEventFull(t *testing.T) { + rec := httptest.NewRecorder() + s := NewStream(rec) + err := s.WriteEvent(Event{ + Event: "delta", + ID: "42", + Data: "line1\nline2", + Retry: 1500 * time.Millisecond, + }) + if err != nil { + t.Fatalf("WriteEvent: %v", err) + } + want := "event: delta\nid: 42\nretry: 1500\ndata: line1\ndata: line2\n\n" + if got := rec.Body.String(); got != want { + t.Errorf("body = %q, want %q", got, want) + } +} + +func TestStream_WriteJSON(t *testing.T) { + rec := httptest.NewRecorder() + s := NewStream(rec) + if err := s.WriteJSON(map[string]int{"n": 1}); err != nil { + t.Fatalf("WriteJSON: %v", err) + } + if got, want := rec.Body.String(), "data: {\"n\":1}\n\n"; got != want { + t.Errorf("body = %q, want %q", got, want) + } +} + +func TestStream_WriteJSONEvent(t *testing.T) { + rec := httptest.NewRecorder() + s := NewStream(rec) + if err := s.WriteJSONEvent(Event{Event: "point", ID: "42"}, map[string]int{"n": 1}); err != nil { + t.Fatalf("WriteJSONEvent: %v", err) + } + want := "event: point\nid: 42\ndata: {\"n\":1}\n\n" + if got := rec.Body.String(); got != want { + t.Fatalf("body = %q, want %q", got, want) + } +} + +func TestStream_Error(t *testing.T) { + rec := httptest.NewRecorder() + s := NewStream(rec) + if err := s.Error("boom"); err != nil { + t.Fatalf("Error: %v", err) + } + want := "event: error\ndata: {\"error\":\"boom\"}\n\n" + if got := rec.Body.String(); got != want { + t.Errorf("body = %q, want %q", got, want) + } +} + +func TestStream_Done(t *testing.T) { + rec := httptest.NewRecorder() + s := NewStream(rec) + if err := s.Done(); err != nil { + t.Fatalf("Done: %v", err) + } + if got, want := rec.Body.String(), "data: [DONE]\n\n"; got != want { + t.Errorf("body = %q, want %q", got, want) + } +} + +func TestPump_CleanFinish(t *testing.T) { + rec := httptest.NewRecorder() + s := NewStream(rec) + chunks := make(chan string, 3) + errs := make(chan error) + chunks <- "a" + chunks <- "" + chunks <- "b" + close(chunks) + if err := s.Pump(context.Background(), chunks, errs); err != nil { + t.Fatalf("Pump: %v", err) + } + got := rec.Body.String() + want := "data: a\n\ndata: b\n\ndata: [DONE]\n\n" + if got != want { + t.Errorf("body = %q, want %q", got, want) + } +} + +func TestPump_ErrorFromErrCh(t *testing.T) { + rec := httptest.NewRecorder() + s := NewStream(rec) + chunks := make(chan string) + errs := make(chan error, 1) + sentinel := errors.New("upstream failure") + errs <- sentinel + err := s.Pump(context.Background(), chunks, errs) + if !errors.Is(err, sentinel) { + t.Errorf("Pump err = %v, want %v", err, sentinel) + } + if !strings.Contains(rec.Body.String(), `"error":"upstream failure"`) { + t.Errorf("body missing error payload: %q", rec.Body.String()) + } + if strings.Contains(rec.Body.String(), DoneMarker) { + t.Errorf("body should not contain DONE on error, got %q", rec.Body.String()) + } +} + +func TestPump_ErrChClosedSilently(t *testing.T) { + rec := httptest.NewRecorder() + s := NewStream(rec) + chunks := make(chan string) + errs := make(chan error) + close(errs) + if err := s.Pump(context.Background(), chunks, errs); err != nil { + t.Errorf("Pump on closed errCh = %v, want nil", err) + } + if got := rec.Body.String(); got != "" { + t.Errorf("body = %q, want empty (no headers written either)", got) + } +} + +func TestPump_NilErrReturnsSilently(t *testing.T) { + rec := httptest.NewRecorder() + s := NewStream(rec) + chunks := make(chan string) + errs := make(chan error, 1) + errs <- nil + if err := s.Pump(context.Background(), chunks, errs); err != nil { + t.Errorf("Pump on nil err = %v, want nil", err) + } + if got := rec.Body.String(); got != "" { + t.Errorf("body = %q, want empty", got) + } +} + +func TestPump_ContextCancel(t *testing.T) { + rec := httptest.NewRecorder() + s := NewStream(rec) + ctx, cancel := context.WithCancel(context.Background()) + chunks := make(chan string) + errs := make(chan error) + cancel() + if err := s.Pump(ctx, chunks, errs); !errors.Is(err, context.Canceled) { + t.Errorf("Pump err = %v, want context.Canceled", err) + } +} + +func TestStream_Comment(t *testing.T) { + rec := httptest.NewRecorder() + s := NewStream(rec) + if err := s.Comment("hello"); err != nil { + t.Fatalf("Comment: %v", err) + } + if got, want := rec.Body.String(), ": hello\n\n"; got != want { + t.Errorf("body = %q, want %q", got, want) + } +} + +func TestStream_CommentEmpty(t *testing.T) { + rec := httptest.NewRecorder() + s := NewStream(rec) + if err := s.Comment(""); err != nil { + t.Fatalf("Comment: %v", err) + } + if got, want := rec.Body.String(), ":\n\n"; got != want { + t.Errorf("body = %q, want %q", got, want) + } +} + +func TestStream_Heartbeat(t *testing.T) { + rec := httptest.NewRecorder() + s := NewStream(rec) + stop := s.Heartbeat(context.Background(), 5*time.Millisecond) + defer stop() + + // Concurrent writes: ensure mutex prevents interleaved frames. + done := make(chan struct{}) + go func() { + for range 20 { + _ = s.Write("token") + time.Sleep(time.Millisecond) + } + close(done) + }() + <-done + stop() + + body := rec.Body.String() + if !strings.Contains(body, ":\n\n") { + t.Errorf("expected at least one comment frame, got %q", body) + } + if !strings.Contains(body, "data: token\n\n") { + t.Errorf("expected data frames, got %q", body) + } +} + +func TestStream_HeartbeatNonPositiveIntervalIsNoop(t *testing.T) { + rec := httptest.NewRecorder() + s := NewStream(rec) + + stop := s.Heartbeat(context.Background(), 0) + stop() + stop() + + if got := rec.Body.String(); got != "" { + t.Errorf("body = %q, want empty", got) + } +} + +func TestLastEventID(t *testing.T) { + r := httptest.NewRequest("GET", "/x", nil) + r.Header.Set(LastEventIDHeader, "42") + if got := LastEventID(r); got != "42" { + t.Errorf("LastEventID(header) = %q, want %q", got, "42") + } + + r = httptest.NewRequest("GET", "/x", nil) + if got := LastEventID(r); got != "" { + t.Errorf("LastEventID(absent) = %q, want empty", got) + } +} + +func TestDetachWriteTimeout_DeadlineDoesNotPropagate(t *testing.T) { + rec := httptest.NewRecorder() + parent, cancel := context.WithDeadline(context.Background(), time.Now().Add(10*time.Millisecond)) + defer cancel() + req := httptest.NewRequest("POST", "/x", nil).WithContext(parent) + + r := DetachWriteTimeout(rec, req, time.Second) + <-parent.Done() + // Give the detach goroutine a moment to observe parent.Err(). + time.Sleep(20 * time.Millisecond) + + select { + case <-r.Context().Done(): + t.Errorf("detached context cancelled on parent DeadlineExceeded; want still alive") + default: + } +} + +func TestDetachWriteTimeout_PropagatesCancel(t *testing.T) { + rec := httptest.NewRecorder() + parent, cancel := context.WithCancel(context.Background()) + req := httptest.NewRequest("POST", "/x", nil).WithContext(parent) + + r := DetachWriteTimeout(rec, req, time.Second) + cancel() + + select { + case <-r.Context().Done(): + case <-time.After(100 * time.Millisecond): + t.Errorf("detached context not cancelled when parent was Canceled") + } +} + +func TestDetachWriteTimeout_NonPositiveClearsWriteDeadline(t *testing.T) { + var got time.Time + w := deadlineResponseWriter{ + ResponseWriter: httptest.NewRecorder(), + setWriteDeadline: func(t time.Time) error { + got = t + return nil + }, + } + req := httptest.NewRequest("POST", "/x", nil) + + _ = DetachWriteTimeout(w, req, 0) + + if !got.IsZero() { + t.Errorf("write deadline = %v, want zero time", got) + } +} + +type deadlineResponseWriter struct { + http.ResponseWriter + setWriteDeadline func(time.Time) error +} + +func (w deadlineResponseWriter) SetWriteDeadline(t time.Time) error { + return w.setWriteDeadline(t) +} + +func (w deadlineResponseWriter) SetReadDeadline(time.Time) error { + return nil +} + +func (w deadlineResponseWriter) EnableFullDuplex() error { + return nil +} + +func (w deadlineResponseWriter) Flush() error { + return nil +} + +func (w deadlineResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { + return nil, nil, http.ErrNotSupported +} diff --git a/text/text.go b/text/text.go index 603856b..8449b60 100644 --- a/text/text.go +++ b/text/text.go @@ -7,31 +7,26 @@ import ( "encoding/binary" "encoding/hex" "fmt" + "io" "math" "math/big" mrand "math/rand" - "os" "strings" "unicode" "golang.org/x/text/width" ) -// GeneratePassword generate password from /dev/urandom. +// GeneratePassword generates a password from cryptographically secure random bytes. func GeneratePassword(size int, accept func(byte) bool) (string, error) { - f, err := os.Open("/dev/urandom") - if err != nil { - return "", fmt.Errorf("open /dev/urandom: %w", err) - } password := make([]byte, 0, size) for len(password) < size { buf := make([]byte, size*2) - n, err := f.Read(buf) + n, err := io.ReadFull(crand.Reader, buf) if err != nil { - return "", fmt.Errorf("read /dev/urandom: %w", err) + return "", fmt.Errorf("read random bytes: %w", err) } - for idx := 0; idx < n; idx++ { - // Ascii printable characters + for idx := range n { if accept(buf[idx]) && len(password) < size { password = append(password, buf[idx]) } @@ -40,7 +35,7 @@ func GeneratePassword(size int, accept func(byte) bool) (string, error) { return string(password), nil } -// GeneratePasswordLitterNumbers generate password with litter and numbers. +// GeneratePasswordLitterNumbers generates a password with letters and numbers. func GeneratePasswordLitterNumbers(size int) (string, error) { return GeneratePassword(size, func(b byte) bool { return b >= '0' && b <= '9' || b >= 'A' && b <= 'Z' || b >= 'a' && b <= 'z' @@ -64,8 +59,8 @@ func RandString(length int) string { // maxInt64 is the maximum value of int64. var maxInt64 = big.NewInt(math.MaxInt64) -// RandStringWithCharset returns a random string with given length and charset. -// it uses crypto/rand to generate random string. +// RandStringWithCharset returns a random string with the given length and charset. +// It uses crypto/rand when available. func RandStringWithCharset(length int, charset string) string { var seed int64 if err := binary.Read(crand.Reader, binary.BigEndian, &seed); err != nil { @@ -99,17 +94,17 @@ func CleanAllSpace(s string) string { }, s) } -// NarrowString 全角转半角 +// NarrowString converts full-width characters to half-width characters. func NarrowString(s string) string { return width.Narrow.String(s) } -// CleanString clean all space and narrow string +// CleanString removes all spaces and narrows full-width characters. func CleanString(s string) string { return CleanAllSpace(NarrowString(strings.ToValidUTF8(s, ""))) } -// TrimString trim prefix and suffix space and narrow string +// TrimString trims surrounding spaces and narrows full-width characters. func TrimString(s string) string { return strings.TrimSpace(NarrowString(strings.ToValidUTF8(s, ""))) }