"""Plotting helpers for deepfold evaluation: distograms, PAE, pLDDT and MSA coverage (module deepfold.eval.plot)."""

from typing import Dict, List, Optional, Tuple

import matplotlib as mpl
import numpy as np
from matplotlib import pyplot as plt
from matplotlib.cm import ScalarMappable
from matplotlib.colors import LinearSegmentedColormap, Normalize
from mpl_toolkits.axes_grid1 import make_axes_locatable

import deepfold.common.residue_constants as rc
from deepfold.eval.distogram import compute_distogram, compute_predicted_distogram
from deepfold.eval.msa import compute_neff_v2 as compute_neff


def _set_size(
    w: float,
    h: float,
    ax: plt.Axes | None = None,
):
    """Resize the figure so the subplot region of ``ax`` measures ``w`` x ``h`` inches.

    The figure's subplot parameters (``left``, ``right``, ``top``, ``bottom``,
    given as fractions of the figure) are read, and the figure size is scaled
    so that the region between them has the requested width and height.

    Args:
        w: Target width of the subplot region, in inches.
        h: Target height of the subplot region, in inches.
        ax: Axes whose figure is resized; defaults to the current axes
            (``plt.gca()``).
    """
    if ax is None:
        ax = plt.gca()
    params = ax.figure.subplotpars
    fig_w = float(w) / (params.right - params.left)
    fig_h = float(h) / (params.top - params.bottom)
    ax.figure.set_size_inches(fig_w, fig_h)


def find_cluster_boundaries(a: np.ndarray) -> List[Tuple[int, int, int]]:
    """Split a 1-D array into runs of consecutive equal values.

    Args:
        a: 1-D array-like (for example a per-residue ``asym_id``).

    Returns:
        One ``(start, end, value)`` tuple per run, in order. ``start`` and
        ``end`` are inclusive indices and ``value`` is the element shared by
        the run. An empty input yields ``[]``.

    Raises:
        ValueError: If ``a`` is not one-dimensional.
    """
    a = np.asarray(a)
    if a.ndim != 1:
        raise ValueError(f"expected a 1-D array, got shape {a.shape}")
    if a.size == 0:
        return []
    # A new run starts wherever an element differs from its predecessor.
    starts = np.concatenate(([0], np.flatnonzero(a[1:] != a[:-1]) + 1))
    ends = np.append(starts[1:] - 1, len(a) - 1)
    return [(int(s), int(e), a[s]) for s, e in zip(starts, ends)]
