 from neuronx_distributed_inference.models.config import (
     FusedSpecNeuronConfig,
     OnDeviceSamplingConfig,
+    ChunkedPrefillConfig,
     to_torch_dtype,
 )
 from neuronx_distributed_inference.models.dbrx.modeling_dbrx import NeuronDbrxForCausalLM
@@ -38,6 +39,7 @@
 from neuronx_distributed_inference.utils.exceptions import LogitMatchingValidationError
 from neuronx_distributed_inference.utils.hf_adapter import load_pretrained_config
 from neuronx_distributed_inference.utils.random import set_random_seed
+from neuronx_distributed_inference.utils.constants import BENCHMARK_REPORT_PATH

 set_random_seed(0)

@@ -120,6 +122,7 @@ def setup_run_parser(run_parser: argparse.ArgumentParser):
     run_parser.add_argument("--max-new-tokens", type=int)
     run_parser.add_argument("--max-length", type=int)
     run_parser.add_argument("--rpl-reduce-dtype", type=to_torch_dtype)
+    run_parser.add_argument("--attention-dtype", type=to_torch_dtype)
     run_parser.add_argument("--output-logits", action="store_true")
     run_parser.add_argument("--vocab-parallel", action="store_true")
     run_parser.add_argument("--layer-boundary-markers", action="store_true", default=False)
@@ -148,6 +151,7 @@ def setup_run_parser(run_parser: argparse.ArgumentParser):
     run_parser.add_argument("--enable-bucketing", action="store_true")
     run_parser.add_argument("--bucket-n-active-tokens", action="store_true")
     run_parser.add_argument("--context-encoding-buckets", nargs="+", type=int)
+    run_parser.add_argument("--prefix-buckets", nargs="+", type=int)
     run_parser.add_argument("--token-generation-buckets", nargs="+", type=int)

     # Quantization
@@ -166,6 +170,13 @@ def setup_run_parser(run_parser: argparse.ArgumentParser):

     # MoE
     run_parser.add_argument("--capacity-factor", type=float)
+    run_parser.add_argument("--early-expert-affinity-modulation", action="store_true")
+    run_parser.add_argument("--disable-normalize-top-k-affinities", action="store_true")
+    run_parser.add_argument("--fused-shared-experts", action="store_true")
+
+    # Router Config
+    run_parser.add_argument("--router-act-fn", type=str)
+    run_parser.add_argument("--router-dtype", type=str)

     # Speculative decoding
     run_parser.add_argument("--draft-model-path", type=str)
@@ -189,6 +200,7 @@ def setup_run_parser(run_parser: argparse.ArgumentParser):

     # Parallelism
     run_parser.add_argument("--tp-degree", type=int)
+    run_parser.add_argument("--cp-degree", type=int)
     run_parser.add_argument("--pp-degree", type=int)
     run_parser.add_argument("--ep-degree", type=int)
     run_parser.add_argument("--world-size", type=int)
@@ -224,8 +236,7 @@ def setup_run_parser(run_parser: argparse.ArgumentParser):
     run_parser.add_argument(
         "--enable-prefix-caching", dest="is_prefix_caching", action="store_true"
     )
-    run_parser.add_argument("--cp-max-num-seqs", type=int)
-    run_parser.add_argument("--cp-num-active-blocks", type=int)
+    run_parser.add_argument("--max-num-seqs", type=int)

     # Async
     run_parser.add_argument("--async-mode", action="store_true")
@@ -242,7 +253,7 @@ def setup_run_parser(run_parser: argparse.ArgumentParser):
     # Kernels
     run_parser.add_argument("--qkv-kernel-enabled", action="store_true")
     run_parser.add_argument("--qkv-kernel-nbsd-layout", action="store_true")
-    run_parser.add_argument("--attn-kernel-enabled", action="store_true")
+    run_parser.add_argument("--attn-kernel-enabled", action=argparse.BooleanOptionalAction, default=None)
     run_parser.add_argument("--mlp-kernel-enabled", action="store_true")
     run_parser.add_argument("--quantized-mlp-kernel-enabled", action="store_true")
     run_parser.add_argument("--fused-rmsnorm-skip-gamma", action="store_true")
@@ -270,10 +281,19 @@ def setup_run_parser(run_parser: argparse.ArgumentParser):

     # Compiler Args
     run_parser.add_argument("--cc-pipeline-tiling-factor", type=int, default=2)
+    run_parser.add_argument("--enable-spill-reload-dge", action="store_true")

     # CPU
     run_parser.add_argument("--on-cpu", action="store_true")

+    # Report generation
+    run_parser.add_argument(
+        "--benchmark-report-path",
+        type=str,
+        default=BENCHMARK_REPORT_PATH,
+        help="File path to save the benchmark report."
+    )
+
     # Debugging
     run_parser.add_argument(
         "--capture-indices",
@@ -283,6 +303,10 @@ def setup_run_parser(run_parser: argparse.ArgumentParser):
         help=f"Specify '{argparse_utils.AUTO}' when using check accuracy mode with {CheckAccuracyMode.LOGIT_MATCHING} to infer capture indices when the test fails and use those indices to capture inputs. Otherwise, provide any number of integer values for capturing inputs at those indices.")
     run_parser.add_argument("--input-capture-save-dir", type=str, default=None)

+    run_parser.add_argument("--cast-type", choices=["config", "as-declared"], default="config",
+                            help="If set to 'config', all parameters will be cast to neuron_config.torch_dtype. "
+                            "If set to 'as-declared', casting will be done based on the dtype set for each parameter.")
+
     # Optional demo arguments
     run_parser.add_argument(
         "--skip-warmup",
@@ -312,6 +336,20 @@ def setup_run_parser(run_parser: argparse.ArgumentParser):
         help="Adds metadata into the generated HLO. This metadata maps the HLO "
         "operators to the corresponding lines in the PyTorch code",
     )
+    run_parser.add_argument(
+        "--apply-seq-ids-mask",
+        action="store_true",
+        help="Avoid KV cache updates for inactive (padded) seq_ids",
+    )
+    run_parser.add_argument(
+        "--input-start-offsets",
+        nargs="+",
+        default=None,
+        type=int,
+        help="Shift the input right by an offset. Multiple offsets can be given, one per sequence. "
+        "If only one value is provided, all sequences are shifted by that amount. "
+        "This flag can be used to test chunked attention.",
+    )


 def validate_file_exists(path):
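Note: taken together, the parser changes above add context-parallelism (--cp-degree), prefix buckets, MoE/router options, --max-num-seqs, input-offset handling, and benchmark-report plumbing to the run parser. A hypothetical invocation exercising a few of the new flags is sketched below; the inference_demo entry point, the run subcommand, and pre-existing flags such as --model-path and --benchmark are assumptions based on the surrounding script, and all paths are placeholders:

    inference_demo <model-type args...> run \
        --model-path /path/to/model \
        --tp-degree 32 \
        --cp-degree 4 \
        --prefix-buckets 1024 2048 \
        --max-num-seqs 8 \
        --input-start-offsets 128 \
        --benchmark \
        --benchmark-report-path /tmp/benchmark_report.json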
@@ -339,7 +377,7 @@ def get_modules_to_not_convert_json(json_path):
     return modules_to_not_convert, draft_model_modules_to_not_convert


-def run_inference(model_cls: Type[NeuronApplicationBase], args):
+def create_neuron_config(model_cls, args):
     # Initialize configs.
     print("Loading configs...")

@@ -348,6 +386,11 @@ def run_inference(model_cls: Type[NeuronApplicationBase], args):
     config_kwargs = {k: v for k, v in config_kwargs.items() if v is not None}
     if args.on_device_sampling:
         config_kwargs["on_device_sampling_config"] = OnDeviceSamplingConfig(**config_kwargs)
+    if args.is_chunked_prefill:
+        max_num_seqs = config_kwargs.pop("max_num_seqs", 0)
+        config_kwargs["chunked_prefill_config"] = ChunkedPrefillConfig(
+            max_num_seqs=max_num_seqs,
+        )

     if (args.quantized and args.quantization_dtype == "f8e4m3") or args.kv_cache_quant:
         os.environ["XLA_HANDLE_SPECIAL_SCALAR"] = "1"
@@ -371,6 +414,11 @@ def run_inference(model_cls: Type[NeuronApplicationBase], args):
     )
     adapter_ids = args.adapter_ids
     neuron_config = model_cls.get_neuron_config_cls()(**config_kwargs)
+    return adapter_ids, neuron_config
+
+
+def run_inference(model_cls: Type[NeuronApplicationBase], args):
+    adapter_ids, neuron_config = create_neuron_config(model_cls, args)

     config = model_cls.get_config_cls()(
         neuron_config, load_config=load_pretrained_config(args.model_path)
@@ -395,7 +443,6 @@ def run_inference(model_cls: Type[NeuronApplicationBase], args):
     # Set eagle specific config changes
     if neuron_config.enable_eagle_speculation:
         draft_neuron_config.is_eagle_draft = True
-        draft_neuron_config.sequence_parallel_enabled = False

     if args.draft_model_tp_degree is not None:
         draft_neuron_config.tp_degree = args.draft_model_tp_degree
@@ -415,6 +462,8 @@ def run_inference(model_cls: Type[NeuronApplicationBase], args):
         draft_model = model_cls(args.draft_model_path, draft_config)

     model = model_cls(args.model_path, config)
+    if args.input_start_offsets:
+        assert len(args.input_start_offsets) == 1 or len(args.input_start_offsets) == args.batch_size, "The number of input offsets has to be either 1 or equal to the batch size."

     # Quantize model.
     if neuron_config.quantized:
@@ -481,7 +530,10 @@ def run_inference(model_cls: Type[NeuronApplicationBase], args):
     generation_config_kwargs = {
         k: getattr(args, k) for k in generation_config_args if getattr(args, k) is not None
     }
-    generation_config.update(**generation_config_kwargs)
+    remaining_kwargs = generation_config.update(**generation_config_kwargs)
+    # Add any remaining ones (this can happen when the model's generation config is missing some entries)
+    for k, v in remaining_kwargs.items():
+        generation_config.__dict__[k] = v

     # With Medusa, the model is also the draft model.
     if neuron_config.is_medusa:
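Note: the replacement code assumes transformers' GenerationConfig.update() returns the keyword arguments it could not match to existing attributes, which are then attached directly through __dict__ (useful when a model ships a generation config that lacks some entries). A minimal sketch of that behavior, using an illustrative made-up key:

    from transformers import GenerationConfig

    generation_config = GenerationConfig()
    # top_k is a known attribute; some_new_knob is hypothetical and unknown to the config.
    remaining_kwargs = generation_config.update(top_k=5, some_new_knob=123)
    assert generation_config.top_k == 5
    assert remaining_kwargs == {"some_new_knob": 123}
    generation_config.__dict__["some_new_knob"] = remaining_kwargs["some_new_knob"]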
@@ -504,6 +556,7 @@ def run_inference(model_cls: Type[NeuronApplicationBase], args):
             num_tokens_to_check=args.num_tokens_to_check,
             draft_model=draft_model,
             expected_outputs_path=args.expected_outputs_path,
+            input_start_offsets=args.input_start_offsets,
         )
     except LogitMatchingValidationError as e:
         logit_error = e
@@ -530,14 +583,15 @@ def run_inference(model_cls: Type[NeuronApplicationBase], args):
         draft_model=draft_model,
         adapter_ids=adapter_ids,
         input_capture_hook=input_capture_hook,
+        input_start_offsets=args.input_start_offsets,
     )

     if logit_error is not None:
         raise logit_error

     # Benchmarking.
     if args.benchmark:
-        benchmark_sampling(model, draft_model, generation_config)
+        benchmark_sampling(model, draft_model, generation_config, benchmark_report_path=args.benchmark_report_path)


 def load_tokenizer(model_path, compiled_model_path, neuron_config):
@@ -555,9 +609,12 @@ def run_generation(
     draft_model=None,
     adapter_ids=None,
     input_capture_hook=None,
+    input_start_offsets=None,
 ):
     print("\nGenerating outputs...")
     print(f"Prompts: {prompts}")
+    if len(prompts) == 1 and model.config.neuron_config.batch_size > 1:
+        prompts = prompts * model.config.neuron_config.batch_size

     _, output_tokens = get_generate_outputs(
         model,
@@ -569,6 +626,7 @@ def run_generation(
         adapter_ids=adapter_ids,
         max_length=model.neuron_config.max_length,
         input_capture_hook=input_capture_hook,
+        input_start_offsets=input_start_offsets
     )

     print("Generated outputs:")
@@ -587,13 +645,15 @@ def run_accuracy_check(
     num_tokens_to_check=None,
     draft_model=None,
     expected_outputs_path=None,
+    input_start_offsets=None,
 ):
     if model.neuron_config.is_medusa:
         # Medusa doesn't use greedy sampling, so check accuracy doesn't work.
         assert (
             check_accuracy_mode == CheckAccuracyMode.SKIP_ACCURACY_CHECK
         ), "Accuracy checking not supported for Medusa"
-
+    if input_start_offsets:
+        assert all(offset < model.config.neuron_config.max_context_length for offset in input_start_offsets), "Input offsets have to be less than the max context length"
     if check_accuracy_mode == CheckAccuracyMode.SKIP_ACCURACY_CHECK:
         print("\nSkipping accuracy check")
         return
@@ -612,6 +672,7 @@ def run_accuracy_check(
             draft_model=draft_model,
             expected_token_ids=expected_outputs,
             num_tokens_to_check=num_tokens_to_check,
+            input_start_offsets=input_start_offsets,
         )
     elif check_accuracy_mode == CheckAccuracyMode.LOGIT_MATCHING:
         assert draft_model is None, "Logit matching not supported for speculation"
@@ -633,6 +694,7 @@ def run_accuracy_check(
             divergence_difference_tol=divergence_difference_tol,
             tol_map=tol_map,
             num_tokens_to_check=num_tokens_to_check,
+            input_start_offsets=input_start_offsets,
         )
     else:
         raise ValueError(f"Unsupported check accuracy mode: {check_accuracy_mode}")