Skip to content

RuntimeError: 'weight' must be 2-D #65

@hengzhi40

Description

@hengzhi40

[rank1]: Traceback (most recent call last):
[rank1]: File "/data/X-R1-main/src/x_r1/grpo.py", line 281, in <module>
[rank1]: main(script_args, training_args, model_args )
[rank1]: File "/data/X-R1-main/src/x_r1/grpo.py", line 245, in main
[rank1]: train_result = trainer.train(resume_from_checkpoint=checkpoint)
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/app/anaconda3/envs/grpo_env/lib/python3.11/site-packages/transformers/trainer.py", line 2241, in train
[rank1]: return inner_training_loop(
[rank1]: ^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/app/anaconda3/envs/grpo_env/lib/python3.11/site-packages/transformers/trainer.py", line 2548, in _inner_training_loop
[rank1]: tr_loss_step = self.training_step(model, inputs, num_items_in_batch)
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/app/anaconda3/envs/grpo_env/lib/python3.11/site-packages/transformers/trainer.py", line 3692, in training_step
[rank1]: inputs = self._prepare_inputs(inputs)
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/data/X-R1-main/src/x_r1/x_grpo_trainer.py", line 494, in _prepare_inputs
[rank1]: inputs = self._generate_and_score_completions(inputs)
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/data/X-R1-main/src/x_r1/x_grpo_trainer.py", line 605, in _generate_and_score_completions
[rank1]: old_per_token_logps = self._get_per_token_logps(
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/data/X-R1-main/src/x_r1/x_grpo_trainer.py", line 446, in _get_per_token_logps
[rank1]: logits = model(input_ids=input_ids, attention_mask=attention_mask, logits_to_keep=logits_to_keep + 1).logits
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/app/anaconda3/envs/grpo_env/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
[rank1]: return self._call_impl(*args, **kwargs)
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/app/anaconda3/envs/grpo_env/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1844, in _call_impl
[rank1]: return inner()
[rank1]: ^^^^^^^
[rank1]: File "/app/anaconda3/envs/grpo_env/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1790, in inner
[rank1]: result = forward_call(*args, **kwargs)
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/app/anaconda3/envs/grpo_env/lib/python3.11/site-packages/peft/peft_model.py", line 1719, in forward
[rank1]: return self.base_model(
[rank1]: ^^^^^^^^^^^^^^^^
[rank1]: File "/app/anaconda3/envs/grpo_env/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
[rank1]: return self._call_impl(*args, **kwargs)
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/app/anaconda3/envs/grpo_env/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1844, in _call_impl
[rank1]: return inner()
[rank1]: ^^^^^^^
[rank1]: File "/app/anaconda3/envs/grpo_env/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1790, in inner
[rank1]: result = forward_call(*args, **kwargs)
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/app/anaconda3/envs/grpo_env/lib/python3.11/site-packages/peft/tuners/tuners_utils.py", line 197, in forward
[rank1]: return self.model.forward(*args, **kwargs)
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/app/anaconda3/envs/grpo_env/lib/python3.11/site-packages/transformers/utils/deprecation.py", line 172, in wrapped_func
[rank1]: return func(*args, **kwargs)
[rank1]: ^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/app/anaconda3/envs/grpo_env/lib/python3.11/site-packages/transformers/models/qwen2/modeling_qwen2.py", line 856, in forward
[rank1]: outputs = self.model(
[rank1]: ^^^^^^^^^^^
[rank1]: File "/app/anaconda3/envs/grpo_env/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
[rank1]: return self._call_impl(*args, **kwargs)
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/app/anaconda3/envs/grpo_env/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1844, in _call_impl
[rank1]: return inner()
[rank1]: ^^^^^^^
[rank1]: File "/app/anaconda3/envs/grpo_env/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1790, in inner
[rank1]: result = forward_call(*args, **kwargs)
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/app/anaconda3/envs/grpo_env/lib/python3.11/site-packages/transformers/models/qwen2/modeling_qwen2.py", line 535, in forward
[rank1]: inputs_embeds = self.embed_tokens(input_ids)
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/app/anaconda3/envs/grpo_env/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
[rank1]: return self._call_impl(*args, **kwargs)
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/app/anaconda3/envs/grpo_env/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1844, in _call_impl
[rank1]: return inner()
[rank1]: ^^^^^^^
[rank1]: File "/app/anaconda3/envs/grpo_env/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1790, in inner
[rank1]: result = forward_call(*args, **kwargs)
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/app/anaconda3/envs/grpo_env/lib/python3.11/site-packages/peft/tuners/lora/layer.py", line 886, in forward
[rank1]: after_A = self._embed(x, embedding_A)
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/app/anaconda3/envs/grpo_env/lib/python3.11/site-packages/peft/tuners/lora/layer.py", line 852, in _embed
[rank1]: return F.embedding(
[rank1]: ^^^^^^^^^^^^
[rank1]: File "/app/anaconda3/envs/grpo_env/lib/python3.11/site-packages/torch/nn/functional.py", line 2551, in embedding
[rank1]: return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: RuntimeError: 'weight' must be 2-D

Package Version


absl-py 2.1.0
accelerate 1.0.1
aiofiles 23.2.1
aiohappyeyeballs 2.4.8
aiohttp 3.11.13
aiosignal 1.3.2
annotated-types 0.7.0
antlr4-python3-runtime 4.13.2
anyio 4.8.0
asttokens 3.0.0
attrs 25.1.0
certifi 2025.1.31
charset-normalizer 3.4.1
click 8.1.8
cloudpickle 3.1.1
comm 0.2.2
compressed-tensors 0.8.0
contourpy 1.3.1
cycler 0.12.1
datasets 3.1.0
debugpy 1.8.13
decorator 5.2.1
deepspeed 0.14.4
dill 0.3.8
diskcache 5.6.3
distro 1.9.0
docker-pycreds 0.4.0
einops 0.8.1
executing 2.2.0
fastapi 0.115.11
ffmpy 0.5.0
filelock 3.17.0
flash_attn 2.7.4.post1
fonttools 4.56.0
frozenlist 1.5.0
fsspec 2024.9.0
gguf 0.10.0
gitdb 4.0.12
GitPython 3.1.44
gradio 4.44.0
gradio_client 1.3.0
grpcio 1.70.0
h11 0.14.0
hjson 3.1.0
httpcore 1.0.7
httptools 0.6.4
httpx 0.28.1
huggingface-hub 0.28.1
idna 3.10
importlib_metadata 8.6.1
importlib_resources 6.5.2
interegular 0.3.3
ipykernel 6.28.0
ipython 8.20.0
jedi 0.19.2
Jinja2 3.1.5
jiter 0.8.2
jsonschema 4.23.0
jsonschema-specifications 2024.10.1
jupyter_client 8.6.3
jupyter_core 5.7.2
kiwisolver 1.4.8
lark 1.2.2
latex2sympy2_extended 1.10.1
liger_kernel 0.5.3
llvmlite 0.44.0
lm-format-enforcer 0.10.11
loguru 0.7.3
Markdown 3.7
markdown-it-py 3.0.0
MarkupSafe 2.1.5
math-verify 0.7.0
matplotlib 3.10.1
matplotlib-inline 0.1.7
mdurl 0.1.2
mistral_common 1.5.3
mpmath 1.3.0
msgpack 1.1.0
msgspec 0.19.0
multidict 6.1.0
multiprocess 0.70.16
nest-asyncio 1.6.0
networkx 3.4.2
ninja 1.11.1.3
numba 0.61.0
numpy 1.26.4
nvidia-cublas-cu12 12.1.3.1
nvidia-cuda-cupti-cu12 12.1.105
nvidia-cuda-nvrtc-cu12 12.1.105
nvidia-cuda-runtime-cu12 12.1.105
nvidia-cudnn-cu12 9.1.0.70
nvidia-cufft-cu12 11.0.2.54
nvidia-curand-cu12 10.3.2.106
nvidia-cusolver-cu12 11.4.5.107
nvidia-cusparse-cu12 12.1.0.106
nvidia-ml-py 12.570.86
nvidia-nccl-cu12 2.21.5
nvidia-nvjitlink-cu12 12.8.61
nvidia-nvtx-cu12 12.1.105
openai 1.56.1
opencv-python-headless 4.11.0.86
orjson 3.10.15
outlines 0.0.46
packaging 24.2
pandas 2.2.3
parso 0.8.4
partial-json-parser 0.2.1.1.post5
peft 0.14.0
pexpect 4.9.0
pillow 10.4.0
pip 25.0
platformdirs 4.3.6
prometheus_client 0.21.1
prometheus-fastapi-instrumentator 7.0.2
prompt_toolkit 3.0.50
propcache 0.3.0
protobuf 5.29.3
psutil 7.0.0
ptyprocess 0.7.0
pure_eval 0.2.3
py-cpuinfo 9.0.0
pyairports 2.1.1
pyarrow 19.0.1
pycountry 24.6.1
pydantic 2.10.6
pydantic_core 2.27.2
pydub 0.25.1
Pygments 2.19.1
pyparsing 3.2.1
python-dateutil 2.9.0.post0
python-dotenv 1.0.1
python-multipart 0.0.20
pytz 2025.1
PyYAML 6.0.2
pyzmq 26.2.1
ray 2.43.0
referencing 0.36.2
regex 2024.11.6
requests 2.32.3
rich 13.9.4
rpds-py 0.23.1
ruff 0.9.9
safetensors 0.5.3
semantic-version 2.10.0
sentencepiece 0.2.0
sentry-sdk 2.22.0
setproctitle 1.3.5
setuptools 75.8.0
shellingham 1.5.4
six 1.17.0
smmap 5.0.2
sniffio 1.3.1
stack-data 0.6.3
starlette 0.46.0
sympy 1.13.1
tensorboard 2.19.0
tensorboard-data-server 0.7.2
tensorboardX 2.6.2.2
tiktoken 0.9.0
tokenizers 0.21.0
tomlkit 0.12.0
torch 2.5.1+cu121
torchvision 0.20.1
tornado 6.4.2
tqdm 4.67.1
traitlets 5.14.3
transformers 4.49.0
triton 3.1.0
trl 0.15.0
typer 0.15.2
typing_extensions 4.12.2
tzdata 2025.1
urllib3 2.3.0
uvicorn 0.34.0
uvloop 0.21.0
vllm 0.6.4.post1
wandb 0.17.3
watchfiles 1.0.4
wcwidth 0.2.13
websockets 12.0
Werkzeug 3.1.3
wheel 0.45.1
xformers 0.0.28.post3
xxhash 3.5.0
yarl 1.18.3
zipp 3.21.0
这个是什么错误？(What is this error?)

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions