
Commit 29e0da6

Add types
1 parent 114845f commit 29e0da6

1 file changed: +26 -20 lines changed

xarray/structure/concat.py

Lines changed: 26 additions & 20 deletions
@@ -46,7 +46,7 @@ def concat(
     objs: Iterable[T_Dataset],
     dim: Hashable | T_Variable | T_DataArray | pd.Index | Any,
     data_vars: T_DataVars | CombineKwargDefault = _DATA_VARS_DEFAULT,
-    coords: ConcatOptions | list[Hashable] | CombineKwargDefault = _COORDS_DEFAULT,
+    coords: ConcatOptions | Iterable[Hashable] | CombineKwargDefault = _COORDS_DEFAULT,
     compat: CompatOptions | CombineKwargDefault = _COMPAT_CONCAT_DEFAULT,
     positions: Iterable[Iterable[int]] | None = None,
     fill_value: object = dtypes.NA,
@@ -61,7 +61,7 @@ def concat(
     objs: Iterable[T_DataArray],
     dim: Hashable | T_Variable | T_DataArray | pd.Index | Any,
     data_vars: T_DataVars | CombineKwargDefault = _DATA_VARS_DEFAULT,
-    coords: ConcatOptions | list[Hashable] | CombineKwargDefault = _COORDS_DEFAULT,
+    coords: ConcatOptions | Iterable[Hashable] | CombineKwargDefault = _COORDS_DEFAULT,
     compat: CompatOptions | CombineKwargDefault = _COMPAT_CONCAT_DEFAULT,
     positions: Iterable[Iterable[int]] | None = None,
     fill_value: object = dtypes.NA,
@@ -75,7 +75,7 @@ def concat(
     objs,
     dim,
     data_vars: T_DataVars | CombineKwargDefault = _DATA_VARS_DEFAULT,
-    coords: ConcatOptions | list[Hashable] | CombineKwargDefault = _COORDS_DEFAULT,
+    coords: ConcatOptions | Iterable[Hashable] | CombineKwargDefault = _COORDS_DEFAULT,
     compat: CompatOptions | CombineKwargDefault = _COMPAT_CONCAT_DEFAULT,
     positions=None,
     fill_value=dtypes.NA,
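
Note: the three concat overloads above widen coords from list[Hashable] to Iterable[Hashable], so any iterable of coordinate names satisfies the annotation, not only a list. A minimal usage sketch, not part of the commit; the datasets and coordinate names are invented for illustration, and depending on the xarray version relying on other defaults may emit a FutureWarning:

import xarray as xr

ds1 = xr.Dataset({"t": ("time", [1.0, 2.0])}, coords={"time": [0, 1], "station": "A"})
ds2 = xr.Dataset({"t": ("time", [3.0, 4.0])}, coords={"time": [2, 3], "station": "B"})

# A tuple of coordinate names is an Iterable[Hashable] but not a list[Hashable];
# under the widened annotation it should now satisfy type checkers as a value for coords.
combined = xr.concat([ds1, ds2], dim="time", coords=("station",))
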
@@ -339,25 +339,25 @@ def _calc_concat_dim_index(
339339

340340

341341
def _calc_concat_over(
342-
datasets,
343-
dim,
344-
all_dims,
342+
datasets: list[T_Dataset],
343+
dim: Hashable,
344+
all_dims: set[Hashable],
345345
data_vars: T_DataVars | CombineKwargDefault,
346-
coords,
347-
compat,
348-
):
346+
coords: ConcatOptions | Iterable[Hashable] | CombineKwargDefault,
347+
compat: CompatOptions | CombineKwargDefault,
348+
) -> tuple[set[Hashable], dict[Hashable, bool], list[int], set[Hashable]]:
349349
"""
350350
Determine which dataset variables need to be concatenated in the result,
351351
"""
352352
# variables to be concatenated
353353
concat_over = set()
354354
# variables checked for equality
355-
equals: dict[Hashable, bool | None] = {}
355+
equals: dict[Hashable, bool] = {}
356356
# skip merging these variables.
357357
# if concatenating over a dimension 'x' that is associated with an index over 2 variables,
358358
# 'x' and 'y', then we assert join="equals" on `y` and don't need to merge it.
359359
# that assertion happens in the align step prior to this function being called
360-
skip_merge = set()
360+
skip_merge: set[Hashable] = set()
361361

362362
if dim in all_dims:
363363
concat_over_existing_dim = True
@@ -380,7 +380,10 @@ def _calc_concat_over(
380380
skip_merge.update(idx_vars.keys())
381381
concat_dim_lengths.append(ds.sizes.get(dim, 1))
382382

383-
def process_subset_opt(opt, subset: Literal["coords", "data_vars"]) -> None:
383+
def process_subset_opt(
384+
opt: ConcatOptions | Iterable[Hashable] | CombineKwargDefault,
385+
subset: Literal["coords", "data_vars"],
386+
) -> None:
384387
original = set(concat_over)
385388
compat_str = (
386389
compat._value if isinstance(compat, CombineKwargDefault) else compat
@@ -411,7 +414,7 @@ def process_subset_opt(opt, subset: Literal["coords", "data_vars"]) -> None:
411414
# all nonindexes that are not the same in each dataset
412415
for k in getattr(datasets[0], subset):
413416
if k not in concat_over:
414-
equals[k] = None
417+
equal = None
415418

416419
variables = [
417420
ds.variables[k] for ds in datasets if k in ds.variables
@@ -430,19 +433,19 @@ def process_subset_opt(opt, subset: Literal["coords", "data_vars"]) -> None:
430433

431434
# first check without comparing values i.e. no computes
432435
for var in variables[1:]:
433-
equals[k] = getattr(variables[0], compat_str)(
436+
equal = getattr(variables[0], compat_str)(
434437
var, equiv=lazy_array_equiv
435438
)
436-
if equals[k] is not True:
439+
if equal is not True:
437440
# exit early if we know these are not equal or that
438441
# equality cannot be determined i.e. one or all of
439442
# the variables wraps a numpy array
440443
break
441444

442-
if equals[k] is False:
445+
if equal is False:
443446
concat_over.add(k)
444447

445-
elif equals[k] is None:
448+
elif equal is None:
446449
# Compare the variable of all datasets vs. the one
447450
# of the first dataset. Perform the minimum amount of
448451
# loads in order to avoid multiple loads from disk
@@ -464,7 +467,10 @@ def process_subset_opt(opt, subset: Literal["coords", "data_vars"]) -> None:
                                         ds.variables[k].data = v.data
                                     break
                             else:
-                                equals[k] = True
+                                equal = True
+                        if TYPE_CHECKING:
+                            assert equal is not None
+                        equals[k] = equal

                 elif opt == "all":
                     concat_over.update(
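
The hunks above replace direct writes of possibly-None results into equals with a local equal that is narrowed before the dict write, which is what allows tightening the annotation from dict[Hashable, bool | None] to dict[Hashable, bool]. The if TYPE_CHECKING assert reassures the type checker at zero runtime cost, since TYPE_CHECKING is False outside of type checking. A self-contained sketch of that pattern, with a toy function and data rather than xarray code:

from __future__ import annotations

from collections.abc import Hashable
from typing import TYPE_CHECKING


def check_equal(groups: dict[Hashable, list[object]]) -> dict[Hashable, bool]:
    # The annotation forbids None values, mirroring equals: dict[Hashable, bool].
    equals: dict[Hashable, bool] = {}
    for key, values in groups.items():
        equal: bool | None = None  # None means "not decided yet"
        for value in values[1:]:
            equal = value == values[0]
            if equal is not True:
                break  # early exit on the first mismatch
        else:
            equal = True  # empty or fully equal sequence
        if TYPE_CHECKING:
            assert equal is not None  # narrows bool | None to bool for the type checker
        equals[key] = equal
    return equals


print(check_equal({"a": [1, 1, 1], "b": [1, 2]}))  # {'a': True, 'b': False}
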
@@ -564,7 +570,7 @@ def _dataset_concat(
564570
datasets: Iterable[T_Dataset],
565571
dim: str | T_Variable | T_DataArray | pd.Index,
566572
data_vars: T_DataVars | CombineKwargDefault,
567-
coords: str | list[Hashable] | CombineKwargDefault,
573+
coords: ConcatOptions | Iterable[Hashable] | CombineKwargDefault,
568574
compat: CompatOptions | CombineKwargDefault,
569575
positions: Iterable[Iterable[int]] | None,
570576
fill_value: Any,
@@ -807,7 +813,7 @@ def _dataarray_concat(
807813
arrays: Iterable[T_DataArray],
808814
dim: str | T_Variable | T_DataArray | pd.Index,
809815
data_vars: T_DataVars | CombineKwargDefault,
810-
coords: str | list[Hashable] | CombineKwargDefault,
816+
coords: ConcatOptions | Iterable[Hashable] | CombineKwargDefault,
811817
compat: CompatOptions | CombineKwargDefault,
812818
positions: Iterable[Iterable[int]] | None,
813819
fill_value: object,
