Skip to content

Commit 50b9704

Browse files
fix(bazel): adjust test setup and environment handling for Bazel compatibility
* Enable to see projects * bazel compliance: Rewrite collect_env_info and direct cache a tempfile
1 parent 978a185 commit 50b9704

File tree

16 files changed

+191
-73
lines changed

16 files changed

+191
-73
lines changed

MODULE.bazel

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -23,28 +23,28 @@ bazel_dep(name = "aspect_rules_py", version = "1.3.2")
2323
# This allows sharding of tests and easier setup
2424
archive_override(
2525
module_name = "aspect_rules_py",
26-
integrity = "sha256-p7Fo6yMoh96SurOcGt1uLt95vRf/RUeuPvr/1oA6jzw=",
27-
strip_prefix = "rules_py-1296c6b88156594543b7f41ceb331bdc9a053055",
28-
urls = ["https://github.com/aspect-build/rules_py/archive/1296c6b88156594543b7f41ceb331bdc9a053055.zip"],
26+
integrity = "sha256-0EJJ2KgKvlISrxVk5Tmc74iMMJNZJZ/3dG55azcc8sA=",
27+
strip_prefix = "rules_py-1445ccaf3665cb5d8f78da4f5fc4d73fd36fa165",
28+
urls = ["https://github.com/aspect-build/rules_py/archive/1445ccaf3665cb5d8f78da4f5fc4d73fd36fa165.zip"],
2929
)
3030

3131
bazel_dep(name = "rules_uv", version = "0.65.0", dev_dependency = True)
3232

3333
python = use_extension("@rules_python//python/extensions:python.bzl", "python")
3434
python.toolchain(
3535
is_default = True,
36-
python_version = "3.12",
36+
python_version = "3.11",
3737
)
3838

3939
pip = use_extension("@rules_python//python/extensions:pip.bzl", "pip")
4040
pip.parse(
4141
enable_implicit_namespace_pkgs = True,
42-
hub_name = "pip",
43-
python_version = "3.12",
42+
hub_name = "direct_pip",
43+
python_version = "3.11",
4444
requirements_darwin = ":requirements_darwin.txt",
4545
requirements_lock = ":requirements_linux.txt",
4646
)
47-
use_repo(pip, "pip")
47+
use_repo(pip, "direct_pip")
4848

