@@ -189,6 +189,52 @@ async def _test_ops(self, client, *ops):
189189 f"{ f .__name__ } did not return implicit session to pool" ,
190190 )
191191
192+ # Explicit bound session
193+ for f , args , kw in ops :
194+ async with client .start_session () as s :
195+ async with s .bind ():
196+ listener .reset ()
197+ s ._materialize ()
198+ last_use = s ._server_session .last_use
199+ start = time .monotonic ()
200+ self .assertLessEqual (last_use , start )
201+ # In case "f" modifies its inputs.
202+ args = copy .copy (args )
203+ kw = copy .copy (kw )
204+ await f (* args , ** kw )
205+ self .assertGreaterEqual (len (listener .started_events ), 1 )
206+ for event in listener .started_events :
207+ self .assertIn (
208+ "lsid" ,
209+ event .command ,
210+ f"{ f .__name__ } sent no lsid with { event .command_name } " ,
211+ )
212+
213+ self .assertEqual (
214+ s .session_id ,
215+ event .command ["lsid" ],
216+ f"{ f .__name__ } sent wrong lsid with { event .command_name } " ,
217+ )
218+
219+ self .assertFalse (s .has_ended )
220+
221+ self .assertTrue (s .has_ended )
222+ with self .assertRaisesRegex (InvalidOperation , "ended session" ):
223+ async with s .bind ():
224+ await f (* args , ** kw )
225+
226+ # Test a session cannot be used on another client.
227+ async with self .client2 .start_session () as s :
228+ async with s .bind ():
229+ # In case "f" modifies its inputs.
230+ args = copy .copy (args )
231+ kw = copy .copy (kw )
232+ with self .assertRaisesRegex (
233+ InvalidOperation ,
234+ "Only the client that created the bound session can perform operations within its context block" ,
235+ ):
236+ await f (* args , ** kw )
237+
192238 async def test_implicit_sessions_checkout (self ):
193239 # "To confirm that implicit sessions only allocate their server session after a
194240 # successful connection checkout" test from Driver Sessions Spec.
@@ -825,6 +871,73 @@ async def test_session_not_copyable(self):
825871 async with client .start_session () as s :
826872 self .assertRaises (TypeError , lambda : copy .copy (s ))
827873
874+ async def test_nested_session_binding (self ):
875+ coll = self .client .pymongo_test .test
876+ await coll .insert_one ({"x" : 1 })
877+
878+ session1 = self .client .start_session ()
879+ session2 = self .client .start_session ()
880+ session1 ._materialize ()
881+ session2 ._materialize ()
882+ try :
883+ self .listener .reset ()
884+ # Uses implicit session
885+ await coll .find_one ()
886+ implicit_lsid = self .listener .started_events [0 ].command .get ("lsid" )
887+ self .assertIsNotNone (implicit_lsid )
888+ self .assertNotEqual (implicit_lsid , session1 .session_id )
889+ self .assertNotEqual (implicit_lsid , session2 .session_id )
890+
891+ async with session1 .bind (end_session = False ):
892+ self .listener .reset ()
893+ # Uses bound session1
894+ await coll .find_one ()
895+ session1_lsid = self .listener .started_events [0 ].command .get ("lsid" )
896+ self .assertEqual (session1_lsid , session1 .session_id )
897+
898+ async with session2 .bind (end_session = False ):
899+ self .listener .reset ()
900+ # Uses bound session2
901+ await coll .find_one ()
902+ session2_lsid = self .listener .started_events [0 ].command .get ("lsid" )
903+ self .assertEqual (session2_lsid , session2 .session_id )
904+ self .assertNotEqual (session2_lsid , session1 .session_id )
905+
906+ self .listener .reset ()
907+ # Use bound session1 again
908+ await coll .find_one ()
909+ session1_lsid = self .listener .started_events [0 ].command .get ("lsid" )
910+ self .assertEqual (session1_lsid , session1 .session_id )
911+ self .assertNotEqual (session1_lsid , session2 .session_id )
912+
913+ self .listener .reset ()
914+ # Uses implicit session
915+ await coll .find_one ()
916+ implicit_lsid = self .listener .started_events [0 ].command .get ("lsid" )
917+ self .assertIsNotNone (implicit_lsid )
918+ self .assertNotEqual (implicit_lsid , session1 .session_id )
919+ self .assertNotEqual (implicit_lsid , session2 .session_id )
920+
921+ finally :
922+ await session1 .end_session ()
923+ await session2 .end_session ()
924+
925+ async def test_session_binding_end_session (self ):
926+ coll = self .client .pymongo_test .test
927+ await coll .insert_one ({"x" : 1 })
928+
929+ async with self .client .start_session ().bind () as s1 :
930+ await coll .find_one ()
931+
932+ self .assertTrue (s1 .has_ended )
933+
934+ async with self .client .start_session ().bind (end_session = False ) as s2 :
935+ await coll .find_one ()
936+
937+ self .assertFalse (s2 .has_ended )
938+
939+ await s2 .end_session ()
940+
828941
829942class TestCausalConsistency (AsyncUnitTest ):
830943 listener : SessionTestListener
0 commit comments