+import contextlib
 import json
 from dataclasses import dataclass, field
 from pathlib import Path
 from typing import Dict, Generic, List, Optional, TypeVar
 
+import filelock
 import torch
 import transformers
+from transformers.utils import HF_MODULES_CACHE
 
 from tensorrt_llm import logger
 from tensorrt_llm._torch.pyexecutor.config_utils import is_nemotron_hybrid
@@ -57,6 +60,35 @@ def get_layer_initial_global_assignments(self, layer_idx: int) -> List[int]:
         return None
 
 
+@contextlib.contextmanager
+def config_file_lock(timeout: int = 10):
+    """
+    Context manager for file locking when loading pretrained configs.
+
+    This prevents race conditions when multiple processes try to download/load
+    the same model configuration simultaneously.
+
+    Args:
+        timeout: Maximum time to wait for lock acquisition in seconds
+    """
+    # Use a single global lock file in HF cache directory
+    # This serializes all model loading operations to prevent race conditions
+    lock_path = Path(HF_MODULES_CACHE) / "_remote_code.lock"
+
+    # Create and acquire the lock
+    lock = filelock.FileLock(str(lock_path), timeout=timeout)
+
+    try:
+        with lock:
+            yield
+    except filelock.Timeout:
+        logger.warning(
+            f"Failed to acquire config lock within {timeout} seconds, proceeding without lock"
+        )
+        # Fallback: proceed without locking to avoid blocking indefinitely
+        yield
+
+
 @dataclass(kw_only=True)
 class ModelConfig(Generic[TConfig]):
     pretrained_config: Optional[TConfig] = None
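For reference, a minimal standalone sketch (not part of the diff) of the filelock pattern the new helper builds on: FileLock blocks until the lock file is free and raises filelock.Timeout once the timeout elapses. The lock path and timeout below are illustrative placeholders rather than values from this change.

import tempfile
from pathlib import Path

import filelock

# Illustrative lock path; the diff itself uses HF_MODULES_CACHE / "_remote_code.lock".
lock_path = Path(tempfile.gettempdir()) / "example_config.lock"
lock = filelock.FileLock(str(lock_path), timeout=5)

try:
    with lock:
        # Critical section: only one process on this host runs this at a time.
        print("lock held, safe to touch the shared cache")
except filelock.Timeout:
    # Same fallback idea as config_file_lock: warn and continue without the lock.
    print("could not acquire the lock within 5 seconds")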
@@ -182,16 +214,20 @@ def from_pretrained(cls,
                         checkpoint_dir: str,
                         trust_remote_code=False,
                         **kwargs):
-        pretrained_config = transformers.AutoConfig.from_pretrained(
-            checkpoint_dir,
-            trust_remote_code=trust_remote_code,
-        )
+        # Use file lock to prevent race conditions when multiple processes
+        # try to import/cache the same remote model config file
+        with config_file_lock():
+            pretrained_config = transformers.AutoConfig.from_pretrained(
+                checkpoint_dir,
+                trust_remote_code=trust_remote_code,
+            )
+
+            # Find the cache path by looking for the config.json file which should be in all
+            # huggingface models
+            model_dir = Path(
+                transformers.utils.hub.cached_file(checkpoint_dir,
+                                                   'config.json')).parent
 
-        # Find the cache path by looking for the config.json file which should be in all
-        # huggingface models
-        model_dir = Path(
-            transformers.utils.hub.cached_file(checkpoint_dir,
-                                               'config.json')).parent
         quant_config = QuantConfig()
         layer_quant_config = None
         # quantized ckpt in modelopt format
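To illustrate the scenario this patch addresses, here is a hedged sketch of several worker processes resolving the same config concurrently, serialized through the same locking pattern. Because the diff does not show the module path where config_file_lock lives, the context manager is re-declared locally and the lock path is a placeholder; only the filelock and multiprocessing usage is assumed.

import contextlib
import multiprocessing as mp
import tempfile
import time
from pathlib import Path

import filelock


@contextlib.contextmanager
def config_file_lock(timeout: int = 10):
    # Local stand-in for the helper added in this diff; placeholder lock path.
    lock_path = Path(tempfile.gettempdir()) / "_remote_code.lock"
    lock = filelock.FileLock(str(lock_path), timeout=timeout)
    try:
        with lock:
            yield
    except filelock.Timeout:
        # Mirror the diff's fallback: proceed unlocked instead of blocking forever.
        yield


def load_config(rank: int) -> None:
    with config_file_lock():
        # Stand-in for AutoConfig.from_pretrained(..., trust_remote_code=True),
        # which can write dynamic modules into a shared cache directory.
        print(f"rank {rank} holds the lock")
        time.sleep(0.2)


if __name__ == "__main__":
    procs = [mp.Process(target=load_config, args=(r,)) for r in range(4)]
    for p in procs:
        p.start()
    for p in procs:
        p.join()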