Commit 973e334

feat: support Ulysses Anything Attention (#12996)

* feat: support Ulysses Anything Attention
* fix UAA broken while using joint attn
* update
* post check
* add docs
* remove lru cache
* move codes

1 parent 769a1f3 commit 973e334

File tree

4 files changed: +417 −18 lines changed

docs/source/en/training/distributed_inference.md

Lines changed: 28 additions & 0 deletions
@@ -343,6 +343,34 @@ We ran a benchmark with Ulysses, Ring, and Unified Attention with [this script](

From the above table, it's clear that Ulysses provides better throughput, but the number of devices it can use remains limited to the number of attention heads, a limitation that is solved by unified attention.

### Ulysses Anything Attention

The default Ulysses Attention mechanism requires the sequence length of the hidden states to be divisible by the number of devices, which significantly limits its practical applicability. [Ulysses Anything Attention](https://github.com/huggingface/diffusers/pull/12996) is a variant of Ulysses Attention that supports arbitrary sequence lengths and arbitrary numbers of attention heads, making Ulysses Attention far more versatile in practice.

[`ContextParallelConfig`] supports Ulysses Anything Attention through the `ulysses_degree` and `ulysses_anything` options. Note that Ulysses Anything Attention is not currently supported by Unified Attention. Pass a [`ContextParallelConfig`] with `ulysses_degree` greater than 1 and `ulysses_anything=True` to [`~ModelMixin.enable_parallelism`].
```py
pipeline.transformer.enable_parallelism(config=ContextParallelConfig(ulysses_degree=2, ulysses_anything=True))
```

> [!TIP]
> To avoid multiple forced CUDA syncs caused by H2D and D2H transfers, add the **gloo** backend in `init_process_group`. This significantly reduces communication latency.
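As a minimal sketch of that tip (assuming a PyTorch install; the single-rank setup below exists only to show the mechanics):

```python
import os
import torch.distributed as dist

# On a real multi-GPU launch you would pass backend="cpu:gloo,cuda:nccl"
# so CPU tensors (e.g. the per-rank sequence-length metadata gathered by
# Ulysses Anything Attention) go through gloo without H2D/D2H copies,
# while GPU tensors still use NCCL. Here we start a one-rank gloo group
# purely for illustration.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29511")
dist.init_process_group(backend="gloo", rank=0, world_size=1)
print(dist.get_backend())
dist.destroy_process_group()
```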
We ran a benchmark for FLUX.1-dev with Ulysses, Ring, Unified Attention, and Ulysses Anything Attention with [this script](https://github.com/huggingface/diffusers/pull/12996#issuecomment-3797695999) on a node of 4 L20 GPUs. The results are summarized as follows:

| CP Backend       | Time / Iter (ms) | Steps / Sec | Peak Memory (GB) | Shape (HxW) |
|------------------|------------------|-------------|------------------|-------------|
| ulysses          | 281.07           | 3.56        | 37.11            | 1024x1024   |
| ring             | 351.34           | 2.85        | 37.01            | 1024x1024   |
| unified_balanced | 324.37           | 3.08        | 37.16            | 1024x1024   |
| ulysses_anything | 280.94           | 3.56        | 37.11            | 1024x1024   |
| ulysses          | failed           | failed      | failed           | 1008x1008   |
| ring             | failed           | failed      | failed           | 1008x1008   |
| unified_balanced | failed           | failed      | failed           | 1008x1008   |
| ulysses_anything | 278.40           | 3.59        | 36.99            | 1008x1008   |

From the above table, it is clear that Ulysses Anything Attention offers better compatibility with arbitrary sequence lengths while maintaining the same performance as standard Ulysses Attention.
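The failed 1008x1008 rows follow from simple token-count arithmetic. A sketch, assuming FLUX.1-dev's 8x VAE downsampling and 2x2 latent patchification (the helper name is made up for illustration):

```python
# Token count for a FLUX-style transformer: the latent is H/8 x W/8,
# then patchified 2x2, giving (H // 16) * (W // 16) image tokens.
def flux_image_tokens(height: int, width: int) -> int:
    return (height // 16) * (width // 16)

for side in (1024, 1008):
    tokens = flux_image_tokens(side, side)
    print(side, tokens, tokens % 4)
# 1024 -> 4096 tokens, 4096 % 4 == 0: standard Ulysses shards evenly on 4 GPUs
# 1008 -> 3969 tokens, 3969 % 4 == 1: equal sharding is impossible,
#         which is exactly the case ulysses_anything handles
```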
### parallel_config

Pass `parallel_config` during model initialization to enable context parallelism.

src/diffusers/hooks/context_parallel.py

Lines changed: 85 additions & 4 deletions
@@ -11,12 +11,14 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
+import copy
+import functools
 import inspect
 from dataclasses import dataclass
-from typing import Dict, List, Type, Union
+from typing import Dict, List, Tuple, Type, Union

 import torch
+import torch.distributed as dist


 if torch.distributed.is_available():
@@ -27,9 +29,10 @@
     ContextParallelInput,
     ContextParallelModelPlan,
     ContextParallelOutput,
+    gather_size_by_comm,
 )
 from ..utils import get_logger
-from ..utils.torch_utils import unwrap_module
+from ..utils.torch_utils import maybe_allow_in_graph, unwrap_module
 from .hooks import HookRegistry, ModelHook

@@ -208,6 +211,10 @@ def _prepare_cp_input(self, x: torch.Tensor, cp_input: ContextParallelInput) ->
             )
             return x
         else:
+            if self.parallel_config.ulysses_anything:
+                return PartitionAnythingSharder.shard_anything(
+                    x, cp_input.split_dim, self.parallel_config._flattened_mesh
+                )
             return EquipartitionSharder.shard(x, cp_input.split_dim, self.parallel_config._flattened_mesh)

@@ -233,7 +240,14 @@ def post_forward(self, module, output):
         for i, cpm in enumerate(self.metadata):
             if cpm is None:
                 continue
-            output[i] = EquipartitionSharder.unshard(output[i], cpm.gather_dim, self.parallel_config._flattened_mesh)
+            if self.parallel_config.ulysses_anything:
+                output[i] = PartitionAnythingSharder.unshard_anything(
+                    output[i], cpm.gather_dim, self.parallel_config._flattened_mesh
+                )
+            else:
+                output[i] = EquipartitionSharder.unshard(
+                    output[i], cpm.gather_dim, self.parallel_config._flattened_mesh
+                )

         return output[0] if is_tensor else tuple(output)
@@ -274,6 +288,73 @@ def unshard(cls, tensor: torch.Tensor, dim: int, mesh: torch.distributed.device_
         return tensor


+class AllGatherAnythingFunction(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, tensor: torch.Tensor, dim: int, group: dist.device_mesh.DeviceMesh):
+        ctx.dim = dim
+        ctx.group = group
+        ctx.world_size = dist.get_world_size(group)
+        ctx.rank = dist.get_rank(group)
+        gathered_tensor = _all_gather_anything(tensor, dim, group)
+        return gathered_tensor
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        # NOTE: We use `tensor_split` instead of `chunk`, because `chunk`
+        # may return fewer than the specified number of chunks!
+        grad_splits = torch.tensor_split(grad_output, ctx.world_size, dim=ctx.dim)
+        return grad_splits[ctx.rank], None, None
+
+
+class PartitionAnythingSharder:
+    @classmethod
+    def shard_anything(
+        cls, tensor: torch.Tensor, dim: int, mesh: torch.distributed.device_mesh.DeviceMesh
+    ) -> torch.Tensor:
+        assert tensor.size()[dim] >= mesh.size(), (
+            f"Cannot shard tensor of size {tensor.size()} along dim {dim} across mesh of size {mesh.size()}."
+        )
+        # NOTE: We use `tensor_split` instead of `chunk`, because `chunk`
+        # may return fewer than the specified number of chunks!
+        return tensor.tensor_split(mesh.size(), dim=dim)[dist.get_rank(mesh.get_group())]
+
+    @classmethod
+    def unshard_anything(
+        cls, tensor: torch.Tensor, dim: int, mesh: torch.distributed.device_mesh.DeviceMesh
+    ) -> torch.Tensor:
+        tensor = tensor.contiguous()
+        tensor = AllGatherAnythingFunction.apply(tensor, dim, mesh.get_group())
+        return tensor
+
+
+@functools.lru_cache(maxsize=64)
+def _fill_gather_shapes(shape: Tuple[int], gather_dims: Tuple[int], dim: int, world_size: int) -> List[List[int]]:
+    gather_shapes = []
+    for i in range(world_size):
+        rank_shape = list(copy.deepcopy(shape))
+        rank_shape[dim] = gather_dims[i]
+        gather_shapes.append(rank_shape)
+    return gather_shapes
+
+
+@maybe_allow_in_graph
+def _all_gather_anything(tensor: torch.Tensor, dim: int, group: dist.device_mesh.DeviceMesh) -> torch.Tensor:
+    world_size = dist.get_world_size(group=group)
+
+    tensor = tensor.contiguous()
+    shape = tensor.shape
+    rank_dim = shape[dim]
+    gather_dims = gather_size_by_comm(rank_dim, group)
+
+    gather_shapes = _fill_gather_shapes(tuple(shape), tuple(gather_dims), dim, world_size)
+
+    gathered_tensors = [torch.empty(shape, device=tensor.device, dtype=tensor.dtype) for shape in gather_shapes]
+
+    dist.all_gather(gathered_tensors, tensor, group=group)
+    gathered_tensor = torch.cat(gathered_tensors, dim=dim)
+    return gathered_tensor
+
+
 def _get_submodule_by_name(model: torch.nn.Module, name: str) -> Union[torch.nn.Module, List[torch.nn.Module]]:
     if name.count("*") > 1:
         raise ValueError("Wildcard '*' can only be used once in the name")
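The `tensor_split`-vs-`chunk` NOTE in this diff is the crux of sharding arbitrary lengths. The two sizing rules can be sketched in plain Python (helper names are illustrative, not part of the diff):

```python
import math

def tensor_split_sizes(length: int, n: int) -> list:
    # torch.tensor_split rule: always n pieces; the first (length % n)
    # pieces get one extra element.
    base, extra = divmod(length, n)
    return [base + 1 if i < extra else base for i in range(n)]

def chunk_sizes(length: int, n: int) -> list:
    # torch.chunk rule: a fixed chunk size of ceil(length / n), so the
    # result may contain FEWER than n pieces.
    size = math.ceil(length / n)
    sizes = []
    remaining = length
    while remaining > 0:
        take = min(size, remaining)
        sizes.append(take)
        remaining -= take
    return sizes

print(tensor_split_sizes(5, 4))  # [2, 1, 1, 1] -> every rank gets a shard
print(chunk_sizes(5, 4))         # [2, 2, 1]    -> only 3 shards for 4 ranks!
```

With `chunk`, the fourth rank of a 4-way mesh would receive nothing for a length-5 sequence, which is why the sharder uses `tensor_split`.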

src/diffusers/models/_modeling_parallel.py

Lines changed: 45 additions & 0 deletions
@@ -19,6 +19,7 @@
 from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Tuple, Union

 import torch
+import torch.distributed as dist

 from ..utils import get_logger

@@ -67,6 +68,9 @@ class ContextParallelConfig:
     convert_to_fp32: bool = True
     # TODO: support alltoall
     rotate_method: Literal["allgather", "alltoall"] = "allgather"
+    # Whether to enable Ulysses Anything Attention to support
+    # any sequence length and any number of heads.
+    ulysses_anything: bool = False

     _rank: int = None
     _world_size: int = None
@@ -94,6 +98,11 @@ def __post_init__(self):
             raise NotImplementedError(
                 f"Only rotate_method='allgather' is supported for now, but got {self.rotate_method}."
             )
+        if self.ulysses_anything:
+            if self.ulysses_degree == 1:
+                raise ValueError("ulysses_degree must be greater than 1 for ulysses_anything to be enabled.")
+            if self.ring_degree > 1:
+                raise ValueError("ulysses_anything cannot be enabled when ring_degree > 1.")

     @property
     def mesh_shape(self) -> Tuple[int, int]:
@@ -257,3 +266,39 @@ def __repr__(self):
 #
 # ContextParallelOutput:
 #     specifies how to gather the input tensor in the post-forward hook in the layer it is attached to
+
+
+# Below are utility functions for distributed communication in context parallelism.
+def gather_size_by_comm(size: int, group: dist.ProcessGroup) -> List[int]:
+    r"""Gather the local size from all ranks.
+
+    size: int, local size
+    return: List[int], list of sizes from all ranks
+    """
+    # NOTE (Serving/CP safety):
+    # Do NOT cache this collective result.
+    #
+    # In "Ulysses Anything" mode, `size` (e.g. per-rank local seq_len / S_LOCAL)
+    # may legitimately differ across ranks. If we cache based on the *local* `size`,
+    # different ranks can have different cache hit/miss patterns over time.
+    #
+    # That can lead to a catastrophic distributed hang:
+    #   - some ranks hit the cache and *skip* dist.all_gather()
+    #   - other ranks miss the cache and *enter* dist.all_gather()
+    # This mismatched collective participation will stall the process group and
+    # eventually trigger NCCL watchdog timeouts (often surfacing later as ALLTOALL
+    # timeouts in Ulysses attention).
+    world_size = dist.get_world_size(group=group)
+    # HACK: Use the Gloo backend for this all_gather, when available, to avoid
+    # H2D and D2H overhead.
+    comm_backends = str(dist.get_backend(group=group))
+    # NOTE: e.g., dist.init_process_group(backend="cpu:gloo,cuda:nccl")
+    gather_device = "cpu" if "cpu" in comm_backends else torch.accelerator.current_accelerator()
+    gathered_sizes = [torch.empty((1,), device=gather_device, dtype=torch.int64) for _ in range(world_size)]
+    dist.all_gather(
+        gathered_sizes,
+        torch.tensor([size], device=gather_device, dtype=torch.int64),
+        group=group,
+    )
+    # NOTE: DON'T use tolist() here - it causes a graph break: the `inductor`
+    # backend compiler fails with aten._local_scalar_dense.default.
+    gathered_sizes = [s[0].item() for s in gathered_sizes]
+    return gathered_sizes
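A dependency-free sketch of the bookkeeping `gather_size_by_comm` enables: once every rank knows all local sizes, each rank can pre-allocate correctly shaped receive buffers before the variable-length all_gather, mirroring what `_fill_gather_shapes` does (names and the 3969-token example are illustrative):

```python
# Simulate 4 ranks holding shards of a 3969-token sequence, shape (B, S_local, D).
def fill_gather_shapes(shape, gather_dims, dim):
    # One buffer shape per rank, differing only along the gather dimension.
    shapes = []
    for size in gather_dims:
        rank_shape = list(shape)
        rank_shape[dim] = size
        shapes.append(rank_shape)
    return shapes

gather_dims = [993, 992, 992, 992]  # result of all-gathering each rank's local size
shapes = fill_gather_shapes((1, 993, 128), gather_dims, dim=1)
print(shapes)
# Concatenating the gathered buffers along dim 1 restores the full sequence:
assert sum(s[1] for s in shapes) == 3969
```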
