@@ -19,6 +19,7 @@
 import os
 import random
 import time
+import warnings
 from contextlib import ContextDecorator
 from typing import Callable
 from typing import Dict
@@ -32,6 +33,7 @@
 import paddle
 from matplotlib import pyplot as plt
 from paddle import distributed as dist
+from paddle.incubate.distributed.models.moe.moe_layer import AllGather
 
 from ppsci.utils import logger
 
@@ -326,14 +328,18 @@ def convert_to_dict(array: np.ndarray, keys: Tuple[str, ...]) -> Dict[str, np.ndarray]:
 
 
 def all_gather(
-    tensor: paddle.Tensor, concat: bool = True, axis: int = 0
+    tensor: paddle.Tensor,
+    concat: bool = True,
+    axis: int = 0,
+    requires_grad: bool = False,
 ) -> Union[paddle.Tensor, List[paddle.Tensor]]:
     """Gather tensor from all devices, concatenate them along given axis if specified.
 
     Args:
         tensor (paddle.Tensor): Tensor to be gathered from all GPUs.
         concat (bool, optional): Whether to concatenate gathered Tensors. Defaults to True.
         axis (int, optional): Axis along which to concatenate. Defaults to 0.
+        requires_grad (bool, optional): Whether gradients should flow back through the gather. Defaults to False.
 
     Returns:
         Union[paddle.Tensor, List[paddle.Tensor]]: Gathered Tensors.
@@ -354,7 +360,7 @@ def all_gather(
          [ 7  8  9]
          [10 11 12]]
     """
-    result: List[paddle.Tensor] = []
+    result: Union[paddle.Tensor, List[paddle.Tensor]] = []
 
     # NOTE: Put tensor to CUDAPlace from CUDAPinnedPlace to use communication.
     if tensor.place.is_cuda_pinned_place():
@@ -363,10 +369,27 @@ def all_gather(
     # TODO(HydrogenSulfate): As non-contiguous(strided) tensor is not supported in
     # dist.all_gather, manually convert given Tensor to contiguous below. Strided tensor
     # will be supported in future.
-    dist.all_gather(result, tensor.contiguous())
+    if not requires_grad:
+        dist.all_gather(result, tensor.contiguous())
+        if concat:
+            if tensor.ndim == 0:
+                warnings.warn(
+                    "given tensor is a 0-dim tensor, so we use `paddle.stack` to replace `paddle.concat`",
+                    category=UserWarning,
+                    stacklevel=2,
+                )
+                result = paddle.stack(result, axis)
+            else:
+                result = paddle.concat(result, axis)
+    else:
+        assert (
+            tensor.ndim > 0
+        ), "`all_gather` is not supported for 0-dim tensor when requires_grad=True"
+        assert concat is True, "`requires_grad=True` only supports `concat=True`"
+        result = AllGather.apply(
+            tensor.contiguous(), dist.get_rank(), dist.get_world_size(), None
+        )
 
-    if concat:
-        return paddle.concat(result, axis)
     return result
 
 
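For reference, a minimal usage sketch of the new `requires_grad` flag. This is illustrative only: it assumes the module is importable as `ppsci.utils.misc` (as in the docstring example) and that the script runs under `paddle.distributed.launch` on two or more devices; tensor shapes and variable names are made up.

```python
# Sketch only: assumes 2+ GPUs launched via
#   python -m paddle.distributed.launch --gpus "0,1" this_script.py
import paddle
import paddle.distributed as dist

from ppsci.utils import misc

dist.init_parallel_env()

# Per-rank features that participate in autodiff.
x = paddle.randn([4, 3])
x.stop_gradient = False

# Default path: plain dist.all_gather; the result is not connected
# to the autograd graph.
gathered = misc.all_gather(x)  # shape: [4 * world_size, 3]

# New path: the gathered tensor stays differentiable, so a loss
# computed on it backpropagates into the local `x`.
gathered_diff = misc.all_gather(x, requires_grad=True)
loss = gathered_diff.sum()
loss.backward()
print(x.grad.shape)  # [4, 3]
```

Reusing `AllGather.apply` (a `PyLayer` from Paddle's MoE utilities) avoids hand-writing a backward for the gather: its backward hands each rank its own slice of the upstream gradient. The trade-off, reflected in the asserts, is that this path only handles `concat=True` and tensors with at least one dimension.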