feat(model): Add FFT, OmniNet, custom DataLoader, Windows-Support #22

Open · wants to merge 35 commits into master
Commits (35) — changes from all commits
1efeab8
feat(model): add fft attention
ClashLuke Oct 31, 2021
0e7ca86
feat(model): add sum-based attention
ClashLuke Nov 9, 2021
3089fa6
fix(model): define f()
ClashLuke Nov 9, 2021
ed89196
feat(model): add omnidirectional, pyramidal attention
ClashLuke Nov 13, 2021
21d80f9
feat(model): allow selection of attention
ClashLuke Nov 13, 2021
59f55fd
fix(model): pad input with zeros if omni
ClashLuke Nov 13, 2021
3ae35bf
fix(dataset): use custom multiprocessing to avoid data replication
ClashLuke Nov 13, 2021
0eac903
style(model): remove debug prints
ClashLuke Nov 13, 2021
20ce54b
fix(dataset): manually slice pytorch data loader
ClashLuke Nov 13, 2021
f073f7c
fix(dataset): manually slice dataset
ClashLuke Nov 13, 2021
e8383e7
perf(dataset): implement sampling in multiprocessing
ClashLuke Nov 13, 2021
2734094
fix(model): take mean of states in omninet case
ClashLuke Nov 13, 2021
d8db018
style(model): increase size/use omninet
ClashLuke Nov 13, 2021
0003ec4
fix(model): remove deepspeed/add windows support
ClashLuke Nov 14, 2021
bd437d2
fix(model): use pytorch onecycle
ClashLuke Nov 14, 2021
e7497fe
feat(model): add input dropout
ClashLuke Nov 14, 2021
2e12178
fix(dataset): don't use rand_like(long_tensor)
ClashLuke Nov 14, 2021
fcf7716
fix(dataset): specify mul/mask order
ClashLuke Nov 14, 2021
4cd79ea
fix(model): first reduce, then output
ClashLuke Nov 14, 2021
9202f57
perf(model): increase features to better utilize gpu
ClashLuke Nov 14, 2021
439b7cd
style(model): add accuracy
ClashLuke Nov 14, 2021
f6f3971
fix(train): cast accuracy to float
ClashLuke Nov 14, 2021
6cea58a
fix(train): argmax accuracy across correct dimension
ClashLuke Nov 14, 2021
0269ce3
fix(train): zero accuracy after log
ClashLuke Nov 14, 2021
4d66347
feat(model): add omninet to all attention modules
ClashLuke Nov 14, 2021
6c5fc74
fix(model): put padding in correct device
ClashLuke Nov 14, 2021
a138f8e
fix(model): invert omninet expansion
ClashLuke Nov 14, 2021
7e28e97
style(dataclass): remove weight decay by default
ClashLuke Nov 14, 2021
d2fcbdc
perf(model): only backprop masked tokens
ClashLuke Nov 14, 2021
28d2533
style(model): show parameters in representation
ClashLuke Nov 14, 2021
91c33e5
fix(model): use previous output when calculating logits
ClashLuke Nov 15, 2021
0e1cfa3
fix(model): take both momentumnet sides
ClashLuke Nov 15, 2021
0ba2df1
import fix/ style edits
JackMcCoy Nov 16, 2021
a85e193
Merge remote-tracking branch 'origin/fft' into fft
ClashLuke Nov 16, 2021
efca91f
feat(model): add multi-head self-attention
ClashLuke Nov 16, 2021
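
The headline feature of this PR is the FFT-based attention added in the first commit and exposed through the attention = "FFTAttention" default in src/dataclass.py below. src/model.py is not part of the diff shown here, so the snippet below is only an FNet-style illustration of what an FFT token-mixing layer typically looks like; the class name and the exact mixing are assumptions, not the PR's implementation.

import torch


class FFTMixer(torch.nn.Module):
    # FNet-style token mixing (Lee-Thorp et al., 2021): a 2D FFT over the
    # sequence and feature dimensions, keeping only the real part. The mixing
    # itself has no learned parameters.
    def forward(self, inp: torch.Tensor) -> torch.Tensor:
        # inp: [batch, sequence, features]
        return torch.fft.fft2(inp, dim=(-2, -1)).real
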
17 changes: 9 additions & 8 deletions configs/small.yaml
@@ -1,17 +1,18 @@
model:
depth: 32
attention: SelfAttention
sequence_length: 2048
omnidirectional: no
depth: 1
conv_kernel_size: 11
weight_shared_blocks: 1
batch_size: 1024
batch_size: 8
offloading: yes
features: 2048
feed_forward_intermediate_factor: 0.125
optimizer:
beta2: 0.95
gradient_accumulation_steps: 1
one_cycle:
cycle_first_step_size: 8192
cycle_second_step_size: null
cycle_max_lr: 0.01
log:
loss_steps_per_print: 8
loss_steps_per_print: 4
dataset:
num_workers: 12
num_workers: 16
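
The keys in configs/small.yaml mirror the nested DataClass fields defined in src/dataclass.py below (model, optimizer, optimizer.one_cycle, log, dataset). How the YAML is actually folded onto the Context object is not visible in this diff; the sketch below shows one plausible recursive-override loader, with apply_overrides and load_config being hypothetical names.

import yaml


def apply_overrides(obj, overrides: dict) -> None:
    # Walk the parsed YAML and copy values onto the matching config attributes,
    # recursing into nested sections such as optimizer -> one_cycle.
    for key, value in overrides.items():
        if isinstance(value, dict):
            apply_overrides(getattr(obj, key), value)
        else:
            setattr(obj, key, value)


def load_config(ctx, path: str = "configs/small.yaml"):
    with open(path) as f:
        apply_overrides(ctx, yaml.safe_load(f))
    return ctx
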
36 changes: 21 additions & 15 deletions src/dataclass.py
@@ -19,11 +19,14 @@ def serialize(instance: typing.Union[DataClass, typing.Dict[str, typing.Any]]):


class Model(DataClass):
omnidirectional: bool = False
attention: str = "FFTAttention"
weight_sharing: bool = False
checkpoint_path: str = "checkpoint.torch"
steps_per_checkpoint: int = 0 # 0 -> disabled
print_on_init: bool = True
features: int = 256
sum_attention_level: int = 0
momentumnet_beta: float = 0.99 # The higher this is, the more numerically stable. BUT also lower impact per layer
depth: int = 64
batch_size: int = 128
@@ -51,6 +54,7 @@ class Dataset(DataClass):
file_name: str = "out.tensor"
classes: int = 256
num_workers: int = 4
dropout: float = 0.3
pin_memory: bool = False
prefetch_factor: int = 256 # 256 (Prefetch) * 8 (Long) * 2048 (GPT context) * 256 (High Batch) = 1GiB RAM

@@ -85,19 +89,19 @@ class Zero(DataClass):


class OneCycle(DataClass):
cycle_min_lr: float = 3e-4 # Base learning rate used at the start and end of cycle.
cycle_max_lr: float = 1e-3 # Learning rate used in the middle of the cycle. Can be smaller than cycle_min_lr
decay_lr_rate: float = 1e-4 # Decay rate for learning rate.
cycle_first_step_size: int = 2048 # Number of training iterations in the increasing half of a cycle.
cycle_second_step_size: typing.Optional[int] = None # steps in second phase. None -> cycle_first_step_size
cycle_first_stair_count: int = 0 # Number of stairs in first phase. 0 means staircase disabled
cycle_second_stair_count: typing.Optional[int] = None # Number of stairs in second phase
decay_step_size: int = 2 # Every how many steps to decay lr. 0 -> no decay
cycle_momentum: bool = True # Whether to cycle `momentum` inversely to learning rate.
cycle_min_mom: float = 0.8 # Initial momentum which is the lower boundary in the cycle for each parameter group.
cycle_max_mom: float = 0.9 # Upper momentum boundaries in the cycle for each parameter group.
decay_mom_rate: float = 0 # Decay rate for momentum
last_batch_iteration: int = -1 # The index of the last batch. This parameter is used when resuming a training job.
max_lr: float = 1e-3
total_steps: typing.Optional[int] = 10 ** 3
epochs: typing.Optional[int] = None
steps_per_epoch: typing.Optional[int] = None
pct_start: float = 0.3
anneal_strategy: str = 'cos'
cycle_momentum: bool = True
base_momentum: float = 0.85
max_momentum: float = 0.95
div_factor: float = 25.
final_div_factor: float = 1e4
three_phase: bool = False
last_epoch: int = -1


class AdaptiveGradientClipping(DataClass):
@@ -114,11 +118,13 @@ class SharpnessAwareMinimization(DataClass):

class Optimizer(DataClass):
type: str = "AdamW"
final_step: int = 2 ** 14
warmup_end: int = 2 ** 10
gradient_accumulation_steps: int = 1
one_cycle: OneCycle = OneCycle()
beta2: float = 0.95 # beta1 is controlled by one_cycle
eps: float = 1e-8
weight_decay: float = 0.01
weight_decay: float = 0.
zero: Zero = Zero()
agc = AdaptiveGradientClipping()
sharpness_aware_minimization: SharpnessAwareMinimization = SharpnessAwareMinimization()
@@ -132,7 +138,7 @@ class Optimizer(DataClass):
statistics_compute_steps: int = 1
block_size: int = 128
best_effort_shape_interpretation: bool = True
graft_type: str = 'adagrad' # 'Adagrad' or 'SGD'
graft_type: str = 'SGD' # 'Adagrad' or 'SGD'
nesterov: bool = True
no_preconditioning_for_layers_with_dim_gt: int = 8192

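
The new OneCycle fields above (max_lr through last_epoch) replace the DeepSpeed-style cycle_* fields and match the keyword arguments of torch.optim.lr_scheduler.OneCycleLR one-to-one, in line with the "use pytorch onecycle" and "remove deepspeed" commits. The scheduler construction itself is not part of this diff, so the wiring below is an assumption (build_scheduler is a hypothetical helper; one_cycle is the OneCycle dataclass instance from src/dataclass.py).

import torch


def build_scheduler(optimizer: torch.optim.Optimizer, one_cycle) -> torch.optim.lr_scheduler.OneCycleLR:
    # Every field of the OneCycle dataclass maps directly onto OneCycleLR.
    return torch.optim.lr_scheduler.OneCycleLR(optimizer,
                                               max_lr=one_cycle.max_lr,
                                               total_steps=one_cycle.total_steps,
                                               epochs=one_cycle.epochs,
                                               steps_per_epoch=one_cycle.steps_per_epoch,
                                               pct_start=one_cycle.pct_start,
                                               anneal_strategy=one_cycle.anneal_strategy,
                                               cycle_momentum=one_cycle.cycle_momentum,
                                               base_momentum=one_cycle.base_momentum,
                                               max_momentum=one_cycle.max_momentum,
                                               div_factor=one_cycle.div_factor,
                                               final_div_factor=one_cycle.final_div_factor,
                                               three_phase=one_cycle.three_phase,
                                               last_epoch=one_cycle.last_epoch)
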
55 changes: 38 additions & 17 deletions src/dataset.py
@@ -1,3 +1,5 @@
import multiprocessing
import random
import typing

import torch
@@ -7,32 +9,51 @@


@torch.jit.script
def get_sample(data: torch.Tensor, batch_index: torch.Tensor, idx: int) -> typing.Tuple[torch.Tensor, torch.Tensor]:
dat = data[batch_index + idx]
dat = dat.to(dtype=torch.long, non_blocking=True)
return dat[:, :-1], dat[:, 1:]
def get_sample(data: torch.Tensor, batch_index: torch.Tensor, idx: int) -> torch.Tensor:
return data[batch_index + idx].to(dtype=torch.long, non_blocking=True)


class Dataset(torch.utils.data.Dataset):
def __init__(self, ctx: Context):
self.data = torch.load(ctx.dataset.file_name)
batch_index = torch.arange(0, ctx.model.batch_size).view(-1, 1)
item_index = torch.arange(0, ctx.model.sequence_length + 1).view(1, -1)
self.batch_index = batch_index + item_index
self.length = self.data.size(0) - ctx.model.batch_size * ctx.model.sequence_length
class Dataset:
def __init__(self, ctx: Context, queue: multiprocessing.Queue, length: int):
self.ctx = ctx
self.length = length
self.queue = queue

def __len__(self):
return self.length

def __getitem__(self, idx: int) -> typing.Tuple[torch.Tensor, torch.Tensor]:
return get_sample(self.data, self.batch_index, idx)
def __iter__(self) -> typing.Tuple[torch.Tensor, torch.Tensor]:
yield next(self)

def __next__(self):
items = [self.queue.get() for _ in range(self.ctx.optimizer.gradient_accumulation_steps)]
return torch.stack(items, 0)

def get_dataset(ctx: Context) -> torch.utils.data.DataLoader:

def _process_fn(ctx: Context, queue: multiprocessing.Queue, idx: int, worker_count: int):
data = torch.load(ctx.dataset.file_name)
data_len = data.size(0) // worker_count
data = data[data_len * idx:data_len * (idx + 1)]
batch_index = torch.arange(0, ctx.model.batch_size).view(-1, 1)
item_index = torch.arange(0, ctx.model.sequence_length).view(1, -1)
batch_index = batch_index + item_index
length = data.size(0) - ctx.model.batch_size * ctx.model.sequence_length

random.seed(idx)
while True:
queue.put(get_sample(data, batch_index, random.randint(0, length)))


def get_dataset(ctx: Context) -> Dataset:
if ctx.dataset.prefetch_factor < ctx.dataset.num_workers:
print(f"Warning: prefetch_factor ({ctx.dataset.prefetch_factor}) < num_workers ({ctx.dataset.num_workers})."
f"Some workers will be idle at all times. Reducing num_workers ({ctx.dataset.num_workers}) to "
f"prefetch_factor ({ctx.dataset.prefetch_factor}).")
return torch.utils.data.DataLoader(Dataset(ctx), ctx.optimizer.gradient_accumulation_steps, True,
num_workers=min(ctx.dataset.num_workers, ctx.dataset.prefetch_factor),
pin_memory=ctx.dataset.pin_memory, prefetch_factor=ctx.dataset.prefetch_factor)
queue = multiprocessing.Queue(ctx.dataset.prefetch_factor)
workers = min(ctx.dataset.num_workers, ctx.dataset.prefetch_factor)
procs = [multiprocessing.Process(target=_process_fn, args=(ctx, queue, idx, workers), daemon=True) for idx in
range(workers)]
for p in procs:
p.start()
data = torch.load(ctx.dataset.file_name)
return Dataset(ctx, queue, data.size(0))
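
Unlike the torch.utils.data.DataLoader it replaces, the Dataset above is consumed by calling next() on it directly: each call blocks on the shared queue and stacks gradient_accumulation_steps samples into one tensor of shape [gradient_accumulation_steps, batch_size, sequence_length]. A short usage sketch (demo_one_batch is a hypothetical helper; input masking and targets are handled inside the model per the input-dropout and masked-token commits, so nothing is split here):

def demo_one_batch(ctx: Context) -> torch.Tensor:
    data = get_dataset(ctx)  # starts daemon worker processes that keep the queue filled
    batch = next(data)       # [gradient_accumulation_steps, batch_size, sequence_length]
    return batch

On Windows, which this PR explicitly targets, multiprocessing uses the spawn start method, so the call that starts the workers has to sit under an if __name__ == "__main__": guard.
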
19 changes: 14 additions & 5 deletions src/executable/train.py
@@ -13,19 +13,23 @@ def train_model(ctx: Context, steps=None, load_model: bool = False):

data = get_dataset(ctx)
data_len = len(data)
data = iter(data)
mod = get_model(ctx, load_model, next(data)[0])
next_data = next(data)[0]
mod = get_model(ctx, load_model, next_data)
wandb.watch(mod, log=ctx.log.wandb.model_log_type, log_freq=ctx.log.wandb.log_frequency)

log = WandbLog(ctx, data_len)
mean_loss = torch.zeros([], device=ctx.model.device, dtype=torch.float16 if ctx.model.float16 else torch.float)
mean_max_loss = mean_loss.clone()
mean_acc = mean_loss.clone()
mean_max_acc = mean_loss.clone()

i = 0
while True:
i += 1

mean_loss += mod.accumulated_step(next(data))
lss, acc = mod.accumulated_step(next(data))
mean_loss += lss
mean_acc += acc
if ctx.optimizer.sharpness_aware_minimization.enabled:
with torch.no_grad():
for p in mod.gradients():
@@ -35,7 +39,10 @@
p.add_(p.grad)
p.prev_step = p.grad
p.grad = None
mean_max_loss += mod.accumulated_step(next(data))

lss, acc = mod.accumulated_step(next(data))
mean_max_loss += lss
mean_max_acc += acc
mod.optimizer.step()
if ctx.optimizer.sharpness_aware_minimization.enabled:
with torch.no_grad():
@@ -50,10 +57,12 @@
p['betas'] = p['betas'][0], mod.ctx.optimizer.beta2
with torch.no_grad():
if mod.ctx.log.loss_steps_per_print and i % mod.ctx.log.loss_steps_per_print == 0:
log(mean_loss, mean_max_loss,
log(mean_loss, mean_max_loss, mean_acc, mean_max_acc,
mod.optimizer.param_groups[0]['lr'], mod.optimizer.param_groups[0]['betas'])
mean_loss.zero_()
mean_max_loss.zero_()
mean_acc.zero_()
mean_max_acc.zero_()
if mod.ctx.model.steps_per_checkpoint and i % mod.ctx.model.steps_per_checkpoint == 0:
mod.save()
if steps and i > steps:
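
The training loop now accumulates an accuracy term next to the loss, matching the "add accuracy", "cast accuracy to float", and "argmax accuracy across correct dimension" commits. The actual computation lives in the model's accumulated_step, which is not shown in this diff; the sketch below is only the usual token-level accuracy those commit messages suggest.

import torch


def token_accuracy(logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
    # logits: [batch, sequence, classes], targets: [batch, sequence]
    predictions = logits.argmax(dim=-1)             # argmax across the class dimension
    return (predictions == targets).float().mean()  # cast to float before averaging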