diff --git a/examples/llama3-70b.sh b/examples/llama3-70b.sh
new file mode 100644
index 0000000..8b6de9c
--- /dev/null
+++ b/examples/llama3-70b.sh
@@ -0,0 +1,15 @@
+# ========= seq_len=8192 mbs=1 python-fsdp=32 =========
+LOW_CPU_MEM_USAGE=1 PJRT_ALLOCATOR_FRACTION=0.95 XLA_FLAGS="--xla_disable_hlo_passes=gpu-convert-async-collectives-to-sync,triton-autotuner --xla_gpu_memory_limit_slop_factor=100 --xla_multiheap_size_constraint_per_heap=4294967296" \
+./examples/run.sh --model ./hf_models/config/llama-3-70b --accelerator acc --mbs 1 --fsdp 32 --max_seq_length 8192 --gc --gc_cnt 80
+
+# ========= seq_len=8192 mbs=1 spmd-fsdp=32 =========
+LOW_CPU_MEM_USAGE=1 PJRT_ALLOCATOR_FRACTION=0.95 XLA_FLAGS="--xla_disable_hlo_passes=gpu-convert-async-collectives-to-sync,triton-autotuner --xla_gpu_memory_limit_slop_factor=100 --xla_multiheap_size_constraint_per_heap=4294967296" \
+XLA_USE_SPMD=1 ./examples/run.sh --model ./hf_models/config/llama-3-70b --accelerator acc --mbs 1 --fsdp 32 --spmd_fsdp --max_seq_length 8192 --gc --gc_cnt 80
+
+# ========= seq_len=8192 mbs=2 python-fsdp=32 =========
+LOW_CPU_MEM_USAGE=1 PJRT_ALLOCATOR_FRACTION=0.90 XLA_FLAGS="--xla_gpu_memory_limit_slop_factor=100 --xla_multiheap_size_constraint_per_heap=8589934592" \
+./examples/run.sh --model ./hf_models/config/llama-3-70b --accelerator acc --mbs 2 --fsdp 32 --max_seq_length 8192 --gc --gc_cnt 80
+
+# ========= seq_len=8192 mbs=2 spmd-fsdp=32 =========
+LOW_CPU_MEM_USAGE=1 PJRT_ALLOCATOR_FRACTION=0.95 XLA_FLAGS="--xla_disable_hlo_passes=gpu-convert-async-collectives-to-sync,triton-autotuner --xla_gpu_memory_limit_slop_factor=100 --xla_multiheap_size_constraint_per_heap=4294967296" \
+XLA_USE_SPMD=1 ./examples/run.sh --model ./hf_models/config/llama-3-70b --accelerator acc --mbs 2 --fsdp 32 --spmd_fsdp --max_seq_length 8192 --gc --gc_cnt 80 # OPTIMAL
\ No newline at end of file
diff --git a/examples/llama3-8b.sh b/examples/llama3-8b.sh
new file mode 100644
index 0000000..df6425d
--- /dev/null
+++ b/examples/llama3-8b.sh
@@ -0,0 +1,17 @@
+# ========= seq_len=2048 mbs=1 python-fsdp=8 =========
+./examples/run.sh --model ./hf_models/config/llama-3-8b --accelerator acc --mbs 1 --fsdp 8 --max_seq_length 2048
+
+# ========= seq_len=8192 mbs=1 spmd-fsdp=8 =========
+XLA_FLAGS="--xla_dump_hlo_as_text --xla_dump_to=./hlo/llama-3-8b-bs1-spmd-2409041135 --xla_disable_hlo_passes=gpu-convert-async-collectives-to-sync,triton-autotuner" \
+XLA_USE_SPMD=1 ./examples/run.sh --model ./hf_models/config/llama-3-8b --accelerator acc --mbs 1 --fsdp 8 --spmd_fsdp --max_seq_length 8192
+
+# ========= seq_len=8192 mbs=2 spmd-fsdp=8 profile =========
+XLA_FLAGS="--xla_dump_hlo_as_text --xla_dump_to=./hlo/llama-3-8b-bs2-best --xla_disable_hlo_passes=gpu-convert-async-collectives-to-sync,triton-autotuner --xla_gpu_memory_limit_slop_factor=95" \
+XLA_USE_SPMD=1 ./examples/run.sh --model ./hf_models/config/llama-3-8b --accelerator acc --mbs 2 --fsdp 8 --spmd_fsdp --max_seq_length 8192 --gc --gc_cnt 1 --profile
+
+# ========= seq_len=8192 mbs=2 spmd-fsdp=8 =========
+XLA_FLAGS="--xla_disable_hlo_passes=gpu-convert-async-collectives-to-sync,triton-autotuner --xla_gpu_memory_limit_slop_factor=97" \
+XLA_USE_SPMD=1 ./examples/run.sh --model ./hf_models/config/llama-3-8b --accelerator acc --mbs 2 --fsdp 8 --spmd_fsdp --max_seq_length 8192 --gc --gc_cnt 1 # OPTIMAL
+
+# ========= seq_len=8192 mbs=2 python-fsdp=8 =========
+./examples/run.sh --model ./hf_models/config/llama-3-8b --accelerator acc --mbs 2 --fsdp 8 --max_seq_length 8192 --gc --gc_cnt 9
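
Note on the recurring knobs: PJRT_ALLOCATOR_FRACTION caps the fraction of GPU memory the PJRT allocator may claim, XLA_USE_SPMD=1 switches torch_xla into SPMD mode (the --spmd_fsdp runs), --xla_gpu_memory_limit_slop_factor and --xla_multiheap_size_constraint_per_heap tune XLA's GPU memory budgeting, and --xla_disable_hlo_passes skips the named HLO passes (here the async-collective downgrade and the Triton autotuner). The snippet below is a minimal sketch of sweeping two of these knobs over the 8B SPMD recipe, assuming the ./examples/run.sh interface used in this patch; the ./sweeps log directory and the swept slop-factor values are illustrative assumptions, not part of the patch.

#!/usr/bin/env bash
# Minimal sketch: sweep micro-batch size and XLA slop factor over the
# llama-3-8b SPMD recipe above. Assumes ./examples/run.sh behaves as in
# this patch; ./sweeps is a hypothetical log directory.
set -euo pipefail

DISABLE_PASSES="gpu-convert-async-collectives-to-sync,triton-autotuner"
mkdir -p ./sweeps

for mbs in 1 2; do
  for slop in 95 97 100; do
    tag="llama-3-8b-mbs${mbs}-slop${slop}"
    XLA_FLAGS="--xla_disable_hlo_passes=${DISABLE_PASSES} --xla_gpu_memory_limit_slop_factor=${slop}" \
    XLA_USE_SPMD=1 ./examples/run.sh \
      --model ./hf_models/config/llama-3-8b --accelerator acc \
      --mbs "${mbs}" --fsdp 8 --spmd_fsdp --max_seq_length 8192 \
      --gc --gc_cnt 1 2>&1 | tee "./sweeps/${tag}.log"
  done
done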