Skip to content

Commit 5bdbf64

Browse files
committed
add pytorch and python version to NN model; remove model metadata functionality (for now)
1 parent 34e3e01 commit 5bdbf64

File tree

2 files changed

+9
-62
lines changed

2 files changed

+9
-62
lines changed

annif/backend/nn_ensemble.py

Lines changed: 5 additions & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -3,11 +3,9 @@
33

44
from __future__ import annotations
55

6-
import importlib
7-
import json
86
import os.path
97
import shutil
10-
import zipfile
8+
import sys
119
from io import BytesIO
1210
from typing import TYPE_CHECKING, Any
1311

@@ -117,8 +115,10 @@ def save(self, filepath):
117115
torch.save(
118116
{
119117
"model_state_dict": self.state_dict(),
120-
"model_config": self.model_config,
121118
"model_class": self.__class__.__name__,
119+
"model_config": self.model_config,
120+
"pytorch_version": str(torch.__version__),
121+
"python_version": sys.version,
122122
},
123123
filepath,
124124
)
@@ -173,11 +173,9 @@ def initialize(self, parallel: bool = False) -> None:
173173
try:
174174
self._model = NNEnsembleModel.load(model_filename)
175175
except Exception as err:
176-
metadata = self.get_model_metadata(model_filename)
177176
message = (
178177
f"loading model from {model_filename}; "
179-
f"model metadata: {metadata}; "
180-
f'Original error message: "{err}"'
178+
f'original error message: "{err}"'
181179
)
182180
raise OperationFailedException(message, backend_id=self.backend_id)
183181

@@ -343,16 +341,3 @@ def _learn(
343341
self._fit_model(
344342
corpus, int(params["learn-epochs"]), int(params["lmdb_map_size"])
345343
)
346-
347-
def get_model_metadata(self, model_filename: str) -> dict | None:
348-
"""Read metadata from Keras model files."""
349-
350-
try:
351-
with zipfile.ZipFile(model_filename, "r") as zip:
352-
with zip.open("metadata.json") as metadata_file:
353-
metadata_str = metadata_file.read().decode("utf-8")
354-
metadata = json.loads(metadata_str)
355-
return metadata
356-
except Exception:
357-
self.warning(f"Failed to read metadata from {model_filename}")
358-
return None

tests/test_backend_nn_ensemble.py

Lines changed: 4 additions & 42 deletions
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,5 @@
11
"""Unit tests for the nn_ensemble backend in Annif"""
22

3-
import importlib
4-
import os.path
53
import time
64
from datetime import datetime, timedelta, timezone
75
from unittest import mock
@@ -193,44 +191,8 @@ def test_nn_ensemble_modification_time(app_project):
193191
assert datetime.now(timezone.utc) - nn_ensemble.modification_time < timedelta(1)
194192

195193

196-
@pytest.mark.skip
197-
def test_nn_ensemble_get_model_metadata(app_project):
198-
nn_ensemble_type = annif.backend.get_backend("nn_ensemble")
199-
nn_ensemble = nn_ensemble_type(
200-
backend_id="nn_ensemble",
201-
config_params={"sources": "dummy-en"},
202-
project=app_project,
203-
)
204-
model_filename = os.path.join(nn_ensemble.datadir, nn_ensemble.MODEL_FILE)
205-
206-
expected_version = importlib.metadata.version("torch")
207-
expected_date_saved = datetime.now(timezone.utc)
208-
actual_metadata = nn_ensemble.get_model_metadata(model_filename)
209-
210-
assert actual_metadata["torch_version"] == expected_version
211-
datetime_format = "%Y-%m-%d@%H:%M:%S"
212-
actual_datetime = datetime.strptime(actual_metadata["date_saved"], datetime_format)
213-
assert expected_date_saved - actual_datetime.astimezone(
214-
tz=timezone.utc
215-
) < timedelta(1)
216-
217-
218-
def test_nn_ensemble_get_model_metadata_nonexistent_file(app_project):
219-
nn_ensemble_type = annif.backend.get_backend("nn_ensemble")
220-
nn_ensemble = nn_ensemble_type(
221-
backend_id="nn_ensemble",
222-
config_params={"sources": "dummy-en"},
223-
project=app_project,
224-
)
225-
nonexistent_model_file = "nonexistent.zip"
226-
model_filename = os.path.join(nn_ensemble.datadir, nonexistent_model_file)
227-
228-
actual_metadata = nn_ensemble.get_model_metadata(model_filename)
229-
assert actual_metadata is None
230-
231-
232-
@mock.patch("annif.backend.nn_ensemble.load_model", side_effect=Exception)
233-
def test_nn_ensemble_initialize_error(load_model, app_project):
194+
@mock.patch("annif.backend.nn_ensemble.NNEnsembleModel.load", side_effect=Exception)
195+
def test_nn_ensemble_initialize_error(load, app_project):
234196
nn_ensemble_type = annif.backend.get_backend("nn_ensemble")
235197
nn_ensemble = nn_ensemble_type(
236198
backend_id="nn_ensemble",
@@ -240,10 +202,10 @@ def test_nn_ensemble_initialize_error(load_model, app_project):
240202
assert nn_ensemble._model is None
241203
with pytest.raises(
242204
OperationFailedException,
243-
match=r"loading Keras model from .*; model metadata: .*",
205+
match=r"loading model from .*; original error message: .*",
244206
):
245207
nn_ensemble.initialize()
246-
assert load_model.called
208+
assert load.called
247209

248210

249211
def test_nn_ensemble_initialize(app_project):

0 commit comments

Comments
 (0)