Skip to content

Commit 811ffe7

Browse files
committed
refactor: remove methods
1 parent 56e5172 commit 811ffe7

File tree

2 files changed

+71
-336
lines changed

2 files changed

+71
-336
lines changed

src/lightning/pytorch/trainer/connectors/logger_connector/logger_connector.py

Lines changed: 40 additions & 175 deletions
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,12 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14-
from collections.abc import ItemsView, Iterable, KeysView, Mapping, ValuesView
15-
from typing import TYPE_CHECKING, Any, Callable, Optional, SupportsIndex, TypeVar, Union
14+
from collections.abc import Iterable, KeysView, Mapping
15+
from typing import TYPE_CHECKING, Any, Optional, SupportsIndex, TypeVar, Union
1616

1717
from lightning_utilities.core.apply_func import apply_to_collection
1818
from torch import Tensor
19-
from typing_extensions import Self, overload
19+
from typing_extensions import overload
2020

2121
import lightning.pytorch as pl
2222
from lightning.fabric.loggers.tensorboard import _TENSORBOARD_AVAILABLE, _TENSORBOARDX_AVAILABLE
@@ -29,9 +29,8 @@
2929

3030
warning_cache = WarningCache()
3131

32-
3332
if TYPE_CHECKING:
34-
from _typeshed import SupportsRichComparison
33+
pass
3534

3635

