Skip to content

Commit 885a947

Browse files
authored
Merge pull request #6 from greenc-FNAL/copilot/implement-variant-helper-in-pull-213
Integrate Variant helper from PR #245 for type-specific Python algorithm registration
2 parents fbffc68 + d2105cf commit 885a947

File tree

6 files changed

+111
-31
lines changed

6 files changed

+111
-31
lines changed

.github/copilot-instructions.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,7 @@ All Markdown files must strictly follow these markdownlint rules:
143143
- **C++ Driver**: Provides data streams (e.g., `test/python/driver.cpp`).
144144
- **Jsonnet Config**: Wires the graph (e.g., `test/python/pytypes.jsonnet`).
145145
- **Python Script**: Implements algorithms (e.g., `test/python/test_types.py`).
146-
- **Type Conversion**: `plugins/python/src/modulewrap.cpp` handles C++ $\leftrightarrow$ Python conversion.
146+
- **Type Conversion**: `plugins/python/src/modulewrap.cpp` handles C++ Python conversion.
147147
- **Mechanism**: Uses string comparison of type names (e.g., `"float64]]"`). This is brittle.
148148
- **Requirement**: Ensure converters exist for all types used in tests (e.g., `float`, `double`, `unsigned int`, and their vector equivalents).
149149
- **Warning**: Exact type matches are required. `numpy.float32` != `float`.

plugins/python/src/modulewrap.cpp

Lines changed: 11 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,6 @@
99
#include <stdexcept>
1010
#include <vector>
1111

12-
// static std::mutex g_py_mutex;
13-
1412
#ifdef PHLEX_HAVE_NUMPY
1513
#define NO_IMPORT_ARRAY
1614
#define PY_ARRAY_UNIQUE_SYMBOL phlex_ARRAY_API
@@ -111,7 +109,6 @@ namespace {
111109
static_assert(sizeof...(Args) == N, "Argument count mismatch");
112110

113111
PyGILRAII gil;
114-
// std::lock_guard<std::mutex> lock(g_py_mutex);
115112

116113
PyObject* result = PyObject_CallFunctionObjArgs(
117114
(PyObject*)m_callable, lifeline_transform(args.get())..., nullptr);
@@ -134,7 +131,6 @@ namespace {
134131
static_assert(sizeof...(Args) == N, "Argument count mismatch");
135132

136133
PyGILRAII gil;
137-
// std::lock_guard<std::mutex> lock(g_py_mutex);
138134

139135
PyObject* result =
140136
PyObject_CallFunctionObjArgs((PyObject*)m_callable, (PyObject*)args.get()..., nullptr);
@@ -372,7 +368,6 @@ namespace {
372368
static PyObjectPtr vint_to_py(std::shared_ptr<std::vector<int>> const& v)
373369
{
374370
PyGILRAII gil;
375-
// std::lock_guard<std::mutex> lock(g_py_mutex);
376371
if (!v)
377372
return PyObjectPtr();
378373
PyObject* list = PyList_New(v->size());
@@ -395,7 +390,6 @@ namespace {
395390
static PyObjectPtr vuint_to_py(std::shared_ptr<std::vector<unsigned int>> const& v)
396391
{
397392
PyGILRAII gil;
398-
// std::lock_guard<std::mutex> lock(g_py_mutex);
399393
if (!v)
400394
return PyObjectPtr();
401395
PyObject* list = PyList_New(v->size());
@@ -418,7 +412,6 @@ namespace {
418412
static PyObjectPtr vlong_to_py(std::shared_ptr<std::vector<long>> const& v)
419413
{
420414
PyGILRAII gil;
421-
// std::lock_guard<std::mutex> lock(g_py_mutex);
422415
if (!v)
423416
return PyObjectPtr();
424417
PyObject* list = PyList_New(v->size());
@@ -441,7 +434,6 @@ namespace {
441434
static PyObjectPtr vulong_to_py(std::shared_ptr<std::vector<unsigned long>> const& v)
442435
{
443436
PyGILRAII gil;
444-
// std::lock_guard<std::mutex> lock(g_py_mutex);
445437
if (!v)
446438
return PyObjectPtr();
447439
PyObject* list = PyList_New(v->size());
@@ -497,7 +489,6 @@ namespace {
497489
static std::shared_ptr<std::vector<int>> py_to_vint(PyObjectPtr pyobj)
498490
{
499491
PyGILRAII gil;
500-
// std::lock_guard<std::mutex> lock(g_py_mutex);
501492
auto vec = std::make_shared<std::vector<int>>();
502493
PyObject* obj = pyobj.get();
503494

@@ -537,7 +528,6 @@ namespace {
537528
static std::shared_ptr<std::vector<unsigned int>> py_to_vuint(PyObjectPtr pyobj)
538529
{
539530
PyGILRAII gil;
540-
// std::lock_guard<std::mutex> lock(g_py_mutex);
541531
auto vec = std::make_shared<std::vector<unsigned int>>();
542532
PyObject* obj = pyobj.get();
543533

@@ -577,7 +567,6 @@ namespace {
577567
static std::shared_ptr<std::vector<long>> py_to_vlong(PyObjectPtr pyobj)
578568
{
579569
PyGILRAII gil;
580-
// std::lock_guard<std::mutex> lock(g_py_mutex);
581570
auto vec = std::make_shared<std::vector<long>>();
582571
PyObject* obj = pyobj.get();
583572

@@ -617,7 +606,6 @@ namespace {
617606
static std::shared_ptr<std::vector<unsigned long>> py_to_vulong(PyObjectPtr pyobj)
618607
{
619608
PyGILRAII gil;
620-
// std::lock_guard<std::mutex> lock(g_py_mutex);
621609
auto vec = std::make_shared<std::vector<unsigned long>>();
622610
PyObject* obj = pyobj.get();
623611

@@ -657,7 +645,6 @@ namespace {
657645
static std::shared_ptr<std::vector<float>> py_to_vfloat(PyObjectPtr pyobj)
658646
{
659647
PyGILRAII gil;
660-
// std::lock_guard<std::mutex> lock(g_py_mutex);
661648
auto vec = std::make_shared<std::vector<float>>();
662649
PyObject* obj = pyobj.get();
663650

@@ -697,7 +684,6 @@ namespace {
697684
static std::shared_ptr<std::vector<double>> py_to_vdouble(PyObjectPtr pyobj)
698685
{
699686
PyGILRAII gil;
700-
// std::lock_guard<std::mutex> lock(g_py_mutex);
701687
auto vec = std::make_shared<std::vector<double>>();
702688
PyObject* obj = pyobj.get();
703689

@@ -863,8 +849,18 @@ static PyObject* parse_args(PyObject* args,
863849
return nullptr;
864850
}
865851

852+
// special case of Phlex Variant wrapper
853+
PyObject* wrapped_callable = PyObject_GetAttrString(callable, "phlex_callable");
854+
if (wrapped_callable) {
855+
// PyObject_GetAttrString returns a new reference, which we return
856+
callable = wrapped_callable;
857+
} else {
858+
// No wrapper, use the original callable with incremented reference count
859+
PyErr_Clear();
860+
Py_INCREF(callable);
861+
}
862+
866863
// no common errors detected; actual registration may have more checks
867-
Py_INCREF(callable);
868864
return callable;
869865
}
870866

test/python/CMakeLists.txt

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -137,8 +137,6 @@ except Exception:
137137
)
138138
list(APPEND ACTIVE_PY_CPHLEX_TESTS py:failure)
139139

140-
message(STATUS "Python_SITELIB: ${Python_SITELIB}")
141-
message(STATUS "Python_SITEARCH: ${Python_SITEARCH}")
142140
set(TEST_PYTHONPATH ${CMAKE_CURRENT_SOURCE_DIR})
143141
# Always add site-packages to PYTHONPATH for tests, as embedded python might
144142
# not find them especially in spack environments where they are in
@@ -154,7 +152,6 @@ except Exception:
154152
# Keep this for backward compatibility or if it adds something else
155153
endif()
156154
set(TEST_PYTHONPATH ${TEST_PYTHONPATH}:$ENV{PYTHONPATH})
157-
message(STATUS "TEST_PYTHONPATH: ${TEST_PYTHONPATH}")
158155

159156
set_tests_properties(
160157
${ACTIVE_PY_CPHLEX_TESTS}

test/python/adder.py

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,19 +4,33 @@
44
real. It serves as a "Hello, World" equivalent for running Python code.
55
"""
66

7+
from typing import Protocol, TypeVar
78

8-
def add(i: int, j: int) -> int:
9+
from variant import Variant
10+
11+
12+
class AddableProtocol[T](Protocol):
13+
"""Typer bound for any types that can be added."""
14+
15+
def __add__(self, other: T) -> T: # noqa: D105
16+
...
17+
18+
19+
Addable = TypeVar('Addable', bound=AddableProtocol)
20+
21+
22+
def add(i: Addable, j: Addable) -> Addable:
923
"""Add the inputs together and return the sum total.
1024
1125
Use the standard `+` operator to add the two inputs together
1226
to arrive at their total.
1327
1428
Args:
15-
i (int): First input.
16-
j (int): Second input.
29+
i (Addable): First input.
30+
j (Addable): Second input.
1731
1832
Returns:
19-
int: Sum of the two inputs.
33+
Addable: Sum of the two inputs.
2034
2135
Examples:
2236
>>> add(1, 2)
@@ -40,4 +54,5 @@ def PHLEX_REGISTER_ALGORITHMS(m, config):
4054
Returns:
4155
None
4256
"""
43-
m.transform(add, input_family=config["input"], output_products=config["output"])
57+
int_adder = Variant(add, {"i": int, "j": int, "return": int}, "iadd")
58+
m.transform(int_adder, input_family=config["input"], output_products=config["output"])

test/python/variant.py

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
"""Annotation helper for C++ typing variants.
2+
3+
Python algorithms are generic, like C++ templates, but the Phlex registration
4+
process requires a single unique signature. These helpers generate annotated
5+
functions for registration with the proper C++ types.
6+
"""
7+
8+
import copy
9+
from typing import Any, Callable
10+
11+
12+
class Variant:
13+
"""Wrapper to associate custom annotations with a callable.
14+
15+
This class wraps a callable and provides custom ``__annotations__`` and
16+
``__name__`` attributes, allowing the same underlying function or callable
17+
object to be registered multiple times with different type annotations.
18+
19+
By default, the provided callable is kept by reference, but can be cloned
20+
(e.g. for callable instances) if requested.
21+
22+
Phlex will recognize the "phlex_callable" data member, allowing an unwrap
23+
and thus saving an indirection. To detect performance degradation, the
24+
wrapper is not callable by default.
25+
26+
Attributes:
27+
phlex_callable (Callable): The underlying callable (public).
28+
__annotations__ (dict): Type information of arguments and return product.
29+
__name__ (str): The name associated with this variant.
30+
31+
Examples:
32+
>>> def add(i: Number, j: Number) -> Number:
33+
... return i + j
34+
...
35+
>>> int_adder = variant(add, {"i": int, "j": int, "return": int}, "iadd")
36+
"""
37+
38+
def __init__(
39+
self,
40+
f: Callable,
41+
annotations: dict[str, str | type | Any],
42+
name: str,
43+
clone: bool | str = False,
44+
allow_call: bool = False,
45+
):
46+
"""Annotate the callable F.
47+
48+
Args:
49+
f (Callable): Annotable function.
50+
annotations (dict): Type information of arguments and return product.
51+
name (str): Name to assign to this variant.
52+
clone (bool|str): If True (or "deep"), creates a shallow (deep) copy
53+
of the callable.
54+
allow_call (bool): Allow this wrapper to forward to the callable.
55+
"""
56+
if clone == 'deep':
57+
self.phlex_callable = copy.deepcopy(f)
58+
elif clone:
59+
self.phlex_callable = copy.copy(f)
60+
else:
61+
self.phlex_callable = f
62+
self.__annotations__ = annotations
63+
self.__name__ = name
64+
self._allow_call = allow_call
65+
66+
def __call__(self, *args, **kwargs):
67+
"""Raises an error if called directly.
68+
69+
Variant instances should not be called directly. The framework should
70+
extract ``phlex_callable`` instead and call that.
71+
72+
Raises:
73+
AssertionError: To indicate incorrect usage, unless overridden.
74+
"""
75+
assert self._allow_call, (
76+
f"Variant '{self.__name__}' was called directly. "
77+
f"The framework should extract phlex_callable instead."
78+
)
79+
return self.phlex_callable(*args, **kwargs) # type: ignore

test/python/verify_extended.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
11
"""Observers to check for various types in tests."""
22

3-
import sys
4-
53

64
class VerifierInt:
75
"""Verify int values."""
@@ -42,7 +40,6 @@ def __init__(self, sum_total: int):
4240

4341
def __call__(self, value: "long") -> None: # type: ignore # noqa: F821
4442
"""Check if value matches expected sum."""
45-
print(f"VerifierLong: value={value}, expected={self._sum_total}")
4643
assert value == self._sum_total
4744

4845

@@ -57,7 +54,6 @@ def __init__(self, sum_total: int):
5754

5855
def __call__(self, value: "unsigned long") -> None: # type: ignore # noqa: F722
5956
"""Check if value matches expected sum."""
60-
print(f"VerifierULong: value={value}, expected={self._sum_total}")
6157
assert value == self._sum_total
6258

6359

@@ -72,7 +68,6 @@ def __init__(self, sum_total: float):
7268

7369
def __call__(self, value: "float") -> None:
7470
"""Check if value matches expected sum."""
75-
sys.stderr.write(f"VerifierFloat: value={value}, expected={self._sum_total}\n")
7671
assert abs(value - self._sum_total) < 1e-5
7772

7873

@@ -87,7 +82,6 @@ def __init__(self, sum_total: float):
8782

8883
def __call__(self, value: "double") -> None: # type: ignore # noqa: F821
8984
"""Check if value matches expected sum."""
90-
print(f"VerifierDouble: value={value}, expected={self._sum_total}")
9185
assert abs(value - self._sum_total) < 1e-5
9286

9387

@@ -102,7 +96,6 @@ def __init__(self, expected: bool):
10296

10397
def __call__(self, value: bool) -> None:
10498
"""Check if value matches expected."""
105-
print(f"VerifierBool: value={value}, expected={self._expected}")
10699
assert value == self._expected
107100

108101

0 commit comments

Comments
 (0)