Skip to content

Commit 5d91c2e

Browse files
Merge pull request #113 from vivekvjyn/vivek-beattrack
Added TCN Carnatic beat tracking module
2 parents 7efdfa8 + d837be0 commit 5d91c2e

File tree

14 files changed

+549
-38
lines changed

14 files changed

+549
-38
lines changed

.github/environment-ci.yml

Lines changed: 34 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,41 +1,42 @@
11
name: compiam-dev
2+
23
channels:
34
- conda-forge
45
- defaults
6+
57
dependencies:
8+
- python=3.9
69
- pip
7-
- "attrs>=23.1.0"
8-
- "matplotlib>=3.0.0"
9-
- "numpy>=1.20.3,<=1.26.4"
10-
- "joblib>=1.2.0"
11-
- "pathlib~=1.0.1"
12-
- "tqdm>=4.66.1"
13-
- "IPython>=7.34.0"
14-
- "ipywidgets>=7.0.0,<8"
15-
- "Jinja2~=3.1.2"
16-
- "configobj~=5.0.6"
17-
- "seaborn"
18-
- "librosa>=0.10.1"
19-
- "scikit-learn==1.5.2"
20-
- "scikit-image~=0.24.0"
21-
- "hmmlearn==0.3.3"
22-
- "fastdtw~=0.3.4"
23-
#######
2410
- libvorbis
25-
- pytest>=7.4.3
26-
#######
11+
2712
- pip:
28-
- "keras<3.0.0"
29-
- "tensorflow>=2.12.0,<2.16"
30-
- "torch==2.0.0"
31-
- "torchaudio==2.0.1"
32-
- "essentia"
33-
- "soundfile>=0.12.1"
34-
- "opencv-python~=4.6.0"
35-
- "mirdata==0.3.9"
36-
- "compmusic==0.4"
37-
- "attrs>=23.1.0"
38-
- "black>=23.3.0"
39-
- "decorator>=5.1.1"
40-
- "future>=0.18.3"
41-
- "testcontainers>=3.7.1"
13+
- keras<3.0.0
14+
- tensorflow>=2.12.0,<2.16
15+
- torch==2.0.0
16+
- torchaudio==2.0.1
17+
- essentia
18+
- soundfile>=0.12.1
19+
- opencv-python~=4.6.0
20+
- mirdata==0.3.9
21+
- compmusic==0.4
22+
- attrs>=23.1.0
23+
- black>=23.3.0
24+
- decorator>=5.1.1
25+
- future>=0.18.3
26+
- testcontainers>=3.7.1
27+
- madmom @ git+https://github.com/vivekvjyn/madmom.git
28+
- matplotlib>=3.0.0
29+
- numpy>=1.20.3,<=1.23.5
30+
- joblib>=1.2.0
31+
- tqdm>=4.66.1
32+
- IPython>=7.34.0
33+
- ipywidgets>=7.0.0,<8
34+
- Jinja2~=3.1.2
35+
- configobj~=5.0.6
36+
- seaborn
37+
- librosa>=0.10.1
38+
- scikit-learn==1.5.2
39+
- scikit-image~=0.24.0
40+
- hmmlearn==0.3.3
41+
- fastdtw~=0.3.4
42+
- pytest>=7.4.3

