Skip to content

Commit 2ea1d3c

Browse files
committed
using filelock
1 parent 0af1296 commit 2ea1d3c

6 files changed

Lines changed: 44 additions & 474 deletions

File tree

.github/workflows/build-test-linux-x86_64.yml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -611,9 +611,11 @@ jobs:
611611
python -m pytest -ra -v --junitxml=${RUNNER_TEST_RESULTS_DIR}/l2_dynamo_distributed_test_results.xml \
612612
distributed/test_nccl_ops.py \
613613
distributed/test_native_nccl.py \
614-
distributed/test_export_save_load.py
614+
distributed/test_export_save_load.py \
615+
distributed/test_distributed_engine_cache.py
615616
python -m torch_tensorrt.distributed.run --nproc_per_node=2 distributed/test_native_nccl.py --multirank
616617
python -m torch_tensorrt.distributed.run --nproc_per_node=2 distributed/test_export_save_load.py --multirank
618+
python -m torch_tensorrt.distributed.run --nproc_per_node=2 distributed/test_distributed_engine_cache.py --multirank
617619
popd
618620
619621
concurrency:

py/torch_tensorrt/distributed/__init__.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,7 @@
22
distributed_context,
33
is_distributed_caching_enabled,
44
set_distributed_mode,
5-
signal_distributed_engine_build_complete,
6-
wait_for_distributed_engine_build,
75
)
8-
from torch_tensorrt.distributed._lock import DistributedFileLock # noqa: F401
96
from torch_tensorrt.distributed._nccl_utils import ( # noqa: F401
107
setup_nccl_for_torch_tensorrt,
118
)

py/torch_tensorrt/distributed/_distributed.py

Lines changed: 0 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -230,71 +230,3 @@ def is_distributed_caching_enabled(
230230
and dist.is_initialized()
231231
and dist.get_world_size() > 1
232232
)
233-
234-
235-
def wait_for_distributed_engine_build(
236-
pull_fn: Any,
237-
cache_dir: str,
238-
hash_val: str,
239-
poll_interval: float = 0.5,
240-
timeout: float = 600.0,
241-
) -> Any:
242-
"""Non-building rank: poll for cached engine file, then load from cache.
243-
244-
Called when this rank failed to acquire the build lock, meaning another
245-
rank is building the engine. Polls the filesystem for the cached engine
246-
file instead of using NCCL collectives (which are unreliable inside
247-
the TRT compilation path due to aot_autograd/CUDA stream conflicts).
248-
249-
Args:
250-
pull_fn: Zero-arg callable (e.g. functools.partial) that loads the
251-
engine from cache. Returns SerializedInterpreterResult on
252-
hit, None on miss.
253-
cache_dir: Shared engine cache directory path.
254-
hash_val: Engine hash for this compilation.
255-
poll_interval: Seconds between filesystem checks (default 0.5s).
256-
timeout: Maximum seconds to wait before giving up (default 600s).
257-
258-
Returns:
259-
SerializedInterpreterResult on cache hit, None on timeout.
260-
"""
261-
import logging
262-
import os
263-
import time
264-
265-
logger = logging.getLogger(__name__)
266-
267-
blob_path = os.path.join(cache_dir, hash_val, "blob.bin")
268-
logger.info(f"Polling for cached engine: {blob_path}")
269-
270-
elapsed = 0.0
271-
while not os.path.exists(blob_path):
272-
time.sleep(poll_interval)
273-
elapsed += poll_interval
274-
if elapsed >= timeout:
275-
logger.warning(
276-
f"Polling timed out after {timeout:.0f}s — building engine locally"
277-
)
278-
return None
279-
280-
logger.info(f"Cached engine found after {elapsed:.1f}s — loading from cache")
281-
cached = pull_fn()
282-
if cached is not None:
283-
return cached
284-
285-
logger.warning("Cache file exists but pull_cached_engine failed — building locally")
286-
return None
287-
288-
289-
def signal_distributed_engine_build_complete(lock: Any) -> None:
290-
"""Building rank: release the file lock after caching the engine.
291-
292-
Called after the building rank has inserted the engine into the shared
293-
cache. Releases the file lock so other ranks' stale lock detection
294-
works correctly. No NCCL collective needed — waiter ranks poll the
295-
filesystem directly.
296-
297-
Args:
298-
lock: DistributedFileLock instance that was acquired by this rank.
299-
"""
300-
lock.release()

