Skip to content

Commit 1184252

Browse files
authored
Update slang-rhi and mlp-training-coopvec example (#9036)
- Update to latest slang-rhi with refactored coopvec APIs for added CUDA/OptiX compatibility
- Update `mlp-training-coopvec` example to the new API
1 parent ce376e1 commit 1184252

File tree

2 files changed

+32
-33
lines changed

2 files changed

+32
-33
lines changed

examples/mlp-training-coopvec/mlp-training-coopvec.cpp

Lines changed: 31 additions & 32 deletions
Original file line number | Diff line number | Diff line change
@@ -201,29 +201,35 @@ struct ExampleProgram : public TestBase
201201
// Copy weight gradients from training-optimal layout to row-major layout,
202202
// so we can read them in the `adjustParameters` kernel.
203203
{
204-
std::vector<rhi::ConvertCooperativeVectorMatrixDesc> matrixDescs;
204+
rhi::CooperativeVectorMatrixDesc srcDescs[kLayerCount];
205+
rhi::CooperativeVectorMatrixDesc dstDescs[kLayerCount];
205206
for (int i = 0; i < kLayerCount; i++)
206207
{
207-
rhi::ConvertCooperativeVectorMatrixDesc desc = {};
208-
desc.rowCount = kLayerSizes[i + 1];
209-
desc.colCount = kLayerSizes[i];
210-
desc.dstComponentType = rhi::CooperativeVectorComponentType::Float16;
211-
desc.dstSize = &layerAllocations[i].weightsSize;
212-
desc.dstData.deviceAddress = networkParamsBuffer->getDeviceAddress() +
213-
layerAllocations[i].weightsGradOffset;
214-
desc.dstLayout = rhi::CooperativeVectorMatrixLayout::RowMajor;
215-
desc.dstStride = getNetworkLayerWeightStride(i);
216-
desc.srcComponentType = rhi::CooperativeVectorComponentType::Float16;
217-
desc.srcSize = layerAllocations[i].weightsGradTrainingSize;
218-
desc.srcData.deviceAddress = networkParamsBuffer->getDeviceAddress() +
219-
layerAllocations[i].weightsGradTrainingOffset;
220-
desc.srcLayout = rhi::CooperativeVectorMatrixLayout::TrainingOptimal;
221-
matrixDescs.push_back(desc);
208+
rhi::CooperativeVectorMatrixDesc& srcDesc = srcDescs[i];
209+
srcDesc = {};
210+
srcDesc.rowCount = kLayerSizes[i + 1];
211+
srcDesc.colCount = kLayerSizes[i];
212+
srcDesc.componentType = rhi::CooperativeVectorComponentType::Float16;
213+
srcDesc.layout = rhi::CooperativeVectorMatrixLayout::TrainingOptimal;
214+
srcDesc.size = layerAllocations[i].weightsGradTrainingSize;
215+
srcDesc.offset = layerAllocations[i].weightsGradTrainingOffset;
216+
rhi::CooperativeVectorMatrixDesc& dstDesc = dstDescs[i];
217+
dstDesc = {};
218+
dstDesc.rowCount = kLayerSizes[i + 1];
219+
dstDesc.colCount = kLayerSizes[i];
220+
dstDesc.componentType = rhi::CooperativeVectorComponentType::Float16;
221+
dstDesc.layout = rhi::CooperativeVectorMatrixLayout::RowMajor;
222+
dstDesc.size = layerAllocations[i].weightsSize;
223+
dstDesc.offset = layerAllocations[i].weightsGradOffset;
224+
dstDesc.rowColumnStride = getNetworkLayerWeightStride(i);
222225
}
223226
auto encoder = queue->createCommandEncoder();
224227
encoder->convertCooperativeVectorMatrix(
225-
matrixDescs.data(),
226-
(uint32_t)matrixDescs.size());
228+
networkParamsBuffer,
229+
dstDescs,
230+
networkParamsBuffer,
231+
srcDescs,
232+
kLayerCount);
227233
ComPtr<rhi::ICommandBuffer> commandBuffer;
228234
encoder->finish(commandBuffer.writeRef());
229235
queue->submit(commandBuffer);
@@ -296,20 +302,13 @@ struct ExampleProgram : public TestBase
296302
for (int i = 0; i < kLayerCount; i++)
297303
{
298304
// Allocate space for gradients in training-optimal layout.
299-
rhi::ConvertCooperativeVectorMatrixDesc matrixDesc = {};
300-
matrixDesc.srcComponentType = rhi::CooperativeVectorComponentType::Float16;
301-
matrixDesc.srcSize = paramStorage[i].weightsSize;
302-
matrixDesc.srcData.hostAddress = nullptr;
303-
matrixDesc.srcLayout = rhi::CooperativeVectorMatrixLayout::RowMajor;
304-
matrixDesc.srcStride = getNetworkLayerWeightStride(i);
305-
matrixDesc.dstComponentType = rhi::CooperativeVectorComponentType::Float16;
306-
matrixDesc.dstSize = &paramStorage[i].weightsGradTrainingSize;
307-
matrixDesc.dstData.hostAddress = nullptr;
308-
matrixDesc.dstLayout = rhi::CooperativeVectorMatrixLayout::TrainingOptimal;
309-
matrixDesc.dstStride = 0;
310-
matrixDesc.rowCount = kLayerSizes[i + 1];
311-
matrixDesc.colCount = kLayerSizes[i];
312-
gDevice->convertCooperativeVectorMatrix(&matrixDesc, 1);
305+
gDevice->getCooperativeVectorMatrixSize(
306+
kLayerSizes[i + 1],
307+
kLayerSizes[i],
308+
rhi::CooperativeVectorComponentType::Float16,
309+
rhi::CooperativeVectorMatrixLayout::TrainingOptimal,
310+
0,
311+
&paramStorage[i].weightsGradTrainingSize);
313312
paramStorage[i].weightsGradTrainingOffset =
314313
allocRowMajorStorage(paramStorage[i].weightsGradTrainingSize);
315314
}

external/slang-rhi

Submodule slang-rhi updated 106 files

0 commit comments

Comments
 (0)