# Source code for diffalign.utils.misc

import os
import time
import random
import logging
import torch
import numpy as np
from glob import glob
from logging import Logger
from tqdm.auto import tqdm
from torch_geometric.data import Batch


class BlackHole:
    """Null object that silently absorbs attribute writes, reads and calls.

    Assigning an attribute is a no-op (nothing is stored on the instance).
    Reading any attribute that normal lookup does not find, or calling the
    instance with any arguments, yields the instance itself, so chains such
    as ``bh.a.b(1).c`` all evaluate to ``bh``.
    """

    def __getattr__(self, name):
        # Only reached when regular attribute lookup fails.
        return self

    def __call__(self, *args, **kwargs):
        return self

    def __setattr__(self, name, value):
        # Deliberately discard every assignment.
        return None
def get_logger(name, log_dir=None, log_fn='log.txt'):
    """Return the logger ``name`` configured to emit records at DEBUG level.

    A console ``StreamHandler`` is attached. When ``log_dir`` is given, a
    ``FileHandler`` writing to ``os.path.join(log_dir, log_fn)`` is attached
    too. ``log_dir`` is not created here. Both handlers use the format
    ``[asctime::name::levelname] message``.

    ``logging.getLogger`` returns the same object for the same ``name``, so a
    handler is only added when an equivalent one is not already attached.
    Calling this function repeatedly therefore does not duplicate output.

    Args:
        name: Logger name passed to ``logging.getLogger``.
        log_dir: Optional directory in which to write the log file.
        log_fn: File name of the log file inside ``log_dir``.

    Returns:
        The configured ``logging.Logger``.
    """
    logger = logging.getLogger(name)
    logger.setLevel(logging.DEBUG)
    formatter = logging.Formatter('[%(asctime)s::%(name)s::%(levelname)s] %(message)s')

    # FileHandler subclasses StreamHandler, so compare the exact type to find
    # a console handler attached by an earlier call.
    has_console = any(type(h) is logging.StreamHandler for h in logger.handlers)
    if not has_console:
        stream_handler = logging.StreamHandler()
        stream_handler.setLevel(logging.DEBUG)
        stream_handler.setFormatter(formatter)
        logger.addHandler(stream_handler)

    if log_dir is not None:
        log_path = os.path.join(log_dir, log_fn)
        # FileHandler stores the absolute target path in ``baseFilename``.
        abs_log_path = os.path.abspath(log_path)
        has_file = any(
            isinstance(h, logging.FileHandler) and h.baseFilename == abs_log_path
            for h in logger.handlers
        )
        if not has_file:
            file_handler = logging.FileHandler(log_path)
            file_handler.setLevel(logging.DEBUG)
            file_handler.setFormatter(formatter)
            logger.addHandler(file_handler)
    return logger
def get_new_log_dir(root='./logs', prefix='', tag=''):
    """Create and return a new, timestamped log directory under ``root``.

    The directory name is ``[prefix_]YYYY_MM_DD__HH_MM_SS[_tag]`` in local
    time. Missing parent directories, including ``root``, are created. If that
    name is already taken (for example, two runs launched within the same
    second), a numeric suffix ``_1``, ``_2``, ... is appended instead of
    raising ``FileExistsError``.

    Args:
        root: Parent directory for log directories.
        prefix: Optional text placed before the timestamp.
        tag: Optional text placed after the timestamp.

    Returns:
        Path of the newly created directory.
    """
    fn = time.strftime('%Y_%m_%d__%H_%M_%S', time.localtime())
    if prefix != '':
        fn = prefix + '_' + fn
    if tag != '':
        fn = fn + '_' + tag
    base_dir = os.path.join(root, fn)

    log_dir = base_dir
    suffix = 0
    while True:
        try:
            # Keeping exist_ok=False makes creation atomic, so concurrent runs
            # that pick the same name cannot both claim it.
            os.makedirs(log_dir)
            return log_dir
        except FileExistsError:
            suffix += 1
            log_dir = '%s_%d' % (base_dir, suffix)
def seed_all(seed):
    """Seed Python's ``random``, NumPy's global RNG and PyTorch with ``seed``."""
    # The three generators are independent, so the order of seeding is irrelevant.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
