Skip to content

Commit 1410dae

Browse files
committed
Update client/crypto to use new helpers
1 parent 079730c commit 1410dae

File tree

2 files changed

+52
-21
lines changed

2 files changed

+52
-21
lines changed

client/crypto.go

Lines changed: 25 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ import (
2222
"google.golang.org/grpc"
2323
"google.golang.org/protobuf/proto"
2424

25+
"github.com/dapr/go-sdk/client/internal/crypto"
2526
commonv1pb "github.com/dapr/go-sdk/internal/proto/dapr/proto/common/v1"
2627
runtimev1pb "github.com/dapr/go-sdk/internal/proto/dapr/proto/runtime/v1"
2728
)
@@ -48,9 +49,11 @@ func (c *GRPCClient) Encrypt(ctx context.Context, in io.Reader, opts EncryptOpti
4849
}
4950

5051
// Use the context of the stream here.
51-
return c.performCryptoOperation(
52-
stream.Context(), stream,
53-
in, opts,
52+
return performCryptoOperation(
53+
stream.Context(),
54+
stream,
55+
in,
56+
opts,
5457
&runtimev1pb.EncryptRequest{},
5558
&runtimev1pb.EncryptResponse{},
5659
)
@@ -72,15 +75,25 @@ func (c *GRPCClient) Decrypt(ctx context.Context, in io.Reader, opts DecryptOpti
7275
}
7376

7477
// Use the context of the stream here.
75-
return c.performCryptoOperation(
76-
stream.Context(), stream,
77-
in, opts,
78+
return performCryptoOperation(
79+
stream.Context(),
80+
stream,
81+
in,
82+
opts,
7883
&runtimev1pb.DecryptRequest{},
7984
&runtimev1pb.DecryptResponse{},
8085
)
8186
}
8287

83-
func (c *GRPCClient) performCryptoOperation(ctx context.Context, stream grpc.ClientStream, in io.Reader, opts cryptoOperationOpts, reqProto runtimev1pb.CryptoRequests, resProto runtimev1pb.CryptoResponses) (io.Reader, error) {
88+
func performCryptoOperation[T runtimev1pb.DecryptRequest | runtimev1pb.EncryptRequest, Y runtimev1pb.DecryptResponse | runtimev1pb.EncryptResponse](
89+
ctx context.Context,
90+
stream grpc.ClientStream,
91+
in io.Reader,
92+
opts cryptoOperationOpts,
93+
reqProto *T,
94+
resProto *Y,
95+
) (io.Reader, error) {
96+
8497
var err error
8598
// Pipe for writing the response
8699
pr, pw := io.Pipe()
@@ -110,11 +123,11 @@ func (c *GRPCClient) performCryptoOperation(ctx context.Context, stream grpc.Cli
110123

111124
// First message only - add the options
112125
if optsProto != nil {
113-
reqProto.SetOptions(optsProto)
126+
crypto.SetOptions(reqProto, optsProto)
114127
optsProto = nil
115128
} else {
116129
// Reset the object so we can re-use it
117-
reqProto.Reset()
130+
crypto.Reset(reqProto)
118131
}
119132

120133
n, err = in.Read(*reqBuf)
@@ -127,7 +140,7 @@ func (c *GRPCClient) performCryptoOperation(ctx context.Context, stream grpc.Cli
127140

128141
// Send the chunk if there's anything to send
129142
if n > 0 {
130-
reqProto.SetPayload(&commonv1pb.StreamPayload{
143+
crypto.SetPayload(reqProto, &commonv1pb.StreamPayload{
131144
Data: (*reqBuf)[:n],
132145
Seq: seq,
133146
})
@@ -184,7 +197,7 @@ func (c *GRPCClient) performCryptoOperation(ctx context.Context, stream grpc.Cli
184197
}
185198

186199
// Write the data, if any, into the pipe
187-
payload = resProto.GetPayload()
200+
payload = crypto.GetPayload(resProto)
188201
if payload != nil {
189202
if payload.GetSeq() != expectSeq {
190203
pw.CloseWithError(fmt.Errorf("invalid sequence number in chunk: %d (expected: %d)", payload.GetSeq(), expectSeq))
@@ -205,7 +218,7 @@ func (c *GRPCClient) performCryptoOperation(ctx context.Context, stream grpc.Cli
205218
}
206219

207220
// Reset the proto
208-
resProto.Reset()
221+
crypto.Reset(resProto)
209222
}
210223

211224
// Close the writer of the pipe when done

client/crypto_test.go

Lines changed: 27 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ import (
2525
"github.com/stretchr/testify/require"
2626
"google.golang.org/grpc"
2727

28+
"github.com/dapr/go-sdk/client/internal/crypto"
2829
commonv1 "github.com/dapr/go-sdk/internal/proto/dapr/proto/common/v1"
2930
runtimev1pb "github.com/dapr/go-sdk/internal/proto/dapr/proto/runtime/v1"
3031
)
@@ -189,22 +190,24 @@ func TestDecrypt(t *testing.T) {
189190
/* --- Server methods --- */
190191

191192
func (s *testDaprServer) EncryptAlpha1(stream runtimev1pb.Dapr_EncryptAlpha1Server) error {
192-
return s.performCryptoOperation(
193+
return testPerformCryptoOperation(
193194
stream,
194195
&runtimev1pb.EncryptRequest{},
195196
&runtimev1pb.EncryptResponse{},
196197
)
197198
}
198199

199200
func (s *testDaprServer) DecryptAlpha1(stream runtimev1pb.Dapr_DecryptAlpha1Server) error {
200-
return s.performCryptoOperation(
201+
return testPerformCryptoOperation(
201202
stream,
202203
&runtimev1pb.DecryptRequest{},
203204
&runtimev1pb.DecryptResponse{},
204205
)
205206
}
206207

207-
func (s *testDaprServer) performCryptoOperation(stream grpc.ServerStream, reqProto runtimev1pb.CryptoRequests, resProto runtimev1pb.CryptoResponses) error {
208+
func testPerformCryptoOperation[T runtimev1pb.DecryptRequest | runtimev1pb.EncryptRequest, Y runtimev1pb.DecryptResponse | runtimev1pb.EncryptResponse](
209+
stream grpc.ServerStream, reqProto *T, resProto *Y,
210+
) error {
208211
// This doesn't really encrypt or decrypt the data and just sends back whatever it receives
209212
pr, pw := io.Pipe()
210213

@@ -216,7 +219,7 @@ func (s *testDaprServer) performCryptoOperation(stream grpc.ServerStream, reqPro
216219
)
217220
first := true
218221
for !done && stream.Context().Err() == nil {
219-
reqProto.Reset()
222+
crypto.Reset(reqProto)
220223
err = stream.RecvMsg(reqProto)
221224
if errors.Is(err, io.EOF) {
222225
done = true
@@ -225,16 +228,16 @@ func (s *testDaprServer) performCryptoOperation(stream grpc.ServerStream, reqPro
225228
return
226229
}
227230

228-
if first && !reqProto.HasOptions() {
231+
if first && !hasOptions(reqProto) {
229232
pw.CloseWithError(errors.New("first message must have options"))
230233
return
231-
} else if !first && reqProto.HasOptions() {
234+
} else if !first && hasOptions(reqProto) {
232235
pw.CloseWithError(errors.New("messages after first must not have options"))
233236
return
234237
}
235238
first = false
236239

237-
payload := reqProto.GetPayload()
240+
payload := crypto.GetPayload(reqProto)
238241
if payload != nil {
239242
if payload.GetSeq() != expectSeq {
240243
pw.CloseWithError(fmt.Errorf("invalid sequence number: %d (expected: %d)", payload.GetSeq(), expectSeq))
@@ -261,7 +264,7 @@ func (s *testDaprServer) performCryptoOperation(stream grpc.ServerStream, reqPro
261264
)
262265
buf := make([]byte, 2<<10)
263266
for !done && stream.Context().Err() == nil {
264-
resProto.Reset()
267+
crypto.Reset(resProto)
265268

266269
n, err = pr.Read(buf)
267270
if errors.Is(err, io.EOF) {
@@ -271,7 +274,7 @@ func (s *testDaprServer) performCryptoOperation(stream grpc.ServerStream, reqPro
271274
}
272275

273276
if n > 0 {
274-
resProto.SetPayload(&commonv1.StreamPayload{
277+
crypto.SetPayload(resProto, &commonv1.StreamPayload{
275278
Seq: seq,
276279
Data: buf[:n],
277280
})
@@ -286,3 +289,18 @@ func (s *testDaprServer) performCryptoOperation(stream grpc.ServerStream, reqPro
286289

287290
return nil
288291
}
292+
293+
func hasOptions[T runtimev1pb.DecryptRequest | runtimev1pb.EncryptRequest](msg *T) bool {
294+
if msg == nil {
295+
return false
296+
}
297+
298+
switch r := any(msg).(type) {
299+
case *runtimev1pb.EncryptRequest:
300+
return r.Options != nil
301+
case *runtimev1pb.DecryptRequest:
302+
return r.Options != nil
303+
}
304+
305+
return false
306+
}

0 commit comments

Comments
 (0)