diff --git a/fbgemm_gpu/src/dram_kv_embedding_cache/dram_kv_embedding_cache.h b/fbgemm_gpu/src/dram_kv_embedding_cache/dram_kv_embedding_cache.h index 80cefd3f74..b66793fd0d 100644 --- a/fbgemm_gpu/src/dram_kv_embedding_cache/dram_kv_embedding_cache.h +++ b/fbgemm_gpu/src/dram_kv_embedding_cache/dram_kv_embedding_cache.h @@ -384,6 +384,19 @@ class DramKVEmbeddingCache : public kv_db::EmbeddingKVDB { local_write_allocate_total_duration += facebook::WallClockUtil::NowInUsecFast() - before_alloc_ts; + if (feature_evict_config_.has_value() && + feature_evict_config_.value()->trigger_mode_ != + EvictTriggerMode::DISABLED && + feature_evict_) { + auto* feature_score_evict = dynamic_cast< + FeatureScoreBasedEvict*>( + feature_evict_.get()); + if (feature_score_evict) { + feature_score_evict + ->update_feature_score_statistics( + block, 0, shard_id, true); + } + } } if (feature_evict_config_.has_value() && feature_evict_config_.value()->trigger_mode_ != @@ -705,16 +718,21 @@ class DramKVEmbeddingCache : public kv_db::EmbeddingKVDB { auto it = wlmap->find(id); if (it != wlmap->end()) { block = it->second; + feature_score_evict + ->update_feature_score_statistics( + block, engege_rate, shard_id, false); } else { // Key doesn't exist, allocate new block and // insert. block = pool->template allocate_t(); FixedBlockPool::set_key(block, id); + FixedBlockPool::set_feature_score_rate( + block, engege_rate); wlmap->insert({id, block}); + feature_score_evict + ->update_feature_score_statistics( + block, 0, shard_id, true); } - - feature_score_evict->update_feature_score_statistics( - block, engege_rate); updated_id_count++; } } @@ -1489,61 +1507,75 @@ class DramKVEmbeddingCache : public kv_db::EmbeddingKVDB { const auto indexes = iter->second; auto f = folly::via(executor_.get()) - .thenValue( - [this, shard_id, indexes, &indices, &weights_with_metaheader]( - folly::Unit) { - FBGEMM_DISPATCH_INTEGRAL_TYPES( - indices.scalar_type(), - "dram_kv_set_with_metaheader", - [this, - shard_id, - indexes, - &indices, - &weights_with_metaheader] { - using index_t = scalar_t; - CHECK(indices.is_contiguous()); - CHECK(weights_with_metaheader.is_contiguous()); - CHECK_EQ( - indices.size(0), weights_with_metaheader.size(0)); - { - auto wlmap = kv_store_.by(shard_id).wlock(); - auto* pool = kv_store_.pool_by(shard_id); - int64_t stride = weights_with_metaheader.size(1); - auto indices_data_ptr = indices.data_ptr(); - auto weights_data_ptr = - weights_with_metaheader.data_ptr(); - for (auto index_iter = indexes.begin(); - index_iter != indexes.end(); - index_iter++) { - const auto& id_index = *index_iter; - auto id = int64_t(indices_data_ptr[id_index]); - // Defensive programming - // used is false shouldn't occur under normal - // circumstances - FixedBlockPool::set_used( - weights_data_ptr + id_index * stride, true); - - // use mempool - weight_type* block = nullptr; - // First check if the key already exists - auto it = wlmap->find(id); - if (it != wlmap->end()) { - block = it->second; - } else { - // Key doesn't exist, allocate new block and - // insert. - block = - pool->template allocate_t(); - wlmap->insert({id, block}); + .thenValue([this, + shard_id, + indexes, + &indices, + &weights_with_metaheader](folly::Unit) { + FBGEMM_DISPATCH_INTEGRAL_TYPES( + indices.scalar_type(), + "dram_kv_set_with_metaheader", + [this, + shard_id, + indexes, + &indices, + &weights_with_metaheader] { + using index_t = scalar_t; + CHECK(indices.is_contiguous()); + CHECK(weights_with_metaheader.is_contiguous()); + CHECK_EQ( + indices.size(0), weights_with_metaheader.size(0)); + { + auto wlmap = kv_store_.by(shard_id).wlock(); + auto* pool = kv_store_.pool_by(shard_id); + int64_t stride = weights_with_metaheader.size(1); + auto indices_data_ptr = indices.data_ptr(); + auto weights_data_ptr = + weights_with_metaheader.data_ptr(); + for (auto index_iter = indexes.begin(); + index_iter != indexes.end(); + index_iter++) { + const auto& id_index = *index_iter; + auto id = int64_t(indices_data_ptr[id_index]); + // Defensive programming + // used is false shouldn't occur under normal + // circumstances + FixedBlockPool::set_used( + weights_data_ptr + id_index * stride, true); + + // use mempool + weight_type* block = nullptr; + // First check if the key already exists + auto it = wlmap->find(id); + if (it != wlmap->end()) { + block = it->second; + } else { + // Key doesn't exist, allocate new block and + // insert. + block = pool->template allocate_t(); + wlmap->insert({id, block}); + if (feature_evict_config_.has_value() && + feature_evict_config_.value()->trigger_mode_ != + EvictTriggerMode::DISABLED && + feature_evict_) { + auto* feature_score_evict = dynamic_cast< + FeatureScoreBasedEvict*>( + feature_evict_.get()); + if (feature_score_evict) { + feature_score_evict + ->update_feature_score_statistics( + block, 0, shard_id, true); } - std::copy( - weights_data_ptr + id_index * stride, - weights_data_ptr + (id_index + 1) * stride, - block); } } - }); - }); + std::copy( + weights_data_ptr + id_index * stride, + weights_data_ptr + (id_index + 1) * stride, + block); + } + } + }); + }); futures.push_back(std::move(f)); } return folly::collect(futures); diff --git a/fbgemm_gpu/src/dram_kv_embedding_cache/dram_kv_embedding_cache_wrapper.h b/fbgemm_gpu/src/dram_kv_embedding_cache/dram_kv_embedding_cache_wrapper.h index 94c5939732..6744e1bd76 100644 --- a/fbgemm_gpu/src/dram_kv_embedding_cache/dram_kv_embedding_cache_wrapper.h +++ b/fbgemm_gpu/src/dram_kv_embedding_cache/dram_kv_embedding_cache_wrapper.h @@ -148,8 +148,7 @@ class DramKVEmbeddingCacheWrapper : public torch::jit::CustomClassHolder { at::Tensor processed_counts, at::Tensor eviction_threshold_with_dry_run, at::Tensor full_duration_ms, - at::Tensor exec_duration_ms, - at::Tensor dry_run_exec_duration_ms) { + at::Tensor exec_duration_ms) { auto metrics = impl_->get_feature_evict_metric(); if (metrics.has_value()) { evicted_counts.copy_( @@ -164,8 +163,6 @@ class DramKVEmbeddingCacheWrapper : public torch::jit::CustomClassHolder { metrics.value().full_duration_ms); // full duration (Long) exec_duration_ms.copy_( metrics.value().exec_duration_ms); // exec duration (Long) - dry_run_exec_duration_ms.copy_( - metrics.value().dry_run_exec_duration_ms); // dry run exec duration } } diff --git a/fbgemm_gpu/src/dram_kv_embedding_cache/feature_evict.h b/fbgemm_gpu/src/dram_kv_embedding_cache/feature_evict.h index feec68fd02..f2f75a46b9 100644 --- a/fbgemm_gpu/src/dram_kv_embedding_cache/feature_evict.h +++ b/fbgemm_gpu/src/dram_kv_embedding_cache/feature_evict.h @@ -268,7 +268,6 @@ struct FeatureEvictMetrics { eviction_threshold_with_dry_run.resize(table_num, 0.0); exec_duration_ms = 0; full_duration_ms = 0; - dry_run_exec_duration_ms = 0; } void reset() { @@ -280,7 +279,6 @@ struct FeatureEvictMetrics { 0.0); exec_duration_ms = 0; full_duration_ms = 0; - dry_run_exec_duration_ms = 0; start_time_ms = std::chrono::duration_cast( std::chrono::high_resolution_clock::now().time_since_epoch()) @@ -296,7 +294,6 @@ struct FeatureEvictMetrics { // The exec_duration of all shards will be accumulated during the // statistics So finally, the number of shards needs to be divided exec_duration_ms /= num_shards; - dry_run_exec_duration_ms /= num_shards; } std::vector evicted_counts; @@ -304,7 +301,6 @@ struct FeatureEvictMetrics { std::vector eviction_threshold_with_dry_run; int64_t exec_duration_ms; int64_t full_duration_ms; - int64_t dry_run_exec_duration_ms; int64_t start_time_ms; }; @@ -315,7 +311,6 @@ struct FeatureEvictMetricTensors { processed_counts(at::zeros({table_num}, at::kLong)), eviction_threshold_with_dry_run(at::zeros({table_num}, at::kFloat)), exec_duration_ms(at::scalar_tensor(0, at::kLong)), - dry_run_exec_duration_ms(at::scalar_tensor(0, at::kLong)), full_duration_ms(at::scalar_tensor(0, at::kLong)) {} // Constructor to initialize from existing tensors @@ -324,14 +319,12 @@ struct FeatureEvictMetricTensors { at::Tensor processed, at::Tensor eviction_threshold_with_dry_run, at::Tensor exec_duration, - at::Tensor dry_run_exec_duration_ms, at::Tensor full_duration) : evicted_counts(std::move(evicted)), processed_counts(std::move(processed)), eviction_threshold_with_dry_run( std::move(eviction_threshold_with_dry_run)), exec_duration_ms(std::move(exec_duration)), - dry_run_exec_duration_ms(std::move(dry_run_exec_duration_ms)), full_duration_ms(std::move(full_duration)) {} [[nodiscard]] FeatureEvictMetricTensors clone() const { @@ -340,7 +333,6 @@ struct FeatureEvictMetricTensors { processed_counts.clone(), eviction_threshold_with_dry_run.clone(), exec_duration_ms.clone(), - dry_run_exec_duration_ms.clone(), full_duration_ms.clone()}; } @@ -352,22 +344,16 @@ struct FeatureEvictMetricTensors { at::Tensor eviction_threshold_with_dry_run; // feature evict exec duration at::Tensor exec_duration_ms; - // feature evict dry run exec duration - at::Tensor dry_run_exec_duration_ms; // feature evict full duration(from trigger to finish) at::Tensor full_duration_ms; }; -enum class EvictState { Idle, Dry_Run_Ongoing, Dry_Run_Done, Evict_Ongoing }; +enum class EvictState { Idle, Evict_Ongoing }; inline std::string to_string(EvictState state) { switch (state) { case EvictState::Idle: return "Idle"; - case EvictState::Dry_Run_Ongoing: - return "Dry_Run_Ongoing"; - case EvictState::Dry_Run_Done: - return "Dry_Run_Done"; case EvictState::Evict_Ongoing: return "Evict_Ongoing"; default: @@ -385,7 +371,6 @@ class FeatureEvict { int64_t interval_for_sufficient_eviction_s, int64_t interval_for_feature_statistics_decay_s, bool is_training = true, - bool enable_dry_run = false, TestMode test_mode = TestMode::DISABLED) : kv_store_(kv_store), evict_state_(EvictState::Idle), @@ -401,7 +386,6 @@ class FeatureEvict { interval_for_feature_statistics_decay_s_( interval_for_feature_statistics_decay_s), is_training_(is_training), - enable_dry_run_(enable_dry_run), test_mode_(test_mode) { executor_ = std::make_unique(num_shards_); @@ -421,8 +405,6 @@ class FeatureEvict { evict_cv_.notify_all(); // wait until futures all finished - folly::collectAll(dry_run_futures_).wait(); - dry_run_futures_.clear(); folly::collectAll(futures_).wait(); futures_.clear(); }; @@ -450,57 +432,24 @@ class FeatureEvict { if (!reach_interval_to_trigger_new_round()) { return; } - if (enable_dry_run_) { - LOG(INFO) << "trigger new round of eviction with dry run enabled: " - << to_string(evict_state_); - auto evict_state = evict_state_.load(); - // if dry run or eviction task is ongoing, return directly - if (evict_state == EvictState::Dry_Run_Ongoing || - evict_state == EvictState::Evict_Ongoing) { - return; - } - // if no dry run running or finished and no evict running, start new - // round eviction, run dry run first - if (evict_state == EvictState::Idle) { - sanity_check_before_new_round(); - evict_state_.store(EvictState::Dry_Run_Ongoing); - prepare_evict(); - - // Decide should decay or not. - if (reach_interval_to_decay_feature_statistics()) { - should_decay_.store(true); - } else { - should_decay_.store(false); - } - - LOG(INFO) - << "Trigger dry run eviction to get the feature evict threshold with decay " - << should_decay_.load(); - - for (int shard_id = 0; shard_id < num_shards_; ++shard_id) { - submit_shard_task(shard_id, true); // Dry run is true - } - } else if (evict_state == EvictState::Dry_Run_Done) { - // if dry run is done, run eviction - evict_state_.store(EvictState::Evict_Ongoing); - LOG(INFO) << "Trigger new round eviction after dry run"; - for (int shard_id = 0; shard_id < num_shards_; ++shard_id) { - submit_shard_task(shard_id, false); - } - return; - } + if (evict_state_.load() == EvictState::Evict_Ongoing) { + return; + } + evict_state_.store(EvictState::Evict_Ongoing); + sanity_check_before_new_round(); + prepare_evict(); + // Decide should decay or not. + if (reach_interval_to_decay_feature_statistics()) { + should_decay_.store(true); } else { - if (evict_state_.load() == EvictState::Evict_Ongoing) { - return; - } - evict_state_.store(EvictState::Evict_Ongoing); - LOG(INFO) << "trigger new round of eviction"; - sanity_check_before_new_round(); - prepare_evict(); - for (int shard_id = 0; shard_id < num_shards_; ++shard_id) { - submit_shard_task(shard_id, false); - } + should_decay_.store(false); + } + + LOG(INFO) << "trigger new round of eviction with decay=" + << should_decay_.load(); + for (int shard_id = 0; shard_id < num_shards_; ++shard_id) { + submit_shard_task(shard_id); } } @@ -545,9 +494,6 @@ class FeatureEvict { void wait_until_eviction_done() { resume(); - folly::collectAll(dry_run_futures_).wait(); - dry_run_futures_.clear(); - folly::collectAll(futures_).wait(); futures_.clear(); } @@ -555,8 +501,6 @@ class FeatureEvict { virtual void update_feature_statistics(weight_type* block) = 0; void wait_completion() { - folly::collectAll(dry_run_futures_).wait(); - dry_run_futures_.clear(); folly::collectAll(futures_).wait(); futures_.clear(); } @@ -575,19 +519,12 @@ class FeatureEvict { << "found " << futures_.size() << " futures before triggering new round of evict, " << "this should be either 0 or num shards:" << num_shards_; - CHECK( - dry_run_futures_.size() == 0 || dry_run_futures_.size() == num_shards_) - << "found " << dry_run_futures_.size() - << " futures before triggering new round of evict, " - << "this should be either 0 or num shards:" << num_shards_; } void init_shard_status() { block_cursors_.resize(num_shards_); - dry_run_block_cursors_.resize(num_shards_); block_nums_snapshot_.resize(num_shards_); for (int i = 0; i < num_shards_; ++i) { block_cursors_[i] = 0; - dry_run_block_cursors_[i] = 0; block_nums_snapshot_[i] = 0; } } @@ -600,33 +537,21 @@ class FeatureEvict { block_nums_snapshot_[shard_id] = mempool->get_chunks().size() * mempool->get_blocks_per_chunk(); block_cursors_[shard_id] = 0; - dry_run_block_cursors_[shard_id] = 0; } metrics_.reset(); - dry_run_futures_.clear(); futures_.clear(); finished_evictions_.store(0); - finished_dry_run_.store(0); // make sure we don't start right away, wait until resume() is called evict_interrupt_.store(true); } // submitting eviction job to the executor - void submit_shard_task(int shard_id, bool dry_run) { - if (dry_run) { - dry_run_futures_.emplace_back( - folly::via(executor_.get()) - .thenValue([this, shard_id, dry_run](auto&&) { - process_shard(shard_id, true); - update_evict_finish_flags(shard_id, dry_run); - })); - } else { - futures_.emplace_back(folly::via(executor_.get()) - .thenValue([this, shard_id, dry_run](auto&&) { - process_shard(shard_id, dry_run); - update_evict_finish_flags(shard_id, dry_run); - })); - } + void submit_shard_task(int shard_id) { + futures_.emplace_back( + folly::via(executor_.get()).thenValue([this, shard_id](auto&&) { + process_shard(shard_id); + update_evict_finish_flags(shard_id); + })); } bool reach_interval_to_trigger_new_round() { @@ -662,23 +587,12 @@ class FeatureEvict { block_cursors_[shard_id] >= block_nums_snapshot_[shard_id]; } - // conditions where we need to break the dry run evict loop - bool should_exit_dry_run_loop(int shard_id) { - return evict_interrupt_.load() || - dry_run_block_cursors_[shard_id] >= block_nums_snapshot_[shard_id]; - } - // check whether there is any evict neither paused nor finished bool has_running_evict() { return (num_waiting_evicts_.load() + finished_evictions_.load()) != num_shards_; } - // check whether there is any dry run neither paused nor finished - bool has_running_dry_run() { - return finished_dry_run_.load() != sub_table_hash_cumsum_.size(); - } - // the inner loop of each evict that can be paused void start_training_eviction_loop( int shard_id, @@ -696,7 +610,7 @@ class FeatureEvict { int64_t key = FixedBlockPool::get_key(block); int sub_table_id = get_sub_table_id(key); processed_counts[sub_table_id]++; - if (evict_block(block, sub_table_id)) { + if (evict_block(block, sub_table_id, shard_id)) { auto it = wlock->find(key); if (it != wlock->end() && block == it->second) { auto time_elapsed = FixedBlockPool::current_timestamp() - @@ -734,7 +648,7 @@ class FeatureEvict { int64_t key = FixedBlockPool::get_key(block); int sub_table_id = get_sub_table_id(key); processed_counts[sub_table_id]++; - if (evict_block(block, sub_table_id)) { + if (evict_block(block, sub_table_id, shard_id)) { pool->template deallocate_t(block); evicted_counts[sub_table_id]++; evicting_keys.push_back(key); @@ -743,8 +657,8 @@ class FeatureEvict { mem_pool_lock.unlock(); // lock dram kv shard hash map to remove evicted blocks in the map - // dedicate map update in a wlock to reduce the blocking time for inference - // read + // dedicate map update in a wlock to reduce the blocking time for + // inference read auto shard_map_wlock = kv_store_.by(shard_id).wlock(); for (auto& key : evicting_keys) { shard_map_wlock->erase(key); @@ -761,10 +675,6 @@ class FeatureEvict { return true; } - if (!should_exit_dry_run_loop(shard_id)) { - return true; - } - num_waiting_evicts_++; evict_cv_.wait(lock, [this] { return !evict_interrupt_.load(); }); num_waiting_evicts_--; @@ -775,105 +685,78 @@ class FeatureEvict { return true; } - // the outer loop of each evict round that only exits when evction round is - // done - void process_shard(int shard_id, bool dry_run) { + // the outer loop of each evict round that only exits when evction round + // is done + void process_shard(int shard_id) { std::chrono::milliseconds duration{}; - if (dry_run) { + pre_calculate_thresholds(shard_id); + std::vector evicted_counts(sub_table_hash_cumsum_.size(), 0); + std::vector processed_counts(sub_table_hash_cumsum_.size(), 0); + // each active eviction round + while (block_cursors_[shard_id] < block_nums_snapshot_[shard_id]) { + if (!wait_until_resume(shard_id)) { + return; + } auto start_time = std::chrono::high_resolution_clock::now(); - dry_run_calculate_thresholds(shard_id); + if (is_training_) { + start_training_eviction_loop( + shard_id, evicted_counts, processed_counts); + } else { + start_inference_eviction_loop( + shard_id, evicted_counts, processed_counts); + } duration += std::chrono::duration_cast( std::chrono::high_resolution_clock::now() - start_time); - { - std::unique_lock lock(metric_mtx_); - metrics_.dry_run_exec_duration_ms += duration.count(); - } - } else { - std::vector evicted_counts(sub_table_hash_cumsum_.size(), 0); - std::vector processed_counts(sub_table_hash_cumsum_.size(), 0); - // each active eviction round - while (block_cursors_[shard_id] < block_nums_snapshot_[shard_id]) { - if (!wait_until_resume(shard_id)) { - return; - } - auto start_time = std::chrono::high_resolution_clock::now(); - if (is_training_) { - start_training_eviction_loop( - shard_id, evicted_counts, processed_counts); - } else { - start_inference_eviction_loop( - shard_id, evicted_counts, processed_counts); - } - duration += std::chrono::duration_cast( - std::chrono::high_resolution_clock::now() - start_time); - } - if (test_mode_ == TestMode::PAUSE_ON_LAST_ITERATION && - block_cursors_[shard_id] == block_nums_snapshot_[shard_id]) { - last_iter_shards_[shard_id]->store(true); - should_call_.store(true); - // hold on on the last iteration for a while waiting for the UT to call - // pause before updating the shards_finished - std::this_thread::sleep_for(std::chrono::milliseconds(50)); - } - { - std::unique_lock lock(metric_mtx_); - metrics_.exec_duration_ms += duration.count(); - for (size_t i = 0; i < evicted_counts.size(); ++i) { - metrics_.evicted_counts[i] += evicted_counts[i]; - metrics_.processed_counts[i] += processed_counts[i]; - } + } + if (test_mode_ == TestMode::PAUSE_ON_LAST_ITERATION && + block_cursors_[shard_id] == block_nums_snapshot_[shard_id]) { + last_iter_shards_[shard_id]->store(true); + should_call_.store(true); + // hold on on the last iteration for a while waiting for the UT to + // call pause before updating the shards_finished + std::this_thread::sleep_for(std::chrono::milliseconds(50)); + } + { + std::unique_lock lock(metric_mtx_); + metrics_.exec_duration_ms += duration.count(); + for (size_t i = 0; i < evicted_counts.size(); ++i) { + metrics_.evicted_counts[i] += evicted_counts[i]; + metrics_.processed_counts[i] += processed_counts[i]; } } } - virtual bool evict_block(weight_type* block, int sub_table_id) = 0; + virtual bool + evict_block(weight_type* block, int sub_table_id, int shard_id) = 0; - virtual void dry_run_calculate_thresholds(int shard_id) = 0; + virtual void pre_calculate_thresholds(int shard_id) = 0; // Check and reset the eviction state . - void update_evict_finish_flags(int shard_id, bool dry_run) { - if (dry_run) { - bool trigger_evict_after_dry_run = false; - { - std::unique_lock lock(mutex_); - if (!has_running_dry_run()) { - bool all_finished = - finished_dry_run_.load() == sub_table_hash_cumsum_.size(); - if (all_finished && - evict_state_.load() == EvictState::Dry_Run_Ongoing && - shard_id == 0) { - evict_state_.store(EvictState::Dry_Run_Done); - // dry run is finished, tigger real eviction process; - // Only change evict_state and tigger evict on shard 0 - trigger_evict_after_dry_run = true; - } + void update_evict_finish_flags(int shard_id) { + std::unique_lock lock(mutex_); + finished_evictions_++; + if (!has_running_evict()) { + bool all_finished = finished_evictions_.load() == num_shards_; + if (all_finished && evict_state_.load() == EvictState::Evict_Ongoing) { + record_metrics_to_report_tensor(); + int64_t num_evicts = 0; + for (long evicted_count : metrics_.evicted_counts) { + num_evicts += evicted_count; } - } - - // Trigger evict outside the lock - if (trigger_evict_after_dry_run) { - trigger_evict(); - } - } else { - std::unique_lock lock(mutex_); - finished_evictions_++; - if (!has_running_evict()) { - bool all_finished = finished_evictions_.load() == num_shards_; - if (all_finished && evict_state_.load() == EvictState::Evict_Ongoing) { - record_metrics_to_report_tensor(); - int64_t num_evicts = 0; - for (long evicted_count : metrics_.evicted_counts) { - num_evicts += evicted_count; - } - is_last_eviction_sufficient_.store(num_evicts > 100); - last_eviction_ts_ = + is_last_eviction_sufficient_.store(num_evicts > 100); + last_eviction_ts_ = + std::chrono::duration_cast( + std::chrono::high_resolution_clock::now().time_since_epoch()) + .count(); + // update evict_state_ in the last place, making sure the future + // finishes around the same time as evict_state_ reset + evict_state_.store(EvictState::Idle); + if (should_decay_) { + last_decay_ts_ = std::chrono::duration_cast( std::chrono::high_resolution_clock::now().time_since_epoch()) .count(); - // update evict_state_ in the last place, making sure the future - // finishes around the same time as evict_state_ reset - evict_state_.store(EvictState::Idle); } } } @@ -918,8 +801,6 @@ class FeatureEvict { at::scalar_tensor(metrics_.full_duration_ms, at::kLong); metric_tensors_.exec_duration_ms = at::scalar_tensor(metrics_.exec_duration_ms, at::kLong); - metric_tensors_.dry_run_exec_duration_ms = - at::scalar_tensor(metrics_.dry_run_exec_duration_ms, at::kLong); std::vector evict_rates(metrics_.evicted_counts.size()); for (size_t i = 0; i < metrics_.evicted_counts.size(); ++i) { evict_rates[i] = metrics_.processed_counts[i] > 0 @@ -931,15 +812,13 @@ class FeatureEvict { " - full Time taken: {}ms\n" " - exec Time taken: {}ms\n" " - exec / full: {:.2f}%\n" - " - dryrun Time taken: {}ms\n" " - Total blocks processed: [{}]\n" " - Blocks evicted: [{}]\n" " - Eviction rate: [{}]%\n" - " - Eviction threshold dry run: [{}]\n", + " - Eviction threshold: [{}]\n", metrics_.full_duration_ms, metrics_.exec_duration_ms, metrics_.exec_duration_ms * 100.0f / metrics_.full_duration_ms, - metrics_.dry_run_exec_duration_ms, fmt::join(metrics_.processed_counts, ", "), fmt::join(metrics_.evicted_counts, ", "), fmt::join(evict_rates, ", "), @@ -952,8 +831,6 @@ class FeatureEvict { SynchronizedShardedMap& kv_store_; // Index of processed blocks. std::vector block_cursors_; - // Index of processed dry run blocks. - std::vector dry_run_block_cursors_; // Snapshot of total blocks at eviction trigger. std::vector block_nums_snapshot_; // Indicates whether an eviction task is ongoing. @@ -967,11 +844,9 @@ class FeatureEvict { // number waiting/finished evicts, used for blocking pause std::atomic num_waiting_evicts_{0}; std::atomic finished_evictions_{0}; - std::atomic finished_dry_run_{0}; // Records of shard tasks. std::vector> futures_; - std::vector> dry_run_futures_; // Interface lock to ensure thread safety for public methods. std::mutex mutex_; // Number of concurrent tasks. @@ -1007,7 +882,6 @@ class FeatureEvict { const double pct_evicts_enough_threshold{0.01}; // 0.01% const bool is_training_; - const bool enable_dry_run_; // UT specific mode TestMode test_mode_; @@ -1031,7 +905,6 @@ class CounterBasedEvict : public FeatureEvict { int64_t interval_for_sufficient_eviction_s, int64_t interval_for_feature_statistics_decay_s, bool is_training, - bool enable_dry_run_, TestMode test_mode = TestMode::DISABLED) : FeatureEvict( kv_store, @@ -1040,7 +913,6 @@ class CounterBasedEvict : public FeatureEvict { interval_for_sufficient_eviction_s, interval_for_feature_statistics_decay_s, is_training, - enable_dry_run_, test_mode), decay_rates_(decay_rates), thresholds_(thresholds) { @@ -1054,7 +926,8 @@ class CounterBasedEvict : public FeatureEvict { } protected: - bool evict_block(weight_type* block, int sub_table_id) override { + bool evict_block(weight_type* block, int sub_table_id, int shard_id) + override { double decay_rate = decay_rates_[sub_table_id]; int64_t threshold = thresholds_[sub_table_id]; // Apply decay and check the threshold. @@ -1064,7 +937,7 @@ class CounterBasedEvict : public FeatureEvict { return current_count < threshold; } - void dry_run_calculate_thresholds(int shard_id) override {} + void pre_calculate_thresholds(int shard_id) override {} private: const std::vector& decay_rates_; // Decay rate for the block count. @@ -1086,7 +959,6 @@ class FeatureScoreBasedEvict : public FeatureEvict { int64_t interval_for_sufficient_eviction_s, int64_t interval_for_feature_statistics_decay_s, bool is_training, - bool enable_dry_run_, TestMode test_mode = TestMode::DISABLED) : FeatureEvict( kv_store, @@ -1095,7 +967,6 @@ class FeatureScoreBasedEvict : public FeatureEvict { interval_for_sufficient_eviction_s, interval_for_feature_statistics_decay_s, is_training, - enable_dry_run_, test_mode), decay_rates_(decay_rates), max_training_id_num_per_table_(max_training_id_num_per_table), @@ -1125,9 +996,29 @@ class FeatureScoreBasedEvict : public FeatureEvict { FixedBlockPool::update_timestamp(block); } - void update_feature_score_statistics(weight_type* block, double ratio) { - double old_ratio = FixedBlockPool::get_feature_score_rate(block); - FixedBlockPool::set_feature_score_rate(block, ratio + old_ratio); + void update_feature_score_statistics( + weight_type* block, + double ratio, + int shard_id, + bool add_new_block = false) { + int64_t key = FixedBlockPool::get_key(block); + int sub_table_id = this->get_sub_table_id(key); + if (add_new_block) { + double ratio = FixedBlockPool::get_feature_score_rate(block); + int64_t idx = get_bucket_id_from_ratio(ratio); + local_buckets_per_shard_per_table_[sub_table_id][shard_id][idx]++; + local_blocks_num_per_shard_per_table_[sub_table_id][shard_id]++; + } else { + double old_ratio = FixedBlockPool::get_feature_score_rate(block); + double new_ratio = old_ratio + ratio; + FixedBlockPool::set_feature_score_rate(block, new_ratio); + + int64_t new_idx = get_bucket_id_from_ratio(new_ratio); + int64_t old_idx = get_bucket_id_from_ratio(old_ratio); + + local_buckets_per_shard_per_table_[sub_table_id][shard_id][new_idx]++; + local_buckets_per_shard_per_table_[sub_table_id][shard_id][old_idx]--; + } } std::vector get_thresholds() { @@ -1139,118 +1030,75 @@ class FeatureScoreBasedEvict : public FeatureEvict { } protected: - bool evict_block(weight_type* block, int sub_table_id) override { + bool evict_block(weight_type* block, int sub_table_id, int shard_id) + override { double threshold = thresholds_[sub_table_id]; + + if (this->should_decay_) { + double ratio = FixedBlockPool::get_feature_score_rate(block); + int64_t old_idx = get_bucket_id_from_ratio(ratio); + double decay_rate = decay_rates_[sub_table_id]; + double new_ratio = ratio * decay_rate; + int64_t new_idx = get_bucket_id_from_ratio(new_ratio); + FixedBlockPool::set_feature_score_rate(block, new_ratio); + + local_buckets_per_shard_per_table_[sub_table_id][shard_id][old_idx]--; + local_buckets_per_shard_per_table_[sub_table_id][shard_id][new_idx]++; + } double overall_ratio = FixedBlockPool::get_feature_score_rate(block); const double EPSILON = 1e-9; + bool should_evict = false; switch (evict_modes_[sub_table_id]) { case EvictMode::NONE: - return false; + should_evict = false; + break; case EvictMode::ONLY_ZERO: - return std::abs(overall_ratio) < EPSILON; + should_evict = std::abs(overall_ratio) < EPSILON; + break; case EvictMode::THRESHOLD: - return overall_ratio < threshold; + should_evict = overall_ratio < threshold; + break; default: LOG(ERROR) << "Invalid evict mode"; - return false; - } - } - - void dry_run_calculate_thresholds(int shard_id) override { - auto collect_buckets_start_time = std::chrono::high_resolution_clock::now(); - while (this->dry_run_block_cursors_[shard_id] < - this->block_nums_snapshot_[shard_id]) { - if (!this->wait_until_resume(shard_id)) { - return; - } - // scan every shard to get bucket information - collect_buckets(shard_id); + should_evict = false; + break; } - auto lock_start_time = std::chrono::high_resolution_clock::now(); - finalize_barrier_.arrive_and_wait(); - auto lock_end_time = std::chrono::high_resolution_clock::now(); - auto lock_duration = std::chrono::duration_cast( - lock_end_time - lock_start_time) - .count(); - LOG(INFO) << "[Dry run debug]collect_buckets lock took " << lock_duration - << " ms"; - - auto collect_buckets_end_time = std::chrono::high_resolution_clock::now(); - - auto collect_buckets_duration = - std::chrono::duration_cast( - collect_buckets_end_time - collect_buckets_start_time) - .count(); + if (should_evict) { + int64_t idx = get_bucket_id_from_ratio(overall_ratio); - LOG(INFO) << "[Dry run debug]collect_buckets for loop took " - << collect_buckets_duration << " ms"; + local_buckets_per_shard_per_table_[sub_table_id][shard_id][idx]--; + local_blocks_num_per_shard_per_table_[sub_table_id][shard_id]--; + } + return should_evict; + } + void pre_calculate_thresholds(int shard_id) override { if (shard_id == 0) { - auto collect_buckets_wait_start_time = - std::chrono::high_resolution_clock::now(); compute_thresholds_from_buckets(); - finalize_dry_run(); - auto collect_buckets_wait_end_time = - std::chrono::high_resolution_clock::now(); - auto collect_buckets_wait_duration = - std::chrono::duration_cast( - collect_buckets_wait_end_time - collect_buckets_wait_start_time) - .count(); - LOG(INFO) - << "[Dry run debug]collect_buckets wait all shards and compute_thresholds_from_buckets and finalize_dry_run took " - << collect_buckets_wait_duration << " ms"; } + finalize_barrier_.arrive_and_wait(); } private: - void collect_buckets(int shard_id) { - auto* pool = this->kv_store_.pool_by(shard_id); - auto wlock = this->kv_store_.by(shard_id).wlock(); - while (!this->should_exit_dry_run_loop(shard_id)) { - auto* block = pool->template get_block( - this->dry_run_block_cursors_[shard_id]++); - if (block == nullptr || !FixedBlockPool::get_used(block)) - continue; - int64_t key = FixedBlockPool::get_key(block); - int sub_table_id = this->get_sub_table_id(key); - if (this->should_decay_) { - auto it = wlock->find(key); - if (it != wlock->end() && block == it->second) { - double ratio = FixedBlockPool::get_feature_score_rate(block); - double decay_rate = decay_rates_[sub_table_id]; - FixedBlockPool::set_feature_score_rate(block, ratio * decay_rate); - } - } - double ratio = FixedBlockPool::get_feature_score_rate(block); - int64_t idx = 0; - const double EPSILON = 1e-9; - if (ratio < 0) { - continue; - } else if (std::abs(ratio) < EPSILON) { - idx = 0; - } else if (ratio >= num_buckets_ * threshold_calculation_bucket_stride_) { + int64_t get_bucket_id_from_ratio(double ratio) { + int64_t idx = 0; + const double EPSILON = 1e-9; + if (std::abs(ratio) < EPSILON) { + idx = 0; + } else if (ratio >= num_buckets_ * threshold_calculation_bucket_stride_) { + idx = num_buckets_ - 1; + } else { + idx = static_cast(ratio / threshold_calculation_bucket_stride_) + + 1; + if (idx >= num_buckets_ || idx < 0) { idx = num_buckets_ - 1; - } else { - idx = - static_cast(ratio / threshold_calculation_bucket_stride_) + - 1; - } - - // Adding check to avoid out of bound access - if (idx < 0 || idx >= num_buckets_) { - LOG(ERROR) << "[Dry Run Debug]Invalid idx: " << idx - << " for key: " << key << " ratio: " << ratio; - continue; } - - local_buckets_per_shard_per_table_[sub_table_id][shard_id][idx]++; - local_blocks_num_per_shard_per_table_[sub_table_id][shard_id]++; } + return idx; } void compute_thresholds_from_buckets() { - auto start_time = std::chrono::high_resolution_clock::now(); for (size_t table_id = 0; table_id < num_tables_; ++table_id) { int64_t total = 0; @@ -1279,9 +1127,9 @@ class FeatureScoreBasedEvict : public FeatureEvict { thresholds_[table_id] = 0.0; evict_modes_[table_id] = EvictMode::NONE; } else if (bucket0_count >= evict_count) { - // Case 2: If bucket 0 alone contains sufficient blocks to satisfy the - // eviction demand, restrict eviction only to bucket 0 (blocks with - // score == 0). + // Case 2: If bucket 0 alone contains sufficient blocks to satisfy + // the eviction demand, restrict eviction only to bucket 0 (blocks + // with score == 0). thresholds_[table_id] = 0.0; evict_modes_[table_id] = EvictMode::ONLY_ZERO; } else { @@ -1311,43 +1159,15 @@ class FeatureScoreBasedEvict : public FeatureEvict { LOG(INFO) << "[Dry Run Result]table " << table_id << " threshold: " << thresholds_[table_id] << " threshold bucket: " << threshold_bucket - << " acc count: " << acc_count - << " evict count: " << evict_count << " total: " << total; + << " actual evict count: " << acc_count + << " target evict count: " << evict_count + << " total count: " << total; - { - std::unique_lock lock(this->mutex_); - this->finished_dry_run_++; + for (int table_id = 0; table_id < num_tables_; ++table_id) { + this->metrics_.eviction_threshold_with_dry_run[table_id] = + thresholds_[table_id]; } } - auto end_time = std::chrono::high_resolution_clock::now(); - auto duration = std::chrono::duration_cast( - end_time - start_time) - .count(); - LOG(INFO) << "[Dry run debug]compute_thresholds_from_buckets for loop took " - << duration << " ms"; - } - - void finalize_dry_run() { - local_buckets_per_shard_per_table_ = - std::vector>>( - num_tables_, - std::vector>( - this->num_shards_, std::vector(num_buckets_, 0))); - - local_blocks_num_per_shard_per_table_ = std::vector>( - num_tables_, std::vector(this->num_shards_, 0)); - - if (this->should_decay_) { - this->last_decay_ts_ = - std::chrono::duration_cast( - std::chrono::high_resolution_clock::now().time_since_epoch()) - .count(); - } - - for (int table_id = 0; table_id < num_tables_; ++table_id) { - this->metrics_.eviction_threshold_with_dry_run[table_id] = - thresholds_[table_id]; - } } private: @@ -1367,8 +1187,8 @@ class FeatureScoreBasedEvict : public FeatureEvict { const std::vector& max_training_id_num_per_table_; // training max id for each table. const std::vector& - target_eviction_percent_per_table_; // target eviction percent for each - // table + target_eviction_percent_per_table_; // target eviction percent for + // each table std::vector>> local_buckets_per_shard_per_table_; std::vector> local_blocks_num_per_shard_per_table_; @@ -1389,16 +1209,14 @@ class TimeBasedEvict : public FeatureEvict { int64_t interval_for_insufficient_eviction_s, int64_t interval_for_sufficient_eviction_s, int64_t interval_for_feature_statistics_decay_s, - bool is_training, - bool enable_dry_run_) + bool is_training) : FeatureEvict( kv_store, sub_table_hash_cumsum, interval_for_insufficient_eviction_s, interval_for_sufficient_eviction_s, interval_for_feature_statistics_decay_s, - is_training, - enable_dry_run_), + is_training), ttls_in_mins_(ttls_in_mins) {} void update_feature_statistics(weight_type* block) override { @@ -1406,7 +1224,8 @@ class TimeBasedEvict : public FeatureEvict { } protected: - bool evict_block(weight_type* block, int sub_table_id) override { + bool evict_block(weight_type* block, int sub_table_id, int shard_id) + override { int64_t ttl = ttls_in_mins_[sub_table_id]; if (ttl == 0) { // ttl = 0 means no eviction @@ -1416,7 +1235,7 @@ class TimeBasedEvict : public FeatureEvict { return current_time - FixedBlockPool::get_timestamp(block) > ttl * 60; } - void dry_run_calculate_thresholds(int shard_id) override {} + void pre_calculate_thresholds(int shard_id) override {} private: const std::vector& ttls_in_mins_; // Time-to-live for eviction. @@ -1431,16 +1250,14 @@ class TimeThresholdBasedEvict : public FeatureEvict { int64_t interval_for_insufficient_eviction_s, int64_t interval_for_sufficient_eviction_s, int64_t interval_for_feature_statistics_decay_s, - bool is_training, - bool enable_dry_run_) + bool is_training) : FeatureEvict( kv_store, sub_table_hash_cumsum, interval_for_insufficient_eviction_s, interval_for_sufficient_eviction_s, interval_for_feature_statistics_decay_s, - is_training, - enable_dry_run_) {} + is_training) {} void update_feature_statistics(weight_type* block) override { FixedBlockPool::update_timestamp(block); @@ -1451,11 +1268,12 @@ class TimeThresholdBasedEvict : public FeatureEvict { } protected: - bool evict_block(weight_type* block, int sub_table_id) override { + bool evict_block(weight_type* block, int sub_table_id, int shard_id) + override { return FixedBlockPool::get_timestamp(block) < eviction_timestamp_threshold_; } - void dry_run_calculate_thresholds(int shard_id) override {} + void pre_calculate_thresholds(int shard_id) override {} private: uint32_t eviction_timestamp_threshold_ = 0; @@ -1473,16 +1291,14 @@ class TimeCounterBasedEvict : public FeatureEvict { int64_t interval_for_insufficient_eviction_s, int64_t interval_for_sufficient_eviction_s, int64_t interval_for_feature_statistics_decay_s, - bool is_training, - bool enable_dry_run_) + bool is_training) : FeatureEvict( kv_store, sub_table_hash_cumsum, interval_for_insufficient_eviction_s, interval_for_sufficient_eviction_s, interval_for_feature_statistics_decay_s, - is_training, - enable_dry_run_), + is_training), ttls_in_mins_(ttls_in_mins), decay_rates_(decay_rates), thresholds_(thresholds) {} @@ -1493,7 +1309,8 @@ class TimeCounterBasedEvict : public FeatureEvict { } protected: - bool evict_block(weight_type* block, int sub_table_id) override { + bool evict_block(weight_type* block, int sub_table_id, int shard_id) + override { int64_t ttl = ttls_in_mins_[sub_table_id]; if (ttl == 0) { // ttl = 0 means no eviction @@ -1515,7 +1332,7 @@ class TimeCounterBasedEvict : public FeatureEvict { (current_count < threshold); } - void dry_run_calculate_thresholds(int shard_id) override {} + void pre_calculate_thresholds(int shard_id) override {} private: const std::vector& ttls_in_mins_; // Time-to-live for eviction. @@ -1534,16 +1351,14 @@ class L2WeightBasedEvict : public FeatureEvict { int64_t interval_for_insufficient_eviction_s, int64_t interval_for_sufficient_eviction_s, int64_t interval_for_feature_statistics_decay_s, - bool is_training, - bool enable_dry_run_) + bool is_training) : FeatureEvict( kv_store, sub_table_hash_cumsum, interval_for_insufficient_eviction_s, interval_for_sufficient_eviction_s, interval_for_feature_statistics_decay_s, - is_training, - enable_dry_run_), + is_training), thresholds_(thresholds), sub_table_dims_(sub_table_dims) {} @@ -1551,7 +1366,8 @@ class L2WeightBasedEvict : public FeatureEvict { } protected: - bool evict_block(weight_type* block, int sub_table_id) override { + bool evict_block(weight_type* block, int sub_table_id, int shard_id) + override { size_t dimension = sub_table_dims_[sub_table_id]; double threshold = thresholds_[sub_table_id]; if (threshold == 0.0) { @@ -1562,7 +1378,7 @@ class L2WeightBasedEvict : public FeatureEvict { return l2weight < threshold; } - void dry_run_calculate_thresholds(int shard_id) override {} + void pre_calculate_thresholds(int shard_id) override {} private: const std::vector& thresholds_; // L2 weight threshold for eviction. @@ -1585,8 +1401,7 @@ std::unique_ptr> create_feature_evict( config->interval_for_insufficient_eviction_s_, config->interval_for_sufficient_eviction_s_, config->interval_for_feature_statistics_decay_s_, - is_training, - false); + is_training); } case EvictTriggerStrategy::BY_COUNTER: { @@ -1605,7 +1420,6 @@ std::unique_ptr> create_feature_evict( config->interval_for_sufficient_eviction_s_, config->interval_for_feature_statistics_decay_s_, is_training, - false, test_mode); } @@ -1629,7 +1443,6 @@ std::unique_ptr> create_feature_evict( config->interval_for_sufficient_eviction_s_, config->interval_for_feature_statistics_decay_s_, is_training, - true, test_mode); } @@ -1649,8 +1462,7 @@ std::unique_ptr> create_feature_evict( config->interval_for_insufficient_eviction_s_, config->interval_for_sufficient_eviction_s_, config->interval_for_feature_statistics_decay_s_, - is_training, - false); + is_training); } case EvictTriggerStrategy::BY_L2WEIGHT: { @@ -1673,8 +1485,7 @@ std::unique_ptr> create_feature_evict( config->interval_for_insufficient_eviction_s_, config->interval_for_sufficient_eviction_s_, config->interval_for_feature_statistics_decay_s_, - is_training, - false); + is_training); } case EvictTriggerStrategy::BY_TIMESTAMP_THRESHOLD: { @@ -1684,8 +1495,7 @@ std::unique_ptr> create_feature_evict( config->interval_for_insufficient_eviction_s_, config->interval_for_sufficient_eviction_s_, config->interval_for_feature_statistics_decay_s_, - is_training, - false); + is_training); } default: diff --git a/fbgemm_gpu/test/dram_kv_embedding_cache/feature_evict_test.cpp b/fbgemm_gpu/test/dram_kv_embedding_cache/feature_evict_test.cpp index df868a9e5f..d1a90a1299 100644 --- a/fbgemm_gpu/test/dram_kv_embedding_cache/feature_evict_test.cpp +++ b/fbgemm_gpu/test/dram_kv_embedding_cache/feature_evict_test.cpp @@ -122,6 +122,7 @@ TEST(FeatureEvictTest, FeatureScoreBasedEvict) { auto* pool = kv_store_->pool_by(shard_id); auto* block = pool->allocate_t(); FixedBlockPool::set_key(block, i); + FixedBlockPool::set_used(block, true); FixedBlockPool::set_feature_score_rate(block, i < 400 ? 0.5 : 0.8); wlock->insert({i, block}); } @@ -132,6 +133,7 @@ TEST(FeatureEvictTest, FeatureScoreBasedEvict) { auto* pool = kv_store_->pool_by(shard_id); auto* block = pool->allocate_t(); FixedBlockPool::set_key(block, i); + FixedBlockPool::set_used(block, true); FixedBlockPool::set_feature_score_rate(block, i < 1500 ? 0.6 : 0.9); wlock->insert({i, block}); } @@ -154,7 +156,7 @@ TEST(FeatureEvictTest, FeatureScoreBasedEvict) { 10, // threshold_calculation_bucket_num 0, // interval_for_insufficient_eviction_s 0, // interval_for_sufficient_eviction_s - 0); // interval_for_feature_statistics_decay_s + 100000); // interval_for_feature_statistics_decay_s auto feature_evict = create_feature_evict( feature_evict_config, @@ -166,11 +168,30 @@ TEST(FeatureEvictTest, FeatureScoreBasedEvict) { auto* feature_score_evict = dynamic_cast*>(feature_evict.get()); + std::vector block_cursors_; + std::vector block_nums_snapshot_; + block_cursors_.resize(NUM_SHARDS); + block_nums_snapshot_.resize(NUM_SHARDS); + for (int i = 0; i < NUM_SHARDS; ++i) { + block_cursors_[i] = 0; + block_nums_snapshot_[i] = 0; + } + // Initial validation size_t total_blocks = 0; for (int shard_id = 0; shard_id < NUM_SHARDS; ++shard_id) { auto rlock = kv_store_->by(shard_id).rlock(); - total_blocks += rlock->size(); + auto* pool = kv_store_->pool_by(shard_id); + block_nums_snapshot_[shard_id] = + pool->get_chunks().size() * pool->get_blocks_per_chunk(); + while (block_cursors_[shard_id] < block_nums_snapshot_[shard_id]) { + auto* block = pool->template get_block(block_cursors_[shard_id]++); + if (block != nullptr && FixedBlockPool::get_used(block)) { + total_blocks++; + feature_score_evict->update_feature_score_statistics( + block, 0, shard_id, true); + } + } } ASSERT_EQ(total_blocks, 2000); // Perform eviction @@ -511,7 +532,6 @@ TEST(FeatureEvictTest, PerformanceTest) { 0, 0, true, // is training - false, // dry run TestMode::NORMAL); auto start_time = std::chrono::high_resolution_clock::now(); diff --git a/fbgemm_gpu/test/tbe/ssd/kv_backend_test.py b/fbgemm_gpu/test/tbe/ssd/kv_backend_test.py index f34c1a2f06..4474401813 100644 --- a/fbgemm_gpu/test/tbe/ssd/kv_backend_test.py +++ b/fbgemm_gpu/test/tbe/ssd/kv_backend_test.py @@ -850,7 +850,6 @@ def test_dram_kv_eviction(self) -> None: eviction_threshold_with_dry_run = torch.zeros(T, dtype=torch.float) full_duration_ms = torch.ones(1, dtype=torch.int64) * -1 exec_duration_ms = torch.empty(1, dtype=torch.int64) - dry_run_exec_duration_ms = torch.empty(1, dtype=torch.int64) shard_load = E / 4 # init @@ -861,7 +860,6 @@ def test_dram_kv_eviction(self) -> None: eviction_threshold_with_dry_run, full_duration_ms, exec_duration_ms, - dry_run_exec_duration_ms, ) for _ in range(10): dram_kv_backend.get(indices.clone(), weights_out, count) # pyre-ignore @@ -875,14 +873,12 @@ def test_dram_kv_eviction(self) -> None: eviction_threshold_with_dry_run, full_duration_ms, exec_duration_ms, - dry_run_exec_duration_ms, ) if all(processed_counts == shard_load): self.assertTrue(all(evicted_counts == 0)) self.assertTrue(all(processed_counts == shard_load)) self.assertTrue(full_duration_ms.item() > 0) self.assertTrue(exec_duration_ms.item() >= 0) - self.assertTrue(dry_run_exec_duration_ms.item() == 0) # after another 10 rounds, the original ids should all be evicted for _ in range(10): @@ -908,7 +904,6 @@ def test_dram_kv_eviction(self) -> None: eviction_threshold_with_dry_run, full_duration_ms, exec_duration_ms, - dry_run_exec_duration_ms, ) if evicted_counts.sum() > 1: # ID E+1 might be evicted break @@ -916,7 +911,6 @@ def test_dram_kv_eviction(self) -> None: self.assertTrue(all(processed_counts >= shard_load)) self.assertTrue(all(full_duration_ms > 0)) self.assertTrue(all(exec_duration_ms >= 0)) - self.assertTrue(all(dry_run_exec_duration_ms == 0)) def test_dram_kv_feature_score_eviction(self) -> None: max_D = 132 # 128 + 4 @@ -930,7 +924,7 @@ def test_dram_kv_feature_score_eviction(self) -> None: eviction_policy: EvictionPolicy = EvictionPolicy( eviction_trigger_mode=1, # eviction is disabled, 0: disabled, 1: iteration, 2: mem_util, 3: manual eviction_strategy=5, # evict_trigger_strategy: 0: timestamp, 1: counter , 2: counter + timestamp, 3: feature l2 norm, 5: feature score - eviction_step_intervals=2, # trigger_step_interval if trigger mode is iteration + eviction_step_intervals=1, # trigger_step_interval if trigger mode is iteration feature_score_counter_decay_rates=[ 0.9, 0.9, @@ -953,7 +947,7 @@ def test_dram_kv_feature_score_eviction(self) -> None: threshold_calculation_bucket_num=1000000, interval_for_insufficient_eviction_s=0, interval_for_sufficient_eviction_s=0, - interval_for_feature_statistics_decay_s=0, + interval_for_feature_statistics_decay_s=10000, ) dram_kv_backend = self.generate_fbgemm_kv_backend( max_D=max_D, @@ -975,7 +969,6 @@ def test_dram_kv_feature_score_eviction(self) -> None: processed_counts = torch.zeros(T, dtype=torch.int64) full_duration_ms = torch.ones(1, dtype=torch.int64) * -1 exec_duration_ms = torch.empty(1, dtype=torch.int64) - dry_run_exec_duration_ms = torch.empty(1, dtype=torch.int64) eviction_threshold_with_dry_run = torch.zeros(T, dtype=torch.float) shard_load = E / 4 @@ -986,30 +979,24 @@ def test_dram_kv_feature_score_eviction(self) -> None: dram_kv_backend.set_feature_score_metadata_cuda( # pyre-ignore indices, count, metadata_2d ) - time.sleep(5) # wait async set_feature_score_metadata_cuda done - for i in range(2): - print(f"round {i}") - dram_kv_backend.get(indices.clone(), weights_out, count) # pyre-ignore - dram_kv_backend.set(indices, weights, count) - dram_kv_backend.set_feature_score_metadata_cuda(indices, count, metadata_2d) - time.sleep(0.01) # 20ms, stimulate training forward time - dram_kv_backend.set_cuda(indices, weights, count, 1, True) # pyre-ignore - print("after set_cuda") - time.sleep(0.01) # 20ms, stimulate training backward time - dram_kv_backend.wait_until_eviction_done() # pyre-ignore - dram_kv_backend.get_feature_evict_metric( # pyre-ignore - evicted_counts, - processed_counts, - eviction_threshold_with_dry_run, - full_duration_ms, - exec_duration_ms, - dry_run_exec_duration_ms, - ) + dram_kv_backend.set_cuda(indices, weights, count, 1, True) # pyre-ignore + time.sleep(5) + # trigger evict + dram_kv_backend.get(indices.clone(), weights_out, count) # pyre-ignore + + dram_kv_backend.wait_until_eviction_done() # pyre-ignore + dram_kv_backend.get_feature_evict_metric( # pyre-ignore + evicted_counts, + processed_counts, + eviction_threshold_with_dry_run, + full_duration_ms, + exec_duration_ms, + ) + self.assertTrue(all(evicted_counts == 700)) self.assertTrue(all(processed_counts == shard_load)) self.assertTrue(full_duration_ms.item() > 0) self.assertTrue(exec_duration_ms.item() >= 0) - self.assertTrue(dry_run_exec_duration_ms.item() > 0) self.assertTrue(all(eviction_threshold_with_dry_run > 0)) @given(