# Source code for deepfold.data.pdbx_parsing

# Copyright 2024 DeepFold Team


"""Protein Data Bank."""


import collections
import functools
import gzip
import logging
import math
import os
import sys
import warnings
from dataclasses import dataclass, field
from io import BytesIO, StringIO
from itertools import product
from pathlib import Path
from typing import Any, Dict, Iterable, List, Mapping, Sequence, Tuple

import numpy as np
import requests
from Bio.Data.PDBData import protein_letters_3to1_extended as protein_letters_3to1
from Bio.PDB.MMCIF2Dict import MMCIF2Dict
from Bio.PDB.PDBExceptions import PDBConstructionException, PDBConstructionWarning
from Bio.PDB.Structure import Structure
from Bio.PDB.StructureBuilder import StructureBuilder

from deepfold.common import residue_constants as rc
from deepfold.data.errors import PDBxConstructionError, PDBxConstructionWarning, PDBxError, PDBxWarning
from deepfold.data.monomer import build_atom_map
from deepfold.utils.file_utils import read_text

logger = logging.getLogger(__name__)


class CategoryNotFoundError(PDBxError):
    """Raised when a required mmCIF category cannot be found in a record."""
# Placeholder tokens mmCIF uses for unassigned/unknown values.
UNASSIGNED = {".", "?"}

# Alphabet used to synthesize insertion codes when a modified residue is
# expanded into several parent residues (see `get_chain_features`).
LATIN = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"

# Fallback (observed residue, parent residue) -> atom-name translation tables
# applied when expanding modified residues in `get_chain_features`.
DEFAULT_ATOM_MAP = {
    ("MSE", "MET"): {"N": "N", "CA": "CA", "C": "C", "O": "O", "OXT": "OXT", "CB": "CB", "CG": "CG", "SE": "SD", "CE": "CE"},
    ("UNK", "UNK"): {"N": "N", "CA": "CA", "C": "C", "O": "O", "OXT": "OXT", "CB": "CB"},
}  # TODO: Implement ATOM_MAP
def loop_to_list(prefix: str, dic: Mapping[str, Sequence[str]]) -> Sequence[Mapping[str, str]]:
    """Extract an mmCIF loop category as a list of row dictionaries.

    Args:
        prefix: Category prefix including the trailing dot, e.g. "_entity_poly_seq.".
        dic: Flat mmCIF mapping from tag name to its column of values.

    Returns:
        One mapping per loop row, keyed by the full tag names. Empty when no
        tag starts with ``prefix``.

    Raises:
        PDBxConstructionError: If the matching columns differ in length.
    """
    cols = [key for key in dic if key.startswith(prefix)]
    data = [dic[key] for key in cols]
    # Every column of a well-formed loop must have the same number of rows.
    if any(len(column) != len(data[0]) for column in data):
        raise PDBxConstructionError(f"Not all loops are the same length: {cols}")
    return [dict(zip(cols, row)) for row in zip(*data)]
def loop_to_dict(prefix: str, index: str, dic: Mapping[str, Sequence[str]]) -> Mapping[str, Mapping[str, str]]:
    """Extract an mmCIF loop category as a dict keyed by one of its tags.

    Args:
        prefix: Category prefix, e.g. "_chem_comp.".
        index: Full tag whose value keys each row, e.g. "_chem_comp.id".
        dic: Flat mmCIF mapping from tag name to its column of values.

    Returns:
        Mapping from each row's ``index`` value to the row itself.
    """
    return {row[index]: row for row in loop_to_list(prefix, dic)}
# _entity_poly_seq
@dataclass(frozen=True)
class Monomer:
    """One row of the `_entity_poly_seq` loop: a monomer of a polymer entity."""

    entity_id: str  # `_entity_poly_seq.entity_id` of the owning entity.
    num: int  # `_entity_poly_seq.num`, position within the entity sequence.
    mon_id: str  # `_entity_poly_seq.mon_id`, three-letter component id.
@dataclass(frozen=True)
class AtomSite:
    """Selected columns of one `_atom_site` record.

    NOTE(review): not instantiated anywhere in this module; field semantics
    are inferred from their names — verify against external callers.
    """

    residue_name: str
    author_chain_id: str
    label_chain_id: str
    author_seq_num: str
    label_seq_num: int
    insertion_code: str
    hetatm_atom: str
    model_num: int
# Used to map SEQRES index to a residue in the structure.
@dataclass(frozen=True)
class ResiduePosition:
    """Location of an observed residue, as built in `PDBxParser._build_structure`."""

    label_asym_id: str  # mmCIF chain id (`_atom_site.label_asym_id`).
    auth_asym_id: str  # Author chain id (`_atom_site.auth_asym_id`).
    label_seq_id: int  # `_atom_site.label_seq_id` converted to int.
    auth_seq_id: int  # `_atom_site.auth_seq_id` converted to int.
    insertion_code: str  # Insertion code; " " when unassigned ("." / "?").
@dataclass(frozen=True)
class ResidueAtPosition:
    """One SEQRES slot: an observed residue or a missing placeholder.

    `position` is None for residues absent from the structure
    (see `PDBxParser._handle_missing_residues`).
    """

    position: ResiduePosition | None
    name: str  # Three-letter residue name.
    is_missing: bool  # True when no atoms were observed for this slot.
    hetflag: str  # " " for ATOM, "H" for HETATM, "W" for waters.
    residue_index: int  # Sequence index (label_seq_id).
    ordinal: int = 0  # For hetero-residues
