diff --git a/configs/small.yaml b/configs/small.yaml
index 9a7c80d..a325023 100644
--- a/configs/small.yaml
+++ b/configs/small.yaml
@@ -1,17 +1,18 @@
 model:
-  depth: 32
+  attention: SelfAttention
+  sequence_length: 2048
+  omnidirectional: no
+  depth: 1
   conv_kernel_size: 11
   weight_shared_blocks: 1
-  batch_size: 1024
+  batch_size: 8
+  offloading: yes
+  features: 2048
   feed_forward_intermediate_factor: 0.125
 optimizer:
   beta2: 0.95
   gradient_accumulation_steps: 1
-  one_cycle:
-    cycle_first_step_size: 8192
-    cycle_second_step_size: null
-    cycle_max_lr: 0.01
 log:
-  loss_steps_per_print: 8
+  loss_steps_per_print: 4
 dataset:
-  num_workers: 12
+  num_workers: 16
diff --git a/src/dataclass.py b/src/dataclass.py
index 9511263..a412173 100644
--- a/src/dataclass.py
+++ b/src/dataclass.py
@@ -19,11 +19,14 @@ def serialize(instance: typing.Union[DataClass, typing.Dict[str, typing.Any]]):
 
 
 class Model(DataClass):
+    omnidirectional: bool = False
+    attention: str = "FFTAttention"
     weight_sharing: bool = False
     checkpoint_path: str = "checkpoint.torch"
     steps_per_checkpoint: int = 0  # 0 -> disabled
     print_on_init: bool = True
     features: int = 256
+    sum_attention_level: int = 0
    momentumnet_beta: float = 0.99  # The higher this is, the more numerically stable. BUT also lower impact per layer
     depth: int = 64
     batch_size: int = 128
@@ -51,6 +54,7 @@ class Dataset(DataClass):
     file_name: str = "out.tensor"
     classes: int = 256
     num_workers: int = 4
+    dropout: float = 0.3
     pin_memory: bool = False
     prefetch_factor: int = 256  # 256 (Prefetch) * 8 (Long) * 2048 (GPT context) * 256 (High Batch) = 1GiB RAM
@@ -85,19 +89,19 @@ class Zero(DataClass):
 
 
 class OneCycle(DataClass):
-    cycle_min_lr: float = 3e-4  # Base learning rate used at the start and end of cycle.
-    cycle_max_lr: float = 1e-3  # Learning rate used in the middle of the cycle. Can be smaller than cycle_min_lr
-    decay_lr_rate: float = 1e-4  # Decay rate for learning rate.
-    cycle_first_step_size: int = 2048  # Number of training iterations in the increasing half of a cycle.
-    cycle_second_step_size: typing.Optional[int] = None  # steps in second phase. None -> cycle_first_step_size
-    cycle_first_stair_count: int = 0  # Number of stairs in first phase. 0 means staircase disabled
-    cycle_second_stair_count: typing.Optional[int] = None  # Number of stairs in second phase
-    decay_step_size: int = 2  # Every how many steps to decay lr. 0 -> no decay
-    cycle_momentum: bool = True  # Whether to cycle `momentum` inversely to learning rate.
-    cycle_min_mom: float = 0.8  # Initial momentum which is the lower boundary in the cycle for each parameter group.
-    cycle_max_mom: float = 0.9  # Upper momentum boundaries in the cycle for each parameter group.
-    decay_mom_rate: float = 0  # Decay rate for momentum
-    last_batch_iteration: int = -1  # The index of the last batch. This parameter is used when resuming a training job.
+    max_lr: float = 1e-3
+    total_steps: typing.Optional[int] = 10 ** 3
+    epochs: typing.Optional[int] = None
+    steps_per_epoch: typing.Optional[int] = None
+    pct_start: float = 0.3
+    anneal_strategy: str = 'cos'
+    cycle_momentum: bool = True
+    base_momentum: float = 0.85
+    max_momentum: float = 0.95
+    div_factor: float = 25.
+    final_div_factor: float = 1e4
+    three_phase: bool = False
+    last_epoch: int = -1
 
 
 class AdaptiveGradientClipping(DataClass):
@@ -114,11 +118,13 @@ class SharpnessAwareMinimization(DataClass):
 
 class Optimizer(DataClass):
     type: str = "AdamW"
+    final_step: int = 2 ** 14
+    warmup_end: int = 2 ** 10
     gradient_accumulation_steps: int = 1
     one_cycle: OneCycle = OneCycle()
     beta2: float = 0.95  # beta1 is controlled by one_cycle
     eps: float = 1e-8
-    weight_decay: float = 0.01
+    weight_decay: float = 0.
     zero: Zero = Zero()
     agc = AdaptiveGradientClipping()
     sharpness_aware_minimization: SharpnessAwareMinimization = SharpnessAwareMinimization()
@@ -132,7 +138,7 @@ class Optimizer(DataClass):
     statistics_compute_steps: int = 1
     block_size: int = 128
     best_effort_shape_interpretation: bool = True
-    graft_type: str = 'adagrad'  # 'Adagrad' or 'SGD'
+    graft_type: str = 'SGD'  # 'Adagrad' or 'SGD'
     nesterov: bool = True
     no_preconditioning_for_layers_with_dim_gt: int = 8192
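
The reworked `OneCycle` block mirrors the keyword arguments of `torch.optim.lr_scheduler.OneCycleLR` one-to-one (the `Trainer` in `src/model.py` below actually switches to a `LambdaLR` schedule, so the mapping here is purely illustrative). A minimal sketch of how the fields could be forwarded, with `one_cycle` standing in for `ctx.optimizer.one_cycle`:

```python
import torch
from torch.optim.lr_scheduler import OneCycleLR


def build_one_cycle(optimizer: torch.optim.Optimizer, one_cycle) -> OneCycleLR:
    # Every keyword below exists in torch's OneCycleLR and matches a dataclass field above.
    return OneCycleLR(optimizer,
                      max_lr=one_cycle.max_lr,
                      total_steps=one_cycle.total_steps,
                      epochs=one_cycle.epochs,
                      steps_per_epoch=one_cycle.steps_per_epoch,
                      pct_start=one_cycle.pct_start,
                      anneal_strategy=one_cycle.anneal_strategy,
                      cycle_momentum=one_cycle.cycle_momentum,
                      base_momentum=one_cycle.base_momentum,
                      max_momentum=one_cycle.max_momentum,
                      div_factor=one_cycle.div_factor,
                      final_div_factor=one_cycle.final_div_factor,
                      three_phase=one_cycle.three_phase,
                      last_epoch=one_cycle.last_epoch)
```
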
diff --git a/src/dataset.py b/src/dataset.py
index 9e51bc9..10bec2e 100644
--- a/src/dataset.py
+++ b/src/dataset.py
@@ -1,3 +1,5 @@
+import multiprocessing
+import random
 import typing
 
 import torch
@@ -7,32 +9,51 @@
 
 
 @torch.jit.script
-def get_sample(data: torch.Tensor, batch_index: torch.Tensor, idx: int) -> typing.Tuple[torch.Tensor, torch.Tensor]:
-    dat = data[batch_index + idx]
-    dat = dat.to(dtype=torch.long, non_blocking=True)
-    return dat[:, :-1], dat[:, 1:]
+def get_sample(data: torch.Tensor, batch_index: torch.Tensor, idx: int) -> torch.Tensor:
+    return data[batch_index + idx].to(dtype=torch.long, non_blocking=True)
 
 
-class Dataset(torch.utils.data.Dataset):
-    def __init__(self, ctx: Context):
-        self.data = torch.load(ctx.dataset.file_name)
-        batch_index = torch.arange(0, ctx.model.batch_size).view(-1, 1)
-        item_index = torch.arange(0, ctx.model.sequence_length + 1).view(1, -1)
-        self.batch_index = batch_index + item_index
-        self.length = self.data.size(0) - ctx.model.batch_size * ctx.model.sequence_length
+class Dataset:
+    def __init__(self, ctx: Context, queue: multiprocessing.Queue, length: int):
+        self.ctx = ctx
+        self.length = length
+        self.queue = queue
 
     def __len__(self):
         return self.length
 
-    def __getitem__(self, idx: int) -> typing.Tuple[torch.Tensor, torch.Tensor]:
-        return get_sample(self.data, self.batch_index, idx)
+    def __iter__(self) -> typing.Tuple[torch.Tensor, torch.Tensor]:
+        yield next(self)
 
+    def __next__(self):
+        items = [self.queue.get() for _ in range(self.ctx.optimizer.gradient_accumulation_steps)]
+        return torch.stack(items, 0)
 
-def get_dataset(ctx: Context) -> torch.utils.data.DataLoader:
+
+def _process_fn(ctx: Context, queue: multiprocessing.Queue, idx: int, worker_count: int):
+    data = torch.load(ctx.dataset.file_name)
+    data_len = data.size(0) // worker_count
+    data = data[data_len * idx:data_len * (idx + 1)]
+    batch_index = torch.arange(0, ctx.model.batch_size).view(-1, 1)
+    item_index = torch.arange(0, ctx.model.sequence_length).view(1, -1)
+    batch_index = batch_index + item_index
+    length = data.size(0) - ctx.model.batch_size * ctx.model.sequence_length
+
+    random.seed(idx)
+    while True:
+        queue.put(get_sample(data, batch_index, random.randint(0, length)))
+
+
+def get_dataset(ctx: Context) -> Dataset:
     if ctx.dataset.prefetch_factor < ctx.dataset.num_workers:
         print(f"Warning: prefetch_factor ({ctx.dataset.prefetch_factor}) < num_workers ({ctx.dataset.num_workers})."
              f"Some workers will be idle at all times. Reducing num_workers ({ctx.dataset.num_workers}) to "
              f"prefetch_factor ({ctx.dataset.prefetch_factor}).")
-    return torch.utils.data.DataLoader(Dataset(ctx), ctx.optimizer.gradient_accumulation_steps, True,
-                                       num_workers=min(ctx.dataset.num_workers, ctx.dataset.prefetch_factor),
-                                       pin_memory=ctx.dataset.pin_memory, prefetch_factor=ctx.dataset.prefetch_factor)
+    queue = multiprocessing.Queue(ctx.dataset.prefetch_factor)
+    workers = min(ctx.dataset.num_workers, ctx.dataset.prefetch_factor)
+    procs = [multiprocessing.Process(target=_process_fn, args=(ctx, queue, idx, workers), daemon=True) for idx in
+             range(workers)]
+    for p in procs:
+        p.start()
+    data = torch.load(ctx.dataset.file_name)
+    return Dataset(ctx, queue, data.size(0))
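
The `DataLoader` is gone: each daemon worker now loads its own shard of the token tensor and pushes random windows into a bounded `multiprocessing.Queue`, while `Dataset.__next__` stacks `gradient_accumulation_steps` of them. A self-contained sketch of the same producer/consumer pattern with toy data (all names and sizes here are illustrative, not taken from the repository):

```python
import multiprocessing
import random

import torch


def producer(queue: multiprocessing.Queue, shard: torch.Tensor, window: int, seed: int):
    random.seed(seed)                      # distinct seed per worker, as in _process_fn
    limit = shard.size(0) - window
    while True:
        start = random.randint(0, limit)
        queue.put(shard[start:start + window].clone())


if __name__ == "__main__":
    data = torch.arange(10_000)            # stand-in for torch.load(ctx.dataset.file_name)
    prefetch, workers, window, accumulation = 8, 2, 16, 4
    queue = multiprocessing.Queue(prefetch)  # bounded, so idle producers block instead of growing RAM
    shard_len = data.size(0) // workers
    procs = [multiprocessing.Process(target=producer,
                                     args=(queue, data[i * shard_len:(i + 1) * shard_len], window, i),
                                     daemon=True)
             for i in range(workers)]
    for p in procs:
        p.start()
    batch = torch.stack([queue.get() for _ in range(accumulation)], 0)  # mirrors Dataset.__next__
    print(batch.shape)  # torch.Size([4, 16])
```
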
diff --git a/src/executable/train.py b/src/executable/train.py
index 0a6d9b4..29422eb 100644
--- a/src/executable/train.py
+++ b/src/executable/train.py
@@ -13,19 +13,23 @@ def train_model(ctx: Context, steps=None, load_model: bool = False):
     data = get_dataset(ctx)
     data_len = len(data)
-    data = iter(data)
-    mod = get_model(ctx, load_model, next(data)[0])
+    next_data = next(data)[0]
+    mod = get_model(ctx, load_model, next_data)
     wandb.watch(mod, log=ctx.log.wandb.model_log_type, log_freq=ctx.log.wandb.log_frequency)
 
     log = WandbLog(ctx, data_len)
     mean_loss = torch.zeros([], device=ctx.model.device, dtype=torch.float16 if ctx.model.float16 else torch.float)
     mean_max_loss = mean_loss.clone()
+    mean_acc = mean_loss.clone()
+    mean_max_acc = mean_loss.clone()
 
     i = 0
     while True:
         i += 1
-        mean_loss += mod.accumulated_step(next(data))
+        lss, acc = mod.accumulated_step(next(data))
+        mean_loss += lss
+        mean_acc += acc
         if ctx.optimizer.sharpness_aware_minimization.enabled:
             with torch.no_grad():
                 for p in mod.gradients():
@@ -35,7 +39,10 @@ def train_model(ctx: Context, steps=None, load_model: bool = False):
                     p.add_(p.grad)
                     p.prev_step = p.grad
                     p.grad = None
-            mean_max_loss += mod.accumulated_step(next(data))
+
+            lss, acc = mod.accumulated_step(next(data))
+            mean_max_loss += lss
+            mean_max_acc += acc
         mod.optimizer.step()
         if ctx.optimizer.sharpness_aware_minimization.enabled:
             with torch.no_grad():
@@ -50,10 +57,12 @@ def train_model(ctx: Context, steps=None, load_model: bool = False):
                 p['betas'] = p['betas'][0], mod.ctx.optimizer.beta2
        with torch.no_grad():
            if mod.ctx.log.loss_steps_per_print and i % mod.ctx.log.loss_steps_per_print == 0:
-                log(mean_loss, mean_max_loss,
+                log(mean_loss, mean_max_loss, mean_acc, mean_max_acc,
                    mod.optimizer.param_groups[0]['lr'], mod.optimizer.param_groups[0]['betas'])
                mean_loss.zero_()
                mean_max_loss.zero_()
+                mean_acc.zero_()
+                mean_max_acc.zero_()
            if mod.ctx.model.steps_per_checkpoint and i % mod.ctx.model.steps_per_checkpoint == 0:
                mod.save()
            if steps and i > steps:
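
`accumulated_step` now returns a `(loss, accuracy)` pair, so the loop keeps four running scalars and, when sharpness-aware minimization is enabled, routes the second (perturbed) pass into the `*_max_*` counters. A stripped-down sketch of just that bookkeeping, with `train_step` standing in for `mod.accumulated_step`:

```python
import torch


def run(steps: int, sam_enabled: bool, train_step):
    mean_loss = torch.zeros([])
    mean_max_loss, mean_acc, mean_max_acc = mean_loss.clone(), mean_loss.clone(), mean_loss.clone()
    for _ in range(steps):
        lss, acc = train_step()      # ordinary forward/backward pass
        mean_loss += lss
        mean_acc += acc
        if sam_enabled:
            lss, acc = train_step()  # second pass at the perturbed weights
            mean_max_loss += lss
            mean_max_acc += acc
    return mean_loss, mean_max_loss, mean_acc, mean_max_acc


# Example with a fake step that returns constant loss/accuracy tensors.
print(run(4, True, lambda: (torch.tensor(2.0), torch.tensor(0.5))))
```
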
diff --git a/src/model.py b/src/model.py
index b57ff9a..0ad45e6 100644
--- a/src/model.py
+++ b/src/model.py
@@ -5,8 +5,8 @@
 import revlib
 import torch
 import torch.utils.data
-from deepspeed.runtime import lr_schedules
 from torch.nn import functional as F
+from torch.optim.lr_scheduler import LambdaLR
 
 from src.dataclass import Context
 from src.optimizers.build import build_optimizer
@@ -76,7 +76,7 @@ def backward(ctx, grad_outputs: torch.Tensor):
 
 
 def moe(inp: torch.Tensor, expert_weights: torch.nn.ParameterList, training: bool,
-        jitter_epsilon: float, feature_shuffle: torch.Tensor, groups: int, experts: int) -> torch.Tensor:
+        jitter_epsilon: float, groups: int, experts: int) -> torch.Tensor:
     *expert_weights, gate = expert_weights
     batch, features, sequence = inp.size()
     tokens = batch * sequence
@@ -112,8 +112,6 @@ def moe(inp: torch.Tensor, expert_weights: torch.nn.ParameterList, training: boo
 
     # permute
     inp = inp.gather(0, expert_permutation.expand_as(inp))
-    if feature_shuffle is not None:
-        inp = inp.gather(1, feature_shuffle.view(1, -1).expand_as(inp))
     inp = inp.view(tokens // experts, experts * groups, features // groups)
     if len(expert_weights) == 1:
         inp = expert_matmul(inp, expert_weights[0])
@@ -126,78 +124,60 @@ def moe(inp: torch.Tensor, expert_weights: torch.nn.ParameterList, training: boo
 
 
 def moe_check(inp: torch.Tensor, w: torch.nn.ParameterList, training: bool,
-              jitter_epsilon: float, feature_shuffle: torch.Tensor, groups: int, experts: int) -> torch.Tensor:
+              jitter_epsilon: float, groups: int, experts: int) -> torch.Tensor:
     if experts > 0:
-        return moe(inp, w, training, jitter_epsilon, feature_shuffle, groups, experts)
+        return moe(inp, w, training, jitter_epsilon, groups, experts)
     return conv(inp, w[0], groups, False)
 
 
-def linear_attention(inp: torch.Tensor, divisor: torch.Tensor,
-                     w0: torch.nn.ParameterList,
-                     feature_shuffle0: typing.Optional[torch.Tensor], groups0: int, experts0: int,
-                     w1: torch.Tensor,
-                     w2: torch.nn.ParameterList,
-                     feature_shuffle2: typing.Optional[torch.Tensor], groups2: int, experts2: int,
-                     input_cache: torch.Tensor, cumsum_cache: torch.Tensor, bottleneck_group: int, training: bool,
-                     caching: bool, idx: int, norm_power: int, jitter_epsilon: float
-                     ) -> typing.Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
-    kernel_size = w1.size(2)
-    pad = True
-    if not training and caching:
-        if idx - 1 > kernel_size and inp.size(2) == 1:
-            pad = False
-            inp = torch.cat([input_cache, inp], -1)
-        input_cache = inp[:, :, -kernel_size + 1:].detach()
-    inp = moe_check(inp, w0, training, jitter_epsilon, feature_shuffle0, groups0, experts0)
-    depth, scale, shift = inp.chunk(3, 1)
-    cum = depth.cumsum(-1)
-    if not training and caching:
-        cum = cum + cumsum_cache
-        scale = scale[:, :, -1:]
-        shift = shift[:, :, -1:]
-        cum = cum[:, :, -1:]
-        if idx - 1 > kernel_size:
-            cumsum_cache = cum.detach()
-    inp = TripleNorm.apply(cum / divisor, scale, shift, norm_power)
-    inp = conv(inp, w1, bottleneck_group, pad)
+def linear_attention(inp: torch.Tensor, w0: torch.nn.ParameterList, groups0: int, experts0: int, w1: torch.Tensor,
+                     w2: torch.nn.ParameterList, groups2: int, experts2: int, bottleneck_group: int, training: bool,
+                     norm_power: int, jitter_epsilon: float) -> torch.Tensor:
+    inp = moe_check(inp, w0, training, jitter_epsilon, groups0, experts0)
     inp = TripleNorm.apply(*inp.chunk(3, 1), norm_power)
-    inp = moe_check(inp, w2, training, jitter_epsilon, feature_shuffle2, groups2, experts2)
-    return input_cache, cumsum_cache, inp
+    inp = conv(inp, w1, bottleneck_group, True)
+    inp = TripleNorm.apply(*inp.chunk(3, 1), norm_power)
+    return moe_check(inp, w2, training, jitter_epsilon, groups2, experts2)
 
 
 def conv_weight(in_features: int, out_features: int, kernel_size: int, groups: int, std: float):
     return orthonormal(torch.nn.Conv1d(in_features, out_features, (kernel_size,), groups=groups).weight, 1 / std)
 
 
+def get_lr_scheduler_fn(ctx: Context) -> typing.Callable[[int], float]:
+    def _fn(step: int) -> float:
+        final_lr = 1 - 2 / (ctx.optimizer.final_step - ctx.optimizer.warmup_end)
+        lr = step
+        lr /= max(step, ctx.optimizer.warmup_end)
+        lr *= final_lr ** max(step - ctx.optimizer.warmup_end, 0)
+        # lr *= ctx.optimizer.learning_rate  # It's a multiplier for the initial learning rate, not the LR itself
+        return lr
+
+    return _fn
+
+
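
`get_lr_scheduler_fn` replaces the DeepSpeed OneCycle schedule with a plain `LambdaLR` multiplier: linear warmup over `warmup_end` steps, then geometric decay with ratio `1 - 2 / (final_step - warmup_end)`. A numeric check of that curve with the defaults from `src/dataclass.py` (`warmup_end = 2 ** 10`, `final_step = 2 ** 14`); this reimplements the multiplier purely for inspection:

```python
def lr_multiplier(step: int, warmup_end: int = 2 ** 10, final_step: int = 2 ** 14) -> float:
    decay = 1 - 2 / (final_step - warmup_end)   # per-step decay ratio after warmup
    warmup = step / max(step, warmup_end)       # ramps 0 -> 1 over the first warmup_end steps
    return warmup * decay ** max(step - warmup_end, 0)


for s in (0, 256, 1024, 4096, 16384):
    print(s, round(lr_multiplier(s), 4))
# 0 -> 0.0 | 256 -> 0.25 | 1024 -> 1.0 | 4096 -> ~0.67 | 16384 -> ~0.14
```
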
 class Trainer(torch.nn.Module):
     def __init__(self, ctx: Context, model: torch.nn.Module, data: typing.Optional[torch.Tensor]):
         super(Trainer, self).__init__()
         self.ctx = ctx
         self.model = torch.jit.trace(model, data) if data else model
         self.optimizer = build_optimizer(ctx, self.model.parameters())
-        self.scheduler = lr_schedules.OneCycle(self.optimizer,
-                                               ctx.optimizer.one_cycle.cycle_min_lr,
-                                               ctx.optimizer.one_cycle.cycle_max_lr,
-                                               ctx.optimizer.one_cycle.decay_lr_rate,
-                                               ctx.optimizer.one_cycle.cycle_first_step_size,
-                                               ctx.optimizer.one_cycle.cycle_second_step_size,
-                                               ctx.optimizer.one_cycle.cycle_first_stair_count,
-                                               ctx.optimizer.one_cycle.cycle_second_stair_count,
-                                               ctx.optimizer.one_cycle.decay_step_size,
-                                               ctx.optimizer.one_cycle.cycle_momentum,
-                                               ctx.optimizer.one_cycle.cycle_min_mom,
-                                               ctx.optimizer.one_cycle.cycle_max_mom,
-                                               ctx.optimizer.one_cycle.decay_mom_rate,
-                                               ctx.optimizer.one_cycle.last_batch_iteration)
+        self.scheduler = LambdaLR(self.optimizer, get_lr_scheduler_fn(ctx))
 
     @torch.no_grad()
     def _to_device_detach(self, inp: torch.Tensor) -> torch.Tensor:
         return inp.to(device=self.ctx.model.device, non_blocking=True).detach()
 
-    def _forward_backward(self, src: torch.Tensor, tgt: torch.Tensor) -> torch.Tensor:
-        loss = F.cross_entropy(self.model(self._to_device_detach(src)), self._to_device_detach(tgt))
+    def _forward_backward(self, src: torch.Tensor) -> typing.Tuple[torch.Tensor, torch.Tensor]:
+        src = self._to_device_detach(src)
+        msk = torch.rand(src.size(), dtype=torch.float32, device=src.device) > self.ctx.dataset.dropout
+        out = self.model(src * msk)
+        msk = ~msk
+        masked = msk.sum()
+        loss = (F.cross_entropy(out, src, reduction="none") * msk).sum() / masked
         loss.backward()
-        return loss.detach()
+        with torch.inference_mode():
+            return loss.detach(), (torch.argmax(out, 1) == src).mul(msk).sum().float() / masked
 
     @torch.no_grad()
     def _clip_gradient(self):
@@ -207,10 +187,15 @@ def _clip_gradient(self):
             grad_scale = (p_norm / g_norm * self.ctx.optimizer.agc.gradient_clipping).clamp(max=1)
             p.grad.data.copy_(p.grad * grad_scale)
 
-    def accumulated_step(self, data: torch.Tensor) -> torch.Tensor:
-        loss = sum(self._forward_backward(s, t) for s, t in zip(*data))
+    def accumulated_step(self, data: torch.Tensor) -> typing.Tuple[torch.Tensor, torch.Tensor]:
+        loss = 0
+        accuracy = 0
+        for src in data:
+            lss, acc = self._forward_backward(src)
+            loss += lss
+            accuracy += acc
         self._clip_gradient()
-        return loss
+        return loss, accuracy
 
     @torch.no_grad()
     def zero_grad(self):
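
`_forward_backward` drops next-token prediction for a masked-token objective: positions are hidden with probability `ctx.dataset.dropout`, the model only sees the masked input, and both loss and accuracy are averaged over the hidden positions alone. A small self-contained sketch of that masking arithmetic, with random logits standing in for the jitted model:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch, classes, seq, dropout = 2, 256, 8, 0.3
src = torch.randint(0, classes, (batch, seq))        # token ids
keep = torch.rand(batch, seq) > dropout              # True = token stays visible
masked_input = src * keep                            # hidden positions become token 0
out = torch.randn(batch, classes, seq)               # stand-in for model(masked_input)
msk = ~keep                                          # score only the hidden positions
masked = msk.sum()                                   # assumed > 0 here; real batches are far larger
loss = (F.cross_entropy(out, src, reduction="none") * msk).sum() / masked
acc = (out.argmax(1) == src).mul(msk).sum().float() / masked
print(round(loss.item(), 3), round(acc.item(), 3))
```
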
@@ -255,22 +240,46 @@ def __init__(self, ctx: Context):
         pos_embd = torch.arange(0, ctx.model.sequence_length).unsqueeze(0) + 1
         self.register_buffer("divisor", pos_embd.unsqueeze(0).to(torch.float).to(ctx.model.device))
 
-        cell = LinearAttentionCell(self, ctx, 1)
+        ff = FeedForward(self, ctx, 1, 1)
+
+        modules = [mod.__name__ for mod in attention_modules]
+        if ctx.model.attention not in modules:
+            raise ValueError(f"{ctx.model.attention} is not a known type of attention. You can pick any of the"
+                             f" following: {modules}")
+        attn = attention_modules[modules.index(ctx.model.attention)](self, ctx, 1)
+        self.expand_sequence = (not attn.get_last) | (not ff.get_last)
         self.stem = revlib.ReversibleSequential(*[c
-                                                  for i in range(1, 1 + ctx.model.depth)
-                                                  for c in [cell.momentum((1 - ctx.model.momentumnet_beta) /
-                                                                          ctx.model.momentumnet_beta ** i, not ctx.model.weight_sharing),
-                                                            MomentumNetSide(ctx.model.momentumnet_beta ** i)]],
-                                                target_device=ctx.model.device)
+                                                  for i in range(1, 1 + ctx.model.depth * 2, 2)
+                                                  for c in [ff.momentum((1 - ctx.model.momentumnet_beta) /
+                                                                        ctx.model.momentumnet_beta ** i,
+                                                                        not ctx.model.weight_sharing,
+                                                                        i + 1),
+                                                            MomentumNetSide(ctx.model.momentumnet_beta ** i),
+                                                            attn.momentum((1 - ctx.model.momentumnet_beta) /
+                                                                          ctx.model.momentumnet_beta ** (i + 1),
+                                                                          not ctx.model.weight_sharing,
+                                                                          i + 1),
+                                                            MomentumNetSide(ctx.model.momentumnet_beta ** (i + 1))]],
+                                                target_device=ctx.model.device,
+                                                memory_mode=revlib.MemoryModes.autograd_function)
         self.output = torch.nn.Conv1d(ctx.model.features * 2, ctx.dataset.classes, (1,)).to(ctx.model.device)
-        torch.nn.init.zeros_(self.output.weight.data)
+        torch.nn.init.orthogonal_(self.output.weight.data)
 
     def forward(self, inp: torch.Tensor):
-        return self.output(self.stem(self.embedding(inp).transpose(1, 2)))
+        out = inp = self.embedding(inp).transpose(1, 2)
+        batch, features, sequence = inp.size()
+        if self.expand_sequence:
+            out = torch.cat([inp, torch.zeros((batch, features, sequence * len(self.stem.stem)), device=inp.device,
+                                              dtype=inp.dtype)], 2)
+
+        out = self.stem(out)
+        if self.expand_sequence:
+            out = out.view(batch, features, -1, sequence).mean(2)
+        return self.output(out)
 
     def reset_cache(self):
         for mod in self.stem.modules():
-            if isinstance(mod, LinearAttentionCell):
+            if isinstance(mod, FeedForward):
                 mod.reset_cache()
 
 
@@ -300,9 +309,10 @@ def get_moe_param(in_features: int, out_features: int, groups: int, experts: int
     return [torch.nn.Parameter(conv_weight(in_features, out_features, 1, groups, std))]
 
 
-class LinearAttentionCell(torch.nn.Module):
-    def __init__(self, base: LinearAttention, ctx: Context, init_scale: float):
-        super(LinearAttentionCell, self).__init__()
+class FeedForward(torch.nn.Module):
+    def __init__(self, base: LinearAttention, ctx: Context, init_scale: float, feature_factor: float = 1):
+        super(FeedForward, self).__init__()
+        self.ctx = ctx
         self.divisor = lambda: base.divisor
         self.init_scale = init_scale
         self.caching = ctx.eval.cache
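
Each depth step now contributes two reversible pairs — a `FeedForward` block and the configured attention block — with `i` advancing by two so every sub-module gets its own MomentumNet coupling factor. A quick look at the scales that loop produces (the expressions are copied from the constructor; `beta` and `depth` are illustrative values):

```python
beta, depth = 0.99, 3  # momentumnet_beta and ctx.model.depth
for i in range(1, 1 + depth * 2, 2):
    ff_scale, attn_scale = (1 - beta) / beta ** i, (1 - beta) / beta ** (i + 1)
    print(f"i={i}: ff={ff_scale:.5f} side={beta ** i:.5f} | attn={attn_scale:.5f} side={beta ** (i + 1):.5f}")
```
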
@@ -316,50 +326,122 @@ def __init__(self, base: LinearAttention, ctx: Context, init_scale: float):
         self.jitter_epsilon = ctx.model.moe_jitter_epsilon
         self.expert_chunks = ctx.model.expert_chunks
         intermediate = int(ctx.model.features * ctx.model.feed_forward_intermediate_factor)
-        self.w0 = torch.nn.ParameterList(get_moe_param(ctx.model.features, intermediate * 3, self.groups0,
-                                                       self.experts0, self.expert_chunks, ctx.model.activation_std))
-        self.w1 = conv_weight(intermediate, intermediate * 3, ctx.model.conv_kernel_size, ctx.model.bottleneck_group,
-                              ctx.model.activation_std)
-        self.w2 = torch.nn.ParameterList(get_moe_param(intermediate, ctx.model.features, self.groups2,
-                                                       self.experts2, self.expert_chunks, 1))
+        if feature_factor:
+            self.w0 = torch.nn.ParameterList(get_moe_param(ctx.model.features * feature_factor, intermediate * 3,
+                                                           self.groups0, self.experts0, self.expert_chunks,
+                                                           ctx.model.activation_std))
+            self.w1 = conv_weight(intermediate, intermediate * 3, ctx.model.conv_kernel_size,
+                                  ctx.model.bottleneck_group, ctx.model.activation_std)
+            self.w2 = torch.nn.ParameterList(get_moe_param(intermediate, ctx.model.features * feature_factor,
+                                                           self.groups2, self.experts2, self.expert_chunks, 1))
         self.idx: int = 0
-        self._input_cache = torch.zeros([])
-        self._cumsum_cache = torch.zeros([])
-        if ctx.model.feature_shuffle:
-            self.register_buffer("feature_shuffle0", torch.argsort(torch.randn(ctx.model.features)).view(1, -1, 1))
-            self.register_buffer("feature_shuffle2", torch.argsort(torch.randn(intermediate)).view(1, -1, 1))
-        else:
-            self.feature_shuffle0 = None
-            self.feature_shuffle2 = None
+        self.depth: int = 0
+        self.get_last: bool = True
 
-    def reset_cache(self):
-        self._cumsum_cache = torch.zeros([])
-        self._input_cache = torch.zeros([])
-        self.idx = 0
+    def __repr__(self):
+        extra = '\n '.join([f'{name}: {param.size()}' for name, param in self.named_parameters()])
+        return f"{self._get_name()}(\n {extra}\n)"
+
+    def _cut_off(self, inp: torch.Tensor) -> torch.Tensor:
+        if inp.size(2) == self.ctx.model.sequence_length:
+            return inp
+
+        base_len = self.ctx.model.sequence_length * self.depth
+        max_len = base_len + self.ctx.model.sequence_length
+        if self.get_last:
+            return inp[base_len:max_len]
+        return inp[:max_len]
+
+    def _pad(self, inp: torch.Tensor, out: torch.Tensor):
+        if inp.size(2) == out.size(2):
+            return inp
+
+        batch, features, sequence = inp.size()
+        if self.get_last:
+            return torch.cat([torch.zeros((batch, features, self.ctx.model.sequence_length * self.depth),
+                                          device=out.device, dtype=out.dtype), out,
+                              torch.zeros((batch, features,
+                                           sequence - self.ctx.model.sequence_length * (self.depth + 1)),
+                                          device=out.device, dtype=out.dtype)], 2)
+        return torch.cat([out, torch.zeros((batch, features, sequence - out.size(2)), device=out.device,
+                                           dtype=out.dtype)], 2)
+
+    def _ff(self, inp: torch.Tensor) -> torch.Tensor:
+        return linear_attention(inp, self.w0, self.groups0, self.experts0, self.w1, self.w2, self.groups2,
+                                self.experts2,
+                                self.bottleneck_group, self.training, self.norm_power, self.jitter_epsilon)
+
+    def _inner_forward(self, inp: torch.Tensor) -> torch.Tensor:
+        return self._ff(inp)
 
     def forward(self, inp: torch.Tensor) -> torch.Tensor:
-        if self.training:
-            div = self.divisor()
-        elif self.caching:
-            self.idx += inp.size(2)
-            div = torch.LongTensor([self.idx]).to(inp.device)
-        else:
-            self.idx = inp.size(2)
-            div = torch.arange(self.idx, device=inp.device).view(1, 1, -1) + 1
-        self._input_cache, self._cumsum_cache, out = linear_attention(inp, div,
-                                                                      self.w0, self.feature_shuffle0, self.groups0,
-                                                                      self.experts0,
-                                                                      self.w1,
-                                                                      self.w2, self.feature_shuffle2, self.groups2,
-                                                                      self.experts2, self._input_cache,
-                                                                      self._cumsum_cache, self.bottleneck_group,
-                                                                      self.training, self.caching, self.idx,
-                                                                      self.norm_power, self.jitter_epsilon
-                                                                      )
-        out = out * self.init_scale
-        return out
+        return self._pad(inp, self._inner_forward(self._cut_off(inp))) * self.init_scale
 
-    def momentum(self, init_scale: float, deep: bool):
+    def momentum(self, init_scale: float, deep: bool, depth: int):
         out = copy.deepcopy(self) if deep else copy.copy(self)
         out.init_scale = init_scale
+        out.depth = depth
         return out
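
`momentum()` hands out per-depth copies of a block: with weight sharing enabled it uses `copy.copy`, so every depth points at the same parameters, while `copy.deepcopy` gives each depth independent weights. A tiny illustration of that distinction with a bare `nn.Linear` (not the repository's module):

```python
import copy

import torch

block = torch.nn.Linear(4, 4)
shared = copy.copy(block)        # weight sharing: the same Parameter objects are reused
private = copy.deepcopy(block)   # independent weights per depth
print(shared.weight is block.weight)   # True
print(private.weight is block.weight)  # False
```
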
+
+
+class AttentionBase(FeedForward):
+    def __init__(self, base: LinearAttention, ctx: Context, init_scale: float, feature_factor: float):
+        super(AttentionBase, self).__init__(base, ctx, init_scale, feature_factor)
+        self.get_last = not ctx.model.omnidirectional
+
+
+class FFTAttention(AttentionBase):
+    def __init__(self, base: LinearAttention, ctx: Context, init_scale: float):
+        super(FFTAttention, self).__init__(base, ctx, init_scale, 2)
+
+    def _inner_forward(self, inp: torch.Tensor) -> torch.Tensor:
+        batch, features, sequence = inp.size()
+        out = torch.view_as_real(torch.fft.rfft(inp, 2 * sequence, norm="ortho"))
+        out = out.transpose(2, 3).reshape(batch, features * 2, sequence + 1)
+        out = self._ff(out)
+        out = out.view(batch, features, 2, sequence + 1).transpose(2, 3).contiguous()
+        out = torch.view_as_complex(out)
+        return torch.fft.irfft(out, 2 * sequence, norm="ortho")[:, :, :sequence]
+
+
+class SumAttention(AttentionBase):
+    def __init__(self, base: LinearAttention, ctx: Context, init_scale: float):
+        super(SumAttention, self).__init__(base, ctx, init_scale, ctx.model.sum_attention_level)
+        self.sum_attention_level = ctx.model.sum_attention_level
+        self.weight = conv_weight(ctx.model.features, ctx.model.features, 3, 1, 1)
+
+    def _inner_forward(self, inp: torch.Tensor) -> torch.Tensor:
+        out = self._ff(inp).chunk(self.sum_attention_level, 1)
+        batch, features, seq = out[0].size()
+        return sum(conv(torch.relu(out[0] + sum(out[inner + 1][outer // batch ** inner % batch]
+                                                for inner in range(self.sum_attention_level)).unsqueeze(0)),
+                        self.weight, 1, True)
+                   for outer in range(batch ** (self.sum_attention_level - 1)))
+
+
+class SqueezeExcitation(AttentionBase):
+    def __init__(self, base: LinearAttention, ctx: Context, init_scale: float):
+        super(SqueezeExcitation, self).__init__(base, ctx, init_scale, 0)
+        self.weight0 = orthonormal([ctx.model.features * 2, ctx.model.features * 2 * 3], 1)
+        self.weight1 = orthonormal([ctx.model.features * 2, ctx.model.features], 1)
+
+    def _inner_forward(self, inp: torch.Tensor) -> torch.Tensor:
+        out = torch.cat([inp.mean(2), inp.max(2).values], 1)
+        out = out.mm(self.weight0).chunk(3, 1)
+        out = TripleNorm.apply(*out, self.ctx.model.norm_power)
+        out = out.mm(self.weight1)
+        return out.unsqueeze(2) * inp
+
+
+class SelfAttention(AttentionBase):
+    def __init__(self, base: LinearAttention, ctx: Context, init_scale: float):
+        super(SelfAttention, self).__init__(base, ctx, init_scale, 0)
+        self.mha = torch.nn.MultiheadAttention(ctx.model.features, 4)
+        self.get_last = not ctx.model.omnidirectional
+
+    def _inner_forward(self, inp: torch.Tensor) -> torch.Tensor:
+        inp = inp.permute(2, 0, 1)
+        return self.mha(inp, inp, inp)[0].permute(1, 2, 0)
+
+
+attention_modules = [FeedForward, FFTAttention, SumAttention, SqueezeExcitation, SelfAttention]
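
`FFTAttention` runs the feed-forward block in the frequency domain: `rfft` over a `2 * sequence` window yields `sequence + 1` complex bins, `view_as_real` doubles the feature dimension (which is why its constructor passes `feature_factor=2`), and `irfft` maps back before cropping to the original length. A standalone shape check of that round trip, with the feed-forward step replaced by the identity:

```python
import torch

batch, features, sequence = 2, 8, 16
inp = torch.randn(batch, features, sequence)

spec = torch.view_as_real(torch.fft.rfft(inp, 2 * sequence, norm="ortho"))  # (batch, features, sequence + 1, 2)
flat = spec.transpose(2, 3).reshape(batch, features * 2, sequence + 1)      # what the feed-forward block sees
# ... _ff(flat) would run here; the identity keeps the check simple ...
spec = flat.view(batch, features, 2, sequence + 1).transpose(2, 3).contiguous()
out = torch.fft.irfft(torch.view_as_complex(spec), 2 * sequence, norm="ortho")[:, :, :sequence]

print(flat.shape)                           # torch.Size([2, 16, 17])
print(torch.allclose(out, inp, atol=1e-5))  # True: the transform round-trips
```
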
diff --git a/src/optimizers/build.py b/src/optimizers/build.py
index 008a1a8..4b06422 100644
--- a/src/optimizers/build.py
+++ b/src/optimizers/build.py
@@ -1,14 +1,13 @@
 import inspect
 import typing
 
-import deepspeed.ops.adam
 import torch
 
 from src.dataclass import Context
 from src.optimizers import shampoo
 
 OWN_OPTIMIZER = {'Shampoo': shampoo.Shampoo}
-LIB_OPTIMIZER = {'DeepSpeedCPUAdam': deepspeed.ops.adam.DeepSpeedCPUAdam}
+LIB_OPTIMIZER = {}
 
 
 def build_optimizer(ctx: Context, parameters: typing.Iterable[torch.nn.Parameter]):
diff --git a/src/utils/formatting.py b/src/utils/formatting.py
index e1a1255..e3a6e7e 100644
--- a/src/utils/formatting.py
+++ b/src/utils/formatting.py
@@ -36,41 +36,56 @@ class WandbLog:
     def __init__(self, ctx: Context, steps: int):
         self.mean_loss = 0
         self.mean_max_loss = 0
+        self.mean_acc = 0
+        self.mean_max_acc = 0
         self.start_time = time.time()
         self.ctx = ctx
         self.idx = 0
         self.prev = 0
         self.steps = steps
 
-    def __call__(self, current_loss: torch.Tensor, max_loss: torch.Tensor, learning_rate: float,
+    def normalize(self, var: torch.Tensor, attribute: str) -> float:
+        attr = getattr(self, attribute)
+        curr_var = var.item() / self.ctx.log.loss_steps_per_print / self.ctx.optimizer.gradient_accumulation_steps
+        setattr(self, attribute, (attr * self.prev + curr_var * self.idx) / (self.prev + self.idx))  # LWMA
+        return curr_var
+
+    def __call__(self, current_loss: torch.Tensor, max_loss: torch.Tensor,
+                 current_acc: torch.Tensor, max_acc: torch.Tensor, learning_rate: float,
                  betas: typing.Tuple[float, float]):
-        grad_accum = self.ctx.optimizer.gradient_accumulation_steps
-        curr_loss = current_loss.item() / self.ctx.log.loss_steps_per_print / grad_accum
-        curr_max_loss = max_loss.item() / self.ctx.log.loss_steps_per_print / grad_accum
         self.idx += 1
-        self.mean_loss = (self.mean_loss * self.prev + curr_loss * self.idx) / (self.prev + self.idx)  # LWMA
-        mean_max = self.mean_max_loss = (self.mean_max_loss * self.prev + max_loss * self.idx) / (self.prev + self.idx)
+        current_loss = self.normalize(current_loss, "mean_loss")
+        current_acc = self.normalize(current_acc, "mean_acc")
+        if self.ctx.optimizer.sharpness_aware_minimization:
+            max_loss = self.normalize(max_loss, "mean_max_loss")
+            max_acc = self.normalize(max_acc, "mean_max_acc")
+        else:
+            max_loss = max_acc = self.mean_max_loss = self.mean_max_acc = None
         self.prev += self.idx
         rate = self.ctx.log.loss_steps_per_print * self.idx / (time.time() - self.start_time)
-        tokens_per_day = grad_accum * 3600 * 24 * rate * self.ctx.model.batch_size * self.ctx.model.sequence_length
+        tokens_per_day = 3600 * 24 * rate * self.ctx.model.batch_size * self.ctx.model.sequence_length
+        tokens_per_day *= self.ctx.optimizer.gradient_accumulation_steps
         pretty_print(f"[{self.idx * self.ctx.log.loss_steps_per_print:{len(str(self.steps))}d}/{self.steps}]",
-                     f"Loss: {curr_loss:7.4f} -",
+                     f"Loss: {current_loss:7.4f} -",
                      f"Mean: {self.mean_loss:7.4f} |",
+                     f"Acc: {current_acc:7.4f} -",
+                     f"Mean: {self.mean_acc:7.4f} |",
                      f"LR: {learning_rate:.6f} -",
                      f"Beta1: {betas[0]:.3f} -",
                      f"Beta2: {betas[1]:.3f} |",
                      f"Batch/s: {rate:6.3f} -",
                      f"Tokens/day: {tokens_per_day:11,.0f}")
-        if not self.ctx.optimizer.sharpness_aware_minimization.enabled:
-            curr_max_loss = None
-            mean_max = None
-        wandb.log({"Loss/Current": curr_loss,
+        wandb.log({"Loss/Current": current_loss,
                    "Loss/Mean": self.mean_loss,
-                   "Loss/Current Max": curr_max_loss,
-                   "Loss/Mean Max": mean_max,
+                   "Loss/Current Max": max_loss,
+                   "Loss/Mean Max": self.mean_max_loss,
+                   "Accuracy/Current": current_acc,
+                   "Accuracy/Mean": self.mean_acc,
+                   "Accuracy/Current Max": max_acc,
+                   "Accuracy/Mean Max": self.mean_max_acc,
                    "Speed/Batches per Second": rate,
                    "Speed/Tokens per Day": tokens_per_day,
                    "Optimizer/Learning Rate": learning_rate,