
Commit 2b93d55

fix bug with scale

1 parent 85c83e3

2 files changed (+1, -4 lines)

memory_efficient_attention_pytorch/memory_efficient_attention.py

Lines changed: 0 additions & 3 deletions
@@ -151,7 +151,6 @@ def __init__(
         super().__init__()
         self.heads = heads
         self.causal = causal
-        self.scale = dim_head ** -0.5

         inner_dim = heads * dim_head

@@ -184,8 +183,6 @@ def forward(
         q = self.to_q(x)
         k, v = self.to_kv(context).chunk(2, dim = -1)

-        q = q * self.scale
-
         q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))

         attn_fn = attention if not memory_efficient else memory_efficient_attention
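The removed lines are the whole fix: `forward` was multiplying the queries by `dim_head ** -0.5` before handing them to the attention functions, which (judging from this commit) already apply that scaling internally, so the softmax logits were being scaled twice. Below is a minimal sketch of the problem, not the library's exact code; the `attention` helper and the tensor shapes are illustrative assumptions.

```python
import torch

# Illustrative stand-in for the package's attention functions, which are
# assumed here to apply the 1/sqrt(dim_head) scaling themselves.
def attention(q, k, v):
    scale = q.shape[-1] ** -0.5
    sim = torch.einsum('b h i d, b h j d -> b h i j', q * scale, k)
    return torch.einsum('b h i j, b h j d -> b h i d', sim.softmax(dim = -1), v)

q = k = v = torch.randn(1, 8, 16, 64)  # (batch, heads, seq, dim_head), made-up shapes

# Before this commit, forward() also did `q = q * self.scale`, so the
# logits were effectively divided by dim_head rather than sqrt(dim_head):
buggy = attention(q * q.shape[-1] ** -0.5, k, v)

# After the commit, q reaches the attention function unscaled and the
# scaling is applied exactly once:
fixed = attention(q, k, v)

assert not torch.allclose(buggy, fixed)  # double scaling changes the output
```

Dropping `self.scale` from `__init__` follows naturally once `forward` no longer uses it.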

setup.py

Lines changed: 1 addition & 1 deletion
@@ -3,7 +3,7 @@
 setup(
   name = 'memory-efficient-attention-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.0.4',
+  version = '0.0.5',
   license='MIT',
   description = 'Memory Efficient Attention - Pytorch',
   author = 'Phil Wang',
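The version bump turns the scale fix into a patch release. Assuming it was published to PyPI as usual, `pip install memory-efficient-attention-pytorch==0.0.5` (or a plain `pip install --upgrade`) pulls in the corrected scaling.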
