reuse checkpoint handle #4736

Open · wants to merge 1 commit into base: main
10 changes: 10 additions & 0 deletions fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py
@@ -20,6 +20,7 @@
 from math import floor, log2
 from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
 import torch # usort:skip
+from collections import defaultdict

 # @manual=//deeplearning/fbgemm/fbgemm_gpu/codegen:split_embedding_codegen_lookup_invokers
 import fbgemm_gpu.split_embedding_codegen_lookup_invokers as invokers
@@ -375,6 +376,9 @@ def __init__(

         self.step = 0
         self.last_flush_step = -1
+        self.state_dict_seq_num = 0
+        # pyre-ignore [4]: Attribute `_step_to_ssd_checkpoint_handle` of class `SSDTableBatchedEmbeddingBags` has no type specified.
+        self._step_to_ssd_checkpoint_handle = defaultdict(set)

         # Set prefetch pipeline
         self.prefetch_pipeline: bool = prefetch_pipeline
@@ -3479,11 +3483,17 @@ def create_rocksdb_hard_link_snapshot(self) -> None:
         """
         if self.backend_type == BackendType.SSD:
             self.ssd_db.create_rocksdb_hard_link_snapshot(self.step)
+            checkpoint_handle = self.ssd_db.get_active_checkpoint_uuid(self.step)
+            self._step_to_ssd_checkpoint_handle[self.step].add(checkpoint_handle)
         else:
             logging.warning(
                 "create_rocksdb_hard_link_snapshot is only supported for SSD backend"
             )

+    def clear_rocksdb_hard_link_snapshot(self) -> None:
+        if self.backend_type == BackendType.SSD:
+            self._step_to_ssd_checkpoint_handle.clear()
+
     def prepare_inputs(
         self,
         indices: Tensor,
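To make the intended lifecycle concrete, here is a minimal usage sketch in the spirit of the unit test further below (assuming an already-constructed SSDTableBatchedEmbeddingBags instance named emb; the surrounding checkpoint-writer logic is elided):

# Pin a checkpoint handle for the current step, keep it alive while state is
# serialized, then release every pinned handle.
emb.create_rocksdb_hard_link_snapshot()   # creates and pins a handle for emb.step
handle = emb._ssd_db.get_active_checkpoint_uuid(emb.step)
assert handle is not None
# ... serialize state here; the pinned handle keeps the RocksDB hard-link
# checkpoint alive even if `handle` itself goes out of scope ...
emb.clear_rocksdb_hard_link_snapshot()    # unpin; the checkpoint may now be destroyed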
@@ -218,15 +218,37 @@ class EmbeddingRocksDBWrapper : public torch::jit::CustomClassHolder {
     impl_->create_checkpoint(global_step);
   }

-  c10::intrusive_ptr<RocksdbCheckpointHandleWrapper> get_active_checkpoint_uuid(
-      int64_t global_step) {
+  std::optional<c10::intrusive_ptr<RocksdbCheckpointHandleWrapper>>
+  get_active_checkpoint_uuid(int64_t global_step) {
     auto uuid_opt = impl_->get_active_checkpoint_uuid(global_step);
-    if (uuid_opt.has_value()) {
-      return c10::make_intrusive<RocksdbCheckpointHandleWrapper>(
-          uuid_opt.value(), impl_);
-    } else {
-      return nullptr;
+    if (!uuid_opt.has_value()) {
+      return std::nullopt;
     }
+    std::lock_guard<std::mutex> _lk(g_mu);
+    for (auto it = g_cache.begin(); it != g_cache.end();) {
+      if (impl_->query_checkpoint_by_uuid(it->first)) {
+        ++it;
+      } else {
+        // handle already destroyed, remove its entry from cache
+        g_cache.erase(it++);
+      }
+    }
+    if (auto it = g_cache.find(uuid_opt.value()); it != g_cache.end()) {
+      if (auto sp = it->second.lock()) {
+        return sp;
+      } else {
+        // reference count is 0, handle is scheduled for destruction.
+        return std::nullopt;
+      }
+    }
+
+    auto obj = c10::make_intrusive<RocksdbCheckpointHandleWrapper>(
+        uuid_opt.value(), impl_);
+    g_cache.emplace(
+        uuid_opt.value(),
+        c10::weak_intrusive_ptr<RocksdbCheckpointHandleWrapper>(obj));
+
+    return obj;
   }

   void set_backend_return_whole_row(bool backend_return_whole_row) {
@@ -238,6 +260,12 @@

   // shared pointer since we use shared_from_this() in callbacks.
   std::shared_ptr<ssd::EmbeddingRocksDB> impl_;
+  // cache of RocksdbCheckpointHandleWrapper
+  std::unordered_map<
+      std::string,
+      c10::weak_intrusive_ptr<RocksdbCheckpointHandleWrapper>>
+      g_cache;
+  std::mutex g_mu;
 };

 } // namespace ssd
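The cache above is what lets repeated get_active_checkpoint_uuid calls hand back the same live handle instead of minting a new wrapper per call, while never extending a handle's lifetime. Below is a minimal Python sketch of the same get-or-create pattern, with weakref.WeakValueDictionary standing in for c10::weak_intrusive_ptr plus the explicit query_checkpoint_by_uuid sweep (CheckpointHandle, HandleCache, and get_or_create are hypothetical names; unlike the C++ version, which returns std::nullopt while a handle is mid-destruction, this sketch simply creates a fresh handle):

import threading
import weakref

class CheckpointHandle:
    # Hypothetical stand-in for RocksdbCheckpointHandleWrapper; in the real
    # code, destroying the handle releases the underlying RocksDB checkpoint.
    def __init__(self, uuid: str) -> None:
        self.uuid = uuid

class HandleCache:
    # Weak-value cache keyed by checkpoint uuid. Entries disappear on their
    # own once the last strong reference to a handle is dropped, so the cache
    # never keeps a checkpoint alive by itself.
    def __init__(self) -> None:
        self._mu = threading.Lock()
        self._cache: "weakref.WeakValueDictionary[str, CheckpointHandle]" = (
            weakref.WeakValueDictionary()
        )

    def get_or_create(self, uuid: str) -> CheckpointHandle:
        with self._mu:
            handle = self._cache.get(uuid)  # None if absent or already dead
            if handle is None:
                handle = CheckpointHandle(uuid)
                self._cache[uuid] = handle
            return handle

cache = HandleCache()
a = cache.get_or_create("ckpt-1")
b = cache.get_or_create("ckpt-1")
assert a is b                       # the live handle is shared, not duplicated
del a, b                            # last strong references gone; entry pruned
c = cache.get_or_create("ckpt-1")   # a fresh handle is created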
@@ -617,6 +617,10 @@ class EmbeddingRocksDB : public kv_db::EmbeddingKVDB {
     return checkpoints_.at(ckpt_uuid)->get_shard_checkpoints();
   }

+  bool query_checkpoint_by_uuid(const std::string& ckpt_uuid) const {
+    return checkpoints_.find(ckpt_uuid) != checkpoints_.end();
+  }
+
   std::string get_tbe_uuid() const {
     return tbe_uuid_;
   }
31 changes: 31 additions & 0 deletions fbgemm_gpu/test/tbe/ssd/kv_backend_test.py
@@ -394,6 +394,37 @@ def test_rocksdb_get_discrete_ids(

         assert torch.equal(ids_in_range_ordered, id_tensor_ordered)

+    @given(**default_st)
+    @settings(**default_settings)
+    def test_ssd_rocksdb_checkpoint_handle(
+        self,
+        T: int,
+        D: int,
+        log_E: int,
+        mixed: bool,
+        weights_precision: SparseType,
+    ) -> None:
+        emb, _, _ = self.generate_fbgemm_kv_tbe(
+            T, D, log_E, weights_precision, mixed, False, 8
+        )
+
+        # expect no checkpoint handle
+        checkpoint_handle = emb._ssd_db.get_active_checkpoint_uuid(0)
+        assert checkpoint_handle is None, f"{checkpoint_handle=}"
+        # create a checkpoint
+        emb.create_rocksdb_hard_link_snapshot()
+        checkpoint_handle = emb._ssd_db.get_active_checkpoint_uuid(0)
+        assert checkpoint_handle is not None, f"{checkpoint_handle=}"
+        # delete checkpoint_handle; handle should still exist because emb holds a reference
+        del checkpoint_handle
+        checkpoint_handle = emb._ssd_db.get_active_checkpoint_uuid(0)
+        assert checkpoint_handle is not None, f"{checkpoint_handle=}"
+        # delete checkpoint_handle and release emb's reference; handle should be destroyed
+        del checkpoint_handle
+        emb.clear_rocksdb_hard_link_snapshot()
+        checkpoint_handle = emb._ssd_db.get_active_checkpoint_uuid(0)
+        assert checkpoint_handle is None, f"{checkpoint_handle=}"
+
     @given(
         E=st.integers(min_value=1000, max_value=10000),
         num_buckets=st.integers(min_value=10, max_value=15),