@@ -30,7 +30,7 @@ use tokio::{
30
30
#[ cfg( target_os = "linux" ) ]
31
31
use tokio_vsock:: VsockListener ;
32
32
33
- use crate :: asynchronous:: unix_incoming:: UnixIncoming ;
33
+ use crate :: asynchronous:: { stream :: SendingMessage , unix_incoming:: UnixIncoming } ;
34
34
use crate :: common:: { self , Domain } ;
35
35
use crate :: context;
36
36
use crate :: error:: { get_status, Error , Result } ;
@@ -329,7 +329,7 @@ struct ServerWriter {
329
329
330
330
#[ async_trait]
331
331
impl WriterDelegate for ServerWriter {
332
- async fn recv ( & mut self ) -> Option < GenMessage > {
332
+ async fn recv ( & mut self ) -> Option < SendingMessage > {
333
333
self . rx . recv ( ) . await
334
334
}
335
335
async fn disconnect ( & self , _msg : & GenMessage , _: Error ) { }
@@ -371,12 +371,14 @@ impl ReaderDelegate for ServerReader {
371
371
async fn handle_msg ( & self , msg : GenMessage ) {
372
372
let handler_shutdown_waiter = self . handler_shutdown . subscribe ( ) ;
373
373
let context = self . context ( ) ;
374
+ let ( wait_tx, wait_rx) = tokio:: sync:: oneshot:: channel :: < ( ) > ( ) ;
374
375
spawn ( async move {
375
376
select ! {
376
- _ = context. handle_msg( msg) => { }
377
+ _ = context. handle_msg( msg, wait_tx ) => { }
377
378
_ = handler_shutdown_waiter. wait_shutdown( ) => { }
378
379
}
379
380
} ) ;
381
+ wait_rx. await . unwrap_or_default ( ) ;
380
382
}
381
383
}
382
384
@@ -402,7 +404,7 @@ struct HandlerContext {
402
404
}
403
405
404
406
impl HandlerContext {
405
- async fn handle_msg ( & self , msg : GenMessage ) {
407
+ async fn handle_msg ( & self , msg : GenMessage , wait_tx : tokio :: sync :: oneshot :: Sender < ( ) > ) {
406
408
let stream_id = msg. header . stream_id ;
407
409
408
410
if ( stream_id % 2 ) != 1 {
@@ -416,7 +418,7 @@ impl HandlerContext {
416
418
}
417
419
418
420
match msg. header . type_ {
419
- MESSAGE_TYPE_REQUEST => match self . handle_request ( msg) . await {
421
+ MESSAGE_TYPE_REQUEST => match self . handle_request ( msg, wait_tx ) . await {
420
422
Ok ( opt_msg) => match opt_msg {
421
423
Some ( msg) => {
422
424
Self :: respond ( self . tx . clone ( ) , stream_id, msg)
@@ -435,7 +437,7 @@ impl HandlerContext {
435
437
} ;
436
438
437
439
self . tx
438
- . send ( msg)
440
+ . send ( SendingMessage :: new ( msg) )
439
441
. await
440
442
. map_err ( err_to_others_err ! ( e, "Send packet to sender error " ) )
441
443
. ok ( ) ;
@@ -444,6 +446,8 @@ impl HandlerContext {
444
446
Err ( status) => Self :: respond_with_status ( self . tx . clone ( ) , stream_id, status) . await ,
445
447
} ,
446
448
MESSAGE_TYPE_DATA => {
449
+ // no need to wait data message handling
450
+ drop ( wait_tx) ;
447
451
// TODO(wllenyj): Compatible with golang behavior.
448
452
if ( msg. header . flags & FLAG_REMOTE_CLOSED ) == FLAG_REMOTE_CLOSED
449
453
&& !msg. payload . is_empty ( )
@@ -492,7 +496,11 @@ impl HandlerContext {
492
496
}
493
497
}
494
498
495
- async fn handle_request ( & self , msg : GenMessage ) -> StdResult < Option < Response > , Status > {
499
+ async fn handle_request (
500
+ & self ,
501
+ msg : GenMessage ,
502
+ wait_tx : tokio:: sync:: oneshot:: Sender < ( ) > ,
503
+ ) -> StdResult < Option < Response > , Status > {
496
504
//TODO:
497
505
//if header.stream_id <= self.last_stream_id {
498
506
// return Err;
@@ -513,10 +521,11 @@ impl HandlerContext {
513
521
} ) ?;
514
522
515
523
if let Some ( method) = srv. get_method ( & req. method ) {
524
+ drop ( wait_tx) ;
516
525
return self . handle_method ( method, req_msg) . await ;
517
526
}
518
527
if let Some ( stream) = srv. get_stream ( & req. method ) {
519
- return self . handle_stream ( stream, req_msg) . await ;
528
+ return self . handle_stream ( stream, req_msg, wait_tx ) . await ;
520
529
}
521
530
Err ( get_status (
522
531
Code :: UNIMPLEMENTED ,
@@ -572,6 +581,7 @@ impl HandlerContext {
572
581
& self ,
573
582
stream : Arc < dyn StreamHandler + Send + Sync > ,
574
583
req_msg : Message < Request > ,
584
+ wait_tx : tokio:: sync:: oneshot:: Sender < ( ) > ,
575
585
) -> StdResult < Option < Response > , Status > {
576
586
let stream_id = req_msg. header . stream_id ;
577
587
let req = req_msg. payload ;
@@ -583,6 +593,9 @@ impl HandlerContext {
583
593
584
594
let _remote_close = ( req_msg. header . flags & FLAG_REMOTE_CLOSED ) == FLAG_REMOTE_CLOSED ;
585
595
let _remote_open = ( req_msg. header . flags & FLAG_REMOTE_OPEN ) == FLAG_REMOTE_OPEN ;
596
+
597
+ drop ( wait_tx) ;
598
+
586
599
let si = StreamInner :: new (
587
600
stream_id,
588
601
self . tx . clone ( ) ,
@@ -631,7 +644,7 @@ impl HandlerContext {
631
644
header : MessageHeader :: new_response ( stream_id, payload. len ( ) as u32 ) ,
632
645
payload,
633
646
} ;
634
- tx. send ( msg)
647
+ tx. send ( SendingMessage :: new ( msg) )
635
648
. await
636
649
. map_err ( err_to_others_err ! ( e, "Send packet to sender error " ) )
637
650
}
0 commit comments