diff --git a/examples/benchmarking_tutorial.ipynb b/examples/benchmarking_tutorial.ipynb new file mode 100644 index 0000000..67dc521 --- /dev/null +++ b/examples/benchmarking_tutorial.ipynb @@ -0,0 +1,479 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "50e8d179", + "metadata": {}, + "source": [ + "# Benchmarking aptamer evaluation algorithms\n", + "Step-by-step guide for benchmarking aptamer pipelines with the `Benchmarking` class." + ] + }, + { + "cell_type": "markdown", + "id": "40a64683", + "metadata": {}, + "source": [ + "## Overview\n", + "This notebook introduces the `Benchmarking` class, a utility for systematically comparing machine learning estimators on a given dataset using cross-validation. It is designed to streamline model evaluation across multiple metrics and provide results in a unified, interpretable format.\n", + "\n", + "The output is a summary table that makes it easy to compare different models and metrics at a glance." + ] + }, + { + "cell_type": "markdown", + "id": "fef8f486", + "metadata": {}, + "source": [ + "## Data preparation\n", + "To train the `AptaNetPipeline` and `AptaTransPipeline`, the notebook uses the dataset that was used to train the `AptaTrans` algorithm. This dataset can be found in `pyaptamer/datasets/data/train_li2014`." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "a2f72544", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "c:\\Users\\satvm\\miniconda3\\envs\\pyaptamer-latest\\Lib\\site-packages\\tqdm\\auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. 
See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n" + ] + } + ], + "source": [ + "# Data imports\n", + "import numpy as np\n", + "\n", + "from pyaptamer.datasets import load_csv_dataset" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "27e4d060", + "metadata": {}, + "outputs": [], + "source": [ + "# Load full dataset\n", + "df = load_csv_dataset(\"train_li2014\")\n", + "\n", + "# Separate features and label column\n", + "X_raw = df.drop(columns=[\"label\"])\n", + "y_raw = df[\"label\"]\n", + "\n", + "# Build combinations (aptamer, protein)\n", + "# assuming the first two columns are aptamer and protein\n", + "X = list(zip(X_raw.iloc[:, 0], X_raw.iloc[:, 1], strict=False))[:100]\n", + "\n", + "# Binary labels\n", + "y = np.where(y_raw == \"positive\", 1, 0)[:100]" + ] + }, + { + "cell_type": "markdown", + "id": "113d9522", + "metadata": {}, + "source": [ + "## Different workflows\n", + "`Benchmarking` offers two main workflows, depending on how you want to use `cv` (cross-validation) in your benchmarking experiment:\n", + "1. Using normal k-fold cross-validation\n", + "2. Using `PredefinedSplit` to create a fixed train/test split" + ] + }, + { + "cell_type": "markdown", + "id": "18b3abd2", + "metadata": {}, + "source": [ + "### 1. 
Using normal k-fold cross validation for benchmarking" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "9b8646b1", + "metadata": {}, + "outputs": [], + "source": [ + "from sklearn.metrics import accuracy_score\n", + "from sklearn.model_selection import KFold\n", + "\n", + "from pyaptamer.aptanet import AptaNetPipeline\n", + "from pyaptamer.benchmarking import Benchmarking" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "adaccf95", + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "\n", + "from pyaptamer.aptatrans import (\n", + " AptaTrans,\n", + " AptaTransPipeline,\n", + " EncoderPredictorConfig,\n", + ")\n", + "from pyaptamer.datasets import (\n", + " load_csv_dataset,\n", + ")\n", + "from pyaptamer.utils._base import filter_words\n", + "\n", + "# setup device\n", + "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", + "\n", + "# auto-reloading external modules\n", + "%load_ext autoreload\n", + "%autoreload 2" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "7235ca71", + "metadata": {}, + "outputs": [], + "source": [ + "BATCH_SIZE = 16\n", + "TEST_SIZE = 0.05 # size of the test set for pretraining\n", + "RAMDOM_STATE = 42 # for reproducibility\n", + "\n", + "# embeddings for pretraining\n", + "# aptamers\n", + "N_APTA_VOCABS = 127\n", + "N_APTA_TARGET_VOCABS = 344\n", + "APTA_MAX_LEN = 275\n", + "# proteins\n", + "N_PROT_VOCABS = 715\n", + "N_PROT_TARGET_VOCABS = 585\n", + "PROT_MAX_LEN = 867" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "534124f9", + "metadata": {}, + "outputs": [], + "source": [ + "prot_words = load_csv_dataset(\n", + " name=\"protein_word_freq\", keep_default_na=False, na_values=[\"_\"]\n", + ") # dict for each protein word in ds gives freq\n", + "prot_words = prot_words.set_index(\"seq\")[\"freq\"].to_dict()\n", + "\n", + "filtered_prot_words = filter_words(prot_words)\n", + "\n", + "# (1.) 
load the api dataset for fine-tuning\n", + "# train_dataset = load_csv_dataset(name=\"train_li2014\")\n", + "# test_dataset = load_csv_dataset(name=\"test_li2014\")\n", + "\n", + "# # (2.) create the API dataset\n", + "# train_dataset = APIDataset(\n", + "# x_apta=train_dataset[\"aptamer\"].to_numpy(),\n", + "# x_prot=train_dataset[\"protein\"].to_numpy(),\n", + "# y=train_dataset[\"label\"].to_numpy(),\n", + "# apta_max_len=APTA_MAX_LEN,\n", + "# prot_max_len=PROT_MAX_LEN,\n", + "# prot_words=filtered_prot_words,\n", + "# )\n", + "# test_dataset = APIDataset(\n", + "# x_apta=test_dataset[\"aptamer\"].to_numpy(),\n", + "# x_prot=test_dataset[\"protein\"].to_numpy(),\n", + "# y=test_dataset[\"label\"].to_numpy(),\n", + "# apta_max_len=APTA_MAX_LEN,\n", + "# prot_max_len=PROT_MAX_LEN,\n", + "# prot_words=filtered_prot_words,\n", + "# split=\"test\",\n", + "# )\n", + "\n", + "# # (3.) create dataloaders\n", + "# train_dataloader = DataLoader(\n", + "# train_dataset,\n", + "# batch_size=BATCH_SIZE,\n", + "# shuffle=True,\n", + "# )\n", + "# test_dataloader = DataLoader(\n", + "# test_dataset,\n", + "# batch_size=BATCH_SIZE,\n", + "# shuffle=True,\n", + "# )" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "6440dc3b", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "c:\\Users\\satvm\\miniconda3\\envs\\pyaptamer-latest\\Lib\\site-packages\\torch\\nn\\modules\\transformer.py:382: UserWarning: enable_nested_tensor is True, but self.use_nested_tensor is False because encoder_layer.self_attn.num_heads is odd\n", + " warnings.warn(\n" + ] + } + ], + "source": [ + "apta_embedding = EncoderPredictorConfig(\n", + " num_embeddings=N_APTA_VOCABS,\n", + " target_dim=N_APTA_TARGET_VOCABS,\n", + " max_len=APTA_MAX_LEN,\n", + ")\n", + "prot_embedding = EncoderPredictorConfig(\n", + " num_embeddings=N_PROT_VOCABS,\n", + " target_dim=N_PROT_TARGET_VOCABS,\n", + " max_len=PROT_MAX_LEN,\n", + ")\n", + "model = 
AptaTrans(\n", + " apta_embedding=apta_embedding,\n", + " prot_embedding=prot_embedding,\n", + " in_dim=128,\n", + " n_encoder_layers=1,\n", + " n_heads=1,\n", + " dropout=0.1,\n", + ").to(device)" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "2a8f88ef", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + " ----- Round: 1 -----\n" + ] + }, + { + "ename": "RuntimeError", + "evalue": "The size of tensor a (0) must match the size of tensor b (128) at non-singleton dimension 1", + "output_type": "error", + "traceback": [ + "\u001b[31m---------------------------------------------------------------------------\u001b[39m", + "\u001b[31mRuntimeError\u001b[39m Traceback (most recent call last)", + "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[8]\u001b[39m\u001b[32m, line 15\u001b[39m\n\u001b[32m 2\u001b[39m target_protein = (\n\u001b[32m 3\u001b[39m \u001b[33m\"\u001b[39m\u001b[33mSTEYKLVVVGADGVGKSALTIQLIQNHFVDEYDPTIEDSYRKQVVIDGETCLLDILDTAGQEEYSAM\u001b[39m\u001b[33m\"\u001b[39m\n\u001b[32m 4\u001b[39m \u001b[33m\"\u001b[39m\u001b[33mRDQYMRTGEGFLCVFAINNTKSFEDIHHYREQIKRVKDSEDVPMVLVGNKCDLPSRTVDTKQAQDLARSYGIPFIETSAKTR\u001b[39m\u001b[33m\"\u001b[39m\n\u001b[32m 5\u001b[39m \u001b[33m\"\u001b[39m\u001b[33mQGVDDAFYTLVREIRKHKEKMSK\u001b[39m\u001b[33m\"\u001b[39m\n\u001b[32m 6\u001b[39m )\n\u001b[32m 8\u001b[39m pipeline = AptaTransPipeline(\n\u001b[32m 9\u001b[39m device=device,\n\u001b[32m 10\u001b[39m model=model,\n\u001b[32m (...)\u001b[39m\u001b[32m 13\u001b[39m n_iterations=\u001b[32m1\u001b[39m, \u001b[38;5;66;03m# higher is better but slower, suggested: 1000\u001b[39;00m\n\u001b[32m 14\u001b[39m )\n\u001b[32m---> \u001b[39m\u001b[32m15\u001b[39m candidates = \u001b[43mpipeline\u001b[49m\u001b[43m.\u001b[49m\u001b[43mrecommend\u001b[49m\u001b[43m(\u001b[49m\n\u001b[32m 16\u001b[39m \u001b[43m 
\u001b[49m\u001b[43mtarget\u001b[49m\u001b[43m=\u001b[49m\u001b[43mtarget_protein\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 17\u001b[39m \u001b[43m \u001b[49m\u001b[43mn_candidates\u001b[49m\u001b[43m=\u001b[49m\u001b[32;43m1\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;66;43;03m# number of candidates to generate\u001b[39;49;00m\n\u001b[32m 18\u001b[39m \u001b[43m \u001b[49m\u001b[43mverbose\u001b[49m\u001b[43m=\u001b[49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[32m 19\u001b[39m \u001b[43m)\u001b[49m\n", + "\u001b[36mFile \u001b[39m\u001b[32mc:\\Users\\satvm\\miniconda3\\envs\\pyaptamer-latest\\Lib\\site-packages\\torch\\utils\\_contextlib.py:116\u001b[39m, in \u001b[36mcontext_decorator..decorate_context\u001b[39m\u001b[34m(*args, **kwargs)\u001b[39m\n\u001b[32m 113\u001b[39m \u001b[38;5;129m@functools\u001b[39m.wraps(func)\n\u001b[32m 114\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34mdecorate_context\u001b[39m(*args, **kwargs):\n\u001b[32m 115\u001b[39m \u001b[38;5;28;01mwith\u001b[39;00m ctx_factory():\n\u001b[32m--> \u001b[39m\u001b[32m116\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[43m*\u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", + "\u001b[36mFile \u001b[39m\u001b[32m~\\pyaptamer\\pyaptamer\\aptatrans\\_pipeline.py:315\u001b[39m, in \u001b[36mAptaTransPipeline.recommend\u001b[39m\u001b[34m(self, target, n_candidates, verbose)\u001b[39m\n\u001b[32m 313\u001b[39m candidates = \u001b[38;5;28mset\u001b[39m()\n\u001b[32m 314\u001b[39m \u001b[38;5;28;01mwhile\u001b[39;00m \u001b[38;5;28mlen\u001b[39m(candidates) < n_candidates:\n\u001b[32m--> \u001b[39m\u001b[32m315\u001b[39m candidate = 
\u001b[43mmcts\u001b[49m\u001b[43m.\u001b[49m\u001b[43mrun\u001b[49m\u001b[43m(\u001b[49m\u001b[43mverbose\u001b[49m\u001b[43m=\u001b[49m\u001b[43mverbose\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 316\u001b[39m candidates.add(\u001b[38;5;28mtuple\u001b[39m(candidate.values()))\n\u001b[32m 318\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m verbose:\n", + "\u001b[36mFile \u001b[39m\u001b[32m~\\pyaptamer\\pyaptamer\\mcts\\_algorithm.py:302\u001b[39m, in \u001b[36mMCTS.run\u001b[39m\u001b[34m(self, verbose)\u001b[39m\n\u001b[32m 299\u001b[39m node = \u001b[38;5;28mself\u001b[39m._expansion(node=node)\n\u001b[32m 301\u001b[39m \u001b[38;5;66;03m# simulation\u001b[39;00m\n\u001b[32m--> \u001b[39m\u001b[32m302\u001b[39m score = \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43m_simulation\u001b[49m\u001b[43m(\u001b[49m\u001b[43mnode\u001b[49m\u001b[43m=\u001b[49m\u001b[43mnode\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 304\u001b[39m \u001b[38;5;66;03m# backpropagation\u001b[39;00m\n\u001b[32m 305\u001b[39m node.backpropagate(score)\n", + "\u001b[36mFile \u001b[39m\u001b[32m~\\pyaptamer\\pyaptamer\\mcts\\_algorithm.py:245\u001b[39m, in \u001b[36mMCTS._simulation\u001b[39m\u001b[34m(self, node)\u001b[39m\n\u001b[32m 242\u001b[39m sequence += random.choice(\u001b[38;5;28mself\u001b[39m.states)\n\u001b[32m 244\u001b[39m \u001b[38;5;66;03m# evaluate the candidate sequence with the goal function\u001b[39;00m\n\u001b[32m--> \u001b[39m\u001b[32m245\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43mexperiment\u001b[49m\u001b[43m.\u001b[49m\u001b[43mevaluate\u001b[49m\u001b[43m(\u001b[49m\u001b[43msequence\u001b[49m\u001b[43m)\u001b[49m\n", + "\u001b[36mFile \u001b[39m\u001b[32mc:\\Users\\satvm\\miniconda3\\envs\\pyaptamer-latest\\Lib\\site-packages\\torch\\utils\\_contextlib.py:116\u001b[39m, in \u001b[36mcontext_decorator..decorate_context\u001b[39m\u001b[34m(*args, **kwargs)\u001b[39m\n\u001b[32m 
113\u001b[39m \u001b[38;5;129m@functools\u001b[39m.wraps(func)\n\u001b[32m 114\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34mdecorate_context\u001b[39m(*args, **kwargs):\n\u001b[32m 115\u001b[39m \u001b[38;5;28;01mwith\u001b[39;00m ctx_factory():\n\u001b[32m--> \u001b[39m\u001b[32m116\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[43m*\u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", + "\u001b[36mFile \u001b[39m\u001b[32m~\\pyaptamer\\pyaptamer\\experiments\\_aptamer_aptatrans.py:115\u001b[39m, in \u001b[36mAptamerEvalAptaTrans.evaluate\u001b[39m\u001b[34m(self, aptamer_candidate, return_interaction_map)\u001b[39m\n\u001b[32m 104\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m (\n\u001b[32m 105\u001b[39m \u001b[38;5;28mself\u001b[39m.model.forward_imap(\n\u001b[32m 106\u001b[39m aptamer_candidate.to(\u001b[38;5;28mself\u001b[39m.device),\n\u001b[32m (...)\u001b[39m\u001b[32m 111\u001b[39m .numpy()\n\u001b[32m 112\u001b[39m )\n\u001b[32m 113\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[32m 114\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m np.float64(\n\u001b[32m--> \u001b[39m\u001b[32m115\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m(\u001b[49m\n\u001b[32m 116\u001b[39m \u001b[43m \u001b[49m\u001b[43maptamer_candidate\u001b[49m\u001b[43m.\u001b[49m\u001b[43mto\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43mdevice\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 117\u001b[39m \u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43mtarget_encoded\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 118\u001b[39m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m.item()\n\u001b[32m 119\u001b[39m )\n", + "\u001b[36mFile 
\u001b[39m\u001b[32mc:\\Users\\satvm\\miniconda3\\envs\\pyaptamer-latest\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1751\u001b[39m, in \u001b[36mModule._wrapped_call_impl\u001b[39m\u001b[34m(self, *args, **kwargs)\u001b[39m\n\u001b[32m 1749\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m._compiled_call_impl(*args, **kwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[32m 1750\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[32m-> \u001b[39m\u001b[32m1751\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[43m*\u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", + "\u001b[36mFile \u001b[39m\u001b[32mc:\\Users\\satvm\\miniconda3\\envs\\pyaptamer-latest\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1762\u001b[39m, in \u001b[36mModule._call_impl\u001b[39m\u001b[34m(self, *args, **kwargs)\u001b[39m\n\u001b[32m 1757\u001b[39m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[32m 1758\u001b[39m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[32m 1759\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m._backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m._backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m._forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m._forward_pre_hooks\n\u001b[32m 1760\u001b[39m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[32m 1761\u001b[39m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m 
_global_forward_pre_hooks):\n\u001b[32m-> \u001b[39m\u001b[32m1762\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[43m*\u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 1764\u001b[39m result = \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[32m 1765\u001b[39m called_always_called_hooks = \u001b[38;5;28mset\u001b[39m()\n", + "\u001b[36mFile \u001b[39m\u001b[32m~\\pyaptamer\\pyaptamer\\aptatrans\\_model.py:376\u001b[39m, in \u001b[36mAptaTrans.forward\u001b[39m\u001b[34m(self, x_apta, x_prot)\u001b[39m\n\u001b[32m 359\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34mforward\u001b[39m(\u001b[38;5;28mself\u001b[39m, x_apta: Tensor, x_prot: Tensor) -> Tensor:\n\u001b[32m 360\u001b[39m \u001b[38;5;250m \u001b[39m\u001b[33;03m\"\"\"Forward pass.\u001b[39;00m\n\u001b[32m 361\u001b[39m \n\u001b[32m 362\u001b[39m \u001b[33;03m This methods performs a forward pass through the entire neural network, minus\u001b[39;00m\n\u001b[32m (...)\u001b[39m\u001b[32m 374\u001b[39m \u001b[33;03m Output tensor of shape (batch__size, 1) containing the model's predictions.\u001b[39;00m\n\u001b[32m 375\u001b[39m \u001b[33;03m \"\"\"\u001b[39;00m\n\u001b[32m--> \u001b[39m\u001b[32m376\u001b[39m out = \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43mforward_imap\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx_apta\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mx_prot\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 378\u001b[39m out = torch.squeeze(out, dim=\u001b[32m2\u001b[39m) \u001b[38;5;66;03m# remove extra dimension\u001b[39;00m\n\u001b[32m 379\u001b[39m out = \u001b[38;5;28mself\u001b[39m.gelu1(\u001b[38;5;28mself\u001b[39m.bn1(\u001b[38;5;28mself\u001b[39m.conv1(out)))\n", + "\u001b[36mFile 
\u001b[39m\u001b[32m~\\pyaptamer\\pyaptamer\\aptatrans\\_model.py:356\u001b[39m, in \u001b[36mAptaTrans.forward_imap\u001b[39m\u001b[34m(self, x_apta, x_prot)\u001b[39m\n\u001b[32m 339\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34mforward_imap\u001b[39m(\u001b[38;5;28mself\u001b[39m, x_apta: Tensor, x_prot: Tensor) -> Tensor:\n\u001b[32m 340\u001b[39m \u001b[38;5;250m \u001b[39m\u001b[33;03m\"\"\"Forward pass to compute the interaction map.\u001b[39;00m\n\u001b[32m 341\u001b[39m \n\u001b[32m 342\u001b[39m \u001b[33;03m This methods performs a forward pass through the encoders, minus the token\u001b[39;00m\n\u001b[32m (...)\u001b[39m\u001b[32m 354\u001b[39m \u001b[33;03m Interaction map tensor of shape (batch_size, 1, seq_len (s1), seq_len (s2)).\u001b[39;00m\n\u001b[32m 355\u001b[39m \u001b[33;03m \"\"\"\u001b[39;00m\n\u001b[32m--> \u001b[39m\u001b[32m356\u001b[39m x_apta, x_prot = \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43mencoder_apta\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx_apta\u001b[49m\u001b[43m)\u001b[49m, \u001b[38;5;28mself\u001b[39m.encoder_prot(x_prot)\n\u001b[32m 357\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m.imap(x_apta, x_prot)\n", + "\u001b[36mFile \u001b[39m\u001b[32mc:\\Users\\satvm\\miniconda3\\envs\\pyaptamer-latest\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1751\u001b[39m, in \u001b[36mModule._wrapped_call_impl\u001b[39m\u001b[34m(self, *args, **kwargs)\u001b[39m\n\u001b[32m 1749\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m._compiled_call_impl(*args, **kwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[32m 1750\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[32m-> \u001b[39m\u001b[32m1751\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m 
\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[43m*\u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", + "\u001b[36mFile \u001b[39m\u001b[32mc:\\Users\\satvm\\miniconda3\\envs\\pyaptamer-latest\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1762\u001b[39m, in \u001b[36mModule._call_impl\u001b[39m\u001b[34m(self, *args, **kwargs)\u001b[39m\n\u001b[32m 1757\u001b[39m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[32m 1758\u001b[39m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[32m 1759\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m._backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m._backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m._forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m._forward_pre_hooks\n\u001b[32m 1760\u001b[39m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[32m 1761\u001b[39m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[32m-> \u001b[39m\u001b[32m1762\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[43m*\u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 1764\u001b[39m result = \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[32m 1765\u001b[39m called_always_called_hooks = \u001b[38;5;28mset\u001b[39m()\n", + "\u001b[36mFile 
\u001b[39m\u001b[32m~\\pyaptamer\\pyaptamer\\aptatrans\\_model.py:225\u001b[39m, in \u001b[36mAptaTrans._make_encoder..\u001b[39m\u001b[34m(x)\u001b[39m\n\u001b[32m 221\u001b[39m encoder_module = nn.ModuleList([embedding, pos_encoding, encoder])\n\u001b[32m 223\u001b[39m \u001b[38;5;66;03m# pass the the encoder a padding mask to ignore zero-padded tokens\u001b[39;00m\n\u001b[32m 224\u001b[39m encoder_module.forward = \u001b[38;5;28;01mlambda\u001b[39;00m x: encoder(\n\u001b[32m--> \u001b[39m\u001b[32m225\u001b[39m \u001b[43mpos_encoding\u001b[49m\u001b[43m(\u001b[49m\u001b[43membedding\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m,\n\u001b[32m 226\u001b[39m src_key_padding_mask=(x == \u001b[32m0\u001b[39m), \u001b[38;5;66;03m# padding mask\u001b[39;00m\n\u001b[32m 227\u001b[39m )\n\u001b[32m 229\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m (encoder_module, token_predictor)\n", + "\u001b[36mFile \u001b[39m\u001b[32mc:\\Users\\satvm\\miniconda3\\envs\\pyaptamer-latest\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1751\u001b[39m, in \u001b[36mModule._wrapped_call_impl\u001b[39m\u001b[34m(self, *args, **kwargs)\u001b[39m\n\u001b[32m 1749\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m._compiled_call_impl(*args, **kwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[32m 1750\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[32m-> \u001b[39m\u001b[32m1751\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[43m*\u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", + "\u001b[36mFile \u001b[39m\u001b[32mc:\\Users\\satvm\\miniconda3\\envs\\pyaptamer-latest\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1762\u001b[39m, in 
\u001b[36mModule._call_impl\u001b[39m\u001b[34m(self, *args, **kwargs)\u001b[39m\n\u001b[32m 1757\u001b[39m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[32m 1758\u001b[39m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[32m 1759\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m._backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m._backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m._forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m._forward_pre_hooks\n\u001b[32m 1760\u001b[39m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[32m 1761\u001b[39m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[32m-> \u001b[39m\u001b[32m1762\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[43m*\u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 1764\u001b[39m result = \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[32m 1765\u001b[39m called_always_called_hooks = \u001b[38;5;28mset\u001b[39m()\n", + "\u001b[36mFile \u001b[39m\u001b[32m~\\pyaptamer\\pyaptamer\\aptatrans\\layers\\_encoder.py:85\u001b[39m, in \u001b[36mPositionalEncoding.forward\u001b[39m\u001b[34m(self, x)\u001b[39m\n\u001b[32m 68\u001b[39m \u001b[38;5;250m\u001b[39m\u001b[33;03m\"\"\"Forward pass.\u001b[39;00m\n\u001b[32m 69\u001b[39m \n\u001b[32m 70\u001b[39m \u001b[33;03mParameters\u001b[39;00m\n\u001b[32m (...)\u001b[39m\u001b[32m 79\u001b[39m \u001b[33;03m positional encodings applied.\u001b[39;00m\n\u001b[32m 80\u001b[39m 
\u001b[33;03m\"\"\"\u001b[39;00m\n\u001b[32m 81\u001b[39m \u001b[38;5;28;01massert\u001b[39;00m x.shape[\u001b[32m1\u001b[39m] <= \u001b[38;5;28mself\u001b[39m.max_len, (\n\u001b[32m 82\u001b[39m \u001b[33mf\u001b[39m\u001b[33m\"\u001b[39m\u001b[33mInput sequence length \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mx.shape[\u001b[32m1\u001b[39m]\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m exceeds maximum length \u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mself\u001b[39m.max_len\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m.\u001b[39m\u001b[33m\"\u001b[39m\n\u001b[32m 83\u001b[39m )\n\u001b[32m---> \u001b[39m\u001b[32m85\u001b[39m out = \u001b[43mx\u001b[49m\u001b[43m \u001b[49m\u001b[43m+\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43mpe\u001b[49m\u001b[43m[\u001b[49m\u001b[43m:\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m:\u001b[49m\u001b[43m \u001b[49m\u001b[43mx\u001b[49m\u001b[43m.\u001b[49m\u001b[43mshape\u001b[49m\u001b[43m[\u001b[49m\u001b[32;43m1\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m:\u001b[49m\u001b[43m]\u001b[49m\n\u001b[32m 86\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m.dropout:\n\u001b[32m 87\u001b[39m out = \u001b[38;5;28mself\u001b[39m.dropout(out)\n", + "\u001b[31mRuntimeError\u001b[39m: The size of tensor a (0) must match the size of tensor b (128) at non-singleton dimension 1" + ] + } + ], + "source": [ + "# specify the target protein sequence here\n", + "target_protein = (\n", + " \"STEYKLVVVGADGVGKSALTIQLIQNHFVDEYDPTIEDSYRKQVVIDGETCLLDILDTAGQEEYSAM\"\n", + " \"RDQYMRTGEGFLCVFAINNTKSFEDIHHYREQIKRVKDSEDVPMVLVGNKCDLPSRTVDTKQAQDLARSYGIPFIETSAKTR\"\n", + " \"QGVDDAFYTLVREIRKHKEKMSK\"\n", + ")\n", + "\n", + "pipeline = AptaTransPipeline(\n", + " device=device,\n", + " model=model,\n", + " prot_words=prot_words,\n", + " depth=1, # depth of the search (i.e., length of generated candidates)\n", + " 
n_iterations=1, # higher is better but slower, suggested: 1000\n", + ")\n", + "candidates = pipeline.recommend(\n", + " target=target_protein,\n", + " n_candidates=1, # number of candidates to generate\n", + " verbose=True,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "840249ec", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "c:\\Users\\satvm\\pyaptamer\\pyaptamer\\pseaac\\_features.py:199: UserWarning: Invalid amino acid(s) found in sequence. Replaced with 'N'.\n", + " seq = clean_protein_seq(protein_sequence)\n", + "c:\\Users\\satvm\\miniconda3\\envs\\pyaptamer-latest\\Lib\\site-packages\\sklearn\\utils\\validation.py:1406: DataConversionWarning: A column-vector y was passed when a 1d array was expected. Please change the shape of y to (n_samples, ), for example using ravel().\n", + " y = column_or_1d(y, warn=True)\n", + "c:\\Users\\satvm\\pyaptamer\\pyaptamer\\pseaac\\_features.py:199: UserWarning: Invalid amino acid(s) found in sequence. Replaced with 'N'.\n", + " seq = clean_protein_seq(protein_sequence)\n", + "c:\\Users\\satvm\\pyaptamer\\pyaptamer\\pseaac\\_features.py:199: UserWarning: Invalid amino acid(s) found in sequence. Replaced with 'N'.\n", + " seq = clean_protein_seq(protein_sequence)\n", + "c:\\Users\\satvm\\pyaptamer\\pyaptamer\\pseaac\\_features.py:199: UserWarning: Invalid amino acid(s) found in sequence. Replaced with 'N'.\n", + " seq = clean_protein_seq(protein_sequence)\n", + "c:\\Users\\satvm\\miniconda3\\envs\\pyaptamer-latest\\Lib\\site-packages\\sklearn\\utils\\validation.py:1406: DataConversionWarning: A column-vector y was passed when a 1d array was expected. Please change the shape of y to (n_samples, ), for example using ravel().\n", + " y = column_or_1d(y, warn=True)\n", + "c:\\Users\\satvm\\pyaptamer\\pyaptamer\\pseaac\\_features.py:199: UserWarning: Invalid amino acid(s) found in sequence. 
Replaced with 'N'.\n", + " seq = clean_protein_seq(protein_sequence)\n", + "c:\\Users\\satvm\\pyaptamer\\pyaptamer\\pseaac\\_features.py:199: UserWarning: Invalid amino acid(s) found in sequence. Replaced with 'N'.\n", + " seq = clean_protein_seq(protein_sequence)\n", + "c:\\Users\\satvm\\miniconda3\\envs\\pyaptamer-latest\\Lib\\site-packages\\sklearn\\utils\\validation.py:1406: DataConversionWarning: A column-vector y was passed when a 1d array was expected. Please change the shape of y to (n_samples, ), for example using ravel().\n", + " y = column_or_1d(y, warn=True)\n", + "c:\\Users\\satvm\\pyaptamer\\pyaptamer\\pseaac\\_features.py:199: UserWarning: Invalid amino acid(s) found in sequence. Replaced with 'N'.\n", + " seq = clean_protein_seq(protein_sequence)\n", + "c:\\Users\\satvm\\pyaptamer\\pyaptamer\\pseaac\\_features.py:199: UserWarning: Invalid amino acid(s) found in sequence. Replaced with 'N'.\n", + " seq = clean_protein_seq(protein_sequence)\n", + "c:\\Users\\satvm\\pyaptamer\\pyaptamer\\pseaac\\_features.py:199: UserWarning: Invalid amino acid(s) found in sequence. Replaced with 'N'.\n", + " seq = clean_protein_seq(protein_sequence)\n", + "c:\\Users\\satvm\\miniconda3\\envs\\pyaptamer-latest\\Lib\\site-packages\\sklearn\\utils\\validation.py:1406: DataConversionWarning: A column-vector y was passed when a 1d array was expected. Please change the shape of y to (n_samples, ), for example using ravel().\n", + " y = column_or_1d(y, warn=True)\n", + "c:\\Users\\satvm\\pyaptamer\\pyaptamer\\pseaac\\_features.py:199: UserWarning: Invalid amino acid(s) found in sequence. Replaced with 'N'.\n", + " seq = clean_protein_seq(protein_sequence)\n", + "c:\\Users\\satvm\\pyaptamer\\pyaptamer\\pseaac\\_features.py:199: UserWarning: Invalid amino acid(s) found in sequence. 
Replaced with 'N'.\n", + " seq = clean_protein_seq(protein_sequence)\n", + "c:\\Users\\satvm\\miniconda3\\envs\\pyaptamer-latest\\Lib\\site-packages\\sklearn\\utils\\validation.py:1406: DataConversionWarning: A column-vector y was passed when a 1d array was expected. Please change the shape of y to (n_samples, ), for example using ravel().\n", + " y = column_or_1d(y, warn=True)\n", + "c:\\Users\\satvm\\pyaptamer\\pyaptamer\\pseaac\\_features.py:199: UserWarning: Invalid amino acid(s) found in sequence. Replaced with 'N'.\n", + " seq = clean_protein_seq(protein_sequence)\n", + "c:\\Users\\satvm\\pyaptamer\\pyaptamer\\pseaac\\_features.py:199: UserWarning: Invalid amino acid(s) found in sequence. Replaced with 'N'.\n", + " seq = clean_protein_seq(protein_sequence)\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + " train test\n", + "estimator metric \n", + "AptaNetPipeline accuracy_score 1.0 1.0\n" + ] + } + ], + "source": [ + "# Example estimator\n", + "aptanet_estimator = AptaNetPipeline(k=4)\n", + "aptatrans_estimator = AptaTransPipeline(\n", + " device=device,\n", + " model=model,\n", + " prot_words=prot_words,\n", + " depth=1,\n", + " n_iterations=1,\n", + ")\n", + "\n", + "# Define a 5-fold CV strategy\n", + "cv = KFold(n_splits=5, shuffle=True, random_state=42)\n", + "\n", + "# Run benchmarking with CV\n", + "bench = Benchmarking(\n", + " estimators=[aptanet_estimator, aptatrans_estimator],\n", + " metrics=[accuracy_score],\n", + " X=X,\n", + " y=y,\n", + " cv=cv,\n", + ")\n", + "results_cv = bench.run()\n", + "print(results_cv)" + ] + }, + { + "cell_type": "markdown", + "id": "87923ac7", + "metadata": {}, + "source": [ + "### 2. 
Using PredefinedSplit for benchmarking with a fixed train/test split" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8e504488", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "c:\\Users\\satvm\\pyaptamer\\pyaptamer\\pseaac\\_features.py:199: UserWarning: Invalid amino acid(s) found in sequence. Replaced with 'N'.\n", + " seq = clean_protein_seq(protein_sequence)\n", + "c:\\Users\\satvm\\miniconda3\\envs\\pyaptamer-latest\\Lib\\site-packages\\sklearn\\utils\\validation.py:1406: DataConversionWarning: A column-vector y was passed when a 1d array was expected. Please change the shape of y to (n_samples, ), for example using ravel().\n", + " y = column_or_1d(y, warn=True)\n", + "c:\\Users\\satvm\\pyaptamer\\pyaptamer\\pseaac\\_features.py:199: UserWarning: Invalid amino acid(s) found in sequence. Replaced with 'N'.\n", + " seq = clean_protein_seq(protein_sequence)\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + " train test\n", + "estimator metric \n", + "AptaNetPipeline accuracy_score 1.0 1.0\n" + ] + } + ], + "source": [ + "from sklearn.model_selection import PredefinedSplit\n", + "\n", + "# Define a custom train/test split\n", + "# Here, last 10 samples are used as test set\n", + "test_fold = np.ones(len(y)) * -1\n", + "test_fold[-10:] = 0\n", + "cv = PredefinedSplit(test_fold)\n", + "\n", + "# Run benchmarking with fixed split\n", + "bench_fixed = Benchmarking(\n", + " estimators=[aptanet_estimator, aptatrans_estimator],\n", + " metrics=[accuracy_score],\n", + " X=X,\n", + " y=y,\n", + " cv=cv,\n", + ")\n", + "results_fixed = bench_fixed.run()\n", + "print(results_fixed)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1f060c7a", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "pyaptamer-latest", + "language": "python", + "name": "python3" + }, + "language_info": { + 
"codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.11" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/pyaptamer/aptatrans/_pipeline.py b/pyaptamer/aptatrans/_pipeline.py index a4e88ac..4469c6f 100644 --- a/pyaptamer/aptatrans/_pipeline.py +++ b/pyaptamer/aptatrans/_pipeline.py @@ -3,22 +3,25 @@ candidate aptamers recommendation. """ -__author__ = ["nennomp"] +__author__ = ["nennomp", "satvshr"] __all__ = ["AptaTransPipeline"] +import lightning as L +import numpy as np +import pandas as pd import torch +from skbase.base import BaseObject from torch import Tensor +from torch.utils.data import DataLoader -from pyaptamer.aptatrans import AptaTrans +from pyaptamer.aptatrans import AptaTrans, AptaTransLightning +from pyaptamer.datasets.dataclasses import APIDataset from pyaptamer.experiments import AptamerEvalAptaTrans from pyaptamer.mcts import MCTS -from pyaptamer.utils import ( - generate_nplets, -) from pyaptamer.utils._base import filter_words -class AptaTransPipeline: +class AptaTransPipeline(BaseObject): """AptaTrans pipeline for aptamer affinity prediction, by Shin et al. Algorithm as originally described in Shin et al [1]_. @@ -37,6 +40,7 @@ class AptaTransPipeline: The device on which to run the model. model : AptaTrans An instance of the AptaTrans() class. + # TODO: ask if apta_max_len and prot_max_len depend on it prot_words : dict[str, float] A dictionary mapping protein n-mer protein subsequences to a unique integer ID. Used to encode protein sequences into their numerical representions. The @@ -46,6 +50,18 @@ class AptaTransPipeline: The depth of the tree in the Monte Carlo Tree Search (MCTS) algorithm. n_iterations : int, optional, default=1000 The number of iterations for the MCTS algorithm. + batch_size : int, default=16 + Batch size for training and inference. 
+ apta_max_len : int, default=275 + Maximum aptamer sequence length. + prot_max_len : int, default=867 + Maximum protein sequence length. + learning_rate : float, default=1e-5 + Learning rate for training the AptaTrans model. + weight_decay : float, default=1e-5 + Weight decay for training the AptaTrans model. + max_epochs : int, default=50 + Maximum number of epochs for training the AptaTrans model. Attributes ---------- @@ -57,11 +73,11 @@ class AptaTransPipeline: References ---------- .. [1] Shin, Incheol, et al. "AptaTrans: a deep neural network for predicting - aptamer-protein interaction using pretrained encoders." BMC bioinformatics 24.1 - (2023): 447. + aptamer-protein interaction using pretrained encoders." BMC bioinformatics 24.1 + (2023): 447. .. [2] Lee, Gwangho, et al. "Predicting aptamer sequences that interact with target - proteins using an aptamer-protein interaction classifier and a Monte Carlo tree - search approach." PloS one 16.6 (2021): e0253760. + proteins using an aptamer-protein interaction classifier and a Monte Carlo tree + search approach." PloS one 16.6 (2021): e0253760. 
Examples -------- @@ -90,54 +106,146 @@ def __init__( prot_words: dict[str, float], depth: int = 20, n_iterations: int = 1000, + batch_size: int = 16, + apta_max_len: int = 275, + prot_max_len: int = 867, + learning_rate: float = 1e-5, + weight_decay: float = 1e-5, + max_epochs: int = 50, ) -> None: super().__init__() self.device = device self.model = model.to(device) self.depth = depth + self.prot_words = prot_words self.n_iterations = n_iterations + self.batch_size = batch_size + self.apta_max_len = apta_max_len + self.prot_max_len = prot_max_len + self.learning_rate = learning_rate + self.weight_decay = weight_decay + self.max_epochs = max_epochs - self.apta_words, self.prot_words = self._init_words(prot_words) + self.prot_words_ = None - def _init_words( - self, - prot_words: dict[str, float], - ) -> tuple[dict[str, int], dict[str, int]]: - """Initialize aptamer and protein word vocabularies. + def _init_vocabularies(self): + """Initialize filtered protein words vocabulary.""" + if self.prot_words_ is None: + self.prot_words_ = filter_words(self.prot_words) + + def _init_aptamer_experiment(self, target: str) -> AptamerEvalAptaTrans: + """Initialize the aptamer recommendation experiment.""" + self._init_vocabularies() + experiment = AptamerEvalAptaTrans( + target=target, + model=self.model, + device=self.device, + prot_words=self.prot_words_, + ) + return experiment + + def _preprocess_data(self, dataset) -> DataLoader: + """Build a shuffled DataLoader from an aptamer/protein/label DataFrame.""" + self._init_vocabularies() + + dataset = APIDataset( + x_apta=dataset["aptamer"].to_numpy(), + x_prot=dataset["protein"].to_numpy(), + y=dataset["label"].to_numpy(), + apta_max_len=self.apta_max_len, + prot_max_len=self.prot_max_len, + prot_words=self.prot_words_, + ) + + dataloader = DataLoader(dataset, batch_size=self.batch_size, shuffle=True) + return dataloader - For aptamers, creates a mapping between all possible 3-mer RNA subsequences and - integer indices. 
For proteins, load protein words mapped to their frequency, - filter out those with below-average frequency, and assign unique integer IDs. + def fit(self, X, y): + """Train the AptaTrans model using PyTorch Lightning. + + This method preprocesses the training dataset, converts it into a PyTorch + DataLoader, and trains the AptaTrans model using the AptaTransLightning + wrapper with the PyTorch Lightning Trainer. Parameters ---------- - prot_words : dict[str, float] - A dictionary containing protein 3-mer subsequences and their frequencies. + X : pandas.DataFrame or numpy.ndarray + Training data that must include the following columns: + + - ``aptamer`` : str + The aptamer nucleotide sequences. + - ``protein`` : str + The target protein sequences. + + y : array-like + The ground truth interaction scores or binary labels. Returns ------- - tuple[dict[str, int], dict[str, int]] - A tuple of dictionaries mapping aptamer 3-mer subsequences to unique - indices and protein words to their frequencies, respectively. + AptaTransPipeline + The fitted AptaTransPipeline instance with an updated and trained model. 
""" - # generate all possible RNA triplets (5^3 -> 125 total) - apta_words = generate_nplets(letters=["A", "C", "G", "U", "N"], repeat=3) + if not isinstance(X, pd.DataFrame): + X = pd.DataFrame(X, columns=["aptamer", "protein"]) - # filter out protein words with below average frequency and assign unique - # integer IDs - prot_words = filter_words(prot_words) + X["label"] = y - return (apta_words, prot_words) + dataloader = self._preprocess_data(X) - def _init_aptamer_experiment(self, target: str) -> AptamerEvalAptaTrans: - """Initialize the aptamer recommendation experiment.""" - experiment = AptamerEvalAptaTrans( - target=target, + model_lightning = AptaTransLightning( model=self.model, - device=self.device, - prot_words=self.prot_words, + lr=self.learning_rate, + weight_decay=self.weight_decay, ) - return experiment + + trainer = L.Trainer( + max_epochs=self.max_epochs, + log_every_n_steps=10, + ) + + trainer.fit(model_lightning, dataloader) + + # Keep trained model + self.model = model_lightning.model.to(self.device) + + return self + + def predict(self, X): + """Predict aptamer-protein interaction (API) score for aptamer–protein pairs. + + This method initializes a new aptamer experiment for each aptamer–protein pair + and predicts their interaction score using the AptaTrans deep neural network. + + Parameters + ---------- + X : pandas.DataFrame or numpy.ndarray, shape (n_samples, 2) + Input data containing aptamer–protein pairs. Must include: + + - ``aptamer`` : str + - ``protein`` : str + + Returns + ------- + np.ndarray, shape (n_samples,) + Predicted interaction scores. 
+ """ + # Convert to list of pairs if given as numpy array or DataFrame + if isinstance(X, pd.DataFrame): + aptamers = X["aptamer"].to_numpy() + proteins = X["protein"].to_numpy() + elif isinstance(X, np.ndarray): + aptamers = X[:, 0] + proteins = X[:, 1] + else: + raise TypeError("X must be a pandas DataFrame or numpy.ndarray.") + + scores = [] + for aptamer, protein in zip(aptamers, proteins, strict=False): + experiment = self._init_aptamer_experiment(protein) + score = experiment.evaluate(aptamer) + scores.append(score) + + return np.array(scores) def get_interaction_map(self, candidate: str, target: str) -> Tensor: # TODO: to make the interaction map ready for plotting (at least if we were to @@ -163,28 +271,6 @@ def get_interaction_map(self, candidate: str, target: str) -> Tensor: experiment = self._init_aptamer_experiment(target) return experiment.evaluate(candidate, return_interaction_map=True) - def predict(self, candidate: str, target: str) -> Tensor: - """Predict aptamer-protein interaction (API) score for a given target protein. - - This methods initializes a new aptamer experiment for the given aptamer - candidate and target protein. Finally, it predict the interaction score using - the AptaTrans' deep neural network. - - Parameters - ---------- - candidate : str - The candidate aptamer sequence. - target : str - The target protein sequence. - - Returns - ------- - Tensor - A tensor containing the predicted interaction score. 
- """ - experiment = self._init_aptamer_experiment(target) - return experiment.evaluate(candidate) - @torch.no_grad() def recommend( self, diff --git a/pyaptamer/aptatrans/tests/test_aptatrans.py b/pyaptamer/aptatrans/tests/test_aptatrans.py index 73cf6e9..5b04372 100644 --- a/pyaptamer/aptatrans/tests/test_aptatrans.py +++ b/pyaptamer/aptatrans/tests/test_aptatrans.py @@ -219,16 +219,13 @@ def test_initialization(self, device, prot_words): assert pipeline.model.device.type == device.type # check word dictionaries - assert isinstance(pipeline.apta_words, dict) assert isinstance(pipeline.prot_words, dict) - # check aptamer words contain all possible triplets (should be 125) - assert len(pipeline.apta_words) == 125 - # check protein words filtering (only above average frequency) mean_freq = sum(prot_words.values()) / len(prot_words) expected_prot_count = sum(1 for freq in prot_words.values() if freq > mean_freq) - assert len(pipeline.prot_words) == expected_prot_count + pipeline._init_vocabularies() + assert len(pipeline.prot_words_) == expected_prot_count @pytest.mark.parametrize( "device, target", diff --git a/pyaptamer/utils/_aptanet_utils.py b/pyaptamer/utils/_aptanet_utils.py index 7c23852..9bea84a 100644 --- a/pyaptamer/utils/_aptanet_utils.py +++ b/pyaptamer/utils/_aptanet_utils.py @@ -94,3 +94,32 @@ def pairs_to_features(X, k=4): # Ensure float32 for PyTorch compatibility return np.vstack(feats).astype(np.float32) + + +def rna2dna(seq): + """ + Convert an RNA sequence to a DNA sequence. + + Nucleotides 'U' in the RNA sequence are replaced with 'T' in the DNA sequence. + Unknown nucleotides are replaced with 'N'. Other nucleotides ('A', 'C', 'G') + remain unchanged. + + Parameters + ---------- + seq : str + The RNA sequence to be converted. + + Returns + ------- + str + The converted DNA sequence. 
+ """ + # Replace nucleotides 'U' with 'T' + result = seq.translate(str.maketrans("U", "T")) + + # Replace any unknown characters with 'N' + for char in result: + if char not in "ACGT": + result = result.replace(char, "N") + + return result