Source code for miniworld.utils.ProteinClass

import numpy as np
import torch
import os
from collections import OrderedDict
import pickle
import string
import gzip

from miniworld.utils.chemical import AA2num, num2AA, aa2num, aa2long

class ProteinStructure():
    """Store the 3D coordinates of a protein.

    Shapes::

        M : number of models
        L : number of residues
        xyz           : (M, L, 14, 3) np.ndarray, list or torch.Tensor. 14 is the number of
                        heavy-atom slots per residue. A 3-D (L, 14, 3) input becomes M = 1.
        atom_mask     : (M, L, 14) np.ndarray, list or torch.Tensor. A 2-D (L, 14) input
                        becomes M = 1.
        position_mask : (M, L) np.ndarray, list or torch.Tensor, 1 for a confident position
                        and 0 for a missing position. Defaults to all ones; a 1-D (L,) input
                        becomes M = 1.
        chain_break   : dict / OrderedDict mapping chain ID to an inclusive (start, end)
                        residue range, e.g. {"A": (0, 197), "B": (198, 356)}.

    xyz, atom_mask and position_mask are stored as torch.Tensor.
    """

    def __init__(self,
                 xyz,
                 chain_break,
                 atom_mask,
                 position_mask=None,
                 has_multiple_chains=False,
                 has_multiple_models=False,
                 ):
        assert isinstance(xyz, (np.ndarray, list, torch.Tensor)), "xyz must be np.ndarray, list or torch.Tensor"
        assert isinstance(atom_mask, (np.ndarray, list, torch.Tensor)), "atom_mask must be np.ndarray, list or torch.Tensor"
        assert isinstance(position_mask, (np.ndarray, list, torch.Tensor)) or position_mask is None, "position_mask must be np.ndarray, list or torch.Tensor"

        if isinstance(xyz, (np.ndarray, list)):
            xyz = torch.tensor(xyz)
        if isinstance(atom_mask, (np.ndarray, list)):
            atom_mask = torch.tensor(atom_mask)
        # Normalise position_mask like the other arrays so it is always a tensor.
        if isinstance(position_mask, (np.ndarray, list)):
            position_mask = torch.tensor(position_mask)

        # Promote single-model inputs to a leading model axis.
        if len(xyz.shape) == 3:
            xyz = torch.unsqueeze(xyz, 0)  # (1, L, 14, 3)
        if len(atom_mask.shape) == 2:
            atom_mask = torch.unsqueeze(atom_mask, 0)  # (1, L, 14)
        if position_mask is None:
            position_mask = torch.ones((xyz.shape[0], xyz.shape[1]))  # (M, L)
        elif len(position_mask.shape) == 1:
            position_mask = torch.unsqueeze(position_mask, 0)  # (1, L)

        self.xyz = xyz  # (M, L, 14, 3) torch.Tensor
        self.atom_mask = atom_mask  # (M, L, 14) torch.Tensor
        self.position_mask = position_mask  # (M, L) torch.Tensor
        self.chain_break = chain_break  # chain ID -> inclusive (start, end)
        self.has_multiple_chains = has_multiple_chains  # True if this protein has multiple chains
        self.has_multiple_models = has_multiple_models  # True if this protein has multiple models

    def __repr__(self):
        return f'xyz shape : {self.xyz.shape}, mask shape : {self.atom_mask.shape}, chain_break : {self.chain_break}'

    def _chain_range(self, chain_id):
        """Return the inclusive (start, end) residue range of `chain_id`.

        Raises:
            KeyError: if `chain_id` is not in chain_break.
        """
        if chain_id not in self.chain_break:
            raise KeyError("Chain ID {} is not in this structure".format(chain_id))
        bounds = self.chain_break[chain_id]
        return bounds[0], bounds[1]

    def get_chain_length_by_ID(self, chain_id):
        """Return the number of residues in `chain_id` (its range is inclusive at both ends)."""
        chain_start, chain_end = self._chain_range(chain_id)
        return chain_end - chain_start + 1

    def get_structure_by_model(self, model_id):
        """Return (xyz, atom_mask, position_mask, chain_break) of one model.

        chain_break is normally shared by all models (a single dict). If it was given as a
        per-model list, the entry for `model_id` is returned.
        """
        assert self.has_multiple_models or model_id == 0, "This protein has only one model"
        chain_break = self.chain_break
        if isinstance(chain_break, (list, tuple)):
            chain_break = chain_break[model_id]
        return (self.xyz[model_id],
                self.atom_mask[model_id],
                self.position_mask[model_id],
                chain_break)

    def get_chain_structure_by_ID(self, chain_id):
        """Return (xyz, atom_mask) of `chain_id` for all models: (M, L_chain, 14, 3), (M, L_chain, 14).

        Raises:
            KeyError: if `chain_id` is not in chain_break.
        """
        chain_start, chain_end = self._chain_range(chain_id)
        # chain_break ranges are inclusive, hence the +1 on the slice end.
        chain_xyz = self.xyz[:, chain_start:chain_end + 1]
        chain_mask = self.atom_mask[:, chain_start:chain_end + 1]
        return chain_xyz, chain_mask
