# Source: deepfold.data.multimer.input_features

import collections
import copy
import dataclasses
from typing import Iterable, List, Mapping, MutableMapping, Sequence

import numpy as np

from deepfold.common import protein
from deepfold.common import residue_constants as rc
from deepfold.data.multimer import msa_pairing
from deepfold.data.search.input_features import create_msa_features
from deepfold.data.search.msa_identifiers import get_identifiers

# Features that survive `_filter_features` at the end of the multimer
# pipeline; anything not listed here is dropped from the final example.
REQUIRED_FEATURES = frozenset(
    {
        "aatype",
        "all_atom_mask",
        "all_atom_positions",
        "all_chains_entity_ids",  # NOTE(review): not produced in this file — confirm upstream producer
        "all_crops_all_chains_mask",  # NOTE(review): not produced in this file — confirm upstream producer
        "all_crops_all_chains_positions",  # NOTE(review): not produced in this file — confirm upstream producer
        "all_crops_all_chains_residue_ids",  # NOTE(review): not produced in this file — confirm upstream producer
        "assembly_num_chains",
        "asym_id",
        "bert_mask",
        "cluster_bias_mask",
        "deletion_matrix",
        "deletion_mean",
        "entity_id",
        "entity_mask",
        "mem_peak",
        "msa",
        "msa_mask",
        "num_alignments",
        "num_templates",
        "queue_size",
        "residue_index",
        "resolution",
        "seq_length",
        "seq_mask",
        "sym_id",
        "template_aatype",
        "template_all_atom_mask",
        "template_all_atom_positions",
    }
)

# Per-chain cap on templates kept by `crop_chains`.
MAX_TEMPLATES = 20
# Total MSA rows (paired + unpaired) kept per chain by `crop_chains`.
MSA_CROP_SIZE = 2048


def convert_monomer_features(
    monomer_features: dict,
) -> dict:
    """Reshape and rename monomer features so they fit the multimer pipeline.

    Scalar-like features lose their unnecessary leading dimension, one-hot
    ``aatype`` / ``template_aatype`` become integer indices (template types
    are additionally re-ordered from the HHblits alphabet), and
    ``template_all_atom_masks`` is renamed to ``template_all_atom_mask``.
    """
    squeeze_feats = {"sequence", "domain_name", "num_alignments", "seq_length"}
    out = {}
    for name, value in monomer_features.items():
        if name in squeeze_feats:
            # Drop the leading dimension; asarray keeps it a np.ndarray.
            value = np.asarray(value[0], dtype=value.dtype)
        elif name == "aatype":
            # The multimer model performs the one-hot operation itself,
            # so store plain integer indices.
            value = np.argmax(value, axis=-1).astype(np.int32)
        elif name == "template_aatype":
            value = np.argmax(value, axis=-1).astype(np.int32)
            # Remap from the HHblits amino-acid ordering to ours.
            value = np.take(rc.MAP_HHBLITS_AATYPE_TO_OUR_AATYPE, value, axis=0)
        elif name == "template_all_atom_masks":
            name = "template_all_atom_mask"
        out[name] = value
    return out
def int_id_to_str_id(num: int) -> str:
    """Encode a positive integer as a reverse-spreadsheet-style chain ID.

    1 = A, 2 = B, ..., 26 = Z, 27 = AA, 28 = BA, 29 = CA, ... — the usual
    way chain IDs are encoded in mmCIF files.

    Args:
        num: A positive integer.

    Returns:
        The reverse spreadsheet style name for ``num``.

    Raises:
        ValueError: If ``num`` is not a positive integer.
    """
    if num <= 0:
        raise ValueError(f"Only positive integers allowed, got {num}.")

    remainder = num - 1  # switch to 0-based indexing
    letters = []
    while remainder >= 0:
        letters.append(chr(ord("A") + remainder % 26))
        remainder = remainder // 26 - 1
    return "".join(letters)
