promptbind.models package

Submodules

promptbind.models.att_model module

class promptbind.models.att_model.ComplexGraph(args, inter_cutoff=10, intra_cutoff=8, normalize_coord=None, unnormalize_coord=None)

Bases: Module

construct_edges(X, batch_id, segment_ids, is_global)

Memory-efficient edge construction with O(Nn) complexity, where N is the total number of nodes in the batch and n is the largest number of nodes in any single graph of the batch.
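The complexity claim can be illustrated with a simplified, hypothetical sketch (`radius_edges_per_graph` is not part of the package). It assumes one 3-D coordinate per node and a single distance cutoff, and it ignores `segment_ids`, `is_global` and the separate `inter_cutoff`/`intra_cutoff` of the real method. Grouping nodes by graph first means each node is compared only with the at most n nodes of its own graph, which gives O(Nn) distance checks instead of O(N²) over the whole batch.

```python
import math
from collections import defaultdict


def radius_edges_per_graph(coords, batch_id, cutoff):
    """Hypothetical sketch: directed edges between same-graph nodes within `cutoff`.

    coords:   list of (x, y, z) tuples, one per node
    batch_id: list giving the graph index of each node
    """
    # Group node indices by graph so that cross-graph pairs are never examined.
    groups = defaultdict(list)
    for idx, graph in enumerate(batch_id):
        groups[graph].append(idx)

    rows, cols = [], []
    for members in groups.values():
        # Each node is checked against at most n nodes (the size of its graph).
        for i in members:
            for j in members:
                if i != j and math.dist(coords[i], coords[j]) <= cutoff:
                    rows.append(i)
                    cols.append(j)
    return rows, cols


coords = [(0.0, 0.0, 0.0), (1.0, 0.0, 0.0), (20.0, 0.0, 0.0), (0.0, 0.0, 0.0)]
rows, cols = radius_edges_per_graph(coords, batch_id=[0, 0, 0, 1], cutoff=8.0)
print(rows, cols)  # [0, 1] [1, 0]
```

Node 3 belongs to a different graph, so it is never compared with nodes 0 to 2, and node 2 lies outside the cutoff.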

forward(X, batch_id, segment_id, is_global)

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class promptbind.models.att_model.EfficientMCAttModel(args, embed_size, hidden_size, prompt_nf, n_channel, n_edge_feats=0, n_layers=5, dropout=0.1, n_iter=5, dense=False, inter_cutoff=10, intra_cutoff=8, normalize_coord=None, unnormalize_coord=None)

Bases: Module

forward(X, H, batch_id, segment_id, mask, is_global, compound_edge_index, LAS_edge_index, batched_complex_coord_LAS, LAS_mask=None, prompt_node=None, prompt_coord=None)
Parameters:
  • X – [n_all_node, n_channel, 3]

  • S – [n_all_node]

  • batch_id – [n_all_node]

promptbind.models.att_model.sequential_and(*tensors)
promptbind.models.att_model.sequential_or(*tensors)
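Both helpers are undocumented. Judging only from their names, they presumably combine several boolean masks into one; a minimal sketch under that assumption folds `&` or `|` over the arguments from left to right. This is illustrative rather than the package's source, and it works for `torch.bool` tensors as well as for plain Python ints used as bit masks.

```python
import functools
import operator


def sequential_and(*tensors):
    """Sketch: elementwise AND of all arguments, folded left to right."""
    return functools.reduce(operator.and_, tensors)


def sequential_or(*tensors):
    """Sketch: elementwise OR of all arguments, folded left to right."""
    return functools.reduce(operator.or_, tensors)


# Integers stand in for boolean masks here (one bit per element).
print(bin(sequential_and(0b1110, 0b1011, 0b0011)))  # 0b10
print(bin(sequential_or(0b0001, 0b0100)))           # 0b101
```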

promptbind.models.cross_att module

class promptbind.models.cross_att.CrossAttentionModule(node_hidden_dim, pair_hidden_dim, rm_layernorm=False, keep_trig_attn=False, dist_hidden_dim=32, normalize_coord=None)

Bases: Module

forward(p_embed_batched, p_mask, c_embed_batched, c_mask, pair_embed, pair_mask, c_c_dist_embed=None, p_p_dist_embed=None)

class promptbind.models.cross_att.RowAttentionBlock(node_hidden_dim, pair_hidden_dim, attention_hidden_dim=32, no_heads=4, dropout=0.1, rm_layernorm=False)

Bases: Module

forward(node_embed_i, node_embed_j, pair_embed, pair_mask, node_mask_i)

inf = 1000000000.0
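A large finite constant such as `inf` is commonly used to mask attention logits: adding `-inf` to padded positions drives their softmax weight to zero. Whether `RowAttentionBlock` applies it exactly this way is an assumption; the helper name `masked_softmax` below is hypothetical.

```python
import math

INF = 1e9  # same magnitude as the documented `inf` attribute


def masked_softmax(logits, mask):
    """Softmax in which positions with mask == 0 receive a -INF bias."""
    # mask is 1 for valid positions and 0 for padding.
    biased = [x + INF * (m - 1) for x, m in zip(logits, mask)]
    top = max(biased)  # subtract the max for numerical stability
    exps = [math.exp(x - top) for x in biased]
    total = sum(exps)
    return [e / total for e in exps]


weights = masked_softmax([1.0, 2.0, 3.0], [1, 1, 0])
print(weights[2])  # 0.0
```

A finite constant keeps the arithmetic free of `inf - inf` NaNs when an entire row is masked, which is a common reason to prefer 1e9 over `float('inf')`.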
class promptbind.models.cross_att.RowTriangleAttentionBlock(pair_hidden_dim, dist_hidden_dim, attention_hidden_dim=32, no_heads=4, dropout=0.1, rm_layernorm=False)

Bases: Module

forward(pair_embed, pair_mask, dist_embed)

inf = 1000000000.0

promptbind.models.egnn module

Most of the code is copied from https://github.com/vgsatorras/egnn, the official implementation of the paper:

E(n) Equivariant Graph Neural Networks. Victor Garcia Satorras, Emiel Hoogeboom, Max Welling.
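For reference, the E(n)-equivariant graph convolutional layer (EGCL) defined in that paper updates node embeddings h_i and coordinates x_i as below, where a_ij are edge attributes, φ_e, φ_x and φ_h are MLPs, and C = 1/(M - 1) for a graph with M nodes. The multi-channel layers in this module generalize each x_i to n_channel coordinate vectors per node.

```latex
\begin{aligned}
m_{ij} &= \phi_e\left(h_i^l, h_j^l, \lVert x_i^l - x_j^l \rVert^2, a_{ij}\right) \\
x_i^{l+1} &= x_i^l + C \sum_{j \neq i} \left(x_i^l - x_j^l\right) \phi_x(m_{ij}) \\
m_i &= \sum_{j \neq i} m_{ij} \\
h_i^{l+1} &= \phi_h\left(h_i^l, m_i\right)
\end{aligned}
```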

class promptbind.models.egnn.MCAttEGNN(args, in_node_nf, hidden_nf, out_node_nf, prompt_nf, n_channel, in_edge_nf=0, act_fn=SiLU(), n_layers=4, residual=True, dropout=0.1, dense=False, normalize_coord=None, unnormalize_coord=None, geometry_reg_step_size=0.001)

Bases: Module

forward(h, x, ctx_edges, att_edges, LAS_edge_list, batched_complex_coord_LAS, segment_id=None, batch_id=None, reduced_tuple=None, pair_embed_batched=None, pair_mask=None, LAS_mask=None, p_p_dist_embed=None, c_c_dist_embed=None, mask=None, ctx_edge_attr=None, att_edge_attr=None, return_attention=False, prompt_node=None, prompt_coord=None)

class promptbind.models.egnn.MC_Att_L(args, input_nf, output_nf, hidden_nf, n_channel, edges_in_d=0, act_fn=SiLU(), dropout=0.1, coord_change_maximum=10, opm=False, normalize_coord=None)

Bases: Module

Multi-Channel Attention Layer

att_model(h, edge_index, radial, edge_attr, pair_embed=None)
Parameters:
  • h – [bs * n_node, input_size]

  • edge_index – list of [n_edge], [n_edge]

  • radial – [n_edge, n_channel, n_channel]

  • edge_attr – [n_edge, edge_dim]

coord_model(coord, edge_index, coord_diff, att_weight, v)
Parameters:
  • coord – [bs * n_node, n_channel, d]

  • edge_index – list of [n_edge], [n_edge]

  • coord_diff – [n_edge, n_channel, d]

  • att_weight – [n_edge, 1], unsqueezed before being passed in

  • v – [n_edge, hidden_size]

forward(h, edge_index, coord, edge_attr=None, segment_id=None, batch_id=None, reduced_tuple=None, pair_embed_batched=None, pair_mask=None, LAS_mask=None, p_p_dist_embed=None, c_c_dist_embed=None)

node_model(h, edge_index, att_weight, v)
Parameters:
  • h – [bs * n_node, input_size]

  • edge_index – list of [n_edge], [n_edge]

  • att_weight – [n_edge, 1], unsqueezed before being passed in

  • v – [n_edge, hidden_size]

trio_encoder(h, edge_index, coord, pair_embed_batched=None, pair_mask=None, batch_id=None, segment_id=None, reduced_tuple=None, LAS_mask=None, p_p_dist_embed=None, c_c_dist_embed=None)

class promptbind.models.egnn.MC_E_GCL(args, input_nf, output_nf, hidden_nf, n_channel, edges_in_d=0, act_fn=SiLU(), residual=True, attention=False, normalize=False, coords_agg='mean', tanh=False, dropout=0.1, coord_change_maximum=10)

Bases: Module

Multi-Channel E(n) Equivariant Convolutional Layer

coord_model(coord, edge_index, coord_diff, edge_feat)

Parameters:
  • coord – [bs * n_node, n_channel, d]

  • edge_index – list of [n_edge], [n_edge]

  • coord_diff – [n_edge, n_channel, d]

  • edge_feat – [n_edge, hidden_size]
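`coord_model` is documented only by its shapes. In the upstream EGNN code the equivalent step scales each `coord_diff` by an MLP output computed from the edge feature, aggregates the results per receiving node (sum or mean, as selected by `coords_agg`), and adds them to `coord`. The sketch below follows that pattern with two labelled assumptions: a per-edge scalar `edge_weight` stands in for the MLP output, and clamping each component of the update to `±max_change` is only a guess at how `coord_change_maximum` is used.

```python
def egnn_coord_update(coord, rows, coord_diff, edge_weight,
                      coords_agg="mean", max_change=10.0):
    """EGNN-style multi-channel coordinate update (illustrative sketch).

    coord:       [n_node][n_channel][d] node coordinates
    rows:        [n_edge] index of the node that receives each message
    coord_diff:  [n_edge][n_channel][d], x_row - x_col per channel
    edge_weight: [n_edge] scalars standing in for phi_x(edge_feat)
    """
    n_node = len(coord)
    n_channel, d = len(coord[0]), len(coord[0][0])
    agg = [[[0.0] * d for _ in range(n_channel)] for _ in range(n_node)]
    count = [0] * n_node
    # Scatter the weighted difference vectors onto the receiving nodes.
    for e, i in enumerate(rows):
        count[i] += 1
        for c in range(n_channel):
            for k in range(d):
                agg[i][c][k] += coord_diff[e][c][k] * edge_weight[e]
    new_coord = []
    for i in range(n_node):
        # "mean" divides by the in-degree (at least 1); "sum" keeps the total.
        scale = 1.0 / max(count[i], 1) if coords_agg == "mean" else 1.0
        # Assumed clamp: each component of the update stays within ±max_change.
        new_coord.append([
            [coord[i][c][k] + max(-max_change, min(max_change, agg[i][c][k] * scale))
             for k in range(d)]
            for c in range(n_channel)
        ])
    return new_coord


coord = [[[0.0]], [[2.0]]]       # two nodes, one channel, d = 1
coord_diff = [[[-2.0]], [[2.0]]]  # edge 0 -> node 0, edge 1 -> node 1
print(egnn_coord_update(coord, [0, 1], coord_diff, [0.5, 0.5]))  # [[[-1.0]], [[3.0]]]
```

Because the update is a weighted sum of difference vectors, translating every input coordinate by the same vector translates the output by that vector, and rotating them rotates the output, which is the equivariance property the layer relies on.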

edge_model(source, target, radial, edge_attr)
Parameters:
  • source – [n_edge, input_size]

  • target – [n_edge, input_size]

  • radial – [n_edge, n_channel, n_channel]

  • edge_attr – [n_edge, edge_dim]

forward(h, edge_index, coord, edge_attr=None, node_attr=None, batch_id=None)

Parameters:
  • h – [bs * n_node, hidden_size]

  • edge_index – list of [n_row] and [n_col], where n_row == n_col (with no cutoff, n_row == bs * n_node * (n_node - 1))

  • coord – [bs * n_node, n_channel, d]

node_model(x, edge_index, edge_attr, node_attr)
Parameters:
  • x – [bs * n_node, input_size]

  • edge_index – list of [n_edge], [n_edge]

  • edge_attr – [n_edge, hidden_size], the message from i to j

  • node_attr – [bs * n_node, node_dim]

class promptbind.models.egnn.MC_E_GCL_Prompt(args, input_nf, output_nf, hidden_nf, prompt_nf, n_channel, edges_in_d=0, act_fn=SiLU(), residual=True, attention=False, normalize=False, coords_agg='mean', tanh=False, dropout=0.1, coord_change_maximum=10)

Bases: Module

Multi-Channel E(n) Equivariant Convolutional Layer

coord_model(coord, edge_index, coord_diff, edge_feat_agg)

Parameters:
  • coord – [bs * n_node, n_channel, d]

  • edge_index – list of [n_edge], [n_edge]

  • coord_diff – [n_edge, n_channel, d]

  • edge_feat_agg – [n_edge, hidden_size]

edge_model(source, target, radial, edge_attr)
Parameters:
  • source – [n_edge, input_size]

  • target – [n_edge, input_size]

  • radial – [n_edge, n_channel, n_channel]

  • edge_attr – [n_edge, edge_dim]

forward(h, edge_index, coord, edge_attr=None, node_attr=None, batch_id=None, prompt_node=None, prompt_coord=None)

Parameters:
  • h – [bs * n_node, hidden_size]

  • edge_index – list of [n_row] and [n_col], where n_row == n_col (with no cutoff, n_row == bs * n_node * (n_node - 1))

  • coord – [bs * n_node, n_channel, d]

node_model(x, edge_index, edge_attr_agg, node_attr)
Parameters:
  • x – [bs * n_node, input_size]

  • edge_index – list of [n_edge], [n_edge]

  • edge_attr_agg – [n_edge, hidden_size], the message from i to j

  • node_attr – [bs * n_node, node_dim]

prompt_generation_module(edge_feat, prompt_node, prompt_coord)
prompt_interaction_module(edge_feat, prompt_node_feat, prompt_coord_feat)
promptbind.models.egnn.coord2radial(edge_index, coord, rm_F_norm, batch_id=None, norm_type=None)
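`coord2radial` is undocumented. In multi-channel EGNN variants the radial feature of an edge is usually the Gram matrix of the per-channel coordinate differences, which yields the [n_edge, n_channel, n_channel] shape listed for `radial` above and is invariant to rotations and translations. The sketch below assumes that form; the hypothetical `multichannel_radial` omits the extra normalization that `rm_F_norm` and `norm_type` appear to control, as well as `batch_id`.

```python
def multichannel_radial(rows, cols, coord):
    """Hypothetical sketch of a multi-channel radial feature (no normalization).

    coord: [n_node][n_channel][d] nested lists.
    Returns (radial, coord_diff) with shapes [n_edge][n_channel][n_channel]
    and [n_edge][n_channel][d].
    """
    radial, coord_diff = [], []
    for i, j in zip(rows, cols):
        # Per-channel difference vectors for edge (i, j).
        diff = [[a - b for a, b in zip(ci, cj)] for ci, cj in zip(coord[i], coord[j])]
        # Gram matrix: entry (c1, c2) is the dot product of channels c1 and c2.
        gram = [[sum(p * q for p, q in zip(u, v)) for v in diff] for u in diff]
        radial.append(gram)
        coord_diff.append(diff)
    return radial, coord_diff


coord = [[[1.0, 0.0, 0.0], [0.0, 2.0, 0.0]],
         [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]]
radial, _ = multichannel_radial([0], [1], coord)
print(radial[0])  # [[1.0, 0.0], [0.0, 4.0]]
```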
promptbind.models.egnn.get_edges(n_nodes)
promptbind.models.egnn.get_edges_batch(n_nodes, batch_size)
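Assuming these helpers mirror the upstream EGNN repository, `get_edges` builds a fully connected edge list (every ordered pair i != j) and `get_edges_batch` repeats it for each graph, offsetting node indices of graph b by b * n_nodes; upstream `get_edges_batch` also converts the lists to tensors and returns an all-ones `edge_attr`, which this stdlib sketch leaves out. The resulting edge count matches the no-cutoff figure bs * n_node * (n_node - 1) quoted for `MC_E_GCL.forward`.

```python
def get_edges(n_nodes):
    """All ordered pairs (i, j) with i != j, as [rows, cols]."""
    rows, cols = [], []
    for i in range(n_nodes):
        for j in range(n_nodes):
            if i != j:
                rows.append(i)
                cols.append(j)
    return [rows, cols]


def get_edges_batch(n_nodes, batch_size):
    """Fully connected edges for `batch_size` graphs of `n_nodes` nodes each."""
    rows, cols = get_edges(n_nodes)
    batched_rows, batched_cols = [], []
    for b in range(batch_size):
        offset = b * n_nodes  # nodes of graph b occupy [offset, offset + n_nodes)
        batched_rows.extend(r + offset for r in rows)
        batched_cols.extend(c + offset for c in cols)
    return [batched_rows, batched_cols]


print(get_edges(3))  # [[0, 0, 1, 1, 2, 2], [1, 2, 0, 2, 0, 1]]
rows, cols = get_edges_batch(3, 2)
print(len(rows))  # 12 == 2 * 3 * (3 - 1)
```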
promptbind.models.egnn.unsorted_segment_mean(data, segment_ids, num_segments)
Parameters:
  • data – [n_edge, *dimensions]

  • segment_ids – [n_edge]

  • num_segments – number of output segments, typically bs * n_node

promptbind.models.egnn.unsorted_segment_sum(data, segment_ids, num_segments)
Parameters:
  • data – [n_edge, *dimensions]

  • segment_ids – [n_edge]

  • num_segments – number of output segments, typically bs * n_node
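In the upstream EGNN code these helpers scatter per-edge rows into `num_segments` output rows with `scatter_add_`, and the mean divides by the per-segment count clamped to at least 1, so segments that receive no rows stay zero. A stdlib sketch of the same semantics, restricted to non-empty 2-D data:

```python
def unsorted_segment_sum(data, segment_ids, num_segments):
    """Sum rows of `data` ([n_edge][dim]) into `num_segments` output rows."""
    dim = len(data[0])
    result = [[0.0] * dim for _ in range(num_segments)]
    for row, seg in zip(data, segment_ids):
        for k, value in enumerate(row):
            result[seg][k] += value
    return result


def unsorted_segment_mean(data, segment_ids, num_segments):
    """Per-segment mean; empty segments divide by 1 and therefore stay zero."""
    sums = unsorted_segment_sum(data, segment_ids, num_segments)
    counts = [0] * num_segments
    for seg in segment_ids:
        counts[seg] += 1
    return [[v / max(c, 1) for v in row] for row, c in zip(sums, counts)]


data = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
print(unsorted_segment_sum(data, [0, 0, 2], 3))   # [[4.0, 6.0], [0.0, 0.0], [5.0, 6.0]]
print(unsorted_segment_mean(data, [0, 0, 2], 3))  # [[2.0, 3.0], [0.0, 0.0], [5.0, 6.0]]
```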

promptbind.models.model module

class promptbind.models.model.IaBNet_mean_and_pocket_prediction_cls_coords_dependent(args, embedding_channels=128, pocket_pred_embedding_channels=128)

Bases: Module

forward(data, stage=1, train=False)

inference(data)

class promptbind.models.model.Transition_diff_out_dim(embedding_channels=256, out_channels=256, n=4)

Bases: Module

forward(z)

promptbind.models.model.get_model(args, logger, device)

promptbind.models.model_utils module

class promptbind.models.model_utils.Attention(c_q: int, c_k: int, c_v: int, c_hidden: int, no_heads: int, gating: bool = True)

Bases: Module

Standard multi-head attention using AlphaFold’s default layer initialization. Allows multiple bias vectors.

forward(q_x: Tensor, kv_x: Tensor, biases: List[Tensor] | None = None) → Tensor
Parameters:
  • q_x – [*, Q, C_q] query data

  • kv_x – [*, K, C_k] key data

  • biases – List of biases that broadcast to [*, H, Q, K]

Returns

[*, Q, C_q] attention update
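Leaving out the learned query/key/value projections, the head split and the optional sigmoid gating, the core of this kind of attention is scaled dot-product attention in which every bias term is added to the logits before the softmax. A single-head stdlib sketch (the function name is hypothetical):

```python
import math


def attention_with_biases(q, k, v, biases=()):
    """Single-head sketch: softmax(q k^T / sqrt(c) + sum(biases)) v.

    q: [Q][c], k: [K][c], v: [K][c_v]; every bias is a [Q][K] nested list.
    Learned projections, multiple heads and gating are deliberately omitted.
    """
    c = len(q[0])
    out = []
    for i, qi in enumerate(q):
        # Scaled dot-product logits of query i against every key.
        logits = [sum(a * b for a, b in zip(qi, kj)) / math.sqrt(c) for kj in k]
        # Each bias term is added to the logits before the softmax.
        for bias in biases:
            logits = [x + b for x, b in zip(logits, bias[i])]
        top = max(logits)
        weights = [math.exp(x - top) for x in logits]
        total = sum(weights)
        weights = [w / total for w in weights]
        # Weighted sum of the value vectors.
        out.append([sum(w * vj[d] for w, vj in zip(weights, v)) for d in range(len(v[0]))])
    return out


q = [[0.0, 0.0]]
k = [[1.0, 0.0], [0.0, 1.0]]
v = [[1.0], [3.0]]
print(attention_with_biases(q, k, v))                   # [[2.0]]
print(attention_with_biases(q, k, v, [[[0.0, -1e9]]]))  # [[1.0]]
```

The second call shows how a bias can act as a mask: a large negative entry removes key 1 from the weighted sum.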

class promptbind.models.model_utils.GaussianSmearing(start=0.0, stop=5.0, num_gaussians=50)

Bases: Module

forward(dist)

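The defaults (start=0.0, stop=5.0, num_gaussians=50) match the GaussianSmearing layer used in SchNet implementations such as PyTorch Geometric's, which expands a distance into Gaussian basis functions centred on evenly spaced offsets, with a width equal to the spacing. Assuming this module follows the same recipe, a stdlib sketch for a single distance:

```python
import math


def gaussian_smearing(dist, start=0.0, stop=5.0, num_gaussians=50):
    """Expand one distance into `num_gaussians` Gaussian basis values (sketch)."""
    step = (stop - start) / (num_gaussians - 1)
    # Evenly spaced centres from start to stop, inclusive.
    offsets = [start + i * step for i in range(num_gaussians)]
    # Width tied to the spacing: exp(-0.5 * ((dist - mu) / step) ** 2).
    coeff = -0.5 / step ** 2
    return [math.exp(coeff * (dist - mu) ** 2) for mu in offsets]


feats = gaussian_smearing(2.5, start=0.0, stop=5.0, num_gaussians=3)
print(len(feats), feats[1])  # 3 1.0
```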

class promptbind.models.model_utils.InteractionModule(node_hidden_dim, pair_hidden_dim, hidden_dim, opm=False, rm_layernorm=False)

Bases: Module

forward(p_embed, c_embed, p_mask=None, c_mask=None)

class promptbind.models.model_utils.RBFDistanceModule(rbf_stop, distance_hidden_dim, num_gaussian=32, dropout=0.1)

Bases: Module

forward(distance)

class promptbind.models.model_utils.Transition(hidden_dim=128, n=4, rm_layernorm=False)

Bases: Module

forward(x)

promptbind.models.model_utils.flatten_final_dims(t: Tensor, no_dims: int)
promptbind.models.model_utils.permute_final_dims(tensor: Tensor, inds: List[int])
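These two helpers share their names and signatures with OpenFold's tensor utilities, where `permute_final_dims` permutes only the last `len(inds)` dimensions (leading batch dimensions stay in place) and `flatten_final_dims` merges the last `no_dims` dimensions into one. Assuming the same semantics, the index arithmetic can be sketched without a tensor library; `permute_final_dims_order` and `flatten_final_dims_shape` are hypothetical names.

```python
def permute_final_dims_order(ndim, inds):
    """Full permutation for Tensor.permute that reorders only the last len(inds) dims."""
    zero_index = -len(inds)
    first = list(range(ndim + zero_index))  # leading dims keep their order
    return first + [ndim + zero_index + i for i in inds]


def flatten_final_dims_shape(shape, no_dims):
    """Shape obtained by merging the last `no_dims` dimensions into one."""
    flat = 1
    for size in shape[len(shape) - no_dims:]:
        flat *= size
    return tuple(shape[:len(shape) - no_dims]) + (flat,)


print(permute_final_dims_order(4, [1, 0]))        # [0, 1, 3, 2]
print(flatten_final_dims_shape((2, 3, 4, 5), 2))  # (2, 3, 20)
```

For example, `permute_final_dims(k, [1, 0])` on a key tensor of shape [*, K, C] swaps its last two dimensions, which is the usual step before computing q k^T.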

Module contents