|
11 | 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | 12 | # See the License for the specific language governing permissions and |
13 | 13 | # limitations under the License. |
14 | | -from collections.abc import ItemsView, Iterable, KeysView, Mapping, ValuesView |
15 | | -from typing import TYPE_CHECKING, Any, Callable, Optional, SupportsIndex, TypeVar, Union |
| 14 | +from collections.abc import Iterable, KeysView, Mapping |
| 15 | +from typing import TYPE_CHECKING, Any, Optional, SupportsIndex, TypeVar, Union |
16 | 16 |
|
17 | 17 | from lightning_utilities.core.apply_func import apply_to_collection |
18 | 18 | from torch import Tensor |
19 | | -from typing_extensions import Self, overload |
| 19 | +from typing_extensions import overload |
20 | 20 |
|
21 | 21 | import lightning.pytorch as pl |
22 | 22 | from lightning.fabric.loggers.tensorboard import _TENSORBOARD_AVAILABLE, _TENSORBOARDX_AVAILABLE |
|
29 | 29 |
|
30 | 30 | warning_cache = WarningCache() |
31 | 31 |
|
32 | | - |
33 | 32 | if TYPE_CHECKING: |
34 | | - from _typeshed import SupportsRichComparison |
| 33 | + pass |
35 | 34 |
|
36 | 35 |
|
37 | 36 | class _LoggerConnector: |
@@ -309,207 +308,73 @@ class _ListMap(list[_T]): |
309 | 308 |
|
310 | 309 | """ |
311 | 310 |
|
312 | | - _dict: dict[str, int] |
| 311 | + _dict_map: dict[str, int] |
313 | 312 |
|
314 | | - def __init__(self, __iterable: Optional[Union[Mapping[str, _T], Iterable[_T]]] = None): |
| 313 | + def __init__(self, __iterable: Optional[Union[dict[str, _T], Iterable[_T]]] = None): |
315 | 314 | if isinstance(__iterable, Mapping): |
316 | 315 | # super inits list with values |
317 | 316 | if any(not isinstance(x, str) for x in __iterable): |
318 | 317 | raise TypeError("When providing a Mapping, all keys must be of type str.") |
319 | 318 | super().__init__(__iterable.values()) |
320 | | - _dict = dict(zip(__iterable.keys(), range(len(__iterable)))) |
| 319 | + self._dict_map = {key: idx for idx, key in enumerate(__iterable)} |
321 | 320 | else: |
322 | | - default_dict: dict[str, int] = {} |
323 | | - if isinstance(__iterable, _ListMap): |
324 | | - default_dict = __iterable._dict.copy() |
325 | | - super().__init__(() if __iterable is None else __iterable) |
326 | | - _dict = default_dict |
327 | | - self._dict = _dict |
| 321 | + super().__init__(__iterable or ()) |
| 322 | + self._dict_map = {} |
328 | 323 |
|
329 | 324 | def __eq__(self, other: Any) -> bool: |
330 | 325 | list_eq = super().__eq__(other) |
331 | 326 | if isinstance(other, _ListMap): |
332 | | - return list_eq and self._dict == other._dict |
| 327 | + list_eq &= other._dict_map == self._dict_map |
333 | 328 | return list_eq |
334 | 329 |
|
335 | | - def copy(self) -> "_ListMap": |
336 | | - new_listmap = _ListMap(self) |
337 | | - new_listmap._dict = self._dict.copy() |
338 | | - return new_listmap |
339 | | - |
340 | | - def extend(self, __iterable: Iterable[_T]) -> None: |
341 | | - if isinstance(__iterable, _ListMap): |
342 | | - offset = len(self) |
343 | | - for key, idx in __iterable._dict.items(): |
344 | | - self._dict[key] = idx + offset |
345 | | - super().extend(__iterable) |
| 330 | + def pop(self, index: SupportsIndex = -1, /): |
| 331 | + if self._dict_map: |
| 332 | + index_int = index.__index__() |
| 333 | + if index_int < 0: |
| 334 | + index_int += len(self) |
| 335 | + |
| 336 | + for key, idx in list(self._dict_map.items()): |
| 337 | + if idx == index_int: |
| 338 | + self._dict_map.pop(key) |
| 339 | + elif idx > index_int: |
| 340 | + self._dict_map[key] -= 1 |
| 341 | + return super().pop(index) |
| 342 | + |
| 343 | + def insert(self, index: SupportsIndex, __object: _T, /) -> None: |
| 344 | + if self._dict_map: |
| 345 | + index_int = index.__index__() |
| 346 | + if index_int < 0: |
| 347 | + index_int += len(self) |
| 348 | + |
| 349 | + for key, idx in list(self._dict_map.items()): |
| 350 | + if idx >= index_int: |
| 351 | + self._dict_map[key] += 1 |
346 | 352 |
|
347 | | - @overload |
348 | | - def pop(self, key: SupportsIndex = -1, /) -> _T: ... |
349 | | - |
350 | | - @overload |
351 | | - def pop(self, key: Union[str, SupportsIndex], default: _T, /) -> _T: ... |
352 | | - |
353 | | - @overload |
354 | | - def pop(self, key: str, default: _PT, /) -> Union[_T, _PT]: ... |
355 | | - |
356 | | - def pop(self, key: Union[SupportsIndex, str] = -1, default: Any = None) -> _T: |
357 | | - if isinstance(key, int): |
358 | | - ret = super().pop(key) |
359 | | - for str_key, idx in list(self._dict.items()): |
360 | | - if idx == key: |
361 | | - self._dict.pop(str_key) |
362 | | - elif idx > key: |
363 | | - self._dict[str_key] = idx - 1 |
364 | | - return ret |
365 | | - if isinstance(key, str): |
366 | | - if key not in self._dict: |
367 | | - return default |
368 | | - return self.pop(self._dict[key]) |
369 | | - raise TypeError("Key must be int or str") |
370 | | - |
371 | | - def insert(self, index: SupportsIndex, __object: _T) -> None: |
372 | | - idx_int = int(index) |
373 | | - # Check for negative indices |
374 | | - if idx_int < 0: |
375 | | - idx_int += len(self) |
376 | | - for key, idx in self._dict.items(): |
377 | | - if idx >= idx_int: |
378 | | - self._dict[key] = idx + 1 |
379 | 353 | return super().insert(index, __object) |
380 | 354 |
|
381 | | - def remove(self, __object: _T) -> None: |
382 | | - idx = self.index(__object) |
383 | | - name = None |
384 | | - for key, val in self._dict.items(): |
385 | | - if val == idx: |
386 | | - name = key |
387 | | - elif val > idx: |
388 | | - self._dict[key] = val - 1 |
389 | | - if name: |
390 | | - self._dict.pop(name, None) |
391 | | - super().remove(__object) |
392 | | - |
393 | | - def sort( |
394 | | - self, |
395 | | - *, |
396 | | - key: Optional[Callable[[_T], "SupportsRichComparison"]] = None, |
397 | | - reverse: bool = False, |
398 | | - ) -> None: |
399 | | - # Create a mapping from item to its name(s) |
400 | | - item_to_names: dict[_T, list[str]] = {} |
401 | | - for name, idx in self._dict.items(): |
402 | | - item = self[idx] |
403 | | - item_to_names.setdefault(item, []).append(name) |
404 | | - # Sort the list |
405 | | - super().sort(key=key, reverse=reverse) |
406 | | - # Update _dict with new indices |
407 | | - new_dict: dict[str, int] = {} |
408 | | - for idx, item in enumerate(self): |
409 | | - if item in item_to_names: |
410 | | - for name in item_to_names[item]: |
411 | | - new_dict[name] = idx |
412 | | - self._dict = new_dict |
413 | | - |
414 | 355 | @overload |
415 | 356 | def __getitem__(self, key: Union[SupportsIndex, str], /) -> _T: ... |
416 | 357 |
|
417 | 358 | @overload |
418 | 359 | def __getitem__(self, key: slice, /) -> list[_T]: ... |
419 | 360 |
|
420 | 361 | def __getitem__(self, key: Union[SupportsIndex, str, slice], /) -> Union[_T, list[_T]]: |
421 | | - if isinstance(key, str): |
422 | | - return self[self._dict[key]] |
| 362 | + if self._dict_map and isinstance(key, str): |
| 363 | + return self[self._dict_map[key]] |
423 | 364 | return super().__getitem__(key) |
424 | 365 |
|
425 | | - def __add__(self, other: Union[list[_T], "_ListMap[_T]"]) -> "_ListMap[_T]": # type: ignore[override] |
426 | | - new_listmap = self.copy() |
427 | | - new_listmap += other |
428 | | - return new_listmap |
429 | | - |
430 | | - def __iadd__(self, other: Iterable[_T]) -> Self: # type: ignore[override] |
431 | | - if isinstance(other, _ListMap): |
432 | | - offset = len(self) |
433 | | - for key, idx in other._dict.items(): |
434 | | - # notes: if there are duplicate keys, the ones from other will overwrite self |
435 | | - self._dict[key] = idx + offset |
436 | | - |
437 | | - return super().__iadd__(other) |
438 | | - |
439 | | - @overload |
440 | | - def __setitem__(self, key: Union[SupportsIndex, str], value: _T, /) -> None: ... |
441 | | - |
442 | | - @overload |
443 | | - def __setitem__(self, key: slice, value: Iterable[_T], /) -> None: ... |
444 | | - |
445 | | - def __setitem__(self, key: Union[SupportsIndex, str, slice], value: Any, /) -> None: |
446 | | - if isinstance(key, str): |
447 | | - # replace or insert by name |
448 | | - if key in self._dict: |
449 | | - self[self._dict[key]] = value |
450 | | - else: |
451 | | - self.append(value) |
452 | | - self._dict[key] = len(self) - 1 |
453 | | - return None |
454 | | - return super().__setitem__(key, value) |
455 | | - |
456 | 366 | def __contains__(self, item: Union[object, str]) -> bool: |
457 | 367 | if isinstance(item, str): |
458 | | - return item in self._dict |
| 368 | + return item in self._dict_map |
459 | 369 | return super().__contains__(item) |
460 | 370 |
|
461 | | - # --- Dict-like interface --- |
462 | | - |
463 | | - def __delitem__(self, key: Union[SupportsIndex, slice, str]) -> None: |
464 | | - index: Union[SupportsIndex, slice] |
465 | | - if isinstance(key, str): |
466 | | - if key not in self._dict: |
467 | | - raise KeyError(f"Key '{key}' not found.") |
468 | | - index = self._dict[key] |
469 | | - else: |
470 | | - index = key |
471 | | - |
472 | | - if isinstance(index, (int, slice)): |
473 | | - super().__delitem__(index) |
474 | | - for _key in index.indices(len(self)) if isinstance(index, slice) else [index]: |
475 | | - # update indices in the dict |
476 | | - for str_key, idx in list(self._dict.items()): |
477 | | - if idx == _key: |
478 | | - self._dict.pop(str_key) |
479 | | - elif idx > _key: |
480 | | - self._dict[str_key] = idx - 1 |
481 | | - else: |
482 | | - raise TypeError("Key must be int or str") |
483 | | - |
484 | | - def keys(self) -> KeysView[str]: |
485 | | - return self._dict.keys() |
486 | | - |
487 | | - def values(self) -> ValuesView[_T]: |
488 | | - return {k: self[v] for k, v in self._dict.items()}.values() |
489 | | - |
490 | | - def items(self) -> ItemsView[str, _T]: |
491 | | - return {k: self[v] for k, v in self._dict.items()}.items() |
492 | | - |
493 | | - @overload |
494 | | - def get(self, __key: str) -> Optional[_T]: ... |
495 | | - |
496 | | - @overload |
497 | | - def get(self, __key: str, default: _PT) -> Union[_T, _PT]: ... |
498 | | - |
499 | | - def get(self, __key: str, default: Optional[_PT] = None) -> Optional[Union[_T, _PT]]: |
500 | | - if __key in self._dict: |
501 | | - return self[self._dict[__key]] |
502 | | - return default |
503 | | - |
504 | 371 | def __repr__(self) -> str: |
505 | 372 | ret = super().__repr__() |
506 | | - return f"{type(self).__name__}({ret}, keys={list(self._dict.keys())})" |
507 | | - |
508 | | - def reverse(self) -> None: |
509 | | - for key, idx in self._dict.items(): |
510 | | - self._dict[key] = len(self) - 1 - idx |
511 | | - return super().reverse() |
| 373 | + return f"{type(self).__name__}({ret}, keys={list(self._dict_map.keys())})" |
512 | 374 |
|
513 | 375 | def clear(self) -> None: |
514 | | - self._dict.clear() |
| 376 | + self._dict_map.clear() |
515 | 377 | return super().clear() |
| 378 | + |
| 379 | + def keys(self) -> KeysView[str]: |
| 380 | + return self._dict_map.keys() |
0 commit comments