diff --git a/fbgemm_gpu/src/ssd_split_embeddings_cache/embedding_rocksdb_wrapper.h b/fbgemm_gpu/src/ssd_split_embeddings_cache/embedding_rocksdb_wrapper.h index 65d45fb568..e69a02854c 100644 --- a/fbgemm_gpu/src/ssd_split_embeddings_cache/embedding_rocksdb_wrapper.h +++ b/fbgemm_gpu/src/ssd_split_embeddings_cache/embedding_rocksdb_wrapper.h @@ -218,14 +218,14 @@ class EmbeddingRocksDBWrapper : public torch::jit::CustomClassHolder { impl_->create_checkpoint(global_step); } - c10::intrusive_ptr get_active_checkpoint_uuid( - int64_t global_step) { + std::optional> + get_active_checkpoint_uuid(int64_t global_step) { auto uuid_opt = impl_->get_active_checkpoint_uuid(global_step); if (uuid_opt.has_value()) { return c10::make_intrusive( uuid_opt.value(), impl_); } else { - return nullptr; + return std::nullopt; } } diff --git a/fbgemm_gpu/src/ssd_split_embeddings_cache/kv_tensor_wrapper.h b/fbgemm_gpu/src/ssd_split_embeddings_cache/kv_tensor_wrapper.h index e9bd6864ae..2a9dfdd617 100644 --- a/fbgemm_gpu/src/ssd_split_embeddings_cache/kv_tensor_wrapper.h +++ b/fbgemm_gpu/src/ssd_split_embeddings_cache/kv_tensor_wrapper.h @@ -47,7 +47,7 @@ struct RocksdbCheckpointHandleWrapper : public torch::jit::CustomClassHolder { const std::string& checkpoint_uuid, std::shared_ptr db); - ~RocksdbCheckpointHandleWrapper(); + // ~RocksdbCheckpointHandleWrapper(); std::string uuid; std::shared_ptr db; diff --git a/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_split_table_batched_embeddings.cpp b/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_split_table_batched_embeddings.cpp index db7d489836..f96580afe3 100644 --- a/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_split_table_batched_embeddings.cpp +++ b/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_split_table_batched_embeddings.cpp @@ -357,9 +357,12 @@ RocksdbCheckpointHandleWrapper::RocksdbCheckpointHandleWrapper( std::shared_ptr db) : uuid(checkpoint_uuid), db(std::move(db)) {} 
-RocksdbCheckpointHandleWrapper::~RocksdbCheckpointHandleWrapper() { - db->release_checkpoint(uuid); -} +// do not release uuid when RocksdbCheckpointHandleWrapper is destroyed +// subsequent get_active_checkpoint_uuid() calls need to retrieve +// the checkpoint uuid +// RocksdbCheckpointHandleWrapper::~RocksdbCheckpointHandleWrapper() { +// db->release_checkpoint(uuid); +// } KVTensorWrapper::KVTensorWrapper( std::vector shape, diff --git a/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_table_batched_embeddings.h b/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_table_batched_embeddings.h index 8bf1d5d3a2..e1ce66017a 100644 --- a/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_table_batched_embeddings.h +++ b/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_table_batched_embeddings.h @@ -593,7 +593,29 @@ class EmbeddingRocksDB : public kv_db::EmbeddingKVDB { "removing the prev rdb ckpt, please make sure it has fullfilled " "its use case, e.g. checkpoint and publish"; } + + // remove old checkpoint handler as they are likely not needed anymore + // this can not be done the same way as rdb snapshot, because we only + // create 1 rdb checkpoint per use case (ckpt/publish), however publish + // calls state_dict multiple times, most of them to just get the fqn, they + // throw away pmt immediately if such state_dict is called at the beginning, + // it will destroy the checkpoint handler and when the real use case needs + // rdb ckpt, it is removed already + if (global_step - 10000 > 0) { + std::vector ckpt_uuids_to_purge; + for (const auto& [glb_step, ckpt_uuid] : global_step_to_ckpt_uuid_) { + if (glb_step < global_step - 10000) { + ckpt_uuids_to_purge.push_back(ckpt_uuid); + } + } + for (auto& ckpt_uuid : ckpt_uuids_to_purge) { + release_checkpoint(ckpt_uuid); + } + } + auto ckpt_uuid = facebook::strings::generateUUID(); + + LOG(INFO) << "creating new rocksdb checkpoint, uuid:" << ckpt_uuid; auto handle = std::make_unique( this, tbe_uuid_, ckpt_uuid, path_, 
use_default_ssd_path_); checkpoints_[ckpt_uuid] = std::move(handle); diff --git a/fbgemm_gpu/test/tbe/ssd/kv_backend_test.py b/fbgemm_gpu/test/tbe/ssd/kv_backend_test.py index 3fe04bfe86..f34c1a2f06 100644 --- a/fbgemm_gpu/test/tbe/ssd/kv_backend_test.py +++ b/fbgemm_gpu/test/tbe/ssd/kv_backend_test.py @@ -394,6 +394,32 @@ def test_rocksdb_get_discrete_ids( assert torch.equal(ids_in_range_ordered, id_tensor_ordered) + @given(**default_st) + @settings(**default_settings) + def test_ssd_rocksdb_checkpoint_handle( + self, + T: int, + D: int, + log_E: int, + mixed: bool, + weights_precision: SparseType, + ) -> None: + emb, _, _ = self.generate_fbgemm_kv_tbe( + T, D, log_E, weights_precision, mixed, False, 8 + ) + + # expect no checkpoint handle + checkpoint_handle = emb._ssd_db.get_active_checkpoint_uuid(0) + assert checkpoint_handle is None, f"{checkpoint_handle=}" + # create a checkpoint + emb.create_rocksdb_hard_link_snapshot() + checkpoint_handle = emb._ssd_db.get_active_checkpoint_uuid(0) + assert checkpoint_handle is not None, f"{checkpoint_handle=}" + # delete checkpoint_handle, handle should still exist because emb holds a reference + del checkpoint_handle + checkpoint_handle = emb._ssd_db.get_active_checkpoint_uuid(0) + assert checkpoint_handle is not None, f"{checkpoint_handle=}" + @given( E=st.integers(min_value=1000, max_value=10000), num_buckets=st.integers(min_value=10, max_value=15),