diff --git a/fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py b/fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py index a3d7469d58..0197c99667 100644 --- a/fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py +++ b/fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py @@ -20,6 +20,7 @@ from math import floor, log2 from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union import torch # usort:skip +from collections import defaultdict # @manual=//deeplearning/fbgemm/fbgemm_gpu/codegen:split_embedding_codegen_lookup_invokers import fbgemm_gpu.split_embedding_codegen_lookup_invokers as invokers @@ -375,6 +376,9 @@ def __init__( self.step = 0 self.last_flush_step = -1 + self.state_dict_seq_num = 0 + # pyre-ignore [4]: Attribute `_step_to_ssd_checkpoint_handle` of class `SSDTableBatchedEmbeddingBags` has no type specified. + self._step_to_ssd_checkpoint_handle = defaultdict(set) # Set prefetch pipeline self.prefetch_pipeline: bool = prefetch_pipeline @@ -3479,11 +3483,17 @@ def create_rocksdb_hard_link_snapshot(self) -> None: """ if self.backend_type == BackendType.SSD: self.ssd_db.create_rocksdb_hard_link_snapshot(self.step) + checkpoint_handle = self.ssd_db.get_active_checkpoint_uuid(self.step) + self._step_to_ssd_checkpoint_handle[self.step].add(checkpoint_handle) else: logging.warning( "create_rocksdb_hard_link_snapshot is only supported for SSD backend" ) + def clear_rocksdb_hard_link_snapshot(self) -> None: + if self.backend_type == BackendType.SSD: + self._step_to_ssd_checkpoint_handle.clear() + def prepare_inputs( self, indices: Tensor, diff --git a/fbgemm_gpu/src/ssd_split_embeddings_cache/embedding_rocksdb_wrapper.h b/fbgemm_gpu/src/ssd_split_embeddings_cache/embedding_rocksdb_wrapper.h index 65d45fb568..2e8ac69694 100644 --- a/fbgemm_gpu/src/ssd_split_embeddings_cache/embedding_rocksdb_wrapper.h +++ b/fbgemm_gpu/src/ssd_split_embeddings_cache/embedding_rocksdb_wrapper.h @@ -218,15 +218,37 @@ class EmbeddingRocksDBWrapper : public torch::jit::CustomClassHolder { 
impl_->create_checkpoint(global_step); } - c10::intrusive_ptr<RocksdbCheckpointHandleWrapper> get_active_checkpoint_uuid( - int64_t global_step) { + std::optional<c10::intrusive_ptr<RocksdbCheckpointHandleWrapper>> + get_active_checkpoint_uuid(int64_t global_step) { auto uuid_opt = impl_->get_active_checkpoint_uuid(global_step); - if (uuid_opt.has_value()) { - return c10::make_intrusive<RocksdbCheckpointHandleWrapper>( - uuid_opt.value(), impl_); - } else { - return nullptr; + if (!uuid_opt.has_value()) { + return std::nullopt; } + std::lock_guard<std::mutex> _lk(g_mu); + for (auto it = g_cache.begin(); it != g_cache.end();) { + if (impl_->query_checkpoint_by_uuid(it->first)) { + ++it; + } else { + // handle already destroyed, remove its entry from cache + g_cache.erase(it++); + } + } + if (auto it = g_cache.find(uuid_opt.value()); it != g_cache.end()) { + if (auto sp = it->second.lock()) { + return sp; + } else { + // reference count is 0, handle is scheduled for destruction. + return std::nullopt; + } + } + + auto obj = c10::make_intrusive<RocksdbCheckpointHandleWrapper>( + uuid_opt.value(), impl_); + g_cache.emplace( + uuid_opt.value(), + c10::weak_intrusive_ptr<RocksdbCheckpointHandleWrapper>(obj)); + + return obj; } void set_backend_return_whole_row(bool backend_return_whole_row) { @@ -238,6 +260,12 @@ class EmbeddingRocksDBWrapper : public torch::jit::CustomClassHolder { // shared pointer since we use shared_from_this() in callbacks. 
std::shared_ptr<ssd::EmbeddingRocksDB> impl_; + // cache of RocksdbCheckpointHandleWrapper + std::unordered_map< + std::string, + c10::weak_intrusive_ptr<RocksdbCheckpointHandleWrapper>> + g_cache; + std::mutex g_mu; }; } // namespace ssd diff --git a/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_table_batched_embeddings.h b/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_table_batched_embeddings.h index 8bf1d5d3a2..9af3b44da8 100644 --- a/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_table_batched_embeddings.h +++ b/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_table_batched_embeddings.h @@ -617,6 +617,10 @@ class EmbeddingRocksDB : public kv_db::EmbeddingKVDB { return checkpoints_.at(ckpt_uuid)->get_shard_checkpoints(); } + bool query_checkpoint_by_uuid(const std::string& ckpt_uuid) const { + return checkpoints_.find(ckpt_uuid) != checkpoints_.end(); + } + std::string get_tbe_uuid() const { return tbe_uuid_; } diff --git a/fbgemm_gpu/test/tbe/ssd/kv_backend_test.py b/fbgemm_gpu/test/tbe/ssd/kv_backend_test.py index 3fe04bfe86..cc32ba9b9b 100644 --- a/fbgemm_gpu/test/tbe/ssd/kv_backend_test.py +++ b/fbgemm_gpu/test/tbe/ssd/kv_backend_test.py @@ -394,6 +394,37 @@ def test_rocksdb_get_discrete_ids( assert torch.equal(ids_in_range_ordered, id_tensor_ordered) + @given(**default_st) + @settings(**default_settings) + def test_ssd_rocksdb_checkpoint_handle( + self, + T: int, + D: int, + log_E: int, + mixed: bool, + weights_precision: SparseType, + ) -> None: + emb, _, _ = self.generate_fbgemm_kv_tbe( + T, D, log_E, weights_precision, mixed, False, 8 + ) + + # expect no checkpoint handle + checkpoint_handle = emb._ssd_db.get_active_checkpoint_uuid(0) + assert checkpoint_handle is None, f"{checkpoint_handle=}" + # create a checkpoint + emb.create_rocksdb_hard_link_snapshot() + checkpoint_handle = emb._ssd_db.get_active_checkpoint_uuid(0) + assert checkpoint_handle is not None, f"{checkpoint_handle=}" + # delete checkpoint_handle, handle should still exist because emb holds a reference + del checkpoint_handle + checkpoint_handle 
= emb._ssd_db.get_active_checkpoint_uuid(0) + assert checkpoint_handle is not None, f"{checkpoint_handle=}" + # delete checkpoint_handle, and release emb's reference, handle should be destroyed + del checkpoint_handle + emb.clear_rocksdb_hard_link_snapshot() + checkpoint_handle = emb._ssd_db.get_active_checkpoint_uuid(0) + assert checkpoint_handle is None, f"{checkpoint_handle=}" + @given( E=st.integers(min_value=1000, max_value=10000), num_buckets=st.integers(min_value=10, max_value=15),