py/torch_tensorrt/distributed/_lock.py

Lines changed: 0 additions & 189 deletions
This file was deleted.

py/torch_tensorrt/dynamo/conversion/_conversion.py

Lines changed: 26 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -2,20 +2,14 @@
22

33
import io
44
import logging
5-
from functools import partial
65
from typing import Any, Dict, List, NamedTuple, Optional, Sequence
76

87
import tensorrt as trt
98
import torch
109
from torch_tensorrt._enums import dtype
1110
from torch_tensorrt._features import ENABLED_FEATURES
1211
from torch_tensorrt._Input import Input
13-
from torch_tensorrt.distributed._distributed import (
14-
is_distributed_caching_enabled,
15-
signal_distributed_engine_build_complete,
16-
wait_for_distributed_engine_build,
17-
)
18-
from torch_tensorrt.distributed._lock import DistributedFileLock
12+
from torch_tensorrt.distributed._distributed import is_distributed_caching_enabled
1913
from torch_tensorrt.dynamo._engine_cache import BaseEngineCache
2014
from torch_tensorrt.dynamo._settings import CompilationSettings, settings_are_compatible
2115
from torch_tensorrt.dynamo.conversion._symbolic_shape_capture import (
@@ -275,33 +269,33 @@ def interpret_module_to_result(
275269
settings.cache_built_engines,
276270
settings.reuse_cached_engines,
277271
)
278-
_build_lock = None
272+
_lock: Optional[Any] = None
279273

280274
if _distributed_caching:
275+
import os as _os
276+
277+
from filelock import FileLock
278+
281279
# is_distributed_caching_enabled guarantees engine_cache and hash_val are set.
282280
assert engine_cache is not None
283281
assert hash_val is not None
284-
_build_lock = DistributedFileLock(engine_cache.engine_cache_dir, hash_val)
285-
if _build_lock.acquire():
286-
logger.info("Acquired engine build lock — this rank builds")
287-
else:
288-
logger.info("Lock held by another rank — polling for cached engine")
289-
_pull_fn = partial(
290-
pull_cached_engine,
291-
hash_val,
292-
module,
293-
engine_cache,
294-
settings,
295-
inputs,
296-
symbolic_shape_expressions,
297-
)
298-
cached: Optional[SerializedInterpreterResult] = (
299-
wait_for_distributed_engine_build(
300-
_pull_fn, engine_cache.engine_cache_dir, hash_val
301-
)
302-
)
303-
if cached is not None:
304-
return cached
282+
283+
_lock_path = _os.path.join(engine_cache.engine_cache_dir, f".{hash_val}.lock")
284+
_lock = FileLock(_lock_path, timeout=600)
285+
_lock.acquire()
286+
287+
# Check cache again — another rank may have built while we waited
288+
cached = pull_cached_engine(
289+
hash_val,
290+
module,
291+
engine_cache,
292+
settings,
293+
inputs,
294+
symbolic_shape_expressions,
295+
)
296+
if cached is not None:
297+
_lock.release()
298+
return cached
305299

306300
output_dtypes = infer_module_output_dtypes(
307301
module, truncate_double=settings.truncate_double
@@ -348,9 +342,9 @@ def interpret_module_to_result(
348342
hash_val, interpreter_result, engine_cache, settings, inputs
349343
)
350344

351-
# Signal other ranks that the engine is cached and ready
352-
if _build_lock is not None and _build_lock.acquired:
353-
signal_distributed_engine_build_complete(_build_lock)
345+
# Release the filelock so other ranks can proceed
346+
if _distributed_caching and _lock is not None:
347+
_lock.release()
354348

355349
serialized_engine = interpreter_result.engine.serialize()
356350
with io.BytesIO() as engine_bytes:

0 commit comments

Comments (0)