Skip to content

Commit 15bb0b0

Browse files
EddyLXJ authored and facebook-github-bot committed
Feature score eviction backend and frontend support (#4681)
Summary:

X-link: pytorch/torchrec#3287

Pull Request resolved: #4681

X-link: facebookresearch/FBGEMM#1707

## Context

We need a new eviction policy for large embeddings that have a high ID growth rate. Feature score eviction is based on the engagement rate of each ID instead of only on time or a counter. This helps the model keep all relatively important IDs during eviction.

## Detail

* New Eviction Strategy: BY_FEATURE_SCORE
  Added a new eviction trigger strategy, BY_FEATURE_SCORE, to the eviction config and logic. This strategy uses feature scores derived from engagement rates to decide which IDs to evict.
* FeatureScoreBasedEvict Class
  Implements the feature-score-based eviction logic. Maintains buckets of feature scores per shard and table to compute eviction thresholds.
* Supports a dry-run mode to calculate thresholds before actual eviction. Eviction decisions are based on thresholds computed from feature score distributions. Supports decay of feature score statistics over time.
* Async Metadata Update API
  Added the set_kv_zch_eviction_metadata_async method to update feature score metadata asynchronously in the KV store. This method shards the input indices and engagement rates and updates the feature score statistics in parallel.
* Dry Run Eviction Mode
  Introduced a dry-run mode that simulates eviction rounds to compute thresholds without actually evicting. Dry-run results are used to finalize thresholds for real eviction rounds.

Reviewed By: emlin

Differential Revision: D78138679

fbshipit-source-id: 6196c3676abf94b690f1ac776ca8f5c739cae1ea
1 parent 635ffe7 commit 15bb0b0

16 files changed

+1319
-132
lines changed

fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_common.py

Lines changed: 52 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -65,7 +65,7 @@ class EvictionPolicy(NamedTuple):
6565
0 # disabled, 0: disabled, 1: iteration, 2: mem_util, 3: manual
6666
)
6767
eviction_strategy: int = (
68-
0 # 0: timestamp, 1: counter (feature score), 2: counter (feature score) + timestamp, 3: feature l2 norm
68+
0 # 0: timestamp, 1: counter , 2: counter + timestamp, 3: feature l2 norm 4: timestamp threshold 5: feature score
6969
)
7070
eviction_step_intervals: Optional[int] = (
7171
None # trigger_step_interval if trigger mode is iteration
@@ -74,17 +74,32 @@ class EvictionPolicy(NamedTuple):
7474
None # eviction trigger condition if trigger mode is mem_util
7575
)
7676
counter_thresholds: Optional[List[int]] = (
77-
None # count_thresholds for each table if eviction strategy is feature score
77+
None # count_thresholds for each table if eviction strategy is counter
7878
)
7979
ttls_in_mins: Optional[List[int]] = (
8080
None # ttls_in_mins for each table if eviction strategy is timestamp
8181
)
8282
counter_decay_rates: Optional[List[float]] = (
83-
None # count_decay_rates for each table if eviction strategy is feature score
83+
None # count_decay_rates for each table if eviction strategy is counter
84+
)
85+
feature_score_counter_decay_rates: Optional[List[float]] = (
86+
None # feature_score_counter_decay_rates for each table if eviction strategy is feature score
87+
)
88+
max_training_id_num_per_table: Optional[List[int]] = (
89+
None # max_training_id_num_per_table for each table
90+
)
91+
target_eviction_percent_per_table: Optional[List[float]] = (
92+
None # target_eviction_percent_per_table for each table
8493
)
8594
l2_weight_thresholds: Optional[List[float]] = (
8695
None # l2_weight_thresholds for each table if eviction strategy is feature l2 norm
8796
)
97+
threshold_calculation_bucket_stride: Optional[float] = (
98+
0.2 # threshold_calculation_bucket_stride if eviction strategy is feature score
99+
)
100+
threshold_calculation_bucket_num: Optional[int] = (
101+
1000000 # 1M, threshold_calculation_bucket_num if eviction strategy is feature score
102+
)
88103
interval_for_insufficient_eviction_s: int = (
89104
# wait at least # seconds before trigger next round of eviction, if last finished eviction is insufficient
90105
# insufficient means we didn't evict enough rows, so we want to wait longer time to
@@ -95,6 +110,9 @@ class EvictionPolicy(NamedTuple):
95110
# wait at least # seconds before trigger next round of eviction, if last finished eviction is sufficient
96111
60
97112
)
113+
interval_for_feature_statistics_decay_s: int = (
114+
24 * 3600 # 1 day, interval for feature statistics decay
115+
)
98116
meta_header_lens: Optional[List[int]] = None # metaheader length for each table
99117

