1010from typing import Type
1111
1212import torch
13- from neuronx_distributed .quantization .quantization_config import QuantizationType , ActivationQuantizationType
13+ from neuronx_distributed .quantization .quantization_config import (
14+ ActivationQuantizationType ,
15+ QuantizationType ,
16+ )
1417from transformers import AutoTokenizer , GenerationConfig
1518
1619from neuronx_distributed_inference .models .application_base import NeuronApplicationBase
2831 check_accuracy_logits ,
2932 get_generate_outputs ,
3033)
34+ from neuronx_distributed_inference .utils import argparse_utils
3135from neuronx_distributed_inference .utils .benchmark import benchmark_sampling
3236from neuronx_distributed_inference .utils .debug_utils import capture_model_inputs
3337from neuronx_distributed_inference .utils .distributed import get_init_rank , get_init_world_size
38+ from neuronx_distributed_inference .utils .exceptions import LogitMatchingValidationError
3439from neuronx_distributed_inference .utils .hf_adapter import load_pretrained_config
3540from neuronx_distributed_inference .utils .random import set_random_seed
3641
@@ -117,10 +122,12 @@ def setup_run_parser(run_parser: argparse.ArgumentParser):
117122 run_parser .add_argument ("--rpl-reduce-dtype" , type = to_torch_dtype )
118123 run_parser .add_argument ("--output-logits" , action = "store_true" )
119124 run_parser .add_argument ("--vocab-parallel" , action = "store_true" )
125+ run_parser .add_argument ("--layer-boundary-markers" , action = "store_true" , default = False )
120126
121127 # Attention
122128 run_parser .add_argument ("--fused-qkv" , action = "store_true" )
123129 run_parser .add_argument ("--sequence-parallel-enabled" , action = "store_true" )
130+ run_parser .add_argument ("--weight-gather-seq-len-threshold" , type = int )
124131 run_parser .add_argument ("--flash-decoding-enabled" , action = "store_true" )
125132
126133 # Continuous batching
@@ -132,6 +139,7 @@ def setup_run_parser(run_parser: argparse.ArgumentParser):
132139 # KV cache
133140 run_parser .add_argument ("--kv-cache-batch-size" , type = int )
134141 run_parser .add_argument ("--kv-cache-padding-size" , type = int )
142+ run_parser .add_argument ("--disable-kv-cache-tiling" , action = "store_true" )
135143
136144 # On device sampling
137145 run_parser .add_argument ("--on-device-sampling" , action = "store_true" )
@@ -193,9 +201,16 @@ def setup_run_parser(run_parser: argparse.ArgumentParser):
193201 "This is useful for ensuring processes on different nodes are in sync" ,
194202 )
195203 run_parser .add_argument (
196- "--skip-save-sharded-checkpoint" , dest = "save_sharded_checkpoint" , action = "store_false"
204+ "--save-sharded-checkpoint" ,
205+ action = "store_true" ,
206+ help = "Save sharded checkpoints to disk when compiling NxDI model. "
207+ "When loading NxDI model, sharded checkpoints will be loaded from the compiled model path." ,
208+ )
209+ run_parser .add_argument (
210+ "--skip-sharding" ,
211+ action = "store_true" ,
212+ help = "Skip sharding checkpoints when compiling NxDI model. "
197213 )
198- run_parser .add_argument ("--skip-sharding" , action = "store_true" )
199214
200215 # PA and CF
201216 run_parser .add_argument (
@@ -206,6 +221,9 @@ def setup_run_parser(run_parser: argparse.ArgumentParser):
206221 run_parser .add_argument (
207222 "--enable-chunked-prefill" , dest = "is_chunked_prefill" , action = "store_true"
208223 )
224+ run_parser .add_argument (
225+ "--enable-prefix-caching" , dest = "is_prefix_caching" , action = "store_true"
226+ )
209227 run_parser .add_argument ("--cp-max-num-seqs" , type = int )
210228 run_parser .add_argument ("--cp-num-active-blocks" , type = int )
211229
@@ -214,8 +232,8 @@ def setup_run_parser(run_parser: argparse.ArgumentParser):
214232
215233 # Lora
216234 run_parser .add_argument ("--enable-lora" , action = "store_true" )
217- run_parser .add_argument ("--max-loras" , type = int )
218- run_parser .add_argument ("--max-lora-rank" , type = int )
235+ run_parser .add_argument ("--max-loras" , type = int , default = 1 )
236+ run_parser .add_argument ("--max-lora-rank" , type = int , default = 16 )
219237 run_parser .add_argument ("--target-modules" , nargs = "+" )
220238 run_parser .add_argument ("--max-loras-on-cpu" , type = int )
221239 run_parser .add_argument ("--lora-ckpt-path" , dest = "lora_ckpt_paths" , type = str , action = "append" )
@@ -227,10 +245,21 @@ def setup_run_parser(run_parser: argparse.ArgumentParser):
227245 run_parser .add_argument ("--attn-kernel-enabled" , action = "store_true" )
228246 run_parser .add_argument ("--mlp-kernel-enabled" , action = "store_true" )
229247 run_parser .add_argument ("--quantized-mlp-kernel-enabled" , action = "store_true" )
230- run_parser .add_argument ("--activation-quantization-type" , type = str , choices = [e .value for e in ActivationQuantizationType ])
248+ run_parser .add_argument ("--fused-rmsnorm-skip-gamma" , action = "store_true" )
249+ run_parser .add_argument (
250+ "--activation-quantization-type" ,
251+ type = str ,
252+ choices = [e .value for e in ActivationQuantizationType ],
253+ )
231254 run_parser .add_argument ("--rmsnorm-quantize-kernel-enabled" , action = "store_true" )
232- run_parser .add_argument ("--quantize-clamp-bound" , type = float , default = float (' inf' ))
255+ run_parser .add_argument ("--quantize-clamp-bound" , type = float , default = float (" inf" ))
233256 run_parser .add_argument ("--mlp-kernel-fuse-residual-add" , action = "store_true" )
257+ run_parser .add_argument ("--qkv-kernel-fuse-residual-add" , action = "store_true" )
258+ run_parser .add_argument ("--attn-tkg-nki-kernel-enabled" , action = "store_true" )
259+ run_parser .add_argument ("--attn-tkg-builtin-kernel-enabled" , action = "store_true" )
260+ run_parser .add_argument ("--attn-block-tkg-nki-kernel-enabled" , action = "store_true" )
261+ run_parser .add_argument ("--attn-block-tkg-nki-kernel-cache-update" , action = "store_true" )
262+ run_parser .add_argument ("--k-cache-transposed" , action = "store_true" )
234263
235264 # Logical NeuronCore Configuration (LNC)
236265 lnc_group = run_parser .add_mutually_exclusive_group ()
@@ -246,7 +275,12 @@ def setup_run_parser(run_parser: argparse.ArgumentParser):
246275 run_parser .add_argument ("--on-cpu" , action = "store_true" )
247276
248277 # Debugging
249- run_parser .add_argument ("--capture-indices" , nargs = "+" , type = int , default = None )
278+ run_parser .add_argument (
279+ "--capture-indices" ,
280+ nargs = "+" ,
281+ action = argparse_utils .StringOrIntegers ,
282+ default = None ,
283+ help = f"Specify '{ argparse_utils .AUTO } ' when using check accuracy mode with { CheckAccuracyMode .LOGIT_MATCHING } for inferrring capture indices when the test fails and use the indices to capture inputs. Otherwise, provide any number of integer values for capturing inputs at those indices." )
250284 run_parser .add_argument ("--input-capture-save-dir" , type = str , default = None )
251285
252286 # Optional demo arguments
@@ -267,6 +301,11 @@ def setup_run_parser(run_parser: argparse.ArgumentParser):
267301 action = "store_true" ,
268302 help = "Only perform model compilation." ,
269303 )
304+ run_parser .add_argument (
305+ "--compile-dry-run" ,
306+ action = "store_true" ,
307+ help = "Perform a compilation dry run (minimal model trace)" ,
308+ )
270309 run_parser .add_argument (
271310 "--hlo-debug" ,
272311 action = "store_true" ,
@@ -385,10 +424,12 @@ def run_inference(model_cls: Type[NeuronApplicationBase], args):
385424 compiling_start_time = time .monotonic ()
386425 if not args .skip_compile and not args .on_cpu :
387426 print ("\n Compiling and saving model..." )
388- model .compile (args .compiled_model_path , debug = args .hlo_debug )
427+ model .compile (args .compiled_model_path , debug = args .hlo_debug , dry_run = args . compile_dry_run )
389428 if draft_model is not None and neuron_config .enable_fused_speculation is False :
390429 print ("\n Compiling and saving draft model..." )
391- draft_model .compile (args .compiled_draft_model_path )
430+ draft_model .compile (
431+ args .compiled_draft_model_path , debug = args .hlo_debug , dry_run = args .compile_dry_run
432+ )
392433 compiling_end_time = time .monotonic ()
393434 total_compiling_time = compiling_end_time - compiling_start_time
394435 print (f"Compiling and tracing time: { total_compiling_time } seconds" )
@@ -398,7 +439,7 @@ def run_inference(model_cls: Type[NeuronApplicationBase], args):
398439 if args .enable_torch_dist :
399440 torch .distributed .barrier ()
400441
401- if args .compile_only :
442+ if args .compile_only or args . compile_dry_run :
402443 return
403444
404445 # Load compiled model to Neuron.
@@ -446,25 +487,37 @@ def run_inference(model_cls: Type[NeuronApplicationBase], args):
446487 if neuron_config .is_medusa :
447488 draft_model = model
448489
490+ input_capture_hook = None
491+ capture_indices = args .capture_indices
492+
449493 # Check accuracy.
450- run_accuracy_check (
451- model ,
452- tokenizer ,
453- generation_config ,
454- args .prompts [0 ],
455- args .check_accuracy_mode ,
456- args .divergence_difference_tol ,
457- args .tol_map ,
458- num_tokens_to_check = args .num_tokens_to_check ,
459- draft_model = draft_model ,
460- expected_outputs_path = args .expected_outputs_path ,
461- )
494+ logit_error = None
495+ try :
496+ run_accuracy_check (
497+ model ,
498+ tokenizer ,
499+ generation_config ,
500+ args .prompts [0 ],
501+ args .check_accuracy_mode ,
502+ args .divergence_difference_tol ,
503+ args .tol_map ,
504+ num_tokens_to_check = args .num_tokens_to_check ,
505+ draft_model = draft_model ,
506+ expected_outputs_path = args .expected_outputs_path ,
507+ )
508+ except LogitMatchingValidationError as e :
509+ logit_error = e
510+ if args .capture_indices == argparse_utils .AUTO :
511+ capture_indices = logit_error .get_divergence_index ()
512+ print (f"\n Auto capture after a failed logits test. Setting capture indices to { capture_indices } " )
462513
463- input_capture_hook = None
464- if args .capture_indices :
514+ if args .capture_indices == argparse_utils .AUTO and logit_error is None :
515+ capture_indices = None
516+
517+ if capture_indices is not None :
465518 input_capture_hook = partial (
466519 capture_model_inputs ,
467- capture_indices = args . capture_indices ,
520+ capture_indices = capture_indices ,
468521 input_capture_save_dir = args .input_capture_save_dir ,
469522 )
470523
@@ -479,6 +532,9 @@ def run_inference(model_cls: Type[NeuronApplicationBase], args):
479532 input_capture_hook = input_capture_hook ,
480533 )
481534
535+ if logit_error is not None :
536+ raise logit_error
537+
482538 # Benchmarking.
483539 if args .benchmark :
484540 benchmark_sampling (model , draft_model , generation_config )
0 commit comments