class ProteinSequence():
    """Store the amino-acid sequence of a protein.

    Attributes::

        M : number of models
        L : number of residues
        sequence        : (M, L) torch.Tensor of residue indices. A str is encoded with AA2num,
                          and a 1-D tensor is promoted to (1, L). Passing a torch.Tensor is
                          recommended.
        masked_sequence : (M, L) torch.Tensor or str; defaults to `sequence`. The full sequence
                          has no masked positions; the masked one may.
        chain_break     : dict / OrderedDict mapping chain ID to an inclusive (start, end)
                          residue range, e.g. {"A": (0, 197), "B": (198, 356)}. Defaults to a
                          single chain "A" covering the whole sequence.
    """

    def __init__(self,
                 sequence,
                 masked_sequence=None,
                 chain_break=None,
                 ):
        assert isinstance(sequence, (str, torch.Tensor)), "sequence must be string or torch.Tensor"
        assert isinstance(masked_sequence, (str, torch.Tensor)) or masked_sequence is None, "masked_sequence must be string or torch.Tensor"

        if isinstance(sequence, str):
            sequence = torch.tensor([AA2num[aa] for aa in sequence])  # (L,)
        if sequence.dim() == 1:
            # Promote a single sequence to (1, L) so the model axis always exists.
            sequence = torch.unsqueeze(sequence, 0)
        self.sequence = sequence  # (M, L) torch.Tensor

        self.masked_sequence = sequence if masked_sequence is None else masked_sequence

        if chain_break is not None:
            self.chain_break = chain_break
        else:
            self.chain_break = OrderedDict()
            self.chain_break["A"] = (0, self.get_sequence_length() - 1)

    def get_sequence_length(self):
        """Return the number of residues L (last axis of the tensor, or len of a str)."""
        if isinstance(self.sequence, str):
            return len(self.sequence)
        elif isinstance(self.sequence, torch.Tensor):
            return self.sequence.shape[-1]

    def __repr__(self):
        if isinstance(self.sequence, str):
            return str(f'{self.sequence}')
        elif isinstance(self.sequence, torch.Tensor):
            return str(f'shape : {self.sequence.shape}')
        else:
            raise TypeError("sequence must be string or torch.Tensor")

    def get_sequence_by_chain_ID(self, chain_id):
        """Return the residues of `chain_id` for all models, shape (M, L_chain).

        Raises:
            KeyError: if `chain_id` is not in chain_break.
        """
        if chain_id not in self.chain_break:
            raise KeyError("Chain ID {} is not in this sequence".format(chain_id))
        chain_start = self.chain_break[chain_id][0]
        chain_end = self.chain_break[chain_id][1]
        if isinstance(self.sequence, str):
            return self.sequence[chain_start:chain_end + 1]
        # Slice the residue axis (not the model axis); ranges are inclusive.
        return self.sequence[:, chain_start:chain_end + 1]
