18 | 18 |
19 | 19 | from .modules import DECODER_OUTPUTS_TYPE, ENCODER_OUTPUTS_TYPE, PAST_KEY_VALUES_TYPE, T5Decoder, T5Encoder
20 | 20 |
21 |    | -# logging library is not automatically supported by Torchscript
22 |    | -import warnings
23 |    | -
24 | 21 |
25 |    | -@dataclass(frozen=True)
   | 22 | +@dataclass
26 | 23 | class T5Conf:
27 | 24 | encoder_only: bool = False
28 | 25 | linear_head: bool = False
@@ -215,7 +212,6 @@ def prepare_inputs_for_generation(
215 | 212 | "return_past_key_values": return_past_key_values,
216 | 213 | }
217 | 214 |
218 |     | -@torch.jit.export
219 | 215 | def get_encoder(self) -> T5Encoder:
220 | 216 | return self.encoder
221 | 217 |
@@ -292,8 +288,6 @@ def forward(
292 | 288 |
293 | 289 | # decoder_tokens is None means at start of inference, in which case decoder sequence should begin with padding idx.
294 | 290 | if decoder_tokens is None:
295 |     | -batch_size = encoder_output.size()[0]
296 |     | -encoder_output_device = encoder_output.device
297 | 291 | decoder_tokens = (
298 | 292 | torch.ones((batch_size, 1), device=encoder_output_device, dtype=torch.long) * self.padding_idx
299 | 293 | )
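
A minimal standalone sketch of the start-of-inference behavior this hunk touches: when no decoder tokens are supplied, decoding begins from a single padding token per batch element, placed on the same device as the encoder output. The values of padding_idx and encoder_output below are illustrative stand-ins, not the torchtext model's real attributes.

    import torch

    # Illustrative stand-ins for model state and a real encoder output.
    padding_idx = 0                           # assumption: the pad token starts decoding
    encoder_output = torch.randn(4, 16, 512)  # (batch, seq_len, d_model)

    # With no decoder tokens given, start from one pad token per example,
    # matching the batch size and device of the encoder output.
    batch_size = encoder_output.size()[0]
    decoder_tokens = (
        torch.ones((batch_size, 1), device=encoder_output.device, dtype=torch.long) * padding_idx
    )
    print(decoder_tokens.shape)  # torch.Size([4, 1])
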
@@ -323,7 +317,7 @@ def forward(
323 | 317 | # Rescale output before projecting on vocab. This happens when the encoder and decoder share the
324 | 318 | # same word embeddings, which is always the case in our t5 implementation.
325 | 319 | # See https://github.com/huggingface/transformers/blob/d0acc9537829e7d067edbb791473bbceb2ecf056/src/transformers/models/t5/modeling_t5.py#L1661
326 |     | -decoder_output = decoder_output * (self.embedding_dim**-0.5)
    | 320 | +decoder_output = decoder_output * (self.embedding_dim ** -0.5)
327 | 321 | decoder_output = self.lm_head(decoder_output)
328 | 322 | decoder_outputs["decoder_output"] = decoder_output
329 | 323 |
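
For context on the rescaling touched by the last hunk: T5 ties the LM head to the shared word embeddings, so decoder hidden states are scaled by embedding_dim ** -0.5 before the vocabulary projection (see the Hugging Face reference linked above). Below is a minimal sketch under assumed, illustrative sizes; the names embedding and lm_head are stand-ins for the model's modules, not the torchtext API.

    import torch
    import torch.nn as nn

    embedding_dim, vocab_size = 512, 32128  # illustrative T5-base-like sizes

    # Tie the LM head to the word embeddings: no bias, shared weight matrix.
    embedding = nn.Embedding(vocab_size, embedding_dim)
    lm_head = nn.Linear(embedding_dim, vocab_size, bias=False)
    lm_head.weight = embedding.weight

    decoder_output = torch.randn(4, 1, embedding_dim)  # (batch, tgt_len, d_model)

    # Because the embeddings are tied, rescale by d_model ** -0.5 before
    # projecting onto the vocabulary.
    decoder_output = decoder_output * (embedding_dim ** -0.5)
    logits = lm_head(decoder_output)  # (batch, tgt_len, vocab_size)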