Skip to content

Commit 0f43195

Browse files
committed
Reorg
Signed-off-by: Alina Buzachis <[email protected]>
1 parent 40aa17c commit 0f43195

File tree

7 files changed

+184
-132
lines changed

7 files changed

+184
-132
lines changed

core/controller_client.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
import logging
2+
from typing import Dict
3+
from typing import Iterator
4+
from typing import Optional
5+
6+
import requests
7+
from requests import Session
8+
from requests.auth import HTTPBasicAuth
9+
10+
from pattern_service.settings.aap import get_aap_settings
11+
12+
logger = logging.getLogger(__name__)
13+
14+
settings = get_aap_settings()
15+
16+
_aap_session: Optional[Session] = None
17+
18+
19+
def get_http_session(force_refresh: bool = False) -> Session:
20+
"""Returns a cached Session instance with AAP credentials."""
21+
global _aap_session
22+
if _aap_session is None or force_refresh:
23+
session = Session()
24+
session.auth = HTTPBasicAuth(settings.username, settings.password)
25+
session.verify = settings.verify_ssl
26+
session.headers.update({'Content-Type': 'application/json'})
27+
_aap_session = session
28+
29+
return _aap_session
30+
31+
32+
def get(path: str, *, params: Optional[Dict] = None) -> requests.Response:
33+
session = get_http_session()
34+
url = f"{settings.url}{path}"
35+
response = session.get(url, params=params, stream=True)
36+
response.raise_for_status()
37+
38+
return response

core/services.py

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
import contextlib
2+
import io
3+
import json
4+
import logging
5+
import os
6+
import shutil
7+
import tarfile
8+
import tempfile
9+
from typing import Iterator
10+
11+
from .controller_client import get
12+
from .models import Pattern
13+
from .models import Task
14+
15+
logger = logging.getLogger(__name__)
16+
17+
18+
def update_task_status(task: Task, status_: str, details: dict):
19+
task.status = status_
20+
task.details = details
21+
task.save()
22+
23+
24+
@contextlib.contextmanager
25+
def download_collection(collection: str, version: str) -> Iterator[str]:
26+
"""
27+
Downloads and extracts a collection tarball to a temporary directory.
28+
29+
Args:
30+
collection: The name of the collection (e.g., 'my_namespace.my_collection').
31+
version: The version of the collection (e.g., '1.0.0').
32+
33+
Yields:
34+
The path to the extracted collection files.
35+
"""
36+
path = f"/api/galaxy/v3/plugin/ansible/content/published/collections/artifacts/{collection}-{version}.tar.gz"
37+
38+
temp_base_dir = tempfile.mkdtemp()
39+
collection_path = os.path.join(temp_base_dir, f"{collection}-{version}")
40+
os.makedirs(collection_path, exist_ok=True)
41+
42+
try:
43+
response = get(path)
44+
in_memory_tar = io.BytesIO(response.content)
45+
46+
with tarfile.open(fileobj=in_memory_tar, mode="r|*") as tar:
47+
tar.extractall(path=collection_path, filter="data")
48+
49+
logger.info(f"Collection extracted to {collection_path}")
50+
yield collection_path # Yield the path to the caller
51+
finally:
52+
shutil.rmtree(temp_base_dir)
53+
54+
55+
def pattern_task(pattern_id: int, task_id: int):
56+
"""
57+
Orchestrates downloading a collection and saving a pattern definition.
58+
"""
59+
task = Task.objects.get(id=task_id)
60+
try:
61+
pattern = Pattern.objects.get(id=pattern_id)
62+
update_task_status(task, "Running", {"info": "Processing pattern"})
63+
with download_collection(pattern.collection_name.replace(".", "-"), pattern.collection_version) as collection_path:
64+
path_to_definition = os.path.join(collection_path, "extensions", "patterns", pattern.pattern_name, "meta", "pattern.json")
65+
with open(path_to_definition, "r") as file:
66+
definition = json.load(file)
67+
68+
pattern.pattern_definition = definition
69+
pattern.collection_version_uri = pattern.collection_version_uri
70+
pattern.save(update_fields=["pattern_definition", "collection_version_uri"])
71+
update_task_status(task, "Completed", {"info": "Pattern processed successfully"})
72+
except FileNotFoundError:
73+
logger.error(f"Could not find pattern definition for task {task_id}")
74+
update_task_status(task, "Failed", {"error": "Pattern definition not found."})
75+
except Exception as e:
76+
logger.error(f"Task {task_id} failed: {e}")
77+
update_task_status(task, "Failed", {"error": str(e)})

