# Copyright 2021 DeepMind Technologies Limited
# Copyright 2022 AlQuraishi Laboratory
# Copyright 2023 NVIDIA CORPORATION
# Copyright 2024 DeepFold Team
from typing import List, Sequence
import numpy as np
from deepfold.common import protein
from deepfold.common import residue_constants as rc
from deepfold.data.search.mmcif import zero_center_atom_positions
from deepfold.data.search.parsers import parse_a3m, parse_hhr, parse_hmmsearch_sto
from deepfold.data.search.templates import TemplateHit, TemplateHitFeaturizer
from deepfold.utils.datetime_utils import datetime_from_string
[docs]
def create_sequence_features(sequence: str, domain_name: str) -> dict:
    """Build the per-sequence input features for a single chain.

    Args:
        sequence: Amino-acid sequence of the chain, one letter per residue.
        domain_name: Identifier stored, UTF-8 encoded, under ``"domain_name"``.

    Returns:
        A dict with the following keys:

        - ``aatype``: ``sequence`` encoded by ``rc.sequence_to_onehot`` with
          ``rc.restype_order_with_x``; unknown letters are mapped to X.
        - ``between_segment_residues``: int32 zeros, one per residue.
        - ``domain_name``: 1-element object array holding the encoded name.
        - ``residue_index``: int32 ``0 .. len(sequence) - 1``.
        - ``seq_length``: ``len(sequence)`` repeated once per residue.
        - ``sequence``: 1-element object array holding the encoded sequence.
    """
    num_res = len(sequence)
    onehot = rc.sequence_to_onehot(
        sequence=sequence,
        mapping=rc.restype_order_with_x,
        map_unknown_to_x=True,
    )
    encoded_name = domain_name.encode("utf-8")
    encoded_sequence = sequence.encode("utf-8")
    return {
        "aatype": onehot,
        "between_segment_residues": np.zeros(shape=(num_res,), dtype=np.int32),
        "domain_name": np.array([encoded_name], dtype=np.object_),
        "residue_index": np.arange(num_res, dtype=np.int32),
        "seq_length": np.full(shape=(num_res,), fill_value=num_res, dtype=np.int32),
        "sequence": np.array([encoded_sequence], dtype=np.object_),
    }
[docs]
def create_mmcif_features(
    mmcif_dict: dict,
    author_chain_id: str,
    zero_center: bool = False,
) -> dict:
    """Build structure features for one chain of a parsed mmCIF entry.

    Args:
        mmcif_dict: Parsed mmCIF data. The keys read are ``"pdb_id"``,
            ``"sequences"``, ``"atoms"``, ``"resolution"`` and
            ``"release_date"``. ``"release_date"`` must support ``.encode``
            (presumably a str).
        author_chain_id: Chain key used to index ``"sequences"`` and ``"atoms"``.
        zero_center: If True, atom positions are passed through
            ``zero_center_atom_positions`` before conversion.

    Returns:
        Sequence features from ``create_sequence_features`` (domain name is
        ``pdb_id + author_chain_id``). It is extended with float32
        ``all_atom_positions`` / ``all_atom_mask``, ``resolution``,
        ``release_date`` and an ``is_distillation`` flag of 0.0.
    """
    domain_name = mmcif_dict["pdb_id"] + author_chain_id
    chain_sequence = mmcif_dict["sequences"][author_chain_id]
    features = dict(create_sequence_features(sequence=chain_sequence, domain_name=domain_name))

    chain_atoms = mmcif_dict["atoms"][author_chain_id]
    positions = chain_atoms["all_atom_positions"]
    mask = chain_atoms["all_atom_mask"]
    if zero_center:
        positions = zero_center_atom_positions(all_atom_positions=positions, all_atom_mask=mask)

    features["all_atom_positions"] = positions.astype(np.float32)
    features["all_atom_mask"] = mask.astype(np.float32)
    features["resolution"] = np.array([mmcif_dict["resolution"]], dtype=np.float32)
    encoded_date = mmcif_dict["release_date"].encode("utf-8")
    features["release_date"] = np.array([encoded_date], dtype=np.object_)
    # Structures loaded from mmCIF are never distillation samples.
    features["is_distillation"] = np.array(0.0, dtype=np.float32)
    return features
