|
2 | 2 |
|
3 | 3 | import io |
4 | 4 | import logging |
5 | | -from functools import partial |
6 | 5 | from typing import Any, Dict, List, NamedTuple, Optional, Sequence |
7 | 6 |
|
8 | 7 | import tensorrt as trt |
9 | 8 | import torch |
10 | 9 | from torch_tensorrt._enums import dtype |
11 | 10 | from torch_tensorrt._features import ENABLED_FEATURES |
12 | 11 | from torch_tensorrt._Input import Input |
13 | | -from torch_tensorrt.distributed._distributed import ( |
14 | | - is_distributed_caching_enabled, |
15 | | - signal_distributed_engine_build_complete, |
16 | | - wait_for_distributed_engine_build, |
17 | | -) |
18 | | -from torch_tensorrt.distributed._lock import DistributedFileLock |
| 12 | +from torch_tensorrt.distributed._distributed import is_distributed_caching_enabled |
19 | 13 | from torch_tensorrt.dynamo._engine_cache import BaseEngineCache |
20 | 14 | from torch_tensorrt.dynamo._settings import CompilationSettings, settings_are_compatible |
21 | 15 | from torch_tensorrt.dynamo.conversion._symbolic_shape_capture import ( |
@@ -275,33 +269,33 @@ def interpret_module_to_result( |
275 | 269 | settings.cache_built_engines, |
276 | 270 | settings.reuse_cached_engines, |
277 | 271 | ) |
278 | | - _build_lock = None |
| 272 | + _lock: Optional[Any] = None |
279 | 273 |
|
280 | 274 | if _distributed_caching: |
| 275 | + import os as _os |
| 276 | + |
| 277 | + from filelock import FileLock |
| 278 | + |
281 | 279 | # is_distributed_caching_enabled guarantees engine_cache and hash_val are set. |
282 | 280 | assert engine_cache is not None |
283 | 281 | assert hash_val is not None |
284 | | - _build_lock = DistributedFileLock(engine_cache.engine_cache_dir, hash_val) |
285 | | - if _build_lock.acquire(): |
286 | | - logger.info("Acquired engine build lock — this rank builds") |
287 | | - else: |
288 | | - logger.info("Lock held by another rank — polling for cached engine") |
289 | | - _pull_fn = partial( |
290 | | - pull_cached_engine, |
291 | | - hash_val, |
292 | | - module, |
293 | | - engine_cache, |
294 | | - settings, |
295 | | - inputs, |
296 | | - symbolic_shape_expressions, |
297 | | - ) |
298 | | - cached: Optional[SerializedInterpreterResult] = ( |
299 | | - wait_for_distributed_engine_build( |
300 | | - _pull_fn, engine_cache.engine_cache_dir, hash_val |
301 | | - ) |
302 | | - ) |
303 | | - if cached is not None: |
304 | | - return cached |
| 282 | + |
| 283 | + _lock_path = _os.path.join(engine_cache.engine_cache_dir, f".{hash_val}.lock") |
| 284 | + _lock = FileLock(_lock_path, timeout=600) |
| 285 | + _lock.acquire() |
| 286 | + |
| 287 | + # Check cache again — another rank may have built while we waited |
| 288 | + cached = pull_cached_engine( |
| 289 | + hash_val, |
| 290 | + module, |
| 291 | + engine_cache, |
| 292 | + settings, |
| 293 | + inputs, |
| 294 | + symbolic_shape_expressions, |
| 295 | + ) |
| 296 | + if cached is not None: |
| 297 | + _lock.release() |
| 298 | + return cached |
305 | 299 |
|
306 | 300 | output_dtypes = infer_module_output_dtypes( |
307 | 301 | module, truncate_double=settings.truncate_double |
@@ -348,9 +342,9 @@ def interpret_module_to_result( |
348 | 342 | hash_val, interpreter_result, engine_cache, settings, inputs |
349 | 343 | ) |
350 | 344 |
|
351 | | - # Signal other ranks that the engine is cached and ready |
352 | | - if _build_lock is not None and _build_lock.acquired: |
353 | | - signal_distributed_engine_build_complete(_build_lock) |
| 345 | + # Release the filelock so other ranks can proceed |
| 346 | + if _distributed_caching and _lock is not None: |
| 347 | + _lock.release() |
354 | 348 |
|
355 | 349 | serialized_engine = interpreter_result.engine.serialize() |
356 | 350 | with io.BytesIO() as engine_bytes: |
|
0 commit comments