forked from Planet-AI-GmbH/tfaip
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathdata.py
More file actions
101 lines (79 loc) · 3.31 KB
/
data.py
File metadata and controls
101 lines (79 loc) · 3.31 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
# Copyright 2021 The tfaip authors. All Rights Reserved.
#
# This file is part of tfaip.
#
# tfaip is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by the
# Free Software Foundation, either version 3 of the License, or (at your
# option) any later version.
#
# tfaip is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
# or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for
# more details.
#
# You should have received a copy of the GNU General Public License along with
# tfaip. If not, see http://www.gnu.org/licenses/.
# ==============================================================================
import logging
from dataclasses import dataclass, field
from typing import Type, Optional, Iterable, List
import numpy as np
import tensorflow as tf
import tensorflow.keras as keras
from paiargparse import pai_meta, pai_dataclass
from tfaip import DataGeneratorParams, DataBaseParams
from tfaip import PipelineMode, Sample
from tfaip import TrainerPipelineParamsBase
from tfaip.data.data import DataBase
from tfaip.data.pipeline.datagenerator import DataGenerator
logger = logging.getLogger(__name__)
@pai_dataclass
@dataclass
class TutorialDataGeneratorParams(DataGeneratorParams):
    """Parameters of the tutorial data generator.

    The single option is the name of the builtin keras dataset to load
    (looked up as an attribute of ``keras.datasets``).
    """

    dataset: str = field(
        default="mnist",
        metadata=pai_meta(help="The dataset to select (chose also fashion_mnist)."),
    )

    @staticmethod
    def cls() -> Type["DataGenerator"]:
        # Bind this parameter set to its generator implementation.
        return TutorialDataGenerator
class TutorialDataGenerator(DataGenerator[TutorialDataGeneratorParams]):
    """Yields the samples of a builtin keras dataset.

    The train split is used in ``PipelineMode.TRAINING``; every other mode
    receives the test split. All samples are materialized once in ``__init__``.
    """

    def __init__(self, mode: PipelineMode, params: "TutorialDataGeneratorParams"):
        super().__init__(mode, params)
        # Resolve the dataset module (e.g. keras.datasets.mnist) by name.
        loader = getattr(keras.datasets, params.dataset)
        train_split, test_split = loader.load_data()
        selected = train_split if mode == PipelineMode.TRAINING else test_split
        self.data = to_samples(selected)

    def __len__(self):
        return len(self.data)

    def generate(self) -> Iterable[Sample]:
        # Samples were precomputed in __init__; just hand out the list.
        return self.data
@pai_dataclass
@dataclass
class TutorialTrainerGeneratorParams(
    TrainerPipelineParamsBase[TutorialDataGeneratorParams, TutorialDataGeneratorParams]
):
    """Trainer pipeline params sharing one generator setup for train and val.

    The same ``train_val`` parameter object backs both the training and the
    validation pipeline.
    """

    train_val: TutorialDataGeneratorParams = field(
        default_factory=TutorialDataGeneratorParams,
        metadata=pai_meta(mode="flat"),
    )

    def train_gen(self) -> TutorialDataGeneratorParams:
        # Training uses the shared parameter set.
        return self.train_val

    def val_gen(self) -> Optional[TutorialDataGeneratorParams]:
        # Validation intentionally reuses the same parameters.
        return self.train_val
def to_samples(samples):
    """Convert an ``(images, labels)`` pair of parallel arrays into Samples.

    Each image is cast to a float numpy array under input key ``"img"``;
    each label is reshaped to a length-1 array under target key ``"gt"``.
    """
    images, labels = samples
    result = []
    for img, gt in zip(images, labels):
        sample = Sample(
            inputs={"img": np.array(img).astype("float")},
            targets={"gt": gt.reshape((1,))},
        )
        result.append(sample)
    return result
@pai_dataclass
@dataclass
class TutorialDataParams(DataBaseParams):
    """Data parameters of the tutorial: just the input image shape."""

    # Height and width of the (grayscale) input images.
    input_shape: List[int] = field(default_factory=lambda: [28, 28])

    @staticmethod
    def cls() -> Type["DataBase"]:
        # Bind this parameter set to its data implementation.
        return TutorialData
class TutorialData(DataBase[TutorialDataParams]):
    """Declares the tensor specs of inputs and targets fed to the graph."""

    def _input_layer_specs(self):
        # One image per sample, shaped according to the params.
        # NOTE(review): dtype is "int32" here although to_samples casts images
        # to float — presumably the pipeline casts to the spec dtype; confirm.
        spec = tf.TensorSpec(shape=self.params.input_shape, dtype="int32")
        return {"img": spec}

    def _target_layer_specs(self):
        # A single integer class label per sample.
        return {"gt": tf.TensorSpec(shape=[1], dtype="int32")}