import os
import sys
from typing import Dict

import numpy as np
from tqdm import tqdm

from compiam.exceptions import ModelNotTrainedError
from compiam.io import write_csv
from compiam.utils import WORKDIR, get_logger
from compiam.utils.download import download_remote_model

# Module-level logger, following the compIAM logging convention.
logger = get_logger(__name__)

class TCNTracker(object):
    """TCN beat tracker tuned to Carnatic Music.

    Wraps a temporal convolutional network (``MultiTracker``) that predicts
    beat and downbeat activations, plus a madmom-based post-processing stage
    that decodes the activations into beat times and beat positions.
    """

    def __init__(
        self,
        post_processor="joint",
        model_version=42,
        model_path=None,
        download_link=None,
        download_checksum=None,
        gpu=-1,
    ):
        """TCN beat tracker init method.

        :param post_processor: Post-processing method to use. Choose from 'joint', or 'sequential'.
        :param model_version: Version of the pre-trained model to use. Choose from 42, 52, or 62.
        :param model_path: path to file to the model weights.
        :param download_link: link to the remote pre-trained model.
        :param download_checksum: checksum of the model file.
        :param gpu: Id of the available GPU to use (-1 by default, to run on CPU).
        """
        ### IMPORTING OPTIONAL DEPENDENCIES
        # torch and madmom are optional extras: import lazily so the rest of
        # compIAM works without them, and fail with an actionable message.
        try:
            global torch
            import torch
        except ImportError:
            raise ImportError(
                "Torch is required to use TCNTracker. "
                "Install compIAM with torch support: pip install 'compiam[torch]'"
            )

        try:
            global madmom
            import madmom
        except ImportError:
            raise ImportError(
                "Madmom is required to use TCNTracker. "
                "Install compIAM with madmom support: pip install 'compiam[madmom]'"
            )
        ###
        global MultiTracker, PreProcessor, joint_tracker, sequential_tracker
        from compiam.rhythm.meter.tcn_carnatic.model import MultiTracker
        from compiam.rhythm.meter.tcn_carnatic.pre import PreProcessor
        from compiam.rhythm.meter.tcn_carnatic.post import joint_tracker, sequential_tracker

        # Validate arguments before building any state.
        # NOTE(review): "beat" is accepted for backward compatibility but is
        # undocumented and silently maps to the sequential tracker below —
        # confirm whether it should be a distinct mode or removed.
        if post_processor not in ["beat", "joint", "sequential"]:
            raise ValueError(
                f"Invalid post_processor: {post_processor}. Choose from 'joint', or 'sequential'."
            )
        if model_version not in [42, 52, 62]:
            raise ValueError(
                f"Invalid model_version: {model_version}. Choose from 42, 52, or 62."
            )

        self.gpu = gpu
        self.device = None
        self.select_gpu(gpu)

        self.model_path = model_path
        # File name of the selected pre-trained checkpoint.
        self.model_version = f"multitracker_{model_version}.pth"
        self.download_link = download_link
        self.download_checksum = download_checksum

        self.trained = False
        self.model = self._build_model()
        if self.model_path is not None:
            self.load_model(self.model_path)
        # Number of context frames replicated at each end of the input
        # features before inference (see preprocess_audio).
        self.pad_frames = 2

        self.post_processor = (
            joint_tracker if post_processor == "joint" else sequential_tracker
        )

    def _build_model(self):
        """Build the TCN model on the selected device, in eval mode."""
        model = MultiTracker().to(self.device)
        model.eval()
        return model

    def load_model(self, model_path):
        """Load pre-trained model weights, downloading them if missing.

        :param model_path: folder that contains (or will receive) the
            checkpoint file named by ``self.model_version``.
        :returns: None
        """
        if not os.path.exists(os.path.join(model_path, self.model_version)):
            self.download_model(model_path)  # Downloading model weights

        self.model.load_weights(
            os.path.join(model_path, self.model_version), self.device
        )

        self.model_path = model_path
        self.trained = True

    def download_model(self, model_path=None, force_overwrite=True):
        """Download the pre-trained model weights.

        :param model_path: target folder for the weights; defaults to the
            compIAM workdir model cache.
        :param force_overwrite: whether to overwrite an existing download.
        :raises ValueError: if no download link/checksum was configured.
        :returns: None
        """
        # Fail early with a clear message instead of passing None downstream.
        if self.download_link is None or self.download_checksum is None:
            raise ValueError(
                "No download link/checksum configured for this model. "
                "Pass download_link and download_checksum at init time."
            )
        download_path = (
            model_path
            if model_path is not None
            else os.path.join(WORKDIR, "models", "rhythm", "tcn-carnatic")
        )
        # Creating model folder to store the weights (race-free).
        os.makedirs(download_path, exist_ok=True)
        download_remote_model(
            self.download_link,
            self.download_checksum,
            download_path,
            force_overwrite=force_overwrite,
        )

    def predict(
        self,
        input_data,
        sr: int = 44100,
        min_bpm=55,
        max_bpm=230,
        beats_per_bar=None,
    ) -> Dict:
        """Run inference on an input audio file or array.

        :param input_data: path to audio file or numpy array like audio signal.
        :param sr: sampling rate of the input audio signal (default: 44100).
        :param min_bpm: minimum BPM for beat tracking (default: 55).
        :param max_bpm: maximum BPM for beat tracking (default: 230).
        :param beats_per_bar: list of possible beats per bar for downbeat
            tracking (default: [3, 5, 7, 8]).
        :raises ModelNotTrainedError: if no weights have been loaded.
        :returns: a 2-D list with beats and beat positions.
        """
        # None sentinel instead of a mutable default argument.
        if beats_per_bar is None:
            beats_per_bar = [3, 5, 7, 8]
        if not self.trained:
            raise ModelNotTrainedError(
                """Model is not trained. Please load model before running inference!
                You can load the pre-trained instance with the load_model wrapper."""
            )

        features = self.preprocess_audio(input_data, sr)
        x = torch.from_numpy(features).to(self.device)
        # No gradients are needed at inference time.
        with torch.no_grad():
            output = self.model(x)
        beats_act = output["beats"].squeeze().detach().cpu().numpy()
        downbeats_act = output["downbeats"].squeeze().detach().cpu().numpy()

        return self.post_processor(
            beats_act,
            downbeats_act,
            min_bpm=min_bpm,
            max_bpm=max_bpm,
            beats_per_bar=beats_per_bar,
        )

    @staticmethod
    def _to_mono(audio):
        """Downmix a 2-D audio array to mono by averaging channels.

        Accepts both (num_channels, num_samples) and (num_samples,
        num_channels) layouts by averaging over the smaller axis (assumes
        fewer channels than samples). 1-D input is returned unchanged.

        :param audio: 1-D or 2-D numpy audio array.
        :returns: 1-D mono audio array.
        """
        if audio.ndim == 2:
            channel_axis = 0 if audio.shape[0] < audio.shape[1] else 1
            audio = audio.mean(axis=channel_axis)
        return audio

    def preprocess_audio(self, input_data, input_sr: int) -> np.ndarray:
        """Preprocess input audio to extract features for inference.

        :param input_data: path to the input audio file, or a numpy audio array.
        :param input_sr: sampling rate of the input audio (used only for array
            input; file input uses the file's own rate).
        :raises FileNotFoundError: if a given path does not exist.
        :raises ValueError: if input is neither a path nor a numpy array.
        :returns: preprocessed features as a numpy array with leading batch
            and channel dimensions added.
        """
        if isinstance(input_data, str):
            if not os.path.exists(input_data):
                raise FileNotFoundError("Target audio not found.")
            audio, sr = madmom.io.audio.load_audio_file(input_data)
            # NOTE(review): the original only downmixed when shape[0] == 2,
            # missing (num_samples, 2) stereo as loaded by madmom; _to_mono
            # handles both layouts.
            audio = self._to_mono(audio)
            signal = madmom.audio.Signal(audio, sr, num_channels=1)
        elif isinstance(input_data, np.ndarray):
            audio = self._to_mono(input_data)
            sr = input_sr
            signal = madmom.audio.Signal(audio, sr, num_channels=1)
        else:
            raise ValueError("Input must be path to audio signal or an audio array")

        x = PreProcessor(sample_rate=sr)(signal)

        # Replicate edge frames so the network has context at the start and
        # end of the excerpt.
        pad_start = np.repeat(x[:1], self.pad_frames, axis=0)
        pad_stop = np.repeat(x[-1:], self.pad_frames, axis=0)
        x_padded = np.concatenate((pad_start, x, pad_stop))

        # Add the batch and channel dimensions expected by the model.
        return np.expand_dims(np.expand_dims(x_padded, axis=0), axis=0)

    @staticmethod
    def save_pitch(data, output_path):
        """Calling the write_csv function in compiam.io to write the output beat track in a file.

        NOTE(review): the name is kept for API compatibility although the
        data written is a beat track, not pitch.

        :param data: the data to write
        :param output_path: the path where the data is going to be stored
        :returns: None
        """
        return write_csv(data, output_path)

    def select_gpu(self, gpu="-1"):
        """Select the device to use for inference.

        :param gpu: Id of the available GPU to use (-1 by default, to run on
            CPU); int or string: '0', '1', etc.
        :returns: None
        """
        if int(gpu) == -1:
            self.device = torch.device("cpu")
        else:
            if torch.cuda.is_available():
                self.device = torch.device("cuda:" + str(gpu))
            elif torch.backends.mps.is_available():
                self.device = torch.device("mps:" + str(gpu))
            else:
                self.device = torch.device("cpu")
                logger.warning("No GPU available. Running on CPU.")
        self.gpu = gpu