def _aatype_to_str_sequence(aatype: Sequence[int]) -> str:
    """Convert residue-type indices to a one-letter amino-acid string.

    Each index is looked up in ``rc.restypes_with_x``. Iterating the input
    directly (instead of indexing with ``range(len(...))``) also accepts any
    iterable of integer indices, e.g. a 1-D NumPy array.

    Args:
        aatype: Residue-type indices into ``rc.restypes_with_x``.

    Returns:
        The concatenated one-letter residue codes.
    """
    return "".join(rc.restypes_with_x[restype] for restype in aatype)
[docs]
def create_protein_features(
    protein_object: protein.Protein,
    description: str,
    is_distillation: bool = False,
) -> dict:
    """Build structure features from a ``protein.Protein`` object.

    Args:
        protein_object: Source structure. Its ``aatype``, ``atom_positions``
            and ``atom_mask`` attributes are read.
        description: Name stored as the ``domain_name`` feature.
        is_distillation: Value (as 1.0 / 0.0) of the ``is_distillation`` feature.

    Returns:
        Sequence features (sequence reconstructed from ``aatype``), plus
        float32 ``all_atom_positions`` / ``all_atom_mask``, a ``resolution``
        of 0.0 and the ``is_distillation`` flag.
    """
    residue_types = list(protein_object.aatype)
    one_letter_sequence = _aatype_to_str_sequence(residue_types)
    features = dict(create_sequence_features(sequence=one_letter_sequence, domain_name=description))

    features["all_atom_positions"] = protein_object.atom_positions.astype(np.float32)
    features["all_atom_mask"] = protein_object.atom_mask.astype(np.float32)
    # No experimental resolution is available for a Protein object.
    features["resolution"] = np.array([0.0]).astype(np.float32)
    distillation_flag = 1.0 if is_distillation else 0.0
    features["is_distillation"] = np.array(distillation_flag).astype(np.float32)
    return features
[docs]
def create_pdb_features(
    protein_object: protein.Protein,
    description: str,
    is_distillation: bool = True,
    confidence_threshold: float = 50.0,
) -> dict:
    """Build features from a Protein, optionally masking low-confidence residues.

    Args:
        protein_object: Source structure. When ``is_distillation`` is True,
            its ``b_factors`` are compared against ``confidence_threshold``.
            NOTE(review): presumably b_factors hold per-atom pLDDT for
            predicted structures — confirm against the loader.
        description: Name stored as the ``domain_name`` feature.
        is_distillation: Marks the sample as a distillation example. It sets
            the ``is_distillation`` feature and enables confidence masking.
        confidence_threshold: A residue whose b-factors never exceed this
            value has all of its atoms removed from ``all_atom_mask``.

    Returns:
        The feature dict produced by ``create_protein_features``, with
        ``all_atom_mask`` masked as described above.
    """
    # Forward the caller's flag so the ``is_distillation`` feature matches it
    # (it used to be hard-coded to True, ignoring ``is_distillation=False``).
    pdb_feats = create_protein_features(protein_object, description, is_distillation=is_distillation)
    if is_distillation:
        # Reduce over the last axis: a residue is kept if any entry exceeds the threshold.
        high_confidence = np.any(protein_object.b_factors > confidence_threshold, axis=-1)
        # ``all_atom_mask`` is a fresh astype() copy, so in-place masking does
        # not touch ``protein_object``.
        pdb_feats["all_atom_mask"] *= high_confidence[..., None]
    return pdb_feats
