Skip to content

Commit c3ecd10

Browse files
lvhan028grimoire
andauthored
Support reward models (#3192)
* tmp * remove update badwords * update * update * update --------- Co-authored-by: grimoire <[email protected]>
1 parent 1fab2f5 commit c3ecd10

File tree

8 files changed

+321
-5
lines changed

8 files changed

+321
-5
lines changed

lmdeploy/pytorch/engine/engine.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -723,7 +723,6 @@ def __update_inputs(next_token_ids):
723723
return_logits=return_logits)
724724
logits = output['logits']
725725
logits = logits[0] # [bs, seq, prob] -> [seq, prob]
726-
727726
# sampling
728727
next_token_ids = await self.async_sampling_logits(logits, all_ids, guided_input_ids, sampling_inputs,
729728
inputs, num_ignore_eos > 0)
Lines changed: 147 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,147 @@
1+
# Copyright (c) OpenMMLab. All rights reserved.
2+
3+
from typing import Any, Iterable, List, Optional, Tuple
4+
5+
import torch
6+
from torch import nn
7+
from transformers.configuration_utils import PretrainedConfig
8+
9+
from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager
10+
from lmdeploy.pytorch.nn.linear import build_rowwise_linear
11+
from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight
12+
13+
from .internlm2 import InternLM2Model
14+
from .utils.cudagraph import CudaGraphMixin
15+
16+
17+
class InternLM2ForRewardModel(nn.Module, CudaGraphMixin):
18+
"""rewrote model of InternLM2ForRewardModel."""
19+
20+
packed_modules_mapping = {
21+
'gate_up_proj': [
22+
'w1',
23+
'w3',
24+
],
25+
}
26+
27+
def __init__(self,
28+
config: PretrainedConfig,
29+
ctx_mgr: StepContextManager,
30+
dtype: torch.dtype = None,
31+
device: torch.device = None):
32+
super().__init__()
33+
self.config = config
34+
self.ctx_mgr = ctx_mgr
35+
# build Model
36+
self.model = InternLM2Model(config, dtype=dtype, device=device)
37+
# build v_head
38+
self.v_head = build_rowwise_linear(config.hidden_size, 1, bias=False, dtype=dtype, device=device)
39+
40+
def forward(
41+
self,
42+
input_ids: torch.Tensor,
43+
position_ids: torch.Tensor,
44+
past_key_values: List[List[torch.Tensor]],
45+
attn_metadata: Any = None,
46+
inputs_embeds: torch.Tensor = None,
47+
**kwargs,
48+
):
49+
"""model forward, return logits."""
50+
hidden_states = self.model(
51+
input_ids=input_ids,
52+
position_ids=position_ids,
53+
past_key_values=past_key_values,
54+
attn_metadata=attn_metadata,
55+
inputs_embeds=inputs_embeds,
56+
)
57+
return hidden_states
58+
59+
def get_logits(self, hidden_states: torch.Tensor):
60+
"""compute logits of the model output."""
61+
return self.v_head(hidden_states)
62+
63+
def get_input_embeddings(self):
64+
"""get input embeddings."""
65+
return self.model.get_input_embeddings()
66+
67+
def prepare_inputs_for_generation(
68+
self,
69+
past_key_values: List[List[torch.Tensor]],
70+
inputs_embeds: Optional[torch.Tensor] = None,
71+
context: StepContext = None,
72+
):
73+
"""prepare input."""
74+
# get input_ids, position_ids and attention metadatas
75+
input_ids = context.input_ids
76+
position_ids = context.position_ids
77+
attn_metadata = context.attn_metadata
78+
79+
vision_embeddings = context.input_embeddings
80+
if vision_embeddings is not None and len(vision_embeddings) > 0:
81+
raise ValueError('InternLM2RewardModel does not support vision embedding')
82+
83+
# inputs of forward
84+
return dict(
85+
input_ids=input_ids,
86+
position_ids=position_ids,
87+
past_key_values=past_key_values,
88+
attn_metadata=attn_metadata,
89+
inputs_embeds=inputs_embeds,
90+
)
91+
92+
def load_lora_weights(self, weights: Iterable[Tuple[str, torch.Tensor]], adapter_id: int):
93+
"""load lora weights."""
94+
95+
from lmdeploy.pytorch.adapter.adapter import load_lora_weights
96+
97+
num_heads = self.config.num_attention_heads
98+
num_key_value_heads = self.config.num_key_value_heads
99+
hidden_size = self.config.hidden_size
100+
head_dim = hidden_size // num_heads
101+
group_size = num_heads // num_key_value_heads
102+
103+
def _rearange_wqkv(weights):
104+
for name, loaded_weight in weights:
105+
if 'wqkv.lora_B' in name:
106+
loaded_weight = loaded_weight.unflatten(0, (-1, 2 + group_size, head_dim))
107+
q = loaded_weight[:, :-2].flatten(0, 2)
108+
k = loaded_weight[:, -2].flatten(0, 1)
109+
v = loaded_weight[:, -1].flatten(0, 1)
110+
loaded_weight = torch.cat([q, k, v], dim=0)
111+
yield name, loaded_weight
112+
113+
weights_iter = _rearange_wqkv(weights)
114+
load_lora_weights(self, weights_iter, adapter_id)
115+
116+
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
117+
"""load weights."""
118+
# modify from vllm
119+
stacked_params_mapping = [
120+
# (param_name, shard_name, shard_id)
121+
('.gate_up_proj', '.w1', 0),
122+
('.gate_up_proj', '.w3', 1),
123+
]
124+
125+
params_dict = dict(self.named_parameters())
126+
for name, loaded_weight in weights:
127+
if 'rotary_emb.inv_freq' in name:
128+
continue
129+
if ('rotary_emb.cos_cached' in name or 'rotary_emb.sin_cached' in name):
130+
continue
131+
for (param_name, weight_name, shard_id) in stacked_params_mapping:
132+
if weight_name not in name:
133+
continue
134+
name = name.replace(weight_name, param_name)
135+
param = params_dict[name]
136+
load_weight(param, loaded_weight, shard_id=shard_id)
137+
break
138+
else:
139+
if '.wqkv' in name:
140+
param = params_dict[name]
141+
q, k, v = param.weight_spliter(loaded_weight, layout='hgd')
142+
load_weight(param, q, shard_id='q')
143+
load_weight(param, k, shard_id='k')
144+
load_weight(param, v, shard_id='v')
145+
else:
146+
param = params_dict[name]
147+
load_weight(param, loaded_weight)

