Skip to content

Commit 2ae9676

Browse files
xaionaro@dx.centerxaionaro@dx.center
authored andcommitted
Replace fmt.Errorf with custom error types in pkg/requester
1 parent 3c09bde commit 2ae9676

17 files changed

+938
-165
lines changed

pkg/requester/auth.go

Lines changed: 25 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ func (r *Requester) GetDigests(ctx context.Context) (_ret [][]byte, _err error)
2929

3030
reqBytes, err := req.Marshal()
3131
if err != nil {
32-
return nil, fmt.Errorf("marshal request: %w", err)
32+
return nil, &ErrMarshalRequest{Err: err}
3333
}
3434

3535
resp, err := r.sendReceive(ctx, req)
@@ -40,7 +40,7 @@ func (r *Requester) GetDigests(ctx context.Context) (_ret [][]byte, _err error)
4040
var dr msgs.DigestResponse
4141
digestSize := r.conn.HashAlgo.Size()
4242
if err := dr.UnmarshalWithDigestSize(resp, digestSize); err != nil {
43-
return nil, fmt.Errorf("unmarshal: %w", err)
43+
return nil, &ErrUnmarshalResponse{Err: err}
4444
}
4545

4646
// Record B transcript: GET_DIGESTS request + DIGESTS response.
@@ -73,7 +73,7 @@ func (r *Requester) GetCertificate(ctx context.Context, slotID uint8) (_ret []by
7373

7474
reqBytes, err := req.Marshal()
7575
if err != nil {
76-
return nil, fmt.Errorf("marshal request: %w", err)
76+
return nil, &ErrMarshalRequest{Err: err}
7777
}
7878

7979
resp, err := r.sendReceive(ctx, req)
@@ -83,7 +83,7 @@ func (r *Requester) GetCertificate(ctx context.Context, slotID uint8) (_ret []by
8383

8484
var cr msgs.CertificateResponse
8585
if err := cr.Unmarshal(resp); err != nil {
86-
return nil, fmt.Errorf("unmarshal: %w", err)
86+
return nil, &ErrUnmarshalResponse{Err: err}
8787
}
8888

8989
// Record B transcript: GET_CERTIFICATE request + CERTIFICATE response.
@@ -103,7 +103,7 @@ func (r *Requester) GetCertificate(ctx context.Context, slotID uint8) (_ret []by
103103
// Validate the certificate chain if a trust anchor pool is configured.
104104
if r.cfg.Crypto.CertPool != nil {
105105
if err := r.validateCertChain(ctx, chain); err != nil {
106-
return nil, fmt.Errorf("certificate chain validation: %w", err)
106+
return nil, &ErrCertChainValidation{Err: err}
107107
}
108108
}
109109

@@ -121,17 +121,17 @@ func (r *Requester) validateCertChain(ctx context.Context, chain []byte) error {
121121
hashSize := r.conn.HashAlgo.Size()
122122
minSize := msgs.CertChainHeaderSize + hashSize
123123
if len(chain) < minSize {
124-
return fmt.Errorf("chain too short: %d bytes, need at least %d", len(chain), minSize)
124+
return &ErrCertChainTooShort{Size: len(chain), MinSize: minSize}
125125
}
126126

127127
certData := chain[msgs.CertChainHeaderSize+hashSize:]
128128
certs, err := parseDERCertificates(certData)
129129
if err != nil {
130-
return fmt.Errorf("parse certificates: %w", err)
130+
return &ErrParseCertificates{Err: err}
131131
}
132132

133133
if len(certs) == 0 {
134-
return fmt.Errorf("no certificates in chain")
134+
return &ErrNoCertificatesInChain{}
135135
}
136136

137137
// Build intermediate pool from all certs except the leaf (last one).
@@ -147,7 +147,7 @@ func (r *Requester) validateCertChain(ctx context.Context, chain []byte) error {
147147
}
148148

149149
if _, err := leaf.Verify(opts); err != nil {
150-
return fmt.Errorf("verify leaf certificate: %w", err)
150+
return &ErrVerifyLeafCertificate{Err: err}
151151
}
152152

153153
logger.Debugf(ctx, "certificate chain validated: %d certificates, leaf CN=%s", len(certs), leaf.Subject.CommonName)
@@ -170,12 +170,12 @@ func (r *Requester) Challenge(ctx context.Context, slotID uint8, hashType uint8)
170170
}},
171171
}
172172
if _, err := rand.Read(req.Nonce[:]); err != nil {
173-
return fmt.Errorf("generate nonce: %w", err)
173+
return &ErrGenerateNonce{Err: err}
174174
}
175175

176176
reqBytes, err := req.Marshal()
177177
if err != nil {
178-
return fmt.Errorf("marshal request: %w", err)
178+
return &ErrMarshalRequest{Err: err}
179179
}
180180

181181
resp, err := r.sendReceive(ctx, req)
@@ -194,13 +194,13 @@ func (r *Requester) Challenge(ctx context.Context, slotID uint8, hashType uint8)
194194

195195
var car msgs.ChallengeAuthResponse
196196
if err := car.UnmarshalWithSizes(resp, digestSize, measHashSize, sigSize); err != nil {
197-
return fmt.Errorf("unmarshal: %w", err)
197+
return &ErrUnmarshalResponse{Err: err}
198198
}
199199

200200
// Verify the responder's signature if a verifier is configured and a cert chain is available.
201201
if r.cfg.Crypto.Verifier != nil && len(r.peerCertChain) > 0 {
202202
if err := r.verifyChallengeSignature(ctx, reqBytes, resp, &car); err != nil {
203-
return fmt.Errorf("signature verification: %w", err)
203+
return &ErrSignatureVerification{Err: err}
204204
}
205205
}
206206

@@ -218,7 +218,7 @@ func (r *Requester) verifyChallengeSignature(
218218
// Extract the responder's public key from the stored certificate chain.
219219
pubKey, err := r.extractPeerPublicKey()
220220
if err != nil {
221-
return fmt.Errorf("extract peer public key: %w", err)
221+
return &ErrExtractPeerPublicKey{Err: err}
222222
}
223223

224224
// Marshal response without signature to build M1.
@@ -240,7 +240,7 @@ func (r *Requester) verifyChallengeSignature(
240240
digest.Write(signData)
241241

242242
if err := r.cfg.Crypto.Verifier.Verify(r.conn.AsymAlgo, pubKey, digest.Sum(nil), car.Signature); err != nil {
243-
return fmt.Errorf("verify: %w", err)
243+
return &ErrVerify{Err: err}
244244
}
245245

246246
logger.Debugf(ctx, "Challenge signature verified successfully")
@@ -284,25 +284,25 @@ func buildSigningData(h gocrypto.Hash, message []byte, contextStr string) []byte
284284
// [4+H:] concatenated DER-encoded X.509 certificates
285285
func (r *Requester) extractPeerPublicKey() (gocrypto.PublicKey, error) {
286286
if len(r.peerCertChain) == 0 {
287-
return nil, fmt.Errorf("no peer certificate chain available")
287+
return nil, &ErrNoPeerCertChain{}
288288
}
289289

290290
hashSize := r.conn.HashAlgo.Size()
291291
minSize := msgs.CertChainHeaderSize + hashSize
292292
if len(r.peerCertChain) < minSize {
293-
return nil, fmt.Errorf("peer cert chain too short: %d bytes, need at least %d", len(r.peerCertChain), minSize)
293+
return nil, &ErrPeerCertChainTooShort{Size: len(r.peerCertChain), MinSize: minSize}
294294
}
295295

296296
// Skip the SPDM cert chain header (4-byte length+reserved) and root hash.
297297
certData := r.peerCertChain[msgs.CertChainHeaderSize+hashSize:]
298298

299299
certs, err := parseDERCertificates(certData)
300300
if err != nil {
301-
return nil, fmt.Errorf("parse certificates: %w", err)
301+
return nil, &ErrParseCertificates{Err: err}
302302
}
303303

304304
if len(certs) == 0 {
305-
return nil, fmt.Errorf("no certificates found in chain")
305+
return nil, &ErrNoCertificatesFoundInChain{}
306306
}
307307

308308
// The leaf certificate is the last one in the chain.
@@ -319,12 +319,12 @@ func parseDERCertificates(data []byte) ([]*x509.Certificate, error) {
319319
// Determine this certificate's DER length first to extract exactly one cert.
320320
certLen, err := derObjectLength(remaining)
321321
if err != nil {
322-
return nil, fmt.Errorf("determine certificate length at offset %d: %w", len(data)-len(remaining), err)
322+
return nil, &ErrDetermineCertLengthAtOffset{Offset: len(data) - len(remaining), Err: err}
323323
}
324324

325325
cert, err := x509.ParseCertificate(remaining[:certLen])
326326
if err != nil {
327-
return nil, fmt.Errorf("parse certificate at offset %d: %w", len(data)-len(remaining), err)
327+
return nil, &ErrParseCertificateAtOffset{Offset: len(data) - len(remaining), Err: err}
328328
}
329329
certs = append(certs, cert)
330330
remaining = remaining[certLen:]
@@ -337,7 +337,7 @@ func parseDERCertificates(data []byte) ([]*x509.Certificate, error) {
337337
// ASN.1 object in data. This is needed to advance past concatenated DER certificates.
338338
func derObjectLength(data []byte) (int, error) {
339339
if len(data) < 2 {
340-
return 0, fmt.Errorf("DER data too short")
340+
return 0, &ErrInvalidDER{Reason: "DER data too short"}
341341
}
342342

343343
// Skip the tag byte.
@@ -350,10 +350,10 @@ func derObjectLength(data []byte) (int, error) {
350350
// Long form: lenByte & 0x7F = number of subsequent length bytes.
351351
numLenBytes := int(lenByte & 0x7F)
352352
if numLenBytes == 0 || numLenBytes > 4 {
353-
return 0, fmt.Errorf("invalid DER length encoding: %d length bytes", numLenBytes)
353+
return 0, &ErrInvalidDER{Reason: fmt.Sprintf("invalid DER length encoding: %d length bytes", numLenBytes)}
354354
}
355355
if 2+numLenBytes > len(data) {
356-
return 0, fmt.Errorf("DER data too short for length encoding")
356+
return 0, &ErrInvalidDER{Reason: "DER data too short for length encoding"}
357357
}
358358

359359
var length uint32
@@ -363,7 +363,7 @@ func derObjectLength(data []byte) (int, error) {
363363

364364
totalLen := 2 + numLenBytes + int(length)
365365
if totalLen > len(data) {
366-
return 0, fmt.Errorf("DER object length %d exceeds data length %d", totalLen, len(data))
366+
return 0, &ErrInvalidDER{Reason: fmt.Sprintf("DER object length %d exceeds data length %d", totalLen, len(data))}
367367
}
368368
return totalLen, nil
369369
}

pkg/requester/chunk.go

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@ package requester
22

33
import (
44
"context"
5-
"fmt"
65

76
"github.com/facebookincubator/go-belt/tool/logger"
87

@@ -25,7 +24,7 @@ func (r *Requester) ChunkSend(ctx context.Context, handle uint8, largeMsg []byte
2524
maxChunkRest := int(r.cfg.DataTransferSize) - 12
2625

2726
if maxChunkFirst <= 0 || maxChunkRest <= 0 {
28-
return fmt.Errorf("chunk_send: DataTransferSize too small for chunking")
27+
return &ErrChunkDataTransferSizeTooSmall{}
2928
}
3029

3130
offset := 0
@@ -67,21 +66,21 @@ func (r *Requester) ChunkSend(ctx context.Context, handle uint8, largeMsg []byte
6766

6867
resp, err := r.sendReceive(ctx, req)
6968
if err != nil {
70-
return fmt.Errorf("chunk_send: seq=%d: %w", seqNo, err)
69+
return &ErrChunkSend{SeqNo: seqNo, Err: err}
7170
}
7271

7372
if resp[1] != uint8(codes.ResponseChunkSendAck) {
74-
return fmt.Errorf("chunk_send: unexpected response code 0x%02X at seq=%d", resp[1], seqNo)
73+
return &ErrChunkSendUnexpectedResponseCode{Code: resp[1], SeqNo: seqNo}
7574
}
7675

7776
var ack msgs.ChunkSendAck
7877
if err := ack.Unmarshal(resp); err != nil {
79-
return fmt.Errorf("chunk_send: unmarshal ack: %w", err)
78+
return &ErrChunkUnmarshalAck{Err: err}
8079
}
8180

8281
// Check for early error response embedded in ACK.
8382
if ack.Header.Param1&msgs.ChunkSendAckAttrEarlyError != 0 {
84-
return fmt.Errorf("chunk_send: early error at seq=%d", seqNo)
83+
return &ErrChunkSendEarlyError{SeqNo: seqNo}
8584
}
8685

8786
offset = end
@@ -112,16 +111,16 @@ func (r *Requester) ChunkGet(ctx context.Context, handle uint8) ([]byte, error)
112111

113112
resp, err := r.sendReceive(ctx, req)
114113
if err != nil {
115-
return nil, fmt.Errorf("chunk_get: seq=%d: %w", seqNo, err)
114+
return nil, &ErrChunkGet{SeqNo: seqNo, Err: err}
116115
}
117116

118117
if resp[1] != uint8(codes.ResponseChunkResponse) {
119-
return nil, fmt.Errorf("chunk_get: unexpected response code 0x%02X at seq=%d", resp[1], seqNo)
118+
return nil, &ErrChunkGetUnexpectedResponseCode{Code: resp[1], SeqNo: seqNo}
120119
}
121120

122121
var cr msgs.ChunkResponse
123122
if err := cr.Unmarshal(resp); err != nil {
124-
return nil, fmt.Errorf("chunk_get: unmarshal: %w", err)
123+
return nil, &ErrChunkUnmarshalResponse{Err: err}
125124
}
126125

127126
result = append(result, cr.Chunk...)

pkg/requester/connection.go

Lines changed: 29 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -20,13 +20,13 @@ func (r *Requester) InitConnection(ctx context.Context) (_ret *ConnectionInfo, _
2020
logger.Tracef(ctx, "InitConnection")
2121
defer func() { logger.Tracef(ctx, "/InitConnection: result:%v; err:%v", _ret, _err) }()
2222
if err := r.getVersion(ctx); err != nil {
23-
return nil, fmt.Errorf("get_version: %w", err)
23+
return nil, &ErrGetVersion{Err: err}
2424
}
2525
if err := r.getCapabilities(ctx); err != nil {
26-
return nil, fmt.Errorf("get_capabilities: %w", err)
26+
return nil, &ErrGetCapabilities{Err: err}
2727
}
2828
if err := r.negotiateAlgorithms(ctx); err != nil {
29-
return nil, fmt.Errorf("negotiate_algorithms: %w", err)
29+
return nil, &ErrNegotiateAlgorithms{Err: err}
3030
}
3131
return &r.conn, nil
3232
}
@@ -48,18 +48,17 @@ func (r *Requester) getVersion(ctx context.Context) error {
4848

4949
var vr msgs.VersionResponse
5050
if err := vr.Unmarshal(resp); err != nil {
51-
return fmt.Errorf("unmarshal: %w", err)
51+
return &ErrUnmarshalResponse{Err: err}
5252
}
5353

5454
// Per DSP0274 Section 10.3, VERSION response SPDMVersion must be 0x10.
5555
if vr.Header.SPDMVersion != 0x10 {
56-
return fmt.Errorf("VERSION response SPDMVersion=0x%02X, expected 0x10: %w",
57-
vr.Header.SPDMVersion, status.ErrInvalidMsgField)
56+
return &ErrVersionResponseInvalid{SPDMVersion: vr.Header.SPDMVersion}
5857
}
5958

6059
// Per DSP0274 Section 10.3, VERSION response must contain at least one entry.
6160
if len(vr.VersionEntries) == 0 {
62-
return fmt.Errorf("VERSION response has 0 entries: %w", status.ErrInvalidMsgField)
61+
return &ErrVersionResponseEmpty{}
6362
}
6463

6564
// Build set of our supported versions for lookup.
@@ -118,14 +117,14 @@ func (r *Requester) getCapabilities(ctx context.Context) error {
118117

119118
var cr msgs.CapabilitiesResponse
120119
if err := cr.Unmarshal(resp); err != nil {
121-
return fmt.Errorf("unmarshal: %w", err)
120+
return &ErrUnmarshalResponse{Err: err}
122121
}
123122

124123
r.conn.PeerCaps = caps.ResponderCaps(cr.Flags)
125124

126125
// Per DSP0274 Section 10.4, validate peer capability flag dependencies.
127126
if err := caps.ValidateResponderCaps(r.conn.PeerCaps); err != nil {
128-
return fmt.Errorf("invalid peer capabilities: %w", err)
127+
return &ErrInvalidPeerCapabilities{Err: err}
129128
}
130129

131130
r.state = StateAfterCapabilities
@@ -158,7 +157,7 @@ func (r *Requester) negotiateAlgorithms(ctx context.Context) error {
158157

159158
var ar msgs.AlgorithmsResponse
160159
if err := ar.Unmarshal(resp); err != nil {
161-
return fmt.Errorf("unmarshal: %w", err)
160+
return &ErrUnmarshalResponse{Err: err}
162161
}
163162

164163
if err := r.validateAlgorithmsResponse(&ar); err != nil {
@@ -231,27 +230,38 @@ func (r *Requester) buildAlgStructs() []msgs.AlgStructTable {
231230
func (r *Requester) validateAlgorithmsResponse(ar *msgs.AlgorithmsResponse) error {
232231
sel := algo.BaseHashAlgo(ar.BaseHashSel)
233232
if sel == 0 {
234-
return fmt.Errorf("ALGORITHMS response has zero BaseHashSel: %w", status.ErrNegotiationFail)
233+
return &ErrAlgorithmsNegotiationFail{
234+
Reason: "ALGORITHMS response has zero BaseHashSel",
235+
Err: status.ErrNegotiationFail,
236+
}
235237
}
236238

237239
// Per DSP0274 Section 10.5, each selection field must have exactly one bit set.
238240
if !isSingleBit(ar.BaseHashSel) {
239-
return fmt.Errorf("ALGORITHMS response BaseHashSel has multiple bits: 0x%08X: %w",
240-
ar.BaseHashSel, status.ErrInvalidMsgField)
241+
return &ErrAlgorithmsNegotiationFail{
242+
Reason: fmt.Sprintf("ALGORITHMS response BaseHashSel has multiple bits: 0x%08X", ar.BaseHashSel),
243+
Err: status.ErrInvalidMsgField,
244+
}
241245
}
242246
if !isSingleBit(ar.BaseAsymSel) && ar.BaseAsymSel != 0 {
243-
return fmt.Errorf("ALGORITHMS response BaseAsymSel has multiple bits: 0x%08X: %w",
244-
ar.BaseAsymSel, status.ErrInvalidMsgField)
247+
return &ErrAlgorithmsNegotiationFail{
248+
Reason: fmt.Sprintf("ALGORITHMS response BaseAsymSel has multiple bits: 0x%08X", ar.BaseAsymSel),
249+
Err: status.ErrInvalidMsgField,
250+
}
245251
}
246252

247253
// Per DSP0274 Section 10.5, selected algorithms must be a subset of what was requested.
248254
if ar.BaseHashSel&uint32(r.cfg.BaseHashAlgo) != ar.BaseHashSel {
249-
return fmt.Errorf("ALGORITHMS BaseHashSel 0x%08X not subset of requested 0x%08X: %w",
250-
ar.BaseHashSel, uint32(r.cfg.BaseHashAlgo), status.ErrInvalidMsgField)
255+
return &ErrAlgorithmsNegotiationFail{
256+
Reason: fmt.Sprintf("ALGORITHMS BaseHashSel 0x%08X not subset of requested 0x%08X", ar.BaseHashSel, uint32(r.cfg.BaseHashAlgo)),
257+
Err: status.ErrInvalidMsgField,
258+
}
251259
}
252260
if ar.BaseAsymSel != 0 && ar.BaseAsymSel&uint32(r.cfg.BaseAsymAlgo) != ar.BaseAsymSel {
253-
return fmt.Errorf("ALGORITHMS BaseAsymSel 0x%08X not subset of requested 0x%08X: %w",
254-
ar.BaseAsymSel, uint32(r.cfg.BaseAsymAlgo), status.ErrInvalidMsgField)
261+
return &ErrAlgorithmsNegotiationFail{
262+
Reason: fmt.Sprintf("ALGORITHMS BaseAsymSel 0x%08X not subset of requested 0x%08X", ar.BaseAsymSel, uint32(r.cfg.BaseAsymAlgo)),
263+
Err: status.ErrInvalidMsgField,
264+
}
255265
}
256266

257267
r.conn.HashAlgo = sel

0 commit comments

Comments
 (0)