Skip to content

Commit cf49291

Browse files
SSYernarfacebook-github-bot
authored andcommitted
Model wrapper for DeepFM (#3115)
Summary: Pull Request resolved: #3115 * Added model wrapper for DeepFM. The wrapper will take ModelInput as an only parameter in the forward method. The forward method will return just the prediction if it's in inference mode and losses with prediction if it's in training mode. (Because training pipeline expects loss and prediction. See https://github.com/pytorch/torchrec/blob/main/torchrec/distributed/train_pipeline/train_pipelines.py#L670) * Added the parameterized unit tests to cover the model's wrapper Reviewed By: aliafzal Differential Revision: D76916471 fbshipit-source-id: 36505e2ec0b367f747e7769eb4a10e2e84f70603
1 parent a77e054 commit cf49291

File tree

2 files changed

+146
-2
lines changed

2 files changed

+146
-2
lines changed

torchrec/models/deepfm.py

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,11 @@
77

88
# pyre-strict
99

10-
from typing import List
10+
from typing import List, Tuple, Union
1111

1212
import torch
1313
from torch import nn
14+
from torchrec.distributed.test_utils.test_input import ModelInput
1415
from torchrec.modules.deepfm import DeepFM, FactorizationMachine
1516
from torchrec.modules.embedding_modules import EmbeddingBagCollection
1617
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor, KeyedTensor
@@ -350,3 +351,34 @@ def forward(
350351
)
351352
logits = self.over_arch(concatenated_dense)
352353
return logits
354+
355+
356+
class SimpleDeepFMNNWrapper(SimpleDeepFMNN):
357+
# pyre-ignore[14, 15]
358+
def forward(
359+
self, model_input: ModelInput
360+
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
361+
"""
362+
Forward pass for the SimpleDeepFMNNWrapper.
363+
364+
Args:
365+
model_input (ModelInput): Contains dense and sparse features.
366+
367+
Returns:
368+
Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
369+
If training, returns (loss, prediction). Otherwise, returns prediction.
370+
"""
371+
pred = super().forward(
372+
dense_features=model_input.float_features,
373+
sparse_features=model_input.idlist_features, # pyre-ignore[6]
374+
)
375+
376+
if self.training:
377+
# Calculate loss and return both loss and prediction
378+
loss = torch.nn.functional.binary_cross_entropy_with_logits(
379+
pred.squeeze(), model_input.label
380+
)
381+
return (loss, pred)
382+
else:
383+
# Return just the prediction
384+
return pred

torchrec/models/tests/test_deepfm.py

Lines changed: 113 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,20 @@
88
# pyre-strict
99

1010
import unittest
11+
from dataclasses import dataclass
12+
from typing import List
1113

1214
import torch
15+
from parameterized import parameterized
1316
from torch.testing import FileCheck # @manual
17+
from torchrec.distributed.test_utils.test_input import ModelInput
1418
from torchrec.fx import symbolic_trace, Tracer
15-
from torchrec.models.deepfm import DenseArch, FMInteractionArch, SimpleDeepFMNN
19+
from torchrec.models.deepfm import (
20+
DenseArch,
21+
FMInteractionArch,
22+
SimpleDeepFMNN,
23+
SimpleDeepFMNNWrapper,
24+
)
1625
from torchrec.modules.embedding_configs import EmbeddingBagConfig
1726
from torchrec.modules.embedding_modules import EmbeddingBagCollection
1827
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor, KeyedTensor
@@ -210,5 +219,108 @@ def test_fx_script(self) -> None:
210219
self.assertEqual(logits.size(), (B, 1))
211220

212221

222+
class SimpleDeepFMNNWrapperTest(unittest.TestCase):
223+
@dataclass
224+
class WrapperTestParams:
225+
# input parameters
226+
embedding_configs: List[EmbeddingBagConfig]
227+
sparse_feature_keys: List[str]
228+
sparse_feature_values: List[int]
229+
sparse_feature_offsets: List[int]
230+
# expected output parameters
231+
expected_output_size: tuple[int, ...]
232+
233+
@parameterized.expand(
234+
[
235+
(
236+
"basic_with_multiple_features",
237+
WrapperTestParams(
238+
embedding_configs=[
239+
EmbeddingBagConfig(
240+
name="t1",
241+
embedding_dim=8,
242+
num_embeddings=100,
243+
feature_names=["f1", "f3"],
244+
),
245+
EmbeddingBagConfig(
246+
name="t2",
247+
embedding_dim=8,
248+
num_embeddings=100,
249+
feature_names=["f2"],
250+
),
251+
],
252+
sparse_feature_keys=["f1", "f3", "f2"],
253+
sparse_feature_values=[1, 2, 4, 5, 4, 3, 2, 9, 1, 2, 3],
254+
sparse_feature_offsets=[0, 2, 4, 6, 8, 10, 11],
255+
expected_output_size=(2, 1),
256+
),
257+
),
258+
(
259+
"empty_sparse_features",
260+
WrapperTestParams(
261+
embedding_configs=[
262+
EmbeddingBagConfig(
263+
name="t1",
264+
embedding_dim=8,
265+
num_embeddings=100,
266+
feature_names=["f1"],
267+
),
268+
],
269+
sparse_feature_keys=["f1"],
270+
sparse_feature_values=[],
271+
sparse_feature_offsets=[0, 0, 0],
272+
expected_output_size=(2, 1),
273+
),
274+
),
275+
]
276+
)
277+
def test_wrapper_functionality(
278+
self, _test_name: str, test_params: WrapperTestParams
279+
) -> None:
280+
B = 2
281+
num_dense_features = 100
282+
283+
ebc = EmbeddingBagCollection(tables=test_params.embedding_configs)
284+
285+
deepfm_wrapper = SimpleDeepFMNNWrapper(
286+
num_dense_features=num_dense_features,
287+
embedding_bag_collection=ebc,
288+
hidden_layer_size=20,
289+
deep_fm_dimension=5,
290+
)
291+
292+
# Create ModelInput
293+
dense_features = torch.rand((B, num_dense_features))
294+
sparse_features = KeyedJaggedTensor.from_offsets_sync(
295+
keys=test_params.sparse_feature_keys,
296+
values=torch.tensor(test_params.sparse_feature_values, dtype=torch.long),
297+
offsets=torch.tensor(test_params.sparse_feature_offsets, dtype=torch.long),
298+
)
299+
300+
model_input = ModelInput(
301+
float_features=dense_features,
302+
idlist_features=sparse_features,
303+
idscore_features=None,
304+
label=torch.rand((B,)),
305+
)
306+
307+
# Test eval mode - should return just logits
308+
deepfm_wrapper.eval()
309+
logits = deepfm_wrapper(model_input)
310+
self.assertIsInstance(logits, torch.Tensor)
311+
self.assertEqual(logits.size(), test_params.expected_output_size)
312+
313+
# Test training mode - should return (loss, logits) tuple
314+
deepfm_wrapper.train()
315+
result = deepfm_wrapper(model_input)
316+
self.assertIsInstance(result, tuple)
317+
self.assertEqual(len(result), 2)
318+
loss, pred = result
319+
self.assertIsInstance(loss, torch.Tensor)
320+
self.assertIsInstance(pred, torch.Tensor)
321+
self.assertEqual(loss.size(), ()) # scalar loss
322+
self.assertEqual(pred.size(), test_params.expected_output_size)
323+
324+
213325
if __name__ == "__main__":
214326
unittest.main()

0 commit comments

Comments
 (0)