@@ -384,6 +384,18 @@ class DramKVEmbeddingCache : public kv_db::EmbeddingKVDB {
             local_write_allocate_total_duration +=
                 facebook::WallClockUtil::NowInUsecFast() -
                 before_alloc_ts;
+            if (feature_evict_config_.has_value() &&
+                feature_evict_config_.value()->trigger_mode_ !=
+                    EvictTriggerMode::DISABLED &&
+                feature_evict_) {
+              auto* feature_score_evict = dynamic_cast<
+                  FeatureScoreBasedEvict<weight_type>*>(
+                  feature_evict_.get());
+              if (feature_score_evict) {
+                feature_score_evict->init_eviction_bucket(
+                    block, shard_id);
+              }
+            }
           }
           if (feature_evict_config_.has_value() &&
               feature_evict_config_.value()->trigger_mode_ !=
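The block added above only touches eviction-bucket state when the configured evictor is actually the feature-score-based one, discovered by a dynamic_cast on the base feature_evict_ pointer. Below is a minimal, self-contained sketch of that downcast guard in isolation; the FeatureEvict base class, the empty init_eviction_bucket body, and the float weight type are stand-ins of my own, not the real FBGEMM definitions.

#include <memory>

// Hypothetical stand-ins for the eviction hierarchy referenced in the hunk.
struct FeatureEvict {
  virtual ~FeatureEvict() = default; // polymorphic base so dynamic_cast works
};

template <typename weight_type>
struct FeatureScoreBasedEvict : FeatureEvict {
  void init_eviction_bucket(weight_type* /*block*/, int /*shard_id*/) {
    // Placeholder: the real implementation would record which eviction
    // bucket this freshly allocated block starts in.
  }
};

int main() {
  // The cache holds the evictor through the base-class pointer, so the
  // score-based path has to downcast before touching bucket state; other
  // eviction strategies simply skip the call.
  std::unique_ptr<FeatureEvict> feature_evict =
      std::make_unique<FeatureScoreBasedEvict<float>>();
  float block[8] = {};

  if (auto* feature_score_evict =
          dynamic_cast<FeatureScoreBasedEvict<float>*>(feature_evict.get())) {
    feature_score_evict->init_eviction_bucket(block, /*shard_id=*/0);
  }
  return 0;
}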
@@ -711,10 +723,12 @@ class DramKVEmbeddingCache : public kv_db::EmbeddingKVDB {
             block = pool->template allocate_t<weight_type>();
             FixedBlockPool::set_key(block, id);
             wlmap->insert({id, block});
+            feature_score_evict->init_eviction_bucket(
+                block, shard_id);
           }
 
           feature_score_evict->update_feature_score_statistics(
-              block, engege_rate);
+              block, engege_rate, shard_id);
           updated_id_count++;
         }
       }
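This hunk threads shard_id into update_feature_score_statistics and initializes the eviction bucket for blocks that were just allocated and inserted. A plausible reading of the extra parameter is that the evictor keeps per-shard bookkeeping so updates under different shard write locks never contend; the sketch below shows one way such per-shard accumulators could look. Every name in it (PerShardScoreStats, update, mean) is an assumption for illustration, not the actual API.

#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical per-shard score accumulator: one slot per shard, indexed by
// the same shard_id that the write path already holds.
struct PerShardScoreStats {
  explicit PerShardScoreStats(size_t num_shards)
      : sums_(num_shards), counts_(num_shards) {}

  void update(double engage_rate, size_t shard_id) {
    sums_[shard_id] += engage_rate; // accumulate score mass for this shard
    counts_[shard_id] += 1;         // and the number of updates seen
  }

  double mean(size_t shard_id) const {
    return counts_[shard_id] ? sums_[shard_id] / counts_[shard_id] : 0.0;
  }

 private:
  std::vector<double> sums_;
  std::vector<uint64_t> counts_;
};

int main() {
  PerShardScoreStats stats(/*num_shards=*/4);
  stats.update(0.7, /*shard_id=*/2);
  stats.update(0.3, /*shard_id=*/2);
  return stats.mean(2) > 0.0 ? 0 : 1;
}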
@@ -1489,61 +1503,75 @@ class DramKVEmbeddingCache : public kv_db::EmbeddingKVDB {
         const auto indexes = iter->second;
         auto f =
             folly::via(executor_.get())
-                .thenValue(
-                    [this, shard_id, indexes, &indices, &weights_with_metaheader](
-                        folly::Unit) {
-                      FBGEMM_DISPATCH_INTEGRAL_TYPES(
-                          indices.scalar_type(),
-                          "dram_kv_set_with_metaheader",
-                          [this,
-                           shard_id,
-                           indexes,
-                           &indices,
-                           &weights_with_metaheader] {
-                            using index_t = scalar_t;
-                            CHECK(indices.is_contiguous());
-                            CHECK(weights_with_metaheader.is_contiguous());
-                            CHECK_EQ(
-                                indices.size(0), weights_with_metaheader.size(0));
-                            {
-                              auto wlmap = kv_store_.by(shard_id).wlock();
-                              auto* pool = kv_store_.pool_by(shard_id);
-                              int64_t stride = weights_with_metaheader.size(1);
-                              auto indices_data_ptr = indices.data_ptr<index_t>();
-                              auto weights_data_ptr =
-                                  weights_with_metaheader.data_ptr<weight_type>();
-                              for (auto index_iter = indexes.begin();
-                                   index_iter != indexes.end();
-                                   index_iter++) {
-                                const auto& id_index = *index_iter;
-                                auto id = int64_t(indices_data_ptr[id_index]);
-                                // Defensive programming
-                                // used is false shouldn't occur under normal
-                                // circumstances
-                                FixedBlockPool::set_used(
-                                    weights_data_ptr + id_index * stride, true);
-
-                                // use mempool
-                                weight_type* block = nullptr;
-                                // First check if the key already exists
-                                auto it = wlmap->find(id);
-                                if (it != wlmap->end()) {
-                                  block = it->second;
-                                } else {
-                                  // Key doesn't exist, allocate new block and
-                                  // insert.
-                                  block =
-                                      pool->template allocate_t<weight_type>();
-                                  wlmap->insert({id, block});
-                                }
-                                std::copy(
-                                    weights_data_ptr + id_index * stride,
-                                    weights_data_ptr + (id_index + 1) * stride,
-                                    block);
+                .thenValue([this,
+                            shard_id,
+                            indexes,
+                            &indices,
+                            &weights_with_metaheader](folly::Unit) {
+                  FBGEMM_DISPATCH_INTEGRAL_TYPES(
+                      indices.scalar_type(),
+                      "dram_kv_set_with_metaheader",
+                      [this,
+                       shard_id,
+                       indexes,
+                       &indices,
+                       &weights_with_metaheader] {
+                        using index_t = scalar_t;
+                        CHECK(indices.is_contiguous());
+                        CHECK(weights_with_metaheader.is_contiguous());
+                        CHECK_EQ(
+                            indices.size(0), weights_with_metaheader.size(0));
+                        {
+                          auto wlmap = kv_store_.by(shard_id).wlock();
+                          auto* pool = kv_store_.pool_by(shard_id);
+                          int64_t stride = weights_with_metaheader.size(1);
+                          auto indices_data_ptr = indices.data_ptr<index_t>();
+                          auto weights_data_ptr =
+                              weights_with_metaheader.data_ptr<weight_type>();
+                          for (auto index_iter = indexes.begin();
+                               index_iter != indexes.end();
+                               index_iter++) {
+                            const auto& id_index = *index_iter;
+                            auto id = int64_t(indices_data_ptr[id_index]);
+                            // Defensive programming
+                            // used is false shouldn't occur under normal
+                            // circumstances
+                            FixedBlockPool::set_used(
+                                weights_data_ptr + id_index * stride, true);
+
+                            // use mempool
+                            weight_type* block = nullptr;
+                            // First check if the key already exists
+                            auto it = wlmap->find(id);
+                            if (it != wlmap->end()) {
+                              block = it->second;
+                            } else {
+                              // Key doesn't exist, allocate new block and
+                              // insert.
+                              block = pool->template allocate_t<weight_type>();
+                              wlmap->insert({id, block});
+                            }
+                            std::copy(
+                                weights_data_ptr + id_index * stride,
+                                weights_data_ptr + (id_index + 1) * stride,
+                                block);
+
+                            if (feature_evict_config_.has_value() &&
+                                feature_evict_config_.value()->trigger_mode_ !=
+                                    EvictTriggerMode::DISABLED &&
+                                feature_evict_) {
+                              auto* feature_score_evict = dynamic_cast<
+                                  FeatureScoreBasedEvict<weight_type>*>(
+                                  feature_evict_.get());
+                              if (feature_score_evict) {
+                                feature_score_evict->init_eviction_bucket(
+                                    block, shard_id);
                               }
                             }
-                      });
-                });
+                          }
+                        }
+                      });
+                });
         futures.push_back(std::move(f));
       }
       return folly::collect(futures);
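The surrounding code in this hunk is unchanged in substance (only re-indented, plus the same guarded init_eviction_bucket call after std::copy); it fans the writes out per shard, scheduling one future per shard on the executor and joining them with folly::collect. Below is a minimal standalone sketch of that fan-out pattern, with a placeholder comment in place of the real per-shard write and a shard count chosen arbitrarily; it assumes folly is available.

#include <folly/executors/CPUThreadPoolExecutor.h>
#include <folly/futures/Future.h>
#include <memory>
#include <vector>

int main() {
  auto executor = std::make_shared<folly::CPUThreadPoolExecutor>(4);
  std::vector<folly::Future<folly::Unit>> futures;

  // One future per shard, each scheduled on the executor.
  for (int shard_id = 0; shard_id < 4; ++shard_id) {
    futures.push_back(
        folly::via(executor.get()).thenValue([shard_id](folly::Unit) {
          // Per-shard work goes here; the real code writes blocks under the
          // shard's write lock and updates eviction metadata.
          (void)shard_id;
        }));
  }

  // Join all shard futures before returning, mirroring folly::collect above.
  folly::collect(futures).get();
  return 0;
}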