Conversation
@NuojCheng NuojCheng commented Aug 5, 2025

Description

Migrates the Mistral model to NNX. To be merged after #2178 is merged.

Tests

Training tests

Command

python -m MaxText.train MaxText/configs/base.yml run_name=nc_test_mistral_$RANDOM steps=5 base_output_directory=gs://chengnuojin-maxtext-logs dataset_path=gs://chengnuojin-maxtext-dataset model_name=mistral-7b enable_checkpointing=True per_device_batch_size=1 load_parameters_path=gs://maxtext-model-checkpoints/mistral-7b/2025-01-23-19-04/scanned/0/items remat_policy='full' opt_type=sgd

Memory

# Before
Total memory size: 18.9 GB, Output size: 6.7 GB, Temp size: 12.1 GB, Argument size: 6.7 GB, Host temp size: 0.0 GB.
# After
Total memory size: 18.9 GB, Output size: 6.7 GB, Temp size: 12.1 GB, Argument size: 6.7 GB, Host temp size: 0.0 GB.

Inference tests

Command

python3 -m MaxText.decode MaxText/configs/base.yml tokenizer_path=assets/tokenizer.mistral-v1 max_target_length=16 prompt='I love to' attention=dot_product megablox=False sparse_matmul=False load_parameters_path=gs://chengnuojin-maxtext-logs/chengnuojin_decode_32458/checkpoints/0/items per_device_batch_size=1 run_name=chengnuojin_decode_$RANDOM max_prefill_predict_length=4 async_checkpointing=false scan_layers=false model_name=mistral-7b

Before

Memstats: After load_params:
        Using (GB) 6.75 / 30.75 (21.951220%) on TPU_0(process=0,(0,0,0,0))
        Using (GB) 6.75 / 30.75 (21.951220%) on TPU_1(process=0,(1,0,0,0))
        Using (GB) 6.75 / 30.75 (21.951220%) on TPU_2(process=0,(0,1,0,0))
        Using (GB) 6.75 / 30.75 (21.951220%) on TPU_3(process=0,(1,1,0,0))

RAMstats: After load_params:
        Using (GB) 41.95 / 400.47 (10.475192%) -->  Available:355.88
Input `I love to` -> `read. I love to read about all kinds of things. I`

After

Memstats: After load_params:
        Using (GB) 6.75 / 30.75 (21.951220%) on TPU_0(process=0,(0,0,0,0))
        Using (GB) 6.75 / 30.75 (21.951220%) on TPU_1(process=0,(1,0,0,0))
        Using (GB) 6.75 / 30.75 (21.951220%) on TPU_2(process=0,(0,1,0,0))
        Using (GB) 6.75 / 30.75 (21.951220%) on TPU_3(process=0,(1,1,0,0))

RAMstats: After load_params:
        Using (GB) 37.16 / 400.47 (9.279097%) -->  Available:360.67
Input `I love to` -> `read. I love to read about all kinds of things. I`

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed.

@NuojCheng force-pushed the chengnuojin-mistral-migration branch 4 times, most recently from b2bfd1d to 65c58f2 on August 6, 2025 01:08
@NuojCheng changed the title from "NNX Migration for Mistral models" to "[Draft] NNX Migration for Mistral models" on Aug 6, 2025
@NuojCheng changed the title from "[Draft] NNX Migration for Mistral models" to "[Draft, NO MERGE] NNX Migration for Mistral models" on Aug 6, 2025
@NuojCheng force-pushed the chengnuojin-mistral-migration branch 13 times, most recently from 0591ca3 to fa13350 on August 13, 2025 00:42
@NuojCheng force-pushed the chengnuojin-mistral-migration branch 10 times, most recently from 3675744 to 5973fdb on August 18, 2025 20:53
@NuojCheng changed the title from "[Draft, NO MERGE] NNX Migration for Mistral models" to "NNX Migration for Mistral models" on Aug 18, 2025
@NuojCheng force-pushed the chengnuojin-mistral-migration branch from 5973fdb to b69edeb on August 22, 2025 20:46