diffalign.models.epsnet package
Submodules
diffalign.models.epsnet.diffusion module
- class diffalign.models.epsnet.diffusion.DiffAlign
Bases: torch.nn.Module
- DDIM_CFG_Sampling(query_batch, reference_batch, n_steps)
- Parameters:
query_batch (torch_geometric.data.Batch) – A batch of query molecules containing atom_type, edge_index, edge_type, pos, and batch as attributes.
reference_batch (torch_geometric.data.Batch) – A batch of reference molecules with the same attributes as query_batch.
n_steps (int) – The number of DDIM sampling steps.
- Returns:
torch.Tensor: Predicted positions of the query molecules, shaped (num_query_nodes, 3).
list[torch.Tensor]: Trajectory of the query molecules, containing 999 tensors each shaped (num_query_nodes, 3).
- Return type:
tuple[torch.Tensor, list[torch.Tensor]]
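The sketch below illustrates what DDIM sampling with classifier-free guidance (CFG) involves. It runs deterministic DDIM updates (eta = 0) over `n_steps` evenly spaced timesteps. It is not DiffAlign's implementation, and several pieces are assumptions:

- NumPy arrays stand in for torch tensors.
- The noise schedule is a linear beta schedule over T = 1000 steps.
- `eps_model(x, t, condition)` is a hypothetical callable standing in for `forward`.
- The guidance weight `w` and the combination `eps_u + w * (eps_c - eps_u)` are illustrative choices.

```python
import numpy as np


def linear_alpha_bars(T=1000, beta_start=1e-4, beta_end=2e-2):
    """Cumulative products alpha_bar_t for an assumed linear beta schedule."""
    betas = np.linspace(beta_start, beta_end, T)
    return np.cumprod(1.0 - betas)


def ddim_cfg_sample(eps_model, x_T, n_steps, w=1.0, T=1000):
    """Deterministic DDIM (eta = 0) with classifier-free guidance.

    eps_model(x, t, condition) is a hypothetical stand-in for DiffAlign.forward.
    Returns the final positions and a list with one entry per DDIM step.
    """
    alpha_bars = linear_alpha_bars(T)
    # Evenly spaced subset of timesteps, from T - 1 down to 0.
    timesteps = np.linspace(T - 1, 0, n_steps).round().astype(int)
    x = x_T
    trajectory = []
    for i, t in enumerate(timesteps):
        # CFG: move the estimate from the unconditional toward the conditional prediction.
        eps_c = eps_model(x, t, condition=True)
        eps_u = eps_model(x, t, condition=False)
        eps = eps_u + w * (eps_c - eps_u)
        a_t = alpha_bars[t]
        # alpha_bar of the next (earlier) timestep; 1.0 after the last step.
        a_prev = alpha_bars[timesteps[i + 1]] if i + 1 < len(timesteps) else 1.0
        # Estimate x_0, then jump deterministically to the earlier timestep.
        x0_pred = (x - np.sqrt(1.0 - a_t) * eps) / np.sqrt(a_t)
        x = np.sqrt(a_prev) * x0_pred + np.sqrt(1.0 - a_prev) * eps
        trajectory.append(x.copy())
    return x, trajectory
```

With `w = 1` the guided estimate reduces to the conditional prediction. Larger `w` extrapolates further away from the unconditional one.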
- DDPM_CFG_Sampling(query_batch, reference_batch)
- Parameters:
query_batch (torch_geometric.data.Batch) – A batch of query molecules containing atom_type, edge_index, edge_type, pos, and batch as attributes.
reference_batch (torch_geometric.data.Batch) – A batch of reference molecules with the same attributes as query_batch.
- Returns:
torch.Tensor: Predicted positions of the query molecules, shaped (num_query_nodes, 3).
list[torch.Tensor]: Trajectory of the query molecules, containing 999 tensors each shaped (num_query_nodes, 3).
- Return type:
tuple[torch.Tensor, list[torch.Tensor]]
- DDPM_Sampling(query_batch, reference_batch)
- Parameters:
query_batch (torch_geometric.data.Batch) – A batch of query molecules containing atom_type, edge_index, edge_type, pos, and batch as attributes.
reference_batch (torch_geometric.data.Batch) – A batch of reference molecules with the same attributes as query_batch.
- Returns:
torch.Tensor: Predicted positions of the query molecules, shaped (num_query_nodes, 3).
list[torch.Tensor]: Trajectory of the query molecules, containing 999 tensors each shaped (num_query_nodes, 3).
- Return type:
tuple[torch.Tensor, list[torch.Tensor]]
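Both DDPM sampling methods can be pictured as ancestral sampling that visits every timestep from T - 1 down to 0. The minimal sketch below assumes:

- NumPy arrays in place of torch tensors.
- A caller-supplied beta schedule.
- The reverse-process variance sigma_t^2 = beta_t.
- A hypothetical `eps_model(x, t, condition)` in place of `forward`.

Passing `w=None` uses only the conditional prediction. This is an assumption about how DDPM_Sampling differs from DDPM_CFG_Sampling. A numeric `w` applies classifier-free guidance. The returned `(positions, trajectory)` pair mirrors the documented return values.

```python
import numpy as np


def ddpm_sample(eps_model, x_T, betas, w=None, seed=0):
    """Ancestral DDPM sampling; w=None skips classifier-free guidance.

    eps_model(x, t, condition) is a hypothetical stand-in for DiffAlign.forward.
    Returns the final positions and one intermediate array per timestep.
    """
    rng = np.random.default_rng(seed)
    alphas = 1.0 - betas
    alpha_bars = np.cumprod(alphas)
    x = x_T
    trajectory = []
    for t in range(len(betas) - 1, -1, -1):
        eps = eps_model(x, t, condition=True)
        if w is not None:
            # Classifier-free guidance using an extra unconditional prediction.
            eps_u = eps_model(x, t, condition=False)
            eps = eps_u + w * (eps - eps_u)
        # Mean of p(x_{t-1} | x_t) under the epsilon parameterization.
        mean = (x - betas[t] / np.sqrt(1.0 - alpha_bars[t]) * eps) / np.sqrt(alphas[t])
        if t > 0:
            # sigma_t^2 = beta_t, one common choice of reverse variance.
            x = mean + np.sqrt(betas[t]) * rng.standard_normal(x.shape)
        else:
            x = mean  # no noise is added at the final step
        trajectory.append(x.copy())
    return x, trajectory
```

Here the trajectory has one entry per timestep, so its length equals `len(betas)`. The sketch does not attempt to reproduce the exact trajectory length documented above.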
- forward(query_batch, reference_batch, time_step, condition=True)
- Parameters:
query_batch (torch_geometric.data.Batch) – A batch of query molecules containing atom_type, edge_index, edge_type, pos, and batch as attributes.
reference_batch (torch_geometric.data.Batch) – A batch of reference molecules with the same attributes as query_batch.
time_step (torch.Tensor) – Diffusion timestep indices, one per graph in the batch, shaped (num_batches,).
condition (bool, optional) – A flag indicating whether to apply conditioning. Defaults to True.
- Returns:
Predicted noise, shaped (num_query_nodes, 3).
- Return type:
torch.Tensor
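`time_step` holds one entry per graph, while positions and predicted noise are per node. In torch_geometric code, a common way to reconcile the two is to index the per-graph vector with the `batch` assignment vector. The sketch below shows this with NumPy. The helper name `per_node_timesteps` is hypothetical, and we have not confirmed that DiffAlign expands timesteps exactly this way.

```python
import numpy as np


def per_node_timesteps(time_step, batch):
    """Expand one timestep per graph, shape (num_graphs,), to one per node.

    batch[i] is the index of the graph that node i belongs to, matching the
    `batch` attribute of a torch_geometric Batch.
    """
    return np.asarray(time_step)[np.asarray(batch)]


batch = np.array([0, 0, 0, 1, 1])  # graph 0 has 3 nodes, graph 1 has 2
time_step = np.array([120, 815])   # shape (num_graphs,)
node_t = per_node_timesteps(time_step, batch)
print(node_t)  # [120 120 120 815 815]
```

With torch tensors, the same expression `time_step[batch]` works because `batch` is an integer (long) tensor.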
- get_loss(query_batch, reference_batch)
- Parameters:
query_batch (torch_geometric.data.Batch) – A batch of query molecules containing atom_type, edge_index, edge_type, pos, and batch as attributes.
reference_batch (torch_geometric.data.Batch) – A batch of reference molecules with the same attributes as query_batch.
- Returns:
The computed loss as a scalar tensor.
- Return type:
torch.Tensor
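The documentation does not state which loss `get_loss` computes. Because `forward` predicts noise, a plausible assumption is the standard DDPM epsilon-prediction objective:

1. Sample one timestep per graph.
2. Noise the query positions with q(x_t | x_0).
3. Take the mean squared error between predicted and true noise.

The sketch below implements that objective in NumPy with an assumed linear beta schedule. `eps_mse_loss` and its `eps_model(pos_t, node_t)` argument are hypothetical.

```python
import numpy as np


def eps_mse_loss(eps_model, pos, batch, num_graphs, T=1000, rng=None):
    """Assumed DDPM epsilon-prediction loss for a batch of graphs.

    pos: (num_nodes, 3) clean positions; batch: (num_nodes,) graph indices.
    eps_model(pos_t, node_t) is a hypothetical noise predictor.
    """
    rng = np.random.default_rng() if rng is None else rng
    betas = np.linspace(1e-4, 2e-2, T)  # assumed linear schedule
    alpha_bars = np.cumprod(1.0 - betas)
    # One timestep per graph, expanded to that graph's nodes.
    time_step = rng.integers(0, T, size=num_graphs)
    node_t = time_step[batch]
    a = alpha_bars[node_t][:, None]        # (num_nodes, 1), broadcasts over xyz
    noise = rng.standard_normal(pos.shape)  # (num_nodes, 3)
    # Forward process q(x_t | x_0).
    pos_t = np.sqrt(a) * pos + np.sqrt(1.0 - a) * noise
    pred = eps_model(pos_t, node_t)         # (num_nodes, 3)
    return np.mean((pred - noise) ** 2)     # scalar
```

A model that recovers the injected noise exactly drives this loss to zero. A model that always predicts zero scores roughly the mean squared magnitude of standard normal noise.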