# Source code for promptbind.utils.inference_pdb_utils
# Preprocess PDB files: clean residue lists, extract backbone coordinates and compute ESM embeddings.
from Bio.PDB import PDBParser
from Bio.PDB.PDBIO import PDBIO
from Bio.PDB.PDBIO import Select
import numpy as np
import esm
from tqdm import tqdm
import torch
# Mapping from the 20 standard three-letter residue names to one-letter codes.
# Residue names absent from this table are treated as non-standard elsewhere in this module.
three_to_one = dict(
    ALA='A', CYS='C', ASP='D', GLU='E', PHE='F', GLY='G', HIS='H',
    ILE='I', LYS='K', LEU='L', MET='M', ASN='N', PRO='P', GLN='Q',
    ARG='R', SER='S', THR='T', VAL='V', TRP='W', TYR='Y',
)
def get_protein_structure(res_list):
    """Build a GVP-style structure record from a sequence of residues.

    Adapted from the protein feature extraction code in
    https://github.com/drorlab/gvp-pytorch. Residues missing any of the
    backbone atoms N, CA, C or O are discarded first.

    Returns:
        dict with the input layout expected by ProteinGraphDataset:
        ``name`` (the fixed string "placeholder"), ``seq`` (one-letter
        sequence of the kept residues) and ``coords`` (for each kept residue,
        the coordinates of N, CA, C, O as lists, i.e. N * 4 * 3 nested lists).
    """
    backbone = ('N', 'CA', 'C', 'O')
    kept = [res for res in res_list if all(atom_name in res for atom_name in backbone)]
    return {
        'name': "placeholder",
        # NOTE(review): .get() yields None for names missing from three_to_one,
        # which makes join() raise TypeError; presumably callers pass residues
        # already filtered to standard amino acids.
        'seq': "".join(three_to_one.get(res.resname) for res in kept),
        'coords': [[list(res[atom_name].coord) for atom_name in backbone] for res in kept],
    }
def get_clean_res_list(res_list, verbose=False, ensure_ca_exist=False, bfactor_cutoff=None):
    """Filter residues down to standard, non-hetero amino acids.

    Args:
        res_list: Iterable of residues (e.g. ``structure.get_residues()``).
        verbose: Print a message for every residue dropped as hetero or as
            having a non-standard residue name.
        ensure_ca_exist: Drop residues that have no CA atom.
        bfactor_cutoff: If not None, drop residues whose CA B-factor is below
            this value. Residues without a CA atom are dropped as well in that
            case, because their B-factor cannot be checked.

    Returns:
        List of the residues that passed every filter, in input order.
    """
    clean_res_list = []
    for res in res_list:
        # Last element of full_id is (hetero flag, residue number, insertion code);
        # a single-space hetero flag marks a regular (non-hetero) residue.
        hetero, _resid, _insertion = res.full_id[-1]
        if hetero != ' ':
            if verbose:
                print(res, res.full_id, "is hetero")
            continue
        if res.resname not in three_to_one:
            if verbose:
                print(res, "has non-standard resname")
            continue
        # The B-factor filter reads the CA atom, so a residue without one can
        # neither satisfy ensure_ca_exist nor be checked against the cutoff
        # (previously this path raised KeyError on res['CA']).
        needs_ca = ensure_ca_exist or bfactor_cutoff is not None
        if needs_ca and 'CA' not in res:
            continue
        if bfactor_cutoff is not None and float(res['CA'].bfactor) < bfactor_cutoff:
            continue
        clean_res_list.append(res)
    return clean_res_list
def extract_protein_structure(path):
    """Parse the PDB file at ``path`` and return its backbone structure record.

    Residues are cleaned with ``get_clean_res_list`` (requiring a CA atom) and
    then converted with ``get_protein_structure``.
    """
    pdb_parser = PDBParser(QUIET=True)
    parsed = pdb_parser.get_structure("x", path)
    residues = get_clean_res_list(parsed.get_residues(), verbose=False, ensure_ca_exist=True)
    return get_protein_structure(residues)
def extract_esm_feature(protein, model=None, alphabet=None, repr_layer=None):
    """Compute per-residue ESM embeddings for ``protein['seq']``.

    Args:
        protein: Mapping with a ``'seq'`` entry holding the one-letter sequence
            (for example the dict built by ``get_protein_structure``).
        model: Optional pre-loaded ESM model. When omitted, ESM-2 650M
            (``esm2_t33_650M_UR50D``) is loaded on every call, which is slow.
            Pass a model to reuse it across proteins. Must be given together
            with ``alphabet``. It is switched to eval mode and run on the
            device of its parameters.
        alphabet: The ESM alphabet matching ``model``.
        repr_layer: Layer whose representations are returned. Defaults to
            ``model.num_layers`` (33 for the default model, i.e. the same layer
            that was previously hard-coded).

    Returns:
        Tensor whose first dimension equals ``len(protein['seq'])``, with one
        embedding per residue. The first and last token positions added by the
        batch converter are removed. The tensor lives on the model's device.

    Raises:
        ValueError: If only one of ``model`` / ``alphabet`` is supplied.
        RuntimeError: If the number of residue embeddings does not match the
            sequence length.
    """
    if (model is None) != (alphabet is None):
        raise ValueError("model and alphabet must be provided together")

    if model is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
        # Load ESM-2 model (smaller alternative: esm.pretrained.esm2_t6_8M_UR50D()).
        model, alphabet = esm.pretrained.esm2_t33_650M_UR50D()
        model.to(device)
    else:
        # Run on whatever device the caller already placed the model on.
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")
    model.eval()  # disables dropout for deterministic results

    if repr_layer is None:
        repr_layer = model.num_layers

    seq = protein['seq']
    batch_converter = alphabet.get_batch_converter()
    _labels, _strs, batch_tokens = batch_converter([("protein1", seq)])
    batch_tokens = batch_tokens.to(device)
    with torch.no_grad():
        results = model(batch_tokens, repr_layers=[repr_layer])
    # Take the single batch entry and strip the special tokens at both ends.
    token_representations = results["representations"][repr_layer][0][1:-1]
    if token_representations.shape[0] != len(seq):
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        raise RuntimeError(
            f"ESM returned {token_representations.shape[0]} residue embeddings "
            f"for a sequence of length {len(seq)}"
        )
    return token_representations