@@ -370,7 +370,7 @@ ParseResult parseWithEnumHandling(OpAsmParser &parser, OperationState &result) {
370370 result.operands )))
371371 return failure ();
372372
373- result.addTypes (fnTy.getResult ( 0 ));
373+ result.addTypes (fnTy.getResults ( ));
374374 result.addAttributes (attrs);
375375
376376 return success ();
@@ -532,6 +532,24 @@ void MatmulTBlockScaledOp::print(OpAsmPrinter &parser) {
532532 printWithEnumHandling (parser, *this );
533533}
534534
535+ ParseResult CastFromBlockScaledOp::parse (OpAsmParser &parser,
536+ OperationState &result) {
537+ return parseWithEnumHandling<tosa::BlockSize>(parser, result);
538+ }
539+
540+ void CastFromBlockScaledOp::print (OpAsmPrinter &parser) {
541+ printWithEnumHandling (parser, *this );
542+ }
543+
544+ ParseResult CastToBlockScaledOp::parse (OpAsmParser &parser,
545+ OperationState &result) {
546+ return parseWithEnumHandling<tosa::BlockSize>(parser, result);
547+ }
548+
549+ void CastToBlockScaledOp::print (OpAsmPrinter &parser) {
550+ printWithEnumHandling (parser, *this );
551+ }
552+
535553// ===----------------------------------------------------------------------===//
536554// Tosa utilities.
537555// ===----------------------------------------------------------------------===//
@@ -3944,6 +3962,145 @@ LogicalResult RescaleOp::inferReturnTypeComponents(
39443962 return success ();
39453963}
39463964
3965+ LogicalResult CastFromBlockScaledOp::inferReturnTypeComponents (
3966+ MLIRContext *context, ::std::optional<Location> location,
3967+ CastFromBlockScaledOp::Adaptor adaptor,
3968+ SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3969+ const ShapeAdaptor inputShape (adaptor.getInputData ().getType ());
3970+ inferredReturnShapes.push_back (ShapedTypeComponents (inputShape));
3971+ return success ();
3972+ }
3973+
3974+ LogicalResult CastFromBlockScaledOp::verify () {
3975+ const Type inputDataType = getInputData ().getType ();
3976+ const Type outputDataType = getResult ().getType ();
3977+ if (failed (verifyCompatibleShape (inputDataType, outputDataType)))
3978+ return emitOpError () << " require compatible shapes for input_data ("
3979+ << inputDataType << " ) and "
3980+ << " output_data (" << outputDataType << " )" ;
3981+
3982+ const ShapeAdaptor inputDataShape = ShapeAdaptor (inputDataType);
3983+
3984+ if (inputDataShape.hasRank ()) {
3985+ const unsigned int blockSize =
3986+ BlockSizeAttr::getBlockSizeValue (getBlockSize ());
3987+ const int64_t inputDataLastDim =
3988+ inputDataShape.getDimSize (inputDataShape.getRank () - 1 );
3989+ if (inputDataLastDim % blockSize != 0 )
3990+ return emitOpError () << " expect last dimension of input_data ("
3991+ << inputDataLastDim
3992+ << " ) to be divisible by block_size (" << blockSize
3993+ << " )" ;
3994+
3995+ const Type inputScaleType = getInputScale ().getType ();
3996+ const ShapeAdaptor inputScaleShape = ShapeAdaptor (inputScaleType);
3997+
3998+ if (inputScaleShape.hasRank ()) {
3999+ SmallVector<int64_t > inputDataDims, inputScaleDims;
4000+ inputDataShape.getDims (inputDataDims);
4001+ inputScaleShape.getDims (inputScaleDims);
4002+
4003+ if (inputDataDims.size () != inputScaleDims.size () ||
4004+ failed (verifyCompatibleShape (
4005+ ArrayRef<int64_t >(inputDataDims).drop_back (1 ),
4006+ ArrayRef<int64_t >(inputScaleDims).drop_back (1 ))))
4007+ return emitOpError () << " require compatible shapes for input_data ("
4008+ << inputDataType << " ) and "
4009+ << " input_scale (" << inputScaleType
4010+ << " ) except for the last dimension" ;
4011+
4012+ const SmallVector<int64_t , 2 > dimsToCheck{inputDataLastDim / blockSize,
4013+ inputScaleDims.back ()};
4014+ if (ShapedType::isStatic (inputDataLastDim) &&
4015+ failed (verifyCompatibleDims (dimsToCheck)))
4016+ return emitOpError ()
4017+ << " expect last dimension of input_scale ("
4018+ << inputScaleDims.back ()
4019+ << " ) to be equal to last dimension of input_data / block_size ("
4020+ << inputDataDims.back () / blockSize << " )" ;
4021+ }
4022+ }
4023+
4024+ return success ();
4025+ }
4026+
4027+ LogicalResult CastToBlockScaledOp::inferReturnTypeComponents (
4028+ MLIRContext *context, ::std::optional<Location> location,
4029+ CastToBlockScaledOp::Adaptor adaptor,
4030+ SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
4031+ const ShapeAdaptor inputShape (adaptor.getInputData ().getType ());
4032+ inferredReturnShapes.push_back (ShapedTypeComponents (inputShape));
4033+ if (!inputShape.hasRank ())
4034+ return success ();
4035+
4036+ // Calculate output_scale shape if ranked input provided
4037+ SmallVector<int64_t > outputScaleShape;
4038+ inputShape.getDims (outputScaleShape);
4039+ const int64_t lastDimLoc = inputShape.getRank () - 1 ;
4040+ const int64_t lastDimSize = inputShape.getDimSize (lastDimLoc);
4041+ if (ShapedType::isStatic (lastDimSize)) {
4042+ const unsigned int blockSize =
4043+ BlockSizeAttr::getBlockSizeValue (adaptor.getBlockSize ());
4044+ outputScaleShape[lastDimLoc] = lastDimSize / blockSize;
4045+ }
4046+ inferredReturnShapes.push_back (ShapedTypeComponents (outputScaleShape));
4047+ return success ();
4048+ }
4049+
4050+ LogicalResult CastToBlockScaledOp::verify () {
4051+ const Type inputDataType = getInputData ().getType ();
4052+ const Type outputDataType = getResult (0 ).getType ();
4053+ if (failed (verifyCompatibleShape (inputDataType, outputDataType)))
4054+ return emitOpError () << " require compatible shapes for input_data ("
4055+ << inputDataType << " ) and "
4056+ << " output_data (" << outputDataType << " )" ;
4057+
4058+ const unsigned int blockSize =
4059+ BlockSizeAttr::getBlockSizeValue (getBlockSize ());
4060+ const ShapeAdaptor inputDataShape = ShapeAdaptor (inputDataType);
4061+ if (inputDataShape.hasRank ()) {
4062+ const int64_t inputDataLastDim =
4063+ inputDataShape.getDimSize (inputDataShape.getRank () - 1 );
4064+ if (ShapedType::isStatic (inputDataLastDim) &&
4065+ inputDataLastDim % blockSize != 0 )
4066+ return emitOpError () << " expect last dimension of input_data ("
4067+ << inputDataLastDim
4068+ << " ) to be divisible by block_size (" << blockSize
4069+ << " )" ;
4070+ }
4071+
4072+ const ShapeAdaptor outputDataShape = ShapeAdaptor (outputDataType);
4073+ const Type outputScaleType = getResult (1 ).getType ();
4074+ const ShapeAdaptor outputScaleShape = ShapeAdaptor (outputScaleType);
4075+ if (outputDataShape.hasRank () && outputScaleShape.hasRank ()) {
4076+ SmallVector<int64_t > outputDataDims, outputScaleDims;
4077+ outputDataShape.getDims (outputDataDims);
4078+ outputScaleShape.getDims (outputScaleDims);
4079+
4080+ if (outputDataDims.size () != outputScaleDims.size () ||
4081+ failed (verifyCompatibleShape (
4082+ ArrayRef<int64_t >(outputDataDims).drop_back (1 ),
4083+ ArrayRef<int64_t >(outputScaleDims).drop_back (1 ))))
4084+ return emitOpError () << " require compatible shapes for output_data ("
4085+ << outputDataType << " ) and "
4086+ << " output_scale (" << outputScaleType
4087+ << " ) except for the last dimension" ;
4088+
4089+ const int64_t outputDataLastDim = outputDataDims.back ();
4090+ const SmallVector<int64_t , 2 > dimsToCheck{outputDataLastDim / blockSize,
4091+ outputScaleDims.back ()};
4092+ if (ShapedType::isStatic (outputDataLastDim) &&
4093+ failed (verifyCompatibleDims (dimsToCheck)))
4094+ return emitOpError ()
4095+ << " expect last dimension of output_scale ("
4096+ << outputScaleDims.back ()
4097+ << " ) to be equal to last dimension of output_data / block_size ("
4098+ << outputDataDims.back () / blockSize << " )" ;
4099+ }
4100+
4101+ return success ();
4102+ }
4103+
39474104LogicalResult IfOp::inferReturnTypeComponents (
39484105 MLIRContext *context, ::std::optional<Location> location,
39494106 IfOp::Adaptor adaptor,
0 commit comments