diff --git a/src/data/dataset_utils.py b/src/data/dataset_utils.py index 3d27f02..b7986c0 100644 --- a/src/data/dataset_utils.py +++ b/src/data/dataset_utils.py @@ -96,9 +96,9 @@ def __getitem__(self, index): if self.config["train"] is True: filename = "train/"+str(filepath)+".jpg" elif self.config["challenge"] is True: # check if this is actually present in the config file. If not, lets add it - (Rohan) - filename = "challenge/"+str(filepath)+".jpg" + filename = str(filepath) elif self.config["val"] is True: - filename = "val/"+str(filepath)+".jpg" + filename = str(filepath) else: filename = filepath.split('/')[-1].split('.')[0] @@ -114,7 +114,8 @@ def __getitem__(self, index): sample = {'img': img, 'prediction_label': pred_label, 'private_label': privacy_label, 'filepath': filepath, 'filename': filename} - + return sample + def __len__(self): return len(self.indicies) diff --git a/src/data/loaders.py b/src/data/loaders.py index ca175a9..b90f44c 100644 --- a/src/data/loaders.py +++ b/src/data/loaders.py @@ -31,6 +31,8 @@ def setup_data_pipeline(self): config = { "transforms": trainTransform, "train": False, + "val": False, + "challenge": True, "path": self.config["dataset_path"], "prediction_attribute": "data", "protected_attribute": self.config["protected_attribute"], @@ -42,11 +44,15 @@ def setup_data_pipeline(self): else: train_config = {"transforms": trainTransform, "train": True, + "val": False, + "challenge": False, "path": self.config["dataset_path"], "prediction_attribute": self.config["prediction_attribute"], "protected_attribute": self.config["protected_attribute"]} test_config = {"transforms": trainTransform, "train": False, + "val": True, + "challenge": False, "path": self.config["dataset_path"], "prediction_attribute": self.config["prediction_attribute"], "protected_attribute": self.config["protected_attribute"]}