Skip to content

Commit 079730c

Browse files
committed
Add helpers to replace hand-made changes to dapr/dapr protos
1 parent 7d96902 commit 079730c

File tree

2 files changed

+218
-0
lines changed

2 files changed

+218
-0
lines changed

client/internal/crypto/helpers.go

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
package crypto
2+
3+
import (
4+
commonv1pb "github.com/dapr/go-sdk/internal/proto/dapr/proto/common/v1"
5+
runtimev1 "github.com/dapr/go-sdk/internal/proto/dapr/proto/runtime/v1"
6+
"google.golang.org/protobuf/proto"
7+
)
8+
9+
func GetPayload[T runtimev1.DecryptRequest | runtimev1.EncryptRequest | runtimev1.DecryptResponse | runtimev1.EncryptResponse](req *T) *commonv1pb.StreamPayload {
10+
if req == nil {
11+
return nil
12+
}
13+
14+
switch r := any(req).(type) {
15+
case *runtimev1.EncryptRequest:
16+
return r.Payload
17+
case *runtimev1.DecryptRequest:
18+
return r.Payload
19+
case *runtimev1.EncryptResponse:
20+
return r.Payload
21+
case *runtimev1.DecryptResponse:
22+
return r.Payload
23+
}
24+
25+
return nil
26+
}
27+
28+
func SetPayload[T runtimev1.DecryptRequest | runtimev1.EncryptRequest | runtimev1.DecryptResponse | runtimev1.EncryptResponse](req *T, payload *commonv1pb.StreamPayload) {
29+
if req == nil {
30+
return
31+
}
32+
33+
switch r := any(req).(type) {
34+
case *runtimev1.EncryptRequest:
35+
r.Payload = payload
36+
case *runtimev1.DecryptRequest:
37+
r.Payload = payload
38+
case *runtimev1.EncryptResponse:
39+
r.Payload = payload
40+
case *runtimev1.DecryptResponse:
41+
r.Payload = payload
42+
}
43+
}
44+
45+
func SetOptions[T runtimev1.DecryptRequest | runtimev1.EncryptRequest](req *T, opts proto.Message) {
46+
if req == nil {
47+
return
48+
}
49+
50+
switch r := any(req).(type) {
51+
case *runtimev1.EncryptRequest:
52+
r.Options = opts.(*runtimev1.EncryptRequestOptions)
53+
case *runtimev1.DecryptRequest:
54+
r.Options = opts.(*runtimev1.DecryptRequestOptions)
55+
}
56+
}
57+
58+
func Reset[T runtimev1.DecryptRequest | runtimev1.EncryptRequest | runtimev1.DecryptResponse | runtimev1.EncryptResponse](msg *T) {
59+
if msg == nil {
60+
return
61+
}
62+
63+
switch r := any(msg).(type) {
64+
case *runtimev1.EncryptRequest:
65+
r.Reset()
66+
case *runtimev1.DecryptRequest:
67+
r.Reset()
68+
case *runtimev1.EncryptResponse:
69+
r.Reset()
70+
case *runtimev1.DecryptResponse:
71+
r.Reset()
72+
}
73+
}
Lines changed: 145 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,145 @@
1+
package crypto_test
2+
3+
import (
4+
"testing"
5+
6+
"github.com/dapr/go-sdk/client/internal/crypto"
7+
commonv1 "github.com/dapr/go-sdk/internal/proto/dapr/proto/common/v1"
8+
runtimev1 "github.com/dapr/go-sdk/internal/proto/dapr/proto/runtime/v1"
9+
"github.com/stretchr/testify/assert"
10+
"github.com/stretchr/testify/require"
11+
"google.golang.org/protobuf/proto"
12+
)
13+
14+
func TestPayloadMethods(t *testing.T) {
15+
testCases := map[string]struct {
16+
protoMsg any
17+
inputData []byte
18+
}{
19+
"EncryptRequest": {
20+
protoMsg: &runtimev1.EncryptRequest{},
21+
inputData: []byte("test data"),
22+
},
23+
"EncryptResponse": {
24+
protoMsg: &runtimev1.EncryptResponse{},
25+
inputData: []byte("test data"),
26+
},
27+
"DecryptRequest": {
28+
protoMsg: &runtimev1.DecryptRequest{},
29+
inputData: []byte("test data"),
30+
},
31+
"DecryptResponse": {
32+
protoMsg: &runtimev1.DecryptResponse{},
33+
inputData: []byte("test data"),
34+
},
35+
}
36+
37+
for name, tc := range testCases {
38+
t.Run(name, func(t *testing.T) {
39+
inputPayload := &commonv1.StreamPayload{Data: tc.inputData}
40+
var outputPayload *commonv1.StreamPayload
41+
42+
switch r := tc.protoMsg.(type) {
43+
case *runtimev1.EncryptRequest:
44+
crypto.SetPayload(r, inputPayload)
45+
outputPayload = crypto.GetPayload(r)
46+
case *runtimev1.EncryptResponse:
47+
crypto.SetPayload(r, inputPayload)
48+
outputPayload = crypto.GetPayload(r)
49+
case *runtimev1.DecryptRequest:
50+
crypto.SetPayload(r, inputPayload)
51+
outputPayload = crypto.GetPayload(r)
52+
case *runtimev1.DecryptResponse:
53+
crypto.SetPayload(r, inputPayload)
54+
outputPayload = crypto.GetPayload(r)
55+
default:
56+
require.Failf(t, "unsupported proto message type", "the type was %T", r)
57+
}
58+
59+
assert.Equal(t, tc.inputData, outputPayload.Data, "payload should match the input")
60+
})
61+
}
62+
}
63+
64+
func TestSetOptions(t *testing.T) {
65+
testCases := map[string]struct {
66+
protoMsg any
67+
options proto.Message
68+
}{
69+
"EncryptRequest": {
70+
protoMsg: &runtimev1.EncryptRequest{},
71+
options: &runtimev1.EncryptRequestOptions{
72+
KeyName: "testing",
73+
},
74+
},
75+
"DecryptRequest": {
76+
protoMsg: &runtimev1.DecryptRequest{},
77+
options: &runtimev1.DecryptRequestOptions{
78+
KeyName: "testing",
79+
},
80+
},
81+
}
82+
83+
for name, tc := range testCases {
84+
t.Run(name, func(t *testing.T) {
85+
var outputOptions proto.Message
86+
87+
switch r := tc.protoMsg.(type) {
88+
case *runtimev1.EncryptRequest:
89+
crypto.SetOptions(r, tc.options)
90+
outputOptions = r.Options
91+
case *runtimev1.DecryptRequest:
92+
crypto.SetOptions(r, tc.options)
93+
outputOptions = r.Options
94+
default:
95+
require.Failf(t, "unsupported proto message type", "the type was %T", r)
96+
}
97+
98+
assert.Equal(t, tc.options, outputOptions, "options should be persisted")
99+
})
100+
}
101+
}
102+
103+
func TestReset(t *testing.T) {
104+
testCases := map[string]struct {
105+
protoMsg any
106+
}{
107+
"EncryptRequest": {
108+
protoMsg: &runtimev1.EncryptRequest{Payload: &commonv1.StreamPayload{Data: []byte("test data")}},
109+
},
110+
"EncryptResponse": {
111+
protoMsg: &runtimev1.EncryptResponse{Payload: &commonv1.StreamPayload{Data: []byte("test data")}},
112+
},
113+
"DecryptRequest": {
114+
protoMsg: &runtimev1.DecryptRequest{Payload: &commonv1.StreamPayload{Data: []byte("test data")}},
115+
},
116+
"DecryptResponse": {
117+
protoMsg: &runtimev1.DecryptResponse{Payload: &commonv1.StreamPayload{Data: []byte("test data")}},
118+
},
119+
}
120+
121+
for name, tc := range testCases {
122+
t.Run(name, func(t *testing.T) {
123+
var payload *commonv1.StreamPayload
124+
125+
switch r := tc.protoMsg.(type) {
126+
case *runtimev1.EncryptRequest:
127+
crypto.Reset(r)
128+
payload = crypto.GetPayload(r)
129+
case *runtimev1.EncryptResponse:
130+
crypto.Reset(r)
131+
payload = crypto.GetPayload(r)
132+
case *runtimev1.DecryptRequest:
133+
crypto.Reset(r)
134+
payload = crypto.GetPayload(r)
135+
case *runtimev1.DecryptResponse:
136+
crypto.Reset(r)
137+
payload = crypto.GetPayload(r)
138+
default:
139+
require.Failf(t, "unsupported proto message type", "the type was %T", r)
140+
}
141+
142+
assert.Nil(t, payload)
143+
})
144+
}
145+
}

0 commit comments

Comments
 (0)