# _pdbx_struct_mod_residue # NOTE: PDB_model_num not used
@dataclass(frozen=True)
class ModResidue:
    """One `_pdbx_struct_mod_residue` record (duplicates merged).

    `parent_comp_id` accumulates every parent component reported for the same
    (chain, seq id, insertion code) key — see `PDBxParser._get_mod_residue`.
    """

    label_asym_id: str
    label_comp_id: str
    label_seq_id: int
    insertion_code: str
    parent_comp_id: List[str] = field(default_factory=list)
# _pdbx_struct_oper_list
@dataclass(frozen=True)
class Op:
    """A rigid-body operation from the `_pdbx_struct_oper_list` category.

    Attributes:
        op_id: Operation identifier (kept as the original string).
        is_identity: True for the "identity operation" type.
        rot: 3x3 rotation matrix (identity by default).
        trans: Length-3 translation vector (zeros by default).
    """

    op_id: str
    is_identity: bool = True
    # Use default_factory (not default=ndarray): a plain ndarray default is a
    # single mutable array shared by every instance, and dataclasses rejects
    # unhashable defaults with ValueError on Python >= 3.11.
    rot: np.ndarray = field(default_factory=lambda: np.identity(3))
    trans: np.ndarray = field(default_factory=lambda: np.zeros(3))

    def to_dict(self):
        """Return a JSON-serializable representation of the operation."""
        return {
            "is_identity": self.is_identity,
            "rot": self.rot.tolist(),
            "trans": self.trans.tolist(),
        }
# _pdbx_struct_assembly_gen
@dataclass(frozen=True)
class AssemblyGenerator:
    """One `_pdbx_struct_assembly_gen` row: chains and operator sequences.

    NOTE(review): `assembly_id` is annotated `str` but
    `PDBxParser._get_assemblies` passes `int(assem_id)` — confirm the
    intended type.
    """

    assembly_id: str
    asym_id_list: List[str]  # Chain ids, filtered to valid (protein) chains.
    # Result of `_parse_oper_expr`: each inner list is one chain of op ids.
    oper_sequence: List[List[str]]
@dataclass(frozen=True)
class PDBxHeader:
    """Header metadata extracted by `PDBxParser._get_header`."""

    # Required
    entry_id: str  # _entry.id
    deposition_date: str  # _pdbx_database_status.recvd_initial_deposition_date
    method: str  # _exptl.method
    # 0.0 in practice when no resolution tag is found (e.g. NMR entries).
    resolution: float | None = None
@dataclass(frozen=True)
class PDBxObject:
    """Immutable result of `PDBxParser.parse`."""

    header: PDBxHeader
    structure: Structure  # Biopython structure built from `_atom_site`.
    chain_ids: List[str]  # Sorted label asym ids of the valid protein chains.
    # NOTE(review): `_process_modres` keys entries by (chain_id, seq_id)
    # pairs, not 3-tuples — confirm this annotation.
    modres: Dict[Tuple[str, str, str], ModResidue]
    label_to_auth: Dict[str, str]  # Label asym id -> author chain id.
    auth_to_label: Dict[str, str]  # Inverse of `label_to_auth`.
    chain_to_entity: Dict[str, str]
    # chain_to_seqres: Dict[str, str]
    chain_to_structure: Dict[str, Dict[int, ResidueAtPosition]]
    assemblies: Dict[int, List[AssemblyGenerator]]
    # NOTE(review): `_get_assemblies` keeps operation ids as strings —
    # confirm the `int` key annotation.
    operations: Dict[int, Op]
@dataclass(frozen=True)
class ParsingResult:
    """Outcome of `PDBxParser.parse`: a parsed object and collected errors."""

    mmcif_object: PDBxObject | None = None  # None when parsing failed.
    errors: List[PDBxError] = field(default_factory=list)
