Source code for miniworld.utils.My_mistake

import torch
from miniworld.utils.chemical import *

def MSA_change_idx(protein_msa, UNK_IDX=20, GAP_IDX=21):
    """Replace unknown-residue tokens with gap tokens in the non-query MSA rows.

    Every occurrence of ``UNK_IDX`` in rows ``1:`` of ``protein_msa.msa`` is
    replaced by ``GAP_IDX``. Row 0 of the result is
    ``protein_msa.query_sequence`` (unchanged, so it may still contain
    ``UNK_IDX``). The new tensor is stored back on ``protein_msa.msa``.

    The tensor previously held in ``protein_msa.msa`` is NOT modified: the
    replacement is done on a copy rather than through a view, so any other
    reference to the old tensor keeps its original tokens.

    Args:
        protein_msa: Object with tensor attributes ``query_sequence`` and
            ``msa``. ``query_sequence.unsqueeze(0)`` is concatenated with
            ``msa[1:]`` along dim 0, so the query presumably is 1-D of
            length L and ``msa`` is (num_seqs, L) -- TODO confirm.
        UNK_IDX: Token index to replace (default 20).
        GAP_IDX: Replacement token index (default 21).

    Returns:
        The same ``protein_msa`` object, with ``msa`` replaced.
    """
    query = protein_msa.query_sequence
    # msa[1:] is a view of the caller's tensor; masked_fill returns a fresh
    # tensor, so the original MSA is not rewritten in place.
    aligned_rows = protein_msa.msa[1:]
    aligned_rows = aligned_rows.masked_fill(aligned_rows == UNK_IDX, GAP_IDX)
    protein_msa.msa = torch.cat([query.unsqueeze(0), aligned_rows], dim=0)
    return protein_msa
def MSA_change_idx_v2(protein_msa, chain_break, UNK_IDX=20, GAP_IDX=21):
    """Chain-aware replacement of unknown-residue tokens with gap tokens.

    For each inclusive column span ``(start, end)`` in ``chain_break``, the
    non-query rows (``msa[1:]``) are examined over that span only:

    * a row whose span consists entirely of ``UNK_IDX`` is left unchanged;
    * in every other row, ``UNK_IDX`` within the span becomes ``GAP_IDX``.

    The per-chain column blocks are concatenated in ``chain_break`` iteration
    order. Columns not covered by any span are dropped. Row 0 of the result is
    ``protein_msa.query_sequence``. The tensor previously held in
    ``protein_msa.msa`` is NOT modified; a new tensor is built instead.

    Args:
        protein_msa: Object with tensor attributes ``query_sequence`` and
            ``msa`` (``msa`` is indexed as ``[:, start:end + 1]``, so it is at
            least 2-D; presumably (num_seqs, L) -- TODO confirm).
        chain_break: Mapping of chain id -> ``(start, end)`` inclusive column
            indices. Spans are assumed to be non-overlapping (NOTE(review):
            confirm with callers). If it is empty, ``torch.cat`` fails.
        UNK_IDX: Token index to replace (default 20).
        GAP_IDX: Replacement token index (default 21).

    Returns:
        The same ``protein_msa`` object, with ``msa`` replaced.
    """
    query = protein_msa.query_sequence
    aligned_rows = protein_msa.msa[1:]

    chain_blocks = []
    for start, end in chain_break.values():
        block = aligned_rows[:, start:end + 1]  # (N, L_chain) view
        is_unk = block == UNK_IDX
        # Rows that are UNK over the whole chain keep their UNK tokens
        # (shape (N, 1) so it broadcasts across the chain's columns).
        row_all_unk = is_unk.all(dim=1, keepdim=True)
        # masked_fill returns a new tensor, so the caller's MSA is untouched.
        chain_blocks.append(block.masked_fill(is_unk & ~row_all_unk, GAP_IDX))

    converted = torch.cat(chain_blocks, dim=1)
    protein_msa.msa = torch.cat([query.unsqueeze(0), converted], dim=0)
    return protein_msa