def process_single_chain(
    chain_features: dict,
    is_homomer_or_monomer: bool,
    a3m_strings_for_paring: Sequence[str] | None = None,
    use_identifier: bool = False,
) -> dict:
    """Return a deep copy of ``chain_features`` augmented for MSA pairing.

    For hetero-complexes, "*_all_seq" pairing features are built from the
    given A3M strings; when ``use_identifier`` is set, identifier-derived
    pairing features from the chain's own MSA are prepended. Homomers and
    monomers are returned as a plain deep copy.

    NOTE(review): the parameter spelling "paring" is kept for caller
    compatibility.
    """
    processed = copy.deepcopy(chain_features)
    if a3m_strings_for_paring is None:
        a3m_strings_for_paring = [""]
    if is_homomer_or_monomer:
        return processed

    pairing_feats = create_all_seq_msa_features_from_a3m(
        a3m_strings_for_paring,
        sequence=chain_features["sequence"].item().decode("utf-8"),
    )
    if use_identifier:
        # Prepend pairing rows derived from the chain's own MSA identifiers.
        identifier_feats = create_all_seq_msa_features(chain_features)
        for key, value in pairing_feats.items():
            pairing_feats[key] = np.concatenate([identifier_feats[key], value], axis=0)
    processed.update(pairing_feats)
    return processed
def add_assembly_features(
    all_chain_features: MutableMapping[str, dict],
) -> MutableMapping[str, dict]:
    """Add features to distinguish between chains.

    Args:
        all_chain_features: A dictionary which maps chain_id to a dictionary
            of features for each chain.

    Returns:
        A dictionary which maps strings of the form `<seq_id>_<sym_id>` to
        the corresponding chain features. E.g. two chains from a homodimer
        would have keys A_1 and A_2. Two chains from a heterodimer would
        have keys A_1 and B_1.
    """
    # Group chains by sequence; each distinct sequence gets a 1-based entity id.
    entity_id_by_seq: dict = {}
    chains_by_entity = collections.defaultdict(list)
    for chain_features in all_chain_features.values():
        seq = chain_features["sequence"].item().decode("utf-8")
        if seq not in entity_id_by_seq:
            entity_id_by_seq[seq] = len(entity_id_by_seq) + 1
        chains_by_entity[entity_id_by_seq[seq]].append(chain_features)

    renamed = {}
    asym_id = 1  # globally unique chain counter across all entities
    for entity_id, entity_chains in chains_by_entity.items():
        for sym_id, chain_features in enumerate(entity_chains, start=1):
            renamed[f"{int_id_to_str_id(entity_id)}_{sym_id}"] = chain_features
            num_res = chain_features["seq_length"]
            # Per-residue chain/copy/entity labels, broadcast over the chain.
            chain_features["asym_id"] = np.full(num_res, asym_id, dtype=np.int64)
            chain_features["sym_id"] = np.full(num_res, sym_id, dtype=np.int64)
            chain_features["entity_id"] = np.full(num_res, entity_id, dtype=np.int64)
            asym_id += 1
    return renamed
def pad_msa(example: dict, min_num_cluster) -> dict:
    """Zero-pad MSA-shaped features up to ``min_num_cluster`` rows.

    Returns a shallow copy of ``example``; the padded arrays are new objects
    and the originals are left untouched. No-op if the MSA already has at
    least ``min_num_cluster`` sequences.
    """
    padded = dict(example)
    deficit = min_num_cluster - padded["msa"].shape[0]
    if deficit <= 0:
        return padded
    for key in ("msa", "deletion_matrix", "bert_mask", "msa_mask"):
        padded[key] = np.pad(padded[key], ((0, deficit), (0, 0)))
    # 1-D feature: pad along the single (sequence) axis only.
    padded["cluster_bias_mask"] = np.pad(padded["cluster_bias_mask"], ((0, deficit),))
    return padded
def create_all_seq_msa_features_from_a3m(
    a3m_strings: Sequence[str],
    sequence: str | None = None,
) -> dict:
    """Get MSA features for paring."""
    raw = create_msa_features(
        list(a3m_strings),
        sequence=sequence,
        use_identifiers=True,
    )
    # Replace raw identifier strings with parsed species identifiers.
    raw["msa_identifiers"] = np.array(
        [get_identifiers(ident) for ident in raw["msa_identifiers"]],
        dtype=np.object_,
    )
    keep = (*msa_pairing.MSA_FEATURES, "msa_identifiers")
    # Suffix the retained features so downstream pairing code can find them.
    return {f"{name}_all_seq": value for name, value in raw.items() if name in keep}