core/tasks.py

Lines changed: 3 additions & 121 deletions
Original file line numberDiff line numberDiff line change
@@ -1,123 +1,5 @@
1-
import contextlib
2-
import io
3-
import json
4-
import logging
5-
import os
6-
import shutil
7-
import tarfile
8-
import tempfile
9-
import urllib.parse
10-
from typing import Dict
11-
from typing import Iterator
12-
from typing import Optional
1+
from core.services import pattern_task
132

14-
import requests
15-
from requests import Session
16-
from requests.auth import HTTPBasicAuth
173

18-
from pattern_service.settings.aap import get_aap_settings
19-
20-
from .models import Pattern
21-
from .models import Task
22-
23-
logger = logging.getLogger(__name__)
24-
25-
settings = get_aap_settings()
26-
_aap_session: Optional[Session] = None
27-
28-
29-
def update_task_status(task: Task, status_: str, details: dict):
30-
task.status = status_
31-
task.details = details
32-
task.save()
33-
34-
35-
@contextlib.contextmanager
36-
def download_collection(collection: str, version: str) -> Iterator[str]:
37-
"""
38-
Downloads and extracts a collection tarball to a temporary directory.
39-
40-
Args:
41-
collection: The name of the collection (e.g., 'my_namespace.my_collection').
42-
version: The version of the collection (e.g., '1.0.0').
43-
44-
Yields:
45-
The path to the extracted collection files.
46-
"""
47-
path = f"/api/galaxy/v3/plugin/ansible/content/published/collections/artifacts/{collection}-{version}.tar.gz"
48-
49-
temp_base_dir = tempfile.mkdtemp()
50-
collection_path = os.path.join(temp_base_dir, f"{collection}-{version}")
51-
os.makedirs(collection_path, exist_ok=True)
52-
53-
try:
54-
response = get(path)
55-
in_memory_tar = io.BytesIO(response.content)
56-
57-
with tarfile.open(fileobj=in_memory_tar, mode="r|*") as tar:
58-
tar.extractall(path=collection_path, filter="data")
59-
60-
logger.info(f"Collection extracted to {collection_path}")
61-
yield collection_path # Yield the path to the caller
62-
finally:
63-
shutil.rmtree(temp_base_dir)
64-
65-
66-
def get_http_session(force_refresh: bool = False) -> Session:
67-
"""Returns a cached Session instance with AAP credentials."""
68-
global _aap_session
69-
if _aap_session is None or force_refresh:
70-
session = Session()
71-
session.auth = HTTPBasicAuth(settings.username, settings.password)
72-
session.verify = settings.verify_ssl
73-
session.headers.update({'Content-Type': 'application/json'})
74-
_aap_session = session
75-
76-
return _aap_session
77-
78-
79-
def get(path: str, *, params: Optional[Dict] = None) -> requests.Response:
80-
session = get_http_session()
81-
url = f"{settings.url}{path}"
82-
response = session.get(url, params=params, stream=True)
83-
response.raise_for_status()
84-
85-
return response
86-
87-
88-
def run_pattern_task(pattern_id: int, task_id: int):
89-
"""
90-
Orchestrates downloading a collection and saving a pattern definition.
91-
"""
92-
task = Task.objects.get(id=task_id)
93-
94-
try:
95-
pattern = Pattern.objects.get(id=pattern_id)
96-
update_task_status(task, "Running", {"info": "Processing pattern"})
97-
98-
# Get all necessary names from the pattern object
99-
collection_name = pattern.collection_name.replace(".", "-")
100-
collection_version = pattern.collection_version
101-
pattern_name = pattern.pattern_name
102-
103-
update_task_status(task, "Running", {"info": "Downloading collection tarball"})
104-
105-
with download_collection(collection_name, collection_version) as collection_path:
106-
path_to_definition = os.path.join(collection_path, "extensions", "patterns", pattern_name, "meta", "pattern.json")
107-
108-
with open(path_to_definition, "r") as file:
109-
definition = json.load(file)
110-
111-
pattern.pattern_definition = definition
112-
pattern.collection_version_uri = pattern.collection_version_uri
113-
pattern.save(update_fields=["pattern_definition", "collection_version_uri"])
114-
115-
update_task_status(task, "Completed", {"info": "Pattern processed successfully"})
116-
117-
except FileNotFoundError:
118-
logger.error(f"Could not find pattern definition for task {task_id}")
119-
update_task_status(task, "Failed", {"error": "Pattern definition file not found in collection."})
120-
121-
except Exception as e:
122-
logger.error(f"Task {task_id} failed: {e}")
123-
update_task_status(task, "Failed", {"error": str(e)})
4+
def execute_pattern_task(pattern_id, task_id):
5+
pattern_task(pattern_id, task_id)
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
from core.controller_client import _aap_session
2+
from core.controller_client import get_http_session
3+
4+
5+
def test_get_http_session_caches():
6+
"""Subsequent calls without force_refresh must return the *same* object."""
7+
s1 = get_http_session()
8+
s2 = get_http_session()
9+
assert s1 is s2
10+
11+
s3 = get_http_session(force_refresh=True)
12+
assert s3 is not s1 and s3 is _aap_session

