Skip to content

Commit bf789f4

Browse files
committed
Add PyTorch frontend support for aten::items
1 parent 36051ef commit bf789f4

File tree

3 files changed

+72
-0
lines changed

3 files changed

+72
-0
lines changed
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
// Copyright (C) 2018-2026 Intel Corporation
2+
// SPDX-License-Identifier: Apache-2.0
3+
//
4+
5+
#include "openvino/frontend/pytorch/node_context.hpp"
6+
#include "utils.hpp"
7+
8+
namespace ov {
9+
namespace frontend {
10+
namespace pytorch {
11+
namespace op {
12+
13+
// Converts aten::items(Dict d) into a list of 2-element (key, value) lists.
// Only dicts that are materialized in the graph by prim::DictConstruct can be
// unpacked; an opaque Dict value carries no visible keys/values to convert.
OutputVector translate_items(const NodeContext& context) {
    num_inputs_check(context, 1, 1);

    const auto dict_value = context.get_input(0);
    const auto producer = dict_value.get_node_shared_ptr();

    // Guard first: reject any producer other than prim::DictConstruct.
    const auto dict_construct = cast_fw_node(producer, "prim::DictConstruct");
    PYTORCH_OP_CONVERSION_CHECK(dict_construct != nullptr,
                                "aten::items: only Dict produced by prim::DictConstruct is supported.");

    // prim::DictConstruct interleaves its operands: key0, value0, key1, value1, ...
    const auto& kv_flat = dict_construct->input_values();
    PYTORCH_OP_CONVERSION_CHECK(kv_flat.size() % 2 == 0,
                                "aten::items: prim::DictConstruct inputs number is not divisible by 2.");

    // Wrap every (key, value) pair into its own list, then wrap all pairs
    // into the outer list that represents the result of items().
    OutputVector pair_lists;
    pair_lists.reserve(kv_flat.size() / 2);
    for (size_t idx = 0; idx + 1 < kv_flat.size(); idx += 2) {
        const auto& key = kv_flat[idx];
        const auto& value = kv_flat[idx + 1];
        pair_lists.push_back(context.mark_node(make_list_construct(OutputVector{key, value})));
    }

    return {context.mark_node(make_list_construct(pair_lists))};
}
42+
43+
} // namespace op
44+
} // namespace pytorch
45+
} // namespace frontend
46+
} // namespace ov

src/frontends/pytorch/src/op_table.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,7 @@ OP_CONVERTER(translate_int);
143143
OP_CONVERTER(translate_inverse);
144144
OP_CONVERTER(translate_istft);
145145
OP_CONVERTER(translate_is_nonzero);
146+
OP_CONVERTER(translate_items);
146147
OP_CONVERTER(translate_kthvalue);
147148
OP_CONVERTER(translate_layer_norm);
148149
OP_CONVERTER(translate_len);
@@ -574,6 +575,7 @@ const std::unordered_map<std::string, CreatorFunction> get_supported_ops_ts() {
574575
{"aten::is_grad_enabled", op::return_false_scalar},
575576
{"aten::istft", op::translate_istft},
576577
{"aten::is_nonzero", op::translate_is_nonzero},
578+
{"aten::items", op::translate_items},
577579
{"aten::kthvalue", op::translate_kthvalue},
578580
{"aten::isfinite", op::translate_1to1_match_1_inputs<opset10::IsFinite>},
579581
{"aten::isinf", op::translate_1to1_match_1_inputs<opset10::IsInf>},
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
# Copyright (C) 2018-2026 Intel Corporation
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
import pytest
5+
import torch
6+
7+
from pytorch_layer_test_class import PytorchLayerTest
8+
9+
10+
class aten_items_dict_input(torch.nn.Module):
    """Builds a literal dict of tensors and returns its items().

    The dict must stay a literal so conversion sees a prim::DictConstruct
    producer — the only Dict source the aten::items converter accepts.
    """

    def forward(self, x):
        table = {0: x, 1: x + x, 2: 2 * x}
        return table.items()
14+
15+
16+
class TestItems(PytorchLayerTest):
    """Layer test for aten::items over a dict built by prim::DictConstruct."""

    def _prepare_input(self):
        # One random float tensor; the dict keys inside the model are constants.
        data = self.random.randn(2, 5, 3, 4)
        return (data,)

    @pytest.mark.nightly
    @pytest.mark.precommit
    def test_items(self, ie_device, precision, ir_version):
        model = aten_items_dict_input()
        self._test(model, "aten::items", ie_device, precision,
                   ir_version, use_convert_model=True)

0 commit comments

Comments
 (0)