def create_all_seq_msa_features(
    all_seq_features: dict,
) -> dict:
    """Build "*_all_seq" pairing features from an existing MSA feature dict.

    NOTE: mutates ``all_seq_features["msa_identifiers"]`` in place, replacing
    raw identifier strings with parsed species identifiers.
    """
    parsed = [get_identifiers(ident) for ident in all_seq_features["msa_identifiers"]]
    all_seq_features["msa_identifiers"] = np.array(parsed, dtype=np.object_)
    keep = (*msa_pairing.MSA_FEATURES, "msa_identifiers")
    return {f"{name}_all_seq": value for name, value in all_seq_features.items() if name in keep}
@dataclasses.dataclass(frozen=False)
class ComplexInfo:
    """Description of a complex to assemble.

    Attributes:
        descriptions: Per-entity chain descriptions (keys into the monomer
            feature mapping).
        num_units: Per-entity copy counts; parallel to ``descriptions``.
    """

    descriptions: List[str] = dataclasses.field(default_factory=list)
    num_units: List[int] = dataclasses.field(default_factory=list)

    def __post_init__(self):
        # Raise instead of assert: asserts are stripped under `python -O`,
        # which would silently disable this input validation.
        if len(self.descriptions) != len(self.num_units):
            raise ValueError(
                f"descriptions and num_units must have the same length, "
                f"got {len(self.descriptions)} and {len(self.num_units)}"
            )
        if not all(n > 0 for n in self.num_units):
            raise ValueError(f"num_units must all be positive, got {self.num_units}")
def create_multimer_features(
    paired_a3m_strings: List[str],
    sequence: str | None = None,
) -> dict:
    """Create multimer features from paired MSA strings."""
    keep = (*msa_pairing.MSA_FEATURES, "msa_identifiers")
    raw = create_msa_features(
        paired_a3m_strings,
        sequence=sequence,
        use_identifiers=True,
    )
    feats = {f"{name}_all_seq": value for name, value in raw.items() if name in keep}
    # Rows of a pre-paired MSA are matched by position, so tag each row with
    # its pair index instead of a species identifier.
    identifiers = feats["msa_identifiers_all_seq"]
    for row in range(len(identifiers)):
        identifiers[row] = f"pair:{row:d}"
    return feats
def process_multimer_features(
    complex: ComplexInfo,
    all_monomer_features: Mapping[str, dict],
    pair_with_identifier: bool = False,
    a3m_strings_with_identifiers: Mapping[str, str] | None = None,
    paired_a3m_strings: Mapping[str, str] | None = None,
    max_num_clusters: int = 508,
) -> dict:
    """Create a multimer input features.

    Args:
        complex: Chain descriptions and copy counts for the assembly.
        all_monomer_features: Per-description monomer feature dicts.
        pair_with_identifier: Whether to also pair on MSA identifiers.
        a3m_strings_with_identifiers: Optional per-description A3M strings
            used for cross-chain pairing.
        paired_a3m_strings: Optional per-description pre-paired A3M strings.
            (Fixed: previously a mutable ``dict()`` default argument.)
        max_num_clusters: Minimum/target number of MSA cluster rows to pad to.

    Returns:
        A single merged multimer feature dict.
    """
    if paired_a3m_strings is None:
        paired_a3m_strings = {}
    all_chain_features = {}
    is_homomer_or_monomer = len(complex.num_units) == 1
    # zip with PDB_CHAIN_IDS silently caps the number of entities at the
    # number of available chain IDs.
    for cid, desc, num in zip(protein.PDB_CHAIN_IDS, complex.descriptions, complex.num_units):
        assert cid is not None
        chain_features = all_monomer_features[desc]
        if a3m_strings_with_identifiers is not None:
            # Process UniProt features:
            chain_features = process_single_chain(
                chain_features,
                is_homomer_or_monomer,
                a3m_strings_for_paring=[a3m_strings_with_identifiers[desc]],
                use_identifier=pair_with_identifier,
            )
        else:
            chain_features = process_single_chain(
                chain_features=chain_features,
                is_homomer_or_monomer=is_homomer_or_monomer,
                a3m_strings_for_paring=None,
                use_identifier=pair_with_identifier,
            )
        # Process custom paired MSA:
        paired_a3m_str = paired_a3m_strings.get(desc, "")
        multimer_features = create_multimer_features(
            [paired_a3m_str],
            sequence=chain_features["sequence"].item().decode("utf-8"),
        )
        if is_homomer_or_monomer:
            chain_features.update(multimer_features)
        else:
            # Append the custom paired rows below the existing pairing rows.
            for k, v in multimer_features.items():
                chain_features[k] = np.concatenate([chain_features[k], v], axis=0)
        # Convert monomer features to multimer features:
        chain_features = convert_monomer_features(chain_features)
        # Replicate the chain `num` times (deep copies, mutated independently
        # later by add_assembly_features).
        for i in range(num):
            chain_id = f"{cid}_{i+1}"
            all_chain_features[chain_id] = copy.deepcopy(chain_features)
    # Add assembly features:
    all_chain_features = add_assembly_features(all_chain_features)
    # Pair and merge features:
    example = pair_and_merge(all_chain_features)
    example = pad_msa(example, max_num_clusters)
    return example
