diff --git a/tests/ut/test_utils.py b/tests/ut/test_utils.py index 5ddc59dea5..1db0d02f8e 100644 --- a/tests/ut/test_utils.py +++ b/tests/ut/test_utils.py @@ -239,17 +239,27 @@ def test_current_stream(self): def test_vllm_version_is(self): with mock.patch.dict(os.environ, {"VLLM_VERSION": "1.0.0"}): with mock.patch("vllm.__version__", "1.0.0"): - self.assertTrue(utils.vllm_version_is("1.0.0")) - self.assertFalse(utils.vllm_version_is("2.0.0")) + self.assertTrue(utils.vllm_version_is.__wrapped__("1.0.0")) + self.assertFalse(utils.vllm_version_is.__wrapped__("2.0.0")) with mock.patch("vllm.__version__", "2.0.0"): - self.assertTrue(utils.vllm_version_is("1.0.0")) - self.assertFalse(utils.vllm_version_is("2.0.0")) + self.assertTrue(utils.vllm_version_is.__wrapped__("1.0.0")) + self.assertFalse(utils.vllm_version_is.__wrapped__("2.0.0")) with mock.patch("vllm.__version__", "1.0.0"): - self.assertTrue(utils.vllm_version_is("1.0.0")) - self.assertFalse(utils.vllm_version_is("2.0.0")) + self.assertTrue(utils.vllm_version_is.__wrapped__("1.0.0")) + self.assertFalse(utils.vllm_version_is.__wrapped__("2.0.0")) with mock.patch("vllm.__version__", "2.0.0"): - self.assertTrue(utils.vllm_version_is("2.0.0")) - self.assertFalse(utils.vllm_version_is("1.0.0")) + self.assertTrue(utils.vllm_version_is.__wrapped__("2.0.0")) + self.assertFalse(utils.vllm_version_is.__wrapped__("1.0.0")) + # Test caching takes effect + utils.vllm_version_is.cache_clear() + utils.vllm_version_is("1.0.0") + misses = utils.vllm_version_is.cache_info().misses + hits = utils.vllm_version_is.cache_info().hits + self.assertEqual(misses, 1) + self.assertEqual(hits, 0) + utils.vllm_version_is("1.0.0") + hits = utils.vllm_version_is.cache_info().hits + self.assertEqual(hits, 1) def test_update_aclgraph_sizes(self): # max_num_batch_sizes < len(original_sizes) diff --git a/vllm_ascend/utils.py b/vllm_ascend/utils.py index 634e13cb9e..a024d93289 100644 --- a/vllm_ascend/utils.py +++ b/vllm_ascend/utils.py @@ -19,6 +19,7 @@ import atexit import fcntl +import functools import math import os import shutil @@ -280,6 +281,7 @@ def adapt_patch(is_global_patch: bool = False): from vllm_ascend.patch import worker # noqa: F401 +@functools.cache def vllm_version_is(target_vllm_version: str): if envs.VLLM_VERSION is not None: vllm_version = envs.VLLM_VERSION