Skip to content

Commit e8771b3

Browse files
committed
Arm backend: add phi3-mini layer tests
- F32/F16 pass - INT8 xfails partially - BF16 passes for TOSA, runtime not supported for VGF Change-Id: I4d1df9faf44486d7de682cae4bd7dbd6bf2ff80d Signed-off-by: Xingguo Li <xingguo.li@arm.com>
1 parent 2cb1ef5 commit e8771b3

File tree

3 files changed

+298
-1
lines changed

3 files changed

+298
-1
lines changed

backends/arm/MODELS.md

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
# The following file contains all models that have been confirmed to be functional and tested for the Arm backend :
1+
<!-- Copyright 2025-2026 Arm Limited and/or its affiliates. -->
2+
# The following file contains all models that have been confirmed to be functional and tested for the Arm backend:
23
- Conformer
34
- Deit Tiny
45
- DeepLab v3 (DL3)
@@ -12,6 +13,7 @@
1213
- Some popular torch.nn.modules models (NN modules)
1314
- Some popular torch ops (Torch Functions)
1415
- Neural Super Sampler (NSS)
16+
- Phi-3
1517
- ResNet 18
1618
- Wav2Letter (W2L)
1719
- Stable Diffusion:
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
# Copyright 2026 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
from transformers.models.phi3.configuration_phi3 import Phi3Config
7+
8+
9+
def get_phi3_test_config() -> Phi3Config:
    """Return a deliberately tiny Phi-3 configuration for fast layer tests.

    All sizes are scaled down (2 layers, hidden size 32, vocab 128) so that
    individual layers export and run quickly in CI.
    """
    cfg = Phi3Config(
        vocab_size=128,
        hidden_size=32,
        intermediate_size=64,
        num_hidden_layers=2,
        num_attention_heads=4,
        num_key_value_heads=4,
        max_position_embeddings=32,
        original_max_position_embeddings=32,
        use_cache=False,
        tie_word_embeddings=False,
    )
    # Force eager attention path to keep the module exportable in tests.
    cfg._attn_implementation = "eager"
    return cfg
Lines changed: 271 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,271 @@
1+
# Copyright 2026 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
from typing import Callable, Tuple
7+
8+
import pytest
9+
import torch
10+
from executorch.backends.arm._passes import (
11+
ConvertInt64ConstOpsToInt32Pass,
12+
ConvertInt64OutputOpsToInt32Pass,
13+
InsertInt32CastsAfterInt64PlaceholdersPass,
14+
)
15+
16+
from executorch.backends.arm.test import common
17+
from executorch.backends.arm.test.tester.test_pipeline import (
18+
TosaPipelineFP,
19+
TosaPipelineINT,
20+
VgfPipeline,
21+
)
22+
23+
pytest.importorskip("transformers.models.phi3")
24+
25+
from executorch.backends.arm.test.models.phi3.phi3_module_test_configs import (
26+
get_phi3_test_config,
27+
)
28+
from transformers.models.phi3.configuration_phi3 import Phi3Config # noqa: E402
29+
from transformers.models.phi3.modeling_phi3 import ( # noqa: E402
30+
Phi3Attention,
31+
Phi3DecoderLayer,
32+
Phi3MLP,
33+
Phi3RMSNorm,
34+
Phi3RotaryEmbedding,
35+
)
36+
37+
# Pipeline input signatures: a single hidden-states tensor, or hidden-states
# plus position-ids.
input_t1 = Tuple[torch.Tensor]
input_t2 = Tuple[torch.Tensor, torch.Tensor]
39+
40+
41+
def _phi3_config() -> Phi3Config:
    """Shared down-scaled Phi-3 configuration used by every test below."""
    return get_phi3_test_config()
43+
44+
45+
def _hidden_states(
    config: Phi3Config, dtype: torch.dtype, batch: int = 2, seq: int = 4
) -> torch.Tensor:
    """Random activations of shape (batch, seq, hidden_size) in *dtype*."""
    if config.hidden_size is None:
        raise RuntimeError("Phi3Config hidden_size must be set for test inputs.")
    return torch.randn(batch, seq, config.hidden_size, dtype=dtype)
