From 2e57d0e22a743af8f1ab6e79df7df104299ae952 Mon Sep 17 00:00:00 2001 From: Adam Page Date: Thu, 10 Oct 2024 21:46:33 +0000 Subject: [PATCH 1/3] feat: Allow providing custom label weights. feat: Filter out class labels below threshold when stratified. --- heartkit/datasets/lsad.py | 23 +++++++++++++++++++++++ heartkit/datasets/ptbxl.py | 22 ++++++++++++++++++++++ heartkit/defines.py | 2 +- heartkit/tasks/diagnostic/train.py | 6 +++++- heartkit/tasks/rhythm/train.py | 6 +++++- heartkit/tasks/segmentation/train.py | 6 +++++- 6 files changed, 61 insertions(+), 4 deletions(-) diff --git a/heartkit/datasets/lsad.py b/heartkit/datasets/lsad.py index 67518cd..e91e83e 100644 --- a/heartkit/datasets/lsad.py +++ b/heartkit/datasets/lsad.py @@ -422,6 +422,7 @@ def split_train_test_patients( test_size: float, label_map: dict[int, int] | None = None, label_type: str | None = None, + label_threshold: int | None = 2, ) -> list[list[int]]: """Perform train/test split on patients for given task. NOTE: We only perform inter-patient splits and not intra-patient. @@ -431,6 +432,7 @@ def split_train_test_patients( test_size (float): Test size label_map (dict[int, int], optional): Label map. Defaults to None. label_type (str, optional): Label type. Defaults to None. + label_threshold (int, optional): Label threshold. Defaults to 2. Returns: list[list[int]]: Train and test sets of patient ids @@ -440,16 +442,37 @@ def split_train_test_patients( patients_labels = self.get_patients_labels(patient_ids, label_map=label_map, label_type=label_type) # Select random label for stratification or -1 if no labels stratify = np.array([random.choice(x) if len(x) > 0 else -1 for x in patients_labels]) + + # Remove patients w/ label counts below threshold + for i, label in enumerate(sorted(set(label_map.values()))): + class_counts = np.sum(stratify == label) + if label_threshold is not None and class_counts < label_threshold: + stratify[stratify == label] = -1 + logger.warning(f"Removed class {label} w/ only {class_counts} samples") + # END IF + # END FOR + # Remove patients w/o labels neg_mask = stratify == -1 stratify = stratify[~neg_mask] patient_ids = patient_ids[~neg_mask] + num_neg = neg_mask.sum() if num_neg > 0: logger.debug(f"Removed {num_neg} patients w/ no target class") # END IF # END IF + # Get occurence of each class along with class index + if stratify is not None: + class_counts = np.zeros(len(label_map), dtype=np.int32) + logger.debug(f"[{self.name}] Stratify class counts:") + for i, label in enumerate(sorted(set(label_map.values()))): + class_counts = np.sum(stratify == label) + logger.debug(f"Class {label}: {class_counts}") + # END FOR + # END IF + return sklearn.model_selection.train_test_split( patient_ids, test_size=test_size, diff --git a/heartkit/datasets/ptbxl.py b/heartkit/datasets/ptbxl.py index 68d6764..50e9079 100644 --- a/heartkit/datasets/ptbxl.py +++ b/heartkit/datasets/ptbxl.py @@ -506,6 +506,7 @@ def split_train_test_patients( test_size: float, label_map: dict[int, int] | None = None, label_type: str | None = None, + label_threshold: int | None = 2, ) -> list[list[int]]: """Perform train/test split on patients for given task. NOTE: We only perform inter-patient splits and not intra-patient. @@ -515,6 +516,7 @@ def split_train_test_patients( test_size (float): Test size label_map (dict[int, int], optional): Label map. Defaults to None. label_type (str, optional): Label type. Defaults to None. + label_threshold (int, optional): Label threshold. Defaults to 2. Returns: list[list[int]]: Train and test sets of patient ids @@ -524,6 +526,16 @@ def split_train_test_patients( patients_labels = self.get_patients_labels(patient_ids, label_map=label_map, label_type=label_type) # Select random label for stratification or -1 if no labels stratify = np.array([random.choice(x) if len(x) > 0 else -1 for x in patients_labels]) + + # Remove patients w/ label counts below threshold + for i, label in enumerate(sorted(set(label_map.values()))): + class_counts = np.sum(stratify == label) + if label_threshold is not None and class_counts < label_threshold: + stratify[stratify == label] = -1 + logger.warning(f"Removed class {label} w/ only {class_counts} samples") + # END IF + # END FOR + # Remove patients w/o labels neg_mask = stratify == -1 stratify = stratify[~neg_mask] @@ -534,6 +546,16 @@ def split_train_test_patients( # END IF # END IF + # Get occurence of each class along with class index + if stratify is not None: + class_counts = np.zeros(len(label_map), dtype=np.int32) + logger.debug(f"[{self.name}] Stratify class counts:") + for i, label in enumerate(sorted(set(label_map.values()))): + class_counts = np.sum(stratify == label) + logger.debug(f"Class {label}: {class_counts}") + # END FOR + # END IF + return sklearn.model_selection.train_test_split( patient_ids, test_size=test_size, diff --git a/heartkit/defines.py b/heartkit/defines.py index e3fde95..ce0db77 100644 --- a/heartkit/defines.py +++ b/heartkit/defines.py @@ -118,7 +118,7 @@ class HKTaskParams(BaseModel, extra="allow"): steps_per_epoch: int = Field(10, description="Number of steps per epoch") val_steps_per_epoch: int = Field(10, description="Number of validation steps") val_metric: Literal["loss", "acc", "f1"] = Field("loss", description="Performance metric") - class_weights: Literal["balanced", "fixed"] = Field("fixed", description="Class weights") + class_weights: list[float] | str = Field("fixed", description="Class weights") # Evaluation arguments threshold: float | None = Field(None, description="Model output threshold") diff --git a/heartkit/tasks/diagnostic/train.py b/heartkit/tasks/diagnostic/train.py index 28f9c0b..6e3f456 100644 --- a/heartkit/tasks/diagnostic/train.py +++ b/heartkit/tasks/diagnostic/train.py @@ -58,7 +58,11 @@ def train(params: HKTaskParams): val_ds.save(str(params.val_file)) class_weights = 0.25 - if params.class_weights == "balanced": + if isinstance(params.class_weights, list): + class_weights = np.array(params.class_weights) + class_weights = class_weights / class_weights.sum() + class_weights = class_weights.tolist() + elif params.class_weights == "balanced": n_samples = np.sum(y_true) class_weights = n_samples / (params.num_classes * np.sum(y_true, axis=0)) # class_weights = (class_weights + class_weights.mean()) / 2 # Smooth out diff --git a/heartkit/tasks/rhythm/train.py b/heartkit/tasks/rhythm/train.py index a63d0bf..35489ba 100644 --- a/heartkit/tasks/rhythm/train.py +++ b/heartkit/tasks/rhythm/train.py @@ -59,7 +59,11 @@ def train(params: HKTaskParams): val_ds.save(str(params.val_file)) class_weights = 0.25 - if params.class_weights == "balanced": + if isinstance(params.class_weights, list): + class_weights = np.array(params.class_weights) + class_weights = class_weights / class_weights.sum() + class_weights = class_weights.tolist() + elif params.class_weights == "balanced": class_weights = sklearn.utils.compute_class_weight("balanced", classes=np.array(classes), y=y_true) class_weights = (class_weights + class_weights.mean()) / 2 # Smooth out class_weights = class_weights.tolist() diff --git a/heartkit/tasks/segmentation/train.py b/heartkit/tasks/segmentation/train.py index b3c50ae..ea69759 100644 --- a/heartkit/tasks/segmentation/train.py +++ b/heartkit/tasks/segmentation/train.py @@ -61,7 +61,11 @@ def train(params: HKTaskParams): val_ds.save(str(params.val_file)) class_weights = 0.25 - if params.class_weights == "balanced": + if isinstance(params.class_weights, list): + class_weights = np.array(params.class_weights) + class_weights = class_weights / class_weights.sum() + class_weights = class_weights.tolist() + elif params.class_weights == "balanced": class_weights = sklearn.utils.compute_class_weight("balanced", classes=np.array(classes), y=y_true) class_weights = (class_weights + class_weights.mean()) / 2 # Smooth out class_weights = class_weights.tolist() From eb293e234913aabc49fbe23d71322c1f73e2131f Mon Sep 17 00:00:00 2001 From: Adam Page Date: Fri, 22 Nov 2024 22:57:24 +0000 Subject: [PATCH 2/3] feat: Add vscode extensions to dev container. --- .devcontainer/devcontainer.json | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json index 17b6925..c519ffe 100644 --- a/.devcontainer/devcontainer.json +++ b/.devcontainer/devcontainer.json @@ -29,5 +29,22 @@ "LD_LIBRARY_PATH": "${containerEnv:LD_LIBRARY_PATH}:/usr/local/cuda/lib64", "PATH": "${containerEnv:PATH}:/usr/local/cuda/bin", "TF_FORCE_GPU_ALLOW_GROWTH": "true" + }, + + "customizations": { + "vscode": { + "extensions": [ + "GitHub.copilot", + "GitHub.copilot-chat", + "ms-toolsai.jupyter", + "ms-toolsai.jupyter-renderers", + "ms-toolsai.tensorboard", + "tamasfe.even-better-toml", + "mechatroner.rainbow-csv", + "ms-python.python", + "charliermarsh.ruff" + ] + } } + } From c7bf15b52237411d174f0ea0f24b2bc211b1449f Mon Sep 17 00:00:00 2001 From: Adam Page Date: Tue, 28 Jan 2025 15:29:47 +0000 Subject: [PATCH 3/3] docs: Update hardware demo links. --- docs/guides/index.md | 3 +-- docs/guides/train-arrhythmia-model.ipynb | 3 +-- heartkit/tasks/denoise/datasets.py | 2 ++ mkdocs.yml | 5 +---- notebooks/train-arrhythmia-model.ipynb | 3 +-- 5 files changed, 6 insertions(+), 10 deletions(-) diff --git a/docs/guides/index.md b/docs/guides/index.md index 44ba7ed..4cf36c3 100644 --- a/docs/guides/index.md +++ b/docs/guides/index.md @@ -16,5 +16,4 @@ This section contains guides to help with various aspects of HeartKit. The guide ## Hardware Guides -- **[Run simple demo on EVB]()**: Running a demo using Ambiq SoC as backend inference engine. -- **[HeartKit Tileio Demo](https://ambiqai.github.io/tileio-docs/demos/heartkit/)**: A guide to running a multi-headed model demo on Ambiq EVB. +Several guides are available for running HeartKit models on Ambiq evaluation boards (EVBs) with ECG/PPG sensors connected via the Tileio App. Please check out the [Tileio Demos Page](https://ambiqai.github.io/tileio-docs/demos/) for more information. diff --git a/docs/guides/train-arrhythmia-model.ipynb b/docs/guides/train-arrhythmia-model.ipynb index fbaef42..4af5d84 100644 --- a/docs/guides/train-arrhythmia-model.ipynb +++ b/docs/guides/train-arrhythmia-model.ipynb @@ -174,7 +174,6 @@ "## Preprocess pipeline\n", "\n", "We will preprocess the ECG signals by applying the following steps:\n", - "* Apply bandpass filter with cutoff frequencies of 1Hz and 30Hz\n", "* Apply Z-score normalization w/ epsilon to avoid division by zero\n", "\n", "The task accepts a list of preprocessing functions that will be applied to the input data. " @@ -917,7 +916,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.12.5" + "version": "3.11.10" } }, "nbformat": 4, diff --git a/heartkit/tasks/denoise/datasets.py b/heartkit/tasks/denoise/datasets.py index faa4b8f..dd1cd66 100644 --- a/heartkit/tasks/denoise/datasets.py +++ b/heartkit/tasks/denoise/datasets.py @@ -43,6 +43,8 @@ def create_data_pipeline( drop_remainder=True, num_parallel_calls=tf.data.AUTOTUNE, ) + # ds = ds.map(lambda x: preprocessor(x), num_parallel_calls=tf.data.AUTOTUNE) + # ds = ds.map(lambda x: (augmenter(x), x), num_parallel_calls=tf.data.AUTOTUNE) ds = ds.map(lambda x: (augmenter(x), x), num_parallel_calls=tf.data.AUTOTUNE) ds = ds.map(lambda x, y: (preprocessor(x), preprocessor(y)), num_parallel_calls=tf.data.AUTOTUNE) return ds.prefetch(tf.data.AUTOTUNE) diff --git a/mkdocs.yml b/mkdocs.yml index 598641e..71943ce 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -98,10 +98,7 @@ nav: - Train ECG Denoiser: guides/train-ecg-denoiser.ipynb - ECG Foundation Model: guides/ecg-foundation-model.ipynb - ECG Segmentation Model: guides/train-ecg-segmentation.ipynb - - Hardware Guides: - - EVB Setup: guides/evb-setup.md - - Rhythm Demo: guides/rhythm-demo.md - - HeartKit Tileio Demo →: https://ambiqai.github.io/tileio-docs/demos/heartkit/ + - Hardware Guides →: https://ambiqai.github.io/tileio-docs/demos/ - API: api/ diff --git a/notebooks/train-arrhythmia-model.ipynb b/notebooks/train-arrhythmia-model.ipynb index fbaef42..4af5d84 100644 --- a/notebooks/train-arrhythmia-model.ipynb +++ b/notebooks/train-arrhythmia-model.ipynb @@ -174,7 +174,6 @@ "## Preprocess pipeline\n", "\n", "We will preprocess the ECG signals by applying the following steps:\n", - "* Apply bandpass filter with cutoff frequencies of 1Hz and 30Hz\n", "* Apply Z-score normalization w/ epsilon to avoid division by zero\n", "\n", "The task accepts a list of preprocessing functions that will be applied to the input data. " @@ -917,7 +916,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.12.5" + "version": "3.11.10" } }, "nbformat": 4,