class PDBxParser:
    """Parse a PDBx/mmCIF record.

    Create an instance and call `parse(mmcif_string)`; intermediate results
    are stored on the instance and bundled into a `PDBxObject`.
    """

    def __init__(self) -> None:
        # Raw tag -> column mapping produced by Biopython's MMCIF2Dict.
        self.mmcif_dict: Dict[str, List[str]]
        # Header
        self.header = None
        # Structure
        self.structure: Structure = None
        # Maps mmCIF chain ids to chain ids used the authors.
        self.label_to_auth = None
        # Maps index into sequence to ResidueAtPosition.
        self.chain_to_structure = None
        # self.chain_to_seqres = None
        # self.entity_to_chains = None
        self.chain_to_entity = None
        self.valid_chains = None
        # self.modified_residues = None
        # Assembly
        self.assemblies: Dict[int, List[AssemblyGenerator]] = None
        self.operations: Dict[int, Op] = None

    def parse(self, mmcif_string: str, catch_all_errors: bool = False) -> ParsingResult:
        """Parse an mmCIF string into a `PDBxObject`.

        Args:
            mmcif_string: Full text of a PDBx/mmCIF file.
            catch_all_errors: When True, exceptions are collected into the
                returned `ParsingResult.errors` instead of being re-raised.

        Returns:
            `ParsingResult` with `mmcif_object` set on success.
        """
        errors = []
        try:
            self.mmcif_dict = MMCIF2Dict(StringIO(mmcif_string))
            self._parse()
            self._handle_missing_residues()
            self._process_modres()
            pdbx_object = PDBxObject(
                header=self.header,
                structure=self.structure,
                chain_ids=sorted(list(self.valid_chains.keys())),
                modres=self.modified_residues,
                label_to_auth=self.label_to_auth,
                auth_to_label={y: x for x, y in self.label_to_auth.items()},
                chain_to_entity=self.chain_to_entity,
                chain_to_structure=self.chain_to_structure,
                assemblies=self.assemblies,
                operations=self.operations,
            )
            return ParsingResult(mmcif_object=pdbx_object, errors=errors)
        except Exception as e:
            errors.append(e)
            if not catch_all_errors:
                raise
        return ParsingResult(errors=errors)

    # NOTE: Do before run `parse()`
    def inject(parents_table: Dict[str, Dict[str, Any]] | None = None):
        # NOTE(review): unimplemented stub, and the signature lacks `self`
        # (or a @staticmethod decorator) — confirm the intended API.
        pass

    def _parse(self):
        """Populate header, chains, structure, modres and assembly tables."""
        with warnings.catch_warnings():
            # Biopython's StructureBuilder warns about disorder etc.
            warnings.filterwarnings("ignore", category=PDBConstructionWarning)
            self.header = self._get_header(self.mmcif_dict)
            self._get_protein_chains()  # valid_chains, entity_to_chains
            self.structure, self.label_to_auth, self.chain_to_structure = self._build_structure(self.mmcif_dict)
            self._get_mod_residue()  # modified_residues
            self.assemblies, self.operations = self._get_assemblies(self.mmcif_dict, self.valid_chains)

    def _handle_missing_residues(self):
        """Fill SEQRES slots absent from the structure with placeholders.

        Chains with no structure at all are removed from `valid_chains`.
        """
        missing_chains = []
        for chain_id, seq_info in self.valid_chains.items():
            try:
                current_mapping = self.chain_to_structure[chain_id]
            except KeyError as key:
                logger.info(f"Chain {chain_id} has no structure. Remove...")
                missing_chains.append(chain_id)
                continue
            # Add missing residue information to seq_to_structure_mappings.
            for monomer in seq_info:
                idx = monomer.num
                if idx not in current_mapping:
                    current_mapping[idx] = ResidueAtPosition(
                        position=None,
                        name=monomer.mon_id,
                        is_missing=True,
                        hetflag=" ",
                        residue_index=idx,
                    )
        for chain_id in missing_chains:
            del self.valid_chains[chain_id]

    # Filter modres records which are used.
    def _process_modres(self):
        """Keep modres entries matching a valid chain and its sequence.

        Re-keys `modified_residues` from (chain, seq, icode) triples to
        (chain, seq) pairs.
        """
        modres_processed = {}
        for chain_id, seq_info in self.valid_chains.items():
            seq_info = {m.num: m for m in seq_info}
            for value in self.modified_residues.values():
                if (value.label_asym_id == chain_id) and (value.label_comp_id == seq_info[value.label_seq_id].mon_id):
                    modres_processed[(chain_id, value.label_seq_id)] = value
        self.modified_residues = modres_processed

    @staticmethod
    def _get_header(mmcif_dict):
        """Extract entry id, deposition date, method and resolution.

        Resolution stays at the 0.0 default when none of the candidate
        tags holds a usable value.
        """
        header = {"entry_id": "", "deposition_date": "", "method": "", "resolution": 0.0}

        def update_header_entry(target_key, keys):
            # Take the first usable (non-"?"/".") value among candidate tags.
            md = mmcif_dict
            for key in keys:
                val = md.get(key)
                try:
                    item = val[0]
                except (TypeError, IndexError):
                    continue
                if item != "?" and item != ".":
                    header[target_key] = item
                    break

        update_header_entry("entry_id", ["_entry_id", "_exptl.entry_id", "_struct.entry_id"])
        update_header_entry("deposition_date", ["_pdbx_database_status.recvd_initial_deposition_date"])
        update_header_entry("method", ["_exptl.method"])
        update_header_entry("resolution", ["_refine.ls_d_res_high", "_refine_hist.d_res_high", "_em_3d_reconstruction.resolution"])
        # NOTE: NMR result has empty resolution.
        header["resolution"] = float(header["resolution"])
        return PDBxHeader(**header)

    def _get_protein_chains(self, only_valid: bool = True):
        """Identify peptide-bearing entities and their chains.

        Sets `entity_to_chains`, `chain_to_entity` and `valid_chains`.
        NOTE(review): `only_valid` is never read — confirm it is vestigial.
        """
        mmcif_dict = self.mmcif_dict
        # _entity_poly
        try:
            entity_poly_seqs = loop_to_list("_entity_poly_seq.", mmcif_dict)
        except KeyError:
            # NOTE(review): `loop_to_list` never raises KeyError (it returns
            # an empty list for absent categories), so this guard looks dead.
            raise PDBxConstructionError("Cannot find `_entity_poly`")
        polymers = collections.defaultdict(list)
        for entity_poly_seq in entity_poly_seqs:
            polymers[entity_poly_seq["_entity_poly_seq.entity_id"]].append(
                Monomer(
                    mon_id=entity_poly_seq["_entity_poly_seq.mon_id"],
                    num=int(entity_poly_seq["_entity_poly_seq.num"]),
                    entity_id=entity_poly_seq["_entity_poly_seq.entity_id"],
                )
            )
        # Get chemical compositions.
        chem_comps = loop_to_dict("_chem_comp.", "_chem_comp.id", mmcif_dict)
        # Get chains information for each entity.
        struct_asyms = loop_to_list("_struct_asym.", mmcif_dict)
        entity_to_chains = collections.defaultdict(list)
        for struct_asym in struct_asyms:
            chain_id = struct_asym["_struct_asym.id"]
            entity_id = struct_asym["_struct_asym.entity_id"]
            entity_to_chains[entity_id].append(chain_id)
        # Identify and return the valid protein chains.
        valid_chains = {}
        for entity_id, seq_info in polymers.items():
            chain_ids = entity_to_chains[entity_id]
            # Reject polymers without any peptide-like components, such as DNA/RNA.
            if any(["peptide" in chem_comps[monomer.mon_id]["_chem_comp.type"].lower() for monomer in seq_info]):
                for chain_id in chain_ids:
                    valid_chains[chain_id] = seq_info
        chain_to_entity = {}
        for k, v in entity_to_chains.items():
            for x in v:
                if x in valid_chains:
                    chain_to_entity[x] = k
        self.entity_to_chains = entity_to_chains
        self.chain_to_entity = chain_to_entity
        self.valid_chains = valid_chains

    def _get_mod_residue(self):
        """Collect `_pdbx_struct_mod_residue` rows for valid chains.

        Rows sharing the same (chain, seq id, insertion code) key are merged
        by appending their `parent_comp_id` values.
        """
        mod_residues = loop_to_list("_pdbx_struct_mod_residue.", self.mmcif_dict)
        modified_residues = {}
        for modres in mod_residues:
            asym_id = modres["_pdbx_struct_mod_residue.label_asym_id"]
            seq_id = modres["_pdbx_struct_mod_residue.label_seq_id"]
            comp_id = modres["_pdbx_struct_mod_residue.label_comp_id"]
            parent_comp_id = modres["_pdbx_struct_mod_residue.parent_comp_id"]
            insertion_code = modres["_pdbx_struct_mod_residue.PDB_ins_code"]
            if asym_id in {".", "?"}:
                asym_id = " "
            if asym_id not in self.valid_chains:
                logger.debug(f"Skip {asym_id}:{seq_id}{insertion_code} {comp_id}...")
                continue
            seq_id = int(seq_id)
            if insertion_code in {".", "?"}:
                insertion_code = " "
            key = (asym_id, seq_id, insertion_code)
            if key in modified_residues:
                # Duplicate key: merge by extending the parent component list.
                row = dict(vars(modified_residues[key]))
                row["parent_comp_id"].append(parent_comp_id)
                row = ModResidue(**row)
            else:
                row = ModResidue(
                    label_asym_id=asym_id,
                    label_comp_id=comp_id,
                    label_seq_id=seq_id,
                    insertion_code=insertion_code,
                    parent_comp_id=[parent_comp_id],
                )
            modified_residues[key] = row
        self.modified_residues = modified_residues

    @staticmethod
    def _build_structure(mmcif_dict):
        """Build a Biopython `Structure` from the `_atom_site` loop.

        Returns:
            A triple ``(structure, label_to_author_chain_id,
            seq_to_structure_mappings)`` where the last maps label chain id
            to ``{label_seq_id: ResidueAtPosition}``.

        Raises:
            PDBxConstructionError: On malformed model numbers, B factors or
                occupancies.
        """
        structure_builder = StructureBuilder()
        atom_serial_list = mmcif_dict["_atom_site.id"]
        atom_id_list = mmcif_dict["_atom_site.label_atom_id"]
        residue_id_list = mmcif_dict["_atom_site.label_comp_id"]
        try:
            element_list = mmcif_dict["_atom_site.type_symbol"]
        except KeyError:
            element_list = None
        auth_chain_id_list = mmcif_dict["_atom_site.auth_asym_id"]
        chain_id_list = mmcif_dict["_atom_site.label_asym_id"]
        x_list = [float(x) for x in mmcif_dict["_atom_site.Cartn_x"]]
        y_list = [float(x) for x in mmcif_dict["_atom_site.Cartn_y"]]
        z_list = [float(x) for x in mmcif_dict["_atom_site.Cartn_z"]]
        alt_list = mmcif_dict["_atom_site.label_alt_id"]
        icode_list = mmcif_dict["_atom_site.pdbx_PDB_ins_code"]
        b_factor_list = mmcif_dict["_atom_site.B_iso_or_equiv"]
        occupancy_list = mmcif_dict["_atom_site.occupancy"]
        fieldname_list = mmcif_dict["_atom_site.group_PDB"]
        try:
            serial_list = [int(n) for n in mmcif_dict["_atom_site.pdbx_PDB_model_num"]]
        except KeyError:
            # No model number column
            serial_list = None
        except ValueError:
            # Invalid model number (malformed file)
            raise PDBxConstructionError("Invalid model number") from None
        auth_seq_id_list = mmcif_dict["_atom_site.auth_seq_id"]
        seq_id_list = mmcif_dict["_atom_site.label_seq_id"]
        # Now loop over atoms and build the structure
        label_to_author_chain_id = {}
        seq_to_structure_mappings = {}
        current_chain_id = None
        current_residue_id = None
        current_resname = None
        structure_builder.init_structure("")
        structure_builder.init_seg(" ")
        current_model_id = -1
        current_serial_id = -1
        for i in range(len(atom_serial_list)):
            # set the line_counter for 'ATOM' lines only and not
            # as a global line counter found in the PDBParser()
            structure_builder.set_line_counter(i)
            # Try coercing serial to int, for compatibility with PDBParser
            # But do not quit if it fails. mmCIF format specs allow strings.
            try:
                serial = int(atom_serial_list[i])
            except ValueError:
                serial = atom_serial_list[i]
                warnings.warn(f"Some atom serial numbers ({serial}) are not numerical", PDBConstructionWarning)
            x = x_list[i]
            y = y_list[i]
            z = z_list[i]
            resname = residue_id_list[i]
            chain_id = chain_id_list[i]
            auth_chain_id = auth_chain_id_list[i]
            altloc = alt_list[i]
            if altloc in UNASSIGNED:
                altloc = " "
            resseq = seq_id_list[i]
            if resseq == ".":
                # Non-existing residue ID
                try:
                    msg_resseq = mmcif_dict["_atom_site.auth_seq_id"][i]
                    msg = f"Non-existing residue ID in chain '{chain_id}', residue '{msg_resseq}'"
                except (KeyError, IndexError):
                    msg = f"Non-existing residue ID in chain '{chain_id}'"
                warnings.warn(msg, PDBxConstructionWarning)
                continue
            int_resseq = int(resseq)
            icode = icode_list[i]
            if icode in UNASSIGNED:
                icode = " "
            name = atom_id_list[i]
            # occupancy & B factor
            try:
                tempfactor = float(b_factor_list[i])
            except ValueError:
                raise PDBxConstructionError("Invalid or missing B factor") from None
            try:
                occupancy = float(occupancy_list[i])
            except ValueError:
                raise PDBxConstructionError("Invalid or missing occupancy") from None
            fieldname = fieldname_list[i]
            if fieldname == "HETATM":
                if resname == "HOH" or resname == "WAT":
                    hetatm_flag = "W"
                else:
                    hetatm_flag = "H"
            else:
                hetatm_flag = " "
            # Record where this residue sits, keyed by label_seq_id.
            position = ResiduePosition(
                label_asym_id=chain_id,
                auth_asym_id=auth_chain_id,
                label_seq_id=int_resseq,
                auth_seq_id=int(auth_seq_id_list[i]),
                insertion_code=icode,
            )
            current = seq_to_structure_mappings.get(chain_id, {})
            current[int_resseq] = ResidueAtPosition(
                position=position,
                name=resname,
                is_missing=False,
                hetflag=hetatm_flag,
                residue_index=int_resseq,
            )
            seq_to_structure_mappings[chain_id] = current
            resseq = (hetatm_flag, int_resseq, icode)
            if serial_list is not None:
                # model column exists; use it
                serial_id = serial_list[i]
                if current_serial_id != serial_id:
                    # if serial changes, update it and start new model
                    current_serial_id = serial_id
                    current_model_id += 1
                    structure_builder.init_model(current_model_id, current_serial_id)
                    current_chain_id = None
                    current_residue_id = None
                    current_resname = None
            else:
                # no explicit model column; initialize single model
                # NOTE(review): this mirrors Bio.PDB.MMCIFParser and re-inits
                # model -1 on every atom — confirm against upstream behavior.
                structure_builder.init_model(current_model_id)
            if current_chain_id != chain_id:
                current_chain_id = chain_id
                structure_builder.init_chain(current_chain_id)
                label_to_author_chain_id[chain_id] = auth_chain_id  # Author chain id
                current_residue_id = None
                current_resname = None
            if current_residue_id != resseq or current_resname != resname:
                current_residue_id = resseq
                current_resname = resname
                structure_builder.init_residue(resname, hetatm_flag, int_resseq, icode)
            coord = np.array((x, y, z), "f")
            element = element_list[i].upper() if element_list else None
            structure_builder.init_atom(
                name,
                coord,
                tempfactor,
                occupancy,
                altloc,
                name,
                serial_number=serial,
                element=element,
            )
        return structure_builder.get_structure(), label_to_author_chain_id, seq_to_structure_mappings

    @staticmethod
    def _get_assemblies(mmcif_dict, valid_chains):
        """Read operation matrices and assembly generators.

        Returns:
            A pair ``(generators, oper_list)``: assembly id (int) to the list
            of `AssemblyGenerator`s, and operation id (str) to `Op`.
        """
        oper_list_id_list = mmcif_dict["_pdbx_struct_oper_list.id"]
        oper_list_type_list = mmcif_dict["_pdbx_struct_oper_list.type"]
        oper_list_r11_list = mmcif_dict["_pdbx_struct_oper_list.matrix[1][1]"]
        oper_list_r12_list = mmcif_dict["_pdbx_struct_oper_list.matrix[1][2]"]
        oper_list_r13_list = mmcif_dict["_pdbx_struct_oper_list.matrix[1][3]"]
        oper_list_r21_list = mmcif_dict["_pdbx_struct_oper_list.matrix[2][1]"]
        oper_list_r22_list = mmcif_dict["_pdbx_struct_oper_list.matrix[2][2]"]
        oper_list_r23_list = mmcif_dict["_pdbx_struct_oper_list.matrix[2][3]"]
        oper_list_r31_list = mmcif_dict["_pdbx_struct_oper_list.matrix[3][1]"]
        oper_list_r32_list = mmcif_dict["_pdbx_struct_oper_list.matrix[3][2]"]
        oper_list_r33_list = mmcif_dict["_pdbx_struct_oper_list.matrix[3][3]"]
        oper_list_t1_list = mmcif_dict["_pdbx_struct_oper_list.vector[1]"]
        oper_list_t2_list = mmcif_dict["_pdbx_struct_oper_list.vector[2]"]
        oper_list_t3_list = mmcif_dict["_pdbx_struct_oper_list.vector[3]"]
        oper_list = {}
        # NOTE(review): `type` shadows the builtin inside this loop.
        for op_id, type, r11, r12, r13, r21, r22, r23, r31, r32, r33, x, y, z in zip(
            oper_list_id_list,
            oper_list_type_list,
            oper_list_r11_list,
            oper_list_r12_list,
            oper_list_r13_list,
            oper_list_r21_list,
            oper_list_r22_list,
            oper_list_r23_list,
            oper_list_r31_list,
            oper_list_r32_list,
            oper_list_r33_list,
            oper_list_t1_list,
            oper_list_t2_list,
            oper_list_t3_list,
        ):
            # op_id = int(op_id) # str
            if type == "identity operation":
                op = Op(op_id=op_id, is_identity=True)
            else:
                r11, r12, r13, r21, r22, r23, r31, r32, r33 = map(float, [r11, r12, r13, r21, r22, r23, r31, r32, r33])
                x, y, z = map(float, [x, y, z])
                rot = np.array([[r11, r12, r13], [r21, r22, r23], [r31, r32, r33]])
                vec = np.array([x, y, z])
                op = Op(op_id=op_id, is_identity=False, rot=rot, trans=vec)
            oper_list[op_id] = op
        assembly_id_list = mmcif_dict["_pdbx_struct_assembly_gen.assembly_id"]
        oper_expression_list = mmcif_dict["_pdbx_struct_assembly_gen.oper_expression"]
        asym_id_list_list = mmcif_dict["_pdbx_struct_assembly_gen.asym_id_list"]
        generators = collections.defaultdict(list)
        for assem_id, expr, asym_ids in zip(assembly_id_list, oper_expression_list, asym_id_list_list):
            assem_id = int(assem_id)
            asym_ids = asym_ids.split(",")
            # Restrict each generator to the valid (protein) chains.
            asym_ids = [aid for aid in asym_ids if aid in valid_chains]
            op_seq = _parse_oper_expr(expr)
            generators[assem_id].append(
                AssemblyGenerator(
                    assembly_id=assem_id,
                    asym_id_list=asym_ids,
                    oper_sequence=op_seq,
                )
            )
        return dict(generators), oper_list
