Fix Protein initialisation #317 #318

Merged · 50 commits · May 11, 2023

Changes from 1 commit

Commits (50)
176d884
add PSW to nonstandard residues
a-r-j Apr 17, 2023
fa89a37
improve insertion and non-standard residue handling
a-r-j Apr 17, 2023
9855b9b
refactor chain selection
a-r-j Apr 17, 2023
f143719
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 17, 2023
3f3b3d9
remove unused verbosity arg
a-r-j Apr 17, 2023
09f05e5
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 17, 2023
b7475df
fix chain selection in tests
a-r-j Apr 17, 2023
2e0a371
Merge branch 'tensor_fixes' of https://www.github.com/a-r-j/graphein …
a-r-j Apr 17, 2023
d2c1808
fix chain selection in tutorial notebook
a-r-j Apr 17, 2023
fc332c6
fix notebook chain selection
a-r-j Apr 17, 2023
4a67851
fix chain selection typehint
a-r-j Apr 17, 2023
5f648d2
Update changelog
a-r-j Apr 17, 2023
ab26d78
Add NLW to non-standard residues
a-r-j Apr 17, 2023
a449bba
Merge branch 'tensor_fixes' of https://www.github.com/a-r-j/graphein …
a-r-j Apr 17, 2023
afc0f8b
add .ent support
a-r-j Apr 20, 2023
258c94d
add entry for construction from dataframe
a-r-j Apr 20, 2023
c9856ae
add missing stage arg
a-r-j Apr 20, 2023
9e1191a
improve obsolete mapping retrieving to include entries with no replac…
a-r-j Apr 20, 2023
17c38ab
Merge branch 'master' into tensor_fixes
a-r-j Apr 20, 2023
7bf4ff3
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 20, 2023
5af9e06
update changelog
a-r-j Apr 21, 2023
e00bdfb
add transforms to foldcomp datasets
a-r-j Apr 22, 2023
31018bc
fix jaxtyping syntax
a-r-j Apr 25, 2023
6e26455
Merge branch 'tensor_fixes' of https://www.github.com/a-r-j/graphein …
a-r-j Apr 25, 2023
3681714
Merge branch 'master' into tensor_fixes
a-r-j Apr 27, 2023
adbdbe1
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 27, 2023
50ac31b
Update changelog
a-r-j Apr 27, 2023
088ae02
fix double application of transforms
a-r-j Apr 27, 2023
fb684af
improve foldcomp data loading performance
a-r-j May 1, 2023
a543a75
Merge branch 'tensor_fixes' of https://www.github.com/a-r-j/graphein …
a-r-j May 1, 2023
a00e2be
Merge branch 'master' into tensor_fixes
a-r-j May 1, 2023
ccf0437
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 1, 2023
7939a82
remove unused imports
a-r-j May 1, 2023
d72abf9
remove unused imports
a-r-j May 1, 2023
8b551c7
linting
a-r-j May 1, 2023
86bedcf
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 1, 2023
685d3db
Update changelog
a-r-j May 1, 2023
bebc3c4
add B factors to FC parsing output
a-r-j May 2, 2023
c973422
Merge branch 'tensor_fixes' of https://www.github.com/a-r-j/graphein …
a-r-j May 2, 2023
828af29
bugfix to alpha & kappa angle embedding
a-r-j May 7, 2023
c986df0
Merge branch 'master' into tensor_fixes
a-r-j May 7, 2023
6c48878
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 7, 2023
fc7657e
update changelog
a-r-j May 7, 2023
7192613
handle selenocysteine in sidechain torsion angle computation
a-r-j May 10, 2023
6a31729
Merge branch 'tensor_fixes' of https://www.github.com/a-r-j/graphein …
a-r-j May 10, 2023
84fc3e4
fix protein data object initialisation #317
a-r-j May 10, 2023
9dcc1c7
Merge branch 'master' into protein_obj
a-r-j May 10, 2023
f5d1f26
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 10, 2023
6269d25
restore eq dunder
a-r-j May 10, 2023
d96d60f
update changelog
a-r-j May 10, 2023
improve foldcomp data loading performance
a-r-j committed May 1, 2023
commit fb684af6e5c21a21341071f591665b3de32361dd
172 changes: 140 additions & 32 deletions graphein/ml/datasets/foldcomp_dataset.py
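The substance of this commit is a switch from parsing each entry's PDB string with biopandas to building the tensor `Protein` directly from FoldComp's decompressed entry data (`residues` and `coordinates`), scattering the atom coordinates into a dense `(num_residues, 37, 3)` array. Below is a minimal sketch of that scattering step, assuming small stand-in versions of the `ATOM_MAP` and `ATOM_NUMBERING` tables used in the diff; the helper name is illustrative, not the PR's code.

```python
import numpy as np

# Stand-in lookup tables (truncated); the diff defines a full ATOM_MAP and
# imports ATOM_NUMBERING from graphein.protein.resi_atoms.
ATOM_MAP = {"ALA": ["N", "CA", "C", "O", "CB"], "GLY": ["N", "CA", "C", "O"]}
ATOM_NUMBERING = {"N": 0, "CA": 1, "C": 2, "O": 3, "CB": 4}

def entry_to_atom37(residues, flat_coords):
    """Scatter a flat (num_atoms, 3) coordinate array into (num_residues, 37, 3).

    Hypothetical helper illustrating the indexing used by fc_to_pyg in the diff.
    """
    atom_names, counts = [], []
    for res in residues:
        atom_names += ATOM_MAP[res]
        counts.append(len(ATOM_MAP[res]))
    res_idx = np.repeat(np.arange(len(residues)), counts)         # residue index per atom
    atom_idx = np.array([ATOM_NUMBERING[a] for a in atom_names])  # atom37 slot per atom
    coords = np.full((len(residues), 37, 3), 1e-5)                # same fill value as the diff
    coords[res_idx, atom_idx] = np.asarray(flat_coords, dtype=float)
    return coords

# e.g. entry_to_atom37(["ALA", "GLY"], np.zeros((9, 3))).shape == (2, 37, 3)
```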
@@ -9,19 +9,32 @@
import os
import random
import shutil
from io import StringIO
from pathlib import Path
from typing import Callable, Dict, Iterable, List, Optional, Union

import pandas as pd
from biopandas.pdb import PandasPdb
import numpy as np

#import pandas as pd
import torch

#from biopandas.pdb import PandasPdb
from biotite.structure.io.pdb import PDBFile
from loguru import logger as log
from sklearn.model_selection import train_test_split
from torch_geometric import transforms as T
from torch_geometric.data import Data, Dataset
from torch_geometric.loader import DataLoader
from tqdm import tqdm

from graphein.protein.resi_atoms import (
ATOM_NUMBERING,
RESI_THREE_TO_1,
STANDARD_AMINO_ACID_MAPPING_1_TO_3,
STANDARD_AMINO_ACIDS,
)
from graphein.protein.tensor import Protein
from graphein.protein.tensor.io import protein_to_pyg
from graphein.utils.dependencies import import_message

try:
@@ -67,6 +80,81 @@
GraphTransform = Callable[[Union[Data, Protein]], Union[Data, Protein]]


ATOM_MAP = {'MET': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'SD', 'CE'],
'ILE': ['N', 'CA', 'C', 'O', 'CB', 'CG1', 'CG2', 'CD1'],
'LEU': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD1', 'CD2'],
'ALA': ['N', 'CA', 'C', 'O', 'CB'],
'ASN': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'OD1', 'ND2'],
'PRO': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD'],
'ARG': ['N',
'CA',
'C',
'O',
'CB',
'CG',
'CD',
'NE',
'CZ',
'NH1',
'NH2'],
'HIS': ['N',
'CA',
'C',
'O',
'CB',
'CG',
'ND1',
'CD2',
'CE1',
'NE2'],
'GLU': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD', 'OE1', 'OE2'],
'TYR': ['N',
'CA',
'C',
'O',
'CB',
'CG',
'CD1',
'CD2',
'CE1',
'CE2',
'CZ',
'OH'],
'VAL': ['N', 'CA', 'C', 'O', 'CB', 'CG1', 'CG2'],
'LYS': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD', 'CE', 'NZ'],
'THR': ['N', 'CA', 'C', 'O', 'CB', 'OG1', 'CG2'],
'PHE': ['N',
'CA',
'C',
'O',
'CB',
'CG',
'CD1',
'CD2',
'CE1',
'CE2',
'CZ'],
'GLY': ['N', 'CA', 'C', 'O'],
'SER': ['N', 'CA', 'C', 'O', 'CB', 'OG'],
'GLN': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD', 'OE1', 'NE2'],
'ASP': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'OD1', 'OD2'],
'CYS': ['N', 'CA', 'C', 'O', 'CB', 'SG'],
'TRP': ['N',
'CA',
'C',
'O',
'CB',
'CG',
'CD1',
'CD2',
'NE1',
'CE2',
'CE3',
'CZ2',
'CZ3',
'CH2']}


class FoldCompDataset(Dataset):
def __init__(
self,
Expand Down Expand Up @@ -112,17 +200,15 @@ def __init__(
self.use_graphein = use_graphein
self.transform = transform

_database_files = [
"$db",
"$db.dbtype",
"$db.index",
"$db.lookup",
"$db.source",
]
self.database_files = [
f.replace("$db", self.database) for f in _database_files
self._database_files = [
f"{self.database}",
f"{self.database}.dbtype",
f"{self.database}.index",
f"{self.database}.lookup",
f"{self.database}.source",
]
self._get_indices()

super().__init__(
root=self.root, transform=self.transform, pre_transform=None # type: ignore
)
@@ -146,7 +232,7 @@ def processed_file_names(self):
def download(self):
"""Downloads foldcomp database if not already downloaded."""

if not all(os.path.exists(self.root / f) for f in self.database_files):
if not all(os.path.exists(self.root / f) for f in self._database_files):
log.info(f"Downloading FoldComp dataset {self.database}...")
try:
foldcomp.setup(self.database)
Expand All @@ -156,7 +242,7 @@ def download(self):
log.info("Download complete.")
log.info("Moving files to raw directory...")

for f in self.database_files:
for f in self._database_files:
shutil.move(f, self.root)
else:
log.info(f"FoldComp database already downloaded: {self.root}.")
@@ -203,23 +289,47 @@ def process(self):
# Open the database
log.info("Opening database...")
if self.ids is not None:
self.db = foldcomp.open(self.root / self.database, ids=self.ids) # type: ignore
self.db = foldcomp.open(self.root / self.database, ids=self.ids, decompress=False) # type: ignore
else:
self.db = foldcomp.open(self.root / self.database) # type: ignore
self.db = foldcomp.open(self.root / self.database, decompress=False) # type: ignore

@staticmethod
def _parse_dataframe(pdb_string: str) -> pd.DataFrame:
"""Reads a PDB string into a Pandas dataframe."""
pdb: List[str] = pdb_string.split("\n")
return PandasPdb().read_pdb_from_list(pdb).df["ATOM"]

def process_pdb(self, pdb_string: str, name: str) -> Union[Protein, Data]:
"""Process a PDB string into a Graphein Protein object."""
df = self._parse_dataframe(pdb_string)
data = Protein().from_dataframe(df, id=name)
if not self.use_graphein:
data = data.to_data()
return data
def fc_to_pyg(data: Dict, name: Optional[str] = None) -> Protein:
# Map sequence to 3-letter codes
res = [STANDARD_AMINO_ACID_MAPPING_1_TO_3[r] for r in data["residues"]]
residue_type = torch.tensor(
[STANDARD_AMINO_ACIDS.index(res) for res in data["residues"]],
)

# Get residue numbers
res_num = [i for i, _ in enumerate(res)]

# Get list of atom types
atom_types = []
atom_counts = []
for r in res:
atom_types += ATOM_MAP[r]
atom_counts.append(len(ATOM_MAP[r]))
atom_types += ["OXT"]
atom_counts[-1] += 1

# Get atom indices
atom_idx = np.array([ATOM_NUMBERING[atm] for atm in atom_types])

# Initialize coordinates
coords = np.ones((len(res), 37, 3)) * 1e-5

res_idx = np.repeat(res_num, atom_counts)
coords[res_idx, atom_idx, :] = np.array(data["coordinates"])

return Protein(
coords=torch.from_numpy(coords).float(),
residues=res,
residue_id=[f"A:{m}:{str(n)}" for m, n in zip(res, res_num)],
chains=torch.zeros(len(res)),
residue_type=residue_type.long(),
id=name
)

def len(self) -> int:
"""Returns length of the dataset"""
@@ -230,12 +340,10 @@ def get(self, idx) -> Union[Data, Protein]:
ID or its index."""
if isinstance(idx, str):
idx = self.protein_to_idx[idx]
name, pdb = self.db[idx]

out = self.process_pdb(pdb, name)

# Apply transforms, if any
return self.transform(out) if self.transform is not None else out
name = self.idx_to_protein[idx]
data = foldcomp.get_data(self.db[idx])
return self.fc_to_pyg(data, name)


class FoldCompLightningDataModule(L.LightningDataModule):
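For context, a hedged usage sketch of the dataset class touched by this commit. The argument names mirror attributes visible in the diff (`root`, `database`, `transform`); the database name is only an example, and the full constructor signature is not shown above.

```python
from graphein.ml.datasets.foldcomp_dataset import FoldCompDataset

# Illustrative only: "afdb_swissprot_v4" is an example FoldComp database name.
ds = FoldCompDataset(
    root="data/foldcomp",
    database="afdb_swissprot_v4",
    transform=None,  # optional transform, passed through to torch_geometric's Dataset
)

protein = ds[0]              # get() decodes one FoldComp entry via fc_to_pyg
print(protein.coords.shape)  # torch.Size([num_residues, 37, 3])
print(len(ds))               # number of entries indexed from the database
```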