@@ -69,10 +69,10 @@ splitLastAxisInResharding(ImplicitLocOpBuilder &builder,
6969                          Sharding sourceSharding,
7070                          TypedValue<ShapedType> sourceShard, GridOp grid,
7171                          int64_t  splitTensorAxis, GridAxis splitGridAxis) {
72-   TypedValue<ShapedType> targetShard = cast<TypedValue<ShapedType>>( 
72+   TypedValue<ShapedType> targetShard =
7373      AllSliceOp::create (builder, sourceShard, grid,
7474                         ArrayRef<GridAxis>(splitGridAxis), splitTensorAxis)
75-           .getResult ()) ;
75+           .getResult ();
7676  Sharding targetSharding = targetShardingInSplitLastAxis (
7777      builder.getContext (), sourceSharding, splitTensorAxis, splitGridAxis);
7878  return  {targetShard, targetSharding};
@@ -204,9 +204,8 @@ static std::tuple<TypedValue<ShapedType>, Sharding> unsplitLastAxisInResharding(
204204      APInt (64 , splitTensorAxis));
205205  ShapedType targetShape =
206206      shardShapedType (sourceUnshardedShape, grid, targetSharding);
207-   TypedValue<ShapedType> targetShard = cast<TypedValue<ShapedType>>(
208-       tensor::CastOp::create (builder, targetShape, allGatherResult)
209-           .getResult ());
207+   TypedValue<ShapedType> targetShard =
208+       tensor::CastOp::create (builder, targetShape, allGatherResult).getResult ();
210209  return  {targetShard, targetSharding};
211210}
212211
@@ -336,8 +335,8 @@ moveLastSplitAxisInResharding(ImplicitLocOpBuilder &builder, GridOp grid,
336335      APInt (64 , targetTensorAxis), APInt (64 , sourceTensorAxis));
337336  ShapedType targetShape =
338337      shardShapedType (sourceUnshardedShape, grid, targetSharding);
339-   TypedValue<ShapedType> targetShard = cast<TypedValue<ShapedType>>( 
340-       tensor::CastOp::create (builder, targetShape, allToAllResult).getResult ()) ;
338+   TypedValue<ShapedType> targetShard =
339+       tensor::CastOp::create (builder, targetShape, allToAllResult).getResult ();
341340  return  {targetShard, targetSharding};
342341}
343342
@@ -510,8 +509,7 @@ TypedValue<ShapedType> reshard(OpBuilder &builder, GridOp grid, ShardOp source,
510509  auto  targetSharding = target.getSharding ();
511510  ImplicitLocOpBuilder implicitLocOpBuilder (target->getLoc (), builder);
512511  return  reshard (implicitLocOpBuilder, grid, sourceSharding, targetSharding,
513-                  cast<TypedValue<ShapedType>>(source.getSrc ()),
514-                  sourceShardValue);
512+                  source.getSrc (), sourceShardValue);
515513}
516514
517515TypedValue<ShapedType> reshard (OpBuilder &builder, ShardOp source,
0 commit comments