From f372505e294c45223efae2a47ae176d23cea64be Mon Sep 17 00:00:00 2001 From: Eddy Li Date: Thu, 21 Aug 2025 10:55:03 -0700 Subject: [PATCH] Remove dry run for feature score eviction Summary: Previously we are using dry run to scan all weight blocks in backend and assign each feature score into eviction bucket and calculate the eviction threshold. This diff is removing dry run process and put the process assign eviction bucket into update and eviction block. This can save half time about total eviction duration. Reviewed By: emlin Differential Revision: D80425794 --- .../dram_kv_embedding_cache.h | 142 +++-- .../dram_kv_embedding_cache_wrapper.h | 5 +- .../dram_kv_embedding_cache/feature_evict.h | 568 ++++++------------ .../feature_evict_test.cpp | 26 +- fbgemm_gpu/test/tbe/ssd/kv_backend_test.py | 45 +- 5 files changed, 316 insertions(+), 470 deletions(-) diff --git a/fbgemm_gpu/src/dram_kv_embedding_cache/dram_kv_embedding_cache.h b/fbgemm_gpu/src/dram_kv_embedding_cache/dram_kv_embedding_cache.h index 80cefd3f74..b66793fd0d 100644 --- a/fbgemm_gpu/src/dram_kv_embedding_cache/dram_kv_embedding_cache.h +++ b/fbgemm_gpu/src/dram_kv_embedding_cache/dram_kv_embedding_cache.h @@ -384,6 +384,19 @@ class DramKVEmbeddingCache : public kv_db::EmbeddingKVDB { local_write_allocate_total_duration += facebook::WallClockUtil::NowInUsecFast() - before_alloc_ts; + if (feature_evict_config_.has_value() && + feature_evict_config_.value()->trigger_mode_ != + EvictTriggerMode::DISABLED && + feature_evict_) { + auto* feature_score_evict = dynamic_cast< + FeatureScoreBasedEvict*>( + feature_evict_.get()); + if (feature_score_evict) { + feature_score_evict + ->update_feature_score_statistics( + block, 0, shard_id, true); + } + } } if (feature_evict_config_.has_value() && feature_evict_config_.value()->trigger_mode_ != @@ -705,16 +718,21 @@ class DramKVEmbeddingCache : public kv_db::EmbeddingKVDB { auto it = wlmap->find(id); if (it != wlmap->end()) { block = it->second; + feature_score_evict + ->update_feature_score_statistics( + block, engege_rate, shard_id, false); } else { // Key doesn't exist, allocate new block and // insert. block = pool->template allocate_t(); FixedBlockPool::set_key(block, id); + FixedBlockPool::set_feature_score_rate( + block, engege_rate); wlmap->insert({id, block}); + feature_score_evict + ->update_feature_score_statistics( + block, 0, shard_id, true); } - - feature_score_evict->update_feature_score_statistics( - block, engege_rate); updated_id_count++; } } @@ -1489,61 +1507,75 @@ class DramKVEmbeddingCache : public kv_db::EmbeddingKVDB { const auto indexes = iter->second; auto f = folly::via(executor_.get()) - .thenValue( - [this, shard_id, indexes, &indices, &weights_with_metaheader]( - folly::Unit) { - FBGEMM_DISPATCH_INTEGRAL_TYPES( - indices.scalar_type(), - "dram_kv_set_with_metaheader", - [this, - shard_id, - indexes, - &indices, - &weights_with_metaheader] { - using index_t = scalar_t; - CHECK(indices.is_contiguous()); - CHECK(weights_with_metaheader.is_contiguous()); - CHECK_EQ( - indices.size(0), weights_with_metaheader.size(0)); - { - auto wlmap = kv_store_.by(shard_id).wlock(); - auto* pool = kv_store_.pool_by(shard_id); - int64_t stride = weights_with_metaheader.size(1); - auto indices_data_ptr = indices.data_ptr(); - auto weights_data_ptr = - weights_with_metaheader.data_ptr(); - for (auto index_iter = indexes.begin(); - index_iter != indexes.end(); - index_iter++) { - const auto& id_index = *index_iter; - auto id = int64_t(indices_data_ptr[id_index]); - // Defensive programming - // used is false shouldn't occur under normal - // circumstances - FixedBlockPool::set_used( - weights_data_ptr + id_index * stride, true); - - // use mempool - weight_type* block = nullptr; - // First check if the key already exists - auto it = wlmap->find(id); - if (it != wlmap->end()) { - block = it->second; - } else { - // Key doesn't exist, allocate new block and - // insert. - block = - pool->template allocate_t(); - wlmap->insert({id, block}); + .thenValue([this, + shard_id, + indexes, + &indices, + &weights_with_metaheader](folly::Unit) { + FBGEMM_DISPATCH_INTEGRAL_TYPES( + indices.scalar_type(), + "dram_kv_set_with_metaheader", + [this, + shard_id, + indexes, + &indices, + &weights_with_metaheader] { + using index_t = scalar_t; + CHECK(indices.is_contiguous()); + CHECK(weights_with_metaheader.is_contiguous()); + CHECK_EQ( + indices.size(0), weights_with_metaheader.size(0)); + { + auto wlmap = kv_store_.by(shard_id).wlock(); + auto* pool = kv_store_.pool_by(shard_id); + int64_t stride = weights_with_metaheader.size(1); + auto indices_data_ptr = indices.data_ptr(); + auto weights_data_ptr = + weights_with_metaheader.data_ptr(); + for (auto index_iter = indexes.begin(); + index_iter != indexes.end(); + index_iter++) { + const auto& id_index = *index_iter; + auto id = int64_t(indices_data_ptr[id_index]); + // Defensive programming + // used is false shouldn't occur under normal + // circumstances + FixedBlockPool::set_used( + weights_data_ptr + id_index * stride, true); + + // use mempool + weight_type* block = nullptr; + // First check if the key already exists + auto it = wlmap->find(id); + if (it != wlmap->end()) { + block = it->second; + } else { + // Key doesn't exist, allocate new block and + // insert. + block = pool->template allocate_t(); + wlmap->insert({id, block}); + if (feature_evict_config_.has_value() && + feature_evict_config_.value()->trigger_mode_ != + EvictTriggerMode::DISABLED && + feature_evict_) { + auto* feature_score_evict = dynamic_cast< + FeatureScoreBasedEvict*>( + feature_evict_.get()); + if (feature_score_evict) { + feature_score_evict + ->update_feature_score_statistics( + block, 0, shard_id, true); } - std::copy( - weights_data_ptr + id_index * stride, - weights_data_ptr + (id_index + 1) * stride, - block); } } - }); - }); + std::copy( + weights_data_ptr + id_index * stride, + weights_data_ptr + (id_index + 1) * stride, + block); + } + } + }); + }); futures.push_back(std::move(f)); } return folly::collect(futures); diff --git a/fbgemm_gpu/src/dram_kv_embedding_cache/dram_kv_embedding_cache_wrapper.h b/fbgemm_gpu/src/dram_kv_embedding_cache/dram_kv_embedding_cache_wrapper.h index 94c5939732..6744e1bd76 100644 --- a/fbgemm_gpu/src/dram_kv_embedding_cache/dram_kv_embedding_cache_wrapper.h +++ b/fbgemm_gpu/src/dram_kv_embedding_cache/dram_kv_embedding_cache_wrapper.h @@ -148,8 +148,7 @@ class DramKVEmbeddingCacheWrapper : public torch::jit::CustomClassHolder { at::Tensor processed_counts, at::Tensor eviction_threshold_with_dry_run, at::Tensor full_duration_ms, - at::Tensor exec_duration_ms, - at::Tensor dry_run_exec_duration_ms) { + at::Tensor exec_duration_ms) { auto metrics = impl_->get_feature_evict_metric(); if (metrics.has_value()) { evicted_counts.copy_( @@ -164,8 +163,6 @@ class DramKVEmbeddingCacheWrapper : public torch::jit::CustomClassHolder { metrics.value().full_duration_ms); // full duration (Long) exec_duration_ms.copy_( metrics.value().exec_duration_ms); // exec duration (Long) - dry_run_exec_duration_ms.copy_( - metrics.value().dry_run_exec_duration_ms); // dry run exec duration } } diff --git a/fbgemm_gpu/src/dram_kv_embedding_cache/feature_evict.h b/fbgemm_gpu/src/dram_kv_embedding_cache/feature_evict.h index feec68fd02..f2f75a46b9 100644 --- a/fbgemm_gpu/src/dram_kv_embedding_cache/feature_evict.h +++ b/fbgemm_gpu/src/dram_kv_embedding_cache/feature_evict.h @@ -268,7 +268,6 @@ struct FeatureEvictMetrics { eviction_threshold_with_dry_run.resize(table_num, 0.0); exec_duration_ms = 0; full_duration_ms = 0; - dry_run_exec_duration_ms = 0; } void reset() { @@ -280,7 +279,6 @@ struct FeatureEvictMetrics { 0.0); exec_duration_ms = 0; full_duration_ms = 0; - dry_run_exec_duration_ms = 0; start_time_ms = std::chrono::duration_cast( std::chrono::high_resolution_clock::now().time_since_epoch()) @@ -296,7 +294,6 @@ struct FeatureEvictMetrics { // The exec_duration of all shards will be accumulated during the // statistics So finally, the number of shards needs to be divided exec_duration_ms /= num_shards; - dry_run_exec_duration_ms /= num_shards; } std::vector evicted_counts; @@ -304,7 +301,6 @@ struct FeatureEvictMetrics { std::vector eviction_threshold_with_dry_run; int64_t exec_duration_ms; int64_t full_duration_ms; - int64_t dry_run_exec_duration_ms; int64_t start_time_ms; }; @@ -315,7 +311,6 @@ struct FeatureEvictMetricTensors { processed_counts(at::zeros({table_num}, at::kLong)), eviction_threshold_with_dry_run(at::zeros({table_num}, at::kFloat)), exec_duration_ms(at::scalar_tensor(0, at::kLong)), - dry_run_exec_duration_ms(at::scalar_tensor(0, at::kLong)), full_duration_ms(at::scalar_tensor(0, at::kLong)) {} // Constructor to initialize from existing tensors @@ -324,14 +319,12 @@ struct FeatureEvictMetricTensors { at::Tensor processed, at::Tensor eviction_threshold_with_dry_run, at::Tensor exec_duration, - at::Tensor dry_run_exec_duration_ms, at::Tensor full_duration) : evicted_counts(std::move(evicted)), processed_counts(std::move(processed)), eviction_threshold_with_dry_run( std::move(eviction_threshold_with_dry_run)), exec_duration_ms(std::move(exec_duration)), - dry_run_exec_duration_ms(std::move(dry_run_exec_duration_ms)), full_duration_ms(std::move(full_duration)) {} [[nodiscard]] FeatureEvictMetricTensors clone() const { @@ -340,7 +333,6 @@ struct FeatureEvictMetricTensors { processed_counts.clone(), eviction_threshold_with_dry_run.clone(), exec_duration_ms.clone(), - dry_run_exec_duration_ms.clone(), full_duration_ms.clone()}; } @@ -352,22 +344,16 @@ struct FeatureEvictMetricTensors { at::Tensor eviction_threshold_with_dry_run; // feature evict exec duration at::Tensor exec_duration_ms; - // feature evict dry run exec duration - at::Tensor dry_run_exec_duration_ms; // feature evict full duration(from trigger to finish) at::Tensor full_duration_ms; }; -enum class EvictState { Idle, Dry_Run_Ongoing, Dry_Run_Done, Evict_Ongoing }; +enum class EvictState { Idle, Evict_Ongoing }; inline std::string to_string(EvictState state) { switch (state) { case EvictState::Idle: return "Idle"; - case EvictState::Dry_Run_Ongoing: - return "Dry_Run_Ongoing"; - case EvictState::Dry_Run_Done: - return "Dry_Run_Done"; case EvictState::Evict_Ongoing: return "Evict_Ongoing"; default: @@ -385,7 +371,6 @@ class FeatureEvict { int64_t interval_for_sufficient_eviction_s, int64_t interval_for_feature_statistics_decay_s, bool is_training = true, - bool enable_dry_run = false, TestMode test_mode = TestMode::DISABLED) : kv_store_(kv_store), evict_state_(EvictState::Idle), @@ -401,7 +386,6 @@ class FeatureEvict { interval_for_feature_statistics_decay_s_( interval_for_feature_statistics_decay_s), is_training_(is_training), - enable_dry_run_(enable_dry_run), test_mode_(test_mode) { executor_ = std::make_unique(num_shards_); @@ -421,8 +405,6 @@ class FeatureEvict { evict_cv_.notify_all(); // wait until futures all finished - folly::collectAll(dry_run_futures_).wait(); - dry_run_futures_.clear(); folly::collectAll(futures_).wait(); futures_.clear(); }; @@ -450,57 +432,24 @@ class FeatureEvict { if (!reach_interval_to_trigger_new_round()) { return; } - if (enable_dry_run_) { - LOG(INFO) << "trigger new round of eviction with dry run enabled: " - << to_string(evict_state_); - auto evict_state = evict_state_.load(); - // if dry run or eviction task is ongoing, return directly - if (evict_state == EvictState::Dry_Run_Ongoing || - evict_state == EvictState::Evict_Ongoing) { - return; - } - // if no dry run running or finished and no evict running, start new - // round eviction, run dry run first - if (evict_state == EvictState::Idle) { - sanity_check_before_new_round(); - evict_state_.store(EvictState::Dry_Run_Ongoing); - prepare_evict(); - - // Decide should decay or not. - if (reach_interval_to_decay_feature_statistics()) { - should_decay_.store(true); - } else { - should_decay_.store(false); - } - - LOG(INFO) - << "Trigger dry run eviction to get the feature evict threshold with decay " - << should_decay_.load(); - - for (int shard_id = 0; shard_id < num_shards_; ++shard_id) { - submit_shard_task(shard_id, true); // Dry run is true - } - } else if (evict_state == EvictState::Dry_Run_Done) { - // if dry run is done, run eviction - evict_state_.store(EvictState::Evict_Ongoing); - LOG(INFO) << "Trigger new round eviction after dry run"; - for (int shard_id = 0; shard_id < num_shards_; ++shard_id) { - submit_shard_task(shard_id, false); - } - return; - } + if (evict_state_.load() == EvictState::Evict_Ongoing) { + return; + } + evict_state_.store(EvictState::Evict_Ongoing); + sanity_check_before_new_round(); + prepare_evict(); + // Decide should decay or not. + if (reach_interval_to_decay_feature_statistics()) { + should_decay_.store(true); } else { - if (evict_state_.load() == EvictState::Evict_Ongoing) { - return; - } - evict_state_.store(EvictState::Evict_Ongoing); - LOG(INFO) << "trigger new round of eviction"; - sanity_check_before_new_round(); - prepare_evict(); - for (int shard_id = 0; shard_id < num_shards_; ++shard_id) { - submit_shard_task(shard_id, false); - } + should_decay_.store(false); + } + + LOG(INFO) << "trigger new round of eviction with decay=" + << should_decay_.load(); + for (int shard_id = 0; shard_id < num_shards_; ++shard_id) { + submit_shard_task(shard_id); } } @@ -545,9 +494,6 @@ class FeatureEvict { void wait_until_eviction_done() { resume(); - folly::collectAll(dry_run_futures_).wait(); - dry_run_futures_.clear(); - folly::collectAll(futures_).wait(); futures_.clear(); } @@ -555,8 +501,6 @@ class FeatureEvict { virtual void update_feature_statistics(weight_type* block) = 0; void wait_completion() { - folly::collectAll(dry_run_futures_).wait(); - dry_run_futures_.clear(); folly::collectAll(futures_).wait(); futures_.clear(); } @@ -575,19 +519,12 @@ class FeatureEvict { << "found " << futures_.size() << " futures before triggering new round of evict, " << "this should be either 0 or num shards:" << num_shards_; - CHECK( - dry_run_futures_.size() == 0 || dry_run_futures_.size() == num_shards_) - << "found " << dry_run_futures_.size() - << " futures before triggering new round of evict, " - << "this should be either 0 or num shards:" << num_shards_; } void init_shard_status() { block_cursors_.resize(num_shards_); - dry_run_block_cursors_.resize(num_shards_); block_nums_snapshot_.resize(num_shards_); for (int i = 0; i < num_shards_; ++i) { block_cursors_[i] = 0; - dry_run_block_cursors_[i] = 0; block_nums_snapshot_[i] = 0; } } @@ -600,33 +537,21 @@ class FeatureEvict { block_nums_snapshot_[shard_id] = mempool->get_chunks().size() * mempool->get_blocks_per_chunk(); block_cursors_[shard_id] = 0; - dry_run_block_cursors_[shard_id] = 0; } metrics_.reset(); - dry_run_futures_.clear(); futures_.clear(); finished_evictions_.store(0); - finished_dry_run_.store(0); // make sure we don't start right away, wait until resume() is called evict_interrupt_.store(true); } // submitting eviction job to the executor - void submit_shard_task(int shard_id, bool dry_run) { - if (dry_run) { - dry_run_futures_.emplace_back( - folly::via(executor_.get()) - .thenValue([this, shard_id, dry_run](auto&&) { - process_shard(shard_id, true); - update_evict_finish_flags(shard_id, dry_run); - })); - } else { - futures_.emplace_back(folly::via(executor_.get()) - .thenValue([this, shard_id, dry_run](auto&&) { - process_shard(shard_id, dry_run); - update_evict_finish_flags(shard_id, dry_run); - })); - } + void submit_shard_task(int shard_id) { + futures_.emplace_back( + folly::via(executor_.get()).thenValue([this, shard_id](auto&&) { + process_shard(shard_id); + update_evict_finish_flags(shard_id); + })); } bool reach_interval_to_trigger_new_round() { @@ -662,23 +587,12 @@ class FeatureEvict { block_cursors_[shard_id] >= block_nums_snapshot_[shard_id]; } - // conditions where we need to break the dry run evict loop - bool should_exit_dry_run_loop(int shard_id) { - return evict_interrupt_.load() || - dry_run_block_cursors_[shard_id] >= block_nums_snapshot_[shard_id]; - } - // check whether there is any evict neither paused nor finished bool has_running_evict() { return (num_waiting_evicts_.load() + finished_evictions_.load()) != num_shards_; } - // check whether there is any dry run neither paused nor finished - bool has_running_dry_run() { - return finished_dry_run_.load() != sub_table_hash_cumsum_.size(); - } - // the inner loop of each evict that can be paused void start_training_eviction_loop( int shard_id, @@ -696,7 +610,7 @@ class FeatureEvict { int64_t key = FixedBlockPool::get_key(block); int sub_table_id = get_sub_table_id(key); processed_counts[sub_table_id]++; - if (evict_block(block, sub_table_id)) { + if (evict_block(block, sub_table_id, shard_id)) { auto it = wlock->find(key); if (it != wlock->end() && block == it->second) { auto time_elapsed = FixedBlockPool::current_timestamp() - @@ -734,7 +648,7 @@ class FeatureEvict { int64_t key = FixedBlockPool::get_key(block); int sub_table_id = get_sub_table_id(key); processed_counts[sub_table_id]++; - if (evict_block(block, sub_table_id)) { + if (evict_block(block, sub_table_id, shard_id)) { pool->template deallocate_t(block); evicted_counts[sub_table_id]++; evicting_keys.push_back(key); @@ -743,8 +657,8 @@ class FeatureEvict { mem_pool_lock.unlock(); // lock dram kv shard hash map to remove evicted blocks in the map - // dedicate map update in a wlock to reduce the blocking time for inference - // read + // dedicate map update in a wlock to reduce the blocking time for + // inference read auto shard_map_wlock = kv_store_.by(shard_id).wlock(); for (auto& key : evicting_keys) { shard_map_wlock->erase(key); @@ -761,10 +675,6 @@ class FeatureEvict { return true; } - if (!should_exit_dry_run_loop(shard_id)) { - return true; - } - num_waiting_evicts_++; evict_cv_.wait(lock, [this] { return !evict_interrupt_.load(); }); num_waiting_evicts_--; @@ -775,105 +685,78 @@ class FeatureEvict { return true; } - // the outer loop of each evict round that only exits when evction round is - // done - void process_shard(int shard_id, bool dry_run) { + // the outer loop of each evict round that only exits when evction round + // is done + void process_shard(int shard_id) { std::chrono::milliseconds duration{}; - if (dry_run) { + pre_calculate_thresholds(shard_id); + std::vector evicted_counts(sub_table_hash_cumsum_.size(), 0); + std::vector processed_counts(sub_table_hash_cumsum_.size(), 0); + // each active eviction round + while (block_cursors_[shard_id] < block_nums_snapshot_[shard_id]) { + if (!wait_until_resume(shard_id)) { + return; + } auto start_time = std::chrono::high_resolution_clock::now(); - dry_run_calculate_thresholds(shard_id); + if (is_training_) { + start_training_eviction_loop( + shard_id, evicted_counts, processed_counts); + } else { + start_inference_eviction_loop( + shard_id, evicted_counts, processed_counts); + } duration += std::chrono::duration_cast( std::chrono::high_resolution_clock::now() - start_time); - { - std::unique_lock lock(metric_mtx_); - metrics_.dry_run_exec_duration_ms += duration.count(); - } - } else { - std::vector evicted_counts(sub_table_hash_cumsum_.size(), 0); - std::vector processed_counts(sub_table_hash_cumsum_.size(), 0); - // each active eviction round - while (block_cursors_[shard_id] < block_nums_snapshot_[shard_id]) { - if (!wait_until_resume(shard_id)) { - return; - } - auto start_time = std::chrono::high_resolution_clock::now(); - if (is_training_) { - start_training_eviction_loop( - shard_id, evicted_counts, processed_counts); - } else { - start_inference_eviction_loop( - shard_id, evicted_counts, processed_counts); - } - duration += std::chrono::duration_cast( - std::chrono::high_resolution_clock::now() - start_time); - } - if (test_mode_ == TestMode::PAUSE_ON_LAST_ITERATION && - block_cursors_[shard_id] == block_nums_snapshot_[shard_id]) { - last_iter_shards_[shard_id]->store(true); - should_call_.store(true); - // hold on on the last iteration for a while waiting for the UT to call - // pause before updating the shards_finished - std::this_thread::sleep_for(std::chrono::milliseconds(50)); - } - { - std::unique_lock lock(metric_mtx_); - metrics_.exec_duration_ms += duration.count(); - for (size_t i = 0; i < evicted_counts.size(); ++i) { - metrics_.evicted_counts[i] += evicted_counts[i]; - metrics_.processed_counts[i] += processed_counts[i]; - } + } + if (test_mode_ == TestMode::PAUSE_ON_LAST_ITERATION && + block_cursors_[shard_id] == block_nums_snapshot_[shard_id]) { + last_iter_shards_[shard_id]->store(true); + should_call_.store(true); + // hold on on the last iteration for a while waiting for the UT to + // call pause before updating the shards_finished + std::this_thread::sleep_for(std::chrono::milliseconds(50)); + } + { + std::unique_lock lock(metric_mtx_); + metrics_.exec_duration_ms += duration.count(); + for (size_t i = 0; i < evicted_counts.size(); ++i) { + metrics_.evicted_counts[i] += evicted_counts[i]; + metrics_.processed_counts[i] += processed_counts[i]; } } } - virtual bool evict_block(weight_type* block, int sub_table_id) = 0; + virtual bool + evict_block(weight_type* block, int sub_table_id, int shard_id) = 0; - virtual void dry_run_calculate_thresholds(int shard_id) = 0; + virtual void pre_calculate_thresholds(int shard_id) = 0; // Check and reset the eviction state . - void update_evict_finish_flags(int shard_id, bool dry_run) { - if (dry_run) { - bool trigger_evict_after_dry_run = false; - { - std::unique_lock lock(mutex_); - if (!has_running_dry_run()) { - bool all_finished = - finished_dry_run_.load() == sub_table_hash_cumsum_.size(); - if (all_finished && - evict_state_.load() == EvictState::Dry_Run_Ongoing && - shard_id == 0) { - evict_state_.store(EvictState::Dry_Run_Done); - // dry run is finished, tigger real eviction process; - // Only change evict_state and tigger evict on shard 0 - trigger_evict_after_dry_run = true; - } + void update_evict_finish_flags(int shard_id) { + std::unique_lock lock(mutex_); + finished_evictions_++; + if (!has_running_evict()) { + bool all_finished = finished_evictions_.load() == num_shards_; + if (all_finished && evict_state_.load() == EvictState::Evict_Ongoing) { + record_metrics_to_report_tensor(); + int64_t num_evicts = 0; + for (long evicted_count : metrics_.evicted_counts) { + num_evicts += evicted_count; } - } - - // Trigger evict outside the lock - if (trigger_evict_after_dry_run) { - trigger_evict(); - } - } else { - std::unique_lock lock(mutex_); - finished_evictions_++; - if (!has_running_evict()) { - bool all_finished = finished_evictions_.load() == num_shards_; - if (all_finished && evict_state_.load() == EvictState::Evict_Ongoing) { - record_metrics_to_report_tensor(); - int64_t num_evicts = 0; - for (long evicted_count : metrics_.evicted_counts) { - num_evicts += evicted_count; - } - is_last_eviction_sufficient_.store(num_evicts > 100); - last_eviction_ts_ = + is_last_eviction_sufficient_.store(num_evicts > 100); + last_eviction_ts_ = + std::chrono::duration_cast( + std::chrono::high_resolution_clock::now().time_since_epoch()) + .count(); + // update evict_state_ in the last place, making sure the future + // finishes around the same time as evict_state_ reset + evict_state_.store(EvictState::Idle); + if (should_decay_) { + last_decay_ts_ = std::chrono::duration_cast( std::chrono::high_resolution_clock::now().time_since_epoch()) .count(); - // update evict_state_ in the last place, making sure the future - // finishes around the same time as evict_state_ reset - evict_state_.store(EvictState::Idle); } } } @@ -918,8 +801,6 @@ class FeatureEvict { at::scalar_tensor(metrics_.full_duration_ms, at::kLong); metric_tensors_.exec_duration_ms = at::scalar_tensor(metrics_.exec_duration_ms, at::kLong); - metric_tensors_.dry_run_exec_duration_ms = - at::scalar_tensor(metrics_.dry_run_exec_duration_ms, at::kLong); std::vector evict_rates(metrics_.evicted_counts.size()); for (size_t i = 0; i < metrics_.evicted_counts.size(); ++i) { evict_rates[i] = metrics_.processed_counts[i] > 0 @@ -931,15 +812,13 @@ class FeatureEvict { " - full Time taken: {}ms\n" " - exec Time taken: {}ms\n" " - exec / full: {:.2f}%\n" - " - dryrun Time taken: {}ms\n" " - Total blocks processed: [{}]\n" " - Blocks evicted: [{}]\n" " - Eviction rate: [{}]%\n" - " - Eviction threshold dry run: [{}]\n", + " - Eviction threshold: [{}]\n", metrics_.full_duration_ms, metrics_.exec_duration_ms, metrics_.exec_duration_ms * 100.0f / metrics_.full_duration_ms, - metrics_.dry_run_exec_duration_ms, fmt::join(metrics_.processed_counts, ", "), fmt::join(metrics_.evicted_counts, ", "), fmt::join(evict_rates, ", "), @@ -952,8 +831,6 @@ class FeatureEvict { SynchronizedShardedMap& kv_store_; // Index of processed blocks. std::vector block_cursors_; - // Index of processed dry run blocks. - std::vector dry_run_block_cursors_; // Snapshot of total blocks at eviction trigger. std::vector block_nums_snapshot_; // Indicates whether an eviction task is ongoing. @@ -967,11 +844,9 @@ class FeatureEvict { // number waiting/finished evicts, used for blocking pause std::atomic num_waiting_evicts_{0}; std::atomic finished_evictions_{0}; - std::atomic finished_dry_run_{0}; // Records of shard tasks. std::vector> futures_; - std::vector> dry_run_futures_; // Interface lock to ensure thread safety for public methods. std::mutex mutex_; // Number of concurrent tasks. @@ -1007,7 +882,6 @@ class FeatureEvict { const double pct_evicts_enough_threshold{0.01}; // 0.01% const bool is_training_; - const bool enable_dry_run_; // UT specific mode TestMode test_mode_; @@ -1031,7 +905,6 @@ class CounterBasedEvict : public FeatureEvict { int64_t interval_for_sufficient_eviction_s, int64_t interval_for_feature_statistics_decay_s, bool is_training, - bool enable_dry_run_, TestMode test_mode = TestMode::DISABLED) : FeatureEvict( kv_store, @@ -1040,7 +913,6 @@ class CounterBasedEvict : public FeatureEvict { interval_for_sufficient_eviction_s, interval_for_feature_statistics_decay_s, is_training, - enable_dry_run_, test_mode), decay_rates_(decay_rates), thresholds_(thresholds) { @@ -1054,7 +926,8 @@ class CounterBasedEvict : public FeatureEvict { } protected: - bool evict_block(weight_type* block, int sub_table_id) override { + bool evict_block(weight_type* block, int sub_table_id, int shard_id) + override { double decay_rate = decay_rates_[sub_table_id]; int64_t threshold = thresholds_[sub_table_id]; // Apply decay and check the threshold. @@ -1064,7 +937,7 @@ class CounterBasedEvict : public FeatureEvict { return current_count < threshold; } - void dry_run_calculate_thresholds(int shard_id) override {} + void pre_calculate_thresholds(int shard_id) override {} private: const std::vector& decay_rates_; // Decay rate for the block count. @@ -1086,7 +959,6 @@ class FeatureScoreBasedEvict : public FeatureEvict { int64_t interval_for_sufficient_eviction_s, int64_t interval_for_feature_statistics_decay_s, bool is_training, - bool enable_dry_run_, TestMode test_mode = TestMode::DISABLED) : FeatureEvict( kv_store, @@ -1095,7 +967,6 @@ class FeatureScoreBasedEvict : public FeatureEvict { interval_for_sufficient_eviction_s, interval_for_feature_statistics_decay_s, is_training, - enable_dry_run_, test_mode), decay_rates_(decay_rates), max_training_id_num_per_table_(max_training_id_num_per_table), @@ -1125,9 +996,29 @@ class FeatureScoreBasedEvict : public FeatureEvict { FixedBlockPool::update_timestamp(block); } - void update_feature_score_statistics(weight_type* block, double ratio) { - double old_ratio = FixedBlockPool::get_feature_score_rate(block); - FixedBlockPool::set_feature_score_rate(block, ratio + old_ratio); + void update_feature_score_statistics( + weight_type* block, + double ratio, + int shard_id, + bool add_new_block = false) { + int64_t key = FixedBlockPool::get_key(block); + int sub_table_id = this->get_sub_table_id(key); + if (add_new_block) { + double ratio = FixedBlockPool::get_feature_score_rate(block); + int64_t idx = get_bucket_id_from_ratio(ratio); + local_buckets_per_shard_per_table_[sub_table_id][shard_id][idx]++; + local_blocks_num_per_shard_per_table_[sub_table_id][shard_id]++; + } else { + double old_ratio = FixedBlockPool::get_feature_score_rate(block); + double new_ratio = old_ratio + ratio; + FixedBlockPool::set_feature_score_rate(block, new_ratio); + + int64_t new_idx = get_bucket_id_from_ratio(new_ratio); + int64_t old_idx = get_bucket_id_from_ratio(old_ratio); + + local_buckets_per_shard_per_table_[sub_table_id][shard_id][new_idx]++; + local_buckets_per_shard_per_table_[sub_table_id][shard_id][old_idx]--; + } } std::vector get_thresholds() { @@ -1139,118 +1030,75 @@ class FeatureScoreBasedEvict : public FeatureEvict { } protected: - bool evict_block(weight_type* block, int sub_table_id) override { + bool evict_block(weight_type* block, int sub_table_id, int shard_id) + override { double threshold = thresholds_[sub_table_id]; + + if (this->should_decay_) { + double ratio = FixedBlockPool::get_feature_score_rate(block); + int64_t old_idx = get_bucket_id_from_ratio(ratio); + double decay_rate = decay_rates_[sub_table_id]; + double new_ratio = ratio * decay_rate; + int64_t new_idx = get_bucket_id_from_ratio(new_ratio); + FixedBlockPool::set_feature_score_rate(block, new_ratio); + + local_buckets_per_shard_per_table_[sub_table_id][shard_id][old_idx]--; + local_buckets_per_shard_per_table_[sub_table_id][shard_id][new_idx]++; + } double overall_ratio = FixedBlockPool::get_feature_score_rate(block); const double EPSILON = 1e-9; + bool should_evict = false; switch (evict_modes_[sub_table_id]) { case EvictMode::NONE: - return false; + should_evict = false; + break; case EvictMode::ONLY_ZERO: - return std::abs(overall_ratio) < EPSILON; + should_evict = std::abs(overall_ratio) < EPSILON; + break; case EvictMode::THRESHOLD: - return overall_ratio < threshold; + should_evict = overall_ratio < threshold; + break; default: LOG(ERROR) << "Invalid evict mode"; - return false; - } - } - - void dry_run_calculate_thresholds(int shard_id) override { - auto collect_buckets_start_time = std::chrono::high_resolution_clock::now(); - while (this->dry_run_block_cursors_[shard_id] < - this->block_nums_snapshot_[shard_id]) { - if (!this->wait_until_resume(shard_id)) { - return; - } - // scan every shard to get bucket information - collect_buckets(shard_id); + should_evict = false; + break; } - auto lock_start_time = std::chrono::high_resolution_clock::now(); - finalize_barrier_.arrive_and_wait(); - auto lock_end_time = std::chrono::high_resolution_clock::now(); - auto lock_duration = std::chrono::duration_cast( - lock_end_time - lock_start_time) - .count(); - LOG(INFO) << "[Dry run debug]collect_buckets lock took " << lock_duration - << " ms"; - - auto collect_buckets_end_time = std::chrono::high_resolution_clock::now(); - - auto collect_buckets_duration = - std::chrono::duration_cast( - collect_buckets_end_time - collect_buckets_start_time) - .count(); + if (should_evict) { + int64_t idx = get_bucket_id_from_ratio(overall_ratio); - LOG(INFO) << "[Dry run debug]collect_buckets for loop took " - << collect_buckets_duration << " ms"; + local_buckets_per_shard_per_table_[sub_table_id][shard_id][idx]--; + local_blocks_num_per_shard_per_table_[sub_table_id][shard_id]--; + } + return should_evict; + } + void pre_calculate_thresholds(int shard_id) override { if (shard_id == 0) { - auto collect_buckets_wait_start_time = - std::chrono::high_resolution_clock::now(); compute_thresholds_from_buckets(); - finalize_dry_run(); - auto collect_buckets_wait_end_time = - std::chrono::high_resolution_clock::now(); - auto collect_buckets_wait_duration = - std::chrono::duration_cast( - collect_buckets_wait_end_time - collect_buckets_wait_start_time) - .count(); - LOG(INFO) - << "[Dry run debug]collect_buckets wait all shards and compute_thresholds_from_buckets and finalize_dry_run took " - << collect_buckets_wait_duration << " ms"; } + finalize_barrier_.arrive_and_wait(); } private: - void collect_buckets(int shard_id) { - auto* pool = this->kv_store_.pool_by(shard_id); - auto wlock = this->kv_store_.by(shard_id).wlock(); - while (!this->should_exit_dry_run_loop(shard_id)) { - auto* block = pool->template get_block( - this->dry_run_block_cursors_[shard_id]++); - if (block == nullptr || !FixedBlockPool::get_used(block)) - continue; - int64_t key = FixedBlockPool::get_key(block); - int sub_table_id = this->get_sub_table_id(key); - if (this->should_decay_) { - auto it = wlock->find(key); - if (it != wlock->end() && block == it->second) { - double ratio = FixedBlockPool::get_feature_score_rate(block); - double decay_rate = decay_rates_[sub_table_id]; - FixedBlockPool::set_feature_score_rate(block, ratio * decay_rate); - } - } - double ratio = FixedBlockPool::get_feature_score_rate(block); - int64_t idx = 0; - const double EPSILON = 1e-9; - if (ratio < 0) { - continue; - } else if (std::abs(ratio) < EPSILON) { - idx = 0; - } else if (ratio >= num_buckets_ * threshold_calculation_bucket_stride_) { + int64_t get_bucket_id_from_ratio(double ratio) { + int64_t idx = 0; + const double EPSILON = 1e-9; + if (std::abs(ratio) < EPSILON) { + idx = 0; + } else if (ratio >= num_buckets_ * threshold_calculation_bucket_stride_) { + idx = num_buckets_ - 1; + } else { + idx = static_cast(ratio / threshold_calculation_bucket_stride_) + + 1; + if (idx >= num_buckets_ || idx < 0) { idx = num_buckets_ - 1; - } else { - idx = - static_cast(ratio / threshold_calculation_bucket_stride_) + - 1; - } - - // Adding check to avoid out of bound access - if (idx < 0 || idx >= num_buckets_) { - LOG(ERROR) << "[Dry Run Debug]Invalid idx: " << idx - << " for key: " << key << " ratio: " << ratio; - continue; } - - local_buckets_per_shard_per_table_[sub_table_id][shard_id][idx]++; - local_blocks_num_per_shard_per_table_[sub_table_id][shard_id]++; } + return idx; } void compute_thresholds_from_buckets() { - auto start_time = std::chrono::high_resolution_clock::now(); for (size_t table_id = 0; table_id < num_tables_; ++table_id) { int64_t total = 0; @@ -1279,9 +1127,9 @@ class FeatureScoreBasedEvict : public FeatureEvict { thresholds_[table_id] = 0.0; evict_modes_[table_id] = EvictMode::NONE; } else if (bucket0_count >= evict_count) { - // Case 2: If bucket 0 alone contains sufficient blocks to satisfy the - // eviction demand, restrict eviction only to bucket 0 (blocks with - // score == 0). + // Case 2: If bucket 0 alone contains sufficient blocks to satisfy + // the eviction demand, restrict eviction only to bucket 0 (blocks + // with score == 0). thresholds_[table_id] = 0.0; evict_modes_[table_id] = EvictMode::ONLY_ZERO; } else { @@ -1311,43 +1159,15 @@ class FeatureScoreBasedEvict : public FeatureEvict { LOG(INFO) << "[Dry Run Result]table " << table_id << " threshold: " << thresholds_[table_id] << " threshold bucket: " << threshold_bucket - << " acc count: " << acc_count - << " evict count: " << evict_count << " total: " << total; + << " actual evict count: " << acc_count + << " target evict count: " << evict_count + << " total count: " << total; - { - std::unique_lock lock(this->mutex_); - this->finished_dry_run_++; + for (int table_id = 0; table_id < num_tables_; ++table_id) { + this->metrics_.eviction_threshold_with_dry_run[table_id] = + thresholds_[table_id]; } } - auto end_time = std::chrono::high_resolution_clock::now(); - auto duration = std::chrono::duration_cast( - end_time - start_time) - .count(); - LOG(INFO) << "[Dry run debug]compute_thresholds_from_buckets for loop took " - << duration << " ms"; - } - - void finalize_dry_run() { - local_buckets_per_shard_per_table_ = - std::vector>>( - num_tables_, - std::vector>( - this->num_shards_, std::vector(num_buckets_, 0))); - - local_blocks_num_per_shard_per_table_ = std::vector>( - num_tables_, std::vector(this->num_shards_, 0)); - - if (this->should_decay_) { - this->last_decay_ts_ = - std::chrono::duration_cast( - std::chrono::high_resolution_clock::now().time_since_epoch()) - .count(); - } - - for (int table_id = 0; table_id < num_tables_; ++table_id) { - this->metrics_.eviction_threshold_with_dry_run[table_id] = - thresholds_[table_id]; - } } private: @@ -1367,8 +1187,8 @@ class FeatureScoreBasedEvict : public FeatureEvict { const std::vector& max_training_id_num_per_table_; // training max id for each table. const std::vector& - target_eviction_percent_per_table_; // target eviction percent for each - // table + target_eviction_percent_per_table_; // target eviction percent for + // each table std::vector>> local_buckets_per_shard_per_table_; std::vector> local_blocks_num_per_shard_per_table_; @@ -1389,16 +1209,14 @@ class TimeBasedEvict : public FeatureEvict { int64_t interval_for_insufficient_eviction_s, int64_t interval_for_sufficient_eviction_s, int64_t interval_for_feature_statistics_decay_s, - bool is_training, - bool enable_dry_run_) + bool is_training) : FeatureEvict( kv_store, sub_table_hash_cumsum, interval_for_insufficient_eviction_s, interval_for_sufficient_eviction_s, interval_for_feature_statistics_decay_s, - is_training, - enable_dry_run_), + is_training), ttls_in_mins_(ttls_in_mins) {} void update_feature_statistics(weight_type* block) override { @@ -1406,7 +1224,8 @@ class TimeBasedEvict : public FeatureEvict { } protected: - bool evict_block(weight_type* block, int sub_table_id) override { + bool evict_block(weight_type* block, int sub_table_id, int shard_id) + override { int64_t ttl = ttls_in_mins_[sub_table_id]; if (ttl == 0) { // ttl = 0 means no eviction @@ -1416,7 +1235,7 @@ class TimeBasedEvict : public FeatureEvict { return current_time - FixedBlockPool::get_timestamp(block) > ttl * 60; } - void dry_run_calculate_thresholds(int shard_id) override {} + void pre_calculate_thresholds(int shard_id) override {} private: const std::vector& ttls_in_mins_; // Time-to-live for eviction. @@ -1431,16 +1250,14 @@ class TimeThresholdBasedEvict : public FeatureEvict { int64_t interval_for_insufficient_eviction_s, int64_t interval_for_sufficient_eviction_s, int64_t interval_for_feature_statistics_decay_s, - bool is_training, - bool enable_dry_run_) + bool is_training) : FeatureEvict( kv_store, sub_table_hash_cumsum, interval_for_insufficient_eviction_s, interval_for_sufficient_eviction_s, interval_for_feature_statistics_decay_s, - is_training, - enable_dry_run_) {} + is_training) {} void update_feature_statistics(weight_type* block) override { FixedBlockPool::update_timestamp(block); @@ -1451,11 +1268,12 @@ class TimeThresholdBasedEvict : public FeatureEvict { } protected: - bool evict_block(weight_type* block, int sub_table_id) override { + bool evict_block(weight_type* block, int sub_table_id, int shard_id) + override { return FixedBlockPool::get_timestamp(block) < eviction_timestamp_threshold_; } - void dry_run_calculate_thresholds(int shard_id) override {} + void pre_calculate_thresholds(int shard_id) override {} private: uint32_t eviction_timestamp_threshold_ = 0; @@ -1473,16 +1291,14 @@ class TimeCounterBasedEvict : public FeatureEvict { int64_t interval_for_insufficient_eviction_s, int64_t interval_for_sufficient_eviction_s, int64_t interval_for_feature_statistics_decay_s, - bool is_training, - bool enable_dry_run_) + bool is_training) : FeatureEvict( kv_store, sub_table_hash_cumsum, interval_for_insufficient_eviction_s, interval_for_sufficient_eviction_s, interval_for_feature_statistics_decay_s, - is_training, - enable_dry_run_), + is_training), ttls_in_mins_(ttls_in_mins), decay_rates_(decay_rates), thresholds_(thresholds) {} @@ -1493,7 +1309,8 @@ class TimeCounterBasedEvict : public FeatureEvict { } protected: - bool evict_block(weight_type* block, int sub_table_id) override { + bool evict_block(weight_type* block, int sub_table_id, int shard_id) + override { int64_t ttl = ttls_in_mins_[sub_table_id]; if (ttl == 0) { // ttl = 0 means no eviction @@ -1515,7 +1332,7 @@ class TimeCounterBasedEvict : public FeatureEvict { (current_count < threshold); } - void dry_run_calculate_thresholds(int shard_id) override {} + void pre_calculate_thresholds(int shard_id) override {} private: const std::vector& ttls_in_mins_; // Time-to-live for eviction. @@ -1534,16 +1351,14 @@ class L2WeightBasedEvict : public FeatureEvict { int64_t interval_for_insufficient_eviction_s, int64_t interval_for_sufficient_eviction_s, int64_t interval_for_feature_statistics_decay_s, - bool is_training, - bool enable_dry_run_) + bool is_training) : FeatureEvict( kv_store, sub_table_hash_cumsum, interval_for_insufficient_eviction_s, interval_for_sufficient_eviction_s, interval_for_feature_statistics_decay_s, - is_training, - enable_dry_run_), + is_training), thresholds_(thresholds), sub_table_dims_(sub_table_dims) {} @@ -1551,7 +1366,8 @@ class L2WeightBasedEvict : public FeatureEvict { } protected: - bool evict_block(weight_type* block, int sub_table_id) override { + bool evict_block(weight_type* block, int sub_table_id, int shard_id) + override { size_t dimension = sub_table_dims_[sub_table_id]; double threshold = thresholds_[sub_table_id]; if (threshold == 0.0) { @@ -1562,7 +1378,7 @@ class L2WeightBasedEvict : public FeatureEvict { return l2weight < threshold; } - void dry_run_calculate_thresholds(int shard_id) override {} + void pre_calculate_thresholds(int shard_id) override {} private: const std::vector& thresholds_; // L2 weight threshold for eviction. @@ -1585,8 +1401,7 @@ std::unique_ptr> create_feature_evict( config->interval_for_insufficient_eviction_s_, config->interval_for_sufficient_eviction_s_, config->interval_for_feature_statistics_decay_s_, - is_training, - false); + is_training); } case EvictTriggerStrategy::BY_COUNTER: { @@ -1605,7 +1420,6 @@ std::unique_ptr> create_feature_evict( config->interval_for_sufficient_eviction_s_, config->interval_for_feature_statistics_decay_s_, is_training, - false, test_mode); } @@ -1629,7 +1443,6 @@ std::unique_ptr> create_feature_evict( config->interval_for_sufficient_eviction_s_, config->interval_for_feature_statistics_decay_s_, is_training, - true, test_mode); } @@ -1649,8 +1462,7 @@ std::unique_ptr> create_feature_evict( config->interval_for_insufficient_eviction_s_, config->interval_for_sufficient_eviction_s_, config->interval_for_feature_statistics_decay_s_, - is_training, - false); + is_training); } case EvictTriggerStrategy::BY_L2WEIGHT: { @@ -1673,8 +1485,7 @@ std::unique_ptr> create_feature_evict( config->interval_for_insufficient_eviction_s_, config->interval_for_sufficient_eviction_s_, config->interval_for_feature_statistics_decay_s_, - is_training, - false); + is_training); } case EvictTriggerStrategy::BY_TIMESTAMP_THRESHOLD: { @@ -1684,8 +1495,7 @@ std::unique_ptr> create_feature_evict( config->interval_for_insufficient_eviction_s_, config->interval_for_sufficient_eviction_s_, config->interval_for_feature_statistics_decay_s_, - is_training, - false); + is_training); } default: diff --git a/fbgemm_gpu/test/dram_kv_embedding_cache/feature_evict_test.cpp b/fbgemm_gpu/test/dram_kv_embedding_cache/feature_evict_test.cpp index df868a9e5f..d1a90a1299 100644 --- a/fbgemm_gpu/test/dram_kv_embedding_cache/feature_evict_test.cpp +++ b/fbgemm_gpu/test/dram_kv_embedding_cache/feature_evict_test.cpp @@ -122,6 +122,7 @@ TEST(FeatureEvictTest, FeatureScoreBasedEvict) { auto* pool = kv_store_->pool_by(shard_id); auto* block = pool->allocate_t(); FixedBlockPool::set_key(block, i); + FixedBlockPool::set_used(block, true); FixedBlockPool::set_feature_score_rate(block, i < 400 ? 0.5 : 0.8); wlock->insert({i, block}); } @@ -132,6 +133,7 @@ TEST(FeatureEvictTest, FeatureScoreBasedEvict) { auto* pool = kv_store_->pool_by(shard_id); auto* block = pool->allocate_t(); FixedBlockPool::set_key(block, i); + FixedBlockPool::set_used(block, true); FixedBlockPool::set_feature_score_rate(block, i < 1500 ? 0.6 : 0.9); wlock->insert({i, block}); } @@ -154,7 +156,7 @@ TEST(FeatureEvictTest, FeatureScoreBasedEvict) { 10, // threshold_calculation_bucket_num 0, // interval_for_insufficient_eviction_s 0, // interval_for_sufficient_eviction_s - 0); // interval_for_feature_statistics_decay_s + 100000); // interval_for_feature_statistics_decay_s auto feature_evict = create_feature_evict( feature_evict_config, @@ -166,11 +168,30 @@ TEST(FeatureEvictTest, FeatureScoreBasedEvict) { auto* feature_score_evict = dynamic_cast*>(feature_evict.get()); + std::vector block_cursors_; + std::vector block_nums_snapshot_; + block_cursors_.resize(NUM_SHARDS); + block_nums_snapshot_.resize(NUM_SHARDS); + for (int i = 0; i < NUM_SHARDS; ++i) { + block_cursors_[i] = 0; + block_nums_snapshot_[i] = 0; + } + // Initial validation size_t total_blocks = 0; for (int shard_id = 0; shard_id < NUM_SHARDS; ++shard_id) { auto rlock = kv_store_->by(shard_id).rlock(); - total_blocks += rlock->size(); + auto* pool = kv_store_->pool_by(shard_id); + block_nums_snapshot_[shard_id] = + pool->get_chunks().size() * pool->get_blocks_per_chunk(); + while (block_cursors_[shard_id] < block_nums_snapshot_[shard_id]) { + auto* block = pool->template get_block(block_cursors_[shard_id]++); + if (block != nullptr && FixedBlockPool::get_used(block)) { + total_blocks++; + feature_score_evict->update_feature_score_statistics( + block, 0, shard_id, true); + } + } } ASSERT_EQ(total_blocks, 2000); // Perform eviction @@ -511,7 +532,6 @@ TEST(FeatureEvictTest, PerformanceTest) { 0, 0, true, // is training - false, // dry run TestMode::NORMAL); auto start_time = std::chrono::high_resolution_clock::now(); diff --git a/fbgemm_gpu/test/tbe/ssd/kv_backend_test.py b/fbgemm_gpu/test/tbe/ssd/kv_backend_test.py index f34c1a2f06..4474401813 100644 --- a/fbgemm_gpu/test/tbe/ssd/kv_backend_test.py +++ b/fbgemm_gpu/test/tbe/ssd/kv_backend_test.py @@ -850,7 +850,6 @@ def test_dram_kv_eviction(self) -> None: eviction_threshold_with_dry_run = torch.zeros(T, dtype=torch.float) full_duration_ms = torch.ones(1, dtype=torch.int64) * -1 exec_duration_ms = torch.empty(1, dtype=torch.int64) - dry_run_exec_duration_ms = torch.empty(1, dtype=torch.int64) shard_load = E / 4 # init @@ -861,7 +860,6 @@ def test_dram_kv_eviction(self) -> None: eviction_threshold_with_dry_run, full_duration_ms, exec_duration_ms, - dry_run_exec_duration_ms, ) for _ in range(10): dram_kv_backend.get(indices.clone(), weights_out, count) # pyre-ignore @@ -875,14 +873,12 @@ def test_dram_kv_eviction(self) -> None: eviction_threshold_with_dry_run, full_duration_ms, exec_duration_ms, - dry_run_exec_duration_ms, ) if all(processed_counts == shard_load): self.assertTrue(all(evicted_counts == 0)) self.assertTrue(all(processed_counts == shard_load)) self.assertTrue(full_duration_ms.item() > 0) self.assertTrue(exec_duration_ms.item() >= 0) - self.assertTrue(dry_run_exec_duration_ms.item() == 0) # after another 10 rounds, the original ids should all be evicted for _ in range(10): @@ -908,7 +904,6 @@ def test_dram_kv_eviction(self) -> None: eviction_threshold_with_dry_run, full_duration_ms, exec_duration_ms, - dry_run_exec_duration_ms, ) if evicted_counts.sum() > 1: # ID E+1 might be evicted break @@ -916,7 +911,6 @@ def test_dram_kv_eviction(self) -> None: self.assertTrue(all(processed_counts >= shard_load)) self.assertTrue(all(full_duration_ms > 0)) self.assertTrue(all(exec_duration_ms >= 0)) - self.assertTrue(all(dry_run_exec_duration_ms == 0)) def test_dram_kv_feature_score_eviction(self) -> None: max_D = 132 # 128 + 4 @@ -930,7 +924,7 @@ def test_dram_kv_feature_score_eviction(self) -> None: eviction_policy: EvictionPolicy = EvictionPolicy( eviction_trigger_mode=1, # eviction is disabled, 0: disabled, 1: iteration, 2: mem_util, 3: manual eviction_strategy=5, # evict_trigger_strategy: 0: timestamp, 1: counter , 2: counter + timestamp, 3: feature l2 norm, 5: feature score - eviction_step_intervals=2, # trigger_step_interval if trigger mode is iteration + eviction_step_intervals=1, # trigger_step_interval if trigger mode is iteration feature_score_counter_decay_rates=[ 0.9, 0.9, @@ -953,7 +947,7 @@ def test_dram_kv_feature_score_eviction(self) -> None: threshold_calculation_bucket_num=1000000, interval_for_insufficient_eviction_s=0, interval_for_sufficient_eviction_s=0, - interval_for_feature_statistics_decay_s=0, + interval_for_feature_statistics_decay_s=10000, ) dram_kv_backend = self.generate_fbgemm_kv_backend( max_D=max_D, @@ -975,7 +969,6 @@ def test_dram_kv_feature_score_eviction(self) -> None: processed_counts = torch.zeros(T, dtype=torch.int64) full_duration_ms = torch.ones(1, dtype=torch.int64) * -1 exec_duration_ms = torch.empty(1, dtype=torch.int64) - dry_run_exec_duration_ms = torch.empty(1, dtype=torch.int64) eviction_threshold_with_dry_run = torch.zeros(T, dtype=torch.float) shard_load = E / 4 @@ -986,30 +979,24 @@ def test_dram_kv_feature_score_eviction(self) -> None: dram_kv_backend.set_feature_score_metadata_cuda( # pyre-ignore indices, count, metadata_2d ) - time.sleep(5) # wait async set_feature_score_metadata_cuda done - for i in range(2): - print(f"round {i}") - dram_kv_backend.get(indices.clone(), weights_out, count) # pyre-ignore - dram_kv_backend.set(indices, weights, count) - dram_kv_backend.set_feature_score_metadata_cuda(indices, count, metadata_2d) - time.sleep(0.01) # 20ms, stimulate training forward time - dram_kv_backend.set_cuda(indices, weights, count, 1, True) # pyre-ignore - print("after set_cuda") - time.sleep(0.01) # 20ms, stimulate training backward time - dram_kv_backend.wait_until_eviction_done() # pyre-ignore - dram_kv_backend.get_feature_evict_metric( # pyre-ignore - evicted_counts, - processed_counts, - eviction_threshold_with_dry_run, - full_duration_ms, - exec_duration_ms, - dry_run_exec_duration_ms, - ) + dram_kv_backend.set_cuda(indices, weights, count, 1, True) # pyre-ignore + time.sleep(5) + # trigger evict + dram_kv_backend.get(indices.clone(), weights_out, count) # pyre-ignore + + dram_kv_backend.wait_until_eviction_done() # pyre-ignore + dram_kv_backend.get_feature_evict_metric( # pyre-ignore + evicted_counts, + processed_counts, + eviction_threshold_with_dry_run, + full_duration_ms, + exec_duration_ms, + ) + self.assertTrue(all(evicted_counts == 700)) self.assertTrue(all(processed_counts == shard_load)) self.assertTrue(full_duration_ms.item() > 0) self.assertTrue(exec_duration_ms.item() >= 0) - self.assertTrue(dry_run_exec_duration_ms.item() > 0) self.assertTrue(all(eviction_threshold_with_dry_run > 0)) @given(