diff --git a/google/cloud/spanner_v1/snapshot.py b/google/cloud/spanner_v1/snapshot.py index 89cbc9fe88..9fa5123119 100644 --- a/google/cloud/spanner_v1/snapshot.py +++ b/google/cloud/spanner_v1/snapshot.py @@ -146,27 +146,12 @@ def _restart_on_unavailable( except ServiceUnavailable: del item_buffer[:] - with trace_call( - trace_name, - session, - attributes, - observability_options=observability_options, - metadata=metadata, - ) as span, MetricsCapture(): - request.resume_token = resume_token - if transaction is not None: - transaction_selector = transaction._build_transaction_selector_pb() - request.transaction = transaction_selector - attempt += 1 - iterator = method( - request=request, - metadata=request_id_manager.metadata_with_request_id( - nth_request, - attempt, - metadata, - span, - ), - ) + request.resume_token = resume_token + if transaction is not None: + transaction_selector = transaction._build_transaction_selector_pb() + request.transaction = transaction_selector + attempt += 1 + iterator = None continue except InternalServerError as exc: @@ -177,27 +162,12 @@ def _restart_on_unavailable( if not resumable_error: raise del item_buffer[:] - with trace_call( - trace_name, - session, - attributes, - observability_options=observability_options, - metadata=metadata, - ) as span, MetricsCapture(): - request.resume_token = resume_token - if transaction is not None: - transaction_selector = transaction._build_transaction_selector_pb() - attempt += 1 - request.transaction = transaction_selector - iterator = method( - request=request, - metadata=request_id_manager.metadata_with_request_id( - nth_request, - attempt, - metadata, - span, - ), - ) + request.resume_token = resume_token + if transaction is not None: + transaction_selector = transaction._build_transaction_selector_pb() + attempt += 1 + request.transaction = transaction_selector + iterator = None continue if len(item_buffer) == 0: diff --git a/tests/unit/test_snapshot.py b/tests/unit/test_snapshot.py index 974cc8e75e..f09bd06d1f 100644 --- a/tests/unit/test_snapshot.py +++ b/tests/unit/test_snapshot.py @@ -405,6 +405,56 @@ def test_iteration_w_raw_raising_unavailable_after_token(self): self.assertEqual(request.resume_token, RESUME_TOKEN) self.assertNoSpans() + def test_iteration_w_raw_raising_unavailable_during_restart(self): + from google.api_core.exceptions import ServiceUnavailable + + FIRST = (self._make_item(0), self._make_item(1, resume_token=RESUME_TOKEN)) + LAST = (self._make_item(2),) + before = _MockIterator( + *FIRST, fail_after=True, error=ServiceUnavailable("testing") + ) + after = _MockIterator(*LAST) + request = mock.Mock(test="test", spec=["test", "resume_token"]) + # The second call (the first retry) raises ServiceUnavailable immediately. + # The third call (the second retry) succeeds. + restart = mock.Mock( + spec=[], + side_effect=[before, ServiceUnavailable("retry failed"), after], + ) + database = _Database() + database.spanner_api = build_spanner_api() + session = _Session(database) + derived = _build_snapshot_derived(session) + resumable = self._call_fut(derived, restart, request, session=session) + self.assertEqual(list(resumable), list(FIRST + LAST)) + self.assertEqual(len(restart.mock_calls), 3) + self.assertEqual(request.resume_token, RESUME_TOKEN) + self.assertNoSpans() + + def test_iteration_w_raw_raising_resumable_internal_error_during_restart(self): + FIRST = (self._make_item(0), self._make_item(1, resume_token=RESUME_TOKEN)) + LAST = (self._make_item(2),) + before = _MockIterator( + *FIRST, + fail_after=True, + error=INTERNAL_SERVER_ERROR_UNEXPECTED_EOS, + ) + after = _MockIterator(*LAST) + request = mock.Mock(test="test", spec=["test", "resume_token"]) + restart = mock.Mock( + spec=[], + side_effect=[before, INTERNAL_SERVER_ERROR_UNEXPECTED_EOS, after], + ) + database = _Database() + database.spanner_api = build_spanner_api() + session = _Session(database) + derived = _build_snapshot_derived(session) + resumable = self._call_fut(derived, restart, request, session=session) + self.assertEqual(list(resumable), list(FIRST + LAST)) + self.assertEqual(len(restart.mock_calls), 3) + self.assertEqual(request.resume_token, RESUME_TOKEN) + self.assertNoSpans() + def test_iteration_w_raw_w_multiuse(self): from google.cloud.spanner_v1 import ( ReadRequest,