@@ -369,89 +369,6 @@ struct ConvertLayoutOpConversion
     }
   }
 
-  LogicalResult
-  lowerDistToDistWithDistSmem(triton::gpu::ConvertLayoutOp op,
-                              OpAdaptor adaptor,
-                              ConversionPatternRewriter &rewriter) const {
-    auto loc = op.getLoc();
-    auto typeConverter = getTypeConverter();
-    auto srcTy = op.getSrc().getType();
-    auto dstTy = op.getType();
-    auto srcLayout = srcTy.getEncoding();
-    auto dstLayout = dstTy.getEncoding();
-    auto srcShapePerCTA = getShapePerCTA(srcTy);
-    auto srcCTAsPerCGA = triton::gpu::getCTAsPerCGA(srcLayout);
-    auto srcCTAOrder = triton::gpu::getCTAOrder(srcLayout);
-    unsigned rank = srcShapePerCTA.size();
-
-    auto llvmElemTy = getTypeConverter()->convertType(dstTy.getElementType());
-    auto elemPtrTy = ptr_ty(rewriter.getContext(), 3);
-
-    Value smemBase =
-        LLVM::intel::getSharedMemoryBase(loc, rewriter, op.getOperation());
-    smemBase = bitcast(smemBase, elemPtrTy);
-    auto smemShape = convertType<unsigned, int64_t>(srcShapePerCTA);
-
-    // Store to local shared memory
-    {
-      auto inVals = unpackLLElements(loc, adaptor.getSrc(), rewriter);
-      auto inIndices = ::intel::emitIndices(loc, rewriter, srcLayout, srcTy,
-                                            /*withCTAOffset*/ false);
-
-      assert(inIndices.size() == inVals.size() &&
-             "Unexpected number of indices emitted");
-
-      for (unsigned i = 0; i < inIndices.size(); ++i) {
-        Value offset = linearize(rewriter, loc, inIndices[i], smemShape);
-        Value ptr = gep(elemPtrTy, llvmElemTy, smemBase, offset);
-        store(inVals[i], ptr);
-      }
-    }
-
-    // Cluster barrier
-    rewriter.create<triton::nvidia_gpu::ClusterArriveOp>(loc, false);
-    rewriter.create<triton::nvidia_gpu::ClusterWaitOp>(loc);
-
-    // Load from remote shared memory
-    {
-      SmallVector<Value> srcShapePerCTACache;
-      for (unsigned i = 0; i < rank; ++i)
-        srcShapePerCTACache.push_back(i32_val(srcShapePerCTA[i]));
-
-      SmallVector<Value> outVals;
-      auto outIndices = ::intel::emitIndices(loc, rewriter, dstLayout, dstTy,
-                                             /*withCTAOffset*/ true);
-
-      for (unsigned i = 0; i < outIndices.size(); ++i) {
-        auto coord = outIndices[i];
-        assert(coord.size() == rank && "Unexpected rank of index emitted");
-
-        SmallVector<Value> multiDimCTAId, localCoord;
-        for (unsigned d = 0; d < rank; ++d) {
-          multiDimCTAId.push_back(udiv(coord[d], srcShapePerCTACache[d]));
-          localCoord.push_back(urem(coord[d], srcShapePerCTACache[d]));
-        }
-
-        Value remoteCTAId =
-            linearize(rewriter, loc, multiDimCTAId, srcCTAsPerCGA, srcCTAOrder);
-        Value localOffset = linearize(rewriter, loc, localCoord, smemShape);
-
-        Value ptr = gep(elemPtrTy, llvmElemTy, smemBase, localOffset);
-        outVals.push_back(load_dsmem(ptr, remoteCTAId, llvmElemTy));
-      }
-
-      Value result =
-          packLLElements(loc, getTypeConverter(), outVals, rewriter, dstTy);
-      rewriter.replaceOp(op, result);
-    }
-
-    // Cluster barrier
-    rewriter.create<triton::nvidia_gpu::ClusterArriveOp>(loc, false);
-    rewriter.create<triton::nvidia_gpu::ClusterWaitOp>(loc);
-
-    return success();
-  }
-
   // blocked/dpas -> blocked/dpas.
   // Data padding in shared memory to avoid bank conflict.
   LogicalResult
@@ -465,8 +382,6 @@ struct ConvertLayoutOpConversion
     Attribute srcLayout = srcTy.getEncoding();
     Attribute dstLayout = dstTy.getEncoding();
 
-    if (shouldUseDistSmem(srcLayout, dstLayout))
-      return lowerDistToDistWithDistSmem(op, adaptor, rewriter);
     Value smemBase =
         LLVM::intel::getSharedMemoryBase(loc, rewriter, op.getOperation());
    auto elemPtrTy = ptr_ty(rewriter.getContext(), 3);
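The core of the removed lowerDistToDistWithDistSmem path is the per-element index math: every destination coordinate (emitted with the CTA offset included) is split into a remote CTA id and a local shared-memory offset, and each part is linearized before the load_dsmem call. Below is a minimal standalone sketch of that decomposition, using hypothetical plain-C++ helpers in place of the MLIR Value-based udiv/urem/linearize builders and assuming row-major order (the removed code additionally honored srcCTAOrder when linearizing the CTA id).

// Illustrative sketch only; not part of this commit.
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <utility>
#include <vector>

// Row-major linearization of a multi-dimensional coordinate.
static int64_t linearizeIndex(const std::vector<int64_t> &coord,
                              const std::vector<int64_t> &shape) {
  assert(coord.size() == shape.size());
  int64_t offset = 0;
  for (size_t d = 0; d < coord.size(); ++d)
    offset = offset * shape[d] + coord[d];
  return offset;
}

// Split a per-CGA destination coordinate into (remote CTA id, local shared
// memory offset), mirroring the udiv/urem loop in the removed function.
static std::pair<int64_t, int64_t>
splitIntoCTAAndLocal(const std::vector<int64_t> &coord,
                     const std::vector<int64_t> &shapePerCTA,
                     const std::vector<int64_t> &ctasPerCGA) {
  std::vector<int64_t> ctaId(coord.size()), local(coord.size());
  for (size_t d = 0; d < coord.size(); ++d) {
    ctaId[d] = coord[d] / shapePerCTA[d]; // udiv(coord[d], srcShapePerCTA[d])
    local[d] = coord[d] % shapePerCTA[d]; // urem(coord[d], srcShapePerCTA[d])
  }
  return {linearizeIndex(ctaId, ctasPerCGA),
          linearizeIndex(local, shapePerCTA)};
}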