diff --git a/src/transformers/distributed/__init__.py b/src/transformers/distributed/__init__.py new file mode 100644 index 000000000000..ba6db8358d2b --- /dev/null +++ b/src/transformers/distributed/__init__.py @@ -0,0 +1,33 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import TYPE_CHECKING + +from ..utils import _LazyModule + + +_import_structure = { + "configuration_utils": ["DistributedConfig"], +} + + +if TYPE_CHECKING: + from .configuration_utils import ( + DistributedConfig, + ) + +else: + import sys + + sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__) diff --git a/src/transformers/distributed/configuration_utils.py b/src/transformers/distributed/configuration_utils.py new file mode 100644 index 000000000000..4b98c175e1b1 --- /dev/null +++ b/src/transformers/distributed/configuration_utils.py @@ -0,0 +1,111 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +import json +import os +from dataclasses import dataclass +from typing import Any, Union + + +@dataclass +class DistributedConfig: + """ + Base class for distributed configs + """ + + enable_expert_parallel: bool = False + # TODO: add tp_plan, pp_plan, device_mesh etc.. + + @classmethod + def from_dict(cls, config_dict, **kwargs): + """ + Constructs a DistributedConfig instance from a dictionary of parameters. + Args: + config_dict (Dict[str, Any]): Dictionary containing configuration parameters. + **kwargs: Additional keyword arguments to override dictionary values. + Returns: + DistributedConfig: Instance of DistributedConfig constructed from the dictionary. + """ + config = cls(**config_dict) + to_remove = [] + for key, value in kwargs.items(): + if hasattr(config, key): + setattr(config, key, value) + to_remove.append(key) + for key in to_remove: + kwargs.pop(key, None) + return config + + # Copied from transformers.utils.quantization_config.QuantizationConfigMixin.to_json_file + def to_json_file(self, json_file_path: Union[str, os.PathLike]): + """ + Save this instance to a JSON file. + Args: + json_file_path (`str` or `os.PathLike`): + Path to the JSON file in which this configuration instance's parameters will be saved. + use_diff (`bool`, *optional*, defaults to `True`): + If set to `True`, only the difference between the config instance and the default + `QuantizationConfig()` is serialized to JSON file. 
+ """ + with open(json_file_path, "w", encoding="utf-8") as writer: + config_dict = self.to_dict() + json_string = json.dumps(config_dict, indent=2, sort_keys=True) + "\n" + + writer.write(json_string) + + def to_dict(self) -> dict[str, Any]: + """ + Serializes this instance to a Python dictionary. Returns: + `Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance. + """ + return copy.deepcopy(self.__dict__) + + # Copied from transformers.utils.quantization_config.QuantizationConfigMixin.__iter__ + def __iter__(self): + """allows `dict(obj)` for situations where obj may be a dict or QuantizationConfigMixin""" + for attr, value in copy.deepcopy(self.__dict__).items(): + yield attr, value + + # Copied from transformers.utils.quantization_config.QuantizationConfigMixin.__repr__ + def __repr__(self): + return f"{self.__class__.__name__} {self.to_json_string()}" + + def to_json_string(self): + """ + Serializes this instance to a JSON formatted string. + Returns: + str: JSON formatted string representing the configuration instance. + """ + return json.dumps(self.__dict__, indent=2) + "\n" + + def update(self, **kwargs): + """ + Updates attributes of this class instance with attributes from `kwargs` if they match existing attributes, + returning all the unused kwargs. + Args: + kwargs (`Dict[str, Any]`): + Dictionary of attributes to tentatively update this class. + Returns: + `Dict[str, Any]`: Dictionary containing all the key-value pairs that were not used to update the instance. + """ + to_remove = [] + for key, value in kwargs.items(): + if hasattr(self, key): + setattr(self, key, value) + to_remove.append(key) + + # Remove all the attributes that were updated, without modifying the input dict + unused_kwargs = {key: value for key, value in kwargs.items() if key not in to_remove} + return unused_kwargs diff --git a/src/transformers/integrations/hub_kernels.py b/src/transformers/integrations/hub_kernels.py index 7aa6c48f4c50..e824a5ab1f0e 100644 --- a/src/transformers/integrations/hub_kernels.py +++ b/src/transformers/integrations/hub_kernels.py @@ -52,6 +52,12 @@ layer_name="TritonLlamaMLP", ) }, + "MegaBlocksMoeMLP": { + "cuda": LayerRepository( + repo_id="kernels-community/megablocks", + layer_name="MegaBlocksMoeMLP", + ) + }, } register_kernel_mapping(_KERNEL_MAPPING) diff --git a/src/transformers/integrations/tensor_parallel.py b/src/transformers/integrations/tensor_parallel.py index d765e0b68494..33fe4bbf6841 100644 --- a/src/transformers/integrations/tensor_parallel.py +++ b/src/transformers/integrations/tensor_parallel.py @@ -23,6 +23,7 @@ import torch.distributed as dist from torch import nn +from ..distributed import DistributedConfig from ..utils import is_torch_greater_or_equal, logging from ..utils.generic import GeneralInterface @@ -90,7 +91,7 @@ def initialize_tensor_parallelism(tp_plan, tp_size=None): device_map = tp_device tp_size = tp_size if tp_size is not None else torch.distributed.get_world_size() device_mesh = torch.distributed.init_device_mesh(tp_device.type, (tp_size,)) - return tp_device, device_map, device_mesh + return tp_device, device_map, device_mesh, tp_size def _blocks_to_block_sizes(total_size: int, blocks: int | list[int]) -> list[int]: @@ -119,20 +120,22 @@ def _blocks_to_block_sizes(total_size: int, blocks: int | list[int]) -> list[int return [single_size] * blocks -def _get_parameter_tp_plan(parameter_name: str, tp_plan: dict[str, str]) -> str | None: +def _get_parameter_tp_plan(parameter_name: str, tp_plan: dict[str, 
str], is_weight=True) -> str | None: """ Get the TP style for a parameter from the TP plan. The TP plan is a dictionary that maps parameter names to TP styles. The parameter name can be a generic name with wildcards (e.g. "*.weight") or a specific name (e.g. "layer_1.weight"). + + The `is_weight` is important because for weights, we want to support `.weights` and `.bias` cases seamlessly! but + not parrent classes for `post_init` calls """ generic_param_name = re.sub(r"\d+", "*", parameter_name) if generic_param_name in tp_plan: return tp_plan[generic_param_name] - elif "." in generic_param_name and generic_param_name.rsplit(".", 1)[0] in tp_plan: + elif "." in generic_param_name and generic_param_name.rsplit(".", 1)[0] in tp_plan and is_weight: return tp_plan[generic_param_name.rsplit(".", 1)[0]] - else: - return None + return None str_to_torch_dtype = { @@ -198,8 +201,10 @@ def get_packed_weights(param, empty_param, device_mesh, rank, dim): slice_dtype = slice_.get_dtype() # Handle F8_E4M3 dtype by converting to float16 before slicing # Without upcasting, the slicing causes : RuntimeError: "index_cpu" not implemented for 'Float8_e4m3fn' - if slice_dtype == "F8_E4M3": + casted = False + if slice_dtype == "F8_E4M3" or slice_dtype == "F8_E5M2": slice_ = slice_[...].to(torch.float16) + casted = True if dim == 0: tensor = slice_[tensors_slices, ...] @@ -209,7 +214,11 @@ def get_packed_weights(param, empty_param, device_mesh, rank, dim): tensor = slice_[..., tensors_slices] else: raise ValueError(f"Unsupported dim {dim}, only dim 0, 1 or 2 are supported") - return tensor.to(str_to_torch_dtype[slice_dtype]) + + if casted: + return tensor + else: + return tensor.to(str_to_torch_dtype[slice_dtype]) def repack_weights( @@ -423,16 +432,27 @@ def __init__( @staticmethod def _prepare_input_fn(input_layouts, desired_input_layouts, mod, inputs, device_mesh): + mod.expert_parallel_group = device_mesh.get_group() if inputs and isinstance(inputs[0], DTensor): inputs = inputs[0].to_local() return inputs @staticmethod def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh): - # this op cannot be async, otherwise it completely breaks the outputs of models - torch.distributed.all_reduce(outputs[0], op=torch.distributed.ReduceOp.SUM, async_op=False) + if isinstance(outputs, torch.Tensor): + dist.all_reduce(outputs, op=dist.ReduceOp.SUM, async_op=False) + else: + dist.all_reduce(outputs[0], op=dist.ReduceOp.SUM, async_op=False) return outputs + def prepare_module_tp(self, module: nn.Module, device_mesh) -> nn.Module: + distribute_module( + module, + device_mesh, + partial(self._prepare_input_fn, None, None), + partial(self._prepare_output_fn, None, None), + ) + class IsolatedParallel(TensorParallelLayer): """ @@ -453,6 +473,14 @@ def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_me # TODO: figure out dynamo support for instance method and switch this to instance method return outputs + def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh): + param = param[...].to(param_casting_dtype) + if to_contiguous: + param = param.contiguous() + param = param / device_mesh.size() # TODO should be optionable + # TODO: assumes parent module will allreduce the output afterwards (e.g rowlinear bias is IsolatedParallel and parent module is GatherParallel) + return param + def prepare_module_tp(self, module: nn.Module, device_mesh) -> nn.Module: distribute_module( module, @@ -773,6 +801,108 @@ def 
partition_tensor(self, param, empty_param, param_type, param_casting_dtype, return nn.Parameter(parameter, requires_grad=parameter.is_floating_point()) +class GroupedGemmParallel(TensorParallelLayer): + """ + Applies Expert Parallelism to MoE experts by loading the correct experts on each device. + """ + + def __init__(self): + super().__init__() + self.use_dtensor = False + + def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh): + ep_rank = rank + global_num_experts = empty_param.shape[0] + if global_num_experts % device_mesh.size() != 0: + raise ValueError( + f"Global number of experts must be divisible by number of devices: {global_num_experts} % {device_mesh.size()} != 0" + ) + local_num_experts = global_num_experts // device_mesh.size() + param = param[ep_rank * local_num_experts : (ep_rank + 1) * local_num_experts].to(param_casting_dtype) + if to_contiguous: + param = param.contiguous() + if "gate_up" in param_type and False: + param = torch.cat([param[..., ::2], param[..., 1::2]], dim=-1) + return param + + +class RouterParallel(TensorParallelLayer): + """ + Allows reshaping the router scores to support running expert parallelism. + """ + + def __init__(self, *args, **kwargs): + self.args = args + self.kwargs = kwargs + self.use_dtensor = False + + @staticmethod + def _prepare_input_fn(input_layouts, desired_input_layouts, mod, inputs, device_mesh): + input_tensor = inputs[0] + if isinstance(input_tensor, DTensor): + raise NotImplementedError("RouterParallel does not support DTensor input for now") + return input_tensor + + @staticmethod + def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh): + """ + Imagine you had 4 tokens, top_k = 4, and 128 experts, + with EP = 8. + Imagine router_indices being: + [ 52, 42, 119, 67], + [102, 89, 61, 40], + [ 82, 103, 4, 34], + [ 93, 23, 109, 11], + + then you can map which rank should get which values: + + [3, 2, 7, 4], + [6, 5, 3, 2], + [5, 6, 0, 2], + [5, 1, 6, 0], + + Thus, for rank 0 for example, you fill the index tensor with 0: + + [ 0, 0, 0, 0], + [ 0, 0, 0, 0], + [ 0, 0, 4, 0], + [ 0, 0, 0, 11], + + This works well. For another rank you need to make sure you take the index modulo num_local_experts, + because the next operation will one-hot encode the router index vector. + + This allows us to know directly which local expert is hit. + Similarly, the scores are indexed with something created from + router_indices. + + The somewhat naive training loop that we use for device_map="auto" uses similar logic. + Here we just make each rank believe that it is alone, and it computes its part of the hidden states.
+ """ + ep_rank, ep_size = device_mesh.get_local_rank(), device_mesh.size() + num_local_experts = mod.num_experts // ep_size + router_scores, router_indices = outputs + router_scores = router_scores[:, ep_rank * num_local_experts : (ep_rank + 1) * num_local_experts] + router_indices = router_indices.masked_fill((router_indices // num_local_experts) != ep_rank, 0) + router_indices = router_indices % num_local_experts + return router_scores, router_indices + + def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh): + # TODO: i'd like for this to be the default + param = param[...].to(param_casting_dtype) + if to_contiguous: + param = param.contiguous() + return param + + def prepare_module_tp(self, module: nn.Module, device_mesh) -> nn.Module: + # TODO: need an abstract Parallel class that is different from TensorParallelLayer + distribute_module( + module, + device_mesh, + partial(self._prepare_input_fn, None, None), + partial(self._prepare_output_fn, None, None), + ) + + class ParallelInterface(GeneralInterface): # Class instance object, so that a call to `register` can be reflected into all other files correctly, even if # a new instance is created (in order to locally override a given entry) @@ -789,6 +919,8 @@ class ParallelInterface(GeneralInterface): "local_packed_rowwise": PackedRowwiseParallel(use_dtensor=False), "sequence_parallel": SequenceParallel(), "replicate": ReplicateParallel(), + "grouped_gemm": GroupedGemmParallel(), + "ep_router": RouterParallel(), } if is_torch_greater_or_equal("2.5") and _torch_distributed_available else {} @@ -841,25 +973,17 @@ def replace_state_dict_local_with_dtensor( return state_dict -def add_tensor_parallel_hooks_to_module(model, module, tp_plan, layer_name, current_module_plan, device_mesh): - """ - Add hooks to the module holding the layer. Meaning: - ``` - class MyModel(nn.Module): - def __init__(self): - self.layer = nn.Linear(10, 10) - ``` - has state_dict like: - ``` - { - "layer.weight": torch.Tensor, - "layer.bias": torch.Tensor - } - ``` - we add hooks to `MyModel` as well as `layer` to make sure that the tensors are correctly sharded and gathered. - """ +def add_tensor_parallel_hooks_to_module( + model, module, tp_plan, layer_name, current_module_plan, device_mesh, parameter_name=None +): + r""" + This function is called in `PretrainedModel.post_init()`. It is responsible of adding hooks + to the modules of the `model`, based on the `PretrainedModel._tp_plan`. + + This is the place where we add the `pre_forward` and `post_forwards` hooks. These are defined + for each `TensorParallelLayer` as `_prepare_input_fn` and `_prepare_output_fn`. - # 1. We add hooks to the layer being loaded: + """ if current_module_plan is not None: tp_layer = ALL_PARALLEL_STYLES[current_module_plan] try: @@ -868,26 +992,19 @@ def __init__(self): print( f"Trying to prepare {layer_name}, but it's not supported. Corresponding module: {module} Fix it's TP plan: {e}" ) + module._hf_tp_plan = current_module_plan module.__repr__ = lambda: f"{module.__repr__()}\nTP Plan: {current_module_plan}" - # 2. We add hooks to the parent module if needed - if "." 
in layer_name: - parent_layer_name = layer_name.rsplit(".", 1)[0] - generic_name = re.sub(r"\d+", "*", parent_layer_name) - # The module itself needs hooks - if module_plan := tp_plan.get(generic_name, False): - tp_layer = ALL_PARALLEL_STYLES[module_plan] - module_to_tp_ = model.get_submodule(parent_layer_name) - tp_layer.prepare_module_tp(module_to_tp_, device_mesh) - module_to_tp_._hf_tp_plan = current_module_plan - module_to_tp_.__repr__ = lambda: f"{module_to_tp_.__repr__()}\nTP Plan: {current_module_plan}" - def shard_and_distribute_module( model, param, empty_param, parameter_name, param_casting_dtype, is_contiguous, rank, device_mesh -): +): # TODO: rename to shard_and_distribute_param r""" + This function is called in `from_pretrained` when loading a model's checkpoints. + It receives the pointer to the parameter (or the parameter itself) and takes care of "sharding". + All process run this function, so they just load the partition of the tensor that they require. + Main uses cases: - column / rowise parallelism, you just shard all the weights of the layer (weight and bias) - packed layers: you slice the weights, then shard like above @@ -898,39 +1015,33 @@ def shard_and_distribute_module( """ param_name, param_type = parameter_name.rsplit(".", 1) if "." in parameter_name else parameter_name tp_plan = model._tp_plan - module_to_tp = model.get_submodule(param_name) + module_to_tp = model.get_submodule(param_name) # TODO: can i loop over modules? rank = int(rank) + current_shard_plan = _get_parameter_tp_plan(parameter_name, tp_plan) - current_module_plan = _get_parameter_tp_plan(parameter_name, tp_plan) + if dist.get_rank() == 0: + if current_shard_plan is None: + logger.info(f"Tensor sharding plan for {param_name} not found, using default 'replicate' plan.") + else: + logger.info(f"Tensor sharding plan for {param_name}: {current_shard_plan}") - if current_module_plan is None: - current_module_plan = "replicate" - if dist.get_rank() == 0: - logger.info(f"Tensor parallel plan for {param_name} not found, using default 'replicate' plan.") + if current_shard_plan is not None: + try: + tp_layer = ALL_PARALLEL_STYLES[current_shard_plan] + param = tp_layer.partition_tensor( + param, empty_param, param_type, param_casting_dtype, is_contiguous, rank, device_mesh + ) + except NotImplementedError as e: + print( + f"Trying to prepare {parameter_name}, but it's not supported. Corresponding module: {module_to_tp} Fix it's TP plan, current layer: {tp_layer} : {e}" + ) else: - if dist.get_rank() == 0: - logger.info(f"Tensor parallel plan for {param_name}: {current_module_plan}") - - # Add hooks to the module if not done yet - # add_tensor_parallel_hooks_to_module(model, module_to_tp, tp_plan, param_name, current_module_plan, device_mesh) - if not getattr(module_to_tp, "_is_hooked", False): - add_tensor_parallel_hooks_to_module(model, module_to_tp, tp_plan, param_name, current_module_plan, device_mesh) - module_to_tp._is_hooked = True - - try: - tp_layer = ALL_PARALLEL_STYLES[current_module_plan] - param = tp_layer.partition_tensor( - param, empty_param, param_type, param_casting_dtype, is_contiguous, rank, device_mesh - ) - except NotImplementedError as e: - print( - f"Trying to prepare {parameter_name}, but it's not supported. 
Corresponding module: {module_to_tp} Fix it's TP plan, current layer: {tp_layer} : {e}" - ) + param = param[:].to(param_casting_dtype) # SUPER IMPORTANT we have to use setattr # otherwise loading is crazy slow if not isinstance(param, torch.nn.Parameter): - param = torch.nn.Parameter(param, requires_grad=param.is_floating_point()) + param = torch.nn.Parameter(param, requires_grad=empty_param.is_floating_point()) setattr(module_to_tp, param_type, param) # module_to_tp.load_state_dict({param_type: param}, strict=False, assign=True) return param @@ -965,3 +1076,43 @@ def verify_tp_plan(expected_keys: list[str], tp_plan: dict[str, str] | None): logger.warning(f"The following TP rules were not applied on any of the layers: {unused_rules}") if len(unsharded_layers) > 0: logger.warning(f"The following layers were not sharded: {', '.join(unsharded_layers)}") + + +def distribute_model(model, distributed_config, device_mesh, tp_size): + _plan = "_tp_plan" + model._tp_plan = getattr(model.config, "base_model_tp_plan").copy() + if distributed_config is not None: + distributed_config = DistributedConfig.from_config(distributed_config) + if distributed_config.enable_expert_parallel: + _plan = "_ep_plan" + model._tp_plan = getattr(model.config, "base_model_ep_plan", model._tp_plan).copy() + + # now fetch my childrens + for name, module in model.named_children(): + if plan := getattr(module, _plan, getattr(module, "tp_plan", None)): + model._tp_plan.update({f"{name}.{k}": v for k, v in plan.copy().items()}) + if hasattr(module, "config"): + plan = getattr(module.config, f"base_model{_plan}", {}) + if plan == {}: + plan = getattr(module.config, "base_model_tp_plan", {}) + model._tp_plan.update({f"{name}.{k}": v for k, v in plan.copy().items()}) + + if model._tp_plan is not None and is_torch_greater_or_equal("2.5") and _torch_distributed_available: + for v in model._tp_plan.values(): + if v not in ALL_PARALLEL_STYLES: + raise ValueError(f"Unsupported tensor parallel style {v}. Supported styles are {ALL_PARALLEL_STYLES}") + for name, module in model.named_modules(): + if not getattr(module, "_is_hooked", False): + from transformers.integrations.tensor_parallel import add_tensor_parallel_hooks_to_module + + plan = _get_parameter_tp_plan(parameter_name=name, tp_plan=model._tp_plan, is_weight=False) + add_tensor_parallel_hooks_to_module( + model=model, + module=module, + tp_plan=model._tp_plan, + layer_name="", + current_module_plan=plan, + device_mesh=device_mesh, + ) + module._is_hooked = True + return model diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 5c4226ad2cae..4fb687103657 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -63,8 +63,8 @@ from .integrations.sdpa_attention import sdpa_attention_forward from .integrations.sdpa_paged import sdpa_attention_paged_forward from .integrations.tensor_parallel import ( - ALL_PARALLEL_STYLES, _get_parameter_tp_plan, + distribute_model, initialize_tensor_parallelism, repack_weights, replace_state_dict_local_with_dtensor, @@ -2220,6 +2220,9 @@ def post_init(self): """ A method executed at the end of each Transformer model initialization, to execute code that needs the model's modules properly initialized (such as weight initialization). + + This is also used when the user is running distributed code. We add hooks to the modules here, according to + the model's tp_plan! 
""" self.init_weights() self._backward_compatibility_gradient_checkpointing() @@ -2252,17 +2255,6 @@ def post_init(self): # If current model is a base model, attach `base_model_tp_plan` and `base_model_pp_plan` from config self._pp_plan = self.config.base_model_pp_plan.copy() if self.config.base_model_pp_plan is not None else None - self._tp_plan = self.config.base_model_tp_plan.copy() if self.config.base_model_tp_plan is not None else {} - for name, module in self.named_children(): - if plan := getattr(module, "_tp_plan", None): - self._tp_plan.update({f"{name}.{k}": v for k, v in plan.copy().items()}) - - if self._tp_plan is not None and is_torch_greater_or_equal("2.5") and _torch_distributed_available: - for v in self._tp_plan.values(): - if v not in ALL_PARALLEL_STYLES: - raise ValueError( - f"Unsupported tensor parallel style {v}. Supported styles are {ALL_PARALLEL_STYLES}" - ) def dequantize(self): """ @@ -4669,6 +4661,7 @@ def from_pretrained( load_in_8bit = kwargs.pop("load_in_8bit", False) load_in_4bit = kwargs.pop("load_in_4bit", False) quantization_config = kwargs.pop("quantization_config", None) + distributed_config = kwargs.pop("distributed_config", None) subfolder = kwargs.pop("subfolder", "") commit_hash = kwargs.pop("_commit_hash", None) variant = kwargs.pop("variant", None) @@ -4689,6 +4682,9 @@ def from_pretrained( ): key_mapping = cls._checkpoint_conversion_mapping + if distributed_config is not None: + tp_plan = "auto" + # Not used anymore -- remove them from the kwargs _ = kwargs.pop("resume_download", None) _ = kwargs.pop("mirror", None) @@ -4720,16 +4716,12 @@ def from_pretrained( # `device_map` pointing to the correct device if tp_plan is not None: if device_mesh is None: - tp_plan, device_map, device_mesh = initialize_tensor_parallelism(tp_plan, tp_size=None) + tp_plan, device_map, device_mesh, tp_size = initialize_tensor_parallelism(tp_plan, tp_size=tp_size) else: - if "tp" not in device_mesh.mesh_dim_names: - raise ValueError( - "When using `tp_plan`, the `device_mesh` must contain a 'tp' dimension. " - "Please provide a valid `device_mesh`." - ) - device_mesh = device_mesh["tp"] - tp_size = device_mesh["tp"].size() - device_map = torch.device(f"{device_mesh.device_type}:{int(os.environ['LOCAL_RANK'])}") + # TODO: make device_mesh support multiple dimensions + if device_mesh.ndim > 1: + raise ValueError("device_mesh must be 1 dimensional and will be used for TP") + device_map = torch.device(device_mesh.device_type, int(os.environ["LOCAL_RANK"])) if tp_size is None: tp_size = torch.distributed.get_world_size() @@ -5029,23 +5021,18 @@ def from_pretrained( ) config.name_or_path = pretrained_model_name_or_path - - # Instantiate model. model_init_context = cls.get_init_context(is_quantized, _is_ds_init_called) - config = copy.deepcopy(config) # We do not want to modify the config inplace in from_pretrained. 
with ContextManagers(model_init_context): # Let's make sure we don't run the init function of buffer modules model = cls(config, *model_args, **model_kwargs) + if _torch_distributed_available and device_mesh is not None: + model = distribute_model(model, distributed_config, device_mesh, tp_size) + # Make sure to tie the weights correctly model.tie_weights() - # Last check for tp - if device_mesh is not None and not model.supports_tp_plan: - if config.base_model_tp_plan is None and config.get_text_config().base_model_tp_plan is None: - raise NotImplementedError("This model does not have a tensor parallel plan.") - # make sure we use the model's config since the __init__ call might have copied it config = model.config @@ -5126,11 +5113,6 @@ def _assign_original_dtype(module): key_mapping=key_mapping, weights_only=weights_only, ) - - # record tp degree the model sharded to - model._tp_size = tp_size - model._device_mesh = device_mesh - # make sure token embedding weights are still tied if needed model.tie_weights() diff --git a/src/transformers/models/llama4/configuration_llama4.py b/src/transformers/models/llama4/configuration_llama4.py index cff2ecb6ed18..76162ee25964 100644 --- a/src/transformers/models/llama4/configuration_llama4.py +++ b/src/transformers/models/llama4/configuration_llama4.py @@ -265,6 +265,19 @@ class Llama4TextConfig(PretrainedConfig): "layers.*.feed_forward.down_proj": "local_rowwise", "layers.*.feed_forward": "gather", } + base_model_ep_plan = { + "layers.*.self_attn.q_proj": "colwise", + "layers.*.self_attn.k_proj": "colwise", + "layers.*.self_attn.v_proj": "colwise", + "layers.*.self_attn.o_proj": "rowwise", + "layers.*.feed_forward.experts.gate_up_proj": "grouped_gemm", # row because not linear + "layers.*.feed_forward.experts.down_proj": "grouped_gemm", # col because not linear + "layers.*.feed_forward.experts": "gather", # all reduce + "layers.*.feed_forward.gate_proj": "local_colwise", + "layers.*.feed_forward.up_proj": "local_colwise", + "layers.*.feed_forward.down_proj": "local_rowwise", + "layers.*.feed_forward.router": "ep_router", + } def __init__( self, diff --git a/src/transformers/models/llama4/modeling_llama4.py b/src/transformers/models/llama4/modeling_llama4.py index 53d9367b7c18..d04d443ec851 100644 --- a/src/transformers/models/llama4/modeling_llama4.py +++ b/src/transformers/models/llama4/modeling_llama4.py @@ -26,7 +26,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin -from ...integrations.hub_kernels import use_kernel_forward_from_hub +from ...integrations import use_kernel_forward_from_hub from ...masking_utils import create_causal_mask, create_chunked_causal_mask from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_layers import GradientCheckpointingLayer @@ -35,6 +35,7 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging +from ...utils.generic import check_model_inputs from .configuration_llama4 import Llama4Config, Llama4TextConfig @@ -65,7 +66,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: Returns: torch.Tensor """ - hidden_states = hidden_states.view(self.num_experts, -1, self.hidden_size) + hidden_states = hidden_states.view(self.gate_up_proj.shape[0], -1, self.hidden_size) gate_up = torch.bmm(hidden_states, self.gate_up_proj) gate, up = gate_up.chunk(2, dim=-1) # not 
supported for DTensors next_states = torch.bmm((up * self.act_fn(gate)), self.down_proj) @@ -127,6 +128,20 @@ def extra_repr(self): return f"{tuple(self.weight.shape)}, eps={self.eps}" +class Llama4Router(nn.Linear): + def __init__(self, config): + super().__init__(config.hidden_size, config.num_local_experts, bias=False) + self.num_experts = config.num_local_experts + self.top_k = config.num_experts_per_tok + + def forward(self, hidden_states): + router_logits = super().forward(hidden_states) + router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=1) + router_scores = torch.full_like(router_logits, float("-inf")).scatter_(1, router_indices, router_top_value) + router_scores = torch.nn.functional.sigmoid(router_scores.float()).to(router_scores.dtype) + return router_scores, router_logits + + @use_kernel_forward_from_hub("Llama4TextMoe") class Llama4TextMoe(nn.Module): def __init__(self, config): @@ -135,28 +150,18 @@ def __init__(self, config): self.hidden_dim = config.hidden_size self.num_experts = config.num_local_experts self.experts = Llama4TextExperts(config) - self.router = nn.Linear(config.hidden_size, config.num_local_experts, bias=False) + self.router = Llama4Router(config) self.shared_expert = Llama4TextMLP(config) def forward(self, hidden_states): hidden_states = hidden_states.reshape(-1, self.hidden_dim) - router_logits = self.router(hidden_states) - - router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=1) - - router_scores = ( - torch.full_like(router_logits, float("-inf")).scatter_(1, router_indices, router_top_value).transpose(0, 1) - ) - router_scores = torch.sigmoid(router_scores.float()).to(hidden_states.dtype) - - routed_in = hidden_states.repeat(self.num_experts, 1) + router_scores, router_logits = self.router(hidden_states) + routed_in = hidden_states.repeat(router_scores.shape[1], 1) routed_in = routed_in * router_scores.reshape(-1, 1) routed_out = self.experts(routed_in) - out = self.shared_expert(hidden_states) - out.add_(routed_out.reshape(self.num_experts, -1, self.hidden_dim).sum(dim=0)) - - return out, router_scores + out.add_(routed_out.reshape(router_scores.shape[1], -1, routed_out.shape[-1]).sum(dim=0)) + return out, router_logits class Llama4TextRotaryEmbedding(nn.Module): @@ -383,8 +388,6 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[tuple[torch.Tensor]] = None, - output_attentions: Optional[bool] = False, - output_router_logits: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC @@ -395,12 +398,11 @@ def forward( hidden_states = self.input_layernorm(hidden_states) # Self Attention - attention_states, self_attn_weights = self.self_attn( + attention_states, _ = self.self_attn( hidden_states=hidden_states, position_embeddings=position_embeddings, attention_mask=attention_mask, past_key_value=past_key_value, - output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, **kwargs, @@ -409,23 +411,12 @@ def forward( # Fully Connected residual = hidden_states - hidden_states = self.post_attention_layernorm(hidden_states) hidden_states = self.feed_forward(hidden_states) if self.is_moe_layer: - hidden_states, router_logits = hidden_states - else: - router_logits = None + hidden_states, _ = hidden_states hidden_states = residual + 
hidden_states.view(residual.shape) - outputs = (hidden_states,) - - if output_attentions: - outputs += (self_attn_weights,) - - if output_router_logits: - outputs += (router_logits,) - - return outputs + return hidden_states @auto_docstring @@ -472,6 +463,11 @@ class Llama4TextModel(Llama4PreTrainedModel): _no_split_modules = ["Llama4TextDecoderLayer"] base_model_prefix = "model" config: Llama4TextConfig + _can_record_outputs = { + "attentions": Llama4TextAttention, + "hidden_states": Llama4TextDecoderLayer, + "router_logits": Llama4TextMoe, + } def __init__(self, config: Llama4TextConfig): super().__init__(config) @@ -489,7 +485,7 @@ def __init__(self, config: Llama4TextConfig): # Initialize weights and apply final processing self.post_init() - @can_return_tuple + @check_model_inputs @auto_docstring def forward( self, @@ -499,28 +495,12 @@ def forward( past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs: Unpack[TransformersKwargs], ) -> Union[tuple, BaseModelOutputWithPast]: - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("You must specify exactly one of input_ids or inputs_embeds") - if self.gradient_checkpointing and self.training and use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." 
- ) - use_cache = False - if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids.to(self.embed_tokens.weight.device)) @@ -558,42 +538,22 @@ def forward( # create position embeddings to be shared across the decoder layers freq_cis = self.rotary_emb(hidden_states, position_ids) - # decoder layers - all_hidden_states = () if output_hidden_states else None - all_self_attns = () if output_attentions else None - for decoder_layer in self.layers[: self.config.num_hidden_layers]: - if output_hidden_states: - all_hidden_states += (hidden_states,) - - layer_outputs = decoder_layer( + hidden_states = decoder_layer( hidden_states, attention_mask=causal_mask_mapping[decoder_layer.attention_type], position_ids=position_ids, past_key_value=past_key_values, - output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, position_embeddings=freq_cis, **kwargs, ) - - hidden_states = layer_outputs[0] - - if output_attentions: - all_self_attns += (layer_outputs[1],) - hidden_states = self.norm(hidden_states) - # add hidden states from the last decoder layer - if output_hidden_states: - all_hidden_states += (hidden_states,) - return BaseModelOutputWithPast( last_hidden_state=hidden_states, past_key_values=past_key_values if use_cache else None, - hidden_states=all_hidden_states, - attentions=all_self_attns, ) @@ -630,9 +590,6 @@ def forward( inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, logits_to_keep: Union[int, torch.Tensor] = 0, **kwargs: Unpack[TransformersKwargs], @@ -659,13 +616,6 @@ def forward( >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." 
```""" - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) outputs = self.model( input_ids=input_ids, attention_mask=attention_mask, @@ -673,9 +623,6 @@ def forward( past_key_values=past_key_values, inputs_embeds=inputs_embeds, use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=True, cache_position=cache_position, **kwargs, ) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 380ccbf40655..43cb4b88f2cd 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -2634,7 +2634,14 @@ def _inner_training_loop( self.control = self.callback_handler.on_pre_optimizer_step(args, self.state, self.control) - self.optimizer.step() + context = contextlib.nullcontext + if self.is_tp_enabled: + from torch.distributed._tensor.experimental import implicit_replication + + context = implicit_replication + + with context(): + self.optimizer.step() self.control = self.callback_handler.on_optimizer_step(args, self.state, self.control) diff --git a/tests/tensor_parallel/test_tensor_parallel.py b/tests/tensor_parallel/test_tensor_parallel.py index 980d6fff8d47..1904fc8bd1e7 100644 --- a/tests/tensor_parallel/test_tensor_parallel.py +++ b/tests/tensor_parallel/test_tensor_parallel.py @@ -109,7 +109,7 @@ def test_model_forward(self): assert has_dtensor == 1, "TP model must has DTensor" - tokenizer = AutoTokenizer.from_pretrained(model_id) + tokenizer = AutoTokenizer.from_pretrained(model_id, legacy=False) prompt = "Can I help" inputs = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)