@@ -201,15 +201,35 @@ def from_dataset(
201201 >>> ds = from_dataset(new({'a': 1, 'b': 2, 'c': 3, 'd': 4}).filter(lambda x: x%2))
202202 >>> dict(ds)
203203 {'a': 1, 'c': 3}
204+
205+ # Works with concatenated datasets: when keys are duplicated, the result is a ListDataset of the values (keys are dropped)
206+ >>> ds = new({'a': 1, 'b': 2})
207+ >>> ds = concatenate(ds, ds)
208+ >>> ds
209+ DictDataset(len=2)
210+ MapDataset(_pickle.loads)
211+ DictDataset(len=2)
212+ MapDataset(_pickle.loads)
213+ ConcatenateDataset()
214+ >>> from_dataset(ds)
215+ ListDataset(len=4)
216+ MapDataset(_pickle.loads)
217+
204218 """
205219 try :
206220 items = list (examples .items ())
207221 except ItemsNotDefined :
208222 return from_list (list (examples ),
209223 immutable_warranty = immutable_warranty , name = name )
210224 else :
211- return from_dict (dict (items ),
212- immutable_warranty = immutable_warranty , name = name )
225+ new = dict (items )
226+ if len (new ) == len (items ):
227+ return from_dict (new ,
228+ immutable_warranty = immutable_warranty , name = name )
229+ else :
230+ # Duplicate keys: a dict would silently drop items, so keep every value in a list instead
231+ return from_list (list (map (operator .itemgetter (1 ), items )),
232+ immutable_warranty = immutable_warranty , name = name )
213233
214234
215235def concatenate (* datasets ):
@@ -417,7 +437,10 @@ def copy(self, freeze: bool = False) -> 'Dataset':
417437 Returns:
418438 A copy of this dataset
419439 """
420- raise NotImplementedError
440+ raise NotImplementedError (
441+ f'copy is not implemented for { self .__class__ } .\n '
442+ f'self: \n { repr (self )} '
443+ )
421444
422445 def __iter__ (self , with_key = False ):
423446 if with_key :
@@ -2973,6 +2996,7 @@ def __init__(self, *input_datasets):
29732996 ]
29742997 raise AssertionError (
29752998 f'Expect that all input_datasets have the same keys. '
2999+ f'Missing: { lengths } of { len (keys )} \n '
29763000 f'Missing keys: '
29773001 f'{ missing_keys } \n { self .input_datasets } '
29783002 )
@@ -3067,8 +3091,8 @@ class ItemsDataset(Dataset):
30673091 >>> ds_nokeys_rng = ds_plain.shuffle(True, rng=np.random.RandomState(0)) # No keys
30683092 >>> list(ds_nokeys.map(lambda x: x + 10).items())
30693093 [('a', 11), ('b', 12), ('c', 13)]
3070- >>> list(ds_nokeys.concatenate(ds_plain).items())
3071- [('a', 1), (' b', 2 ), ('c', 3 ), ('a', 1), ('b', 2), ('c', 3 )]
3094+ >>> list(ds_nokeys.map(lambda x: x + 10). concatenate(ds_plain).filter(lambda x: x in [1, 12, 13] ).items())
3095+ [('b', 12 ), ('c', 13 ), ('a', 1)]
30723096 >>> list(ds_nokeys_rng.intersperse(ds_nokeys_rng).items())
30733097 [('c', 3), ('a', 1), ('c', 3), ('c', 3), ('b', 2), ('b', 2)]
30743098 >>> list(ds_plain.key_zip(ds_plain).items())
0 commit comments