Skip to content

Commit e89bb1f

Browse files
authored
Passing dtype to Pandas.Series (#53)
* Passing dtype to Pandas.Series * Fix test
1 parent 6f1a7b8 commit e89bb1f

File tree

3 files changed

+11
-9
lines changed

3 files changed

+11
-9
lines changed

.github/workflows/main.yml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,10 @@ jobs:
2424
with:
2525
environment-name: test-env
2626
create-args: >-
27-
python
27+
python=3.11
2828
pip
29-
pandas
30-
numpy
29+
pandas=1
30+
numpy=1
3131
xarray
3232
pytest
3333

traittypes/tests/test_traittypes.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,7 @@ class TestSeries(TestCase):
149149
def test_series_equal(self):
150150
notifications = []
151151
class Foo(HasTraits):
152-
bar = Series([1, 2])
152+
bar = Series([1, 2], dtype=np.int64)
153153
@observe('bar')
154154
def _(self, change):
155155
notifications.append(change)

traittypes/traittypes.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -155,18 +155,20 @@ def set(self, obj, value):
155155
not old_value.equals(new_value)):
156156
obj._notify_trait(self.name, old_value, new_value)
157157

158-
def __init__(self, default_value=Empty, allow_none=False, klass=None, **kwargs):
158+
def __init__(self, default_value=Empty, allow_none=False, klass=None, klass_kwargs=None, **kwargs):
159159
if klass is None:
160160
klass = self.klass
161+
if klass_kwargs is None:
162+
klass_kwargs = {}
161163
if (klass is not None) and inspect.isclass(klass):
162164
self.klass = klass
163165
else:
164166
raise TraitError('The klass attribute must be a class'
165167
' not: %r' % klass)
166168
if default_value is Empty:
167-
default_value = klass()
169+
default_value = klass(**klass_kwargs)
168170
elif default_value is not None and default_value is not Undefined:
169-
default_value = klass(default_value)
171+
default_value = klass(default_value, **klass_kwargs)
170172
super(PandasType, self).__init__(default_value=default_value, allow_none=allow_none, **kwargs)
171173

172174
def make_dynamic_default(self):
@@ -198,12 +200,12 @@ class Series(PandasType):
198200
info_text = 'a pandas series'
199201
dtype = None
200202

201-
def __init__(self, default_value=Empty, allow_none=False, dtype=None, **kwargs):
203+
def __init__(self, default_value=Empty, allow_none=False, dtype=np.float64, **kwargs):
202204
if 'klass' not in kwargs and self.klass is None:
203205
import pandas as pd
204206
kwargs['klass'] = pd.Series
205207
super(Series, self).__init__(
206-
default_value=default_value, allow_none=allow_none, **kwargs)
208+
default_value=default_value, allow_none=allow_none, klass_kwargs={"dtype": dtype}, **kwargs)
207209
self.tag(dtype=dtype)
208210
self.dtype = dtype
209211

0 commit comments

Comments
 (0)