From 1efeab84050517d200aa2fbef6b4c1ea93883a3c Mon Sep 17 00:00:00 2001 From: ClashLuke <39779310+ClashLuke@users.noreply.github.com> Date: Sun, 31 Oct 2021 16:27:10 +0100 Subject: [PATCH 01/34] feat(model): add fft attention --- src/model.py | 127 +++++++++++++++++++-------------------------------- 1 file changed, 48 insertions(+), 79 deletions(-) diff --git a/src/model.py b/src/model.py index b57ff9a..8ee4e9f 100644 --- a/src/model.py +++ b/src/model.py @@ -76,7 +76,7 @@ def backward(ctx, grad_outputs: torch.Tensor): def moe(inp: torch.Tensor, expert_weights: torch.nn.ParameterList, training: bool, - jitter_epsilon: float, feature_shuffle: torch.Tensor, groups: int, experts: int) -> torch.Tensor: + jitter_epsilon: float, groups: int, experts: int) -> torch.Tensor: *expert_weights, gate = expert_weights batch, features, sequence = inp.size() tokens = batch * sequence @@ -112,8 +112,6 @@ def moe(inp: torch.Tensor, expert_weights: torch.nn.ParameterList, training: boo # permute inp = inp.gather(0, expert_permutation.expand_as(inp)) - if feature_shuffle is not None: - inp = inp.gather(1, feature_shuffle.view(1, -1).expand_as(inp)) inp = inp.view(tokens // experts, experts * groups, features // groups) if len(expert_weights) == 1: inp = expert_matmul(inp, expert_weights[0]) @@ -126,43 +124,20 @@ def moe(inp: torch.Tensor, expert_weights: torch.nn.ParameterList, training: boo def moe_check(inp: torch.Tensor, w: torch.nn.ParameterList, training: bool, - jitter_epsilon: float, feature_shuffle: torch.Tensor, groups: int, experts: int) -> torch.Tensor: + jitter_epsilon: float, groups: int, experts: int) -> torch.Tensor: if experts > 0: - return moe(inp, w, training, jitter_epsilon, feature_shuffle, groups, experts) + return moe(inp, w, training, jitter_epsilon, groups, experts) return conv(inp, w[0], groups, False) -def linear_attention(inp: torch.Tensor, divisor: torch.Tensor, - w0: torch.nn.ParameterList, - feature_shuffle0: typing.Optional[torch.Tensor], groups0: int, experts0: int, - w1: torch.Tensor, - w2: torch.nn.ParameterList, - feature_shuffle2: typing.Optional[torch.Tensor], groups2: int, experts2: int, - input_cache: torch.Tensor, cumsum_cache: torch.Tensor, bottleneck_group: int, training: bool, - caching: bool, idx: int, norm_power: int, jitter_epsilon: float - ) -> typing.Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - kernel_size = w1.size(2) - pad = True - if not training and caching: - if idx - 1 > kernel_size and inp.size(2) == 1: - pad = False - inp = torch.cat([input_cache, inp], -1) - input_cache = inp[:, :, -kernel_size + 1:].detach() - inp = moe_check(inp, w0, training, jitter_epsilon, feature_shuffle0, groups0, experts0) - depth, scale, shift = inp.chunk(3, 1) - cum = depth.cumsum(-1) - if not training and caching: - cum = cum + cumsum_cache - scale = scale[:, :, -1:] - shift = shift[:, :, -1:] - cum = cum[:, :, -1:] - if idx - 1 > kernel_size: - cumsum_cache = cum.detach() - inp = TripleNorm.apply(cum / divisor, scale, shift, norm_power) - inp = conv(inp, w1, bottleneck_group, pad) +def linear_attention(inp: torch.Tensor, w0: torch.nn.ParameterList, groups0: int, experts0: int, w1: torch.Tensor, + w2: torch.nn.ParameterList, groups2: int, experts2: int, bottleneck_group: int, training: bool, + norm_power: int, jitter_epsilon: float) -> torch.Tensor: + inp = moe_check(inp, w0, training, jitter_epsilon, groups0, experts0) inp = TripleNorm.apply(*inp.chunk(3, 1), norm_power) - inp = moe_check(inp, w2, training, jitter_epsilon, feature_shuffle2, groups2, experts2) - return 
input_cache, cumsum_cache, inp + inp = conv(inp, w1, bottleneck_group, True) + inp = TripleNorm.apply(*inp.chunk(3, 1), norm_power) + return moe_check(inp, w2, training, jitter_epsilon, groups2, experts2) def conv_weight(in_features: int, out_features: int, kernel_size: int, groups: int, std: float): @@ -255,12 +230,18 @@ def __init__(self, ctx: Context): pos_embd = torch.arange(0, ctx.model.sequence_length).unsqueeze(0) + 1 self.register_buffer("divisor", pos_embd.unsqueeze(0).to(torch.float).to(ctx.model.device)) - cell = LinearAttentionCell(self, ctx, 1) + ff = FeedForward(self, ctx, 1, 1) + attn = FFTAttention(self, ctx, 1) self.stem = revlib.ReversibleSequential(*[c - for i in range(1, 1 + ctx.model.depth) - for c in [cell.momentum((1 - ctx.model.momentumnet_beta) / - ctx.model.momentumnet_beta ** i, not ctx.model.weight_sharing), - MomentumNetSide(ctx.model.momentumnet_beta ** i)]], + for i in range(1, 1 + ctx.model.depth * 2, 2) + for c in [ff.momentum((1 - ctx.model.momentumnet_beta) / + ctx.model.momentumnet_beta ** i, + not ctx.model.weight_sharing), + MomentumNetSide(ctx.model.momentumnet_beta ** i), + attn.momentum((1 - ctx.model.momentumnet_beta) / + ctx.model.momentumnet_beta ** (i + 1), + not ctx.model.weight_sharing), + MomentumNetSide(ctx.model.momentumnet_beta ** (i + 1))]], target_device=ctx.model.device) self.output = torch.nn.Conv1d(ctx.model.features * 2, ctx.dataset.classes, (1,)).to(ctx.model.device) torch.nn.init.zeros_(self.output.weight.data) @@ -270,7 +251,7 @@ def forward(self, inp: torch.Tensor): def reset_cache(self): for mod in self.stem.modules(): - if isinstance(mod, LinearAttentionCell): + if isinstance(mod, FeedForward): mod.reset_cache() @@ -300,9 +281,9 @@ def get_moe_param(in_features: int, out_features: int, groups: int, experts: int return [torch.nn.Parameter(conv_weight(in_features, out_features, 1, groups, std))] -class LinearAttentionCell(torch.nn.Module): - def __init__(self, base: LinearAttention, ctx: Context, init_scale: float): - super(LinearAttentionCell, self).__init__() +class FeedForward(torch.nn.Module): + def __init__(self, base: LinearAttention, ctx: Context, init_scale: float, feature_factor: float): + super(FeedForward, self).__init__() self.divisor = lambda: base.divisor self.init_scale = init_scale self.caching = ctx.eval.cache @@ -316,50 +297,38 @@ def __init__(self, base: LinearAttention, ctx: Context, init_scale: float): self.jitter_epsilon = ctx.model.moe_jitter_epsilon self.expert_chunks = ctx.model.expert_chunks intermediate = int(ctx.model.features * ctx.model.feed_forward_intermediate_factor) - self.w0 = torch.nn.ParameterList(get_moe_param(ctx.model.features, intermediate * 3, self.groups0, - self.experts0, self.expert_chunks, ctx.model.activation_std)) + self.w0 = torch.nn.ParameterList(get_moe_param(ctx.model.features * feature_factor, intermediate * 3, + self.groups0, self.experts0, self.expert_chunks, + ctx.model.activation_std)) self.w1 = conv_weight(intermediate, intermediate * 3, ctx.model.conv_kernel_size, ctx.model.bottleneck_group, ctx.model.activation_std) - self.w2 = torch.nn.ParameterList(get_moe_param(intermediate, ctx.model.features, self.groups2, + self.w2 = torch.nn.ParameterList(get_moe_param(intermediate, ctx.model.features * feature_factor, self.groups2, self.experts2, self.expert_chunks, 1)) self.idx: int = 0 - self._input_cache = torch.zeros([]) - self._cumsum_cache = torch.zeros([]) - if ctx.model.feature_shuffle: - self.register_buffer("feature_shuffle0", 
torch.argsort(torch.randn(ctx.model.features)).view(1, -1, 1)) - self.register_buffer("feature_shuffle2", torch.argsort(torch.randn(intermediate)).view(1, -1, 1)) - else: - self.feature_shuffle0 = None - self.feature_shuffle2 = None - def reset_cache(self): - self._cumsum_cache = torch.zeros([]) - self._input_cache = torch.zeros([]) - self.idx = 0 + def _ff(self, inp: torch.Tensor) -> torch.Tensor: + return linear_attention(inp, self.w0, self.groups0, self.experts0, self.w1, self.w2, self.groups2, + self.experts2, + self.bottleneck_group, self.training, self.norm_power, self.jitter_epsilon) def forward(self, inp: torch.Tensor) -> torch.Tensor: - if self.training: - div = self.divisor() - elif self.caching: - self.idx += inp.size(2) - div = torch.LongTensor([self.idx]).to(inp.device) - else: - self.idx = inp.size(2) - div = torch.arange(self.idx, device=inp.device).view(1, 1, -1) + 1 - self._input_cache, self._cumsum_cache, out = linear_attention(inp, div, - self.w0, self.feature_shuffle0, self.groups0, - self.experts0, - self.w1, - self.w2, self.feature_shuffle2, self.groups2, - self.experts2, self._input_cache, - self._cumsum_cache, self.bottleneck_group, - self.training, self.caching, self.idx, - self.norm_power, self.jitter_epsilon - ) - out = out * self.init_scale - return out + return self._ff(inp) * self.init_scale def momentum(self, init_scale: float, deep: bool): out = copy.deepcopy(self) if deep else copy.copy(self) out.init_scale = init_scale return out + + +class FFTAttention(FeedForward): + def __init__(self, base: LinearAttention, ctx: Context, init_scale: float): + super(FFTAttention, self).__init__(base, ctx, init_scale, 2) + + def forward(self, inp: torch.Tensor) -> torch.Tensor: + batch, features, sequence = inp.size() + out = torch.view_as_real(torch.fft.rfft(inp, 2 * sequence)) + out = out.transpose(2, 3).reshape(batch, features * 2, sequence + 1) + out = self._ff(out) + out = out.view(batch, features, 2, sequence + 1).transpose(2, 3) + out = torch.view_as_complex(out) + return torch.fft.irfft(out, 2 * sequence)[:, :, :sequence] From 0e7ca86f67f0f88877cb7c3fdef666cd6feeb41c Mon Sep 17 00:00:00 2001 From: ClashLuke <39779310+ClashLuke@users.noreply.github.com> Date: Tue, 9 Nov 2021 22:06:43 +0100 Subject: [PATCH 02/34] feat(model): add sum-based attention --- src/dataclass.py | 4 ++-- src/model.py | 16 ++++++++++++++-- 2 files changed, 16 insertions(+), 4 deletions(-) diff --git a/src/dataclass.py b/src/dataclass.py index 9511263..3912ce6 100644 --- a/src/dataclass.py +++ b/src/dataclass.py @@ -1,7 +1,6 @@ import pathlib -import typing - import torch +import typing import yaml @@ -24,6 +23,7 @@ class Model(DataClass): steps_per_checkpoint: int = 0 # 0 -> disabled print_on_init: bool = True features: int = 256 + sum_attention_level: int = 0 momentumnet_beta: float = 0.99 # The higher this is, the more numerically stable. 
BUT also lower impact per layer depth: int = 64 batch_size: int = 128 diff --git a/src/model.py b/src/model.py index 8ee4e9f..f6cf211 100644 --- a/src/model.py +++ b/src/model.py @@ -1,10 +1,9 @@ import copy -import typing - import numpy as np import revlib import torch import torch.utils.data +import typing from deepspeed.runtime import lr_schedules from torch.nn import functional as F @@ -332,3 +331,16 @@ def forward(self, inp: torch.Tensor) -> torch.Tensor: out = out.view(batch, features, 2, sequence + 1).transpose(2, 3) out = torch.view_as_complex(out) return torch.fft.irfft(out, 2 * sequence)[:, :, :sequence] + + +class SumAttention(FeedForward): + def __init__(self, base: LinearAttention, ctx: Context, init_scale: float): + super(SumAttention, self).__init__(base, ctx, init_scale, ctx.model.sum_attention_level) + self.sum_attention_level = ctx.model.sum_attention_level + + def forward(self, inp: torch.Tensor) -> torch.Tensor: + out = self._ff(inp).chunk(self.sum_attention_level, 1) + batch, features, seq = out[0].size() + return sum(f(out[0] + sum(out[inner + 1][outer // batch ** inner % batch] + for inner in range(self.sum_attention_level)).unsqueeze(0)) + for outer in range(batch ** (self.sum_attention_level - 1))) From 3089fa68cbe1d14e3180f936b33b22a921b4ea6f Mon Sep 17 00:00:00 2001 From: ClashLuke <39779310+ClashLuke@users.noreply.github.com> Date: Tue, 9 Nov 2021 22:19:09 +0100 Subject: [PATCH 03/34] fix(model): define f() --- src/model.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/src/model.py b/src/model.py index f6cf211..8157f33 100644 --- a/src/model.py +++ b/src/model.py @@ -1,9 +1,10 @@ import copy +import typing + import numpy as np import revlib import torch import torch.utils.data -import typing from deepspeed.runtime import lr_schedules from torch.nn import functional as F @@ -337,10 +338,12 @@ class SumAttention(FeedForward): def __init__(self, base: LinearAttention, ctx: Context, init_scale: float): super(SumAttention, self).__init__(base, ctx, init_scale, ctx.model.sum_attention_level) self.sum_attention_level = ctx.model.sum_attention_level + self.weight = conv_weight(ctx.model.features, ctx.model.features, 3, 1, 1) def forward(self, inp: torch.Tensor) -> torch.Tensor: out = self._ff(inp).chunk(self.sum_attention_level, 1) batch, features, seq = out[0].size() - return sum(f(out[0] + sum(out[inner + 1][outer // batch ** inner % batch] - for inner in range(self.sum_attention_level)).unsqueeze(0)) + return sum(conv(torch.relu(out[0] + sum(out[inner + 1][outer // batch ** inner % batch] + for inner in range(self.sum_attention_level)).unsqueeze(0)), + self.weight, 1, True) for outer in range(batch ** (self.sum_attention_level - 1))) From ed891964604612c8baa9ca0e820fe9d973d2e475 Mon Sep 17 00:00:00 2001 From: ClashLuke <39779310+ClashLuke@users.noreply.github.com> Date: Sat, 13 Nov 2021 10:44:13 +0100 Subject: [PATCH 04/34] feat(model): add omnidirectional, pyramidal attention --- src/model.py | 45 +++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 41 insertions(+), 4 deletions(-) diff --git a/src/model.py b/src/model.py index 8157f33..c9cf53f 100644 --- a/src/model.py +++ b/src/model.py @@ -4,6 +4,7 @@ import numpy as np import revlib import torch +import torch.nn.functional import torch.utils.data from deepspeed.runtime import lr_schedules from torch.nn import functional as F @@ -236,11 +237,13 @@ def __init__(self, ctx: Context): for i in range(1, 1 + ctx.model.depth * 2, 2) for c in [ff.momentum((1 - 
ctx.model.momentumnet_beta) / ctx.model.momentumnet_beta ** i, - not ctx.model.weight_sharing), + not ctx.model.weight_sharing, + i + 1), MomentumNetSide(ctx.model.momentumnet_beta ** i), attn.momentum((1 - ctx.model.momentumnet_beta) / ctx.model.momentumnet_beta ** (i + 1), - not ctx.model.weight_sharing), + not ctx.model.weight_sharing, + i + 1), MomentumNetSide(ctx.model.momentumnet_beta ** (i + 1))]], target_device=ctx.model.device) self.output = torch.nn.Conv1d(ctx.model.features * 2, ctx.dataset.classes, (1,)).to(ctx.model.device) @@ -284,6 +287,7 @@ def get_moe_param(in_features: int, out_features: int, groups: int, experts: int class FeedForward(torch.nn.Module): def __init__(self, base: LinearAttention, ctx: Context, init_scale: float, feature_factor: float): super(FeedForward, self).__init__() + self.ctx = ctx self.divisor = lambda: base.divisor self.init_scale = init_scale self.caching = ctx.eval.cache @@ -305,6 +309,29 @@ def __init__(self, base: LinearAttention, ctx: Context, init_scale: float, featu self.w2 = torch.nn.ParameterList(get_moe_param(intermediate, ctx.model.features * feature_factor, self.groups2, self.experts2, self.expert_chunks, 1)) self.idx: int = 0 + self.depth: int = 0 + self.get_last: bool = True + + def _cut_off(self, inp: torch.Tensor) -> torch.Tensor: + if inp.size(2) == self.ctx.model.sequence_length: + return inp + + base_len = self.ctx.model.sequence_length * self.depth + max_len = base_len + self.ctx.model.sequence_length + if self.get_last: + return inp[base_len:max_len] + return inp[:max_len] + + def _pad(self, inp: torch.Tensor, out: torch.Tensor): + if inp.size(2) == out.size(2): + return inp + + batch, features, sequence = inp.size() + if self.get_last: + return torch.cat([torch.zeros((batch, features, self.ctx.model.sequence_length * self.depth)), out.size(), + torch.zeros((batch, features, + sequence - self.ctx.model.sequence_length * (self.depth + 1)))], 2) + return torch.cat([out.size(), torch.zeros((batch, features, sequence - out.size()))], 2) def _ff(self, inp: torch.Tensor) -> torch.Tensor: return linear_attention(inp, self.w0, self.groups0, self.experts0, self.w1, self.w2, self.groups2, @@ -312,11 +339,12 @@ def _ff(self, inp: torch.Tensor) -> torch.Tensor: self.bottleneck_group, self.training, self.norm_power, self.jitter_epsilon) def forward(self, inp: torch.Tensor) -> torch.Tensor: - return self._ff(inp) * self.init_scale + return self._pad(inp, self._ff(self._cut_off(inp))) * self.init_scale - def momentum(self, init_scale: float, deep: bool): + def momentum(self, init_scale: float, deep: bool, depth: int): out = copy.deepcopy(self) if deep else copy.copy(self) out.init_scale = init_scale + out.depth = depth return out @@ -347,3 +375,12 @@ def forward(self, inp: torch.Tensor) -> torch.Tensor: for inner in range(self.sum_attention_level)).unsqueeze(0)), self.weight, 1, True) for outer in range(batch ** (self.sum_attention_level - 1))) + + +class OmnidirectionalAttention(FFTAttention): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.get_last = False + + def forward(self, inp: torch.Tensor) -> torch.Tensor: + return super().forward(inp) From 21d80f9ded59478f984362b2eebda7b1d50cf8c2 Mon Sep 17 00:00:00 2001 From: ClashLuke <39779310+ClashLuke@users.noreply.github.com> Date: Sat, 13 Nov 2021 11:09:58 +0100 Subject: [PATCH 05/34] feat(model): allow selection of attentio --- src/dataclass.py | 1 + src/model.py | 9 ++++++++- 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/src/dataclass.py 
b/src/dataclass.py index 3912ce6..495f449 100644 --- a/src/dataclass.py +++ b/src/dataclass.py @@ -18,6 +18,7 @@ def serialize(instance: typing.Union[DataClass, typing.Dict[str, typing.Any]]): class Model(DataClass): + attention: str = "OmnidirectionalAttention" weight_sharing: bool = False checkpoint_path: str = "checkpoint.torch" steps_per_checkpoint: int = 0 # 0 -> disabled diff --git a/src/model.py b/src/model.py index c9cf53f..514b507 100644 --- a/src/model.py +++ b/src/model.py @@ -232,7 +232,12 @@ def __init__(self, ctx: Context): self.register_buffer("divisor", pos_embd.unsqueeze(0).to(torch.float).to(ctx.model.device)) ff = FeedForward(self, ctx, 1, 1) - attn = FFTAttention(self, ctx, 1) + + modules = [mod.__name__ for mod in attention_modules] + if ctx.model.attention not in modules: + raise ValueError(f"{ctx.model.attention} is not a known type of attention. You can pick any of the" + f" following: {modules}") + attn = attention_modules[modules.index(ctx.model.attention)](self, ctx, 1) self.stem = revlib.ReversibleSequential(*[c for i in range(1, 1 + ctx.model.depth * 2, 2) for c in [ff.momentum((1 - ctx.model.momentumnet_beta) / @@ -384,3 +389,5 @@ def __init__(self, *args, **kwargs): def forward(self, inp: torch.Tensor) -> torch.Tensor: return super().forward(inp) + +attention_modules = [FeedForward, FFTAttention, SumAttention, OmnidirectionalAttention] \ No newline at end of file From 59f55fd5eed227eb8fca9778fe08a163f31e98d9 Mon Sep 17 00:00:00 2001 From: ClashLuke <39779310+ClashLuke@users.noreply.github.com> Date: Sat, 13 Nov 2021 11:40:46 +0100 Subject: [PATCH 06/34] fix(model): pad input with zeros if omni --- src/model.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/src/model.py b/src/model.py index 514b507..731b14a 100644 --- a/src/model.py +++ b/src/model.py @@ -238,6 +238,7 @@ def __init__(self, ctx: Context): raise ValueError(f"{ctx.model.attention} is not a known type of attention. 
You can pick any of the" f" following: {modules}") attn = attention_modules[modules.index(ctx.model.attention)](self, ctx, 1) + self.expand_sequence = attn.get_last | ff.get_last self.stem = revlib.ReversibleSequential(*[c for i in range(1, 1 + ctx.model.depth * 2, 2) for c in [ff.momentum((1 - ctx.model.momentumnet_beta) / @@ -255,7 +256,11 @@ def __init__(self, ctx: Context): torch.nn.init.zeros_(self.output.weight.data) def forward(self, inp: torch.Tensor): - return self.output(self.stem(self.embedding(inp).transpose(1, 2))) + inp = self.embedding(inp).transpose(1, 2) + if self.expand_sequence: + batch, features, sequence = inp.size() + inp = torch.cat([inp, torch.zeros((batch, features, sequence * len(self.stem.stem)))], 2) + return self.output(self.stem(inp)) def reset_cache(self): for mod in self.stem.modules(): @@ -362,7 +367,7 @@ def forward(self, inp: torch.Tensor) -> torch.Tensor: out = torch.view_as_real(torch.fft.rfft(inp, 2 * sequence)) out = out.transpose(2, 3).reshape(batch, features * 2, sequence + 1) out = self._ff(out) - out = out.view(batch, features, 2, sequence + 1).transpose(2, 3) + out = out.view(batch, features, 2, sequence + 1).transpose(2, 3).contiguous() out = torch.view_as_complex(out) return torch.fft.irfft(out, 2 * sequence)[:, :, :sequence] @@ -387,7 +392,5 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.get_last = False - def forward(self, inp: torch.Tensor) -> torch.Tensor: - return super().forward(inp) -attention_modules = [FeedForward, FFTAttention, SumAttention, OmnidirectionalAttention] \ No newline at end of file +attention_modules = [FeedForward, FFTAttention, SumAttention, OmnidirectionalAttention] From 3ae35bf005d5281405433a933a02f92f527fdd8e Mon Sep 17 00:00:00 2001 From: ClashLuke <39779310+ClashLuke@users.noreply.github.com> Date: Sat, 13 Nov 2021 12:09:32 +0100 Subject: [PATCH 07/34] fix(dataset): use custom multiprocessing to avoid data replication --- src/dataclass.py | 1 - src/dataset.py | 48 ++++++++++++++++++++++++++++++----------- src/executable/train.py | 1 - src/model.py | 5 +++++ src/utils/setup.py | 1 + 5 files changed, 41 insertions(+), 15 deletions(-) diff --git a/src/dataclass.py b/src/dataclass.py index 495f449..3bc4d28 100644 --- a/src/dataclass.py +++ b/src/dataclass.py @@ -52,7 +52,6 @@ class Dataset(DataClass): file_name: str = "out.tensor" classes: int = 256 num_workers: int = 4 - pin_memory: bool = False prefetch_factor: int = 256 # 256 (Prefetch) * 8 (Long) * 2048 (GPT context) * 256 (High Batch) = 1GiB RAM diff --git a/src/dataset.py b/src/dataset.py index 9e51bc9..16e7366 100644 --- a/src/dataset.py +++ b/src/dataset.py @@ -1,3 +1,5 @@ +import multiprocessing +import random import typing import torch @@ -13,26 +15,46 @@ def get_sample(data: torch.Tensor, batch_index: torch.Tensor, idx: int) -> typin return dat[:, :-1], dat[:, 1:] -class Dataset(torch.utils.data.Dataset): - def __init__(self, ctx: Context): - self.data = torch.load(ctx.dataset.file_name) - batch_index = torch.arange(0, ctx.model.batch_size).view(-1, 1) - item_index = torch.arange(0, ctx.model.sequence_length + 1).view(1, -1) - self.batch_index = batch_index + item_index - self.length = self.data.size(0) - ctx.model.batch_size * ctx.model.sequence_length +class Dataset: + def __init__(self, ctx: Context, length: int, queue: multiprocessing.Queue): + self.length = length + self.batch = ctx.optimizer.gradient_accumulation_steps + self.queue = queue def __len__(self): return self.length - def __getitem__(self, idx: int) -> 
typing.Tuple[torch.Tensor, torch.Tensor]: - return get_sample(self.data, self.batch_index, idx) + def __iter__(self, idx: int) -> typing.Tuple[torch.Tensor, torch.Tensor]: + yield next(self) + def __next__(self): + return self.queue.get() -def get_dataset(ctx: Context) -> torch.utils.data.DataLoader: + +def _get_process_fn(ctx: Context, queue: multiprocessing.Queue) -> typing.Tuple[typing.Callable[[int], None], int]: + data = torch.load(ctx.dataset.file_name) + batch_index = torch.arange(0, ctx.model.batch_size).view(-1, 1) + item_index = torch.arange(0, ctx.model.sequence_length + 1).view(1, -1) + batch_index = batch_index + item_index + length = data.size(0) - ctx.model.batch_size * ctx.model.sequence_length + + def _fn(idx): + random.seed(idx) + while True: + queue.put(get_sample(data, batch_index, random.randint(0, length))) + + return _fn, length + + +def get_dataset(ctx: Context) -> Dataset: if ctx.dataset.prefetch_factor < ctx.dataset.num_workers: print(f"Warning: prefetch_factor ({ctx.dataset.prefetch_factor}) < num_workers ({ctx.dataset.num_workers})." f"Some workers will be idle at all times. Reducing num_workers ({ctx.dataset.num_workers}) to " f"prefetch_factor ({ctx.dataset.prefetch_factor}).") - return torch.utils.data.DataLoader(Dataset(ctx), ctx.optimizer.gradient_accumulation_steps, True, - num_workers=min(ctx.dataset.num_workers, ctx.dataset.prefetch_factor), - pin_memory=ctx.dataset.pin_memory, prefetch_factor=ctx.dataset.prefetch_factor) + queue = multiprocessing.Queue(ctx.dataset.prefetch_factor) + proc_fn, length = _get_process_fn(ctx, queue) + procs = [multiprocessing.Process(target=proc_fn, args=(idx,), daemon=True) + for idx in range(min(ctx.dataset.num_workers, ctx.dataset.prefetch_factor))] + for p in procs: + p.start() + return Dataset(ctx, length, queue) diff --git a/src/executable/train.py b/src/executable/train.py index 0a6d9b4..ce45d72 100644 --- a/src/executable/train.py +++ b/src/executable/train.py @@ -13,7 +13,6 @@ def train_model(ctx: Context, steps=None, load_model: bool = False): data = get_dataset(ctx) data_len = len(data) - data = iter(data) mod = get_model(ctx, load_model, next(data)[0]) wandb.watch(mod, log=ctx.log.wandb.model_log_type, log_freq=ctx.log.wandb.log_frequency) diff --git a/src/model.py b/src/model.py index 731b14a..fa79825 100644 --- a/src/model.py +++ b/src/model.py @@ -225,7 +225,9 @@ def forward(self, inp: torch.Tensor): class LinearAttention(torch.nn.Module): def __init__(self, ctx: Context): super(LinearAttention, self).__init__() + print("Enter Linear attention") self.embedding = torch.nn.Embedding(ctx.dataset.classes, ctx.model.features * 2).to(ctx.model.device) + print("EMbedding") orthonormal(self.embedding.weight, ctx.model.input_embedding_std * 2 ** -0.5) pos_embd = torch.arange(0, ctx.model.sequence_length).unsqueeze(0) + 1 @@ -239,6 +241,7 @@ def __init__(self, ctx: Context): f" following: {modules}") attn = attention_modules[modules.index(ctx.model.attention)](self, ctx, 1) self.expand_sequence = attn.get_last | ff.get_last + print("Attn/ff") self.stem = revlib.ReversibleSequential(*[c for i in range(1, 1 + ctx.model.depth * 2, 2) for c in [ff.momentum((1 - ctx.model.momentumnet_beta) / @@ -252,7 +255,9 @@ def __init__(self, ctx: Context): i + 1), MomentumNetSide(ctx.model.momentumnet_beta ** (i + 1))]], target_device=ctx.model.device) + print("Stem") self.output = torch.nn.Conv1d(ctx.model.features * 2, ctx.dataset.classes, (1,)).to(ctx.model.device) + print("Out") torch.nn.init.zeros_(self.output.weight.data) def 
forward(self, inp: torch.Tensor): diff --git a/src/utils/setup.py b/src/utils/setup.py index c7d4e62..7c3a3a3 100644 --- a/src/utils/setup.py +++ b/src/utils/setup.py @@ -40,6 +40,7 @@ def setup_torch(seed: int): def get_model(ctx: Context, load_model: bool, data: typing.Optional[torch.Tensor] = None) -> Trainer: + pritn("Get model") mod = Trainer(ctx, LinearAttention(ctx).to(dtype=torch.float16 if ctx.model.float16 else torch.float), data if data is None else None) From 0eac9038d50a9ba1a9500c057cccab0642fed5ad Mon Sep 17 00:00:00 2001 From: ClashLuke <39779310+ClashLuke@users.noreply.github.com> Date: Sat, 13 Nov 2021 12:10:43 +0100 Subject: [PATCH 08/34] style(model): remove debug prints --- src/model.py | 5 ----- src/utils/setup.py | 1 - 2 files changed, 6 deletions(-) diff --git a/src/model.py b/src/model.py index fa79825..731b14a 100644 --- a/src/model.py +++ b/src/model.py @@ -225,9 +225,7 @@ def forward(self, inp: torch.Tensor): class LinearAttention(torch.nn.Module): def __init__(self, ctx: Context): super(LinearAttention, self).__init__() - print("Enter Linear attention") self.embedding = torch.nn.Embedding(ctx.dataset.classes, ctx.model.features * 2).to(ctx.model.device) - print("EMbedding") orthonormal(self.embedding.weight, ctx.model.input_embedding_std * 2 ** -0.5) pos_embd = torch.arange(0, ctx.model.sequence_length).unsqueeze(0) + 1 @@ -241,7 +239,6 @@ def __init__(self, ctx: Context): f" following: {modules}") attn = attention_modules[modules.index(ctx.model.attention)](self, ctx, 1) self.expand_sequence = attn.get_last | ff.get_last - print("Attn/ff") self.stem = revlib.ReversibleSequential(*[c for i in range(1, 1 + ctx.model.depth * 2, 2) for c in [ff.momentum((1 - ctx.model.momentumnet_beta) / @@ -255,9 +252,7 @@ def __init__(self, ctx: Context): i + 1), MomentumNetSide(ctx.model.momentumnet_beta ** (i + 1))]], target_device=ctx.model.device) - print("Stem") self.output = torch.nn.Conv1d(ctx.model.features * 2, ctx.dataset.classes, (1,)).to(ctx.model.device) - print("Out") torch.nn.init.zeros_(self.output.weight.data) def forward(self, inp: torch.Tensor): diff --git a/src/utils/setup.py b/src/utils/setup.py index 7c3a3a3..c7d4e62 100644 --- a/src/utils/setup.py +++ b/src/utils/setup.py @@ -40,7 +40,6 @@ def setup_torch(seed: int): def get_model(ctx: Context, load_model: bool, data: typing.Optional[torch.Tensor] = None) -> Trainer: - pritn("Get model") mod = Trainer(ctx, LinearAttention(ctx).to(dtype=torch.float16 if ctx.model.float16 else torch.float), data if data is None else None) From 20ce54b4d6ec6fdbccf2de518604e59465956d39 Mon Sep 17 00:00:00 2001 From: ClashLuke <39779310+ClashLuke@users.noreply.github.com> Date: Sat, 13 Nov 2021 12:52:47 +0100 Subject: [PATCH 09/34] fix(dataset): manually slice pytorch data loader --- src/dataclass.py | 4 ++- src/dataset.py | 60 ++++++++++++++++++----------------------- src/executable/train.py | 4 ++- 3 files changed, 32 insertions(+), 36 deletions(-) diff --git a/src/dataclass.py b/src/dataclass.py index 3bc4d28..77d59a5 100644 --- a/src/dataclass.py +++ b/src/dataclass.py @@ -1,6 +1,7 @@ import pathlib -import torch import typing + +import torch import yaml @@ -52,6 +53,7 @@ class Dataset(DataClass): file_name: str = "out.tensor" classes: int = 256 num_workers: int = 4 + pin_memory: bool = False prefetch_factor: int = 256 # 256 (Prefetch) * 8 (Long) * 2048 (GPT context) * 256 (High Batch) = 1GiB RAM diff --git a/src/dataset.py b/src/dataset.py index 16e7366..d813d18 100644 --- a/src/dataset.py +++ b/src/dataset.py @@ -1,5 +1,3 
@@ -import multiprocessing -import random import typing import torch @@ -15,46 +13,40 @@ def get_sample(data: torch.Tensor, batch_index: torch.Tensor, idx: int) -> typin return dat[:, :-1], dat[:, 1:] -class Dataset: - def __init__(self, ctx: Context, length: int, queue: multiprocessing.Queue): - self.length = length - self.batch = ctx.optimizer.gradient_accumulation_steps - self.queue = queue +class Dataset(torch.utils.data.Dataset): + def __init__(self, ctx: Context, workers: int): + self.ctx = ctx + self.workers = workers + self.data = torch.empty((1,)) + data = torch.load(self.ctx.dataset.file_name) + self.length = data.size(0) - self.ctx.model.batch_size * self.ctx.model.sequence_length + + batch_index = torch.arange(0, ctx.model.batch_size).view(-1, 1) + item_index = torch.arange(0, ctx.model.sequence_length + 1).view(1, -1) + self.batch_index = batch_index + item_index + self.worker_id: int = 0 + self.slice_size = self.length // self.workers def __len__(self): return self.length - def __iter__(self, idx: int) -> typing.Tuple[torch.Tensor, torch.Tensor]: - yield next(self) - - def __next__(self): - return self.queue.get() - - -def _get_process_fn(ctx: Context, queue: multiprocessing.Queue) -> typing.Tuple[typing.Callable[[int], None], int]: - data = torch.load(ctx.dataset.file_name) - batch_index = torch.arange(0, ctx.model.batch_size).view(-1, 1) - item_index = torch.arange(0, ctx.model.sequence_length + 1).view(1, -1) - batch_index = batch_index + item_index - length = data.size(0) - ctx.model.batch_size * ctx.model.sequence_length - - def _fn(idx): - random.seed(idx) - while True: - queue.put(get_sample(data, batch_index, random.randint(0, length))) + def __getitem__(self, idx: int) -> typing.Tuple[torch.Tensor, torch.Tensor]: + return get_sample(self.data, self.batch_index, idx % self.slice_size) - return _fn, length + def set_id(self, worker_id: int): + self.worker_id = worker_id + data = torch.load(self.dataset.file_name) + self.data = data[self.slice_size * worker_id: self.slice_size * (worker_id + 1)] -def get_dataset(ctx: Context) -> Dataset: +def get_dataset(ctx: Context) -> torch.utils.data.DataLoader: if ctx.dataset.prefetch_factor < ctx.dataset.num_workers: print(f"Warning: prefetch_factor ({ctx.dataset.prefetch_factor}) < num_workers ({ctx.dataset.num_workers})." f"Some workers will be idle at all times. 
Reducing num_workers ({ctx.dataset.num_workers}) to " f"prefetch_factor ({ctx.dataset.prefetch_factor}).") - queue = multiprocessing.Queue(ctx.dataset.prefetch_factor) - proc_fn, length = _get_process_fn(ctx, queue) - procs = [multiprocessing.Process(target=proc_fn, args=(idx,), daemon=True) - for idx in range(min(ctx.dataset.num_workers, ctx.dataset.prefetch_factor))] - for p in procs: - p.start() - return Dataset(ctx, length, queue) + workers = min(ctx.dataset.num_workers, ctx.dataset.prefetch_factor) + dset = Dataset(ctx, workers) + return torch.utils.data.DataLoader(dset, ctx.optimizer.gradient_accumulation_steps, True, + num_workers=workers, pin_memory=ctx.dataset.pin_memory, + prefetch_factor=ctx.dataset.prefetch_factor, + worker_init_fn=dset.set_id) diff --git a/src/executable/train.py b/src/executable/train.py index ce45d72..27a25b2 100644 --- a/src/executable/train.py +++ b/src/executable/train.py @@ -13,7 +13,9 @@ def train_model(ctx: Context, steps=None, load_model: bool = False): data = get_dataset(ctx) data_len = len(data) - mod = get_model(ctx, load_model, next(data)[0]) + data = iter(data) + next_data = next(data)[0] + mod = get_model(ctx, load_model, next_data) wandb.watch(mod, log=ctx.log.wandb.model_log_type, log_freq=ctx.log.wandb.log_frequency) log = WandbLog(ctx, data_len) From f073f7ce17f3f9d2770cc5fc5a79763a4d06591d Mon Sep 17 00:00:00 2001 From: ClashLuke <39779310+ClashLuke@users.noreply.github.com> Date: Sat, 13 Nov 2021 13:44:03 +0100 Subject: [PATCH 10/34] fix(dataset): manually slice dataset --- src/dataclass.py | 4 ++-- src/dataset.py | 17 ++++++++++------- 2 files changed, 12 insertions(+), 9 deletions(-) diff --git a/src/dataclass.py b/src/dataclass.py index 77d59a5..a5165bc 100644 --- a/src/dataclass.py +++ b/src/dataclass.py @@ -19,7 +19,7 @@ def serialize(instance: typing.Union[DataClass, typing.Dict[str, typing.Any]]): class Model(DataClass): - attention: str = "OmnidirectionalAttention" + attention: str = "FFTAttention" weight_sharing: bool = False checkpoint_path: str = "checkpoint.torch" steps_per_checkpoint: int = 0 # 0 -> disabled @@ -134,7 +134,7 @@ class Optimizer(DataClass): statistics_compute_steps: int = 1 block_size: int = 128 best_effort_shape_interpretation: bool = True - graft_type: str = 'adagrad' # 'Adagrad' or 'SGD' + graft_type: str = 'SGD' # 'Adagrad' or 'SGD' nesterov: bool = True no_preconditioning_for_layers_with_dim_gt: int = 8192 diff --git a/src/dataset.py b/src/dataset.py index d813d18..6429447 100644 --- a/src/dataset.py +++ b/src/dataset.py @@ -14,29 +14,32 @@ def get_sample(data: torch.Tensor, batch_index: torch.Tensor, idx: int) -> typin class Dataset(torch.utils.data.Dataset): + data: torch.Tensor + def __init__(self, ctx: Context, workers: int): self.ctx = ctx self.workers = workers - self.data = torch.empty((1,)) data = torch.load(self.ctx.dataset.file_name) - self.length = data.size(0) - self.ctx.model.batch_size * self.ctx.model.sequence_length + self.offset = self.ctx.model.batch_size * self.ctx.model.sequence_length + self.length = data.size(0) - self.offset batch_index = torch.arange(0, ctx.model.batch_size).view(-1, 1) item_index = torch.arange(0, ctx.model.sequence_length + 1).view(1, -1) self.batch_index = batch_index + item_index self.worker_id: int = 0 - self.slice_size = self.length // self.workers + self.slice_size = data.size(0) // self.workers def __len__(self): return self.length def __getitem__(self, idx: int) -> typing.Tuple[torch.Tensor, torch.Tensor]: - return get_sample(self.data, self.batch_index, idx 
% self.slice_size) + return get_sample(self.data, self.batch_index, idx % (self.slice_size - self.offset)) + @torch.no_grad() def set_id(self, worker_id: int): self.worker_id = worker_id - data = torch.load(self.dataset.file_name) - self.data = data[self.slice_size * worker_id: self.slice_size * (worker_id + 1)] + data = torch.load(self.ctx.dataset.file_name) + self.data = data[self.slice_size * worker_id: self.slice_size * (worker_id + 1)].detach() def get_dataset(ctx: Context) -> torch.utils.data.DataLoader: @@ -46,7 +49,7 @@ def get_dataset(ctx: Context) -> torch.utils.data.DataLoader: f"prefetch_factor ({ctx.dataset.prefetch_factor}).") workers = min(ctx.dataset.num_workers, ctx.dataset.prefetch_factor) dset = Dataset(ctx, workers) - return torch.utils.data.DataLoader(dset, ctx.optimizer.gradient_accumulation_steps, True, + return torch.utils.data.DataLoader(dset, ctx.optimizer.gradient_accumulation_steps, False, num_workers=workers, pin_memory=ctx.dataset.pin_memory, prefetch_factor=ctx.dataset.prefetch_factor, worker_init_fn=dset.set_id) From e8383e7e1534ea8359dadecf96bc74cd5ca7f55d Mon Sep 17 00:00:00 2001 From: ClashLuke <39779310+ClashLuke@users.noreply.github.com> Date: Sat, 13 Nov 2021 14:00:22 +0100 Subject: [PATCH 11/34] perf(dataset): implement sampling in multiprocessing --- src/dataset.py | 60 ++++++++++++++++++++++------------------- src/executable/train.py | 1 - 2 files changed, 33 insertions(+), 28 deletions(-) diff --git a/src/dataset.py b/src/dataset.py index 6429447..714a6f0 100644 --- a/src/dataset.py +++ b/src/dataset.py @@ -1,3 +1,5 @@ +import multiprocessing +import random import typing import torch @@ -13,43 +15,47 @@ def get_sample(data: torch.Tensor, batch_index: torch.Tensor, idx: int) -> typin return dat[:, :-1], dat[:, 1:] -class Dataset(torch.utils.data.Dataset): - data: torch.Tensor - - def __init__(self, ctx: Context, workers: int): +class Dataset: + def __init__(self, ctx: Context, queue: multiprocessing.Queue, length: int): self.ctx = ctx - self.workers = workers - data = torch.load(self.ctx.dataset.file_name) - self.offset = self.ctx.model.batch_size * self.ctx.model.sequence_length - self.length = data.size(0) - self.offset - - batch_index = torch.arange(0, ctx.model.batch_size).view(-1, 1) - item_index = torch.arange(0, ctx.model.sequence_length + 1).view(1, -1) - self.batch_index = batch_index + item_index - self.worker_id: int = 0 - self.slice_size = data.size(0) // self.workers + self.length = length + self.queue = queue def __len__(self): return self.length - def __getitem__(self, idx: int) -> typing.Tuple[torch.Tensor, torch.Tensor]: - return get_sample(self.data, self.batch_index, idx % (self.slice_size - self.offset)) + def __iter__(self, idx: int) -> typing.Tuple[torch.Tensor, torch.Tensor]: + yield next(self) + + def __next__(self): + items = [self.queue.get() for _ in range(self.ctx.optimizer.gradient_accumulation_steps)] + return torch.stack([itm[0] for itm in items], 0), torch.stack([itm[1] for itm in items], 0) + + +def _process_fn(ctx: Context, queue: multiprocessing.Queue, idx: int, worker_count: int): + data = torch.load(ctx.dataset.file_name) + data_len = data.size(0) // worker_count + data = data[data_len * idx:data_len * (idx + 1)] + batch_index = torch.arange(0, ctx.model.batch_size).view(-1, 1) + item_index = torch.arange(0, ctx.model.sequence_length + 1).view(1, -1) + batch_index = batch_index + item_index + length = data.size(0) - ctx.model.batch_size * ctx.model.sequence_length - @torch.no_grad() - def set_id(self, worker_id: 
int): - self.worker_id = worker_id - data = torch.load(self.ctx.dataset.file_name) - self.data = data[self.slice_size * worker_id: self.slice_size * (worker_id + 1)].detach() + random.seed(idx) + while True: + queue.put(get_sample(data, batch_index, random.randint(0, length))) -def get_dataset(ctx: Context) -> torch.utils.data.DataLoader: +def get_dataset(ctx: Context) -> Dataset: if ctx.dataset.prefetch_factor < ctx.dataset.num_workers: print(f"Warning: prefetch_factor ({ctx.dataset.prefetch_factor}) < num_workers ({ctx.dataset.num_workers})." f"Some workers will be idle at all times. Reducing num_workers ({ctx.dataset.num_workers}) to " f"prefetch_factor ({ctx.dataset.prefetch_factor}).") + queue = multiprocessing.Queue(ctx.dataset.prefetch_factor) workers = min(ctx.dataset.num_workers, ctx.dataset.prefetch_factor) - dset = Dataset(ctx, workers) - return torch.utils.data.DataLoader(dset, ctx.optimizer.gradient_accumulation_steps, False, - num_workers=workers, pin_memory=ctx.dataset.pin_memory, - prefetch_factor=ctx.dataset.prefetch_factor, - worker_init_fn=dset.set_id) + procs = [multiprocessing.Process(target=_process_fn, args=(ctx, queue, idx, workers), daemon=True) for idx in + range(workers)] + for p in procs: + p.start() + data = torch.load(ctx.dataset.file_name) + return Dataset(ctx, queue, data.size(0)) diff --git a/src/executable/train.py b/src/executable/train.py index 27a25b2..e51ad92 100644 --- a/src/executable/train.py +++ b/src/executable/train.py @@ -13,7 +13,6 @@ def train_model(ctx: Context, steps=None, load_model: bool = False): data = get_dataset(ctx) data_len = len(data) - data = iter(data) next_data = next(data)[0] mod = get_model(ctx, load_model, next_data) wandb.watch(mod, log=ctx.log.wandb.model_log_type, log_freq=ctx.log.wandb.log_frequency) From 2734094608ca7e45aca9bd2c0a036ff4d617e25d Mon Sep 17 00:00:00 2001 From: ClashLuke <39779310+ClashLuke@users.noreply.github.com> Date: Sun, 14 Nov 2021 00:14:48 +0100 Subject: [PATCH 12/34] fix(model): take mean of states in omninet case --- src/model.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/src/model.py b/src/model.py index 731b14a..e8361ed 100644 --- a/src/model.py +++ b/src/model.py @@ -256,11 +256,17 @@ def __init__(self, ctx: Context): torch.nn.init.zeros_(self.output.weight.data) def forward(self, inp: torch.Tensor): - inp = self.embedding(inp).transpose(1, 2) + out = inp = self.embedding(inp).transpose(1, 2) if self.expand_sequence: batch, features, sequence = inp.size() - inp = torch.cat([inp, torch.zeros((batch, features, sequence * len(self.stem.stem)))], 2) - return self.output(self.stem(inp)) + out = torch.cat([inp, torch.zeros((batch, features, sequence * len(self.stem.stem)), device=inp.device, + dtype=inp.dtype)], 2) + out = self.output(self.stem(out)) + + if self.expand_sequence: + batch, features, sequence = inp.size() + inp = out.view(batch, features // 2, -1, sequence).mean(2) + return inp def reset_cache(self): for mod in self.stem.modules(): From d8db018f38ee2964ad7a44958e702a39f75631e9 Mon Sep 17 00:00:00 2001 From: ClashLuke <39779310+ClashLuke@users.noreply.github.com> Date: Sun, 14 Nov 2021 00:16:07 +0100 Subject: [PATCH 13/34] style(model): increase size/use omninet --- configs/small.yaml | 4 ++-- src/dataclass.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/configs/small.yaml b/configs/small.yaml index 9a7c80d..5fe4b60 100644 --- a/configs/small.yaml +++ b/configs/small.yaml @@ -1,8 +1,8 @@ model: - depth: 32 + depth: 16 
conv_kernel_size: 11 weight_shared_blocks: 1 - batch_size: 1024 + batch_size: 8 feed_forward_intermediate_factor: 0.125 optimizer: beta2: 0.95 diff --git a/src/dataclass.py b/src/dataclass.py index a5165bc..29d8893 100644 --- a/src/dataclass.py +++ b/src/dataclass.py @@ -19,7 +19,7 @@ def serialize(instance: typing.Union[DataClass, typing.Dict[str, typing.Any]]): class Model(DataClass): - attention: str = "FFTAttention" + attention: str = "OmnidirectionalAttention" weight_sharing: bool = False checkpoint_path: str = "checkpoint.torch" steps_per_checkpoint: int = 0 # 0 -> disabled From 0003ec4e5bdf45c3318ee6a33bf4ee41b97ddd8d Mon Sep 17 00:00:00 2001 From: ClashLuke <39779310+ClashLuke@users.noreply.github.com> Date: Sun, 14 Nov 2021 12:18:39 +0100 Subject: [PATCH 14/34] fix(model): remove deepspeed/add windows support --- configs/small.yaml | 2 +- src/model.py | 33 +++++++++++++++++---------------- src/optimizers/build.py | 3 +-- 3 files changed, 19 insertions(+), 19 deletions(-) diff --git a/configs/small.yaml b/configs/small.yaml index 5fe4b60..43d2d8b 100644 --- a/configs/small.yaml +++ b/configs/small.yaml @@ -1,5 +1,5 @@ model: - depth: 16 + depth: 4 conv_kernel_size: 11 weight_shared_blocks: 1 batch_size: 8 diff --git a/src/model.py b/src/model.py index e8361ed..627d5c3 100644 --- a/src/model.py +++ b/src/model.py @@ -6,8 +6,8 @@ import torch import torch.nn.functional import torch.utils.data -from deepspeed.runtime import lr_schedules from torch.nn import functional as F +from torch.optim.lr_scheduler import OneCycleLR from src.dataclass import Context from src.optimizers.build import build_optimizer @@ -151,20 +151,20 @@ def __init__(self, ctx: Context, model: torch.nn.Module, data: typing.Optional[t self.ctx = ctx self.model = torch.jit.trace(model, data) if data else model self.optimizer = build_optimizer(ctx, self.model.parameters()) - self.scheduler = lr_schedules.OneCycle(self.optimizer, - ctx.optimizer.one_cycle.cycle_min_lr, - ctx.optimizer.one_cycle.cycle_max_lr, - ctx.optimizer.one_cycle.decay_lr_rate, - ctx.optimizer.one_cycle.cycle_first_step_size, - ctx.optimizer.one_cycle.cycle_second_step_size, - ctx.optimizer.one_cycle.cycle_first_stair_count, - ctx.optimizer.one_cycle.cycle_second_stair_count, - ctx.optimizer.one_cycle.decay_step_size, - ctx.optimizer.one_cycle.cycle_momentum, - ctx.optimizer.one_cycle.cycle_min_mom, - ctx.optimizer.one_cycle.cycle_max_mom, - ctx.optimizer.one_cycle.decay_mom_rate, - ctx.optimizer.one_cycle.last_batch_iteration) + self.scheduler = OneCycleLR(self.optimizer, + ctx.optimizer.one_cycle.cycle_min_lr, + ctx.optimizer.one_cycle.cycle_max_lr, + ctx.optimizer.one_cycle.decay_lr_rate, + ctx.optimizer.one_cycle.cycle_first_step_size, + ctx.optimizer.one_cycle.cycle_second_step_size, + ctx.optimizer.one_cycle.cycle_first_stair_count, + ctx.optimizer.one_cycle.cycle_second_stair_count, + ctx.optimizer.one_cycle.decay_step_size, + ctx.optimizer.one_cycle.cycle_momentum, + ctx.optimizer.one_cycle.cycle_min_mom, + ctx.optimizer.one_cycle.cycle_max_mom, + ctx.optimizer.one_cycle.decay_mom_rate, + ctx.optimizer.one_cycle.last_batch_iteration) @torch.no_grad() def _to_device_detach(self, inp: torch.Tensor) -> torch.Tensor: @@ -251,7 +251,8 @@ def __init__(self, ctx: Context): not ctx.model.weight_sharing, i + 1), MomentumNetSide(ctx.model.momentumnet_beta ** (i + 1))]], - target_device=ctx.model.device) + target_device=ctx.model.device, + memory_mode=revlib.MemoryModes.autograd_graph) self.output = torch.nn.Conv1d(ctx.model.features * 2, 
ctx.dataset.classes, (1,)).to(ctx.model.device) torch.nn.init.zeros_(self.output.weight.data) diff --git a/src/optimizers/build.py b/src/optimizers/build.py index 008a1a8..4b06422 100644 --- a/src/optimizers/build.py +++ b/src/optimizers/build.py @@ -1,14 +1,13 @@ import inspect import typing -import deepspeed.ops.adam import torch from src.dataclass import Context from src.optimizers import shampoo OWN_OPTIMIZER = {'Shampoo': shampoo.Shampoo} -LIB_OPTIMIZER = {'DeepSpeedCPUAdam': deepspeed.ops.adam.DeepSpeedCPUAdam} +LIB_OPTIMIZER = {} def build_optimizer(ctx: Context, parameters: typing.Iterable[torch.nn.Parameter]): From bd437d21f67249882f121941bd7476f08102ac0f Mon Sep 17 00:00:00 2001 From: ClashLuke <39779310+ClashLuke@users.noreply.github.com> Date: Sun, 14 Nov 2021 12:43:25 +0100 Subject: [PATCH 15/34] fix(model): use pytorch onecycle --- configs/small.yaml | 2 +- src/dataclass.py | 26 +++++++++++++------------- src/model.py | 27 +++++++++++++-------------- 3 files changed, 27 insertions(+), 28 deletions(-) diff --git a/configs/small.yaml b/configs/small.yaml index 43d2d8b..5fe4b60 100644 --- a/configs/small.yaml +++ b/configs/small.yaml @@ -1,5 +1,5 @@ model: - depth: 4 + depth: 16 conv_kernel_size: 11 weight_shared_blocks: 1 batch_size: 8 diff --git a/src/dataclass.py b/src/dataclass.py index 29d8893..b51e1ea 100644 --- a/src/dataclass.py +++ b/src/dataclass.py @@ -87,19 +87,19 @@ class Zero(DataClass): class OneCycle(DataClass): - cycle_min_lr: float = 3e-4 # Base learning rate used at the start and end of cycle. - cycle_max_lr: float = 1e-3 # Learning rate used in the middle of the cycle. Can be smaller than cycle_min_lr - decay_lr_rate: float = 1e-4 # Decay rate for learning rate. - cycle_first_step_size: int = 2048 # Number of training iterations in the increasing half of a cycle. - cycle_second_step_size: typing.Optional[int] = None # steps in second phase. None -> cycle_first_step_size - cycle_first_stair_count: int = 0 # Number of stairs in first phase. 0 means staircase disabled - cycle_second_stair_count: typing.Optional[int] = None # Number of stairs in second phase - decay_step_size: int = 2 # Every how many steps to decay lr. 0 -> no decay - cycle_momentum: bool = True # Whether to cycle `momentum` inversely to learning rate. - cycle_min_mom: float = 0.8 # Initial momentum which is the lower boundary in the cycle for each parameter group. - cycle_max_mom: float = 0.9 # Upper momentum boundaries in the cycle for each parameter group. - decay_mom_rate: float = 0 # Decay rate for momentum - last_batch_iteration: int = -1 # The index of the last batch. This parameter is used when resuming a training job. + max_lr: float = 1e-3 + total_steps: typing.Optional[int] = 10 ** 3 + epochs: typing.Optional[int] = None + steps_per_epoch: typing.Optional[int] = None + pct_start: float = 0.3 + anneal_strategy: str = 'cos' + cycle_momentum: bool = True + base_momentum: float = 0.85 + max_momentum: float = 0.95 + div_factor: float = 25. 
+ final_div_factor: float = 1e4 + three_phase: bool = False + last_epoch: int = -1 class AdaptiveGradientClipping(DataClass): diff --git a/src/model.py b/src/model.py index 627d5c3..2db0327 100644 --- a/src/model.py +++ b/src/model.py @@ -152,20 +152,19 @@ def __init__(self, ctx: Context, model: torch.nn.Module, data: typing.Optional[t self.model = torch.jit.trace(model, data) if data else model self.optimizer = build_optimizer(ctx, self.model.parameters()) self.scheduler = OneCycleLR(self.optimizer, - ctx.optimizer.one_cycle.cycle_min_lr, - ctx.optimizer.one_cycle.cycle_max_lr, - ctx.optimizer.one_cycle.decay_lr_rate, - ctx.optimizer.one_cycle.cycle_first_step_size, - ctx.optimizer.one_cycle.cycle_second_step_size, - ctx.optimizer.one_cycle.cycle_first_stair_count, - ctx.optimizer.one_cycle.cycle_second_stair_count, - ctx.optimizer.one_cycle.decay_step_size, + ctx.optimizer.one_cycle.max_lr, + ctx.optimizer.one_cycle.total_steps, + ctx.optimizer.one_cycle.epochs, + ctx.optimizer.one_cycle.steps_per_epoch, + ctx.optimizer.one_cycle.pct_start, + ctx.optimizer.one_cycle.anneal_strategy, ctx.optimizer.one_cycle.cycle_momentum, - ctx.optimizer.one_cycle.cycle_min_mom, - ctx.optimizer.one_cycle.cycle_max_mom, - ctx.optimizer.one_cycle.decay_mom_rate, - ctx.optimizer.one_cycle.last_batch_iteration) - + ctx.optimizer.one_cycle.base_momentum, + ctx.optimizer.one_cycle.max_momentum, + ctx.optimizer.one_cycle.div_factor, + ctx.optimizer.one_cycle.final_div_factor, + ctx.optimizer.one_cycle.three_phase, + ctx.optimizer.one_cycle.last_epoch) @torch.no_grad() def _to_device_detach(self, inp: torch.Tensor) -> torch.Tensor: return inp.to(device=self.ctx.model.device, non_blocking=True).detach() @@ -252,7 +251,7 @@ def __init__(self, ctx: Context): i + 1), MomentumNetSide(ctx.model.momentumnet_beta ** (i + 1))]], target_device=ctx.model.device, - memory_mode=revlib.MemoryModes.autograd_graph) + memory_mode=revlib.MemoryModes.autograd_function) self.output = torch.nn.Conv1d(ctx.model.features * 2, ctx.dataset.classes, (1,)).to(ctx.model.device) torch.nn.init.zeros_(self.output.weight.data) From e7497fe8588ae21d867f4d6aa85fff59f3824df8 Mon Sep 17 00:00:00 2001 From: ClashLuke <39779310+ClashLuke@users.noreply.github.com> Date: Sun, 14 Nov 2021 12:50:17 +0100 Subject: [PATCH 16/34] feat(modeol): add input dropout --- src/dataclass.py | 1 + src/dataset.py | 10 ++++++---- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/src/dataclass.py b/src/dataclass.py index b51e1ea..251aa7a 100644 --- a/src/dataclass.py +++ b/src/dataclass.py @@ -53,6 +53,7 @@ class Dataset(DataClass): file_name: str = "out.tensor" classes: int = 256 num_workers: int = 4 + dropout: float = 0.3 pin_memory: bool = False prefetch_factor: int = 256 # 256 (Prefetch) * 8 (Long) * 2048 (GPT context) * 256 (High Batch) = 1GiB RAM diff --git a/src/dataset.py b/src/dataset.py index 714a6f0..909db2b 100644 --- a/src/dataset.py +++ b/src/dataset.py @@ -9,10 +9,12 @@ @torch.jit.script -def get_sample(data: torch.Tensor, batch_index: torch.Tensor, idx: int) -> typing.Tuple[torch.Tensor, torch.Tensor]: +def get_sample(data: torch.Tensor, batch_index: torch.Tensor, idx: int, + drop_probability:float) -> typing.Tuple[torch.Tensor, torch.Tensor]: dat = data[batch_index + idx] dat = dat.to(dtype=torch.long, non_blocking=True) - return dat[:, :-1], dat[:, 1:] + inp = dat * torch.rand_like(dat) > drop_probability + return inp, dat class Dataset: @@ -37,13 +39,13 @@ def _process_fn(ctx: Context, queue: multiprocessing.Queue, idx: int, worker_cou 
data_len = data.size(0) // worker_count data = data[data_len * idx:data_len * (idx + 1)] batch_index = torch.arange(0, ctx.model.batch_size).view(-1, 1) - item_index = torch.arange(0, ctx.model.sequence_length + 1).view(1, -1) + item_index = torch.arange(0, ctx.model.sequence_length).view(1, -1) batch_index = batch_index + item_index length = data.size(0) - ctx.model.batch_size * ctx.model.sequence_length random.seed(idx) while True: - queue.put(get_sample(data, batch_index, random.randint(0, length))) + queue.put(get_sample(data, batch_index, random.randint(0, length), ctx.dataset.dropout)) def get_dataset(ctx: Context) -> Dataset: From 2e121787f930403dcc2bed2dac6bcc78876b35fe Mon Sep 17 00:00:00 2001 From: ClashLuke <39779310+ClashLuke@users.noreply.github.com> Date: Sun, 14 Nov 2021 12:51:33 +0100 Subject: [PATCH 17/34] fix(dataset): don't use rand_like(long_tensor) --- src/dataset.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/dataset.py b/src/dataset.py index 909db2b..9279337 100644 --- a/src/dataset.py +++ b/src/dataset.py @@ -10,10 +10,10 @@ @torch.jit.script def get_sample(data: torch.Tensor, batch_index: torch.Tensor, idx: int, - drop_probability:float) -> typing.Tuple[torch.Tensor, torch.Tensor]: + drop_probability: float) -> typing.Tuple[torch.Tensor, torch.Tensor]: dat = data[batch_index + idx] dat = dat.to(dtype=torch.long, non_blocking=True) - inp = dat * torch.rand_like(dat) > drop_probability + inp = dat * torch.rand(dat.size(), device=dat.device) > drop_probability return inp, dat From fcf771684820c4f4ad3af0013bd5a266d9fb8923 Mon Sep 17 00:00:00 2001 From: ClashLuke <39779310+ClashLuke@users.noreply.github.com> Date: Sun, 14 Nov 2021 12:52:20 +0100 Subject: [PATCH 18/34] fix(dataset): specify mul/mask order --- src/dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/dataset.py b/src/dataset.py index 9279337..9e7806e 100644 --- a/src/dataset.py +++ b/src/dataset.py @@ -13,7 +13,7 @@ def get_sample(data: torch.Tensor, batch_index: torch.Tensor, idx: int, drop_probability: float) -> typing.Tuple[torch.Tensor, torch.Tensor]: dat = data[batch_index + idx] dat = dat.to(dtype=torch.long, non_blocking=True) - inp = dat * torch.rand(dat.size(), device=dat.device) > drop_probability + inp = dat * (torch.rand(dat.size(), device=dat.device) > drop_probability) return inp, dat From 4cd79eac951ede245e59164bfe26191472f3f89c Mon Sep 17 00:00:00 2001 From: ClashLuke <39779310+ClashLuke@users.noreply.github.com> Date: Sun, 14 Nov 2021 13:19:43 +0100 Subject: [PATCH 19/34] fix(model): first reduce, then output --- configs/small.yaml | 8 ++++---- src/model.py | 7 ++++--- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/configs/small.yaml b/configs/small.yaml index 5fe4b60..678852d 100644 --- a/configs/small.yaml +++ b/configs/small.yaml @@ -2,15 +2,15 @@ model: depth: 16 conv_kernel_size: 11 weight_shared_blocks: 1 - batch_size: 8 + batch_size: 1 + features: 1024 feed_forward_intermediate_factor: 0.125 optimizer: beta2: 0.95 gradient_accumulation_steps: 1 one_cycle: - cycle_first_step_size: 8192 - cycle_second_step_size: null - cycle_max_lr: 0.01 + total_steps: 8192 + max_lr: 0.01 log: loss_steps_per_print: 8 dataset: diff --git a/src/model.py b/src/model.py index 2db0327..1b11120 100644 --- a/src/model.py +++ b/src/model.py @@ -165,6 +165,7 @@ def __init__(self, ctx: Context, model: torch.nn.Module, data: typing.Optional[t ctx.optimizer.one_cycle.final_div_factor, ctx.optimizer.one_cycle.three_phase, 
ctx.optimizer.one_cycle.last_epoch) + @torch.no_grad() def _to_device_detach(self, inp: torch.Tensor) -> torch.Tensor: return inp.to(device=self.ctx.model.device, non_blocking=True).detach() @@ -261,12 +262,12 @@ def forward(self, inp: torch.Tensor): batch, features, sequence = inp.size() out = torch.cat([inp, torch.zeros((batch, features, sequence * len(self.stem.stem)), device=inp.device, dtype=inp.dtype)], 2) - out = self.output(self.stem(out)) + out = self.stem(out) if self.expand_sequence: batch, features, sequence = inp.size() - inp = out.view(batch, features // 2, -1, sequence).mean(2) - return inp + inp = out.view(batch, features, -1, sequence).mean(2) + return self.output(inp) def reset_cache(self): for mod in self.stem.modules(): From 9202f571aaee8b6ff05049b2e97f3f4a41f6854f Mon Sep 17 00:00:00 2001 From: ClashLuke <39779310+ClashLuke@users.noreply.github.com> Date: Sun, 14 Nov 2021 13:23:11 +0100 Subject: [PATCH 20/34] perf(model): increase features to better utilize gpu --- configs/small.yaml | 5 +++-- src/dataclass.py | 2 +- src/model.py | 1 - 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/configs/small.yaml b/configs/small.yaml index 678852d..43ad312 100644 --- a/configs/small.yaml +++ b/configs/small.yaml @@ -1,9 +1,10 @@ model: - depth: 16 + attention: OmnidirectionalAttention + depth: 8 conv_kernel_size: 11 weight_shared_blocks: 1 batch_size: 1 - features: 1024 + features: 2048 feed_forward_intermediate_factor: 0.125 optimizer: beta2: 0.95 diff --git a/src/dataclass.py b/src/dataclass.py index 251aa7a..19cdbe7 100644 --- a/src/dataclass.py +++ b/src/dataclass.py @@ -19,7 +19,7 @@ def serialize(instance: typing.Union[DataClass, typing.Dict[str, typing.Any]]): class Model(DataClass): - attention: str = "OmnidirectionalAttention" + attention: str = "FFTAttention" weight_sharing: bool = False checkpoint_path: str = "checkpoint.torch" steps_per_checkpoint: int = 0 # 0 -> disabled diff --git a/src/model.py b/src/model.py index 1b11120..5a1f850 100644 --- a/src/model.py +++ b/src/model.py @@ -263,7 +263,6 @@ def forward(self, inp: torch.Tensor): out = torch.cat([inp, torch.zeros((batch, features, sequence * len(self.stem.stem)), device=inp.device, dtype=inp.dtype)], 2) out = self.stem(out) - if self.expand_sequence: batch, features, sequence = inp.size() inp = out.view(batch, features, -1, sequence).mean(2) From 439b7cd3777e85b05dee4c3e2014b1900d5e789a Mon Sep 17 00:00:00 2001 From: ClashLuke <39779310+ClashLuke@users.noreply.github.com> Date: Sun, 14 Nov 2021 13:53:02 +0100 Subject: [PATCH 21/34] style(model): add accuracy --- src/executable/train.py | 13 ++++++++++--- src/model.py | 20 +++++++++++++------ src/utils/formatting.py | 43 +++++++++++++++++++++++++++-------------- 3 files changed, 53 insertions(+), 23 deletions(-) diff --git a/src/executable/train.py b/src/executable/train.py index e51ad92..d7a593f 100644 --- a/src/executable/train.py +++ b/src/executable/train.py @@ -20,12 +20,16 @@ def train_model(ctx: Context, steps=None, load_model: bool = False): log = WandbLog(ctx, data_len) mean_loss = torch.zeros([], device=ctx.model.device, dtype=torch.float16 if ctx.model.float16 else torch.float) mean_max_loss = mean_loss.clone() + mean_acc = mean_loss.clone() + mean_max_acc = mean_loss.clone() i = 0 while True: i += 1 - mean_loss += mod.accumulated_step(next(data)) + lss, acc = mod.accumulated_step(next(data)) + mean_loss += lss + mean_acc += acc if ctx.optimizer.sharpness_aware_minimization.enabled: with torch.no_grad(): for p in mod.gradients(): @@ -35,7 
+39,10 @@ def train_model(ctx: Context, steps=None, load_model: bool = False): p.add_(p.grad) p.prev_step = p.grad p.grad = None - mean_max_loss += mod.accumulated_step(next(data)) + + lss, acc = mod.accumulated_step(next(data)) + mean_max_loss += lss + mean_max_acc += acc mod.optimizer.step() if ctx.optimizer.sharpness_aware_minimization.enabled: with torch.no_grad(): @@ -50,7 +57,7 @@ def train_model(ctx: Context, steps=None, load_model: bool = False): p['betas'] = p['betas'][0], mod.ctx.optimizer.beta2 with torch.no_grad(): if mod.ctx.log.loss_steps_per_print and i % mod.ctx.log.loss_steps_per_print == 0: - log(mean_loss, mean_max_loss, + log(mean_loss, mean_max_loss, mean_acc, mean_max_acc, mod.optimizer.param_groups[0]['lr'], mod.optimizer.param_groups[0]['betas']) mean_loss.zero_() mean_max_loss.zero_() diff --git a/src/model.py b/src/model.py index 5a1f850..531a4f2 100644 --- a/src/model.py +++ b/src/model.py @@ -170,10 +170,13 @@ def __init__(self, ctx: Context, model: torch.nn.Module, data: typing.Optional[t def _to_device_detach(self, inp: torch.Tensor) -> torch.Tensor: return inp.to(device=self.ctx.model.device, non_blocking=True).detach() - def _forward_backward(self, src: torch.Tensor, tgt: torch.Tensor) -> torch.Tensor: - loss = F.cross_entropy(self.model(self._to_device_detach(src)), self._to_device_detach(tgt)) + def _forward_backward(self, src: torch.Tensor, tgt: torch.Tensor) -> typing.Tuple[torch.Tensor, torch.Tensor]: + out = self.model(self._to_device_detach(src)) + tgt = self._to_device_detach(tgt) + loss = F.cross_entropy(out, tgt) loss.backward() - return loss.detach() + with torch.inference_mode(): + return loss.detach(), (out == tgt).sum() / tgt.size() @torch.no_grad() def _clip_gradient(self): @@ -183,10 +186,15 @@ def _clip_gradient(self): grad_scale = (p_norm / g_norm * self.ctx.optimizer.agc.gradient_clipping).clamp(max=1) p.grad.data.copy_(p.grad * grad_scale) - def accumulated_step(self, data: torch.Tensor) -> torch.Tensor: - loss = sum(self._forward_backward(s, t) for s, t in zip(*data)) + def accumulated_step(self, data: torch.Tensor) -> typing.Tuple[torch.Tensor, torch.Tensor]: + loss = 0 + accuracy = 0 + for src, tgt in zip(*data): + lss, acc = self._forward_backward(src, tgt) + loss += lss + accuracy += acc self._clip_gradient() - return loss + return loss / data.size(0), accuracy / data.size(0) @torch.no_grad() def zero_grad(self): diff --git a/src/utils/formatting.py b/src/utils/formatting.py index e1a1255..e3a6e7e 100644 --- a/src/utils/formatting.py +++ b/src/utils/formatting.py @@ -36,41 +36,56 @@ class WandbLog: def __init__(self, ctx: Context, steps: int): self.mean_loss = 0 self.mean_max_loss = 0 + self.mean_acc = 0 + self.mean_max_acc = 0 self.start_time = time.time() self.ctx = ctx self.idx = 0 self.prev = 0 self.steps = steps - def __call__(self, current_loss: torch.Tensor, max_loss: torch.Tensor, learning_rate: float, + def normalize(self, var: torch.Tensor, attribute: str) -> float: + attr = getattr(self, attribute) + curr_var = var.item() / self.ctx.log.loss_steps_per_print / self.ctx.optimizer.gradient_accumulation_steps + setattr(self, attribute, (attr * self.prev + curr_var * self.idx) / (self.prev + self.idx)) # LWMA + return curr_var + + def __call__(self, current_loss: torch.Tensor, max_loss: torch.Tensor, + current_acc: torch.Tensor, max_acc: torch.Tensor, learning_rate: float, betas: typing.Tuple[float, float]): - grad_accum = self.ctx.optimizer.gradient_accumulation_steps - curr_loss = current_loss.item() / 
self.ctx.log.loss_steps_per_print / grad_accum - curr_max_loss = max_loss.item() / self.ctx.log.loss_steps_per_print / grad_accum self.idx += 1 - self.mean_loss = (self.mean_loss * self.prev + curr_loss * self.idx) / (self.prev + self.idx) # LWMA - mean_max = self.mean_max_loss = (self.mean_max_loss * self.prev + max_loss * self.idx) / (self.prev + self.idx) + current_loss = self.normalize(current_loss, "mean_loss") + current_acc = self.normalize(current_acc, "mean_acc") + if self.ctx.optimizer.sharpness_aware_minimization: + max_loss = self.normalize(max_loss, "mean_max_loss") + max_acc = self.normalize(max_acc, "mean_max_acc") + else: + max_loss = max_acc = self.mean_max_loss = self.mean_max_acc = None self.prev += self.idx rate = self.ctx.log.loss_steps_per_print * self.idx / (time.time() - self.start_time) - tokens_per_day = grad_accum * 3600 * 24 * rate * self.ctx.model.batch_size * self.ctx.model.sequence_length + tokens_per_day = 3600 * 24 * rate * self.ctx.model.batch_size * self.ctx.model.sequence_length + tokens_per_day *= self.ctx.optimizer.gradient_accumulation_steps pretty_print(f"[{self.idx * self.ctx.log.loss_steps_per_print:{len(str(self.steps))}d}/{self.steps}]", - f"Loss: {curr_loss:7.4f} -", + f"Loss: {current_loss:7.4f} -", f"Mean: {self.mean_loss:7.4f} |", + f"Acc: {current_acc:7.4f} -", + f"Mean: {self.mean_acc:7.4f} |", f"LR: {learning_rate:.6f} -", f"Beta1: {betas[0]:.3f} -", f"Beta2: {betas[1]:.3f} |", f"Batch/s: {rate:6.3f} -", f"Tokens/day: {tokens_per_day:11,.0f}") - if not self.ctx.optimizer.sharpness_aware_minimization.enabled: - curr_max_loss = None - mean_max = None - wandb.log({"Loss/Current": curr_loss, + wandb.log({"Loss/Current": current_loss, "Loss/Mean": self.mean_loss, - "Loss/Current Max": curr_max_loss, - "Loss/Mean Max": mean_max, + "Loss/Current Max": max_loss, + "Loss/Mean Max": self.mean_max_loss, + "Accuracy/Current": current_acc, + "Accuracy/Mean": self.mean_acc, + "Accuracy/Current Max": max_acc, + "Accuracy/Mean Max": self.mean_max_acc, "Speed/Batches per Second": rate, "Speed/Tokens per Day": tokens_per_day, "Optimizer/Learning Rate": learning_rate, From f6f39718a8e9465d0f73b55071547e8ac7669908 Mon Sep 17 00:00:00 2001 From: ClashLuke <39779310+ClashLuke@users.noreply.github.com> Date: Sun, 14 Nov 2021 14:00:16 +0100 Subject: [PATCH 22/34] fix(train): cast accuracy to float --- configs/small.yaml | 6 +++--- src/model.py | 7 ++++--- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/configs/small.yaml b/configs/small.yaml index 43ad312..d1a3f0e 100644 --- a/configs/small.yaml +++ b/configs/small.yaml @@ -10,9 +10,9 @@ optimizer: beta2: 0.95 gradient_accumulation_steps: 1 one_cycle: - total_steps: 8192 + total_steps: 1024 max_lr: 0.01 log: - loss_steps_per_print: 8 + loss_steps_per_print: 4 dataset: - num_workers: 12 + num_workers: 16 diff --git a/src/model.py b/src/model.py index 531a4f2..52ee29c 100644 --- a/src/model.py +++ b/src/model.py @@ -176,7 +176,7 @@ def _forward_backward(self, src: torch.Tensor, tgt: torch.Tensor) -> typing.Tupl loss = F.cross_entropy(out, tgt) loss.backward() with torch.inference_mode(): - return loss.detach(), (out == tgt).sum() / tgt.size() + return loss.detach(), (out == tgt).sum().float() / tgt.numel() @torch.no_grad() def _clip_gradient(self): @@ -186,7 +186,8 @@ def _clip_gradient(self): grad_scale = (p_norm / g_norm * self.ctx.optimizer.agc.gradient_clipping).clamp(max=1) p.grad.data.copy_(p.grad * grad_scale) - def accumulated_step(self, data: torch.Tensor) -> typing.Tuple[torch.Tensor, 
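The running means logged above use a linearly weighted moving average (the "LWMA" comment): the value from print interval k enters with weight k, so recent intervals dominate older ones while nothing is ever fully forgotten. A standalone sketch of that recurrence, with purely illustrative names:

    # Sketch of the LWMA recurrence behind mean_loss / mean_acc; `samples` stands in
    # for the normalized per-print values.
    def lwma(samples):
        mean, prev = 0.0, 0
        for idx, value in enumerate(samples, 1):
            mean = (mean * prev + value * idx) / (prev + idx)  # sample k enters with weight k
            prev += idx
        return mean

    assert abs(lwma([4.0, 3.0, 2.0]) - (4 * 1 + 3 * 2 + 2 * 3) / 6) < 1e-9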
torch.Tensor]: + def accumulated_step(self, data: typing.Tuple[torch.Tensor, torch.Tensor] + ) -> typing.Tuple[torch.Tensor, torch.Tensor]: loss = 0 accuracy = 0 for src, tgt in zip(*data): @@ -194,7 +195,7 @@ def accumulated_step(self, data: torch.Tensor) -> typing.Tuple[torch.Tensor, tor loss += lss accuracy += acc self._clip_gradient() - return loss / data.size(0), accuracy / data.size(0) + return loss / data[0].size(0), accuracy / data[0].size(0) @torch.no_grad() def zero_grad(self): From 6cea58a53cad6377f7ea7d9f4dc1573978cc3a17 Mon Sep 17 00:00:00 2001 From: ClashLuke <39779310+ClashLuke@users.noreply.github.com> Date: Sun, 14 Nov 2021 14:06:15 +0100 Subject: [PATCH 23/34] fix(train): argmax accuracy accross correct dimension --- src/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/model.py b/src/model.py index 52ee29c..0c731f0 100644 --- a/src/model.py +++ b/src/model.py @@ -176,7 +176,7 @@ def _forward_backward(self, src: torch.Tensor, tgt: torch.Tensor) -> typing.Tupl loss = F.cross_entropy(out, tgt) loss.backward() with torch.inference_mode(): - return loss.detach(), (out == tgt).sum().float() / tgt.numel() + return loss.detach(), (torch.argmax(out, 1) == tgt).sum().float() / tgt.numel() @torch.no_grad() def _clip_gradient(self): From 0269ce33b068c3248135961fedb4b840b86973c7 Mon Sep 17 00:00:00 2001 From: ClashLuke <39779310+ClashLuke@users.noreply.github.com> Date: Sun, 14 Nov 2021 14:08:29 +0100 Subject: [PATCH 24/34] fix(train): zero accuracy after log --- src/executable/train.py | 2 ++ src/model.py | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/src/executable/train.py b/src/executable/train.py index d7a593f..29422eb 100644 --- a/src/executable/train.py +++ b/src/executable/train.py @@ -61,6 +61,8 @@ def train_model(ctx: Context, steps=None, load_model: bool = False): mod.optimizer.param_groups[0]['lr'], mod.optimizer.param_groups[0]['betas']) mean_loss.zero_() mean_max_loss.zero_() + mean_acc.zero_() + mean_max_acc.zero_() if mod.ctx.model.steps_per_checkpoint and i % mod.ctx.model.steps_per_checkpoint == 0: mod.save() if steps and i > steps: diff --git a/src/model.py b/src/model.py index 0c731f0..bee2977 100644 --- a/src/model.py +++ b/src/model.py @@ -195,7 +195,7 @@ def accumulated_step(self, data: typing.Tuple[torch.Tensor, torch.Tensor] loss += lss accuracy += acc self._clip_gradient() - return loss / data[0].size(0), accuracy / data[0].size(0) + return loss, accuracy @torch.no_grad() def zero_grad(self): From 4d6634760d96779ff52a40f33cbefc32117bc6d1 Mon Sep 17 00:00:00 2001 From: ClashLuke <39779310+ClashLuke@users.noreply.github.com> Date: Sun, 14 Nov 2021 18:39:02 +0100 Subject: [PATCH 25/34] feat(model): add omninet to all attention modules --- configs/small.yaml | 7 ++-- src/dataclass.py | 3 ++ src/model.py | 85 +++++++++++++++++++++++++++------------------- 3 files changed, 58 insertions(+), 37 deletions(-) diff --git a/configs/small.yaml b/configs/small.yaml index d1a3f0e..4f88766 100644 --- a/configs/small.yaml +++ b/configs/small.yaml @@ -1,16 +1,17 @@ model: - attention: OmnidirectionalAttention + attention: FFTAttention + omnidirectional: yes depth: 8 conv_kernel_size: 11 weight_shared_blocks: 1 batch_size: 1 - features: 2048 + features: 1024 feed_forward_intermediate_factor: 0.125 optimizer: beta2: 0.95 gradient_accumulation_steps: 1 one_cycle: - total_steps: 1024 + total_steps: 1024000 max_lr: 0.01 log: loss_steps_per_print: 4 diff --git a/src/dataclass.py b/src/dataclass.py index 19cdbe7..4ae61d9 100644 --- 
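The accuracy fix above argmaxes over dimension 1 because the model returns logits shaped (batch, classes, sequence) while the targets are (batch, sequence), which is the layout F.cross_entropy already consumes. A small self-contained check of that convention; the sizes are arbitrary:

    import torch

    batch, classes, sequence = 2, 5, 7
    logits = torch.randn(batch, classes, sequence)             # same layout cross_entropy consumes
    targets = torch.randint(0, classes, (batch, sequence))
    accuracy = (torch.argmax(logits, 1) == targets).sum().float() / targets.numel()
    assert 0.0 <= accuracy.item() <= 1.0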
a/src/dataclass.py +++ b/src/dataclass.py @@ -19,6 +19,7 @@ def serialize(instance: typing.Union[DataClass, typing.Dict[str, typing.Any]]): class Model(DataClass): + omnidirectional: bool = False attention: str = "FFTAttention" weight_sharing: bool = False checkpoint_path: str = "checkpoint.torch" @@ -117,6 +118,8 @@ class SharpnessAwareMinimization(DataClass): class Optimizer(DataClass): type: str = "AdamW" + final_step: int = 2 ** 14 + warmup_end: int = 2 ** 10 gradient_accumulation_steps: int = 1 one_cycle: OneCycle = OneCycle() beta2: float = 0.95 # beta1 is controlled by one_cycle diff --git a/src/model.py b/src/model.py index bee2977..44d59fd 100644 --- a/src/model.py +++ b/src/model.py @@ -7,7 +7,7 @@ import torch.nn.functional import torch.utils.data from torch.nn import functional as F -from torch.optim.lr_scheduler import OneCycleLR +from torch.optim.lr_scheduler import LambdaLR from src.dataclass import Context from src.optimizers.build import build_optimizer @@ -145,26 +145,25 @@ def conv_weight(in_features: int, out_features: int, kernel_size: int, groups: i return orthonormal(torch.nn.Conv1d(in_features, out_features, (kernel_size,), groups=groups).weight, 1 / std) +def get_lr_scheduler_fn(ctx: Context) -> typing.Callable[[int], float]: + def _fn(step: int) -> float: + final_lr = 1 - 2 / (ctx.optimizer.final_step - ctx.optimizer.warmup_end) + lr = step + lr /= max(step, ctx.optimizer.warmup_end) + lr *= final_lr ** max(step - ctx.optimizer.warmup_end, 0) + # lr *= ctx.optimizer.learning_rate # It's a multiplier for the initial learning rate, not the LR itself + return lr + + return _fn + + class Trainer(torch.nn.Module): def __init__(self, ctx: Context, model: torch.nn.Module, data: typing.Optional[torch.Tensor]): super(Trainer, self).__init__() self.ctx = ctx self.model = torch.jit.trace(model, data) if data else model self.optimizer = build_optimizer(ctx, self.model.parameters()) - self.scheduler = OneCycleLR(self.optimizer, - ctx.optimizer.one_cycle.max_lr, - ctx.optimizer.one_cycle.total_steps, - ctx.optimizer.one_cycle.epochs, - ctx.optimizer.one_cycle.steps_per_epoch, - ctx.optimizer.one_cycle.pct_start, - ctx.optimizer.one_cycle.anneal_strategy, - ctx.optimizer.one_cycle.cycle_momentum, - ctx.optimizer.one_cycle.base_momentum, - ctx.optimizer.one_cycle.max_momentum, - ctx.optimizer.one_cycle.div_factor, - ctx.optimizer.one_cycle.final_div_factor, - ctx.optimizer.one_cycle.three_phase, - ctx.optimizer.one_cycle.last_epoch) + self.scheduler = LambdaLR(self.optimizer, get_lr_scheduler_fn(ctx)) @torch.no_grad() def _to_device_detach(self, inp: torch.Tensor) -> torch.Tensor: @@ -326,13 +325,14 @@ def __init__(self, base: LinearAttention, ctx: Context, init_scale: float, featu self.jitter_epsilon = ctx.model.moe_jitter_epsilon self.expert_chunks = ctx.model.expert_chunks intermediate = int(ctx.model.features * ctx.model.feed_forward_intermediate_factor) - self.w0 = torch.nn.ParameterList(get_moe_param(ctx.model.features * feature_factor, intermediate * 3, - self.groups0, self.experts0, self.expert_chunks, - ctx.model.activation_std)) - self.w1 = conv_weight(intermediate, intermediate * 3, ctx.model.conv_kernel_size, ctx.model.bottleneck_group, - ctx.model.activation_std) - self.w2 = torch.nn.ParameterList(get_moe_param(intermediate, ctx.model.features * feature_factor, self.groups2, - self.experts2, self.expert_chunks, 1)) + if feature_factor: + self.w0 = torch.nn.ParameterList(get_moe_param(ctx.model.features * feature_factor, intermediate * 3, + self.groups0, 
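The LambdaLR factor introduced above amounts to a linear warmup that reaches 1.0 at warmup_end, followed by a per-step exponential decay with base 1 - 2 / (final_step - warmup_end). A sketch of the same multiplier decoupled from Context; the two constants mirror the new dataclass defaults and are otherwise placeholders:

    WARMUP_END, FINAL_STEP = 2 ** 10, 2 ** 14                     # placeholder schedule bounds

    def lr_multiplier(step: int) -> float:
        decay = 1 - 2 / (FINAL_STEP - WARMUP_END)
        warmup = step / max(step, WARMUP_END)                     # ramps linearly from 0 to 1
        return warmup * decay ** max(step - WARMUP_END, 0)        # then decays every step after warmup

    assert lr_multiplier(WARMUP_END // 2) == 0.5
    assert lr_multiplier(WARMUP_END) == 1.0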
self.experts0, self.expert_chunks, + ctx.model.activation_std)) + self.w1 = conv_weight(intermediate, intermediate * 3, ctx.model.conv_kernel_size, + ctx.model.bottleneck_group, ctx.model.activation_std) + self.w2 = torch.nn.ParameterList(get_moe_param(intermediate, ctx.model.features * feature_factor, + self.groups2, self.experts2, self.expert_chunks, 1)) self.idx: int = 0 self.depth: int = 0 self.get_last: bool = True @@ -363,8 +363,11 @@ def _ff(self, inp: torch.Tensor) -> torch.Tensor: self.experts2, self.bottleneck_group, self.training, self.norm_power, self.jitter_epsilon) + def _inner_forward(self, inp: torch.Tensor) -> torch.Tensor: + return self._ff(inp) + def forward(self, inp: torch.Tensor) -> torch.Tensor: - return self._pad(inp, self._ff(self._cut_off(inp))) * self.init_scale + return self._pad(inp, self._inner_forward(self._cut_off(inp))) * self.init_scale def momentum(self, init_scale: float, deep: bool, depth: int): out = copy.deepcopy(self) if deep else copy.copy(self) @@ -373,27 +376,33 @@ def momentum(self, init_scale: float, deep: bool, depth: int): return out -class FFTAttention(FeedForward): +class AttentionBase(FeedForward): + def __init__(self, base: LinearAttention, ctx: Context, init_scale: float, feature_factor: float): + super(AttentionBase, self).__init__(base, ctx, init_scale, feature_factor) + self.get_last = not ctx.model.omnidirectional + + +class FFTAttention(AttentionBase): def __init__(self, base: LinearAttention, ctx: Context, init_scale: float): super(FFTAttention, self).__init__(base, ctx, init_scale, 2) - def forward(self, inp: torch.Tensor) -> torch.Tensor: + def _inner_forward(self, inp: torch.Tensor) -> torch.Tensor: batch, features, sequence = inp.size() - out = torch.view_as_real(torch.fft.rfft(inp, 2 * sequence)) + out = torch.view_as_real(torch.fft.rfft(inp, 2 * sequence, norm="ortho")) out = out.transpose(2, 3).reshape(batch, features * 2, sequence + 1) out = self._ff(out) out = out.view(batch, features, 2, sequence + 1).transpose(2, 3).contiguous() out = torch.view_as_complex(out) - return torch.fft.irfft(out, 2 * sequence)[:, :, :sequence] + return torch.fft.irfft(out, 2 * sequence, norm="ortho")[:, :, :sequence] -class SumAttention(FeedForward): +class SumAttention(AttentionBase): def __init__(self, base: LinearAttention, ctx: Context, init_scale: float): super(SumAttention, self).__init__(base, ctx, init_scale, ctx.model.sum_attention_level) self.sum_attention_level = ctx.model.sum_attention_level self.weight = conv_weight(ctx.model.features, ctx.model.features, 3, 1, 1) - def forward(self, inp: torch.Tensor) -> torch.Tensor: + def _inner_forward(self, inp: torch.Tensor) -> torch.Tensor: out = self._ff(inp).chunk(self.sum_attention_level, 1) batch, features, seq = out[0].size() return sum(conv(torch.relu(out[0] + sum(out[inner + 1][outer // batch ** inner % batch] @@ -402,10 +411,18 @@ def forward(self, inp: torch.Tensor) -> torch.Tensor: for outer in range(batch ** (self.sum_attention_level - 1))) -class OmnidirectionalAttention(FFTAttention): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.get_last = False +class SqueezeExcitation(AttentionBase): + def __init__(self, base: LinearAttention, ctx: Context, init_scale: float): + super(SqueezeExcitation, self).__init__(base, ctx, init_scale, 0) + self.weight0 = orthonormal([ctx.model.features * 2, ctx.model.features * 2 * 3], 1) + self.weight1 = orthonormal([ctx.model.features * 2, ctx.model.features], 1) + + def _inner_forward(self, inp: torch.Tensor) -> 
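FFTAttention mixes tokens by moving the zero-padded sequence into the frequency domain, running the feed-forward over the stacked real and imaginary parts, and transforming back. A shape-only sketch of that round trip, with an identity standing in for the learned feed-forward:

    import torch

    def fft_mix(inp: torch.Tensor, ff=lambda x: x) -> torch.Tensor:
        # `ff` stands in for the learned feed-forward; identity keeps the sketch self-contained.
        batch, features, sequence = inp.size()
        out = torch.view_as_real(torch.fft.rfft(inp, 2 * sequence, norm="ortho"))  # (b, f, s + 1, 2)
        out = out.transpose(2, 3).reshape(batch, features * 2, sequence + 1)       # real/imag stacked on features
        out = ff(out)
        out = out.view(batch, features, 2, sequence + 1).transpose(2, 3).contiguous()
        return torch.fft.irfft(torch.view_as_complex(out), 2 * sequence, norm="ortho")[:, :, :sequence]

    x = torch.randn(2, 4, 16)
    assert torch.allclose(fft_mix(x), x, atol=1e-5)  # identity ff: the ortho FFT round trip returns the input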
torch.Tensor: + out = torch.cat([inp.mean(2), inp.max(2).values], 1) + out = out.mm(self.weight0).chunk(3, 1) + out = TripleNorm.apply(*out, self.ctx.model.norm_power) + out = out.mm(self.weight1) + return out.unsqueeze(2) -attention_modules = [FeedForward, FFTAttention, SumAttention, OmnidirectionalAttention] +attention_modules = [FeedForward, FFTAttention, SumAttention, SqueezeExcitation] From 6c5fc7423b63fb69c92f4dbbc91ea39b6be62a45 Mon Sep 17 00:00:00 2001 From: ClashLuke <39779310+ClashLuke@users.noreply.github.com> Date: Sun, 14 Nov 2021 18:56:04 +0100 Subject: [PATCH 26/34] fix(model): put padding in correct device --- configs/small.yaml | 2 +- src/model.py | 9 ++++++--- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/configs/small.yaml b/configs/small.yaml index 4f88766..d1d0de3 100644 --- a/configs/small.yaml +++ b/configs/small.yaml @@ -1,5 +1,5 @@ model: - attention: FFTAttention + attention: SqueezeExcitation omnidirectional: yes depth: 8 conv_kernel_size: 11 diff --git a/src/model.py b/src/model.py index 44d59fd..c24a10d 100644 --- a/src/model.py +++ b/src/model.py @@ -353,10 +353,13 @@ def _pad(self, inp: torch.Tensor, out: torch.Tensor): batch, features, sequence = inp.size() if self.get_last: - return torch.cat([torch.zeros((batch, features, self.ctx.model.sequence_length * self.depth)), out.size(), + return torch.cat([torch.zeros((batch, features, self.ctx.model.sequence_length * self.depth), + device=out.device, dtype=out.dtype), out, torch.zeros((batch, features, - sequence - self.ctx.model.sequence_length * (self.depth + 1)))], 2) - return torch.cat([out.size(), torch.zeros((batch, features, sequence - out.size()))], 2) + sequence - self.ctx.model.sequence_length * (self.depth + 1)), + device=out.device, dtype=out.dtype)], 2) + return torch.cat([out, torch.zeros((batch, features, sequence - out.size(2)), device=out.device, + dtype=out.dtype)], 2) def _ff(self, inp: torch.Tensor) -> torch.Tensor: return linear_attention(inp, self.w0, self.groups0, self.experts0, self.w1, self.w2, self.groups2, From a138f8e03d165ce6df8846e3da0c18a8f1ca330a Mon Sep 17 00:00:00 2001 From: ClashLuke <39779310+ClashLuke@users.noreply.github.com> Date: Sun, 14 Nov 2021 19:16:15 +0100 Subject: [PATCH 27/34] fix(model): invert omninet expansion --- configs/small.yaml | 6 ++++-- src/model.py | 4 ++-- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/configs/small.yaml b/configs/small.yaml index d1d0de3..2029ee5 100644 --- a/configs/small.yaml +++ b/configs/small.yaml @@ -1,10 +1,12 @@ model: attention: SqueezeExcitation - omnidirectional: yes - depth: 8 + sequence_length: 2048 + omnidirectional: no + depth: 16 conv_kernel_size: 11 weight_shared_blocks: 1 batch_size: 1 + offloading: yes features: 1024 feed_forward_intermediate_factor: 0.125 optimizer: diff --git a/src/model.py b/src/model.py index c24a10d..1feba8c 100644 --- a/src/model.py +++ b/src/model.py @@ -246,7 +246,7 @@ def __init__(self, ctx: Context): raise ValueError(f"{ctx.model.attention} is not a known type of attention. 
You can pick any of the" f" following: {modules}") attn = attention_modules[modules.index(ctx.model.attention)](self, ctx, 1) - self.expand_sequence = attn.get_last | ff.get_last + self.expand_sequence = (not attn.get_last) | (not ff.get_last) self.stem = revlib.ReversibleSequential(*[c for i in range(1, 1 + ctx.model.depth * 2, 2) for c in [ff.momentum((1 - ctx.model.momentumnet_beta) / @@ -425,7 +425,7 @@ def _inner_forward(self, inp: torch.Tensor) -> torch.Tensor: out = out.mm(self.weight0).chunk(3, 1) out = TripleNorm.apply(*out, self.ctx.model.norm_power) out = out.mm(self.weight1) - return out.unsqueeze(2) + return out.unsqueeze(2) * inp attention_modules = [FeedForward, FFTAttention, SumAttention, SqueezeExcitation] From 7e28e9791d96bb2d101e6e7b71282ea8bb11ef43 Mon Sep 17 00:00:00 2001 From: ClashLuke <39779310+ClashLuke@users.noreply.github.com> Date: Sun, 14 Nov 2021 19:36:56 +0100 Subject: [PATCH 28/34] style(dataclass): remove weight decay by default --- src/dataclass.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/dataclass.py b/src/dataclass.py index 4ae61d9..a412173 100644 --- a/src/dataclass.py +++ b/src/dataclass.py @@ -124,7 +124,7 @@ class Optimizer(DataClass): one_cycle: OneCycle = OneCycle() beta2: float = 0.95 # beta1 is controlled by one_cycle eps: float = 1e-8 - weight_decay: float = 0.01 + weight_decay: float = 0. zero: Zero = Zero() agc = AdaptiveGradientClipping() sharpness_aware_minimization: SharpnessAwareMinimization = SharpnessAwareMinimization() From d2fcbdc9f29ed9e4620ad6bc94e9d25923bd1612 Mon Sep 17 00:00:00 2001 From: ClashLuke <39779310+ClashLuke@users.noreply.github.com> Date: Sun, 14 Nov 2021 19:52:40 +0100 Subject: [PATCH 29/34] perf(model): only backprop masked tokens --- configs/small.yaml | 5 +---- src/dataset.py | 12 ++++-------- src/model.py | 20 +++++++++++--------- 3 files changed, 16 insertions(+), 21 deletions(-) diff --git a/configs/small.yaml b/configs/small.yaml index 2029ee5..f402e18 100644 --- a/configs/small.yaml +++ b/configs/small.yaml @@ -7,14 +7,11 @@ model: weight_shared_blocks: 1 batch_size: 1 offloading: yes - features: 1024 + features: 2048 feed_forward_intermediate_factor: 0.125 optimizer: beta2: 0.95 gradient_accumulation_steps: 1 - one_cycle: - total_steps: 1024000 - max_lr: 0.01 log: loss_steps_per_print: 4 dataset: diff --git a/src/dataset.py b/src/dataset.py index 9e7806e..f444104 100644 --- a/src/dataset.py +++ b/src/dataset.py @@ -9,12 +9,8 @@ @torch.jit.script -def get_sample(data: torch.Tensor, batch_index: torch.Tensor, idx: int, - drop_probability: float) -> typing.Tuple[torch.Tensor, torch.Tensor]: - dat = data[batch_index + idx] - dat = dat.to(dtype=torch.long, non_blocking=True) - inp = dat * (torch.rand(dat.size(), device=dat.device) > drop_probability) - return inp, dat +def get_sample(data: torch.Tensor, batch_index: torch.Tensor, idx: int) -> torch.Tensor: + return data[batch_index + idx].to(dtype=torch.long, non_blocking=True) class Dataset: @@ -31,7 +27,7 @@ def __iter__(self, idx: int) -> typing.Tuple[torch.Tensor, torch.Tensor]: def __next__(self): items = [self.queue.get() for _ in range(self.ctx.optimizer.gradient_accumulation_steps)] - return torch.stack([itm[0] for itm in items], 0), torch.stack([itm[1] for itm in items], 0) + return torch.stack([itm for itm in items], 0) def _process_fn(ctx: Context, queue: multiprocessing.Queue, idx: int, worker_count: int): @@ -45,7 +41,7 @@ def _process_fn(ctx: Context, queue: multiprocessing.Queue, idx: int, worker_cou random.seed(idx) 
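The SqueezeExcitation path a few hunks up pools the whole sequence into per-feature statistics (mean and max), projects them twice, and, after the fix above, uses the result as a gate on the input rather than as a replacement for it. A rough plain-torch sketch of that pattern, with Linear layers and a ReLU standing in for the repo's orthonormal weights and TripleNorm:

    import torch

    class SqueezeExciteSketch(torch.nn.Module):
        def __init__(self, features: int):
            super().__init__()
            self.reduce = torch.nn.Linear(features * 2, features * 2)
            self.expand = torch.nn.Linear(features * 2, features)

        def forward(self, inp: torch.Tensor) -> torch.Tensor:        # inp: (batch, features, sequence)
            pooled = torch.cat([inp.mean(2), inp.max(2).values], 1)  # squeeze: (batch, features * 2)
            gate = self.expand(torch.relu(self.reduce(pooled)))      # excite: (batch, features)
            return gate.unsqueeze(2) * inp                           # broadcast the gate over the sequence

    x = torch.randn(2, 8, 32)
    assert SqueezeExciteSketch(8)(x).shape == x.shape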
while True: - queue.put(get_sample(data, batch_index, random.randint(0, length), ctx.dataset.dropout)) + queue.put(get_sample(data, batch_index, random.randint(0, length))) def get_dataset(ctx: Context) -> Dataset: diff --git a/src/model.py b/src/model.py index 1feba8c..8a10246 100644 --- a/src/model.py +++ b/src/model.py @@ -169,13 +169,16 @@ def __init__(self, ctx: Context, model: torch.nn.Module, data: typing.Optional[t def _to_device_detach(self, inp: torch.Tensor) -> torch.Tensor: return inp.to(device=self.ctx.model.device, non_blocking=True).detach() - def _forward_backward(self, src: torch.Tensor, tgt: torch.Tensor) -> typing.Tuple[torch.Tensor, torch.Tensor]: - out = self.model(self._to_device_detach(src)) - tgt = self._to_device_detach(tgt) - loss = F.cross_entropy(out, tgt) + def _forward_backward(self, src: torch.Tensor) -> typing.Tuple[torch.Tensor, torch.Tensor]: + src = self._to_device_detach(src) + msk = torch.rand(src.size(), dtype=torch.float32, device=src.device) > self.ctx.dataset.dropout + out = self.model(src * msk) + msk = 1 - msk + masked = msk.sum() + loss = (F.cross_entropy(out, src, reduction="none") * msk).sum() / masked loss.backward() with torch.inference_mode(): - return loss.detach(), (torch.argmax(out, 1) == tgt).sum().float() / tgt.numel() + return loss.detach(), (torch.argmax(out, 1) == src).mul(msk).sum().float() / masked @torch.no_grad() def _clip_gradient(self): @@ -185,12 +188,11 @@ def _clip_gradient(self): grad_scale = (p_norm / g_norm * self.ctx.optimizer.agc.gradient_clipping).clamp(max=1) p.grad.data.copy_(p.grad * grad_scale) - def accumulated_step(self, data: typing.Tuple[torch.Tensor, torch.Tensor] - ) -> typing.Tuple[torch.Tensor, torch.Tensor]: + def accumulated_step(self, data: torch.Tensor) -> typing.Tuple[torch.Tensor, torch.Tensor]: loss = 0 accuracy = 0 - for src, tgt in zip(*data): - lss, acc = self._forward_backward(src, tgt) + for src in data: + lss, acc = self._forward_backward(src) loss += lss accuracy += acc self._clip_gradient() From 28d253331be101545e3459728784bc49912e9ee8 Mon Sep 17 00:00:00 2001 From: ClashLuke <39779310+ClashLuke@users.noreply.github.com> Date: Sun, 14 Nov 2021 20:09:41 +0100 Subject: [PATCH 30/34] style(model): show parameters in representation --- configs/small.yaml | 2 +- src/model.py | 6 +++++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/configs/small.yaml b/configs/small.yaml index f402e18..4353171 100644 --- a/configs/small.yaml +++ b/configs/small.yaml @@ -5,7 +5,7 @@ model: depth: 16 conv_kernel_size: 11 weight_shared_blocks: 1 - batch_size: 1 + batch_size: 16 offloading: yes features: 2048 feed_forward_intermediate_factor: 0.125 diff --git a/src/model.py b/src/model.py index 8a10246..e7fb52e 100644 --- a/src/model.py +++ b/src/model.py @@ -173,7 +173,7 @@ def _forward_backward(self, src: torch.Tensor) -> typing.Tuple[torch.Tensor, tor src = self._to_device_detach(src) msk = torch.rand(src.size(), dtype=torch.float32, device=src.device) > self.ctx.dataset.dropout out = self.model(src * msk) - msk = 1 - msk + msk = ~msk masked = msk.sum() loss = (F.cross_entropy(out, src, reduction="none") * msk).sum() / masked loss.backward() @@ -339,6 +339,10 @@ def __init__(self, base: LinearAttention, ctx: Context, init_scale: float, featu self.depth: int = 0 self.get_last: bool = True + def __repr__(self): + extra = '\n '.join([f'{name}: {param.size()}' for name, param in self.named_parameters()]) + return f"{self._get_name()}(\n {extra}\n)" + def _cut_off(self, inp: torch.Tensor) -> torch.Tensor: 
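The new objective above corrupts a random subset of tokens, feeds the corrupted sequence to the model, and averages the cross-entropy (and the accuracy) only over the corrupted positions. A self-contained sketch of that loss with a dummy model; the names, the 10% drop rate, and the clamp guard are illustrative additions:

    import torch
    import torch.nn.functional as F

    def masked_token_loss(model, tokens: torch.Tensor, drop_probability: float = 0.1) -> torch.Tensor:
        keep = torch.rand(tokens.size(), device=tokens.device) > drop_probability  # True = token stays visible
        logits = model(tokens * keep)                        # dropped tokens are zeroed out
        dropped = ~keep
        per_token = F.cross_entropy(logits, tokens, reduction="none")
        return (per_token * dropped).sum() / dropped.sum().clamp(min=1)

    vocab, batch, sequence = 256, 2, 32
    dummy = lambda x: torch.randn(batch, vocab, sequence)    # stands in for the real model
    print(masked_token_loss(dummy, torch.randint(0, vocab, (batch, sequence))))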
if inp.size(2) == self.ctx.model.sequence_length: return inp From 91c33e5469be66b2175e939a97756f3e3e208669 Mon Sep 17 00:00:00 2001 From: ClashLuke <39779310+ClashLuke@users.noreply.github.com> Date: Mon, 15 Nov 2021 10:33:10 +0100 Subject: [PATCH 31/34] fix(model): use previous output when calculating logits --- configs/small.yaml | 4 ++-- src/model.py | 11 +++++------ 2 files changed, 7 insertions(+), 8 deletions(-) diff --git a/configs/small.yaml b/configs/small.yaml index 4353171..4e01d06 100644 --- a/configs/small.yaml +++ b/configs/small.yaml @@ -1,11 +1,11 @@ model: - attention: SqueezeExcitation + attention: FFTAttention sequence_length: 2048 omnidirectional: no depth: 16 conv_kernel_size: 11 weight_shared_blocks: 1 - batch_size: 16 + batch_size: 8 offloading: yes features: 2048 feed_forward_intermediate_factor: 0.125 diff --git a/src/model.py b/src/model.py index e7fb52e..ca928b9 100644 --- a/src/model.py +++ b/src/model.py @@ -263,20 +263,19 @@ def __init__(self, ctx: Context): MomentumNetSide(ctx.model.momentumnet_beta ** (i + 1))]], target_device=ctx.model.device, memory_mode=revlib.MemoryModes.autograd_function) - self.output = torch.nn.Conv1d(ctx.model.features * 2, ctx.dataset.classes, (1,)).to(ctx.model.device) - torch.nn.init.zeros_(self.output.weight.data) + self.output = torch.nn.Conv1d(ctx.model.features, ctx.dataset.classes, (1,)).to(ctx.model.device) + torch.nn.init.orthogonal_(self.output.weight.data) def forward(self, inp: torch.Tensor): out = inp = self.embedding(inp).transpose(1, 2) + batch, features, sequence = inp.size() if self.expand_sequence: - batch, features, sequence = inp.size() out = torch.cat([inp, torch.zeros((batch, features, sequence * len(self.stem.stem)), device=inp.device, dtype=inp.dtype)], 2) out = self.stem(out) if self.expand_sequence: - batch, features, sequence = inp.size() - inp = out.view(batch, features, -1, sequence).mean(2) - return self.output(inp) + out = out.view(batch, features, -1, sequence).mean(2) + return self.output(out) def reset_cache(self): for mod in self.stem.modules(): From 0e1cfa39586c844d5ceeb87679062fe1900696fd Mon Sep 17 00:00:00 2001 From: ClashLuke <39779310+ClashLuke@users.noreply.github.com> Date: Mon, 15 Nov 2021 10:33:50 +0100 Subject: [PATCH 32/34] fix(model): take both momentumnet sides --- src/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/model.py b/src/model.py index ca928b9..8fa703b 100644 --- a/src/model.py +++ b/src/model.py @@ -263,7 +263,7 @@ def __init__(self, ctx: Context): MomentumNetSide(ctx.model.momentumnet_beta ** (i + 1))]], target_device=ctx.model.device, memory_mode=revlib.MemoryModes.autograd_function) - self.output = torch.nn.Conv1d(ctx.model.features, ctx.dataset.classes, (1,)).to(ctx.model.device) + self.output = torch.nn.Conv1d(ctx.model.features * 2, ctx.dataset.classes, (1,)).to(ctx.model.device) torch.nn.init.orthogonal_(self.output.weight.data) def forward(self, inp: torch.Tensor): From 0ba2df1e042159a7ba7aaf960393e286c7d6c313 Mon Sep 17 00:00:00 2001 From: JackMcCoy Date: Tue, 16 Nov 2021 12:40:34 -0500 Subject: [PATCH 33/34] import fix/ style edits --- src/dataset.py | 4 ++-- src/model.py | 1 - 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/src/dataset.py b/src/dataset.py index f444104..10bec2e 100644 --- a/src/dataset.py +++ b/src/dataset.py @@ -22,12 +22,12 @@ def __init__(self, ctx: Context, queue: multiprocessing.Queue, length: int): def __len__(self): return self.length - def __iter__(self, idx: int) -> typing.Tuple[torch.Tensor, 
torch.Tensor]: + def __iter__(self) -> typing.Tuple[torch.Tensor, torch.Tensor]: yield next(self) def __next__(self): items = [self.queue.get() for _ in range(self.ctx.optimizer.gradient_accumulation_steps)] - return torch.stack([itm for itm in items], 0) + return torch.stack(items, 0) def _process_fn(ctx: Context, queue: multiprocessing.Queue, idx: int, worker_count: int): diff --git a/src/model.py b/src/model.py index e7fb52e..bcf0393 100644 --- a/src/model.py +++ b/src/model.py @@ -4,7 +4,6 @@ import numpy as np import revlib import torch -import torch.nn.functional import torch.utils.data from torch.nn import functional as F from torch.optim.lr_scheduler import LambdaLR From efca91f071862e2a916252df5f2ba3ef1a427798 Mon Sep 17 00:00:00 2001 From: ClashLuke <39779310+ClashLuke@users.noreply.github.com> Date: Tue, 16 Nov 2021 21:53:31 +0100 Subject: [PATCH 34/34] feat(model): add multi-head self-attention --- configs/small.yaml | 4 ++-- src/model.py | 16 ++++++++++++++-- 2 files changed, 16 insertions(+), 4 deletions(-) diff --git a/configs/small.yaml b/configs/small.yaml index 4e01d06..a325023 100644 --- a/configs/small.yaml +++ b/configs/small.yaml @@ -1,8 +1,8 @@ model: - attention: FFTAttention + attention: SelfAttention sequence_length: 2048 omnidirectional: no - depth: 16 + depth: 1 conv_kernel_size: 11 weight_shared_blocks: 1 batch_size: 8 diff --git a/src/model.py b/src/model.py index 962a8ea..0ad45e6 100644 --- a/src/model.py +++ b/src/model.py @@ -271,6 +271,7 @@ def forward(self, inp: torch.Tensor): if self.expand_sequence: out = torch.cat([inp, torch.zeros((batch, features, sequence * len(self.stem.stem)), device=inp.device, dtype=inp.dtype)], 2) + out = self.stem(out) if self.expand_sequence: out = out.view(batch, features, -1, sequence).mean(2) @@ -309,7 +310,7 @@ def get_moe_param(in_features: int, out_features: int, groups: int, experts: int class FeedForward(torch.nn.Module): - def __init__(self, base: LinearAttention, ctx: Context, init_scale: float, feature_factor: float): + def __init__(self, base: LinearAttention, ctx: Context, init_scale: float, feature_factor: float = 1): super(FeedForward, self).__init__() self.ctx = ctx self.divisor = lambda: base.divisor @@ -432,4 +433,15 @@ def _inner_forward(self, inp: torch.Tensor) -> torch.Tensor: return out.unsqueeze(2) * inp -attention_modules = [FeedForward, FFTAttention, SumAttention, SqueezeExcitation] +class SelfAttention(AttentionBase): + def __init__(self, base: LinearAttention, ctx: Context, init_scale: float): + super(SelfAttention, self).__init__(base, ctx, init_scale, 0) + self.mha = torch.nn.MultiheadAttention(ctx.model.features, 4) + self.get_last = not ctx.model.omnidirectional + + def _inner_forward(self, inp: torch.Tensor) -> torch.Tensor: + inp = inp.permute(2, 0, 1) + return self.mha(inp, inp, inp)[0].permute(1, 2, 0) + + +attention_modules = [FeedForward, FFTAttention, SumAttention, SqueezeExcitation, SelfAttention]
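The SelfAttention wrapper permutes before and after the call because the rest of the model keeps activations as (batch, features, sequence), while torch.nn.MultiheadAttention without batch_first expects (sequence, batch, features). A minimal sketch of that layout round trip; the tensor sizes are arbitrary, and the four heads match the hard-coded value above:

    import torch

    features, heads, batch, sequence = 64, 4, 2, 16
    mha = torch.nn.MultiheadAttention(features, heads)        # defaults to the (seq, batch, feature) layout

    x = torch.randn(batch, features, sequence)                # the model's native layout
    seq_first = x.permute(2, 0, 1)                            # -> (sequence, batch, features)
    attended, _ = mha(seq_first, seq_first, seq_first)        # self-attention: query = key = value
    out = attended.permute(1, 2, 0)                           # back to (batch, features, sequence)
    assert out.shape == x.shape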