Commit 8c8b80d

Merge pull request #3546 from dzubke/Iss-3511_split-sets

Fix #3511: split-sets on sample size

2 parents: 385c8c7 + 6945663
2 files changed: +62 −12 lines

bin/import_fisher.py (+31 −6)
@@ -2,6 +2,7 @@
 import codecs
 import fnmatch
 import os
+import random
 import subprocess
 import sys
 import unicodedata
@@ -236,14 +237,18 @@ def _split_and_resample_wav(origAudio, start_time, stop_time, new_wav_file):


 def _split_sets(filelist):
-    # We initially split the entire set into 80% train and 20% test, then
-    # split the train set into 80% train and 20% validation.
+    """
+    randomly split the dataset into train, validation, and test sets, where the sizes of
+    the validation and test sets are determined by the `get_sample_size` function.
+    """
+    random.shuffle(filelist)
+    sample_size = get_sample_size(len(filelist))
+
     train_beg = 0
-    train_end = int(0.8 * len(filelist))
+    train_end = len(filelist) - 2 * sample_size

-    dev_beg = int(0.8 * train_end)
-    dev_end = train_end
-    train_end = dev_beg
+    dev_beg = train_end
+    dev_end = train_end + sample_size

     test_beg = dev_end
     test_end = len(filelist)
@@ -255,5 +260,25 @@ def _split_sets(filelist):
     )


+def get_sample_size(population_size):
+    """calculates the sample size for a 99% confidence level and a 1% margin of error
+    """
+    margin_of_error = 0.01
+    fraction_picking = 0.50
+    z_score = 2.58  # corresponds to a 99% confidence level
+    numerator = (z_score ** 2 * fraction_picking * (1 - fraction_picking)) / (
+        margin_of_error ** 2
+    )
+    sample_size = 0
+    for train_size in range(population_size, 0, -1):
+        denominator = 1 + (z_score ** 2 * fraction_picking * (1 - fraction_picking)) / (
+            margin_of_error ** 2 * train_size
+        )
+        sample_size = int(numerator / denominator)
+        if 2 * sample_size + train_size <= population_size:
+            break
+    return sample_size
+
+
 if __name__ == "__main__":
     _download_and_preprocess_data(sys.argv[1])
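The new `get_sample_size` helper is essentially Cochran's sample-size formula with a finite-population correction, evaluated for progressively smaller candidate train sizes until two samples of that size (one for dev, one for test) fit inside the population. As a reviewer aid, here is a minimal standalone sketch of the closed-form piece of that calculation; the function name, defaults, and example population sizes are illustrative only and are not part of the commit.

def cochran_sample_size(train_size, z_score=2.58, fraction_picking=0.50, margin_of_error=0.01):
    """Finite-population-corrected Cochran sample size for a candidate train size."""
    # Unbounded-population sample size: n0 = z^2 * p * (1 - p) / e^2
    n0 = (z_score ** 2 * fraction_picking * (1 - fraction_picking)) / (margin_of_error ** 2)
    # Finite-population correction: n = n0 / (1 + n0 / N), truncated like the committed code
    return int(n0 / (1 + n0 / train_size))

# With a 99% confidence level and a 1% margin of error, n0 is 16,641, so the
# corrected sample size approaches that ceiling as the candidate train size grows.
for n in (1_000, 10_000, 100_000):
    print(n, cochran_sample_size(n))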

bin/import_swb.py (+31 −6)
@@ -5,6 +5,7 @@
 import codecs
 import fnmatch
 import os
+import random
 import subprocess
 import sys
 import tarfile
@@ -290,14 +291,18 @@ def _split_wav(origAudio, start_time, stop_time, new_wav_file):


 def _split_sets(filelist):
-    # We initially split the entire set into 80% train and 20% test, then
-    # split the train set into 80% train and 20% validation.
+    """
+    randomly split the dataset into train, validation, and test sets, where the sizes of
+    the validation and test sets are determined by the `get_sample_size` function.
+    """
+    random.shuffle(filelist)
+    sample_size = get_sample_size(len(filelist))
+
     train_beg = 0
-    train_end = int(0.8 * len(filelist))
+    train_end = len(filelist) - 2 * sample_size

-    dev_beg = int(0.8 * train_end)
-    dev_end = train_end
-    train_end = dev_beg
+    dev_beg = train_end
+    dev_end = train_end + sample_size

     test_beg = dev_end
     test_end = len(filelist)
@@ -309,6 +314,26 @@ def _split_sets(filelist):
     )


+def get_sample_size(population_size):
+    """calculates the sample size for a 99% confidence level and a 1% margin of error
+    """
+    margin_of_error = 0.01
+    fraction_picking = 0.50
+    z_score = 2.58  # corresponds to a 99% confidence level
+    numerator = (z_score ** 2 * fraction_picking * (1 - fraction_picking)) / (
+        margin_of_error ** 2
+    )
+    sample_size = 0
+    for train_size in range(population_size, 0, -1):
+        denominator = 1 + (z_score ** 2 * fraction_picking * (1 - fraction_picking)) / (
+            margin_of_error ** 2 * train_size
+        )
+        sample_size = int(numerator / denominator)
+        if 2 * sample_size + train_size <= population_size:
+            break
+    return sample_size
+
+
 def _read_data_set(
     filelist,
     thread_count,
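The same change is applied verbatim in bin/import_swb.py. To see the end-to-end effect of the new `_split_sets`, one can replicate the split on a dummy file list; the snippet below mirrors the committed logic but is not imported from either importer script, and the 20,000-entry stand-in list is purely illustrative.

import random

def get_sample_size(population_size):
    """Mirror of the committed helper: shrink the candidate train size until two
    finite-population-corrected samples (dev and test) fit inside the population."""
    margin_of_error = 0.01
    fraction_picking = 0.50
    z_score = 2.58  # 99% confidence level
    numerator = (z_score ** 2 * fraction_picking * (1 - fraction_picking)) / (margin_of_error ** 2)
    sample_size = 0
    for train_size in range(population_size, 0, -1):
        denominator = 1 + numerator / train_size  # algebraically the same as the committed expression
        sample_size = int(numerator / denominator)
        if 2 * sample_size + train_size <= population_size:
            break
    return sample_size

filelist = ["sample_%05d.wav" % i for i in range(20000)]  # stand-in for the real file list
random.shuffle(filelist)
sample_size = get_sample_size(len(filelist))

train_end = len(filelist) - 2 * sample_size
train = filelist[:train_end]
dev = filelist[train_end:train_end + sample_size]
test = filelist[train_end + sample_size:]
print(len(train), len(dev), len(test))  # dev and test each hold sample_size files

Note that dev and test each receive `sample_size` files regardless of the population size, unlike the previous fixed 80/20 percentage split.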
