promptbind.models package¶
Submodules¶
promptbind.models.att_model module¶
- class promptbind.models.att_model.ComplexGraph(args, inter_cutoff=10, intra_cutoff=8, normalize_coord=None, unnormalize_coord=None)¶
Bases: Module
- construct_edges(X, batch_id, segment_ids, is_global)¶
Memory-efficient edge construction with complexity O(Nn), where N is the total number of nodes and n is the largest number of nodes in a single graph of the batch.
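The docs do not show the implementation, but the O(Nn) bound typically comes from computing pairwise distances one graph at a time rather than over the whole batch, then keeping pairs within the intra_cutoff / inter_cutoff radii. A minimal sketch of that pattern (variable layout, cutoff logic, and the omission of is_global are assumptions, not the actual PromptBind code):

    import torch

    def construct_edges_sketch(X, batch_id, segment_ids, intra_cutoff=8.0, inter_cutoff=10.0):
        # Hypothetical O(N*n) edge construction: per-graph distance matrices only.
        edges = []
        for b in batch_id.unique():
            idx = (batch_id == b).nonzero(as_tuple=True)[0]   # nodes of this graph
            pos = X[idx, 0]                                   # assume channel 0 holds coordinates
            dist = torch.cdist(pos, pos)                      # [n, n], n = nodes of this graph only
            same_seg = segment_ids[idx, None] == segment_ids[None, idx]
            # Tighter cutoff for intra-segment pairs, looser for inter-segment pairs.
            cutoff = same_seg * intra_cutoff + (~same_seg) * inter_cutoff
            keep = (dist < cutoff) & ~torch.eye(len(idx), dtype=torch.bool, device=X.device)
            row, col = keep.nonzero(as_tuple=True)
            edges.append(torch.stack([idx[row], idx[col]]))
        return torch.cat(edges, dim=1)                        # [2, n_edge]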
- forward(X, batch_id, segment_id, is_global)¶
Define the computation performed at every call.
Should be overridden by all subclasses.
Note: Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- class promptbind.models.att_model.EfficientMCAttModel(args, embed_size, hidden_size, prompt_nf, n_channel, n_edge_feats=0, n_layers=5, dropout=0.1, n_iter=5, dense=False, inter_cutoff=10, intra_cutoff=8, normalize_coord=None, unnormalize_coord=None)¶
Bases: Module
- forward(X, H, batch_id, segment_id, mask, is_global, compound_edge_index, LAS_edge_index, batched_complex_coord_LAS, LAS_mask=None, prompt_node=None, prompt_coord=None)¶
- Parameters:
X – [n_all_node, n_channel, 3]
S – [n_all_node]
batch – [n_all_node]
- promptbind.models.att_model.sequential_and(*tensors)¶
- promptbind.models.att_model.sequential_or(*tensors)¶
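Given the names and variadic signatures, these helpers presumably reduce an arbitrary number of boolean masks elementwise without stacking them into one tensor. A plausible reading (not verified against the source):

    import torch

    def sequential_and(*tensors):
        # Elementwise logical AND folded pairwise over the inputs.
        res = tensors[0]
        for t in tensors[1:]:
            res = torch.logical_and(res, t)
        return res

    def sequential_or(*tensors):
        # Elementwise logical OR folded pairwise over the inputs.
        res = tensors[0]
        for t in tensors[1:]:
            res = torch.logical_or(res, t)
        return res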
promptbind.models.cross_att module¶
- class promptbind.models.cross_att.CrossAttentionModule(node_hidden_dim, pair_hidden_dim, rm_layernorm=False, keep_trig_attn=False, dist_hidden_dim=32, normalize_coord=None)¶
Bases: Module
- forward(p_embed_batched, p_mask, c_embed_batched, c_mask, pair_embed, pair_mask, c_c_dist_embed=None, p_p_dist_embed=None)¶
Define the computation performed at every call.
Should be overridden by all subclasses.
Note: Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
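A shape-level usage sketch, under the assumption that the p_* arguments are batched protein/pocket embeddings, the c_* arguments are batched compound embeddings, and pair_embed is the protein-compound pair representation (all dimensions below are illustrative, not prescribed by the docs):

    import torch
    from promptbind.models.cross_att import CrossAttentionModule

    bs, n_p, n_c, d_node, d_pair = 2, 64, 32, 128, 64
    module = CrossAttentionModule(node_hidden_dim=d_node, pair_hidden_dim=d_pair)

    p_embed = torch.randn(bs, n_p, d_node)           # protein/pocket node embeddings (assumed layout)
    c_embed = torch.randn(bs, n_c, d_node)           # compound node embeddings (assumed layout)
    p_mask = torch.ones(bs, n_p, dtype=torch.bool)
    c_mask = torch.ones(bs, n_c, dtype=torch.bool)
    pair_embed = torch.randn(bs, n_p, n_c, d_pair)   # pairwise representation (assumed layout)
    pair_mask = p_mask[:, :, None] & c_mask[:, None, :]

    out = module(p_embed, p_mask, c_embed, c_mask, pair_embed, pair_mask)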
- class promptbind.models.cross_att.RowAttentionBlock(node_hidden_dim, pair_hidden_dim, attention_hidden_dim=32, no_heads=4, dropout=0.1, rm_layernorm=False)¶
Bases: Module
- forward(node_embed_i, node_embed_j, pair_embed, pair_mask, node_mask_i)¶
Define the computation performed at every call.
Should be overridden by all subclasses.
Note: Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- inf = 1000000000.0¶
- class promptbind.models.cross_att.RowTriangleAttentionBlock(pair_hidden_dim, dist_hidden_dim, attention_hidden_dim=32, no_heads=4, dropout=0.1, rm_layernorm=False)¶
Bases: Module
- forward(pair_embed, pair_mask, dist_embed)¶
Define the computation performed at every call.
Should be overridden by all subclasses.
Note: Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- inf = 1000000000.0¶
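The inf class attribute on these blocks is the usual large constant for additive attention masking: padded positions receive a score of approximately -inf (here -1e9) before the softmax so they contribute near-zero weight. Schematically:

    import torch

    inf = 1000000000.0                                # same constant as the attribute above

    logits = torch.randn(2, 4, 8, 8)                  # [bs, heads, n_query, n_key]
    key_mask = torch.ones(2, 8, dtype=torch.bool)
    key_mask[:, 6:] = False                           # last two keys are padding

    logits = logits.masked_fill(~key_mask[:, None, None, :], -inf)
    attn = logits.softmax(dim=-1)                     # padded keys get near-zero attention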
promptbind.models.egnn module¶
- Most of the code is copied from https://github.com/vgsatorras/egnn, the official implementation of the paper:
E(n) Equivariant Graph Neural Networks, Victor Garcia Satorras, Emiel Hoogeboom, Max Welling
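For reference, the equivariant graph convolutional layer (EGCL) update from that paper, which the multi-channel layers below generalize, is

$$
\begin{aligned}
m_{ij} &= \phi_e\big(h_i^l,\ h_j^l,\ \lVert x_i^l - x_j^l \rVert^2,\ a_{ij}\big),\\
x_i^{l+1} &= x_i^l + C \sum_{j \ne i} \big(x_i^l - x_j^l\big)\,\phi_x(m_{ij}),\\
h_i^{l+1} &= \phi_h\Big(h_i^l,\ \sum_{j \ne i} m_{ij}\Big),
\end{aligned}
$$

where \(\phi_e\), \(\phi_x\), \(\phi_h\) are learned MLPs and \(a_{ij}\) are optional edge attributes.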
- class promptbind.models.egnn.MCAttEGNN(args, in_node_nf, hidden_nf, out_node_nf, prompt_nf, n_channel, in_edge_nf=0, act_fn=SiLU(), n_layers=4, residual=True, dropout=0.1, dense=False, normalize_coord=None, unnormalize_coord=None, geometry_reg_step_size=0.001)¶
Bases: Module
- forward(h, x, ctx_edges, att_edges, LAS_edge_list, batched_complex_coord_LAS, segment_id=None, batch_id=None, reduced_tuple=None, pair_embed_batched=None, pair_mask=None, LAS_mask=None, p_p_dist_embed=None, c_c_dist_embed=None, mask=None, ctx_edge_attr=None, att_edge_attr=None, return_attention=False, prompt_node=None, prompt_coord=None)¶
Define the computation performed at every call.
Should be overridden by all subclasses.
Note: Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- class promptbind.models.egnn.MC_Att_L(args, input_nf, output_nf, hidden_nf, n_channel, edges_in_d=0, act_fn=SiLU(), dropout=0.1, coord_change_maximum=10, opm=False, normalize_coord=None)¶
Bases: Module
Multi-Channel Attention Layer
- att_model(h, edge_index, radial, edge_attr, pair_embed=None)¶
- Parameters:
h – [bs * n_node, input_size]
edge_index – list of [n_edge], [n_edge]
radial – [n_edge, n_channel, n_channel]
edge_attr – [n_edge, edge_dim]
- coord_model(coord, edge_index, coord_diff, att_weight, v)¶
- Parameters:
coord – [bs * n_node, n_channel, d]
edge_index – list of [n_edge], [n_edge]
coord_diff – [n_edge, n_channel, d]
att_weight – [n_edge, 1], unsqueezed before passed in
v – [n_edge, hidden_size]
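Putting these shapes together, coord_model plausibly performs an attention-weighted EGNN-style coordinate update: each edge's coordinate difference is scaled by its attention weight and a learned scalar derived from v, then summed per node. A sketch of that pattern (coord_mlp is a hypothetical stand-in for the layer's learned edge-to-scalar MLP, and the clamp mirrors coord_change_maximum=10):

    import torch

    def coord_model_sketch(coord, edge_index, coord_diff, att_weight, v, coord_mlp):
        row, _ = edge_index                          # updates aggregate at the row nodes
        scale = att_weight * coord_mlp(v)            # [n_edge, 1]: attention times learned step size
        trans = coord_diff * scale[..., None]        # [n_edge, n_channel, d]
        agg = torch.zeros_like(coord)
        agg.index_add_(0, row, trans)                # sum incoming coordinate updates per node
        return coord + agg.clamp(-10, 10)            # cap the per-step displacement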
- forward(h, edge_index, coord, edge_attr=None, segment_id=None, batch_id=None, reduced_tuple=None, pair_embed_batched=None, pair_mask=None, LAS_mask=None, p_p_dist_embed=None, c_c_dist_embed=None)¶
Define the computation performed at every call.
Should be overridden by all subclasses.
Note: Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- node_model(h, edge_index, att_weight, v)¶
- Parameters:
h – [bs * n_node, input_size]
edge_index – list of [n_edge], [n_edge]
att_weight – [n_edge, 1], unsqueezed before passed in
v – [n_edge, hidden_size]
- trio_encoder(h, edge_index, coord, pair_embed_batched=None, pair_mask=None, batch_id=None, segment_id=None, reduced_tuple=None, LAS_mask=None, p_p_dist_embed=None, c_c_dist_embed=None)¶
- class promptbind.models.egnn.MC_E_GCL(args, input_nf, output_nf, hidden_nf, n_channel, edges_in_d=0, act_fn=SiLU(), residual=True, attention=False, normalize=False, coords_agg='mean', tanh=False, dropout=0.1, coord_change_maximum=10)¶
Bases: Module
Multi-Channel E(n) Equivariant Convolutional Layer
- coord_model(coord, edge_index, coord_diff, edge_feat)¶
- Parameters:
coord – [bs * n_node, n_channel, d]
edge_index – list of [n_edge], [n_edge]
coord_diff – [n_edge, n_channel, d]
edge_feat – [n_edge, hidden_size]
- edge_model(source, target, radial, edge_attr)¶
- Parameters:
source – [n_edge, input_size]
target – [n_edge, input_size]
radial – [n_edge, n_channel, n_channel]
edge_attr – [n_edge, edge_dim]
- forward(h, edge_index, coord, edge_attr=None, node_attr=None, batch_id=None)¶
- Parameters:
h – [bs * n_node, hidden_size]
edge_index – list of [n_row] and [n_col], where n_row == n_col (with no cutoff, n_row == bs * n_node * (n_node - 1))
coord – [bs * n_node, n_channel, d]
- node_model(x, edge_index, edge_attr, node_attr)¶
- Parameters:
x – [bs * n_node, input_size]
edge_index – list of [n_edge], [n_edge]
edge_attr – [n_edge, hidden_size], the message passed from node i to node j
node_attr – [bs * n_node, node_dim]
- class promptbind.models.egnn.MC_E_GCL_Prompt(args, input_nf, output_nf, hidden_nf, prompt_nf, n_channel, edges_in_d=0, act_fn=SiLU(), residual=True, attention=False, normalize=False, coords_agg='mean', tanh=False, dropout=0.1, coord_change_maximum=10)¶
Bases: Module
Multi-Channel E(n) Equivariant Convolutional Layer
- coord_model(coord, edge_index, coord_diff, edge_feat_agg)¶
- Parameters:
coord – [bs * n_node, n_channel, d]
edge_index – list of [n_edge], [n_edge]
coord_diff – [n_edge, n_channel, d]
edge_feat_agg – [n_edge, hidden_size]
- edge_model(source, target, radial, edge_attr)¶
- Parameters:
source – [n_edge, input_size]
target – [n_edge, input_size]
radial – [n_edge, n_channel, n_channel]
edge_attr – [n_edge, edge_dim]
- forward(h, edge_index, coord, edge_attr=None, node_attr=None, batch_id=None, prompt_node=None, prompt_coord=None)¶
- Parameters:
h – [bs * n_node, hidden_size]
edge_index – list of [n_row] and [n_col], where n_row == n_col (with no cutoff, n_row == bs * n_node * (n_node - 1))
coord – [bs * n_node, n_channel, d]
- node_model(x, edge_index, edge_attr_agg, node_attr)¶
- Parameters:
x – [bs * n_node, input_size]
edge_index – list of [n_edge], [n_edge]
edge_attr_agg – [n_edge, hidden_size], the message passed from node i to node j
node_attr – [bs * n_node, node_dim]
- prompt_generation_module(edge_feat, prompt_node, prompt_coord)¶
- prompt_interaction_module(edge_feat, prompt_node_feat, prompt_coord_feat)¶
- promptbind.models.egnn.coord2radial(edge_index, coord, rm_F_norm, batch_id=None, norm_type=None)¶
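Given the [n_edge, n_channel, n_channel] shape of the radial argument seen in the layers above, coord2radial presumably builds, per edge, the Gram matrix of multi-channel coordinate differences. A shape-consistent sketch (the rm_F_norm / norm_type options of the real function are omitted):

    import torch

    def coord2radial_sketch(edge_index, coord):
        row, col = edge_index
        coord_diff = coord[row] - coord[col]                        # [n_edge, n_channel, d]
        radial = torch.bmm(coord_diff, coord_diff.transpose(1, 2))  # [n_edge, n_channel, n_channel]
        return radial, coord_diff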
- promptbind.models.egnn.get_edges(n_nodes)¶
- promptbind.models.egnn.get_edges_batch(n_nodes, batch_size)¶
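In the upstream egnn repository, these helpers enumerate a fully connected graph; with no cutoff this yields exactly the bs * n_node * (n_node - 1) edges mentioned in MC_E_GCL.forward above. A sketch of that behavior (the originals may also return edge attributes or tensors rather than lists):

    def get_edges_sketch(n_nodes):
        # All ordered pairs (i, j) with i != j of a fully connected graph.
        rows, cols = [], []
        for i in range(n_nodes):
            for j in range(n_nodes):
                if i != j:
                    rows.append(i)
                    cols.append(j)
        return [rows, cols]

    def get_edges_batch_sketch(n_nodes, batch_size):
        # Replicate the edge list per graph, offsetting node indices by n_nodes each time.
        rows, cols = get_edges_sketch(n_nodes)
        edges = [[], []]
        for b in range(batch_size):
            edges[0].extend(r + b * n_nodes for r in rows)
            edges[1].extend(c + b * n_nodes for c in cols)
        return edges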
promptbind.models.model module¶
- class promptbind.models.model.IaBNet_mean_and_pocket_prediction_cls_coords_dependent(args, embedding_channels=128, pocket_pred_embedding_channels=128)¶
Bases: Module
- forward(data, stage=1, train=False)¶
Define the computation performed at every call.
Should be overridden by all subclasses.
Note: Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- inference(data)¶
- class promptbind.models.model.Transition_diff_out_dim(embedding_channels=256, out_channels=256, n=4)¶
Bases: Module
- forward(z)¶
Define the computation performed at every call.
Should be overridden by all subclasses.
Note: Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- promptbind.models.model.get_model(args, logger, device)¶
promptbind.models.model_utils module¶
- class promptbind.models.model_utils.Attention(c_q: int, c_k: int, c_v: int, c_hidden: int, no_heads: int, gating: bool = True)¶
Bases: Module
Standard multi-head attention using AlphaFold’s default layer initialization. Allows multiple bias vectors.
- class promptbind.models.model_utils.GaussianSmearing(start=0.0, stop=5.0, num_gaussians=50)¶
Bases: Module
- forward(dist)¶
Define the computation performed at every call.
Should be overridden by all subclasses.
Note: Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
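The defaults (start=0.0, stop=5.0, num_gaussians=50) match the standard SchNet-style radial basis expansion: each scalar distance is mapped to num_gaussians activations centered on an evenly spaced grid. A sketch of that well-known pattern (the actual class may differ in details):

    import torch

    class GaussianSmearingSketch(torch.nn.Module):
        def __init__(self, start=0.0, stop=5.0, num_gaussians=50):
            super().__init__()
            offset = torch.linspace(start, stop, num_gaussians)
            self.coeff = -0.5 / (offset[1] - offset[0]).item() ** 2
            self.register_buffer("offset", offset)

        def forward(self, dist):
            diff = dist.view(-1, 1) - self.offset.view(1, -1)  # [n_dist, num_gaussians]
            return torch.exp(self.coeff * diff.pow(2))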
- class promptbind.models.model_utils.InteractionModule(node_hidden_dim, pair_hidden_dim, hidden_dim, opm=False, rm_layernorm=False)¶
Bases: Module
- forward(p_embed, c_embed, p_mask=None, c_mask=None)¶
Define the computation performed at every call.
Should be overridden by all subclasses.
Note: Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- class promptbind.models.model_utils.RBFDistanceModule(rbf_stop, distance_hidden_dim, num_gaussian=32, dropout=0.1)¶
Bases: Module
- forward(distance)¶
Define the computation performed at every call.
Should be overridden by all subclasses.
Note: Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- class promptbind.models.model_utils.Transition(hidden_dim=128, n=4, rm_layernorm=False)¶
Bases: Module
- forward(x)¶
Define the computation performed at every call.
Should be overridden by all subclasses.
Note: Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- promptbind.models.model_utils.flatten_final_dims(t: Tensor, no_dims: int)¶
- promptbind.models.model_utils.permute_final_dims(tensor: Tensor, inds: List[int])¶
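These two signatures match the tensor utilities that ship with OpenFold/AlphaFold-style codebases, whose standard implementations are short enough to restate as a sketch:

    import torch
    from typing import List

    def permute_final_dims_sketch(tensor: torch.Tensor, inds: List[int]) -> torch.Tensor:
        # Permute only the last len(inds) dimensions, leaving leading batch dims in place.
        zero_index = -1 * len(inds)
        first_inds = list(range(len(tensor.shape[:zero_index])))
        return tensor.permute(first_inds + [zero_index + i for i in inds])

    def flatten_final_dims_sketch(t: torch.Tensor, no_dims: int) -> torch.Tensor:
        # Collapse the last no_dims dimensions into a single one.
        return t.reshape(t.shape[:-no_dims] + (-1,))

For example, permute_final_dims_sketch(x, [2, 0, 1]) turns a [..., A, B, C] tensor into [..., C, A, B].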