|
| 1 | +from typing import Callable |
| 2 | +from torch.utils.data import Dataset |
| 3 | +import pandas as pd |
| 4 | +import os |
| 5 | +import numpy as np |
| 6 | +from Bio import SeqIO |
| 7 | + |
| 8 | + |
| 9 | +class ClusterDataset(Dataset): |
| 10 | + def __init__( |
| 11 | + self, |
| 12 | + dataset_path: str, |
| 13 | + cluster_table_path: str, |
| 14 | + size_to_sample_prob: Callable = lambda x: x, |
| 15 | + seed: int = 42, |
| 16 | + ) -> None: |
| 17 | + super().__init__() |
| 18 | + self.dataset_path = dataset_path |
| 19 | + self.cluster_table_path = cluster_table_path |
| 20 | + self.cluster_to_seqs = {} |
| 21 | + self.cluster_table = pd.read_csv( |
| 22 | + cluster_table_path, dtype={'cluster_name': str, 'cluster_size': int} |
| 23 | + ) |
| 24 | + self.cluster_table['sample_prob'] = self.cluster_table['cluster_size'].apply(size_to_sample_prob) |
| 25 | + self.cluster_table['sample_prob'] /= self.cluster_table['sample_prob'].sum() |
| 26 | + self.generator = np.random.default_rng(seed) |
| 27 | + |
| 28 | + def __len__(self) -> int: |
| 29 | + return len(self.cluster_table) |
| 30 | + |
| 31 | + def get_cluster_seqs(self, cluster_path: str) -> list: |
| 32 | + if cluster_path not in self.cluster_to_seqs: |
| 33 | + self.cluster_to_seqs[cluster_path] = [ |
| 34 | + str(x.seq) for x in SeqIO.parse(cluster_path, 'fasta') |
| 35 | + ] |
| 36 | + return self.cluster_to_seqs[cluster_path] |
| 37 | + |
| 38 | + def __iter__(self): |
| 39 | + for _ in range(len(self)): |
| 40 | + cluster_name = self.cluster_table.sample( |
| 41 | + n=1, weights='sample_prob', random_state=self.generator |
| 42 | + )[['cluster_name']].values[0][0] |
| 43 | + # Now we map cluster_name to the folder it is in |
| 44 | + if cluster_name == "unk": |
| 45 | + cluster_path = os.path.join(self.dataset_path, "unk", "unk.fasta") |
| 46 | + else: |
| 47 | + cluster_dir = f"{int(cluster_name) // 1000}000" |
| 48 | + cluster_path = os.path.join(self.dataset_path, cluster_dir, f"{cluster_name}.fasta") |
| 49 | + seqs = self.get_cluster_seqs(cluster_path) |
| 50 | + yield seqs[self.generator.integers(len(seqs))] |
0 commit comments