def _parse_oper_expr(expr: str): # Remove whitespaces for easier parsing expr = expr.replace(" ", "") # Helper function to expand ranges like "1-4" into "1,2,3,4" def expand_range(r): start, end = map(int, r.split("-")) return list(map(str, range(start, end + 1))) # Helper function to compute cartesian product def cartesian_product(groups): # NOTE: Apply from right to left. # if len(groups) == 1: # return [groups[0][::-1]] return [list(r) for r in product(*groups[::-1])] # Tokenize and parse the expression tokens = expr.strip("()").split(")(") parsed_tokens = [] for token in tokens: if "-" in token: # Range detected ranges = token.split(",") expanded_ranges = [expand_range(r) if "-" in r else [r] for r in ranges] merged_ranges = [item for sublist in expanded_ranges for item in sublist] parsed_tokens.append(merged_ranges) else: # Single numbers or lists # parsed_tokens.append(list(map(int, token.split(",")))) parsed_tokens.append(token.split(",")) # Compute cartesian product if necessary return cartesian_product(parsed_tokens) ## Fetch PDB records
@functools.lru_cache(maxsize=16, typed=False)
def fetch_mmcif(rcsb_id: str) -> str:
    """Fetch mmCIF from RCSB."""
    rcsb_id = rcsb_id.lower()
    response = requests.get(
        f"https://files.rcsb.org/download/{rcsb_id}.cif.gz",
        headers={"Accept-Encoding": "gzip"},
    )
    if response.status_code != 200:
        raise RuntimeError(f"Cannot fetch '{rcsb_id}'")
    # The payload is a gzip-compressed mmCIF file; decompress it in memory.
    with gzip.GzipFile(mode="r", fileobj=BytesIO(response.content)) as handle:
        return handle.read().decode()
