Skip to content

Commit d0d1e1a

Browse files
add retry on database connection (#59)
* add retry on database connection * qa * qa
1 parent c8e32f5 commit d0d1e1a

File tree

3 files changed

+138
-10
lines changed

3 files changed

+138
-10
lines changed

.github/workflows/main.yaml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,12 @@ on:
44
workflow_dispatch:
55
push:
66
branches:
7-
- main
7+
- main
88
tags:
9-
- "v*"
9+
- "v*"
1010
pull_request:
1111
branches:
12-
- main
12+
- main
1313

1414

1515
jobs:

cads_worker/worker.py

Lines changed: 31 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import os
55
import random
66
import socket
7+
import time
78
from typing import Any
89

910
import cacholote
@@ -23,6 +24,8 @@
2324

2425
LEVELS_MAPPING = logging.getLevelNamesMapping()
2526

27+
DB_CONNECTION_RETRIES = int(os.getenv("WORKER_DB_CONNECTION_RETRIES", 3))
28+
2629

2730
@functools.lru_cache
2831
def create_session_maker() -> cads_broker.database.sa.orm.sessionmaker:
@@ -32,13 +35,34 @@ def create_session_maker() -> cads_broker.database.sa.orm.sessionmaker:
3235
def ensure_session(func):
3336
@functools.wraps(func)
3437
def wrapper(self, *args, session=None, **kwargs):
35-
close_session = False
36-
if session is None:
37-
session = create_session_maker()()
38-
close_session = True
39-
func(self, *args, session=session, **kwargs)
40-
if close_session:
41-
session.close()
38+
retries = 1
39+
while retries <= DB_CONNECTION_RETRIES:
40+
try:
41+
close_session = False
42+
# create a new session if not provided
43+
if session is None:
44+
session = create_session_maker()()
45+
close_session = True
46+
# run the function
47+
result = func(self, *args, session=session, **kwargs)
48+
# close the session if we created it
49+
if close_session:
50+
session.close()
51+
return result
52+
except cads_broker.database.sa.exc.OperationalError as e:
53+
exception = e
54+
retries += 1
55+
self.logger.warning(
56+
f"Database operation failed. Retrying {retries}/{DB_CONNECTION_RETRIES}...",
57+
error=str(e),
58+
)
59+
# close the session anyway because it could be broken
60+
session.close()
61+
session = None
62+
time.sleep(os.getenv("WORKER_DB_CONNECTION_RETRY_SLEEP", 2))
63+
64+
self.logger.error("Max retries reached. Aborting operation.")
65+
raise exception
4266

4367
return wrapper
4468

tests/test_40_worker.py

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
from unittest.mock import patch
2+
3+
import pytest
4+
from sqlalchemy.orm import Session
5+
6+
from cads_worker import worker
7+
8+
9+
def mock_session_maker():
10+
return Session()
11+
12+
13+
def test_ensure_session_decorator():
14+
# Test case 1: when session is None
15+
@worker.ensure_session
16+
def sample_function(self, session=None):
17+
assert isinstance(session, Session)
18+
return session
19+
20+
with patch(
21+
"cads_worker.worker.create_session_maker", return_value=mock_session_maker
22+
):
23+
result = sample_function(self=None)
24+
assert isinstance(result, Session)
25+
26+
# Test case 2: when session is provided
27+
mock_session = Session()
28+
result = sample_function(self=None, session=mock_session)
29+
assert result is mock_session
30+
31+
# Clean up
32+
mock_session.close()
33+
34+
35+
def test_ensure_session_decorator_nested():
36+
# Test nested function calls with session
37+
@worker.ensure_session
38+
def outer_function(self, session=None):
39+
@worker.ensure_session
40+
def inner_function(self, session=None):
41+
return session
42+
43+
return inner_function(self=None, session=session)
44+
45+
with patch(
46+
"cads_worker.worker.create_session_maker", return_value=mock_session_maker
47+
):
48+
result = outer_function(self=None)
49+
assert isinstance(result, Session)
50+
result.close()
51+
52+
53+
def test_ensure_session_decorator_error():
54+
# Test error handling
55+
@worker.ensure_session
56+
def failing_function(self, session=None):
57+
raise ValueError("Test error")
58+
59+
with pytest.raises(ValueError):
60+
failing_function(self=None)
61+
62+
63+
call_count = 1
64+
65+
66+
def test_ensure_session_retry():
67+
# Test retries
68+
69+
context = worker.Context()
70+
71+
@worker.ensure_session
72+
def failing_function(self, session=None):
73+
print("failing function called")
74+
raise worker.cads_broker.database.sa.exc.OperationalError(
75+
"Simulated DB error", None, None
76+
)
77+
78+
with patch(
79+
"cads_worker.worker.create_session_maker", return_value=mock_session_maker
80+
):
81+
with pytest.raises(worker.cads_broker.database.sa.exc.OperationalError):
82+
# This should raise after max retries
83+
failing_function(self=context)
84+
85+
global call_count
86+
87+
@worker.ensure_session
88+
def successful_function(self, session=None):
89+
global call_count
90+
if call_count < 3:
91+
print("failing function called - retries:", call_count)
92+
call_count += 1
93+
raise worker.cads_broker.database.sa.exc.OperationalError(
94+
"Simulated DB error", None, None
95+
)
96+
else:
97+
print("successful function called - retries:", call_count)
98+
return session
99+
100+
with patch(
101+
"cads_worker.worker.create_session_maker", return_value=mock_session_maker
102+
):
103+
result = successful_function(self=context)
104+
assert isinstance(result, Session)

0 commit comments

Comments
 (0)