Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion named_arrays/_functions/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -438,6 +438,9 @@ def _getitem(
inputs = array.inputs
outputs = array.outputs

if isinstance(item, bool):
item = na.as_named_array(item)

if isinstance(item, na.AbstractArray):
if isinstance(item, na.AbstractFunctionArray):
if not np.all(item.inputs == array.inputs):
Expand Down Expand Up @@ -991,6 +994,9 @@ def __setitem__(
value: float | u.Quantity | na.FunctionArray,
):

if isinstance(item, bool):
item = na.as_named_array(item)

if isinstance(item, na.AbstractFunctionArray):
if not np.all(item.inputs == self.inputs):
raise ValueError("boolean advanced index does not have the same inputs as the array")
Expand All @@ -1011,7 +1017,7 @@ def __setitem__(
item_inputs[ax] = item_outputs[ax] = item_ax
else:
raise TypeError(
f"`item` must be an instance of `{dict.__name__}`, or `{na.AbstractFunctionArray.__name__}`, "
f"`item` must be an instance of `bool`, `{dict.__name__}`, or `{na.AbstractFunctionArray.__name__}`, "
f"got `{type(item)}`"
)

Expand Down
19 changes: 8 additions & 11 deletions named_arrays/_functions/tests/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,7 @@ def test_interp_linear_identity(

with pytest.raises(NotImplementedError):
array.interp_linear(array.indices)

@pytest.mark.parametrize(
argnames='item',
argvalues=[
Expand All @@ -206,6 +207,7 @@ def test_interp_linear_identity(
outputs=na.ScalarArrayRange(0, 2, axis='y'),
)
),
True,
na.ScalarLinearSpace(0, 1, axis='y', num=_num_y) > 0.5,
na.FunctionArray(
inputs=na.ScalarLinearSpace(0, 1, axis='y', num=_num_y),
Expand All @@ -225,9 +227,6 @@ def test_interp_linear_identity(
)
],
)



def test__getitem__(
self,
array: na.AbstractFunctionArray,
Expand All @@ -240,6 +239,9 @@ def test__getitem__(
array[item]
return

if isinstance(item, bool):
item_outputs = item_inputs = item

if isinstance(item, na.AbstractArray):
item = item.explicit
if isinstance(item, na.AbstractFunctionArray):
Expand All @@ -261,7 +263,7 @@ def test__getitem__(
item_outputs[ax] = item_ax.outputs
else:
if ax in array.axes_center:
#can't assume center ax is in both outputs and inputs
# can't assume center ax is in both outputs and inputs
if ax in array.inputs.shape:
item_inputs[ax] = item_ax
if ax in array.outputs.shape:
Expand Down Expand Up @@ -409,7 +411,6 @@ def test_ufunc_binary(
assert np.all(result[i] == result_out[i])
assert result_out[i] is out[i]


class TestMatmul(
named_arrays.tests.test_core.AbstractTestAbstractArray.TestMatmul
):
Expand Down Expand Up @@ -704,7 +705,6 @@ def test_arg_reduction_functions(

assert np.all(result.outputs == outputs_expected)


class TestFFTLikeFunctions(
named_arrays.tests.test_core.AbstractTestAbstractArray.TestArrayFunctions.TestFFTLikeFunctions,
):
Expand Down Expand Up @@ -864,7 +864,7 @@ def test_pcolormesh(

components = list(array.inputs.components.keys())[:2]

#probably a smarter way to deal with plotting broadcasting during testing
# probably a smarter way to deal with plotting broadcasting during testing
if len(array.axes) > 2:
array = array[dict(z=0)]

Expand Down Expand Up @@ -947,6 +947,7 @@ class TestFunctionArray(
dict(y=slice(None)),
dict(y=na.ScalarArrayRange(0, _num_y, axis='y')),
dict(x=na.ScalarArrayRange(0, _num_x, axis='x'), y=na.ScalarArrayRange(0, _num_y, axis='y')),
True,
na.FunctionArray(
inputs=na.ScalarLinearSpace(0, 1, axis='y', num=_num_y),
outputs=na.ScalarArray.ones(shape=dict(y=_num_y), dtype=bool),
Expand Down Expand Up @@ -984,7 +985,6 @@ class TestMatmul(
pass



@pytest.mark.parametrize("type_array", [na.FunctionArray])
class TestFunctionArrayCreation(
named_arrays.tests.test_core.AbstractTestAbstractExplicitArrayCreation,
Expand Down Expand Up @@ -1125,6 +1125,3 @@ class TestMatmul(
AbstractTestAbstractFunctionArray.TestMatmul
):
pass



12 changes: 12 additions & 0 deletions named_arrays/_scalars/scalars.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,6 +393,9 @@ def _getitem(
item: dict[str, int | slice | AbstractScalarArray] | AbstractScalarArray,
):

if isinstance(item, bool):
item = na.as_named_array(item)

if isinstance(item, AbstractScalarArray):

if not set(item.shape).issubset(self.axes):
Expand Down Expand Up @@ -1016,6 +1019,9 @@ def __setitem__(
else:
value = ScalarArray(value)

if isinstance(item, bool):
item = na.as_named_array(item)

if isinstance(item, AbstractScalarArray):

item = item.explicit
Expand Down Expand Up @@ -1080,6 +1086,12 @@ def __setitem__(

self.ndarray_aligned(axes_self)[tuple(index)] = value

else:
raise TypeError(
f"`item` must be an instance of `bool`, `{na.AbstractArray.__name__}`, or {dict.__name__}, "
f"got `{type(item)}`"
)


@dataclasses.dataclass(eq=False, repr=False)
class AbstractImplicitScalarArray(
Expand Down
7 changes: 6 additions & 1 deletion named_arrays/_scalars/tests/test_scalars.py
Original file line number Diff line number Diff line change
Expand Up @@ -425,6 +425,7 @@ def test_change_axis_index(self, array: na.ScalarArray, index: int):
dict(y=0),
dict(y=slice(0,1)),
dict(y=na.ScalarArray(np.array([0, 1]), axes=('y', ))),
True,
na.ScalarLinearSpace(0, 1, axis='y', num=_num_y) > 0.5,
]
)
Expand All @@ -448,7 +449,10 @@ def test__getitem__(
else:
if array.shape:
result = array[item]
item_expected = (Ellipsis, item.ndarray)
if isinstance(item, bool):
item_expected = item
else:
item_expected = (Ellipsis, item.ndarray)
else:
with pytest.raises(ValueError):
array[item]
Expand Down Expand Up @@ -1319,6 +1323,7 @@ class TestScalarArray(
dict(y=slice(None)),
dict(y=na.ScalarArrayRange(0, _num_y, axis='y')),
dict(x=na.ScalarArrayRange(0, _num_x, axis='x'), y=na.ScalarArrayRange(0, _num_y, axis='y')),
True,
na.ScalarArray.ones(shape=dict(y=_num_y), dtype=bool),
],
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ def test_shape_distribution(self, array: na.AbstractUncertainScalarArray):
)
)
),
True,
na.ScalarLinearSpace(0, 1, axis='y', num=_num_y) > 0.5,
na.UncertainScalarArray(
nominal=na.ScalarLinearSpace(0, 1, axis='y', num=_num_y),
Expand All @@ -132,7 +133,10 @@ def test__getitem__(
):
super().test__getitem__(array=array, item=item)

if isinstance(item, na.AbstractArray):
if isinstance(item, bool):
item_nominal = item_distribution = item

elif isinstance(item, na.AbstractArray):

if not set(item.shape).issubset(array.shape_distribution):
with pytest.raises(ValueError):
Expand Down Expand Up @@ -795,6 +799,7 @@ class TestUncertainScalarArray(
dict(y=slice(None)),
dict(y=na.ScalarArrayRange(0, _num_y, axis='y')),
dict(x=na.ScalarArrayRange(0, _num_x, axis='x'), y=na.ScalarArrayRange(0, _num_y, axis='y')),
True,
na.ScalarArray.ones(shape=dict(y=_num_y), dtype=bool),
],
)
Expand Down
8 changes: 7 additions & 1 deletion named_arrays/_scalars/uncertainties/uncertainties.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,9 @@ def _getitem(
nominal = na.as_named_array(array.nominal)
distribution = na.as_named_array(array.distribution)

if isinstance(item, bool):
item = na.as_named_array(item)

if isinstance(item, na.AbstractArray):
item = item.explicit
if isinstance(item, AbstractUncertainScalarArray):
Expand Down Expand Up @@ -649,6 +652,9 @@ def __setitem__(
):
shape_self = self.shape

if isinstance(item, bool):
item = na.as_named_array(item)

if isinstance(item, na.AbstractArray):

item = item.explicit
Expand Down Expand Up @@ -697,7 +703,7 @@ def __setitem__(

else:
raise TypeError(
f"`item` must be an instance of `{na.AbstractArray.__name__}` or {dict.__name__}, "
f"`item` must be an instance of `bool`, `{na.AbstractArray.__name__}`, or {dict.__name__}, "
f"got `{type(item)}`"
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ def _cartesian2d_items() -> list[na.AbstractArray | dict[str, int, slice, na.Abs
)
)
),
True,
na.ScalarLinearSpace(0, 1, axis='y', num=_num_y) > 0.5,
na.UniformUncertainScalarArray(
nominal=na.ScalarLinearSpace(0, 1, axis='y', num=_num_y),
Expand Down Expand Up @@ -259,6 +260,7 @@ class TestCartesian2dVectorArray(
y=na.ScalarArrayRange(0, _num_y, axis='y'),
)
),
True,
na.ScalarArray.ones(shape=dict(y=_num_y), dtype=bool),
np.ones_like(na.Cartesian2dVectorArray(), dtype=bool, shape=dict(y=_num_y)),
],
Expand Down
3 changes: 3 additions & 0 deletions named_arrays/_vectors/tests/test_vectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,9 @@ def test__getitem__(
for c in components:
components_item[c][ax] = components_item_ax[c]

elif isinstance(item, bool):
components_item = array.type_explicit.from_scalar(item, like=array).components

else:
if not array.shape:
with pytest.raises(ValueError):
Expand Down
8 changes: 7 additions & 1 deletion named_arrays/_vectors/vectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,9 @@ def _getitem(
shape_array = array.shape
components = array.components

if isinstance(item, bool):
item = na.as_named_array(item)

if isinstance(item, na.AbstractArray):
item = item.explicit
shape_item = item.shape
Expand Down Expand Up @@ -610,6 +613,9 @@ def __setitem__(
):
components_self = self.components

if isinstance(item, bool):
item = na.as_named_array(item)

if isinstance(item, na.AbstractArray):
if isinstance(item, na.AbstractVectorArray):
if item.type_abstract == self.type_abstract:
Expand Down Expand Up @@ -662,7 +668,7 @@ def __setitem__(

else:
raise TypeError(
f"`item` must be an instance of `{na.AbstractArray.__name__}` or {dict.__name__}, "
f"`item` must be an instance of `bool`, `{na.AbstractArray.__name__}`, or {dict.__name__}, "
f"got `{type(item)}`"
)

Expand Down
Loading