Skip to content

Commit 1ac9ee5

Browse files
committed
test: add unit tests to execute transform flow e2e
1 parent 192dcdb commit 1ac9ee5

File tree

2 files changed

+96
-2
lines changed

2 files changed

+96
-2
lines changed

pyproject.toml

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,14 @@ module-name = "cocoindex._engine"
2929
features = ["pyo3/extension-module"]
3030

3131
[project.optional-dependencies]
32-
test = ["pytest"]
33-
dev = ["ruff", "pre-commit"]
32+
test = [
33+
"pytest",
34+
"pytest-asyncio",
35+
]
36+
dev = [
37+
"ruff",
38+
"pre-commit",
39+
]
3440
embeddings = ["sentence-transformers>=3.3.1"]
3541

3642
[tool.mypy]
Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
import typing
2+
from dataclasses import dataclass
3+
4+
import pytest
5+
6+
import cocoindex
7+
8+
9+
@dataclass
10+
class Child:
11+
value: int
12+
13+
14+
@dataclass
15+
class Parent:
16+
children: list[Child]
17+
18+
19+
# Fixture to initialize CocoIndex library
20+
@pytest.fixture(scope="session", autouse=True)
21+
def init_cocoindex() -> typing.Generator[None, None, None]:
22+
cocoindex.init()
23+
yield
24+
25+
26+
@cocoindex.op.function()
27+
def add_suffix(text: str) -> str:
28+
"""Append ' world' to the input text."""
29+
return f"{text} world"
30+
31+
32+
@cocoindex.transform_flow()
33+
def simple_transform(text: cocoindex.DataSlice[str]) -> cocoindex.DataSlice[str]:
34+
"""Transform flow that applies add_suffix to input text."""
35+
return text.transform(add_suffix)
36+
37+
38+
@cocoindex.op.function()
39+
def extract_child_values(parent: Parent) -> list[int]:
40+
"""Extract values from each child in the Parent's children list."""
41+
return [child.value for child in parent.children]
42+
43+
44+
@cocoindex.transform_flow()
45+
def for_each_transform(
46+
data: cocoindex.DataSlice[Parent],
47+
) -> cocoindex.DataSlice[list[int]]:
48+
"""Transform flow that processes child rows to extract values."""
49+
return data.transform(extract_child_values)
50+
51+
52+
def test_simple_transform_flow() -> None:
53+
"""Test the simple transform flow."""
54+
input_text = "hello"
55+
result = simple_transform.eval(input_text)
56+
assert result == "hello world", f"Expected 'hello world', got {result}"
57+
58+
result = simple_transform.eval("")
59+
assert result == " world", f"Expected ' world', got {result}"
60+
61+
62+
@pytest.mark.asyncio
63+
async def test_simple_transform_flow_async() -> None:
64+
"""Test the simple transform flow asynchronously."""
65+
input_text = "async"
66+
result = await simple_transform.eval_async(input_text)
67+
assert result == "async world", f"Expected 'async world', got {result}"
68+
69+
70+
def test_for_each_transform_flow() -> None:
71+
"""Test the complex transform flow with child rows."""
72+
input_data = Parent(children=[Child(1), Child(2), Child(3)])
73+
result = for_each_transform.eval(input_data)
74+
expected = [1, 2, 3]
75+
assert result == expected, f"Expected {expected}, got {result}"
76+
77+
input_data = Parent(children=[])
78+
result = for_each_transform.eval(input_data)
79+
assert result == [], f"Expected [], got {result}"
80+
81+
82+
@pytest.mark.asyncio
83+
async def test_for_each_transform_flow_async() -> None:
84+
"""Test the complex transform flow asynchronously."""
85+
input_data = Parent(children=[Child(4), Child(5)])
86+
result = await for_each_transform.eval_async(input_data)
87+
expected = [4, 5]
88+
assert result == expected, f"Expected {expected}, got {result}"

0 commit comments

Comments
 (0)