Skip to content
7 changes: 5 additions & 2 deletions sdk/cosmos/azure-cosmos/azure/cosmos/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,7 +347,9 @@ def set_session_token_header(
path: str,
request_object: "RequestObject",
options: Mapping[str, Any],
partition_key_range_id: Optional[str] = None) -> None:
partition_key_range_id: Optional[str] = None,
**kwargs
) -> None:
# set session token if required
if _is_session_token_request(cosmos_client_connection, headers, request_object):
# if there is a token set via option, then use it to override default
Expand All @@ -367,7 +369,8 @@ def set_session_token_header(
cosmos_client_connection._container_properties_cache,
cosmos_client_connection._routing_map_provider,
partition_key_range_id,
options))
options,
**kwargs))
if session_token != "":
headers[http_constants.HttpHeaders.SessionToken] = session_token

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,15 +65,17 @@ def populate_request_headers(
self,
routing_provider: SmartRoutingMapProvider,
request_headers: Dict[str, Any],
feed_options: Optional[Dict[str, Any]] = None) -> None:
feed_options: Optional[Dict[str, Any]] = None,
**kwargs: Any) -> None:
pass

@abstractmethod
async def populate_request_headers_async(
self,
async_routing_provider: AsyncSmartRoutingMapProvider,
request_headers: Dict[str, Any],
feed_options: Optional[Dict[str, Any]] = None) -> None:
feed_options: Optional[Dict[str, Any]] = None,
**kwargs: Any) -> None:
pass

@abstractmethod
Expand Down Expand Up @@ -152,7 +154,7 @@ def populate_request_headers(
self,
routing_provider: SmartRoutingMapProvider,
request_headers: Dict[str, Any],
feed_options: Optional[Dict[str, Any]] = None) -> None:
feed_options: Optional[Dict[str, Any]] = None, **kwargs: Any) -> None:
request_headers[http_constants.HttpHeaders.AIM] = http_constants.HttpHeaders.IncrementalFeedHeaderValue

self._change_feed_start_from.populate_request_headers(request_headers)
Expand All @@ -163,7 +165,7 @@ async def populate_request_headers_async(
self,
async_routing_provider: AsyncSmartRoutingMapProvider,
request_headers: Dict[str, Any],
feed_options: Optional[Dict[str, Any]] = None) -> None: # pylint: disable=unused-argument
feed_options: Optional[Dict[str, Any]] = None, **kwargs: Any) -> None: # pylint: disable=unused-argument

request_headers[http_constants.HttpHeaders.AIM] = http_constants.HttpHeaders.IncrementalFeedHeaderValue

Expand Down Expand Up @@ -282,15 +284,15 @@ def populate_request_headers(
self,
routing_provider: SmartRoutingMapProvider,
request_headers: Dict[str, Any],
feed_options = None) -> None:
feed_options = None, **kwargs) -> None:
self.set_start_from_request_headers(request_headers)

# based on the feed range to find the overlapping partition key range id
over_lapping_ranges = \
routing_provider.get_overlapping_ranges(
self._container_link,
[self._continuation.current_token.feed_range],
feed_options)
feed_options, **kwargs)

self.set_pk_range_id_request_headers(over_lapping_ranges, request_headers)

Expand All @@ -301,15 +303,15 @@ async def populate_request_headers_async(
self,
async_routing_provider: AsyncSmartRoutingMapProvider,
request_headers: Dict[str, Any],
feed_options: Optional[Dict[str, Any]] = None) -> None:
feed_options: Optional[Dict[str, Any]] = None, **kwargs) -> None:
self.set_start_from_request_headers(request_headers)

# based on the feed range to find the overlapping partition key range id
over_lapping_ranges = \
await async_routing_provider.get_overlapping_ranges(
self._container_link,
[self._continuation.current_token.feed_range],
feed_options)
feed_options, **kwargs)

self.set_pk_range_id_request_headers(over_lapping_ranges, request_headers)

