In `run_dpo.py`, why is `model` loaded in `torch.float16` while `ref_model` is loaded in `torch.bfloat16`?
```python
# 1. load a pretrained model
model = AutoModelForCausalLM.from_pretrained(
    model_config.model_name_or_path,
    attn_implementation="flash_attention_2",
    torch_dtype=torch.float16,
)
model.config.use_cache = False
if script_args.ignore_bias_buffers:
    # torch distributed hack
    model._ddp_params_and_buffers_to_ignore = [
        name for name, buffer in model.named_buffers() if buffer.dtype == torch.bool
    ]

if script_args.ref_model:
    ref_name = script_args.ref_model
else:
    ref_name = model_config.model_name_or_path

model_ref = AutoModelForCausalLM.from_pretrained(
    ref_name,
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
)
```
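For context on why the mismatch matters: `float16` and `bfloat16` are not interchangeable. They have the same bit width but trade range for precision differently, so the policy and reference model would see slightly different numerics under these settings. A quick sketch comparing the two dtypes (assuming PyTorch is installed; this snippet is illustrative, not part of `run_dpo.py`):

```python
import torch

# float16: 5 exponent bits, 10 mantissa bits -> narrow range, finer precision
# bfloat16: 8 exponent bits, 7 mantissa bits -> float32-like range, coarser precision
fp16 = torch.finfo(torch.float16)
bf16 = torch.finfo(torch.bfloat16)

print(f"float16  max={fp16.max:.3e}  eps={fp16.eps:.3e}")
print(f"bfloat16 max={bf16.max:.3e}  eps={bf16.eps:.3e}")

# A value that overflows float16 but is still representable in bfloat16:
x = torch.tensor(70000.0)
print(x.to(torch.float16))   # overflows to inf (float16 max is 65504)
print(x.to(torch.bfloat16))  # representable, just rounded
```

Given that difference, loading both models with the same `torch_dtype` would seem the safer default unless the asymmetry is intentional.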