Skip to content

Commit 3ef1044

Browse files
fix and enhance all_gather (#1228)
1 parent 1f8d6d9 commit 3ef1044

File tree

6 files changed

+60
-13
lines changed

6 files changed

+60
-13
lines changed

ppsci/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
from ppsci.utils.checker import run_check # isort:skip
3131
from ppsci.utils.checker import run_check_mesh # isort:skip
3232
from ppsci.utils import lambdify # isort:skip
33+
from ppsci.utils import misc # isort:skip
3334

3435

3536
try:
@@ -58,6 +59,7 @@
5859
"run_check",
5960
"run_check_mesh",
6061
"lambdify",
62+
"misc",
6163
]
6264

6365

ppsci/arch/cuboid_transformer_encoder.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -313,9 +313,9 @@ def compute_cuboid_self_attention_mask(
313313
"""Compute the shift window attention mask
314314
315315
Args:
316-
data_shape (Tuple[int,....]): Should be (T, H, W).
317-
cuboid_size (Tuple[int,....]): Size of the cuboid.
318-
shift_size (Tuple[int,....]): The shift size.
316+
data_shape (Tuple[int, ...]): Should be (T, H, W).
317+
cuboid_size (Tuple[int, ...]): Size of the cuboid.
318+
shift_size (Tuple[int, ...]): The shift size.
319319
strategy (str): The decomposition strategy.
320320
padding_type (str): Type of the padding.
321321
device (str): The device.

ppsci/arch/extformer_moe_cuboid_encoder.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -337,9 +337,9 @@ def compute_cuboid_self_attention_mask(
337337
"""Compute the shift window attention mask
338338
339339
Args:
340-
data_shape (Tuple[int,....]): Should be (T, H, W).
341-
cuboid_size (Tuple[int,....]): Size of the cuboid.
342-
shift_size (Tuple[int,....]): The shift size.
340+
data_shape (Tuple[int, ...]): Should be (T, H, W).
341+
cuboid_size (Tuple[int, ...]): Size of the cuboid.
342+
shift_size (Tuple[int, ...]): The shift size.
343343
strategy (str): The decomposition strategy.
344344
padding_type (str): Type of the padding.
345345
device (str): The device.

ppsci/metric/func.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,4 +63,10 @@ def __init__(
6363
self.metric_expr = metric_expr
6464

6565
def forward(self, output_dict, label_dict=None) -> Dict[str, "paddle.Tensor"]:
66-
return self.metric_expr(output_dict, label_dict)
66+
metric: "paddle.Tensor" = self.metric_expr(output_dict, label_dict)
67+
if self.keep_batch:
68+
assert metric.ndim >= 1, (
69+
f"metric.shape should be like [batch_size, ...], but got {metric.shape} when keep_batch is True, "
70+
"please check the return value of your metric_expr function."
71+
)
72+
return metric

ppsci/solver/solver.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -280,11 +280,27 @@ def __init__(
280280
*[_v.metric.values() for _v in self.validator.values()]
281281
):
282282
if metric.keep_batch ^ self.compute_metric_by_batch:
283+
"""
284+
Evaluation has two modes:
285+
1. compute_metric_by_batch=True:
286+
- The metric is computed for each batch separately, and the results
287+
are averaged across all batches.
288+
- Suitable for metrics that support additive aggregation (e.g. accuracy).
289+
- Saves memory since batch outputs are not stored.
290+
- In this mode, metric.keep_batch should be True.
291+
292+
2. compute_metric_by_batch=False:
293+
- The outputs and labels of all batches are cached.
294+
- Metric is computed once on the concatenated results at the end.
295+
- Needed for metrics that cannot be computed additively (e.g. L2 relative error).
296+
- In this mode, metric.keep_batch should be False.
297+
"""
283298
raise ValueError(
284299
f"{misc.typename(metric)}.keep_batch should be "
285-
f"{self.compute_metric_by_batch} when compute_metric_by_batch="
300+
f"{self.compute_metric_by_batch} when compute_metric_by_batch is "
286301
f"{self.compute_metric_by_batch}."
287302
)
303+
288304
# check metric name uniqueness over all validators
289305
_count = {}
290306
for _validator in validator.values():

ppsci/utils/misc.py

Lines changed: 28 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import os
2020
import random
2121
import time
22+
import warnings
2223
from contextlib import ContextDecorator
2324
from typing import Callable
2425
from typing import Dict
@@ -32,6 +33,7 @@
3233
import paddle
3334
from matplotlib import pyplot as plt
3435
from paddle import distributed as dist
36+
from paddle.incubate.distributed.models.moe.moe_layer import AllGather
3537

3638
from ppsci.utils import logger
3739

@@ -326,14 +328,18 @@ def convert_to_dict(array: np.ndarray, keys: Tuple[str, ...]) -> Dict[str, np.nd
326328

327329

328330
def all_gather(
329-
tensor: paddle.Tensor, concat: bool = True, axis: int = 0
331+
tensor: paddle.Tensor,
332+
concat: bool = True,
333+
axis: int = 0,
334+
requires_grad: bool = False,
330335
) -> Union[paddle.Tensor, List[paddle.Tensor]]:
331336
"""Gather tensor from all devices, concatenate them along given axis if specified.
332337
333338
Args:
334339
tensor (paddle.Tensor): Tensor to be gathered from all GPUs.
335340
concat (bool, optional): Whether to concatenate gathered Tensors. Defaults to True.
336341
axis (int, optional): Axis which concatenated along. Defaults to 0.
342+
requires_grad (bool, optional): Whether to require gradient. Defaults to False.
337343
338344
Returns:
339345
Union[paddle.Tensor, List[paddle.Tensor]]: Gathered Tensors.
@@ -354,7 +360,7 @@ def all_gather(
354360
[ 7 8 9]
355361
[10 11 12]]
356362
"""
357-
result: List[paddle.Tensor] = []
363+
result: Union[paddle.Tensor, List[paddle.Tensor]] = []
358364

359365
# NOTE: Put tensor to CUDAPlace from CUDAPinnedPlace to use communication.
360366
if tensor.place.is_cuda_pinned_place():
@@ -363,10 +369,27 @@ def all_gather(
363369
# TODO(HydrogenSulfate): As non-contiguous(strided) tensor is not supported in
364370
# dist.all_gather, manually convert given Tensor to contiguous below. Strided tensor
365371
# will be supported in future.
366-
dist.all_gather(result, tensor.contiguous())
372+
if not requires_grad:
373+
dist.all_gather(result, tensor.contiguous())
374+
if concat:
375+
if tensor.ndim == 0:
376+
warnings.warn(
377+
"given tensor is a 0-dim tensor, so we use `paddle.stack` to replace `paddle.concat`",
378+
category=UserWarning,
379+
stacklevel=2,
380+
)
381+
result = paddle.stack(result, axis)
382+
else:
383+
result = paddle.concat(result, axis)
384+
else:
385+
assert (
386+
tensor.ndim > 0
387+
), "`all_gather` is not supported for 0-dim tensor when requires_grad=True"
388+
assert concat is True, "`requires_grad=True` only support `concat=True`"
389+
result = AllGather.apply(
390+
tensor.contiguous(), dist.get_rank(), dist.get_world_size(), None
391+
)
367392

368-
if concat:
369-
return paddle.concat(result, axis)
370393
return result
371394

372395

0 commit comments

Comments (0)