22
22
23
23
import torch # usort:skip
24
24
from torch import nn , Tensor # usort:skip
25
+ from torch .autograd .profiler import record_function # usort:skip
25
26
26
27
# @manual=//deeplearning/fbgemm/fbgemm_gpu/codegen:split_embedding_codegen_lookup_invokers
27
28
import fbgemm_gpu .split_embedding_codegen_lookup_invokers as invokers
@@ -626,6 +627,7 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
626
627
lxu_cache_locations_list : List [Tensor ]
627
628
lxu_cache_locations_empty : Tensor
628
629
timesteps_prefetched : List [int ]
630
+ prefetched_info : List [Tuple [Tensor , Tensor ]]
629
631
record_cache_metrics : RecordCacheMetrics
630
632
# pyre-fixme[13]: Attribute `uvm_cache_stats` is never initialized.
631
633
uvm_cache_stats : torch .Tensor
@@ -690,6 +692,8 @@ def __init__( # noqa C901
690
692
embedding_table_index_type : torch .dtype = torch .int64 ,
691
693
embedding_table_offset_type : torch .dtype = torch .int64 ,
692
694
embedding_shard_info : Optional [List [Tuple [int , int , int , int ]]] = None ,
695
+ enable_raw_embedding_streaming : bool = False ,
696
+ res_params : Optional [RESParams ] = None ,
693
697
) -> None :
694
698
super (SplitTableBatchedEmbeddingBagsCodegen , self ).__init__ ()
695
699
self .uuid = str (uuid .uuid4 ())
@@ -700,6 +704,7 @@ def __init__( # noqa C901
700
704
)
701
705
702
706
self .logging_table_name : str = self .get_table_name_for_logging (table_names )
707
+ self .enable_raw_embedding_streaming : bool = enable_raw_embedding_streaming
703
708
self .pooling_mode = pooling_mode
704
709
self .is_nobag : bool = self .pooling_mode == PoolingMode .NONE
705
710
@@ -1460,6 +1465,30 @@ def __init__( # noqa C901
1460
1465
)
1461
1466
self .embedding_table_offset_type : torch .dtype = embedding_table_offset_type
1462
1467
1468
+ self .prefetched_info : List [Tuple [Tensor , Tensor ]] = torch .jit .annotate (
1469
+ List [Tuple [Tensor , Tensor ]], []
1470
+ )
1471
+ if self .enable_raw_embedding_streaming :
1472
+ self .res_params : RESParams = res_params or RESParams ()
1473
+ self .res_params .table_sizes = [0 ] + list (accumulate (rows ))
1474
+ res_port_from_env = os .getenv ("LOCAL_RES_PORT" )
1475
+ self .res_params .res_server_port = (
1476
+ int (res_port_from_env ) if res_port_from_env else 0
1477
+ )
1478
+ # pyre-fixme[4]: Attribute must be annotated.
1479
+ self ._raw_embedding_streamer = torch .classes .fbgemm .RawEmbeddingStreamer (
1480
+ self .uuid ,
1481
+ self .enable_raw_embedding_streaming ,
1482
+ self .res_params .res_store_shards ,
1483
+ self .res_params .res_server_port ,
1484
+ self .res_params .table_names ,
1485
+ self .res_params .table_offsets ,
1486
+ self .res_params .table_sizes ,
1487
+ )
1488
+ logging .info (
1489
+ f"{ self .uuid } raw embedding streaming enabled with { self .res_params = } "
1490
+ )
1491
+
1463
1492
@torch .jit .ignore
1464
1493
def log (self , msg : str ) -> None :
1465
1494
"""
@@ -2521,7 +2550,13 @@ def _prefetch(
2521
2550
self .local_uvm_cache_stats .zero_ ()
2522
2551
self ._report_io_size_count ("prefetch_input" , indices )
2523
2552
2553
+ # streaming before updating the cache
2554
+ self .raw_embedding_stream ()
2555
+
2524
2556
final_lxu_cache_locations = torch .empty_like (indices , dtype = torch .int32 )
2557
+ linear_cache_indices_merged = torch .zeros (
2558
+ 0 , dtype = indices .dtype , device = indices .device
2559
+ )
2525
2560
for (
2526
2561
partial_indices ,
2527
2562
partial_lxu_cache_locations ,
@@ -2537,6 +2572,9 @@ def _prefetch(
2537
2572
vbe_metadata .max_B if vbe_metadata is not None else - 1 ,
2538
2573
base_offset ,
2539
2574
)
2575
+ linear_cache_indices_merged = torch .cat (
2576
+ [linear_cache_indices_merged , linear_cache_indices ]
2577
+ )
2540
2578
2541
2579
if (
2542
2580
self .record_cache_metrics .record_cache_miss_counter
@@ -2617,6 +2655,23 @@ def _prefetch(
2617
2655
if self .should_log ():
2618
2656
self .print_uvm_cache_stats (use_local_cache = False )
2619
2657
2658
+ if self .enable_raw_embedding_streaming :
2659
+ with record_function (
2660
+ "## uvm_save_prefetched_rows {} {} ##" .format (self .timestep , self .uuid )
2661
+ ):
2662
+ (
2663
+ linear_unique_indices ,
2664
+ linear_unique_indices_length ,
2665
+ _ ,
2666
+ ) = torch .ops .fbgemm .get_unique_indices (
2667
+ linear_cache_indices_merged ,
2668
+ self .total_cache_hash_size ,
2669
+ compute_count = False ,
2670
+ )
2671
+ self .prefetched_info .append (
2672
+ (linear_unique_indices , linear_unique_indices_length )
2673
+ )
2674
+
2620
2675
def should_log (self ) -> bool :
2621
2676
"""Determines if we should log for this step, using exponentially decreasing frequency.
2622
2677
@@ -3829,6 +3884,55 @@ def _debug_print_input_stats_factory_null(
3829
3884
return _debug_print_input_stats_factory_impl
3830
3885
return _debug_print_input_stats_factory_null
3831
3886
3887
@torch.jit.ignore
def raw_embedding_stream(self) -> None:
    """Stream the cached rows recorded by an earlier prefetch to the raw embedding streamer.

    Does nothing when ``self.enable_raw_embedding_streaming`` is false, or
    when ``self.prefetched_info`` does not yet hold enough records (two when
    ``self.prefetch_pipeline`` is set, otherwise one).

    Otherwise the oldest ``(unique indices, unique count)`` record is removed
    from ``self.prefetched_info``, the matching cache locations are looked up
    in ``self.lxu_cache_state``, the rows are gathered from
    ``self.lxu_cache_weights`` into a fresh
    ``[num_indices, self.max_D_cache]`` buffer, and CPU copies of the
    indices, rows and count are handed to
    ``self._raw_embedding_streamer.stream``.
    """
    if not self.enable_raw_embedding_streaming:
        return

    # How far back we must look depends on pipelining:
    #   * pipelining enabled: the prefetch of iter i runs before the sparse
    #     backward of iter i - 1, so the embeddings changed in iter i - 1
    #     are not updated yet; only the ids from iter i - 2 can be fetched.
    #   * pipelining disabled: the prefetch of iter i runs before the
    #     forward of iter i, so the ids changed in iter i - 1 are safe.
    required_records = 2 if self.prefetch_pipeline else 1
    if len(self.prefetched_info) < required_records:
        return

    profiler_label = "## uvm_lookup_prefetched_rows {} {} ##".format(
        self.timestep, self.uuid
    )
    with record_function(profiler_label):
        # Oldest record first: the list is appended to at the end of each
        # prefetch, so index 0 is the earliest outstanding iteration.
        row_ids, row_count = self.prefetched_info.pop(0)

        cache_slots = torch.ops.fbgemm.lxu_cache_lookup(
            row_ids,
            self.lxu_cache_state,
            self.total_cache_hash_size,
            gather_cache_stats=False,  # not collecting cache stats
            num_uniq_cache_indices=row_count,
        )

        # Destination buffer for the gathered rows; it is left uninitialized
        # and filled in by masked_index_select below.
        row_buffer = torch.empty(
            [row_ids.shape[0], self.max_D_cache],
            # pyre-ignore Incompatible parameter type [6]: In call `torch._C._VariableFunctions.empty`, for argument `dtype`, expected `Optional[dtype]` but got `Union[Module, dtype, Tensor]`
            dtype=self.lxu_cache_weights.dtype,
            # pyre-ignore Incompatible parameter type [6]: In call `torch._C._VariableFunctions.empty`, for argument `device`, expected `Union[None, int, str, device]` but got `Union[Module, device, Tensor]`
            device=self.lxu_cache_weights.device,
        )
        torch.ops.fbgemm.masked_index_select(
            row_buffer,
            cache_slots,
            self.lxu_cache_weights,
            row_count,
        )

        # stream weights
        host = torch.device("cpu")
        self._raw_embedding_streamer.stream(
            row_ids.to(device=host),
            row_buffer.to(device=host),
            row_count.to(device=host),
            False,  # require_tensor_copy
            False,  # blocking_tensor_copy
        )
3935
+
3832
3936
3833
3937
class DenseTableBatchedEmbeddingBagsCodegen (nn .Module ):
3834
3938
"""
0 commit comments