import argparse
import os
import json
TRUNK_PARAMS = ['n_extra_block', 'n_main_block', 'n_ref_block',\
'd_target', 'd_msa', 'd_pair', 'd_templ',\
'n_head_msa', 'n_head_pair', 'n_head_templ', 'd_hidden', 'd_hidden_templ', 'p_drop']
TRUNK_PARAMS = ['n_extra_block', 'n_main_block', 'n_ref_block',\
'd_msa', 'd_pair', 'd_templ',\
'n_head_msa', 'n_head_pair', 'n_head_templ', 'd_hidden', 'd_hidden_templ', 'p_drop']
base_SE3 = ['num_layers', 'num_channels', 'l0_in_features', 'l0_out_features', 'l1_in_features', 'l1_out_features', 'num_edge_features']
SE3_PARAMS = ['num_degrees', 'n_heads', 'div']
for se3 in base_SE3:
for suffix in ['full', 'topk', 'SC']:
SE3_PARAMS.append("%s_%s"%(se3, suffix))
[docs]
def get_args():
parser = argparse.ArgumentParser()
# JSON input file
parser.add_argument('-json_input', type=str, required=False,
help="Path to JSON file containing input data")
parser.add_argument('-output_dir', type=str, required=False,
help="Path to output directory")
parser.add_argument('-seed_num', type=int, required=False, default=1,
help="Seed number for random number generator")
# training parameters
train_group = parser.add_argument_group("training parameters")
train_group.add_argument("-model_name", default="BFF",
help="model name for saving")
train_group.add_argument('-batch_size', type=int, default=1,
help="Batch size [1]")
train_group.add_argument("-seed", type=int, default=0,
help="seed for random number, should be randomized for different training run [0]")
# data-loading parameters
data_group = parser.add_argument_group("data loading parameters")
data_group.add_argument('-maxseq', type=int, default=1024,
help="Maximum depth of subsampled MSA [1024]")
data_group.add_argument('-maxlat', type=int, default=128,
help="Maximum depth of subsampled MSA [128]")
data_group.add_argument("-crop", type=int, default=256,
help="Upper limit of crop size [256]")
data_group.add_argument('-mintplt', type=int, default=0,
help="Minimum number of templates to select [0]")
data_group.add_argument('-maxtplt', type=int, default=4,
help="maximum number of templates to select [4]")
data_group.add_argument("-rescut", type=float, default=5.0,
help="Resolution cutoff [5.0]")
data_group.add_argument("-datcut", default="2020-Apr-30",
help="PDB release date cutoff [2020-Apr-30]")
data_group.add_argument('-plddtcut', type=float, default=70.0,
help="pLDDT cutoff for distillation set [70.0]")
data_group.add_argument('-seqid', type=float, default=150.0,
help="maximum sequence identity cutoff for template selection [150.0]")
data_group.add_argument('-maxcycle', type=int, default=4,
help="maximum number of recycle [4]")
# Trunk module properties
trunk_group = parser.add_argument_group("Trunk module parameters")
trunk_group.add_argument('-n_extra_block', type=int, default=4,
help="Number of iteration blocks for extra sequences [4]")
trunk_group.add_argument('-n_main_block', type=int, default=48,
help="Number of iteration blocks for main sequences [48]")
trunk_group.add_argument('-n_ref_block', type=int, default=4,
help="Number of refinement layers [4]")
# trunk_group.add_argument('-d_target', type=int, default=768,
# help="Number of MSA features [256]")
trunk_group.add_argument('-d_msa', type=int, default=256,
help="Number of MSA features [256]")
trunk_group.add_argument('-d_pair', type=int, default=128,
help="Number of pair features [128]")
trunk_group.add_argument('-d_templ', type=int, default=64,
help="Number of templ features [64]")
trunk_group.add_argument('-n_head_msa', type=int, default=8,
help="Number of attention heads for MSA2MSA [8]")
trunk_group.add_argument('-n_head_pair', type=int, default=4,
help="Number of attention heads for Pair2Pair [4]")
trunk_group.add_argument('-n_head_templ', type=int, default=4,
help="Number of attention heads for template [4]")
trunk_group.add_argument("-d_hidden", type=int, default=32,
help="Number of hidden features [32]")
trunk_group.add_argument("-d_hidden_templ", type=int, default=32,
help="Number of hidden features for templates [32]")
trunk_group.add_argument("-p_drop", type=float, default=0.15,
help="Dropout ratio [0.15]")
# Structure module properties
str_group = parser.add_argument_group("structure module parameters")
str_group.add_argument('-num_degrees', type=int, default=2,
help="Number of degrees for SE(3) network [2]")
str_group.add_argument('-n_heads', type=int, default=4,
help="Number of attention heads for SE3-Transformer [4]")
str_group.add_argument("-div", type=int, default=4,
help="Div parameter for SE3-Transformer [4]")
str_group.add_argument('-num_layers_full', type=int, default=1,
help="Number of equivariant layers in fully-connected structure module block [1]")
str_group.add_argument('-num_channels_full', type=int, default=48,
help="Number of channels in structure module block [48]")
str_group.add_argument('-l0_in_features_full', type=int, default=32,
help="Number of type 0 input features for full-connected graph [32]")
str_group.add_argument('-l0_out_features_full', type=int, default=32,
help="Number of type 0 output features for full-connected graph [32]")
str_group.add_argument('-l1_in_features_full', type=int, default=2,
help="Number of type 1 input features [2]")
# str_group.add_argument('-l1_out_features_full', type=int, default=2,
# help="Number of type 1 output features [2]")
str_group.add_argument('-l1_out_features_full', type=int, default=2,
help="Number of type 1 output features [2]")
str_group.add_argument('-num_edge_features_full', type=int, default=32,
help="Number of edge features for full-connected graph [32]")
str_group.add_argument('-num_layers_topk', type=int, default=1,
help="Number of equivariant layers in top-k structure module block [1]")
str_group.add_argument('-num_channels_topk', type=int, default=128,
help="Number of channels in structure module block [128]")
str_group.add_argument('-l0_in_features_topk', type=int, default=64,
help="Number of type 0 input features for top-k graph [64]")
str_group.add_argument('-l0_out_features_topk', type=int, default=64,
help="Number of type 0 output features for top-k graph [64]")
str_group.add_argument('-l1_in_features_topk', type=int, default=2,
help="Number of type 1 input features [2]")
str_group.add_argument('-l1_out_features_topk', type=int, default=2,
help="Number of type 1 output features [2]")
str_group.add_argument('-num_edge_features_topk', type=int, default=64,
help="Number of edge features for top-k graph [64]")
# parse arguments
args = parser.parse_args()
# make dictionary for each parameters
trunk_param = {}
for param in TRUNK_PARAMS:
trunk_param[param] = getattr(args, param)
SE3_param_full = {}
SE3_param_topk = {}
for param in SE3_PARAMS:
if hasattr(args, param):
if "full" in param:
SE3_param_full[param[:-5]] = getattr(args, param)
elif "topk" in param:
SE3_param_topk[param[:-5]] = getattr(args, param)
else: # common arguments
SE3_param_full[param] = getattr(args, param)
SE3_param_topk[param] = getattr(args, param)
# Load JSON file if provided
if args.json_input:
if os.path.exists(args.json_input):
with open(args.json_input, 'r') as json_file:
json_data = json.load(json_file)
else:
raise FileNotFoundError(f"JSON file {args.json_input} not found!")
trunk_param['SE3_param_full'] = SE3_param_full
trunk_param['SE3_param_topk'] = SE3_param_topk
return args, trunk_param, json_data