core/tests/test_tasks.py renamed to core/tests/test_services.py

Lines changed: 42 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from core.models import Pattern
1313
from core.models import PatternInstance
1414
from core.models import Task
15-
from core.tasks import run_pattern_task
15+
from core.services import pattern_task
1616

1717

1818
class SharedDataMixin:
@@ -70,28 +70,62 @@ def tearDown(self):
7070
self.temp_dirs.clear()
7171

7272

73-
class TaskTests(SharedDataMixin, TestCase):
73+
class PatternTaskTest(SharedDataMixin, TestCase):
74+
@patch("core.services.update_task_status", wraps=pattern_task.__globals__["update_task_status"])
75+
@patch("core.services.open", new_callable=mock_open, read_data='{"name": "test"}')
76+
@patch("core.services.download_collection")
77+
def test_run_pattern_task_success(self, mock_download, mock_open_fn, mock_update_status):
78+
pattern = Pattern.objects.create(
79+
collection_name="demo.collection",
80+
collection_version="1.0.0",
81+
pattern_name="test_pattern",
82+
)
83+
task = Task.objects.create(status="Initiated", details={})
84+
temp_dir = tempfile.mkdtemp()
85+
mock_download.return_value.__enter__.return_value = temp_dir
86+
87+
os.makedirs(os.path.join(temp_dir, "extensions", "patterns", "test_pattern", "meta"))
88+
with open(os.path.join(temp_dir, "extensions", "patterns", "test_pattern", "meta", "pattern.json"), "w") as f:
89+
f.write(json.dumps({"name": "test"}))
90+
91+
pattern_task(pattern.id, task.id)
92+
93+
mock_update_status.assert_any_call(task, "Running", {"info": "Processing pattern"})
94+
mock_update_status.assert_any_call(task, "Completed", {"info": "Pattern processed successfully"})
95+
96+
@patch("core.services.update_task_status", wraps=pattern_task.__globals__["update_task_status"])
97+
@patch("core.services.download_collection", side_effect=FileNotFoundError)
98+
def test_run_pattern_task_file_not_found(self, mock_download, mock_update_status):
99+
pattern = Pattern.objects.create(
100+
collection_name="demo.collection",
101+
collection_version="1.0.0",
102+
pattern_name="missing_pattern",
103+
)
104+
task = Task.objects.create(status="Initiated", details={})
105+
106+
pattern_task(pattern.id, task.id)
107+
108+
mock_update_status.assert_called_with(task, "Failed", {"error": "Pattern definition not found."})
74109