[docs] def read_mmcif( rcsb_id: str, mmcif_path: os.PathLike | None = None, ) -> str: """Fetch a mmCIF file from local directory or from the web.""" # Is valid ID? assert len(rcsb_id) == 4 # Use lowercase rcsb_id = rcsb_id.lower() dv = rcsb_id[1:3] mmcif_path = Path.cwd() if mmcif_path is None else Path(mmcif_path) mmcif_path = mmcif_path / dv / f"{rcsb_id}.cif.gz" if mmcif_path is not None and os.path.exists(mmcif_path): return read_text(mmcif_path) else: return fetch_mmcif(rcsb_id=rcsb_id)
# Parse PDB records
def get_fasta(mmcif_object: PDBxObject) -> str:
    """Render every chain of a parsed mmCIF object as FASTA text.

    Chains are emitted sorted by label asym id. Modified residues are
    expanded to their parent components before translation, and any residue
    without a single-letter code becomes "X". Each header line carries both
    the label and the author chain id.
    """
    entry_id = mmcif_object.header.entry_id.lower()
    chains = mmcif_object.chain_to_structure
    modres = mmcif_object.modres

    def one_letter(resname: str) -> str:
        code = protein_letters_3to1.get(resname, "X")
        return code if len(code) == 1 else "X"

    records = []
    for asym_id in sorted(chains):
        residues = chains[asym_id]
        names = []
        for seq_id in sorted(residues):
            mod = modres.get((asym_id, seq_id))
            if mod is not None:
                # Apply MODRES: substitute the parent component(s).
                names.extend(mod.parent_comp_id)
            else:
                names.append(residues[seq_id].name)
        sequence = "".join(one_letter(name) for name in names)
        auth_id = mmcif_object.label_to_auth[asym_id]
        records.append(f">{entry_id}_{asym_id} | {entry_id}_{auth_id}\n")
        records.append(f"{sequence}\n")
    return "".join(records)