def plot_distogram(
    outputs: dict,
    asym_id: Optional[np.ndarray] = None,
    ncols: int = 5,
    sort: bool = False,
    fig_kwargs: Optional[dict] = None,
) -> plt.Figure:
    """Plot, for every model, the predicted distogram above the structure-derived one.

    The upper panel shows ``compute_predicted_distogram(distogram_logits)``.
    The lower panel shows ``compute_distogram`` of the pseudo-beta coordinates
    taken from ``final_atom_positions`` (CB where the CB atom is present in
    ``final_atom_mask``, CA otherwise), clipped to [0, 22]. Models are laid
    out ``ncols`` per row, and each row of models occupies two subplot rows.

    Args:
        outputs: Mapping of model name to an output dict holding
            ``distogram_logits``, ``final_atom_positions`` and ``final_atom_mask``.
        asym_id: Optional per-residue chain ids. Chain boundaries are drawn as
            horizontal and vertical lines on every panel.
        ncols: Number of models per row.
        sort: If true, models are plotted in sorted key order.
        fig_kwargs: Extra keyword arguments for ``plt.figure``. Keys given here
            override the computed ``figsize`` and the default ``dpi`` (150).
            The dict is not modified.

    Returns:
        The created figure.
    """
    num_models = len(outputs)
    # Number of model rows (ceil division); each holds two stacked panels.
    num_bands = max(1, -(-num_models // ncols))
    fig = plt.figure(**{"figsize": (5 * ncols, 9 * num_bands), "dpi": 150.0, **(fig_kwargs or {})})

    items = sorted(outputs.items()) if sort else outputs.items()

    # Chain breaks fall between pixel rows/columns, hence the -0.5 offset.
    breaks = []
    if asym_id is not None:
        breaks = [start - 0.5 for start, _, _ in find_cluster_boundaries(asym_id)[1:]]

    def _finish_panel(ax, image):
        # Overlay chain boundaries and attach a tick-less colorbar on the right.
        for z in breaks:
            ax.axhline(y=z, color="k", linestyle="-", alpha=0.6)
            ax.axvline(x=z, color="k", linestyle="-", alpha=0.6)
        cax = make_axes_locatable(ax).append_axes("right", size="5%", pad=0.05)
        fig.colorbar(image, cax=cax, orientation="vertical", ticks=[])

    ca_idx = rc.atom_order["CA"]
    cb_idx = rc.atom_order["CB"]

    for m, (model_name, value) in enumerate(items):
        band, col = divmod(m, ncols)
        upper = 2 * band * ncols + col + 1  # 1-based subplot index of the upper panel
        lower = upper + ncols  # the panel directly below it

        # From the distogram head
        ax = fig.add_subplot(2 * num_bands, ncols, upper)
        ax.set_title(f"{model_name} from the Distogram Head")
        distogram = compute_predicted_distogram(value["distogram_logits"])
        _finish_panel(ax, ax.imshow(distogram, cmap="viridis_r", vmin=0, vmax=22))

        # Pseudo-beta positions: CB where the CB atom exists (mask > 0.5), CA otherwise (e.g. glycine).
        all_atom_positions = value["final_atom_positions"]
        all_atom_mask = value["final_atom_mask"]
        no_cb = all_atom_mask[..., cb_idx] < 0.5
        pseudo_beta = np.where(
            no_cb[..., None],
            all_atom_positions[..., ca_idx, :],
            all_atom_positions[..., cb_idx, :],
        )

        # From the final atom positions
        ax = fig.add_subplot(2 * num_bands, ncols, lower)
        ax.set_title(f"{model_name} from the Structure")
        distogram = compute_distogram(pseudo_beta)
        _finish_panel(ax, ax.imshow(np.clip(distogram, 0, 22), cmap="viridis_r", vmin=0, vmax=22))

    return fig
def plot_predicted_alignment_error(
    outputs: dict,
    asym_id: Optional[np.ndarray] = None,
    fig_kwargs: Optional[dict] = None,
) -> plt.Figure:
    """Plot each model's predicted aligned error (PAE) matrix side by side.

    Each panel shows ``predicted_aligned_error`` with the ``Greens`` colormap
    on a fixed 0-30 colour range, plus a colorbar on its right.

    Args:
        outputs: Mapping of model name to an output dict holding
            ``predicted_aligned_error``.
        asym_id: Optional per-residue chain ids. Chain boundaries are drawn as
            horizontal and vertical lines on every panel.
        fig_kwargs: Extra keyword arguments for ``plt.figure``. Keys given here
            override the computed ``figsize`` and the default ``dpi`` (150).
            The dict is not modified.

    Returns:
        The created figure.
    """
    num_models = len(outputs)
    fig = plt.figure(**{"figsize": (5 * max(num_models, 1), 4), "dpi": 150.0, **(fig_kwargs or {})})

    # Chain breaks fall between pixel rows/columns, hence the -0.5 offset.
    breaks = []
    if asym_id is not None:
        breaks = [start - 0.5 for start, _, _ in find_cluster_boundaries(asym_id)[1:]]

    for n, (model_name, value) in enumerate(outputs.items(), start=1):
        # Draw PAE
        ax = fig.add_subplot(1, num_models, n)
        ax.set_title(model_name)
        im = ax.imshow(value["predicted_aligned_error"], label=model_name, cmap="Greens", vmin=0, vmax=30)
        # Draw chain breaks
        for z in breaks:
            ax.axhline(y=z, color="k", linestyle="-", alpha=0.6)
            ax.axvline(x=z, color="k", linestyle="-", alpha=0.6)
        cax = make_axes_locatable(ax).append_axes("right", size="5%", pad=0.05)
        fig.colorbar(im, cax=cax, orientation="vertical")
    return fig
# Colour stops for the pLDDT colormap as (position, colour) pairs on a [0, 1]
# scale (pLDDT / 100): orange at 0.0, yellow at 0.5, light blue at 0.7 and
# dark blue from 0.9 to 1.0. The colormap interpolates linearly between stops.
PLDDT_COLORS = [
    (0.0, "#ff7d45"),
    (0.5, "#ffdb13"),
    (0.7, "#65cbf3"),
    (0.9, "#0053d6"),
    (1.0, "#0053d6"),
]
# Continuous colormap named "plddt" built from the stops above.
plddt_cmap = LinearSegmentedColormap.from_list("plddt", PLDDT_COLORS)
def plot_plddt(
    outputs: dict,
    asym_id: Optional[np.ndarray] = None,
    scale_with_len: bool = False,
    fig_kwargs: Optional[dict] = None,
) -> plt.Figure:
    """Plot per-residue pLDDT curves of all models on one axes, ranked by mean pLDDT.

    Residues are plotted at 1-based x positions. Each point is coloured with
    ``plddt_cmap`` on a fixed 0-1 scale (pLDDT / 100), matching the colorbar.
    Legend entries read ``Rank k (model)`` with rank 1 for the highest mean.

    Args:
        outputs: Mapping of model name to an output dict holding a 1-D
            ``plddt`` array (drawn on a 0-100 y-axis).
        asym_id: Optional per-residue chain ids. Chain boundaries are drawn as
            dash-dot vertical lines midway between adjacent residues.
        scale_with_len: If true, widen the figure by one 12-inch unit per 200
            residues of the longest ``plddt`` array.
        fig_kwargs: Extra keyword arguments for ``plt.figure``. Keys given here
            override the computed ``figsize`` and the default ``dpi`` (150).
            The dict is not modified.

    Returns:
        The created figure.
    """
    scale = 1
    if scale_with_len and outputs:
        max_len = max(v["plddt"].shape[0] for v in outputs.values())
        scale = 1 + max_len // 200

    fig = plt.figure(**{"figsize": (12 * scale, 5), "dpi": 150.0, **(fig_kwargs or {})})
    ax = fig.add_subplot(1, 1, 1)
    ax.set_title(r"Predicted C$\alpha$-lDDT")

    # Highest mean pLDDT first; the sort is stable, so ties keep insertion order.
    ranking = sorted(outputs, key=lambda k: outputs[k]["plddt"].mean(), reverse=True)
    for rank, model_name in enumerate(ranking, start=1):
        y = outputs[model_name]["plddt"]
        x = np.arange(1, len(y) + 1)
        # Fixed vmin/vmax so colours agree with the 0-1 colorbar instead of
        # being rescaled to each model's own min/max.
        ax.scatter(x=x, y=y, c=y / 100, cmap=plddt_cmap, vmin=0.0, vmax=1.0, marker=".", zorder=2)
        ax.plot(x, y, "-", label=f"Rank {rank} ({model_name})", zorder=1)

    if asym_id is not None:
        # Residue i (0-based) sits at x = i + 1, so the boundary in front of a
        # chain starting at index i lies at x = i + 0.5. Drawn once, not per model.
        for start, _, _ in find_cluster_boundaries(asym_id)[1:]:
            ax.axvline(x=start + 0.5, color="k", linestyle="-.", alpha=0.6)

    ax.legend(ncols=4, fontsize="smaller")
    ax.set_ylim(0, 100)
    ax.set_xlabel("Positions")
    ax.set_ylabel("plDDT")
    fig.colorbar(ScalarMappable(norm=Normalize(vmin=0.0, vmax=1.0), cmap=plddt_cmap), ax=ax)
    return fig
def _msa_coverage_lines(
    msa: np.ndarray,
    query: np.ndarray,
    chain_bounds: np.ndarray,
    sort_lines: bool,
) -> Tuple[np.ndarray, np.ndarray]:
    """Build the identity-coloured coverage image used by ``plot_msa``.

    Sequences are grouped by which chains they cover. Within each group, a row
    holds the sequence's identity score at its non-gap columns and NaN at gaps.
    Rows are ordered by ascending identity when ``sort_lines`` is true, and in
    reversed input order otherwise.

    Returns:
        ``(lines, group_bounds)``: the stacked rows, and cumulative row offsets
        of the groups (starting at 0).
    """
    non_gap = msa != 21  # token 21 is treated as a gap
    matches_query = msa == query
    spans = list(zip(chain_bounds[:-1], chain_bounds[1:]))
    # (num_seqs, num_chains): does the sequence have any residue in each chain?
    covered = np.stack([non_gap[:, s:e].max(-1) for s, e in spans], -1)

    blocks = []
    group_sizes = []
    for pattern in np.unique(covered, axis=0):
        in_group = (covered == pattern).all(axis=-1)
        group_matches = matches_query[in_group]
        # Per-chain identity summed over chains, divided by the number of covered chains.
        seqid = np.stack([group_matches[:, s:e].mean(-1) for s, e in spans], -1).sum(-1) / (pattern.sum(-1) + 1e-8)
        coverage = non_gap[in_group].astype(float)
        coverage[coverage == 0] = np.nan  # gaps are left blank in the image
        if sort_lines:
            order = seqid.argsort()
            block = coverage[order] * seqid[order, None]
        else:
            block = coverage[::-1] * seqid[::-1, None]
        group_sizes.append(len(block))
        blocks.append(block)

    return np.concatenate(blocks, 0), np.cumsum(np.append(0, group_sizes))


def plot_msa(
    feature_dict: Dict[str, np.ndarray],
    sort_lines: bool = True,
    dpi: float = 150.0,
    scale_with_len: bool = False,
) -> plt.Figure:
    """Plot MSA coverage, coloured by each sequence's identity to the query.

    Every aligned sequence is one horizontal line over its non-gap columns,
    coloured (``rainbow_r``, 0-1) by identity to the query (row 0 of ``msa``).
    Vertical black lines separate chains and horizontal ones separate groups
    of sequences covering the same chains. The black curve counts non-gap
    sequences per column. The title reports N_eff from ``compute_neff``.

    Args:
        feature_dict: Features with ``msa`` (2-D token array, query in row 0),
            ``num_alignments`` (a scalar, or an array whose first element is
            used) and, for multimers, a per-residue ``asym_id``.
        sort_lines: Order sequences within each group by identity.
        dpi: Figure resolution.
        scale_with_len: If true, widen the figure by one 12-inch unit per 200
            columns of the query.

    Returns:
        The created figure.
    """
    query = feature_dict["msa"][0]
    scale = 1 + len(query) // 200 if scale_with_len else 1

    if "asym_id" in feature_dict:  # Multimer: one chain per run of equal asym_id
        asym_id = np.asarray(feature_dict["asym_id"])
        run_starts = np.flatnonzero(asym_id[1:] != asym_id[:-1]) + 1
        chain_bounds = np.concatenate(([0], run_starts, [len(asym_id)]))
    else:
        chain_bounds = np.array([0, len(query)])

    # ``num_alignments`` may be stored as a scalar or broadcast to an array.
    num_alignments = int(np.ravel(feature_dict["num_alignments"])[0])
    msa = feature_dict["msa"][:num_alignments]

    lines, group_bounds = _msa_coverage_lines(msa, query, chain_bounds, sort_lines)
    num_rows, num_cols = lines.shape

    fig = plt.figure(figsize=(12 * scale, 5), dpi=dpi)
    ax = fig.add_subplot(1, 1, 1)
    im = ax.imshow(
        lines,
        cmap="rainbow_r",
        interpolation="nearest",
        aspect="auto",
        vmax=1.0,
        vmin=0.0,
        origin="lower",
        extent=(0, num_cols, 0, num_rows),
    )
    for x in chain_bounds[1:-1]:
        ax.plot([x, x], [0, num_rows], color="black", alpha=0.6)
    for y in group_bounds[1:-1]:
        ax.plot([0, num_cols], [y, y], color="black", alpha=0.6)

    neff = compute_neff(msa)
    ax.set_title("MSA Depth" + r" ($N_{eff}=$" + f"{neff:6.3f})")
    ax.plot((~np.isnan(lines)).sum(0), color="black")
    ax.set_xlim(0, num_cols)
    ax.set_ylim(0, max(num_rows, 100))
    fig.colorbar(im)
    ax.set_xlabel("Positions")
    ax.set_ylabel("Sequences")
    return fig