from fbgemm_gpu.utils.loader import load_torch_module, load_torch_module_bc

+from torch.autograd.profiler import record_function
+
try:
    load_torch_module(
        "//deeplearning/fbgemm/fbgemm_gpu/codegen:embedding_ops_training_gpu",
@@ -626,6 +628,7 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
    lxu_cache_locations_list: List[Tensor]
    lxu_cache_locations_empty: Tensor
    timesteps_prefetched: List[int]
+    prefetched_info: List[Tuple[Tensor, Tensor]]
    record_cache_metrics: RecordCacheMetrics
    # pyre-fixme[13]: Attribute `uvm_cache_stats` is never initialized.
    uvm_cache_stats: torch.Tensor
@@ -690,6 +693,8 @@ def __init__( # noqa C901
        embedding_table_index_type: torch.dtype = torch.int64,
        embedding_table_offset_type: torch.dtype = torch.int64,
        embedding_shard_info: Optional[List[Tuple[int, int, int, int]]] = None,
+        enable_raw_embedding_streaming: bool = False,
+        res_params: Optional[RESParams] = None,
    ) -> None:
        super(SplitTableBatchedEmbeddingBagsCodegen, self).__init__()
        self.uuid = str(uuid.uuid4())
@@ -700,6 +705,7 @@ def __init__( # noqa C901
        )

        self.logging_table_name: str = self.get_table_name_for_logging(table_names)
+        self.enable_raw_embedding_streaming: bool = enable_raw_embedding_streaming
        self.pooling_mode = pooling_mode
        self.is_nobag: bool = self.pooling_mode == PoolingMode.NONE
@@ -1460,6 +1466,30 @@ def __init__( # noqa C901
        )
        self.embedding_table_offset_type: torch.dtype = embedding_table_offset_type

+        self.prefetched_info: List[Tuple[Tensor, Tensor]] = torch.jit.annotate(
+            List[Tuple[Tensor, Tensor]], []
+        )
+        if self.enable_raw_embedding_streaming:
+            self.res_params: RESParams = res_params or RESParams()
+            self.res_params.table_sizes = [0] + list(accumulate(rows))
+            res_port_from_env = os.getenv("LOCAL_RES_PORT")
+            self.res_params.res_server_port = (
+                int(res_port_from_env) if res_port_from_env else 0
+            )
+            # pyre-fixme[4]: Attribute must be annotated.
+            self._raw_embedding_streamer = torch.classes.fbgemm.RawEmbeddingStreamer(
+                self.uuid,
+                self.enable_raw_embedding_streaming,
+                self.res_params.res_store_shards,
+                self.res_params.res_server_port,
+                self.res_params.table_names,
+                self.res_params.table_offsets,
+                self.res_params.table_sizes,
+            )
+            logging.info(
+                f"{self.uuid} raw embedding streaming enabled with {self.res_params=}"
+            )
+
    @torch.jit.ignore
    def log(self, msg: str) -> None:
        """
@@ -2521,7 +2551,13 @@ def _prefetch(
            self.local_uvm_cache_stats.zero_()
        self._report_io_size_count("prefetch_input", indices)

+        # streaming before updating the cache
+        self.raw_embedding_stream()
+
        final_lxu_cache_locations = torch.empty_like(indices, dtype=torch.int32)
+        linear_cache_indices_merged = torch.zeros(
+            0, dtype=indices.dtype, device=indices.device
+        )
        for (
            partial_indices,
            partial_lxu_cache_locations,
@@ -2537,6 +2573,9 @@ def _prefetch(
                vbe_metadata.max_B if vbe_metadata is not None else -1,
                base_offset,
            )
+            linear_cache_indices_merged = torch.cat(
+                [linear_cache_indices_merged, linear_cache_indices]
+            )

            if (
                self.record_cache_metrics.record_cache_miss_counter
@@ -2617,6 +2656,23 @@ def _prefetch(
            if self.should_log():
                self.print_uvm_cache_stats(use_local_cache=False)

+        if self.enable_raw_embedding_streaming:
+            with record_function(
+                "## uvm_save_prefetched_rows {} {} ##".format(self.timestep, self.uuid)
+            ):
+                (
+                    linear_unique_indices,
+                    linear_unique_indices_length,
+                    _,
+                ) = torch.ops.fbgemm.get_unique_indices(
+                    linear_cache_indices_merged,
+                    self.total_cache_hash_size,
+                    compute_count=False,
+                )
+                self.prefetched_info.append(
+                    (linear_unique_indices, linear_unique_indices_length)
+                )
+
    def should_log(self) -> bool:
        """Determines if we should log for this step, using exponentially decreasing frequency.
@@ -3829,6 +3885,55 @@ def _debug_print_input_stats_factory_null(
            return _debug_print_input_stats_factory_impl
        return _debug_print_input_stats_factory_null

+    @torch.jit.ignore
+    def raw_embedding_stream(self) -> None:
+        if not self.enable_raw_embedding_streaming:
+            return None
+        # When pipelining is enabled, the prefetch in iter i happens before the
+        # backward sparse in iter i - 1, so the embeddings for iter i - 1's changed
+        # ids are not yet updated and we can only fetch the indices from iter i - 2.
+        # When pipelining is disabled, the prefetch in iter i happens before the
+        # forward of iter i, so we can safely fetch iter i - 1's changed ids.
+        target_prev_iter = 1
+        if self.prefetch_pipeline:
+            target_prev_iter = 2
+        if not len(self.prefetched_info) > (target_prev_iter - 1):
+            return None
+        with record_function(
+            "## uvm_lookup_prefetched_rows {} {} ##".format(self.timestep, self.uuid)
+        ):
+            (updated_indices, updated_count) = self.prefetched_info.pop(0)
+            updated_locations = torch.ops.fbgemm.lxu_cache_lookup(
+                updated_indices,
+                self.lxu_cache_state,
+                self.total_cache_hash_size,
+                gather_cache_stats=False,  # not collecting cache stats
+                num_uniq_cache_indices=updated_count,
+            )
+            updated_weights = torch.empty(
+                [updated_indices.size()[0], self.max_D_cache],
+                # pyre-ignore Incompatible parameter type [6]: In call `torch._C._VariableFunctions.empty`, for argument `dtype`, expected `Optional[dtype]` but got `Union[Module, dtype, Tensor]`
+                dtype=self.lxu_cache_weights.dtype,
+                # pyre-ignore Incompatible parameter type [6]: In call `torch._C._VariableFunctions.empty`, for argument `device`, expected `Union[None, int, str, device]` but got `Union[Module, device, Tensor]`
+                device=self.lxu_cache_weights.device,
+            )
+            torch.ops.fbgemm.masked_index_select(
+                updated_weights,
+                updated_locations,
+                self.lxu_cache_weights,
+                updated_count,
+            )
+            # stream weights
+            self._raw_embedding_streamer.stream(
+                updated_indices.to(device=torch.device("cpu")),
+                updated_weights.to(device=torch.device("cpu")),
+                updated_count.to(device=torch.device("cpu")),
+                False,  # require_tensor_copy
+                False,  # blocking_tensor_copy
+            )
+

class DenseTableBatchedEmbeddingBagsCodegen(nn.Module):
    """