75-
@patch("core.tasks.download_collection", side_effect=Exception("Download failed"))
110+
@patch("core.services.download_collection", side_effect=Exception("Download failed"))
76111
def test_run_pattern_task_handles_download_failure(self, mock_download):
77-
run_pattern_task(self.pattern.id, self.task.id)
112+
pattern_task(self.pattern.id, self.task.id)
78113
self.task.refresh_from_db()
79114
self.assertEqual(self.task.status, "Failed")
80115
self.assertIn("Download failed", self.task.details.get("error", ""))
81116

82-
@patch("core.tasks.update_task_status", wraps=run_pattern_task.__globals__["update_task_status"])
83-
@patch("core.tasks.download_collection")
117+
@patch("core.services.update_task_status", wraps=pattern_task.__globals__["update_task_status"])
118+
@patch("core.services.download_collection")
84119
def test_full_status_update_flow(self, mock_download, mock_update_status):
85120
temp_dir_path = self.create_temp_collection_dir()
86121
mock_download.return_value.__enter__.return_value = temp_dir_path
87122

88123
# Run the task
89-
run_pattern_task(self.pattern.id, self.task.id)
124+
pattern_task(self.pattern.id, self.task.id)
90125

91126
# Verify calls to update_task_status
92127
expected_calls = [
93128
(self.task, "Running", {"info": "Processing pattern"}),
94-
(self.task, "Running", {"info": "Downloading collection tarball"}),
95129
(self.task, "Completed", {"info": "Pattern processed successfully"}),
96130
]
97131
actual_calls = [tuple(call.args) for call in mock_update_status.call_args_list]

core/tests/test_views.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,14 @@ def setUpTestData(cls):
5151

5252

5353
class PatternViewSetTest(SharedDataMixin, APITestCase):
54+
def setUp(self):
55+
patcher = patch("core.services.update_task_status")
56+
self.mock_update_status = patcher.start()
57+
self.addCleanup(patcher.stop)
58+
59+
# No-op to avoid context['request'] access
60+
self.mock_update_status.side_effect = lambda task, status, details=None: (Task.objects.filter(pk=task.pk).update(status=status, details=details or {}))
61+
5462
def create_temp_collection_dir(self):
5563
temp_dir = tempfile.mkdtemp()
5664
os.makedirs(os.path.join(temp_dir, "extensions", "patterns", "new_pattern", "meta"), exist_ok=True)
@@ -73,7 +81,7 @@ def test_pattern_detail_view(self):
7381
self.assertEqual(response.status_code, status.HTTP_200_OK)
7482
self.assertEqual(response.data["collection_name"], "mynamespace.mycollection")
7583

76-
@patch("core.tasks.download_collection")
84+
@patch("core.services.download_collection")
7785
def test_pattern_create_view(self, mock_download_collection):
7886
temp_dir = self.create_temp_collection_dir() # Simulate a valid pattern.json
7987
mock_download_collection.return_value.__enter__.return_value = temp_dir

core/views.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from .serializers import PatternInstanceSerializer
1616
from .serializers import PatternSerializer
1717
from .serializers import TaskSerializer
18-
from .tasks import run_pattern_task
18+
from .tasks import pattern_task
1919

2020

2121
class CoreViewSet(AnsibleBaseView):
@@ -33,7 +33,7 @@ def create(self, request, *args, **kwargs):
3333

3434
task = Task.objects.create(status="Initiated", details={"model": "Pattern", "id": pattern.id})
3535

36-
run_pattern_task(pattern.id, task.id)
36+
pattern_task(pattern.id, task.id)
3737

3838
headers = self.get_success_headers(serializer.data)
3939

@@ -43,6 +43,7 @@ def create(self, request, *args, **kwargs):
4343
"message": "Pattern creation initiated. Check task status for progress.",
4444
},
4545
status=status.HTTP_202_ACCEPTED,
46+
headers=headers,
4647
)
4748

4849

0 commit comments

Comments
 (0)