# Copyright 2024 DeepFold Team
# Copyright 2022 DP Technology
# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
import logging
from dataclasses import dataclass
from enum import Enum
from functools import partial
from typing import List, Union
import numpy as np
import torch
from deepfold.modules.alphafold import AlphaFold
logger = logging.getLogger(__name__)
_NPZ_KEY_PREFIX = "alphafold/alphafold_iteration/"
def reshape_weight(x: Union[np.ndarray, torch.Tensor]) -> Union[np.ndarray, torch.Tensor]:
    """Convert a JAX-style linear weight to the PyTorch layout.

    A 2-D kernel stored as (in_features, out_features) is transposed to
    (out_features, in_features); a 1-D vector is reshaped to a column vector.
    """
    len_shape = len(x.shape)
    if len_shape == 2:
        return x.transpose(-1, -2)
    elif len_shape == 1:
        return x.reshape(-1, 1)
    else:
        raise RuntimeError(f"Unexpected weight shape: {tuple(x.shape)}")
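# Illustrative example (the shapes are hypothetical): a JAX kernel of shape
# (in_features=128, out_features=256) becomes a (256, 128) matrix, i.e. the
# layout expected by torch.nn.Linear.weight, and a 1-D vector of shape (256,)
# becomes a (256, 1) column vector:
#
#     w = np.zeros((128, 256))
#     assert reshape_weight(w).shape == (256, 128)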
# NOTE: Wrapping each lambda in `partial` prevents the enum values from being bound as methods.
class ParamType(Enum):
LinearWeight = partial(lambda w: reshape_weight(w))
LinearWeightMHA = partial(lambda w: w.reshape(*w.shape[:-2], -1).transpose(-1, -2))
LinearMHAOutputWeight = partial(lambda w: w.reshape(*w.shape[:-3], -1, w.shape[-1]).transpose(-1, -2))
LinearBiasMHA = partial(lambda w: w.reshape(*w.shape[:-2], -1))
LinearWeightOPM = partial(lambda w: w.reshape(*w.shape[:-3], -1, w.shape[-1]).transpose(-1, -2))
Other = partial(lambda w: w)
def __init__(self, fn):
self.transformation = fn
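# Illustrative example (shapes are hypothetical): an attention query kernel
# stored as (c_in=384, n_head=8, c_head=16) is flattened and transposed by
# ParamType.LinearWeightMHA into a (128, 384) matrix, i.e. the (out, in)
# layout of the corresponding fused torch.nn.Linear weight:
#
#     w = torch.zeros(384, 8, 16)
#     assert ParamType.LinearWeightMHA.transformation(w).shape == (128, 384)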
@dataclass
class Param:
    """A reference to a target tensor (or list of tensors) in the PyTorch model, together with how the matching source array must be transformed and copied into it."""

    param: Union[torch.Tensor, List[torch.Tensor]]
    param_type: ParamType = ParamType.Other
    stacked: bool = False  # Whether `param` is a list spanning a stacked (per-block) source array.
    swap: bool = False  # Whether the two halves of the output dimension are swapped at copy time.
def _process_translations_dict(d, top_layer=True):
    """Flatten a nested translation dict into NPZ-style keys joined by "/"."""
    flat = {}
    for k, v in d.items():
        if type(v) == dict:
            prefix = _NPZ_KEY_PREFIX if top_layer else ""
            sub_flat = {
                (prefix + "/".join([k, k_prime])): v_prime
                for k_prime, v_prime in _process_translations_dict(v, top_layer=False).items()
            }
            flat.update(sub_flat)
        else:
            k = "/" + k if not top_layer else k
            flat[k] = v
    return flat
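# Illustrative example (the tensor `w` is a placeholder): nested dicts are
# joined with "/", and leaf entries below the top level gain a leading "/", so
# the flattened keys carry the "//" separator used in the DeepMind NPZ files:
#
#     _process_translations_dict({"evoformer": {"prev_pos_linear": {"weights": w}}})
#     # -> {"alphafold/alphafold_iteration/evoformer/prev_pos_linear//weights": w}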
def stacked(param_dict_list, out=None):
"""
Args:
param_dict_list:
A list of (nested) Param dicts to stack. The structure of
each dict must be the identical (down to the ParamTypes of
"parallel" Params). There must be at least one dict
in the list.
"""
if out is None:
out = {}
template = param_dict_list[0]
for k, _ in template.items():
v = [d[k] for d in param_dict_list]
if type(v[0]) is dict:
out[k] = {}
stacked(v, out=out[k])
elif type(v[0]) is Param:
stacked_param = Param(
param=[param.param for param in v],
param_type=v[0].param_type,
stacked=True,
swap=v[0].swap,
)
out[k] = stacked_param
return out
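# Illustrative example (`blocks` is a placeholder): given one Param dict per
# block, as built for the Evoformer stack inside import_jax_weights_ below,
# `stacked` returns a single dict of the same shape whose leaf Params hold a
# list of per-block target tensors with stacked=True, so that `assign` can
# torch.unbind the stacked NPZ array along dim 0 and copy slice i into block i:
#
#     blocks_params = stacked([EvoformerBlockParams(b) for b in blocks])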
def assign(translation_dict, orig_weights):
    """Copy the arrays in `orig_weights` into the model tensors referenced by the flat `translation_dict`."""
    for k, param in translation_dict.items():
with torch.no_grad():
weights = torch.as_tensor(orig_weights[k])
ref, param_type = param.param, param.param_type
if param.stacked:
weights = torch.unbind(weights, 0)
else:
weights = [weights]
ref = [ref]
try:
weights = list(map(param_type.transformation, weights))
for p, w in zip(ref, weights):
if param.swap:
index = p.shape[0] // 2
p[:index].copy_(w[index:])
p[index:].copy_(w[:index])
else:
p.copy_(w)
            except Exception:
                logger.error(f"{k}: could not copy source weights of shape {weights[0].shape} into target of shape {ref[0].shape}")
                raise
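# Illustrative usage (names mirror the end of import_jax_weights_ below):
#
#     flat = _process_translations_dict(translations)
#     assign(flat, data)  # `data` is the mapping loaded from the NPZ file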
def import_jax_weights_(
model: AlphaFold,
npz_path: str,
is_multimer: bool = False,
enable_ptm: bool = False,
enable_templates: bool = False,
fuse_projection_weights: bool = False,
) -> None:
"""Import AlphaFold JAX parameters.
Args:
model: AlphaFold model.
npz_path: Path to an NPZ file.
is_multimer: Whether multimer model or not.
enable_ptm: Enable predicted aligned error related modules.
enable_templates: Enable template related modules.
fuse_projection_weights: Whether triangular multiplicative layers are fused or not.
"""
# Multimer model can predict TM score.
enable_ptm |= is_multimer
data = np.load(npz_path, allow_pickle=True)
if "arr_0" in data:
data = data["arr_0"].flat[0]
global _NPZ_KEY_PREFIX
_NPZ_KEY_PREFIX = "deepfold_batch/deepfold/deepfold_iteration/"
keys = list(data.keys())
for key in keys:
for subkey in data[key]:
data[key + "//" + subkey] = data[key][subkey]
del data[key]
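        # NOTE: After this flattening, each pickled entry data[scope][name] is
        # exposed as data[scope + "//" + name], matching the "//" convention of
        # the flat translation keys built below.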
    #######################
    # Translation helpers
    #######################
LinearWeight = lambda l: (Param(l, param_type=ParamType.LinearWeight))
LinearWeightSwap = lambda l: (Param(l, param_type=ParamType.LinearWeight, swap=True))
LinearBias = lambda l: (Param(l))
LinearBiasSwap = lambda l: (Param(l, swap=True))
LinearWeightMHA = lambda l: (Param(l, param_type=ParamType.LinearWeightMHA))
LinearBiasMHA = lambda b: (Param(b, param_type=ParamType.LinearBiasMHA))
LinearWeightOPM = lambda l: (Param(l, param_type=ParamType.LinearWeightOPM))
LinearParams = lambda l: {
"weights": LinearWeight(l.weight),
"bias": LinearBias(l.bias),
}
LinearLeftParams = lambda l, index: {
"weights": LinearWeight(l.weight[:index, :]),
"bias": LinearBias(l.bias[:index]),
}
LinearRightParams = lambda l, index: {
"weights": LinearWeight(l.weight[index:, :]),
"bias": LinearBias(l.bias[index:]),
}
    # NOTE: `index` is unused here; the two halves are swapped at copy time via Param.swap in `assign`.
    LinearSwapParams = lambda l, index: {
        "weights": LinearWeightSwap(l.weight),
        "bias": LinearBiasSwap(l.bias),
    }
LinearMHAParams = lambda l: {
"weights": LinearWeightMHA(l.weight),
"bias": LinearBiasMHA(l.bias),
}
LinearNoBiasParams = lambda l: {
"weights": LinearWeight(l.weight),
}
LayerNormParams = lambda l: {
"scale": Param(l.weight),
"offset": Param(l.bias),
}
AttentionParams = lambda att: {
"query_w": LinearWeightMHA(att.linear_q.weight),
"key_w": LinearWeightMHA(att.linear_k.weight),
"value_w": LinearWeightMHA(att.linear_v.weight),
"output_w": Param(
att.linear_o.weight,
param_type=ParamType.LinearMHAOutputWeight,
),
"output_b": LinearBias(att.linear_o.bias),
}
AttentionGatedParams = lambda att: dict(
**AttentionParams(att),
**{
"gating_w": LinearWeightMHA(att.linear_g.weight),
"gating_b": LinearBiasMHA(att.linear_g.bias),
},
)
GlobalAttentionParams = lambda att: dict(
AttentionGatedParams(att),
key_w=LinearWeight(att.linear_k.weight),
value_w=LinearWeight(att.linear_v.weight),
)
TriAttParams = lambda tri_att: {
"query_norm": LayerNormParams(tri_att.layer_norm),
"feat_2d_weights": LinearWeight(tri_att.linear.weight),
"attention": AttentionGatedParams(tri_att.mha),
}
if fuse_projection_weights:
TriMulOutParams = lambda tri_mul: {
"left_norm_input": LayerNormParams(tri_mul.layer_norm_in),
# "left_projection": LinearLeftParams(tri_mul.linear_ab_p, tri_mul.linear_ab_p.weight.shape[0]//2),
# "right_projection": LinearRightParams(tri_mul.linear_ab_p, tri_mul.linear_ab_p.weight.shape[0]//2),
"projection": LinearParams(tri_mul.linear_ab_p),
# "left_gate": LinearLeftParams(tri_mul.linear_ab_g, tri_mul.linear_ab_g.weight.shape[0]//2),
# "right_gate": LinearRightParams(tri_mul.linear_ab_g, tri_mul.linear_ab_g.weight.shape[0]//2),
"gate": LinearParams(tri_mul.linear_ab_g),
"center_norm": LayerNormParams(tri_mul.layer_norm_out),
"output_projection": LinearParams(tri_mul.linear_z),
"gating_linear": LinearParams(tri_mul.linear_g),
}
        # See commit b88f8da in the AlphaFold repository:
        # AlphaFold swaps the pseudocode's a and b between the incoming/outgoing iterations of
        # triangle multiplication, which is confusing and not reproduced in our implementation.
TriMulInParams = lambda tri_mul: {
"left_norm_input": LayerNormParams(tri_mul.layer_norm_in),
# "left_projection": LinearRightParams(tri_mul.linear_ab_p, tri_mul.linear_ab_p.weight.shape[0]//2),
# "right_projection": LinearLeftParams(tri_mul.linear_ab_p, tri_mul.linear_ab_p.weight.shape[0]//2),
"projection": LinearSwapParams(tri_mul.linear_ab_p, tri_mul.linear_ab_p.weight.shape[0] // 2),
"gate": LinearSwapParams(tri_mul.linear_ab_g, tri_mul.linear_ab_g.weight.shape[0] // 2),
# "left_gate": LinearRightParams(tri_mul.linear_ab_g, tri_mul.linear_ab_g.weight.shape[0]//2),
# "right_gate": LinearLeftParams(tri_mul.linear_ab_g, tri_mul.linear_ab_g.weight.shape[0]//2),
"center_norm": LayerNormParams(tri_mul.layer_norm_out),
"output_projection": LinearParams(tri_mul.linear_z),
"gating_linear": LinearParams(tri_mul.linear_g),
}
else:
TriMulOutParams = lambda tri_mul: {
"layer_norm_input": LayerNormParams(tri_mul.layer_norm_in),
"left_projection": LinearLeftParams(tri_mul.linear_ab_p, tri_mul.linear_ab_p.weight.shape[0] // 2),
"right_projection": LinearRightParams(tri_mul.linear_ab_p, tri_mul.linear_ab_p.weight.shape[0] // 2),
"left_gate": LinearLeftParams(tri_mul.linear_ab_g, tri_mul.linear_ab_g.weight.shape[0] // 2),
"right_gate": LinearRightParams(tri_mul.linear_ab_g, tri_mul.linear_ab_g.weight.shape[0] // 2),
"center_layer_norm": LayerNormParams(tri_mul.layer_norm_out),
"output_projection": LinearParams(tri_mul.linear_z),
"gating_linear": LinearParams(tri_mul.linear_g),
}
        # See commit b88f8da in the AlphaFold repository:
        # AlphaFold swaps the pseudocode's a and b between the incoming/outgoing iterations of
        # triangle multiplication, which is confusing and not reproduced in our implementation.
TriMulInParams = lambda tri_mul: {
"layer_norm_input": LayerNormParams(tri_mul.layer_norm_in),
"left_projection": LinearRightParams(tri_mul.linear_ab_p, tri_mul.linear_ab_p.weight.shape[0] // 2),
"right_projection": LinearLeftParams(tri_mul.linear_ab_p, tri_mul.linear_ab_p.weight.shape[0] // 2),
"left_gate": LinearRightParams(tri_mul.linear_ab_g, tri_mul.linear_ab_g.weight.shape[0] // 2),
"right_gate": LinearLeftParams(tri_mul.linear_ab_g, tri_mul.linear_ab_g.weight.shape[0] // 2),
"center_layer_norm": LayerNormParams(tri_mul.layer_norm_out),
"output_projection": LinearParams(tri_mul.linear_z),
"gating_linear": LinearParams(tri_mul.linear_g),
}
PairTransitionParams = lambda pt: {
"input_layer_norm": LayerNormParams(pt.layer_norm),
"transition1": LinearParams(pt.linear_1),
"transition2": LinearParams(pt.linear_2),
}
MSAAttParams = lambda matt: {
"query_norm": LayerNormParams(matt.layer_norm_m),
"attention": AttentionGatedParams(matt.mha),
}
MSAColAttParams = lambda matt: {
"query_norm": LayerNormParams(matt.layer_norm_m),
"attention": AttentionGatedParams(matt.mha),
}
MSAGlobalAttParams = lambda matt: {
"query_norm": LayerNormParams(matt.layer_norm_m),
"attention": GlobalAttentionParams(matt.global_attention),
}
MSAAttPairBiasParams = lambda matt: dict(
**MSAAttParams(matt),
**{
"feat_2d_norm": LayerNormParams(matt.layer_norm_z),
"feat_2d_weights": LinearWeight(matt.linear_z.weight),
},
)
IPAParams = lambda ipa: {
"q_scalar": LinearParams(ipa.linear_q),
"kv_scalar": LinearParams(ipa.linear_kv),
"q_point_local": LinearParams(ipa.linear_q_points),
"kv_point_local": LinearParams(ipa.linear_kv_points),
"trainable_point_weights": Param(param=ipa.head_weights, param_type=ParamType.Other),
"attention_2d": LinearParams(ipa.linear_b),
"output_projection": LinearParams(ipa.linear_out),
}
if is_multimer:
MultimerIPAParams = lambda ipa: {
"q_scalar_projection": {"weights": LinearWeightMHA(ipa.linear_q.weight)},
"k_scalar_projection": {"weights": LinearWeightMHA(ipa.linear_k.weight)},
"v_scalar_projection": {"weights": LinearWeightMHA(ipa.linear_v.weight)},
"q_point_projection": {"point_projection": LinearMHAParams(ipa.linear_q_points.linear)},
"k_point_projection": {"point_projection": LinearMHAParams(ipa.linear_k_points.linear)},
"v_point_projection": {"point_projection": LinearMHAParams(ipa.linear_v_points.linear)},
"trainable_point_weights": Param(param=ipa.head_weights, param_type=ParamType.Other),
"attention_2d": LinearParams(ipa.linear_b),
"output_projection": LinearParams(ipa.linear_out),
}
TemplatePairBlockParams = lambda b: {
"triangle_attention_starting_node": TriAttParams(b.tri_att_start),
"triangle_attention_ending_node": TriAttParams(b.tri_att_end),
"triangle_multiplication_outgoing": TriMulOutParams(b.tri_mul_out),
"triangle_multiplication_incoming": TriMulInParams(b.tri_mul_in),
"pair_transition": PairTransitionParams(b.pair_transition),
}
MSATransitionParams = lambda m: {
"input_layer_norm": LayerNormParams(m.layer_norm),
"transition1": LinearParams(m.linear_1),
"transition2": LinearParams(m.linear_2),
}
OuterProductMeanParams = lambda o: {
"layer_norm_input": LayerNormParams(o.layer_norm),
"left_projection": LinearParams(o.linear_1),
"right_projection": LinearParams(o.linear_2),
"output_w": LinearWeightOPM(o.linear_out.weight),
"output_b": LinearBias(o.linear_out.bias),
}
def EvoformerBlockParams(b, is_extra_msa=False):
if is_extra_msa:
col_att_name = "msa_column_global_attention"
msa_col_att_params = MSAGlobalAttParams(b.msa_att_col)
else:
col_att_name = "msa_column_attention"
msa_col_att_params = MSAColAttParams(b.msa_att_col)
d = {
"msa_row_attention_with_pair_bias": MSAAttPairBiasParams(b.msa_att_row),
col_att_name: msa_col_att_params,
"msa_transition": MSATransitionParams(b.msa_transition),
"outer_product_mean": OuterProductMeanParams(b.outer_product_mean),
"triangle_multiplication_outgoing": TriMulOutParams(b.pair_core.tri_mul_out),
"triangle_multiplication_incoming": TriMulInParams(b.pair_core.tri_mul_in),
"triangle_attention_starting_node": TriAttParams(b.pair_core.tri_att_start),
"triangle_attention_ending_node": TriAttParams(b.pair_core.tri_att_end),
"pair_transition": PairTransitionParams(b.pair_core.pair_transition),
}
return d
ExtraMSABlockParams = partial(EvoformerBlockParams, is_extra_msa=True)
FoldIterationParams = lambda sm: {
"invariant_point_attention": IPAParams(sm.ipa),
"attention_layer_norm": LayerNormParams(sm.layer_norm_ipa),
"transition": LinearParams(sm.transition.linear_1),
"transition_1": LinearParams(sm.transition.linear_2),
"transition_2": LinearParams(sm.transition.linear_3),
"transition_layer_norm": LayerNormParams(sm.transition.layer_norm),
"affine_update": LinearParams(sm.bb_update.linear),
"rigid_sidechain": {
"input_projection": LinearParams(sm.angle_resnet.linear_in),
"input_projection_1": LinearParams(sm.angle_resnet.linear_initial),
"resblock1": LinearParams(sm.angle_resnet.layers[0].linear_1),
"resblock2": LinearParams(sm.angle_resnet.layers[0].linear_2),
"resblock1_1": LinearParams(sm.angle_resnet.layers[1].linear_1),
"resblock2_1": LinearParams(sm.angle_resnet.layers[1].linear_2),
"unnormalized_angles": LinearParams(sm.angle_resnet.linear_out),
},
}
if is_multimer:
MultimerFoldIterationParams = lambda sm: {
"invariant_point_attention": MultimerIPAParams(sm.ipa),
"attention_layer_norm": LayerNormParams(sm.layer_norm_ipa),
"transition": LinearParams(sm.transition.linear_1),
"transition_1": LinearParams(sm.transition.linear_2),
"transition_2": LinearParams(sm.transition.linear_3),
"transition_layer_norm": LayerNormParams(sm.transition.layer_norm),
"quat_rigid": {"rigid": LinearParams(sm.bb_update.linear)},
"rigid_sidechain": {
"input_projection": LinearParams(sm.angle_resnet.linear_in),
"input_projection_1": LinearParams(sm.angle_resnet.linear_initial),
"resblock1": LinearParams(sm.angle_resnet.layers[0].linear_1),
"resblock2": LinearParams(sm.angle_resnet.layers[0].linear_2),
"resblock1_1": LinearParams(sm.angle_resnet.layers[1].linear_1),
"resblock2_1": LinearParams(sm.angle_resnet.layers[1].linear_2),
"unnormalized_angles": LinearParams(sm.angle_resnet.linear_out),
},
}
    ############################
    # Optional template modules
    ############################
tps_blocks_params = None
template_pair_ln = None
template_angle_emb = None
template_angle_proj = None
# if model.template_pair_stack is not None:
if enable_templates:
tps_blocks = model.template_pair_stack.blocks
tps_blocks_params = stacked([TemplatePairBlockParams(b) for b in tps_blocks])
template_pair_ln = LayerNormParams(model.template_pair_stack.layer_norm)
template_angle_emb = LinearParams(model.template_angle_embedder.linear_1)
template_angle_proj = LinearParams(model.template_angle_embedder.linear_2)
ems_blocks = model.extra_msa_stack.blocks
ems_blocks_params = stacked([ExtraMSABlockParams(b) for b in ems_blocks])
evo_blocks = model.evoformer_stack.blocks
evo_blocks_params = stacked([EvoformerBlockParams(b) for b in evo_blocks])
translations = {
"evoformer": {
"preprocess_1d": LinearParams(model.input_embedder.linear_tf_m),
"preprocess_msa": LinearParams(model.input_embedder.linear_msa_m),
"left_single": LinearParams(model.input_embedder.linear_tf_z_i),
"right_single": LinearParams(model.input_embedder.linear_tf_z_j),
"prev_pos_linear": LinearParams(model.recycling_embedder.linear),
"prev_msa_first_row_norm": LayerNormParams(model.recycling_embedder.layer_norm_m),
"prev_pair_norm": LayerNormParams(model.recycling_embedder.layer_norm_z),
"pair_activiations": LinearParams(model.input_embedder.linear_relpos),
"template_embedding": {
"single_template_embedding": {
"template_pair_stack": {
"__layer_stack_no_state": tps_blocks_params,
},
"output_layer_norm": template_pair_ln,
},
# "attention": AttentionParams(model.template_pointwise_att.mha),
},
"extra_msa_activations": LinearParams(model.extra_msa_embedder.linear),
"extra_msa_stack": ems_blocks_params,
"template_single_embedding": template_angle_emb,
"template_projection": template_angle_proj,
"evoformer_iteration": evo_blocks_params,
"single_activations": LinearParams(model.evoformer_stack.linear),
},
"structure_module": {
"single_layer_norm": LayerNormParams(model.structure_module.layer_norm_s),
"initial_projection": LinearParams(model.structure_module.linear_in),
"pair_layer_norm": LayerNormParams(model.structure_module.layer_norm_z),
"fold_iteration": (MultimerFoldIterationParams(model.structure_module) if is_multimer else FoldIterationParams(model.structure_module)),
},
"predicted_lddt_head": {
"input_layer_norm": LayerNormParams(model.auxiliary_heads.plddt.layer_norm),
"act_0": LinearParams(model.auxiliary_heads.plddt.linear_1),
"act_1": LinearParams(model.auxiliary_heads.plddt.linear_2),
"logits": LinearParams(model.auxiliary_heads.plddt.linear_3),
},
"distogram_head": {
"half_logits": LinearParams(model.auxiliary_heads.distogram.linear),
},
"experimentally_resolved_head": {
"logits": LinearParams(model.auxiliary_heads.experimentally_resolved.linear),
},
"masked_msa_head": {
"logits": LinearParams(model.auxiliary_heads.masked_msa.linear),
},
}
if not enable_templates:
evo_dict = translations["evoformer"]
keys = list(evo_dict.keys())
for k in keys:
if "template_" in k:
evo_dict.pop(k)
if enable_ptm:
translations["predicted_aligned_error_head"] = {"logits": LinearParams(model.auxiliary_heads.tm.linear)}
# fmt: off
if is_multimer:
del translations["evoformer"]["pair_activiations"]
del translations["evoformer"]["template_embedding"]["single_template_embedding"]["template_pair_stack"]
translations["predicted_aligned_error_head"] = {"logits": LinearParams(model.auxiliary_heads.tm.linear)}
translations["evoformer"]["~_relative_encoding"] = {}
translations["evoformer"]["~_relative_encoding"]["position_activations"] = LinearParams(model.input_embedder.linear_relpos)
translations["evoformer"]["template_embedding"]["single_template_embedding"]["query_embedding_norm"] = LayerNormParams(model.template_pair_embedder.query_embedding_layer_norm)
translations["evoformer"]["template_embedding"]["single_template_embedding"]["template_pair_embedding_0"] = LinearParams(model.template_pair_embedder.dgram_linear)
translations["evoformer"]["template_embedding"]["single_template_embedding"]["template_pair_embedding_1"] = LinearParams(model.template_pair_embedder.pseudo_beta_mask_linear)
translations["evoformer"]["template_embedding"]["single_template_embedding"]["template_pair_embedding_2"] = LinearParams(model.template_pair_embedder.aatype_linear_1)
translations["evoformer"]["template_embedding"]["single_template_embedding"]["template_pair_embedding_3"] = LinearParams(model.template_pair_embedder.aatype_linear_2)
translations["evoformer"]["template_embedding"]["single_template_embedding"]["template_pair_embedding_4"] = LinearParams(model.template_pair_embedder.x_linear)
translations["evoformer"]["template_embedding"]["single_template_embedding"]["template_pair_embedding_5"] = LinearParams(model.template_pair_embedder.y_linear)
translations["evoformer"]["template_embedding"]["single_template_embedding"]["template_pair_embedding_6"] = LinearParams(model.template_pair_embedder.z_linear)
translations["evoformer"]["template_embedding"]["single_template_embedding"]["template_pair_embedding_7"] = LinearParams(model.template_pair_embedder.backbone_mask_linear)
translations["evoformer"]["template_embedding"]["single_template_embedding"]["template_pair_embedding_8"] = LinearParams(model.template_pair_embedder.query_embedding_linear)
translations["evoformer"]["template_embedding"]["single_template_embedding"]["template_embedding_iteration"] = tps_blocks_params
translations["evoformer"]["template_embedding"]["output_linear"] = LinearParams(model.template_projection.linear_t)
else:
if enable_templates:
translations["evoformer"]["template_embedding"]["single_template_embedding"]["embedding2d"] = LinearParams(model.template_pair_embedder.linear)
translations["evoformer"]["template_embedding"]["attention"] = AttentionParams(model.template_pointwise_attention.mha)
# fmt: on
# Flatten keys and insert missing key prefixes
flat = _process_translations_dict(translations)
# Sanity check
keys = list(data.keys())
flat_keys = list(flat.keys())
incorrect = [k for k in flat_keys if k not in keys]
missing = [k for k in keys if k not in flat_keys]
# assert len(missing) == 0
# assert(sorted(list(flat.keys())) == sorted(list(data.keys())))
logger.debug("incorrect keys:", incorrect) # which with error names
logger.debug("missing keys:", missing) # which with
# Set weights
assign(flat, data)
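# Illustrative usage (the path and the way `config` and the model are built are
# hypothetical; only import_jax_weights_ itself is defined in this module):
#
#     from deepfold.modules.alphafold import AlphaFold
#
#     model = AlphaFold(config)  # hypothetical constructor call; build the model as usual
#     import_jax_weights_(
#         model,
#         "params/params_model_1_ptm.npz",  # hypothetical DeepMind parameter file
#         enable_ptm=True,
#         enable_templates=True,
#     )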