diffalign.models.epsnet package
Submodules
diffalign.models.epsnet.diffusion module
- class diffalign.models.epsnet.diffusion.DiffAlign
Bases: PatchedModule
- DDIM_CFG_Sampling(query_batch, reference_batch, n_steps)
- Parameters:
query_batch (torch_geometric.data.Batch) – A batch of query molecules containing atom_type, edge_index, edge_type, pos, and batch as attributes.
reference_batch (torch_geometric.data.Batch) – A batch of reference molecules with the same structure as query_batch.
n_steps (int) – The number of DDIM sampling steps.
- Returns:
torch.Tensor: Predicted positions of the query molecules, shaped (num_query_nodes, 3).
list[torch.Tensor]: Trajectory of the query molecules, containing 999 tensors each shaped (num_query_nodes, 3).
- Return type:
tuple[torch.Tensor, list[torch.Tensor]]
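DDIM_CFG_Sampling takes an n_steps argument, which suggests a DDIM sampler that visits only a strided subset of the training time steps. The sketch below shows a deterministic (eta = 0) DDIM loop under stated assumptions: a linear beta schedule over a hypothetical 1000 training steps, and a placeholder eps_fn(x, t) standing in for the guided noise prediction. All helper names are hypothetical and not part of diffalign; this sketch's trajectory has n_steps entries, which need not match DiffAlign's own trajectory length.

```python
import torch


def linear_alpha_bars(num_timesteps: int = 1000) -> torch.Tensor:
    # Assumed linear beta schedule; DiffAlign's actual schedule may differ.
    betas = torch.linspace(1e-4, 0.02, num_timesteps)
    return torch.cumprod(1.0 - betas, dim=0)


def ddim_timesteps(num_timesteps: int, n_steps: int) -> list:
    # Evenly spaced time indices, ordered from noisiest (T - 1) to cleanest (0).
    return torch.linspace(num_timesteps - 1, 0, n_steps).round().long().tolist()


def ddim_step(x_t, eps, alpha_bar_t, alpha_bar_prev):
    # Deterministic DDIM update (eta = 0): estimate x_0, then re-noise it to the previous level.
    x0_pred = (x_t - (1.0 - alpha_bar_t).sqrt() * eps) / alpha_bar_t.sqrt()
    return alpha_bar_prev.sqrt() * x0_pred + (1.0 - alpha_bar_prev).sqrt() * eps


def ddim_sample(eps_fn, shape, n_steps, num_timesteps=1000, generator=None):
    alpha_bars = linear_alpha_bars(num_timesteps)
    steps = ddim_timesteps(num_timesteps, n_steps)
    x = torch.randn(shape, generator=generator)
    trajectory = []
    for i, t in enumerate(steps):
        # On the last step alpha_bar_prev = 1, so the update returns the x_0 estimate.
        alpha_bar_prev = alpha_bars[steps[i + 1]] if i + 1 < len(steps) else torch.tensor(1.0)
        x = ddim_step(x, eps_fn(x, t), alpha_bars[t], alpha_bar_prev)
        trajectory.append(x)
    return x, trajectory


# Usage with a dummy noise predictor on five hypothetical query atoms.
gen = torch.Generator().manual_seed(0)
pos, traj = ddim_sample(lambda x, t: torch.zeros_like(x), (5, 3), n_steps=50, generator=gen)
```

Because the update is deterministic, the only randomness is the initial noise, which is why DDIM can trade fewer steps for speed without resampling noise at each step.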
- DDPM_CFG_Sampling(query_batch, reference_batch)
- Parameters:
query_batch (torch_geometric.data.Batch) – A batch of query molecules containing atom_type, edge_index, edge_type, pos, and batch as attributes.
reference_batch (torch_geometric.data.Batch) – A batch of reference molecules with the same structure as query_batch.
- Returns:
torch.Tensor: Predicted positions of the query molecules, shaped (num_query_nodes, 3).
list[torch.Tensor]: Trajectory of the query molecules, containing 999 tensors each shaped (num_query_nodes, 3).
- Return type:
tuple[torch.Tensor, list[torch.Tensor]]
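The _CFG samplers presumably combine a conditional and an unconditional noise prediction with classifier-free guidance, for example by calling forward with condition=True and with condition=False. A minimal sketch of that combination step; guidance_scale is a hypothetical parameter name, since neither sampler's documented signature exposes a guidance weight.

```python
import torch


def cfg_noise(eps_cond: torch.Tensor, eps_uncond: torch.Tensor, guidance_scale: float) -> torch.Tensor:
    # Classifier-free guidance: move from the unconditional prediction toward
    # (and, for guidance_scale > 1, past) the conditional one.
    # guidance_scale = 1 returns eps_cond; guidance_scale = 0 returns eps_uncond.
    return eps_uncond + guidance_scale * (eps_cond - eps_uncond)


# Toy predictions for two nodes; in DiffAlign these would be (num_query_nodes, 3) tensors.
eps_cond = torch.tensor([[1.0, 2.0, 3.0], [0.0, -1.0, 0.5]])
eps_uncond = torch.zeros(2, 3)
guided = cfg_noise(eps_cond, eps_uncond, guidance_scale=2.0)
```

Some papers write the same rule as (1 + w) * eps_cond - w * eps_uncond; that form is identical to the one above with guidance_scale = 1 + w.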
- DDPM_Sampling(query_batch, reference_batch)
- Parameters:
query_batch (torch_geometric.data.Batch) – A batch of query molecules containing atom_type, edge_index, edge_type, pos, and batch as attributes.
reference_batch (torch_geometric.data.Batch) – A batch of reference molecules with the same structure as query_batch.
- Returns:
torch.Tensor: Predicted positions of the query molecules, shaped (num_query_nodes, 3).
list[torch.Tensor]: Trajectory of the query molecules, containing 999 tensors each shaped (num_query_nodes, 3).
- Return type:
tuple[torch.Tensor, list[torch.Tensor]]
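DDPM_Sampling, unlike the DDIM variant, takes no step count, which is consistent with running the full reverse chain. A minimal sketch of the standard DDPM ancestral update with the variance choice sigma_t^2 = beta_t, assuming a linear beta schedule and a placeholder eps_fn(x, t); the schedule, the helper names, and the 1000-step default are assumptions, not DiffAlign's confirmed settings.

```python
import torch


def make_schedule(num_timesteps: int = 1000):
    # Assumed linear beta schedule; DiffAlign's actual schedule may differ.
    betas = torch.linspace(1e-4, 0.02, num_timesteps)
    alphas = 1.0 - betas
    return betas, alphas, torch.cumprod(alphas, dim=0)


def ddpm_step(x_t, eps, t, betas, alphas, alpha_bars, generator=None):
    # Mean of p(x_{t-1} | x_t) under the epsilon parameterization.
    mean = (x_t - betas[t] / (1.0 - alpha_bars[t]).sqrt() * eps) / alphas[t].sqrt()
    if t == 0:
        return mean  # no noise is added on the final step
    noise = torch.randn(x_t.shape, generator=generator)
    return mean + betas[t].sqrt() * noise


def ddpm_sample(eps_fn, shape, num_timesteps=1000, generator=None):
    betas, alphas, alpha_bars = make_schedule(num_timesteps)
    x = torch.randn(shape, generator=generator)
    trajectory = []
    for t in reversed(range(num_timesteps)):
        x = ddpm_step(x, eps_fn(x, t), t, betas, alphas, alpha_bars, generator)
        trajectory.append(x)
    return x, trajectory


# Usage with a dummy noise predictor and a short 10-step chain.
gen = torch.Generator().manual_seed(0)
pos, traj = ddpm_sample(lambda x, t: torch.zeros_like(x), (4, 3), num_timesteps=10, generator=gen)
```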
- forward(query_batch, reference_batch, time_step, condition=True)
- Parameters:
query_batch (torch_geometric.data.Batch) – A batch of query molecules containing atom_type, edge_index, edge_type, pos, and batch as attributes.
reference_batch (torch_geometric.data.Batch) – A batch of reference molecules with the same structure as query_batch.
time_step (torch.Tensor) – Time-step indices, one per graph in the batch, shaped (num_batches,).
condition (bool, optional) – Whether to apply conditioning. Defaults to True.
- Returns:
Predicted noise, shaped (n_nodes, 3).
- Return type:
torch.Tensor
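forward receives one time-step index per graph, while SinusoidalTimeEmbeddings.forward documents a (num_nodes, 1) input, so the per-graph indices presumably get gathered onto nodes through the batch vector. A sketch of that gather, assuming the usual torch_geometric batch assignment vector; node_time_steps is a hypothetical helper, and whether DiffAlign does exactly this internally is an assumption.

```python
import torch


def node_time_steps(time_step: torch.Tensor, batch: torch.Tensor) -> torch.Tensor:
    # time_step: (num_batches,), one index per graph.
    # batch: (num_nodes,), mapping each node to its graph as in torch_geometric batches.
    # Returns a float tensor shaped (num_nodes, 1), the layout SinusoidalTimeEmbeddings.forward documents.
    return time_step[batch].unsqueeze(-1).to(torch.float32)


# Hypothetical batch: graph 0 has three nodes, graph 1 has two.
batch = torch.tensor([0, 0, 0, 1, 1])
time_step = torch.tensor([10, 500])
node_t = node_time_steps(time_step, batch)
```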
- get_loss(query_batch, reference_batch)
- Parameters:
query_batch (torch_geometric.data.Batch) – A batch of query molecules containing atom_type, edge_index, edge_type, pos, and batch as attributes.
reference_batch (torch_geometric.data.Batch) – A batch of reference molecules with the same structure as query_batch.
- Returns:
The computed loss as a scalar tensor.
- Return type:
torch.Tensor
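Since forward returns predicted noise, a natural training objective for get_loss is the mean squared error between the noise actually added to the positions and the model's prediction. A minimal sketch, assuming one uniformly sampled time step per graph and the closed-form forward noising of the positions; eps_model is a hypothetical stand-in for a call such as forward(query_batch, reference_batch, time_step), and alpha_bars comes from an assumed linear schedule.

```python
import torch
import torch.nn.functional as F


def diffusion_loss(eps_model, pos, batch, alpha_bars, generator=None):
    num_graphs = int(batch.max()) + 1
    # One time step per graph, looked up for every node through the batch vector.
    t = torch.randint(0, alpha_bars.numel(), (num_graphs,), generator=generator)
    a_bar = alpha_bars[t][batch].unsqueeze(-1)  # (num_nodes, 1)
    # Closed-form forward process: x_t = sqrt(a_bar) * x_0 + sqrt(1 - a_bar) * noise.
    noise = torch.randn(pos.shape, generator=generator)
    noisy_pos = a_bar.sqrt() * pos + (1.0 - a_bar).sqrt() * noise
    # Regress the model's prediction onto the noise that was actually added.
    return F.mse_loss(eps_model(noisy_pos, t), noise)


# Usage: an "oracle" model that inverts the noising exactly should give near-zero loss.
alpha_bars = torch.cumprod(1.0 - torch.linspace(1e-4, 0.02, 1000), dim=0)
pos = torch.randn(5, 3)
batch = torch.tensor([0, 0, 0, 1, 1])


def oracle(x_t, t):
    a = alpha_bars[t][batch].unsqueeze(-1)
    return (x_t - a.sqrt() * pos) / (1.0 - a).sqrt()


gen = torch.Generator().manual_seed(0)
loss = diffusion_loss(oracle, pos, batch, alpha_bars, gen)
zero_loss = diffusion_loss(lambda x, t: torch.zeros_like(x), pos, batch, alpha_bars, gen)
```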
- class diffalign.models.epsnet.diffusion.SinusoidalTimeEmbeddings(out_dim)
Bases: PatchedModule
Sinusoidal time embedder.
- Parameters:
out_dim (int) – The output dimension of the embedding.
- forward(time)
- Parameters:
time (torch.Tensor) – A tensor shaped (num_nodes, 1) holding the time value of each node.
- Returns:
A tensor shaped (num_nodes, self.out_dim) containing the sinusoidal time embeddings.
- Return type:
torch.Tensor
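A minimal sketch of a sinusoidal time embedder with the documented input and output shapes, using the Transformer-style geometric frequency ladder with base 10000. The class name and the even out_dim requirement are assumptions; DiffAlign's constants, sin/cos ordering, or time scaling may differ.

```python
import math

import torch
import torch.nn as nn


class SinusoidalTimeEmbeddingSketch(nn.Module):
    """Transformer-style sinusoidal embedding of per-node time values (hypothetical)."""

    def __init__(self, out_dim: int):
        super().__init__()
        if out_dim % 2 != 0:
            raise ValueError("out_dim must be even so the sin and cos halves match")
        self.out_dim = out_dim

    def forward(self, time: torch.Tensor) -> torch.Tensor:
        # time: (num_nodes, 1) -> (num_nodes, out_dim)
        half = self.out_dim // 2
        # Geometric sequence of frequencies from 1 down toward 1/10000.
        freqs = torch.exp(
            -math.log(10000.0) * torch.arange(half, dtype=torch.float32, device=time.device) / half
        )
        args = time.to(torch.float32) * freqs  # (num_nodes, 1) * (half,) -> (num_nodes, half)
        return torch.cat([args.sin(), args.cos()], dim=-1)


embed = SinusoidalTimeEmbeddingSketch(16)
emb = embed(torch.tensor([[0.0], [10.0], [500.0]]))
```

The low frequencies vary slowly across time steps and the high ones quickly, so nearby steps get similar but distinguishable embeddings.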
- diffalign.models.epsnet.diffusion.get_distance(pos, edge_index)
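get_distance(pos, edge_index) is undocumented. Judging by its name and arguments, it likely returns the length of each edge. A sketch of that computation, assuming edge_index is a (2, num_edges) tensor in torch_geometric layout; edge_lengths is a hypothetical name, and the real function's output shape (for example (num_edges,) versus (num_edges, 1)) is not documented.

```python
import torch


def edge_lengths(pos: torch.Tensor, edge_index: torch.Tensor) -> torch.Tensor:
    # pos: (num_nodes, 3); edge_index: (2, num_edges) of source/target node indices.
    src, dst = edge_index
    # Euclidean distance between the endpoints of every edge -> (num_edges,).
    return (pos[src] - pos[dst]).norm(dim=-1)


pos = torch.tensor([[0.0, 0.0, 0.0], [3.0, 4.0, 0.0], [0.0, 0.0, 1.0]])
edge_index = torch.tensor([[0, 0], [1, 2]])
d = edge_lengths(pos, edge_index)
```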
- diffalign.models.epsnet.diffusion.merge_graphs_in_batch(batch1, batch2)
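merge_graphs_in_batch(batch1, batch2) is also undocumented. One plausible reading, given that DiffAlign conditions each query molecule on a reference molecule, is that it fuses graph i of the first batch with graph i of the second into a single graph. The sketch below does this on plain tensors rather than torch_geometric Batch objects; the function name, argument layout, and node reordering are all assumptions.

```python
import torch


def merge_graph_batches(x1, edge_index1, batch1, x2, edge_index2, batch2):
    """Fuse graph i of batch 1 with graph i of batch 2 (hypothetical helper)."""
    # Stack nodes, shifting the second batch's edge indices past the first batch's nodes.
    x = torch.cat([x1, x2], dim=0)
    edge_index = torch.cat([edge_index1, edge_index2 + x1.size(0)], dim=1)
    batch = torch.cat([batch1, batch2], dim=0)

    # Reorder so each merged graph is contiguous, since many torch_geometric
    # utilities assume a sorted batch vector.
    batch, perm = torch.sort(batch, stable=True)
    x = x[perm]
    # Relabel edges: inv maps each old node index to its new position.
    inv = torch.empty_like(perm)
    inv[perm] = torch.arange(perm.numel())
    return x, inv[edge_index], batch


# Two hypothetical batches of two graphs each; node features double as labels.
x1, e1, b1 = torch.tensor([[0.0], [1.0], [2.0]]), torch.tensor([[0], [1]]), torch.tensor([0, 0, 1])
x2, e2, b2 = torch.tensor([[10.0], [11.0], [12.0]]), torch.tensor([[0], [1]]), torch.tensor([0, 0, 1])
x, edge_index, batch = merge_graph_batches(x1, e1, b1, x2, e2, b2)
```

No edges are added between the two source graphs here; if the real function also connects query and reference atoms, that step is not shown.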