22# SPDX-License-Identifier: Apache-2.0
33
44"""You may copy this file as the starting point of your own model."""
5-
6- from collections .abc import Iterable
7- from logging import getLogger
85import os
96import sys
7+ from collections .abc import Iterable
8+ from logging import getLogger
109
11-
12- from openfl .federated import PyTorchDataLoader
1310import numpy as np
14- from openfl .federated .data .sources .torch .verifiable_map_style_image_folder import VerifiableImageFolder
15- from openfl .federated .data .sources .data_sources_json_parser import DataSourcesJsonParser
16- from openfl .utilities .path_check import is_directory_traversal
1711import torch
1812from torch .utils .data import random_split
1913from torchvision .transforms import ToTensor
2014
15+ from openfl .federated import PyTorchDataLoader
16+ from openfl .federated .data .sources .data_sources_json_parser import DataSourcesJsonParser
17+ from openfl .federated .data .sources .torch .verifiable_map_style_image_folder import VerifiableImageFolder
18+ from openfl .utilities .path_check import is_directory_traversal
2119
2220logger = getLogger (__name__ )
2321
2422
2523class PyTorchHistologyVerifiableDataLoader (PyTorchDataLoader ):
2624 """PyTorch data loader for Histology dataset."""
2725
28- def __init__ (self , data_path , batch_size , ** kwargs ):
26+ def __init__ (self , data_path = None , batch_size = 32 , ** kwargs ):
2927 """Instantiate the data object.
3028
3129 Args:
32- data_path: The file path to the data
30+ data_path: The file path to the data. If None, initialize for model creation only.
3331 batch_size: The batch size of the data loader
3432 **kwargs: Additional arguments, passed to super init
3533 and load_histology_shard
@@ -61,17 +59,19 @@ def __init__(self, data_path, batch_size, **kwargs):
6159 else :
6260 logger .info ("The dataset is valid." )
6361
64- _ , num_classes , X_train , y_train , X_valid , y_valid = load_histology_shard (
65- verifible_dataset_info = verifible_dataset_info , verify_dataset_items = verify_dataset_items , ** kwargs )
62+ X_train , y_train , X_valid , y_valid = load_histology_shard (
63+ verifible_dataset_info = verifible_dataset_info ,
64+ verify_dataset_items = verify_dataset_items ,
65+ feature_shape = self .feature_shape ,
66+ num_classes = self .num_classes ,
67+ ** kwargs
68+ )
6669
6770 self .X_train = X_train
6871 self .y_train = y_train
6972 self .X_valid = X_valid
7073 self .y_valid = y_valid
7174
72- self .num_classes = num_classes
73-
74-
7575 def get_feature_shape (self ):
7676 """Returns the shape of an example feature array.
7777
@@ -101,7 +101,6 @@ def get_verifiable_dataset_info(self, data_path):
101101 Raises:
102102 SystemExit: If `data_path` is invalid or missing `datasources.json`.
103103 """
104- """Return the verifiable dataset info object for the given data sources."""
105104 if data_path and is_directory_traversal (data_path ):
106105 logger .error ("Data path is out of the openfl workspace scope." )
107106 if not os .path .isdir (data_path ):
@@ -152,7 +151,8 @@ def _load_raw_data(verifiable_dataset_info, verify_dataset_items=False, train_sp
152151 n_train = int (train_split_ratio * len (dataset ))
153152 n_valid = len (dataset ) - n_train
154153 ds_train , ds_val = random_split (
155- dataset , lengths = [n_train , n_valid ], generator = torch .manual_seed (0 ))
154+ dataset , lengths = [n_train , n_valid ], generator = torch .manual_seed (0 )
155+ )
156156
157157 # create the shards
158158 X_train , y_train = list (zip (* ds_train ))
@@ -164,41 +164,40 @@ def _load_raw_data(verifiable_dataset_info, verify_dataset_items=False, train_sp
164164 return (X_train , y_train ), (X_valid , y_valid )
165165
166166
167-
168- def load_histology_shard (verifible_dataset_info , verify_dataset_items ,
167+ def load_histology_shard (verifible_dataset_info , verify_dataset_items , feature_shape = None , num_classes = None ,
169168 categorical = False , channels_last = False , ** kwargs ):
170169 """
171170 Load the Histology dataset.
172171
173172 Args:
174- data_path (str): path to data directory
173+ verifible_dataset_info (VerifiableDatasetInfo): The verifiable dataset info object.
174+ verify_dataset_items (bool): True = verify the dataset items while loading data
175+ feature_shape (list, optional): The shape of input features.
176+ num_classes (int, optional): Number of classes.
175177 categorical (bool): True = convert the labels to one-hot encoded
176178 vectors (Default = False)
177179 channels_last (bool): True = The input images have the channels
178180 last (Default = False)
179181 **kwargs: Additional parameters to pass to the function
180182
181183 Returns:
182- list: The input shape
183- int: The number of classes
184184 numpy.ndarray: The training data
185185 numpy.ndarray: The training labels
186186 numpy.ndarray: The validation data
187187 numpy.ndarray: The validation labels
188188 """
189- img_rows , img_cols = 150 , 150
190- num_classes = 8
189+ img_rows , img_cols = feature_shape [1 ], feature_shape [2 ]
191190
192- (X_train , y_train ), (X_valid , y_valid ) = _load_raw_data (verifible_dataset_info , verify_dataset_items , ** kwargs )
191+ (X_train , y_train ), (X_valid , y_valid ) = _load_raw_data (
192+ verifible_dataset_info , verify_dataset_items , ** kwargs
193+ )
193194
194195 if channels_last :
195196 X_train = X_train .reshape (X_train .shape [0 ], img_rows , img_cols , 3 )
196197 X_valid = X_valid .reshape (X_valid .shape [0 ], img_rows , img_cols , 3 )
197- input_shape = (img_rows , img_cols , 3 )
198198 else :
199199 X_train = X_train .reshape (X_train .shape [0 ], 3 , img_rows , img_cols )
200200 X_valid = X_valid .reshape (X_valid .shape [0 ], 3 , img_rows , img_cols )
201- input_shape = (3 , img_rows , img_cols )
202201
203202 logger .info (f'Histology > X_train Shape : { X_train .shape } ' )
204203 logger .info (f'Histology > y_train Shape : { y_train .shape } ' )
@@ -210,4 +209,4 @@ def load_histology_shard(verifible_dataset_info, verify_dataset_items,
210209 y_train = np .eye (num_classes )[y_train ]
211210 y_valid = np .eye (num_classes )[y_valid ]
212211
213- return input_shape , num_classes , X_train , y_train , X_valid , y_valid
212+ return X_train , y_train , X_valid , y_valid
0 commit comments