Commits (53)
62ce27f
update
bertranMiquel May 13, 2024
e8a35dc
First ring-based commit
bertranMiquel Jun 13, 2024
1385b4e
Merge remote-tracking branch 'upstream/main'
bertranMiquel Jun 13, 2024
82c6b25
Readme update
bertranMiquel Jun 13, 2024
4be4aa6
mol ring lifting correctation
bertranMiquel Jun 13, 2024
02ea2bf
Correct more errors
bertranMiquel Jun 13, 2024
a83f484
Fix lifting details
bertranMiquel Jun 13, 2024
cd936f7
Hope it is the last one
bertranMiquel Jun 13, 2024
431e246
Deleting other test since they are using other test dataset
bertranMiquel Jun 13, 2024
ae365b5
git ignore modification
bertranMiquel Jun 13, 2024
b201b54
minor modifications
bertranMiquel Jun 14, 2024
917799b
Update ring_lifting.py
bertranMiquel Jun 14, 2024
a7534a8
Ring lifting modifications
bertranMiquel Jun 15, 2024
b884a92
Update dependencies
bertranMiquel Jun 18, 2024
1da64be
Updates due to warnings
bertranMiquel Jun 20, 2024
2f35444
Reload tests
bertranMiquel Jun 20, 2024
9ab1d8a
Trail spaces
bertranMiquel Jun 20, 2024
59735bd
Changes in test manual data.
bertranMiquel Jun 20, 2024
2e30876
Remove space
bertranMiquel Jun 20, 2024
28e7b74
Refine test.
bertranMiquel Jun 20, 2024
704e2e1
Remove notebook
bertranMiquel Jun 20, 2024
a07528f
correct torch in test
bertranMiquel Jun 20, 2024
db73c5e
Adding attributes functions
bertranMiquel Jul 1, 2024
0bbc8cd
Adding attributes
bertranMiquel Jul 1, 2024
0783baf
Load UniProt data
bertranMiquel Jul 3, 2024
29608e8
Trying lifting stuff
bertranMiquel Jul 4, 2024
456dddc
Lifting try
bertranMiquel Jul 5, 2024
06b2084
All developed
bertranMiquel Jul 8, 2024
7a7d2df
Comments updated
bertranMiquel Jul 8, 2024
cf7ffb7
Update readme
bertranMiquel Jul 8, 2024
530a198
Remove random package
bertranMiquel Jul 8, 2024
e1f83a3
update size parameter
bertranMiquel Jul 8, 2024
44e083d
Change size parameter
bertranMiquel Jul 8, 2024
c406216
Change attributes to mass function name
bertranMiquel Jul 8, 2024
7c46cb9
Update import class
bertranMiquel Jul 8, 2024
9cc415a
Update import class
bertranMiquel Jul 8, 2024
d1f49f6
Add parameter to test
bertranMiquel Jul 8, 2024
0fbae67
Finish test data
bertranMiquel Jul 8, 2024
932b9b3
Updata load manual data
bertranMiquel Jul 9, 2024
fbecc31
Merge branch 'uniprot_knn' of github.com:bertranMiquel/challenge-icml…
bertranMiquel Jul 9, 2024
48aeaa4
Pointcloud to graph lifting. First try.
bertranMiquel Jul 9, 2024
78c91fe
Rectify test data import
bertranMiquel Jul 9, 2024
f924f79
Rectify tests
bertranMiquel Jul 9, 2024
9ea1b6b
Rectify test results
bertranMiquel Jul 9, 2024
84cfe5b
clone vectors in angle
bertranMiquel Jul 9, 2024
fbf6a7b
Minor corrections
bertranMiquel Jul 9, 2024
c85b1b8
Correct last assert
bertranMiquel Jul 9, 2024
bd7d0d8
clone to copy adjust
bertranMiquel Jul 9, 2024
860c0e2
ruff import adjust
bertranMiquel Jul 9, 2024
30025fb
Add feature_lifting
bertranMiquel Jul 9, 2024
dea738e
Remove cell
bertranMiquel Jul 9, 2024
ebbe527
Last try
bertranMiquel Jul 9, 2024
4ff6054
Clean tutorials
bertranMiquel Jul 9, 2024
20 changes: 20 additions & 0 deletions configs/datasets/UniProt.yaml
@@ -0,0 +1,20 @@
data_domain: pointcloud
data_type: UniProt
data_name: UniProt
data_dir: datasets/${data_domain}/${data_type}
#data_split_dir: ${oc.env:PROJECT_ROOT}/datasets/data_splits/${data_name}

# Query parameters for the UniProt REST API
query: "length:[95 TO 155]" # number of residues per protein
format: "tsv"
fields: "accession,length"
size: 20 # number of proteins to load

# Dataset parameters
num_features: 20
num_classes: 1
task: regression
loss_type: mse
monitor_metric: mae
task_level: graph
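For reference, a minimal sketch of the search request these parameters drive, mirroring fetch_uniprot_ids in modules/data/load/loaders.py below (the endpoint and parameter names come from that loader; the printout is illustrative):

import requests

# Query parameters taken from this config file
params = {
    "query": "length:[95 TO 155]",  # residue-count filter
    "format": "tsv",
    "fields": "accession,length",
    "size": 20,                     # number of proteins to load
}
response = requests.get("https://rest.uniprot.org/uniprotkb/search", params=params)
print(response.text.splitlines()[0])  # first line is the TSV header row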

12 changes: 12 additions & 0 deletions configs/datasets/manual_prot_pointcloud.yaml
@@ -0,0 +1,12 @@
data_domain: pointcloud
data_type: toy_dataset
data_name: manual_prot
data_dir: datasets/${data_domain}/${data_type}

# Dataset parameters
num_features: 1
num_classes: 2
task: classification
loss_type: cross_entropy
monitor_metric: accuracy
task_level: node
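As a rough illustration (hedged: load_manual_prot_pointcloud itself is not part of this diff, so the exact contents are assumptions), a toy Data object matching this config might look like:

import torch
import torch_geometric

num_points = 8
data = torch_geometric.data.Data(
    x=torch.ones(num_points, 1),           # num_features: 1
    pos=torch.rand(num_points, 3),         # 3D point positions
    y=torch.randint(0, 2, (num_points,)),  # num_classes: 2, task_level: node
)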
6 changes: 6 additions & 0 deletions configs/models/graph/graphsage.yaml
@@ -0,0 +1,6 @@
in_channels_0: null # This will be set by the dataset
in_channels_1: null # This will be set by the dataset
in_channels_2: null # This will be set by the dataset
hidden_channels: 32
out_channels: null # This will be set by the dataset
n_layers: 2
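One hedged way these hyperparameters could map onto PyTorch Geometric's built-in GraphSAGE model (the model class this repository actually wires up is not shown in this diff; the channel values below are illustrative, since the null entries are filled in from the dataset):

from torch_geometric.nn import GraphSAGE

model = GraphSAGE(
    in_channels=20,      # in_channels_0, set by the dataset (e.g. one-hot residue types)
    hidden_channels=32,  # hidden_channels
    num_layers=2,        # n_layers
    out_channels=1,      # out_channels, set by the dataset (e.g. protein mass)
)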
8 changes: 8 additions & 0 deletions configs/transforms/liftings/pointcloud2graph/knn_lifting.yaml
@@ -0,0 +1,8 @@
transform_type: 'lifting'
transform_name: "PointCloudKNNLifting"
max_cell_length: null
preserve_edge_attr: False
feature_lifting: ProjectionSum

k_value: 10
loop: False
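The core operation behind this lifting is a k-nearest-neighbour graph over the point positions. A minimal sketch (the PointCloudKNNLifting class itself is not included in this diff; k and loop correspond to k_value and loop above):

import torch
from torch_geometric.nn import knn_graph

pos = torch.rand(100, 3)  # e.g. alpha-carbon coordinates
edge_index = knn_graph(pos, k=10, loop=False)  # one edge per point-neighbour pair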
296 changes: 296 additions & 0 deletions modules/data/load/loaders.py
@@ -1,8 +1,12 @@
import os
import random

import numpy as np
import requests
import rootutils
import torch
import torch_geometric
from Bio import PDB
from omegaconf import DictConfig

from modules.data.load.base import AbstractLoader
@@ -12,6 +16,7 @@
load_cell_complex_dataset,
load_hypergraph_pickle_dataset,
load_manual_graph,
load_manual_prot_pointcloud,
load_simplicial_dataset,
)

@@ -204,3 +209,294 @@ def load(
torch_geometric.data.Dataset object containing the loaded data.
"""
return load_hypergraph_pickle_dataset(self.parameters)


class PointCloudLoader(AbstractLoader):

def __init__(self, parameters: DictConfig):
super().__init__(parameters)
self.parameters = parameters

#######################################################################
############## Auxiliary functions for loading UniProt data ###########
#######################################################################

def fetch_uniprot_ids(self) -> list[dict]:
r"""Fetch UniProt IDs by its API under the parameters specified in the configuration file."""
query_url = "https://rest.uniprot.org/uniprotkb/search"
params = {
"query": self.parameters.query,
"format": self.parameters.format,
"fields": self.parameters.fields,
"size": self.parameters.size
}

response = requests.get(query_url, params=params)
if response.status_code != 200:
print(f"Failed to fetch data from UniProt. Status code: {response.status_code}")
return []

data = response.text.strip().split("\n")[1:]
proteins = [{"uniprot_id": row.split("\t")[0], "sequence_length": int(row.split("\t")[1])} for row in data]

# Ensure we have at least the required proteins to sample from
if len(proteins) >= self.parameters.size:
sampled_proteins = random.sample(proteins, self.parameters.size)
else:
print(f"Only found {len(proteins)} proteins within the specified length range. Returning all available proteins.")
sampled_proteins = proteins

        # Save the sampled proteins to a CSV file, creating the directory if needed
        os.makedirs(self.data_dir, exist_ok=True)
        with open(self.data_dir + "/uniprot_ids.csv", "w") as file:
            file.write("uniprot_id,sequence_length\n")
            for protein in sampled_proteins:
                file.write(f"{protein['uniprot_id']},{protein['sequence_length']}\n")

return sampled_proteins

def fetch_protein_mass(
self, uniprot_id : str
) -> float:
r"""Returns the mass of a protein given its UniProt ID.
This will be used as our target variable.

Parameters
----------
uniprot_id : str
The UniProt ID of the protein.

Returns
-------
float
The mass of the protein.
"""
url = f"https://www.ebi.ac.uk/proteins/api/proteins/{uniprot_id}"
response = requests.get(url, headers={"Accept": "application/json"})
if response.status_code == 200:
data = response.json()
return data.get("sequence", {}).get("mass")
return None

def fetch_alphafold_structure(
self, uniprot_id : str
) -> str:
r"""Fetches the AlphaFold structure for a given UniProt ID.
Not all the proteins have a structure available.
This ones will be descarded.

Parameters
----------
uniprot_id : str
The UniProt ID of the protein.

Returns
-------
str
The path to the downloaded PDB file.
"""
pdb_dir = self.data_dir + "/pdbs"
os.makedirs(pdb_dir, exist_ok=True)
file_path = os.path.join(pdb_dir, f"{uniprot_id}.pdb")

if os.path.exists(file_path):
print(f"PDB file for {uniprot_id} already exists.")
else:
url = f"https://alphafold.ebi.ac.uk/files/AF-{uniprot_id}-F1-model_v4.pdb"
response = requests.get(url)
if response.status_code == 200:
with open(file_path, "w") as file:
file.write(response.text)
print(f"PDB file for {uniprot_id} downloaded successfully.")
else:
print(f"Failed to fetch the structure for {uniprot_id}. Status code: {response.status_code}")
return None
return file_path

def parse_pdb(
self, file_path : str
) -> PDB.Structure:
r"""Parse a PDB file and return a BioPython structure object.

Parameters
----------
file_path : str
The path to the PDB file.

Returns
-------
PDB.Structure
The BioPython structure object.
"""

return PDB.PDBParser(QUIET=True).get_structure("alphafold_structure", file_path)

def residue_mapping(
self, uniprot_ids : list[str]
) -> dict:
r"""Create a mapping of residue types to unique integers.
Each residue type will be represented as a one unique integer.
There are 20 standard amino acids, so we will have 20 unique integers (at maximum).

Parameters
----------
uniprot_ids : list[str]
The list of UniProt IDs to process.

Returns
-------
dict
The mapping of residue types to unique integers.
"""

residue_map = {}
residue_counter = 0

# First pass: determine unique residue types
for uniprot_id in uniprot_ids:
pdb_file = self.fetch_alphafold_structure(uniprot_id)
if pdb_file:
structure = self.parse_pdb(pdb_file)
residues = [residue for model in structure for chain in model for residue in chain]
for residue in residues:
residue_type = residue.get_resname()
if residue_type not in residue_map:
residue_map[residue_type] = residue_counter
residue_counter += 1
return residue_map

def calculate_residue_ca_distances_and_vectors(
self, structure : PDB.Structure
):
r"""Calculate the distances between the alpha carbon atoms of the residues.
Also, calculate the vectors between the alpha carbon and beta carbon atoms of each residue.

Parameters
----------
structure : PDB.Structure
The BioPython structure object.

Returns
-------
list
The list of residues.
dict
The dictionary of alpha carbon coordinates.
dict
The dictionary of beta carbon vectors.
np.ndarray
The matrix of distances between the residues.
"""

residues = [residue for model in structure for chain in model for residue in chain]
ca_coordinates = {}
cb_vectors = {}
residue_keys = []

for residue in residues:
if "CA" in residue:
ca_coord = residue["CA"].get_coord()
residue_type = residue.get_resname()
residue_number = residue.get_id()[1]
key = f"{residue_type}_{residue_number}"
ca_coordinates[key] = ca_coord
cb_vectors[key] = residue["CB"].get_coord() - ca_coord if "CB" in residue else None
residue_keys.append(key)

return ca_coordinates, cb_vectors, residue_keys

    def save_point_cloud(self, ca_coordinates, cb_vectors, file_path):
        r"""Write the point cloud (CA coordinates and CB vectors) to a CSV file."""
        data = []
        for key, ca_coord in ca_coordinates.items():
            cb_vector = cb_vectors.get(key)
            if cb_vector is None:
                cb_vector = np.zeros(3)
data.append({
"residue_id": key,
"x": ca_coord[0],
"y": ca_coord[1],
"z": ca_coord[2],
"cb_x": cb_vector[0],
"cb_y": cb_vector[1],
"cb_z": cb_vector[2]
})

        # Write the point cloud as CSV
        os.makedirs(os.path.dirname(file_path), exist_ok=True)
with open(file_path, "w") as file:
file.write("residue_id,x,y,z,cb_x,cb_y,cb_z\n")
for row in data:
file.write(f"{row['residue_id']},{row['x']},{row['y']},{row['z']},{row['cb_x']},{row['cb_y']},{row['cb_z']}\n")


def load(
self,
) -> torch_geometric.data.Dataset:
r"""Load point cloud dataset.

Returns
-------
torch_geometric.data.Dataset
torch_geometric.data.Dataset object containing the loaded data.
"""

root_folder = rootutils.find_root()
root_data_dir = os.path.join(root_folder, self.parameters["data_dir"])

self.data_dir = os.path.join(root_data_dir, self.parameters["data_name"])
if self.parameters.data_name in ["UniProt"]:
datasets = []
protein_data = self.fetch_uniprot_ids()
uniprot_ids = [protein["uniprot_id"] for protein in protein_data]
residue_map = self.residue_mapping(uniprot_ids)

for uniprot_id in uniprot_ids:
pdb_file = self.fetch_alphafold_structure(uniprot_id)
y = self.fetch_protein_mass(uniprot_id)

if pdb_file and y:
structure = self.parse_pdb(pdb_file)
ca_coordinates, cb_vectors, residue_keys = self.calculate_residue_ca_distances_and_vectors(structure)
point_cloud_file = os.path.join(self.data_dir, "point_cloud", f"{uniprot_id}.csv")
self.save_point_cloud(ca_coordinates, cb_vectors, point_cloud_file)

# Create one-hot residues
one_hot_residues = []
for res_id in residue_keys:
res_type = res_id.split("_")[0]
one_hot = torch.zeros(len(residue_map))
one_hot[residue_map[res_type]] = 1
one_hot_residues.append(one_hot)

x = torch.stack(one_hot_residues)
pos_np = np.array([ca_coordinates[res_id] for res_id in residue_keys])
pos = torch.tensor(pos_np, dtype=torch.float)

node_attr = [None if cb_vectors[res_id] is None else cb_vectors[res_id] for res_id in residue_keys]

data = torch_geometric.data.Data(
x=x,
pos=pos,
node_attr=node_attr,
y=y,
uniprot_id=uniprot_id
)

datasets.append(data)

dataset = CustomDataset(datasets, self.data_dir)

elif self.parameters.data_name in ["manual_prot"]:
data = load_manual_prot_pointcloud()
dataset = CustomDataset([data], self.data_dir)
else:
raise NotImplementedError(
f"Dataset {self.parameters.data_name} not implemented"
)
return dataset
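A hedged end-to-end usage sketch, assuming the UniProt.yaml config above is loaded into a DictConfig (key names match what the loader reads; the literal values are copied from that config, with the data_dir interpolation resolved by hand):

from omegaconf import OmegaConf

cfg = OmegaConf.create({
    "data_domain": "pointcloud",
    "data_type": "UniProt",
    "data_name": "UniProt",
    "data_dir": "datasets/pointcloud/UniProt",
    "query": "length:[95 TO 155]",
    "format": "tsv",
    "fields": "accession,length",
    "size": 20,
})
loader = PointCloudLoader(cfg)
dataset = loader.load()  # one Data object per protein with an AlphaFold structure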