class Protein():
    """Bundle a sequence, a structure and metadata for one protein entry.

    Args:
        sequence: ProteinSequence, str or None.
        structure: ProteinStructure or None. Its chain_break drives chain lookups; when it is
            None, the chain_break of a ProteinSequence is used instead (or an empty one).
        occupancy: stored as given.
        symmetry_related_info: dict or None, stored as given. __repr__ expects each value to be
            an iterable of dicts.
        ID: identifier; must differ from the "----" placeholder before `save`.
        IS_LABEL: True if this protein comes from the PDB.
        source: free-form description of where the entry came from.
    """

    def __init__(self,
                 sequence,
                 structure,
                 occupancy=None,
                 symmetry_related_info=None,
                 ID="----",
                 IS_LABEL=False,  # If this protein from PDB, IS_LABEL is True
                 source="",
                 ):
        assert isinstance(sequence, (ProteinSequence, str)) or sequence is None, f"sequence must be ProteinSequence or string or None, but {type(sequence)} is given"
        assert isinstance(structure, ProteinStructure) or structure is None, f"structure must be ProteinStructure or None, but {type(structure)} is given"
        self.ID = ID
        self.sequence = sequence
        self.structure = structure
        self.occupancy = occupancy
        # Stored as given; no post-processing helper exists on this class.
        self.symmetry_related_info = symmetry_related_info  # dict or None

        if structure is not None:
            self.chain_break = structure.chain_break
        elif isinstance(sequence, ProteinSequence):
            self.chain_break = sequence.chain_break
        else:
            self.chain_break = OrderedDict()
        self.chain_list = list(self.chain_break.keys())
        self.IS_LABEL = IS_LABEL
        self.source = source

    def __extract__(self, residue_idx_list):
        """Placeholder for residue extraction; not implemented (returns None)."""
        pass

    def __repr__(self):
        output = '\n'
        output += '-' * 15 + ' Protein Class ' + '-' * 15 + '\n'
        output += f'ID : {self.ID}\nSource : {self.source}\nIS_LABEL : {self.IS_LABEL}\nSequence\t| {self.sequence.__repr__()}\nStructure\t| {self.structure.__repr__()}\n'
        if self.symmetry_related_info is not None:
            output += 'Symmetry Related Info\n'
            for key, value in self.symmetry_related_info.items():
                for value_ in value:
                    for key_in_value, value_in_value in value_.items():
                        # Short values on one line, long ones (e.g. matrices) on the next.
                        if len(str(value_in_value)) < 8:
                            output += f'{key}.{key_in_value} : {value_in_value}\n'
                        else:
                            output += f'{key}.{key_in_value} : \n {value_in_value}\n'
        output += '-' * 45
        return output

    def _chain_range(self, chain_id):
        """Return the inclusive (start, end) residue range of `chain_id`; KeyError if unknown."""
        if chain_id not in self.chain_list:
            raise KeyError("Chain ID {} is not in this protein".format(chain_id))
        return self.chain_break[chain_id][0], self.chain_break[chain_id][1]

    def get_sequence_by_chain_ID(self, chain_id):
        """Return the residues of `chain_id`.

        For a ProteinSequence the result is its tensor sliced on the residue axis; for a str it
        is the substring. Ranges are inclusive at both ends.
        """
        chain_start, chain_end = self._chain_range(chain_id)
        if isinstance(self.sequence, ProteinSequence):
            return self.sequence.sequence[..., chain_start:chain_end + 1]
        return self.sequence[chain_start:chain_end + 1]

    def get_structure_by_chain_ID(self, chain_id):
        """Return (xyz, atom_mask) of `chain_id` for all models.

        Raises:
            KeyError: if `chain_id` is unknown.
            ValueError: if this protein has no structure.
        """
        chain_start, chain_end = self._chain_range(chain_id)
        if self.structure is None:
            raise ValueError("This protein has no structure")
        return (self.structure.xyz[:, chain_start:chain_end + 1],
                self.structure.atom_mask[:, chain_start:chain_end + 1])

    def save(self, saving_directory):
        """ Save file using pickle

        Writes `<saving_directory>/<ID>.pkl`, creating the directory if needed.

        Raises:
            ValueError: if ID is still the "----" placeholder (checked before touching disk).
        """
        if self.ID == "----":
            raise ValueError("ID must be changed")
        os.makedirs(saving_directory, exist_ok=True)
        with open(os.path.join(saving_directory, self.ID + ".pkl"), 'wb') as f:
            pickle.dump(self, f)
