Skip to content

Commit 2fa3b3c

Browse files
Merge branch 'develop' into quick_tune_code_review
2 parents 70b50ac + 58c991b commit 2fa3b3c

File tree

12 files changed

+975
-245
lines changed

12 files changed

+975
-245
lines changed

mlir/lib/Conversion/MIGraphXToTosa/MIGraphXToTosa.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -584,7 +584,7 @@ BroadcastConverter::matchAndRewrite(migraphx::BroadcastOp op, OpAdaptor adaptor,
584584
// because tosa does not have an explicit broadcast op
585585
auto oneTensor = rock::tosa::getOneTensor(rewriter, loc, outType);
586586
auto mulWithOne = rock::tosa::getMulOp(rewriter, loc, sameRankReshapedOp,
587-
oneTensor, elemType);
587+
oneTensor, newOutElementTy);
588588
rewriter.replaceOp(op, mulWithOne);
589589
return success();
590590
}

mlir/lib/Conversion/TosaToRock/TosaToRock.cpp

Lines changed: 398 additions & 139 deletions
Large diffs are not rendered by default.

mlir/test/Conversion/TosaToRock/tosa-to-rock-attention-causal.mlir

Lines changed: 125 additions & 0 deletions
Large diffs are not rendered by default.

mlir/test/Conversion/TosaToRock/tosa-to-rock-attention-lse.mlir

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -345,3 +345,46 @@ func.func @mlir_attention_single_token(%arg0: tensor<128xf32>, %arg1: tensor<256
345345
%collapsed_7 = tensor.collapse_shape %20 [[0, 1, 2]] : tensor<8x1x32xf32> into tensor<256xf32>
346346
return %collapsed_7, %collapsed_4 : tensor<256xf32>, tensor<8xf32>
347347
}
348+
349+
// CHECK-LABEL: @mlir_attention_lse_unfolded
350+
// CHECK: %[[lseBuffer:.+]] = bufferization.alloc_tensor() : tensor<8x1xf32>
351+
// CHECK: %{{.*}}, %[[lseOut:.*]] = rock.attention
352+
// CHECK: lse = %[[lseBuffer]] : tensor<8x1xf32>
353+
// CHECK: %[[lseExpanded:.*]] = tensor.expand_shape %[[lseOut]]
354+
// CHECK: %[[lseCollapsed:.*]] = tensor.collapse_shape %[[lseExpanded]]
355+
// CHECK: return %{{.*}}, %[[lseCollapsed]] : tensor<256xf32>, tensor<8xf32>
356+
func.func private @mlir_attention_lse_unfolded(%arg0: tensor<128xf32>, %arg1: tensor<256xf32>, %arg2: tensor<128xf32>) -> (tensor<256xf32>, tensor<8xf32>) attributes {arch = "##TOKEN_ARCH##", kernel} {
357+
%0 = tosa.const_shape {values = dense<256> : tensor<1xindex>} : () -> !tosa.shape<1>
358+
%1 = tosa.const_shape {values = dense<[8, 1, 1]> : tensor<3xindex>} : () -> !tosa.shape<3>
359+
%2 = tosa.const_shape {values = dense<8> : tensor<1xindex>} : () -> !tosa.shape<1>
360+
%3 = tosa.const_shape {values = dense<[2, 4, 1, 1]> : tensor<4xindex>} : () -> !tosa.shape<4>
361+
%4 = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
362+
%5 = tosa.const_shape {values = dense<[8, 32, 1]> : tensor<3xindex>} : () -> !tosa.shape<3>
363+
%6 = tosa.const_shape {values = dense<[8, 1, 32]> : tensor<3xindex>} : () -> !tosa.shape<3>
364+
%7 = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
365+
%8 = "tosa.const"() <{values = dense<1.000000e+00> : tensor<2x2x2x1x32xf32>}> : () -> tensor<2x2x2x1x32xf32>
366+
%9 = tosa.const_shape {values = dense<[2, 2, 1, 1, 32]> : tensor<5xindex>} : () -> !tosa.shape<5>
367+
%10 = tosa.const_shape {values = dense<[2, 4, 1, 32]> : tensor<4xindex>} : () -> !tosa.shape<4>
368+
%expanded = tensor.expand_shape %arg0 [[0, 1, 2, 3, 4]] output_shape [2, 2, 1, 1, 32] : tensor<128xf32> into tensor<2x2x1x1x32xf32>
369+
%11 = tosa.mul %expanded, %8, %7 : (tensor<2x2x1x1x32xf32>, tensor<2x2x2x1x32xf32>, tensor<1xi8>) -> tensor<2x2x2x1x32xf32>
370+
%expanded_0 = tensor.expand_shape %arg2 [[0, 1, 2, 3, 4]] output_shape [2, 2, 1, 1, 32] : tensor<128xf32> into tensor<2x2x1x1x32xf32>
371+
%12 = tosa.mul %expanded_0, %8, %7 : (tensor<2x2x1x1x32xf32>, tensor<2x2x2x1x32xf32>, tensor<1xi8>) -> tensor<2x2x2x1x32xf32>
372+
%collapsed = tensor.collapse_shape %12 [[0], [1, 2], [3], [4]] : tensor<2x2x2x1x32xf32> into tensor<2x4x1x32xf32>
373+
%13 = tosa.transpose %collapsed {perms = array<i32: 0, 1, 3, 2>} : (tensor<2x4x1x32xf32>) -> tensor<2x4x32x1xf32>
374+
%expanded_1 = tensor.expand_shape %arg1 [[0, 1, 2]] output_shape [8, 1, 32] : tensor<256xf32> into tensor<8x1x32xf32>
375+
%collapsed_2 = tensor.collapse_shape %13 [[0, 1], [2], [3]] : tensor<2x4x32x1xf32> into tensor<8x32x1xf32>
376+
%14 = tosa.matmul %expanded_1, %collapsed_2, %4, %4 {acc_type = f32} : (tensor<8x1x32xf32>, tensor<8x32x1xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<8x1x1xf32>
377+
%expanded_3 = tensor.expand_shape %14 [[0, 1], [2], [3]] output_shape [2, 4, 1, 1] : tensor<8x1x1xf32> into tensor<2x4x1x1xf32>
378+
%15 = tosa.sub %expanded_3, %expanded_3 : (tensor<2x4x1x1xf32>, tensor<2x4x1x1xf32>) -> tensor<2x4x1x1xf32>
379+
%16 = tosa.exp %15 : (tensor<2x4x1x1xf32>) -> tensor<2x4x1x1xf32>
380+
%17 = tosa.reciprocal %16 : (tensor<2x4x1x1xf32>) -> tensor<2x4x1x1xf32>
381+
%18 = tosa.mul %16, %17, %7 : (tensor<2x4x1x1xf32>, tensor<2x4x1x1xf32>, tensor<1xi8>) -> tensor<2x4x1x1xf32>
382+
%19 = tosa.log %16 : (tensor<2x4x1x1xf32>) -> tensor<2x4x1x1xf32>
383+
%20 = tosa.add %19, %expanded_3 : (tensor<2x4x1x1xf32>, tensor<2x4x1x1xf32>) -> tensor<2x4x1x1xf32>
384+
%collapsed_4 = tensor.collapse_shape %20 [[0, 1, 2, 3]] : tensor<2x4x1x1xf32> into tensor<8xf32>
385+
%collapsed_5 = tensor.collapse_shape %18 [[0, 1], [2], [3]] : tensor<2x4x1x1xf32> into tensor<8x1x1xf32>
386+
%collapsed_6 = tensor.collapse_shape %11 [[0, 1, 2], [3], [4]] : tensor<2x2x2x1x32xf32> into tensor<8x1x32xf32>
387+
%21 = tosa.matmul %collapsed_5, %collapsed_6, %4, %4 {acc_type = f32} : (tensor<8x1x1xf32>, tensor<8x1x32xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<8x1x32xf32>
388+
%collapsed_7 = tensor.collapse_shape %21 [[0, 1, 2]] : tensor<8x1x32xf32> into tensor<256xf32>
389+
return %collapsed_7, %collapsed_4 : tensor<256xf32>, tensor<8xf32>
390+
}
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
// RUN: rocmlir-gen -fut mlir_attention --arch %arch --clone-harness %s | rocmlir-driver -kernel-pipeline=migraphx,highlevel -host-pipeline=migraphx,highlevel | rocmlir-gen -ph -rand_min_int 0 -rand_max_int 1024 -rand_type_int_for_inputs=3 -rand 1 -rand_type float -fut mlir_attention_wrapper -RMS_threshold 0.01 --verifier clone - | rocmlir-driver -host-pipeline mhal -kernel-pipeline full -targets %arch | xmir-runner --shared-libs=%linalg_test_lib_dir/libmlir_rocm_runtime%shlibext,%conv_validation_wrapper_library_dir/libconv-validation-wrappers%shlibext,%linalg_test_lib_dir/libmlir_runner_utils%shlibext,%linalg_test_lib_dir/libmlir_float16_utils%shlibext,%linalg_test_lib_dir/libmlir_c_runner_utils%shlibext,%linalg_test_lib_dir/libmlir_async_runtime%shlibext --entry-point-result=void | FileCheck %s
2+
// CHECK: [1 1 1]
3+
4+
module {
5+
func.func @mlir_attention(%arg0: !migraphx.shaped<1x1x1xsi32, 1x1x1>, %arg1: !migraphx.shaped<1x96x1x128xf16, 12288x128x128x1>, %arg2: !migraphx.shaped<1x32x256x128xf16, 1048576x32768x128x1>, %arg3: !migraphx.shaped<1x32x256x128xf16, 1048576x32768x128x1>) -> !migraphx.shaped<1x1x4096xf16, 4096x4096x1> {
6+
%0 = migraphx.literal(dense<"0x000000000100000002000000030000000400000005000000060000000700000008000000090000000A0000000B0000000C0000000D0000000E0000000F000000100000001100000012000000130000001400000015000000160000001700000018000000190000001A0000001B0000001C0000001D0000001E0000001F000000200000002100000022000000230000002400000025000000260000002700000028000000290000002A0000002B0000002C0000002D0000002E0000002F000000300000003100000032000000330000003400000035000000360000003700000038000000390000003A0000003B0000003C0000003D0000003E0000003F000000400000004100000042000000430000004400000045000000460000004700000048000000490000004A0000004B0000004C0000004D0000004E0000004F000000500000005100000052000000530000005400000055000000560000005700000058000000590000005A0000005B0000005C0000005D0000005E0000005F000000600000006100000062000000630000006400000065000000660000006700000068000000690000006A0000006B0000006C0000006D0000006E0000006F000000700000007100000072000000730000007400000075000000760000007700000078000000790000007A0000007B0000007C0000007D0000007E0000007F000000800000008100000082000000830000008400000085000000860000008700000088000000890000008A0000008B0000008C0000008D0000008E0000008F000000900000009100000092000000930000009400000095000000960000009700000098000000990000009A0000009B0000009C0000009D0000009E0000009F000000A0000000A1000000A2000000A3000000A4000000A5000000A6000000A7000000A8000000A9000000AA000000AB000000AC000000AD000000AE000000AF000000B0000000B1000000B2000000B3000000B4000000B5000000B6000000B7000000B8000000B9000000BA000000BB000000BC000000BD000000BE000000BF000000C0000000C1000000C2000000C3000000C4000000C5000000C6000000C7000000C8000000C9000000CA000000CB000000CC000000CD000000CE000000CF000000D0000000D1000000D2000000D3000000D4000000D5000000D6000000D7000000D8000000D9000000DA000000DB000000DC000000DD000000DE000000DF000000E0000000E1000000E2000000E3000000E4000000E5000000E6000000E7000000E8000000E9000000EA000000EB000000EC000000ED000000EE000000EF000000F0000000F1000000F2000000F3000000F4000000F5000000F6000000F7000000F8000000F9000000FA000000FB000000FC000000FD000000FE000000FF000000"> : tensor<256xsi32>) : <256xsi32, 1>
7+
%1 = migraphx.literal(dense<0xFC00> : tensor<1xf16>) : <1xf16, 1>
8+
%2 = migraphx.literal(dense<8.837890e-02> : tensor<1xf16>) : <1xf16, 1>
9+
%3 = migraphx.multibroadcast %0 {out_dyn_dims = [], out_lens = [1, 1, 1, 256]} : <256xsi32, 1> -> <1x1x1x256xsi32, 0x0x0x1>
10+
%4 = migraphx.multibroadcast %2 {out_dyn_dims = [], out_lens = [1, 32, 1, 256]} : <1xf16, 1> -> <1x32x1x256xf16, 0x0x0x0>
11+
%5 = migraphx.broadcast %arg0 {axis = 0 : i64, out_lens = [1, 1, 1, 256]} : <1x1x1xsi32, 1x1x1> -> <1x1x1x256xsi32, 1x1x1x0>
12+
%6 = migraphx.greater %3, %5 : <1x1x1x256xsi32, 0x0x0x1>, <1x1x1x256xsi32, 1x1x1x0> -> <1x1x1x256xsi32, 0x0x0x1>
13+
%7 = migraphx.convert %6 {target_type = 0 : i64} : <1x1x1x256xsi32, 0x0x0x1> to <1x1x1x256xsi8, 0x0x0x1>
14+
%8 = migraphx.multibroadcast %7 {out_dyn_dims = [], out_lens = [1, 32, 1, 256]} : <1x1x1x256xsi8, 0x0x0x1> -> <1x32x1x256xsi8, 0x0x0x1>
15+
%9 = migraphx.slice %arg1 {axes = [1], ends = [32], starts = [0]} : <1x96x1x128xf16, 12288x128x128x1> -> <1x32x1x128xf16, 12288x128x128x1>
16+
%10 = migraphx.transpose %arg2 {permutation = [0, 1, 3, 2]} : <1x32x256x128xf16, 1048576x32768x128x1> -> <1x32x128x256xf16, 1048576x32768x1x128>
17+
%11 = migraphx.dot %9, %10 : <1x32x1x128xf16, 12288x128x128x1>, <1x32x128x256xf16, 1048576x32768x1x128> -> <1x32x1x256xf16, 8192x256x256x1>
18+
%12 = migraphx.multibroadcast %1 {out_dyn_dims = [], out_lens = [1, 32, 1, 256]} : <1xf16, 1> -> <1x32x1x256xf16, 0x0x0x0>
19+
%13 = migraphx.mul %11, %4 : <1x32x1x256xf16, 8192x256x256x1>, <1x32x1x256xf16, 0x0x0x0> -> <1x32x1x256xf16, 8192x256x256x1>
20+
%14 = migraphx.where %8, %12, %13 : <1x32x1x256xsi8, 0x0x0x1>, <1x32x1x256xf16, 0x0x0x0>, <1x32x1x256xf16, 8192x256x256x1> -> <1x32x1x256xf16, 8192x256x256x1>
21+
%15 = migraphx.reshape %14 {dims = [1, 32, 1, 256]} : <1x32x1x256xf16, 8192x256x256x1> -> <1x32x1x256xf16, 8192x256x256x1>
22+
%16 = migraphx.reduce_max %15 {axes = [3]} : <1x32x1x256xf16, 8192x256x256x1> -> <1x32x1x1xf16, 32x1x1x1>
23+
%17 = migraphx.reshape %16 {dims = [1, 32, 1, 1]} : <1x32x1x1xf16, 32x1x1x1> -> <1x32x1x1xf16, 32x1x1x1>
24+
%18 = migraphx.multibroadcast %17 {out_dyn_dims = [], out_lens = [1, 32, 1, 256]} : <1x32x1x1xf16, 32x1x1x1> -> <1x32x1x256xf16, 32x1x1x0>
25+
%19 = migraphx.sub %14, %18 : <1x32x1x256xf16, 8192x256x256x1>, <1x32x1x256xf16, 32x1x1x0> -> <1x32x1x256xf16, 8192x256x256x1>
26+
%20 = migraphx.exp %19 : <1x32x1x256xf16, 8192x256x256x1> -> <1x32x1x256xf16, 8192x256x256x1>
27+
%21 = migraphx.reshape %20 {dims = [1, 32, 1, 256]} : <1x32x1x256xf16, 8192x256x256x1> -> <1x32x1x256xf16, 8192x256x256x1>
28+
%22 = migraphx.reduce_sum %21 {axes = [3]} : <1x32x1x256xf16, 8192x256x256x1> -> <1x32x1x1xf16, 32x1x1x1>
29+
%23 = migraphx.reshape %22 {dims = [1, 32, 1, 1]} : <1x32x1x1xf16, 32x1x1x1> -> <1x32x1x1xf16, 32x1x1x1>
30+
%24 = migraphx.multibroadcast %23 {out_dyn_dims = [], out_lens = [1, 32, 1, 256]} : <1x32x1x1xf16, 32x1x1x1> -> <1x32x1x256xf16, 32x1x1x0>
31+
%25 = migraphx.div %20, %24 : <1x32x1x256xf16, 8192x256x256x1>, <1x32x1x256xf16, 32x1x1x0> -> <1x32x1x256xf16, 8192x256x256x1>
32+
%26 = migraphx.dot %25, %arg3 : <1x32x1x256xf16, 8192x256x256x1>, <1x32x256x128xf16, 1048576x32768x128x1> -> <1x32x1x128xf16, 4096x128x128x1>
33+
%27 = migraphx.transpose %26 {permutation = [0, 2, 1, 3]} : <1x32x1x128xf16, 4096x128x128x1> -> <1x1x32x128xf16, 4096x128x128x1>
34+
%28 = migraphx.reshape %27 {dims = [1, 1, 4096]} : <1x1x32x128xf16, 4096x128x128x1> -> <1x1x4096xf16, 4096x4096x1>
35+
return %28 : !migraphx.shaped<1x1x4096xf16, 4096x4096x1>
36+
}
37+
}
38+
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
// RUN: rocmlir-gen -fut mlir_attention --arch %arch --clone-harness %s | rocmlir-driver -kernel-pipeline=migraphx,highlevel -host-pipeline=migraphx,highlevel | rocmlir-gen -ph -rand_min_int 0 -rand_max_int 1024 -rand_type_int_for_inputs=3 -rand 1 -rand_type float -fut mlir_attention_wrapper -RMS_threshold 0.01 --verifier clone - | rocmlir-driver -host-pipeline mhal -kernel-pipeline full -targets %arch | xmir-runner --shared-libs=%linalg_test_lib_dir/libmlir_rocm_runtime%shlibext,%conv_validation_wrapper_library_dir/libconv-validation-wrappers%shlibext,%linalg_test_lib_dir/libmlir_runner_utils%shlibext,%linalg_test_lib_dir/libmlir_float16_utils%shlibext,%linalg_test_lib_dir/libmlir_c_runner_utils%shlibext,%linalg_test_lib_dir/libmlir_async_runtime%shlibext --entry-point-result=void | FileCheck %s
2+
// CHECK: [1 1 1]
3+
4+
module {
5+
func.func @mlir_attention(%arg0: !migraphx.shaped<1x96x1x128xf16, 12288x128x128x1>, %arg1: !migraphx.shaped<1x32x256x128xf16, 1048576x32768x128x1>, %arg2: !migraphx.shaped<1x1x1xsi32, 1x1x1>, %arg3: !migraphx.shaped<1x32x256x128xf16, 1048576x32768x128x1>) -> !migraphx.shaped<1x1x4096xf16, 4096x4096x1> {
6+
%0 = migraphx.literal(dense<"0x000000000100000002000000030000000400000005000000060000000700000008000000090000000A0000000B0000000C0000000D0000000E0000000F000000100000001100000012000000130000001400000015000000160000001700000018000000190000001A0000001B0000001C0000001D0000001E0000001F000000200000002100000022000000230000002400000025000000260000002700000028000000290000002A0000002B0000002C0000002D0000002E0000002F000000300000003100000032000000330000003400000035000000360000003700000038000000390000003A0000003B0000003C0000003D0000003E0000003F000000400000004100000042000000430000004400000045000000460000004700000048000000490000004A0000004B0000004C0000004D0000004E0000004F000000500000005100000052000000530000005400000055000000560000005700000058000000590000005A0000005B0000005C0000005D0000005E0000005F000000600000006100000062000000630000006400000065000000660000006700000068000000690000006A0000006B0000006C0000006D0000006E0000006F000000700000007100000072000000730000007400000075000000760000007700000078000000790000007A0000007B0000007C0000007D0000007E0000007F000000800000008100000082000000830000008400000085000000860000008700000088000000890000008A0000008B0000008C0000008D0000008E0000008F000000900000009100000092000000930000009400000095000000960000009700000098000000990000009A0000009B0000009C0000009D0000009E0000009F000000A0000000A1000000A2000000A3000000A4000000A5000000A6000000A7000000A8000000A9000000AA000000AB000000AC000000AD000000AE000000AF000000B0000000B1000000B2000000B3000000B4000000B5000000B6000000B7000000B8000000B9000000BA000000BB000000BC000000BD000000BE000000BF000000C0000000C1000000C2000000C3000000C4000000C5000000C6000000C7000000C8000000C9000000CA000000CB000000CC000000CD000000CE000000CF000000D0000000D1000000D2000000D3000000D4000000D5000000D6000000D7000000D8000000D9000000DA000000DB000000DC000000DD000000DE000000DF000000E0000000E1000000E2000000E3000000E4000000E5000000E6000000E7000000E8000000E9000000EA000000EB000000EC000000ED000000EE000000EF000000F0000000F1000000F2000000F3000000F4000000F5000000F6000000F7000000F8000000F9000000FA000000FB000000FC000000FD000000FE000000FF000000"> : tensor<256xsi32>) : <256xsi32, 1>
7+
%1 = migraphx.literal(dense<0xFC00> : tensor<1xf16>) : <1xf16, 1>
8+
%2 = migraphx.literal(dense<8.837890e-02> : tensor<1xf16>) : <1xf16, 1>
9+
%3 = migraphx.multibroadcast %0 {out_dyn_dims = [], out_lens = [1, 1, 1, 256]} : <256xsi32, 1> -> <1x1x1x256xsi32, 0x0x0x1>
10+
%4 = migraphx.slice %arg0 {axes = [1], ends = [32], starts = [0]} : <1x96x1x128xf16, 12288x128x128x1> -> <1x32x1x128xf16, 12288x128x128x1>
11+
%5 = migraphx.transpose %arg1 {permutation = [0, 1, 3, 2]} : <1x32x256x128xf16, 1048576x32768x128x1> -> <1x32x128x256xf16, 1048576x32768x1x128>
12+
%6 = migraphx.dot %4, %5 : <1x32x1x128xf16, 12288x128x128x1>, <1x32x128x256xf16, 1048576x32768x1x128> -> <1x32x1x256xf16, 8192x256x256x1>
13+
%7 = migraphx.multibroadcast %1 {out_dyn_dims = [], out_lens = [1, 32, 1, 256]} : <1xf16, 1> -> <1x32x1x256xf16, 0x0x0x0>
14+
%8 = migraphx.multibroadcast %2 {out_dyn_dims = [], out_lens = [1, 32, 1, 256]} : <1xf16, 1> -> <1x32x1x256xf16, 0x0x0x0>
15+
%9 = migraphx.mul %6, %8 : <1x32x1x256xf16, 8192x256x256x1>, <1x32x1x256xf16, 0x0x0x0> -> <1x32x1x256xf16, 8192x256x256x1>
16+
%10 = migraphx.broadcast %arg2 {axis = 0 : i64, out_lens = [1, 1, 1, 256]} : <1x1x1xsi32, 1x1x1> -> <1x1x1x256xsi32, 1x1x1x0>
17+
%11 = migraphx.greater %3, %10 : <1x1x1x256xsi32, 0x0x0x1>, <1x1x1x256xsi32, 1x1x1x0> -> <1x1x1x256xsi32, 0x0x0x1>
18+
%12 = migraphx.convert %11 {target_type = 0 : i64} : <1x1x1x256xsi32, 0x0x0x1> to <1x1x1x256xsi8, 0x0x0x1>
19+
%13 = migraphx.multibroadcast %12 {out_dyn_dims = [], out_lens = [1, 32, 1, 256]} : <1x1x1x256xsi8, 0x0x0x1> -> <1x32x1x256xsi8, 0x0x0x1>
20+
%14 = migraphx.where %13, %7, %9 : <1x32x1x256xsi8, 0x0x0x1>, <1x32x1x256xf16, 0x0x0x0>, <1x32x1x256xf16, 8192x256x256x1> -> <1x32x1x256xf16, 8192x256x256x1>
21+
%15 = migraphx.reshape %14 {dims = [1, 32, 1, 256]} : <1x32x1x256xf16, 8192x256x256x1> -> <1x32x1x256xf16, 8192x256x256x1>
22+
%16 = migraphx.reduce_max %15 {axes = [3]} : <1x32x1x256xf16, 8192x256x256x1> -> <1x32x1x1xf16, 32x1x1x1>
23+
%17 = migraphx.reshape %16 {dims = [1, 32, 1, 1]} : <1x32x1x1xf16, 32x1x1x1> -> <1x32x1x1xf16, 32x1x1x1>
24+
%18 = migraphx.multibroadcast %17 {out_dyn_dims = [], out_lens = [1, 32, 1, 256]} : <1x32x1x1xf16, 32x1x1x1> -> <1x32x1x256xf16, 32x1x1x0>
25+
%19 = migraphx.sub %14, %18 : <1x32x1x256xf16, 8192x256x256x1>, <1x32x1x256xf16, 32x1x1x0> -> <1x32x1x256xf16, 8192x256x256x1>
26+
%20 = migraphx.exp %19 : <1x32x1x256xf16, 8192x256x256x1> -> <1x32x1x256xf16, 8192x256x256x1>
27+
%21 = migraphx.reshape %20 {dims = [1, 32, 1, 256]} : <1x32x1x256xf16, 8192x256x256x1> -> <1x32x1x256xf16, 8192x256x256x1>
28+
%22 = migraphx.reduce_sum %21 {axes = [3]} : <1x32x1x256xf16, 8192x256x256x1> -> <1x32x1x1xf16, 32x1x1x1>
29+
%23 = migraphx.reshape %22 {dims = [1, 32, 1, 1]} : <1x32x1x1xf16, 32x1x1x1> -> <1x32x1x1xf16, 32x1x1x1>
30+
%24 = migraphx.multibroadcast %23 {out_dyn_dims = [], out_lens = [1, 32, 1, 256]} : <1x32x1x1xf16, 32x1x1x1> -> <1x32x1x256xf16, 32x1x1x0>
31+
%25 = migraphx.div %20, %24 : <1x32x1x256xf16, 8192x256x256x1>, <1x32x1x256xf16, 32x1x1x0> -> <1x32x1x256xf16, 8192x256x256x1>
32+
%26 = migraphx.dot %25, %arg3 : <1x32x1x256xf16, 8192x256x256x1>, <1x32x256x128xf16, 1048576x32768x128x1> -> <1x32x1x128xf16, 4096x128x128x1>
33+
%27 = migraphx.transpose %26 {permutation = [0, 2, 1, 3]} : <1x32x1x128xf16, 4096x128x128x1> -> <1x1x32x128xf16, 4096x128x128x1>
34+
%28 = migraphx.reshape %27 {dims = [1, 1, 4096]} : <1x1x32x128xf16, 4096x128x128x1> -> <1x1x4096xf16, 4096x4096x1>
35+
return %28 : !migraphx.shaped<1x1x4096xf16, 4096x4096x1>
36+
}
37+
}
38+

0 commit comments

Comments
 (0)