@@ -31,7 +31,7 @@ def __init__(
3131 unlist_data : Any ,
3232 partitioning : Partitioning ,
3333 element_type : Any = None ,
34- element_metadata : dict = None ,
34+ element_metadata : Optional [ dict ] = None ,
3535 metadata : Optional [dict ] = None ,
3636 validate : bool = True ,
3737 ):
@@ -65,7 +65,7 @@ class for the type of elements.
6565 if validate :
6666 _validate_data_and_partitions (self ._unlist_data , self ._partitioning )
6767
68- def _define_output (self , in_place : bool = False ) -> "Partitioning " :
68+ def _define_output (self , in_place : bool = False ) -> "CompressedList " :
6969 if in_place is True :
7070 return self
7171 else :
@@ -207,7 +207,7 @@ def get_names(self) -> Optional[ut.NamedList]:
207207 """Get the names of list elements."""
208208 return self ._partitioning .get_names ()
209209
210- def set_names (self , names : Sequence [str ], in_place : bool = False ) -> "CompressedList" :
210+ def set_names (self , names : List [str ], in_place : bool = False ) -> "CompressedList" :
211211 """Set the names of list elements.
212212
213213 names:
@@ -401,7 +401,7 @@ def __getitem__(self, key: Union[int, str, slice]) -> Any:
401401 """
402402 # string keys (names)
403403 if isinstance (key , str ):
404- if key not in self .names :
404+ if key not in list ( self .get_names ()) :
405405 raise KeyError (f"No element named '{ key } '." )
406406 key = list (self .names ).index (key )
407407
@@ -422,14 +422,14 @@ def __getitem__(self, key: Union[int, str, slice]) -> Any:
422422 for i in indices :
423423 start , end = self ._partitioning .get_partition_range (i )
424424 result .append (self .extract_range (start , end ))
425-
426- # Create a new CompressedList from the result
427- return self . __class__ .from_list (
425+
426+ current_class_const = type ( self )
427+ return current_class_const .from_list (
428428 result , names = [self .names [i ] for i in indices ] if self .names [0 ] is not None else None
429429 )
430430
431431 else :
432- raise TypeError ("Index must be int, str, or slice." )
432+ raise TypeError ("'key' must be int, str, or slice." )
433433
434434 ##################################
435435 ######>> abstract methods <<######
@@ -460,8 +460,8 @@ def extract_range(self, start: int, end: int) -> Any:
460460
461461 @classmethod
462462 def from_list (
463- cls , lst : List [Any ], names : Optional [Sequence [str ]] = None , metadata : dict = None
464- ) -> "CompressedList[Any] " :
463+ cls , lst : List [Any ], names : Optional [Sequence [str ]] = None , metadata : Optional [ dict ] = None
464+ ) -> "CompressedList" :
465465 """Create a CompressedList from a regular list.
466466
467467 This method must be implemented by subclasses to handle
@@ -519,7 +519,7 @@ def unlist(self, use_names: bool = True) -> Any:
519519 """
520520 return self ._unlist_data
521521
522- def relist (self , unlist_data : Any ) -> "CompressedList[Any] " :
522+ def relist (self , unlist_data : Any ) -> "CompressedList" :
523523 """Create a new `CompressedList` with the same partitioning but different data.
524524
525525 Args:
@@ -531,15 +531,16 @@ def relist(self, unlist_data: Any) -> "CompressedList[Any]":
531531 """
532532 _validate_data_and_partitions (unlist_data , self ._partitioning )
533533
534- return self .__class__ (
534+ current_class_const = type (self )
535+ return current_class_const (
535536 unlist_data ,
536537 self ._partitioning .copy (),
537538 element_type = self ._element_type ,
538539 element_metadata = self ._element_metadata .copy (),
539540 metadata = self ._metadata .copy (),
540541 )
541542
542- def extract_subset (self , indices : Sequence [int ]) -> "CompressedList[Any] " :
543+ def extract_subset (self , indices : Sequence [int ]) -> "CompressedList" :
543544 """Extract a subset of elements by indices.
544545
545546 Args:
@@ -555,8 +556,8 @@ def extract_subset(self, indices: Sequence[int]) -> "CompressedList[Any]":
555556 raise IndexError (f"Index { i } out of range" )
556557
557558 # Extract element lengths and names
558- new_lengths = [ self .get_element_lengths ()[ i ] for i in indices ]
559- new_names = [ self .names [ i ] for i in indices ] if self .names [ 0 ] is not None else None
559+ new_lengths = ut . subset_sequence ( self .get_element_lengths (), indices )
560+ new_names = ut . subset_sequence ( self .names , indices ) if self .names is not None else None
560561
561562 # Create new partitioning
562563 new_partitioning = Partitioning .from_lengths (new_lengths , new_names )
@@ -573,16 +574,16 @@ def extract_subset(self, indices: Sequence[int]) -> "CompressedList[Any]":
573574 if isinstance (self ._unlist_data , np .ndarray ):
574575 new_data = np .concatenate (new_data )
575576
576- # Create new compressed list
577- return self . __class__ (
577+ current_class_const = type ( self )
578+ return current_class_const (
578579 new_data ,
579580 new_partitioning ,
580581 element_type = self ._element_type ,
581582 element_metadata = {k : v for k , v in self ._element_metadata .items () if k in indices },
582583 metadata = self ._metadata .copy (),
583584 )
584585
585- def lapply (self , func : Callable ) -> "CompressedList[Any] " :
586+ def lapply (self , func : Callable ) -> "CompressedList" :
586587 """Apply a function to each element.
587588
588589 Args:
@@ -593,4 +594,6 @@ def lapply(self, func: Callable) -> "CompressedList[Any]":
593594 A new CompressedList with the results.
594595 """
595596 result = [func (elem ) for elem in self ]
596- return self .__class__ .from_list (result , self .names , self ._metadata )
597+
598+ current_class_const = type (self )
599+ return current_class_const .from_list (result , self .names , self ._metadata )
0 commit comments