Skip to content

Commit f02bc4c

Browse files
chouxifacebook-github-bot
authored andcommitted
Extract the res backend to a separate class and export to python side (#4714)
Summary: Pull Request resolved: #4714 X-link: facebookresearch/FBGEMM#1738 We're extending the raw embedding streaming to tables with UVM_CACHING. Extract the backend out to a standalone class under deeplearning/fbgemm/fbgemm_gpu/src/split_embeddings_cache folder, so the logic could be reused by [SplitTableBatchedEmbeddingBagsCodegen](https://www.internalfb.com/code/fbsource/fbcode/deeplearning/fbgemm/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py?lines=1349) that supports the UVM_CACHING tables. Differential Revision: D79192671
1 parent 8ecece6 commit f02bc4c

File tree

7 files changed

+750
-443
lines changed

7 files changed

+750
-443
lines changed
Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#pragma once
10+
#include <ATen/ATen.h>
11+
#ifdef FBGEMM_FBCODE
12+
#include <folly/coro/Task.h>
13+
#endif
14+
15+
#include <utility>
16+
17+
namespace fbgemm_gpu {
18+
19+
struct StreamQueueItem {
20+
at::Tensor indices;
21+
at::Tensor weights;
22+
at::Tensor count;
23+
StreamQueueItem(
24+
at::Tensor src_indices,
25+
at::Tensor src_weights,
26+
at::Tensor src_count) {
27+
indices = std::move(src_indices);
28+
weights = std::move(src_weights);
29+
count = std::move(src_count);
30+
}
31+
};
32+
33+
class RawEmbeddingStreamer : public torch::jit::CustomClassHolder {
34+
public:
35+
explicit RawEmbeddingStreamer(
36+
std::string unique_id,
37+
bool enable_raw_embedding_streaming,
38+
int64_t res_store_shards,
39+
int64_t res_server_port,
40+
std::vector<std::string> table_names,
41+
std::vector<int64_t> table_offsets,
42+
const std::vector<int64_t>& table_sizes);
43+
44+
virtual ~RawEmbeddingStreamer();
45+
46+
/// Stream out non-negative elements in <indices> and its paired embeddings
47+
/// from <weights> for the first <count> elements in the tensor.
48+
/// It spins up a thread that will copy all 3 tensors to CPU and inject them
49+
/// into the background queue which will be picked up by another set of thread
50+
/// pools for streaming out to the thrift server (co-located on same host
51+
/// now).
52+
///
53+
/// This is used in cuda stream callback, which doesn't require to be
54+
/// serialized with other callbacks, thus a separate thread is used to
55+
/// maximize the overlapping with other callbacks.
56+
///
57+
/// @param indices The 1D embedding index tensor, should skip on negative
58+
/// value
59+
/// @param weights The 2D tensor that each row(embeddings) is paired up with
60+
/// relative element in <indices>
61+
/// @param count A single element tensor that contains the number of indices
62+
/// to be processed
63+
/// @param blocking_tensor_copy whether to copy the tensors to be streamed in
64+
/// a blocking manner
65+
///
66+
/// @return None
67+
void stream(
68+
const at::Tensor& indices,
69+
const at::Tensor& weights,
70+
const at::Tensor& count,
71+
bool require_tensor_copy,
72+
bool blocking_tensor_copy = true);
73+
74+
#ifdef FBGEMM_FBCODE
75+
folly::coro::Task<void> tensor_stream(
76+
const at::Tensor& indices,
77+
const at::Tensor& weights);
78+
/*
79+
* Copy the indices, weights and count tensors and enqueue them for
80+
* asynchronous stream.
81+
*/
82+
void copy_and_enqueue_stream_tensors(
83+
const at::Tensor& indices,
84+
const at::Tensor& weights,
85+
const at::Tensor& count);
86+
87+
/*
88+
* Join the stream tensor copy thread, make sure the thread is properly
89+
* finished before creating new.
90+
*/
91+
void join_stream_tensor_copy_thread();
92+
93+
/*
94+
* FOR TESTING: Join the weight stream thread, make sure the thread is
95+
* properly finished for destruction and testing.
96+
*/
97+
void join_weights_stream_thread();
98+
// FOR TESTING: get queue size.
99+
uint64_t get_weights_to_stream_queue_size();
100+
#endif
101+
private:
102+
std::atomic<bool> stop_{false};
103+
std::string unique_id_;
104+
bool enable_raw_embedding_streaming_;
105+
int64_t res_store_shards_;
106+
int64_t res_server_port_;
107+
std::vector<std::string> table_names_;
108+
std::vector<int64_t> table_offsets_;
109+
at::Tensor table_sizes_;
110+
std::unique_ptr<std::thread> weights_stream_thread_;
111+
folly::UMPSCQueue<StreamQueueItem, true> weights_to_stream_queue_;
112+
std::unique_ptr<std::thread> stream_tensor_copy_thread_;
113+
};
114+
115+
} // namespace fbgemm_gpu

0 commit comments

Comments
 (0)