diff --git a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py
index 05e9cdb8c8..8aeecc910d 100644
--- a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py
+++ b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py
@@ -22,6 +22,7 @@
 import torch  # usort:skip
 from torch import nn, Tensor  # usort:skip
+from torch.autograd.profiler import record_function  # usort:skip

 # @manual=//deeplearning/fbgemm/fbgemm_gpu/codegen:split_embedding_codegen_lookup_invokers
 import fbgemm_gpu.split_embedding_codegen_lookup_invokers as invokers
@@ -626,6 +627,7 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
     lxu_cache_locations_list: List[Tensor]
     lxu_cache_locations_empty: Tensor
     timesteps_prefetched: List[int]
+    prefetched_info: List[Tuple[Tensor, Tensor, Optional[Tensor]]]
     record_cache_metrics: RecordCacheMetrics
     # pyre-fixme[13]: Attribute `uvm_cache_stats` is never initialized.
     uvm_cache_stats: torch.Tensor
@@ -690,6 +692,8 @@ def __init__(  # noqa C901
         embedding_table_index_type: torch.dtype = torch.int64,
         embedding_table_offset_type: torch.dtype = torch.int64,
         embedding_shard_info: Optional[List[Tuple[int, int, int, int]]] = None,
+        enable_raw_embedding_streaming: bool = False,
+        res_params: Optional[RESParams] = None,
     ) -> None:
         super(SplitTableBatchedEmbeddingBagsCodegen, self).__init__()
         self.uuid = str(uuid.uuid4())
@@ -700,6 +704,7 @@ def __init__(  # noqa C901
         )

         self.logging_table_name: str = self.get_table_name_for_logging(table_names)
+        self.enable_raw_embedding_streaming: bool = enable_raw_embedding_streaming
         self.pooling_mode = pooling_mode
         self.is_nobag: bool = self.pooling_mode == PoolingMode.NONE
@@ -1460,6 +1465,30 @@ def __init__(  # noqa C901
         )
         self.embedding_table_offset_type: torch.dtype = embedding_table_offset_type

+        self.prefetched_info: List[Tuple[Tensor, Tensor, Optional[Tensor]]] = (
+            torch.jit.annotate(List[Tuple[Tensor, Tensor, Optional[Tensor]]], [])
+        )
+        if self.enable_raw_embedding_streaming:
+            self.res_params: RESParams = res_params or RESParams()
+            self.res_params.table_sizes = [0] + list(accumulate(rows))
+            res_port_from_env = os.getenv("LOCAL_RES_PORT")
+            self.res_params.res_server_port = (
+                int(res_port_from_env) if res_port_from_env else 0
+            )
+            # pyre-fixme[4]: Attribute must be annotated.
+            self._raw_embedding_streamer = torch.classes.fbgemm.RawEmbeddingStreamer(
+                self.uuid,
+                self.enable_raw_embedding_streaming,
+                self.res_params.res_store_shards,
+                self.res_params.res_server_port,
+                self.res_params.table_names,
+                self.res_params.table_offsets,
+                self.res_params.table_sizes,
+            )
+            logging.info(
+                f"{self.uuid} raw embedding streaming enabled with {self.res_params=}"
+            )
+
     @torch.jit.ignore
     def log(self, msg: str) -> None:
         """
@@ -1979,8 +2008,13 @@ def forward(  # noqa: C901
             # In forward, we don't enable multi-pass prefetch as we want the process
             # to be as fast as possible and memory usage doesn't matter (will be recycled
             # by dense fwd/bwd)
+            # TODO: Properly pass in the hash_zch_identities
             self._prefetch(
-                indices, offsets, vbe_metadata, multipass_prefetch_config=None
+                indices,
+                offsets,
+                vbe_metadata,
+                multipass_prefetch_config=None,
+                hash_zch_identities=None,
             )

         if len(self.timesteps_prefetched) > 0:
@@ -2503,6 +2537,7 @@ def _prefetch(
         offsets: Tensor,
         vbe_metadata: Optional[invokers.lookup_args.VBEMetadata] = None,
         multipass_prefetch_config: Optional[MultiPassPrefetchConfig] = None,
+        hash_zch_identities: Optional[Tensor] = None,
     ) -> None:
         if not is_torchdynamo_compiling():
             # Mutations of nn.Module attr forces dynamo restart of Analysis which increases compilation time
@@ -2521,7 +2556,13 @@
             self.local_uvm_cache_stats.zero_()
         self._report_io_size_count("prefetch_input", indices)

+        # stream rows saved on earlier iterations before this prefetch updates the cache
+        self.raw_embedding_stream()
+
         final_lxu_cache_locations = torch.empty_like(indices, dtype=torch.int32)
+        linear_cache_indices_merged = torch.zeros(
+            0, dtype=indices.dtype, device=indices.device
+        )
         for (
             partial_indices,
             partial_lxu_cache_locations,
             base_offset,
         ) in self.multipass_prefetch(
@@ -2537,6 +2578,9 @@
                 vbe_metadata.max_B if vbe_metadata is not None else -1,
                 base_offset,
             )
+            linear_cache_indices_merged = torch.cat(
+                [linear_cache_indices_merged, linear_cache_indices]
+            )

             if (
                 self.record_cache_metrics.record_cache_miss_counter
@@ -2617,6 +2661,36 @@
             if self.should_log():
                 self.print_uvm_cache_stats(use_local_cache=False)

+        if self.enable_raw_embedding_streaming:
+            with record_function(
+                "## uvm_save_prefetched_rows {} {} ##".format(self.timestep, self.uuid)
+            ):
+                (
+                    linear_unique_indices,
+                    linear_unique_indices_length,
+                    _,
+                ) = torch.ops.fbgemm.get_unique_indices(
+                    linear_cache_indices_merged,
+                    self.total_cache_hash_size,
+                    compute_count=False,
+                )
+                linear_unique_indices = linear_unique_indices.narrow(
+                    0, 0, linear_unique_indices_length[0]
+                )
+                self.prefetched_info.append(
+                    (
+                        linear_unique_indices,
+                        linear_unique_indices_length,
+                        (
+                            hash_zch_identities[linear_unique_indices].to(
+                                device=torch.device("cpu")
+                            )
+                            if hash_zch_identities is not None
+                            else None
+                        ),
+                    )
+                )
+
     def should_log(self) -> bool:
         """Determines if we should log for this step, using exponentially
         decreasing frequency.
@@ -3829,6 +3903,62 @@ def _debug_print_input_stats_factory_null(
             return _debug_print_input_stats_factory_impl
         return _debug_print_input_stats_factory_null

+    @torch.jit.ignore
+    def raw_embedding_stream(self) -> None:
+        if not self.enable_raw_embedding_streaming:
+            return None
+        # When pipelining is enabled, the prefetch for iteration i runs before
+        # the sparse backward of iteration i - 1, so the embeddings for the ids
+        # changed in iteration i - 1 have not been written back yet; we can
+        # only safely stream the indices saved at iteration i - 2.
+        # When pipelining is disabled, the prefetch for iteration i runs before
+        # the forward of iteration i, so the ids changed in iteration i - 1
+        # can be streamed safely.
+        target_prev_iter = 1
+        if self.prefetch_pipeline:
+            target_prev_iter = 2
+        if not len(self.prefetched_info) > (target_prev_iter - 1):
+            return None
+        with record_function(
+            "## uvm_lookup_prefetched_rows {} {} ##".format(self.timestep, self.uuid)
+        ):
+            (updated_indices, updated_count, updated_identities) = (
+                self.prefetched_info.pop(0)
+            )
+            updated_locations = torch.ops.fbgemm.lxu_cache_lookup(
+                updated_indices,
+                self.lxu_cache_state,
+                self.total_cache_hash_size,
+                gather_cache_stats=False,  # not collecting cache stats
+                num_uniq_cache_indices=updated_count,
+            )
+            updated_weights = torch.empty(
+                [updated_indices.size()[0], self.max_D_cache],
+                # pyre-ignore Incompatible parameter type [6]: In call `torch._C._VariableFunctions.empty`, for argument `dtype`, expected `Optional[dtype]` but got `Union[Module, dtype, Tensor]`
+                dtype=self.lxu_cache_weights.dtype,
+                # pyre-ignore Incompatible parameter type [6]: In call `torch._C._VariableFunctions.empty`, for argument `device`, expected `Union[None, int, str, device]` but got `Union[Module, device, Tensor]`
+                device=self.lxu_cache_weights.device,
+            )
+            torch.ops.fbgemm.masked_index_select(
+                updated_weights,
+                updated_locations,
+                self.lxu_cache_weights,
+                updated_count,
+            )
+            # stream weights
+            self._raw_embedding_streamer.stream(
+                updated_indices.to(device=torch.device("cpu")),
+                updated_weights.to(device=torch.device("cpu")),
+                (
+                    updated_identities.to(device=torch.device("cpu"))
+                    if updated_identities is not None
+                    else None
+                ),
+                updated_count.to(device=torch.device("cpu")),
+                False,  # require_tensor_copy
+                False,  # blocking_tensor_copy
+            )
+

 class DenseTableBatchedEmbeddingBagsCodegen(nn.Module):
     """
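Editor's note (illustration, not part of the patch): the iteration delay in `raw_embedding_stream` above is easiest to see in isolation. A minimal pure-Python model of the queue logic, with hypothetical names:

    # Why raw_embedding_stream() pops rows saved one or two iterations ago:
    # with prefetch_pipeline=True, the prefetch for iteration i runs before
    # the sparse backward of iteration i - 1, so only rows saved at i - 2
    # are guaranteed to be up to date in the cache.
    from collections import deque

    prefetched_info = deque()

    def raw_embedding_stream_toy(prefetch_pipeline: bool) -> None:
        target_prev_iter = 2 if prefetch_pipeline else 1
        if len(prefetched_info) <= target_prev_iter - 1:
            return  # not enough history yet; stream on a later iteration
        saved_iter = prefetched_info.popleft()
        print(f"rows saved at iteration {saved_iter} are streamed now")

    for t in range(4):
        raw_embedding_stream_toy(prefetch_pipeline=True)  # stream first
        prefetched_info.append(t)  # then _prefetch() saves this iteration

With `prefetch_pipeline=True` the first pop happens at t=2 and returns the rows saved at t=0, matching the two-iteration lag described in the comment.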
diff --git a/fbgemm_gpu/include/fbgemm_gpu/split_embeddings_cache/raw_embedding_streamer.h b/fbgemm_gpu/include/fbgemm_gpu/split_embeddings_cache/raw_embedding_streamer.h
index e44dc65025..32e363c5c4 100644
--- a/fbgemm_gpu/include/fbgemm_gpu/split_embeddings_cache/raw_embedding_streamer.h
+++ b/fbgemm_gpu/include/fbgemm_gpu/split_embeddings_cache/raw_embedding_streamer.h
@@ -19,13 +19,16 @@ namespace fbgemm_gpu {
 struct StreamQueueItem {
   at::Tensor indices;
   at::Tensor weights;
+  std::optional<at::Tensor> identities;
   at::Tensor count;
   StreamQueueItem(
       at::Tensor src_indices,
       at::Tensor src_weights,
+      std::optional<at::Tensor> src_identities,
       at::Tensor src_count) {
     indices = std::move(src_indices);
     weights = std::move(src_weights);
+    identities = std::move(src_identities);
     count = std::move(src_count);
   }
 };
@@ -67,6 +70,7 @@ class RawEmbeddingStreamer : public torch::jit::CustomClassHolder {
   void stream(
       const at::Tensor& indices,
       const at::Tensor& weights,
+      std::optional<at::Tensor> identities,
       const at::Tensor& count,
       bool require_tensor_copy,
       bool blocking_tensor_copy = true);

 #ifdef FBGEMM_FBCODE
@@ -74,7 +78,8 @@
   folly::coro::Task<void> tensor_stream(
       const at::Tensor& indices,
-      const at::Tensor& weights);
+      const at::Tensor& weights,
+      std::optional<at::Tensor> identities);
   /*
    * Copy the indices, weights and count tensors and enqueue them for
    * asynchronous stream.
    */
@@ -82,6 +87,7 @@
   void copy_and_enqueue_stream_tensors(
       const at::Tensor& indices,
       const at::Tensor& weights,
+      std::optional<at::Tensor> identities,
       const at::Tensor& count);

   /*
diff --git a/fbgemm_gpu/src/split_embeddings_cache/raw_embedding_streamer.cpp b/fbgemm_gpu/src/split_embeddings_cache/raw_embedding_streamer.cpp
index ffee3ff672..96d48f9a32 100644
--- a/fbgemm_gpu/src/split_embeddings_cache/raw_embedding_streamer.cpp
+++ b/fbgemm_gpu/src/split_embeddings_cache/raw_embedding_streamer.cpp
@@ -35,8 +35,9 @@ std::unique_ptr<
     apache::thrift::Client<TrainingParameterServerService>>
 get_res_client(int64_t res_server_port) {
   auto& factory = facebook::servicerouter::cpp2::getClientFactory();
-  auto& params = facebook::servicerouter::ClientParams().setSingleHost(
-      "::", res_server_port);
+  auto params =
+      folly::copy(facebook::servicerouter::ClientParams().setSingleHost(
+          "::", res_server_port));
   return factory.getSRClientUnique<
       apache::thrift::Client<TrainingParameterServerService>>(
@@ -56,6 +57,7 @@ inline int64_t get_maybe_uvm_scalar(const at::Tensor& tensor) {
 fbgemm_gpu::StreamQueueItem tensor_copy(
     const at::Tensor& indices,
     const at::Tensor& weights,
+    std::optional<at::Tensor> identities,
     const at::Tensor& count) {
   auto num_sets = get_maybe_uvm_scalar(count);
   auto new_indices = at::empty(
@@ -63,6 +65,12 @@ fbgemm_gpu::StreamQueueItem tensor_copy(
   auto new_weights = at::empty(
       {num_sets, weights.size(1)},
       at::TensorOptions().device(at::kCPU).dtype(weights.dtype()));
+  std::optional<at::Tensor> new_identities = std::nullopt;
+  if (identities.has_value()) {
+    new_identities = at::empty(
+        num_sets,
+        at::TensorOptions().device(at::kCPU).dtype(identities->dtype()));
+  }
   auto new_count =
       at::empty({1}, at::TensorOptions().device(at::kCPU).dtype(at::kLong));
   FBGEMM_DISPATCH_FLOAT_HALF_AND_BYTE(
@@ -71,23 +79,52 @@ fbgemm_gpu::StreamQueueItem tensor_copy(
       weights.scalar_type(), "tensor_copy", [&] {
         using value_t = scalar_t;
         FBGEMM_DISPATCH_INTEGRAL_TYPES(
             indices.scalar_type(), "tensor_copy", [&] {
               using index_t = scalar_t;
-              auto indices_addr = indices.data_ptr<index_t>();
-              auto new_indices_addr = new_indices.data_ptr<index_t>();
-              std::copy(
-                  indices_addr,
-                  indices_addr + num_sets,
-                  new_indices_addr); // dst_start
+              if (identities.has_value()) {
+                FBGEMM_DISPATCH_INTEGRAL_TYPES(
+                    identities->scalar_type(), "tensor_copy", [&] {
+                      using identities_t = scalar_t;
+                      auto identities_addr =
+                          identities->data_ptr<identities_t>();
+                      auto new_identities_addr =
+                          new_identities->data_ptr<identities_t>();
+                      std::copy(
+                          identities_addr,
+                          identities_addr + num_sets,
+                          new_identities_addr); // dst_start
+                      auto indices_addr = indices.data_ptr<index_t>();
+                      auto new_indices_addr = new_indices.data_ptr<index_t>();
+                      std::copy(
+                          indices_addr,
+                          indices_addr + num_sets,
+                          new_indices_addr); // dst_start
+
+                      auto weights_addr = weights.data_ptr<value_t>();
+                      auto new_weights_addr = new_weights.data_ptr<value_t>();
+                      std::copy(
+                          weights_addr,
+                          weights_addr + num_sets * weights.size(1),
+                          new_weights_addr); // dst_start
+                    });
+              } else {
+                auto indices_addr = indices.data_ptr<index_t>();
+                auto new_indices_addr = new_indices.data_ptr<index_t>();
+                std::copy(
+                    indices_addr,
+                    indices_addr + num_sets,
+                    new_indices_addr); // dst_start

-              auto weights_addr = weights.data_ptr<value_t>();
-              auto new_weightss_addr = new_weights.data_ptr<value_t>();
-              std::copy(
-                  weights_addr,
-                  weights_addr + num_sets * weights.size(1),
-                  new_weightss_addr); // dst_start
+                auto weights_addr = weights.data_ptr<value_t>();
+                auto new_weights_addr = new_weights.data_ptr<value_t>();
+                std::copy(
+                    weights_addr,
+                    weights_addr + num_sets * weights.size(1),
+                    new_weights_addr); // dst_start
+              }
             });
       });
   *new_count.data_ptr<int64_t>() = num_sets;
-  return fbgemm_gpu::StreamQueueItem{new_indices, new_weights, new_count};
+  return fbgemm_gpu::StreamQueueItem{
+      new_indices, new_weights, new_identities, new_count};
 }
 } // namespace
@@ -127,8 +164,9 @@ RawEmbeddingStreamer::RawEmbeddingStreamer(
           }
           auto& indices = stream_item_ptr->indices;
           auto& weights = stream_item_ptr->weights;
+          auto& identities = stream_item_ptr->identities;
           folly::stop_watch<> stop_watch;
-          folly::coro::blockingWait(tensor_stream(indices, weights));
+          folly::coro::blockingWait(tensor_stream(indices, weights, identities));

           weights_to_stream_queue_.dequeue();
           XLOG_EVERY_MS(INFO, 60000)
@@ -154,6 +192,7 @@ RawEmbeddingStreamer::~RawEmbeddingStreamer() {
 void RawEmbeddingStreamer::stream(
     const at::Tensor& indices,
     const at::Tensor& weights,
+    std::optional<at::Tensor> identities,
     const at::Tensor& count,
     bool require_tensor_copy,
     bool blocking_tensor_copy) {
@@ -164,12 +203,13 @@ void RawEmbeddingStreamer::stream(
   auto rec = torch::autograd::profiler::record_function_enter_new(
       "## RawEmbeddingStreamer::stream_callback ##");
   if (!require_tensor_copy) {
-    StreamQueueItem stream_item(indices, weights, count);
+    StreamQueueItem stream_item(indices, weights, std::move(identities), count);
     weights_to_stream_queue_.enqueue(stream_item);
     return;
   }
   if (blocking_tensor_copy) {
-    copy_and_enqueue_stream_tensors(indices, weights, count);
+    copy_and_enqueue_stream_tensors(
+        indices, weights, std::move(identities), count);
     return;
   }
   // Make sure the previous thread is done before starting a new one
@@ -179,7 +219,8 @@ void RawEmbeddingStreamer::stream(
   // So, We need to spin up a new thread to unblock the CUDA stream, so the CUDA
   // can continue executing other host callbacks, eg. get/evict.
   stream_tensor_copy_thread_ = std::make_unique<std::thread>([=, this]() {
-    copy_and_enqueue_stream_tensors(indices, weights, count);
+    copy_and_enqueue_stream_tensors(
+        indices, weights, std::move(identities), count);
   });
   rec->record.end();
 #endif
 }
@@ -188,7 +229,8 @@
 #ifdef FBGEMM_FBCODE
 folly::coro::Task<void> RawEmbeddingStreamer::tensor_stream(
     const at::Tensor& indices,
-    const at::Tensor& weights) {
+    const at::Tensor& weights,
+    std::optional<at::Tensor> identities) {
   using namespace ::aiplatform::gmpp::experimental::training_ps;
   if (indices.size(0) != weights.size(0)) {
     XLOG(ERR) << "[TBE_ID" << unique_id_
@@ -200,13 +242,19 @@ folly::coro::Task<void> RawEmbeddingStreamer::tensor_stream(
   XLOG_EVERY_MS(INFO, 60000)
       << "[TBE_ID" << unique_id_
       << "] send streaming request: indices = " << indices.size(0)
-      << ", weights = " << weights.size(0);
+      << ", weights = " << weights.size(0) << ", identities = "
+      << (identities.has_value() ? std::to_string(identities->size(0))
+                                 : "none");
   auto biggest_idx = table_sizes_.index({table_sizes_.size(0) - 1});
   auto mask =
       at::logical_and(indices >= 0, indices < biggest_idx).nonzero().squeeze();
   auto filtered_indices = indices.index_select(0, mask);
   auto filtered_weights = weights.index_select(0, mask);
+  std::optional<at::Tensor> filtered_identities = std::nullopt;
+  if (identities.has_value()) {
+    filtered_identities = identities->index_select(0, mask);
+  }
   auto num_invalid_indices = indices.size(0) - filtered_indices.size(0);
   if (num_invalid_indices > 0) {
     XLOG(INFO) << "[TBE_ID" << unique_id_
@@ -273,6 +321,10 @@ folly::coro::Task<void> RawEmbeddingStreamer::tensor_stream(
     req.rowIndices() = torch::distributed::wireDumpTensor(global_indices_masked);
     req.weights() = torch::distributed::wireDumpTensor(weights_masked);
+    if (filtered_identities.has_value()) {
+      auto identities_masked = filtered_identities->index_select(0, shrad_mask);
+      req.identities() = torch::distributed::wireDumpTensor(identities_masked);
+    }
     co_await res_client->co_setEmbeddings(req);
   }
   co_return;
 }
@@ -281,10 +333,12 @@
 void RawEmbeddingStreamer::copy_and_enqueue_stream_tensors(
     const at::Tensor& indices,
     const at::Tensor& weights,
+    std::optional<at::Tensor> identities,
     const at::Tensor& count) {
   auto rec = torch::autograd::profiler::record_function_enter_new(
       "## RawEmbeddingStreamer::copy_and_enqueue_stream_tensors ##");
-  auto stream_item = tensor_copy(indices, weights, count);
+  auto stream_item =
+      tensor_copy(indices, weights, std::move(identities), count);
   weights_to_stream_queue_.enqueue(stream_item);
   rec->record.end();
 }
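Editor's note: `tensor_copy` above detaches the queued item from cache/UVM memory by copying only the first `count` rows into fresh CPU buffers. A rough Python analogue of that contract (function and variable names are illustrative, not part of the API):

    import torch
    from typing import Optional, Tuple

    def tensor_copy_sketch(
        indices: torch.Tensor,
        weights: torch.Tensor,
        identities: Optional[torch.Tensor],
        count: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], torch.Tensor]:
        num_sets = int(count[0])  # number of valid rows; the rest is scratch
        new_indices = indices[:num_sets].cpu().clone()
        new_weights = weights[:num_sets].cpu().clone()
        # identities ride along only when the caller tracks ZCH identities
        new_identities = (
            identities[:num_sets].cpu().clone() if identities is not None else None
        )
        new_count = torch.tensor([num_sets], dtype=torch.long)
        return new_indices, new_weights, new_identities, new_count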
diff --git a/fbgemm_gpu/src/split_embeddings_cache/split_embeddings_cache_ops.cpp b/fbgemm_gpu/src/split_embeddings_cache/split_embeddings_cache_ops.cpp
index 89b1e9d720..b3c2a726a8 100644
--- a/fbgemm_gpu/src/split_embeddings_cache/split_embeddings_cache_ops.cpp
+++ b/fbgemm_gpu/src/split_embeddings_cache/split_embeddings_cache_ops.cpp
@@ -104,6 +104,7 @@ static auto raw_embedding_streamer =
           {
               torch::arg("indices"),
               torch::arg("weights"),
+              torch::arg("identities"),
               torch::arg("count"),
               torch::arg("require_tensor_copy"),
               torch::arg("blocking_tensor_copy"),
diff --git a/fbgemm_gpu/src/split_embeddings_cache/tests/raw_embedding_streamer_test.cpp b/fbgemm_gpu/src/split_embeddings_cache/tests/raw_embedding_streamer_test.cpp
index 3b408f7a93..a07c5e8ac6 100644
--- a/fbgemm_gpu/src/split_embeddings_cache/tests/raw_embedding_streamer_test.cpp
+++ b/fbgemm_gpu/src/split_embeddings_cache/tests/raw_embedding_streamer_test.cpp
@@ -80,7 +80,7 @@ TEST(RawEmbeddingStreamerTest, TestStreamWithoutStreaming) {
       {indices.size(0)}, at::TensorOptions().device(at::kCPU).dtype(at::kLong));
   // Should not crash when streaming is disabled
-  streamer->stream(indices, weights, count, true, true);
+  streamer->stream(indices, weights, std::nullopt, count, true, true);
 }

 #ifdef FBGEMM_FBCODE
@@ -112,7 +112,8 @@ TEST(RawEmbeddingStreamerTest, TestTensorStream) {
       {invalid_indices.size(0), EMBEDDING_DIMENSION},
       at::TensorOptions().device(at::kCPU).dtype(c10::kFloat));
   EXPECT_CALL(*mock_service, co_setEmbeddings(_)).Times(0);
-  folly::coro::blockingWait(streamer->tensor_stream(invalid_indices, weights));
+  folly::coro::blockingWait(
+      streamer->tensor_stream(invalid_indices, weights, std::nullopt));

   // Test with valid indices - should call service
   auto valid_indices = at::tensor(
@@ -134,7 +135,8 @@
               aiplatform::gmpp::experimental::training_ps::
                   SetEmbeddingsResponse>();
         }));
-  folly::coro::blockingWait(streamer->tensor_stream(valid_indices, weights));
+  folly::coro::blockingWait(
+      streamer->tensor_stream(valid_indices, weights, std::nullopt));
 }

 TEST(RawEmbeddingStreamerTest, TestStreamWithCopy) {
@@ -171,11 +173,11 @@ TEST(RawEmbeddingStreamerTest, TestStreamWithCopy) {
   streamer->join_weights_stream_thread();

   // Test blocking tensor copy
-  streamer->stream(indices, weights, count, true, true);
+  streamer->stream(indices, weights, std::nullopt, count, true, true);
   EXPECT_EQ(streamer->get_weights_to_stream_queue_size(), 1);

   // Test non-blocking tensor copy
-  streamer->stream(indices, weights, count, true, false);
+  streamer->stream(indices, weights, std::nullopt, count, true, false);
   EXPECT_EQ(streamer->get_weights_to_stream_queue_size(), 1);
   streamer->join_stream_tensor_copy_thread();
   EXPECT_EQ(streamer->get_weights_to_stream_queue_size(), 2);
@@ -225,7 +227,7 @@ TEST(RawEmbeddingStreamerTest, TestStreamE2E) {
   auto count = at::tensor(
       {indices.size(0)}, at::TensorOptions().device(at::kCPU).dtype(at::kLong));
-  streamer->stream(indices, weights, count, true, true);
+  streamer->stream(indices, weights, std::nullopt, count, true, true);
   // Make sure dequeue finished
   std::this_thread::sleep_for(std::chrono::seconds(1));
   streamer->join_weights_stream_thread();
@@ -260,6 +262,65 @@ TEST(RawEmbeddingStreamerTest, TestMismatchedIndicesWeights) {
       at::TensorOptions().device(at::kCPU).dtype(c10::kFloat));

   EXPECT_CALL(*mock_service, co_setEmbeddings(_)).Times(0);
-  folly::coro::blockingWait(streamer->tensor_stream(indices, weights));
+  folly::coro::blockingWait(
+      streamer->tensor_stream(indices, weights, std::nullopt));
+}
+
+TEST(RawEmbeddingStreamerTest, TestStreamWithIdentities) {
+  std::vector<std::string> table_names = {"tb1", "tb2", "tb3"};
+  std::vector<int64_t> table_offsets = {0, 100, 300};
+  std::vector<int64_t> table_sizes = {0, 50, 200, 300};
+
+  auto streamer = getRawEmbeddingStreamer(
+      "test_stream_identities", true, table_names, table_offsets, table_sizes);
+
+  // Mock TrainingParameterServerService
+  auto mock_service = std::make_shared<MockTrainingParameterServerService>();
+  auto mock_server =
+      std::make_shared<apache::thrift::ScopedServerInterfaceThread>(
+          mock_service,
+          "::1",
+          0,
+          facebook::services::TLSConfig::applyDefaultsToThriftServer);
+  auto& mock_client_factory =
+      facebook::servicerouter::getMockSRClientFactory(false /* strict */);
+  mock_client_factory.registerMockService(
+      "realtime.delta.publish.esr", mock_server);
+
+  auto indices = at::tensor(
+      {10, 2, 1, 150, 170, 230, 280},
+      at::TensorOptions().device(at::kCPU).dtype(at::kLong));
+  auto weights = at::randn(
+      {indices.size(0), EMBEDDING_DIMENSION},
+      at::TensorOptions().device(at::kCPU).dtype(c10::kFloat));
+  auto identities = at::tensor(
+      {1001, 1002, 1003, 1004, 1005, 1006, 1007},
+      at::TensorOptions().device(at::kCPU).dtype(at::kLong));
+  auto count = at::tensor(
+      {indices.size(0)}, at::TensorOptions().device(at::kCPU).dtype(at::kLong));
+
+  // Test that identities are properly handled in tensor_stream
+  EXPECT_CALL(*mock_service, co_setEmbeddings(_))
+      .Times(3) // 3 shards with consistent hashing
+      .WillRepeatedly(folly::coro::gmock_helpers::CoInvoke(
+          [](std::unique_ptr<
              aiplatform::gmpp::experimental::training_ps::SetEmbeddingsRequest>
                 request)
+              -> folly::coro::Task<std::unique_ptr<
                  aiplatform::gmpp::experimental::training_ps::
                      SetEmbeddingsResponse>> {
+            // Verify that the request is properly formed
+            EXPECT_GT(request->fqns()->size(), 0);
+            co_return std::make_unique<
+                aiplatform::gmpp::experimental::training_ps::
+                    SetEmbeddingsResponse>();
+          }));
+  folly::coro::blockingWait(
+      streamer->tensor_stream(indices, weights, identities));
+
+  // Test streaming with identities using the stream method
+  streamer->join_weights_stream_thread(); // Stop dequeue thread for testing
+  streamer->stream(indices, weights, identities, count, true, true);
+  EXPECT_EQ(streamer->get_weights_to_stream_queue_size(), 1);
 }

 #endif
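Editor's note: from Python, the same custom class is exercised through `torch.classes.fbgemm.RawEmbeddingStreamer`, which is how the TBE constructor in the first file builds it. A hedged usage sketch with made-up table metadata (assumes an fbgemm_gpu build that registers the class):

    import torch
    import fbgemm_gpu  # noqa: F401  # loads the fbgemm custom-class registry

    streamer = torch.classes.fbgemm.RawEmbeddingStreamer(
        "example-uuid",  # unique_id
        True,            # enable_raw_embedding_streaming
        1,               # res_store_shards
        0,               # res_server_port
        ["tb1"],         # table_names
        [0],             # table_offsets
        [0, 300],        # table_sizes
    )
    indices = torch.tensor([10, 2, 1], dtype=torch.long)
    weights = torch.randn(3, 8)
    identities = torch.tensor([1001, 1002, 1003], dtype=torch.long)
    count = torch.tensor([3], dtype=torch.long)
    # identities is optional after this patch; pass None when ZCH
    # identities are not tracked (the SSD path below does exactly that)
    streamer.stream(indices, weights, identities, count, True, True)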
diff --git a/fbgemm_gpu/src/ssd_split_embeddings_cache/kv_db_table_batched_embeddings.cpp b/fbgemm_gpu/src/ssd_split_embeddings_cache/kv_db_table_batched_embeddings.cpp
index 55e548af0f..8548d53e0a 100644
--- a/fbgemm_gpu/src/ssd_split_embeddings_cache/kv_db_table_batched_embeddings.cpp
+++ b/fbgemm_gpu/src/ssd_split_embeddings_cache/kv_db_table_batched_embeddings.cpp
@@ -301,6 +301,7 @@ void EmbeddingKVDB::stream_cuda(
         self->raw_embedding_streamer_->stream(
             indices,
             weights,
+            std::nullopt, /*identities*/
             count,
             true, /*require_tensor_copy*/
             blocking_tensor_copy);
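Editor's note: on the TBE side (first file in this patch), the rows to stream are read out of the UVM cache with the same two ops the patch calls in `raw_embedding_stream`. A condensed sketch of that read path, with illustrative parameter names:

    import torch
    import fbgemm_gpu  # noqa: F401

    def gather_cached_rows(
        updated_indices: torch.Tensor,   # unique linearized cache indices
        updated_count: torch.Tensor,     # number of valid entries (1-elem tensor)
        lxu_cache_state: torch.Tensor,
        lxu_cache_weights: torch.Tensor,
        total_cache_hash_size: int,
        max_D_cache: int,
    ) -> torch.Tensor:
        # Map each linearized index to its slot in the cache
        locations = torch.ops.fbgemm.lxu_cache_lookup(
            updated_indices,
            lxu_cache_state,
            total_cache_hash_size,
            gather_cache_stats=False,
            num_uniq_cache_indices=updated_count,
        )
        out = torch.empty(
            updated_indices.size(0),
            max_D_cache,
            dtype=lxu_cache_weights.dtype,
            device=lxu_cache_weights.device,
        )
        # Copy the located rows into the dense output buffer
        torch.ops.fbgemm.masked_index_select(
            out, locations, lxu_cache_weights, updated_count
        )
        return out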
diff --git a/fbgemm_gpu/test/tbe/training/forward_test.py b/fbgemm_gpu/test/tbe/training/forward_test.py
index 20bd998b6d..2109ad30c1 100644
--- a/fbgemm_gpu/test/tbe/training/forward_test.py
+++ b/fbgemm_gpu/test/tbe/training/forward_test.py
@@ -12,6 +12,7 @@
 import math
 import random
 import unittest
+from unittest.mock import MagicMock, patch

 import hypothesis.strategies as st
 import numpy as np
@@ -24,6 +25,7 @@
 )
 from fbgemm_gpu.split_table_batched_embeddings_ops_training import (
     ComputeDevice,
+    RESParams,
     SplitTableBatchedEmbeddingBagsCodegen,
 )
 from fbgemm_gpu.tbe.utils import (
@@ -129,6 +131,8 @@ def execute_forward_(  # noqa C901
         use_cpu: bool,
         output_dtype: SparseType,
         use_experimental_tbe: bool,
+        enable_raw_embedding_streaming: bool = False,
+        prefetch_pipeline: bool = False,
     ) -> None:
         # NOTE: cache is not applicable to CPU version.
         assume(not use_cpu or not use_cache)
@@ -158,6 +162,10 @@ def execute_forward_(  # noqa C901
                 and pooling_mode != PoolingMode.NONE
             )
         )
+        # NOTE: Raw embedding streaming requires the UVM cache
+        assume(not enable_raw_embedding_streaming or use_cache)
+        # NOTE: Raw embedding streaming is not supported on CPU
+        assume(not enable_raw_embedding_streaming or not use_cpu)

         emb_op = SplitTableBatchedEmbeddingBagsCodegen
         if pooling_mode == PoolingMode.SUM:
@@ -285,6 +293,16 @@ def execute_forward_(  # noqa C901
         else:
             f = torch.cat(fs, dim=0).view(-1, D)

+        # Create RES parameters if raw embedding streaming is enabled
+        res_params = None
+        if enable_raw_embedding_streaming:
+            res_params = RESParams(
+                res_store_shards=1,
+                table_names=[f"table_{i}" for i in range(T)],
+                table_offsets=[sum(Es[:i]) for i in range(T + 1)],
+                table_sizes=Es,
+            )
+
         # Create a TBE op
         cc = emb_op(
             embedding_specs=[
@@ -305,6 +323,9 @@ def execute_forward_(  # noqa C901
             pooling_mode=pooling_mode,
             output_dtype=output_dtype,
             use_experimental_tbe=use_experimental_tbe,
+            prefetch_pipeline=prefetch_pipeline,
+            enable_raw_embedding_streaming=enable_raw_embedding_streaming,
+            res_params=res_params,
         )
         # Test torch JIT script compatibility
         if not use_cpu:
@@ -1158,6 +1179,96 @@ def test_forward_fused_pooled_emb_quant(
             cat_deq_lowp_pooled_output, cat_dq_fp32_pooled_output
         )

+    def _check_raw_embedding_stream_call_counts(
+        self,
+        mock_raw_embedding_stream: unittest.mock.Mock,
+        num_iterations: int,
+        prefetch_pipeline: bool,
+        L: int,
+    ) -> None:
+        # For TBE (not SSD), raw_embedding_stream is called once per prefetch
+        # when there is data to stream
+        expected_calls = num_iterations if L > 0 else 0
+        if prefetch_pipeline:
+            # With the prefetch pipeline, streaming lags one extra iteration,
+            # so expect one fewer call over the run
+            expected_calls = max(0, expected_calls - 1)
+
+        self.assertGreaterEqual(mock_raw_embedding_stream.call_count, 0)
+        # Allow some flexibility in the call count due to caching behavior
+        self.assertLessEqual(mock_raw_embedding_stream.call_count, expected_calls + 2)
+
+    @unittest.skipIf(*gpu_unavailable)
+    @given(
+        T=st.integers(min_value=1, max_value=5),
+        D=st.integers(min_value=2, max_value=64),
+        B=st.integers(min_value=1, max_value=32),
+        log_E=st.integers(min_value=3, max_value=4),
+        L=st.integers(min_value=1, max_value=10),
+        weights_precision=st.sampled_from([SparseType.FP32, SparseType.FP16]),
+        cache_algorithm=st.sampled_from(CacheAlgorithm),
+        pooling_mode=st.sampled_from([PoolingMode.SUM, PoolingMode.MEAN]),
+        weighted=st.booleans(),
+        mixed=st.booleans(),
+        prefetch_pipeline=st.booleans(),
+    )
+    @settings(
+        verbosity=VERBOSITY,
+        max_examples=MAX_EXAMPLES_LONG_RUNNING,
+        deadline=None,
+        suppress_health_check=[HealthCheck.filter_too_much, HealthCheck.data_too_large],
+    )
+    def test_forward_raw_embedding_streaming(
+        self,
+        T: int,
+        D: int,
+        B: int,
+        log_E: int,
+        L: int,
+        weights_precision: SparseType,
+        cache_algorithm: CacheAlgorithm,
+        pooling_mode: PoolingMode,
+        weighted: bool,
+        mixed: bool,
+        prefetch_pipeline: bool,
+    ) -> None:
+        """Test raw embedding streaming functionality integrated with the forward pass."""
+        num_iterations = 5
+        # Only LRU supports prefetch_pipeline
+        assume(not prefetch_pipeline or cache_algorithm == CacheAlgorithm.LRU)
+
+        with patch(
+            "fbgemm_gpu.split_table_batched_embeddings_ops_training.torch.classes.fbgemm.RawEmbeddingStreamer"
+        ) as mock_streamer_class:
+            # Mock the RawEmbeddingStreamer class
+            mock_streamer_instance = MagicMock()
+            mock_streamer_class.return_value = mock_streamer_instance
+
+            # Run multiple iterations to test streaming behavior
+            for _ in range(num_iterations):
+                self.execute_forward_(
+                    T=T,
+                    D=D,
+                    B=B,
+                    log_E=log_E,
+                    L=L,
+                    weights_precision=weights_precision,
+                    weighted=weighted,
+                    mixed=mixed,
+                    mixed_B=False,  # Keep it simple for streaming tests
+                    use_cache=True,  # Required for streaming
+                    cache_algorithm=cache_algorithm,
+                    pooling_mode=pooling_mode,
+                    use_cpu=False,  # Streaming is not supported on CPU
+                    output_dtype=SparseType.FP32,
+                    use_experimental_tbe=False,
+                    enable_raw_embedding_streaming=True,
+                    prefetch_pipeline=prefetch_pipeline,
+                )
+
+            self._check_raw_embedding_stream_call_counts(
+                mock_streamer_instance, num_iterations, prefetch_pipeline, L
+            )
+

 if __name__ == "__main__":
     unittest.main()
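Editor's note: putting the pieces together, enabling streaming from user code only requires the two new constructor arguments; `__init__` fills in the rest (port discovery via LOCAL_RES_PORT, `table_sizes` from the row counts). A configuration sketch with hypothetical table shapes, mirroring the unit test above:

    import torch
    from fbgemm_gpu.split_table_batched_embeddings_ops_common import EmbeddingLocation
    from fbgemm_gpu.split_table_batched_embeddings_ops_training import (
        ComputeDevice,
        RESParams,
        SplitTableBatchedEmbeddingBagsCodegen,
    )

    Es, Ds = [100, 200], [8, 8]  # rows and embedding dims per table
    res_params = RESParams(
        res_store_shards=1,
        table_names=[f"table_{i}" for i in range(len(Es))],
        table_offsets=[0, 100, 300],  # cumulative row offsets, as in the test
        table_sizes=Es,
    )
    tbe = SplitTableBatchedEmbeddingBagsCodegen(
        embedding_specs=[
            (E, D, EmbeddingLocation.MANAGED_CACHING, ComputeDevice.CUDA)
            for E, D in zip(Es, Ds)
        ],
        enable_raw_embedding_streaming=True,
        res_params=res_params,
    )

Streaming then happens as a side effect of `forward`/`prefetch`: each `_prefetch` saves the deduplicated cache indices for the batch, and one or two iterations later (depending on `prefetch_pipeline`) the saved rows are looked up in the cache and handed to the streamer asynchronously.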