diff --git a/traittypes/traittypes.py b/traittypes/traittypes.py index 516a91b..9de6e97 100644 --- a/traittypes/traittypes.py +++ b/traittypes/traittypes.py @@ -187,7 +187,8 @@ def __init__(self, default_value=Empty, allow_none=False, dtype=None, **kwargs): import pandas as pd kwargs['klass'] = pd.DataFrame super(DataFrame, self).__init__( - default_value=default_value, allow_none=allow_none, dtype=dtype, **kwargs) + default_value=default_value, allow_none=allow_none, **kwargs) + self.tag(dtype=dtype) class Series(PandasType): @@ -202,7 +203,8 @@ def __init__(self, default_value=Empty, allow_none=False, dtype=None, **kwargs): import pandas as pd kwargs['klass'] = pd.Series super(Series, self).__init__( - default_value=default_value, allow_none=allow_none, dtype=dtype, **kwargs) + default_value=default_value, allow_none=allow_none, **kwargs) + self.tag(dtype=dtype) self.dtype = dtype @@ -266,7 +268,8 @@ def __init__(self, default_value=Empty, allow_none=False, dtype=None, **kwargs): import xarray as xr kwargs['klass'] = xr.Dataset super(Dataset, self).__init__( - default_value=default_value, allow_none=allow_none, dtype=dtype, **kwargs) + default_value=default_value, allow_none=allow_none, **kwargs) + self.tag(dtype=dtype) class DataArray(XarrayType): @@ -281,5 +284,6 @@ def __init__(self, default_value=Empty, allow_none=False, dtype=None, **kwargs): import xarray as xr kwargs['klass'] = xr.DataArray super(DataArray, self).__init__( - default_value=default_value, allow_none=allow_none, dtype=dtype, **kwargs) + default_value=default_value, allow_none=allow_none, **kwargs) + self.tag(dtype=dtype) self.dtype = dtype