lmdeploy/pytorch/models/module_map.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -168,4 +168,11 @@
168168
'InternLM3ForCausalLM': f'{LMDEPLOY_PYTORCH_MODEL_PATH}.internlm3.InternLM3ForCausalLM',
169169
})
170170

171+
# internlm2 reward model
172+
MODULE_MAP.update(
173+
{'InternLM2ForRewardModel': f'{LMDEPLOY_PYTORCH_MODEL_PATH}.internlm2_reward.InternLM2ForRewardModel'})
174+
175+
# qwen2 reward model
176+
MODULE_MAP.update({'Qwen2ForRewardModel': f'{LMDEPLOY_PYTORCH_MODEL_PATH}.qwen2_reward.Qwen2ForRewardModel'})
177+
171178
CUSTOM_MODULE_MAP = dict()
Lines changed: 135 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,135 @@
1+
# Copyright (c) OpenMMLab. All rights reserved.
2+
from typing import Any, Iterable, List, Optional, Tuple
3+
4+
import torch
5+
from torch import nn
6+
from transformers.configuration_utils import PretrainedConfig
7+
8+
from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager
9+
from lmdeploy.pytorch.nn.linear import build_rowwise_linear
10+
from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight
11+
12+
from .qwen2 import Qwen2Model
13+
from .utils.cudagraph import CudaGraphMixin
14+
15+
16+
class Qwen2ForRewardModel(nn.Module, CudaGraphMixin):
17+
"""ModelForCausalLM."""
18+
19+
packed_modules_mapping = {
20+
'qkv_proj': [
21+
'q_proj',
22+
'k_proj',
23+
'v_proj',
24+
],
25+
'gate_up_proj': [
26+
'gate_proj',
27+
'up_proj',
28+
],
29+
}
30+
31+
def __init__(self,
32+
config: PretrainedConfig,
33+
ctx_mgr: StepContextManager,
34+
dtype: torch.dtype = None,
35+
device: torch.device = None):
36+
super().__init__()
37+
self.config = config
38+
self.ctx_mgr = ctx_mgr
39+
# build model
40+
self.model = Qwen2Model(config, dtype=dtype, device=device)
41+
42+
self.lm_head = build_rowwise_linear(config.hidden_size,
43+
config.vocab_size,
44+
bias=False,
45+
dtype=dtype,
46+
device=device)
47+
48+
self.num_labels = 1
49+
self.score = nn.Sequential(
50+
build_rowwise_linear(config.hidden_size, config.hidden_size, bias=True, dtype=dtype, device=device),
51+
nn.ReLU(), build_rowwise_linear(config.hidden_size, self.num_labels, bias=True, dtype=dtype, device=device))
52+
53+
def forward(
54+
self,
55+
input_ids: torch.Tensor,
56+
position_ids: torch.Tensor,
57+
past_key_values: List[List[torch.Tensor]],
58+
attn_metadata: Any = None,
59+
inputs_embeds: torch.Tensor = None,
60+
**kwargs,
61+
):
62+
"""model forward, return logits."""
63+
hidden_states = self.model(
64+
input_ids=input_ids,
65+
position_ids=position_ids,
66+
past_key_values=past_key_values,
67+
attn_metadata=attn_metadata,
68+
inputs_embeds=inputs_embeds,
69+
)
70+
return hidden_states
71+
72+
def get_logits(self, hidden_states: torch.Tensor):
73+
"""compute logits of the model output."""
74+
logits = self.score(hidden_states)
75+
return logits
76+
77+
def update_weights(self):
78+
"""update weights."""
79+
pass
80+
81+
def get_input_embeddings(self):
82+
"""get input embeddings."""
83+
return self.model.get_input_embeddings()
84+
85+
def prepare_inputs_for_generation(
86+
self,
87+
past_key_values: List[List[torch.Tensor]],
88+
inputs_embeds: Optional[torch.Tensor] = None,
89+
context: StepContext = None,
90+
):
91+
"""prepare input."""
92+
# get input_ids, position_ids and attention metadatas
93+
input_ids = context.input_ids
94+
position_ids = context.position_ids
95+
attn_metadata = context.attn_metadata
96+
97+
# inputs of forward
98+
return dict(
99+
input_ids=input_ids,
100+
position_ids=position_ids,
101+
past_key_values=past_key_values,
102+
attn_metadata=attn_metadata,
103+
# inputs_embeds=inputs_embeds,
104+
)
105+
106+
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
107+
"""load weights."""
108+
# modify from vllm
109+
stacked_params_mapping = [
110+
# (param_name, shard_name, shard_id)
111+
('.qkv_proj', '.q_proj', 'q'),
112+
('.qkv_proj', '.k_proj', 'k'),
113+
('.qkv_proj', '.v_proj', 'v'),
114+
('.gate_up_proj', '.gate_proj', 0),
115+
('.gate_up_proj', '.up_proj', 1),
116+
]
117+
118+
params_dict = dict(self.named_parameters())
119+
for name, loaded_weight in weights:
120+
if 'rotary_emb.inv_freq' in name:
121+
continue
122+
if ('rotary_emb.cos_cached' in name or 'rotary_emb.sin_cached' in name):
123+
continue
124+
if self.config.tie_word_embeddings and 'lm_head.weight' in name:
125+
continue
126+
for (param_name, weight_name, shard_id) in stacked_params_mapping:
127+
if weight_name not in name:
128+
continue
129+
name = name.replace(weight_name, param_name)
130+
param = params_dict[name]
131+
load_weight(param, loaded_weight, shard_id=shard_id)
132+
break
133+
else:
134+
param = params_dict[name]
135+
load_weight(param, loaded_weight)