Expand Down
24 changes: 13 additions & 11 deletions sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -1058,6 +1058,7 @@ def read_items(
options: Optional[Mapping[str, Any]] = None,
*,
executor: Optional[ThreadPoolExecutor] = None,
response_hook: Optional[Callable[[Mapping[str, Any], List[Dict[str, Any]]], None]] = None,
**kwargs: Any
) -> CosmosList:
"""Reads many items.
Expand Down Expand Up @@ -1091,6 +1092,7 @@ def read_items(
partition_key_definition=partition_key_definition,
executor=executor,
max_concurrency=max_concurrency,
response_hook=response_hook,
**kwargs)
return helper.read_items()

Expand Down Expand Up @@ -2134,7 +2136,7 @@ def PatchItem(
headers,
options.get("partitionKey", None))
request_params.set_excluded_location_from_options(options)
base.set_session_token_header(self, headers, path, request_params, options)
base.set_session_token_header(self, headers, path, request_params, options, **kwargs)
request_params.set_retry_write(options, self.connection_policy.RetryNonIdempotentWrites)
request_data = {}
if options.get("filterPredicate"):
Expand Down Expand Up @@ -2229,7 +2231,7 @@ def _Batch(
headers,
options.get("partitionKey", None))
request_params.set_excluded_location_from_options(options)
base.set_session_token_header(self, headers, path, request_params, options)
base.set_session_token_header(self, headers, path, request_params, options, **kwargs)
return cast(
Tuple[List[Dict[str, Any]], CaseInsensitiveDict],
self.__Post(path, request_params, batch_operations, headers, **kwargs)
Expand Down Expand Up @@ -2761,7 +2763,7 @@ def Create(
headers,
options.get("partitionKey", None))
request_params.set_excluded_location_from_options(options)
base.set_session_token_header(self, headers, path, request_params, options)
base.set_session_token_header(self, headers, path, request_params, options, **kwargs)
request_params.set_retry_write(options, self.connection_policy.RetryNonIdempotentWrites)
result, last_response_headers = self.__Post(path, request_params, body, headers, **kwargs)
self.last_response_headers = last_response_headers
Expand Down Expand Up @@ -2812,7 +2814,7 @@ def Upsert(
headers,
options.get("partitionKey", None))
request_params.set_excluded_location_from_options(options)
base.set_session_token_header(self, headers, path, request_params, options)
base.set_session_token_header(self, headers, path, request_params, options, **kwargs)
request_params.set_retry_write(options, self.connection_policy.RetryNonIdempotentWrites)
result, last_response_headers = self.__Post(path, request_params, body, headers, **kwargs)
self.last_response_headers = last_response_headers
Expand Down Expand Up @@ -2861,7 +2863,7 @@ def Replace(
headers,
options.get("partitionKey", None))
request_params.set_excluded_location_from_options(options)
base.set_session_token_header(self, headers, path, request_params, options)
base.set_session_token_header(self, headers, path, request_params, options, **kwargs)
request_params.set_retry_write(options, self.connection_policy.RetryNonIdempotentWrites)
result, last_response_headers = self.__Put(path, request_params, resource, headers, **kwargs)
self.last_response_headers = last_response_headers
Expand Down Expand Up @@ -2909,7 +2911,7 @@ def Read(
headers,
options.get("partitionKey", None))
request_params.set_excluded_location_from_options(options)
base.set_session_token_header(self, headers, path, request_params, options)
base.set_session_token_header(self, headers, path, request_params, options, **kwargs)
result, last_response_headers = self.__Get(path, request_params, headers, **kwargs)
# update session for request mutates data on server side
self._UpdateSessionIfRequired(headers, result, last_response_headers)
Expand Down Expand Up @@ -2955,7 +2957,7 @@ def DeleteResource(
documents._OperationType.Delete,
headers,
options.get("partitionKey", None))
base.set_session_token_header(self, headers, path, request_params, options)
base.set_session_token_header(self, headers, path, request_params, options, **kwargs)
request_params.set_retry_write(options, self.connection_policy.RetryNonIdempotentWrites)
request_params.set_excluded_location_from_options(options)
result, last_response_headers = self.__Delete(path, request_params, headers, **kwargs)
Expand Down Expand Up @@ -3209,14 +3211,14 @@ def __GetBodiesFromQueryResult(result: Dict[str, Any]) -> List[Dict[str, Any]]:
options.get("partitionKey", None)
)
request_params.set_excluded_location_from_options(options)
base.set_session_token_header(self, headers, path, request_params, options, partition_key_range_id)
base.set_session_token_header(self, headers, path, request_params, options, partition_key_range_id, **kwargs)

change_feed_state: Optional[ChangeFeedState] = options.get("changeFeedState")
if change_feed_state is not None:
feed_options = {}
if 'excludedLocations' in options:
feed_options['excludedLocations'] = options['excludedLocations']
change_feed_state.populate_request_headers(self._routing_map_provider, headers, feed_options)
change_feed_state.populate_request_headers(self._routing_map_provider, headers, feed_options, **kwargs)
request_params.headers = headers

result, last_response_headers = self.__Get(path, request_params, headers, **kwargs)
Expand Down Expand Up @@ -3254,7 +3256,7 @@ def __GetBodiesFromQueryResult(result: Dict[str, Any]) -> List[Dict[str, Any]]:
request_params.set_excluded_location_from_options(options)
if not is_query_plan:
req_headers[http_constants.HttpHeaders.IsQuery] = "true"
base.set_session_token_header(self, req_headers, path, request_params, options, partition_key_range_id)
base.set_session_token_header(self, req_headers, path, request_params, options, partition_key_range_id, **kwargs)

# Check if the over lapping ranges can be populated
feed_range_epk = None
Expand All @@ -3271,7 +3273,7 @@ def __GetBodiesFromQueryResult(result: Dict[str, Any]) -> List[Dict[str, Any]]:
if feed_range_epk is not None:
last_response_headers = CaseInsensitiveDict()
over_lapping_ranges = self._routing_map_provider.get_overlapping_ranges(resource_id, [feed_range_epk],
options)
options, **kwargs)
# It is possible to get more than one over lapping range. We need to get the query results for each one
results: Dict[str, Any] = {}
# For each over lapping range we will take a sub range of the feed range EPK that overlaps with the over
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -113,10 +113,7 @@ def refresh_endpoint_list(self, database_account, **kwargs):
# if refresh is not needed or refresh is already taking place, return
if not self.refresh_needed:
return
try:
self._refresh_endpoint_list_private(database_account, **kwargs)
except Exception as e:
raise e
self._refresh_endpoint_list_private(database_account, **kwargs)

def _refresh_endpoint_list_private(self, database_account=None, **kwargs):
if database_account:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def is_circuit_breaker_applicable(self, request: RequestObject) -> bool:
return self.global_partition_endpoint_manager_core.is_circuit_breaker_applicable(request)


def create_pk_range_wrapper(self, request: RequestObject) -> Optional[PartitionKeyRangeWrapper]:
def create_pk_range_wrapper(self, request: RequestObject, **kwargs) -> Optional[PartitionKeyRangeWrapper]:
if HttpHeaders.IntendedCollectionRID in request.headers:
container_rid = request.headers[HttpHeaders.IntendedCollectionRID]
else:
Expand All @@ -73,12 +73,12 @@ def create_pk_range_wrapper(self, request: RequestObject) -> Optional[PartitionK
# get the partition key range for the given partition key
epk_range = [partition_key._get_epk_range_for_partition_key(partition_key_value)] # pylint: disable=protected-access
partition_ranges = (self.client._routing_map_provider # pylint: disable=protected-access
.get_overlapping_ranges(container_link, epk_range, options))
.get_overlapping_ranges(container_link, epk_range, options, **kwargs))
partition_range = Range.PartitionKeyRangeToRange(partition_ranges[0])
elif HttpHeaders.PartitionKeyRangeID in request.headers:
pk_range_id = request.headers[HttpHeaders.PartitionKeyRangeID]
epk_range =(self.client._routing_map_provider # pylint: disable=protected-access
.get_range_by_partition_key_range_id(container_link, pk_range_id, options))
.get_range_by_partition_key_range_id(container_link, pk_range_id, options, **kwargs))
if not epk_range:
self.global_partition_endpoint_manager_core.log_warn_or_debug(
"Illegal state: partition key range cache not initialized correctly. "
Expand Down
11 changes: 4 additions & 7 deletions sdk/cosmos/azure-cosmos/azure/cosmos/_location_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,13 +69,10 @@ def get_regional_routing_contexts_by_loc(new_locations: List[Dict[str, str]]):
if not new_location["name"]:
# during fail-over the location name is empty
continue
try:
region_uri = new_location["databaseAccountEndpoint"]
parsed_locations.append(new_location["name"])
regional_object = RegionalRoutingContext(region_uri)
regional_routing_contexts_by_location.update({new_location["name"]: regional_object})
except Exception as e:
raise e
region_uri = new_location["databaseAccountEndpoint"]
parsed_locations.append(new_location["name"])
regional_object = RegionalRoutingContext(region_uri)
regional_routing_contexts_by_location.update({new_location["name"]: regional_object})

# Also store a hash map of endpoints for each location
locations_by_endpoints = {value.get_primary(): key for key, value in regional_routing_contexts_by_location.items()}
Expand Down
10 changes: 6 additions & 4 deletions sdk/cosmos/azure-cosmos/azure/cosmos/_read_items_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
import logging
from collections.abc import Sequence
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Dict, List, Tuple, Any, Optional, TYPE_CHECKING,Mapping
from typing import Dict, List, Tuple, Any, Optional, TYPE_CHECKING, Mapping, Callable

from azure.core.utils import CaseInsensitiveDict

Expand All @@ -49,6 +49,7 @@ def __init__(
*,
executor: Optional[ThreadPoolExecutor] = None,
max_concurrency: int = 10,
response_hook: Optional[Callable] = None,
**kwargs: Any
):
self.client = client
Expand All @@ -60,6 +61,7 @@ def __init__(
self.executor = executor
self.max_concurrency = max_concurrency
self.max_items_per_query = 1000
self.response_hook = response_hook

def read_items(self) -> CosmosList:
"""Reads many items synchronously using a query-based approach with a thread pool.
Expand Down Expand Up @@ -136,8 +138,8 @@ def _execute_with_executor(
cosmos_list = CosmosList(results, response_headers=final_headers)

# Call the original response hook with the final results if provided
if 'response_hook' in self.kwargs:
self.kwargs['response_hook'](final_headers, cosmos_list)
if self.response_hook:
self.response_hook(final_headers, cosmos_list)

return cosmos_list

Expand Down Expand Up @@ -167,7 +169,7 @@ def _partition_items_by_range(self) -> Dict[str, List[Tuple[int, str, "_Partitio
pk_value = pk_items[0][2]
epk_range = partition_key._get_epk_range_for_partition_key(pk_value)
overlapping_ranges = self.client._routing_map_provider.get_overlapping_ranges(
collection_rid, [epk_range], self.options
collection_rid, [epk_range], self.options, **self.kwargs
)
if overlapping_ranges:
range_id = overlapping_ranges[0]["id"]
Expand Down
2 changes: 1 addition & 1 deletion sdk/cosmos/azure-cosmos/azure/cosmos/_retry_utility.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def Execute(client, global_endpoint_manager, function, *args, **kwargs): # pylin
"""
pk_range_wrapper = None
if args and global_endpoint_manager.is_circuit_breaker_applicable(args[0]):
pk_range_wrapper = global_endpoint_manager.create_pk_range_wrapper(args[0])
pk_range_wrapper = global_endpoint_manager.create_pk_range_wrapper(args[0], **kwargs)
# instantiate all retry policies here to be applied for each request execution
endpointDiscovery_retry_policy = _endpoint_discovery_retry_policy.EndpointDiscoveryRetryPolicy(
client.connection_policy, global_endpoint_manager, *args
Expand Down
Loading
Loading