|
8 | 8 | # pyre-strict
|
9 | 9 |
|
10 | 10 | import unittest
|
| 11 | +from dataclasses import dataclass |
| 12 | +from typing import List |
11 | 13 |
|
12 | 14 | import torch
|
| 15 | +from parameterized import parameterized |
13 | 16 | from torch.testing import FileCheck # @manual
|
| 17 | +from torchrec.distributed.test_utils.test_input import ModelInput |
14 | 18 | from torchrec.fx import symbolic_trace, Tracer
|
15 |
| -from torchrec.models.deepfm import DenseArch, FMInteractionArch, SimpleDeepFMNN |
| 19 | +from torchrec.models.deepfm import ( |
| 20 | + DenseArch, |
| 21 | + FMInteractionArch, |
| 22 | + SimpleDeepFMNN, |
| 23 | + SimpleDeepFMNNWrapper, |
| 24 | +) |
16 | 25 | from torchrec.modules.embedding_configs import EmbeddingBagConfig
|
17 | 26 | from torchrec.modules.embedding_modules import EmbeddingBagCollection
|
18 | 27 | from torchrec.sparse.jagged_tensor import KeyedJaggedTensor, KeyedTensor
|
@@ -210,5 +219,108 @@ def test_fx_script(self) -> None:
|
210 | 219 | self.assertEqual(logits.size(), (B, 1))
|
211 | 220 |
|
212 | 221 |
|
| 222 | +class SimpleDeepFMNNWrapperTest(unittest.TestCase): |
| 223 | + @dataclass |
| 224 | + class WrapperTestParams: |
| 225 | + # input parameters |
| 226 | + embedding_configs: List[EmbeddingBagConfig] |
| 227 | + sparse_feature_keys: List[str] |
| 228 | + sparse_feature_values: List[int] |
| 229 | + sparse_feature_offsets: List[int] |
| 230 | + # expected output parameters |
| 231 | + expected_output_size: tuple[int, ...] |
| 232 | + |
| 233 | + @parameterized.expand( |
| 234 | + [ |
| 235 | + ( |
| 236 | + "basic_with_multiple_features", |
| 237 | + WrapperTestParams( |
| 238 | + embedding_configs=[ |
| 239 | + EmbeddingBagConfig( |
| 240 | + name="t1", |
| 241 | + embedding_dim=8, |
| 242 | + num_embeddings=100, |
| 243 | + feature_names=["f1", "f3"], |
| 244 | + ), |
| 245 | + EmbeddingBagConfig( |
| 246 | + name="t2", |
| 247 | + embedding_dim=8, |
| 248 | + num_embeddings=100, |
| 249 | + feature_names=["f2"], |
| 250 | + ), |
| 251 | + ], |
| 252 | + sparse_feature_keys=["f1", "f3", "f2"], |
| 253 | + sparse_feature_values=[1, 2, 4, 5, 4, 3, 2, 9, 1, 2, 3], |
| 254 | + sparse_feature_offsets=[0, 2, 4, 6, 8, 10, 11], |
| 255 | + expected_output_size=(2, 1), |
| 256 | + ), |
| 257 | + ), |
| 258 | + ( |
| 259 | + "empty_sparse_features", |
| 260 | + WrapperTestParams( |
| 261 | + embedding_configs=[ |
| 262 | + EmbeddingBagConfig( |
| 263 | + name="t1", |
| 264 | + embedding_dim=8, |
| 265 | + num_embeddings=100, |
| 266 | + feature_names=["f1"], |
| 267 | + ), |
| 268 | + ], |
| 269 | + sparse_feature_keys=["f1"], |
| 270 | + sparse_feature_values=[], |
| 271 | + sparse_feature_offsets=[0, 0, 0], |
| 272 | + expected_output_size=(2, 1), |
| 273 | + ), |
| 274 | + ), |
| 275 | + ] |
| 276 | + ) |
| 277 | + def test_wrapper_functionality( |
| 278 | + self, _test_name: str, test_params: WrapperTestParams |
| 279 | + ) -> None: |
| 280 | + B = 2 |
| 281 | + num_dense_features = 100 |
| 282 | + |
| 283 | + ebc = EmbeddingBagCollection(tables=test_params.embedding_configs) |
| 284 | + |
| 285 | + deepfm_wrapper = SimpleDeepFMNNWrapper( |
| 286 | + num_dense_features=num_dense_features, |
| 287 | + embedding_bag_collection=ebc, |
| 288 | + hidden_layer_size=20, |
| 289 | + deep_fm_dimension=5, |
| 290 | + ) |
| 291 | + |
| 292 | + # Create ModelInput |
| 293 | + dense_features = torch.rand((B, num_dense_features)) |
| 294 | + sparse_features = KeyedJaggedTensor.from_offsets_sync( |
| 295 | + keys=test_params.sparse_feature_keys, |
| 296 | + values=torch.tensor(test_params.sparse_feature_values, dtype=torch.long), |
| 297 | + offsets=torch.tensor(test_params.sparse_feature_offsets, dtype=torch.long), |
| 298 | + ) |
| 299 | + |
| 300 | + model_input = ModelInput( |
| 301 | + float_features=dense_features, |
| 302 | + idlist_features=sparse_features, |
| 303 | + idscore_features=None, |
| 304 | + label=torch.rand((B,)), |
| 305 | + ) |
| 306 | + |
| 307 | + # Test eval mode - should return just logits |
| 308 | + deepfm_wrapper.eval() |
| 309 | + logits = deepfm_wrapper(model_input) |
| 310 | + self.assertIsInstance(logits, torch.Tensor) |
| 311 | + self.assertEqual(logits.size(), test_params.expected_output_size) |
| 312 | + |
| 313 | + # Test training mode - should return (loss, logits) tuple |
| 314 | + deepfm_wrapper.train() |
| 315 | + result = deepfm_wrapper(model_input) |
| 316 | + self.assertIsInstance(result, tuple) |
| 317 | + self.assertEqual(len(result), 2) |
| 318 | + loss, pred = result |
| 319 | + self.assertIsInstance(loss, torch.Tensor) |
| 320 | + self.assertIsInstance(pred, torch.Tensor) |
| 321 | + self.assertEqual(loss.size(), ()) # scalar loss |
| 322 | + self.assertEqual(pred.size(), test_params.expected_output_size) |
| 323 | + |
| 324 | + |
213 | 325 | if __name__ == "__main__":
|
214 | 326 | unittest.main()
|
0 commit comments