Source code for bsitep.utils.model_utils
import torch
from collections import OrderedDict
from bsitep.seresnet import SEResNet
[docs]
def load_model(model_path, device='cuda'):
    """Build an SEResNet, load its weights from ``model_path``, and return it in eval mode.

    Args:
        model_path: Path (or file-like object) accepted by ``torch.load`` that
            holds the model's state dict.
        device: Device (string or ``torch.device``) the checkpoint tensors are
            mapped to and the model is placed on. Defaults to ``'cuda'``, which
            matches the previous hard-coded behaviour; pass ``'cpu'`` to load on
            a machine without a GPU.

    Returns:
        The ``SEResNet`` instance with the loaded weights, on ``device`` and
        switched to evaluation mode.

    Raises:
        RuntimeError: propagated from ``load_state_dict`` when the (prefix-stripped)
            checkpoint keys or tensor shapes do not match the model.
    """
    device = torch.device(device)
    model = SEResNet().to(device)
    # weights_only=True restricts unpickling to tensors and primitive containers,
    # so a tampered checkpoint cannot execute arbitrary code on load.
    state_dict = torch.load(model_path, map_location=device, weights_only=True)
    # Checkpoints saved from an nn.DataParallel / DistributedDataParallel wrapper
    # prefix every key with "module." (7 characters, which is what the previous
    # unconditional key[7:] removed). Strip it only where present: slicing blindly
    # corrupted every key of a checkpoint saved without the wrapper.
    prefix = 'module.'
    new_dict = OrderedDict(
        (key[len(prefix):] if key.startswith(prefix) else key, value)
        for key, value in state_dict.items()
    )
    model.load_state_dict(new_dict)
    # eval() switches layers such as dropout/batch-norm to inference behaviour;
    # the model is already on `device`, so no further move is needed.
    model.eval()
    return model