Skip to content

Commit cb269f0

Browse files
kausv and facebook-github-bot
authored and committed
support 2D weights in permute kernel (#4723)
Summary: Rollback Plan: Differential Revision: D80466458
1 parent 2d4e455 commit cb269f0

File tree

2 files changed

+50
-3
lines changed

2 files changed

+50
-3
lines changed

fbgemm_gpu/src/sparse_ops/sparse_permute_2d.cu

Lines changed: 17 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -25,6 +25,7 @@ __global__ __launch_bounds__(kMaxThreads) void permute_2D_data_kernel(
2525
int32_t B,
2626
const indices_t* __restrict__ indices,
2727
const weights_t* __restrict__ weights,
28+
const int32_t weights_columns,
2829
const int32_t* __restrict__ permute,
2930
const offsets_t* __restrict__ input_offsets,
3031
const offsets_t* __restrict__ output_offsets,
@@ -46,7 +47,10 @@ __global__ __launch_bounds__(kMaxThreads) void permute_2D_data_kernel(
4647
for (auto i = threadIdx.x; i < segment_length; i += blockDim.x) {
4748
permuted_indices[output_start + i] = indices[input_start + i];
4849
if (has_weight) {
49-
permuted_weights[output_start + i] = weights[input_start + i];
50+
for (auto w_col = 0; w_col < weights_columns; ++w_col) {
51+
permuted_weights[(output_start + i) * weights_columns + w_col] =
52+
weights[(input_start + i) * weights_columns + w_col];
53+
}
5054
}
5155
}
5256
}
@@ -143,8 +147,16 @@ permute_2D_sparse_data_cuda(
143147
if (weights.has_value()) {
144148
const Tensor weights_value = weights.value();
145149
const auto weights_value_contig = weights_value.contiguous();
146-
permuted_weights =
147-
at::empty(permuted_indices_size, weights_value.options());
150+
int32_t weights_columns = 1;
151+
if (weights_value.dense_dim() > 1) {
152+
weights_columns = weights_value.size(1);
153+
permuted_weights = at::empty(
154+
{permuted_indices_size, weights_columns},
155+
weights_value.options());
156+
} else {
157+
permuted_weights =
158+
at::empty(permuted_indices_size, weights_value.options());
159+
}
148160
FBGEMM_DISPATCH_ALL_TYPES_AND_DOUBLE(
149161
weights_value.scalar_type(),
150162
"permute_2D_data_kernel_3",
@@ -164,6 +176,7 @@ permute_2D_sparse_data_cuda(
164176
B,
165177
indices_contig.data_ptr<indices_t>(),
166178
weights_value_contig.data_ptr<weights_t>(),
179+
weights_columns,
167180
permute_contig.data_ptr<int32_t>(),
168181
input_offsets.data_ptr<offsets_t>(),
169182
output_offsets.data_ptr<offsets_t>(),
@@ -186,6 +199,7 @@ permute_2D_sparse_data_cuda(
186199
B,
187200
indices_contig.data_ptr<indices_t>(),
188201
nullptr,
202+
1,
189203
permute_contig.data_ptr<int32_t>(),
190204
input_offsets.data_ptr<offsets_t>(),
191205
output_offsets.data_ptr<offsets_t>(),

fbgemm_gpu/test/sparse/permute_sparse_features_test.py

Lines changed: 33 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -192,6 +192,39 @@ def test_permute_sparse_features_with_repeats(
192192
assert permuted_weights_cpu is None
193193

194194

195+
@unittest.skipIf(
    not torch.cuda.is_available(),
    "permute_2D_sparse_data 2D-weights tests need a CUDA device",
)
class Permute2DSparseFeaturesTest(unittest.TestCase):
    """Tests torch.ops.fbgemm.permute_2D_sparse_data with 2D weights.

    "2D weights" means one row of weights per index, i.e. a tensor of shape
    (num_indices, weights_columns) instead of (num_indices,).

    Expected values are derived from the relation
    ``permuted_lengths[t] == lengths[permute[t]]``, with the indices (and
    their weight rows) of each (t, b) segment moving together.
    """

    def _assert_same(self, name: str, actual: torch.Tensor, expected: torch.Tensor) -> None:
        """Exact (bitwise) tensor comparison with a readable failure message.

        The op only copies values, so no floating-point tolerance is needed.
        """
        self.assertEqual(
            tuple(actual.shape), tuple(expected.shape), f"{name}: shape mismatch"
        )
        self.assertEqual(actual.dtype, expected.dtype, f"{name}: dtype mismatch")
        self.assertTrue(
            torch.equal(actual, expected),
            f"{name} mismatch:\nactual={actual}\nexpected={expected}",
        )

    def test_permute_2D_sparse_data(self) -> None:
        """Permutation that moves only empty rows past non-empty ones.

        The three non-empty segments keep their relative order, so indices and
        weights must come back unchanged while the lengths rows are permuted.
        """
        lengths = torch.tensor(
            [[0, 0, 1], [0, 1, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 1]],
            dtype=torch.int32,
            device="cuda",
        )
        indices = torch.tensor([500, 1000, 1999], dtype=torch.int32, device="cuda")
        permute = torch.tensor([0, 3, 1, 4, 2, 5], dtype=torch.int32, device="cuda")
        weights = torch.rand((3, 64), device="cuda")

        (
            lengths_actual,
            values_actual,
            weights_actual,
        ) = torch.ops.fbgemm.permute_2D_sparse_data(
            permute, lengths, indices, weights, indices.numel()
        )

        lengths_expected = torch.tensor(
            [[0, 0, 1], [0, 0, 0], [0, 1, 0], [0, 0, 0], [0, 0, 0], [0, 0, 1]],
            dtype=torch.int32,
            device="cuda",
        )
        self._assert_same("permuted lengths", lengths_actual, lengths_expected)
        self._assert_same("permuted indices", values_actual, indices)
        self._assert_same("permuted weights", weights_actual, weights)

    def test_permute_2D_sparse_data_reorders_weight_rows(self) -> None:
        """Permutation that actually moves data, so weight rows must follow.

        Input segments in (t, b) order:
            t=0: [10] []      t=1: [] [20, 21]      t=2: [30] [31]
        With permute = [2, 0, 1] the output rows are input rows 2, 0, 1, so
        the flattened index order becomes [30, 31, 10, 20, 21], i.e. input
        positions [3, 4, 0, 1, 2].
        """
        lengths = torch.tensor(
            [[1, 0], [0, 2], [1, 1]], dtype=torch.int32, device="cuda"
        )
        indices = torch.tensor(
            [10, 20, 21, 30, 31], dtype=torch.int32, device="cuda"
        )
        permute = torch.tensor([2, 0, 1], dtype=torch.int32, device="cuda")
        weights_columns = 4
        # Distinct, deterministic value per element so any misplaced row or
        # column is detected.
        weights = torch.arange(
            indices.numel() * weights_columns, dtype=torch.float32, device="cuda"
        ).reshape(indices.numel(), weights_columns)

        (
            lengths_actual,
            values_actual,
            weights_actual,
        ) = torch.ops.fbgemm.permute_2D_sparse_data(
            permute, lengths, indices, weights, indices.numel()
        )

        source_rows = torch.tensor([3, 4, 0, 1, 2], dtype=torch.int64, device="cuda")
        lengths_expected = torch.tensor(
            [[1, 1], [1, 0], [0, 2]], dtype=torch.int32, device="cuda"
        )
        self._assert_same("permuted lengths", lengths_actual, lengths_expected)
        self._assert_same("permuted indices", values_actual, indices[source_rows])
        self._assert_same("permuted weights", weights_actual, weights[source_rows])
226+
227+
195228
extend_test_class(PermuteSparseFeaturesTest)
196229

197230
if __name__ == "__main__":

0 commit comments

Comments
 (0)