def get_assemblies(mmcif_object: PDBxObject):
    """Summarize each biological assembly of a parsed mmCIF object.

    Returns:
        Dict keyed by "<entry_id>-<assembly_id>". Each value bundles header
        metadata with the assembly's generators, the operations they
        reference, the per-entity chain counts (oligomeric state) and the
        relevant label-to-author chain-id map.
    """
    header = mmcif_object.header
    common = {
        "entry_id": header.entry_id,
        "release_date": header.deposition_date,
        "method": header.method,
        "resolution": header.resolution,
    }
    ops_as_dicts = {op_id: op.to_dict() for op_id, op in mmcif_object.operations.items()}
    chain_to_entity = mmcif_object.chain_to_entity

    assemblies = {}
    for aid, assembly in mmcif_object.assemblies.items():
        name = f"{header.entry_id}-{aid}"
        chains_needed = []  # Asym units needed to construct the complex.
        ops_needed = []
        generators = []  # Sequences of operations.
        for generator in assembly:
            generators.append(
                {
                    "asym_id_list": generator.asym_id_list,
                    "oper_list": generator.oper_sequence,
                }
            )
            chains_needed.extend(generator.asym_id_list * len(generator.oper_sequence))
            ops_needed.extend(sum(generator.oper_sequence, []))
        # Deduplicate and order the referenced operation ids.
        ops_needed = sorted(set(ops_needed))
        # Oligomeric state: copies of each entity present in the assembly.
        entity_counter = collections.defaultdict(int)
        for chain_id, count in collections.Counter(chains_needed).items():
            entity_counter[chain_to_entity[chain_id]] += count
        assemblies[name] = {
            **common,
            "assembly_id": name,
            "assembly_num_chains": len(chains_needed),
            "generators": generators,
            "oper_list": {op_id: ops_as_dicts[op_id] for op_id in ops_needed},
            "oligomeric_state": dict(sorted((k, v) for k, v in entity_counter.items())),
            "label_to_auth": {x: y for x, y in mmcif_object.label_to_auth.items() if x in chains_needed},
        }
    return assemblies
