Skip to content

Commit 86ca8f5

Browse files
authored
Create cluster_dataset.py (#63)
1 parent 0d4cba5 commit 86ca8f5

File tree

1 file changed

+50
-0
lines changed

1 file changed

+50
-0
lines changed
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
from typing import Callable
2+
from torch.utils.data import Dataset, IterableDataset
3+
import pandas as pd
4+
import os
5+
import numpy as np
6+
from Bio import SeqIO
7+
8+
9+
class ClusterDataset(IterableDataset):
    """Iterable dataset that yields random sequences drawn from FASTA clusters.

    Every draw first picks a cluster according to a per-cluster sampling
    probability and then picks one sequence uniformly at random from that
    cluster's FASTA file.  One pass over the dataset yields ``len(self)``
    sequences, i.e. as many draws as there are rows in the cluster table.

    The cluster table is a CSV file with (at least) the columns
    ``cluster_name`` and ``cluster_size``.  Cluster files are looked up as:

    * ``<dataset_path>/unk/unk.fasta`` for the cluster named ``"unk"``;
    * ``<dataset_path>/<N>000/<cluster_name>.fasta`` otherwise, where ``N``
      is ``int(cluster_name) // 1000``.

    The class derives from ``IterableDataset`` because it only implements
    ``__iter__``; a map-style ``Dataset`` would require ``__getitem__``.

    NOTE(review): the random generator lives on the instance, so copies of
    the dataset (e.g. one per ``DataLoader`` worker process) presumably
    start from the same state and produce the same stream -- confirm before
    using ``num_workers > 0``.

    Args:
        dataset_path: Root directory containing the cluster sub-directories.
        cluster_table_path: Path to the CSV cluster table.
        size_to_sample_prob: Maps a cluster size to an unnormalised sampling
            weight.  The weights are normalised to sum to one.
        seed: Seed for the NumPy random generator driving both sampling
            steps.
    """

    def __init__(
        self,
        dataset_path: str,
        cluster_table_path: str,
        size_to_sample_prob: Callable[[int], float] = lambda x: x,
        seed: int = 42,
    ) -> None:
        super().__init__()
        self.dataset_path = dataset_path
        self.cluster_table_path = cluster_table_path
        # Cache: FASTA path -> list of sequence strings parsed from that file.
        self.cluster_to_seqs = {}
        self.cluster_table = pd.read_csv(
            cluster_table_path,
            dtype={'cluster_name': str, 'cluster_size': int},
        )
        weights = self.cluster_table['cluster_size'].apply(size_to_sample_prob)
        # Normalise so the column can be used directly as sampling weights.
        self.cluster_table['sample_prob'] = weights / weights.sum()
        self.generator = np.random.default_rng(seed)

    def __len__(self) -> int:
        """Return the number of clusters, which is also the draws per pass."""
        return len(self.cluster_table)

    def _cluster_path(self, cluster_name: str) -> str:
        """Return the FASTA file path that stores ``cluster_name``."""
        if cluster_name == "unk":
            return os.path.join(self.dataset_path, "unk", "unk.fasta")
        # Numeric clusters are grouped into directories per thousand ids.
        cluster_dir = f"{int(cluster_name) // 1000}000"
        return os.path.join(
            self.dataset_path, cluster_dir, f"{cluster_name}.fasta"
        )

    def get_cluster_seqs(self, cluster_path: str) -> list:
        """Return the sequences stored in ``cluster_path`` as strings.

        The file is parsed on first access only; later calls are served from
        ``self.cluster_to_seqs``.
        """
        seqs = self.cluster_to_seqs.get(cluster_path)
        if seqs is None:
            seqs = [
                str(record.seq)
                for record in SeqIO.parse(cluster_path, 'fasta')
            ]
            self.cluster_to_seqs[cluster_path] = seqs
        return seqs

    def __iter__(self):
        """Yield ``len(self)`` randomly sampled sequence strings.

        Raises:
            ValueError: If a sampled cluster file contains no sequences.
        """
        for _ in range(len(self)):
            sampled_row = self.cluster_table.sample(
                n=1, weights='sample_prob', random_state=self.generator
            )
            cluster_name = sampled_row['cluster_name'].iloc[0]
            cluster_path = self._cluster_path(cluster_name)
            seqs = self.get_cluster_seqs(cluster_path)
            if not seqs:
                # Without this guard ``integers(0)`` fails with an opaque
                # "high <= 0" error that does not name the offending file.
                raise ValueError(
                    f"Cluster file {cluster_path!r} contains no sequences"
                )
            yield seqs[self.generator.integers(len(seqs))]

0 commit comments

Comments
 (0)