|
| 1 | +# Copyright (c) OpenMMLab. All rights reserved. |
| 2 | + |
| 3 | +from typing import Any, Iterable, List, Optional, Tuple |
| 4 | + |
| 5 | +import torch |
| 6 | +from torch import nn |
| 7 | +from transformers.configuration_utils import PretrainedConfig |
| 8 | + |
| 9 | +from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager |
| 10 | +from lmdeploy.pytorch.nn.linear import build_rowwise_linear |
| 11 | +from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight |
| 12 | + |
| 13 | +from .internlm2 import InternLM2Model |
| 14 | +from .utils.cudagraph import CudaGraphMixin |
| 15 | + |
| 16 | + |
| 17 | +class InternLM2ForRewardModel(nn.Module, CudaGraphMixin): |
| 18 | + """rewrote model of InternLM2ForRewardModel.""" |
| 19 | + |
| 20 | + packed_modules_mapping = { |
| 21 | + 'gate_up_proj': [ |
| 22 | + 'w1', |
| 23 | + 'w3', |
| 24 | + ], |
| 25 | + } |
| 26 | + |
| 27 | + def __init__(self, |
| 28 | + config: PretrainedConfig, |
| 29 | + ctx_mgr: StepContextManager, |
| 30 | + dtype: torch.dtype = None, |
| 31 | + device: torch.device = None): |
| 32 | + super().__init__() |
| 33 | + self.config = config |
| 34 | + self.ctx_mgr = ctx_mgr |
| 35 | + # build Model |
| 36 | + self.model = InternLM2Model(config, dtype=dtype, device=device) |
| 37 | + # build v_head |
| 38 | + self.v_head = build_rowwise_linear(config.hidden_size, 1, bias=False, dtype=dtype, device=device) |
| 39 | + |
| 40 | + def forward( |
| 41 | + self, |
| 42 | + input_ids: torch.Tensor, |
| 43 | + position_ids: torch.Tensor, |
| 44 | + past_key_values: List[List[torch.Tensor]], |
| 45 | + attn_metadata: Any = None, |
| 46 | + inputs_embeds: torch.Tensor = None, |
| 47 | + **kwargs, |
| 48 | + ): |
| 49 | + """model forward, return logits.""" |
| 50 | + hidden_states = self.model( |
| 51 | + input_ids=input_ids, |
| 52 | + position_ids=position_ids, |
| 53 | + past_key_values=past_key_values, |
| 54 | + attn_metadata=attn_metadata, |
| 55 | + inputs_embeds=inputs_embeds, |
| 56 | + ) |
| 57 | + return hidden_states |
| 58 | + |
| 59 | + def get_logits(self, hidden_states: torch.Tensor): |
| 60 | + """compute logits of the model output.""" |
| 61 | + return self.v_head(hidden_states) |
| 62 | + |
| 63 | + def get_input_embeddings(self): |
| 64 | + """get input embeddings.""" |
| 65 | + return self.model.get_input_embeddings() |
| 66 | + |
| 67 | + def prepare_inputs_for_generation( |
| 68 | + self, |
| 69 | + past_key_values: List[List[torch.Tensor]], |
| 70 | + inputs_embeds: Optional[torch.Tensor] = None, |
| 71 | + context: StepContext = None, |
| 72 | + ): |
| 73 | + """prepare input.""" |
| 74 | + # get input_ids, position_ids and attention metadatas |
| 75 | + input_ids = context.input_ids |
| 76 | + position_ids = context.position_ids |
| 77 | + attn_metadata = context.attn_metadata |
| 78 | + |
| 79 | + vision_embeddings = context.input_embeddings |
| 80 | + if vision_embeddings is not None and len(vision_embeddings) > 0: |
| 81 | + raise ValueError('InternLM2RewardModel does not support vision embedding') |
| 82 | + |
| 83 | + # inputs of forward |
| 84 | + return dict( |
| 85 | + input_ids=input_ids, |
| 86 | + position_ids=position_ids, |
| 87 | + past_key_values=past_key_values, |
| 88 | + attn_metadata=attn_metadata, |
| 89 | + inputs_embeds=inputs_embeds, |
| 90 | + ) |
| 91 | + |
| 92 | + def load_lora_weights(self, weights: Iterable[Tuple[str, torch.Tensor]], adapter_id: int): |
| 93 | + """load lora weights.""" |
| 94 | + |
| 95 | + from lmdeploy.pytorch.adapter.adapter import load_lora_weights |
| 96 | + |
| 97 | + num_heads = self.config.num_attention_heads |
| 98 | + num_key_value_heads = self.config.num_key_value_heads |
| 99 | + hidden_size = self.config.hidden_size |
| 100 | + head_dim = hidden_size // num_heads |
| 101 | + group_size = num_heads // num_key_value_heads |
| 102 | + |
| 103 | + def _rearange_wqkv(weights): |
| 104 | + for name, loaded_weight in weights: |
| 105 | + if 'wqkv.lora_B' in name: |
| 106 | + loaded_weight = loaded_weight.unflatten(0, (-1, 2 + group_size, head_dim)) |
| 107 | + q = loaded_weight[:, :-2].flatten(0, 2) |
| 108 | + k = loaded_weight[:, -2].flatten(0, 1) |
| 109 | + v = loaded_weight[:, -1].flatten(0, 1) |
| 110 | + loaded_weight = torch.cat([q, k, v], dim=0) |
| 111 | + yield name, loaded_weight |
| 112 | + |
| 113 | + weights_iter = _rearange_wqkv(weights) |
| 114 | + load_lora_weights(self, weights_iter, adapter_id) |
| 115 | + |
| 116 | + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): |
| 117 | + """load weights.""" |
| 118 | + # modify from vllm |
| 119 | + stacked_params_mapping = [ |
| 120 | + # (param_name, shard_name, shard_id) |
| 121 | + ('.gate_up_proj', '.w1', 0), |
| 122 | + ('.gate_up_proj', '.w3', 1), |
| 123 | + ] |
| 124 | + |
| 125 | + params_dict = dict(self.named_parameters()) |
| 126 | + for name, loaded_weight in weights: |
| 127 | + if 'rotary_emb.inv_freq' in name: |
| 128 | + continue |
| 129 | + if ('rotary_emb.cos_cached' in name or 'rotary_emb.sin_cached' in name): |
| 130 | + continue |
| 131 | + for (param_name, weight_name, shard_id) in stacked_params_mapping: |
| 132 | + if weight_name not in name: |
| 133 | + continue |
| 134 | + name = name.replace(weight_name, param_name) |
| 135 | + param = params_dict[name] |
| 136 | + load_weight(param, loaded_weight, shard_id=shard_id) |
| 137 | + break |
| 138 | + else: |
| 139 | + if '.wqkv' in name: |
| 140 | + param = params_dict[name] |
| 141 | + q, k, v = param.weight_spliter(loaded_weight, layout='hgd') |
| 142 | + load_weight(param, q, shard_id='q') |
| 143 | + load_weight(param, k, shard_id='k') |
| 144 | + load_weight(param, v, shard_id='v') |
| 145 | + else: |
| 146 | + param = params_dict[name] |
| 147 | + load_weight(param, loaded_weight) |
0 commit comments