Saving and Loading Models
When saving or loading your model, these are the main functions you will work with:
| Function | Purpose |
|---|---|
| torch.save() | Saves an object to disk. It uses Python's pickle utility for serialization, so models, tensors, and dictionaries of all kinds of objects can be saved (everything is serialized with pickle into a binary file, not a human-readable format like JSON) |
| torch.load() | Loads an object. It uses pickle's unpickling facilities to deserialize a pickled object file into memory |
| torch.nn.Module.load_state_dict() | Loads a model's parameter dictionary from a deserialized state_dict |
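As a quick illustration of the torch.save() / torch.load() round trip (a minimal sketch; the object contents and the filename 'tensors.pt' are arbitrary placeholders, not from the original):

```python
import torch

# Serialize an arbitrary object (here, a dict holding a tensor) to disk
obj = {'weights': torch.randn(3, 3)}
torch.save(obj, 'tensors.pt')

# Deserialize it back into memory
restored = torch.load('tensors.pt')
print(restored['weights'].shape)  # torch.Size([3, 3])
```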
A state_dict is a dictionary object that maps each layer to its parameter tensors. Note that only layers with learnable parameters, as well as registered buffers (e.g., batchnorm's running_mean), have entries in the state_dict.
Every Module and Optimizer exposes a .state_dict() method:
```python
# Print model's state_dict
print("Model's state_dict:")
for param_tensor in model.state_dict():
    print(param_tensor, "\t", model.state_dict()[param_tensor].size())

# Print optimizer's state_dict
print("Optimizer's state_dict:")
for var_name in optimizer.state_dict():
    print(var_name, "\t", optimizer.state_dict()[var_name])
```
```
Model's state_dict:
conv1.weight    torch.Size([6, 3, 5, 5])
conv1.bias      torch.Size([6])
conv2.weight    torch.Size([16, 6, 5, 5])
conv2.bias      torch.Size([16])
fc1.weight      torch.Size([120, 400])
fc1.bias        torch.Size([120])
fc2.weight      torch.Size([84, 120])
fc2.bias        torch.Size([84])
fc3.weight      torch.Size([10, 84])
fc3.bias        torch.Size([10])

Optimizer's state_dict:
state    {}
param_groups    [{'lr': 0.001, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [4675713712, 4675713784, 4675714000, 4675714072, 4675714216, 4675714288, 4675714432, 4675714504, 4675714648, 4675714720]}]
```
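For reference, a model and optimizer that would produce a state_dict like the one printed above could be defined roughly as follows (a sketch reconstructed from the printed shapes, not taken from the original text; note that the pooling layer has no learnable parameters, which is why it never appears in the state_dict):

```python
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

class TheModelClass(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)        # conv1.weight: [6, 3, 5, 5]
        self.pool = nn.MaxPool2d(2, 2)         # no parameters -> absent from state_dict
        self.conv2 = nn.Conv2d(6, 16, 5)       # conv2.weight: [16, 6, 5, 5]
        self.fc1 = nn.Linear(16 * 5 * 5, 120)  # fc1.weight: [120, 400]
        self.fc2 = nn.Linear(120, 84)          # fc2.weight: [84, 120]
        self.fc3 = nn.Linear(84, 10)           # fc3.weight: [10, 84]

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)

model = TheModelClass()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
```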
Detailed procedures:
Save/Load Model for Inference: the PyTorch convention is to save models with a .pt or .pth file extension.
Save/load the state_dict (recommended)
Note that load_state_dict() takes a dictionary object as its argument, not a PATH to a saved file, so you must deserialize the file with torch.load() first.
```python
# Save
torch.save(model.state_dict(), PATH)

# Load
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.eval()
```
Save/load the entire model. This approach pickles the model class itself, so the class definition must be importable at load time; this fragility is why the state_dict approach above is the recommended one.
```python
# Save
torch.save(model, PATH)

# Load (the model class must be defined somewhere in scope)
model = torch.load(PATH)
model.eval()
```
Save/load a General Checkpoint for Inference and/or Resuming Training: the convention is to save these checkpoints with a .tar file extension.
```python
# Save
torch.save({
    'epoch': epoch,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'loss': loss,
    ...
}, PATH)

# Load
checkpoint = torch.load(PATH)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']

model.eval()
# - or -
model.train()
```
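To make the "resuming training" half concrete, here is a minimal sketch of continuing from the restored epoch (train_loader, criterion, and num_epochs are assumed names from a typical training script, not part of the original):

```python
# Resume training from the epoch stored in the checkpoint
# (train_loader, criterion, and num_epochs are assumed to exist)
model.train()
for e in range(epoch + 1, num_epochs):
    for inputs, targets in train_loader:
        optimizer.zero_grad()
        output = model(inputs)
        loss = criterion(output, targets)
        loss.backward()
        optimizer.step()
```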
Saving Multiple Models in One File: the procedure is the same as saving a general checkpoint, and the convention is likewise to save these checkpoints with a .tar file extension.
```python
# Save
torch.save({
    'modelA_state_dict': modelA.state_dict(),
    'modelB_state_dict': modelB.state_dict(),
    'optimizerA_state_dict': optimizerA.state_dict(),
    'optimizerB_state_dict': optimizerB.state_dict(),
    ...
}, PATH)

# Load
checkpoint = torch.load(PATH)
modelA.load_state_dict(checkpoint['modelA_state_dict'])
modelB.load_state_dict(checkpoint['modelB_state_dict'])
optimizerA.load_state_dict(checkpoint['optimizerA_state_dict'])
optimizerB.load_state_dict(checkpoint['optimizerB_state_dict'])

modelA.eval()
modelB.eval()
# - or -
modelA.train()
modelB.train()
```
Warmstarting a Model Using Parameters from a Different Model: whether the state_dict being loaded has more keys or fewer keys than the model actually expects, you can set strict=False to ignore the non-matching keys.
```python
# Save
torch.save(modelA.state_dict(), PATH)

# Load into a different model, ignoring non-matching keys
modelB = TheModelBClass(*args, **kwargs)
modelB.load_state_dict(torch.load(PATH), strict=False)
```
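Note that strict=False only ignores missing and unexpected keys; if a key exists in both models but the tensor shapes differ, loading still fails. A common workaround is to filter the state_dict by hand first (a sketch; the variable names are illustrative):

```python
# Keep only the entries whose names and shapes match modelB
pretrained_dict = torch.load(PATH)
model_dict = modelB.state_dict()
compatible = {k: v for k, v in pretrained_dict.items()
              if k in model_dict and v.size() == model_dict[k].size()}
model_dict.update(compatible)
modelB.load_state_dict(model_dict)
```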
Save/Load Model Across Devices
Save on GPU, Load on CPU
```python
# Save
torch.save(model.state_dict(), PATH)

# Load
device = torch.device('cpu')
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH, map_location=device))
```
Save on GPU, Load on GPU
```python
# Save
torch.save(model.state_dict(), PATH)

# Load
device = torch.device("cuda")
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.to(device)
# Make sure to call input = input.to(device) on any input tensors
# that you feed to the model
```
Save on CPU, Load on GPU
```python
# Save
torch.save(model.state_dict(), PATH)

# Load
device = torch.device("cuda")
model = TheModelClass(*args, **kwargs)
# Choose whatever GPU device number you want
model.load_state_dict(torch.load(PATH, map_location="cuda:0"))
model.to(device)
# Make sure to call input = input.to(device) on any input tensors
# that you feed to the model
```
Saving torch.nn.DataParallel Models
```python
torch.save(model.module.state_dict(), PATH)
# Then load to whatever device you want
```
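Putting the DataParallel case end to end (a sketch; TheModelClass is the same assumed model as above): saving model.module.state_dict() stores the parameters without the "module." prefix that DataParallel adds to every key, so the file can later be loaded into a plain, unwrapped model on any device.

```python
# Save: unwrap the DataParallel container so keys have no "module." prefix
model = torch.nn.DataParallel(TheModelClass(*args, **kwargs))
torch.save(model.module.state_dict(), PATH)

# Load: into a plain model, on whatever device you want
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
plain_model = TheModelClass(*args, **kwargs)
plain_model.load_state_dict(torch.load(PATH, map_location=device))
plain_model.to(device)
```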