diff --git a/getting-started.md b/getting-started.md
index 93f1a0e7..6d5bffec 100644
--- a/getting-started.md
+++ b/getting-started.md
@@ -58,6 +58,7 @@ Stamp currently supports the following feature extractors:
   - [mSTAR][mstar]
   - [MUSK][musk]
   - [PLIP][plip]
+  - [TICON][ticon]
 
 As some of the above require you to request access to the model on huggingface,
@@ -158,6 +159,7 @@ meaning that it was ignored during feature extraction.
 [EAGLE]: https://github.com/KatherLab/EAGLE
 [MADELEINE]: https://huggingface.co/MahmoodLab/madeleine
 [PRISM]: https://huggingface.co/paige-ai/Prism
+[TICON]: https://cvlab-stonybrook.github.io/TICON/ "TICON: A Slide-Level Tile Contextualizer for Histopathology Representation Learning"
diff --git a/mcp/README.md b/mcp/README.md
index 7821a368..0f62c1b8 100644
--- a/mcp/README.md
+++ b/mcp/README.md
@@ -1,23 +1,24 @@
 # STAMP MCP Server
 
-A FastMCP-based Model Context Protocol server wrapping [STAMP](https://github.com/KatherLab/STAMP)’s CLI, enabling seamless integration of STAMP preprocessing, training, encoding, evaluation, and inference into LLM-based pipelines.
+A FastMCP-based Model Context Protocol server wrapping [STAMP](https://github.com/KatherLab/STAMP)'s tools, enabling seamless integration of STAMP preprocessing, training, encoding, evaluation, and inference into LLM-based pipelines.
 
 ## Overview
 
 This server lets LLM agents invoke STAMP tools via structured calls. It exposes the following tools:
 
-- `preprocess_stamp(...)`: tile & extract WSI features
-- `train_stamp(...)`: train weakly supervised models
-- `crossval_stamp(...)`: k-fold cross‑validation
-- `deploy_stamp(...)`: inference on held‑out data
-- `encode_slides_stamp(...)`: slide-level feature encoding
-- `encode_patients_stamp(...)`: patient-level feature encoding
-- `heatmaps_stamp(...)`: model-based heatmap visualization
-- `statistics_stamp(...)`: compute classification metrics
-- `read_file(...)` & `list_files(...)`: safe disk access
-- `check_available_devices()`: query Torch/Platform device availability
-
-Each tool serializes config into YAML, launches `stamp `, streams logs back, and returns stdout/stderr.
+- `preprocess_stamp()`: tile & extract WSI features
+- `train_stamp()`: train weakly supervised models
+- `crossval_stamp()`: k-fold cross‑validation
+- `deploy_stamp()`: inference on held‑out data
+- `encode_slides_stamp()`: slide-level feature encoding
+- `encode_patients_stamp()`: patient-level feature encoding
+- `heatmaps_stamp()`: model-based heatmap visualization
+- `statistics_stamp()`: compute classification metrics
+- `read_file()` & `list_files()`: safe disk access
+- `check_available_devices()`: query Torch/platform device availability
+- `analyze_csv()` & `list_column_values()`: useful for clinical and slide tables
+
+Each tool serializes its config into YAML and directly calls STAMP's internal `_run_cli()` function, streaming logs back in real time and returning the execution results.
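+
+For example, a `crossval_stamp()` call is serialized into a config along these lines before being handed to `_run_cli()` (a sketch; paths and label values are illustrative):
+
+```yaml
+crossval:
+  output_dir: output/crossval
+  ground_truth_label: OUTCOME
+  categories: ["Positive", "Negative"]
+  patient_label: PATIENT
+  filename_label: FILENAME
+  n_splits: 5
+```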
 
 ## Installation
 
 Running the MCP server is as simple as installing STAMP as explained in the main README.md file, adding `--extra mcp` to the command.
 For a GPU installation of the repository, it would look like this:
diff --git a/mcp/server.py b/mcp/server.py
index 28781b2a..c66dc4f9 100644
--- a/mcp/server.py
+++ b/mcp/server.py
@@ -1,16 +1,21 @@
+"""STAMP MCP Server"""
+
+import argparse
 import asyncio
 import logging
 import os
 import platform
-import subprocess
 import tempfile
 from pathlib import Path
 from typing import Annotated
 
+import pandas as pd
 import torch
 import yaml
 from fastmcp import Context, FastMCP
 from pydantic import Field
 
+from stamp.__main__ import _run_cli
+
 # Initialize the FastMCP server
 mcp = FastMCP("STAMP MCP Server")
 
@@ -18,24 +23,47 @@ STAMP_LOGGER = logging.getLogger("stamp")
 
 # TODO: add proper filesystem management
-base_dir = "./"
-base = Path(base_dir).resolve()
+# The idea would be to send the safe workspace via HTTP headers or roots
+# once the OpenAI Agents SDK implements it. Check the docs for more info.
+WORKSPACE_FOLDER = "./"  # Folder the agent is allowed to work in.
+WORKSPACE_PATH = Path(WORKSPACE_FOLDER).resolve()
+# List of additional allowed paths outside the workspace
+ALLOWED_EXTERNAL_PATHS = [
+    "/mnt/bulk-curie/peter/fmbenchmark/images/tcga_crc",
+    "/mnt/bulk-curie/peter/fmbenchmark/20mag_experiments/features/tcga_crc/ctranspath/STAMP_raw_xiyuewang-ctranspath-7c998680",
+    "/mnt/copernicus3/PATHOLOGY/others/public/CPTAC/features/features-20x/virchow2/CPTAC-CCRCC/virchow2-stamp-maru-21-12-24",
+    "/mnt/copernicus3/PATHOLOGY/others/public/CPTAC/CPTAC-CCRCC/data",
+    "/mnt/copernicus3/PATHOLOGY/others/public/CPTAC/CPTAC-BRCA/features-STAMP/conch1_5-778e1572",
+    "/mnt/copernicus3/PATHOLOGY/others/public/CPTAC/CPTAC-BRCA/data",
+    # Add other specific paths you want to allow
+]
+MAX_ITEMS = 100  # Maximum number of items listed by the list_files tool.
+# Larger values could exceed the LLM's context length; when the limit is
+# exceeded, the listing is summarized.
 
 
 class MCPLogHandler(logging.Handler):
-    def __init__(self, ctx):
+    def __init__(self, ctx, loop: asyncio.AbstractEventLoop):
         super().__init__()
         self.ctx = ctx
+        self.loop = loop
+        # FIXME: expose these logs to the agent once a run finishes; currently
+        # the streamed logging is only visible to the user.
+        self.captured_logs = []
 
-    def emit(self, record):
+    def emit(self, record: logging.LogRecord) -> None:
         msg = self.format(record)
-        # Fire-and-forget the coroutine
-        asyncio.create_task(self.ctx.log(msg))
+        try:
+            self.captured_logs.append(msg)
+            # Thread-safe: schedule on the captured event loop
+            asyncio.run_coroutine_threadsafe(self.ctx.log(msg), self.loop)
+            # Alternatively:
+            # self.loop.call_soon_threadsafe(self.loop.create_task, self.ctx.log(msg))
+        except Exception:
+            self.handleError(record)
 
 
 async def _run_stamp(mode, config, ctx):
     """
-    Run the STAMP command as a subprocess and capture its console output.
+    Run the STAMP command directly by calling _run_cli() instead of a subprocess.
 
     Args:
         mode (str): The mode to run the STAMP command in (e.g., "preprocess", "train").
@@ -50,20 +78,36 @@ async def _run_stamp(mode, config, ctx):
         yaml.dump(config, tmp_config)
         tmp_config_path = tmp_config.name
 
-    handler = MCPLogHandler(ctx)
-    handler.setLevel(logging.DEBUG)
+    # Set up a logging handler to capture STAMP logs
+    loop = asyncio.get_running_loop()
+    handler = MCPLogHandler(ctx, loop)
     STAMP_LOGGER.addHandler(handler)
 
-    print("Running command...")
     try:
-        cmd = ["stamp", "--config", tmp_config_path, mode]
-        result = subprocess.run(cmd, capture_output=True, text=True, check=True)
-        print("Result returned...")
-        print(f"Command completed successfully:\n{result.stdout}\n{result.stderr}")
-        return f"Command completed successfully:\n{result.stdout}\n{result.stderr}"
-    except subprocess.CalledProcessError as e:
-        return f"Command failed with error:\n{e.stdout}\n{e.stderr}"
+        await ctx.info(f"Starting STAMP {mode} tool...")
+        # Create an argparse Namespace to mimic command-line arguments
+        args = argparse.Namespace(command=mode, config_file_path=Path(tmp_config_path))
+
+        # Call the STAMP CLI function directly
+        await asyncio.to_thread(_run_cli, args)
+
+        # Get the captured logs
+        captured_logs_text = (
+            "\n".join(handler.captured_logs)
+            if handler.captured_logs
+            else "Tool completed successfully (no logs captured)"
+        )
+        await ctx.info(f"STAMP {mode} completed successfully")
+        return f"Tool completed successfully:\n{captured_logs_text}"
+
+    except Exception as e:
+        captured_logs_text = (
+            "\n".join(handler.captured_logs) if handler.captured_logs else ""
+        )
+        error_msg = f"Tool failed with error: {str(e)}\n{captured_logs_text}"
+        await ctx.error(f"STAMP {mode} failed: {str(e)}")
+        return error_msg
+
     finally:
         os.remove(tmp_config_path)
         STAMP_LOGGER.removeHandler(handler)
@@ -206,22 +250,11 @@ async def train_stamp(
             "in the slide table containing the feature file path relative to `feature_dir`"
         ),
     ] = "FILENAME",
-    bag_size: Annotated[
-        int,
-        Field(
-            description="Amount of tiles to sample when training. "
-            "Reducing this value reduces memory usage, but it is not recommended as the model can miss"
-            "relevant regions of the slide. Default value works well on H&E tissue images."
-        ),
-    ] = 512,
-    batch_size: Annotated[
-        int, Field(description="Amount of bags processed together.")
-    ] = 64,
 ) -> str:
     """
     Train a model using clinical data and WSI-derived features via STAMP.
 
     Takes in a clinical table, slide associations, and extracted features
-    to train a model on a specified label.
+    to train a model on a specified label. This is the best option when an
+    external cohort is available.
 
     Returns:
         str: message indicating the success or failure of the training operation,
@@ -250,8 +283,6 @@
             "categories": categories,
             "patient_label": patient_label,
             "filename_label": filename_label,
-            "bag_size": bag_size,
-            "batch_size": batch_size,
         }
     }
     return await _run_stamp(mode="train", config=config, ctx=ctx)
@@ -306,22 +337,12 @@ async def crossval_stamp(
             description="Number of folds to split the data into for cross-validation"
         ),
     ] = 5,
-    bag_size: Annotated[
-        int,
-        Field(
-            description="Amount of tiles to sample when training. "
-            "Reducing this value reduces memory usage, but it is not recommended as the model can miss"
-            "relevant regions of the slide. Default value works well on H&E tissue images."
-        ),
-    ] = 512,
-    batch_size: Annotated[
-        int, Field(description="Amount of bags processed together.")
-    ] = 64,
 ) -> str:
     """
     Perform cross-validation for model training using STAMP.
 
     Splits the data into folds and trains a model on each to assess generalization.
     Uses clinical data, features, and slide mappings.
+    This is the best option when only one cohort is available.
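+
+    Example (a sketch; parameter names and values are illustrative):
+        >>> crossval_stamp(
+        ...     output_dir="output/crossval",
+        ...     ground_truth_label="OUTCOME",
+        ...     categories=["Positive", "Negative"],
+        ...     n_splits=5,
+        ... )
+        "Tool completed successfully: ..."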
 
     Returns:
         str: A message indicating the success or failure of the cross-validation operation, along with
@@ -353,10 +374,6 @@
             "filename_label": filename_label,
             "n_splits": n_splits,
         },
-        "advanced_config": {  # Add advanced config for bag_size and batch_size
-            "bag_size": bag_size,
-            "batch_size": batch_size,
-        },
     }
     return await _run_stamp(mode="crossval", config=config, ctx=ctx)
 
@@ -486,7 +503,7 @@ async def statistics_stamp(
         output_dir="output/statistics",
         ground_truth_label="OUTCOME",
         true_class="Positive",
-        pred_csvs=["predictions/fold1.csv", "predictions/fold2.csv"]
+        pred_csvs=["/pathto/split-0/patient-preds.csv", "/pathto/split-1/patient-preds.csv"]
     )
     "Command completed successfully: ..."
     """
@@ -517,24 +534,40 @@ async def heatmaps_stamp(
         str, Field(description="Path of the model to generate the heatmaps with.")
     ],
     slide_paths: Annotated[
-        list[str] | None,
+        list[str],
         Field(
-            description="List of slide paths relative "
-            "to `wsi_dir` to generate heatmaps for. If not specified, heatmaps will be generated "
-            "for all slides in `wsi_dir`."
+            description="List of slide paths relative to `wsi_dir` to "
+            "generate heatmaps for. The slide paths HAVE to be specified relative to `wsi_dir`.",
+            min_length=1,
         ),
-    ] = None,
+    ],
     topk: Annotated[
         int | None, Field(description="Number of top-scoring tiles to extract")
     ] = None,
     bottomk: Annotated[
         int | None, Field(description="Number of bottom-scoring tiles to extract")
    ] = None,
+    device: Annotated[
+        str | None,
+        Field(
+            description="The device to use for computation. "
+            "Possible options are 'cuda' for NVIDIA GPUs, 'cpu' for general-purpose "
+            "processors, and 'mps' for Apple Silicon GPUs. Defaults to automatic detection."
+        ),
+    ] = None,
 ) -> str:
     """
     Generate heatmaps and tile scorings from WSIs using a trained model.
-    Produces visual explanations and optionally extracts top/bottom
-    scoring tiles.
+
+    Creates visual attention maps showing which regions the model focuses on for predictions.
+    Works only with tile-level features. For each slide, it generates:
+    - Overview plots with complete heatmaps and class overlays
+    - Raw data including thumbnails, class maps, and per-class heatmaps
+    - Individual tile extractions (top/bottom scoring, if specified)
+
+    Output structure: each slide gets its own folder
+    (slide name without file extension) containing plots/, raw/, and tiles/ subdirectories.
 
     Returns:
         str: A message indicating the success or failure of the heatmap generation operation,
@@ -547,8 +580,8 @@
         wsi_dir="input/slides",
         checkpoint_path="models/checkpoint.pth",
         slide_paths=["slide1.svs", "slide2.svs"],
-        topk=10,
-        bottomk=5
+        topk=3,
+        bottomk=3
     )
     "Command completed successfully: ..."
     """
@@ -561,6 +594,7 @@
             "slide_paths": slide_paths,
             "topk": topk,
             "bottomk": bottomk,
+            "device": device,
         }
     }
     return await _run_stamp(mode="heatmaps", config=config, ctx=ctx)
@@ -697,14 +731,68 @@ async def encode_patients_stamp(
 
 
 def _resolve_path(subpath: str) -> Path:
-    requested = (base / subpath).resolve()
-    if base not in requested.parents and requested != base:
-        raise PermissionError(f"Access denied: {subpath}")
-    return requested
+    """
+    Resolve a path with security checks:
+    - Paths starting with /mnt/, /tmp/, /home/, etc. are treated as external absolute paths
+    - All other paths (including /tables, /data, etc.) are treated as workspace-relative
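+
+    Illustrative resolutions (assuming the default workspace "./"):
+    - "tables/clini.csv" and "/tables/clini.csv" both resolve inside the workspace
+    - "/mnt/data/x.h5" is allowed only if it falls under ALLOWED_EXTERNAL_PATHS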
+    """
+    requested = Path(subpath)
+
+    # Check if it's a true external absolute path (starting with known system roots)
+    external_roots = [
+        "/mnt/",
+        "/tmp/",
+        "/home/",
+        "/usr/",
+        "/var/",
+        "/opt/",
+        "/etc/",
+        "/root/",
+        "/boot/",
+        "/sys/",
+        "/proc/",
+        "/dev/",
+    ]
+    is_external_absolute = any(subpath.startswith(root) for root in external_roots)
+
+    if is_external_absolute:
+        # This is a true external absolute path - check against allowed external paths
+        requested_resolved = requested.resolve()
+
+        # Check if it's in the allowed external paths
+        for allowed_path in ALLOWED_EXTERNAL_PATHS:
+            allowed_path = Path(allowed_path).resolve()
+            # Check both: exact match OR if allowed_path is a parent of requested
+            if (
+                requested_resolved == allowed_path
+                or allowed_path in requested_resolved.parents
+            ):
+                return requested_resolved
+
+        # If not in the allowed external paths, raise an error
+        raise PermissionError(f"Access denied to external absolute path: {subpath}")
+
+    else:
+        # Treat as workspace-relative (including paths like /tables, /data, etc.)
+        # Remove the leading slash if present to make it clearly relative
+        clean_path = subpath.lstrip("/")
+        requested_resolved = (WORKSPACE_PATH / clean_path).resolve()
+
+        # Check if the resolved path is within the workspace
+        if (
+            WORKSPACE_PATH in requested_resolved.parents
+            or requested_resolved == WORKSPACE_PATH
+        ):
+            return requested_resolved
+
+        # If not within the workspace, raise an error
+        raise PermissionError(
+            f"Access denied: path {subpath} resolves outside workspace"
+        )
 
 
 @mcp.tool
-def read_file(path: str) -> str:
+async def read_file(ctx: Context, path: str) -> str:
     """
     Read the contents of a file inside the allowed folder.
 
@@ -714,41 +802,290 @@
     Args:
         path (str): Relative path to the file.
 
     Returns:
         str: Content of the file.
     """
+    await ctx.info("Starting read_file tool...")
     safe_path = _resolve_path(path)
     with open(safe_path, "r", encoding="utf-8") as f:
         return f.read()
 
 
 @mcp.tool
-def list_files(subdir: str = "") -> list:
+async def list_files(ctx: Context, subdir: str = "") -> str:
     """
     List all files and directories under the given subdirectory (default is root), recursively,
-    returning paths relative to the base directory.
+    returning paths relative to the base directory. If the list is too long, shows only directories
+    with file type summaries. If still too long, shows a truncated message.
 
     Args:
         subdir (str): Relative subdirectory path to list files from.
 
     Returns:
-        list: List of relative file paths found.
+        str: Formatted list of files/directories or summary information.
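+
+    Example (illustrative):
+        >>> list_files(subdir="tables")
+        "tables/clini.csv\ntables/slide.csv"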
""" - safe = _resolve_path(subdir) - if not safe.is_dir(): + await ctx.info("Starting list_files tool...") + subdir_path = _resolve_path(subdir) if subdir else WORKSPACE_PATH + if not subdir_path.is_dir(): raise FileNotFoundError(f"Subdirectory does not exist: {subdir}") - results = [] - base_len = len(str(base)) + 1 # To slice off base path + separator - for root, dirs, files in os.walk(safe): + + # Collect all files and directories + all_items = [] + directories = {} + base_len = len(str(WORKSPACE_PATH)) + 1 # To slice off base path + separator + + for root, dirs, files in os.walk(subdir_path): rel_root = str(root)[base_len:] # relative path under base_dir + + # Track file types in each directory + if rel_root not in directories: + directories[rel_root] = {"subdirs": [], "file_types": {}, "file_count": 0} + + # Add subdirectories for d in dirs: path = os.path.join(rel_root, d) - results.append(path + "/") + all_items.append(path + "/") + directories[rel_root]["subdirs"].append(d) + + # Add files and track their extensions for f in files: path = os.path.join(rel_root, f) - results.append(path) - return sorted(results) + all_items.append(path) + + # Track file extension + ext = Path(f).suffix.lower() or "no extension" + directories[rel_root]["file_types"][ext] = ( + directories[rel_root]["file_types"].get(ext, 0) + 1 + ) + directories[rel_root]["file_count"] += 1 + + # If the list is manageable, return the full list + if len(all_items) <= MAX_ITEMS: + return "\n".join(sorted(all_items)) + + # Try directory summary instead with sample files + dir_summary = [] + sample_files_per_dir = 5 # Show up to 5 sample files per directory + + for dir_path, info in sorted(directories.items()): + if not dir_path: # Root directory + dir_display = "/ (root)" + current_dir = WORKSPACE_PATH + else: + dir_display = f"{dir_path}/" + current_dir = WORKSPACE_PATH / dir_path + + # File type summary + if info["file_count"] > 0: + file_types = [] + for ext, count in sorted(info["file_types"].items()): + file_types.append(f"{count} {ext}") + file_summary = f" [{', '.join(file_types)}]" + else: + file_summary = " [empty]" + + # Subdirectory info + if info["subdirs"]: + subdir_info = f" (contains {len(info['subdirs'])} subdirs)" + else: + subdir_info = "" + + dir_summary.append(f"{dir_display}{file_summary}{subdir_info}") + + # Add sample files from this directory + if info["file_count"] > 0: + try: + # Get sample files from this specific directory (not recursive) + sample_files = [] + if current_dir.exists() and current_dir.is_dir(): + for item in sorted(current_dir.iterdir()): + if item.is_file() and len(sample_files) < sample_files_per_dir: + rel_path = str(item.relative_to(WORKSPACE_PATH)) + sample_files.append(f" • {rel_path}") + + if sample_files: + dir_summary.extend(sample_files) + if info["file_count"] > sample_files_per_dir: + dir_summary.append( + f" ... and {info['file_count'] - len(sample_files)} more files" + ) + except Exception: + # If we can't read the directory, just skip the sample files + pass + + # If directory summary is still too long, truncate + if len(dir_summary) > MAX_ITEMS: + total_dirs = len(directories) + total_files = sum(info["file_count"] for info in directories.values()) + + # Show first few directories and a summary + shown_dirs = dir_summary[: MAX_ITEMS // 2] + summary_text = ( + f"\n... 
+    """
+    await ctx.info("Starting analyze_csv tool...")
+    safe_path = _resolve_path(path)
+
+    if not safe_path.exists():
+        raise FileNotFoundError(f"CSV file does not exist: {path}")
+
+    if safe_path.suffix.lower() not in [".csv", ".tsv"]:
+        raise ValueError(f"File is not a CSV file: {path}")
+
+    try:
+        # Read the CSV file
+        df = pd.read_csv(safe_path)
+
+        # Get basic information
+        num_rows, num_columns = df.shape
+        column_names = df.columns.tolist()
+
+        # Get the first 3 rows as examples
+        sample_rows = df.head(3).to_string(index=True, max_cols=None)
+
+        # Format the output
+        result = f"""CSV File Analysis: {path}
+
+Dimensions:
+- Number of rows: {num_rows:,}
+- Number of columns: {num_columns}
+
+Column Names:
+{", ".join([f'"{col}"' for col in column_names])}
+
+First 3 rows (sample data):
+{sample_rows}
+
+Data Types:
+{df.dtypes.to_string()}
+        """
+
+        return result.strip()
+
+    except pd.errors.EmptyDataError:
+        return f"CSV file is empty: {path}"
+    except pd.errors.ParserError as e:
+        return f"Error parsing CSV file {path}: {str(e)}"
+    except Exception as e:
+        return f"Error analyzing CSV file {path}: {str(e)}"
+
+
+@mcp.tool
+async def list_column_values(ctx: Context, path: str, column_name: str) -> str:
+    """
+    List all unique values in a specific column of a CSV file.
+
+    Args:
+        path (str): Relative path to the CSV file.
+        column_name (str): Name of the column to analyze.
+
+    Returns:
+        str: Information about the unique values in the specified column.
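+
+    Example (illustrative):
+        >>> list_column_values(path="tables/clini.csv", column_name="OUTCOME")
+        "Column Analysis for 'OUTCOME' in tables/clini.csv ..."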
+ """ + await ctx.info("Starting list_column_values tool...") + safe_path = _resolve_path(path) + + if not safe_path.exists(): + raise FileNotFoundError(f"CSV file does not exist: {path}") + + if safe_path.suffix.lower() not in [".csv", ".tsv"]: + raise ValueError(f"File is not a CSV file: {path}") + + try: + # Read the CSV file + df = pd.read_csv(safe_path) + + # Check if column exists + if column_name not in df.columns: + available_columns = ", ".join([f'"{col}"' for col in df.columns]) + return f"Column '{column_name}' not found in CSV file: {path}\nAvailable columns: {available_columns}" + + # Get unique values + unique_values = df[column_name].unique() + + # Count occurrences of each value + value_counts = df[column_name].value_counts().sort_index() + + # Handle missing values + null_count = df[column_name].isnull().sum() + + # Format the output + result = f"""Column Analysis for '{column_name}' in {path} + +Total rows: {len(df):,} +Unique values: {len(unique_values):,} +Missing/null values: {null_count:,} + +Value distribution: +{value_counts.to_string()} + """ + + # If there are many unique values, show a sample + if len(unique_values) > 20: + result += f""" + +First 20 unique values: +{", ".join([str(val) for val in unique_values[:20]])} +... and {len(unique_values) - 20} more values + """ + else: + result += f""" + +All unique values: +{", ".join([str(val) for val in unique_values if pd.notna(val)])} + """ + + return result.strip() + + except pd.errors.EmptyDataError: + return f"CSV file is empty: {path}" + except pd.errors.ParserError as e: + return f"Error parsing CSV file {path}: {str(e)}" + except Exception as e: + return f"Error analyzing column in CSV file {path}: {str(e)}" @mcp.tool -def check_available_devices() -> str: +async def check_available_devices(ctx: Context) -> str: """ Check which computation devices are available on the system. This includes checking for cuda (NVIDIA GPUs) and mps (Apple Silicon GPUs). @@ -756,6 +1093,7 @@ def check_available_devices() -> str: Returns: A string describing the available devices. """ + await ctx.info("Starting check_available_devices tool...") devices = [] # Check for CUDA availability diff --git a/src/stamp/heatmaps/config.py b/src/stamp/heatmaps/config.py index 98bc1744..1b8b199c 100644 --- a/src/stamp/heatmaps/config.py +++ b/src/stamp/heatmaps/config.py @@ -9,15 +9,21 @@ class HeatmapConfig(BaseModel): model_config = ConfigDict(extra="forbid") - output_dir: Path + output_dir: Path = Field(description="Directory to save heatmap outputs") - feature_dir: Path - wsi_dir: Path - checkpoint_path: Path + feature_dir: Path = Field(description="Directory containing extracted features") + wsi_dir: Path = Field(description="Directory containing whole slide images") + checkpoint_path: Path = Field(description="Path to model checkpoint file") - slide_paths: list[Path] | None = None + slide_paths: list[Path] | None = Field( + default=None, + description="Specific slide paths to process. If None, processes all slides in wsi_dir", + ) - device: str = "cuda" if torch.cuda.is_available() else "cpu" + device: str = Field( + default_factory=lambda: "cuda" if torch.cuda.is_available() else "cpu", + description="Device to use for computation", + ) opacity: float = Field( default=0.6, @@ -26,8 +32,19 @@ class HeatmapConfig(BaseModel): le=1, ) - topk: int = 0 - bottomk: int = 0 + topk: int = Field( + default=0, + description="Number of top patches to highlight. 
diff --git a/src/stamp/modeling/data.py b/src/stamp/modeling/data.py
index 2f121f27..6cabec64 100755
--- a/src/stamp/modeling/data.py
+++ b/src/stamp/modeling/data.py
@@ -37,7 +37,6 @@
 )
 
 _logger = logging.getLogger("stamp")
-_logged_stamp_v1_warning = False
 
 __author__ = "Marko van Treeck, Minh Duc Nguyen"
 
@@ -569,13 +568,9 @@ def get_coords(feature_h5: h5py.File) -> CoordsInfo:
             == 224
         ):
             # Historic STAMP format
-            # TODO: find a better way to get this warning just once
-            global _logged_stamp_v1_warning
-            if not _logged_stamp_v1_warning:
-                _logger.info(
-                    f"{feature_h5.filename}: tile stride is roughly 224, assuming coordinates have unit 256um/224px (historic STAMP format)"
-                )
-                _logged_stamp_v1_warning = True
+            _logger.debug(
+                f"{feature_h5.filename}: tile stride is roughly 224, assuming coordinates have unit 256um/224px (historic STAMP format)"
+            )
             tile_size_um = Microns(256.0)
             tile_size_px = TilePixels(224)
             coords_um = coords / 224 * 256
diff --git a/src/stamp/preprocessing/__init__.py b/src/stamp/preprocessing/__init__.py
index a1844526..ab3ff0d2 100755
--- a/src/stamp/preprocessing/__init__.py
+++ b/src/stamp/preprocessing/__init__.py
@@ -222,6 +222,11 @@ def extract_(
 
             extractor = plip()
 
+        case ExtractorName.TICON:
+            from stamp.preprocessing.extractor.ticon import ticon
+
+            extractor = ticon()
+
         case ExtractorName.EMPTY:
             from stamp.preprocessing.extractor.empty import empty
 
diff --git a/src/stamp/preprocessing/config.py b/src/stamp/preprocessing/config.py
index 244d70dd..5eca41dd 100644
--- a/src/stamp/preprocessing/config.py
+++ b/src/stamp/preprocessing/config.py
@@ -28,6 +28,7 @@ class ExtractorName(StrEnum):
     MUSK = "musk"
     MSTAR = "mstar"
     PLIP = "plip"
+    TICON = "ticon"
     EMPTY = "empty"
 
diff --git a/src/stamp/preprocessing/extractor/ticon.py b/src/stamp/preprocessing/extractor/ticon.py
new file mode 100644
index 00000000..fb8f9b43
--- /dev/null
+++ b/src/stamp/preprocessing/extractor/ticon.py
@@ -0,0 +1,730 @@
+import math
+from collections.abc import Callable, Mapping
+from functools import partial
+from typing import Any
+
+import torch
+import torch.nn as nn
+from huggingface_hub import hf_hub_download
+from jaxtyping import Float
+from torch import Tensor
+
+from stamp.preprocessing.extractor import Extractor
+
+try:
+    import timm
+    from torchvision import transforms
+except ModuleNotFoundError as e:
+    raise ModuleNotFoundError(
+        "h_optimus_1 dependencies not installed."
+        " Please reinstall stamp using `pip install 'stamp[h_optimus_1]'`"
+    ) from e
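+
+# This module vendors the TICON slide-level tile contextualizer
+# (https://cvlab-stonybrook.github.io/TICON/). Extraction runs in two stages:
+# H-optimus-1 embeds each 224x224 tile, and the TICON encoder contextualizes
+# the tile embeddings with distance-biased attention. The classes below
+# implement the encoder/decoder architecture matching the checkpoint on the
+# varunb/TICON Hugging Face repository.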
+ " Please reinstall stamp using `pip install 'stamp[h_optimus_1]'`" + ) from e + + +class LayerScale(nn.Module): + def __init__( + self, + dim: int, + init_values: float = 1e-5, + inplace: bool = False, + ) -> None: + super().__init__() + self.inplace = inplace + self.init_values = init_values + self.gamma = nn.Parameter(torch.empty(dim)) + self.reset_parameters() + + def reset_parameters(self): + nn.init.constant_(self.gamma, self.init_values) + + def forward(self, x: Tensor) -> Tensor: + return x.mul_(self.gamma) if self.inplace else x * self.gamma + + +class Mlp(nn.Module): + def __init__( + self, + in_features: int, + hidden_features: int | None = None, + mlp_ratio: int | float | None = (16 / 3), + bias: bool = True, + ) -> None: + super().__init__() + if hidden_features is None: + assert mlp_ratio is not None + hidden_features = int(in_features * mlp_ratio) + else: + assert mlp_ratio is None + self.fc1 = nn.Linear(in_features, hidden_features, bias=bias) + self.act = nn.SiLU() + self.fc2 = nn.Linear(hidden_features // 2, in_features, bias=bias) + + def forward(self, x: Float[Tensor, "*b d"]) -> Float[Tensor, "*b d"]: + x = self.fc1(x) + x1, x2 = x.chunk(2, dim=-1) + x = self.act(x1) * x2 + x = self.fc2(x) + return x + + +class ProjectionMlp(nn.Module): + def __init__( + self, + in_features: int, + hidden_features: int, + out_features: int, + bias: bool = True, + ) -> None: + super().__init__() + self.fc1 = nn.Linear(in_features, hidden_features, bias=bias) + self.act = nn.SiLU() + self.fc2 = nn.Linear(hidden_features, out_features, bias=bias) + self.norm = nn.LayerNorm(out_features) + + def forward(self, x: Float[Tensor, "*b d"]) -> Float[Tensor, "*b d"]: + x = self.fc1(x) + x = self.act(x) + x = self.fc2(x) + x = self.norm(x) + return x + + +def get_slopes(n): + def get_slopes_power_of_2(n): + start = 2 ** (-(2 ** -(math.log2(n) - 3))) + ratio = start + return [start * ratio**i for i in range(n)] + + if math.log2(n).is_integer(): + return get_slopes_power_of_2( + n + ) # In the paper, we only train models that have 2^a heads for some a. This function has + else: # some good properties that only occur when the input is a power of 2. To maintain that even + closest_power_of_2 = 2 ** math.floor( + math.log2(n) + ) # when the number of heads is not a power of 2, we use this workaround. 
+def get_slopes(n):
+    def get_slopes_power_of_2(n):
+        start = 2 ** (-(2 ** -(math.log2(n) - 3)))
+        ratio = start
+        return [start * ratio**i for i in range(n)]
+
+    if math.log2(n).is_integer():
+        # In the paper, we only train models that have 2^a heads for some a.
+        # This function has some good properties that only occur when the
+        # input is a power of 2.
+        return get_slopes_power_of_2(n)
+    else:
+        # To maintain those properties even when the number of heads is not a
+        # power of 2, we use this workaround.
+        closest_power_of_2 = 2 ** math.floor(math.log2(n))
+        return (
+            get_slopes_power_of_2(closest_power_of_2)
+            + get_slopes(2 * closest_power_of_2)[0::2][: n - closest_power_of_2]
+        )
+
+
+def scaled_dot_product_attention_custom(
+    query,
+    key,
+    value,
+    attn_bias=None,
+    attn_mask=None,
+    dropout_p=0.0,
+    is_causal=False,
+    scale=None,
+    enable_gqa=False,
+) -> torch.Tensor:
+    L, S = query.size(-2), key.size(-2)
+    scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
+    # attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device)
+    if is_causal:
+        assert attn_mask is None
+        temp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0)
+        attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))  # pyright: ignore[reportOptionalMemberAccess]
+        attn_bias.to(query.dtype)  # pyright: ignore[reportOptionalMemberAccess]
+
+    if attn_mask is not None:
+        if attn_mask.dtype == torch.bool:
+            attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))  # pyright: ignore[reportOptionalMemberAccess]
+        else:
+            attn_bias = attn_mask + attn_bias
+
+    if enable_gqa:
+        key = key.repeat_interleave(query.size(-3) // key.size(-3), -3)
+        value = value.repeat_interleave(query.size(-3) // value.size(-3), -3)
+
+    attn_weight = query @ key.transpose(-2, -1) * scale_factor
+    attn_weight += attn_bias
+    attn_weight = torch.softmax(attn_weight, dim=-1)
+    attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
+    return attn_weight @ value
+
+
+class Attention(nn.Module):
+    def __init__(
+        self,
+        dim: int,
+        num_heads: int,
+        qkv_bias: bool = True,
+        proj_bias: bool = True,
+        context_dim: int | None = None,
+    ) -> None:
+        super().__init__()
+        self.num_heads = num_heads
+        context_dim = context_dim or dim
+
+        self.q_proj = nn.Linear(dim, dim, bias=qkv_bias)
+        self.k_proj = nn.Linear(context_dim, dim, bias=qkv_bias)
+        self.v_proj = nn.Linear(context_dim, dim, bias=qkv_bias)
+        self.proj = nn.Linear(dim, dim, bias=proj_bias)
+        slopes = torch.Tensor(get_slopes(num_heads))
+        self.slopes = slopes[
+            None, :, None, None
+        ]  # einops.rearrange(slopes, 'b -> 1 b 1 1')
+
+    def forward(
+        self,
+        x: Float[Tensor, "b n_q d"],
+        coords: Float[Tensor, "b n_q 2"],
+        context: Float[Tensor, "b n_k d_k"] | None = None,
+        context_coords: Float[Tensor, "b n_k 2"] | None = None,
+    ) -> Float[Tensor, "b n_q d"]:
+        if context is None or context_coords is None:
+            context = x
+            context_coords = coords
+        b, n_q, d = x.shape
+        b, n_k, _ = context.shape
+        h = self.num_heads
+
+        q = self.q_proj(x).reshape(b, n_q, h, d // h).transpose(1, 2)
+        k = self.k_proj(context).reshape(b, n_k, h, d // h).transpose(1, 2)
+        v = self.v_proj(context).reshape(b, n_k, h, d // h).transpose(1, 2)
+
+        # (b, n_q, 2) -> (b, n_q, 1, 2) -> (b, n_q, n_k, 2)
+        coords_expanded = coords.unsqueeze(2).expand(-1, -1, n_k, -1)
+        context_coords_expanded = context_coords.unsqueeze(1).expand(-1, n_q, -1, -1)
+        euclid_dist = torch.sqrt(
+            torch.sum((coords_expanded - context_coords_expanded) ** 2, dim=-1)
+        )
+        self.slopes = self.slopes.to(x.device)
+        attn_bias = (-1) * self.slopes * euclid_dist[:, None, :, :]
+
+        # x = F.scaled_dot_product_attention(q, k, v)
+        x = scaled_dot_product_attention_custom(q, k, v, attn_bias=attn_bias)
+        x = x.transpose(1, 2).reshape([b, n_q, d])
+        x = self.proj(x)
+        return x
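+
+
+# The two residual wrappers below implement stochastic depth (drop-path):
+# NaiveResidual masks whole samples after computing fn on the full batch,
+# while EfficientResidual computes fn only on a random subset of the batch
+# and scatters the scaled result back via index_add.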
+class NaiveResidual(nn.Module):
+    def __init__(
+        self,
+        drop_prob: float | int,
+        norm: nn.Module,
+        fn: nn.Module,
+        gamma: nn.Parameter,
+    ):
+        super().__init__()
+        self.norm = norm
+        self.fn = fn
+        self.keep_prob = 1 - drop_prob
+        self.gamma = gamma
+
+    def forward(
+        self,
+        x: Float[Tensor, "b n d"],
+        **kwargs: Float[Tensor, "b ..."] | None,
+    ) -> Float[Tensor, "b n d"]:
+        fn_out = self.fn(self.norm(x), **kwargs)
+        if self.gamma is not None:
+            if self.keep_prob == 1.0 or not self.training:
+                return x + self.gamma * fn_out
+            mask = fn_out.new_empty(x.shape[0]).bernoulli_(self.keep_prob)[
+                :, None, None
+            ]
+            return x + self.gamma * fn_out * mask / self.keep_prob
+        else:
+            if self.keep_prob == 1.0 or not self.training:
+                return x + fn_out
+            mask = fn_out.new_empty(x.shape[0]).bernoulli_(self.keep_prob)[
+                :, None, None
+            ]
+            return x + fn_out * mask / self.keep_prob
+
+
+class EfficientResidual(NaiveResidual):
+    def forward(
+        self,
+        x: Float[Tensor, "b n d"],
+        **kwargs: Float[Tensor, "b ..."] | None,
+    ) -> Float[Tensor, "b n d"]:
+        if self.keep_prob == 1.0 or not self.training:
+            if self.gamma is not None:
+                return x + self.gamma * self.fn(self.norm(x), **kwargs)
+            else:
+                return x + self.fn(self.norm(x), **kwargs)
+
+        b, _, _ = x.shape
+        n_keep = max(int(b * self.keep_prob), 1)
+        indices = torch.randperm(b, device=x.device)[:n_keep]
+        for k, v in kwargs.items():
+            if v is not None:
+                kwargs[k] = v[indices]
+        if self.gamma is not None:
+            return torch.index_add(
+                x,
+                dim=0,
+                source=self.gamma * self.fn(self.norm(x[indices]), **kwargs),
+                index=indices,
+                alpha=b / n_keep,
+            )
+        else:
+            return torch.index_add(
+                x,
+                dim=0,
+                source=self.fn(self.norm(x[indices]), **kwargs),
+                index=indices,
+                alpha=b / n_keep,
+            )
+
+
+class Block(nn.Module):
+    def __init__(
+        self,
+        dim: int,
+        drop_path: float | int,
+        norm_layer: Callable[[int], nn.Module],
+        context_dim: int | None,
+        drop_path_type: str = "efficient",
+        layer_scale: bool = True,
+        attn_kwargs: Mapping = {},
+    ) -> None:
+        super().__init__()
+        residual_module = {
+            "naive": NaiveResidual,
+            "efficient": EfficientResidual,
+        }[drop_path_type]
+
+        self.layer_scale = layer_scale
+        if layer_scale:
+            gamma1 = nn.Parameter(torch.ones((dim)), requires_grad=True)
+            gamma2 = nn.Parameter(torch.ones((dim)), requires_grad=True)
+        else:
+            gamma1 = None
+            gamma2 = None
+
+        self.residual1 = residual_module(
+            drop_path,
+            norm_layer(dim),
+            Attention(
+                dim,
+                context_dim=context_dim,
+                **attn_kwargs,
+            ),
+            gamma1,
+        )
+        self.residual2 = residual_module(
+            drop_path, norm_layer(dim), Mlp(in_features=dim), gamma2
+        )
+
+    def forward(
+        self,
+        x: Float[Tensor, "b n d"],
+        context: Float[Tensor, "b n_k d_k"] | None = None,
+        coords: Float[Tensor, "b n 2"] | None = None,
+        context_coords: Float[Tensor, "b n_k 2"] | None = None,
+    ) -> Float[Tensor, "b n d"]:
+        x = self.residual1(
+            x,
+            context=context,
+            coords=coords,
+            context_coords=context_coords,
+        )
+        x = self.residual2(x)
+        return x
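+
+
+# Transformer runs the blocks sequentially; `return_layers` selects which
+# intermediate activations to return, and `contexts` optionally supplies a
+# per-block cross-attention context (used by the decoder below).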
+class Transformer(nn.Module):
+    def __init__(
+        self,
+        embed_dim: int,
+        norm_layer: Callable[[int], nn.Module],
+        depth: int,
+        drop_path_rate: float | int,
+        context_dim: int | None = None,
+        block_kwargs: Mapping[str, Any] = {},
+    ):
+        super().__init__()
+        self.embed_dim = embed_dim
+        self.n_blocks = depth
+
+        self.blocks = nn.ModuleList(
+            [
+                Block(
+                    dim=embed_dim,
+                    drop_path=drop_path_rate,
+                    norm_layer=norm_layer,
+                    context_dim=context_dim,
+                    **block_kwargs,
+                )
+                for i in range(depth)
+            ],
+        )
+
+    def forward(
+        self,
+        x: Float[Tensor, "b n d"],
+        return_layers: set[int],
+        contexts: list[Float[Tensor, "b n_k d_k"]] | None = None,
+        coords: Float[Tensor, "b n 2"] | None = None,
+        context_coords: Float[Tensor, "b n_k 2"] | None = None,
+    ) -> dict[int, Float[Tensor, "b n d"]]:
+        outputs = {}
+        if 0 in return_layers:
+            outputs[0] = x
+        for blk_idx, blk in enumerate(self.blocks):
+            context = contexts[blk_idx] if contexts is not None else None
+            x = blk(
+                x,
+                context=context,
+                coords=coords,
+                context_coords=context_coords,
+            )
+            if blk_idx + 1 in return_layers:
+                outputs[blk_idx + 1] = x
+        return outputs
+
+
+class EncoderDecoder(nn.Module):
+    def __init__(
+        self,
+        patch_size: int = 14,
+        in_dims: list = [],
+        tile_encoder_keys: list = [],
+        norm_layer_type: str = "LayerNorm",
+        transformers_kwargs: Mapping[str, Any] = {},
+        encoder_kwargs: Mapping[str, Any] = {},
+        decoder_kwargs: Mapping[str, Any] = {},
+        norm_layer_kwargs: Mapping[str, Any] = {"eps": 1e-5},
+        final_norm_kwargs: Mapping[str, Any] = {"elementwise_affine": True},
+        out_layer: int = -1,
+        num_decoders: int = 1,
+        decoder_out_dims: list = [],
+    ):
+        super().__init__()
+        self.patch_size = patch_size
+
+        norm_layer: Callable[[int], nn.Module] = partial(
+            getattr(torch.nn, norm_layer_type), **norm_layer_kwargs
+        )
+
+        self.encoder = Transformer(
+            **transformers_kwargs,
+            **encoder_kwargs,
+            norm_layer=norm_layer,
+        )
+
+        self.tile_encoder_keys = tile_encoder_keys
+        self.embed_dim = self.encoder.embed_dim
+        self.n_blocks = len(self.encoder.blocks)
+        self.out_layer = out_layer % (len(self.encoder.blocks) + 1)
+        self.enc_norm = norm_layer(self.embed_dim, **final_norm_kwargs)
+        self.num_decoders = num_decoders
+        self.decoder_out_dims = decoder_out_dims
+
+        self.decoder_dict = nn.ModuleDict({})
+        self.mask_dict = nn.ParameterDict({})
+        self.input_proj_dict = nn.ModuleDict({})
+        self.output_proj_dict = nn.ModuleDict({})
+
+        for i in range(len(in_dims)):
+            self.input_proj_dict[f"input_proj_{self.tile_encoder_keys[i]}"] = (
+                ProjectionMlp(
+                    in_features=in_dims[i],
+                    hidden_features=self.encoder.embed_dim,
+                    out_features=self.encoder.embed_dim,
+                )
+            )
+
+        for i in range(self.num_decoders):
+            self.decoder_dict[f"decoder_{i}"] = nn.ModuleDict({})
+            self.decoder_dict[f"decoder_{i}"]["transformer"] = Transformer(  # pyright: ignore[reportIndexIssue]
+                **transformers_kwargs,
+                **decoder_kwargs,
+                context_dim=self.encoder.embed_dim,
+                norm_layer=norm_layer,
+            )
+
+            self.decoder_dict[f"decoder_{i}"]["norm"] = norm_layer(  # pyright: ignore[reportIndexIssue]
+                self.decoder_dict[f"decoder_{i}"]["transformer"].embed_dim,  # pyright: ignore[reportIndexIssue]
+                **final_norm_kwargs,
+            )
+            self.mask_dict[f"mask_token_{i}"] = nn.Parameter(
+                torch.empty(
+                    1,
+                    self.decoder_dict[f"decoder_{i}"]["transformer"].embed_dim,  # pyright: ignore[reportIndexIssue]
+                )
+            )
+
+        for i in range(len(self.decoder_out_dims)):
+            self.output_proj_dict[f"output_proj_{self.tile_encoder_keys[i]}"] = (
+                ProjectionMlp(
+                    in_features=self.encoder.embed_dim,
+                    hidden_features=self.encoder.embed_dim,
+                    out_features=self.decoder_out_dims[i],
+                )
+            )
+
+        assert self.num_decoders <= 1
+
+    def init_weights(self):
+        for mask_key in self.mask_dict.keys():
+            nn.init.normal_(self.mask_dict[mask_key], std=0.02)
+        self.apply(_init_weights)
+        return self
+
+    def forward_features(
+        self,
+        x: Float[Tensor, "b n d"],
+        relative_coords: Float[Tensor, "b n 2"] | None,
+        predict_coords: Float[Tensor, "b n 2"] | None,
+        enc_layer: int,
+        dec_layer: int | None,
+        tile_encoder_key: str | None,
+    ) -> tuple[Float[Tensor, "b n d"], dict | None]:
+        b, _, _ = x.shape
+
+        # these are the layers we need
+        enc_layers = {enc_layer}
+        if dec_layer is not None:
+            enc_layers.add(len(self.encoder.blocks))
+
+        # encoder fwd
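+        # Project the tile-encoder features into the shared embedding space,
+        # then contextualize them with distance-biased self-attention; only
+        # the layers requested in `enc_layers` are kept.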
+        coords_enc = relative_coords
+        coords_dec = predict_coords
+        x = self.input_proj_dict[f"input_proj_{tile_encoder_key}"](x)
+        encoder_outputs = self.encoder(x, coords=coords_enc, return_layers=enc_layers)
+        encoder_outputs = {k: self.enc_norm(v) for k, v in encoder_outputs.items()}
+
+        # decoder fwd
+        if dec_layer is not None:
+            dec_final_output = {}
+            assert self.num_decoders == 1
+            for dec_index in range(self.num_decoders):
+                decoder_outputs = self.decoder_dict[
+                    f"decoder_{dec_index}"
+                ][  # pyright: ignore[reportIndexIssue]
+                    "transformer"
+                ](
+                    self.mask_dict[f"mask_token_{dec_index}"][None].expand(
+                        *coords_dec.shape[:2],  # pyright: ignore[reportOptionalMemberAccess]
+                        -1,  # pyright: ignore[reportOptionalMemberAccess]
+                    ),
+                    contexts=[encoder_outputs[len(self.encoder.blocks)]]
+                    * self.decoder_dict[f"decoder_{dec_index}"]["transformer"].n_blocks,  # pyright: ignore[reportIndexIssue]
+                    coords=coords_dec,
+                    context_coords=coords_enc,
+                    return_layers={dec_layer},
+                )
+                dec_output = self.decoder_dict[f"decoder_{dec_index}"]["norm"](  # pyright: ignore[reportIndexIssue]
+                    decoder_outputs[dec_layer]
+                )
+
+                for out_index in range(len(self.decoder_out_dims)):
+                    dec_final_output[self.tile_encoder_keys[out_index]] = (
+                        self.output_proj_dict[
+                            f"output_proj_{self.tile_encoder_keys[out_index]}"
+                        ](dec_output)
+                    )
+        else:
+            dec_final_output = None
+        enc_output = encoder_outputs[enc_layer]
+        return (enc_output, dec_final_output)
+
+    def forward(
+        self,
+        x: Float[Tensor, "b n d"],
+        relative_coords: Float[Tensor, "b n 2"] | None = None,
+        tile_encoder_key: str | None = None,
+    ) -> Float[Tensor, "b n d"]:
+        enc_output, dec_output = self.forward_features(
+            x,
+            relative_coords=relative_coords,
+            predict_coords=None,
+            enc_layer=self.out_layer,
+            dec_layer=None,
+            tile_encoder_key=tile_encoder_key,
+        )
+        return enc_output
+
+
+# from https://github.com/facebookresearch/mae/blob/main/models_mae.py
+def _init_weights(m: nn.Module, xavier_gain=1) -> None:
+    if isinstance(m, nn.Linear):
+        # we use xavier_uniform following the official JAX ViT:
+        torch.nn.init.xavier_uniform_(m.weight, gain=xavier_gain)
+        if isinstance(m, nn.Linear) and m.bias is not None:
+            nn.init.constant_(m.bias, 0)
+    elif isinstance(m, nn.LayerNorm | nn.RMSNorm) and m.elementwise_affine:
+        nn.init.constant_(m.weight, 1.0)
+        if hasattr(m, "bias") and m.bias is not None:
+            nn.init.constant_(m.bias, 0)  # pyright: ignore[reportArgumentType]
+    if hasattr(m, "_device_weight_init"):
+        m._device_weight_init()  # pyright: ignore[reportCallIssue]
+
+
+def load_ticon(device: str = "cuda") -> nn.Module:
+    model_cfg = {
+        "transformers_kwargs": {
+            "embed_dim": 1536,
+            "drop_path_rate": 0.0,
+            "block_kwargs": {
+                "attn_kwargs": {"num_heads": 24},
+            },
+        },
+        "encoder_kwargs": {"depth": 6},
+        "decoder_kwargs": {"depth": 1},
+        "in_dims": [768, 1536, 1536, 1536, 1280],
+        "tile_encoder_keys": [
+            "conchv15",
+            "hoptimus1",
+            "uni2h",
+            "gigapath",
+            "virchow2",
+        ],
+        "num_decoders": 1,
+        "decoder_out_dims": [768, 1536, 1536, 1536, 1280],
+    }
+
+    ckpt = hf_hub_download(
+        repo_id="varunb/TICON",
+        filename="backbone/checkpoint.pth",
+        repo_type="model",
+    )
+
+    with torch.device("meta"):
+        model = EncoderDecoder(**model_cfg)
+
+    model.to_empty(device=device)
+    model.init_weights()
+
+    sd = torch.load(ckpt, map_location="cpu", weights_only=True)
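+    # Keep only the backbone weights and strip their prefix so the keys match
+    # this module's state_dict.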
+    sd = {
+        k.removeprefix("backbone."): v
+        for k, v in sd.items()
+        if k.startswith("backbone.")
+    }
+
+    model.load_state_dict(sd, strict=False)
+    model.eval()
+    return model
+
+
+class HOptimusTICON(nn.Module):
+    def __init__(self, device: torch.device):
+        super().__init__()
+        self.device = device
+
+        # ----------------------------
+        # Stage 1: H-optimus-1 tile encoder
+        # ----------------------------
+        self.tile_encoder = timm.create_model(
+            "hf-hub:bioptimus/H-optimus-1",
+            pretrained=True,
+            init_values=1e-5,
+            dynamic_img_size=False,
+        )
+
+        # ----------------------------
+        # Stage 2: TICON contextualizer
+        # ----------------------------
+        ticon_cfg = {
+            "transformers_kwargs": {
+                "embed_dim": 1536,
+                "drop_path_rate": 0.0,
+                "block_kwargs": {
+                    "attn_kwargs": {"num_heads": 24},
+                },
+            },
+            "encoder_kwargs": {"depth": 6},
+            "decoder_kwargs": {"depth": 1},
+            "in_dims": [768, 1536, 1536, 1536, 1280],
+            "tile_encoder_keys": [
+                "conchv15",
+                "hoptimus1",
+                "uni2h",
+                "gigapath",
+                "virchow2",
+            ],
+            "num_decoders": 1,
+            "decoder_out_dims": [768, 1536, 1536, 1536, 1280],
+        }
+
+        with torch.device("meta"):
+            self.ticon = EncoderDecoder(**ticon_cfg)
+
+        self.ticon.to_empty(device=device)
+        self.ticon.init_weights()
+
+        ckpt = hf_hub_download(
+            repo_id="varunb/TICON",
+            filename="backbone/checkpoint.pth",
+            repo_type="model",
+        )
+
+        sd = torch.load(ckpt, map_location="cpu", weights_only=True)
+        sd = {
+            k.removeprefix("backbone."): v
+            for k, v in sd.items()
+            if k.startswith("backbone.")
+        }
+        self.ticon.load_state_dict(sd, strict=False)
+
+        self.to(device)
+        self.eval()
+
+    @torch.inference_mode()
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """
+        x: [B, 3, 224, 224] (CPU or CUDA)
+        """
+        x = x.to(self.device, non_blocking=True)
+
+        # Stage 1: H-optimus-1 tile embeddings
+        emb = self.tile_encoder(x)  # [B, 1536]
+        emb = emb.unsqueeze(1)  # [B, 1, 1536]
+
+        # Stage 2: TICON; a single tile gets zero relative coords
+        coords = torch.zeros(
+            emb.size(0),
+            1,
+            2,
+            device=self.device,
+            dtype=torch.float32,
+        )
+
+        out = self.ticon(
+            x=emb,
+            relative_coords=coords,
+            tile_encoder_key="hoptimus1",
+        )
+
+        return out.squeeze(1)  # [B, 1536]
+
+
+def ticon(device: str = "cuda") -> Extractor[nn.Module]:
+    model = HOptimusTICON(torch.device(device))
+
+    transform = transforms.Compose(
+        [
+            transforms.Resize(224),
+            transforms.ToTensor(),
+            transforms.Normalize(
+                mean=(0.707223, 0.578729, 0.703617),
+                std=(0.211883, 0.230117, 0.177517),
+            ),
+        ]
+    )
+
+    return Extractor(
+        model=model,
+        transform=transform,
+        identifier="ticon",
+    )
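+
+
+# Usage sketch (illustrative; STAMP normally wires this up via
+# `extractor: ticon` in the preprocessing config):
+#
+#   extractor = ticon(device="cuda")
+#   batch = torch.stack([extractor.transform(tile) for tile in pil_tiles])
+#   feats = extractor.model(batch)  # [B, 1536] contextualized tile features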