Commit 5e66106

chouxi authored and facebook-github-bot committed
Add tracking and streaming logic to SplitTableBatchedEmbeddingBagsCodegen (#4741)
Summary:
Pull Request resolved: #4741

X-link: facebookresearch/FBGEMM#1762

This follows logic similar to the SSD TBE (https://fburl.com/code/fxdcxma3). It:

1. stores the updated ids/counts during prefetch, and
2. on the next iteration, streams out the updated embeddings and ids before the embedding cache is populated again.

The prefetch-pipeline handling is also the same as in SSD TBE.

Differential Revision: D78438757
1 parent 7f89b52 commit 5e66106
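The core idea is a deferred handshake: each prefetch first streams out the rows it recorded one iteration earlier (two iterations with the prefetch pipeline, since the previous iteration's backward pass has not yet updated its rows), then records the rows touched in the current iteration. Below is a minimal sketch of that ordering in plain Python; the queue, ids, and stream method are stand-ins for the actual prefetched_info buffer and RawEmbeddingStreamer, not the real TBE internals.

from collections import deque
from typing import Deque, List

class StreamingHandshakeSketch:
    def __init__(self, prefetch_pipeline: bool = False) -> None:
        # With the prefetch pipeline, prefetch for iter i runs before the
        # backward pass of iter i - 1, so only ids recorded at iter i - 2
        # are guaranteed to have fully updated embeddings.
        self.delay = 2 if prefetch_pipeline else 1
        self.pending: Deque[List[int]] = deque()

    def prefetch(self, touched_ids: List[int]) -> None:
        # 1) Stream rows recorded `delay` iterations ago, while their
        #    embeddings still sit in the cache (before it is repopulated).
        if len(self.pending) >= self.delay:
            self.stream(self.pending.popleft())
        # 2) Record this iteration's unique touched ids for a later stream.
        self.pending.append(sorted(set(touched_ids)))

    def stream(self, ids: List[int]) -> None:
        print("streaming rows:", ids)  # placeholder for the real streamer

For example, with pipelining disabled, prefetch([3, 7, 3]) followed by prefetch([5]) streams [3, 7] on the second call.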

File tree

fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py
fbgemm_gpu/test/tbe/training/forward_test.py

2 files changed: +214 −0


fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py

Lines changed: 105 additions & 0 deletions
@@ -60,6 +60,8 @@
 
 from fbgemm_gpu.utils.loader import load_torch_module, load_torch_module_bc
 
+from torch.autograd.profiler import record_function
+
 try:
     load_torch_module(
         "//deeplearning/fbgemm/fbgemm_gpu/codegen:embedding_ops_training_gpu",
@@ -626,6 +628,7 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
     lxu_cache_locations_list: List[Tensor]
     lxu_cache_locations_empty: Tensor
     timesteps_prefetched: List[int]
+    prefetched_info: List[Tuple[Tensor, Tensor]]
     record_cache_metrics: RecordCacheMetrics
     # pyre-fixme[13]: Attribute `uvm_cache_stats` is never initialized.
     uvm_cache_stats: torch.Tensor
@@ -690,6 +693,8 @@ def __init__( # noqa C901
         embedding_table_index_type: torch.dtype = torch.int64,
         embedding_table_offset_type: torch.dtype = torch.int64,
         embedding_shard_info: Optional[List[Tuple[int, int, int, int]]] = None,
+        enable_raw_embedding_streaming: bool = False,
+        res_params: Optional[RESParams] = None,
     ) -> None:
         super(SplitTableBatchedEmbeddingBagsCodegen, self).__init__()
         self.uuid = str(uuid.uuid4())
@@ -700,6 +705,7 @@ def __init__( # noqa C901
         )
 
         self.logging_table_name: str = self.get_table_name_for_logging(table_names)
+        self.enable_raw_embedding_streaming: bool = enable_raw_embedding_streaming
         self.pooling_mode = pooling_mode
         self.is_nobag: bool = self.pooling_mode == PoolingMode.NONE
 
@@ -1460,6 +1466,30 @@ def __init__( # noqa C901
         )
         self.embedding_table_offset_type: torch.dtype = embedding_table_offset_type
 
+        self.prefetched_info: List[Tuple[Tensor, Tensor]] = torch.jit.annotate(
+            List[Tuple[Tensor, Tensor]], []
+        )
+        if self.enable_raw_embedding_streaming:
+            self.res_params: RESParams = res_params or RESParams()
+            self.res_params.table_sizes = [0] + list(accumulate(rows))
+            res_port_from_env = os.getenv("LOCAL_RES_PORT")
+            self.res_params.res_server_port = (
+                int(res_port_from_env) if res_port_from_env else 0
+            )
+            # pyre-fixme[4]: Attribute must be annotated.
+            self._raw_embedding_streamer = torch.classes.fbgemm.RawEmbeddingStreamer(
+                self.uuid,
+                self.enable_raw_embedding_streaming,
+                self.res_params.res_store_shards,
+                self.res_params.res_server_port,
+                self.res_params.table_names,
+                self.res_params.table_offsets,
+                self.res_params.table_sizes,
+            )
+            logging.info(
+                f"{self.uuid} raw embedding streaming enabled with {self.res_params=}"
+            )
+
     @torch.jit.ignore
     def log(self, msg: str) -> None:
         """
@@ -2521,7 +2551,13 @@ def _prefetch(
             self.local_uvm_cache_stats.zero_()
         self._report_io_size_count("prefetch_input", indices)
 
+        # streaming before updating the cache
+        self.raw_embedding_stream()
+
         final_lxu_cache_locations = torch.empty_like(indices, dtype=torch.int32)
+        linear_cache_indices_merged = torch.zeros(
+            0, dtype=indices.dtype, device=indices.device
+        )
         for (
             partial_indices,
             partial_lxu_cache_locations,
@@ -2537,6 +2573,9 @@
                 vbe_metadata.max_B if vbe_metadata is not None else -1,
                 base_offset,
             )
+            linear_cache_indices_merged = torch.cat(
+                [linear_cache_indices_merged, linear_cache_indices]
+            )
 
             if (
                 self.record_cache_metrics.record_cache_miss_counter
@@ -2617,6 +2656,23 @@
         if self.should_log():
             self.print_uvm_cache_stats(use_local_cache=False)
 
+        if self.enable_raw_embedding_streaming:
+            with record_function(
+                "## uvm_save_prefetched_rows {} {} ##".format(self.timestep, self.uuid)
+            ):
+                (
+                    linear_unique_indices,
+                    linear_unique_indices_length,
+                    _,
+                ) = torch.ops.fbgemm.get_unique_indices(
+                    linear_cache_indices_merged,
+                    self.total_cache_hash_size,
+                    compute_count=False,
+                )
+                self.prefetched_info.append(
+                    (linear_unique_indices, linear_unique_indices_length)
+                )
+
     def should_log(self) -> bool:
         """Determines if we should log for this step, using exponentially decreasing frequency.
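Here get_unique_indices deduplicates the merged linearized cache indices, so each touched row is recorded once per prefetch, together with a length tensor saying how many entries are valid. A rough stand-in for that behavior using stock torch ops (the real fbgemm op runs on GPU and returns a padded output, which this sketch ignores):

import torch

def get_unique_indices_sketch(linear_cache_indices: torch.Tensor):
    # Dedup and sort the ids; report how many unique entries there are.
    unique_ids = torch.unique(linear_cache_indices, sorted=True)
    count = torch.tensor([unique_ids.numel()], dtype=torch.int32)
    return unique_ids, count

ids = torch.tensor([7, 3, 7, 12, 3])
unique_ids, count = get_unique_indices_sketch(ids)
# unique_ids == tensor([3, 7, 12]), count == tensor([3], dtype=torch.int32)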
@@ -3829,6 +3885,55 @@ def _debug_print_input_stats_factory_null(
             return _debug_print_input_stats_factory_impl
         return _debug_print_input_stats_factory_null
 
+    @torch.jit.ignore
+    def raw_embedding_stream(self) -> None:
+        if not self.enable_raw_embedding_streaming:
+            return None
+        # When pipelining is enabled, prefetch for iter i happens before the
+        # sparse backward of iter i - 1, so embeddings for iter i - 1's changed
+        # ids are not yet updated and we can only fetch indices from iter i - 2.
+        # When pipelining is disabled, prefetch for iter i happens before the
+        # forward of iter i, so we can safely fetch iter i - 1's changed ids.
+        target_prev_iter = 1
+        if self.prefetch_pipeline:
+            target_prev_iter = 2
+        if not len(self.prefetched_info) > (target_prev_iter - 1):
+            return None
+        with record_function(
+            "## uvm_lookup_prefetched_rows {} {} ##".format(self.timestep, self.uuid)
+        ):
+            (updated_indices, updated_count) = self.prefetched_info.pop(0)
+            updated_locations = torch.ops.fbgemm.lxu_cache_lookup(
+                updated_indices,
+                self.lxu_cache_state,
+                self.total_cache_hash_size,
+                gather_cache_stats=False,  # not collecting cache stats
+                num_uniq_cache_indices=updated_count,
+            )
+            updated_weights = torch.empty(
+                [updated_indices.size()[0], self.max_D_cache],
+                # pyre-ignore[6]
+                dtype=self.lxu_cache_weights.dtype,
+                # pyre-ignore[6]
+                device=self.lxu_cache_weights.device,
+            )
+            torch.ops.fbgemm.masked_index_select(
+                updated_weights,
+                updated_locations,
+                self.lxu_cache_weights,
+                updated_count,
+            )
+            # stream weights
+            self._raw_embedding_streamer.stream(
+                updated_indices.to(device=torch.device("cpu")),
+                updated_weights.to(device=torch.device("cpu")),
+                updated_count.to(device=torch.device("cpu")),
+                False,  # require_tensor_copy
+                False,  # blocking_tensor_copy
+            )
+
 
 class DenseTableBatchedEmbeddingBagsCodegen(nn.Module):
     """

fbgemm_gpu/test/tbe/training/forward_test.py

Lines changed: 109 additions & 0 deletions
@@ -12,6 +12,7 @@
 import math
 import random
 import unittest
+from unittest.mock import MagicMock, patch
 
 import hypothesis.strategies as st
 import numpy as np
@@ -24,6 +25,7 @@
 )
 from fbgemm_gpu.split_table_batched_embeddings_ops_training import (
     ComputeDevice,
+    RESParams,
     SplitTableBatchedEmbeddingBagsCodegen,
 )
 from fbgemm_gpu.tbe.utils import (
@@ -129,6 +131,8 @@ def execute_forward_( # noqa C901
         use_cpu: bool,
         output_dtype: SparseType,
         use_experimental_tbe: bool,
+        enable_raw_embedding_streaming: bool = False,
+        prefetch_pipeline: bool = False,
     ) -> None:
         # NOTE: cache is not applicable to CPU version.
         assume(not use_cpu or not use_cache)
@@ -158,6 +162,10 @@ def execute_forward_( # noqa C901
                 and pooling_mode != PoolingMode.NONE
             )
         )
+        # NOTE: Raw embedding streaming requires UVM cache
+        assume(not enable_raw_embedding_streaming or use_cache)
+        # NOTE: Raw embedding streaming not supported on CPU
+        assume(not enable_raw_embedding_streaming or not use_cpu)
 
         emb_op = SplitTableBatchedEmbeddingBagsCodegen
         if pooling_mode == PoolingMode.SUM:
@@ -285,6 +293,16 @@ def execute_forward_( # noqa C901
         else:
             f = torch.cat(fs, dim=0).view(-1, D)
 
+        # Create RES parameters if raw embedding streaming is enabled
+        res_params = None
+        if enable_raw_embedding_streaming:
+            res_params = RESParams(
+                res_store_shards=1,
+                table_names=[f"table_{i}" for i in range(T)],
+                table_offsets=[sum(Es[:i]) for i in range(T + 1)],
+                table_sizes=Es,
+            )
+
         # Create a TBE op
         cc = emb_op(
             embedding_specs=[
@@ -305,6 +323,9 @@ def execute_forward_( # noqa C901
             pooling_mode=pooling_mode,
             output_dtype=output_dtype,
             use_experimental_tbe=use_experimental_tbe,
+            prefetch_pipeline=prefetch_pipeline,
+            enable_raw_embedding_streaming=enable_raw_embedding_streaming,
+            res_params=res_params,
         )
         # Test torch JIT script compatibility
         if not use_cpu:
@@ -1158,6 +1179,94 @@ def test_forward_fused_pooled_emb_quant(
             cat_deq_lowp_pooled_output, cat_dq_fp32_pooled_output
         )
 
+    def _check_raw_embedding_stream_call_counts(
+        self,
+        mock_raw_embedding_stream: unittest.mock.Mock,
+        num_iterations: int,
+        prefetch_pipeline: bool,
+        L: int,
+    ) -> None:
+        # For TBE (not SSD), raw_embedding_stream is called once per prefetch
+        # when there's data to stream
+        expected_calls = num_iterations if L > 0 else 0
+        if prefetch_pipeline:
+            # With prefetch pipeline, there might be fewer calls initially
+            expected_calls = max(0, expected_calls - 1)
+
+        self.assertGreaterEqual(mock_raw_embedding_stream.call_count, 0)
+        # Allow some flexibility in call count due to caching behavior
+        self.assertLessEqual(mock_raw_embedding_stream.call_count, expected_calls + 2)
+
+    @unittest.skipIf(*gpu_unavailable)
+    @given(
+        T=st.integers(min_value=1, max_value=5),
+        D=st.integers(min_value=2, max_value=64),
+        B=st.integers(min_value=1, max_value=32),
+        log_E=st.integers(min_value=3, max_value=4),
+        L=st.integers(min_value=1, max_value=10),
+        weights_precision=st.sampled_from([SparseType.FP32, SparseType.FP16]),
+        cache_algorithm=st.sampled_from(CacheAlgorithm),
+        pooling_mode=st.sampled_from([PoolingMode.SUM, PoolingMode.MEAN]),
+        weighted=st.booleans(),
+        mixed=st.booleans(),
+        prefetch_pipeline=st.booleans(),
+    )
+    @settings(
+        verbosity=VERBOSITY,
+        max_examples=MAX_EXAMPLES_LONG_RUNNING,
+        deadline=None,
+        suppress_health_check=[HealthCheck.filter_too_much, HealthCheck.data_too_large],
+    )
+    def test_forward_raw_embedding_streaming(
+        self,
+        T: int,
+        D: int,
+        B: int,
+        log_E: int,
+        L: int,
+        weights_precision: SparseType,
+        cache_algorithm: CacheAlgorithm,
+        pooling_mode: PoolingMode,
+        weighted: bool,
+        mixed: bool,
+        prefetch_pipeline: bool,
+    ) -> None:
+        """Test raw embedding streaming functionality integrated with forward pass."""
+        num_iterations = 5
+
+        with patch(
+            "fbgemm_gpu.split_table_batched_embeddings_ops_training.torch.classes.fbgemm.RawEmbeddingStreamer"
+        ) as mock_streamer_class:
+            # Mock the RawEmbeddingStreamer class
+            mock_streamer_instance = MagicMock()
+            mock_streamer_class.return_value = mock_streamer_instance
+
+            # Run multiple iterations to test streaming behavior
+            for _ in range(num_iterations):
+                self.execute_forward_(
+                    T=T,
+                    D=D,
+                    B=B,
+                    log_E=log_E,
+                    L=L,
+                    weights_precision=weights_precision,
+                    weighted=weighted,
+                    mixed=mixed,
+                    mixed_B=False,  # Keep simple for streaming tests
+                    use_cache=True,  # Required for streaming
+                    cache_algorithm=cache_algorithm,
+                    pooling_mode=pooling_mode,
+                    use_cpu=False,  # Streaming not supported on CPU
+                    output_dtype=SparseType.FP32,
+                    use_experimental_tbe=False,
+                    enable_raw_embedding_streaming=True,
+                    prefetch_pipeline=prefetch_pipeline,
+                )
+
+            self._check_raw_embedding_stream_call_counts(
+                mock_streamer_instance, num_iterations, prefetch_pipeline, L
+            )
 
 
 if __name__ == "__main__":
     unittest.main()
