@@ -105,6 +105,7 @@ class DramKVEmbeddingCache : public kv_db::EmbeddingKVDB {
       int64_t num_shards = 8,
       int64_t num_threads = 32,
       int64_t row_storage_bitwidth = 32,
+      bool backend_return_whole_row = false,
       bool enable_async_update = false,
       std::optional<at::Tensor> table_dims = std::nullopt,
       std::optional<at::Tensor> hash_size_cumsum = std::nullopt)
@@ -126,6 +127,7 @@ class DramKVEmbeddingCache : public kv_db::EmbeddingKVDB {
             block_alignment_,
             /*blocks_per_chunk=*/8192)),
         elem_size_(row_storage_bitwidth / 8),
+        backend_return_whole_row_(backend_return_whole_row),
         feature_evict_config_(feature_evict_config) {
     executor_ = std::make_unique<folly::CPUThreadPoolExecutor>(std::max<size_t>(
         num_threads, facebook::Proc::getCpuInfo().numCpuCores));
@@ -608,11 +610,15 @@ class DramKVEmbeddingCache : public kv_db::EmbeddingKVDB {
   void set_range_to_storage(
       const at::Tensor& weights,
       const int64_t start,
-      const int64_t length) {
-    const auto seq_indices =
-        at::arange(start, start + length, at::TensorOptions().dtype(at::kLong));
-    const auto count = at::tensor({length}, at::ScalarType::Long);
-    folly::coro::blockingWait(set_kv_db_async(seq_indices, weights, count));
+      const int64_t length) override {
+    if (backend_return_whole_row_) {
+      set_kv_with_metaheader_to_storage(weights);
+    } else {
+      const auto seq_indices = at::arange(
+          start, start + length, at::TensorOptions().dtype(at::kLong));
+      const auto count = at::tensor({length}, at::ScalarType::Long);
+      folly::coro::blockingWait(set_kv_db_async(seq_indices, weights, count));
+    }
   }
 
   void get_range_from_snapshot(
@@ -625,10 +631,16 @@ class DramKVEmbeddingCache : public kv_db::EmbeddingKVDB {
     CHECK(snapshot_handle == nullptr);
     const auto seq_indices =
         at::arange(start, start + length, at::TensorOptions().dtype(at::kLong));
-    const auto count = at::tensor({length}, at::ScalarType::Long);
-    get_kv_db_async_impl(
-        seq_indices, weights, count, width_offset, width_length)
-        .wait();
+
+    if (backend_return_whole_row_) {
+      get_kv_with_metaheader_from_storage(seq_indices, weights);
+    } else {
+      const auto count = at::tensor({length}, at::ScalarType::Long);
+      get_kv_db_async_impl(
+          seq_indices, weights, count, width_offset, width_length)
+          .wait();
+    }
+
     // this is called by checkpoint mostly, and checkpoint should wait until
     // eviction finishes so that we could reach a consistent state before/after
     // state_dict() calls
@@ -642,8 +654,41 @@ class DramKVEmbeddingCache : public kv_db::EmbeddingKVDB {
       int64_t width_offset = 0,
       std::optional<int64_t> width_length = std::nullopt) override {
     CHECK(snapshot_handle == nullptr);
+
+    if (backend_return_whole_row_) {
+      get_kv_with_metaheader_from_storage(
+          ids, weights, width_offset, width_length);
+    } else {
+      const auto count = at::tensor({ids.size(0)}, at::ScalarType::Long);
+      get_kv_db_async_impl(ids, weights, count, width_offset, width_length)
+          .wait();
+    }
+  }
+
+  // Used for checkpointing: read KV rows with their metaheader from storage.
+  void get_kv_with_metaheader_from_storage(
+      const at::Tensor& ids,
+      const at::Tensor& weights_with_metaheader,
+      int64_t width_offset = 0,
+      std::optional<int64_t> width_length = std::nullopt) {
     const auto count = at::tensor({ids.size(0)}, at::ScalarType::Long);
-    get_kv_db_async_impl(ids, weights, count, width_offset, width_length)
+    get_kv_db_with_metaheader_async_impl(
+        ids, weights_with_metaheader, count, width_offset, width_length)
+        .wait();
+  }
+
+  void set_kv_with_metaheader_to_storage(
+      const at::Tensor& weights_with_metaheader) {
+    std::vector<int64_t> keys(weights_with_metaheader.size(0), 0);
+    for (int64_t i = 0; i < weights_with_metaheader.size(0); ++i) {
+      keys[i] = FixedBlockPool::get_key(weights_with_metaheader[i].data_ptr());
+    }
+    auto indices =
+        torch::from_blob(keys.data(), {int64_t(keys.size())}, torch::kInt64);
+    const auto count =
+        at::tensor({weights_with_metaheader.size(0)}, at::ScalarType::Long);
+    set_kv_db_with_metaheader_async_impl(
+        indices, weights_with_metaheader, count)
         .wait();
     // this is called by checkpoint mostly, and checkpoint should wait until
     // eviction finishes so that we could reach a consistent state before/after
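Note that the set path above rebuilds the index tensor directly from the rows by calling FixedBlockPool::get_key on each row's leading metaheader, and skips rows whose "used" flag is unset. The commit does not show the metaheader definition itself; the struct below is only a hedged sketch of the assumed layout (field names and widths are illustrative, the real definition lives in FixedBlockPool):

#include <cstdint>

// Hedged sketch, not part of this commit: one plausible shape of the metaheader
// that FixedBlockPool keeps at the front of every row. The exact fields and
// widths are assumptions; the point is only that the key and a liveness flag
// travel with the row, so FixedBlockPool::get_key()/get_used() can recover them
// from the row pointer alone.
struct RowWithMetaheaderSketch {
  int64_t key;   // embedding index, what get_key() would return
  uint32_t used; // liveness flag, what get_used() would check
  // ... embedding values and optimizer state follow in the same block
};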
@@ -826,6 +871,16 @@ class DramKVEmbeddingCache : public kv_db::EmbeddingKVDB {
 
   void flush_or_compact(const int64_t timestep) override {}
 
+  bool get_backend_return_whole_row() override {
+    return backend_return_whole_row_;
+  }
+
+  int64_t get_metaheader_width_in_front() override {
+    return backend_return_whole_row_
+        ? FixedBlockPool::get_metaheader_dim<weight_type>()
+        : 0;
+  }
+
   void resume_ongoing_eviction() override {
     if (feature_evict_) {
       feature_evict_->resume();
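These two accessors let a checkpoint reader decide how wide its destination rows must be before calling the read paths above. A minimal caller-side sketch, assuming a hypothetical emb_dim for the plain embedding width (a real whole row also carries optimizer state after the embedding):

#include <ATen/ATen.h>
#include <cstdint>

// Hedged sketch, not part of this commit: sizing a read buffer for
// state_dict()-style reads. `db`, `emb_dim`, and `row_count` are hypothetical
// caller-side values; when the backend returns whole rows, each row is prefixed
// by the metaheader, so the buffer must be widened accordingly.
template <typename Db>
at::Tensor make_read_buffer(Db& db, int64_t emb_dim, int64_t row_count) {
  const int64_t meta_width = db.get_metaheader_width_in_front(); // 0 if disabled
  const int64_t row_width = db.get_backend_return_whole_row()
      ? meta_width + emb_dim // simplified; optimizer state would add more width
      : emb_dim;
  return at::zeros({row_count, row_width}, at::kFloat);
}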
@@ -930,6 +985,192 @@ class DramKVEmbeddingCache : public kv_db::EmbeddingKVDB {
     return ret;
   }
 
+  /// Get embeddings and metaheader from kvstore.
+  ///
+  /// @param indices The 1D embedding index tensor, should skip on negative
+  /// value
+  /// @param weights_with_metaheader The 2D tensor that each row (embeddings) is
+  /// paired up with relative element in <indices>. This tensor will be
+  /// filled up with the returned embeddings from KVstore.
+  /// @param count A single element tensor that contains the number of indices
+  /// to be processed
+  ///
+  /// @return None
+  folly::SemiFuture<std::vector<folly::Unit>>
+  get_kv_db_with_metaheader_async_impl(
+      const at::Tensor& indices,
+      const at::Tensor& weights_with_metaheader,
+      const at::Tensor& count,
+      int64_t width_offset = 0,
+      std::optional<int64_t> width_length = std::nullopt) {
+    std::vector<folly::Future<folly::Unit>> futures;
+    auto row_width = weights_with_metaheader.size(1);
+    auto copy_width = width_length.value_or(row_width);
+    CHECK_LE(row_width, block_size_);
+    CHECK_EQ(copy_width, row_width);
+    auto shardid_to_indexes = shard_input(indices, count);
+
+    for (auto iter = shardid_to_indexes.begin();
+         iter != shardid_to_indexes.end();
+         iter++) {
+      const auto shard_id = iter->first;
+      const auto indexes = iter->second;
+      auto f =
+          folly::via(executor_.get())
+              .thenValue([this,
+                          shard_id,
+                          indexes,
+                          &indices,
+                          &weights_with_metaheader,
+                          width_offset,
+                          row_width](folly::Unit) {
+                FBGEMM_DISPATCH_INTEGRAL_TYPES(
+                    indices.scalar_type(),
+                    "dram_kvstore_get_with_metaheader",
+                    [this,
+                     shard_id,
+                     indexes,
+                     &indices,
+                     &weights_with_metaheader,
+                     width_offset,
+                     row_width] {
+                      using index_t = scalar_t;
+                      CHECK(indices.is_contiguous());
+                      CHECK(weights_with_metaheader.is_contiguous());
+                      CHECK_EQ(
+                          indices.size(0), weights_with_metaheader.size(0));
+                      auto wlmap = kv_store_.by(shard_id).wlock();
+                      auto indices_data_ptr = indices.data_ptr<index_t>();
+                      auto weights_data_ptr =
+                          weights_with_metaheader.data_ptr<weight_type>();
+                      {
+                        for (auto index_iter = indexes.begin();
+                             index_iter != indexes.end();
+                             index_iter++) {
+                          const auto weights_row_index = *index_iter;
+                          auto weight_idx =
+                              int64_t(indices_data_ptr[weights_row_index]);
+                          const auto cached_iter = wlmap->find(weight_idx);
+                          // Defensive programming
+                          // it shouldn't occur under normal circumstances
+                          if (cached_iter == wlmap->end()) {
+                            std::memset(
+                                &(weights_data_ptr
+                                      [weights_row_index * row_width]),
+                                0,
+                                row_width);
+                            continue;
+                          }
+
+                          // For weight KVT, offset=0 and it will read the whole
+                          // row. For optimizer, offset=dim(metaheader) +
+                          // emb_dim so it will only read the optimizer part
+                          const auto* ptr_offset_from_front =
+                              FixedBlockPool::ptr_offset_from_front<
+                                  weight_type>(
+                                  cached_iter->second, width_offset);
+                          std::copy(
+                              ptr_offset_from_front,
+                              ptr_offset_from_front + row_width,
+                              &(weights_data_ptr
+                                    [weights_row_index * row_width]));
+                        }
+                      }
+                    });
+              });
+      futures.push_back(std::move(f));
+    }
+    return folly::collect(futures);
+  }
+
+  /// insert embeddings and metaheader into kvstore.
+  /// current underlying memory management is done through F14FastMap
+  /// key value pair will be sharded into multiple shards to increase
+  /// parallelism.
+  ///
+  /// @param indices The 1D embedding index tensor, should skip on negative
+  /// value
+  /// @param weights_with_metaheader The 2D tensor that each row (embeddings with
+  /// metaheader) is paired up with relative element in <indices>
+  /// @param count A single element tensor that contains the number of indices
+  /// to be processed
+  ///
+  /// @return None
+  folly::SemiFuture<std::vector<folly::Unit>>
+  set_kv_db_with_metaheader_async_impl(
+      const at::Tensor& indices,
+      const at::Tensor& weights_with_metaheader,
+      const at::Tensor& count) {
+    std::vector<folly::Future<folly::Unit>> futures;
+    auto shardid_to_indexes = shard_input(indices, count);
+    for (auto iter = shardid_to_indexes.begin();
+         iter != shardid_to_indexes.end();
+         iter++) {
+      const auto shard_id = iter->first;
+      const auto indexes = iter->second;
+      auto f =
+          folly::via(executor_.get())
+              .thenValue(
+                  [this, shard_id, indexes, &indices, &weights_with_metaheader](
+                      folly::Unit) {
+                    FBGEMM_DISPATCH_INTEGRAL_TYPES(
+                        indices.scalar_type(),
+                        "dram_kv_set_with_metaheader",
+                        [this,
+                         shard_id,
+                         indexes,
+                         &indices,
+                         &weights_with_metaheader] {
+                          using index_t = scalar_t;
+                          CHECK(indices.is_contiguous());
+                          CHECK(weights_with_metaheader.is_contiguous());
+                          CHECK_EQ(
+                              indices.size(0), weights_with_metaheader.size(0));
+                          {
+                            auto wlmap = kv_store_.by(shard_id).wlock();
+                            auto* pool = kv_store_.pool_by(shard_id);
+                            int64_t stride = weights_with_metaheader.size(1);
+                            auto indices_data_ptr = indices.data_ptr<index_t>();
+                            auto weights_data_ptr =
+                                weights_with_metaheader.data_ptr<weight_type>();
+                            for (auto index_iter = indexes.begin();
+                                 index_iter != indexes.end();
+                                 index_iter++) {
+                              const auto& id_index = *index_iter;
+                              auto id = int64_t(indices_data_ptr[id_index]);
+                              // Defensive programming
+                              // it shouldn't occur under normal circumstances
+                              auto used = FixedBlockPool::get_used(
+                                  weights_data_ptr + id_index * stride);
+                              if (!used) {
+                                continue;
+                              }
+                              // use mempool
+                              weight_type* block = nullptr;
+                              // First check if the key already exists
+                              auto it = wlmap->find(id);
+                              if (it != wlmap->end()) {
+                                block = it->second;
+                              } else {
+                                // Key doesn't exist, allocate new block and
+                                // insert.
+                                block =
+                                    pool->template allocate_t<weight_type>();
+                                wlmap->insert({id, block});
+                              }
+                              std::copy(
+                                  weights_data_ptr + id_index * stride,
+                                  weights_data_ptr + (id_index + 1) * stride,
+                                  block);
+                            }
+                          }
+                        });
+                  });
+      futures.push_back(std::move(f));
+    }
+    return folly::collect(futures);
+  }
+
   std::unique_ptr<folly::CPUThreadPoolExecutor> executor_;
   // background thread
   folly::FunctionScheduler scheduler_;
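Both new *_with_metaheader_async_impl helpers follow the same concurrency pattern as the existing get/set paths: shard the ids, schedule one task per shard on the thread pool, and collect the futures. A stripped-down sketch of that skeleton, assuming only folly; process_shard stands in for the per-shard lambdas above:

#include <folly/executors/CPUThreadPoolExecutor.h>
#include <folly/futures/Future.h>
#include <cstdint>
#include <functional>
#include <vector>

// Hedged sketch, not part of this commit: the shard-then-collect skeleton used
// by get_kv_db_with_metaheader_async_impl and set_kv_db_with_metaheader_async_impl.
// `process_shard` stands in for the per-shard work that locks one shard of the
// kv store and copies rows in or out.
folly::SemiFuture<std::vector<folly::Unit>> run_per_shard(
    folly::CPUThreadPoolExecutor* executor,
    const std::vector<int64_t>& shard_ids,
    std::function<void(int64_t)> process_shard) {
  std::vector<folly::Future<folly::Unit>> futures;
  futures.reserve(shard_ids.size());
  for (const auto shard_id : shard_ids) {
    // One task per shard, so shards proceed in parallel without contending on
    // a single lock.
    futures.push_back(folly::via(executor).thenValue(
        [shard_id, process_shard](folly::Unit) { process_shard(shard_id); }));
  }
  // Resolves once every per-shard task has completed.
  return folly::collect(futures);
}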
@@ -942,6 +1183,7 @@ class DramKVEmbeddingCache : public kv_db::EmbeddingKVDB {
   std::atomic_bool is_eviction_ongoing_ = false;
   std::vector<std::unique_ptr<ssd::Initializer>> initializers_;
   int64_t elem_size_;
+  bool backend_return_whole_row_;
   std::vector<int64_t> sub_table_dims_;
   std::vector<int64_t> sub_table_hash_cumsum_;
   std::optional<c10::intrusive_ptr<FeatureEvictConfig>> feature_evict_config_;