|
16 | 16 | import os |
17 | 17 | import pathlib |
18 | 18 | import sys |
| 19 | +import tempfile |
19 | 20 | from collections import namedtuple |
20 | 21 | from typing import Callable, Dict, Optional, Tuple, Union |
21 | 22 |
|
|
29 | 30 | from direct.utils.io import check_is_valid_url, read_text_from_url |
30 | 31 | from direct.utils.logging import setup |
31 | 32 |
|
| 33 | +import platform |
| 34 | +import importlib.metadata |
| 35 | +from collections import namedtuple |
| 36 | + |
32 | 37 | logger = logging.getLogger(__name__) |
33 | 38 |
|
34 | 39 | # Environmental variables |
35 | 40 | DIRECT_ROOT_DIR = pathlib.Path(pathlib.Path(__file__).resolve().parent.parent) |
36 | | -DIRECT_CACHE_DIR = pathlib.Path(os.environ.get("DIRECT_CACHE_DIR", str(DIRECT_ROOT_DIR))) |
37 | | -DIRECT_MODEL_DOWNLOAD_DIR = ( |
38 | | - pathlib.Path(os.environ.get("DIRECT_MODEL_DOWNLOAD_DIR", str(DIRECT_ROOT_DIR))) / "downloaded_models" |
39 | | -) |
| 41 | + |
| 42 | + |
def resolve_cache_dir() -> pathlib.Path:
    """Resolve a writable cache directory for DIRECT.

    Resolution order:
      1. ``$DIRECT_CACHE_DIR`` if set, otherwise the package root — used when writable.
      2. ``<system temp dir>/direct_cache``, created on demand.
      3. ``/tmp/direct_cache`` as a last resort.

    Returns
    -------
    pathlib.Path
        Path to a directory expected to be writable for caching.
    """
    # 1. Explicit env override, falling back to the package root. When the env
    # var is set this IS the env path, so no separate env-var re-check is needed.
    candidate = pathlib.Path(os.environ.get("DIRECT_CACHE_DIR", str(DIRECT_ROOT_DIR)))
    if os.access(str(candidate), os.W_OK):
        logger.info("Using cache directory: %s", candidate)
        return candidate

    # 2. System temp directory. tempfile.gettempdir() already consults TMPDIR,
    # so an explicit os.environ lookup would be redundant.
    try:
        tmp_cache = pathlib.Path(tempfile.gettempdir()) / "direct_cache"
        tmp_cache.mkdir(parents=True, exist_ok=True)
        if os.access(str(tmp_cache), os.W_OK):
            logger.info("Using cache directory: %s", tmp_cache)
            return tmp_cache
    except OSError:
        # mkdir can fail (read-only fs, permissions); fall through to the fixed path.
        pass

    # 3. Last resort. If even this fails the OSError propagates to the caller
    # rather than being silently swallowed — there is no cache dir to return.
    fallback = pathlib.Path("/tmp/direct_cache")
    fallback.mkdir(parents=True, exist_ok=True)
    logger.warning("Falling back to cache directory: %s", fallback)
    return fallback
| 69 | + |
| 70 | + |
# Resolved once at import time so every downstream consumer shares one
# writable cache root (env-overridable via DIRECT_CACHE_DIR).
DIRECT_CACHE_DIR = resolve_cache_dir()
# Downloaded model weights are kept in a subdirectory of the cache root.
DIRECT_MODEL_DOWNLOAD_DIR = DIRECT_CACHE_DIR / "downloaded_models"
| 73 | + |
| 74 | + |
def collect_env_info() -> str:
    """Collect environment information.

    Returns
    -------
    env_info: str
        Environment information as a formatted string.
    """

    def _pkg_version(name):
        # Installed distribution version, or a placeholder when absent.
        try:
            return importlib.metadata.version(name)
        except importlib.metadata.PackageNotFoundError:
            return "Not installed"

    def _cudnn_version():
        try:
            if torch.backends.cudnn.is_available():
                return str(torch.backends.cudnn.version())
            return "Unavailable"
        except Exception:
            return "Unknown"

    def _cpu_info():
        try:
            return platform.processor() or platform.machine()
        except Exception:
            return "Unknown"

    tracked = ["torch", "numpy", "triton", "optree", "mypy", "flake8", "onnx"]
    pip_block = "\n " + "\n ".join(f"{name}=={_pkg_version(name)}" for name in tracked)

    # Ordered (label, value) pairs; order and labels define the report format.
    report = [
        ("PyTorch version", torch.__version__),
        ("Is debug build", str(getattr(torch.version, "debug", "Unknown"))),
        ("CUDA used to build PyTorch", getattr(torch.version, "cuda", "None")),
        ("Python version", sys.version.replace("\n", " ")),
        ("Python platform", platform.platform()),
        ("OS", platform.platform()),
        ("Libc version", "-".join(platform.libc_ver()) if sys.platform.startswith("linux") else "N/A"),
        ("Is CUDA available", str(torch.cuda.is_available())),
        ("CUDA runtime version", getattr(torch.version, "cuda", "No CUDA")),
        ("cuDNN version", _cudnn_version()),
        ("CPU info", _cpu_info()),
        ("Relevant pip packages", pip_block),
    ]
    return "\n" + "\n".join(f"{label}: {value}" for label, value in report)
40 | 155 |
|
41 | 156 |
|
42 | 157 | def load_model_config_from_name(model_name: str) -> Callable: |
@@ -152,7 +267,7 @@ def setup_logging( |
152 | 267 | logger.info("Run name: %s", run_name) |
153 | 268 | logger.info("Config file: %s", cfg_filename) |
154 | 269 | logger.info("CUDA %s - cuDNN %s", torch.version.cuda, torch.backends.cudnn.version()) |
155 | | - logger.info("Environment information: %s", collect_env.get_pretty_env_info()) |
| 270 | + logger.info("Environment information: %s", collect_env_info()) |
156 | 271 | logger.info("DIRECT version: %s", direct.__version__) |
157 | 272 | git_hash = direct.utils.git_hash() |
158 | 273 | logger.info("Git hash: %s", git_hash if git_hash else "N/A") |
|
0 commit comments