Source code for deepfold.utils.dist_utils
import torch
[docs]
def pad_tensor(
    tensor: torch.Tensor,
    dim: int,
    pad_size: int,
    value: float = 0.0,
) -> torch.Tensor:
    """Append ``pad_size`` constant-filled entries to the end of dimension ``dim``.

    Args:
        tensor: Tensor to pad.
        dim: Dimension to extend. Negative values count from the last dimension.
        pad_size: Number of entries appended after the existing ones along
            ``dim``. Expected to be non-negative.
        value: Fill value for the appended entries (defaults to ``0.0``).

    Returns:
        ``tensor`` itself (not a copy) when ``pad_size`` is 0. Otherwise a new
        tensor produced by ``torch.nn.functional.pad`` whose size along ``dim``
        is ``tensor.size(dim) + pad_size``.

    Raises:
        IndexError: If ``pad_size`` is non-zero and ``dim`` is outside
            ``[-tensor.ndim, tensor.ndim - 1]``.
    """
    if pad_size == 0:
        return tensor

    ndim = tensor.ndim
    if not -ndim <= dim < ndim:
        raise IndexError(
            f"Dimension out of range (expected to be in range of "
            f"[{-ndim}, {ndim - 1}], but got {dim})"
        )

    # F.pad takes (left, right) pairs starting from the LAST dimension and
    # moving backwards, so one pair is needed for every dimension from `dim`
    # through the last one.
    num_pairs = ndim - (dim % ndim)
    pad = [0, 0] * num_pairs
    # The final entry is the right-hand (trailing) pad of `dim` itself.
    pad[-1] = pad_size
    return torch.nn.functional.pad(tensor, pad, mode="constant", value=value)
[docs]
def get_pad_size(tensor: torch.Tensor, dim: int, num_chunks: int) -> int:
    """Return how many entries must be appended along ``dim`` so that its size
    becomes a multiple of ``num_chunks``.

    Args:
        tensor: Tensor whose size along ``dim`` is inspected.
        dim: Dimension to measure (negative values count from the end, as in
            ``torch.Tensor.size``).
        num_chunks: Number of equal chunks the dimension should split into.

    Returns:
        A padding amount in ``[0, num_chunks)``. It is 0 when
        ``tensor.size(dim)`` is already divisible by ``num_chunks``.

    Raises:
        ValueError: If ``num_chunks`` is not positive.
    """
    if num_chunks <= 0:
        raise ValueError(f"num_chunks must be positive, got {num_chunks}")
    seq_len = tensor.size(dim)
    # Distance up to the next multiple of num_chunks. This equals
    # num_chunks * ceil(seq_len / num_chunks) - seq_len for positive num_chunks.
    return -seq_len % num_chunks