def inf_iterator(iterable):
    """Yield the items of ``iterable`` over and over, restarting when exhausted.

    Each pass calls ``iter(iterable)`` afresh, so re-iterable containers and
    loaders are cycled indefinitely. Unlike ``itertools.cycle``, items are not
    cached; each pass re-reads the source.

    If a whole pass yields nothing (an empty iterable, or a one-shot iterator
    that has already been consumed), the generator stops. The previous
    behavior was to spin forever without ever yielding.
    """
    while True:
        produced = False
        for item in iterable:
            produced = True
            yield item
        if not produced:
            return
def log_hyperparams(writer, args):
    """Record the attributes of ``args`` as TensorBoard hyper-parameters.

    Every entry of ``vars(args)`` is logged; string values are kept as-is and
    all other values are converted with ``repr``. No metrics are attached.
    The three summaries produced by ``hparams`` are written, in order, to
    ``writer.file_writer``.
    """
    from torch.utils.tensorboard.summary import hparams

    def _as_text(value):
        if isinstance(value, str):
            return value
        return repr(value)

    hparam_dict = {key: _as_text(val) for key, val in vars(args).items()}
    experiment, session_start, session_end = hparams(hparam_dict, {})
    for summary in (experiment, session_start, session_end):
        writer.file_writer.add_summary(summary)
def int_tuple(argstr):
    """Parse a comma-separated string such as ``'1,2,3'`` into a tuple of ints."""
    return tuple(int(token) for token in argstr.split(','))
def str_tuple(argstr):
    """Split a comma-separated string into a tuple of its raw (unstripped) pieces."""
    pieces = argstr.split(',')
    return tuple(pieces)
def repeat_data(data, num_repeat):
    """Return a ``Batch`` made of ``num_repeat`` independent clones of ``data``.

    Each copy comes from ``data.clone()``, so the copies share no storage
    with ``data`` or with each other.
    """
    return Batch.from_data_list([data.clone() for _ in range(num_repeat)])
def repeat_batch(batch, num_repeat):
    """Return a new ``Batch`` holding every graph of ``batch`` ``num_repeat`` times.

    The batch is split into its individual data objects, and the whole list
    is appended ``num_repeat`` times. The output order is therefore
    ``[g0, g1, ..., g0, g1, ...]``. Every entry is a fresh ``clone()``, so the
    repeats share no storage.

    The previous implementation called ``.clone()`` on the Python list
    returned by ``to_data_list()``, which always raised ``AttributeError``.
    """
    datas = batch.to_data_list()
    new_data = [data.clone() for _ in range(num_repeat) for data in datas]
    return Batch.from_data_list(new_data)
def get_checkpoint_path(folder, it=None):
    """Return ``(path, iteration)`` for a checkpoint in ``folder``.

    If ``it`` is given, ``os.path.join(folder, '%d.pt' % it)`` is returned
    without checking that it exists. Otherwise ``folder`` is scanned for
    ``*.pt`` files whose stem parses as an integer iteration number (e.g.
    ``100.pt`` or ``0100.pt``). The file with the largest iteration is
    returned, using its actual path. Other ``.pt`` files, such as
    ``best.pt``, are ignored.

    Raises:
        IndexError: if ``it`` is None and no integer-named ``.pt`` file is
            found. IndexError is kept for compatibility with callers that
            caught the original ``all_iters[-1]`` failure.
    """
    if it is not None:
        return os.path.join(folder, '%d.pt' % it), it

    from glob import escape as glob_escape

    candidates = []
    # Escape the folder so characters such as '[' in its name match literally.
    for path in glob(os.path.join(glob_escape(folder), '*.pt')):
        stem = os.path.splitext(os.path.basename(path))[0]
        try:
            candidates.append((int(stem), path))
        except ValueError:
            # Not an iteration-numbered checkpoint (e.g. 'best.pt').
            continue

    if not candidates:
        raise IndexError('no iteration-numbered .pt checkpoint found in %r' % folder)
    latest_it, latest_path = max(candidates)
    return latest_path, latest_it