diffalign.models.epsnet package

Submodules

diffalign.models.epsnet.diffusion module

class diffalign.models.epsnet.diffusion.DiffAlign

Bases: PatchedModule

DDIM_CFG_Sampling(query_batch, reference_batch, n_steps)
Parameters:
  • query_batch (torch_geometric.data.Batch) – A batch of query molecules containing atom_type, edge_index, edge_type, pos, and batch as attributes.

  • reference_batch (torch_geometric.data.Batch) – A batch of reference molecules with the same attributes as query_batch.

  • n_steps (int) – The number of DDIM sampling steps.

Returns:

  • torch.tensor – Predicted positions of the query molecules, shaped (num_query_nodes, 3).

  • list[torch.tensor] – Trajectory of the query molecules, containing 999 tensors each shaped (num_query_nodes, 3).

Return type:

tuple(torch.tensor, list[torch.tensor])
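For orientation, a single deterministic DDIM update (eta = 0) and an evenly spaced step schedule can be sketched as below. This is a minimal NumPy sketch of the standard DDIM rule, not the package's implementation: the helper names, the num_timesteps parameter, and the linear spacing of the n_steps time steps are assumptions. In DDIM_CFG_Sampling, the eps passed in would presumably be the guided noise prediction.

```python
import numpy as np

# Hypothetical helpers illustrating the standard DDIM update; not DiffAlign's code.

def ddim_timesteps(num_timesteps, n_steps):
    """Decreasing, evenly spaced subset of the training time steps [0, num_timesteps)."""
    # Assumption: linear spacing; other spacings (e.g. quadratic) are also common.
    return np.linspace(num_timesteps - 1, 0, n_steps).round().astype(int)

def ddim_step(x_t, eps, alpha_bar_t, alpha_bar_prev):
    """One deterministic (eta = 0) DDIM update from level t to the previous kept level."""
    # Clean positions implied by the current noise prediction.
    x0_pred = (x_t - np.sqrt(1.0 - alpha_bar_t) * eps) / np.sqrt(alpha_bar_t)
    # Re-noise that estimate to the lower noise level, reusing the same eps.
    return np.sqrt(alpha_bar_prev) * x0_pred + np.sqrt(1.0 - alpha_bar_prev) * eps
```

A sampler would walk consecutive pairs of the schedule, look up alpha_bar at each kept step, and use alpha_bar_prev = 1 for the final step so that the last update returns the clean estimate.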

DDPM_CFG_Sampling(query_batch, reference_batch)
Parameters:
  • query_batch (torch_geometric.data.Batch) – A batch of query molecules containing atom_type, edge_index, edge_type, pos, and batch as attributes.

  • reference_batch (torch_geometric.data.Batch) – A batch of reference molecules with the same attributes as query_batch.

Returns:

  • torch.tensor – Predicted positions of the query molecules, shaped (num_query_nodes, 3).

  • list[torch.tensor] – Trajectory of the query molecules, containing 999 tensors each shaped (num_query_nodes, 3).

Return type:

tuple(torch.tensor, list[torch.tensor])
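The classifier-free guidance (CFG) combination that the CFG samplers are named after is usually written eps = eps_uncond + w * (eps_cond - eps_uncond). Below is a minimal sketch assuming the conditional and unconditional predictions come from forward(..., condition=True) and forward(..., condition=False); the helper guided_noise and the weight name guidance_scale are hypothetical, and DiffAlign's actual weighting may differ.

```python
import numpy as np

def guided_noise(model, query_batch, reference_batch, time_step, guidance_scale):
    """Classifier-free guidance: extrapolate from the unconditional toward the conditional prediction."""
    # Assumption: `model` is callable like DiffAlign.forward and returns an array of noise.
    eps_cond = model(query_batch, reference_batch, time_step, condition=True)
    eps_uncond = model(query_batch, reference_batch, time_step, condition=False)
    # guidance_scale = 0 -> unconditional, 1 -> purely conditional, > 1 -> amplified guidance.
    return eps_uncond + guidance_scale * (eps_cond - eps_uncond)
```

Each guided step costs two network evaluations, one with and one without the reference condition.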

DDPM_Sampling(query_batch, reference_batch)
Parameters:
  • query_batch (torch_geometric.data.Batch) – A batch of query molecules containing atom_type, edge_index, edge_type, pos, and batch as attributes.

  • reference_batch (torch_geometric.data.Batch) – A batch of reference molecules with the same attributes as query_batch.

Returns:

  • torch.tensor – Predicted positions of the query molecules, shaped (num_query_nodes, 3).

  • list[torch.tensor] – Trajectory of the query molecules, containing 999 tensors each shaped (num_query_nodes, 3).

Return type:

tuple(torch.tensor, list[torch.tensor])
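Unguided ancestral DDPM sampling starts from Gaussian noise and applies the reverse update at every training time step. The NumPy sketch below assumes the standard epsilon parameterization with posterior variance sigma_t^2 = beta_t; eps_fn stands in for the trained noise predictor, and the names ddpm_step and ddpm_sample are hypothetical. It is an illustration of the technique, not DiffAlign's implementation.

```python
import numpy as np

def ddpm_step(x_t, eps, t, betas, rng):
    """One ancestral DDPM update x_t -> x_{t-1} under the epsilon parameterization."""
    alphas = 1.0 - betas
    alpha_bars = np.cumprod(alphas)
    # Posterior mean of x_{t-1} given x_t and the predicted noise.
    mean = (x_t - betas[t] / np.sqrt(1.0 - alpha_bars[t]) * eps) / np.sqrt(alphas[t])
    if t == 0:
        return mean  # no noise is added at the final step
    # Assumption: sigma_t^2 = beta_t (one of the two standard choices).
    return mean + np.sqrt(betas[t]) * rng.standard_normal(x_t.shape)

def ddpm_sample(eps_fn, shape, betas, rng):
    """Run the reverse chain from pure noise and keep every intermediate state."""
    x = rng.standard_normal(shape)
    trajectory = []
    for t in reversed(range(len(betas))):
        x = ddpm_step(x, eps_fn(x, t), t, betas, rng)
        trajectory.append(x)
    return x, trajectory
```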

forward(query_batch, reference_batch, time_step, condition=True)
Parameters:
  • query_batch (torch_geometric.data.Batch) – A batch of query molecules containing atom_type, edge_index, edge_type, pos, and batch as attributes.

  • reference_batch (torch_geometric.data.Batch) – A batch of reference molecules with the same attributes as query_batch.

  • time_step (torch.tensor) – A vector of time step indices, one per graph in the batch, shaped (num_batches,).

  • condition (bool, optional) – A flag indicating whether to apply conditioning. Defaults to True.

Returns:

Predicted noise of shape (n_nodes, 3).

Return type:

torch.tensor
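Because time_step holds one index per graph while noise is predicted per node, the per-graph values have to be broadcast to nodes somewhere. A common way to do this with PyTorch Geometric batches is to index with the batch vector (time_step[batch]); the NumPy sketch below shows the same indexing. Reshaping to (num_nodes, 1) matches the input documented for SinusoidalTimeEmbeddings.forward, but whether DiffAlign normalizes the indices before embedding is not stated here, so the helper name and its exact output are assumptions.

```python
import numpy as np

def node_time_steps(time_step, batch):
    """Broadcast per-graph time steps (num_batches,) to per-node values (num_nodes, 1)."""
    # batch[i] is the index of the graph that node i belongs to (PyG convention).
    return time_step[batch].reshape(-1, 1)

batch = np.array([0, 0, 0, 1, 1])   # two graphs with 3 and 2 nodes
time_step = np.array([10, 500])     # one time step index per graph
print(node_time_steps(time_step, batch).ravel().tolist())  # [10, 10, 10, 500, 500]
```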

get_loss(query_batch, reference_batch)
Parameters:
  • query_batch (torch_geometric.data.Batch) – A batch of query molecules containing atom_type, edge_index, edge_type, pos, and batch as attributes.

  • reference_batch (torch_geometric.data.Batch) – A batch of reference molecules with the same attributes as query_batch.

Returns:

Calculated loss, a scalar tensor.

Return type:

torch.tensor
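Assuming get_loss follows the standard DDPM epsilon-prediction objective (sample one time step per graph, noise the query positions in closed form, and regress the injected noise with mean squared error), the computation reduces to the sketch below. eps_model stands in for the trained network and alpha_bars for the cumulative noise schedule; the function name is hypothetical, and the real method may differ, for example in how the reference condition is handled during training.

```python
import numpy as np

def noise_prediction_loss(eps_model, x0, batch, alpha_bars, rng):
    """Standard DDPM epsilon-prediction loss over a batch of graphs."""
    num_graphs = int(batch.max()) + 1
    # One random time step per graph, shared by all of that graph's nodes.
    t = rng.integers(0, len(alpha_bars), size=num_graphs)
    ab = alpha_bars[t][batch][:, None]              # (num_nodes, 1)
    noise = rng.standard_normal(x0.shape)           # (num_nodes, 3)
    # Closed-form forward process: x_t = sqrt(ab) * x0 + sqrt(1 - ab) * noise.
    x_t = np.sqrt(ab) * x0 + np.sqrt(1.0 - ab) * noise
    eps_pred = eps_model(x_t, t)
    return np.mean((eps_pred - noise) ** 2)         # scalar
```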

class diffalign.models.epsnet.diffusion.SinusoidalTimeEmbeddings(out_dim)

Bases: PatchedModule

Sinusoidal Time Embedder.

Parameters:

out_dim (int) – The output dimension of the embedding.

forward(time)
Parameters:

time (torch.tensor) – A tensor of shape (num_nodes, 1) holding the time value of each node.

Returns:

A tensor of shape (num_nodes, self.out_dim) containing the sinusoidal time embeddings.

Return type:

torch.tensor
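Sinusoidal time embeddings usually follow the Transformer positional-encoding recipe: each scalar time is multiplied by a geometric ladder of frequencies, and the sines and cosines are concatenated. The sketch below assumes that recipe with base 10000 and an even out_dim; the function name is hypothetical, and the exact frequency layout inside SinusoidalTimeEmbeddings may differ.

```python
import math
import numpy as np

def sinusoidal_time_embedding(time, out_dim):
    """Map times shaped (num_nodes, 1) to embeddings shaped (num_nodes, out_dim)."""
    assert out_dim % 2 == 0, "this sketch assumes an even out_dim"
    half = out_dim // 2
    # Frequencies decay geometrically from 1 toward 1/10000.
    freqs = np.exp(-math.log(10000.0) * np.arange(half) / half)   # (half,)
    args = time * freqs[None, :]                                   # (num_nodes, half)
    return np.concatenate([np.sin(args), np.cos(args)], axis=-1)
```

At time 0 the sine half is all zeros and the cosine half is all ones, so every time step maps to a distinct, smoothly varying vector.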

diffalign.models.epsnet.diffusion.get_distance(pos, edge_index)
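get_distance is listed without a description. Judging only by its name and arguments, it most likely computes the Euclidean length of every edge; the sketch below shows that computation assuming edge_index has the PyG layout (2, num_edges). The function name and the returned shape, (num_edges,) here, are assumptions; the real function may keep a trailing dimension.

```python
import numpy as np

def get_distance_sketch(pos, edge_index):
    """Euclidean length of each edge; pos is (num_nodes, 3), edge_index is (2, num_edges)."""
    src, dst = edge_index
    return np.linalg.norm(pos[src] - pos[dst], axis=-1)   # (num_edges,)
```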
diffalign.models.epsnet.diffusion.merge_graphs_in_batch(batch1, batch2)
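merge_graphs_in_batch is also listed without a description. One plausible reading, given that each query is paired with a reference, is that it merges graph i of batch1 with graph i of batch2 into a single graph. The sketch below shows how that is typically done with PyG-style arrays: concatenate the nodes, shift the second batch's edge indices, then regroup the nodes by graph. The function name, working on bare arrays rather than Batch objects, and the resulting node order are all assumptions.

```python
import numpy as np

def merge_graphs_sketch(pos1, edge_index1, batch1, pos2, edge_index2, batch2):
    """Merge graph i of the first batch with graph i of the second into one graph."""
    # Stack nodes and shift the second batch's edges past the first batch's nodes.
    n1 = pos1.shape[0]
    pos = np.concatenate([pos1, pos2], axis=0)
    edge_index = np.concatenate([edge_index1, edge_index2 + n1], axis=1)
    batch = np.concatenate([batch1, batch2])
    # Regroup nodes so each merged graph is contiguous, as many PyG utilities expect.
    order = np.argsort(batch, kind="stable")
    new_index = np.empty_like(order)
    new_index[order] = np.arange(order.size)   # old node id -> new node id
    return pos[order], new_index[edge_index], batch[order]
```

The stable sort keeps each graph's first-batch nodes ahead of its second-batch nodes, and remapping edge_index through new_index keeps every edge attached to the same atoms.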

Module contents