Conversation
open_lm/model.py
Outdated
self.head_dim = args.dim // args.n_heads
self.in_proj = nn.Linear(args.dim, 3 * args.n_heads * self.head_dim, bias=False)
self.out_proj = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False)
if using_te:
instead of if/else's can we have a single helper function that recursively searches the model and replaces the linears?
another option is having a linear/layernorm module m that is set to either te or nn
Ideally, yes. But that would replace the Linear layer found here, which breaks training. So until a fix is found, we need to isolate that particular layer.
can't we just not recurse for special cases?
missed this earlier, but yeah this might also address my swiglu comment below (since it would recurse and replace the linear within the swiglu)
Completed in the latest commit.
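For context, the single-alias alternative suggested above could look roughly like the sketch below. This is illustrative only: the `using_te` flag comes from the diff, while `LinearCls` / `LayerNormCls` are hypothetical names, not the code that was merged.

```python
# Minimal sketch of the "single alias" option, assuming a module-level using_te flag.
import torch.nn as nn

try:
    import transformer_engine.pytorch as te
    using_te = True
except ImportError:
    using_te = False

# Pick the backing implementation once; the rest of the model uses
# LinearCls / LayerNormCls instead of branching with if/else at every call site.
LinearCls = te.Linear if using_te else nn.Linear
LayerNormCls = te.LayerNorm if using_te else nn.LayerNorm
```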
open_lm/model.py
Outdated
torch.nn.init.trunc_normal_(self.in_proj.weight_tensor.float(), std=std, a=-3 * std, b=3 * std)
# scale init by depth as in https://arxiv.org/abs/1908.11365 -- worked slightly better.
std = std / math.sqrt(2 * (self.layer_id + 1))
torch.nn.init.trunc_normal_(self.out_proj.weight_tensor.float(), std=std, a=-3 * std, b=3 * std)
why do we need to cast to float? does a float cast happen in place?
We don't need the float cast. Removing that in the next commit.
Removed completely, as we are now recursively replacing nn.Linear with te.Linear.
open_lm/model.py
Outdated
eps=args.norm_eps,
)
else:
    self.attention_norm = args.norm_type(
can we just add te.LayerNorm as one of the args.norm_type?
Added this in the latest commit. Based on the presence of TE, either te.LayerNorm or nn.LayerNorm will be used.
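As a rough illustration of that selection (a sketch that assumes Transformer Engine exposes `te.LayerNorm`; `build_norm` and `_norm_cls` are hypothetical names, not the merged open_lm code):

```python
# Hypothetical sketch: choose the LayerNorm backend once, based on whether
# Transformer Engine imports successfully, and build block norms from it.
import torch.nn as nn

try:
    import transformer_engine.pytorch as te
    _norm_cls = te.LayerNorm
except ImportError:
    _norm_cls = nn.LayerNorm

def build_norm(dim: int, eps: float = 1e-5) -> nn.Module:
    # Both te.LayerNorm and nn.LayerNorm take the normalized size as the first
    # argument and accept an eps keyword, so they can be swapped transparently.
    return _norm_cls(dim, eps=eps)
```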
open_lm/norms.py
Outdated
|
|
def forward(self, input: Tensor) -> Tensor:
    return F.layer_norm(input, self.normalized_shape, self.weight, self.bias, self.eps)
if using_te:
@achalddave do you think we should have a separate class for TeLayerNorm? Or do you prefer combining it with the existing LayerNorm?
We can have a separate class for TeLayerNorm.
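A separate class could be a thin wrapper along these lines; this is only a sketch that assumes `transformer_engine.pytorch` is importable, not the actual norms.py implementation:

```python
# Sketch of a standalone TE-backed norm kept separate from the nn-based LayerNorm.
import transformer_engine.pytorch as te

class TeLayerNorm(te.LayerNorm):
    """Thin wrapper so the rest of the code can refer to a single class name."""

    def __init__(self, normalized_shape: int, eps: float = 1e-5):
        # te.LayerNorm expects the hidden size as its first argument.
        super().__init__(normalized_shape, eps=eps)
```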
open_lm/params.py
Outdated
help="Using SMP Flash Attention.",
)
parser.add_argument(
    "--sharding-strategy",
is this used? if so can we have a more specific name?
also not seeing where --use-smp-flash-attention is used
This is not used for FP8. These are just placeholder flags defaulted to None for SageMaker Model Parallel.
Removed this to avoid confusion
@@ -202,9 +245,14 @@ def __init__(self, layer_id, args: Params):
elif args.ffn_type == "gelu":
Could we also support FP8 for SwiGLU above? We can make a copy of the SwiGLU class in this file. Here's the source for SwiGLU: https://github.com/facebookresearch/xformers/blob/7f8c290183344343771f4e1d945a8ce10a9500ff/xformers/ops/swiglu_op.py#L430
@rams16592 seems like the recursive replace-linear pattern should take care of this automatically. a function like this seems like it would be great, and we can exclude certain linears that need to stay in higher precision for stability. this function has an include field instead of exclude, but hopefully that's easy to flip:
https://github.com/mlfoundations/open_clip/blob/73fa7f03a33da53653f61841eb6d69aef161e521/src/open_clip/utils.py#L65
Applied this change in the latest commit. The final output Linear layer is excluded from the conversion to te.Linear, as converting it was running into errors.
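For reference, the recursive replacement with an exclusion list might look like the sketch below. The helper name, the `exclude` default of `("output",)`, and the omission of weight copying are assumptions for illustration, not the merged code:

```python
# Hypothetical sketch of recursively swapping nn.Linear for te.Linear while
# skipping excluded child names (e.g. the final output projection).
import torch.nn as nn

try:
    import transformer_engine.pytorch as te
except ImportError:
    te = None

def replace_linear(module: nn.Module, exclude: tuple = ("output",)) -> nn.Module:
    if te is None:
        return module
    for name, child in module.named_children():
        if name in exclude:
            continue  # keep excluded layers (e.g. the output head) in higher precision
        if isinstance(child, nn.Linear):
            # Weight copying / re-initialization is omitted in this sketch.
            setattr(module, name, te.Linear(child.in_features, child.out_features,
                                            bias=child.bias is not None))
        else:
            replace_linear(child, exclude)
    return module
```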
…til natively supports FSDP
train.py, model.py, and norms.py to check if Transformer Engine can be imported for P5 instances or H100s and use FP8 for Linear and LayerNorm layers.
main.py for FP8 support.