精华内容
下载资源
问答
  • pytorch保存和加载模型
    2022-07-22 22:34:31

    模型的保存和加载

    1 只保存和加载模型参数

    torch.save(model.state_dict(), PATH)   ###将模型的参数保存到这个地址下,后缀名为pt
    
    model = model(*args, **kwargs)    ###定义模型
    model.load_state_dict(torch.load(PATH, map_location=lambda storage, loc: storage))  ##导入模型参数
    

    2 保存和加载整个模型

    torch.save(model,path)
    
    model=torch.load(path)

    这种方式可以直接保存整个模型,在应用的时候不用再重新定义模型。

    定义网络结构

    这里定义了最简单的网络结构。两层的全连接层

    class Net(nn.Module):
        def __init__(self):
            super(Net, self).__init__()
            self.layer1=nn.Linear(1,3)   ###线性层
            self.layer2=nn.Linear(3,1)
        def forward(self,x):
            x=self.layer1(x)
            x=torch.relu(x)   ###relu激活函数
            x=self.layer2(x)
            return x
    

    训练神经网络

    import torch
    import torch.nn as nn
    import numpy as np
    import matplotlib.pyplot as plt
    
    
    epoches=2000
    
    
    # 学习率定义为0.01
    learning_rate=0.01
    
    # 创建一个模型
    model=Net()
    
    optimizer = optim.Adam(params=model.parameters(), lr=learning_rate, weight_decay=1e-5)
    criterion = nn.MSELoss() #定义损失函数
    # 使用优化器来更新网络权重,lr为学习率,
    for i in range(epoch): #设定训练epoch次
    	model.train() #将模型的状态设置为train
    	for j in Sample: #对每一个样本进行遍历
    		optimizer.zero_grad() #将梯度清理,为这次的梯度计算做准备
    		output = model(j)
    		loss = criterion(output, target)
    		loss.backward()    ###这里记录的平均loss
    		optimizer.step() #更新网络权重	
    
            if (epoch+1) % 10==0:   ##每十次打印一下当前的状态
            print("Epoch {} / {},loss {:.4f}".format(epoch+1,num_epoches,loss.item()))	
    
    

    Pytorch模型保存

    ##torch.save() 可以保存字典类型的数据
    save_checkpoint({'loss': i, 'state_dict': model.state_dict()},dir)
    
    def save_checkpoint(state, dic): #state是模型的权重和状态  dic是模型保存的目录
    	if not os.path.exists(dir):
    		os.makedirs(directory)
    	fileName = directory + 'last.pth'
    	torch.save(state,fileName)#使用torch.save函数直接对训练好的模型进行保存
    

    更多相关内容
  • 各种情况下pytorch 如何保存/加载模型

    当保存和加载模型时,需要熟悉三个核心功能:

    1. torch.save :将序列化对象保存到磁盘。此函数使用 Python pickle 模块进行序列化。使
    用此函数可以保存如模型、 tensor 、字典等各种对象。
    2. torch. load :使用 pickle unpickling 功能将 pickle 对象文件反序列化到内存。此功能还可
    以有助于设备加载数据。
    3. torch.nn.Module.load_state_dict :使用反序列化函数 state_dict 来加载模型的参数字典。

    1 保存和加载推理模型

    保存

    torch.save(model.state_dict(), PATH)

    加载

    model = TheModelClass(*args, **kwargs)
    model.load_state_dict(torch.load(PATH))
    model.eval()
        当保存好模型用来推断的时候,只需要保存模型学习到的参数,使用 torch.save() 函数来保
    存模型 state_dict , 它会给模型恢复提供 最大的灵活性,这就是为什么要推荐它来保存的原
    因。
        在 PyTorch 中最常见的模型保存使 ‘.pt’ 或者是 ‘.pth’ 作为模型文件扩展名。
        请记住,在运行推理之前,务必调用 model.eval() 去设置 dropout batch normalization 层为评 估模式。如果不这么做,可能导致 模型推断结果不一致。
    注意load_state_dict() 函数只接受字典对象,而不是保存对象的路径。这就意味着在你传给
    load_state_dict() 函数之前,你必须反序列化 你保存的 state_dict 。例如,你无法通过
    model.load_state_dict(PATH) 来加载模型。

    2 保存/加载完整模型

    保存

    torch.save(model, PATH)

    加载

    new_model = torch.load(PATH)
    new_model.eval()
    #new_model 不再需要第一种方法中的建立新模型的步骤
        此部分保存 / 加载过程使用最直观的语法并涉及最少量的代码。以 Python ‘pickle’  模块的方式
    来保存模型。这种方法的缺点是序列化数据受 限于某种特殊的类而且需要确切的字典结构。
    这是因为 pickle 无法保存模型类本身。相反,它保存包含类的文件的路径,该文件在加载时使
    用。 因此,当在其他项目使用或者重构之后,您的代码可能会以各种方式中断。
        在 PyTorch 中最常见的模型保存使用 ‘.pt’ 或者是 ‘.pth’ 作为模型文件扩展名。
        请记住,在运行推理之前,务必调用 model.eval() 设置 dropout batch normalization 层为评估模式。如果不这么做,可能导致模型推断结果不一致。

    3 保存和加载 Checkpoint 用于推理/继续训练

    保存

    torch.save({
     'epoch': epoch,
     'model_state_dict': model.state_dict(),
     'optimizer_state_dict': optimizer.state_dict(),
     'loss': loss,
      ...
      }, PATH)

    加载

    model = TheModelClass(*args, **kwargs)
    optimizer = TheOptimizerClass(*args, **kwargs)
    checkpoint = torch.load(PATH)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    epoch = checkpoint['epoch']
    loss = checkpoint['loss']
    model.eval()
    # - or -
    model.train()

    在一个文件中保存多个模型

    保存

    torch.save({
     'modelA_state_dict': modelA.state_dict(),
     'modelB_state_dict': modelB.state_dict(),
     'optimizerA_state_dict': optimizerA.state_dict(),
     'optimizerB_state_dict': optimizerB.state_dict(),
     ...
     }, PATH)

    加载

    modelA = TheModelAClass(*args, **kwargs)
    modelB = TheModelBClass(*args, **kwargs)
    optimizerA = TheOptimizerAClass(*args, **kwargs)
    optimizerB = TheOptimizerBClass(*args, **kwargs)
    checkpoint = torch.load(PATH)
    modelA.load_state_dict(checkpoint['modelA_state_dict'])
    modelB.load_state_dict(checkpoint['modelB_state_dict'])
    optimizerA.load_state_dict(checkpoint['optimizerA_state_dict'])
    optimizerB.load_state_dict(checkpoint['optimizerB_state_dict'])
    modelA.eval()
    modelB.eval()
    # - or -
    modelA.train()
    modelB.train()
        当保存一个模型由多个 torch.nn.Modules 组成时,例如 GAN( 对抗生成网络 ) sequence-to-
    sequence ( 序列到序列模型 ), 或者是多个模型融合 , 可以采用与保存常规检查点相同的方法。
    换句话说,保存每个模型的 state_dict 的字典和相对应的优化器。如前所述,可以通过简单地
    将它们附加到字典的方式来保存任何其他项目,这样有助于恢复训练。
        PyTorch 中常见的保存 checkpoint 是使用 .tar 文件扩展名。 要加载项目,首先需要初始化模型和优化器,然后使用 torch. load () 来加载本地字典。这里,你可以非常容易的通过简单查询字典来访问你所保存的项目。
        请记住在运行推理之前,务必调用 model.eval() 去设置 dropout batch normalization 为评估。如果不这样做,有可能得到不一致的推断结果。 如果你想要恢复训练,请调用 model.train() 以 确保这些层处于训练模式。

    使用在不同模型参数下的热启动模式

    保存

    torch.save(modelA.state_dict(), PATH)

    加载

    modelB = TheModelBClass(*args, **kwargs)
    modelB.load_state_dict(torch.load(PATH), strict=False)
        在迁移学习或训练新的复杂模型时,部分加载模型或加载部分模型是常见的情况。利用训练好
    的参数,有助于热启动训练过程,并希望帮助你的模型比从头开始训练能够更快地收敛。
        无论是从缺少某些键的 state_dict 加载还是从键的数目多于加载模型的 state_dict , 都可以通过在 load_state_dict() 函数中将 strict 参数设置为 False 来忽略非匹配键的函数。
        如果要将参数从一个层加载到另一个层,但是某些键不匹配,主要修改正在加载的 state_dict 中的 参数键的名称以匹配要在加载到模型中的键即可。

    通过设备保存/加载模型

    6.1 保存到 CPU、加载到 CPU

    保存

    torch.save(model.state_dict(), PATH)

    加载

    device = torch.device('cpu')
    model = TheModelClass(*args, **kwargs)
    model.load_state_dict(torch.load(PATH, map_location=device))

    6.2 保存到 GPU、加载到 GPU

    保存

    torch.save(model.state_dict(), PATH)

    加载

    device = torch.device("cuda")
    model = TheModelClass(*args, **kwargs)
    model.load_state_dict(torch.load(PATH))
    model.to(device) # 确保在你提供给模型的任何输入张量上调用input = input.to(device)
    当在 GPU 上训练并把模型保存在 GPU ,只需要使用 model. to (torch.device( 'cuda' )) ,将初
    始化的 model 转换为 CUDA 优化模型。另外,请 务必在所有模型输入上使用 .to ( torch.device ( ' cuda ' )) 函数来为模型准备数据。请注意,调用 my_tensor. to (device)
    会在 GPU 上返回 my_tensor 的副本。 因此,请记住手动覆盖张量: my_tensor= my_tensor. to (torch.device( 'cuda' ))

    6.3 保存到 CPU,加载到 GPU

    保存

    torch.save(model.state_dict(), PATH)

    加载

    device = torch.device("cuda")
    model = TheModelClass(*args, **kwargs)
    model.load_state_dict(torch.load(PATH, map_location="cuda:0")) # Choose whatever GPU device number you want
    model.to(device) # 确保在你提供给模型的任何输入张量上调用input = input.to(device)
        在 CPU 上训练好并保存的模型加载到 GPU 时,将 torch. load () 函数中的 map_location 参数设
    置为 cuda:device_id 。这会将模型加载到指定的 GPU 设备。接下来,请务必调用
    model. to (torch.device( 'cuda' )) 将模型的参数张量转换为 CUDA 张量。最后,确保在所有
    模型输入上使用 .to ( torch.device ( ' cuda ' )) 函数来为 CUDA 优化模型。请注意,调用
    my_tensor. to (device) 会在 GPU 上返回 my_tensor 的新副本。它不会覆盖 my_tensor 。因
    此, 请手动覆盖张量 my_tensor = my_tensor. to (torch.device( 'cuda' ))

    6.4 保存 torch.nn.DataParallel 模型

    保存

    torch.save(model.module.state_dict(), PATH)

    加载

    # 加载任何你想要的设备
    torch.nn.DataParallel 是一个模型封装,支持并行 GPU 使用。要普通保存 DataParallel 模型 ,
    请保存 model.module.state_dict() 。 这样,你就可以非常灵活地以任何方式加载模型到你
    想要的设备中。

    展开全文
  • 模型保存和加载,于pytorch而言,也是很简单的。本文做了一个比较实验,方便大家理解。 首先,要清楚几个函数:torch.save,torch.load,state_dict(),load_state_dict()。 先举最简单的例子: import torch ...

    Pytorch目前成为学术界最流行的DL框架,没有之一。很大程度上,简洁直观地操作有关。模型的保存和加载,于pytorch而言,也是很简单的。本文做了一个比较实验,方便大家理解。
    首先,要清楚几个函数:torch.save,torch.load,state_dict(),load_state_dict()。
    先举最简单的例子:

    import torch
    
    model = torch.load('my_model.pth')
    torch.save(model, 'new_model.pth')
    
    

    上面的代码非常直观,一载一存。但是有一个问题,这样保存的pth文件直接包含了整个模型的结构当你需要灵活加载模型参数时,比如只加载部分参数,那么这种情况保存的pth文件读取进来还得额外解析出“参数文件”

    如果想更灵活对待咱们训练好的模型参数,咱们可以使用下面这个方法。pytorch把所有的模型参数用一个内部定义的dict进行保存,自称为“state_dict”。这个所谓的state_dict就是不带模型结构的模型参数了~
    咱们的加载和保存就发生了一点微妙的变化:

    import torch
    model = MyModel() # init your model class, build the graph shape
    
    state_dict = torch.load('model_state_dict.pth')
    model.load_state_dict(state_dict)
    
    torch.save(model.state_dict(), 'model_state_dict1.pth')
    
    

    比较上面两段代码,咱们可以有一下结论:

    pth文件既可能保存了模型的图结构,也有可能没保存;
    加载没保存图结构的pth时,需要先初始化模型结构,即把架子搭好;
    在保存模型的时候,如果不想保存图结构,可以单独保存model.state_dict()

    • 实验

    import torch
    import torchvision.models as models
    
    model = models.vgg16(pretrained=True)
    torch.save(model.state_dict(), 'only_weights.pth')
    
    model_state_dict = torch.load('only_weights.pth')
    model1 = models.vgg16() # describe the graph shape
    model1.load_state_dict(model_state_dict)
    model1.eval()
    
    torch.save(model1, 'whole_model.pth')
    
    model2 = torch.load('whole_model.pth')
    model2.eval()
    
    # model3 = torch.load('only_weights.pth')
    # model3.eval()    # Error
    
    
    

    model3切换到eval()模式就会报错,原因是model3只包含weights而缺乏图结构~

    • torch.load_state_dict()函数的用法

    在Pytorch中构建好一个模型后,一般需要进行预训练权重中加载。torch.load_state_dict()函数就是用于将预训练的参数权重加载到新的模型之中,操作方式如下所示:

    sd_net = torchvision.models.resnte50(pretrained=False)
    sd_net.load_state_dict(torch.load('*.pth'), strict=True)
    

    在本博文中重点关注的是 属性 strict; 当strict=True,要求预训练权重层数的键值与新构建的模型中的权重层数名称完全吻合;如果新构建的模型在层数上进行了部分微调,则上述代码就会报错:说key对应不上。

    此时,如果我们采用strict=False 就能够完美的解决这个问题。也即,与训练权重中与新构建网络中匹配层的键值就进行使用,没有的就默认初始化

    参考博文:
    https://blog.csdn.net/ChaoMartin/article/details/118686268

    https://blog.csdn.net/leviopku/article/details/123925804

    展开全文
  • pytorch保存和加载模型的两种方式

    千次阅读 2020-06-22 21:09:01
    pytorch保存和加载模型是绑在一起的。 这里我需要注意一下不同的保存方式对应不同的读取方式,两者各有利弊。 首先说说pytorch.save()这个函数,可以参考官网:pytroch.save。 简而言之,这个函数可以保存任意的...

    pytorch中保存和加载模型是绑在一起的。
    这里我需要注意一下不同的保存方式对应不同的读取方式,两者各有利弊。

    首先说说pytorch.save()这个函数,可以参考官网:pytroch.save
    简而言之,这个函数可以保存任意的东西,比如tensor或者模型,或者仅仅是模型的参数。
    如果将保存对象局限在模型上,通常来说我们有两种方式直接保存所有的模型,只保存模型中的参数(模型结构就保存了)。以下分别说说两种不同的方式。

    为了说明,我们先建立一个简单的模型。

    import torch
    import torch.nn as nn
    
    class Generator(nn.Module):
        def __init__(self, in_c, out_c, ngf=64):
            super(Generator, self).__init__()
            model = []
            model += [
                nn.Conv2d(in_c, ngf, 3, 2, 1),
                nn.ReLU(),
                nn.BatchNorm2d(ngf),
                nn.Conv2d(ngf, out_c, 3, 2, 1)
            ]
            self.model = nn.Sequential(*model)
            
        def forward(self, x):
            return self.model(x)
    
    netG = Generator(3, 3)
    input = torch.zeros(10, 3, 256, 256)
    output = netG(input)
    

    直接保存所有模型并读取

    直接使用简单粗暴的方式保存:

    torch.save(netG, 'netG.pt')
    

    对应的,我们可以这样读取模型

    netC = torch.load('netG.pt')
    input = torch.zeros(10, 3, 256, 256)
    output = netC(input)
    

    正常情况如下(警告先忽略):在这里插入图片描述

    只保存模型中的参数并读取

    我们说模型的参数保存在网络的state_dict中,使用这个就可以读取网络的参数了。

    torch.save({'netG': netG.state_dict()}, 'model_test.pt')
    

    对应的加载模型的方式如下:

    netD = Generator(3, 3)
    state_dict = torch.load('model_test.pt')
    netD.load_state_dict(state_dict['netG'])
    input = torch.zeros(10, 3, 256, 256)
    output = netD(input)
    

    总结

    我们可以看到第一种方法可以直接保存模型,加载模型的时候直接把读取的模型给一个参数就行。而第二种方法则只是保存参数,在读取模型参数前要先定义一个模型(模型必须与原模型相同的构造),然后对这个模型导入参数。虽然麻烦,但是可以同时保存多个模型的参数,而第一种方法则不能,而且第一种方法有时不能保证模型的相同性(你读取的模型并不是你想要的)。

    总的来说,我们一般来选择第二种来保存和读取
    退一步讲,如何保存模型决定了如何读取模型

    展开全文
  • Pytorch保存和加载模型的两种方式

    千次阅读 2020-08-20 19:58:38
    与Tensorflow、Keras等框架一样,Pytorch也提供了两种保存模型的方式,这两种方式都是通过调用pickle序列化方法实现的: 只保存模型参数 保存完整模型 下面我们依次对这两种方式进行实现,以以下多层感知机模型为...
  • Pytorch保存和加载模型

    2019-11-01 14:09:47
    文章目录一、保存加载模型基本用法二、保存加载自定义模型三、跨设备保存加载模型四、CUDA 的用法 一、保存加载模型基本用法 1、保存加载整个模型(不推荐) 保存整个网络模型(网络结构+权重参数)。 torch.save...
  • 本文详解了PyTorch 模型保存加载方法。 目录 1 需要掌握3个重要的函数 2 state_dict 2.1 state_dict 介绍 2.2 保存和加载 state_dict (已经训练完,无需继续训练) 2.3 保存和加载整个模型 (已经训练完,无需...
  • # 保存和加载整个模型 torch.save(model_object, 'resnet.pth') model = torch.load('resnet.pth') 二、只保存神经网络的训练模型参数,save的对象是net.state_dict() # 将my_resnet模型储存为my_resnet.pth torc.....
  • PyTorch保存和加载模型(全面汇总)

    千次阅读 多人点赞 2019-11-22 18:41:46
    pytorch 中的 state_dict ...只有那些参数可以训练的layer才会被保存模型的state_dict中,如卷积层,线性层等等。 优化器对象Optimizer也有一个state_dict,它包含了优化器的状态以及被使用的超参数(如lr, momentum,w...
  • pytorch保存和加载模型

    2021-08-25 20:53:38
    之前想学习保存和加载模型的代码,在知乎上看到一个回答,发现两行代码就可以搞定,于是兴冲冲的加上了: torch.save(model, "model.pth.tar") model_dict=torch.load("model.pth.tar") 然后就大胆的去训练了,...
  • pytorch保存和加载模型 保存模型 有两种方式保存模型 一、保存整个网络 保存整个神经网络的的结构信息模型参数信息,save的对象是网络net。 后缀一般命名为.pkl net = Net() # 保存加载整个模型 torch.save...
  • pytorch保存和加载模型权重的方式

    千次阅读 2021-02-03 10:20:32
    (1) 保存和加载整个模型 ...(2) 仅仅保存模型参数以及分别加载模型结构参数 # 模型参数保存 torch.save(model.state_dict(), 'model_param.pkl') # 模型参数加载 model = ModelClass(...) mo
  • 另一方面,训练好的模型后需要对实际数据进行预测(Predict,或称为推理Inference),这时候就需要把模型的权重保存到硬盘中,方便后续直接调用模型进行预测。 1. 模块张量的序列化及反序列化 pytorch的一系列方法...
  • 文章目录简介一、什么是状态字典(state_dict)二... 不同设备下保存和加载模型总结 简介 本文主要介绍如何加载和保存 PyTorch 的模型。这里主要有三个核心函数: torch.save :把序列化的对象保存到硬盘。它利用了 Pyt
  • Pytorch保存和加载模型参数参考博客1. 保存模型和参数2. 仅保存参数3. 加载pytorch预训练模型3.1 加载预训练模型和参数3.2 只加载...pytorch的模型和参数是分开的,可以分别保存或加载模型和参数。 pytorch有两种模...
  • 保存和加载模型时,需要熟悉三个核心功能: torch.save:将序列化对象保存到磁盘。此函数使用Python的pickle模块进行序列化。使用此函数可以保存如模型、tensor、字典等各种对象。 torch.load:使用pickle的...
  • PyTorch保存和加载模型CUDA

    千次阅读 2019-11-22 03:11:44
    保存了使用CUDA训练的模型后,加载时也一定得保持一致,换句话说,在定义网络的时候需要用 net.to(device) 而且在测试的时候也需要把输入标签统统转移到cuda上面,即 inputs, labels = inputs.to(device), ...
  • 保存模型: torch.save({ 'epoch': epoch + 1, 'state_dict': model.state_dict(), 'optimizer': optimizer....加载模型 model = model_class(num_classes=num_classes) # 定义模型 state = torch.load(datadir)
  • 欢迎关注“小白玩转Python”,发现更多 “有趣”本文的目的是展示如何保存一个模型加载它,以便在上一个 epoch 之后继续训练并进行预测。如果您正在阅读本文,我假定您熟悉深度学习...
  • 本文是一篇关于如何用Pytorch保存和加载模型的指南。 文章目录1 读写tensor1.1 单个张量1.2 张量列表张量词典2 保存和加载模型2.1 *state_dict*2.2 保存加载2.2.1 保存加载state_dict(推荐方式)2.2.2 保存...

空空如也

空空如也

1 2 3 4 5 ... 20
收藏数 12,793
精华内容 5,117
关键字:

pytorch保存和加载模型