
[WIP] Adding support for FP8 training #218

Open
shahromil16 wants to merge 27 commits into main from feature/fp8

Conversation

@shahromil16
Collaborator

  • Changes made to train.py, model.py, and norms.py to check whether Transformer Engine can be imported (on P5 instances / H100s) and, if so, use FP8 for the Linear and LayerNorm layers (see the sketch below).
  • Minor modifications to main.py for FP8 support
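Below is a minimal sketch of the availability check described above, assuming the standard transformer_engine.pytorch import path; the using_te flag name appears in the diff later in this thread, but the exact guard here is an illustration rather than the PR's code.

```python
# Illustrative import guard (not the PR's exact code): fall back to plain
# PyTorch modules when Transformer Engine is unavailable (e.g. not on an
# H100 / P5 instance).
try:
    import transformer_engine.pytorch as te

    using_te = True
except ImportError:
    te = None
    using_te = False
```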

shahromil16 self-assigned this on Feb 21, 2024
shahromil16 changed the title from "Adding support for FP8 training" to "[WIP] Adding support for FP8 training" on Feb 21, 2024
open_lm/model.py Outdated
self.head_dim = args.dim // args.n_heads
self.in_proj = nn.Linear(args.dim, 3 * args.n_heads * self.head_dim, bias=False)
self.out_proj = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False)
if using_te:
Collaborator

Instead of if/elses, can we have a single helper function that recursively searches the model and replaces the Linear layers?
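A rough sketch of that helper, assuming te.Linear accepts the same in_features/out_features/bias arguments as nn.Linear; the function name is hypothetical.

```python
import torch.nn as nn
import transformer_engine.pytorch as te


def replace_linear_with_te(module: nn.Module) -> None:
    """Hypothetical helper: recursively swap every nn.Linear child for te.Linear."""
    for name, child in module.named_children():
        if isinstance(child, nn.Linear):
            # te.Linear allocates its own FP8-capable weights; copying weights
            # over from the original module is omitted here for brevity.
            setattr(
                module,
                name,
                te.Linear(child.in_features, child.out_features, bias=child.bias is not None),
            )
        else:
            replace_linear_with_te(child)
```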

Collaborator

Another option is having a Linear/LayerNorm module m that is set to either te or nn.
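A sketch of this alternative, with illustrative alias names; where the switch lives is up to the implementation.

```python
# Resolve the implementation once, then build every layer against the alias.
Linear = te.Linear if using_te else nn.Linear
LayerNorm = te.LayerNorm if using_te else nn.LayerNorm

# e.g. inside the attention block:
# self.in_proj = Linear(args.dim, 3 * args.n_heads * self.head_dim, bias=False)
```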

Collaborator Author

Ideally, yes. But that would also replace the Linear layer found here, which breaks training. So until a fix is found, we need to isolate that particular layer.

Collaborator

can't we just not recurse for special cases?

Collaborator

Missed this earlier, but yeah, this might also address my swiglu comment below (since it would recurse and replace the Linear within the swiglu).

Collaborator Author

Completed in the latest commit.

open_lm/model.py Outdated
torch.nn.init.trunc_normal_(self.in_proj.weight_tensor.float(), std=std, a=-3 * std, b=3 * std)
# scale init by depth as in https://arxiv.org/abs/1908.11365 -- worked slightly better.
std = std / math.sqrt(2 * (self.layer_id + 1))
torch.nn.init.trunc_normal_(self.out_proj.weight_tensor.float(), std=std, a=-3 * std, b=3 * std)
Collaborator

Why do we need to cast to float? Does a float cast happen in place?

Collaborator Author

We don't need the float cast. Removing it in the next commit.

Collaborator Author

Removed completely, as we are now recursively changing nn.Linear to te.Linear.

open_lm/model.py Outdated
eps=args.norm_eps,
)
else:
self.attention_norm = args.norm_type(
Collaborator

Can we just add te.LayerNorm as one of the args.norm_type options?

Collaborator Author

Added this in the latest commit: based on the presence of TE, either te.LayerNorm or nn.LayerNorm will be used.
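A rough sketch of what that selection could look like; the helper name is hypothetical, and both constructors are assumed to accept (args.dim, eps=args.norm_eps).

```python
def default_layer_norm():
    """Hypothetical resolver: prefer te.LayerNorm when Transformer Engine is present."""
    return te.LayerNorm if using_te else nn.LayerNorm


# e.g. when building a transformer block:
# self.attention_norm = default_layer_norm()(args.dim, eps=args.norm_eps)
```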

open_lm/norms.py Outdated

def forward(self, input: Tensor) -> Tensor:
return F.layer_norm(input, self.normalized_shape, self.weight, self.bias, self.eps)
if using_te:
Collaborator

@achalddave do you think we should have a separate class for TeLayerNorm? Or do you prefer combining it with the existing layer norm?

Collaborator Author

We can have a separate class for TeLayerNorm.
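One possible shape for that separate class, assuming te.LayerNorm can be subclassed directly and takes the hidden size as its first argument; the class name comes from this thread.

```python
if using_te:

    class TeLayerNorm(te.LayerNorm):
        """TE-backed LayerNorm kept as its own class so it can live alongside
        the existing LayerNorm implementations in norms.py."""

        def __init__(self, normalized_shape: int, eps: float = 1e-5):
            # Assumption: an integer normalized_shape maps directly onto
            # te.LayerNorm's hidden-size argument.
            super().__init__(normalized_shape, eps=eps)
```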

help="Using SMP Flash Attention.",
)
parser.add_argument(
"--sharding-strategy",
Collaborator

Is this used? If so, can we have a more specific name?

Collaborator

Also not seeing where --use-smp-flash-attention is used.

Collaborator Author

These are not used for FP8; they are just placeholder flags defaulted to None for SageMaker Model Parallel.

Collaborator Author

Removed these to avoid confusion.

@@ -202,9 +245,14 @@ def __init__(self, layer_id, args: Params):
elif args.ffn_type == "gelu":
Collaborator

Could we also support fp8 for swiglu above? We can make a copy of the Swiglu class in this file. Here's the source for Swiglu https://github.com/facebookresearch/xformers/blob/7f8c290183344343771f4e1d945a8ce10a9500ff/xformers/ops/swiglu_op.py#L430

Collaborator

@rams16592 seems like the recursive replace-Linear pattern should take care of this automatically. A function like this seems like it would be great, and we can exclude certain Linears that need to be higher precision for stability. This function has an include field instead of exclude, but hopefully that's easy to flip:
https://github.com/mlfoundations/open_clip/blob/73fa7f03a33da53653f61841eb6d69aef161e521/src/open_clip/utils.py#L65

Collaborator Author

Applied this change in the latest commit. The last output Linear layer is excluded from the conversion to te.Linear, as converting it runs into errors.
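A rough sketch of that exclusion twist on the recursive replacement discussed above; the exclude parameter, the qualified-name matching, and the "output" name in the usage comment are assumptions, not the committed code.

```python
def replace_linear_with_te(module, exclude=frozenset(), prefix=""):
    """Hypothetical variant: skip modules whose qualified name is listed in exclude."""
    for name, child in module.named_children():
        qualified = f"{prefix}.{name}" if prefix else name
        if qualified in exclude:
            continue  # leave this submodule (and its children) untouched
        if isinstance(child, nn.Linear):
            setattr(
                module,
                name,
                te.Linear(child.in_features, child.out_features, bias=child.bias is not None),
            )
        else:
            replace_linear_with_te(child, exclude, qualified)


# e.g. keep the final output projection in higher precision:
# replace_linear_with_te(model, exclude={"output"})
```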