class ProteinTemplate(Protein):
    """A template protein (sequence + structure) together with template features.

    position    : torch.Tensor of template residue positions. Defaults to torch.arange(L),
                  i.e. every residue is treated as a confident position. (Earlier docs describe
                  an (M, L_template) layout with L_template <= L; confirm against callers.)
    template_ID : identifier of the template.
    f0d, f1d    : optional torch.Tensor template features.
    """

    def __init__(self,
                 sequence,
                 structure,
                 position=None,
                 template_ID=None,
                 f0d=None,
                 f1d=None,
                 ):
        assert isinstance(position, torch.Tensor) or position is None, "position must be torch.Tensor"
        assert isinstance(f0d, torch.Tensor) or f0d is None, "f0d must be torch.Tensor or None"
        assert isinstance(f1d, torch.Tensor) or f1d is None, "f1d must be torch.Tensor or None"
        super().__init__(sequence, structure)

        if position is None:
            # It means that all residue position is confident.
            if isinstance(sequence, ProteinSequence):
                protein_length = sequence.get_sequence_length()
            elif isinstance(sequence, str):
                protein_length = len(sequence)
            else:
                protein_length = structure.xyz.shape[1]
            position = torch.arange(protein_length)
        # Always set, so __repr__ works whether or not a position was passed.
        self.position = position
        self.template_ID = template_ID
        self.f0d = f0d
        self.f1d = f1d

    @classmethod
    def from_dict(cls, dictionary):
        """Build a template from a flat dict.

        Required keys: "sequence", "xyz", "mask", "position", "template_ID", "f0d", "f1d".
        Optional key: "chain_break". When present it is used for both sequence and structure;
        otherwise a single chain "A" spanning the whole sequence is assumed.
        """
        sequence = ProteinSequence(dictionary["sequence"],
                                   chain_break=dictionary.get("chain_break"))
        structure = ProteinStructure(xyz=dictionary["xyz"],
                                     atom_mask=dictionary["mask"],
                                     chain_break=sequence.chain_break)
        return cls(sequence=sequence,
                   structure=structure,
                   position=dictionary["position"],
                   template_ID=dictionary["template_ID"],
                   f0d=dictionary["f0d"],
                   f1d=dictionary["f1d"])

    @classmethod
    def from_dict2(cls, dictionary):
        """Build a template from a dict that already holds ProteinSequence/ProteinStructure objects."""
        keys = ("sequence", "structure", "position", "template_ID", "f0d", "f1d")
        return cls(**{key: dictionary[key] for key in keys})

    def __repr__(self):
        f0d_output = "None" if self.f0d is None else self.f0d.shape
        f1d_output = "None" if self.f1d is None else self.f1d.shape
        return f'Sequence : {self.sequence.__repr__()}, Structure : {self.structure.__repr__()}, Position : {self.position.shape}, f0d : {f0d_output}, f1d : {f1d_output}'
