1+ import contextvars
12from typing import Optional
23
34import html5lib
45from asgiref .local import Local
56from django .http import HttpResponse
6- from django .test import Client , RequestFactory , TestCase , TransactionTestCase
7+ from django .test import (
8+ AsyncClient ,
9+ AsyncRequestFactory ,
10+ Client ,
11+ RequestFactory ,
12+ TestCase ,
13+ TransactionTestCase ,
14+ )
715
816from debug_toolbar .panels import Panel
917from debug_toolbar .toolbar import DebugToolbar
1018
19+ data_contextvar = contextvars .ContextVar ("djdt_toolbar_test_client" )
20+
1121
1222class ToolbarTestClient (Client ):
1323 def request (self , ** request ):
@@ -29,11 +39,35 @@ def handle_toolbar_created(sender, toolbar=None, **kwargs):
2939 return response
3040
3141
42+ class AsyncToolbarTestClient (AsyncClient ):
43+ async def request (self , ** request ):
44+ # Use a thread/async task context-local variable to guard against a
45+ # concurrent _created signal from a different thread/task.
46+ # In cases testsuite will have both regular and async tests or
47+ # multiple async tests running in an eventloop making async_client calls.
48+ data_contextvar .set (None )
49+
50+ def handle_toolbar_created (sender , toolbar = None , ** kwargs ):
51+ data_contextvar .set (toolbar )
52+
53+ DebugToolbar ._created .connect (handle_toolbar_created )
54+ try :
55+ response = await super ().request (** request )
56+ finally :
57+ DebugToolbar ._created .disconnect (handle_toolbar_created )
58+ response .toolbar = data_contextvar .get ()
59+
60+ return response
61+
62+
3263rf = RequestFactory ()
64+ arf = AsyncRequestFactory ()
3365
3466
3567class BaseMixin :
68+ _is_async = False
3669 client_class = ToolbarTestClient
70+ async_client_class = AsyncToolbarTestClient
3771
3872 panel : Optional [Panel ] = None
3973 panel_id = None
@@ -42,7 +76,11 @@ def setUp(self):
4276 super ().setUp ()
4377 self ._get_response = lambda request : HttpResponse ()
4478 self .request = rf .get ("/" )
45- self .toolbar = DebugToolbar (self .request , self .get_response )
79+ if self ._is_async :
80+ self .request = arf .get ("/" )
81+ self .toolbar = DebugToolbar (self .request , self .get_response_async )
82+ else :
83+ self .toolbar = DebugToolbar (self .request , self .get_response )
4684 self .toolbar .stats = {}
4785
4886 if self .panel_id :
@@ -59,6 +97,9 @@ def tearDown(self):
5997 def get_response (self , request ):
6098 return self ._get_response (request )
6199
100+ async def get_response_async (self , request ):
101+ return self ._get_response (request )
102+
62103 def assertValidHTML (self , content ):
63104 parser = html5lib .HTMLParser ()
64105 parser .parseFragment (content )
0 commit comments