11 changes: 6 additions & 5 deletions python/paddle/distributed/auto_parallel/api.py
@@ -3954,6 +3954,7 @@ def __len__(self):
        return len(self._dataloader)

    def __iter__(self):
+        self.iter = None
        return self

    def _get_mesh_and_placement(self, index):
@@ -4007,7 +4008,9 @@ def _dtensors_from_list_input(
    ):
        dist_data = []
        for j in range(len(list_tensors)):
-            if dense_tensor_idx is not None and j in dense_tensor_idx:
+            if (
+                dense_tensor_idx is not None and j in dense_tensor_idx
+            ) or not isinstance(list_tensors[j], paddle.Tensor):
                dist_data.append(list_tensors[j])
            else:
                dist_data.append(
@@ -4095,9 +4098,7 @@ def _get_batch(self, batch_data):
                        batch_data[key], mesh, placements
                    )
                else:
-                    raise ValueError(
-                        f"Unsupported input_data type {type(input_data)}"
-                    )
+                    dist_batch_data[key] = input_data
            return dist_batch_data
        elif isinstance(batch_data, paddle.Tensor):
            mesh, placements = self._get_mesh_and_placement(0)
@@ -4112,7 +4113,7 @@ def __next__(self):
        return self._get_batch(batch_data)

    def __call__(self):
-        self.iter = self._dataloader.__iter__()
+        self.iter = None
        return self


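The api.py changes above make dist.shard_dataloader tolerate non-tensor inputs: in _dtensors_from_list_input, list entries that are not paddle.Tensor (for example plain floats or ints emitted by a custom collate_fn) are now appended unchanged, and in _get_batch the dict branch stores such values as-is instead of raising ValueError. The __iter__ and __call__ changes also reset self.iter to None, presumably so the underlying iterator is (re)created lazily in __next__. Below is a minimal sketch of the pass-through rule; dtensors_from_list_sketch and to_dist are illustrative names, not Paddle APIs, so treat it as a simplified model rather than the actual implementation.

import paddle


def dtensors_from_list_sketch(list_tensors, to_dist, dense_tensor_idx=None):
    # Mirror of the pass-through rule added above (simplified, not Paddle's API).
    out = []
    for j, item in enumerate(list_tensors):
        if (
            dense_tensor_idx is not None and j in dense_tensor_idx
        ) or not isinstance(item, paddle.Tensor):
            # Dense-flagged entries and plain Python values (e.g. a float or int
            # produced by a custom collate_fn) are kept as-is.
            out.append(item)
        else:
            # Only real paddle.Tensor entries are converted, via the caller-supplied
            # to_dist callable standing in for the dist-tensor conversion.
            out.append(to_dist(item))
    return out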
8 changes: 8 additions & 0 deletions test/auto_parallel/hybrid_strategy/CMakeLists.txt
@@ -80,6 +80,14 @@ if((WITH_GPU) AND (LINUX))
  set_tests_properties(test_semi_auto_parallel_multi_inputs
                       PROPERTIES TIMEOUT "120" LABELS "RUN_TYPE=HYBRID")
endif()
+if((WITH_GPU) AND (LINUX))
+  py_test_modules(
+    test_semi_auto_parallel_no_tensor_inputs MODULES
+    test_semi_auto_parallel_no_tensor_inputs ENVS
+    "http_proxy=;https_proxy=;PYTHONPATH=../..:${PADDLE_BINARY_DIR}/python")
+  set_tests_properties(test_semi_auto_parallel_no_tensor_inputs
+                       PROPERTIES TIMEOUT "50" LABELS "RUN_TYPE=HYBRID")
+endif()
if((WITH_GPU) AND (LINUX))
  py_test_modules(
    test_semi_auto_parallel_llama_model_vpp MODULES
194 changes: 194 additions & 0 deletions test/auto_parallel/hybrid_strategy/semi_auto_parallel_no_tensor_inputs.py
@@ -0,0 +1,194 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

import numpy as np

import paddle
import paddle.distributed as dist
from paddle.io import BatchSampler, DataLoader, Dataset

SEQ_LEN = 4
HIDDEN_SIZE = 8
global_mesh = dist.ProcessMesh(
    [[[0, 1], [2, 3]], [[4, 5], [6, 7]]], dim_names=['pp', 'dp', 'mp']
)
mesh0 = dist.ProcessMesh([[0, 1], [2, 3]], dim_names=['dp', 'mp'])
mesh1 = dist.ProcessMesh([[4, 5], [6, 7]], dim_names=['dp', 'mp'])


class MlpModel(paddle.nn.Layer):
    def __init__(self, variable_initial_values, run_single_process=False):
        super().__init__()
        self.w0 = self.create_parameter(
            shape=[HIDDEN_SIZE, HIDDEN_SIZE],
            default_initializer=paddle.nn.initializer.Assign(
                variable_initial_values[0]
            ),
        )
        self.w1 = self.create_parameter(
            shape=[HIDDEN_SIZE, HIDDEN_SIZE],
            default_initializer=paddle.nn.initializer.Assign(
                variable_initial_values[1]
            ),
        )
        if run_single_process is False:
            self.w0 = dist.shard_tensor(
                self.w0,
                mesh0,
                [dist.Replicate(), dist.Shard(1)],
            )
            self.w1 = dist.shard_tensor(
                self.w1,
                mesh1,
                [dist.Replicate(), dist.Shard(0)],
            )
        self.run_single_process = run_single_process

    def forward(self, input1, input2, extra_input1=None, extra_input2=None):
        # extra_input1 and extra_input2 are only used to test non-tensor inputs in shard_dataloader
        x = input1 + input2
        # x: [bs, seq_len, hidden]
        # forward on mesh0
        y = x @ self.w0
        # forward on mesh1
        if self.run_single_process is False:
            y = dist.reshard(y, mesh1, [dist.Shard(0), dist.Shard(2)])
        z = y @ self.w1
        return z


class RandomDataset(Dataset):
    def __init__(self, seq_len, hidden, num_samples=8):
        super().__init__()
        self.seq_len = seq_len
        self.hidden = hidden
        self.num_samples = num_samples
        self.inputs1 = [
            np.random.uniform(size=[self.seq_len, self.hidden]).astype(
                "float32"
            )
            for _ in range(num_samples)
        ]
        self.inputs2 = [
            np.random.uniform(size=[self.seq_len, self.hidden]).astype(
                "float32"
            )
            for _ in range(num_samples)
        ]
        self.labels = [
            np.array(index, dtype="float32") for index in range(num_samples)
        ]

    def __getitem__(self, index):
        return {
            "inputs": [self.inputs1[index], self.inputs2[index]],
            "label": self.labels[index],
        }

    def __len__(self):
        return self.num_samples


def create_dataloader(collate_fn=None):
    dataset = RandomDataset(SEQ_LEN, HIDDEN_SIZE)
    sampler = BatchSampler(
        dataset,
        batch_size=2,
    )
    dataloader = DataLoader(
        dataset,
        batch_sampler=sampler,
        collate_fn=collate_fn,
    )
    return dataloader


def get_variable_initial_value(var_num=2):
    res = []
    for i in range(var_num):
        res.append(
            paddle.uniform(
                shape=[HIDDEN_SIZE, HIDDEN_SIZE],
                dtype=paddle.float32,
                min=-0.0001,
                max=0.0001,
            )
        )
    return res


def loss_fn(logits, label):
    # logits: [bs, seq_len, hidden], label: [bs]
    loss = paddle.nn.MSELoss(reduction="sum")
    logits = paddle.sum(logits, axis=[1, 2])
    return loss(logits, label)


class TestSemiAutoParallelMultiInputs:
    def __init__(self):
        self._backend = os.getenv("backend")
        self._seed = eval(os.getenv("seed"))
        self._run_static = eval(os.getenv("run_static"))
        paddle.seed(self._seed)
        np.random.seed(self._seed)
        paddle.set_device(self._backend)
        self.dataloader = create_dataloader()
        self.variable_initial_values = get_variable_initial_value()

    def test_non_tensor_input(self):
        model = MlpModel(variable_initial_values=self.variable_initial_values)
        opt = paddle.optimizer.AdamW(
            learning_rate=0.001, parameters=model.parameters()
        )

        def custom_collate_fn(batch):
            collated_batch = {
                "inputs": [
                    paddle.to_tensor([item["inputs"][0] for item in batch]),
                    paddle.to_tensor([item["inputs"][1] for item in batch]),
                    12.0,
                ],
                "extra_input": 12,
                "label": paddle.to_tensor([item["label"] for item in batch]),
            }
            return collated_batch

        self.dataloader = create_dataloader(custom_collate_fn)

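        # "inputs" deliberately mixes two tensors with a plain float, and
        # "extra_input" is a plain int; shard_dataloader is expected to pass
        # these non-tensor values through unchanged.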
        dist_dataloader = dist.shard_dataloader(
            dataloader=self.dataloader,
            meshes=[mesh0, mesh0, mesh1],
            shard_dims="dp",
            input_keys=["inputs", "extra_input", "label"],
        )

        dist_opt = dist.shard_optimizer(opt)
        for step, data in enumerate(dist_dataloader()):
            input1, input2, extra_input1 = data["inputs"]
            extra_input2 = data["extra_input"]
            logits = model(input1, input2, extra_input1, extra_input2)
            label = data["label"]
            loss = loss_fn(logits, label)
            loss.backward()
            dist_opt.step()
            dist_opt.clear_grad()

    def run_test_case(self):
        self.test_non_tensor_input()


if __name__ == '__main__':
    TestSemiAutoParallelMultiInputs().run_test_case()
49 changes: 49 additions & 0 deletions test/auto_parallel/hybrid_strategy/test_semi_auto_parallel_no_tensor_inputs.py
@@ -0,0 +1,49 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import unittest

import collective.test_communication_api_base as test_base

os.environ['FLAGS_enable_pir_api'] = '1'


class TestSemiAutoParallelMultiInputs(test_base.CommunicationTestDistBase):
    def setUp(self):
        super().setUp(
            num_of_devices=8,
            timeout=120,
            nnode=1,
        )
        self._default_envs = {
            "dtype": "float32",
            "seed": "1024",
        }
        self._changeable_envs = {"backend": ["gpu"]}

    def test_dynamic(self):
        self._default_envs.update({"run_static": "0"})
        envs_list = test_base.gen_product_envs_list(
            self._default_envs, self._changeable_envs
        )
        for envs in envs_list:
            self.run_test_case(
                "semi_auto_parallel_no_tensor_inputs.py",
                user_defined_envs=envs,
            )


if __name__ == "__main__":
    unittest.main()
1 change: 1 addition & 0 deletions test/auto_parallel/hybrid_strategy/testslist.csv
@@ -9,6 +9,7 @@ test_semi_auto_parallel_hybrid_sharding_strategy,LINUX,GPU,120,HYBRID,test_runne
test_global_mesh_reshard,LINUX,GPU,120,HYBRID,test_runner.py,,,http_proxy=;https_proxy=;PYTHONPATH=../..,
test_semi_auto_parallel_global_input,LINUX,GPU,120,HYBRID,test_runner.py,,,http_proxy=;https_proxy=;PYTHONPATH=../..,
test_semi_auto_parallel_multi_inputs,LINUX,GPU,120,HYBRID,test_runner.py,,,http_proxy=;https_proxy=;PYTHONPATH=../..,
+test_semi_auto_parallel_no_tensor_inputs,LINUX,GPU,50,HYBRID,test_runner.py,,,http_proxy=;https_proxy=;PYTHONPATH=../..,
test_semi_auto_parallel_llama_model_vpp,LINUX,GPU,180,HYBRID,test_runner.py,,,http_proxy=;https_proxy=;PYTHONPATH=../..,
test_semi_auto_parallel_llama_model_pir,LINUX,GPU,180,HYBRID,test_runner.py,,,http_proxy=;https_proxy=;PYTHONPATH=../..;FLAGS_enable_pir_api=1,
test_pir_reshard_nd_mesh_func,LINUX,GPU,60,HYBRID,test_runner.py,,,http_proxy=;https_proxy=;PYTHONPATH=../..,
2 changes: 1 addition & 1 deletion third_party/openblas