@@ -2055,6 +2055,80 @@ class joint_matrix {
2055
2055
matrix_accessor x;
2056
2056
const size_t num_elements;
2057
2057
};
2058
+
2059
+ // / Multiplies 2 16x16 & 16x8 matrices and accumulates the result to a 16x8 b32
2060
+ // / matrix
2061
+ // / \tparam [in] MulType The type of the multiplication result
2062
+ // / \tparam [in] ABType The type of the input matrices
2063
+ // / \tparam [in] CDType The type of the output matrix
2064
+ // / \param [in] aTrans Indicates whether the 1st matrix to be transposed
2065
+ // / \param [in] bTrans Indicates whether the 2nd matrix to be transposed
2066
+ // / \param [in] d0 The 1st element to be written to the output D matrix
2067
+ // / \param [in] d1 The 2nd element to be written to the output D matrix
2068
+ // / \param [in] d2 The 3rd element to be written to the output D matrix
2069
+ // / \param [in] d3 The 4th element to be written to the output D matrix
2070
+ // / \param [in] a0 The 1st element from A matrix to be multiplied with B matrix
2071
+ // / \param [in] a1 The 2nd element from A matrix to be multiplied with B matrix
2072
+ // / \param [in] a2 The 3rd element from A matrix to be multiplied with B matrix
2073
+ // / \param [in] a3 The 4th element from A matrix to be multiplied with B matrix
2074
+ // / \param [in] b0 The 1st element from B matrix to be multiplied with A matrix
2075
+ // / \param [in] b1 The 2nd element from B matrix to be multiplied with A matrix
2076
+ // / \param [in] c0 The 1st element from C matrix to be added with d0
2077
+ // / \param [in] c1 The 2nd element from C matrix to be added with d1
2078
+ // / \param [in] c2 The 3rd element from C matrix to be added with d2
2079
+ // / \param [in] c3 The 4th element from C matrix to be added with d3
2080
+ // / \param [in] item The sycl::nd_item index space class
2081
+ template <typename MulType, typename ABType, typename CDType, typename ItemT>
2082
+ __attribute__ ((optnone)) void
2083
+ mma (CDType *d0, CDType *d1, CDType *d2, CDType *d3, ABType a0, ABType a1,
2084
+ ABType a2, ABType a3, ABType b0, ABType b1, CDType c0, CDType c1, CDType c2,
2085
+ CDType c3, const ItemT &item) {
2086
+ int lane = item.get_sub_group ().get_local_linear_id ();
2087
+
2088
+ short ROW_LOAD_OFFSET = 4 * (lane / 4 );
2089
+ short COL_LOAD_OFFSET = 8 * (lane % 4 );
2090
+
2091
+ ABType recv_a[4 * 4 ], recv_b[4 * 4 ];
2092
+ for (int i = 0 ; i < 4 ; i++) {
2093
+ recv_a[0 * 4 + i] = dpct::select_from_sub_group (item.get_sub_group (), a0,
2094
+ ROW_LOAD_OFFSET + i);
2095
+ recv_a[1 * 4 + i] = dpct::select_from_sub_group (item.get_sub_group (), a2,
2096
+ ROW_LOAD_OFFSET + i);
2097
+ recv_a[2 * 4 + i] = dpct::select_from_sub_group (item.get_sub_group (), a1,
2098
+ ROW_LOAD_OFFSET + i);
2099
+ recv_a[3 * 4 + i] = dpct::select_from_sub_group (item.get_sub_group (), a3,
2100
+ ROW_LOAD_OFFSET + i);
2101
+
2102
+ recv_b[0 * 4 + i] = dpct::select_from_sub_group (item.get_sub_group (), b0,
2103
+ COL_LOAD_OFFSET + i);
2104
+ recv_b[1 * 4 + i] = dpct::select_from_sub_group (item.get_sub_group (), b1,
2105
+ COL_LOAD_OFFSET + i);
2106
+ recv_b[2 * 4 + i] = dpct::select_from_sub_group (item.get_sub_group (), b0,
2107
+ COL_LOAD_OFFSET + 4 + i);
2108
+ recv_b[3 * 4 + i] = dpct::select_from_sub_group (item.get_sub_group (), b1,
2109
+ COL_LOAD_OFFSET + 4 + i);
2110
+ }
2111
+
2112
+ auto *a = reinterpret_cast <MulType *>(recv_a);
2113
+ auto *b = reinterpret_cast <MulType *>(recv_b);
2114
+ for (int i = 0 ; i < 16 ; i++) {
2115
+ auto a0 = static_cast <CDType>(a[i]);
2116
+ auto a1 = static_cast <CDType>(a[i + 16 ]);
2117
+ auto b0 = static_cast <CDType>(b[i]);
2118
+ auto b1 = static_cast <CDType>(b[i + 16 ]);
2119
+
2120
+ c0 += a0 * b0;
2121
+ c1 += a0 * b1;
2122
+ c2 += a1 * b0;
2123
+ c3 += a1 * b1;
2124
+ }
2125
+
2126
+ *d0 = c0;
2127
+ *d1 = c1;
2128
+ *d2 = c2;
2129
+ *d3 = c3;
2130
+ }
2131
+
2058
2132
} // namespace matrix
2059
2133
} // namespace experimental
2060
2134
0 commit comments