Skip to content

Commit e6fa7f0

Browse files
Jialinjinzhen-lin
authored andcommitted
[Core] Add basic unit test for maybe_evict_cached_block (vllm-project#21400)
Signed-off-by: Jialin Ouyang <[email protected]> Signed-off-by: Jinzhen Lin <[email protected]>
1 parent 6bad8a5 commit e6fa7f0

File tree

1 file changed

+67
-0
lines changed

1 file changed

+67
-0
lines changed

tests/v1/core/test_prefix_caching.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1097,6 +1097,73 @@ def test_prefix_cache_stats_disabled():
10971097
assert manager.prefix_cache_stats is None
10981098

10991099

1100+
def test_maybe_evict_cached_block():
1101+
pool = BlockPool(num_gpu_blocks=4, enable_caching=True)
1102+
block_hash0 = BlockHashWithGroupId(block_hash=BlockHash(hash_value=10,
1103+
token_ids=(100, )),
1104+
group_id=1000)
1105+
block_hash1 = BlockHashWithGroupId(block_hash=BlockHash(hash_value=20,
1106+
token_ids=(200, )),
1107+
group_id=2000)
1108+
block_hash2 = BlockHashWithGroupId(block_hash=BlockHash(hash_value=30,
1109+
token_ids=(300, )),
1110+
group_id=3000)
1111+
block_hashes = [
1112+
block_hash0,
1113+
block_hash1,
1114+
block_hash2,
1115+
# block3 had the exact same block_hash as the first block
1116+
block_hash0,
1117+
]
1118+
assert len(pool.blocks) == len(block_hashes)
1119+
# Manually add all blocks to cached_blocks
1120+
for block, block_hash in zip(pool.blocks, block_hashes):
1121+
block.block_hash = block_hash
1122+
pool.cached_block_hash_to_block[block_hash][block.block_id] = block
1123+
1124+
block0, block1, block2, block3 = pool.blocks
1125+
assert pool.cached_block_hash_to_block == {
1126+
block_hash0: {
1127+
block0.block_id: block0,
1128+
block3.block_id: block3
1129+
},
1130+
block_hash1: {
1131+
block1.block_id: block1
1132+
},
1133+
block_hash2: {
1134+
block2.block_id: block2
1135+
}
1136+
}
1137+
# Evict block1
1138+
pool._maybe_evict_cached_block(block1)
1139+
assert pool.cached_block_hash_to_block == {
1140+
block_hash0: {
1141+
block0.block_id: block0,
1142+
block3.block_id: block3
1143+
},
1144+
block_hash2: {
1145+
block2.block_id: block2
1146+
}
1147+
}
1148+
# Evict block0: block_hash0 entry should NOT be removed, as block3
1149+
# also use the same hash
1150+
pool._maybe_evict_cached_block(block0)
1151+
assert pool.cached_block_hash_to_block == {
1152+
block_hash0: {
1153+
block3.block_id: block3
1154+
},
1155+
block_hash2: {
1156+
block2.block_id: block2
1157+
}
1158+
}
1159+
# Evict block2
1160+
pool._maybe_evict_cached_block(block2)
1161+
assert pool.cached_block_hash_to_block == {block_hash0: {3: block3}}
1162+
# Evict block3
1163+
pool._maybe_evict_cached_block(block3)
1164+
assert pool.cached_block_hash_to_block == {}
1165+
1166+
11001167
@pytest.mark.parametrize("blocks_to_cache", [2, 3, 10])
11011168
def test_kv_cache_events(blocks_to_cache: int):
11021169
block_size = 16

0 commit comments

Comments
 (0)