Skip to content

Commit 1487f36

Browse files
jiayulufacebook-github-bot
authored and committed
reuse checkpoint handle
Summary: cache the ssd rocksdb checkpoint handle so it can be reused multiple times. Differential Revision: D80548370
1 parent 6b49418 commit 1487f36

File tree

4 files changed

+88
-7
lines changed

4 files changed

+88
-7
lines changed

fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from math import floor, log2
2121
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
2222
import torch # usort:skip
23+
from collections import defaultdict
2324

2425
# @manual=//deeplearning/fbgemm/fbgemm_gpu/codegen:split_embedding_codegen_lookup_invokers
2526
import fbgemm_gpu.split_embedding_codegen_lookup_invokers as invokers
@@ -375,6 +376,9 @@ def __init__(
375376

376377
self.step = 0
377378
self.last_flush_step = -1
379+
self.state_dict_seq_num = 0
380+
# pyre-ignore [4]: Attribute `_step_to_ssd_checkpoint_handle` of class `SSDTableBatchedEmbeddingBags` has no type specified.
381+
self._step_to_ssd_checkpoint_handle = defaultdict(set)
378382

379383
# Set prefetch pipeline
380384
self.prefetch_pipeline: bool = prefetch_pipeline
@@ -3479,11 +3483,17 @@ def create_rocksdb_hard_link_snapshot(self) -> None:
34793483
"""
34803484
if self.backend_type == BackendType.SSD:
34813485
self.ssd_db.create_rocksdb_hard_link_snapshot(self.step)
3486+
checkpoint_handle = self.ssd_db.get_active_checkpoint_uuid(self.step)
3487+
self._step_to_ssd_checkpoint_handle[self.step].add(checkpoint_handle)
34823488
else:
34833489
logging.warning(
34843490
"create_rocksdb_hard_link_snapshot is only supported for SSD backend"
34853491
)
34863492

3493+
# Empties the step -> checkpoint-handle mapping when the backend is SSD;
# does nothing for any other backend type.
# NOTE(review): presumably dropping these references is what allows the
# cached checkpoint handles to be destroyed - confirm against the handle's
# destructor.
def clear_rocksdb_hard_link_snapshot(self) -> None:
3494+
if self.backend_type == BackendType.SSD:
3495+
self._step_to_ssd_checkpoint_handle.clear()
3496+
34873497
def prepare_inputs(
34883498
self,
34893499
indices: Tensor,

fbgemm_gpu/src/ssd_split_embeddings_cache/embedding_rocksdb_wrapper.h

Lines changed: 35 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -218,15 +218,37 @@ class EmbeddingRocksDBWrapper : public torch::jit::CustomClassHolder {
218218
impl_->create_checkpoint(global_step);
219219
}
220220

221-
c10::intrusive_ptr<RocksdbCheckpointHandleWrapper> get_active_checkpoint_uuid(
222-
int64_t global_step) {
221+
// Returns the checkpoint handle wrapper for the active checkpoint uuid that
// impl_ reports for `global_step`, or std::nullopt when impl_ reports no uuid
// for that step. Wrappers are remembered in g_cache as weak references keyed
// by uuid, so repeated calls for the same uuid hand back the same wrapper
// object for as long as some caller still holds a strong reference to it.
std::optional<c10::intrusive_ptr<RocksdbCheckpointHandleWrapper>>
222+
get_active_checkpoint_uuid(int64_t global_step) {
223223
auto uuid_opt = impl_->get_active_checkpoint_uuid(global_step);
224-
if (uuid_opt.has_value()) {
225-
return c10::make_intrusive<RocksdbCheckpointHandleWrapper>(
226-
uuid_opt.value(), impl_);
227-
} else {
228-
return nullptr;
224+
// no active checkpoint uuid for this step: nothing to wrap
if (!uuid_opt.has_value()) {
225+
return std::nullopt;
229226
}
227+
// g_mu is held from here until the function returns
std::lock_guard<std::mutex> _lk(g_mu);
228+
// sweep the whole cache and drop entries whose uuid impl_ no longer knows
for (auto it = g_cache.begin(); it != g_cache.end();) {
229+
if (impl_->query_checkpoint_by_uuid(it->first)) {
230+
++it;
231+
} else {
232+
// handle already destroyed, remove its entry from cache
233+
g_cache.erase(it++);
234+
}
235+
}
236+
// cache hit: try to promote the stored weak reference to a strong one
if (auto it = g_cache.find(uuid_opt.value()); it != g_cache.end()) {
237+
if (auto sp = it->second.lock()) {
238+
return sp;
239+
} else {
240+
// reference count is 0, handle is scheduled for destruction.
241+
return std::nullopt;
242+
}
243+
}
244+
245+
// cache miss: build a new wrapper and remember a weak reference to it
auto obj = c10::make_intrusive<RocksdbCheckpointHandleWrapper>(
246+
uuid_opt.value(), impl_);
247+
g_cache.emplace(
248+
uuid_opt.value(),
249+
c10::weak_intrusive_ptr<RocksdbCheckpointHandleWrapper>(obj));
250+
251+
return obj;
230252
}
231253

232254
void set_backend_return_whole_row(bool backend_return_whole_row) {
@@ -238,6 +260,12 @@ class EmbeddingRocksDBWrapper : public torch::jit::CustomClassHolder {
238260

239261
// shared pointer since we use shared_from_this() in callbacks.
240262
std::shared_ptr<ssd::EmbeddingRocksDB> impl_;
263+
// cache of weak references to RocksdbCheckpointHandleWrapper objects, keyed by checkpoint uuid; accessed under g_mu
264+
std::unordered_map<
265+
std::string,
266+
c10::weak_intrusive_ptr<RocksdbCheckpointHandleWrapper>>
267+
g_cache;
268+
std::mutex g_mu;
241269
};
242270

243271
} // namespace ssd

fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_table_batched_embeddings.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -617,6 +617,10 @@ class EmbeddingRocksDB : public kv_db::EmbeddingKVDB {
617617
return checkpoints_.at(ckpt_uuid)->get_shard_checkpoints();
618618
}
619619

620+
// Returns true when checkpoints_ contains an entry keyed by ckpt_uuid,
// false otherwise. Does not modify checkpoints_.
bool query_checkpoint_by_uuid(const std::string& ckpt_uuid) const {
621+
return checkpoints_.find(ckpt_uuid) != checkpoints_.end();
622+
}
623+
620624
// Returns a copy of tbe_uuid_.
std::string get_tbe_uuid() const {
621625
return tbe_uuid_;
622626
}

fbgemm_gpu/test/tbe/ssd/ssd_split_tbe_training_test.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3136,3 +3136,42 @@ def copy_opt_states_hook(
31363136
atol=tolerance,
31373137
rtol=tolerance,
31383138
)
3139+
3140+
# Checks how the handle returned by get_active_checkpoint_uuid(0) follows the
# references held on it: None before any snapshot is taken, non-None after
# create_rocksdb_hard_link_snapshot() (even once the local reference is
# deleted), and None again after clear_rocksdb_hard_link_snapshot().
@given(
3141+
weights_precision=st.sampled_from([SparseType.FP32]),
3142+
)
3143+
@settings(verbosity=Verbosity.verbose, max_examples=MAX_EXAMPLES, deadline=None)
3144+
def test_ssd_rocksdb_checkpoint(self, weights_precision: SparseType) -> None:
3145+
import tempfile
3146+
3147+
# single table with E rows of dimension D
E = int(1e4)
3148+
D = 128
3149+
3150+
feature_table_map = list(range(1))
3151+
emb = SSDTableBatchedEmbeddingBags(
3152+
embedding_specs=[(E, D)],
3153+
feature_table_map=feature_table_map,
3154+
# NOTE(review): the directory created by mkdtemp() is not removed by this test
ssd_storage_directory=tempfile.mkdtemp(),
3155+
cache_sets=1,
3156+
ssd_uniform_init_lower=-0.1,
3157+
ssd_uniform_init_upper=0.1,
3158+
weights_precision=weights_precision,
3159+
l2_cache_size=8,
3160+
)
3161+
3162+
# expect no checkpoint handle
# (queries step 0; presumably emb.step is still 0 because no training step has run - confirm)
3163+
checkpoint_handle = emb._ssd_db.get_active_checkpoint_uuid(0)
3164+
assert checkpoint_handle is None, f"{checkpoint_handle=}"
3165+
# create a checkpoint
3166+
emb.create_rocksdb_hard_link_snapshot()
3167+
checkpoint_handle = emb._ssd_db.get_active_checkpoint_uuid(0)
3168+
assert checkpoint_handle is not None, f"{checkpoint_handle=}"
3169+
# delete checkpoint_handle, handle should still exist because emb holds a reference
3170+
del checkpoint_handle
3171+
checkpoint_handle = emb._ssd_db.get_active_checkpoint_uuid(0)
3172+
assert checkpoint_handle is not None, f"{checkpoint_handle=}"
3173+
# delete checkpoint_handle, and release emb's reference, handle should be destroyed
3174+
del checkpoint_handle
3175+
emb.clear_rocksdb_hard_link_snapshot()
3176+
checkpoint_handle = emb._ssd_db.get_active_checkpoint_uuid(0)
3177+
assert checkpoint_handle is None, f"{checkpoint_handle=}"

0 commit comments

Comments
 (0)