[docs]
def create_template_features(
    sequence: str,  # Query sequence
    template_hits: Sequence[TemplateHit],
    template_hit_featurizer: TemplateHitFeaturizer,
    max_release_date: str,
    pdb_id: str | None = None,  # Optional query pdb id
    sort_by_sum_probs: bool = True,
    shuffling_seed: int | None = None,
) -> dict:
    """Featurize template hits for a query sequence.

    Args:
        sequence: Query sequence.
        template_hits: Hits to featurize; copied into a list before use.
        template_hit_featurizer: Object whose ``get_template_features`` does the work.
        max_release_date: Date string in ``%Y-%m-%d`` format. It is parsed
            with ``datetime_from_string`` and passed as ``max_template_date``.
        pdb_id: Optional PDB id of the query, forwarded as ``query_pdb_id``.
        sort_by_sum_probs: Forwarded unchanged to the featurizer.
        shuffling_seed: Forwarded unchanged to the featurizer.

    Returns:
        Whatever ``template_hit_featurizer.get_template_features`` returns.
    """
    cutoff_date = datetime_from_string(max_release_date, r"%Y-%m-%d")
    hits = list(template_hits)
    return template_hit_featurizer.get_template_features(
        query_sequence=sequence,
        template_hits=hits,
        max_template_date=cutoff_date,
        query_pdb_id=pdb_id,
        sort_by_sum_probs=sort_by_sum_probs,
        shuffling_seed=shuffling_seed,
    )
[docs]
def create_template_features_from_hhr_string(
    sequence: str,
    hhr_string: str,
    template_hit_featurizer: TemplateHitFeaturizer,
    release_date: str,
    pdb_id: str | None = None,
    shuffling_seed: int | None = None,
) -> dict:
    """Parse ``hhr_string`` with ``parse_hhr`` and featurize the resulting hits.

    Delegates to ``create_template_features``, with ``release_date`` as the
    maximum template release date. ``sort_by_sum_probs`` keeps that
    function's default.
    """
    return create_template_features(
        sequence=sequence,
        template_hits=parse_hhr(hhr_string),
        template_hit_featurizer=template_hit_featurizer,
        max_release_date=release_date,
        pdb_id=pdb_id,
        shuffling_seed=shuffling_seed,
    )
[docs]
def create_template_features_from_hmmsearch_sto_string(
    sequence: str,
    sto_string: str,
    template_hit_featurizer: TemplateHitFeaturizer,
    release_date: str,
    pdb_id: str | None = None,
    shuffling_seed: int | None = None,
) -> dict:
    """Parse hmmsearch Stockholm output and featurize the resulting hits.

    ``parse_hmmsearch_sto`` receives both the query ``sequence`` and
    ``sto_string``. The hits are then passed to ``create_template_features``,
    with ``release_date`` as the maximum template release date.
    ``sort_by_sum_probs`` keeps that function's default.
    """
    return create_template_features(
        sequence=sequence,
        template_hits=parse_hmmsearch_sto(sequence, sto_string),
        template_hit_featurizer=template_hit_featurizer,
        max_release_date=release_date,
        pdb_id=pdb_id,
        shuffling_seed=shuffling_seed,
    )
[docs]
def create_msa_features(
a3m_strings: List[str],
sequence: str | None = None,
use_identifiers: bool = False,
) -> dict:
msas = []
deletion_matrices = []
descriptions = []
for a3m_string in a3m_strings:
if not a3m_string:
continue
msa, deletion_matrix, desc = parse_a3m(a3m_string)
msas.append(msa)
deletion_matrices.append(deletion_matrix)
descriptions.append(desc)
if len(msas) == 0:
assert sequence is not None
msas.append([sequence])
deletion_matrices.append([[0 for _ in sequence]])
descriptions.append([""])
int_msa = []
deletion_matrix = []
identifiers = []
seen_sequences = set()
for msa_index, msa in enumerate(msas):
if not msa:
raise ValueError(f"MSA {msa_index} must contain at least one sequence.")
for sequence_index, sequence in enumerate(msa):
if sequence in seen_sequences:
continue
seen_sequences.add(sequence)
int_msa.append([rc.HHBLITS_AA_TO_ID[res] for res in sequence])
deletion_matrix.append(deletion_matrices[msa_index][sequence_index])
identifiers.append(descriptions[msa_index][sequence_index])
num_res = len(msas[0][0]) # First sequence must be the query sequence.
num_alignments = len(int_msa)
msa_features = {}
msa_features["deletion_matrix_int"] = np.array(deletion_matrix, dtype=np.int32)
msa_features["msa"] = np.array(int_msa, dtype=np.int32)
msa_features["num_alignments"] = np.array([num_alignments] * num_res, dtype=np.int32)
if use_identifiers:
msa_features["msa_identifiers"] = np.array(identifiers, dtype=np.object_)
return msa_features