deepfold.utils.chunk_utils.chunk_layer
- deepfold.utils.chunk_utils.chunk_layer(layer: Callable, inputs: Dict[str, Any], chunk_size: int, num_batch_dims: int, low_mem: bool = False, _out: Any | None = None, _add_into_out: bool = False) → Any
Implements the “chunking” procedure described in section 1.11.8.
Layer outputs and inputs are assumed to be simple “pytrees,” consisting only of (arbitrarily nested) lists, tuples, and dicts with torch.Tensor leaves.
- Parameters:
layer – The layer to be applied chunk-wise
inputs – A (non-nested) dictionary of keyworded inputs. All leaves must be tensors and must share the same batch dimensions.
chunk_size – The number of sub-batches per chunk. If multiple batch dimensions are specified, a “sub-batch” is defined as a single indexing of all batch dimensions simultaneously (s.t. the number of sub-batches is the product of the batch dimensions).
num_batch_dims – How many of the initial dimensions of each input tensor are treated as batch dimensions.
low_mem – When True, avoids flattening potentially large input tensors. Unnecessary in most cases, and slightly slower than the default setting.
- Returns:
The reassembled output of the layer on the inputs.
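To illustrate the chunking idea without depending on the library itself, the sketch below applies a layer chunk-wise over a single batch dimension, using plain Python lists to stand in for tensors. The function name `chunk_layer_sketch` and the list-based representation are illustrative assumptions, not the actual implementation, which operates on (possibly nested) pytrees of torch.Tensor leaves and supports multiple batch dimensions.

```python
from typing import Any, Callable, Dict, List

def chunk_layer_sketch(
    layer: Callable[..., List[Any]],
    inputs: Dict[str, List[Any]],
    chunk_size: int,
) -> List[Any]:
    """Minimal sketch: apply `layer` to sub-batches of size `chunk_size`
    along a single batch dimension, then reassemble the outputs."""
    n = len(next(iter(inputs.values())))
    # All inputs must share the same batch dimension.
    assert all(len(v) == n for v in inputs.values())
    out: List[Any] = []
    for start in range(0, n, chunk_size):
        # Index every input identically to form one sub-batch of the chunk.
        chunk_inputs = {k: v[start:start + chunk_size] for k, v in inputs.items()}
        out.extend(layer(**chunk_inputs))
    return out

# Example layer: element-wise sum over the batch.
def add(a, b):
    return [x + y for x, y in zip(a, b)]

result = chunk_layer_sketch(
    add,
    {"a": [1, 2, 3, 4, 5], "b": [10, 20, 30, 40, 50]},
    chunk_size=2,
)
# result == [11, 22, 33, 44, 55]
```

With multiple batch dimensions, the real function first flattens them (unless low_mem is set), so the number of sub-batches becomes the product of the batch dimensions, as described for chunk_size above.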