3736
class _LoggerConnector:
@@ -309,207 +308,73 @@ class _ListMap(list[_T]):
309308
310309
"""
311310

312-
_dict: dict[str, int]
311+
_dict_map: dict[str, int]
313312

314-
def __init__(self, __iterable: Optional[Union[Mapping[str, _T], Iterable[_T]]] = None):
313+
def __init__(self, __iterable: Optional[Union[dict[str, _T], Iterable[_T]]] = None):
315314
if isinstance(__iterable, Mapping):
316315
# super inits list with values
317316
if any(not isinstance(x, str) for x in __iterable):
318317
raise TypeError("When providing a Mapping, all keys must be of type str.")
319318
super().__init__(__iterable.values())
320-
_dict = dict(zip(__iterable.keys(), range(len(__iterable))))
319+
self._dict_map = {key: idx for idx, key in enumerate(__iterable)}
321320
else:
322-
default_dict: dict[str, int] = {}
323-
if isinstance(__iterable, _ListMap):
324-
default_dict = __iterable._dict.copy()
325-
super().__init__(() if __iterable is None else __iterable)
326-
_dict = default_dict
327-
self._dict = _dict
321+
super().__init__(__iterable or ())
322+
self._dict_map = {}
328323

329324
def __eq__(self, other: Any) -> bool:
330325
list_eq = super().__eq__(other)
331326
if isinstance(other, _ListMap):
332-
return list_eq and self._dict == other._dict
327+
list_eq &= other._dict_map == self._dict_map
333328
return list_eq
334329

335-
def copy(self) -> "_ListMap":
336-
new_listmap = _ListMap(self)
337-
new_listmap._dict = self._dict.copy()
338-
return new_listmap
339-
340-
def extend(self, __iterable: Iterable[_T]) -> None:
341-
if isinstance(__iterable, _ListMap):
342-
offset = len(self)
343-
for key, idx in __iterable._dict.items():
344-
self._dict[key] = idx + offset
345-
super().extend(__iterable)
330+
def pop(self, index: SupportsIndex = -1, /):
331+
if self._dict_map:
332+
index_int = index.__index__()
333+
if index_int < 0:
334+
index_int += len(self)
335+
336+
for key, idx in list(self._dict_map.items()):
337+
if idx == index_int:
338+
self._dict_map.pop(key)
339+
elif idx > index_int:
340+
self._dict_map[key] -= 1
341+
return super().pop(index)
342+
343+
def insert(self, index: SupportsIndex, __object: _T, /) -> None:
344+
if self._dict_map:
345+
index_int = index.__index__()
346+
if index_int < 0:
347+
index_int += len(self)
348+
349+
for key, idx in list(self._dict_map.items()):
350+
if idx >= index_int:
351+
self._dict_map[key] += 1
346352

347-
@overload
348-
def pop(self, key: SupportsIndex = -1, /) -> _T: ...
349-
350-
@overload
351-
def pop(self, key: Union[str, SupportsIndex], default: _T, /) -> _T: ...
352-
353-
@overload
354-
def pop(self, key: str, default: _PT, /) -> Union[_T, _PT]: ...
355-
356-
def pop(self, key: Union[SupportsIndex, str] = -1, default: Any = None) -> _T:
357-
if isinstance(key, int):
358-
ret = super().pop(key)
359-
for str_key, idx in list(self._dict.items()):
360-
if idx == key:
361-
self._dict.pop(str_key)
362-
elif idx > key:
363-
self._dict[str_key] = idx - 1
364-
return ret
365-
if isinstance(key, str):
366-
if key not in self._dict:
367-
return default
368-
return self.pop(self._dict[key])
369-
raise TypeError("Key must be int or str")
370-
371-
def insert(self, index: SupportsIndex, __object: _T) -> None:
372-
idx_int = int(index)
373-
# Check for negative indices
374-
if idx_int < 0:
375-
idx_int += len(self)
376-
for key, idx in self._dict.items():
377-
if idx >= idx_int:
378-
self._dict[key] = idx + 1
379353
return super().insert(index, __object)
380354

381-
def remove(self, __object: _T) -> None:
382-
idx = self.index(__object)
383-
name = None
384-
for key, val in self._dict.items():
385-
if val == idx:
386-
name = key
387-
elif val > idx:
388-
self._dict[key] = val - 1
389-
if name:
390-
self._dict.pop(name, None)
391-
super().remove(__object)
392-
393-
def sort(
394-
self,
395-
*,
396-
key: Optional[Callable[[_T], "SupportsRichComparison"]] = None,
397-
reverse: bool = False,
398-
) -> None:
399-
# Create a mapping from item to its name(s)
400-
item_to_names: dict[_T, list[str]] = {}
401-
for name, idx in self._dict.items():
402-
item = self[idx]
403-
item_to_names.setdefault(item, []).append(name)
404-
# Sort the list
405-
super().sort(key=key, reverse=reverse)
406-
# Update _dict with new indices
407-
new_dict: dict[str, int] = {}
408-
for idx, item in enumerate(self):
409-
if item in item_to_names:
410-
for name in item_to_names[item]:
411-
new_dict[name] = idx
412-
self._dict = new_dict
413-
414355
@overload
415356
def __getitem__(self, key: Union[SupportsIndex, str], /) -> _T: ...
416357

417358
@overload
418359
def __getitem__(self, key: slice, /) -> list[_T]: ...
419360

420361
def __getitem__(self, key: Union[SupportsIndex, str, slice], /) -> Union[_T, list[_T]]:
421-
if isinstance(key, str):
422-
return self[self._dict[key]]
362+
if self._dict_map and isinstance(key, str):
363+
return self[self._dict_map[key]]
423364
return super().__getitem__(key)
424365

425-
def __add__(self, other: Union[list[_T], "_ListMap[_T]"]) -> "_ListMap[_T]": # type: ignore[override]
426-
new_listmap = self.copy()
427-
new_listmap += other
428-
return new_listmap
429-
430-
def __iadd__(self, other: Iterable[_T]) -> Self: # type: ignore[override]
431-
if isinstance(other, _ListMap):
432-
offset = len(self)
433-
for key, idx in other._dict.items():
434-
# notes: if there are duplicate keys, the ones from other will overwrite self
435-
self._dict[key] = idx + offset
436-
437-
return super().__iadd__(other)
438-
439-
@overload
440-
def __setitem__(self, key: Union[SupportsIndex, str], value: _T, /) -> None: ...
441-
442-
@overload
443-
def __setitem__(self, key: slice, value: Iterable[_T], /) -> None: ...
444-
445-
def __setitem__(self, key: Union[SupportsIndex, str, slice], value: Any, /) -> None:
446-
if isinstance(key, str):
447-
# replace or insert by name
448-
if key in self._dict:
449-
self[self._dict[key]] = value
450-
else:
451-
self.append(value)
452-
self._dict[key] = len(self) - 1
453-
return None
454-
return super().__setitem__(key, value)
455-
456366
def __contains__(self, item: Union[object, str]) -> bool:
457367
if isinstance(item, str):
458-
return item in self._dict
368+
return item in self._dict_map
459369
return super().__contains__(item)
460370

461-
# --- Dict-like interface ---
462-
463-
def __delitem__(self, key: Union[SupportsIndex, slice, str]) -> None:
464-
index: Union[SupportsIndex, slice]
465-
if isinstance(key, str):
466-
if key not in self._dict:
467-
raise KeyError(f"Key '{key}' not found.")
468-
index = self._dict[key]
469-
else:
470-
index = key
471-
472-
if isinstance(index, (int, slice)):
473-
super().__delitem__(index)
474-
for _key in index.indices(len(self)) if isinstance(index, slice) else [index]:
475-
# update indices in the dict
476-
for str_key, idx in list(self._dict.items()):
477-
if idx == _key:
478-
self._dict.pop(str_key)
479-
elif idx > _key:
480-
self._dict[str_key] = idx - 1
481-
else:
482-
raise TypeError("Key must be int or str")
483-
484-
def keys(self) -> KeysView[str]:
485-
return self._dict.keys()
486-
487-
def values(self) -> ValuesView[_T]:
488-
return {k: self[v] for k, v in self._dict.items()}.values()
489-
490-
def items(self) -> ItemsView[str, _T]:
491-
return {k: self[v] for k, v in self._dict.items()}.items()
492-
493-
@overload
494-
def get(self, __key: str) -> Optional[_T]: ...
495-
496-
@overload
497-
def get(self, __key: str, default: _PT) -> Union[_T, _PT]: ...
498-
499-
def get(self, __key: str, default: Optional[_PT] = None) -> Optional[Union[_T, _PT]]:
500-
if __key in self._dict:
501-
return self[self._dict[__key]]
502-
return default
503-
504371
def __repr__(self) -> str:
505372
ret = super().__repr__()
506-
return f"{type(self).__name__}({ret}, keys={list(self._dict.keys())})"
507-
508-
def reverse(self) -> None:
509-
for key, idx in self._dict.items():
510-
self._dict[key] = len(self) - 1 - idx
511-
return super().reverse()
373+
return f"{type(self).__name__}({ret}, keys={list(self._dict_map.keys())})"
512374

513375
def clear(self) -> None:
514-
self._dict.clear()
376+
self._dict_map.clear()
515377
return super().clear()
378+
379+
def keys(self) -> KeysView[str]:
380+
return self._dict_map.keys()

0 commit comments

Comments
 (0)