diff --git a/geomdl/NURBS.py b/geomdl/NURBS.py index 2eafaa9e..8b3ce8c0 100644 --- a/geomdl/NURBS.py +++ b/geomdl/NURBS.py @@ -91,9 +91,11 @@ def __deepcopy__(self, memo): result.init_cache() return result - def init_cache(self): - self._cache['ctrlpts'] = self._init_array() - self._cache['weights'] = self._init_array() + def init_cache(self, ctrlpts=[], weights=[]): + self._cache['ctrlpts'] = self._array_type(iter(ctrlpts)) + self._cache['weights'] = self._array_type(iter(weights)) + self._cache['ctrlpts'].register_callback(lambda: setattr(self, '_control_points_valid', False)) + self._cache['weights'].register_callback(lambda: setattr(self, '_control_points_valid', False)) @property def ctrlptsw(self): @@ -107,6 +109,9 @@ def ctrlptsw(self): :getter: Gets the weighted control points :setter: Sets the weighted control points """ + if not self._control_points_valid: + ctrlptsw = compatibility.combine_ctrlpts_weights(self.ctrlpts, self.weights) + self.set_ctrlpts(ctrlptsw) return self._control_points @ctrlptsw.setter @@ -127,22 +132,17 @@ def ctrlpts(self): # Populate the cache, if necessary if not self._cache['ctrlpts']: c, w = compatibility.separate_ctrlpts_weights(self._control_points) - self._cache['ctrlpts'] = [crd for crd in c] - self._cache['weights'] = w + self.init_cache(c, w) return self._cache['ctrlpts'] @ctrlpts.setter def ctrlpts(self, value): # Check if we can retrieve the existing weights. If not, generate a weights vector of 1.0s. if not self.weights: - weights = [1.0 for _ in range(len(value))] - else: - weights = self.weights + self.weights[:] = [1.0 for _ in range(len(value))] # Generate weighted control points using the new control points - ctrlptsw = compatibility.combine_ctrlpts_weights(value, weights) - - # Set new weighted control points + ctrlptsw = compatibility.combine_ctrlpts_weights(value, self.weights) self.set_ctrlpts(ctrlptsw) @property @@ -159,8 +159,7 @@ def weights(self): # Populate the cache, if necessary if not self._cache['weights']: c, w = compatibility.separate_ctrlpts_weights(self._control_points) - self._cache['ctrlpts'] = [crd for crd in c] - self._cache['weights'] = w + self.init_cache(c, w) return self._cache['weights'] @weights.setter @@ -174,6 +173,12 @@ def weights(self, value): # Set new weighted control points self.set_ctrlpts(ctrlptsw) + def _check_variables(self): + super(Curve, self)._check_variables() + if not self._control_points_valid: + ctrlptsw = compatibility.combine_ctrlpts_weights(self.ctrlpts, self.weights) + self.set_ctrlpts(ctrlptsw) + def reset(self, **kwargs): """ Resets control points and/or evaluated points. @@ -189,10 +194,9 @@ def reset(self, **kwargs): # Call parent function super(Curve, self).reset(ctrlpts=reset_ctrlpts, evalpts=reset_evalpts) + # Delete the caches if reset_ctrlpts: - # Delete the caches - self._cache['ctrlpts'] = self._init_array() - self._cache['weights'][:] = self._init_array() + self.init_cache() @export @@ -330,8 +334,8 @@ def ctrlpts(self): """ if not self._cache['ctrlpts']: c, w = compatibility.separate_ctrlpts_weights(self._control_points) - self._cache['ctrlpts'] = [crd for crd in c] - self._cache['weights'] = w + self._cache['ctrlpts'] = self._array_type(iter(c)) + self._cache['weights'] = self._array_type(iter(w)) return self._cache['ctrlpts'] @ctrlpts.setter @@ -361,8 +365,8 @@ def weights(self): """ if not self._cache['weights']: c, w = compatibility.separate_ctrlpts_weights(self._control_points) - self._cache['ctrlpts'] = [crd for crd in c] - self._cache['weights'] = w + self._cache['ctrlpts'] = self._array_type(iter(c)) + self._cache['weights'] = self._array_type(iter(w)) return self._cache['weights'] @weights.setter @@ -518,8 +522,8 @@ def ctrlpts(self): """ if not self._cache['ctrlpts']: c, w = compatibility.separate_ctrlpts_weights(self._control_points) - self._cache['ctrlpts'] = [crd for crd in c] - self._cache['weights'] = w + self._cache['ctrlpts'] = self._array_type(iter(c)) + self._cache['weights'] = self._array_type(iter(w)) return self._cache['ctrlpts'] @ctrlpts.setter @@ -549,8 +553,8 @@ def weights(self): """ if not self._cache['weights']: c, w = compatibility.separate_ctrlpts_weights(self._control_points) - self._cache['ctrlpts'] = [crd for crd in c] - self._cache['weights'] = w + self._cache['ctrlpts'] = self._array_type(iter(c)) + self._cache['weights'] = self._array_type(iter(w)) return self._cache['weights'] @weights.setter diff --git a/geomdl/_collections.py b/geomdl/_collections.py new file mode 100644 index 00000000..cf50ed31 --- /dev/null +++ b/geomdl/_collections.py @@ -0,0 +1,55 @@ +# callback handlers for list modification +# https://stackoverflow.com/a/13259435/1162349 + +import sys + +_pyversion = sys.version_info[0] + +def callback_method(func): + def notify(self,*args,**kwargs): + for _,callback in self._callbacks: + callback() + return func(self,*args,**kwargs) + return notify + +class NotifyList(list): + extend = callback_method(list.extend) + append = callback_method(list.append) + remove = callback_method(list.remove) + pop = callback_method(list.pop) + __delitem__ = callback_method(list.__delitem__) + __setitem__ = callback_method(list.__setitem__) + __iadd__ = callback_method(list.__iadd__) + __imul__ = callback_method(list.__imul__) + + #Take care to return a new NotifyList if we slice it. + if _pyversion < 3: + __setslice__ = callback_method(list.__setslice__) + __delslice__ = callback_method(list.__delslice__) + def __getslice__(self,*args): + return self.__class__(list.__getslice__(self,*args)) + + def __getitem__(self,item): + if isinstance(item,slice): + return self.__class__(list.__getitem__(self,item)) + else: + return list.__getitem__(self,item) + + def __init__(self,*args): + list.__init__(self,*args) + self._callbacks = [] + self._callback_cntr = 0 + + def register_callback(self,cb): + self._callbacks.append((self._callback_cntr,cb)) + self._callback_cntr += 1 + return self._callback_cntr - 1 + + def unregister_callback(self,cbid): + for idx,(i,cb) in enumerate(self._callbacks): + if i == cbid: + self._callbacks.pop(idx) + return cb + else: + return None + diff --git a/geomdl/abstract.py b/geomdl/abstract.py index 29e4b62e..ebb0c9cc 100644 --- a/geomdl/abstract.py +++ b/geomdl/abstract.py @@ -11,6 +11,7 @@ from .six import add_metaclass from . import vis, helpers, knotvector, voxelize, utilities, tessellate from .base import GeomdlBase, GeomdlEvaluator, GeomdlError, GeomdlWarning, GeomdlTypeSequence +from ._collections import NotifyList @add_metaclass(abc.ABCMeta) @@ -34,9 +35,9 @@ class Geometry(GeomdlBase): # __slots__ = ('_iter_index', '_array_type', '_eval_points') def __init__(self, *args, **kwargs): - self._geometry_type = "default" if not hasattr(self, '_geometry_type') else self._geometry_type # geometry type super(Geometry, self).__init__(*args, **kwargs) - self._array_type = list if not hasattr(self, '_array_type') else self._array_type # array storage type + self._geometry_type = getattr(self, '_geometry_type', 'default') # geometry type + self._array_type = getattr(self, '_array_type', NotifyList) # array storage type self._eval_points = self._init_array() # evaluated points def __iter__(self): @@ -134,6 +135,7 @@ def __init__(self, **kwargs): self._knot_vector = [self._init_array() for _ in range(self._pdim)] # knot vector self._control_points = self._init_array() # control points self._control_points_size = [0 for _ in range(self._pdim)] # control points length + self._control_points_valid = False self._delta = [self._dinit for _ in range(self._pdim)] # evaluation delta self._bounding_box = self._init_array() # bounding box self._evaluator = None # evaluator instance @@ -465,7 +467,7 @@ def validate_and_clean(pts_in, check_for, dimension, pts_out, **kws): raise ValueError("Number of arguments after ctrlpts must be " + str(self._pdim)) # Keyword arguments - array_init = kwargs.get('array_init', [[] for _ in range(len(ctrlpts))]) + array_init = kwargs.get('array_init', self._array_type([] for _ in range(len(ctrlpts)))) array_check_for = kwargs.get('array_check_for', (list, tuple)) callback_func = kwargs.get('callback', validate_and_clean) self._dimension = kwargs.get('dimension', len(ctrlpts[0])) @@ -479,6 +481,7 @@ def validate_and_clean(pts_in, check_for, dimension, pts_out, **kws): # Set control points and sizes self._control_points = callback_func(ctrlpts, array_check_for, self._dimension, array_init, **kwargs) self._control_points_size = [int(arg) for arg in args] + self._control_points_valid = True @abc.abstractmethod def render(self, **kwargs): @@ -890,6 +893,7 @@ def reset(self, **kwargs): if reset_ctrlpts: self._control_points = self._init_array() self._bounding_box = self._init_array() + self._control_points_valid = False if reset_evalpts: self._eval_points = self._init_array() diff --git a/tests/test_curve.py b/tests/test_curve.py index b1228d06..e980571c 100644 --- a/tests/test_curve.py +++ b/tests/test_curve.py @@ -5,9 +5,11 @@ Requires "pytest" to run. """ +import math from pytest import fixture, mark from geomdl import BSpline +from geomdl import NURBS from geomdl import evaluators from geomdl import helpers from geomdl import convert @@ -234,6 +236,13 @@ def nurbs_curve(spline_curve): curve.weights = [0.5, 1.0, 0.75, 1.0, 0.25, 1.0] return curve +@fixture +def unit_circle_tri_ctrlpts(): + r = 1. + a, h = 3. * r / math.sqrt(3.), 1.5 * r + ctrlpts = [(0., -r), (-a,-r), (-a/2,-r+h), (0., 2*h-r), (a/2, -r+h), (a, -r), (0., -r)] + return ctrlpts + def test_nurbs_curve2d_weights(nurbs_curve): assert nurbs_curve.weights == [0.5, 1.0, 0.75, 1.0, 0.25, 1.0] @@ -252,6 +261,27 @@ def test_nurbs_curve2d_eval(nurbs_curve, param, res): assert abs(evalpt[1] - res[1]) < GEOMDL_DELTA +@mark.parametrize("param, res", [ + (0.0, (0.0, -1.0)), + (0.2, (-0.9571859726038534, -0.2894736842105261)), + (0.5, (1.1102230246251568e-16, 1.0)), + (0.95, (0.27544074447012257, -0.9613180515759312)) +]) +def test_nurbs_curve2d_slice_eval(unit_circle_tri_ctrlpts, param, res): + crv = NURBS.Curve() + crv.degree = 2 + crv.ctrlpts = unit_circle_tri_ctrlpts + crv.knotvector = [0.,0.,0., 1./3, 1./3, 2./3, 2./3, 1.,1.,1.] + crv.weights[1::2] = [0.5, 0.5, 0.5] + + evalpt = crv.evaluate_single(param) + + assert abs(evalpt[0] - res[0]) < GEOMDL_DELTA + assert abs(evalpt[1] - res[1]) < GEOMDL_DELTA + + +# TODO: derivative of a circle is a circle +@mark.xfail @mark.parametrize("param, order, res", [ (0.0, 1, ((5.0, 5.0), (90.9090, 90.9090))), (0.2, 2, ((13.8181, 11.5103), (40.0602, 17.3878), (104.4062, -29.3672))),