Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 45 additions & 0 deletions CONTRIBUTING.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
## 🧱 Contributing
We welcome contributions to RingTool! If you have suggestions, bug reports, or feature requests, please open an issue or submit a pull request.

Some example contributions include:
* Adding new algorithms or models.
* Improving documentation or examples.
* Enhancing performance or usability.
* Fixing bugs or issues.
* Adding new datasets or benchmarks.
* Improving the configuration system.

### Add a new supervised model
1. Create a new file in the [`nets`](nets) directory, e.g., `nets/new_model.py`.
2. Append the new model registration to the [`constants/model.py`](constants/model.py) like below:

```python
from enum import Enum

from nets.inception_time import InceptionTime
from nets.mamba2 import RingToolMamba
from nets.resnet import ResNet1D
from nets.transformer import RingToolBERT

# Import your new model
from nets.new_model import NewModel


class SupportedSupervisedModels(Enum):
RESNET = "resnet"
INCEPTION_TIME = "inception_time"
TRANSFORMER = "transformer"
MAMBA2 = "mamba2"
NEW_MODEL = "new_model" # Add your new model here


MODEL_CLASSES = {
SupportedSupervisedModels.RESNET: ResNet1D,
SupportedSupervisedModels.INCEPTION_TIME: InceptionTime,
SupportedSupervisedModels.TRANSFORMER: RingToolBERT,
SupportedSupervisedModels.MAMBA2: RingToolMamba,
SupportedSupervisedModels.NEW_MODEL: NewModel, # Add your new model here
}
```
3. Add logic to the [`main.py`](main.py) to use the model in the following training and evaluation process.
4. Add the model to the configuration files in the [`config`](config) directory. You can refer to the existing models for examples.
50 changes: 50 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,56 @@ python3 main.py --data-path <replace-with-your-data-path> --config config/superv
If you want to integrate Slack bot notifications, you can add the `--send-notification-slack` argument to the command. This will send notifications to a specified Slack channel when the training ends. See Also [How to slack training bot](notifications/README.md).



#### Physiological Range Filtering
The training pipeline includes physiological range filtering to improve model quality by removing samples with physiologically impossible values. Filtering is applied during training and validation phases, but not during testing to evaluate true model performance.

**Physiological ranges defined:**
- Heart rate (HR): 40-200 bpm
- Respiratory rate (RR): 4-30 breaths/min
- Blood oxygen saturation (SpO2): 75-100%
- Systolic blood pressure (SBP): 60-260 mmHg
- Diastolic blood pressure (DBP): 30-200 mmHg

The filtering is automatically applied in the standard training pipeline. Custom ranges can be specified via the `physiological_filter()` function in [`utils/utils.py`](utils/utils.py).

#### Detailed Prediction Pairs
To save detailed prediction pairs during training for subsequent analysis, use the `--save-predictions` flag:
```sh
python3 main.py --data-path <path> --config <config.json> --save-predictions
```

This generates CSV files in `predictions/<exp_name>/<fold>.csv` with the following format:
```csv
prediction,target,task,fold,exp_name
75.68,78.0,hr,Fold-1,resnet-ring1-hr-all-ir
82.31,80.0,hr,Fold-1,resnet-ring1-hr-all-ir
```

#### Batch Testing with Metadata
For batch testing all trained models and generating predictions with complete metadata (subject ID, scenario, timestamps), use:
```sh
python3 test_all_models_predictions.py
```

This script:
- Automatically finds all trained models in the `models/` directory
- Tests each model across all folds
- Generates detailed predictions with metadata:
```csv
prediction,target,subject_id,scenario,start_time,end_time,task,exp_name
76.68,101.73,00023,sitting,1742822979.47,1742823009.47,hr,resnet-ring1-hr-all-ir
```
- Saves results to `predictions/<exp_name>/<fold>.csv`

**Testing specific models:**
```sh
python3 test_all_models_predictions.py --models resnet-ring1-hr-all-ir inception-time-ring1-hr-all-irred
```

This is particularly useful for paper reproduction and detailed performance analysis across subjects and scenarios.


## 🧱 Contributing
We welcome contributions to RingTool! If you have suggestions, bug reports, or feature requests, please open an issue or submit a pull request.

Expand Down
2 changes: 1 addition & 1 deletion dataset/preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def welch_spectrum(x, fs, window='hann', nperseg=None, noverlap=None, nfft=None,

# Use appropriate nfft
if nfft is None:
nfft = max(256, 2 ** int(np.log2(len(x))))
nfft = np.maximum(256, 2 ** int(np.log2(len(x))))

try:
f, Pxx = welch(x, fs, window=window, nperseg=nperseg, noverlap=noverlap, nfft=nfft)
Expand Down
20 changes: 19 additions & 1 deletion main.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
setup_slack,
)
from trainer.load_trainer import load_trainer
from utils.utils import calculate_avg_metrics, save_metrics_to_csv
from utils.utils import calculate_avg_metrics, save_metrics_to_csv, save_prediction_pairs_detailed


def generate_split_config(mode: str, split: Dict) -> List[Dict]:
Expand Down Expand Up @@ -242,6 +242,21 @@ def supervised(config: Dict, data_path: str) -> List[Tuple[str, str, Dict]]:
all_preds_and_targets.append(preds_and_targets)

all_test_results.append((split_config["fold"], task, test_results))

# Save prediction pairs when --save-predictions flag is enabled
if config.get('_save_predictions_', False):
exp_name = config.get("exp_name", "unknown")
csv_path = os.path.join("predictions", exp_name, f"{split_config['fold']}.csv")
preds, targets = preds_and_targets
save_prediction_pairs_detailed(
preds=preds,
targets=targets,
save_path=csv_path,
metadata=None, # main.py doesn't collect metadata
task=task,
fold=split_config["fold"],
exp_name=exp_name
)

metrics = calculate_avg_metrics(all_preds_and_targets)
logging.critical(f"Average metrics across all tasks: "
Expand Down Expand Up @@ -339,6 +354,7 @@ def do_run_experiment(config: Dict, data_path: str, send_notification_slack=Fals
parser = argparse.ArgumentParser(description='RingTool.')
parser.add_argument('--data-path', type=str, default=None, help='Path to the data folder.')
parser.add_argument('--send-notification-slack', action="store_true", help='Send notification to slack.')
parser.add_argument('--save-predictions', action="store_true", help='Save detailed prediction pairs to predictions directory.')

# --- Group for mutually exclusive config options ---
group = parser.add_mutually_exclusive_group(required=False) # Make the group itself not strictly required initially
Expand All @@ -358,6 +374,7 @@ def do_run_experiment(config: Dict, data_path: str, send_notification_slack=Fals
batch_configs_dirs = args.batch_configs_dirs # This will be a list of paths or None
single_config_path = args.config
send_notification_slack = args.send_notification_slack
save_predictions = args.save_predictions

config_files_to_run = []

Expand Down Expand Up @@ -414,6 +431,7 @@ def do_run_experiment(config: Dict, data_path: str, send_notification_slack=Fals

# Add config path to config dict for potential logging inside do_run_experiment
config['_config_path_'] = config_file_path
config['_save_predictions_'] = save_predictions # Pass flag to experiment

do_run_experiment(config, data_path, send_notification_slack)

Expand Down
Loading