                            get_spec_resource_manager)
 from ..virtual_memory import ExecutorMemoryType, RestoreMode
 from ..virtual_memory import scope as virtual_memory_scope
-from ._util import (KvCacheCreator, _adjust_torch_mem_fraction,
-                    create_py_executor_instance, instantiate_sampler, is_mla,
-                    validate_feature_combination)
+from ._util import (KvCacheCreator, create_py_executor_instance,
+                    instantiate_sampler, is_mla, validate_feature_combination)
 from .config_utils import is_mla
 from .guided_decoder import CapturableGuidedDecoder, GuidedDecoder
 from .kv_cache_connector import KvCacheConnectorManager
@@ -222,7 +221,6 @@ def create_py_executor(
     tokenizer: Optional[TokenizerBase] = None,
     profiling_stage_data: Optional[dict] = None,
 ) -> PyExecutor:
-    torch.cuda.set_per_process_memory_fraction(1.0)
     garbage_collection_gen0_threshold = llm_args.garbage_collection_gen0_threshold
     lora_config = llm_args.lora_config
     kv_connector_config = llm_args.kv_connector_config
@@ -434,7 +432,6 @@ def drafting_loop_wrapper(model):
 
     # PyTorchModelEngine modifies these fields, update them
     model_engine_max_seq_len = model_engine.max_seq_len
-    net_max_seq_len = model_engine_max_seq_len
     if not llm_args.disable_overlap_scheduler:
         model_engine_max_seq_len = model_engine.max_seq_len + 1
     if spec_config is not None:
@@ -604,7 +601,6 @@ def drafting_loop_wrapper(model):
         kv_connector_manager = None
 
     resources = {}
-    estimating_kv_cache = False
     kv_cache_creator = None
 
     # Create the execution stream for model forward operations
@@ -619,7 +615,6 @@ def drafting_loop_wrapper(model):
         model_engine=model_engine,
         draft_model_engine=draft_model_engine,
         mapping=mapping,
-        net_max_seq_len=net_max_seq_len,
         kv_connector_manager=kv_connector_manager,
         max_num_tokens=max_num_tokens,
         max_beam_width=max_beam_width,
@@ -633,11 +628,8 @@ def drafting_loop_wrapper(model):
         sparse_attention_config=sparse_attention_config,
         execution_stream=execution_stream,
     )
-    estimating_kv_cache = kv_cache_creator.try_prepare_estimation()
-    with allocation_scope(
-            ExecutorMemoryType.INIT_KV_CACHE if estimating_kv_cache else
-            ExecutorMemoryType.KV_CACHE, RestoreMode.NONE):
-        kv_cache_creator.build_managers(resources, estimating_kv_cache)
+    with allocation_scope(ExecutorMemoryType.KV_CACHE, RestoreMode.NONE):
+        kv_cache_creator.build_managers(resources)
     # Originally, max_seq_len might be mutated inside build_managers as field of executor config.
     # Since now, we are changing kv_cache_creator._max_seq_len instead. Restore max_seq_len here.
     max_seq_len = kv_cache_creator._max_seq_len
@@ -663,100 +655,40 @@ def drafting_loop_wrapper(model):
                                spec_resource_manager=spec_resource_manager,
                                guided_decoder=guided_decoder)
 
-    with allocation_scope(
-            ExecutorMemoryType.INIT_EXTRA_RESOURCES if estimating_kv_cache else
-            ExecutorMemoryType.EXTRA_RESOURCES, RestoreMode.PINNED):
+    with allocation_scope(ExecutorMemoryType.EXTRA_RESOURCES,
+                          RestoreMode.PINNED):
+
+        # run gc.collect() to free memory of the previous py_executor, avoid cudaFree overlap with cuda graph capture
+        gc.collect()
         py_executor = create_py_executor_instance(
             dist=dist,
             resources=resources,
             mapping=mapping,
             llm_args=llm_args,
             ctx_chunk_config=ctx_chunk_config,
             model_engine=model_engine,
-            start_worker=False,
             sampler=sampler,
             drafter=drafter,
             guided_decoder=guided_decoder,
             lora_config=lora_config,
             garbage_collection_gen0_threshold=garbage_collection_gen0_threshold,
-            kv_connector_manager=kv_connector_manager
-            if not estimating_kv_cache else None,
+            kv_connector_manager=kv_connector_manager,
             max_seq_len=max_seq_len,
             max_batch_size=max_batch_size,
             max_beam_width=max_beam_width,
             max_num_tokens=max_num_tokens,
             peft_cache_config=peft_cache_config,
             scheduler_config=scheduler_config,
             cache_transceiver_config=cache_transceiver_config,
-            virtual_memory_pools=vm_pools if not estimating_kv_cache else None,
+            virtual_memory_pools=vm_pools,
             execution_stream=execution_stream,
         )
+
     # Originally, peft_cache_config might be mutated inside
     # create_py_executor_instance. Restore it here.
     peft_cache_config = py_executor.peft_cache_config
 
-    if estimating_kv_cache:
-        assert kv_cache_creator is not None
-        with allocation_scope(ExecutorMemoryType.MODEL_EXTRA,
-                              RestoreMode.PINNED):
-            kv_cache_creator.configure_kv_cache_capacity(py_executor)
-        kv_cache_creator.teardown_managers(resources)
-        del py_executor  # free before constructing new
-
-        with allocation_scope(ExecutorMemoryType.KV_CACHE, RestoreMode.NONE):
-            # Before estimating KV cache size, a minimal KV cache has been allocated using
-            # create_kv_cache_manager above, which caps kv_cache_creator.max_seq_len. Restoring
-            # the original value before creating the final KV cache.
-            kv_cache_creator._max_seq_len = model_engine_max_seq_len
-            kv_cache_creator.build_managers(resources, False)
-            # Originally, max_seq_len might be mutated inside build_managers as field of executor config.
-            # Since now, we are changing kv_cache_creator._max_seq_len instead. Restore max_seq_len here.
-            max_seq_len = kv_cache_creator._max_seq_len
-            update_sampler_max_seq_len(max_seq_len, sampler)
-
-        for eng in [model_engine, draft_model_engine]:
-            if eng is None:
-                continue
-            if eng.attn_metadata is not None:
-                if llm_args.cuda_graph_config is not None:
-                    eng._release_cuda_graphs()
-                eng.attn_metadata = None
-
-        with allocation_scope(ExecutorMemoryType.EXTRA_RESOURCES,
-                              RestoreMode.PINNED):
-
-            # run gc.collect() to free memory of the previous py_executor, avoid cudaFree overlap with cuda graph capture
-            gc.collect()
-            py_executor = create_py_executor_instance(
-                dist=dist,
-                resources=resources,
-                mapping=mapping,
-                llm_args=llm_args,
-                ctx_chunk_config=ctx_chunk_config,
-                model_engine=model_engine,
-                start_worker=False,
-                sampler=sampler,
-                drafter=drafter,
-                guided_decoder=guided_decoder,
-                lora_config=lora_config,
-                garbage_collection_gen0_threshold=
-                garbage_collection_gen0_threshold,
-                kv_connector_manager=kv_connector_manager,
-                max_seq_len=max_seq_len,
-                max_batch_size=max_batch_size,
-                max_beam_width=max_beam_width,
-                max_num_tokens=max_num_tokens,
-                peft_cache_config=peft_cache_config,
-                scheduler_config=scheduler_config,
-                cache_transceiver_config=cache_transceiver_config,
-                virtual_memory_pools=vm_pools,
-                execution_stream=execution_stream,
-            )
-
-    _adjust_torch_mem_fraction()
-
     if mapping.rank == 0:
         logger.info(f"LLM Args:\n{llm_args}")
 
-    py_executor.start_worker()
     return py_executor
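
For context, the single-pass construction flow that this diff leaves in place can be condensed as follows. This is only a sketch assembled from the added and unchanged lines above; the argument list is elided with `...` and it is not a verbatim excerpt of the repository.

```python
# Sketch of the remaining flow (arguments elided, assembled from the diff above).
with allocation_scope(ExecutorMemoryType.KV_CACHE, RestoreMode.NONE):
    kv_cache_creator.build_managers(resources)          # build the final KV cache managers once
max_seq_len = kv_cache_creator._max_seq_len              # restore possibly-mutated max_seq_len

with allocation_scope(ExecutorMemoryType.EXTRA_RESOURCES, RestoreMode.PINNED):
    gc.collect()                                         # drop any previous executor before CUDA graph capture
    py_executor = create_py_executor_instance(...)       # same keyword arguments as shown above

peft_cache_config = py_executor.peft_cache_config        # restore possibly-mutated config
return py_executor
```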