def get_chain_features(
    mmcif_object: PDBxObject,
    model_num: int,
    chain_id: str,
    out_chain_id: str | None = None,
) -> Tuple[Dict[str, np.ndarray], Structure]:
    """Get atom positions and mask from a list of Biopython Residues.

    Builds fixed-size per-residue feature arrays for one chain of one model,
    expanding modified residues into one slot per parent component, and in
    parallel rebuilds a single-chain Biopython structure.

    Args:
        mmcif_object: Parsed record from `PDBxParser.parse`.
        model_num: 1-based model index into the parsed structure.
        chain_id: Label asym id of the chain to extract.
        out_chain_id: Chain id to use in the rebuilt structure (defaults to
            `chain_id`).

    Returns:
        A pair of (feature dict, rebuilt structure). Features are `aatype`,
        `residue_index`, `seq_length`, `seq_mask`, `all_atom_positions` and
        `all_atom_mask`, down-cast for storage.

    Raises:
        PDBxConstructionError: If the model does not contain exactly one
            chain with `chain_id`.
    """
    builder = StructureBuilder()
    builder.set_header(f"{mmcif_object.header}")
    builder.init_structure(structure_id=f"{mmcif_object.header.entry_id}_{chain_id}")
    builder.init_model(model_id=model_num)
    builder.init_seg(" ")
    builder.init_chain(chain_id=chain_id if out_chain_id is None else out_chain_id)
    # Advance the iterator `model_num` times: model_num is 1-based.
    model_iter = mmcif_object.structure.get_models()
    for _ in range(model_num):
        model = next(model_iter)
    relevant_chains = [c for c in model.get_chains() if c.id == chain_id]
    if len(relevant_chains) != 1:
        raise PDBxConstructionError(f"Expected exactly one chain in structure with id {chain_id}")
    chain = relevant_chains[0]
    # Modified residues of this chain, keyed by label seq id.
    modres = {m.label_seq_id: m for m in mmcif_object.modres.values() if m.label_asym_id == chain_id}
    # NOTE(review): assumes residue keys run contiguously so that the count
    # equals the largest index — confirm for chains not starting at 1.
    max_rid = len(mmcif_object.chain_to_structure[chain_id])
    # Each modified residue contributes one slot per parent component.
    num_res = max_rid + sum(len(m.parent_comp_id) - 1 for m in modres.values())
    aatype = np.full(num_res, rc.unk_restype_index, dtype=np.int64)
    residue_index = np.zeros(num_res, dtype=np.int64)
    seq_length = np.array(num_res, dtype=np.int64)
    seq_mask = np.ones(num_res, dtype=np.float32)
    all_atom_positions = np.zeros([num_res, rc.atom_type_num, 3])
    all_atom_mask = np.zeros([num_res, rc.atom_type_num], dtype=np.int64)
    rid = min(mmcif_object.chain_to_structure[chain_id].keys())  # First index of chain_to_structure
    aid = 0  # Array index must be 0
    shift = rid
    while rid <= max_rid:
        offset = 1  # Default
        rap = mmcif_object.chain_to_structure[chain_id][rid]

        def fill_features(res, array_index, builder, atom_map=None):
            # Writes atom coordinates/mask for one residue into the outer
            # arrays at `array_index` and mirrors the atoms into `builder`.
            pos = np.zeros([rc.atom_type_num, 3], dtype=np.float32)
            mask = np.zeros([rc.atom_type_num], dtype=np.float32)
            for atom in res.get_atoms():
                atom_name = atom.get_name()
                if atom_map is not None:
                    if atom_name in atom_map:
                        atom_name = atom_map[atom_name]
                    else:
                        # Atoms absent from the map are dropped.
                        continue  # TODO: Really?
                # Get only canonical atoms
                if atom_name in rc.atom_order.keys():
                    x, y, z = atom.get_coord()
                    pos[rc.atom_order[atom_name]] = [x, y, z]
                    mask[rc.atom_order[atom_name]] = 1.0
                    builder.init_atom(
                        name=atom_name,
                        coord=[x, y, z],
                        b_factor=atom.get_bfactor(),
                        occupancy=atom.get_occupancy(),
                        altloc=atom.get_altloc(),
                        fullname=f"{atom_name:^4s}",
                        element=atom_name[0],
                    )
            # Fixing naming errors in arginine residues where NH2 is
            # incorrectly assigned to be closer to CD than NH1.
            if res.get_resname() == "ARG":
                _fix_arg(pos, mask)
            all_atom_positions[array_index] = pos
            all_atom_mask[array_index] = mask
            residue_index[array_index] = array_index + shift

        if not rap.is_missing:
            assert rap.position is not None
            key = (rap.hetflag, rap.position.label_seq_id, rap.position.insertion_code)
            try:
                res = chain[key]
            except KeyError as e:
                logger.error(f"Model {model_num} does not have residue {key} in chain {chain_id}")
                # Missing residues
                aatype[aid] = rc.resname_to_idx.get(rap.name, rc.unk_restype_index)
                builder.init_residue(rap.name, field=" ", resseq=rid, icode=" ")
                seq_mask[aid] = 0.0  # Initially one.
                residue_index[aid] = aid + shift
                rid += 1
                aid += offset
                continue
            resname = res.get_resname()
            if key[1] in modres:
                # Modified residues: one output slot per parent component,
                # distinguished by synthetic insertion codes.
                mr = modres[key[1]]
                offset = len(mr.parent_comp_id)
                for i, resname_parent in enumerate(mr.parent_comp_id):
                    aatype[aid + i] = rc.resname_to_idx.get(resname_parent, rc.unk_restype_index)
                    # icode = " " if offset == 1 else LATIN[i]
                    icode = LATIN[i]
                    builder.init_residue(resname=resname_parent, field=" ", resseq=rid, icode=icode)
                    atom_map = DEFAULT_ATOM_MAP.get((resname, resname_parent), None)
                    if atom_map is None:
                        # Fall back to the minimal backbone-only map.
                        atom_map = DEFAULT_ATOM_MAP[("UNK", "UNK")]
                    fill_features(res, aid + i, builder, atom_map=atom_map)
            else:
                # Otherwise
                aatype[aid] = rc.resname_to_idx.get(resname, rc.unk_restype_index)
                builder.init_residue(resname, field=" ", resseq=rid, icode=" ")
                fill_features(res, aid, builder)
        else:
            # Missing residues
            aatype[aid] = rc.resname_to_idx.get(rap.name, rc.unk_restype_index)
            builder.init_residue(rap.name, field=" ", resseq=rid, icode=" ")
            seq_mask[aid] = 0.0  # Initially one.
            residue_index[aid] = aid + shift
        rid += 1
        aid += offset
    # Casting for saving storage:
    return {
        "aatype": aatype.astype(np.int8),
        "residue_index": residue_index.astype(np.int32),
        "seq_length": seq_length.astype(np.int32),
        "seq_mask": seq_mask.astype(np.int8),
        "all_atom_positions": all_atom_positions.astype(np.float32),
        "all_atom_mask": all_atom_mask.astype(np.int8),
    }, builder.get_structure()
def _fix_arg(pos, mask):
    """Swap mis-assigned NH1/NH2 atoms of an arginine, in place.

    By convention NH1 is the guanidinium nitrogen closer to CD; when the
    observed NH1 is farther from CD than NH2, the two atoms' coordinates and
    mask entries are exchanged. No-op unless CD, NH1 and NH2 are all present.
    """
    cd = rc.atom_order["CD"]
    nh1 = rc.atom_order["NH1"]
    nh2 = rc.atom_order["NH2"]
    if not all(mask[idx] for idx in (cd, nh1, nh2)):
        return
    if np.linalg.norm(pos[nh1] - pos[cd]) <= np.linalg.norm(pos[nh2] - pos[cd]):
        return
    pos[nh1], pos[nh2] = pos[nh2].copy(), pos[nh1].copy()
    mask[nh1], mask[nh2] = mask[nh2].copy(), mask[nh1].copy()