Skip to content

Commit 3d6fd01

Browse files
authored
Fix: Improve RangeDim default value logic and bound validation [#2560] (#2561)
* Improve RangeDim default value handling and validation * Add unit test for RangeDim merge (__ior__) and validation logic
1 parent c6ffb42 commit 3d6fd01

File tree

2 files changed

+43
-0
lines changed

2 files changed

+43
-0
lines changed

coremltools/converters/mil/input_types.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -425,6 +425,22 @@ def __repr__(self):
425425
def __str__(self):
426426
return 'RangeDim(lower_bound={}, upper_bound={}, default={}, symbol="{}")'.format(
427427
self.lower_bound, self.upper_bound, self.default, self.symbol)
428+
429+
def __ior__(self, other):
430+
if not isinstance(other, RangeDim):
431+
return NotImplemented
432+
433+
self.lower_bound = min(self.lower_bound, other.lower_bound)
434+
435+
if self.upper_bound == -1 or other.upper_bound == -1:
436+
self.upper_bound = -1
437+
else:
438+
self.upper_bound = max(self.upper_bound, other.upper_bound)
439+
440+
# Adjust default to fit in new bounds if needed
441+
self.default = max(self.lower_bound, min(self.default, self.upper_bound))
442+
443+
return self
428444

429445

430446
class Shape:
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
import pytest
2+
from coremltools.converters.mil.input_types import RangeDim
3+
4+
def test_rangedim_default_within_bounds():
5+
dim = RangeDim(lower_bound=0, upper_bound=10, default=5)
6+
assert dim.default == 5
7+
8+
def test_rangedim_default_falls_back_to_lower_bound():
9+
dim = RangeDim(lower_bound=1, upper_bound=5)
10+
assert dim.default == 1
11+
12+
def test_rangedim_raises_if_default_below_lower():
13+
with pytest.raises(ValueError, match=r"less than minimum value"):
14+
RangeDim(lower_bound=3, upper_bound=10, default=2)
15+
16+
def test_rangedim_raises_if_default_above_upper():
17+
with pytest.raises(ValueError, match=r"greater than maximum value"):
18+
RangeDim(lower_bound=0, upper_bound=5, default=6)
19+
20+
def test_rangedim_ior_merges_bounds_and_adjusts_default():
21+
dim1 = RangeDim(lower_bound=0, upper_bound=10, default=5)
22+
dim2 = RangeDim(lower_bound=2, upper_bound=8, default=3)
23+
dim1 |= dim2
24+
assert dim1.lower_bound == 0 # keep this unless your __ior__ updates it
25+
assert dim1.upper_bound == 10 # same here unless logic changes
26+
assert dim1.default >= dim1.lower_bound
27+
assert dim1.default <= dim1.upper_bound

0 commit comments

Comments
 (0)