190 changes: 190 additions & 0 deletions src/scripts/get_schuerch_lmdb_train_test_valid_split.ipynb
@@ -0,0 +1,190 @@
{
"cells": [

Would be nice to have this as a Python script rather than a notebook; it would make it much easier to track changes.
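If it helps, here is a rough sketch of what the script version could look like (the data directory and output CSV become CLI arguments; the `.hdf` filtering, the `reg` prefix handling, and the 70/15/15 ratios are taken from this notebook, everything else is just a suggestion):

```python
# Sketch only: the notebook's patient-level split as a plain script.
import argparse
import os
from collections import defaultdict

import numpy as np
import pandas as pd


def run(data_dir: str, out_csv: str, seed: int = 42) -> None:
    files = [f for f in os.listdir(data_dir) if f.endswith(".hdf")]

    # Group files by patient number, e.g. 'reg006_A.hdf' -> patient '006'
    patient_groups = defaultdict(list)
    for f in files:
        patient_groups[f.split("_")[0][3:]].append(f)

    # Patient-level 70/15/15 split, reproducible via the seed
    patients = np.random.default_rng(seed).permutation(list(patient_groups))
    n_train = int(len(patients) * 0.7)
    n_val = int(len(patients) * 0.15)
    splits = {
        "train": patients[:n_train],
        "valid": patients[n_train:n_train + n_val],
        "test": patients[n_train + n_val:],
    }

    # One row per file, labelled with its patient's split
    rows = [(f, split) for split, pats in splits.items()
            for p in pats for f in patient_groups[p]]
    df = pd.DataFrame(rows, columns=["sample_name", "train_test_val_split"])
    df.sort_values("sample_name").to_csv(out_csv, index=False)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Patient-level train/valid/test split.")
    parser.add_argument("--data_dir", required=True, help="Directory containing the .hdf files")
    parser.add_argument("--out_csv", required=True, help="Path of the split CSV to write")
    parser.add_argument("--seed", type=int, default=42)
    args = parser.parse_args()
    run(args.data_dir, args.out_csv, args.seed)
```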

{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"from collections import defaultdict\n",
"import numpy as np\n",
"\n",
"import pandas as pd\n"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Total patients: 67\n",
"Train patients: 46 (74 files)\n",
"Val patients: 10 (16 files)\n",
"Test patients: 11 (19 files)\n"
]
}
],
"source": [
"# Get all files\n",
"files = os.listdir('/home/fabian/raid5/schuerch_dataset/preprocessed/full')\n",
"files = [f for f in files if f.endswith('.hdf')]\n",
"\n",
"# Extract unique patient numbers and group files\n",
"patient_groups = defaultdict(list)\n",
"for file in files:\n",
" # Extract patient number (e.g., '006' from 'reg006_A.hdf')\n",
" patient_num = file.split('_')[0][3:] # removes 'reg' prefix\n",
" patient_groups[patient_num].append(file)\n",
"\n",
"# Get unique patient numbers\n",
"unique_patients = list(patient_groups.keys())\n",
"n_patients = len(unique_patients)\n",
"\n",
"# Calculate split sizes\n",
"n_train = int(n_patients * 0.7)\n",
"n_val = int(n_patients * 0.15)\n",
"n_test = n_patients - n_train - n_val\n",
"\n",
"# Randomly split patients\n",
"np.random.seed(42) # for reproducibility\n",
"patients_shuffled = np.random.permutation(unique_patients)\n",
"train_patients = patients_shuffled[:n_train]\n",
"val_patients = patients_shuffled[n_train:n_train+n_val]\n",
"test_patients = patients_shuffled[n_train+n_val:]\n",
"\n",
"# Create the final splits\n",
"train_files = []\n",
"val_files = []\n",
"test_files = []\n",
"\n",
"for patient in train_patients:\n",
" train_files.extend(patient_groups[patient])\n",
"for patient in val_patients:\n",
" val_files.extend(patient_groups[patient])\n",
"for patient in test_patients:\n",
" test_files.extend(patient_groups[patient])\n",
"\n",
"# Print summary\n",
"print(f\"Total patients: {n_patients}\")\n",
"print(f\"Train patients: {len(train_patients)} ({len(train_files)} files)\")\n",
"print(f\"Val patients: {len(val_patients)} ({len(val_files)} files)\")\n",
"print(f\"Test patients: {len(test_patients)} ({len(test_files)} files)\")\n",
"\n",
"# Save splits to files\n",
"# with open('train_split.txt', 'w') as f:\n",
"# f.write('\\n'.join(sorted(train_files)))\n",
"# with open('val_split.txt', 'w') as f:\n",
"# f.write('\\n'.join(sorted(val_files)))\n",
"# with open('test_split.txt', 'w') as f:\n",
"# f.write('\\n'.join(sorted(test_files)))"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"# Create a list of (filename, split) tuples\n",
"split_entries = []\n",
"\n",
"for file in train_files:\n",
" split_entries.append((file, 'train'))\n",
"for file in val_files:\n",
" split_entries.append((file, 'valid'))\n",
"for file in test_files:\n",
" split_entries.append((file, 'test'))\n",
"\n",
"# Create DataFrame and save to CSV\n",
"df = pd.DataFrame(split_entries, columns=['sample_name', 'train_test_val_split'])\n",
"df = df.sort_values('sample_name') # Optional: sort by filename\n",
"\n",
"# Save to CSV\n",
"# df.to_csv('/home/fabian/raid5/schuerch_dataset/splits/schuerch_dataset_split.csv', index=False)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"All unique semantic mask values found: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14]\n"
]
}
],
"source": [
"import lmdb\n",
"import pickle\n",
"import numpy as np\n",
"\n",
"# Open the LMDB environment\n",
"env = lmdb.open(\"/home/fabian/raid5/schuerch_dataset/schuerch_dataset_lmdb/lmdb/\", readonly=True)\n",
"\n",
"# Keep track of all unique semantic mask values\n",
"all_unique_values = set()\n",
"\n",
"with env.begin() as txn:\n",
" cursor = txn.cursor()\n",
" for key, value in cursor:\n",
" # Deserialize the data\n",
" tile_dict = pickle.loads(value)\n",
" # Get semantic mask and find unique values\n",
" semantic_mask = tile_dict['semantic_mask']\n",
" unique_values = np.unique(semantic_mask)\n",
" all_unique_values.update(unique_values.tolist())\n",
"\n",
"print(\"All unique semantic mask values found:\", sorted(all_unique_values))"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"15"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"len(all_unique_values)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "bio_bench",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.10"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
39 changes: 39 additions & 0 deletions src/scripts/make_arctique_train_test_val.py
@@ -0,0 +1,39 @@
import numpy as np
import os
from pathlib import Path
import shutil
import pandas as pd

orig_train_path = Path("/fast/AG_Kainmueller/data/patho_foundation_model_bench_data/arctique_dataset/original_data/v1-0/train") #Path("C:/Users/cwinklm/Documents/Data/v_review_sample10/train")

Would be nice to have this in a run() function that is called under `if __name__ == "__main__":`.
Also, it would be nice to use argparse to get the arguments (paths etc.); that would make the script reusable if the data changes location, or if we want to adapt it to another dataset.
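Something like this skeleton, maybe (argument names are only suggestions; the existing split/copy logic would move into `run()` unchanged):

```python
# Rough sketch of the suggested structure; argument names are placeholders.
import argparse
from pathlib import Path


def run(orig_train_path: Path, target_folder: Path, train_percent: float = 0.7) -> None:
    # ... the split + copy logic currently at module level goes here ...
    pass


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Create the arctique train/test/val split.")
    parser.add_argument("--orig_train_path", type=Path, required=True,
                        help="Folder containing the original images/ and masks/ subfolders")
    parser.add_argument("--target_folder", type=Path, required=True,
                        help="Output folder for the restructured dataset and split CSV")
    parser.add_argument("--train_percent", type=float, default=0.7)
    args = parser.parse_args()
    run(args.orig_train_path, args.target_folder, args.train_percent)
```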

target_folder = Path("/fast/AG_Kainmueller/data/patho_foundation_model_bench_data/arctique_dataset/arctique") #Path("C:/Users/cwinklm/Documents/Data/v_review_sample10/train_test_val")
os.makedirs(target_folder, exist_ok=True)
#for split_name in ["train", "test", "val"]:
os.makedirs(target_folder.joinpath("images"), exist_ok=True)
os.makedirs(target_folder.joinpath("masks", "instance"), exist_ok=True)
os.makedirs(target_folder.joinpath("masks", "semantic"), exist_ok=True)


all_samples = [int(n.split("_")[1].split(".")[0]) for n in os.listdir(orig_train_path.joinpath("images"))]
np.random.shuffle(all_samples)

train_percent = 0.7
n_train = int(len(all_samples)*train_percent)
n_test_val = len(all_samples) - n_train

train_samples = all_samples[:n_train]
val_samples = all_samples[n_train:n_train+n_test_val//2]
test_samples = all_samples[n_train+n_test_val//2:]
labels = ["train"]*len(train_samples) + ["val"]*len(val_samples) + ["test"]*len(test_samples)

split_dict = pd.DataFrame({"sample_name":all_samples, "train_test_val_split":labels})
split_dict.to_csv(target_folder.joinpath("train_test_val_split.csv"), index=False)

for sample_idx, sample_name in enumerate(all_samples):
shutil.copy(orig_train_path.joinpath("images", f"img_{sample_name}.png"),
target_folder.joinpath("images", f"img_{sample_name}.png"))

shutil.copy(orig_train_path.joinpath("masks", "semantic", f"{sample_name}.tif"),
target_folder.joinpath("masks", "semantic", f"{sample_name}.png"))

shutil.copy(orig_train_path.joinpath("masks", "instance", f"{sample_name}.tif"),
target_folder.joinpath("masks", "instance", f"{sample_name}.png"))
27 changes: 27 additions & 0 deletions src/scripts/pannuke_split.py
@@ -0,0 +1,27 @@
import os
import pandas as pd
from bio_image_datasets.pannuke_dataset import PanNukeDataset

if __name__ == "__main__":
dataset = PanNukeDataset(local_path='/fast/AG_Kainmueller/data/pannuke_cp')
samples_names = dataset.get_sample_names()

# Create a list of (filename, split) tuples
split_entries = []

for sample_name in samples_names:
if 'fold3' in sample_name:
split_entries.append((sample_name, 'train'))
elif 'fold2' in sample_name:
split_entries.append((sample_name, 'test'))
else:
split_entries.append((sample_name, 'valid'))

print('n samples:', len(split_entries))

# Create DataFrame and save to CSV
df = pd.DataFrame(split_entries, columns=['sample_name', 'train_test_val_split'])
df = df.sort_values('sample_name') # Optional: sort by filename

# Save to CSV
df.to_csv('/fast/AG_Kainmueller/data/patho_foundation_model_bench_data/pannuke/train_test_val_split.csv', index=False)