def _is_homomer_or_monomer(chains: Iterable[dict]) -> bool:
    """Checks if a list of chains represents a homomer/monomer example."""
    # NOTE: An entity_id of 0 indicates padding.
    per_chain_entities = [np.unique(chain["entity_id"][chain["entity_id"] > 0]) for chain in chains]
    num_unique_chains = len(np.unique(np.concatenate(per_chain_entities)))
    return num_unique_chains == 1
def pair_and_merge(all_chain_features: MutableMapping[str, dict]) -> dict:
    """Runs processing on features to augment, pair and merge.

    Args:
        all_chain_features: A MutableMap of dictionaries of features for
            each chain.

    Returns:
        A dictionary of features.
    """
    # Per-chain post-processing must run first; mutates the chain dicts
    # in place (adds deletion_mean, atom masks, entity_mask, ...).
    process_unmerged_features(all_chain_features)

    chains = list(all_chain_features.values())
    pair_msa_sequences = not _is_homomer_or_monomer(chains)
    if pair_msa_sequences:
        chains = msa_pairing.create_paired_features(chains=chains)
        chains = msa_pairing.deduplicate_unpaired_sequences(chains)
    chains = crop_chains(
        chains,
        msa_crop_size=MSA_CROP_SIZE,
        pair_msa_sequences=pair_msa_sequences,
        max_templates=MAX_TEMPLATES,
    )
    # `merge_chain_features` crashes if there are additional features only
    # present in one chain, so keep only features common to all chains.
    shared_keys = set(chains[0]).intersection(*chains)
    chains = [{name: feat for name, feat in chain.items() if name in shared_keys} for chain in chains]
    merged = msa_pairing.merge_chain_features(
        np_chains_list=chains,
        pair_msa_sequences=pair_msa_sequences,
        max_templates=MAX_TEMPLATES,
    )
    return process_final(merged)
def crop_chains(
    chains_list: List[dict],
    msa_crop_size: int,
    pair_msa_sequences: bool,
    max_templates: int,
) -> List[dict]:
    """Crops the MSAs for a set of chains.

    Args:
        chains_list: A list of chains to be cropped.
        msa_crop_size: The total number of sequences to crop from the MSA.
        pair_msa_sequences: Whether we are operating in sequence-pairing mode.
        max_templates: The maximum templates to use per chain.

    Returns:
        The chains cropped.
    """
    return [
        _crop_single_chain(
            chain,
            msa_crop_size=msa_crop_size,
            pair_msa_sequences=pair_msa_sequences,
            max_templates=max_templates,
        )
        for chain in chains_list
    ]
