Skip to content
114 changes: 63 additions & 51 deletions onnxruntime/core/providers/webgpu/tensor/concat.cc
Original file line number Diff line number Diff line change
Expand Up @@ -38,19 +38,19 @@
WEBGPU_CONCAT_VERSIONED_KERNEL(11, 12)
WEBGPU_CONCAT_KERNEL(13)

void AppendCalCulateInputIndexFunction(std::ostream& os, size_t input_count) {
os << "fn calculate_input_index(index: u32) -> u32 {\n"
<< " for (var i = 0u; i < " << input_count << "; i = i + 1u) {\n"
<< " if (index < " << GetElementAt("uniforms.size_in_concat_axis", "i", input_count) << ") {\n"
<< " return i;\n"
void AppendCalculateInputIndexFunction(std::ostream& os, size_t input_count) {
os << "fn calculate_input_index(global_idx: u32) -> u32 {\n"
<< " for (var i = 1u; i < " << input_count << "; i = i + 1u) {\n"
<< " if (global_idx < " << GetElementAt("uniforms.offsets", "i", input_count) << ") {\n"
<< " return i - 1;\n"
<< " }\n"
<< " }\n"
<< " return " << input_count << ";\n"
<< " return " << input_count - 1 << ";\n"
<< "}\n";
}

void AppendAssignOutputDataFunction(std::ostream& os, gsl::span<const ShaderVariableHelper*> inputs, const ShaderVariableHelper& output) {
os << "fn assign_output_data(global_idx: u32, input_index: u32, indices: output_indices_t) {\n";
void AppendAssignOutputDataFunction(std::ostream& os, gsl::span<const ShaderVariableHelper*> inputs, const ShaderVariableHelper& output, size_t axis, size_t input_count) {
os << "fn assign_output_data(global_idx: u32, input_index: u32) {\n";
for (size_t i = 0; i < inputs.size(); ++i) {
if (i == 0) {
os << " if (input_index == 0u) {\n";
Expand All @@ -59,7 +59,12 @@
} else {
os << " } else if (input_index == " << i << "u) {\n";
}
os << " " << output.SetByOffset("global_idx", inputs[i]->GetByIndices("indices")) << ";\n";
std::string offset = GetElementAt("uniforms.offsets", "input_index", input_count);
std::string concat_axis_offset = GetElementAt("uniforms.sizes_in_concat_axis", std::to_string(i), input_count);
std::string output_indices_axis = "output_indices" + (inputs[i]->Rank() > 1 ? "[" + std::to_string(axis) + "]" : "");

Check warning on line 64 in onnxruntime/core/providers/webgpu/tensor/concat.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Add #include <string> for string [build/include_what_you_use] [4] Raw Output: onnxruntime/core/providers/webgpu/tensor/concat.cc:64: Add #include <string> for string [build/include_what_you_use] [4]
os << " var output_indices = " << inputs[i]->OffsetToIndices("global_idx - " + offset) << ";\n"
<< " " << output_indices_axis << " += " << concat_axis_offset << ";\n"
<< " " << output.SetByIndices("output_indices", inputs[i]->GetByOffset("global_idx - " + offset)) << "\n";
}
os << " }\n"
"}\n";
Expand All @@ -74,27 +79,21 @@
}
const auto& output = shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias);

// add implementation of fn calculate_input_index
AppendCalCulateInputIndexFunction(shader.AdditionalImplementation(), input_count);
// add implementation of fn assign_output_data
AppendAssignOutputDataFunction(shader.AdditionalImplementation(), inputs, output);
const std::string size_in_concat_axis = GetElementAt("uniforms.size_in_concat_axis", "input_index - 1", input_count);
AppendCalculateInputIndexFunction(shader.AdditionalImplementation(), input_count);
AppendAssignOutputDataFunction(shader.AdditionalImplementation(), inputs, output, axis_, input_count);

shader.MainFunctionBody() << shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.output_size")
<< " var indices = " << output.OffsetToIndices("global_idx") << ";\n"
<< " let indices_axis = " << output.IndicesGet("indices", axis_) << ";\n"
<< " let input_index = calculate_input_index(indices_axis);\n"
<< " if (input_index != 0u) {\n"
<< " " << output.IndicesSet("indices", axis_, "indices_axis - " + size_in_concat_axis) << ";\n"
<< " }\n"
" assign_output_data(global_idx, input_index, indices);\n";
<< "let input_index = calculate_input_index(global_idx);\n"
<< "assign_output_data(global_idx, input_index);\n";

return Status::OK();
}

