 import copy
 import gc
 import os
+from contextlib import AbstractContextManager, nullcontext
 from typing import TYPE_CHECKING, Any, Optional

 import torch
@@ -118,6 +119,21 @@ def wake_up(self, tags: Optional[list[str]] = None) -> None:
                     buffer.data.copy_(self._sleep_saved_buffers[name].data)
             self._sleep_saved_buffers = {}

+    def _maybe_get_memory_pool_context(self,
+                                       tag: str) -> AbstractContextManager:
+        if self.vllm_config.model_config.enable_sleep_mode:
+            from vllm.device_allocator.cumem import CuMemAllocator
+
+            allocator = CuMemAllocator.get_instance()
+            if tag == "weights":
+                assert allocator.get_current_usage() == 0, (
+                    "Sleep mode can only be "
+                    "used for one instance per process.")
+            context = allocator.use_memory_pool(tag=tag)
+        else:
+            context = nullcontext()
+        return context
+
     def initialize_cache(self, num_gpu_blocks: int,
                          num_cpu_blocks: int) -> None:
         self.cache_config.num_gpu_blocks = num_gpu_blocks
@@ -179,24 +195,17 @@ def init_device(self):
     # FIXME(youkaichao & ywang96): Use TorchDispatchMode instead of memory pool
     # to hijack tensor allocation.
     def load_model(self) -> None:
-        if self.vllm_config.model_config.enable_sleep_mode:
-            from vllm.device_allocator.cumem import CuMemAllocator
-
-            allocator = CuMemAllocator.get_instance()
-            assert allocator.get_current_usage() == 0, (
-                "Sleep mode can only be "
-                "used for one instance per process.")
-            context = allocator.use_memory_pool(tag="weights")
-        else:
-            from contextlib import nullcontext
-            context = nullcontext()
         eep_scale_up = os.environ.get("VLLM_ELASTIC_EP_SCALE_UP_LAUNCH") == "1"
-        with context:
+        with self._maybe_get_memory_pool_context(tag="weights"):
             self.model_runner.load_model(eep_scale_up=eep_scale_up)

     def update_config(self, overrides: dict[str, Any]) -> None:
         self.model_runner.update_config(overrides)

+    def reload_weights(self) -> None:
+        with self._maybe_get_memory_pool_context(tag="weights"):
+            self.model_runner.reload_weights()
+
     @torch.inference_mode()
     def determine_available_memory(self) -> int:
         """Profiles the peak memory usage of the model to determine how much