3030_logger  =  logging .getLogger (__name__ )
3131
3232
def is_ndarray(obj: typing.Any) -> bool:
    """Return True if *obj* is a :class:`numpy.ndarray`.

    Centralizes the direct ``isinstance`` check so the banned-API lint
    rule (TID251) is suppressed in exactly one place instead of at every
    call site.

    Args:
        obj: Any object to test.

    Returns:
        True when ``obj`` is an ndarray (subclasses included), else False.
    """
    # Note: the original wrote ``(np.ndarray)`` — redundant parens that
    # look like a one-element tuple but are not; plain form is clearer.
    return isinstance(obj, np.ndarray)  # noqa: TID251
36+ 
37+ 
3338@dataclasses .dataclass  
3439class  IOReaderData :
3540    """ 
@@ -58,10 +63,10 @@ def create(cls, other: typing.Any) -> "IOReaderData":
5863
5964        other should be such an instance. 
6065        """ 
61-         coords  =  other .coords 
62-         geoinfos  =  other .geoinfos 
63-         data  =  other .data 
64-         datetimes  =  other .datetimes 
66+         coords  =  np . asarray ( other .coords ) 
67+         geoinfos  =  np . asarray ( other .geoinfos ) 
68+         data  =  np . asarray ( other .data ) 
69+         datetimes  =  np . asarray ( other .datetimes ) 
6570
6671        n_datapoints  =  len (data )
6772
@@ -130,22 +135,22 @@ class OutputDataset:
130135    item_key : ItemKey 
131136
132137    # (datapoints, channels, ens) 
133-     data : zarr .Array   # wrong type => array like 
138+     data : zarr .Array  |   NDArray    # wrong type => array like 
134139
135140    # (datapoints,) 
136-     times : zarr .Array 
141+     times : zarr .Array   |   NDArray 
137142
138143    # (datapoints, 2) 
139-     coords : zarr .Array 
144+     coords : zarr .Array   |   NDArray 
140145
141146    # (datapoints, geoinfos) geoinfos are stream dependent => 0 for most gridded data 
142-     geoinfo : zarr .Array 
147+     geoinfo : zarr .Array   |   NDArray 
143148
144149    channels : list [str ]
145150    geoinfo_channels : list [str ]
146151
147152    @functools .cached_property  
148-     def  arrays (self ) ->  dict [str , zarr .Array ]:
153+     def  arrays (self ) ->  dict [str , zarr .Array   |   NDArray ]:
149154        """Iterate over the arrays and their names.""" 
150155        return  {
151156            "data" : self .data ,
@@ -236,7 +241,8 @@ def write_zarr(self, item: OutputItem):
236241        """Write one output item to the zarr store.""" 
237242        group  =  self ._get_group (item .key , create = True )
238243        for  dataset  in  item .datasets :
239-             self ._write_dataset (group , dataset )
244+             if  dataset  is  not None :
245+                 self ._write_dataset (group , dataset )
240246
241247    def  get_data (self , sample : int , stream : str , forecast_step : int ) ->  OutputItem :
242248        """Get datasets for the output item matching the arguments.""" 
@@ -285,6 +291,7 @@ def _write_arrays(self, dataset_group: zarr.Group, dataset: OutputDataset):
285291            self ._create_dataset (dataset_group , array_name , array )
286292
287293    def  _create_dataset (self , group : zarr .Group , name : str , array : NDArray ):
294+         assert is_ndarray(array), f"Expected ndarray but got: {type(array)}"
288295        if  array .size  ==  0 :  # sometimes for geoinfo 
289296            chunks  =  None 
290297        else :
@@ -394,20 +401,10 @@ def extract(self, key: ItemKey) -> OutputItem:
394401            target_data  =  np .zeros ((0 , len (self .target_channels [stream_idx ])), dtype = np .float32 )
395402            preds_data  =  np .zeros ((0 , len (self .target_channels [stream_idx ])), dtype = np .float32 )
396403        else :
397-             target_data  =  (
398-                 self .targets [offset_key .forecast_step ][stream_idx ][0 ][datapoints ]
399-                 .cpu ()
400-                 .detach ()
401-                 .numpy ()
402-             )
403-             preds_data  =  (
404-                 self .predictions [offset_key .forecast_step ][stream_idx ][0 ]
405-                 .transpose (1 , 0 )
406-                 .transpose (1 , 2 )[datapoints ]
407-                 .cpu ()
408-                 .detach ()
409-                 .numpy ()
410-             )
404+             target_data  =  self .targets [offset_key .forecast_step ][stream_idx ][0 ][datapoints ]
405+             preds_data  =  self .predictions [offset_key .forecast_step ][stream_idx ][0 ].transpose (
406+                 1 , 2 , 0 
407+             )[datapoints ]
411408
412409        data_coords  =  self ._extract_coordinates (stream_idx , offset_key , datapoints )
413410
@@ -423,6 +420,8 @@ def extract(self, key: ItemKey) -> OutputItem:
423420        else :
424421            source_dataset  =  None 
425422
423+         assert is_ndarray(target_data), f"Expected ndarray but got: {type(target_data)}"
424+         assert is_ndarray(preds_data), f"Expected ndarray but got: {type(preds_data)}"
426425        return  OutputItem (
427426            key = key ,
428427            source = source_dataset ,
@@ -501,10 +500,10 @@ def _extract_sources(self, sample, stream_idx, key):
501500        source_dataset  =  OutputDataset (
502501            "source" ,
503502            key ,
504-             source .data ,
505-             source .datetimes ,
506-             source .coords ,
507-             source .geoinfos ,
503+             np . asarray ( source .data ) ,
504+             np . asarray ( source .datetimes ) ,
505+             np . asarray ( source .coords ) ,
506+             np . asarray ( source .geoinfos ) ,
508507            channels ,
509508            geoinfo_channels ,
510509        )
0 commit comments