Commit 2c46add

Merge pull request #20 from aws-neuron/release_2.24.0
Neuron Release 2.24.0
2 parents: 9b90cd0 + 98c0fea

File tree: 99 files changed, +15044 additions, -1914 deletions


build.sh

Lines changed: 5 additions & 5 deletions
@@ -3,22 +3,22 @@ set -e
 
 : ${BUILD_PATH:=build}
 
-python3.10 -m pip install ruff
+python -m pip install ruff
 # remove --exit-zero once all errors are fixed/explicitly ignore
-python3.10 -m ruff check --line-length=120 --ignore=F401,E203
+python -m ruff check --line-length=120 --ignore=F401,E203
 # exit when asked to run `ruff` only
 if [[ "$1" == "ruff" ]]
 then
     exit 0
 fi
 
 # Run static code analysis
-python3.10 -m pip install mypy
-python3.10 -m mypy --no-incremental || true
+python -m pip install mypy
+python -m mypy --no-incremental || true
 # exit when asked to run `mypy` only
 if [[ "$1" == "mypy" ]]
 then
     exit 0
 fi
 
-python3.10 setup.py bdist_wheel --dist-dir ${BUILD_PATH}/pip/public/neuronx-distributed-inference
+python setup.py bdist_wheel --dist-dir ${BUILD_PATH}/pip/public/neuronx-distributed-inference
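
Usage implied by the unchanged guards above: `./build.sh ruff` exits after the lint step, `./build.sh mypy` exits after type checking, and a bare `./build.sh` continues on to build the wheel under `${BUILD_PATH}/pip/public/neuronx-distributed-inference` (with `BUILD_PATH` defaulting to `build`). Since the hard-coded `python3.10` is gone, the script now uses whichever interpreter `python` resolves to, so it should be run from an environment whose `python` is the intended version.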

examples/requirements.txt

Lines changed: 1 addition & 1 deletion
@@ -1,4 +1,4 @@
-transformers==4.48.*
+transformers==4.51.*
 huggingface-hub
 diffusers==0.32.0
 sentencepiece

setup.py

Lines changed: 2 additions & 3 deletions
@@ -39,16 +39,15 @@ def get_version(version_str):
     package_data={"": []},
     install_requires=[
         "neuronx_distributed",
-        "transformers==4.48.*",
+        "transformers==4.51.*",
         "huggingface-hub",
         "sentencepiece",
         "torchvision",
         "pillow",
         "blobfile",
     ],
     extras_require={
-        "test": ["pytest", "pytest-forked", "pytest-cov", "pytest-xdist", "accelerate", "diffusers==0.32.0"],
-        "flux": ["diffusers==0.32.0"],
+        "test": ["pytest", "pytest-forked", "pytest-cov", "pytest-xdist", "accelerate"],
     },
     python_requires=">=3.7",
     package_dir={"": "src"},
Lines changed: 1 addition & 1 deletion
@@ -1,3 +1,3 @@
 # Copyright Amazon Web Services and its Affiliates. All Rights Reserved.
 # ==============================================================================
-__version__ = "0.3.0"
+__version__ = "0.4.0"

src/neuronx_distributed_inference/inference_demo.py

Lines changed: 70 additions & 8 deletions
@@ -20,6 +20,7 @@
 from neuronx_distributed_inference.models.config import (
     FusedSpecNeuronConfig,
     OnDeviceSamplingConfig,
+    ChunkedPrefillConfig,
     to_torch_dtype,
 )
 from neuronx_distributed_inference.models.dbrx.modeling_dbrx import NeuronDbrxForCausalLM
@@ -38,6 +39,7 @@
 from neuronx_distributed_inference.utils.exceptions import LogitMatchingValidationError
 from neuronx_distributed_inference.utils.hf_adapter import load_pretrained_config
 from neuronx_distributed_inference.utils.random import set_random_seed
+from neuronx_distributed_inference.utils.constants import BENCHMARK_REPORT_PATH
 
 set_random_seed(0)
 
@@ -120,6 +122,7 @@ def setup_run_parser(run_parser: argparse.ArgumentParser):
     run_parser.add_argument("--max-new-tokens", type=int)
     run_parser.add_argument("--max-length", type=int)
     run_parser.add_argument("--rpl-reduce-dtype", type=to_torch_dtype)
+    run_parser.add_argument("--attention-dtype", type=to_torch_dtype)
     run_parser.add_argument("--output-logits", action="store_true")
     run_parser.add_argument("--vocab-parallel", action="store_true")
     run_parser.add_argument("--layer-boundary-markers", action="store_true", default=False)
@@ -148,6 +151,7 @@ def setup_run_parser(run_parser: argparse.ArgumentParser):
     run_parser.add_argument("--enable-bucketing", action="store_true")
     run_parser.add_argument("--bucket-n-active-tokens", action="store_true")
     run_parser.add_argument("--context-encoding-buckets", nargs="+", type=int)
+    run_parser.add_argument("--prefix-buckets", nargs="+", type=int)
     run_parser.add_argument("--token-generation-buckets", nargs="+", type=int)
 
     # Quantization
@@ -166,6 +170,13 @@ def setup_run_parser(run_parser: argparse.ArgumentParser):
 
     # MoE
     run_parser.add_argument("--capacity-factor", type=float)
+    run_parser.add_argument("--early-expert-affinity-modulation", action="store_true")
+    run_parser.add_argument("--disable-normalize-top-k-affinities", action="store_true")
+    run_parser.add_argument("--fused-shared-experts", action="store_true")
+
+    # Router Config
+    run_parser.add_argument("--router-act-fn", type=str)
+    run_parser.add_argument("--router-dtype", type=str)
 
     # Speculative decoding
     run_parser.add_argument("--draft-model-path", type=str)
@@ -189,6 +200,7 @@ def setup_run_parser(run_parser: argparse.ArgumentParser):
 
     # Parallelism
     run_parser.add_argument("--tp-degree", type=int)
+    run_parser.add_argument("--cp-degree", type=int)
     run_parser.add_argument("--pp-degree", type=int)
     run_parser.add_argument("--ep-degree", type=int)
     run_parser.add_argument("--world-size", type=int)
@@ -224,8 +236,7 @@ def setup_run_parser(run_parser: argparse.ArgumentParser):
     run_parser.add_argument(
         "--enable-prefix-caching", dest="is_prefix_caching", action="store_true"
     )
-    run_parser.add_argument("--cp-max-num-seqs", type=int)
-    run_parser.add_argument("--cp-num-active-blocks", type=int)
+    run_parser.add_argument("--max-num-seqs", type=int)
 
     # Async
     run_parser.add_argument("--async-mode", action="store_true")
@@ -242,7 +253,7 @@ def setup_run_parser(run_parser: argparse.ArgumentParser):
     # Kernels
     run_parser.add_argument("--qkv-kernel-enabled", action="store_true")
    run_parser.add_argument("--qkv-kernel-nbsd-layout", action="store_true")
-    run_parser.add_argument("--attn-kernel-enabled", action="store_true")
+    run_parser.add_argument("--attn-kernel-enabled", action=argparse.BooleanOptionalAction, default=None)
     run_parser.add_argument("--mlp-kernel-enabled", action="store_true")
     run_parser.add_argument("--quantized-mlp-kernel-enabled", action="store_true")
     run_parser.add_argument("--fused-rmsnorm-skip-gamma", action="store_true")
@@ -270,10 +281,19 @@ def setup_run_parser(run_parser: argparse.ArgumentParser):
 
     # Compiler Args
     run_parser.add_argument("--cc-pipeline-tiling-factor", type=int, default=2)
+    run_parser.add_argument("--enable-spill-reload-dge", action="store_true")
 
     # CPU
     run_parser.add_argument("--on-cpu", action="store_true")
 
+    # Report generation
+    run_parser.add_argument(
+        "--benchmark-report-path",
+        type=str,
+        default=BENCHMARK_REPORT_PATH,
+        help="File path to save benchmark report."
+    )
+
     # Debugging
     run_parser.add_argument(
         "--capture-indices",
@@ -283,6 +303,10 @@ def setup_run_parser(run_parser: argparse.ArgumentParser):
         help=f"Specify '{argparse_utils.AUTO}' when using check accuracy mode with {CheckAccuracyMode.LOGIT_MATCHING} for inferring capture indices when the test fails and use the indices to capture inputs. Otherwise, provide any number of integer values for capturing inputs at those indices.")
     run_parser.add_argument("--input-capture-save-dir", type=str, default=None)
 
+    run_parser.add_argument("--cast-type", choices=["config", "as-declared"], default="config",
+                            help="If set to 'config', all parameters will be cast to neuron_config.torch_dtype. "
+                            "If set to 'as-declared', casting will be done based on the dtype set for each parameter")
+
     # Optional demo arguments
     run_parser.add_argument(
         "--skip-warmup",
@@ -312,6 +336,20 @@ def setup_run_parser(run_parser: argparse.ArgumentParser):
         help="Adds metadata into the generated HLO. This metadata maps the HLO "
         "operators to the corresponding lines in the PyTorch code",
     )
+    run_parser.add_argument(
+        "--apply-seq-ids-mask",
+        action='store_true',
+        help="Avoid KV cache update on inactive (padded) seq_ids"
+    )
+    run_parser.add_argument(
+        "--input-start-offsets",
+        nargs="+",
+        default=None,
+        type=int,
+        help="Shift the input right by an offset. There can be multiple offsets, one per sequence. "
+        "If only 1 value is provided, all sequences will be shifted by this amount. "
+        "This flag can be used to test chunked attention."
+    )
 
 
 def validate_file_exists(path):
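
Taken together with the batch-size assertion added later in `run_inference`, the new `--input-start-offsets` flag accepts either one offset, broadcast to every sequence (for example `--input-start-offsets 128`), or exactly one offset per sequence (for example `--input-start-offsets 0 64 128 256` with a batch size of 4); each offset must also stay below the model's max context length, as checked in `run_accuracy_check`. These example values are illustrative.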
@@ -339,7 +377,7 @@ def get_modules_to_not_convert_json(json_path):
     return modules_to_not_convert, draft_model_modules_to_not_convert
 
 
-def run_inference(model_cls: Type[NeuronApplicationBase], args):
+def create_neuron_config(model_cls, args):
     # Initialize configs.
     print("Loading configs...")
 
@@ -348,6 +386,11 @@ def run_inference(model_cls: Type[NeuronApplicationBase], args):
     config_kwargs = {k: v for k, v in config_kwargs.items() if v is not None}
     if args.on_device_sampling:
         config_kwargs["on_device_sampling_config"] = OnDeviceSamplingConfig(**config_kwargs)
+    if args.is_chunked_prefill:
+        max_num_seqs = config_kwargs.pop("max_num_seqs", 0)
+        config_kwargs["chunked_prefill_config"] = ChunkedPrefillConfig(
+            max_num_seqs=max_num_seqs,
+        )
 
     if (args.quantized and args.quantization_dtype == "f8e4m3") or args.kv_cache_quant:
         os.environ["XLA_HANDLE_SPECIAL_SCALAR"] = "1"
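
Note the `dict.pop` with a default here: when `--max-num-seqs` was not passed, `max_num_seqs` falls back to 0, and in either case the key is removed from `config_kwargs` so it reaches the model config only through `ChunkedPrefillConfig`. In isolation:

config_kwargs = {"tp_degree": 32, "max_num_seqs": 8}
max_num_seqs = config_kwargs.pop("max_num_seqs", 0)   # 8; key removed from the dict
assert "max_num_seqs" not in config_kwargs

no_flag = {"tp_degree": 32}
assert no_flag.pop("max_num_seqs", 0) == 0            # default when the flag was absent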
@@ -371,6 +414,11 @@ def run_inference(model_cls: Type[NeuronApplicationBase], args):
     )
     adapter_ids = args.adapter_ids
     neuron_config = model_cls.get_neuron_config_cls()(**config_kwargs)
+    return adapter_ids, neuron_config
+
+
+def run_inference(model_cls: Type[NeuronApplicationBase], args):
+    adapter_ids, neuron_config = create_neuron_config(model_cls, args)
 
     config = model_cls.get_config_cls()(
         neuron_config, load_config=load_pretrained_config(args.model_path)
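
Splitting `create_neuron_config` out of `run_inference` makes config construction reusable on its own, e.g. to inspect the resolved NeuronConfig without compiling or loading a model. A minimal sketch of standalone use, reusing the `NeuronDbrxForCausalLM` class imported at the top of this file; the `--model-path` flag name and argument list are illustrative, not taken from this diff:

import argparse
from neuronx_distributed_inference.models.dbrx.modeling_dbrx import NeuronDbrxForCausalLM

run_parser = argparse.ArgumentParser()
setup_run_parser(run_parser)  # the parser built in this module

# Hypothetical invocation; real runs pass many more flags.
args = run_parser.parse_args(["--model-path", "/path/to/model", "--tp-degree", "32"])

# Build just the NeuronConfig, without compiling or loading model weights.
adapter_ids, neuron_config = create_neuron_config(NeuronDbrxForCausalLM, args)
print(neuron_config.tp_degree)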
@@ -395,7 +443,6 @@ def run_inference(model_cls: Type[NeuronApplicationBase], args):
     # Set eagle specific config changes
     if neuron_config.enable_eagle_speculation:
         draft_neuron_config.is_eagle_draft = True
-        draft_neuron_config.sequence_parallel_enabled = False
 
     if args.draft_model_tp_degree is not None:
         draft_neuron_config.tp_degree = args.draft_model_tp_degree
@@ -415,6 +462,8 @@ def run_inference(model_cls: Type[NeuronApplicationBase], args):
         draft_model = model_cls(args.draft_model_path, draft_config)
 
     model = model_cls(args.model_path, config)
+    if args.input_start_offsets:
+        assert len(args.input_start_offsets) == 1 or len(args.input_start_offsets) == args.batch_size, "The number of input offsets has to be either 1 or equal to batch size."
 
     # Quantize model.
     if neuron_config.quantized:
@@ -481,7 +530,10 @@ def run_inference(model_cls: Type[NeuronApplicationBase], args):
     generation_config_kwargs = {
         k: getattr(args, k) for k in generation_config_args if getattr(args, k) is not None
     }
-    generation_config.update(**generation_config_kwargs)
+    remaining_kwargs = generation_config.update(**generation_config_kwargs)
+    # add any remaining ones (this can happen when the model generation config is missing some entries)
+    for k, v in remaining_kwargs.items():
+        generation_config.__dict__[k] = v
 
     # With Medusa, the model is also the draft model.
     if neuron_config.is_medusa:
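
For context, Hugging Face's `GenerationConfig.update()` only assigns kwargs that already exist as attributes on the config and returns the rest as a dict, which is why the old call silently dropped entries missing from a model's generation config and the new loop writes them back. A small sketch (the `my_sampling_knob` key is hypothetical):

from transformers import GenerationConfig

generation_config = GenerationConfig()
remaining = generation_config.update(max_new_tokens=64, my_sampling_knob=3)

print(remaining)  # {'my_sampling_knob': 3}: not an existing attribute, so returned unused
# Backfill the leftovers, as run_inference now does:
for k, v in remaining.items():
    generation_config.__dict__[k] = v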
@@ -504,6 +556,7 @@ def run_inference(model_cls: Type[NeuronApplicationBase], args):
             num_tokens_to_check=args.num_tokens_to_check,
             draft_model=draft_model,
             expected_outputs_path=args.expected_outputs_path,
+            input_start_offsets=args.input_start_offsets,
         )
     except LogitMatchingValidationError as e:
         logit_error = e
@@ -530,14 +583,15 @@ def run_inference(model_cls: Type[NeuronApplicationBase], args):
         draft_model=draft_model,
         adapter_ids=adapter_ids,
         input_capture_hook=input_capture_hook,
+        input_start_offsets=args.input_start_offsets,
     )
 
     if logit_error is not None:
         raise logit_error
 
     # Benchmarking.
     if args.benchmark:
-        benchmark_sampling(model, draft_model, generation_config)
+        benchmark_sampling(model, draft_model, generation_config, benchmark_report_path=args.benchmark_report_path)
 
 
 def load_tokenizer(model_path, compiled_model_path, neuron_config):
@@ -555,9 +609,12 @@ def run_generation(
     draft_model=None,
     adapter_ids=None,
     input_capture_hook=None,
+    input_start_offsets=None,
 ):
     print("\nGenerating outputs...")
     print(f"Prompts: {prompts}")
+    if len(prompts) == 1 and model.config.neuron_config.batch_size > 1:
+        prompts = prompts * model.config.neuron_config.batch_size
 
     _, output_tokens = get_generate_outputs(
         model,
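
The new guard in `run_generation` broadcasts a single prompt across the compiled batch dimension. Its effect in isolation:

prompts = ["Hello, world"]
batch_size = 4  # stands in for model.config.neuron_config.batch_size

if len(prompts) == 1 and batch_size > 1:
    prompts = prompts * batch_size

print(prompts)  # ['Hello, world', 'Hello, world', 'Hello, world', 'Hello, world']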
@@ -569,6 +626,7 @@ def run_generation(
         adapter_ids=adapter_ids,
         max_length=model.neuron_config.max_length,
         input_capture_hook=input_capture_hook,
+        input_start_offsets=input_start_offsets
     )
 
     print("Generated outputs:")
@@ -587,13 +645,15 @@ def run_accuracy_check(
     num_tokens_to_check=None,
     draft_model=None,
     expected_outputs_path=None,
+    input_start_offsets=None,
 ):
     if model.neuron_config.is_medusa:
         # Medusa doesn't use greedy sampling, so check accuracy doesn't work.
         assert (
             check_accuracy_mode == CheckAccuracyMode.SKIP_ACCURACY_CHECK
         ), "Accuracy checking not supported for Medusa"
-
+    if input_start_offsets:
+        assert all(offset < model.config.neuron_config.max_context_length for offset in input_start_offsets), "Input offset has to be less than max context length"
     if check_accuracy_mode == CheckAccuracyMode.SKIP_ACCURACY_CHECK:
         print("\nSkipping accuracy check")
         return
@@ -612,6 +672,7 @@ def run_accuracy_check(
             draft_model=draft_model,
             expected_token_ids=expected_outputs,
             num_tokens_to_check=num_tokens_to_check,
+            input_start_offsets=input_start_offsets,
         )
     elif check_accuracy_mode == CheckAccuracyMode.LOGIT_MATCHING:
         assert draft_model is None, "Logit matching not supported for speculation"
@@ -633,6 +694,7 @@ def run_accuracy_check(
             divergence_difference_tol=divergence_difference_tol,
             tol_map=tol_map,
             num_tokens_to_check=num_tokens_to_check,
+            input_start_offsets=input_start_offsets,
         )
     else:
         raise ValueError(f"Unsupported check accuracy mode: {check_accuracy_mode}")
