# Source code for miniworld.utils.My_mistake
import torch
from miniworld.utils.chemical import *
# [docs]
def MSA_change_idx(protein_msa, UNK_IDX = 20, GAP_IDX = 21):
    """Replace ``UNK_IDX`` tokens with ``GAP_IDX`` in every non-query MSA row.

    The rebuilt MSA has ``protein_msa.query_sequence`` as row 0 (its tokens
    are never remapped), followed by ``protein_msa.msa[1:]`` with every
    ``UNK_IDX`` entry replaced by ``GAP_IDX``.

    Args:
        protein_msa: Object exposing ``query_sequence`` and ``msa`` tensor
            attributes. Presumably ``query_sequence`` is (L,) and ``msa`` is
            (N, L) with the query as row 0 -- confirm against the caller.
        UNK_IDX: Token value to be replaced (default 20).
        GAP_IDX: Token value written in its place (default 21).

    Returns:
        The same ``protein_msa`` object, with ``msa`` rebound to a newly
        allocated tensor.
    """
    query = protein_msa.query_sequence
    aligned = protein_msa.msa[1:]
    # ``msa[1:]`` is a view of the caller's tensor. Use the out-of-place
    # ``masked_fill`` so the tensor originally stored on ``protein_msa`` (and
    # anything aliasing it) is not silently rewritten.
    aligned = aligned.masked_fill(aligned == UNK_IDX, GAP_IDX)
    protein_msa.query_sequence = query
    protein_msa.msa = torch.cat([query.unsqueeze(0), aligned], dim=0)
    return protein_msa
# [docs]
def MSA_change_idx_v2(protein_msa, chain_break, UNK_IDX = 20, GAP_IDX=21):
    """Replace ``UNK_IDX`` with ``GAP_IDX`` per chain in the non-query MSA rows.

    For each chain's column range, a row is remapped only if its segment for
    that chain is not made up entirely of ``UNK_IDX``. Rows whose segment is
    all ``UNK_IDX`` are left untouched for that chain (presumably these mark
    sequences with no alignment for the chain -- confirm with the MSA
    builder). Row 0 of the result is ``protein_msa.query_sequence``, which is
    never remapped.

    Args:
        protein_msa: Object exposing ``query_sequence`` and ``msa`` tensor
            attributes. Presumably ``query_sequence`` is (L,) and ``msa`` is
            (N, L) with the query as row 0 -- confirm against the caller.
        chain_break: Mapping whose values are ``(chain_start, chain_end)``
            column indices, with ``chain_end`` inclusive. Keys are not used.
        UNK_IDX: Token value to be replaced (default 20).
        GAP_IDX: Token value written in its place (default 21).

    Returns:
        The same ``protein_msa`` object, with ``msa`` rebound to a newly
        allocated tensor.
    """
    query = protein_msa.query_sequence
    # Work on a private copy: ``msa[1:]`` is a view, so editing it directly
    # would rewrite the tensor the caller handed in.
    rows = protein_msa.msa[1:].clone()
    for chain_start, chain_end in chain_break.values():
        # Basic slicing gives a view, so the in-place fill below lands in
        # ``rows`` at the chain's own columns. Columns therefore keep their
        # position regardless of dict order or of gaps between chains.
        chain_cols = rows[:, chain_start:chain_end + 1]  # (N, L_chain)
        is_unk = chain_cols == UNK_IDX  # (N, L_chain)
        # Rows that contain at least one non-UNK token within this chain.
        has_content = ~is_unk.all(dim=1)  # (N,)
        chain_cols.masked_fill_(is_unk & has_content.unsqueeze(1), GAP_IDX)
    protein_msa.query_sequence = query
    protein_msa.msa = torch.cat([query.unsqueeze(0), rows], dim=0)
    return protein_msa