Skip to content

Commit abbed70

Browse files
Fix dotted model names (#40745)
* Fix module loading for models with dots in names * quality check * added test * wrong import * Trigger CI rerun after making test model public * Update src/transformers/dynamic_module_utils.py * Update tests/utils/test_dynamic_module_utils.py * Update tests/utils/test_dynamic_module_utils.py * Move test * make fixup --------- Co-authored-by: Matt <[email protected]> Co-authored-by: Matt <[email protected]>
1 parent 75202b0 commit abbed70

File tree

2 files changed

+25
-3
lines changed

2 files changed

+25
-3
lines changed

src/transformers/dynamic_module_utils.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,16 @@
4545

4646

4747
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
48+
49+
50+
def _sanitize_module_name(name: str) -> str:
51+
"""
52+
Replace `.` in module names with `_dot_` so that it doesn't
53+
look like an import path separator.
54+
"""
55+
return name.replace(".", "_dot_")
56+
57+
4858
_HF_REMOTE_CODE_LOCK = threading.Lock()
4959

5060

@@ -358,9 +368,9 @@ def get_cached_module_file(
358368
pretrained_model_name_or_path = str(pretrained_model_name_or_path)
359369
is_local = os.path.isdir(pretrained_model_name_or_path)
360370
if is_local:
361-
submodule = os.path.basename(pretrained_model_name_or_path)
371+
submodule = _sanitize_module_name(os.path.basename(pretrained_model_name_or_path))
362372
else:
363-
submodule = pretrained_model_name_or_path.replace("/", os.path.sep)
373+
submodule = _sanitize_module_name(pretrained_model_name_or_path.replace("/", os.path.sep))
364374
cached_module = try_to_load_from_cache(
365375
pretrained_model_name_or_path, module_file, cache_dir=cache_dir, revision=_commit_hash, repo_type=repo_type
366376
)
@@ -395,7 +405,7 @@ def get_cached_module_file(
395405
full_submodule = TRANSFORMERS_DYNAMIC_MODULE_NAME + os.path.sep + submodule
396406
create_dynamic_module(full_submodule)
397407
submodule_path = Path(HF_MODULES_CACHE) / full_submodule
398-
if submodule == os.path.basename(pretrained_model_name_or_path):
408+
if submodule == _sanitize_module_name(os.path.basename(pretrained_model_name_or_path)):
399409
# We copy local files to avoid putting too many folders in sys.path. This copy is done when the file is new or
400410
# has changed since last copy.
401411
if not (submodule_path / module_file).exists() or not filecmp.cmp(

tests/models/auto/test_modeling_auto.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -594,3 +594,15 @@ def test_custom_model_patched_generation_inheritance(self):
594594
# More precisely, it directly inherits from GenerationMixin. This check would fail prior to v4.45 (inheritance
595595
# patching was added in v4.45)
596596
self.assertTrue("GenerationMixin" in str(model.__class__.__bases__))
597+
598+
def test_model_with_dotted_name_and_relative_imports(self):
599+
"""
600+
Test for issue #40496: AutoModel.from_pretrained() doesn't work for models with '.' in their name
601+
when there's a relative import.
602+
603+
Without the fix, this raises: ModuleNotFoundError: No module named 'transformers_modules.test-model_v1'
604+
"""
605+
model_id = "hf-internal-testing/remote_code_model_with_dots"
606+
607+
model = AutoModel.from_pretrained(model_id, trust_remote_code=True)
608+
self.assertIsNotNone(model)

0 commit comments

Comments
 (0)