Skip to content

Commit f533157

Browse files
authored
Python 4542 - Improved sessions API (#2712)
1 parent e028fe2 commit f533157

File tree

8 files changed

+364
-14
lines changed

8 files changed

+364
-14
lines changed

doc/changelog.rst

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,15 @@
11
Changelog
22
=========
33

4+
Changes in Version 4.17.0 (2026/XX/XX)
5+
--------------------------------------
6+
7+
PyMongo 4.17 brings a number of changes including:
8+
9+
- Added the :meth:`~pymongo.asynchronous.client_session.AsyncClientSession.bind` and :meth:`~pymongo.client_session.ClientSession.bind` methods
10+
that allow users to bind a session to all database operations within the scope of a context manager instead of having to explicitly pass the session to each individual operation.
11+
See <PLACEHOLDER> for examples and more information.
12+
413
Changes in Version 4.16.0 (2026/01/07)
514
--------------------------------------
615

pymongo/asynchronous/client_session.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,7 @@
139139
import time
140140
import uuid
141141
from collections.abc import Mapping as _Mapping
142+
from contextvars import ContextVar, Token
142143
from typing import (
143144
TYPE_CHECKING,
144145
Any,
@@ -181,6 +182,28 @@
181182

182183
_IS_SYNC = False
183184

185+
_SESSION: ContextVar[Optional[AsyncClientSession]] = ContextVar("SESSION", default=None)
186+
187+
188+
class _AsyncBoundSessionContext:
189+
"""Context manager returned by AsyncClientSession.bind() that manages bound state."""
190+
191+
def __init__(self, session: AsyncClientSession, end_session: bool) -> None:
192+
self._session = session
193+
self._session_token: Optional[Token[AsyncClientSession]] = None
194+
self._end_session = end_session
195+
196+
async def __aenter__(self) -> AsyncClientSession:
197+
self._session_token = _SESSION.set(self._session) # type: ignore[assignment]
198+
return self._session
199+
200+
async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
201+
if self._session_token:
202+
_SESSION.reset(self._session_token) # type: ignore[arg-type]
203+
self._session_token = None
204+
if self._end_session:
205+
await self._session.end_session()
206+
184207

185208
class SessionOptions:
186209
"""Options for a new :class:`AsyncClientSession`.
@@ -547,6 +570,24 @@ def _check_ended(self) -> None:
547570
if self._server_session is None:
548571
raise InvalidOperation("Cannot use ended session")
549572

573+
def bind(self, end_session: bool = True) -> _AsyncBoundSessionContext:
574+
"""Bind this session so it is implicitly passed to all database operations within the returned context.
575+
576+
.. code-block:: python
577+
578+
async with client.start_session() as s:
579+
async with s.bind():
580+
# session=s is passed implicitly
581+
await client.db.collection.insert_one({"x": 1})
582+
583+
:param end_session: Whether to end the session on exiting the returned context. Defaults to True.
584+
If set to False, :meth:`~pymongo.asynchronous.client_session.AsyncClientSession.end_session()` must be called
585+
once the session is no longer used.
586+
587+
.. versionadded:: 4.17
588+
"""
589+
return _AsyncBoundSessionContext(self, end_session)
590+
550591
async def __aenter__(self) -> AsyncClientSession:
551592
return self
552593

pymongo/asynchronous/mongo_client.py

Lines changed: 23 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@
6565
from pymongo.asynchronous import client_session, database, uri_parser
6666
from pymongo.asynchronous.change_stream import AsyncChangeStream, AsyncClusterChangeStream
6767
from pymongo.asynchronous.client_bulk import _AsyncClientBulk
68-
from pymongo.asynchronous.client_session import _EmptyServerSession
68+
from pymongo.asynchronous.client_session import _SESSION, _EmptyServerSession
6969
from pymongo.asynchronous.command_cursor import AsyncCommandCursor
7070
from pymongo.asynchronous.settings import TopologySettings
7171
from pymongo.asynchronous.topology import Topology, _ErrorContext
@@ -1408,7 +1408,8 @@ def start_session(
14081408
def _ensure_session(
14091409
self, session: Optional[AsyncClientSession] = None
14101410
) -> Optional[AsyncClientSession]:
1411-
"""If provided session is None, lend a temporary session."""
1411+
"""If provided session and bound session are None, lend a temporary session."""
1412+
session = session or self._get_bound_session()
14121413
if session:
14131414
return session
14141415

@@ -2267,11 +2268,14 @@ async def _tmp_session(
22672268
self, session: Optional[client_session.AsyncClientSession]
22682269
) -> AsyncGenerator[Optional[client_session.AsyncClientSession], None]:
22692270
"""If provided session is None, lend a temporary session."""
2270-
if session is not None:
2271-
if not isinstance(session, client_session.AsyncClientSession):
2272-
raise ValueError(
2273-
f"'session' argument must be an AsyncClientSession or None, not {type(session)}"
2274-
)
2271+
if session is not None and not isinstance(session, client_session.AsyncClientSession):
2272+
raise ValueError(
2273+
f"'session' argument must be an AsyncClientSession or None, not {type(session)}"
2274+
)
2275+
2276+
# Check for a bound session. If one exists, treat it as an explicitly passed session.
2277+
session = session or self._get_bound_session()
2278+
if session:
22752279
# Don't call end_session.
22762280
yield session
22772281
return
@@ -2301,6 +2305,18 @@ async def _process_response(
23012305
if session is not None:
23022306
session._process_response(reply)
23032307

2308+
def _get_bound_session(self) -> Optional[AsyncClientSession]:
2309+
bound_session = _SESSION.get()
2310+
if bound_session:
2311+
if bound_session.client is self:
2312+
return bound_session
2313+
else:
2314+
raise InvalidOperation(
2315+
"Only the client that created the bound session can perform operations within its context block. See <PLACEHOLDER> for more information."
2316+
)
2317+
else:
2318+
return None
2319+
23042320
async def server_info(
23052321
self, session: Optional[client_session.AsyncClientSession] = None
23062322
) -> dict[str, Any]:

pymongo/synchronous/client_session.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,7 @@
139139
import time
140140
import uuid
141141
from collections.abc import Mapping as _Mapping
142+
from contextvars import ContextVar, Token
142143
from typing import (
143144
TYPE_CHECKING,
144145
Any,
@@ -180,6 +181,28 @@
180181

181182
_IS_SYNC = True
182183

184+
_SESSION: ContextVar[Optional[ClientSession]] = ContextVar("SESSION", default=None)
185+
186+
187+
class _BoundSessionContext:
188+
"""Context manager returned by ClientSession.bind() that manages bound state."""
189+
190+
def __init__(self, session: ClientSession, end_session: bool) -> None:
191+
self._session = session
192+
self._session_token: Optional[Token[ClientSession]] = None
193+
self._end_session = end_session
194+
195+
def __enter__(self) -> ClientSession:
196+
self._session_token = _SESSION.set(self._session) # type: ignore[assignment]
197+
return self._session
198+
199+
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
200+
if self._session_token:
201+
_SESSION.reset(self._session_token) # type: ignore[arg-type]
202+
self._session_token = None
203+
if self._end_session:
204+
self._session.end_session()
205+
183206

184207
class SessionOptions:
185208
"""Options for a new :class:`ClientSession`.
@@ -546,6 +569,24 @@ def _check_ended(self) -> None:
546569
if self._server_session is None:
547570
raise InvalidOperation("Cannot use ended session")
548571

572+
def bind(self, end_session: bool = True) -> _BoundSessionContext:
573+
"""Bind this session so it is implicitly passed to all database operations within the returned context.
574+
575+
.. code-block:: python
576+
577+
with client.start_session() as s:
578+
with s.bind():
579+
# session=s is passed implicitly
580+
client.db.collection.insert_one({"x": 1})
581+
582+
:param end_session: Whether to end the session on exiting the returned context. Defaults to True.
583+
If set to False, :meth:`~pymongo.client_session.ClientSession.end_session()` must be called
584+
once the session is no longer used.
585+
586+
.. versionadded:: 4.17
587+
"""
588+
return _BoundSessionContext(self, end_session)
589+
549590
def __enter__(self) -> ClientSession:
550591
return self
551592

pymongo/synchronous/mongo_client.py

Lines changed: 23 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@
108108
from pymongo.synchronous import client_session, database, uri_parser
109109
from pymongo.synchronous.change_stream import ChangeStream, ClusterChangeStream
110110
from pymongo.synchronous.client_bulk import _ClientBulk
111-
from pymongo.synchronous.client_session import _EmptyServerSession
111+
from pymongo.synchronous.client_session import _SESSION, _EmptyServerSession
112112
from pymongo.synchronous.command_cursor import CommandCursor
113113
from pymongo.synchronous.settings import TopologySettings
114114
from pymongo.synchronous.topology import Topology, _ErrorContext
@@ -1406,7 +1406,8 @@ def start_session(
14061406
)
14071407

14081408
def _ensure_session(self, session: Optional[ClientSession] = None) -> Optional[ClientSession]:
1409-
"""If provided session is None, lend a temporary session."""
1409+
"""If provided session and bound session are None, lend a temporary session."""
1410+
session = session or self._get_bound_session()
14101411
if session:
14111412
return session
14121413

@@ -2263,11 +2264,14 @@ def _tmp_session(
22632264
self, session: Optional[client_session.ClientSession]
22642265
) -> Generator[Optional[client_session.ClientSession], None]:
22652266
"""If provided session is None, lend a temporary session."""
2266-
if session is not None:
2267-
if not isinstance(session, client_session.ClientSession):
2268-
raise ValueError(
2269-
f"'session' argument must be a ClientSession or None, not {type(session)}"
2270-
)
2267+
if session is not None and not isinstance(session, client_session.ClientSession):
2268+
raise ValueError(
2269+
f"'session' argument must be a ClientSession or None, not {type(session)}"
2270+
)
2271+
2272+
# Check for a bound session. If one exists, treat it as an explicitly passed session.
2273+
session = session or self._get_bound_session()
2274+
if session:
22712275
# Don't call end_session.
22722276
yield session
22732277
return
@@ -2295,6 +2299,18 @@ def _process_response(self, reply: Mapping[str, Any], session: Optional[ClientSe
22952299
if session is not None:
22962300
session._process_response(reply)
22972301

2302+
def _get_bound_session(self) -> Optional[ClientSession]:
2303+
bound_session = _SESSION.get()
2304+
if bound_session:
2305+
if bound_session.client is self:
2306+
return bound_session
2307+
else:
2308+
raise InvalidOperation(
2309+
"Only the client that created the bound session can perform operations within its context block. See <PLACEHOLDER> for more information."
2310+
)
2311+
else:
2312+
return None
2313+
22982314
def server_info(self, session: Optional[client_session.ClientSession] = None) -> dict[str, Any]:
22992315
"""Get information about the MongoDB server we're connected to.
23002316

test/asynchronous/test_session.py

Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -189,6 +189,52 @@ async def _test_ops(self, client, *ops):
189189
f"{f.__name__} did not return implicit session to pool",
190190
)
191191

192+
# Explicit bound session
193+
for f, args, kw in ops:
194+
async with client.start_session() as s:
195+
async with s.bind():
196+
listener.reset()
197+
s._materialize()
198+
last_use = s._server_session.last_use
199+
start = time.monotonic()
200+
self.assertLessEqual(last_use, start)
201+
# In case "f" modifies its inputs.
202+
args = copy.copy(args)
203+
kw = copy.copy(kw)
204+
await f(*args, **kw)
205+
self.assertGreaterEqual(len(listener.started_events), 1)
206+
for event in listener.started_events:
207+
self.assertIn(
208+
"lsid",
209+
event.command,
210+
f"{f.__name__} sent no lsid with {event.command_name}",
211+
)
212+
213+
self.assertEqual(
214+
s.session_id,
215+
event.command["lsid"],
216+
f"{f.__name__} sent wrong lsid with {event.command_name}",
217+
)
218+
219+
self.assertFalse(s.has_ended)
220+
221+
self.assertTrue(s.has_ended)
222+
with self.assertRaisesRegex(InvalidOperation, "ended session"):
223+
async with s.bind():
224+
await f(*args, **kw)
225+
226+
# Test a session cannot be used on another client.
227+
async with self.client2.start_session() as s:
228+
async with s.bind():
229+
# In case "f" modifies its inputs.
230+
args = copy.copy(args)
231+
kw = copy.copy(kw)
232+
with self.assertRaisesRegex(
233+
InvalidOperation,
234+
"Only the client that created the bound session can perform operations within its context block",
235+
):
236+
await f(*args, **kw)
237+
192238
async def test_implicit_sessions_checkout(self):
193239
# "To confirm that implicit sessions only allocate their server session after a
194240
# successful connection checkout" test from Driver Sessions Spec.
@@ -825,6 +871,73 @@ async def test_session_not_copyable(self):
825871
async with client.start_session() as s:
826872
self.assertRaises(TypeError, lambda: copy.copy(s))
827873

874+
async def test_nested_session_binding(self):
875+
coll = self.client.pymongo_test.test
876+
await coll.insert_one({"x": 1})
877+
878+
session1 = self.client.start_session()
879+
session2 = self.client.start_session()
880+
session1._materialize()
881+
session2._materialize()
882+
try:
883+
self.listener.reset()
884+
# Uses implicit session
885+
await coll.find_one()
886+
implicit_lsid = self.listener.started_events[0].command.get("lsid")
887+
self.assertIsNotNone(implicit_lsid)
888+
self.assertNotEqual(implicit_lsid, session1.session_id)
889+
self.assertNotEqual(implicit_lsid, session2.session_id)
890+
891+
async with session1.bind(end_session=False):
892+
self.listener.reset()
893+
# Uses bound session1
894+
await coll.find_one()
895+
session1_lsid = self.listener.started_events[0].command.get("lsid")
896+
self.assertEqual(session1_lsid, session1.session_id)
897+
898+
async with session2.bind(end_session=False):
899+
self.listener.reset()
900+
# Uses bound session2
901+
await coll.find_one()
902+
session2_lsid = self.listener.started_events[0].command.get("lsid")
903+
self.assertEqual(session2_lsid, session2.session_id)
904+
self.assertNotEqual(session2_lsid, session1.session_id)
905+
906+
self.listener.reset()
907+
# Use bound session1 again
908+
await coll.find_one()
909+
session1_lsid = self.listener.started_events[0].command.get("lsid")
910+
self.assertEqual(session1_lsid, session1.session_id)
911+
self.assertNotEqual(session1_lsid, session2.session_id)
912+
913+
self.listener.reset()
914+
# Uses implicit session
915+
await coll.find_one()
916+
implicit_lsid = self.listener.started_events[0].command.get("lsid")
917+
self.assertIsNotNone(implicit_lsid)
918+
self.assertNotEqual(implicit_lsid, session1.session_id)
919+
self.assertNotEqual(implicit_lsid, session2.session_id)
920+
921+
finally:
922+
await session1.end_session()
923+
await session2.end_session()
924+
925+
async def test_session_binding_end_session(self):
926+
coll = self.client.pymongo_test.test
927+
await coll.insert_one({"x": 1})
928+
929+
async with self.client.start_session().bind() as s1:
930+
await coll.find_one()
931+
932+
self.assertTrue(s1.has_ended)
933+
934+
async with self.client.start_session().bind(end_session=False) as s2:
935+
await coll.find_one()
936+
937+
self.assertFalse(s2.has_ended)
938+
939+
await s2.end_session()
940+
828941

829942
class TestCausalConsistency(AsyncUnitTest):
830943
listener: SessionTestListener

0 commit comments

Comments
 (0)