lmdeploy/pytorch/nn/rotary_embedding.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,9 @@ def _get_llama3_parameters(config: PretrainedConfig):
6464
def build_rotary_params(config: PretrainedConfig):
6565
"""get scaling_factor rotary params, and emb_type."""
6666
params = dict(emb_type=RopeType.Default)
67-
if config.rope_scaling is not None:
67+
# cannot access config.rope_scaling when the model is "Qwen/Qwen2-Math-RM-72B"
68+
rope_scaling = getattr(config, 'rope_scaling', None)
69+
if rope_scaling is not None:
6870
rope_type_str = config.rope_scaling.get('rope_type', 'default')
6971
build_funcs = dict(default=_get_default_rope_parameters,
7072
linear=_get_linear_scaling_rope_parameters,

lmdeploy/pytorch/supported_models.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
InternLMForCausalLM=True,
2424
# internlm2
2525
InternLM2ForCausalLM=True,
26+
InternLM2ForRewardModel=True,
2627
# internlm-xcomposer
2728
InternLMXComposerForCausalLM=False,
2829
# internlm2-xcomposer
@@ -107,7 +108,7 @@ def is_supported(model_path: str):
107108

108109
triton_model_path = os.path.join(model_path, 'triton_models')
109110
if os.path.exists(triton_model_path):
110-
logger.warning(f'{model_path} seems to be a turbomind workspace, '
111+
logger.warning(f'{model_path} seems to be a turbomind model, '
111112
'which can only be ran with turbomind engine.')
112113
else:
113114
try:

lmdeploy/serve/async_engine.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import tqdm
2121

2222
from lmdeploy import Tokenizer
23+
from lmdeploy.archs import get_model_arch
2324
from lmdeploy.logger import RequestLogger
2425
from lmdeploy.messages import GenerationConfig, PytorchEngineConfig, Response, ResponseType, TurbomindEngineConfig
2526
from lmdeploy.model import MODELS, ChatTemplateConfig, best_match_model
@@ -271,6 +272,7 @@ def __init__(self,
271272

272273
self.tokenizer = Tokenizer(model_path)
273274
self.hf_gen_cfg = get_hf_gen_cfg(model_path)
275+
self.arch, _ = get_model_arch(model_path)
274276

275277
# build backend engine
276278
if backend == 'turbomind':

lmdeploy/serve/utils.py

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,28 @@
1717

1818

1919
class LogitsMixin:
20-
"""Helper class to calculate ppl."""
20+
"""Helper class to get logits, reward score and calculate ppl."""
21+
22+
def get_reward_score(self, input_ids: List) -> List[float]:
23+
"""
24+
Args:
25+
input_ids(List): a list of token_id or a list of token_id list or a tensor containing
26+
token_ids
27+
Return:
28+
reward score in a list. If the input_ids is a list of token_id, the return value
29+
is still a list with length 1.
30+
"""
31+
supported_reward_models = ['InternLM2ForRewardModel', 'Qwen2ForRewardModel']
32+
if self.arch not in supported_reward_models:
33+
raise ValueError(f'{self.arch} is not in reward mode list: {supported_reward_models}')
34+
assert isinstance(input_ids, List)
35+
assert all(isinstance(x, int) for x in input_ids) or all(isinstance(x, List) for x in input_ids)
36+
# Make input_ids a list of token_id list
37+
input_ids = [input_ids] if isinstance(input_ids[0], int) else input_ids
38+
logits = self._run(coro=self._async_get_logits(input_ids=input_ids)).result()
39+
logits = [x.squeeze() for x in logits]
40+
scores = [x[-1].cpu().item() for x in logits]
41+
return scores
2142

2243
async def _async_get_logits(self,
2344
input_ids,
@@ -34,7 +55,9 @@ async def _proc(i):
3455
input_len = len(input_ids[i])
3556
# TODO(lvhan): Fix the ugly code later on
3657
max_new_tokens = 1 if self.backend == 'turbomind' else 0
37-
gen_config = GenerationConfig(max_new_tokens=max_new_tokens, output_logits='all')
58+
# The reason to set `top_k=1` is that pt engine crashes at top_k sampling stage
59+
# when perform inference on a reward model.
60+
gen_config = GenerationConfig(max_new_tokens=max_new_tokens, output_logits='all', top_k=1)
3861
async with self.safe_run(inst,
3962
session_id=i,
4063
input_ids=input_ids[i],

0 commit comments

Comments
 (0)