@@ -404,6 +404,117 @@ def test_create_unix_connection_6(self):
404404 lambda : None , path = '/tmp/a' ,
405405 ssl_handshake_timeout = SSL_HANDSHAKE_TIMEOUT ))
406406
407+ def test_create_unix_connection_sock_cancel_detaches (self ):
408+ async def test ():
409+ srv_path = os .path .join (tempfile .mkdtemp (), 'test.sock' )
410+ srv = await asyncio .start_unix_server (
411+ lambda r , w : w .close (), path = srv_path )
412+
413+ sock = socket .socket (socket .AF_UNIX , socket .SOCK_STREAM )
414+ sock .setblocking (False )
415+ try :
416+ sock .connect (srv_path )
417+ except BlockingIOError :
418+ pass
419+ await asyncio .sleep (0.01 )
420+
421+ task = asyncio .ensure_future (
422+ self .loop .create_unix_connection (
423+ asyncio .Protocol , sock = sock ))
424+ await asyncio .sleep (0 )
425+ task .cancel ()
426+ with self .assertRaises (asyncio .CancelledError ):
427+ await task
428+
429+ self .assertEqual (sock .fileno (), - 1 )
430+
431+ srv .close ()
432+ await srv .wait_closed ()
433+ if os .path .exists (srv_path ):
434+ os .unlink (srv_path )
435+
436+ self .loop .run_until_complete (test ())
437+
438+ def test_create_unix_connection_sock_cancel_fd_leak (self ):
439+ # Same as test_create_connection_sock_cancel_fd_leak but for
440+ # the create_unix_connection(sock=) path.
441+
442+ async def test ():
443+ srv_path = os .path .join (tempfile .mkdtemp (), 'test.sock' )
444+ srv = await asyncio .start_unix_server (
445+ lambda r , w : w .close (), path = srv_path )
446+
447+ sock = socket .socket (socket .AF_UNIX , socket .SOCK_STREAM )
448+ sock .setblocking (False )
449+ await self .loop .sock_connect (sock , srv_path )
450+ stale_fd = sock .fileno ()
451+
452+ task = self .loop .create_task (
453+ self .loop .create_unix_connection (
454+ asyncio .Protocol , sock = sock ))
455+ await asyncio .sleep (0 )
456+ task .cancel ()
457+ with self .assertRaises (asyncio .CancelledError ):
458+ await task
459+
460+ # Create victim that reuses the fd.
461+ victim_sock = socket .socket (socket .AF_UNIX , socket .SOCK_STREAM )
462+ victim_sock .setblocking (False )
463+ await self .loop .sock_connect (victim_sock , srv_path )
464+ victim_tr , _ = await self .loop .create_unix_connection (
465+ asyncio .Protocol , sock = victim_sock )
466+ victim_fd = victim_tr .get_extra_info ('socket' ).fileno ()
467+ if victim_fd != stale_fd :
468+ victim_tr .close ()
469+ sock .close ()
470+ srv .close ()
471+ await srv .wait_closed ()
472+ if os .path .exists (srv_path ):
473+ os .unlink (srv_path )
474+ raise unittest .SkipTest (
475+ f'fd not reused (got { victim_fd } , need { stale_fd } )' )
476+
477+ spy_a , spy_b = socket .socketpair ()
478+ spy_b .setblocking (False )
479+
480+ sock .close ()
481+
482+ victim_broken = False
483+ try :
484+ os .fstat (victim_fd )
485+ except OSError :
486+ victim_broken = True
487+
488+ if victim_broken :
489+ os .dup2 (spy_a .fileno (), stale_fd )
490+ spy_a .close ()
491+
492+ victim_tr .write (b'LEAKED' )
493+
494+ try :
495+ leaked = spy_b .recv (4096 )
496+ except BlockingIOError :
497+ leaked = b''
498+
499+ if victim_broken :
500+ os .close (stale_fd )
501+ spy_b .close ()
502+ victim_tr .close ()
503+ # Let pending callbacks (e.g. server-side connection_lost
504+ # from the cancelled connection) run before closing the
505+ # server, to avoid triggering call_exception_handler().
506+ await asyncio .sleep (0 )
507+ srv .close ()
508+ await srv .wait_closed ()
509+ if os .path .exists (srv_path ):
510+ os .unlink (srv_path )
511+
512+ self .assertEqual (leaked , b'' ,
513+ f"Data leaked to an unrelated socket: "
514+ f"got { leaked !r} " )
515+
516+ self .loop .run_until_complete (test ())
517+
407518
408519class Test_UV_Unix (_TestUnix , tb .UVTestCase ):
409520
0 commit comments