Status Concat::ComputeInternal(ComputeContext& context) const {
int input_count = context.InputCount();
uint32_t input_count = context.InputCount();
InlinedTensorsVector input_tensors;
input_tensors.reserve(input_count);
for (int i = 0; i < input_count; ++i) {
for (uint32_t i = 0; i < input_count; ++i) {
input_tensors.push_back(context.Input<Tensor>(i));
}

Expand All @@ -104,42 +103,55 @@
return Status::OK();
}

uint32_t output_size = onnxruntime::narrow<int32_t>(prepare.output_tensor->Shape().Size());
uint32_t axis = static_cast<uint32_t>(prepare.axis);
uint32_t max_inputs_per_concat = context.DeviceLimits().maxStorageBuffersPerShaderStage - 1;

uint32_t input_index = 0;
uint32_t cumulative_size_in_concat_axis = 0;

while (input_index < input_count) {
ConcatProgram program{axis};
uint32_t num_inputs_this_concat = std::min(max_inputs_per_concat, input_count - input_index);

Check warning on line 114 in onnxruntime/core/providers/webgpu/tensor/concat.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Add #include <algorithm> for min [build/include_what_you_use] [4] Raw Output: onnxruntime/core/providers/webgpu/tensor/concat.cc:114: Add #include <algorithm> for min [build/include_what_you_use] [4]

std::vector<uint32_t> offsets;
offsets.reserve(num_inputs_this_concat + 1);
offsets.push_back(0);

size_t axis = static_cast<size_t>(prepare.axis);
ConcatProgram program{axis};
std::vector<uint32_t> sizes_in_concat_axis;

Check warning on line 120 in onnxruntime/core/providers/webgpu/tensor/concat.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Add #include <vector> for vector<> [build/include_what_you_use] [4] Raw Output: onnxruntime/core/providers/webgpu/tensor/concat.cc:120: Add #include <vector> for vector<> [build/include_what_you_use] [4]
sizes_in_concat_axis.reserve(num_inputs_this_concat + 1);
sizes_in_concat_axis.push_back(cumulative_size_in_concat_axis);

std::vector<uint32_t> sizes_in_concat_axis;
sizes_in_concat_axis.reserve(input_count);
uint32_t sum = 0;
for (int i = 0; i < input_count; ++i) {
const auto& input = prepare.inputs[i];
if (input.tensor->Shape().Size() == 0) {
continue;
uint32_t output_size = 0;
for (uint32_t i = 0; i < num_inputs_this_concat; i++) {
auto& input = prepare.inputs[input_index + i];
if (input.tensor->Shape().Size() == 0) {
continue;
}
program.AddInput({input.tensor, ProgramTensorMetadataDependency::TypeAndRank});

uint32_t size = onnxruntime::narrow<int32_t>(input.tensor->Shape().Size());
uint32_t axis_size = static_cast<uint32_t>(input.tensor->Shape()[axis]);

output_size += size;
offsets.push_back(output_size);
cumulative_size_in_concat_axis += axis_size;
sizes_in_concat_axis.push_back(cumulative_size_in_concat_axis);
}
program.AddInput({input.tensor, ProgramTensorMetadataDependency::TypeAndRank});

auto axis_size = input.tensor->Shape()[axis];
sum += static_cast<uint32_t>(axis_size);
sizes_in_concat_axis.push_back(sum);
}
offsets.pop_back();
sizes_in_concat_axis.pop_back();

size_t non_empty_input_count = sizes_in_concat_axis.size();
program.CacheHint(absl::StrJoin(std::make_tuple(num_inputs_this_concat, prepare.axis), ","))
.AddOutputs({prepare.output_tensor})
.SetDispatchGroupSize((output_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE)
.AddUniformVariables({gsl::span<const uint32_t>(offsets.data(), offsets.size()), gsl::span<const uint32_t>(sizes_in_concat_axis.data(), sizes_in_concat_axis.size()), output_size});
ORT_RETURN_IF_ERROR(context.RunProgram(program));

if (non_empty_input_count + 1 > context.DeviceLimits().maxStorageBuffersPerShaderStage) {
// TODO: support when input_count + 1 > maxStorageBuffersPerShaderStage, by raising the limit or run the program in multiple passes.
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "The number of storage buffer (input=",
input_count, ", output=1) exceeds the limit (",
context.DeviceLimits().maxStorageBuffersPerShaderStage, ") of the device.");
input_index += num_inputs_this_concat;
}

program.CacheHint(absl::StrJoin(std::make_tuple(non_empty_input_count, prepare.axis), ","))
.AddOutputs({prepare.output_tensor})
.SetDispatchGroupSize((prepare.output_num_elements + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE)
.AddUniformVariables({gsl::span<const uint32_t>(sizes_in_concat_axis.data(), sizes_in_concat_axis.size()),
output_size});
return context.RunProgram(program);
return Status::OK();
}

} // namespace webgpu
} // namespace onnxruntime
} // namespace onnxruntime
5 changes: 3 additions & 2 deletions onnxruntime/core/providers/webgpu/tensor/concat.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@ class ConcatProgram final : public Program<ConcatProgram> {

Status GenerateShaderCode(ShaderHelper& sh) const override;

WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"size_in_concat_axis", ProgramUniformVariableDataType::Uint32},
WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"offsets", ProgramUniformVariableDataType::Uint32},
{"sizes_in_concat_axis", ProgramUniformVariableDataType::Uint32},
{"output_size", ProgramUniformVariableDataType::Uint32});