compiam/data.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,22 @@
125125
},
126126
},
127127
},
128+
"rhythm:tcn-carnatic": {
129+
"module_name": "compiam.rhythm.meter.tcn_carnatic",
130+
"class_name": "TCNTracker",
131+
"default_version": "v1",
132+
"kwargs": {
133+
"v1": {
134+
"model_path": os.path.join(
135+
"models",
136+
"rhythm",
137+
"tcn-carnatic"
138+
),
139+
"download_link": "https://zenodo.org/records/18449067/files/compIAM-TCNCarnatic.zip?download=1",
140+
"download_checksum": "995369933f2a344af0ffa57ea5c15e62",
141+
},
142+
},
143+
},
128144
"structure:dhrupad-bandish-segmentation": {
129145
"module_name": "compiam.structure.segmentation.dhrupad_bandish_segmentation",
130146
"class_name": "DhrupadBandishSegmentation",

compiam/rhythm/README.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,9 @@
44
|--------------------------------|------------------------------------------------|-----------|
55
| Akshara Pulse Tracker Detector | Detect onsets of aksharas in tabla recordings | [1] |
66
| Mnemonic Stroke Transcription | Bol/Solkattu transcription using HMM | [2] |
7+
| TCN Carnatic | Carnatic meter tracking using TCN | [3] |
78

89

910
[1] Originally implemented by Ajay Srinivasamurthy as part of PyCompMusic - https://github.com/MTG/pycompmusic
1011

11-
[2] Gupta, S., Srinivasamurthy, A., Kumar, M., Murthy, H., & Serra, X. (2015, October). Discovery of Syllabic Percussion Patterns in Tabla Solo Recordings. In Proceedings of the 16th International Society for Music Information Retrieval Conference (ISMIR 2015) (pp. 385–391). Malaga, Spain.
12+
[2] Gupta, S., Srinivasamurthy, A., Kumar, M., Murthy, H., & Serra, X. (2015, October). Discovery of Syllabic Percussion Patterns in Tabla Solo Recordings. In Proceedings of the 16th International Society for Music Information Retrieval Conference (ISMIR 2015) (pp. 385–391). Malaga, Spain.

compiam/rhythm/meter/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from compiam.data import models_dict
55

66
from compiam.rhythm.meter.akshara_pulse_tracker import AksharaPulseTracker
7-
7+
from compiam.rhythm.meter.tcn_carnatic import TCNTracker
88

99
# Show user the available tools
1010
def list_tools():
Lines changed: 203 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,203 @@
1+
import os
2+
import sys
3+
import numpy as np
4+
from typing import Dict
5+
from tqdm import tqdm
6+
from compiam.exceptions import ModelNotTrainedError
7+
8+
from compiam.utils.download import download_remote_model
9+
from compiam.utils import get_logger, WORKDIR
10+
from compiam.io import write_csv
11+
12+
logger = get_logger(__name__)
13+
14+
class TCNTracker(object):
    """TCN beat tracker tuned to Carnatic Music.

    Wraps a temporal convolutional network that jointly predicts beat and
    downbeat activations from a log-spectrogram front-end, followed by a
    madmom-based post-processing stage ('joint' or 'sequential').
    """

    def __init__(
        self,
        post_processor="joint",
        model_version=42,
        model_path=None,
        download_link=None,
        download_checksum=None,
        gpu=-1,
    ):
        """TCN beat tracker init method.

        :param post_processor: Post-processing method to use. Choose from 'joint', or 'sequential'.
        :param model_version: Version of the pre-trained model to use. Choose from 42, 52, or 62.
        :param model_path: path to file to the model weights.
        :param download_link: link to the remote pre-trained model.
        :param download_checksum: checksum of the model file.
        :param gpu: id of the GPU to use for inference (-1 to run on CPU).
        """
        ### IMPORTING OPTIONAL DEPENDENCIES
        try:
            global torch
            import torch
        except ImportError:
            raise ImportError(
                "Torch is required to use TCNTracker. "
                "Install compIAM with torch support: pip install 'compiam[torch]'"
            )

        try:
            global madmom
            import madmom
        except ImportError:
            raise ImportError(
                "Madmom is required to use TCNTracker. "
                "Install compIAM with madmom support: pip install 'compiam[madmom]'"
            )
        ###
        global MultiTracker, PreProcessor, joint_tracker, sequential_tracker
        from compiam.rhythm.meter.tcn_carnatic.model import MultiTracker
        from compiam.rhythm.meter.tcn_carnatic.pre import PreProcessor
        from compiam.rhythm.meter.tcn_carnatic.post import (
            joint_tracker,
            sequential_tracker,
        )

        # FIX: the original list also accepted "beat", which contradicted the
        # docstring and the error message below, and silently fell through to
        # the sequential tracker. Only the two documented options are valid.
        if post_processor not in ["joint", "sequential"]:
            raise ValueError(
                f"Invalid post_processor: {post_processor}. Choose from 'joint', or 'sequential'."
            )
        if model_version not in [42, 52, 62]:
            raise ValueError(
                f"Invalid model_version: {model_version}. Choose from 42, 52, or 62."
            )

        self.gpu = gpu
        self.device = None
        self.select_gpu(gpu)

        self.model_path = model_path
        self.model_version = f'multitracker_{model_version}.pth'
        self.download_link = download_link
        self.download_checksum = download_checksum

        self.trained = False
        self.model = self._build_model()
        if self.model_path is not None:
            self.load_model(self.model_path)
        # Frames of context replicated at both ends of the feature sequence
        # before inference (see preprocess_audio).
        self.pad_frames = 2

        self.post_processor = (
            joint_tracker if post_processor == "joint" else sequential_tracker
        )

    def _build_model(self):
        """Build the TCN model on the selected device, in eval mode."""
        model = MultiTracker().to(self.device)
        model.eval()
        return model

    def load_model(self, model_path):
        """Load pre-trained model weights, downloading them if missing.

        :param model_path: folder containing (or to receive) the weight file.
        """
        if not os.path.exists(os.path.join(model_path, self.model_version)):
            self.download_model(model_path)  # Downloading model weights

        self.model.load_weights(os.path.join(model_path, self.model_version), self.device)

        self.model_path = model_path
        self.trained = True

    def download_model(self, model_path=None, force_overwrite=True):
        """Download the pre-trained model to ``model_path``.

        :param model_path: destination folder; defaults to the compIAM
            workdir model store when None.
        :param force_overwrite: overwrite an existing download if True.
        """
        download_path = (
            model_path
            if model_path is not None
            else os.path.join(WORKDIR, "models", "rhythm", "tcn-carnatic")
        )
        # Creating model folder to store the weights
        if not os.path.exists(download_path):
            os.makedirs(download_path)
        download_remote_model(
            self.download_link,
            self.download_checksum,
            download_path,
            force_overwrite=force_overwrite,
        )

    def predict(self, input_data: str, sr: int = 44100, min_bpm=55, max_bpm=230, beats_per_bar=[3, 5, 7, 8]) -> Dict:
        """Run inference on input audio file.

        :param input_data: path to audio file or numpy array like audio signal.
        :param sr: sampling rate of the input audio signal (default: 44100).
        :param min_bpm: minimum BPM for beat tracking (default: 55).
        :param max_bpm: maximum BPM for beat tracking (default: 230).
        :param beats_per_bar: list of possible beats per bar for downbeat tracking (default: [3, 5, 7, 8]).

        :returns: a 2-D list with beats and beat positions.
        :raises ModelNotTrainedError: if no weights have been loaded.
        """
        if self.trained is False:
            raise ModelNotTrainedError(
                """Model is not trained. Please load model before running inference!
                You can load the pre-trained instance with the load_model wrapper."""
            )

        features = self.preprocess_audio(input_data, sr)
        x = torch.from_numpy(features).to(self.device)
        output = self.model(x)
        beats_act = output["beats"].squeeze().detach().cpu().numpy()
        downbeats_act = output["downbeats"].squeeze().detach().cpu().numpy()

        # Decode activations into (beat time, beat position) pairs using the
        # configured madmom-based post-processor.
        pred = self.post_processor(beats_act, downbeats_act, min_bpm=min_bpm, max_bpm=max_bpm, beats_per_bar=beats_per_bar)

        return pred

    def preprocess_audio(self, input_data: str, input_sr: int) -> np.ndarray:
        """Preprocess input audio to extract features for inference.

        :param input_data: path to the input audio file, or an audio array.
        :param input_sr: sampling rate of the input audio (used only for
            array input; files carry their own rate).

        :returns: preprocessed features as a numpy array shaped
            (1, 1, frames, bins), padded with ``pad_frames`` of edge context.
        """
        if isinstance(input_data, str):
            if not os.path.exists(input_data):
                raise FileNotFoundError("Target audio not found.")
            audio, sr = madmom.io.audio.load_audio_file(input_data)
            # FIX: madmom returns samples-first arrays, (num_samples,) for
            # mono and (num_samples, num_channels) for multichannel. The
            # original test `audio.shape[0] == 2` checked "two samples", not
            # "two channels", and `mean(axis=0)` averaged over time. Downmix
            # over the channel axis instead.
            if audio.ndim == 2:
                audio = audio.mean(axis=1)
            signal = madmom.audio.Signal(audio, sr, num_channels=1)
        elif isinstance(input_data, np.ndarray):
            audio = input_data
            # NOTE(review): array input is assumed channels-first, librosa
            # style (2, num_samples) — TODO confirm against callers.
            if audio.shape[0] == 2:
                audio = audio.mean(axis=0)
            signal = madmom.audio.Signal(audio, input_sr, num_channels=1)
            sr = input_sr
        else:
            raise ValueError("Input must be path to audio signal or an audio array")

        x = PreProcessor(sample_rate=sr)(signal)

        # Replicate edge frames so the TCN's receptive field has context at
        # both ends of the excerpt.
        pad_start = np.repeat(x[:1], self.pad_frames, axis=0)
        pad_stop = np.repeat(x[-1:], self.pad_frames, axis=0)
        x_padded = np.concatenate((pad_start, x, pad_stop))

        # Add batch and channel dimensions expected by the model.
        x_final = np.expand_dims(np.expand_dims(x_padded, axis=0), axis=0)

        return x_final

    @staticmethod
    def save_pitch(data, output_path):
        """Calling the write_csv function in compiam.io to write the output beat track in a file

        :param data: the data to write
        :param output_path: the path where the data is going to be stored

        :returns: None
        """
        # NOTE(review): despite the name (kept for API compatibility), this
        # writes the beat-tracking output, not pitch.
        return write_csv(data, output_path)

    def select_gpu(self, gpu="-1"):
        """Select the GPU to use for inference.

        :param gpu: Id of the available GPU to use (-1 by default, to run on CPU), use string: '0', '1', etc.
        :returns: None
        """
        if int(gpu) == -1:
            self.device = torch.device("cpu")
        else:
            if torch.cuda.is_available():
                self.device = torch.device("cuda:" + str(gpu))
            elif torch.backends.mps.is_available():
                self.device = torch.device("mps:" + str(gpu))
            else:
                self.device = torch.device("cpu")
                logger.warning("No GPU available. Running on CPU.")
        self.gpu = gpu

0 commit comments

Comments
 (0)