Skip to content

Commit 87efa45

Browse files
bazel compliance: Rewrite collect_env_info and direct cache a tempfile
1 parent 07d65fc commit 87efa45

File tree

5 files changed

+134
-26
lines changed

5 files changed

+134
-26
lines changed

direct/environment.py

Lines changed: 120 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import os
1717
import pathlib
1818
import sys
19+
import tempfile
1920
from collections import namedtuple
2021
from typing import Callable, Dict, Optional, Tuple, Union
2122

@@ -29,14 +30,128 @@
2930
from direct.utils.io import check_is_valid_url, read_text_from_url
3031
from direct.utils.logging import setup
3132

33+
import platform
34+
import importlib.metadata
35+
from collections import namedtuple
36+
3237
logger = logging.getLogger(__name__)
3338

3439
# Environmental variables
3540
DIRECT_ROOT_DIR = pathlib.Path(pathlib.Path(__file__).resolve().parent.parent)
36-
DIRECT_CACHE_DIR = pathlib.Path(os.environ.get("DIRECT_CACHE_DIR", str(DIRECT_ROOT_DIR)))
37-
DIRECT_MODEL_DOWNLOAD_DIR = (
38-
pathlib.Path(os.environ.get("DIRECT_MODEL_DOWNLOAD_DIR", str(DIRECT_ROOT_DIR))) / "downloaded_models"
39-
)
41+
42+
43+
def resolve_cache_dir() -> pathlib.Path:
44+
cache_dir_path = pathlib.Path(os.environ.get("DIRECT_CACHE_DIR", str(DIRECT_ROOT_DIR)))
45+
# Check if the directory is writable
46+
if os.access(str(cache_dir_path), os.W_OK):
47+
logger.info(f"Using cache directory: {cache_dir_path}")
48+
return cache_dir_path
49+
if "DIRECT_CACHE_DIR" in os.environ:
50+
env_path = pathlib.Path(os.environ["DIRECT_CACHE_DIR"])
51+
if os.access(str(env_path), os.W_OK):
52+
logger.info(f"Using cache directory: {env_path}")
53+
return env_path
54+
try:
55+
tmpdir = os.environ.get("TMPDIR", tempfile.gettempdir())
56+
cache_dir = pathlib.Path(tmpdir) / "direct_cache"
57+
cache_dir.mkdir(parents=True, exist_ok=True)
58+
if os.access(str(cache_dir), os.W_OK):
59+
logger.info(f"Using cache directory: {cache_dir}")
60+
return cache_dir
61+
except Exception:
62+
pass
63+
64+
# Fallback to a default tmp directory
65+
fallback = pathlib.Path("/tmp/direct_cache")
66+
fallback.mkdir(parents=True, exist_ok=True)
67+
logger.warning(f"Falling back to cache directory: {fallback}")
68+
return fallback
69+
70+
71+
DIRECT_CACHE_DIR = resolve_cache_dir()
72+
DIRECT_MODEL_DOWNLOAD_DIR = DIRECT_CACHE_DIR / "downloaded_models"
73+
74+
75+
def collect_env_info() -> str:
76+
"""Collects environment information.
77+
78+
Returns
79+
-------
80+
env_info: str
81+
Environment information as a formatted string.
82+
"""
83+
SystemEnv = namedtuple(
84+
"SystemEnv",
85+
[
86+
"torch_version",
87+
"is_debug_build",
88+
"cuda_compiled_version",
89+
"python_version",
90+
"python_platform",
91+
"os",
92+
"libc_version",
93+
"is_cuda_available",
94+
"cuda_runtime_version",
95+
"cudnn_version",
96+
"pip_packages",
97+
"cpu_info",
98+
],
99+
)
100+
101+
def safe_version(pkg):
102+
try:
103+
return importlib.metadata.version(pkg)
104+
except importlib.metadata.PackageNotFoundError:
105+
return "Not installed"
106+
107+
def get_cudnn_version():
108+
try:
109+
return str(torch.backends.cudnn.version()) if torch.backends.cudnn.is_available() else "Unavailable"
110+
except Exception:
111+
return "Unknown"
112+
113+
def get_cpu_info():
114+
try:
115+
return platform.processor() or platform.machine()
116+
except Exception:
117+
return "Unknown"
118+
119+
pip_packages = {pkg: safe_version(pkg) for pkg in ["torch", "numpy", "triton", "optree", "mypy", "flake8", "onnx"]}
120+
pip_str = "\n " + "\n ".join(f"{pkg}=={ver}" for pkg, ver in pip_packages.items())
121+
122+
def pretty_print(env):
123+
lines = [
124+
f"PyTorch version: {env.torch_version}",
125+
f"Is debug build: {env.is_debug_build}",
126+
f"CUDA used to build PyTorch: {env.cuda_compiled_version}",
127+
f"Python version: {env.python_version}",
128+
f"Python platform: {env.python_platform}",
129+
f"OS: {env.os}",
130+
f"Libc version: {env.libc_version}",
131+
f"Is CUDA available: {env.is_cuda_available}",
132+
f"CUDA runtime version: {env.cuda_runtime_version}",
133+
f"cuDNN version: {env.cudnn_version}",
134+
f"CPU info: {env.cpu_info}",
135+
f"Relevant pip packages: {env.pip_packages}",
136+
]
137+
return "\n" + "\n".join(lines)
138+
139+
return pretty_print(
140+
SystemEnv(
141+
torch_version=torch.__version__,
142+
is_debug_build=str(getattr(torch.version, "debug", "Unknown")),
143+
cuda_compiled_version=getattr(torch.version, "cuda", "None"),
144+
python_version=sys.version.replace("\n", " "),
145+
python_platform=platform.platform(),
146+
os=platform.platform(),
147+
libc_version="-".join(platform.libc_ver()) if sys.platform.startswith("linux") else "N/A",
148+
is_cuda_available=str(torch.cuda.is_available()),
149+
cuda_runtime_version=getattr(torch.version, "cuda", "No CUDA"),
150+
cudnn_version=get_cudnn_version(),
151+
pip_packages=pip_str,
152+
cpu_info=get_cpu_info(),
153+
)
154+
)
40155

