@@ -201,29 +201,35 @@ struct ExampleProgram : public TestBase
201201 // Copy weight gradients from training-optimal layout to row-major layout,
202202 // so we can read them in the `adjustParameters` kernel.
203203 {
204- std::vector<rhi::ConvertCooperativeVectorMatrixDesc> matrixDescs;
204+ rhi::CooperativeVectorMatrixDesc srcDescs[kLayerCount ];
205+ rhi::CooperativeVectorMatrixDesc dstDescs[kLayerCount ];
205206 for (int i = 0 ; i < kLayerCount ; i++)
206207 {
207- rhi::ConvertCooperativeVectorMatrixDesc desc = {};
208- desc.rowCount = kLayerSizes [i + 1 ];
209- desc.colCount = kLayerSizes [i];
210- desc.dstComponentType = rhi::CooperativeVectorComponentType::Float16;
211- desc.dstSize = &layerAllocations[i].weightsSize ;
212- desc.dstData .deviceAddress = networkParamsBuffer->getDeviceAddress () +
213- layerAllocations[i].weightsGradOffset ;
214- desc.dstLayout = rhi::CooperativeVectorMatrixLayout::RowMajor;
215- desc.dstStride = getNetworkLayerWeightStride (i);
216- desc.srcComponentType = rhi::CooperativeVectorComponentType::Float16;
217- desc.srcSize = layerAllocations[i].weightsGradTrainingSize ;
218- desc.srcData .deviceAddress = networkParamsBuffer->getDeviceAddress () +
219- layerAllocations[i].weightsGradTrainingOffset ;
220- desc.srcLayout = rhi::CooperativeVectorMatrixLayout::TrainingOptimal;
221- matrixDescs.push_back (desc);
208+ rhi::CooperativeVectorMatrixDesc& srcDesc = srcDescs[i];
209+ srcDesc = {};
210+ srcDesc.rowCount = kLayerSizes [i + 1 ];
211+ srcDesc.colCount = kLayerSizes [i];
212+ srcDesc.componentType = rhi::CooperativeVectorComponentType::Float16;
213+ srcDesc.layout = rhi::CooperativeVectorMatrixLayout::TrainingOptimal;
214+ srcDesc.size = layerAllocations[i].weightsGradTrainingSize ;
215+ srcDesc.offset = layerAllocations[i].weightsGradTrainingOffset ;
216+ rhi::CooperativeVectorMatrixDesc& dstDesc = dstDescs[i];
217+ dstDesc = {};
218+ dstDesc.rowCount = kLayerSizes [i + 1 ];
219+ dstDesc.colCount = kLayerSizes [i];
220+ dstDesc.componentType = rhi::CooperativeVectorComponentType::Float16;
221+ dstDesc.layout = rhi::CooperativeVectorMatrixLayout::RowMajor;
222+ dstDesc.size = layerAllocations[i].weightsSize ;
223+ dstDesc.offset = layerAllocations[i].weightsGradOffset ;
224+ dstDesc.rowColumnStride = getNetworkLayerWeightStride (i);
222225 }
223226 auto encoder = queue->createCommandEncoder ();
224227 encoder->convertCooperativeVectorMatrix (
225- matrixDescs.data (),
226- (uint32_t )matrixDescs.size ());
228+ networkParamsBuffer,
229+ dstDescs,
230+ networkParamsBuffer,
231+ srcDescs,
232+ kLayerCount );
227233 ComPtr<rhi::ICommandBuffer> commandBuffer;
228234 encoder->finish (commandBuffer.writeRef ());
229235 queue->submit (commandBuffer);
@@ -296,20 +302,13 @@ struct ExampleProgram : public TestBase
296302 for (int i = 0 ; i < kLayerCount ; i++)
297303 {
298304 // Allocate space for gradients in training-optimal layout.
299- rhi::ConvertCooperativeVectorMatrixDesc matrixDesc = {};
300- matrixDesc.srcComponentType = rhi::CooperativeVectorComponentType::Float16;
301- matrixDesc.srcSize = paramStorage[i].weightsSize ;
302- matrixDesc.srcData .hostAddress = nullptr ;
303- matrixDesc.srcLayout = rhi::CooperativeVectorMatrixLayout::RowMajor;
304- matrixDesc.srcStride = getNetworkLayerWeightStride (i);
305- matrixDesc.dstComponentType = rhi::CooperativeVectorComponentType::Float16;
306- matrixDesc.dstSize = &paramStorage[i].weightsGradTrainingSize ;
307- matrixDesc.dstData .hostAddress = nullptr ;
308- matrixDesc.dstLayout = rhi::CooperativeVectorMatrixLayout::TrainingOptimal;
309- matrixDesc.dstStride = 0 ;
310- matrixDesc.rowCount = kLayerSizes [i + 1 ];
311- matrixDesc.colCount = kLayerSizes [i];
312- gDevice ->convertCooperativeVectorMatrix (&matrixDesc, 1 );
305+ gDevice ->getCooperativeVectorMatrixSize (
306+ kLayerSizes [i + 1 ],
307+ kLayerSizes [i],
308+ rhi::CooperativeVectorComponentType::Float16,
309+ rhi::CooperativeVectorMatrixLayout::TrainingOptimal,
310+ 0 ,
311+ &paramStorage[i].weightsGradTrainingSize );
313312 paramStorage[i].weightsGradTrainingOffset =
314313 allocRowMajorStorage (paramStorage[i].weightsGradTrainingSize );
315314 }
0 commit comments