private:
Expand All @@ -33,4 +34,4 @@ class Concat final : public WebGpuKernel, public ConcatBase {
};

} // namespace webgpu
} // namespace onnxruntime
} // namespace onnxruntime
105 changes: 105 additions & 0 deletions onnxruntime/test/providers/cpu/tensor/concat_op_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -434,5 +434,110 @@ TEST(ConcatOpTest, Concat4D_2) {
test.Run();
}

#ifdef USE_WEBGPU
TEST(ConcatOpTest, Concat1D_int32_4inputs) {
OpTester test("Concat");
test.AddAttribute("axis", int64_t{0});

test.AddInput<int32_t>("input1", {1}, {1});
test.AddInput<int32_t>("input2", {2}, {2, 3});
test.AddInput<int32_t>("input3", {4}, {4, 5, 6, 7});
test.AddInput<int32_t>("input4", {2}, {8, 9});
test.AddOutput<int32_t>("concat_result", {9}, {1, 2, 3, 4, 5, 6, 7, 8, 9});
test.Run();
}

TEST(ConcatOpTest, Concat1D_exceed_maxStorageBuffersPerShaderStage) {
// maxStorageBuffersPerShaderStage==8
OpTester test("Concat");
test.AddAttribute("axis", int64_t{0});

test.AddInput<int32_t>("input1", {1}, {1});
test.AddInput<int32_t>("input2", {1}, {2});
test.AddInput<int32_t>("input3", {1}, {3});
test.AddInput<int32_t>("input4", {1}, {4});
test.AddInput<int32_t>("input5", {1}, {5});
test.AddInput<int32_t>("input6", {1}, {6});
test.AddInput<int32_t>("input7", {1}, {7});
test.AddInput<int32_t>("input8", {1}, {8});
test.AddInput<int32_t>("input9", {1}, {9});
test.AddOutput<int32_t>("concat_result", {9}, {1, 2, 3, 4, 5, 6, 7, 8, 9});
test.Run();
}

