@@ -4,6 +4,7 @@
 import copy
 import gc
 import os
+from contextlib import AbstractContextManager, nullcontext
 from typing import TYPE_CHECKING, Any, Optional
 
 import torch
@@ -118,6 +119,21 @@ def wake_up(self, tags: Optional[list[str]] = None) -> None:
                     buffer.data.copy_(self._sleep_saved_buffers[name].data)
             self._sleep_saved_buffers = {}
 
+    def _maybe_get_memory_pool_context(self,
+                                       tag: str) -> AbstractContextManager:
+        if self.vllm_config.model_config.enable_sleep_mode:
+            from vllm.device_allocator.cumem import CuMemAllocator
+
+            allocator = CuMemAllocator.get_instance()
+            if tag == "weights":
+                assert allocator.get_current_usage() == 0, (
+                    "Sleep mode can only be "
+                    "used for one instance per process.")
+            context = allocator.use_memory_pool(tag=tag)
+        else:
+            context = nullcontext()
+        return context
+
     def initialize_cache(self, num_gpu_blocks: int,
                          num_cpu_blocks: int) -> None:
         self.cache_config.num_gpu_blocks = num_gpu_blocks
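
The helper above folds the sleep-mode branch into a single context-manager factory. A minimal standalone sketch of the same pattern, runnable without vLLM (`_dummy_pool` is hypothetical and merely stands in for `CuMemAllocator.use_memory_pool`):

```python
from contextlib import AbstractContextManager, contextmanager, nullcontext


@contextmanager
def _dummy_pool(tag: str):
    # Hypothetical stand-in for CuMemAllocator.use_memory_pool(tag=tag).
    print(f"allocations tagged {tag!r} now go to the pool")
    yield
    print("back to the default allocator")


def maybe_pool(enabled: bool, tag: str) -> AbstractContextManager:
    # Same shape as _maybe_get_memory_pool_context: choose the context
    # once, so every call site can use a single `with` statement.
    return _dummy_pool(tag) if enabled else nullcontext()


with maybe_pool(enabled=True, tag="weights"):
    pass  # tensor allocations would be hijacked here
with maybe_pool(enabled=False, tag="weights"):
    pass  # no-op: the default allocator stays in effect
```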
@@ -179,24 +195,17 @@ def init_device(self):
     # FIXME(youkaichao & ywang96): Use TorchDispatchMode instead of memory pool
     # to hijack tensor allocation.
     def load_model(self) -> None:
-        if self.vllm_config.model_config.enable_sleep_mode:
-            from vllm.device_allocator.cumem import CuMemAllocator
-
-            allocator = CuMemAllocator.get_instance()
-            assert allocator.get_current_usage() == 0, (
-                "Sleep mode can only be "
-                "used for one instance per process.")
-            context = allocator.use_memory_pool(tag="weights")
-        else:
-            from contextlib import nullcontext
-            context = nullcontext()
         eep_scale_up = os.environ.get("VLLM_ELASTIC_EP_SCALE_UP_LAUNCH") == "1"
-        with context:
+        with self._maybe_get_memory_pool_context(tag="weights"):
             self.model_runner.load_model(eep_scale_up=eep_scale_up)
 
     def update_config(self, overrides: dict[str, Any]) -> None:
         self.model_runner.update_config(overrides)
 
+    def reload_weights(self) -> None:
+        with self._maybe_get_memory_pool_context(tag="weights"):
+            self.model_runner.reload_weights()
+
     @torch.inference_mode()
     def determine_available_memory(self) -> int:
         """Profiles the peak memory usage of the model to determine how much
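
The new `reload_weights` entry point reuses the same tagged memory pool that `load_model` allocates into, so reloaded checkpoints land in sleep-mode-managed memory. A hypothetical call sequence, illustrative only: `worker` is an already-initialized worker with `enable_sleep_mode=True`, and the `sleep` level semantics are assumed from vLLM's sleep-mode behavior rather than shown in this diff (only the `wake_up(tags=...)` signature appears above).

```python
worker.load_model()               # weights allocated inside the "weights" pool
worker.sleep(level=2)             # assumed: level 2 discards weight contents
worker.wake_up(tags=["weights"])  # re-maps the pooled "weights" memory
worker.reload_weights()           # re-reads checkpoints into the same tagged pool
```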