From 75ff62e82f46b291f69e4d0d530b8fa88cf17518 Mon Sep 17 00:00:00 2001
From: Pascal Seeber
Date: Thu, 3 Jul 2025 15:17:43 +0200
Subject: [PATCH 1/4] add audio tasks

add audio training
---
 configs/audio_classification/hub_dataset.yml | 33 +
 configs/audio_classification/local.yml | 43 +
 configs/audio_detection/hub_dataset.yml | 47 +
 configs/audio_detection/local.yml | 43 +
 configs/audio_segmentation/hub_dataset.yml | 49 +
 configs/audio_segmentation/local.yml | 44 +
 src/autotrain/app/api_routes.py | 65 +-
 src/autotrain/app/colab.py | 4 +-
 src/autotrain/app/models.py | 186 ++
 src/autotrain/app/params.py | 130 ++-
 src/autotrain/app/templates/index.html | 19 +-
 src/autotrain/app/ui_routes.py | 62 ++
 src/autotrain/app/utils.py | 4 +-
 src/autotrain/backends/base.py | 12 +
 src/autotrain/cli/autotrain.py | 6 +
 src/autotrain/cli/run_audio_classification.py | 108 +++
 src/autotrain/cli/run_audio_detection.py | 106 +++
 src/autotrain/cli/run_audio_segmentation.py | 106 +++
 src/autotrain/commands.py | 143 ++-
 src/autotrain/dataset.py | 443 +++++++++
 src/autotrain/preprocessor/audio.py | 841 ++++++++++++++++++
 src/autotrain/preprocessor/vision.py | 12 +-
 src/autotrain/project.py | 280 ++++++
 src/autotrain/tasks.py | 7 +
 .../trainers/audio_classification/__init__.py | 3 +
 .../trainers/audio_classification/__main__.py | 284 ++++++
 .../trainers/audio_classification/dataset.py | 121 +++
 .../trainers/audio_classification/params.py | 78 ++
 .../trainers/audio_classification/utils.py | 216 +++++
 .../trainers/audio_detection/__init__.py | 3 +
 .../trainers/audio_detection/__main__.py | 276 ++++++
 .../trainers/audio_detection/dataset.py | 131 +++
 .../trainers/audio_detection/params.py | 88 ++
 .../trainers/audio_detection/utils.py | 209 +++++
 .../trainers/audio_segmentation/__init__.py | 3 +
 .../trainers/audio_segmentation/__main__.py | 385 ++++++++
 .../trainers/audio_segmentation/dataset.py | 101 +++
 .../trainers/audio_segmentation/params.py | 84 ++
 .../trainers/audio_segmentation/utils.py | 276 ++++++
 .../trainers/image_classification/__main__.py | 5 +
 src/autotrain/trainers/tabular/utils.py | 9 +-
 src/autotrain/utils.py | 9 +
 42 files changed, 5051 insertions(+), 23 deletions(-)
 create mode 100644 configs/audio_classification/hub_dataset.yml
 create mode 100644 configs/audio_classification/local.yml
 create mode 100644 configs/audio_detection/hub_dataset.yml
 create mode 100644 configs/audio_detection/local.yml
 create mode 100644 configs/audio_segmentation/hub_dataset.yml
 create mode 100644 configs/audio_segmentation/local.yml
 create mode 100644 src/autotrain/cli/run_audio_classification.py
 create mode 100644 src/autotrain/cli/run_audio_detection.py
 create mode 100644 src/autotrain/cli/run_audio_segmentation.py
 create mode 100644 src/autotrain/preprocessor/audio.py
 create mode 100644 src/autotrain/trainers/audio_classification/__init__.py
 create mode 100644 src/autotrain/trainers/audio_classification/__main__.py
 create mode 100644 src/autotrain/trainers/audio_classification/dataset.py
 create mode 100644 src/autotrain/trainers/audio_classification/params.py
 create mode 100644 src/autotrain/trainers/audio_classification/utils.py
 create mode 100644 src/autotrain/trainers/audio_detection/__init__.py
 create mode 100644 src/autotrain/trainers/audio_detection/__main__.py
 create mode 100644 src/autotrain/trainers/audio_detection/dataset.py
 create mode 100644 src/autotrain/trainers/audio_detection/params.py
 create mode 100644 src/autotrain/trainers/audio_detection/utils.py
 create mode 100644 src/autotrain/trainers/audio_segmentation/__init__.py
 create mode 100644 src/autotrain/trainers/audio_segmentation/__main__.py
 create mode 100644 src/autotrain/trainers/audio_segmentation/dataset.py
 create mode 100644 src/autotrain/trainers/audio_segmentation/params.py
 create mode 100644 src/autotrain/trainers/audio_segmentation/utils.py
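The new tasks are config-driven like the existing ones; a quick smoke test is `autotrain --config configs/audio_classification/local.yml`. A programmatic equivalent is sketched below, assuming the current `AutoTrainProject` interface and the field names used in the new params classes (sketch only, not part of this patch):

    from autotrain.project import AutoTrainProject
    from autotrain.trainers.audio_classification.params import AudioClassificationParams

    # Assumed usage, mirroring the existing image-classification flow;
    # constructor signature and defaults are assumptions, not part of this patch.
    params = AudioClassificationParams(
        model="facebook/wav2vec2-base",
        data_path="data/",                 # folder or CSV, as in the configs below
        audio_column="audio_path",
        target_column="label",
        train_split="train",
        valid_split="validation",
        project_name="my-autotrain-audio-clf",
        push_to_hub=False,
    )
    project = AutoTrainProject(params=params, backend="local", process=True)
    project.create()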
diff --git a/configs/audio_classification/hub_dataset.yml b/configs/audio_classification/hub_dataset.yml
new file mode 100644
index 0000000000..9e7efd132c
--- /dev/null
+++ b/configs/audio_classification/hub_dataset.yml
@@ -0,0 +1,33 @@
+task: audio-classification
+base_model: facebook/wav2vec2-base
+project_name: my-autotrain-audio-clf
+log: tensorboard
+backend: local
+
+data_path: superb
+train_split: train
+valid_split: validation
+
+column_mapping:
+  audio_column: audio
+  target_column: label
+
+parameters:
+  learning_rate: 3e-5
+  epochs: 5
+  batch_size: 8
+  warmup_ratio: 0.1
+  weight_decay: 0.01
+  mixed_precision: fp16
+  gradient_accumulation: 1
+  auto_find_batch_size: false
+  push_to_hub: false
+  logging_steps: -1
+  eval_strategy: epoch
+  save_total_limit: 1
+  early_stopping_patience: 5
+  early_stopping_threshold: 0.01
+  max_length: 480000 # 30 seconds at 16kHz
+  sampling_rate: 16000
+  feature_extractor_normalize: true
+  feature_extractor_return_attention_mask: true
\ No newline at end of file
diff --git a/configs/audio_classification/local.yml b/configs/audio_classification/local.yml
new file mode 100644
index 0000000000..0200a4fc0d
--- /dev/null
+++ b/configs/audio_classification/local.yml
@@ -0,0 +1,43 @@
+task: audio-classification
+base_model: facebook/wav2vec2-base
+project_name: my-autotrain-audio-clf-local
+log: tensorboard
+backend: local
+
+# Local data path - should contain audio files and CSV with labels
+data_path: /path/to/audio/dataset.csv
+train_split: train
+valid_split: validation
+
+column_mapping:
+  audio_column: audio_path
+  target_column: label
+
+parameters:
+  learning_rate: 3e-5
+  epochs: 5
+  batch_size: 8
+  warmup_ratio: 0.1
+  weight_decay: 0.01
+  mixed_precision: fp16
+  gradient_accumulation: 1
+  auto_find_batch_size: false
+  push_to_hub: false
+  logging_steps: -1
+  eval_strategy: epoch
+  save_total_limit: 1
+  early_stopping_patience: 5
+  early_stopping_threshold: 0.01
+  max_length: 480000 # 30 seconds at 16kHz
+  sampling_rate: 16000
+  feature_extractor_normalize: true
+  feature_extractor_return_attention_mask: true
+
+# Note: For local audio classification:
+# - audio_path column should contain paths to audio files (.wav, .mp3, .flac)
+# - label column should contain class labels (strings or integers)
+# - CSV format: audio_path,label
+# Example:
+# /path/to/audio1.wav,speech
+# /path/to/audio2.wav,music
+# /path/to/audio3.wav,noise
\ No newline at end of file
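The CSV layout described in the notes above (one `audio_path,label` row per clip) can be produced with a few lines of pandas; the paths and labels here are illustrative only:

    import pandas as pd

    # Write a CSV in the audio_path,label layout expected by the local config.
    rows = [
        {"audio_path": "/path/to/audio1.wav", "label": "speech"},
        {"audio_path": "/path/to/audio2.wav", "label": "music"},
        {"audio_path": "/path/to/audio3.wav", "label": "noise"},
    ]
    pd.DataFrame(rows).to_csv("dataset.csv", index=False)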
diff --git a/configs/audio_detection/hub_dataset.yml b/configs/audio_detection/hub_dataset.yml
new file mode 100644
index 0000000000..e2040dbced
--- /dev/null
+++ b/configs/audio_detection/hub_dataset.yml
@@ -0,0 +1,47 @@
+task: audio-detection
+base_model: facebook/wav2vec2-base
+project_name: my-autotrain-audio-detection-hub
+log: tensorboard
+backend: local
+
+# Hub dataset configuration
+data_path: audiofolder/audio_detection_dataset
+train_split: train
+valid_split: validation
+
+column_mapping:
+  audio_column: audio
+  events_column: events
+
+parameters:
+  learning_rate: 3e-5
+  epochs: 3
+  batch_size: 8
+  warmup_ratio: 0.1
+  weight_decay: 0.01
+  mixed_precision: fp16
+  gradient_accumulation: 1
+  auto_find_batch_size: false
+  push_to_hub: false
+  logging_steps: -1
+  eval_strategy: epoch
+  save_total_limit: 1
+  early_stopping_patience: 5
+  early_stopping_threshold: 0.01
+  max_length: 480000 # 30 seconds at 16kHz
+  sampling_rate: 16000
+  event_overlap_threshold: 0.5 # IoU threshold for overlapping events
+  confidence_threshold: 0.1 # Minimum confidence threshold for event detection
+
+# Hub settings
+hub:
+  username: ${HF_USERNAME}
+  token: ${HF_TOKEN}
+  push_to_hub: true
+
+# Note: For hub audio detection datasets:
+# - The dataset should have 'audio' and 'events' columns
+# - Events should be formatted as a list of dictionaries:
+#   [{"start": 0.0, "end": 2.5, "label": "speech"}, {"start": 2.5, "end": 3.0, "label": "silence"}]
+# - Audio column should contain audio data (array or file paths)
+# - Similar to object detection but for temporal events in audio
\ No newline at end of file
diff --git a/configs/audio_detection/local.yml b/configs/audio_detection/local.yml
new file mode 100644
index 0000000000..8db2f96d8d
--- /dev/null
+++ b/configs/audio_detection/local.yml
@@ -0,0 +1,43 @@
+task: audio-detection
+base_model: facebook/wav2vec2-base
+project_name: my-autotrain-audio-detection-local
+log: tensorboard
+backend: local
+
+# Local data path - should contain audio files and CSV with event annotations
+data_path: /path/to/audio/dataset.csv
+train_split: train
+valid_split: validation
+
+column_mapping:
+  audio_column: audio_path
+  events_column: events
+
+parameters:
+  learning_rate: 3e-5
+  epochs: 3
+  batch_size: 8
+  warmup_ratio: 0.1
+  weight_decay: 0.01
+  mixed_precision: fp16
+  gradient_accumulation: 1
+  auto_find_batch_size: false
+  push_to_hub: false
+  logging_steps: -1
+  eval_strategy: epoch
+  save_total_limit: 1
+  early_stopping_patience: 5
+  early_stopping_threshold: 0.01
+  max_length: 480000 # 30 seconds at 16kHz
+  sampling_rate: 16000
+  event_overlap_threshold: 0.5 # IoU threshold for overlapping events
+  confidence_threshold: 0.1 # Minimum confidence threshold for event detection
+
+# Note: For local audio detection:
+# - audio_path column should contain paths to audio files (.wav, .mp3, .flac)
+# - events column should contain event annotations as JSON list
+# - CSV format: audio_path,events
+# Example:
+# /path/to/audio1.wav,"[{""start"": 0.0, ""end"": 2.5, ""label"": ""speech""}, {""start"": 2.5, ""end"": 3.0, ""label"": ""silence""}]"
+# /path/to/audio2.wav,"[{""start"": 1.0, ""end"": 4.0, ""label"": ""music""}, {""start"": 4.0, ""end"": 5.0, ""label"": ""noise""}]"
+# /path/to/audio3.wav,"[{""start"": 0.5, ""end"": 3.5, ""label"": ""car_crash""}]"
\ No newline at end of file
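The doubled quotes in the example rows above are ordinary CSV escaping of a JSON string. One way (not prescribed by this patch) to build such a file from event annotations:

    import json

    import pandas as pd

    # Each events cell is a JSON list of {start, end, label} dicts; the pandas
    # CSV writer adds the surrounding quotes and doubles the embedded ones.
    events = [
        [{"start": 0.0, "end": 2.5, "label": "speech"}, {"start": 2.5, "end": 3.0, "label": "silence"}],
        [{"start": 0.5, "end": 3.5, "label": "car_crash"}],
    ]
    df = pd.DataFrame(
        {
            "audio_path": ["/path/to/audio1.wav", "/path/to/audio3.wav"],
            "events": [json.dumps(e) for e in events],
        }
    )
    df.to_csv("dataset.csv", index=False)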
diff --git a/configs/audio_segmentation/hub_dataset.yml b/configs/audio_segmentation/hub_dataset.yml
new file mode 100644
index 0000000000..b6f2c2c968
--- /dev/null
+++ b/configs/audio_segmentation/hub_dataset.yml
@@ -0,0 +1,49 @@
+task: audio_segmentation
+base_model: microsoft/speecht5_vc
+project_name: autotrain-audio-segmentation-hub
+log: tensorboard
+backend: spaces-a10g-large
+
+# Hub dataset configuration
+data_path: audiofolder/audio_segmentation_dataset
+train_split: train
+valid_split: validation
+audio_column: audio
+target_column: segments
+
+# Training parameters
+epochs: 10
+batch_size: 16
+lr: 2e-5
+scheduler: cosine
+optimizer: adamw_torch
+weight_decay: 0.01
+warmup_ratio: 0.05
+gradient_accumulation: 2
+mixed_precision: fp16
+logging_steps: 25
+save_total_limit: 5
+eval_strategy: steps
+early_stopping_patience: 5
+early_stopping_threshold: 0.005
+
+# Audio specific parameters
+max_length: 320000 # 20 seconds at 16kHz (shorter for better memory usage)
+sampling_rate: 16000
+feature_extractor_normalize: true
+feature_extractor_return_attention_mask: true
+
+# Segmentation specific parameters
+segment_length: 3.0 # seconds (shorter segments for better granularity)
+overlap_length: 0.3 # seconds
+min_segment_length: 0.5 # seconds
+
+# Model parameters
+seed: 42
+max_grad_norm: 1.0
+auto_find_batch_size: true
+push_to_hub: true
+
+# Hub settings
+token: ${HF_TOKEN}
+username: ${HF_USERNAME}
\ No newline at end of file
diff --git a/configs/audio_segmentation/local.yml b/configs/audio_segmentation/local.yml
new file mode 100644
index 0000000000..ca9e78f124
--- /dev/null
+++ b/configs/audio_segmentation/local.yml
@@ -0,0 +1,44 @@
+task: audio_segmentation
+base_model: microsoft/speecht5_vc
+project_name: autotrain-audio-segmentation-local
+log: tensorboard
+backend: local
+
+data_path: data/
+train_split: train
+valid_split: validation
+audio_column: audio_path
+target_column: segments
+
+# Training parameters
+epochs: 5
+batch_size: 8
+lr: 3e-5
+scheduler: linear
+optimizer: adamw_torch
+weight_decay: 0.01
+warmup_ratio: 0.1
+gradient_accumulation: 1
+mixed_precision: fp16
+logging_steps: 50
+save_total_limit: 3
+eval_strategy: epoch
+early_stopping_patience: 3
+early_stopping_threshold: 0.01
+
+# Audio specific parameters
+max_length: 480000 # 30 seconds at 16kHz
+sampling_rate: 16000
+feature_extractor_normalize: true
+feature_extractor_return_attention_mask: true
+
+# Segmentation specific parameters
+segment_length: 5.0 # seconds
+overlap_length: 0.5 # seconds
+min_segment_length: 1.0 # seconds
+
+# Model parameters
+seed: 42
+max_grad_norm: 1.0
+auto_find_batch_size: false
+push_to_hub: false
\ No newline at end of file
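The segmentation-specific parameters (`segment_length`, `overlap_length`, `min_segment_length`) describe overlapping analysis windows. A rough illustration of the intended geometry, assuming simple sliding windows (the actual chunking lives in the new audio_segmentation trainer and may differ):

    # Sketch: derive (start, end) windows in seconds from the config values.
    def windows(duration, segment_length=3.0, overlap_length=0.3, min_segment_length=0.5):
        step = segment_length - overlap_length
        start, out = 0.0, []
        while start < duration:
            end = min(start + segment_length, duration)
            if end - start >= min_segment_length:
                out.append((round(start, 3), round(end, 3)))
            start += step
        return out

    print(windows(10.0))  # [(0.0, 3.0), (2.7, 5.7), (5.4, 8.4), (8.1, 10.0)]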
diff --git a/src/autotrain/app/api_routes.py b/src/autotrain/app/api_routes.py
index 8563ab15b8..7a5b19e8b8 100644
--- a/src/autotrain/app/api_routes.py
+++ b/src/autotrain/app/api_routes.py
@@ -11,6 +11,9 @@
 from autotrain.app.params import HIDDEN_PARAMS, PARAMS, AppParams
 from autotrain.app.utils import token_verification
 from autotrain.project import AutoTrainProject
+from autotrain.trainers.audio_classification.params import AudioClassificationParams
+from autotrain.trainers.audio_detection.params import AudioDetectionParams
+from autotrain.trainers.audio_segmentation.params import AudioSegmentationParams
 from autotrain.trainers.clm.params import LLMTrainingParams
 from autotrain.trainers.extractive_question_answering.params import ExtractiveQuestionAnsweringParams
 from autotrain.trainers.image_classification.params import ImageClassificationParams
@@ -25,7 +28,7 @@
 from autotrain.trainers.vlm.params import VLMTrainingParams


-FIELDS_TO_EXCLUDE = HIDDEN_PARAMS + ["push_to_hub"]
+FIELDS_TO_EXCLUDE = HIDDEN_PARAMS


 def create_api_base_model(base_class, class_name):
@@ -108,10 +111,11 @@ def create_api_base_model(base_class, class_name):
 SentenceTransformersParamsAPI = create_api_base_model(SentenceTransformersParams, "SentenceTransformersParamsAPI")
 ImageRegressionParamsAPI = create_api_base_model(ImageRegressionParams, "ImageRegressionParamsAPI")
 VLMTrainingParamsAPI = create_api_base_model(VLMTrainingParams, "VLMTrainingParamsAPI")
-ExtractiveQuestionAnsweringParamsAPI = create_api_base_model(
-    ExtractiveQuestionAnsweringParams, "ExtractiveQuestionAnsweringParamsAPI"
-)
+ExtractiveQuestionAnsweringParamsAPI = create_api_base_model(ExtractiveQuestionAnsweringParams, "ExtractiveQuestionAnsweringParamsAPI")
 ObjectDetectionParamsAPI = create_api_base_model(ObjectDetectionParams, "ObjectDetectionParamsAPI")
+AudioClassificationParamsAPI = create_api_base_model(AudioClassificationParams, "AudioClassificationParamsAPI")
+AudioSegmentationParamsAPI = create_api_base_model(AudioSegmentationParams, "AudioSegmentationParamsAPI")
+AudioDetectionParamsAPI = create_api_base_model(AudioDetectionParams, "AudioDetectionParamsAPI")


 class LLMSFTColumnMapping(BaseModel):
@@ -224,6 +228,21 @@ class ObjectDetectionColumnMapping(BaseModel):
     objects_column: str


+class AudioClassificationColumnMapping(BaseModel):
+    audio_column: str
+    target_column: str
+
+
+class AudioSegmentationColumnMapping(BaseModel):
+    audio_column: str
+    target_column: str
+
+
+class AudioDetectionColumnMapping(BaseModel):
+    audio_column: str
+    events_column: str
+
+
 class APICreateProjectModel(BaseModel):
     """
     APICreateProjectModel is a Pydantic model that defines the schema for creating a project.
@@ -275,6 +294,8 @@ class APICreateProjectModel(BaseModel):
         "vlm:vqa",
         "extractive-question-answering",
         "image-object-detection",
+        "audio-classification",
+        "audio-segmentation",
     ]
     base_model: str
     hardware: Literal[
@@ -312,6 +333,9 @@ class APICreateProjectModel(BaseModel):
         VLMTrainingParamsAPI,
         ExtractiveQuestionAnsweringParamsAPI,
         ObjectDetectionParamsAPI,
+        AudioClassificationParamsAPI,
+        AudioSegmentationParamsAPI,
+        AudioDetectionParamsAPI,
     ]
     username: str
     column_mapping: Optional[
@@ -337,6 +361,9 @@ class APICreateProjectModel(BaseModel):
             VLMColumnMapping,
             ExtractiveQuestionAnsweringColumnMapping,
             ObjectDetectionColumnMapping,
+            AudioClassificationColumnMapping,
+            AudioSegmentationColumnMapping,
+            AudioDetectionColumnMapping,
         ]
     ] = None
     hub_dataset: str
@@ -534,6 +561,30 @@ def validate_column_mapping(cls, values):
             if not values.get("column_mapping").get("objects_column"):
                 raise ValueError("objects_column is required for image-object-detection")
             values["column_mapping"] = ObjectDetectionColumnMapping(**values["column_mapping"])
+        elif values.get("task") == "audio-classification":
+            if not values.get("column_mapping"):
+                raise ValueError("column_mapping is required for audio-classification")
+            if not values.get("column_mapping").get("audio_column"):
+                raise ValueError("audio_column is required for audio-classification")
+            if not values.get("column_mapping").get("target_column"):
+                raise ValueError("target_column is required for audio-classification")
+            values["column_mapping"] = AudioClassificationColumnMapping(**values["column_mapping"])
+        elif values.get("task") == "audio-segmentation":
+            if not values.get("column_mapping"):
+                raise ValueError("column_mapping is required for audio-segmentation")
+            if not values.get("column_mapping").get("audio_column"):
+                raise ValueError("audio_column is required for audio-segmentation")
+            if not values.get("column_mapping").get("target_column"):
+                raise ValueError("target_column is required for audio-segmentation")
+            values["column_mapping"] = AudioSegmentationColumnMapping(**values["column_mapping"])
+        elif values.get("task") == "audio-detection":
+            if not values.get("column_mapping"):
+                raise ValueError("column_mapping is required for audio-detection")
+            if not values.get("column_mapping").get("audio_column"):
+                raise ValueError("audio_column is required for audio-detection")
+            if not values.get("column_mapping").get("events_column"):
+                raise ValueError("events_column is required for audio-detection")
+            values["column_mapping"] = AudioDetectionColumnMapping(**values["column_mapping"])
         return values

     @model_validator(mode="before")
@@ -573,6 +624,12 @@ def validate_params(cls, values):
values["params"] = ExtractiveQuestionAnsweringParamsAPI(**values["params"]) elif values.get("task") == "image-object-detection": values["params"] = ObjectDetectionParamsAPI(**values["params"]) + elif values.get("task") == "audio-classification": + values["params"] = AudioClassificationParamsAPI(**values["params"]) + elif values.get("task") == "audio-segmentation": + values["params"] = AudioSegmentationParamsAPI(**values["params"]) + elif values.get("task") == "audio-detection": + values["params"] = AudioDetectionParamsAPI(**values["params"]) return values diff --git a/src/autotrain/app/colab.py b/src/autotrain/app/colab.py index 2193ba048f..a54a6d6465 100644 --- a/src/autotrain/app/colab.py +++ b/src/autotrain/app/colab.py @@ -68,7 +68,7 @@ def colab_app(): def _get_params(task, param_type): _p = get_task_params(task, param_type=param_type) - _p["push_to_hub"] = True + _p["push_to_hub"] = False _p = json.dumps(_p, indent=4) return _p @@ -342,7 +342,7 @@ def start_training(b): if chat_template is not None: params_val = {k: v for k, v in params_val.items() if k != "chat_template"} - push_to_hub = params_val.get("push_to_hub", True) + push_to_hub = params_val.get("push_to_hub", False) if "push_to_hub" in params_val: params_val = {k: v for k, v in params_val.items() if k != "push_to_hub"} diff --git a/src/autotrain/app/models.py b/src/autotrain/app/models.py index 1d1f658113..caba63ec57 100644 --- a/src/autotrain/app/models.py +++ b/src/autotrain/app/models.py @@ -333,6 +333,189 @@ def _fetch_vlm_models(): return hub_models +def _fetch_audio_classification_models(): + """ + Fetches and sorts audio classification models from the Hugging Face model hub. + + This function retrieves models for the task "audio-classification" + from the Hugging Face model hub, sorts them by the number of downloads. + Additionally, it fetches trending models based on the number + of likes in the past 7 days, sorts them, and places them at the beginning of the list + if they are not already included. + + Returns: + list: A sorted list of model identifiers from the Hugging Face model hub. + """ + hub_models = list( + list_models( + task="audio-classification", + library="transformers", + sort="downloads", + direction=-1, + limit=100, + full=False, + ) + ) + hub_models = get_sorted_models(hub_models) + + trending_models = list( + list_models( + task="audio-classification", + library="transformers", + sort="likes7d", + direction=-1, + limit=30, + full=False, + ) + ) + if len(trending_models) > 0: + trending_models = get_sorted_models(trending_models) + hub_models = [m for m in hub_models if m not in trending_models] + hub_models = trending_models + hub_models + + return hub_models + + +def _fetch_audio_segmentation_models(): + """ + Fetches and sorts audio segmentation models from the Hugging Face model hub. + + This function retrieves models suitable for audio segmentation tasks such as + speaker diarization, voice activity detection, and speech/music segmentation. + It includes audio classification models that can be fine-tuned for segmentation. + + Returns: + list: A sorted list of model identifiers from the Hugging Face model hub. 
+ """ + # Get audio classification models (can be used for segmentation) + hub_models1 = list( + list_models( + task="audio-classification", + library="transformers", + sort="downloads", + direction=-1, + limit=50, + full=False, + ) + ) + + # Get automatic speech recognition models (useful for segmentation) + hub_models2 = list( + list_models( + task="automatic-speech-recognition", + library="transformers", + sort="downloads", + direction=-1, + limit=50, + full=False, + ) + ) + + hub_models = list(hub_models1) + list(hub_models2) + hub_models = get_sorted_models(hub_models) + + # Get trending models + trending_models1 = list( + list_models( + task="audio-classification", + library="transformers", + sort="likes7d", + direction=-1, + limit=15, + full=False, + ) + ) + + trending_models2 = list( + list_models( + task="automatic-speech-recognition", + library="transformers", + sort="likes7d", + direction=-1, + limit=15, + full=False, + ) + ) + + trending_models = list(trending_models1) + list(trending_models2) + if len(trending_models) > 0: + trending_models = get_sorted_models(trending_models) + hub_models = [m for m in hub_models if m not in trending_models] + hub_models = trending_models + hub_models + + return hub_models + + +def _fetch_audio_detection_models(): + """ + Fetches and sorts audio detection models from the Hugging Face model hub. + + This function retrieves models suitable for audio detection tasks such as + event detection, audio classification, and temporal audio analysis. + It includes audio classification models that can be fine-tuned for detection. + + Returns: + list: A sorted list of model identifiers from the Hugging Face model hub. + """ + # Get audio classification models (can be used for detection) + hub_models1 = list( + list_models( + task="audio-classification", + library="transformers", + sort="downloads", + direction=-1, + limit=50, + full=False, + ) + ) + + # Get automatic speech recognition models (useful for audio analysis) + hub_models2 = list( + list_models( + task="automatic-speech-recognition", + library="transformers", + sort="downloads", + direction=-1, + limit=30, + full=False, + ) + ) + + hub_models = list(hub_models1) + list(hub_models2) + hub_models = get_sorted_models(hub_models) + + # Get trending models + trending_models1 = list( + list_models( + task="audio-classification", + library="transformers", + sort="likes7d", + direction=-1, + limit=15, + full=False, + ) + ) + + trending_models2 = list( + list_models( + task="automatic-speech-recognition", + library="transformers", + sort="likes7d", + direction=-1, + limit=10, + full=False, + ) + ) + + trending_models = list(trending_models1) + list(trending_models2) + if len(trending_models) > 0: + trending_models = get_sorted_models(trending_models) + hub_models = [m for m in hub_models if m not in trending_models] + hub_models = trending_models + hub_models + + return hub_models + + def fetch_models(): _mc = collections.defaultdict(list) _mc["text-classification"] = _fetch_text_classification_models() @@ -346,6 +529,9 @@ def fetch_models(): _mc["sentence-transformers"] = _fetch_st_models() _mc["vlm"] = _fetch_vlm_models() _mc["extractive-qa"] = _fetch_text_classification_models() + _mc["audio-classification"] = _fetch_audio_classification_models() + _mc["audio-segmentation"] = _fetch_audio_segmentation_models() + _mc["audio-detection"] = _fetch_audio_detection_models() # tabular-classification _mc["tabular-classification"] = [ diff --git a/src/autotrain/app/params.py b/src/autotrain/app/params.py 
index a6f4addbc5..bebb1a5d84 100644 --- a/src/autotrain/app/params.py +++ b/src/autotrain/app/params.py @@ -2,6 +2,9 @@ from dataclasses import dataclass from typing import Optional +from autotrain.trainers.audio_classification.params import AudioClassificationParams +from autotrain.trainers.audio_detection.params import AudioDetectionParams +from autotrain.trainers.audio_segmentation.params import AudioSegmentationParams from autotrain.trainers.clm.params import LLMTrainingParams from autotrain.trainers.extractive_question_answering.params import ExtractiveQuestionAnsweringParams from autotrain.trainers.image_classification.params import ImageClassificationParams @@ -67,7 +70,6 @@ "answer_column", ] - PARAMS = {} PARAMS["llm"] = LLMTrainingParams( target_modules="all-linear", @@ -135,6 +137,18 @@ max_seq_length=512, max_doc_stride=128, ).model_dump() +PARAMS["audio-classification"] = AudioClassificationParams( + mixed_precision="fp16", + log="tensorboard", +).model_dump() +PARAMS["audio-detection"] = AudioDetectionParams( + mixed_precision="fp16", + log="tensorboard", +).model_dump() +PARAMS["audio-segmentation"] = AudioSegmentationParams( + mixed_precision="fp16", + log="tensorboard", +).model_dump() @dataclass @@ -216,6 +230,12 @@ def munge(self): return self._munge_params_vlm() elif self.task == "extractive-qa": return self._munge_params_extractive_qa() + elif self.task == "audio-classification": + return self._munge_params_audio_clf() + elif self.task == "audio-detection": + return self._munge_params_audio_det() + elif self.task == "audio-segmentation": + return self._munge_params_audio_seg() else: raise ValueError(f"Unknown task: {self.task}") @@ -488,6 +508,54 @@ def _munge_params_tabular(self): return TabularParams(**_params) + def _munge_params_audio_clf(self): + _params = self._munge_common_params() + _params["model"] = self.base_model + if "log" not in _params: + _params["log"] = "tensorboard" + if not self.using_hub_dataset: + _params["audio_column"] = "autotrain_audio" + _params["target_column"] = "autotrain_label" + _params["valid_split"] = "validation" + else: + _params["audio_column"] = self.column_mapping.get("audio" if not self.api else "audio_column", "audio") + _params["target_column"] = self.column_mapping.get("label" if not self.api else "target_column", "label") + _params["train_split"] = self.train_split + _params["valid_split"] = self.valid_split + return AudioClassificationParams(**_params) + + def _munge_params_audio_det(self): + _params = self._munge_common_params() + _params["model"] = self.base_model + if "log" not in _params: + _params["log"] = "tensorboard" + if not self.using_hub_dataset: + _params["audio_column"] = "autotrain_audio" + _params["events_column"] = "autotrain_events" + _params["valid_split"] = "validation" + else: + _params["audio_column"] = self.column_mapping.get("audio" if not self.api else "audio_column", "audio") + _params["events_column"] = self.column_mapping.get("events" if not self.api else "events_column", "events") + _params["train_split"] = self.train_split + _params["valid_split"] = self.valid_split + return AudioDetectionParams(**_params) + + def _munge_params_audio_seg(self): + _params = self._munge_common_params() + _params["model"] = self.base_model + if "log" not in _params: + _params["log"] = "tensorboard" + if not self.using_hub_dataset: + _params["audio_column"] = "autotrain_audio" + _params["target_column"] = "autotrain_label" + _params["valid_split"] = "validation" + else: + _params["audio_column"] = 
self.column_mapping.get("audio" if not self.api else "audio_column", "audio") + _params["target_column"] = self.column_mapping.get("label" if not self.api else "target_column", "label") + _params["train_split"] = self.train_split + _params["valid_split"] = self.valid_split + return AudioSegmentationParams(**_params) + def get_task_params(task, param_type): """ @@ -735,5 +803,65 @@ def get_task_params(task, param_type): "early_stopping_threshold", ] task_params = {k: v for k, v in task_params.items() if k not in more_hidden_params} + if task == "audio-classification" and param_type == "basic": + more_hidden_params = [ + "warmup_ratio", + "weight_decay", + "max_grad_norm", + "seed", + "logging_steps", + "auto_find_batch_size", + "save_total_limit", + "eval_strategy", + "early_stopping_patience", + "early_stopping_threshold", + "feature_extractor_normalize", + "feature_extractor_return_attention_mask", + "gradient_accumulation", + "max_length", + "sampling_rate", + ] + task_params = {k: v for k, v in task_params.items() if k not in more_hidden_params} + if task == "audio-segmentation" and param_type == "basic": + more_hidden_params = [ + "warmup_ratio", + "weight_decay", + "max_grad_norm", + "seed", + "logging_steps", + "auto_find_batch_size", + "save_total_limit", + "eval_strategy", + "early_stopping_patience", + "early_stopping_threshold", + "feature_extractor_normalize", + "feature_extractor_return_attention_mask", + "gradient_accumulation", + "max_length", + "sampling_rate", + "segment_length", + "overlap_length", + "min_segment_length", + ] + task_params = {k: v for k, v in task_params.items() if k not in more_hidden_params} + if task == "audio-detection" and param_type == "basic": + more_hidden_params = [ + "warmup_ratio", + "weight_decay", + "max_grad_norm", + "seed", + "logging_steps", + "auto_find_batch_size", + "save_total_limit", + "eval_strategy", + "early_stopping_patience", + "early_stopping_threshold", + "gradient_accumulation", + "max_length", + "sampling_rate", + "event_overlap_threshold", + "confidence_threshold", + ] + task_params = {k: v for k, v in task_params.items() if k not in more_hidden_params} return task_params diff --git a/src/autotrain/app/templates/index.html b/src/autotrain/app/templates/index.html index 0ee5226c9d..c513389960 100644 --- a/src/autotrain/app/templates/index.html +++ b/src/autotrain/app/templates/index.html @@ -84,6 +84,18 @@ fields = ['image', 'label']; fieldNames = ['image', 'target']; break; + case 'audio-classification': + fields = ['audio', 'label']; + fieldNames = ['audio_path', 'intent']; + break; + case 'audio-segmentation': + fields = ['audio', 'label']; + fieldNames = ['audio_path', 'segments']; + break; + case 'audio-detection': + fields = ['audio', 'events']; + fieldNames = ['audio_path', 'events']; + break; case 'image-object-detection': fields = ['image', 'objects']; fieldNames = ['image', 'objects']; @@ -222,6 +234,11 @@ + + + + + @@ -678,7 +695,7 @@

Dataset Vi