
Commit 4b83779

chouxi authored and facebook-github-bot committed
Add tracking and streaming logic to SplitTableBatchedEmbeddingBagsCodegen (#4741)
Summary:
Pull Request resolved: #4741

X-link: facebookresearch/FBGEMM#1762

This follows logic similar to the SSD TBE (https://fburl.com/code/fxdcxma3). It does two things:
1. During prefetch, store the updated ids and their count.
2. On the next iteration, stream out the updated embeddings and ids before the embedding cache is populated again.

The prefetch-pipeline handling is the same as in SSD TBE. A minimal sketch of this flow follows the commit metadata below.

Differential Revision: D78438757
1 parent 21cbfbb commit 4b83779
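The commit's store-then-stream flow, reduced to a minimal, self-contained Python sketch. This is an illustration rather than the FBGEMM implementation: the class name `StreamingCacheSketch`, the plain-tensor cache, and the print sink in `stream()` are stand-ins, while the real code below deduplicates ids with `torch.ops.fbgemm.get_unique_indices`, gathers rows out of the UVM cache with `masked_index_select`, and hands CPU tensors to `RawEmbeddingStreamer`.

from collections import deque

import torch


class StreamingCacheSketch:
    """Minimal sketch of the store-then-stream pattern (illustrative only)."""

    def __init__(self, num_rows: int, dim: int, prefetch_pipeline: bool) -> None:
        self.weights = torch.zeros(num_rows, dim)  # stand-in for the embedding cache
        self.prefetch_pipeline = prefetch_pipeline
        self.prefetched_info: deque = deque()  # ids recorded by earlier prefetches

    def prefetch(self, indices: torch.Tensor) -> None:
        # Stream previously recorded ids *before* this prefetch repopulates the
        # cache: ids from iter i - 1 without pipelining, iter i - 2 with it.
        target_prev_iter = 2 if self.prefetch_pipeline else 1
        if len(self.prefetched_info) >= target_prev_iter:
            updated = self.prefetched_info.popleft()
            self.stream(updated, self.weights[updated])
        # Record this iteration's deduplicated ids for a later stream.
        self.prefetched_info.append(indices.unique())

    def stream(self, indices: torch.Tensor, rows: torch.Tensor) -> None:
        # Placeholder sink; the real code streams to RawEmbeddingStreamer.
        print(f"streaming {indices.numel()} updated rows")


cache = StreamingCacheSketch(num_rows=100, dim=4, prefetch_pipeline=False)
for _ in range(3):
    cache.prefetch(torch.tensor([1, 2, 2, 5]))  # streams from the second call on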

File tree

2 files changed: +215 -0

fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py

Lines changed: 104 additions & 0 deletions
@@ -22,6 +22,7 @@

 import torch  # usort:skip
 from torch import nn, Tensor  # usort:skip
+from torch.autograd.profiler import record_function  # usort:skip

 # @manual=//deeplearning/fbgemm/fbgemm_gpu/codegen:split_embedding_codegen_lookup_invokers
 import fbgemm_gpu.split_embedding_codegen_lookup_invokers as invokers
@@ -626,6 +627,7 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
     lxu_cache_locations_list: List[Tensor]
     lxu_cache_locations_empty: Tensor
     timesteps_prefetched: List[int]
+    prefetched_info: List[Tuple[Tensor, Tensor]]
     record_cache_metrics: RecordCacheMetrics
     # pyre-fixme[13]: Attribute `uvm_cache_stats` is never initialized.
     uvm_cache_stats: torch.Tensor
@@ -690,6 +692,8 @@ def __init__( # noqa C901
         embedding_table_index_type: torch.dtype = torch.int64,
         embedding_table_offset_type: torch.dtype = torch.int64,
         embedding_shard_info: Optional[List[Tuple[int, int, int, int]]] = None,
+        enable_raw_embedding_streaming: bool = False,
+        res_params: Optional[RESParams] = None,
     ) -> None:
         super(SplitTableBatchedEmbeddingBagsCodegen, self).__init__()
         self.uuid = str(uuid.uuid4())
@@ -700,6 +704,7 @@ def __init__( # noqa C901
         )

         self.logging_table_name: str = self.get_table_name_for_logging(table_names)
+        self.enable_raw_embedding_streaming: bool = enable_raw_embedding_streaming
         self.pooling_mode = pooling_mode
         self.is_nobag: bool = self.pooling_mode == PoolingMode.NONE

@@ -1460,6 +1465,30 @@ def __init__( # noqa C901
         )
         self.embedding_table_offset_type: torch.dtype = embedding_table_offset_type

+        self.prefetched_info: List[Tuple[Tensor, Tensor]] = torch.jit.annotate(
+            List[Tuple[Tensor, Tensor]], []
+        )
+        if self.enable_raw_embedding_streaming:
+            self.res_params: RESParams = res_params or RESParams()
+            self.res_params.table_sizes = [0] + list(accumulate(rows))
+            res_port_from_env = os.getenv("LOCAL_RES_PORT")
+            self.res_params.res_server_port = (
+                int(res_port_from_env) if res_port_from_env else 0
+            )
+            # pyre-fixme[4]: Attribute must be annotated.
+            self._raw_embedding_streamer = torch.classes.fbgemm.RawEmbeddingStreamer(
+                self.uuid,
+                self.enable_raw_embedding_streaming,
+                self.res_params.res_store_shards,
+                self.res_params.res_server_port,
+                self.res_params.table_names,
+                self.res_params.table_offsets,
+                self.res_params.table_sizes,
+            )
+            logging.info(
+                f"{self.uuid} raw embedding streaming enabled with {self.res_params=}"
+            )
+
     @torch.jit.ignore
     def log(self, msg: str) -> None:
         """
@@ -2521,7 +2550,13 @@ def _prefetch(
             self.local_uvm_cache_stats.zero_()
         self._report_io_size_count("prefetch_input", indices)

+        # streaming before updating the cache
+        self.raw_embedding_stream()
+
         final_lxu_cache_locations = torch.empty_like(indices, dtype=torch.int32)
+        linear_cache_indices_merged = torch.zeros(
+            0, dtype=indices.dtype, device=indices.device
+        )
         for (
             partial_indices,
             partial_lxu_cache_locations,
@@ -2537,6 +2572,9 @@
                 vbe_metadata.max_B if vbe_metadata is not None else -1,
                 base_offset,
             )
+            linear_cache_indices_merged = torch.cat(
+                [linear_cache_indices_merged, linear_cache_indices]
+            )

             if (
                 self.record_cache_metrics.record_cache_miss_counter
@@ -2617,6 +2655,23 @@
             if self.should_log():
                 self.print_uvm_cache_stats(use_local_cache=False)

+        if self.enable_raw_embedding_streaming:
+            with record_function(
+                "## uvm_save_prefetched_rows {} {} ##".format(self.timestep, self.uuid)
+            ):
+                (
+                    linear_unique_indices,
+                    linear_unique_indices_length,
+                    _,
+                ) = torch.ops.fbgemm.get_unique_indices(
+                    linear_cache_indices_merged,
+                    self.total_cache_hash_size,
+                    compute_count=False,
+                )
+                self.prefetched_info.append(
+                    (linear_unique_indices, linear_unique_indices_length)
+                )
+
     def should_log(self) -> bool:
         """Determines if we should log for this step, using exponentially decreasing frequency.
@@ -3829,6 +3884,55 @@ def _debug_print_input_stats_factory_null(
             return _debug_print_input_stats_factory_impl
         return _debug_print_input_stats_factory_null

+    @torch.jit.ignore
+    def raw_embedding_stream(self) -> None:
+        if not self.enable_raw_embedding_streaming:
+            return None
+        # When pipelining is enabled, prefetch in iter i happens before the
+        # backward of the sparse modules in iter i - 1, so the embeddings
+        # for iter i - 1's changed ids are not yet updated; we can only
+        # fetch the indices from iter i - 2.
+        # When pipelining is disabled, prefetch in iter i happens before
+        # the forward of iter i, so we can safely fetch iter i - 1's
+        # changed ids.
+        target_prev_iter = 1
+        if self.prefetch_pipeline:
+            target_prev_iter = 2
+        if not len(self.prefetched_info) > (target_prev_iter - 1):
+            return None
+        with record_function(
+            "## uvm_lookup_prefetched_rows {} {} ##".format(self.timestep, self.uuid)
+        ):
+            (updated_indices, updated_count) = self.prefetched_info.pop(0)
+            updated_locations = torch.ops.fbgemm.lxu_cache_lookup(
+                updated_indices,
+                self.lxu_cache_state,
+                self.total_cache_hash_size,
+                gather_cache_stats=False,  # not collecting cache stats
+                num_uniq_cache_indices=updated_count,
+            )
+            updated_weights = torch.empty(
+                [updated_indices.size()[0], self.max_D_cache],
+                # pyre-ignore Incompatible parameter type [6]: In call `torch._C._VariableFunctions.empty`, for argument `dtype`, expected `Optional[dtype]` but got `Union[Module, dtype, Tensor]`
+                dtype=self.lxu_cache_weights.dtype,
+                # pyre-ignore Incompatible parameter type [6]: In call `torch._C._VariableFunctions.empty`, for argument `device`, expected `Union[None, int, str, device]` but got `Union[Module, device, Tensor]`
+                device=self.lxu_cache_weights.device,
+            )
+            torch.ops.fbgemm.masked_index_select(
+                updated_weights,
+                updated_locations,
+                self.lxu_cache_weights,
+                updated_count,
+            )
+            # stream weights
+            self._raw_embedding_streamer.stream(
+                updated_indices.to(device=torch.device("cpu")),
+                updated_weights.to(device=torch.device("cpu")),
+                updated_count.to(device=torch.device("cpu")),
+                False,  # require_tensor_copy
+                False,  # blocking_tensor_copy
+            )
+


 class DenseTableBatchedEmbeddingBagsCodegen(nn.Module):
     """

fbgemm_gpu/test/tbe/training/forward_test.py

Lines changed: 111 additions & 0 deletions
@@ -12,6 +12,7 @@
 import math
 import random
 import unittest
+from unittest.mock import MagicMock, patch

 import hypothesis.strategies as st
 import numpy as np
@@ -24,6 +25,7 @@
 )
 from fbgemm_gpu.split_table_batched_embeddings_ops_training import (
     ComputeDevice,
+    RESParams,
     SplitTableBatchedEmbeddingBagsCodegen,
 )
 from fbgemm_gpu.tbe.utils import (
@@ -129,6 +131,8 @@ def execute_forward_( # noqa C901
         use_cpu: bool,
         output_dtype: SparseType,
         use_experimental_tbe: bool,
+        enable_raw_embedding_streaming: bool = False,
+        prefetch_pipeline: bool = False,
     ) -> None:
         # NOTE: cache is not applicable to CPU version.
         assume(not use_cpu or not use_cache)
@@ -158,6 +162,10 @@ def execute_forward_( # noqa C901
                 and pooling_mode != PoolingMode.NONE
             )
         )
+        # NOTE: Raw embedding streaming requires UVM cache
+        assume(not enable_raw_embedding_streaming or use_cache)
+        # NOTE: Raw embedding streaming not supported on CPU
+        assume(not enable_raw_embedding_streaming or not use_cpu)

         emb_op = SplitTableBatchedEmbeddingBagsCodegen
         if pooling_mode == PoolingMode.SUM:
@@ -285,6 +293,16 @@ def execute_forward_( # noqa C901
         else:
             f = torch.cat(fs, dim=0).view(-1, D)

+        # Create RES parameters if raw embedding streaming is enabled
+        res_params = None
+        if enable_raw_embedding_streaming:
+            res_params = RESParams(
+                res_store_shards=1,
+                table_names=[f"table_{i}" for i in range(T)],
+                table_offsets=[sum(Es[:i]) for i in range(T + 1)],
+                table_sizes=Es,
+            )
+
         # Create a TBE op
         cc = emb_op(
             embedding_specs=[
@@ -305,6 +323,9 @@ def execute_forward_( # noqa C901
             pooling_mode=pooling_mode,
             output_dtype=output_dtype,
             use_experimental_tbe=use_experimental_tbe,
+            prefetch_pipeline=prefetch_pipeline,
+            enable_raw_embedding_streaming=enable_raw_embedding_streaming,
+            res_params=res_params,
         )
         # Test torch JIT script compatibility
         if not use_cpu:
@@ -1158,6 +1179,96 @@ def test_forward_fused_pooled_emb_quant(
             cat_deq_lowp_pooled_output, cat_dq_fp32_pooled_output
         )

+    def _check_raw_embedding_stream_call_counts(
+        self,
+        mock_raw_embedding_stream: unittest.mock.Mock,
+        num_iterations: int,
+        prefetch_pipeline: bool,
+        L: int,
+    ) -> None:
+        # For TBE (not SSD), raw_embedding_stream is called once per prefetch
+        # when there's data to stream
+        expected_calls = num_iterations if L > 0 else 0
+        if prefetch_pipeline:
+            # With prefetch pipeline, there might be fewer calls initially
+            expected_calls = max(0, expected_calls - 1)
+
+        self.assertGreaterEqual(mock_raw_embedding_stream.call_count, 0)
+        # Allow some flexibility in call count due to caching behavior
+        self.assertLessEqual(mock_raw_embedding_stream.call_count, expected_calls + 2)
+
+    @unittest.skipIf(*gpu_unavailable)
+    @given(
+        T=st.integers(min_value=1, max_value=5),
+        D=st.integers(min_value=2, max_value=64),
+        B=st.integers(min_value=1, max_value=32),
+        log_E=st.integers(min_value=3, max_value=4),
+        L=st.integers(min_value=1, max_value=10),
+        weights_precision=st.sampled_from([SparseType.FP32, SparseType.FP16]),
+        cache_algorithm=st.sampled_from(CacheAlgorithm),
+        pooling_mode=st.sampled_from([PoolingMode.SUM, PoolingMode.MEAN]),
+        weighted=st.booleans(),
+        mixed=st.booleans(),
+        prefetch_pipeline=st.booleans(),
+    )
+    @settings(
+        verbosity=VERBOSITY,
+        max_examples=MAX_EXAMPLES_LONG_RUNNING,
+        deadline=None,
+        suppress_health_check=[HealthCheck.filter_too_much, HealthCheck.data_too_large],
+    )
+    def test_forward_raw_embedding_streaming(
+        self,
+        T: int,
+        D: int,
+        B: int,
+        log_E: int,
+        L: int,
+        weights_precision: SparseType,
+        cache_algorithm: CacheAlgorithm,
+        pooling_mode: PoolingMode,
+        weighted: bool,
+        mixed: bool,
+        prefetch_pipeline: bool,
+    ) -> None:
+        """Test raw embedding streaming functionality integrated with forward pass."""
+        num_iterations = 5
+        # only LRU supports prefetch_pipeline
+        assume(not prefetch_pipeline or cache_algorithm == CacheAlgorithm.LRU)
+
+        with patch(
+            "fbgemm_gpu.split_table_batched_embeddings_ops_training.torch.classes.fbgemm.RawEmbeddingStreamer"
+        ) as mock_streamer_class:
+            # Mock the RawEmbeddingStreamer class
+            mock_streamer_instance = MagicMock()
+            mock_streamer_class.return_value = mock_streamer_instance
+
+            # Run multiple iterations to test streaming behavior
+            for _ in range(num_iterations):
+                self.execute_forward_(
+                    T=T,
+                    D=D,
+                    B=B,
+                    log_E=log_E,
+                    L=L,
+                    weights_precision=weights_precision,
+                    weighted=weighted,
+                    mixed=mixed,
+                    mixed_B=False,  # Keep simple for streaming tests
+                    use_cache=True,  # Required for streaming
+                    cache_algorithm=cache_algorithm,
+                    pooling_mode=pooling_mode,
+                    use_cpu=False,  # Streaming not supported on CPU
+                    output_dtype=SparseType.FP32,
+                    use_experimental_tbe=False,
+                    enable_raw_embedding_streaming=True,
+                    prefetch_pipeline=prefetch_pipeline,
+                )
+
+            self._check_raw_embedding_stream_call_counts(
+                mock_streamer_instance, num_iterations, prefetch_pipeline, L
+            )
+

 if __name__ == "__main__":
     unittest.main()
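Assuming a GPU build of fbgemm_gpu is installed, the new test can be run on its own with pytest's keyword filter:

python -m pytest fbgemm_gpu/test/tbe/training/forward_test.py -k test_forward_raw_embedding_streaming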
