Skip to content

Commit 300d42a

Browse files
ArthurZucker, NouamaneTazi, drbh
authored
Add ep (#39501)
* EP + updates Co-authored-by: Nouamane Tazi <[email protected]> Co-authored-by: drbh <[email protected]> * remove unrelated change * not working yet but let's see where it goes! * update the api a bit * udpate * where I am at for now * fix ep * refactor the API * yups * fix * fixup * clean modeling * just support llama4 for now! * properly avoid * fix * nits * Update src/transformers/models/llama4/modeling_llama4.py * Update src/transformers/integrations/tensor_parallel.py * style * ,,,, * update --------- Co-authored-by: Nouamane Tazi <[email protected]> Co-authored-by: drbh <[email protected]>
1 parent abaa043 commit 300d42a

File tree

9 files changed

+436
-186
lines changed

9 files changed

+436
-186
lines changed
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
# Copyright 2025 The HuggingFace Team. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from typing import TYPE_CHECKING

from ..utils import _LazyModule


# Public API of this package: maps submodule name -> list of symbols it
# exports. Consumed by `_LazyModule` below so that importing the package
# stays cheap until a symbol is actually accessed.
_import_structure = {
    "configuration_utils": ["DistributedConfig"],
}


if TYPE_CHECKING:
    # Real imports only for static type checkers / IDEs; at runtime the
    # lazy-module machinery below resolves these on first attribute access.
    from .configuration_utils import (
        DistributedConfig,
    )

else:
    import sys

    # Replace this module object in sys.modules with a lazy proxy that
    # imports submodules on demand (standard Hugging Face `__init__` pattern).
    sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
# Copyright 2025 The HuggingFace Team. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import copy
16+
import json
17+
import os
18+
from dataclasses import dataclass
19+
from typing import Any, Union
20+
21+
22+
@dataclass
class DistributedConfig:
    """
    Base class for distributed configs.

    Attributes:
        enable_expert_parallel (`bool`, *optional*, defaults to `False`):
            Whether to enable expert parallelism. Only the flag is stored here;
            consumers of the config interpret it — presumably to shard MoE
            experts across ranks (TODO confirm against the integration code).
    """

    enable_expert_parallel: bool = False
    # TODO: add tp_plan, pp_plan, device_mesh etc..

    @classmethod
    def from_dict(cls, config_dict, **kwargs):
        """
        Constructs a DistributedConfig instance from a dictionary of parameters.

        Args:
            config_dict (`Dict[str, Any]`): Dictionary containing configuration parameters.
            **kwargs: Additional keyword arguments to override dictionary values.

        Returns:
            `DistributedConfig`: Instance of DistributedConfig constructed from the dictionary.
        """
        config = cls(**config_dict)
        # kwargs matching an existing attribute override the dict values;
        # unknown kwargs are silently ignored.
        for key, value in kwargs.items():
            if hasattr(config, key):
                setattr(config, key, value)
        return config

    def to_json_file(self, json_file_path: Union[str, os.PathLike]):
        """
        Save this instance to a JSON file.

        Args:
            json_file_path (`str` or `os.PathLike`):
                Path to the JSON file in which this configuration instance's parameters will be saved.
        """
        with open(json_file_path, "w", encoding="utf-8") as writer:
            config_dict = self.to_dict()
            json_string = json.dumps(config_dict, indent=2, sort_keys=True) + "\n"

            writer.write(json_string)

    def to_dict(self) -> dict[str, Any]:
        """
        Serializes this instance to a Python dictionary.

        Returns:
            `Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance.
        """
        # deepcopy so callers cannot mutate the config through the returned dict
        return copy.deepcopy(self.__dict__)

    # Copied from transformers.utils.quantization_config.QuantizationConfigMixin.__iter__
    def __iter__(self):
        """allows `dict(obj)` for situations where obj may be a dict or QuantizationConfigMixin"""
        for attr, value in copy.deepcopy(self.__dict__).items():
            yield attr, value

    # Copied from transformers.utils.quantization_config.QuantizationConfigMixin.__repr__
    def __repr__(self):
        return f"{self.__class__.__name__} {self.to_json_string()}"

    def to_json_string(self):
        """
        Serializes this instance to a JSON formatted string.

        Returns:
            str: JSON formatted string representing the configuration instance.
        """
        return json.dumps(self.__dict__, indent=2) + "\n"

    def update(self, **kwargs):
        """
        Updates attributes of this class instance with attributes from `kwargs` if they match existing attributes,
        returning all the unused kwargs.

        Args:
            kwargs (`Dict[str, Any]`):
                Dictionary of attributes to tentatively update this class.

        Returns:
            `Dict[str, Any]`: Dictionary containing all the key-value pairs that were not used to update the instance.
        """
        to_remove = []
        for key, value in kwargs.items():
            if hasattr(self, key):
                setattr(self, key, value)
                to_remove.append(key)

        # Remove all the attributes that were updated, without modifying the input dict
        unused_kwargs = {key: value for key, value in kwargs.items() if key not in to_remove}
        return unused_kwargs

src/transformers/integrations/hub_kernels.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,12 @@
5252
layer_name="TritonLlamaMLP",
5353
)
5454
},
55+
"MegaBlocksMoeMLP": {
56+
"cuda": LayerRepository(
57+
repo_id="kernels-community/megablocks",
58+
layer_name="MegaBlocksMoeMLP",
59+
)
60+
},
5561
}
5662

5763
register_kernel_mapping(_KERNEL_MAPPING)

0 commit comments

Comments
 (0)