
Commit 7992b16

Make key optional in ipex.llm.functional.rotary_embedding

Parent: 61f0bff

4 files changed: +46 -27 lines

intel_extension_for_pytorch/llm/functional/fusions.py

Lines changed: 4 additions & 3 deletions
@@ -13,7 +13,7 @@
 
 def rotary_embedding(
     query: torch.Tensor,
-    key: torch.Tensor,
+    key: Optional[torch.Tensor],
     sin: torch.Tensor,
     cos: torch.Tensor,
     rotary_dim: int,
@@ -25,9 +25,10 @@ def rotary_embedding(
     on the `query ` or `key` before their multi-head attention computation.
 
     Args:
-        query, key (torch.Tensor) : inputs to be applied with position embeddings,
+        query (torch.Tensor), key (Optional[torch.Tensor]): inputs to be applied with position embeddings,
            taking shape of [batch size, sequence length, num_head/num_kv_head, head_dim]
            or [num_tokens, num_head/num_kv_head, head_dim] (as well as the output shape).
+            `key` may be `None`, e.g. in case of cross-layer KV sharing.
        sin/cos (torch.Tensor): [num_tokens, rotary_dim] the sin/cos value tensor
            generated to be applied on query/key.
        rotary_ndims (int): the rotary dimension. e.g., 64 for GPTJ. head size for LLama.
@@ -42,7 +43,7 @@ def rotary_embedding(
        The according position_ids for the input. The shape should be [batch size, sequence length].
 
    Return
-        query, key (torch.Tensor): [batch size, sequence length, num_head/num_kv_head, head_dim]
+        query (torch.Tensor), key (Optional[torch.Tensor]): [batch size, sequence length, num_head/num_kv_head, head_dim]
            or [num_tokens, num_head/num_kv_head, head_dim].
 
    """

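With this change, a call site that shares a rotated key across layers can pass `key=None` and only rotate the query. The snippet below is a minimal usage sketch, not part of the commit: the tensor sizes and the sin/cos construction (a common GPT-NeoX-style recipe) are illustrative assumptions, and the positional `True` argument is the rotary-half flag as used in the repository's own tests.

# Minimal usage sketch (illustrative, not from this commit).
import torch
import intel_extension_for_pytorch as ipex

num_tokens, num_head, head_dim = 32, 16, 128
rotary_dim = head_dim

query = torch.randn(num_tokens, num_head, head_dim)
key = torch.randn(num_tokens, num_head, head_dim)

# sin/cos of shape [num_tokens, rotary_dim]; one common way to build them.
inv_freq = 1.0 / (10000 ** (torch.arange(0, rotary_dim, 2).float() / rotary_dim))
freqs = torch.outer(torch.arange(num_tokens).float(), inv_freq)
emb = torch.cat((freqs, freqs), dim=-1)
sin, cos = emb.sin(), emb.cos()

# As before: rotate both query and key.
q_rot, k_rot = ipex.llm.functional.rotary_embedding(query, key, sin, cos, rotary_dim, True)

# New with this commit: key may be None, e.g. for cross-layer KV sharing.
q_only, k_none = ipex.llm.functional.rotary_embedding(query, None, sin, cos, rotary_dim, True)
assert k_none is None
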
intel_extension_for_pytorch/llm/modules/mha_fusion.py

Lines changed: 9 additions & 5 deletions
@@ -49,14 +49,15 @@ class RotaryEmbedding(nn.Module):
 
    [Direct function call] This module also provides a `.apply_function` function call
        to be used on query and key at the same time without initializing the module
-        (assume rotary embedding sin/cos values are provided).
+        (assume rotary embedding sin/cos values are provided). `key` is optional for `.apply_function` call.
 
    `apply_function()`
 
    Args:
-        query, key (torch.Tensor) : inputs to be applied with position embeddings, taking shape of
+        query (torch.Tensor), key (Optional[torch.Tensor]) : inputs to be applied with position embeddings, taking shape of
            [batch size, sequence length, num_head/num_kv_head, head_dim]
            or [num_tokens, num_head/num_kv_head, head_dim] (as well as the output shape).
+            `key` may be None, e.g. in case of cross-layer KV sharing.
        sin/cos (torch.Tensor): [num_tokens, rotary_dim] the sin/cos value tensor generated to be applied on query/key.
        rotary_ndims (int): the rotary dimension. e.g., 64 for GPTJ. head size for LLama.
        head_dim (int) : head dim from the input shape.
@@ -68,7 +69,7 @@ class RotaryEmbedding(nn.Module):
            for the input. The shape should be [batch size, sequence length].
 
    Return:
-        query, key (torch.Tensor): [batch size, sequence length, num_head/num_kv_head, head_dim]
+        query (torch.Tensor), key (Optional[torch.Tensor]): [batch size, sequence length, num_head/num_kv_head, head_dim]
            or [num_tokens, num_head/num_kv_head, head_dim].
 
    """
@@ -137,14 +138,17 @@ def forward(
    def apply_function(
        cls,
        query: torch.Tensor,
-        key: torch.Tensor,
+        key: Optional[torch.Tensor],
        sin: torch.Tensor,
        cos: torch.Tensor,
        rotary_dim: int,
        rotary_half: bool,
        position_ids: torch.Tensor = None,
    ):
-        # query, key (in/out shape) torch.Tensor :
+        # query: torch.Tensor with in/out shape:
+        #    4D: [batch, seqlen, num_head/num_kv_head, head_dim]
+        #    3D: [num_tokens, num_head/num_kv_head, head_dim]
+        # key (optional) None or torch.Tensor with in/out shape:
        #    4D: [batch, seqlen, num_head/num_kv_head, head_dim]
        #    3D: [num_tokens, num_head/num_kv_head, head_dim]
        # sin, cos: torch.Tensor [num_tokens, rotary_dim]

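The same optional-key behavior is available through the direct class-method call documented above. A hedged sketch, assuming the class is exported as `ipex.llm.modules.RotaryEmbedding` (as in the IPEX LLM module docs) and using placeholder sin/cos values purely to show the expected shapes:

# Sketch of the direct .apply_function call with the key omitted (illustrative only).
import torch
from intel_extension_for_pytorch.llm.modules import RotaryEmbedding

num_tokens, num_head, head_dim = 8, 4, 64
rotary_dim = head_dim
query = torch.randn(num_tokens, num_head, head_dim)
sin = torch.randn(num_tokens, rotary_dim)  # placeholder values, shape [num_tokens, rotary_dim]
cos = torch.randn(num_tokens, rotary_dim)

q_out, k_out = RotaryEmbedding.apply_function(query, None, sin, cos, rotary_dim, True)
assert k_out is None
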
intel_extension_for_pytorch/transformers/models/cpu/fusions/mha_fusion.py

Lines changed: 25 additions & 15 deletions
@@ -65,23 +65,28 @@ def forward(
    def rotary_embedding(
        cls, query, key, sin, cos, rotary_dim, rotary_half, position_ids=None
    ):
-        # query, key (in/out shape) torch.Tensor :
+        # query: torch.Tensor with in/out shape:
+        #    4D: [bs, seqlen, num_head/num_kv_head, head_dim]
+        #    3D: [num_tokens, num_head/num_kv_head, head_dim]
+        # key (optional) None or torch.Tensor with in/out shape:
        #    4D: [bs, seqlen, num_head/num_kv_head, head_dim]
        #    3D: [num_tokens, num_head/num_kv_head, head_dim]
        # sin, cos: torch.Tensor [num_tokens, rotary_dim]
        # position_ids (optional): torch.Tensor [bs, seqlen]
        head_dim = query.size(-1)
        num_head = query.size(-2)
-        num_kv_head = key.size(-2)
+        num_kv_head = key.size(-2) if key is not None else 0
        input_3d = False
        assert (
-            key.dim() == query.dim() and query.dim() == 3 or query.dim() == 4
+            (key is None or key.dim() == query.dim())
+            and query.dim() == 3
+            or query.dim() == 4
        ), "rotary embedding query/key dim == 3 or 4"
 
        if query.dim() == 3:
            input_3d = True
            query_ = query.unsqueeze(0)
-            key_ = key.unsqueeze(0)
+            key_ = key.unsqueeze(0) if key is not None else None
        else:
            query_ = query
            key_ = key
@@ -124,21 +129,26 @@ def rotary_embedding(
            rotary_dim,
        )
 
-        key_, _, _ = torch.ops.torch_ipex.rotary_position_embedding(
-            key_,
-            sin_cos,
-            position_ids,
-            num_kv_head,
-            head_dim,
-            offset,
-            rotary_dim,
-        )
+        if key is not None:
+            key_, _, _ = torch.ops.torch_ipex.rotary_position_embedding(
+                key_,
+                sin_cos,
+                position_ids,
+                num_kv_head,
+                head_dim,
+                offset,
+                rotary_dim,
+            )
        if input_3d:
            query_ = query_.view([-1, num_head, head_dim])
-            key_ = key_.view([-1, num_kv_head, head_dim])
+            if key_ is not None:
+                key_ = key_.view([-1, num_kv_head, head_dim])
        # keep the inplace context as used in TGI
        query.copy_(query_)
-        key.copy_(key_)
+
+        if key is not None:
+            key.copy_(key_)
+
        return query, key
 

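To make the new control flow easier to follow, here is an eager-mode reference, an illustration rather than the fused `torch_ipex` kernel, that mirrors the branching this file now performs: the query is always rotated, while every key-side step (rotation, reshape, in-place copy) is skipped when `key` is None. It assumes full-head rotation (rotary_dim == head_dim) and a 3D [num_tokens, num_head, head_dim] layout.

# Eager PyTorch reference of the optional-key flow (not the fused IPEX kernel).
from typing import Optional, Tuple
import torch


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # Standard "rotate half" transform used with rotary_half=True style models.
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def reference_rotary_embedding(
    query: torch.Tensor,          # [num_tokens, num_head, head_dim]
    key: Optional[torch.Tensor],  # same layout as query, or None
    sin: torch.Tensor,            # [num_tokens, rotary_dim], rotary_dim == head_dim here
    cos: torch.Tensor,            # [num_tokens, rotary_dim]
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    sin_, cos_ = sin.unsqueeze(1), cos.unsqueeze(1)  # broadcast over the head dimension
    query = query * cos_ + rotate_half(query) * sin_
    if key is not None:  # mirrors the new `if key is not None:` branches above
        key = key * cos_ + rotate_half(key) * sin_
    return query, key
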
tests/cpu/test_ipex_llm_module.py

Lines changed: 8 additions & 4 deletions
@@ -884,23 +884,27 @@ def test_rotary_embedding_tgi(self):
            (1, 32, 128),
            (32, 32, 128),
        ]
-        for size in test_tensor_size:
+        for size, use_key in itertools.product(test_tensor_size, [True, False]):
            q = torch.randn(size).float()
-            k = torch.randn(size).float()
+            k = torch.randn(size).float() if use_key else None
            rotary_dim = size[-1]
            seqlen = size[0]
            position_ids = torch.arange(size[0])
            sin, cos = get_sin_cos(position_ids, rotary_dim, 10000, seqlen, q.dtype)
 
            ref_q = apply(q, cos, sin)
-            ref_k = apply(k, cos, sin)
+            ref_k = apply(k, cos, sin) if use_key else None
 
            ipex_q, ipex_k = ipex.llm.functional.rotary_embedding(
                q, k, sin, cos, rotary_dim, True
            )
 
            self.assertEqual(ipex_q, ref_q)
-            self.assertEqual(ref_k, ipex_k)
+            if use_key:
+                self.assertEqual(ref_k, ipex_k)
+            else:
+                self.assertIsNone(ipex_k)
+                self.assertIsNone(ref_k)
 
    def test_add_layernorm(self):
        for add_back in [True, False]:

0 commit comments