TEST(ConcatOpTest, Concat2D_exceed_maxStorageBuffersPerShaderStage_axis0) {
// maxStorageBuffersPerShaderStage==8
OpTester test("Concat");
test.AddAttribute("axis", int64_t{0});

test.AddInput<int32_t>("input1", {1, 2}, {1, 2});
test.AddInput<int32_t>("input2", {1, 2}, {3, 4});
test.AddInput<int32_t>("input3", {1, 2}, {5, 6});
test.AddInput<int32_t>("input4", {1, 2}, {7, 8});
test.AddInput<int32_t>("input5", {1, 2}, {9, 10});
test.AddInput<int32_t>("input6", {1, 2}, {11, 12});
test.AddInput<int32_t>("input7", {1, 2}, {13, 14});
test.AddInput<int32_t>("input8", {1, 2}, {15, 16});
test.AddInput<int32_t>("input9", {1, 2}, {17, 18});
test.AddOutput<int32_t>("concat_result", {9, 2}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18});
test.Run();
}

TEST(ConcatOpTest, Concat2D_exceed_maxStorageBuffersPerShaderStage_axis1) {
// maxStorageBuffersPerShaderStage==8
OpTester test("Concat");
test.AddAttribute("axis", int64_t{1});

test.AddInput<int32_t>("input1", {1, 2}, {1, 2});
test.AddInput<int32_t>("input2", {1, 2}, {3, 4});
test.AddInput<int32_t>("input3", {1, 2}, {5, 6});
test.AddInput<int32_t>("input4", {1, 2}, {7, 8});
test.AddInput<int32_t>("input5", {1, 2}, {9, 10});
test.AddInput<int32_t>("input6", {1, 2}, {11, 12});
test.AddInput<int32_t>("input7", {1, 2}, {13, 14});
test.AddInput<int32_t>("input8", {1, 2}, {15, 16});
test.AddInput<int32_t>("input9", {1, 2}, {17, 18});
test.AddOutput<int32_t>("concat_result", {1, 18}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18});
test.Run();
}

TEST(ConcatOpTest, Concat3D_exceed_maxStorageBuffersPerShaderStage) {
// maxStorageBuffersPerShaderStage==8
OpTester test("Concat");
test.AddAttribute("axis", int64_t{1});

test.AddInput<int32_t>("input1", {2, 1, 1}, {1, 2});
test.AddInput<int32_t>("input2", {2, 1, 1}, {3, 4});
test.AddInput<int32_t>("input3", {2, 1, 1}, {5, 6});
test.AddInput<int32_t>("input4", {2, 1, 1}, {7, 8});
test.AddInput<int32_t>("input5", {2, 1, 1}, {9, 10});
test.AddInput<int32_t>("input6", {2, 1, 1}, {11, 12});
test.AddInput<int32_t>("input7", {2, 1, 1}, {13, 14});
test.AddInput<int32_t>("input8", {2, 1, 1}, {15, 16});
test.AddInput<int32_t>("input9", {2, 1, 1}, {17, 18});
test.AddOutput<int32_t>("concat_result", {2, 9, 1}, {// batch 0
1, 3, 5, 7, 9, 11, 13, 15, 17,
// batch 1
2, 4, 6, 8, 10, 12, 14, 16, 18});
test.Run();
}

TEST(ConcatOpTest, Concat3D_exceed_maxStorageBuffersPerShaderStage_mixed_sizes) {
// maxStorageBuffersPerShaderStage==8
OpTester test("Concat");
test.AddAttribute("axis", int64_t{1});

test.AddInput<int32_t>("input1", {2, 1, 1}, {1, 2});
test.AddInput<int32_t>("input2", {2, 3, 1}, {3, 4, 5, 6, 7, 8});
test.AddInput<int32_t>("input3", {2, 2, 1}, {9, 10, 11, 12});
test.AddInput<int32_t>("input4", {2, 1, 1}, {13, 14});
test.AddOutput<int32_t>("concat_result", {2, 7, 1}, {// batch 0
1, 3, 4, 5, 9, 10, 13,
// batch 1
2, 6, 7, 8, 11, 12, 14});
test.Run();
}
#endif // USE_WEBGPU

} // namespace test
} // namespace onnxruntime
Loading