41156

42157
def load_model_config_from_name(model_name: str) -> Callable:
@@ -152,7 +267,7 @@ def setup_logging(
152267
logger.info("Run name: %s", run_name)
153268
logger.info("Config file: %s", cfg_filename)
154269
logger.info("CUDA %s - cuDNN %s", torch.version.cuda, torch.backends.cudnn.version())
155-
logger.info("Environment information: %s", collect_env.get_pretty_env_info())
270+
logger.info("Environment information: %s", collect_env_info())
156271
logger.info("DIRECT version: %s", direct.__version__)
157272
git_hash = direct.utils.git_hash()
158273
logger.info("Git hash: %s", git_hash if git_hash else "N/A")

requirements_darwin.txt

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -778,9 +778,7 @@ scipy==1.15.2 \
778778
setuptools==79.0.0 \
779779
--hash=sha256:9828422e7541213b0aacb6e10bbf9dd8febeaa45a48570e09b6d100e063fc9f9 \
780780
--hash=sha256:b9ab3a104bedb292323f53797b00864e10e434a3ab3906813a7169e4745b912a
781-
# via
782-
# tensorboard
783-
# torch
781+
# via tensorboard
784782
six==1.17.0 \
785783
--hash=sha256:4721f391ed90541fddacab5acf947aa0d3dc7d27b2e1e8eda2be8970586c3274 \
786784
--hash=sha256:ff70335d468e7eb6ec65b95b99d3a2836546063f63acc5171de367e834932a81

requirements_linux.txt

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -848,9 +848,7 @@ scipy==1.15.2 \
848848
setuptools==79.0.0 \
849849
--hash=sha256:9828422e7541213b0aacb6e10bbf9dd8febeaa45a48570e09b6d100e063fc9f9 \
850850
--hash=sha256:b9ab3a104bedb292323f53797b00864e10e434a3ab3906813a7169e4745b912a
851-
# via
852-
# tensorboard
853-
# torch
851+
# via tensorboard
854852
six==1.17.0 \
855853
--hash=sha256:4721f391ed90541fddacab5acf947aa0d3dc7d27b2e1e8eda2be8970586c3274 \
856854
--hash=sha256:ff70335d468e7eb6ec65b95b99d3a2836546063f63acc5171de367e834932a81

tests/tests_nn/cirim_test.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -31,19 +31,19 @@ def create_input(shape):
3131
)
3232
@pytest.mark.parametrize(
3333
"depth",
34-
[2, 4],
34+
[2, 3],
3535
)
3636
@pytest.mark.parametrize(
3737
"time_steps",
38-
[8, 16],
38+
[4, 6],
3939
)
4040
@pytest.mark.parametrize(
4141
"recurrent_hidden_channels",
42-
[64, 128],
42+
[64],
4343
)
4444
@pytest.mark.parametrize(
4545
"num_cascades",
46-
[1, 2, 8],
46+
[1, 4],
4747
)
4848
@pytest.mark.parametrize(
4949
"no_parameter_sharing",

tests/tests_nn/transformers_test.py

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -29,10 +29,7 @@ def create_input(shape):
2929

3030
@pytest.mark.parametrize(
3131
"shape",
32-
[
33-
[3, 2, 32, 32],
34-
[3, 2, 16, 16],
35-
],
32+
[[1, 2, 32, 32]],
3633
)
3734
@pytest.mark.parametrize(
3835
"embedding_dim",
@@ -123,23 +120,23 @@ def test_uformer(
123120
)
124121
@pytest.mark.parametrize(
125122
"patch_size",
126-
[16, 8, (16, 10)],
123+
[16, (16, 10)],
127124
)
128125
@pytest.mark.parametrize(
129126
"embedding_dim",
130127
[6, 12],
131128
)
132129
@pytest.mark.parametrize(
133130
"depth",
134-
[2, 4],
131+
[2],
135132
)
136133
@pytest.mark.parametrize(
137134
"num_heads",
138-
[3, 4],
135+
[3],
139136
)
140137
@pytest.mark.parametrize(
141138
"mlp_ratio",
142-
[4.0, 2.0],
139+
[2.0],
143140
)
144141
@pytest.mark.parametrize(
145142
"qkv_bias",
@@ -163,7 +160,7 @@ def test_uformer(
163160
)
164161
@pytest.mark.parametrize(
165162
"normalized",
166-
[True, False],
163+
[True],
167164
)
168165
def test_vision_transformer_2d(
169166
shape,
@@ -213,11 +210,11 @@ def test_vision_transformer_2d(
213210
)
214211
@pytest.mark.parametrize(
215212
"embedding_dim",
216-
[8, 16],
213+
[8],
217214
)
218215
@pytest.mark.parametrize(
219216
"depth",
220-
[4, 8],
217+
[4],
221218
)
222219
@pytest.mark.parametrize(
223220
"num_heads",

0 commit comments

Comments
 (0)