diff --git a/pb_chime5/core_chime6_rttm.py b/pb_chime5/core_chime6_rttm.py index 9f07773..7f59748 100644 --- a/pb_chime5/core_chime6_rttm.py +++ b/pb_chime5/core_chime6_rttm.py @@ -122,7 +122,8 @@ def enhance_session( session_ids, audio_dir, dataset_slice=False, - audio_dir_exist_ok=False + audio_dir_exist_ok=False, + is_chime=True, ): """ @@ -133,7 +134,9 @@ def enhance_session( audio_dir_exist_ok: When True: It is ok, when the audio dir exists and the files insinde may be overwritten. - + is_chime: + If true, map the session_id to the dataset name for the folder + naming. Otherwise keep the session_id for the folder name. Returns: @@ -149,9 +152,8 @@ def enhance_session( if dlp_mpi.IS_MASTER: audio_dir.mkdir(exist_ok=audio_dir_exist_ok) - - for dataset in set(mapping.session_to_dataset.values()): - (audio_dir / dataset).mkdir(exist_ok=audio_dir_exist_ok) + # for dataset in self.db.data['alias']: + # (audio_dir / dataset).mkdir(exist_ok=audio_dir_exist_ok) dlp_mpi.barrier() @@ -170,13 +172,17 @@ def enhance_session( x_hat = self.enhance_example(ex) example_id = ex["example_id"] session_id = ex["session_id"] - dataset = mapping.session_to_dataset[session_id] + if is_chime: + dataset = mapping.session_to_dataset[session_id] + else: + dataset = session_id if x_hat.ndim == 1: save_path = audio_dir / f'{dataset}' / f'{example_id}.wav' dump_audio( x_hat, save_path, + mkdir=True ) else: raise NotImplementedError(x_hat.shape) diff --git a/pb_chime5/database/chime5/rttm.py b/pb_chime5/database/chime5/rttm.py index edba1f1..2abe5d1 100644 --- a/pb_chime5/database/chime5/rttm.py +++ b/pb_chime5/database/chime5/rttm.py @@ -401,7 +401,10 @@ def __init__(self, rttm_path, audio_paths, alias=None): """ super().__init__() self._rttm_path = rttm_path + assert isinstance(audio_paths, dict), audio_paths self._audio_paths = audio_paths + if alias is None: + alias = [] self._alias = alias @cached_property @@ -465,6 +468,22 @@ def data(self): for speaker_id, speaker in session.items(): 
for start, end in speaker.intervals: example_id = self.example_id(session_id, speaker_id, start, end) + + try: + audio_path = self._audio_paths[session_id] + except KeyError as e: + raise ValueError( + 'self._audio_paths does not contain the session_id' + f'{session_id!r},\nit has only: ' + f'{list(self._audio_paths.keys())}' + ) + except Exception as e: + raise AssertionError( + 'Something went wrong.\n' + f'session_id: {session_id}\n' + f'self._audio_paths == {self._audio_paths!r}' + ) from e + datasets[session_id][example_id] = { 'example_id': example_id, 'start': start, @@ -472,7 +491,7 @@ def data(self): 'num_samples': end - start, 'session_id': session_id, 'speaker_id': speaker_id, - 'audio_path': self._audio_paths[session_id] + 'audio_path': audio_path, } return { 'datasets': datasets, diff --git a/pb_chime5/io/audiowrite.py b/pb_chime5/io/audiowrite.py index fbee41b..7ca535e 100644 --- a/pb_chime5/io/audiowrite.py +++ b/pb_chime5/io/audiowrite.py @@ -1,7 +1,10 @@ import io +from distutils.command.config import config + import numpy as np import threading from pathlib import Path +import contextlib import soundfile from scipy.io.wavfile import write as wav_write @@ -22,6 +25,7 @@ def dump_audio( start=None, normalize=True, format=None, + mkdir=False, ): """ If normalize is False and the dytpe is float, the values of obj should be in @@ -200,7 +204,16 @@ def dump_audio( # soundfile.write() - with soundfile.SoundFile(path, **sf_args) as f: + with contextlib.ExitStack() as exit_stack: + try: + f = exit_stack.enter_context(soundfile.SoundFile(path, **sf_args)) + except RuntimeError: + # Not sure, why this is a RuntimeError. Maybe a bug in SoundFile. 
+            if mkdir: +                Path(path).parent.mkdir(exist_ok=True, parents=True) +                f = exit_stack.enter_context(soundfile.SoundFile(path, **sf_args)) +            else: +                raise if start is not None: f.seek(start) f.write(obj.T) diff --git a/pb_chime5/scripts/kaldi_run_rttm_libri_css.py b/pb_chime5/scripts/kaldi_run_rttm_libri_css.py new file mode 100644 index 0000000..f068168 --- /dev/null +++ b/pb_chime5/scripts/kaldi_run_rttm_libri_css.py @@ -0,0 +1,249 @@ +""" +[mpirun -np $(nproc --all)] python -m pb_chime5.scripts.kaldi_run_rttm_libri_css with storage_dir=<...> database_rttm=<...> [activity_rttm=<...>] [session_id=dev] [session_to_audio_paths=<...>.{yaml,json}] job_id=1 number_of_jobs=1 + + - [mpirun -np $(nproc --all)]: Use all local cores + - storage_dir: Where to store the files: + /audio//_- + - database_rttm: RTTM files that contain the utterance/sentence start and end times. + - activity_rttm: RTTM files that contain the word boundaries (default database_rttm) + - session_to_audio_paths: yaml or json file + Contains the mapping from session/file ID to the actual files. + - job_id=1: Kaldi style parallel option + - number_of_jobs=1: Kaldi style parallel option +""" + +from pathlib import Path +import inspect + +import sacred +from sacred.commands import print_config +from sacred.observers import FileStorageObserver + +import dlp_mpi + +from pb_chime5.core_chime6_rttm import get_enhancer + +experiment = sacred.Experiment('Chime5 Array Enhancement') + + +@experiment.config +def config(): + locals().update({k: v.default for k, v in inspect.signature(get_enhancer).parameters.items()}) + + session_id = None # Default: All sessions from database_rttm + storage_dir: str = None + + database_rttm: str = None + activity_rttm: str = database_rttm + + # session_to_audio_paths: A file with a mapping from session ID (or file ID) used + # in rttm to the actual audio file. + # e.g. 
session_to_audio_paths = { + # session_1: multichannel.wav, + # session_1: [ch1.wav, ch2.wav, ...], + # } + # Supported are json and yaml. + session_to_audio_paths = None + + job_id = 1 + number_of_jobs = 1 + + assert storage_dir is not None, (storage_dir, 'overwrite the storage_dir from the command line') + assert database_rttm is not None, (database_rttm, 'overwrite the database_rttm from the command line') + assert activity_rttm is not None, (database_rttm, 'overwrite the activity_rttm from the command line') + + if dlp_mpi.IS_MASTER: + experiment.observers.append(FileStorageObserver.create(str( + Path(storage_dir).expanduser().resolve() / 'sacred' + ))) + +@experiment.named_config +def my_test_rttm(): + database_rttm = '/scratch/hpc-prf-nt1/cbj/net/vol/boeddeker/chime6/kaldi/egs/chime6/s5_track2_download/data/dev_beamformit_dereverb_stats_seg/rttm.U06' + + +# get_enhancer = experiment.capture(get_enhancer) + +@experiment.capture +def get_sessions(database_rttm): + database_rttm = Path(database_rttm) + database_rttm_text = database_rttm.read_text() + sessions: set = {line.split()[1] for line in database_rttm_text.splitlines()} + return sessions + +@experiment.capture +def get_enhancer( + database_rttm, + activity_rttm, + + # session_to_audio_paths: A file with a mapping from session ID (or file ID) used + # in rttm to the actual audio file. + # e.g. session_to_audio_paths = { + # session_1: multichannel.wav, + # session_1: [ch1.wav, ch2.wav, ...], + # } + # Supported are json and yaml. 
+ session_to_audio_paths=None, + + # chime6_dir='/net/fastdb/chime6/CHiME6', + # multiarray='outer_array_mics', + context_samples=240000, + + wpe=True, + wpe_tabs=10, + wpe_delay=2, + wpe_iterations=3, + wpe_psd_context=0, + + activity_garbage_class=True, + + stft_size=1024, + stft_shift=256, + stft_fading=True, + + bss_iterations=20, + bss_iterations_post=1, + + bf_drop_context=True, + + bf='mvdrSouden_ban', + postfilter=None, +): + from pb_chime5.core_chime6_rttm import ( + Enhancer, + get_database, + WPE, + Activity, + GSS, + Beamformer, + RTTMDatabase, + ) + if session_to_audio_paths is None: + from paderbox.array import intervall as array_intervall + + sessions: set = get_sessions(database_rttm) + assert len(sessions) == 1, sessions + + files = list(Path(database_rttm).parent.glob('*.wav')) + if len(files) == 1: + # Assume multichannel file + files, = files + + session_to_audio_paths = { + sessions.pop(): files + } + elif isinstance(session_to_audio_paths, str): + import paderbox as pb + # Load file, load detects yaml, json, ... (pkl is not allowed -> unsafe) + file = session_to_audio_paths + session_to_audio_paths = pb.io.load(file) + assert isinstance(session_to_audio_paths, dict), (file, session_to_audio_paths) + elif isinstance(session_to_audio_paths, dict): + session_to_audio_paths = session_to_audio_paths + else: + raise NotImplementedError(type(session_to_audio_paths), session_to_audio_paths) + + assert wpe is True or wpe is False, wpe + + class MyRTTMDatabase(RTTMDatabase): + @staticmethod + def example_id(file_id, speaker_id, start, end): + # Don't use the strange CHiME-6 pattern for the example ID. 
+ max_digits = len(str(16000 * 60 * 60 * 10)) # 10h + start = str(start).zfill(max_digits) + end = str(end).zfill(max_digits) + + # return f'{file_id}_{speaker_id}-{start}_{end}' + return f'{speaker_id}_{start}-{end}' + + db = MyRTTMDatabase( + database_rttm, + session_to_audio_paths, + # rttm, audio_paths, alias=alias + ) + + return Enhancer( + db=db, + context_samples=context_samples, + wpe_block=WPE( + taps=wpe_tabs, + delay=wpe_delay, + iterations=wpe_iterations, + psd_context=wpe_psd_context, + ) if wpe else None, + activity=Activity( + garbage_class=activity_garbage_class, + rttm=activity_rttm, + ), + gss_block=GSS( + iterations=bss_iterations, + iterations_post=bss_iterations_post, + verbose=False, + ), + bf_drop_context=bf_drop_context, + bf_block=Beamformer( + type=bf, + postfilter=postfilter, + ), + stft_size=stft_size, + stft_shift=stft_shift, + stft_fading=stft_fading, + ) + + +@experiment.main +def main(_run, storage_dir): + run(_run, storage_dir=storage_dir) + + +@experiment.command +def test_run(_run, storage_dir, test_run=True): + assert test_run is not False, test_run + run(_run, storage_dir=storage_dir, test_run=test_run) + + +@experiment.capture +def run(_run, storage_dir, job_id, number_of_jobs, session_id, test_run=False): + if dlp_mpi.IS_MASTER: + print_config(_run) + + assert job_id >= 1 and job_id <= number_of_jobs, (job_id, number_of_jobs) + + enhancer = get_enhancer() + + if test_run: + print('Database', enhancer.db) + + if test_run is False: + dataset_slice = slice(job_id - 1, None, number_of_jobs) + else: + dataset_slice = test_run + + if dlp_mpi.IS_MASTER: + print('Enhancer:', enhancer) + print(session_id) + + if session_id is None: + session_ids = sorted(get_sessions()) + elif isinstance(session_id, str): + session_ids = [session_id] + elif isinstance(session_id, (tuple, list)): + session_ids = session_id + else: + raise TypeError(type(session_id), session_id) + + for session_id in session_ids: + enhancer.enhance_session( + session_id, 
+ Path(storage_dir) / 'audio', + dataset_slice=dataset_slice, + audio_dir_exist_ok=True, + is_chime=False, + ) + + if dlp_mpi.IS_MASTER: + print('Finished experiment dir:', storage_dir) + + +if __name__ == '__main__': + experiment.run_commandline()