22 changes: 22 additions & 0 deletions src/fairseq2/composition/models.py
@@ -67,6 +67,13 @@
create_nllb_model,
register_nllb_configs,
)
from fairseq2.models.opt import (
OPT_FAMILY,
OPTConfig,
convert_opt_state_dict,
create_opt_model,
register_opt_configs,
)
from fairseq2.models.qwen import (
QWEN_FAMILY,
QwenConfig,
@@ -296,6 +303,21 @@ def _register_model_families(container: DependencyContainer) -> None:

register_nllb_configs(container)

# OPT
register_model_family(
container,
OPT_FAMILY,
kls=TransformerLM,
config_kls=OPTConfig,
factory=create_opt_model,
state_dict_converter=convert_opt_state_dict,
compiler=compile_transformer_lm,
fsdp_applier=apply_fsdp_to_transformer_lm,
layerwise_ac_applier=apply_ac_to_transformer_lm,
)

register_opt_configs(container)

# Qwen
register_model_family(
container,
15 changes: 15 additions & 0 deletions src/fairseq2/models/opt/__init__.py
@@ -0,0 +1,15 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from __future__ import annotations

from fairseq2.models.opt.config import OPT_FAMILY as OPT_FAMILY
from fairseq2.models.opt.config import OPTConfig as OPTConfig
from fairseq2.models.opt.config import register_opt_configs as register_opt_configs
from fairseq2.models.opt.factory import OPTFactory as OPTFactory
from fairseq2.models.opt.factory import create_opt_model as create_opt_model
from fairseq2.models.opt.hub import get_opt_model_hub as get_opt_model_hub
from fairseq2.models.opt.interop import convert_opt_state_dict as convert_opt_state_dict
62 changes: 62 additions & 0 deletions src/fairseq2/models/opt/config.py
@@ -0,0 +1,62 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from __future__ import annotations

from dataclasses import dataclass
from typing import Final

from fairseq2.runtime.config_registry import ConfigRegistrar
from fairseq2.runtime.dependency import DependencyContainer

OPT_FAMILY: Final = "opt"


@dataclass(kw_only=True)
class OPTConfig:
"""Holds the configuration of a OPT model.

The default values correspond to the base architecture as described in
:cite:t:`https://arxiv.org/abs/2205.01068`.
"""

model_dim: int = 768
"""The dimensionality of the model."""

max_seq_len: int = 2048 + 1
"""The maximum sequence length."""

vocab_size: int = 50272
"""The size of the vocabulary."""

pad_idx: int | None = 1
"""The index of the PAD symbol in the vocabulary."""

attn_window_len: int = 2048
"""The local attention window length."""

num_layers: int = 12
"""The number of decoder layers."""

num_attn_heads: int = 12
"""The number of attention heads in decoder layers."""

num_key_value_heads: int = 12
"""The number of key/value heads for Grouped Query Attention."""

ffn_inner_dim: int = 3072
"""The dimensionality of inner projection layers in feed-forward networks."""

dropout_p: float = 0.1
"""The dropout probability on outputs of Transformer layers."""


def register_opt_configs(container: DependencyContainer) -> None:
arch = ConfigRegistrar(container, OPTConfig)

@arch("125m")
def _125m() -> OPTConfig:
return OPTConfig()
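
Only the 125m architecture is registered here. Other OPT sizes would presumably follow the same `ConfigRegistrar` pattern; a minimal sketch (the "1.3b" name and hyperparameter values below are illustrative assumptions, not part of this change):

    @arch("1.3b")
    def _1_3b() -> OPTConfig:
        # Illustrative values only; verify against the OPT paper / Hugging Face
        # config before registering a real architecture.
        return OPTConfig(
            model_dim=2048,
            num_layers=24,
            num_attn_heads=32,
            num_key_value_heads=32,
            ffn_inner_dim=8192,
        )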
175 changes: 175 additions & 0 deletions src/fairseq2/models/opt/factory.py
@@ -0,0 +1,175 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from __future__ import annotations

import torch.nn as nn

from fairseq2.models.opt.config import OPTConfig
from fairseq2.models.transformer import (
CausalAttentionBias,
FeedForwardNetwork,
LocalAttentionStateFactory,
MultiheadAttention,
StandardFeedForwardNetwork,
StandardMultiheadAttention,
TransformerEmbeddingFrontend,
TransformerFrontend,
TransformerNormOrder,
create_default_sdpa,
)
from fairseq2.models.transformer_lm import (
StandardTransformerLMDecoder,
StandardTransformerLMDecoderLayer,
TransformerLM,
TransformerLMDecoder,
TransformerLMDecoderLayer,
)
from fairseq2.nn import (
Embedding,
LayerNorm,
LearnedPositionEncoder,
Linear,
PositionEncoder,
Projection,
StandardEmbedding,
StandardLayerNorm,
)


def create_opt_model(config: OPTConfig) -> TransformerLM:
return OPTFactory(config).create_model()


class OPTFactory:
def __init__(self, config: OPTConfig) -> None:
self._config = config

def create_model(self) -> TransformerLM:
config = self._config

decoder_frontend = self.create_decoder_frontend()

decoder = self.create_decoder()

final_proj = self.create_final_projection()

return TransformerLM(
config.model_dim,
decoder_frontend,
decoder,
final_proj,
config.pad_idx,
config.max_seq_len,
)

def create_decoder_frontend(self) -> TransformerFrontend:
config = self._config

embed = self.create_embedding()

pos_encoder = self.create_position_encoder()

return TransformerEmbeddingFrontend(
config.model_dim,
embed,
pos_encoder=pos_encoder,
no_scale=True,
# dropout_p=config.dropout_p, # TODO: check if there is dropout here
)

def create_embedding(self) -> Embedding:
config = self._config

return StandardEmbedding(config.vocab_size, config.model_dim, config.pad_idx)

def create_decoder(self) -> TransformerLMDecoder:
config = self._config

layers = []

for _ in range(config.num_layers):
layer = self.create_decoder_layer()

layers.append(layer)

layer_norm = self.create_layer_norm()

return StandardTransformerLMDecoder(layers, layer_norm)

def create_position_encoder(self) -> PositionEncoder:
config = self._config

return LearnedPositionEncoder(
config.model_dim, config.max_seq_len, _legacy_pad_idx=1
)

def create_decoder_layer(self) -> TransformerLMDecoderLayer:
config = self._config

self_attn = self.create_self_attention()

self_attn_layer_norm = self.create_layer_norm()

ffn = self.create_ffn()

ffn_layer_norm = self.create_layer_norm()

return StandardTransformerLMDecoderLayer(
self_attn,
self_attn_layer_norm,
ffn,
ffn_layer_norm,
norm_order=TransformerNormOrder.PRE,
dropout_p=config.dropout_p,
)

def create_self_attention(self) -> MultiheadAttention:
config = self._config

attn_bias = CausalAttentionBias(attn_window_len=config.attn_window_len)

sdpa = create_default_sdpa(attn_bias)

state_factory = LocalAttentionStateFactory(config.attn_window_len)

return StandardMultiheadAttention(
config.model_dim,
config.num_attn_heads,
sdpa,
num_key_value_heads=config.num_key_value_heads,
bias=True,
state_factory=state_factory,
)

def create_ffn(self) -> FeedForwardNetwork:
config = self._config

return StandardFeedForwardNetwork(
config.model_dim, config.ffn_inner_dim, bias=True
)

def create_layer_norm(self) -> LayerNorm:
config = self._config

return StandardLayerNorm(config.model_dim, bias=True)

def create_final_projection(self) -> Projection:
config = self._config

return Linear(
config.model_dim,
config.vocab_size,
bias=False,
init_fn=_init_final_projection,
)


def _init_final_projection(proj: Linear) -> None:
nn.init.normal_(proj.weight, std=proj.input_dim**-0.5)

if proj.bias is not None:
nn.init.zeros_(proj.bias)
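
For reference, a model can be instantiated directly from the entry point added above; a minimal usage sketch based only on the symbols this PR exports from `fairseq2.models.opt`:

    from fairseq2.models.opt import OPTConfig, create_opt_model

    config = OPTConfig()  # defaults match the registered "125m" architecture
    model = create_opt_model(config)

    # TransformerLM is a regular torch.nn.Module, so standard introspection works.
    num_params = sum(p.numel() for p in model.parameters())
    print(f"OPT-125m parameter count: {num_params:,}")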
15 changes: 15 additions & 0 deletions src/fairseq2/models/opt/hub.py
@@ -0,0 +1,15 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from __future__ import annotations

from fairseq2.models import ModelHubAccessor
from fairseq2.models.opt.config import OPT_FAMILY, OPTConfig
from fairseq2.models.transformer_lm import TransformerLM

get_opt_model_hub = ModelHubAccessor(
OPT_FAMILY, kls=TransformerLM, config_kls=OPTConfig
)
37 changes: 37 additions & 0 deletions src/fairseq2/models/opt/interop.py
@@ -0,0 +1,37 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from __future__ import annotations

from fairseq2.models.opt.config import OPTConfig
from fairseq2.models.utils.checkpoint import convert_state_dict

_HG_KEY_MAP = {
# fmt: off
r"^model\.decoder\.embed_tokens\.": r"decoder_frontend.embed.",
r"^model\.decoder\.embed_positions\.": r"decoder_frontend.pos_encoder.",
r"^model\.decoder\.layers\.([0-9]+)\.self_attn_layer_norm\.": r"decoder.layers.\1.self_attn_layer_norm.",
r"^model\.decoder\.layers\.([0-9]+)\.self_attn\.q_proj\.": r"decoder.layers.\1.self_attn.q_proj.",
r"^model\.decoder\.layers\.([0-9]+)\.self_attn\.k_proj\.": r"decoder.layers.\1.self_attn.k_proj.",
r"^model\.decoder\.layers\.([0-9]+)\.self_attn\.v_proj\.": r"decoder.layers.\1.self_attn.v_proj.",
r"^model\.decoder\.layers\.([0-9]+)\.self_attn\.out_proj\.": r"decoder.layers.\1.self_attn.output_proj.",
r"^model\.decoder\.layers\.([0-9]+)\.mlp\.gate_proj\.": r"decoder.layers.\1.ffn.gate_proj.",
r"^model\.decoder\.layers\.([0-9]+)\.fc1\.": r"decoder.layers.\1.ffn.inner_proj.",
r"^model\.decoder\.layers\.([0-9]+)\.fc2\.": r"decoder.layers.\1.ffn.output_proj.",
r"^model\.decoder\.layers\.([0-9]+)\.final_layer_norm\.": r"decoder.layers.\1.ffn_layer_norm.",
r"^model\.decoder\.final_layer_norm\.": r"decoder.layer_norm.",
r"^lm_head\.": r"final_proj.",
# fmt: on
}


def convert_opt_state_dict(
state_dict: dict[str, object], config: OPTConfig
) -> dict[str, object]:
if "model.decoder.embed_tokens.weight" in state_dict: # Hugging Face
state_dict = convert_state_dict(state_dict, _HG_KEY_MAP)

return state_dict
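
`_HG_KEY_MAP` maps Hugging Face parameter names to fairseq2 ones via regex substitution. A self-contained sketch of that renaming step (stdlib only; fairseq2's `convert_state_dict` helper may differ in details):

    import re

    def rename_keys(
        state_dict: dict[str, object], key_map: dict[str, str]
    ) -> dict[str, object]:
        renamed = {}
        for key, value in state_dict.items():
            # Apply the first pattern that matches; unmatched keys are kept as-is.
            for pattern, repl in key_map.items():
                new_key, count = re.subn(pattern, repl, key)
                if count:
                    key = new_key
                    break
            renamed[key] = value
        return renamed

    # e.g. "model.decoder.layers.0.fc1.weight" -> "decoder.layers.0.ffn.inner_proj.weight"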
7 changes: 6 additions & 1 deletion src/fairseq2/nn/position_encoder.py
@@ -225,6 +225,7 @@ def __init__(
encoding_dim: int,
max_seq_len: int,
*,
_legacy_pad_idx: int | None = None,
device: Device | None = None,
dtype: DataType | None = None,
) -> None:
@@ -243,6 +244,10 @@

self.max_seq_len = max_seq_len

# This is a legacy parameter that should only be set when the encodings
# must be compatible with fairseq.
self._legacy_pad_idx = 0 if _legacy_pad_idx is None else _legacy_pad_idx

self.reset_parameters()

def reset_parameters(self) -> None:
@@ -271,7 +276,7 @@ def forward(
f"The lengths of all sequences in `seqs` must be less than or equal to the maximum sequence length ({self.max_seq_len}), but at least one sequence is of length {max_seq_len} instead."
)

indices = seqs_layout.position_indices + 1 # +1 for padding
indices = seqs_layout.position_indices + (1 + self._legacy_pad_idx)

if not self.training and state_bag is not None:
indices = state_bag.step_nr + indices
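
A quick check of the new offset: with the fairseq-compatible `_legacy_pad_idx=1` passed by `OPTFactory`, the first token maps to embedding index 0 + (1 + 1) = 2, matching the offset of 2 that OPT's learned position embeddings expect (Hugging Face's `OPTLearnedPositionalEmbedding` applies the same offset).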