52+
53+
54+
def _position_ids(batch: int = 2, seq: int = 4) -> torch.Tensor:
55+
return torch.arange(seq, dtype=torch.long).unsqueeze(0).repeat(batch, 1)
56+
57+
58+
class Phi3AttentionModule(torch.nn.Module):
    """Bundles Phi3Attention with its rotary embedding into one module.

    Computing the rotary embeddings inside forward() keeps the exported
    graph self-contained: callers pass position ids, not precomputed
    position embeddings.
    """

    def __init__(self, config: Phi3Config) -> None:
        super().__init__()
        self.attn = Phi3Attention(config, layer_idx=0)
        self.rotary = Phi3RotaryEmbedding(config)

    def forward(
        self, hidden_states: torch.Tensor, position_ids: torch.Tensor
    ) -> torch.Tensor:
        cos_sin = self.rotary(hidden_states, position_ids)
        # Phi3Attention returns a tuple; only the attention output is needed.
        attn_out, *_ = self.attn(hidden_states, cos_sin, None)
        return attn_out
69+
70+
71+
class Phi3DecoderLayerModule(torch.nn.Module):
    """Bundles a single Phi3DecoderLayer with its rotary embedding.

    As with Phi3AttentionModule, rotary embeddings are produced inside
    forward() so the exported graph only takes tensors as inputs.
    """

    def __init__(self, config: Phi3Config) -> None:
        super().__init__()
        self.layer = Phi3DecoderLayer(config, layer_idx=0)
        self.rotary = Phi3RotaryEmbedding(config)

    def forward(
        self, hidden_states: torch.Tensor, position_ids: torch.Tensor
    ) -> torch.Tensor:
        cos_sin = self.rotary(hidden_states, position_ids)
        layer_out, _ = self.layer(hidden_states, position_embeddings=cos_sin)
        return layer_out
83+
84+
85+
def _module_cases() -> list[
    tuple[
        str,
        Callable[[Phi3Config], torch.nn.Module],
        Callable[[Phi3Config, torch.dtype], Tuple],
    ]
]:
    """(name, module_factory, input_factory) triples for each tested layer."""

    def _rms_norm(cfg: Phi3Config) -> torch.nn.Module:
        # Fall back to the transformers default epsilon when unset.
        eps = float(cfg.rms_norm_eps) if cfg.rms_norm_eps is not None else 1e-6
        return Phi3RMSNorm(cfg.hidden_size, eps=eps)

    def _single_input(cfg: Phi3Config, dtype: torch.dtype) -> Tuple:
        return (_hidden_states(cfg, dtype),)

    def _attention_inputs(cfg: Phi3Config, dtype: torch.dtype) -> Tuple:
        # Clamp the sequence length to what the config's positions allow.
        seq = min(4, cfg.max_position_embeddings or 4)
        return (_hidden_states(cfg, dtype), _position_ids(seq=seq))

    return [
        ("rms_norm", _rms_norm, _single_input),
        ("mlp", Phi3MLP, _single_input),
        ("attention", Phi3AttentionModule, _attention_inputs),
        ("decoder_layer", Phi3DecoderLayerModule, _attention_inputs),
    ]
123+
124+
125+
def _module_cases_int() -> list[object]:
    """INT8 variants of the layer cases in :func:`_module_cases`.

    Reuses the FP case table so the two parametrizations cannot drift
    apart, then marks ``attention`` and ``decoder_layer`` as strict xfail:
    the INT8 TOSA path currently fails the delegate-count check for them.
    """
    xfail_reason = (
        "INT8 TOSA path delegates to executorch_call_delegate for attention and "
        "decoder_layer (check_count.exir fails)."
    )
    xfail_names = {"attention", "decoder_layer"}

    cases: list[object] = []
    for name, module_factory, input_factory in _module_cases():
        if name in xfail_names:
            cases.append(
                pytest.param(
                    name,
                    module_factory,
                    input_factory,
                    marks=pytest.mark.xfail(strict=True, reason=xfail_reason),
                    id=name,
                )
            )
        else:
            cases.append((name, module_factory, input_factory))
    return cases
165+
166+
167+
def _dtype_cases() -> list:
    """(dtype, tosa_extensions) parameter sets for the FP TOSA pipeline.

    BF16 requires the "bf16" TOSA extension; FP32 and FP16 need none.
    """
    return [
        pytest.param(torch.float32, [], id="fp32"),
        pytest.param(torch.bfloat16, ["bf16"], id="bf16"),
        pytest.param(torch.float16, [], id="fp16"),
    ]
