@@ -25,6 +25,7 @@ __global__ __launch_bounds__(kMaxThreads) void permute_2D_data_kernel(
25
25
int32_t B,
26
26
const indices_t * __restrict__ indices,
27
27
const weights_t * __restrict__ weights,
28
+ const int32_t weights_columns,
28
29
const int32_t * __restrict__ permute,
29
30
const offsets_t * __restrict__ input_offsets,
30
31
const offsets_t * __restrict__ output_offsets,
@@ -46,7 +47,10 @@ __global__ __launch_bounds__(kMaxThreads) void permute_2D_data_kernel(
46
47
for (auto i = threadIdx .x ; i < segment_length; i += blockDim .x ) {
47
48
permuted_indices[output_start + i] = indices[input_start + i];
48
49
if (has_weight) {
49
- permuted_weights[output_start + i] = weights[input_start + i];
50
+ for (auto w_col = 0 ; w_col < weights_columns; ++w_col) {
51
+ permuted_weights[(output_start + i) * weights_columns + w_col] =
52
+ weights[(input_start + i) * weights_columns + w_col];
53
+ }
50
54
}
51
55
}
52
56
}
@@ -143,8 +147,16 @@ permute_2D_sparse_data_cuda(
143
147
if (weights.has_value ()) {
144
148
const Tensor weights_value = weights.value ();
145
149
const auto weights_value_contig = weights_value.contiguous ();
146
- permuted_weights =
147
- at::empty (permuted_indices_size, weights_value.options ());
150
+ int32_t weights_columns = 1 ;
151
+ if (weights_value.dense_dim () > 1 ) {
152
+ weights_columns = weights_value.size (1 );
153
+ permuted_weights = at::empty (
154
+ {permuted_indices_size, weights_columns},
155
+ weights_value.options ());
156
+ } else {
157
+ permuted_weights =
158
+ at::empty (permuted_indices_size, weights_value.options ());
159
+ }
148
160
FBGEMM_DISPATCH_ALL_TYPES_AND_DOUBLE (
149
161
weights_value.scalar_type (),
150
162
" permute_2D_data_kernel_3" ,
@@ -164,6 +176,7 @@ permute_2D_sparse_data_cuda(
164
176
B,
165
177
indices_contig.data_ptr<indices_t >(),
166
178
weights_value_contig.data_ptr<weights_t>(),
179
+ weights_columns,
167
180
permute_contig.data_ptr<int32_t>(),
168
181
input_offsets.data_ptr<offsets_t>(),
169
182
output_offsets.data_ptr<offsets_t>(),
@@ -186,6 +199,7 @@ permute_2D_sparse_data_cuda(
186
199
B,
187
200
indices_contig.data_ptr<indices_t >(),
188
201
nullptr,
202
+ 1,
189
203
permute_contig.data_ptr<int32_t>(),
190
204
input_offsets.data_ptr<offsets_t>(),
191
205
output_offsets.data_ptr<offsets_t>(),
0 commit comments