diff --git a/models.py b/models.py index cb2e90f..9616c8d 100644 --- a/models.py +++ b/models.py @@ -57,7 +57,7 @@ class CausalLinearAttentionAMP(CausalLinearAttention): def forward(self, queries, keys, values, query_mask=None, key_mask=None, cache=None): - self.feature_map.new_feature_map() + self.feature_map.new_feature_map(queries.device) Q = self.feature_map.forward_queries(queries) K = self.feature_map.forward_keys(keys)