def _crop_single_chain(chain: dict, msa_crop_size: int, pair_msa_sequences: bool, max_templates: int) -> dict:
    """Crops msa sequences to `msa_crop_size`.

    Mutates `chain` in place (MSA/template feature rows are truncated and the
    `num_*` counters are updated) and also returns it.
    """
    msa_size = chain["num_alignments"]
    if pair_msa_sequences:
        # Paired rows get at most half the crop budget.
        msa_size_all_seq = chain["num_alignments_all_seq"]
        msa_crop_size_all_seq = np.minimum(msa_size_all_seq, msa_crop_size // 2)

        # We reduce the number of un-paired sequences, by the number of times a
        # sequence from this chain's MSA is included in the paired MSA. This keeps
        # the MSA size for each chain roughly constant.
        msa_all_seq = chain["msa_all_seq"][:msa_crop_size_all_seq, :]
        num_non_gapped_pairs = np.sum(np.any(msa_all_seq != msa_pairing.MSA_GAP_IDX, axis=1))
        num_non_gapped_pairs = np.minimum(num_non_gapped_pairs, msa_crop_size_all_seq)

        # Restrict the unpaired crop size so that paired+unpaired sequences do not
        # exceed msa_seqs_per_chain for each chain.
        max_msa_crop_size = np.maximum(msa_crop_size - num_non_gapped_pairs, 0)
        msa_crop_size = np.minimum(msa_size, max_msa_crop_size)
    else:
        msa_crop_size = np.minimum(msa_size, msa_crop_size)

    # Truthy only when templates exist AND max_templates is non-zero.
    include_templates = "template_aatype" in chain and max_templates
    if include_templates:
        num_templates = chain["template_aatype"].shape[0]
        templates_crop_size = np.minimum(num_templates, max_templates)

    for k in chain:
        # Strip the "_all_seq" suffix to look the feature up in the registries.
        k_split = k.split("_all_seq")[0]
        if k_split in msa_pairing.TEMPLATE_FEATURES:
            chain[k] = chain[k][:templates_crop_size, :]
        elif k_split in msa_pairing.MSA_FEATURES:
            if "_all_seq" in k and pair_msa_sequences:
                chain[k] = chain[k][:msa_crop_size_all_seq, :]
            else:
                chain[k] = chain[k][:msa_crop_size, :]

    # Record the post-crop sizes.
    chain["num_alignments"] = np.asarray(msa_crop_size, dtype=np.int32)
    if include_templates:
        chain["num_templates"] = np.asarray(templates_crop_size, dtype=np.int32)
    if pair_msa_sequences:
        chain["num_alignments_all_seq"] = np.asarray(msa_crop_size_all_seq, dtype=np.int32)
    return chain
def process_final(np_example: dict) -> dict:
    """Final processing steps in data pipeline, after merging and pairing."""
    # Applied strictly in order: restype remap, seq mask, msa mask, filter.
    for transform in (_correct_msa_restypes, _make_seq_mask, _make_msa_mask, _filter_features):
        np_example = transform(np_example)
    return np_example
def _correct_msa_restypes(np_example):
    """Correct MSA restype to have the same order as residue_constants."""
    remapped = np.take(rc.MAP_HHBLITS_AATYPE_TO_OUR_AATYPE, np_example["msa"], axis=0)
    np_example["msa"] = remapped.astype(np.int32)
    return np_example


def _make_seq_mask(np_example):
    """Sequence mask is 1.0 wherever there is a real (non-padding) residue."""
    np_example["seq_mask"] = (np_example["entity_id"] > 0).astype(np.float32)
    return np_example


def _make_msa_mask(np_example):
    """Mask features are all ones, but will later be zero-padded."""
    msa_mask = np.ones_like(np_example["msa"], dtype=np.float32)
    # Zero out padding residue columns (entity_id == 0 indicates padding).
    residue_mask = (np_example["entity_id"] > 0).astype(np.float32)
    np_example["msa_mask"] = msa_mask * residue_mask[None]
    return np_example


def _filter_features(np_example: dict) -> dict:
    """Filters features of example to only those requested."""
    return {key: value for key, value in np_example.items() if key in REQUIRED_FEATURES}
def process_unmerged_features(all_chain_features: MutableMapping[str, dict]):
    """Postprocessing stage for per-chain features before merging.

    Mutates each chain's feature dict in place: converts deletion matrices to
    float32, adds `deletion_mean`, `all_atom_mask`, dummy `all_atom_positions`,
    `assembly_num_chains` and `entity_mask`.
    """
    num_chains = len(all_chain_features)
    for features in all_chain_features.values():
        # Convert deletion matrices to float.
        features["deletion_matrix"] = np.asarray(features.pop("deletion_matrix_int"), dtype=np.float32)
        if "deletion_matrix_int_all_seq" in features:
            features["deletion_matrix_all_seq"] = np.asarray(
                features.pop("deletion_matrix_int_all_seq"), dtype=np.float32
            )
        features["deletion_mean"] = np.mean(features["deletion_matrix"], axis=0)

        # Add all_atom_mask and dummy (all-zero) all_atom_positions based on aatype.
        atom_mask = rc.STANDARD_ATOM_MASK[features["aatype"]]
        features["all_atom_mask"] = atom_mask.astype(dtype=np.float32)
        features["all_atom_positions"] = np.zeros(list(atom_mask.shape) + [3], dtype=np.float32)

        # Add assembly_num_chains.
        features["assembly_num_chains"] = np.asarray(num_chains)

    # Add entity_mask (entity_id == 0 marks padding).
    for features in all_chain_features.values():
        features["entity_mask"] = (features["entity_id"] != 0).astype(np.int32)