class ProteinTemplates():
    """ ProteinTemplates : list of ProteinTemplate

    At most one loader is used, checked in this order:
      * template_pdb_list : list of PDB file paths, parsed with PDB_parsing;
      * hhr_file          : already-loaded template data, a list of template dicts
                            (templates_ID is kept as the collection ID);
      * hhr_file_path     : path to a torch.load-able file holding that list; the file name
                            without its last extension becomes templates_ID.
    Otherwise `templates` (a list of ProteinTemplate, default empty) is used as given.
    """

    def __init__(self,
                 templates=None,
                 template_pdb_list=None,
                 hhr_file=None,
                 hhr_file_path=None,
                 templates_ID=None,
                 ):
        assert isinstance(hhr_file, (list, tuple)) or hhr_file is None, "hhr_file must be a list of template dicts or None"
        assert isinstance(templates, list) or templates is None, "templates must be list"
        # A fresh list per instance (a shared `[]` default would leak templates across objects).
        self.templates = [] if templates is None else templates
        self.templates_ID = templates_ID

        if template_pdb_list is not None:
            self.from_pdb_file(template_pdb_list)
        elif hhr_file is not None:
            self.from_hhr_file(hhr_file, templates_ID)
        elif hhr_file_path is not None:
            hhr_hash = ".".join(os.path.basename(hhr_file_path).split(".")[:-1])
            # NOTE: torch.load unpickles the file; only load trusted files.
            self.from_hhr_file(torch.load(hhr_file_path), hhr_hash)
        self._refresh_template_ID_list()

    def __repr__(self) -> str:
        output_str = ""
        for ii, template in enumerate(self.templates):
            if ii > 3:
                output_str += "...\n"
                break
            output_str += template.__repr__() + "\n"
        return output_str

    def _refresh_template_ID_list(self):
        """Recompute template_ID_list from the current templates."""
        self.template_ID_list = [template.template_ID for template in self.templates]

    def from_hhr_file(self, hhr_file, hhr_hash):
        """Load templates from a list of template dicts ([template 1 dict, template 2 dict, ...]).

        Raises:
            ValueError: if templates are already loaded.
        """
        if len(self.templates) != 0:
            raise ValueError("templates must be empty")
        self.templates_ID = hhr_hash
        self.templates = [ProteinTemplate.from_dict(template_dict) for template_dict in hhr_file]
        self._refresh_template_ID_list()

    def from_pdb_file(self, template_pdb_list):
        """Load templates from PDB files via PDB_parsing.

        f0d is a constant tensor([0.0]); f1d is the per-residue B-factor returned by
        PDB_parsing (its `lddt` output) divided by 100.

        Raises:
            ValueError: if templates are already loaded.
        """
        if len(self.templates) != 0:
            raise ValueError("templates must be empty")
        templates = []
        for template_pdb in template_pdb_list:
            template, lddt = PDB_parsing(template_pdb, return_lddt=True)
            template_dict = {
                "sequence": template.sequence.sequence,
                "xyz": template.structure.xyz,
                "mask": template.structure.atom_mask,
                "chain_break": template.structure.chain_break,
                "f0d": torch.tensor([0.0]),
                "f1d": lddt / 100,
                "position": None,
                "template_ID": template.ID,
            }
            templates.append(ProteinTemplate.from_dict(template_dict))
        self.templates = templates
        self._refresh_template_ID_list()

    def filter_template_by_sequence_identity(self, sequence_identity):
        """Not implemented yet; currently a no-op.

        TODO: keep only templates whose sequence identity is >= `sequence_identity` and
        refresh template_ID_list.
        """
        pass

    def save(self, saving_directory):
        """Pickle this collection to `<saving_directory>/<templates_ID>.pkl`.

        Raises:
            ValueError: if templates_ID is not set (checked before touching disk).
        """
        if self.templates_ID is None:
            raise ValueError("templates_ID must be set before saving")
        os.makedirs(saving_directory, exist_ok=True)
        with open(os.path.join(saving_directory, self.templates_ID + ".pkl"), 'wb') as f:
            pickle.dump(self, f)
