diff --git a/CHANGELOG.md b/CHANGELOG.md index af8711b0..313bf2b9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,7 @@ +## 0.7.35-dev1 + +* feat: add INFERENCE_GLOBAL_WORKING_DIR and INFERENCE_GLOBAL_WOKRING_PROCESS_DIR configuration parameters to control temporary storage + ## 0.7.34 * Reduce excessive logging diff --git a/test_unstructured_inference/test_config.py b/test_unstructured_inference/test_config.py index 14d93a50..fc0ccf63 100644 --- a/test_unstructured_inference/test_config.py +++ b/test_unstructured_inference/test_config.py @@ -1,3 +1,9 @@ +from pathlib import Path +import shutil +import tempfile +import pytest + + def test_default_config(): from unstructured_inference.config import inference_config @@ -9,3 +15,43 @@ def test_env_override(monkeypatch): from unstructured_inference.config import inference_config assert inference_config.TT_TABLE_CONF == 1 + + +@pytest.fixture() +def _setup_tmpdir(): + from unstructured_inference.config import inference_config + + _tmpdir = tempfile.tempdir + _storage_tmpdir = inference_config.INFERENCE_GLOBAL_WORKING_PROCESS_DIR + _storage_tmpdir_bak = f"{inference_config.INFERENCE_GLOBAL_WORKING_PROCESS_DIR}_bak" + if Path(_storage_tmpdir).is_dir(): + shutil.move(_storage_tmpdir, _storage_tmpdir_bak) + tempfile.tempdir = None + yield + if Path(_storage_tmpdir_bak).is_dir(): + if Path(_storage_tmpdir).is_dir(): + shutil.rmtree(_storage_tmpdir) + shutil.move(_storage_tmpdir_bak, _storage_tmpdir) + tempfile.tempdir = _tmpdir + + +@pytest.mark.usefixtures("_setup_tmpdir") +def test_env_storage_disabled(monkeypatch): + monkeypatch.setenv("INFERENCE_GLOBAL_WORKING_DIR_ENABLED", "false") + from unstructured_inference.config import inference_config + + assert not inference_config.INFERENCE_GLOBAL_WORKING_DIR_ENABLED + assert str(Path.home() / ".cache/unstructured") == inference_config.INFERENCE_GLOBAL_WORKING_DIR + assert not Path(inference_config.INFERENCE_GLOBAL_WORKING_PROCESS_DIR).is_dir() + assert tempfile.gettempdir() != inference_config.INFERENCE_GLOBAL_WORKING_PROCESS_DIR + + +@pytest.mark.usefixtures("_setup_tmpdir") +def test_env_storage_enabled(monkeypatch): + monkeypatch.setenv("INFERENCE_GLOBAL_WORKING_DIR_ENABLED", "true") + from unstructured_inference.config import inference_config + + assert inference_config.INFERENCE_GLOBAL_WORKING_DIR_ENABLED + assert str(Path.home() / ".cache/unstructured") == inference_config.INFERENCE_GLOBAL_WORKING_DIR + assert Path(inference_config.INFERENCE_GLOBAL_WORKING_PROCESS_DIR).is_dir() + assert tempfile.gettempdir() == inference_config.INFERENCE_GLOBAL_WORKING_PROCESS_DIR diff --git a/unstructured_inference/__init__.py b/unstructured_inference/__init__.py index e69de29b..2404c038 100644 --- a/unstructured_inference/__init__.py +++ b/unstructured_inference/__init__.py @@ -0,0 +1,4 @@ +from .config import inference_config + +# init inference_config +inference_config diff --git a/unstructured_inference/__version__.py b/unstructured_inference/__version__.py index e6fd9f15..7f3624d3 100644 --- a/unstructured_inference/__version__.py +++ b/unstructured_inference/__version__.py @@ -1 +1 @@ -__version__ = "0.7.34" # pragma: no cover +__version__ = "0.7.35-dev1" # pragma: no cover diff --git a/unstructured_inference/config.py b/unstructured_inference/config.py index d5765bbf..ce740a62 100644 --- a/unstructured_inference/config.py +++ b/unstructured_inference/config.py @@ -7,13 +7,26 @@ """ import os +import tempfile from dataclasses import dataclass +from functools import lru_cache +from pathlib import Path + + +@lru_cache(maxsize=1) +def _get_tempdir(dir: str) -> str: + tempdir = Path(dir) / f"tmp/{os.getpgid(0)}" + return str(tempdir) @dataclass class InferenceConfig: """class for configuring inference parameters""" + def ___post_init__(self): + if self.INFERENCE_GLOBAL_WORKING_DIR_ENABLED: + self._setup_tmpdir(self.INFERENCE_GLOBAL_WORKING_PROCESS_DIR) + def _get_string(self, var: str, default_value: str = "") -> str: """attempt to get the value of var from the os environment; if not present return the default_value""" @@ -29,6 +42,15 @@ def _get_float(self, var: str, default_value: float) -> float: return float(value) return default_value + def _get_bool(self, var: str, default_value: bool) -> bool: + if value := self._get_string(var): + return value.lower() in ("true", "1", "t") + return default_value + + def _setup_tmpdir(self, tmpdir: str) -> None: + Path(tmpdir).mkdir(parents=True, exist_ok=True) + tempfile.tempdir = tmpdir + @property def TABLE_IMAGE_BACKGROUND_PAD(self) -> int: """number of pixels to pad around an table image with a white background color @@ -106,5 +128,30 @@ def ELEMENTS_V_PADDING_COEF(self) -> float: """Same as ELEMENTS_H_PADDING_COEF but the vertical extension.""" return self._get_float("ELEMENTS_V_PADDING_COEF", 0.3) + @property + def INFERENCE_GLOBAL_WORKING_DIR_ENABLED(self) -> bool: + """Enable usage of INFERENCE_GLOBAL_WORKING_DIR and INFERENCE_GLOBAL_WORKING_PROCESS_DIR.""" + return self._get_bool("INFERENCE_GLOBAL_WORKING_DIR_ENABLED", False) + + @property + def INFERENCE_GLOBAL_WORKING_DIR(self) -> str: + """Path to Unstructured cache directory.""" + return self._get_string( + "INFERENCE_GLOBAL_WORKING_DIR", str(Path.home() / ".cache/unstructured") + ) + + @property + def INFERENCE_GLOBAL_WORKING_PROCESS_DIR(self) -> str: + """Path to Unstructured cache tempdir. Overrides TMPDIR, TEMP and TMP. + Defaults to '{INFERENCE_GLOBAL_WORKING_DIR}/tmp/{os.getpgid(0)}'. + """ + default_tmpdir = _get_tempdir(dir=self.INFERENCE_GLOBAL_WORKING_DIR) + tmpdir = self._get_string("INFERENCE_GLOBAL_WORKING_PROCESS_DIR", default_tmpdir) + if tmpdir == "": + tmpdir = default_tmpdir + if self.INFERENCE_GLOBAL_WORKING_DIR_ENABLED: + self._setup_tmpdir(tmpdir) + return tmpdir + inference_config = InferenceConfig()