Skip to content

Commit 2f02195

Browse files
committed
delay destruction of rocksdb checkpoint handle (#4740)
Summary: Pull Request resolved: #4740 X-link: facebookresearch/FBGEMM#1761 Delay the destruction of the rocksdb checkpoint handle, so that the rdb checkpoint handle can be used multiple times. This is a stopgap measure for the silvertorch publishing use case, until the rdb checkpoint lifecycle unification effort is ready. Reviewed By: duduyi2013 Differential Revision: D80564473
1 parent 87f413c commit 2f02195

File tree

5 files changed

+58
-7
lines changed

5 files changed

+58
-7
lines changed

fbgemm_gpu/src/ssd_split_embeddings_cache/embedding_rocksdb_wrapper.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -218,14 +218,14 @@ class EmbeddingRocksDBWrapper : public torch::jit::CustomClassHolder {
218218
impl_->create_checkpoint(global_step);
219219
}
220220

221-
c10::intrusive_ptr<RocksdbCheckpointHandleWrapper> get_active_checkpoint_uuid(
222-
int64_t global_step) {
221+
std::optional<c10::intrusive_ptr<RocksdbCheckpointHandleWrapper>>
222+
get_active_checkpoint_uuid(int64_t global_step) {
223223
auto uuid_opt = impl_->get_active_checkpoint_uuid(global_step);
224224
if (uuid_opt.has_value()) {
225225
return c10::make_intrusive<RocksdbCheckpointHandleWrapper>(
226226
uuid_opt.value(), impl_);
227227
} else {
228-
return nullptr;
228+
return std::nullopt;
229229
}
230230
}
231231

fbgemm_gpu/src/ssd_split_embeddings_cache/kv_tensor_wrapper.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ struct RocksdbCheckpointHandleWrapper : public torch::jit::CustomClassHolder {
4747
const std::string& checkpoint_uuid,
4848
std::shared_ptr<EmbeddingRocksDB> db);
4949

50-
~RocksdbCheckpointHandleWrapper();
50+
// ~RocksdbCheckpointHandleWrapper();
5151

5252
std::string uuid;
5353
std::shared_ptr<EmbeddingRocksDB> db;

fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_split_table_batched_embeddings.cpp

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -357,9 +357,12 @@ RocksdbCheckpointHandleWrapper::RocksdbCheckpointHandleWrapper(
357357
std::shared_ptr<EmbeddingRocksDB> db)
358358
: uuid(checkpoint_uuid), db(std::move(db)) {}
359359

360-
RocksdbCheckpointHandleWrapper::~RocksdbCheckpointHandleWrapper() {
361-
db->release_checkpoint(uuid);
362-
}
360+
// do not release uuid when RocksdbCheckpointHandleWrapper is destroyed
361+
// subsequent get_active_checkpoint_uuid() calls need to retrieve
362+
// the checkpoint uuid
363+
// RocksdbCheckpointHandleWrapper::~RocksdbCheckpointHandleWrapper() {
364+
// db->release_checkpoint(uuid);
365+
// }
363366

364367
KVTensorWrapper::KVTensorWrapper(
365368
std::vector<int64_t> shape,

fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_table_batched_embeddings.h

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -593,7 +593,29 @@ class EmbeddingRocksDB : public kv_db::EmbeddingKVDB {
593593
"removing the prev rdb ckpt, please make sure it has fullfilled "
594594
"its use case, e.g. checkpoint and publish";
595595
}
596+
597+
// remove old checkpoint handles as they are likely not needed anymore
598+
// this can not be done the same way as rdb snapshot, because we only
599+
// create 1 rdb checkpoint per use cases(ckpt/publish), however publish
600+
// calls state_dict multiple times, most of them to just get the fqn, they
601+
// throw away pmt immediately if such state_dict is called at the beginning,
602+
// it will destroy the checkpoint handler and when the real use case needs
603+
// rdb ckpt, it is removed already
604+
if (global_step - 10000 > 0) {
605+
std::vector<std::string> ckpt_uuids_to_purge;
606+
for (const auto& [glb_step, ckpt_uuid] : global_step_to_ckpt_uuid_) {
607+
if (glb_step < global_step - 10000) {
608+
ckpt_uuids_to_purge.push_back(ckpt_uuid);
609+
}
610+
}
611+
for (auto& ckpt_uuid : ckpt_uuids_to_purge) {
612+
release_checkpoint(ckpt_uuid);
613+
}
614+
}
615+
596616
auto ckpt_uuid = facebook::strings::generateUUID();
617+
618+
LOG(INFO) << "creating new rocksdb checkpoint, uuid:" << ckpt_uuid;
597619
auto handle = std::make_unique<CheckpointHandle>(
598620
this, tbe_uuid_, ckpt_uuid, path_, use_default_ssd_path_);
599621
checkpoints_[ckpt_uuid] = std::move(handle);

fbgemm_gpu/test/tbe/ssd/kv_backend_test.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -394,6 +394,32 @@ def test_rocksdb_get_discrete_ids(
394394

395395
assert torch.equal(ids_in_range_ordered, id_tensor_ordered)
396396

397+
@given(**default_st)
398+
@settings(**default_settings)
399+
def test_ssd_rocksdb_checkpoint_handle(
400+
self,
401+
T: int,
402+
D: int,
403+
log_E: int,
404+
mixed: bool,
405+
weights_precision: SparseType,
406+
) -> None:
407+
emb, _, _ = self.generate_fbgemm_kv_tbe(
408+
T, D, log_E, weights_precision, mixed, False, 8
409+
)
410+
411+
# expect no checkpoint handle
412+
checkpoint_handle = emb._ssd_db.get_active_checkpoint_uuid(0)
413+
assert checkpoint_handle is None, f"{checkpoint_handle=}"
414+
# create a checkpoint
415+
emb.create_rocksdb_hard_link_snapshot()
416+
checkpoint_handle = emb._ssd_db.get_active_checkpoint_uuid(0)
417+
assert checkpoint_handle is not None, f"{checkpoint_handle=}"
418+
# delete checkpoint_handle, handle should still exist because emb holds a reference
419+
del checkpoint_handle
420+
checkpoint_handle = emb._ssd_db.get_active_checkpoint_uuid(0)
421+
assert checkpoint_handle is not None, f"{checkpoint_handle=}"
422+
397423
@given(
398424
E=st.integers(min_value=1000, max_value=10000),
399425
num_buckets=st.integers(min_value=10, max_value=15),

0 commit comments

Comments
 (0)