class ProteinMSA():
    """Represents the Multiple Sequence Alignment (MSA) of a Protein::

        msa_ID         : str or None. When not given, it is taken from a3m_file_path
                         (file name without its last extension).
        query_sequence : torch.Tensor (L,) or (1, L); a str is encoded with AA2num. When
                         omitted and msa_tensor is given, the first MSA row is used.
        msa            : torch.Tensor (N, L); N is MSA depth. None until provided or loaded.
        insertion      : torch.Tensor (N, L). None until provided or loaded.
        chain_break    : {chain index: inclusive (start, end)} built from the first a3m
                         sequence line containing ':' separators; None otherwise.
        a3m_file_path  : optional .a3m / .a3m.gz path to load this object from.
    """

    def __init__(self,
                 msa_ID=None,
                 query_sequence=None,
                 msa_tensor=None,
                 insertion_tensor=None,
                 a3m_file_path=None,  # None when there is no MSA file
                 ):
        assert isinstance(msa_tensor, torch.Tensor) or msa_tensor is None, "msa_tensor must be torch.Tensor or None"
        assert isinstance(insertion_tensor, torch.Tensor) or insertion_tensor is None, "insertion_tensor must be torch.Tensor or None"

        # A tensor carries no file name, so only the a3m path can provide a default ID.
        if msa_ID is None and a3m_file_path is not None:
            msa_ID = ".".join(a3m_file_path.split("/")[-1].split(".")[:-1])

        if isinstance(query_sequence, str):
            query_sequence = torch.tensor([AA2num[aa] for aa in query_sequence])
        elif query_sequence is None and msa_tensor is not None \
                and msa_tensor.dim() > 0 and msa_tensor.shape[0] > 0:
            # Same convention as from_msa_tensor: the first MSA row is the query.
            query_sequence = msa_tensor[0]

        self.msa_ID = msa_ID
        self.query_sequence = query_sequence
        self.chain_break = None
        self.msa = msa_tensor
        self.insertion = insertion_tensor
        if a3m_file_path is not None:
            self.from_a3m_file_path(a3m_file_path)

    def __repr__(self):
        def shape_of(value):
            return None if value is None else value.shape
        return (f'MSA ID : {self.msa_ID}, '
                f'Query Sequence : {shape_of(getattr(self, "query_sequence", None))}, '
                f'MSA : {shape_of(getattr(self, "msa", None))}, '
                f'Insertion : {shape_of(getattr(self, "insertion", None))}')

    def from_msa_tensor(self, msa_tensor):
        """Set msa from a tensor and take its first row as the query sequence."""
        self.msa = msa_tensor
        self.query_sequence = msa_tensor[0]

    @staticmethod
    def _count_insertions(line, length):
        """Return a float array (length,) counting lowercase insertions per aligned column.

        Each run of lowercase letters is counted at the index (in the cleaned, uppercase/gap
        sequence) of the column that follows it.
        """
        # '/' only separates chains and is not an aligned column; drop it before indexing.
        residues = line.replace('/', '')
        # 0 - match or gap; 1 - insertion
        lower_case = np.array([0 if c.isupper() or c == '-' else 1 for c in residues])
        insertion = np.zeros(length)
        if np.sum(lower_case) > 0:
            pos = np.where(lower_case == 1)[0]
            # shift by the number of earlier insertions -> position in the cleaned sequence
            pos, num = np.unique(pos - np.arange(pos.shape[0]), return_counts=True)
            insertion[pos] = num
        return insertion

    def from_a3m_file_path(self, a3m_file_path, max_seq=5000):  # Code from RF2 parsers.py
        """ Input : a3m_file_path (.a3m or .a3m.gz)

        Lines starting with '#' or '>' are skipped. ':' separators are removed; the first line
        containing them defines self.chain_break. '/' separates chains whose columns are
        concatenated. Lowercase letters (insertions) are removed from the MSA and counted in
        self.insertion. Residues map to indices of "ARNDCQEGHILKMFPSTWYV-" and anything else
        maps to 20 (gap).

        NOTE(review): max_seq is accepted for API compatibility but is not applied; every
        sequence in the file is read.

        Raises:
            ValueError: if sequence/chain lengths are inconsistent or no sequence is found.
        """
        table = str.maketrans(dict.fromkeys(string.ascii_lowercase))
        msa = []  # one list of aligned strings per chain
        insertion_list = []
        chain_break = {}
        chain_lengths = []
        opener = gzip.open if a3m_file_path.split('.')[-1] == 'gz' else open
        with opener(a3m_file_path, 'rt') as f:
            for line in f:
                # skip '#' comment lines and '>' labels
                if line.startswith(('#', '>')):
                    continue
                if not chain_break and ':' in line:
                    chain_start = 0
                    for chain_idx, chain in enumerate(line.strip().split(':')):
                        chain_end = chain_start + len(chain) - 1
                        chain_break[chain_idx] = (chain_start, chain_end)
                        chain_start = chain_end + 1
                    self.chain_break = chain_break
                line = line.replace(':', '').rstrip()
                if len(line) == 0:
                    continue
                chains_i = line.translate(table).split('/')
                if len(msa) == 0:  # first (query) sequence
                    chain_lengths = [len(s) for s in chains_i]
                    msa = [[s] for s in chains_i]
                elif len(chains_i) == len(msa) and all(
                        len(msa[i][0]) == len(chains_i[i]) for i in range(len(msa))):
                    for chain_rows, s in zip(msa, chains_i):
                        chain_rows.append(s)
                else:
                    raise ValueError("Len error", a3m_file_path, len(msa[0]), self.msa_ID)
                insertion_list.append(self._count_insertions(line, sum(chain_lengths)))

        if len(msa) == 0:
            raise ValueError(f"No sequences found in {a3m_file_path}")

        # concatenate chains column-wise as raw bytes
        msa = np.concatenate(
            [np.array([list(s) for s in rows], dtype='|S1').view(np.uint8) for rows in msa],
            axis=-1)
        # convert letters into numbers; unknown characters are treated as gaps (20)
        alphabet = np.array(list("ARNDCQEGHILKMFPSTWYV-"), dtype='|S1').view(np.uint8)
        lookup = np.full(256, 20, dtype=np.uint8)
        lookup[alphabet] = np.arange(alphabet.shape[0], dtype=np.uint8)
        msa = lookup[msa]

        self.msa = torch.tensor(msa)  # (N, L) torch.Tensor
        self.insertion = torch.tensor(np.array(insertion_list, dtype=np.uint8))  # (N, L)
        self.query_sequence = torch.tensor(msa[0])  # (L, ) torch.Tensor

    def save(self, saving_directory):
        """Pickle this object to `<saving_directory>/<msa_ID>_msas.pkl`.

        Raises:
            ValueError: if msa_ID is not set (checked before touching disk).
        """
        if self.msa_ID is None:
            raise ValueError("msa_ID must be set before saving")
        os.makedirs(saving_directory, exist_ok=True)
        with open(os.path.join(saving_directory, self.msa_ID + "_msas.pkl"), 'wb') as f:
            pickle.dump(self, f)

    def print_query(self):
        """Print the query sequence as one-letter codes (via num2AA)."""
        print("".join([num2AA[aa] for aa in self.query_sequence]))
