@@ -70,7 +70,10 @@ type validation struct {
70
70
// mx protects the validator map
71
71
mx sync.Mutex
72
72
// topicVals tracks per topic validators
73
- topicVals map [string ]* topicVal
73
+ topicVals map [string ]* validatorImpl
74
+
75
+ // defaultVals tracks default validators applicable to all topics
76
+ defaultVals []* validatorImpl
74
77
75
78
// validateQ is the front-end to the validation pipeline
76
79
validateQ chan * validateReq
@@ -84,13 +87,13 @@ type validation struct {
84
87
85
88
// validation requests
86
89
type validateReq struct {
87
- vals []* topicVal
90
+ vals []* validatorImpl
88
91
src peer.ID
89
92
msg * Message
90
93
}
91
94
92
95
// representation of topic validators
93
- type topicVal struct {
96
+ type validatorImpl struct {
94
97
topic string
95
98
validate ValidatorEx
96
99
validateTimeout time.Duration
@@ -117,7 +120,7 @@ type rmValReq struct {
117
120
// newValidation creates a new validation pipeline
118
121
func newValidation () * validation {
119
122
return & validation {
120
- topicVals : make (map [string ]* topicVal ),
123
+ topicVals : make (map [string ]* validatorImpl ),
121
124
validateQ : make (chan * validateReq , defaultValidateQueueSize ),
122
125
validateThrottle : make (chan struct {}, defaultValidateThrottle ),
123
126
validateWorkers : runtime .NumCPU (),
@@ -136,17 +139,28 @@ func (v *validation) Start(p *PubSub) {
136
139
137
140
// AddValidator adds a new validator
138
141
func (v * validation ) AddValidator (req * addValReq ) {
142
+ val , err := v .makeValidator (req )
143
+ if err != nil {
144
+ req .resp <- err
145
+ return
146
+ }
147
+
139
148
v .mx .Lock ()
140
149
defer v .mx .Unlock ()
141
150
142
- topic := req .topic
151
+ topic := val .topic
143
152
144
153
_ , ok := v .topicVals [topic ]
145
154
if ok {
146
155
req .resp <- fmt .Errorf ("duplicate validator for topic %s" , topic )
147
156
return
148
157
}
149
158
159
+ v .topicVals [topic ] = val
160
+ req .resp <- nil
161
+ }
162
+
163
+ func (v * validation ) makeValidator (req * addValReq ) (* validatorImpl , error ) {
150
164
makeValidatorEx := func (v Validator ) ValidatorEx {
151
165
return func (ctx context.Context , p peer.ID , msg * Message ) ValidationResult {
152
166
if v (ctx , p , msg ) {
@@ -170,12 +184,15 @@ func (v *validation) AddValidator(req *addValReq) {
170
184
validator = v
171
185
172
186
default :
173
- req .resp <- fmt .Errorf ("unknown validator type for topic %s; must be an instance of Validator or ValidatorEx" , topic )
174
- return
187
+ topic := req .topic
188
+ if req .topic == "" {
189
+ topic = "(default)"
190
+ }
191
+ return nil , fmt .Errorf ("unknown validator type for topic %s; must be an instance of Validator or ValidatorEx" , topic )
175
192
}
176
193
177
- val := & topicVal {
178
- topic : topic ,
194
+ val := & validatorImpl {
195
+ topic : req . topic ,
179
196
validate : validator ,
180
197
validateTimeout : 0 ,
181
198
validateThrottle : make (chan struct {}, defaultValidateConcurrency ),
@@ -190,8 +207,7 @@ func (v *validation) AddValidator(req *addValReq) {
190
207
val .validateThrottle = make (chan struct {}, req .throttle )
191
208
}
192
209
193
- v .topicVals [topic ] = val
194
- req .resp <- nil
210
+ return val , nil
195
211
}
196
212
197
213
// RemoveValidator removes an existing validator
@@ -244,18 +260,21 @@ func (v *validation) Push(src peer.ID, msg *Message) bool {
244
260
}
245
261
246
262
// getValidators returns all validators that apply to a given message
247
- func (v * validation ) getValidators (msg * Message ) []* topicVal {
263
+ func (v * validation ) getValidators (msg * Message ) []* validatorImpl {
248
264
v .mx .Lock ()
249
265
defer v .mx .Unlock ()
250
266
267
+ var vals []* validatorImpl
268
+ vals = append (vals , v .defaultVals ... )
269
+
251
270
topic := msg .GetTopic ()
252
271
253
272
val , ok := v .topicVals [topic ]
254
273
if ! ok {
255
- return nil
274
+ return vals
256
275
}
257
276
258
- return [] * topicVal { val }
277
+ return append ( vals , val )
259
278
}
260
279
261
280
// validateWorker is an active goroutine performing inline validation
@@ -271,7 +290,7 @@ func (v *validation) validateWorker() {
271
290
}
272
291
273
292
// validate performs validation and only sends the message if all validators succeed
274
- func (v * validation ) validate (vals []* topicVal , src peer.ID , msg * Message , synchronous bool ) error {
293
+ func (v * validation ) validate (vals []* validatorImpl , src peer.ID , msg * Message , synchronous bool ) error {
275
294
// If signature verification is enabled, but signing is disabled,
276
295
// the Signature is required to be nil upon receiving the message in PubSub.pushMsg.
277
296
if msg .Signature != nil {
@@ -292,7 +311,7 @@ func (v *validation) validate(vals []*topicVal, src peer.ID, msg *Message, synch
292
311
v .tracer .ValidateMessage (msg )
293
312
}
294
313
295
- var inline , async []* topicVal
314
+ var inline , async []* validatorImpl
296
315
for _ , val := range vals {
297
316
if val .validateInline || synchronous {
298
317
inline = append (inline , val )
@@ -360,7 +379,7 @@ func (v *validation) validateSignature(msg *Message) bool {
360
379
return true
361
380
}
362
381
363
- func (v * validation ) doValidateTopic (vals []* topicVal , src peer.ID , msg * Message , r ValidationResult ) {
382
+ func (v * validation ) doValidateTopic (vals []* validatorImpl , src peer.ID , msg * Message , r ValidationResult ) {
364
383
result := v .validateTopic (vals , src , msg )
365
384
366
385
if result == ValidationAccept && r != ValidationAccept {
@@ -388,7 +407,7 @@ func (v *validation) doValidateTopic(vals []*topicVal, src peer.ID, msg *Message
388
407
}
389
408
}
390
409
391
- func (v * validation ) validateTopic (vals []* topicVal , src peer.ID , msg * Message ) ValidationResult {
410
+ func (v * validation ) validateTopic (vals []* validatorImpl , src peer.ID , msg * Message ) ValidationResult {
392
411
if len (vals ) == 1 {
393
412
return v .validateSingleTopic (vals [0 ], src , msg )
394
413
}
@@ -404,7 +423,7 @@ func (v *validation) validateTopic(vals []*topicVal, src peer.ID, msg *Message)
404
423
405
424
select {
406
425
case val .validateThrottle <- struct {}{}:
407
- go func (val * topicVal ) {
426
+ go func (val * validatorImpl ) {
408
427
rch <- val .validateMsg (ctx , src , msg )
409
428
<- val .validateThrottle
410
429
}(val )
@@ -438,7 +457,7 @@ loop:
438
457
}
439
458
440
459
// fast path for single topic validation that avoids the extra goroutine
441
- func (v * validation ) validateSingleTopic (val * topicVal , src peer.ID , msg * Message ) ValidationResult {
460
+ func (v * validation ) validateSingleTopic (val * validatorImpl , src peer.ID , msg * Message ) ValidationResult {
442
461
select {
443
462
case val .validateThrottle <- struct {}{}:
444
463
res := val .validateMsg (v .p .ctx , src , msg )
@@ -451,7 +470,7 @@ func (v *validation) validateSingleTopic(val *topicVal, src peer.ID, msg *Messag
451
470
}
452
471
}
453
472
454
- func (val * topicVal ) validateMsg (ctx context.Context , src peer.ID , msg * Message ) ValidationResult {
473
+ func (val * validatorImpl ) validateMsg (ctx context.Context , src peer.ID , msg * Message ) ValidationResult {
455
474
start := time .Now ()
456
475
defer func () {
457
476
log .Debugf ("validation done; took %s" , time .Since (start ))
@@ -479,6 +498,31 @@ func (val *topicVal) validateMsg(ctx context.Context, src peer.ID, msg *Message)
479
498
}
480
499
481
500
// / Options
501
+ // WithDefaultValidator adds a validator that applies to all topics by default; it can be used
502
+ // more than once and add multiple validators. Having a defult validator does not inhibit registering
503
+ // a per topic validator.
504
+ func WithDefaultValidator (val interface {}, opts ... ValidatorOpt ) Option {
505
+ return func (ps * PubSub ) error {
506
+ addVal := & addValReq {
507
+ validate : val ,
508
+ }
509
+
510
+ for _ , opt := range opts {
511
+ err := opt (addVal )
512
+ if err != nil {
513
+ return err
514
+ }
515
+ }
516
+
517
+ val , err := ps .val .makeValidator (addVal )
518
+ if err != nil {
519
+ return err
520
+ }
521
+
522
+ ps .val .defaultVals = append (ps .val .defaultVals , val )
523
+ return nil
524
+ }
525
+ }
482
526
483
527
// WithValidateQueueSize sets the buffer of validate queue. Defaults to 32.
484
528
// When queue is full, validation is throttled and new messages are dropped.
0 commit comments