4949
npm = use_extension("@aspect_rules_js//npm:extensions.bzl", "npm")
5050
npm.npm_translate_lock(

direct/BUILD.bazel

Lines changed: 39 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,39 +1,39 @@
11
load("@aspect_rules_py//py:defs.bzl", "py_binary", "py_library")
2-
load("@pip//:requirements.bzl", "requirement")
2+
load("@direct_pip//:requirements.bzl", "requirement")
33
load("@rules_cc//cc:defs.bzl", "cc_library")
44
load("//tools:cython_rules.bzl", "pyx_library")
55

66
cc_library(
77
name = "numpy_headers",
88
hdrs = [
9-
"@@rules_python~~pip~pip_312_numpy//:site-packages/numpy/core/include/numpy/__multiarray_api.h",
10-
"@@rules_python~~pip~pip_312_numpy//:site-packages/numpy/core/include/numpy/__ufunc_api.h",
11-
"@@rules_python~~pip~pip_312_numpy//:site-packages/numpy/core/include/numpy/_dtype_api.h",
12-
"@@rules_python~~pip~pip_312_numpy//:site-packages/numpy/core/include/numpy/_neighborhood_iterator_imp.h",
13-
"@@rules_python~~pip~pip_312_numpy//:site-packages/numpy/core/include/numpy/_numpyconfig.h",
14-
"@@rules_python~~pip~pip_312_numpy//:site-packages/numpy/core/include/numpy/arrayobject.h",
15-
"@@rules_python~~pip~pip_312_numpy//:site-packages/numpy/core/include/numpy/arrayscalars.h",
16-
"@@rules_python~~pip~pip_312_numpy//:site-packages/numpy/core/include/numpy/experimental_dtype_api.h",
17-
"@@rules_python~~pip~pip_312_numpy//:site-packages/numpy/core/include/numpy/halffloat.h",
18-
"@@rules_python~~pip~pip_312_numpy//:site-packages/numpy/core/include/numpy/ndarrayobject.h",
19-
"@@rules_python~~pip~pip_312_numpy//:site-packages/numpy/core/include/numpy/ndarraytypes.h",
20-
"@@rules_python~~pip~pip_312_numpy//:site-packages/numpy/core/include/numpy/noprefix.h",
21-
"@@rules_python~~pip~pip_312_numpy//:site-packages/numpy/core/include/numpy/npy_1_7_deprecated_api.h",
22-
"@@rules_python~~pip~pip_312_numpy//:site-packages/numpy/core/include/numpy/npy_3kcompat.h",
23-
"@@rules_python~~pip~pip_312_numpy//:site-packages/numpy/core/include/numpy/npy_common.h",
24-
"@@rules_python~~pip~pip_312_numpy//:site-packages/numpy/core/include/numpy/npy_cpu.h",
25-
"@@rules_python~~pip~pip_312_numpy//:site-packages/numpy/core/include/numpy/npy_endian.h",
26-
"@@rules_python~~pip~pip_312_numpy//:site-packages/numpy/core/include/numpy/npy_interrupt.h",
27-
"@@rules_python~~pip~pip_312_numpy//:site-packages/numpy/core/include/numpy/npy_math.h",
28-
"@@rules_python~~pip~pip_312_numpy//:site-packages/numpy/core/include/numpy/npy_no_deprecated_api.h",
29-
"@@rules_python~~pip~pip_312_numpy//:site-packages/numpy/core/include/numpy/npy_os.h",
30-
"@@rules_python~~pip~pip_312_numpy//:site-packages/numpy/core/include/numpy/numpyconfig.h",
31-
"@@rules_python~~pip~pip_312_numpy//:site-packages/numpy/core/include/numpy/old_defines.h",
32-
"@@rules_python~~pip~pip_312_numpy//:site-packages/numpy/core/include/numpy/random/bitgen.h",
33-
"@@rules_python~~pip~pip_312_numpy//:site-packages/numpy/core/include/numpy/random/distributions.h",
34-
"@@rules_python~~pip~pip_312_numpy//:site-packages/numpy/core/include/numpy/random/libdivide.h",
35-
"@@rules_python~~pip~pip_312_numpy//:site-packages/numpy/core/include/numpy/ufuncobject.h",
36-
"@@rules_python~~pip~pip_312_numpy//:site-packages/numpy/core/include/numpy/utils.h",
9+
"@@rules_python~~pip~direct_pip_311_numpy//:site-packages/numpy/core/include/numpy/__multiarray_api.h",
10+
"@@rules_python~~pip~direct_pip_311_numpy//:site-packages/numpy/core/include/numpy/__ufunc_api.h",
11+
"@@rules_python~~pip~direct_pip_311_numpy//:site-packages/numpy/core/include/numpy/_dtype_api.h",
12+
"@@rules_python~~pip~direct_pip_311_numpy//:site-packages/numpy/core/include/numpy/_neighborhood_iterator_imp.h",
13+
"@@rules_python~~pip~direct_pip_311_numpy//:site-packages/numpy/core/include/numpy/_numpyconfig.h",
14+
"@@rules_python~~pip~direct_pip_311_numpy//:site-packages/numpy/core/include/numpy/arrayobject.h",
15+
"@@rules_python~~pip~direct_pip_311_numpy//:site-packages/numpy/core/include/numpy/arrayscalars.h",
16+
"@@rules_python~~pip~direct_pip_311_numpy//:site-packages/numpy/core/include/numpy/experimental_dtype_api.h",
17+
"@@rules_python~~pip~direct_pip_311_numpy//:site-packages/numpy/core/include/numpy/halffloat.h",
18+
"@@rules_python~~pip~direct_pip_311_numpy//:site-packages/numpy/core/include/numpy/ndarrayobject.h",
19+
"@@rules_python~~pip~direct_pip_311_numpy//:site-packages/numpy/core/include/numpy/ndarraytypes.h",
20+
"@@rules_python~~pip~direct_pip_311_numpy//:site-packages/numpy/core/include/numpy/noprefix.h",
21+
"@@rules_python~~pip~direct_pip_311_numpy//:site-packages/numpy/core/include/numpy/npy_1_7_deprecated_api.h",
22+
"@@rules_python~~pip~direct_pip_311_numpy//:site-packages/numpy/core/include/numpy/npy_3kcompat.h",
23+
"@@rules_python~~pip~direct_pip_311_numpy//:site-packages/numpy/core/include/numpy/npy_common.h",
24+
"@@rules_python~~pip~direct_pip_311_numpy//:site-packages/numpy/core/include/numpy/npy_cpu.h",
25+
"@@rules_python~~pip~direct_pip_311_numpy//:site-packages/numpy/core/include/numpy/npy_endian.h",
26+
"@@rules_python~~pip~direct_pip_311_numpy//:site-packages/numpy/core/include/numpy/npy_interrupt.h",
27+
"@@rules_python~~pip~direct_pip_311_numpy//:site-packages/numpy/core/include/numpy/npy_math.h",
28+
"@@rules_python~~pip~direct_pip_311_numpy//:site-packages/numpy/core/include/numpy/npy_no_deprecated_api.h",
29+
"@@rules_python~~pip~direct_pip_311_numpy//:site-packages/numpy/core/include/numpy/npy_os.h",
30+
"@@rules_python~~pip~direct_pip_311_numpy//:site-packages/numpy/core/include/numpy/numpyconfig.h",
31+
"@@rules_python~~pip~direct_pip_311_numpy//:site-packages/numpy/core/include/numpy/old_defines.h",
32+
"@@rules_python~~pip~direct_pip_311_numpy//:site-packages/numpy/core/include/numpy/random/bitgen.h",
33+
"@@rules_python~~pip~direct_pip_311_numpy//:site-packages/numpy/core/include/numpy/random/distributions.h",
34+
"@@rules_python~~pip~direct_pip_311_numpy//:site-packages/numpy/core/include/numpy/random/libdivide.h",
35+
"@@rules_python~~pip~direct_pip_311_numpy//:site-packages/numpy/core/include/numpy/ufuncobject.h",
36+
"@@rules_python~~pip~direct_pip_311_numpy//:site-packages/numpy/core/include/numpy/utils.h",
3737
],
3838
strip_include_prefix = "/site-packages/numpy/core/include",
3939
visibility = ["//visibility:private"],
@@ -63,7 +63,10 @@ pyx_library(
6363
py_library(
6464
name = "algorithms",
6565
srcs = glob(["algorithms/**/*.py"]),
66-
deps = [requirement("torch"), requirement("numpy")],
66+
deps = [
67+
requirement("torch"),
68+
requirement("numpy"),
69+
],
6770
)
6871

6972
py_library(
@@ -130,7 +133,9 @@ py_library(
130133

131134
py_library(
132135
name = "direct_lib",
133-
srcs = glob(["*.py"]),
136+
srcs = glob(["**/*.py"]),
137+
imports = ["."],
138+
visibility = ["//visibility:public"],
134139
deps = [
135140
":algorithms",
136141
":common",
@@ -145,14 +150,16 @@ py_library(
145150
requirement("torchvision"),
146151
requirement("omegaconf"),
147152
],
148-
visibility = ["//visibility:public"],
149153
)
150154

151155
py_binary(
152156
name = "direct",
153157
srcs = glob([
154158
"cli/**/*.py",
155159
]),
160+
data = [
161+
"//projects:all_configs",
162+
],
156163
main = "cli/cli.py",
157164
deps = [
158165
":direct_lib",

direct/environment.py

Lines changed: 118 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,13 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
import argparse
15+
import importlib.metadata
1516
import logging
1617
import os
1718
import pathlib
19+
import platform
1820
import sys
21+
import tempfile
1922
from collections import namedtuple
2023
from typing import Callable, Dict, Optional, Tuple, Union
2124

@@ -33,10 +36,120 @@
3336

3437
# Environmental variables
3538
DIRECT_ROOT_DIR = pathlib.Path(pathlib.Path(__file__).resolve().parent.parent)
36-
DIRECT_CACHE_DIR = pathlib.Path(os.environ.get("DIRECT_CACHE_DIR", str(DIRECT_ROOT_DIR)))
37-
DIRECT_MODEL_DOWNLOAD_DIR = (
38-
pathlib.Path(os.environ.get("DIRECT_MODEL_DOWNLOAD_DIR", str(DIRECT_ROOT_DIR))) / "downloaded_models"
39-
)
39+
40+
41+
def resolve_cache_dir() -> pathlib.Path:
42+
cache_dir_path = pathlib.Path(os.environ.get("DIRECT_CACHE_DIR", str(DIRECT_ROOT_DIR)))
43+
# Check if the directory is writable
44+
if os.access(str(cache_dir_path), os.W_OK):
45+
logger.info("Using cache directory: %s", cache_dir_path)
46+
return cache_dir_path
47+
if "DIRECT_CACHE_DIR" in os.environ:
48+
env_path = pathlib.Path(os.environ["DIRECT_CACHE_DIR"])
49+
if os.access(str(env_path), os.W_OK):
50+
logger.info("Using cache directory: %s", env_path)
51+
return env_path
52+
try:
53+
tmpdir = os.environ.get("TMPDIR", tempfile.gettempdir())
54+
cache_dir = pathlib.Path(tmpdir) / "direct_cache"
55+
cache_dir.mkdir(parents=True, exist_ok=True)
56+
if os.access(str(cache_dir), os.W_OK):
57+
logger.info("Using cache directory: %s", cache_dir)
58+
return cache_dir
59+
except Exception as e:
60+
logger.warning("Failed to create or access cache directory in TMPDIR: %s", e)
61+
62+
# Fallback to a default tmp directory
63+
fallback = pathlib.Path(tempfile.gettempdir()) / "direct_cache"
64+
fallback.mkdir(parents=True, exist_ok=True)
65+
logger.warning("Falling back to cache directory: %s", fallback)
66+
return fallback
67+
68+
69+
DIRECT_CACHE_DIR = resolve_cache_dir()
70+
DIRECT_MODEL_DOWNLOAD_DIR = DIRECT_CACHE_DIR / "downloaded_models"
71+
72+
73+
def collect_env_info() -> str:
74+
"""Collects environment information.
75+
76+
Returns
77+
-------
78+
env_info: str
79+
Environment information as a formatted string.
80+
"""
81+
SystemEnv = namedtuple(
82+
"SystemEnv",
83+
[
84+
"torch_version",
85+
"is_debug_build",
86+
"cuda_compiled_version",
87+
"python_version",
88+
"python_platform",
89+
"os",
90+
"libc_version",
91+
"is_cuda_available",
92+
"cuda_runtime_version",
93+
"cudnn_version",
94+
"pip_packages",
95+
"cpu_info",
96+
],
97+
)
98+
99+
def safe_version(pkg):
100+
try:
101+
return importlib.metadata.version(pkg)
102+
except importlib.metadata.PackageNotFoundError:
103+
return "Not installed"
104+
105+
def get_cudnn_version():
106+
try:
107+
return str(torch.backends.cudnn.version()) if torch.backends.cudnn.is_available() else "Unavailable"
108+
except Exception:
109+
return "Unknown"
110+
111+
def get_cpu_info():
112+
try:
113+
return platform.processor() or platform.machine()
114+
except Exception:
115+
return "Unknown"
116+
117+
pip_packages = {pkg: safe_version(pkg) for pkg in ["torch", "numpy", "triton", "optree", "mypy", "flake8", "onnx"]}
118+
pip_str = "\n " + "\n ".join(f"{pkg}=={ver}" for pkg, ver in pip_packages.items())
119+
120+
def pretty_print(env):
121+
lines = [
122+
f"PyTorch version: {env.torch_version}",
123+
f"Is debug build: {env.is_debug_build}",
124+
f"CUDA used to build PyTorch: {env.cuda_compiled_version}",
125+
f"Python version: {env.python_version}",
126+
f"Python platform: {env.python_platform}",
127+
f"OS: {env.os}",
128+
f"Libc version: {env.libc_version}",
129+
f"Is CUDA available: {env.is_cuda_available}",
130+
f"CUDA runtime version: {env.cuda_runtime_version}",
131+
f"cuDNN version: {env.cudnn_version}",
132+
f"CPU info: {env.cpu_info}",
133+
f"Relevant pip packages: {env.pip_packages}",
134+
]
135+
return "\n" + "\n".join(lines)
136+
137+
return pretty_print(
138+
SystemEnv(
139+
torch_version=torch.__version__,
140+
is_debug_build=str(getattr(torch.version, "debug", "Unknown")),
141+
cuda_compiled_version=getattr(torch.version, "cuda", "None"),
142+
python_version=sys.version.replace("\n", " "),
143+
python_platform=platform.platform(),
144+
os=platform.platform(),
145+
libc_version="-".join(platform.libc_ver()) if sys.platform.startswith("linux") else "N/A",
146+
is_cuda_available=str(torch.cuda.is_available()),
147+
cuda_runtime_version=getattr(torch.version, "cuda", "No CUDA"),
148+
cudnn_version=get_cudnn_version(),
149+
pip_packages=pip_str,
150+
cpu_info=get_cpu_info(),
151+
)
152+
)
40153

41154

42155
def load_model_config_from_name(model_name: str) -> Callable:
@@ -152,7 +265,7 @@ def setup_logging(
152265
logger.info("Run name: %s", run_name)
153266
logger.info("Config file: %s", cfg_filename)
154267
logger.info("CUDA %s - cuDNN %s", torch.version.cuda, torch.backends.cudnn.version())
155-
logger.info("Environment information: %s", collect_env.get_pretty_env_info())
268+
logger.info("Environment information: %s", collect_env_info())
156269
logger.info("DIRECT version: %s", direct.__version__)
157270
git_hash = direct.utils.git_hash()
158271
logger.info("Git hash: %s", git_hash if git_hash else "N/A")

projects/BUILD.bazel

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
filegroup(
2+
name = "all_configs",
3+
srcs = glob(["**/*.yaml"]),
4+
visibility = ["//visibility:public"],
5+
)

requirements_darwin.txt

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -778,9 +778,7 @@ scipy==1.15.2 \
778778
setuptools==79.0.0 \
779779
--hash=sha256:9828422e7541213b0aacb6e10bbf9dd8febeaa45a48570e09b6d100e063fc9f9 \
780780
--hash=sha256:b9ab3a104bedb292323f53797b00864e10e434a3ab3906813a7169e4745b912a
781-
# via
782-
# tensorboard
783-
# torch
781+
# via tensorboard
784782
six==1.17.0 \
785783
--hash=sha256:4721f391ed90541fddacab5acf947aa0d3dc7d27b2e1e8eda2be8970586c3274 \
786784
--hash=sha256:ff70335d468e7eb6ec65b95b99d3a2836546063f63acc5171de367e834932a81

requirements_linux.txt

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -848,9 +848,7 @@ scipy==1.15.2 \
848848
setuptools==79.0.0 \
849849
--hash=sha256:9828422e7541213b0aacb6e10bbf9dd8febeaa45a48570e09b6d100e063fc9f9 \
850850
--hash=sha256:b9ab3a104bedb292323f53797b00864e10e434a3ab3906813a7169e4745b912a
851-
# via
852-
# tensorboard
853-
# torch
851+
# via tensorboard
854852
six==1.17.0 \
855853
--hash=sha256:4721f391ed90541fddacab5acf947aa0d3dc7d27b2e1e8eda2be8970586c3274 \
856854
--hash=sha256:ff70335d468e7eb6ec65b95b99d3a2836546063f63acc5171de367e834932a81

tests/BUILD.bazel

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
load("@aspect_rules_py//py:defs.bzl", "py_test")
2-
load("@pip//:requirements.bzl", "requirement")
2+
load("@direct_pip//:requirements.bzl", "requirement")
33

44
REQUIREMENTS = [
55
"//direct:direct_lib",

tests/test_cli/BUILD.bazel

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
load("@aspect_rules_py//py:defs.bzl", "py_test")
2-
load("@pip//:requirements.bzl", "requirement")
2+
load("@direct_pip//:requirements.bzl", "requirement")
33

44
REQUIREMENTS = [
55
"//direct:direct_lib",

tests/tests_common/BUILD.bazel

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
load("@aspect_rules_py//py:defs.bzl", "py_test")
2-
load("@pip//:requirements.bzl", "requirement")
2+
load("@direct_pip//:requirements.bzl", "requirement")
33

44
REQUIREMENTS = [
55
"//direct:direct_lib",

tests/tests_data/BUILD.bazel

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
load("@aspect_rules_py//py:defs.bzl", "py_test")
2-
load("@pip//:requirements.bzl", "requirement")
2+
load("@direct_pip//:requirements.bzl", "requirement")
33

44
REQUIREMENTS = [
55
"//direct:direct_lib",

0 commit comments

Comments
 (0)