"""Source code for promptbind.utils.inference_pdb_utils: PDB preprocessing and ESM feature extraction for inference."""

# preprocess pdb file
import functools

import esm
import numpy as np
import torch
from Bio.PDB import PDBParser
from Bio.PDB.PDBIO import PDBIO
from Bio.PDB.PDBIO import Select
from tqdm import tqdm

# Standard three-letter amino-acid residue names paired, in order, with their
# one-letter codes (the one-letter column happens to be alphabetical).
three_to_one = dict(zip(
    ('ALA', 'CYS', 'ASP', 'GLU', 'PHE', 'GLY', 'HIS',
     'ILE', 'LYS', 'LEU', 'MET', 'ASN', 'PRO', 'GLN',
     'ARG', 'SER', 'THR', 'VAL', 'TRP', 'TYR'),
    'ACDEFGHIKLMNPQRSTVWY',
))

def get_protein_structure(res_list):
    """Build a GVP-style structure dict (name, seq, coords) from residues.

    Adapted from the protein feature extraction code in
    https://github.com/drorlab/gvp-pytorch. Residues missing any of the
    backbone atoms N, CA, C or O are dropped first. The result is the input
    expected by ProteinGraphDataset: a name, a one-letter sequence and an
    N x 4 x 3 nested list of backbone coordinates.

    Args:
        res_list: Iterable of Bio.PDB-style residues. Each residue must support
            ``atom_name in res`` and ``res[atom_name]``, and expose ``resname``.
            Each atom must expose ``coord``.

    Returns:
        dict with keys, in this order:
            ``'name'``: the constant string ``"placeholder"``;
            ``'seq'``: one-letter sequence of the kept residues;
            ``'coords'``: one entry per kept residue, each a list of four
            ``list(atom.coord)`` values in N, CA, C, O order.

    Raises:
        ValueError: if a kept residue's ``resname`` has no entry in
            ``three_to_one``. Previously this surfaced as an opaque TypeError
            from ``str.join`` on a ``None`` item.
    """
    backbone_atoms = ('N', 'CA', 'C', 'O')

    # Ensure every residue carries the full backbone.
    res_list = [res for res in res_list if all(atom in res for atom in backbone_atoms)]

    seq_letters = []
    for res in res_list:
        letter = three_to_one.get(res.resname)
        if letter is None:
            raise ValueError(f"residue {res!r} has non-standard resname {res.resname!r}")
        seq_letters.append(letter)

    structure = {}
    structure['name'] = "placeholder"
    structure['seq'] = "".join(seq_letters)
    structure['coords'] = [
        [list(res[atom].coord) for atom in backbone_atoms]
        for res in res_list
    ]
    return structure
def get_clean_res_list(res_list, verbose=False, ensure_ca_exist=False, bfactor_cutoff=None):
    """Keep only standard, non-hetero amino-acid residues.

    Args:
        res_list: Iterable of Bio.PDB-style residues.
        verbose: If True, print each residue dropped for being hetero or for
            having a resname outside ``three_to_one``.
        ensure_ca_exist: If True, drop residues that have no 'CA' atom.
        bfactor_cutoff: If not None, drop residues whose CA B-factor is below
            this value. Residues without a 'CA' atom are also dropped in that
            case, because the cutoff cannot be evaluated for them. Previously
            they raised KeyError when ``ensure_ca_exist`` was False.

    Returns:
        list of the kept residues, in input order.
    """
    clean_res_list = []
    for res in res_list:
        # full_id[-1] is the residue id tuple: (hetero flag, sequence number, insertion code).
        hetero, _resid, _insertion = res.full_id[-1]
        if hetero != ' ':
            if verbose:
                print(res, res.full_id, "is hetero")
            continue
        if res.resname not in three_to_one:
            if verbose:
                print(res, "has non-standard resname")
            continue

        has_ca = 'CA' in res
        if ensure_ca_exist and not has_ca:
            continue
        if bfactor_cutoff is not None:
            if not has_ca:
                continue
            if float(res['CA'].bfactor) < bfactor_cutoff:
                continue
        clean_res_list.append(res)
    return clean_res_list
def extract_protein_structure(path):
    """Parse a PDB file and return its GVP-style structure dict.

    Only the first model in the file is used. ``Structure.get_residues()``
    walks every model, so multi-model files (e.g. NMR ensembles) would
    otherwise have their residues concatenated several times. Residues are
    cleaned with ``get_clean_res_list`` (requiring a CA atom) before being
    passed to ``get_protein_structure``.

    Args:
        path: Path to the PDB file, passed to ``PDBParser.get_structure``.

    Returns:
        The dict produced by ``get_protein_structure``.
    """
    parser = PDBParser(QUIET=True)
    s = parser.get_structure("x", path)

    # Iterating a Structure yields its models; take the first, if any.
    first_model = next(iter(s), None)
    residues = first_model.get_residues() if first_model is not None else ()

    res_list = get_clean_res_list(residues, verbose=False, ensure_ca_exist=True)
    structure = get_protein_structure(res_list)
    return structure
@functools.lru_cache(maxsize=1)
def _load_esm2_650m(device):
    """Load ESM-2 (esm2_t33_650M_UR50D) on ``device`` in eval mode, plus its batch converter.

    The most recent (model, batch_converter) pair is cached, so repeated
    feature extraction does not reload the weights. The model stays resident
    until ``_load_esm2_650m.cache_clear()`` is called.
    """
    model, alphabet = esm.pretrained.esm2_t33_650M_UR50D()
    model.to(device)
    model.eval()  # disables dropout for deterministic results
    return model, alphabet.get_batch_converter()


def extract_esm_feature(protein):
    """Compute per-residue ESM-2 (650M) layer-33 representations for a protein.

    Args:
        protein: Mapping with a ``'seq'`` entry holding the amino-acid
            sequence string (e.g. the dict returned by
            ``extract_protein_structure``).

    Returns:
        torch.Tensor of shape ``(len(protein['seq']), D)``, where D is the
        model's embedding width. It lives on "cuda" when available, else "cpu".
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model, batch_converter = _load_esm2_650m(device)

    data = [("protein1", protein['seq'])]
    _labels, _strs, batch_tokens = batch_converter(data)
    batch_tokens = batch_tokens.to(device)

    with torch.no_grad():
        results = model(batch_tokens, repr_layers=[33])

    # Drop the first and last token positions (the BOS/EOS tokens the ESM-2
    # batch converter adds) so rows align one-to-one with residues.
    token_representations = results["representations"][33][0][1:-1]
    assert token_representations.shape[0] == len(protein['seq']), (
        f"ESM returned {token_representations.shape[0]} residue embeddings "
        f"for a sequence of length {len(protein['seq'])}"
    )
    return token_representations