Skip to content

Commit 9996814

Browse files
gabrielfruetzy1gitNicolasHug
authored
Fix Type Generics for VisionDataset (#9381)
Co-authored-by: zy1git <zycoding1@gmail.com> Co-authored-by: Nicolas Hug <nh.nicolas.hug@gmail.com>
1 parent b0e5fb4 commit 9996814

File tree

1 file changed

+5
-3
lines changed

1 file changed

+5
-3
lines changed

torchvision/datasets/vision.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,15 @@
11
import os
22
from pathlib import Path
3-
from typing import Any, Callable, Optional, Union
3+
from typing import Any, Callable, Optional, TypeVar, Union
44

55
import torch.utils.data as data
66

77
from ..utils import _log_api_usage_once
88

9+
T_co = TypeVar("T_co", covariant=True)
910

10-
class VisionDataset(data.Dataset):
11+
12+
class VisionDataset(data.Dataset[T_co]):
1113
"""
1214
Base Class For making datasets which are compatible with torchvision.
1315
It is necessary to override the ``__getitem__`` and ``__len__`` method.
@@ -53,7 +55,7 @@ def __init__(
5355
transforms = StandardTransform(transform, target_transform)
5456
self.transforms = transforms
5557

56-
def __getitem__(self, index: int) -> Any:
58+
def __getitem__(self, index: int) -> T_co:
5759
"""
5860
Args:
5961
index (int): Index

0 commit comments

Comments
 (0)