
ClashLuke (Member)

This PR also

  • adds a new custom sum-based attention (see the sketch after this list)
  • changes a bunch of parameter names
  • changes `small.yaml` to integrate omnidirectional attention
  • breaks up our linear attention module into one ff and one attention module
  • removes DeepSpeed's broken `CPUAdam`
  • enforces full attention while removing autoregressive attention
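
The PR does not spell out the exact formulation, but a minimal sketch of what a sum-based attention could look like is below, assuming values are mixed with a length-normalised cumulative sum over the sequence dimension; the class name and projections are illustrative, not the actual module added here.

```python
import torch
import torch.nn as nn


class CumulativeSumAttention(nn.Module):
    """Illustrative sum-based attention: every position receives a running
    (mean-normalised) sum of all earlier positions instead of a softmax
    over pairwise attention scores."""

    def __init__(self, features: int):
        super().__init__()
        self.in_proj = nn.Linear(features, features)
        self.out_proj = nn.Linear(features, features)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, sequence, features)
        mixed = torch.cumsum(self.in_proj(x), dim=1)
        # Normalise by the number of summed positions so early and late
        # tokens end up on a similar scale.
        counts = torch.arange(1, x.size(1) + 1, device=x.device, dtype=x.dtype)
        return self.out_proj(mixed / counts.view(1, -1, 1))
```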

The idea of FFT-based attention comes from FNet, LMU, and *On Learning the Transformer Kernel*, but our implementation differs to maximize the model's expressivity.
OmniNet attends to all previous hidden states instead of only the current hidden state, bridging the gap between linear attention and full attention.
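
For reference, the FNet-style baseline this builds on (not the exact kernel used here) mixes tokens with a Fourier transform over the feature and sequence dimensions, keeping only the real part:

```python
import torch
import torch.nn as nn


class FourierMixing(nn.Module):
    """FNet-style token mixing: a 2D FFT over the feature and sequence
    dimensions replaces the quadratic attention matrix; only the real
    part of the result is kept."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, sequence, features)
        return torch.fft.fft(torch.fft.fft(x, dim=-1), dim=-2).real
```
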
A custom data loader is required, as PyTorch's data loader causes CPU OOMs, has a broken shuffling function, and needs more than 8 GiB of RAM to instantiate 12 empty dataset classes. This wasn't the case in PyTorch 1.9, but it is in 1.10 on WSL.
As WSL cannot deallocate GPU memory, we had to support Windows natively.
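
For context, a minimal sketch of the kind of replacement loader this points towards, assuming a pre-tokenised dataset memory-mapped from disk; the class name, dtype, and file layout are assumptions for illustration, not the actual implementation in this PR.

```python
import numpy as np
import torch


class MemmapLoader:
    """Illustrative stand-in for torch.utils.data.DataLoader: streams random
    fixed-length token windows from a memory-mapped array, so no worker
    processes or in-RAM copies of the dataset are needed."""

    def __init__(self, path: str, batch_size: int, sequence_length: int):
        self.tokens = np.memmap(path, dtype=np.uint16, mode="r")
        self.batch_size = batch_size
        self.sequence_length = sequence_length

    def __iter__(self):
        max_start = len(self.tokens) - self.sequence_length - 1
        while True:
            starts = np.random.randint(0, max_start, size=self.batch_size)
            batch = np.stack([self.tokens[s:s + self.sequence_length] for s in starts])
            yield torch.from_numpy(batch.astype(np.int64))
```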