def _split_pdb_models(lines):
    """Group ATOM records by 0-based model index; a file without MODEL records is model 0."""
    if not any(line.startswith("MODEL") for line in lines):
        return {0: [line for line in lines if line.startswith("ATOM")]}
    models = {}
    model_idx = None
    for line in lines:
        if line.startswith("MODEL"):
            # PDB model serials start at 1; this module uses 0-based model indices.
            model_idx = int(line.split()[1]) - 1
            models[model_idx] = []
        elif line.startswith("ATOM"):
            models[model_idx].append(line)
    return models


def PDB_parsing(PDB_file, IS_LABEL=True, return_lddt=False):
    """Parse a PDB file into a Protein object.

    Handles single- and multi-model files (MODEL records) and multiple chains. Model indices
    are 0-based here, while PDB model serials start at 1, so 1 is subtracted. Only ATOM
    records are read. A new residue starts at every backbone N atom, and each residue gets 14
    heavy-atom slots (the first 14 names of aa2long[residue]). Chain breaks are taken from the
    first model and assumed to be the same in every model.

    Args:
        PDB_file: path to the PDB file. If the path contains "_unrelaxed" (AlphaFold output
            naming), the ID is the part of the file name before "_unrelaxed"; otherwise it is
            the file name without its last extension.
        IS_LABEL: stored on the returned Protein.
        return_lddt: if True, also return the B-factor column read from each residue's N atom
            as an (M, L) tensor (presumably pLDDT for predicted models).

    Returns:
        Protein, or (Protein, lddt) when return_lddt is True.

    Raises:
        ValueError: if no residue is found, or a heavy atom appears before any N atom.
    """
    file_name = PDB_file.split("/")[-1]
    if "_unrelaxed" in PDB_file:  # for AF output
        PDB_ID = file_name.split("_unrelaxed")[0]
    else:
        PDB_ID = ".".join(file_name.split(".")[:-1])

    with open(PDB_file, "r") as f:
        lines = [line.strip() for line in f]
    models = _split_pdb_models(lines)

    xyz, mask, seq, lddt = [], [], [], []
    chain_break = OrderedDict()  # chain ID -> [start, end] (inclusive), from the first model
    chain_key = None
    before_chain = "-"
    chain_start = -1  # index of the current residue in the first model
    for ii, model_idx in enumerate(models.keys()):
        first_model = ii == 0
        xyz_model, mask_model, seq_model, lddt_model = [], [], [], []
        for line in models[model_idx]:
            residue = line[17:20]
            residue_num = aa2num[residue]
            heavy_atom_list = aa2long[residue_num][:14]  # 0~13 : heavy atoms
            atom = line[12:16]
            chain = line[20:22]  # two columns; originally noted as needed for large PDB files
            bfactor = float(line[60:66])
            if atom.strip() == "N":
                # A backbone N starts a new residue.
                seq_model.append(residue_num)
                lddt_model.append(bfactor)
                if first_model:
                    chain_start += 1
                    if chain != before_chain:
                        chain_key = chain.replace(" ", "")
                        if len(chain_break) != 0:
                            # Close the previous chain at the residue before this one.
                            chain_break[before_chain.replace(" ", "")].append(chain_start - 1)
                        chain_break[chain_key] = [chain_start]
                xyz_model.append([[0.0, 0.0, 0.0] for _ in range(14)])
                mask_model.append([0] * 14)
            if atom in heavy_atom_list:
                if not xyz_model:
                    raise ValueError(
                        f"{PDB_file}: atom '{atom.strip()}' of {residue} appears before any N atom")
                atom_idx = heavy_atom_list.index(atom)
                xyz_model[-1][atom_idx] = [float(line[30:38]), float(line[38:46]), float(line[46:54])]
                mask_model[-1][atom_idx] = 1
            if first_model:
                before_chain = chain
        if first_model:
            if chain_key is None:
                raise ValueError(f"No residues (ATOM records with an N atom) found in {PDB_file}")
            chain_break[chain_key].append(chain_start)  # close the last chain

        # TODO: optionally strip N-/C-terminal His-tags (5-6 His, aa index 8) and shift
        # chain_break accordingly; this must not break complexes containing His runs.

        xyz.append(xyz_model)
        mask.append(mask_model)
        seq.append(seq_model)
        lddt.append(lddt_model)

    xyz = torch.tensor(xyz)  # (M, L, 14, 3)
    mask = torch.tensor(mask)  # (M, L, 14)
    seq = torch.tensor(seq)  # (M, L)
    lddt = torch.tensor(lddt)  # (M, L)
    has_multiple_models = xyz.shape[0] > 1
    has_multiple_chains = len(chain_break) > 1
    chain_break = OrderedDict((key, (bounds[0], bounds[1])) for key, bounds in chain_break.items())

    Structure = ProteinStructure(xyz,
                                 atom_mask=mask,
                                 position_mask=None,
                                 chain_break=chain_break,
                                 has_multiple_models=has_multiple_models,
                                 has_multiple_chains=has_multiple_chains)
    Sequence = ProteinSequence(sequence=seq, masked_sequence=None, chain_break=chain_break)
    Protein_object = Protein(sequence=Sequence, structure=Structure, ID=PDB_ID, IS_LABEL=IS_LABEL)
    if return_lddt:
        return Protein_object, lddt
    return Protein_object