Skip to content

Commit ce6ff5b

Browse files
EddyLXJfacebook-github-bot
authored andcommitted
Remove dry run
Summary: X-link: facebookresearch/FBGEMM#1769 Rollback Plan: Differential Revision: D80425794
1 parent 87f413c commit ce6ff5b

File tree

5 files changed

+310
-468
lines changed

5 files changed

+310
-468
lines changed

fbgemm_gpu/src/dram_kv_embedding_cache/dram_kv_embedding_cache.h

Lines changed: 82 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -384,6 +384,18 @@ class DramKVEmbeddingCache : public kv_db::EmbeddingKVDB {
384384
local_write_allocate_total_duration +=
385385
facebook::WallClockUtil::NowInUsecFast() -
386386
before_alloc_ts;
387+
if (feature_evict_config_.has_value() &&
388+
feature_evict_config_.value()->trigger_mode_ !=
389+
EvictTriggerMode::DISABLED &&
390+
feature_evict_) {
391+
auto* feature_score_evict = dynamic_cast<
392+
FeatureScoreBasedEvict<weight_type>*>(
393+
feature_evict_.get());
394+
if (feature_score_evict) {
395+
feature_score_evict->init_eviction_bucket(
396+
block, shard_id);
397+
}
398+
}
387399
}
388400
if (feature_evict_config_.has_value() &&
389401
feature_evict_config_.value()->trigger_mode_ !=
@@ -711,10 +723,12 @@ class DramKVEmbeddingCache : public kv_db::EmbeddingKVDB {
711723
block = pool->template allocate_t<weight_type>();
712724
FixedBlockPool::set_key(block, id);
713725
wlmap->insert({id, block});
726+
feature_score_evict->init_eviction_bucket(
727+
block, shard_id);
714728
}
715729

716730
feature_score_evict->update_feature_score_statistics(
717-
block, engege_rate);
731+
block, engege_rate, shard_id);
718732
updated_id_count++;
719733
}
720734
}
@@ -1489,61 +1503,75 @@ class DramKVEmbeddingCache : public kv_db::EmbeddingKVDB {
14891503
const auto indexes = iter->second;
14901504
auto f =
14911505
folly::via(executor_.get())
1492-
.thenValue(
1493-
[this, shard_id, indexes, &indices, &weights_with_metaheader](
1494-
folly::Unit) {
1495-
FBGEMM_DISPATCH_INTEGRAL_TYPES(
1496-
indices.scalar_type(),
1497-
"dram_kv_set_with_metaheader",
1498-
[this,
1499-
shard_id,
1500-
indexes,
1501-
&indices,
1502-
&weights_with_metaheader] {
1503-
using index_t = scalar_t;
1504-
CHECK(indices.is_contiguous());
1505-
CHECK(weights_with_metaheader.is_contiguous());
1506-
CHECK_EQ(
1507-
indices.size(0), weights_with_metaheader.size(0));
1508-
{
1509-
auto wlmap = kv_store_.by(shard_id).wlock();
1510-
auto* pool = kv_store_.pool_by(shard_id);
1511-
int64_t stride = weights_with_metaheader.size(1);
1512-
auto indices_data_ptr = indices.data_ptr<index_t>();
1513-
auto weights_data_ptr =
1514-
weights_with_metaheader.data_ptr<weight_type>();
1515-
for (auto index_iter = indexes.begin();
1516-
index_iter != indexes.end();
1517-
index_iter++) {
1518-
const auto& id_index = *index_iter;
1519-
auto id = int64_t(indices_data_ptr[id_index]);
1520-
// Defensive programming
1521-
// used is false shouldn't occur under normal
1522-
// circumstances
1523-
FixedBlockPool::set_used(
1524-
weights_data_ptr + id_index * stride, true);
1525-
1526-
// use mempool
1527-
weight_type* block = nullptr;
1528-
// First check if the key already exists
1529-
auto it = wlmap->find(id);
1530-
if (it != wlmap->end()) {
1531-
block = it->second;
1532-
} else {
1533-
// Key doesn't exist, allocate new block and
1534-
// insert.
1535-
block =
1536-
pool->template allocate_t<weight_type>();
1537-
wlmap->insert({id, block});
1538-
}
1539-
std::copy(
1540-
weights_data_ptr + id_index * stride,
1541-
weights_data_ptr + (id_index + 1) * stride,
1542-
block);
1506+
.thenValue([this,
1507+
shard_id,
1508+
indexes,
1509+
&indices,
1510+
&weights_with_metaheader](folly::Unit) {
1511+
FBGEMM_DISPATCH_INTEGRAL_TYPES(
1512+
indices.scalar_type(),
1513+
"dram_kv_set_with_metaheader",
1514+
[this,
1515+
shard_id,
1516+
indexes,
1517+
&indices,
1518+
&weights_with_metaheader] {
1519+
using index_t = scalar_t;
1520+
CHECK(indices.is_contiguous());
1521+
CHECK(weights_with_metaheader.is_contiguous());
1522+
CHECK_EQ(
1523+
indices.size(0), weights_with_metaheader.size(0));
1524+
{
1525+
auto wlmap = kv_store_.by(shard_id).wlock();
1526+
auto* pool = kv_store_.pool_by(shard_id);
1527+
int64_t stride = weights_with_metaheader.size(1);
1528+
auto indices_data_ptr = indices.data_ptr<index_t>();
1529+
auto weights_data_ptr =
1530+
weights_with_metaheader.data_ptr<weight_type>();
1531+
for (auto index_iter = indexes.begin();
1532+
index_iter != indexes.end();
1533+
index_iter++) {
1534+
const auto& id_index = *index_iter;
1535+
auto id = int64_t(indices_data_ptr[id_index]);
1536+
// Defensive programming
1537+
// used is false shouldn't occur under normal
1538+
// circumstances
1539+
FixedBlockPool::set_used(
1540+
weights_data_ptr + id_index * stride, true);
1541+
1542+
// use mempool
1543+
weight_type* block = nullptr;
1544+
// First check if the key already exists
1545+
auto it = wlmap->find(id);
1546+
if (it != wlmap->end()) {
1547+
block = it->second;
1548+
} else {
1549+
// Key doesn't exist, allocate new block and
1550+
// insert.
1551+
block = pool->template allocate_t<weight_type>();
1552+
wlmap->insert({id, block});
1553+
}
1554+
std::copy(
1555+
weights_data_ptr + id_index * stride,
1556+
weights_data_ptr + (id_index + 1) * stride,
1557+
block);
1558+
1559+
if (feature_evict_config_.has_value() &&
1560+
feature_evict_config_.value()->trigger_mode_ !=
1561+
EvictTriggerMode::DISABLED &&
1562+
feature_evict_) {
1563+
auto* feature_score_evict = dynamic_cast<
1564+
FeatureScoreBasedEvict<weight_type>*>(
1565+
feature_evict_.get());
1566+
if (feature_score_evict) {
1567+
feature_score_evict->init_eviction_bucket(
1568+
block, shard_id);
15431569
}
15441570
}
1545-
});
1546-
});
1571+
}
1572+
}
1573+
});
1574+
});
15471575
futures.push_back(std::move(f));
15481576
}
15491577
return folly::collect(futures);

fbgemm_gpu/src/dram_kv_embedding_cache/dram_kv_embedding_cache_wrapper.h

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -148,8 +148,7 @@ class DramKVEmbeddingCacheWrapper : public torch::jit::CustomClassHolder {
148148
at::Tensor processed_counts,
149149
at::Tensor eviction_threshold_with_dry_run,
150150
at::Tensor full_duration_ms,
151-
at::Tensor exec_duration_ms,
152-
at::Tensor dry_run_exec_duration_ms) {
151+
at::Tensor exec_duration_ms) {
153152
auto metrics = impl_->get_feature_evict_metric();
154153
if (metrics.has_value()) {
155154
evicted_counts.copy_(
@@ -164,8 +163,6 @@ class DramKVEmbeddingCacheWrapper : public torch::jit::CustomClassHolder {
164163
metrics.value().full_duration_ms); // full duration (Long)
165164
exec_duration_ms.copy_(
166165
metrics.value().exec_duration_ms); // exec duration (Long)
167-
dry_run_exec_duration_ms.copy_(
168-
metrics.value().dry_run_exec_duration_ms); // dry run exec duration
169166
}
170167
}
171168

0 commit comments

Comments
 (0)