Skip to content

Commit f42b56a

Browse files
duduyi2013facebook-github-bot
authored andcommitted
fix fetch eviction metadata bug and adjust UT to surface it (#4701)
Summary: Pull Request resolved: #4701 X-link: facebookresearch/FBGEMM#1726 pad id with table offset, to get the linearzied id and pass it into eviction metadata fetching logic to get the corresponding metaheader info. adjust UT to make it catch the bug locally Reviewed By: EddyLXJ Differential Revision: D80234997 fbshipit-source-id: d6a614bdef462221c9356bee407f4dbc1eadc16d
1 parent cb3f5ce commit f42b56a

File tree

2 files changed

+10
-10
lines changed

2 files changed

+10
-10
lines changed

fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2979,7 +2979,7 @@ def split_embedding_weights(
29792979
)
29802980
)
29812981
metadata_tensor = self._ssd_db.get_kv_zch_eviction_metadata_by_snapshot(
2982-
bucket_ascending_id_tensor,
2982+
bucket_ascending_id_tensor + table_offset,
29832983
torch.as_tensor(bucket_ascending_id_tensor.size(0)),
29842984
snapshot_handle,
29852985
)

fbgemm_gpu/test/tbe/ssd/ssd_split_tbe_training_test.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -471,13 +471,12 @@ def generate_kvzch_tbes(
471471
Ds = [D] * T
472472
Es = [E] * T
473473
else:
474-
Ds = [
475-
round_up(np.random.randint(low=int(0.25 * D), high=int(1.0 * D)), 4)
476-
for _ in range(T)
477-
]
478-
Es = [
479-
np.random.randint(low=int(0.5 * E), high=int(2.0 * E)) for _ in range(T)
480-
]
474+
# Ds = [
475+
# round_up(np.random.randint(low=int(0.25 * D), high=int(1.0 * D)), 4)
476+
# for _ in range(T)
477+
# ]
478+
Ds = [D] * T
479+
Es = [np.random.randint(low=int(0.5 * E), high=int(E)) for _ in range(T)]
481480

482481
if pooling_mode == PoolingMode.SUM:
483482
mode = "sum"
@@ -571,9 +570,9 @@ def generate_kvzch_tbes(
571570
pad_opt = torch.zeros(emb_ref_.size(0), pad_opt_width, dtype=emb_ref_.dtype)
572571
emb_opt_ref = torch.cat((emb_ref_, pad_opt), dim=1)
573572
emb.ssd_db.set_cuda(
574-
torch.arange(t * virtual_E, t * virtual_E + E).to(torch.int64),
573+
torch.arange(t * virtual_E, t * virtual_E + Es[t]).to(torch.int64),
575574
emb_opt_ref,
576-
torch.as_tensor([E]),
575+
torch.as_tensor([Es[t]]),
577576
t,
578577
)
579578
emb_ref_cpu.append(emb_ref_)
@@ -2099,6 +2098,7 @@ def test_kv_emb_state_dict(
20992098
num_buckets=num_buckets,
21002099
enable_optimizer_offloading=enable_optimizer_offloading,
21012100
backend_type=backend_type,
2101+
mixed=True,
21022102
)
21032103

21042104
# Generate inputs

0 commit comments

Comments
 (0)