100118
def validate(self) -> None:
@@ -105,8 +123,8 @@ def validate(self) -> None:
105123
if self.eviction_trigger_mode == 0:
106124
return
107125

108-
assert self.eviction_strategy in [0, 1, 2, 3], (
109-
"eviction_strategy must be 0, 1, 2, or 3, "
126+
assert self.eviction_strategy in [0, 1, 2, 3, 4, 5], (
127+
"eviction_strategy must be 0, 1, 2, 3, 4 or 5, "
110128
f"actual {self.eviction_strategy}"
111129
)
112130
if self.eviction_trigger_mode == 1:
@@ -161,6 +179,35 @@ def validate(self) -> None:
161179
"counter_thresholds and ttls_in_mins must have the same length, "
162180
f"actual {self.counter_thresholds} vs {self.ttls_in_mins}"
163181
)
182+
elif self.eviction_strategy == 5:
183+
assert self.feature_score_counter_decay_rates is not None, (
184+
"feature_score_counter_decay_rates must be set if eviction_strategy is 5, "
185+
f"actual {self.feature_score_counter_decay_rates}"
186+
)
187+
assert self.max_training_id_num_per_table is not None, (
188+
"max_training_id_num_per_table must be set if eviction_strategy is 5,"
189+
f"actual {self.max_training_id_num_per_table}"
190+
)
191+
assert self.target_eviction_percent_per_table is not None, (
192+
"target_eviction_percent_per_table must be set if eviction_strategy is 5,"
193+
f"actual {self.target_eviction_percent_per_table}"
194+
)
195+
assert self.threshold_calculation_bucket_stride is not None, (
196+
"threshold_calculation_bucket_stride must be set if eviction_strategy is 5,"
197+
f"actual {self.threshold_calculation_bucket_stride}"
198+
)
199+
assert self.threshold_calculation_bucket_num is not None, (
200+
"threshold_calculation_bucket_num must be set if eviction_strategy is 5,"
201+
f"actual {self.threshold_calculation_bucket_num}"
202+
)
203+
assert (
204+
len(self.target_eviction_percent_per_table)
205+
== len(self.feature_score_counter_decay_rates)
206+
== len(self.max_training_id_num_per_table)
207+
), (
208+
"feature_score_thresholds, max_training_id_num_per_table and target_eviction_percent_per_table must have the same length, "
209+
f"actual {self.target_eviction_percent_per_table} vs {self.feature_score_counter_decay_rates} vs {self.max_training_id_num_per_table}"
210+
)
164211

165212

166213
class KVZCHParams(NamedTuple):

fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py

Lines changed: 46 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -677,18 +677,25 @@ def __init__(
677677
if self.kv_zch_params.eviction_policy.eviction_mem_threshold_gb
678678
else self.l2_cache_size
679679
)
680+
# Please refer to https://fburl.com/gdoc/nuupjwqq for the following eviction parameters.
680681
eviction_config = torch.classes.fbgemm.FeatureEvictConfig(
681682
self.kv_zch_params.eviction_policy.eviction_trigger_mode, # eviction is disabled, 0: disabled, 1: iteration, 2: mem_util, 3: manual
682-
self.kv_zch_params.eviction_policy.eviction_strategy, # evict_trigger_strategy: 0: timestamp, 1: counter (feature score), 2: counter (feature score) + timestamp, 3: feature l2 norm
683+
self.kv_zch_params.eviction_policy.eviction_strategy, # evict_trigger_strategy: 0: timestamp, 1: counter, 2: counter + timestamp, 3: feature l2 norm, 4: timestamp threshold 5: feature score
683684
self.kv_zch_params.eviction_policy.eviction_step_intervals, # trigger_step_interval if trigger mode is iteration
684685
eviction_mem_threshold_gb, # mem_util_threshold_in_GB if trigger mode is mem_util
685686
self.kv_zch_params.eviction_policy.ttls_in_mins, # ttls_in_mins for each table if eviction strategy is timestamp
686-
self.kv_zch_params.eviction_policy.counter_thresholds, # counter_thresholds for each table if eviction strategy is feature score
687-
self.kv_zch_params.eviction_policy.counter_decay_rates, # counter_decay_rates for each table if eviction strategy is feature score
687+
self.kv_zch_params.eviction_policy.counter_thresholds, # counter_thresholds for each table if eviction strategy is counter
688+
self.kv_zch_params.eviction_policy.counter_decay_rates, # counter_decay_rates for each table if eviction strategy is counter
689+
self.kv_zch_params.eviction_policy.feature_score_counter_decay_rates, # feature_score_counter_decay_rates for each table if eviction strategy is feature score
690+
self.kv_zch_params.eviction_policy.max_training_id_num_per_table, # max_training_id_num for each table
691+
self.kv_zch_params.eviction_policy.target_eviction_percent_per_table, # target_eviction_percent for each table
688692
self.kv_zch_params.eviction_policy.l2_weight_thresholds, # l2_weight_thresholds for each table if eviction strategy is feature l2 norm
689693
table_dims.tolist() if table_dims is not None else None,
694+
self.kv_zch_params.eviction_policy.threshold_calculation_bucket_stride, # threshold_calculation_bucket_stride if eviction strategy is feature score
695+
self.kv_zch_params.eviction_policy.threshold_calculation_bucket_num, # threshold_calculation_bucket_num if eviction strategy is feature score
690696
self.kv_zch_params.eviction_policy.interval_for_insufficient_eviction_s,
691697
self.kv_zch_params.eviction_policy.interval_for_sufficient_eviction_s,
698+
self.kv_zch_params.eviction_policy.interval_for_feature_statistics_decay_s,
692699
)
693700
self._ssd_db = torch.classes.fbgemm.DramKVEmbeddingCacheWrapper(
694701
self.cache_row_dim,
@@ -1020,6 +1027,9 @@ def __init__(
10201027
self.stats_reporter.register_stats(
10211028
"eviction.feature_table.exec_duration_ms"
10221029
)
1030+
self.stats_reporter.register_stats(
1031+
"eviction.feature_table.dry_run_exec_duration_ms"
1032+
)
10231033
self.stats_reporter.register_stats(
10241034
"eviction.feature_table.exec_div_full_duration_rate"
10251035
)
@@ -1607,6 +1617,7 @@ def prefetch(
16071617
self,
16081618
indices: Tensor,
16091619
offsets: Tensor,
1620+
weights: Optional[Tensor] = None, # todo: need to update caller
16101621
forward_stream: Optional[torch.cuda.Stream] = None,
16111622
batch_size_per_feature_per_rank: Optional[List[List[int]]] = None,
16121623
) -> None:
@@ -1632,6 +1643,7 @@ def prefetch(
16321643
self._prefetch(
16331644
indices,
16341645
offsets,
1646+
weights,
16351647
vbe_metadata,
16361648
forward_stream,
16371649
)
@@ -1640,6 +1652,7 @@ def _prefetch( # noqa C901
16401652
self,
16411653
indices: Tensor,
16421654
offsets: Tensor,
1655+
weights: Optional[Tensor] = None,
16431656
vbe_metadata: Optional[invokers.lookup_args.VBEMetadata] = None,
16441657
forward_stream: Optional[torch.cuda.Stream] = None,
16451658
) -> None:
@@ -1667,6 +1680,12 @@ def _prefetch( # noqa C901
16671680

16681681
self.timestep += 1
16691682
self.timesteps_prefetched.append(self.timestep)
1683+
if self.backend_type == BackendType.DRAM and weights is not None:
1684+
# DRAM backend supports feature score eviction, if there is weights available
1685+
# in the prefetch call, we will set metadata for feature score eviction asynchronously
1686+
cloned_linear_cache_indices = linear_cache_indices.clone()
1687+
else:
1688+
cloned_linear_cache_indices = None
16701689

16711690
# Lookup and virtually insert indices into L1. After this operator,
16721691
# we know:
@@ -2024,6 +2043,16 @@ def _prefetch( # noqa C901
20242043
is_bwd=False,
20252044
)
20262045

2046+
if self.backend_type == BackendType.DRAM and weights is not None:
2047+
# Write feature score metadata to DRAM
2048+
self.record_function_via_dummy_profile(
2049+
"## ssd_write_feature_score_metadata ##",
2050+
self.ssd_db.set_feature_score_metadata_cuda,
2051+
cloned_linear_cache_indices.cpu(),
2052+
torch.tensor([weights.shape[0]], device="cpu", dtype=torch.long),
2053+
weights.cpu().view(torch.float32).view(-1, 2),
2054+
)
2055+
20272056
# Generate row addresses (pointing to either L1 or the current
20282057
# iteration's scratch pad)
20292058
with record_function("## ssd_generate_row_addrs ##"):
@@ -2166,6 +2195,7 @@ def forward(
21662195
self,
21672196
indices: Tensor,
21682197
offsets: Tensor,
2198+
weights: Optional[Tensor] = None,
21692199
per_sample_weights: Optional[Tensor] = None,
21702200
feature_requires_grad: Optional[Tensor] = None,
21712201
batch_size_per_feature_per_rank: Optional[List[List[int]]] = None,
@@ -2187,7 +2217,7 @@ def forward(
21872217
context=self.step,
21882218
stream=self.ssd_eviction_stream,
21892219
):
2190-
self._prefetch(indices, offsets, vbe_metadata)
2220+
self._prefetch(indices, offsets, weights, vbe_metadata)
21912221

21922222
assert len(self.ssd_prefetch_data) > 0
21932223

@@ -3792,8 +3822,13 @@ def _report_eviction_stats(self) -> None:
37923822
processed_counts = torch.zeros(T, dtype=torch.int64)
37933823
full_duration_ms = torch.tensor(0, dtype=torch.int64)
37943824
exec_duration_ms = torch.tensor(0, dtype=torch.int64)
3825+
dry_run_exec_duration_ms = torch.tensor(0, dtype=torch.int64)
37953826
self.ssd_db.get_feature_evict_metric(
3796-
evicted_counts, processed_counts, full_duration_ms, exec_duration_ms
3827+
evicted_counts,
3828+
processed_counts,
3829+
full_duration_ms,
3830+
exec_duration_ms,
3831+
dry_run_exec_duration_ms,
37973832
)
37983833

37993834
stats_reporter.report_data_amount(
@@ -3845,6 +3880,12 @@ def _report_eviction_stats(self) -> None:
38453880
duration_ms=exec_duration_ms.item(),
38463881
time_unit="ms",
38473882
)
3883+
stats_reporter.report_duration(
3884+
iteration_step=self.step,
3885+
event_name="eviction.feature_table.dry_run_exec_duration_ms",
3886+
duration_ms=dry_run_exec_duration_ms.item(),
3887+
time_unit="ms",
3888+
)
38483889
if full_duration_ms.item() != 0:
38493890
stats_reporter.report_data_amount(
38503891
iteration_step=self.step,

fbgemm_gpu/src/dram_kv_embedding_cache/dram_kv_embedding_cache.h

Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -633,6 +633,111 @@ class DramKVEmbeddingCache : public kv_db::EmbeddingKVDB {
633633
});
634634
}
635635

636+
/// Update feature scores metadata into kvstore.
637+
folly::SemiFuture<std::vector<folly::Unit>>
638+
set_kv_zch_eviction_metadata_async(
639+
at::Tensor indices,
640+
at::Tensor count,
641+
at::Tensor engege_rates) override {
642+
if (!feature_evict_ || !feature_evict_config_.has_value() ||
643+
feature_evict_config_.value()->trigger_mode_ ==
644+
EvictTriggerMode::DISABLED) {
645+
// featre eviction is disabled
646+
return folly::makeSemiFuture(std::vector<folly::Unit>());
647+
}
648+
649+
CHECK_EQ(engege_rates.scalar_type(), at::ScalarType::Float);
650+
auto* feature_score_evict =
651+
dynamic_cast<FeatureScoreBasedEvict<weight_type>*>(
652+
feature_evict_.get());
653+
654+
if (feature_score_evict == nullptr) {
655+
// Not a feature score based eviction
656+
return folly::makeSemiFuture(std::vector<folly::Unit>());
657+
}
658+
pause_ongoing_eviction();
659+
std::vector<folly::Future<int64_t>> futures;
660+
auto shardid_to_indexes = shard_input(indices, count);
661+
for (auto iter = shardid_to_indexes.begin();
662+
iter != shardid_to_indexes.end();
663+
iter++) {
664+
const auto shard_id = iter->first;
665+
const auto indexes = iter->second;
666+
auto f =
667+
folly::via(executor_.get())
668+
.thenValue([this,
669+
shard_id,
670+
indexes,
671+
indices,
672+
engege_rates,
673+
feature_score_evict](folly::Unit) {
674+
int64_t updated_id_count = 0;
675+
FBGEMM_DISPATCH_INTEGRAL_TYPES(
676+
indices.scalar_type(),
677+
"dram_set_kv_feature_score_metadata",
678+
[this,
679+
shard_id,
680+
indexes,
681+
indices,
682+
engege_rates,
683+
feature_score_evict,
684+
&updated_id_count] {
685+
using index_t = scalar_t;
686+
CHECK(indices.is_contiguous());
687+
CHECK(engege_rates.is_contiguous());
688+
CHECK_EQ(indices.size(0), engege_rates.size(0));
689+
auto indices_data_ptr = indices.data_ptr<index_t>();
690+
auto engage_rate_ptr = engege_rates.data_ptr<float>();
691+
int64_t stride = 2;
692+
{
693+
auto wlmap = kv_store_.by(shard_id).wlock();
694+
auto* pool = kv_store_.pool_by(shard_id);
695+
696+
for (auto index_iter = indexes.begin();
697+
index_iter != indexes.end();
698+
index_iter++) {
699+
const auto& id_index = *index_iter;
700+
auto id = int64_t(indices_data_ptr[id_index]);
701+
float engege_rate =
702+
float(engage_rate_ptr[id_index * stride + 0]);
703+
// use mempool
704+
weight_type* block = nullptr;
705+
auto it = wlmap->find(id);
706+
if (it != wlmap->end()) {
707+
block = it->second;
708+
} else {
709+
// Key doesn't exist, allocate new block and
710+
// insert.
711+
block = pool->template allocate_t<weight_type>();
712+
FixedBlockPool::set_key(block, id);
713+
wlmap->insert({id, block});
714+
}
715+
716+
feature_score_evict->update_feature_score_statistics(
717+
block, engege_rate);
718+
updated_id_count++;
719+
}
720+
}
721+
});
722+
return updated_id_count;
723+
});
724+
futures.push_back(std::move(f));
725+
}
726+
return folly::collect(std::move(futures))
727+
.via(executor_.get())
728+
.thenValue([this](const std::vector<int64_t>& results) {
729+
resume_ongoing_eviction();
730+
int total_updated_ids = 0;
731+
for (const auto& result : results) {
732+
total_updated_ids += result;
733+
}
734+
LOG(INFO)
735+
<< "[DRAM KV][Feature Score Eviction]Total updated IDs across all shards: "
736+
<< total_updated_ids;
737+
return std::vector<folly::Unit>(results.size());
738+
});
739+
}
740+
636741
/// Get embeddings from kvstore.
637742
///
638743
/// @param indices The 1D embedding index tensor, should skip on negative

fbgemm_gpu/src/dram_kv_embedding_cache/dram_kv_embedding_cache_wrapper.h

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ class DramKVEmbeddingCacheWrapper : public torch::jit::CustomClassHolder {
7676
at::Tensor count,
7777
int64_t timestep,
7878
bool is_bwd) {
79-
return impl_->set_cuda(indices, weights, count, timestep);
79+
return impl_->set_cuda(indices, weights, count, timestep, is_bwd);
8080
}
8181

8282
void get_cuda(at::Tensor indices, at::Tensor weights, at::Tensor count) {
@@ -147,7 +147,8 @@ class DramKVEmbeddingCacheWrapper : public torch::jit::CustomClassHolder {
147147
at::Tensor evicted_counts,
148148
at::Tensor processed_counts,
149149
at::Tensor full_duration_ms,
150-
at::Tensor exec_duration_ms) {
150+
at::Tensor exec_duration_ms,
151+
at::Tensor dry_run_exec_duration_ms) {
151152
auto metrics = impl_->get_feature_evict_metric();
152153
if (metrics.has_value()) {
153154
evicted_counts.copy_(
@@ -158,6 +159,8 @@ class DramKVEmbeddingCacheWrapper : public torch::jit::CustomClassHolder {
158159
metrics.value().full_duration_ms); // full duration (Long)
159160
exec_duration_ms.copy_(
160161
metrics.value().exec_duration_ms); // exec duration (Long)
162+
dry_run_exec_duration_ms.copy_(
163+
metrics.value().dry_run_exec_duration_ms); // dry run exec duration
161164
}
162165
}
163166

@@ -169,6 +172,13 @@ class DramKVEmbeddingCacheWrapper : public torch::jit::CustomClassHolder {
169172
impl_->set_backend_return_whole_row(backend_return_whole_row);
170173
}
171174

175+
void set_feature_score_metadata_cuda(
176+
at::Tensor indices,
177+
at::Tensor count,
178+
at::Tensor engage_show_count) {
179+
impl_->set_feature_score_metadata_cuda(indices, count, engage_show_count);
180+
}
181+
172182
private:
173183
// friend class EmbeddingRocksDBWrapper;
174184
friend class ssd::KVTensorWrapper;

0 commit comments

Comments
 (0)