From 2b45ed9eefc2274ccc0167473d1e6ec356393aea Mon Sep 17 00:00:00 2001 From: Nadav Tzaysler Date: Mon, 21 Jul 2025 00:20:07 +0300 Subject: [PATCH 1/3] added healthcheck to the trino python client making it able to process queries that will take more than 15 minutes --- tests/unit/test_client.py | 146 ++++++++++++++++++++++++++++++++++++++ trino/client.py | 51 ++++++++++++- 2 files changed, 196 insertions(+), 1 deletion(-) diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index 909823ee..f8bf319b 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -1447,3 +1447,149 @@ def delete_password(self, servicename, username): return None os.remove(file_path) + + +@mock.patch("trino.client.TrinoRequest.http") +def test_trinoquery_heartbeat_success(mock_requests, sample_post_response_data, sample_get_response_data): + """Test that heartbeat is sent periodically and does not stop on success.""" + head_call_count = 0 + def fake_head(url, timeout=10): + nonlocal head_call_count + head_call_count += 1 + class Resp: + status_code = 200 + return Resp() + mock_requests.head.side_effect = fake_head + mock_requests.Response.return_value.json.return_value = sample_post_response_data + mock_requests.get.return_value.json.return_value = sample_get_response_data + mock_requests.post.return_value.json.return_value = sample_post_response_data + req = TrinoRequest( + host="coordinator", + port=8080, + client_session=ClientSession(user="test"), + http_scheme="http", + ) + query = TrinoQuery(request=req, query="SELECT 1", heartbeat_interval=0.1) + def finish_query(*args, **kwargs): + query._finished = True + return [] + query.fetch = finish_query + query._next_uri = "http://coordinator/v1/statement/next" + query._row_mapper = mock.Mock(map=lambda x: []) + query._start_heartbeat() + time.sleep(0.3) + query._stop_heartbeat() + assert head_call_count >= 2 + +@mock.patch("trino.client.TrinoRequest.http") +def 
test_trinoquery_heartbeat_failure_stops(mock_requests, sample_post_response_data, sample_get_response_data): + """Test that heartbeat stops after 3 consecutive failures.""" + def fake_head(url, timeout=10): + class Resp: + status_code = 500 + return Resp() + mock_requests.head.side_effect = fake_head + mock_requests.Response.return_value.json.return_value = sample_post_response_data + mock_requests.get.return_value.json.return_value = sample_get_response_data + mock_requests.post.return_value.json.return_value = sample_post_response_data + req = TrinoRequest( + host="coordinator", + port=8080, + client_session=ClientSession(user="test"), + http_scheme="http", + ) + query = TrinoQuery(request=req, query="SELECT 1", heartbeat_interval=0.05) + query._next_uri = "http://coordinator/v1/statement/next" + query._row_mapper = mock.Mock(map=lambda x: []) + query._start_heartbeat() + time.sleep(0.3) + assert not query._heartbeat_enabled + query._stop_heartbeat() + +@mock.patch("trino.client.TrinoRequest.http") +def test_trinoquery_heartbeat_404_405_stops(mock_requests, sample_post_response_data, sample_get_response_data): + """Test that heartbeat stops if server returns 404 or 405.""" + for code in (404, 405): + def fake_head(url, timeout=10, code=code): + class Resp: + status_code = code + return Resp() + mock_requests.head.side_effect = fake_head + mock_requests.Response.return_value.json.return_value = sample_post_response_data + mock_requests.get.return_value.json.return_value = sample_get_response_data + mock_requests.post.return_value.json.return_value = sample_post_response_data + req = TrinoRequest( + host="coordinator", + port=8080, + client_session=ClientSession(user="test"), + http_scheme="http", + ) + query = TrinoQuery(request=req, query="SELECT 1", heartbeat_interval=0.05) + query._next_uri = "http://coordinator/v1/statement/next" + query._row_mapper = mock.Mock(map=lambda x: []) + query._start_heartbeat() + time.sleep(0.2) + assert not query._heartbeat_enabled 
+ query._stop_heartbeat() + +@mock.patch("trino.client.TrinoRequest.http") +def test_trinoquery_heartbeat_stops_on_finish(mock_requests, sample_post_response_data, sample_get_response_data): + """Test that heartbeat stops when the query is finished.""" + head_call_count = 0 + def fake_head(url, timeout=10): + nonlocal head_call_count + head_call_count += 1 + class Resp: + status_code = 200 + return Resp() + mock_requests.head.side_effect = fake_head + mock_requests.Response.return_value.json.return_value = sample_post_response_data + mock_requests.get.return_value.json.return_value = sample_get_response_data + mock_requests.post.return_value.json.return_value = sample_post_response_data + req = TrinoRequest( + host="coordinator", + port=8080, + client_session=ClientSession(user="test"), + http_scheme="http", + ) + query = TrinoQuery(request=req, query="SELECT 1", heartbeat_interval=0.05) + query._next_uri = "http://coordinator/v1/statement/next" + query._row_mapper = mock.Mock(map=lambda x: []) + query._start_heartbeat() + time.sleep(0.1) + query._finished = True + time.sleep(0.1) + query._stop_heartbeat() + # Heartbeat should have stopped after query finished + assert head_call_count >= 1 + +@mock.patch("trino.client.TrinoRequest.http") +def test_trinoquery_heartbeat_stops_on_cancel(mock_requests, sample_post_response_data, sample_get_response_data): + """Test that heartbeat stops when the query is cancelled.""" + head_call_count = 0 + def fake_head(url, timeout=10): + nonlocal head_call_count + head_call_count += 1 + class Resp: + status_code = 200 + return Resp() + mock_requests.head.side_effect = fake_head + mock_requests.Response.return_value.json.return_value = sample_post_response_data + mock_requests.get.return_value.json.return_value = sample_get_response_data + mock_requests.post.return_value.json.return_value = sample_post_response_data + req = TrinoRequest( + host="coordinator", + port=8080, + client_session=ClientSession(user="test"), + 
http_scheme="http", + ) + query = TrinoQuery(request=req, query="SELECT 1", heartbeat_interval=0.05) + query._next_uri = "http://coordinator/v1/statement/next" + query._row_mapper = mock.Mock(map=lambda x: []) + query._start_heartbeat() + time.sleep(0.1) + query._cancelled = True + time.sleep(0.1) + query._stop_heartbeat() + # Heartbeat should have stopped after query cancelled + assert head_call_count >= 1 diff --git a/trino/client.py b/trino/client.py index 7cc1f0f2..73200a72 100644 --- a/trino/client.py +++ b/trino/client.py @@ -808,7 +808,8 @@ def __init__( request: TrinoRequest, query: str, legacy_primitive_types: bool = False, - fetch_mode: Literal["mapped", "segments"] = "mapped" + fetch_mode: Literal["mapped", "segments"] = "mapped", + heartbeat_interval: float = 60.0, # seconds ) -> None: self._query_id: Optional[str] = None self._stats: Dict[Any, Any] = {} @@ -826,6 +827,11 @@ def __init__( self._legacy_primitive_types = legacy_primitive_types self._row_mapper: Optional[RowMapper] = None self._fetch_mode = fetch_mode + self._heartbeat_interval = heartbeat_interval + self._heartbeat_thread = None + self._heartbeat_stop_event = threading.Event() + self._heartbeat_failures = 0 + self._heartbeat_enabled = True @property def query_id(self) -> Optional[str]: @@ -868,6 +874,39 @@ def result(self): def info_uri(self): return self._info_uri + def _start_heartbeat(self): + if self._heartbeat_thread is not None: + return + self._heartbeat_stop_event.clear() + self._heartbeat_thread = threading.Thread(target=self._heartbeat_loop, daemon=True) + self._heartbeat_thread.start() + + def _stop_heartbeat(self): + self._heartbeat_stop_event.set() + if self._heartbeat_thread is not None: + self._heartbeat_thread.join(timeout=2) + self._heartbeat_thread = None + + def _heartbeat_loop(self): + while not self._heartbeat_stop_event.is_set() and not self.finished and not self.cancelled and self._heartbeat_enabled: + if self._next_uri is None: + break + try: + response = 
self._request.http.head(self._next_uri, timeout=10) + if response.status_code == 404 or response.status_code == 405: + self._heartbeat_enabled = False + break + if response.status_code == 200: + self._heartbeat_failures = 0 + else: + self._heartbeat_failures += 1 + except Exception: + self._heartbeat_failures += 1 + if self._heartbeat_failures >= 3: + self._heartbeat_enabled = False + break + self._heartbeat_stop_event.wait(self._heartbeat_interval) + def execute(self, additional_http_headers=None) -> TrinoResult: """Initiate a Trino query by sending the SQL statement @@ -895,6 +934,9 @@ def execute(self, additional_http_headers=None) -> TrinoResult: rows = self._row_mapper.map(status.rows) if self._row_mapper else status.rows self._result = TrinoResult(self, rows) + # Start heartbeat thread + self._start_heartbeat() + # Execute should block until at least one row is received or query is finished or cancelled while not self.finished and not self.cancelled and len(self._result.rows) == 0: self._result.rows += self.fetch() @@ -921,6 +963,7 @@ def fetch(self) -> List[Union[List[Any]], Any]: self._update_state(status) if status.next_uri is None: self._finished = True + self._stop_heartbeat() if not self._row_mapper: return [] @@ -968,6 +1011,7 @@ def cancel(self) -> None: if response.status_code == requests.codes.no_content: self._cancelled = True logger.debug("query cancelled: %s", self.query_id) + self._stop_heartbeat() return self._request.raise_response_error(response) @@ -985,6 +1029,11 @@ def finished(self) -> bool: def cancelled(self) -> bool: return self._cancelled + @property + def is_running(self) -> bool: + """Return True if the query is still running (not finished or cancelled).""" + return not self.finished and not self.cancelled + def _retry_with(handle_retry, handled_exceptions, conditions, max_attempts): def wrapper(func): From 6f2109dc461eab95770e4df6449dcba372e90e77 Mon Sep 17 00:00:00 2001 From: EdenKik Date: Sun, 27 Jul 2025 15:13:05 +0300 Subject: 
[PATCH 2/3] fixed lint issues --- tests/unit/test_client.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index f8bf319b..a03e4d35 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -1453,9 +1453,11 @@ def delete_password(self, servicename, username): def test_trinoquery_heartbeat_success(mock_requests, sample_post_response_data, sample_get_response_data): """Test that heartbeat is sent periodically and does not stop on success.""" head_call_count = 0 + def fake_head(url, timeout=10): nonlocal head_call_count head_call_count += 1 + class Resp: status_code = 200 return Resp() @@ -1470,6 +1472,7 @@ class Resp: http_scheme="http", ) query = TrinoQuery(request=req, query="SELECT 1", heartbeat_interval=0.1) + def finish_query(*args, **kwargs): query._finished = True return [] @@ -1481,6 +1484,7 @@ def finish_query(*args, **kwargs): query._stop_heartbeat() assert head_call_count >= 2 + @mock.patch("trino.client.TrinoRequest.http") def test_trinoquery_heartbeat_failure_stops(mock_requests, sample_post_response_data, sample_get_response_data): """Test that heartbeat stops after 3 consecutive failures.""" @@ -1506,6 +1510,7 @@ class Resp: assert not query._heartbeat_enabled query._stop_heartbeat() + @mock.patch("trino.client.TrinoRequest.http") def test_trinoquery_heartbeat_404_405_stops(mock_requests, sample_post_response_data, sample_get_response_data): """Test that heartbeat stops if server returns 404 or 405.""" @@ -1532,13 +1537,16 @@ class Resp: assert not query._heartbeat_enabled query._stop_heartbeat() + @mock.patch("trino.client.TrinoRequest.http") def test_trinoquery_heartbeat_stops_on_finish(mock_requests, sample_post_response_data, sample_get_response_data): """Test that heartbeat stops when the query is finished.""" head_call_count = 0 + def fake_head(url, timeout=10): nonlocal head_call_count head_call_count += 1 + class Resp: status_code = 200 return Resp() @@ 
-1563,13 +1571,16 @@ class Resp: # Heartbeat should have stopped after query finished assert head_call_count >= 1 + @mock.patch("trino.client.TrinoRequest.http") def test_trinoquery_heartbeat_stops_on_cancel(mock_requests, sample_post_response_data, sample_get_response_data): """Test that heartbeat stops when the query is cancelled.""" head_call_count = 0 + def fake_head(url, timeout=10): nonlocal head_call_count head_call_count += 1 + class Resp: status_code = 200 return Resp() From 885c4ae6d368034f52a42bcb3a67cf85816a3553 Mon Sep 17 00:00:00 2001 From: nadav tzaysler Date: Sun, 27 Jul 2025 15:38:24 +0300 Subject: [PATCH 3/3] fixed lint issues --- trino/client.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/trino/client.py b/trino/client.py index 73200a72..d1d0ab61 100644 --- a/trino/client.py +++ b/trino/client.py @@ -888,7 +888,8 @@ def _stop_heartbeat(self): self._heartbeat_thread = None def _heartbeat_loop(self): - while not self._heartbeat_stop_event.is_set() and not self.finished and not self.cancelled and self._heartbeat_enabled: + while all([not self._heartbeat_stop_event.is_set(), not self.finished, not self.cancelled, + self._heartbeat_enabled]): if self._next_uri is None: break try: