Skip to content

Commit 24f6639

Browse files
Merge pull request #181 from clEsperanto/improve_pyopencl_interoperability
improve pyopencl interoperability + added test
2 parents 87fd2d1 + 6d4c65d commit 24f6639

File tree

2 files changed

+26
-1
lines changed

2 files changed

+26
-1
lines changed

pyclesperanto_prototype/_tier0/_create.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@ def create_like(*args):
2929
dimensions = dimensions.shape
3030
elif isinstance(dimensions, np.ndarray):
3131
dimensions = dimensions.shape[::-1]
32+
elif hasattr(dimensions, "shape"):
33+
dimensions = dimensions.shape
3234
return create(dimensions)
3335

3436
def create_binary_like(*args):
@@ -37,6 +39,8 @@ def create_binary_like(*args):
3739
dimensions = dimensions.shape
3840
elif isinstance(dimensions, np.ndarray):
3941
dimensions = dimensions.shape[::-1]
42+
elif hasattr(dimensions, "shape"):
43+
dimensions = dimensions.shape
4044
return create(dimensions, np.uint8)
4145

4246
def create_labels_like(*args):
@@ -45,6 +49,8 @@ def create_labels_like(*args):
4549
dimensions = dimensions.shape
4650
elif isinstance(dimensions, np.ndarray):
4751
dimensions = dimensions.shape[::-1]
52+
elif hasattr(dimensions, "shape"):
53+
dimensions = dimensions.shape
4854
return create(dimensions, np.uint32)
4955

5056
def create_same_type_like(*args):
@@ -53,6 +59,8 @@ def create_same_type_like(*args):
5359
dimensions = dimensions.shape
5460
elif isinstance(dimensions, np.ndarray):
5561
dimensions = dimensions.shape[::-1]
62+
elif hasattr(dimensions, "shape"):
63+
dimensions = dimensions.shape
5664
return create(dimensions, dimensions.dtype)
5765

5866
def create_pointlist_from_labelmap(source, *args):

tests/test_pyopencl_compatibility.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,4 +15,21 @@ def test_pyopencl_compatibility():
1515
cl_c = cl_a + cl_b
1616

1717
import pyclesperanto_prototype as cle
18-
cl_c = cl_a + cl_b
18+
cl_c = cl_a + cl_b
19+
20+
def test_semi_push():
21+
import numpy as np
22+
img = np.asarray([[1,2],[3,4], [5,6]])
23+
24+
import pyclesperanto_prototype as cle
25+
device = cle.get_device()
26+
27+
import pyopencl.array as cla
28+
pushed = cla.to_device(device.queue, img)
29+
30+
print(type(pushed))
31+
print(pushed.shape)
32+
blurred = cle.gaussian_blur(pushed, sigma_x=10, sigma_y=10, sigma_z=10)
33+
34+
assert np.array_equal(blurred.shape, pushed.shape)
35+

0 commit comments

Comments
 (0)