Add tracking and streaming logic to SplitTableBatchedEmbeddingBagsCodegen #4741

Open · wants to merge 2 commits into base: main
1 change: 1 addition & 0 deletions fbgemm_gpu/cmake/TbeInference.cmake
@@ -22,6 +22,7 @@ gpu_cpp_library(
${FBGEMM_GPU}/src/split_embeddings_cache/lru_cache_populate_byte.cpp
${FBGEMM_GPU}/src/split_embeddings_cache/lxu_cache.cpp
${FBGEMM_GPU}/src/split_embeddings_cache/split_embeddings_cache_ops.cpp
${FBGEMM_GPU}/src/split_embeddings_cache/raw_embedding_streamer.cpp
GPU_SRCS
${FBGEMM_GPU}/src/split_embeddings_cache/lfu_cache_find.cu
${FBGEMM_GPU}/src/split_embeddings_cache/lfu_cache_populate.cu
104 changes: 104 additions & 0 deletions fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py
@@ -22,6 +22,7 @@

import torch # usort:skip
from torch import nn, Tensor # usort:skip
from torch.autograd.profiler import record_function # usort:skip

# @manual=//deeplearning/fbgemm/fbgemm_gpu/codegen:split_embedding_codegen_lookup_invokers
import fbgemm_gpu.split_embedding_codegen_lookup_invokers as invokers
@@ -626,6 +627,7 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
lxu_cache_locations_list: List[Tensor]
lxu_cache_locations_empty: Tensor
timesteps_prefetched: List[int]
prefetched_info: List[Tuple[Tensor, Tensor]]
record_cache_metrics: RecordCacheMetrics
# pyre-fixme[13]: Attribute `uvm_cache_stats` is never initialized.
uvm_cache_stats: torch.Tensor
@@ -690,6 +692,8 @@ def __init__( # noqa C901
embedding_table_index_type: torch.dtype = torch.int64,
embedding_table_offset_type: torch.dtype = torch.int64,
embedding_shard_info: Optional[List[Tuple[int, int, int, int]]] = None,
enable_raw_embedding_streaming: bool = False,
res_params: Optional[RESParams] = None,
) -> None:
super(SplitTableBatchedEmbeddingBagsCodegen, self).__init__()
self.uuid = str(uuid.uuid4())
@@ -700,6 +704,7 @@ def __init__( # noqa C901
)

self.logging_table_name: str = self.get_table_name_for_logging(table_names)
self.enable_raw_embedding_streaming: bool = enable_raw_embedding_streaming
self.pooling_mode = pooling_mode
self.is_nobag: bool = self.pooling_mode == PoolingMode.NONE

@@ -1460,6 +1465,30 @@ def __init__( # noqa C901
)
self.embedding_table_offset_type: torch.dtype = embedding_table_offset_type

self.prefetched_info: List[Tuple[Tensor, Tensor]] = torch.jit.annotate(
List[Tuple[Tensor, Tensor]], []
)
if self.enable_raw_embedding_streaming:
self.res_params: RESParams = res_params or RESParams()
self.res_params.table_sizes = [0] + list(accumulate(rows))
res_port_from_env = os.getenv("LOCAL_RES_PORT")
self.res_params.res_server_port = (
int(res_port_from_env) if res_port_from_env else 0
)
# pyre-fixme[4]: Attribute must be annotated.
self._raw_embedding_streamer = torch.classes.fbgemm.RawEmbeddingStreamer(
self.uuid,
self.enable_raw_embedding_streaming,
self.res_params.res_store_shards,
self.res_params.res_server_port,
self.res_params.table_names,
self.res_params.table_offsets,
self.res_params.table_sizes,
)
logging.info(
f"{self.uuid} raw embedding streaming enabled with {self.res_params=}"
)
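
As a reference, here is a minimal sketch of how the new constructor arguments might be wired up. The RESParams field names come from this diff, but the import locations, the keyword-constructor form of RESParams, and the embedding_specs values are illustrative assumptions rather than a verified configuration:

# Hypothetical usage sketch; values and import locations are placeholders.
from fbgemm_gpu.split_table_batched_embeddings_ops_common import (  # RESParams location assumed
    EmbeddingLocation,
    RESParams,
)
from fbgemm_gpu.split_table_batched_embeddings_ops_training import (
    ComputeDevice,
    SplitTableBatchedEmbeddingBagsCodegen,
)

res_params = RESParams(  # assumed to accept these fields as keyword arguments
    res_store_shards=1,
    table_names=["table_0"],
    table_offsets=[0],
)
emb_op = SplitTableBatchedEmbeddingBagsCodegen(
    embedding_specs=[(1000, 64, EmbeddingLocation.MANAGED_CACHING, ComputeDevice.CUDA)],
    enable_raw_embedding_streaming=True,
    res_params=res_params,
)
# __init__ then fills in res_params.table_sizes from the table rows and
# res_params.res_server_port from the LOCAL_RES_PORT environment variable.
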

@torch.jit.ignore
def log(self, msg: str) -> None:
"""
@@ -2521,7 +2550,13 @@ def _prefetch(
self.local_uvm_cache_stats.zero_()
self._report_io_size_count("prefetch_input", indices)

# Stream out rows recorded by earlier prefetches before this prefetch updates the cache.
self.raw_embedding_stream()

final_lxu_cache_locations = torch.empty_like(indices, dtype=torch.int32)
linear_cache_indices_merged = torch.zeros(
0, dtype=indices.dtype, device=indices.device
)
for (
partial_indices,
partial_lxu_cache_locations,
@@ -2537,6 +2572,9 @@
vbe_metadata.max_B if vbe_metadata is not None else -1,
base_offset,
)
linear_cache_indices_merged = torch.cat(
[linear_cache_indices_merged, linear_cache_indices]
)

if (
self.record_cache_metrics.record_cache_miss_counter
@@ -2617,6 +2655,23 @@
if self.should_log():
self.print_uvm_cache_stats(use_local_cache=False)

if self.enable_raw_embedding_streaming:
with record_function(
"## uvm_save_prefetched_rows {} {} ##".format(self.timestep, self.uuid)
):
(
linear_unique_indices,
linear_unique_indices_length,
_,
) = torch.ops.fbgemm.get_unique_indices(
linear_cache_indices_merged,
self.total_cache_hash_size,
compute_count=False,
)
self.prefetched_info.append(
(linear_unique_indices, linear_unique_indices_length)
)
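
A small illustration of what the tracking above records, with torch.unique standing in for torch.ops.fbgemm.get_unique_indices (the fbgemm op also returns the number of valid entries as a tensor, which is kept alongside the indices):

import torch

# Merged, possibly duplicated linearized cache indices from one prefetch.
linear_cache_indices_merged = torch.tensor([7, 3, 7, 42, 3])
unique_indices = torch.unique(linear_cache_indices_merged)  # tensor([ 3,  7, 42])
unique_length = torch.tensor([unique_indices.numel()])      # tensor([3])
prefetched_info = []
prefetched_info.append((unique_indices, unique_length))
# The (indices, length) pair stays queued until raw_embedding_stream() pops it
# on a later prefetch and streams the corresponding cache rows.
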

def should_log(self) -> bool:
"""Determines if we should log for this step, using exponentially decreasing frequency.

@@ -3829,6 +3884,55 @@ def _debug_print_input_stats_factory_null(
return _debug_print_input_stats_factory_impl
return _debug_print_input_stats_factory_null

@torch.jit.ignore
def raw_embedding_stream(self) -> None:
if not self.enable_raw_embedding_streaming:
return None
# When pipelining is enabled, the prefetch for iteration i runs before the
# sparse backward of iteration i - 1, so the embeddings for the ids changed in
# iteration i - 1 have not been updated yet; we can therefore only stream the
# indices recorded in iteration i - 2.
# When pipelining is disabled, the prefetch for iteration i runs before the
# forward of iteration i, so the ids changed in iteration i - 1 can be
# streamed safely.
target_prev_iter = 1
if self.prefetch_pipeline:
target_prev_iter = 2
if len(self.prefetched_info) < target_prev_iter:
return None
with record_function(
"## uvm_lookup_prefetched_rows {} {} ##".format(self.timestep, self.uuid)
):
(updated_indices, updated_count) = self.prefetched_info.pop(0)
updated_locations = torch.ops.fbgemm.lxu_cache_lookup(
updated_indices,
self.lxu_cache_state,
self.total_cache_hash_size,
gather_cache_stats=False, # not collecting cache stats
num_uniq_cache_indices=updated_count,
)
updated_weights = torch.empty(
[updated_indices.size()[0], self.max_D_cache],
# pyre-ignore Incompatible parameter type [6]: In call `torch._C._VariableFunctions.empty`, for argument `dtype`, expected `Optional[dtype]` but got `Union[Module, dtype, Tensor]`
dtype=self.lxu_cache_weights.dtype,
# pyre-ignore Incompatible parameter type [6]: In call `torch._C._VariableFunctions.empty`, for argument `device`, expected `Union[None, int, str, device]` but got `Union[Module, device, Tensor]`
device=self.lxu_cache_weights.device,
)
torch.ops.fbgemm.masked_index_select(
updated_weights,
updated_locations,
self.lxu_cache_weights,
updated_count,
)
# stream weights
self._raw_embedding_streamer.stream(
updated_indices.to(device=torch.device("cpu")),
updated_weights.to(device=torch.device("cpu")),
updated_count.to(device=torch.device("cpu")),
False, # require_tensor_copy
False, # blocking_tensor_copy
)
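
A toy sketch of the queue-depth check above: prefetched_info acts as a FIFO of (indices, count) pairs, and with the prefetch pipeline enabled the pop is delayed by one extra prefetch so that the sparse backward that updates those rows has already run. The helper below only mirrors the threshold logic and is not part of the module:

from collections import deque

def ready_to_stream(queue_len: int, prefetch_pipeline: bool) -> bool:
    # Stream only entries old enough that their backward pass has completed.
    target_prev_iter = 2 if prefetch_pipeline else 1
    return queue_len > target_prev_iter - 1

prefetched_info = deque([("iter0_indices", "iter0_count")])
print(ready_to_stream(len(prefetched_info), prefetch_pipeline=False))  # True: iter 0 can be streamed
print(ready_to_stream(len(prefetched_info), prefetch_pipeline=True))   # False: wait one more prefetch
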


class DenseTableBatchedEmbeddingBagsCodegen(nn.Module):
"""
117 changes: 117 additions & 0 deletions
@@ -0,0 +1,117 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#pragma once
#include <ATen/ATen.h>
#include <torch/custom_class.h>
#ifdef FBGEMM_FBCODE
#include <folly/concurrency/UnboundedQueue.h>
#include <folly/coro/Task.h>
#endif

#include <atomic>
#include <memory>
#include <string>
#include <thread>
#include <utility>
#include <vector>

namespace fbgemm_gpu {

struct StreamQueueItem {
at::Tensor indices;
at::Tensor weights;
at::Tensor count;
StreamQueueItem(
at::Tensor src_indices,
at::Tensor src_weights,
at::Tensor src_count) {
indices = std::move(src_indices);
weights = std::move(src_weights);
count = std::move(src_count);
}
};

class RawEmbeddingStreamer : public torch::jit::CustomClassHolder {
public:
explicit RawEmbeddingStreamer(
std::string unique_id,
bool enable_raw_embedding_streaming,
int64_t res_store_shards,
int64_t res_server_port,
std::vector<std::string> table_names,
std::vector<int64_t> table_offsets,
const std::vector<int64_t>& table_sizes);

virtual ~RawEmbeddingStreamer();

/// Stream out the non-negative elements of <indices> and their paired
/// embedding rows from <weights>, considering only the first <count> elements
/// of the tensors.
/// It spins up a thread that copies all three tensors to the CPU and pushes
/// them onto a background queue, which a separate thread pool drains to
/// stream the rows to the Thrift server (currently co-located on the same
/// host).
///
/// This is used in a CUDA stream callback. It does not need to be serialized
/// with other callbacks, so a separate thread is used to maximize the overlap
/// with them.
///
/// @param indices The 1D embedding index tensor; negative values are skipped
/// @param weights The 2D tensor whose rows (embeddings) are paired with the
/// corresponding elements of <indices>
/// @param count A single-element tensor containing the number of indices to
/// be processed
/// @param require_tensor_copy Whether the tensors must be copied before being
/// enqueued
/// @param blocking_tensor_copy Whether to copy the tensors to be streamed in
/// a blocking manner
///
/// @return None
void stream(
const at::Tensor& indices,
const at::Tensor& weights,
const at::Tensor& count,
bool require_tensor_copy,
bool blocking_tensor_copy = true);

#ifdef FBGEMM_FBCODE
folly::coro::Task<void> tensor_stream(
const at::Tensor& indices,
const at::Tensor& weights);
/*
* Copy the indices, weights and count tensors and enqueue them for
* asynchronous streaming.
*/
void copy_and_enqueue_stream_tensors(
const at::Tensor& indices,
const at::Tensor& weights,
const at::Tensor& count);

/*
* Join the stream tensor copy thread to make sure it has finished properly
* before a new one is created.
*/
void join_stream_tensor_copy_thread();

/*
* FOR TESTING: Join the weights stream thread to make sure it has finished
* properly before destruction and for testing.
*/
void join_weights_stream_thread();
// FOR TESTING: get queue size.
uint64_t get_weights_to_stream_queue_size();
#endif
private:
std::atomic<bool> stop_{false};
std::string unique_id_;
bool enable_raw_embedding_streaming_;
int64_t res_store_shards_;
int64_t res_server_port_;
std::vector<std::string> table_names_;
std::vector<int64_t> table_offsets_;
at::Tensor table_sizes_;
#ifdef FBGEMM_FBCODE
std::unique_ptr<std::thread> weights_stream_thread_;
folly::UMPSCQueue<StreamQueueItem, true> weights_to_stream_queue_;
std::unique_ptr<std::thread> stream_tensor_copy_thread_;
#endif
};

} // namespace fbgemm_gpu
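
From Python the class is used as a TorchScript custom class, mirroring the construction and stream() call added to SplitTableBatchedEmbeddingBagsCodegen in this diff. The constructor values and tensors below are placeholders for illustration, and loading the registration via import fbgemm_gpu is an assumption about the build being used:

import torch
import fbgemm_gpu  # noqa: F401  # assumed to register the custom class

streamer = torch.classes.fbgemm.RawEmbeddingStreamer(
    "example-uuid",  # unique_id
    True,            # enable_raw_embedding_streaming
    1,               # res_store_shards
    0,               # res_server_port
    ["table_0"],     # table_names
    [0],             # table_offsets
    [0, 1000],       # table_sizes (cumulative row offsets)
)

indices = torch.tensor([3, 7, -1], dtype=torch.int64)  # negative entries are skipped
weights = torch.zeros(3, 64, dtype=torch.float32)      # one embedding row per index
count = torch.tensor([2])                              # only the first 2 indices are valid
streamer.stream(indices, weights, count, False, True)  # require_tensor_copy, blocking_tensor_copy
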