181+
182+
183+
def _vgf_dtype_cases() -> list:
    """Dtype parameters for the VGF pipelines; BF16 is xfail for now."""
    bf16_xfail = pytest.mark.xfail(reason="BF16 runtime support not ready for VGF.")
    return [
        pytest.param(torch.float32, id="fp32"),
        pytest.param(torch.bfloat16, marks=bf16_xfail, id="bf16"),
    ]
192+
193+
194+
@pytest.mark.parametrize("dtype,tosa_extensions", _dtype_cases())
@pytest.mark.parametrize("name,module_factory,input_factory", _module_cases())
def test_phi3_tosa_FP_layers(
    dtype, tosa_extensions, name, module_factory, input_factory
):
    """Run each Phi-3 layer through the FP TOSA pipeline at several dtypes."""
    config = _phi3_config()
    module = module_factory(config).to(dtype)
    inputs = input_factory(config, dtype)
    # BF16 has far less mantissa precision, so widen tolerances for it.
    tolerance = 1e-02 if dtype == torch.bfloat16 else 1e-03

    pipeline_cls = TosaPipelineFP[input_t1 if len(inputs) == 1 else input_t2]
    pipeline = pipeline_cls(
        module,
        inputs,
        aten_op=[],
        tosa_extensions=tosa_extensions or None,
        atol=tolerance,
        rtol=tolerance,
        # Downcast int64 constants/outputs/placeholders for the Arm backend.
        transform_passes=[
            ConvertInt64ConstOpsToInt32Pass(),
            ConvertInt64OutputOpsToInt32Pass(),
            InsertInt32CastsAfterInt64PlaceholdersPass(),
        ],
    )
    pipeline.run()
219+
220+
221+
@pytest.mark.parametrize("name,module_factory,input_factory", _module_cases_int())
def test_phi3_tosa_INT_layers(name, module_factory, input_factory):
    """Quantized (INT8) TOSA pipeline for each Phi-3 layer."""
    config = _phi3_config()
    module = module_factory(config)
    # Quantization pipelines always start from float32 inputs.
    inputs = input_factory(config, torch.float32)

    pipeline_cls = TosaPipelineINT[input_t1 if len(inputs) == 1 else input_t2]
    pipeline_cls(module, inputs, aten_op=[]).run()
233+
234+
235+
@common.SkipIfNoModelConverter
@pytest.mark.parametrize("dtype", _vgf_dtype_cases())
@pytest.mark.parametrize("name,module_factory,input_factory", _module_cases())
def test_phi3_vgf_no_quant_layers(name, module_factory, input_factory, dtype):
    """Unquantized VGF pipeline for each Phi-3 layer at each supported dtype."""
    config = _phi3_config()
    module = module_factory(config).to(dtype)
    inputs = input_factory(config, dtype)

    # Downcast int64 constants/outputs/placeholders for the Arm backend.
    int64_passes = [
        ConvertInt64ConstOpsToInt32Pass(),
        ConvertInt64OutputOpsToInt32Pass(),
        InsertInt32CastsAfterInt64PlaceholdersPass(),
    ]
    pipeline_cls = VgfPipeline[input_t1 if len(inputs) == 1 else input_t2]
    pipeline_cls(
        module,
        inputs,
        aten_op=[],
        transform_passes=int64_passes,
        quantize=False,
    ).run()
255+
256+
257+
@common.SkipIfNoModelConverter
@pytest.mark.parametrize("dtype", _vgf_dtype_cases())
@pytest.mark.parametrize("name,module_factory,input_factory", _module_cases())
def test_phi3_vgf_quant_layers(name, module_factory, input_factory, dtype):
    """Quantized VGF pipeline for each Phi-3 layer at each supported dtype."""
    config = _phi3_config()
    module = module_factory(config).to(dtype)
    inputs = input_factory(config, dtype)

    pipeline_cls = VgfPipeline[input_t1 if len(inputs) == 1 else input_t2]
    pipeline_cls(module, inputs, aten_op=[], quantize=True).run()

0 commit comments

Comments
 (0)