Skip to content

Commit b992c7e

Browse files
authored
[Data] Add shuffle for DataPipeline (#93)
1 parent 53cb0ee commit b992c7e

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

deepray/datasets/datapipeline.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ def __init__(self, context: tf.distribute.InputContext = None, **kwargs):
3535
# self.conf = Foo(flags.FLAGS.conf_file).conf
3636
self.url = None
3737
self.prebatch_size = kwargs.get("prebatch_size", None)
38+
self.shuffle = kwargs.get("shuffle", False)
3839

3940
@abc.abstractmethod
4041
def __len__(self):
@@ -63,7 +64,7 @@ def parser(self, record):
6364

6465
@abc.abstractmethod
6566
def build_dataset(
66-
self, batch_size, input_file_pattern=None, is_training=True, epochs=1, shuffle=False, *args, **kwargs
67+
self, batch_size, input_file_pattern=None, is_training=True, epochs=1, *args, **kwargs
6768
):
6869
"""
6970
must be defined in subclass

0 commit comments

Comments
 (0)