佳木斯湛栽影视文化发展公司

主頁 > 知識(shí)庫 > 解決Pytorch修改預(yù)訓(xùn)練模型時(shí)遇到key不匹配的情況

解決Pytorch修改預(yù)訓(xùn)練模型時(shí)遇到key不匹配的情況

熱門標(biāo)簽:銀行業(yè)務(wù) 呼叫中心市場需求 智能手機(jī) 鐵路電話系統(tǒng) 服務(wù)器配置 網(wǎng)站文章發(fā)布 美圖手機(jī) 檢查注冊(cè)表項(xiàng)

一、Pytorch修改預(yù)訓(xùn)練模型時(shí)遇到key不匹配

最近想著修改網(wǎng)絡(luò)的預(yù)訓(xùn)練模型vgg.pth,但是發(fā)現(xiàn)當(dāng)我加載預(yù)訓(xùn)練模型權(quán)重到新建的模型并保存之后。

在我使用新賦值的網(wǎng)絡(luò)模型時(shí)出現(xiàn)了key不匹配的問題

#加載后保存(未修改網(wǎng)絡(luò))
base_weights = torch.load(args.save_folder + args.basenet)
ssd_net.vgg.load_state_dict(base_weights) 
torch.save(ssd_net.state_dict(), args.save_folder + 'ssd_base' + '.pth')
# 將新保存的網(wǎng)絡(luò)代替之前的預(yù)訓(xùn)練模型
    ssd_net = build_ssd('train', cfg['min_dim'], cfg['num_classes'])
    net = ssd_net
    ...
    if args.resume:
        ...
    else:
        base_weights = torch.load(args.save_folder + args.basenet)
        #args.basenet為ssd_base.pth
        print('Loading base network...')
        ssd_net.vgg.load_state_dict(base_weights) 

此時(shí)會(huì)如下出錯(cuò)誤:

Loading base network…
Traceback (most recent call last):
File “train.py”, line 264, in
train()
File “train.py”, line 110, in train
ssd_net.vgg.load_state_dict(base_weights)

RuntimeError: Error(s) in loading state_dict for ModuleList:
Missing key(s) in state_dict: “0.weight”, “0.bias”, … “33.weight”, “33.bias”.
Unexpected key(s) in state_dict: “vgg.0.weight”, “vgg.0.bias”, … “vgg.33.weight”, “vgg.33.bias”.

說明之前的預(yù)訓(xùn)練模型 key參數(shù)為"0.weight", “0.bias”,但是經(jīng)過加載保存之后變?yōu)榱?vgg.0.weight", “vgg.0.bias”

我認(rèn)為是因?yàn)楸旧淼哪P投x文件里self.vgg = nn.ModuleList(base)這一句。

現(xiàn)在的問題是因?yàn)樽约憾x保存的模型key參數(shù)多了一個(gè)前綴。

可以通過如下語句進(jìn)行修改,并加載

from collections import OrderedDict   #導(dǎo)入此模塊
base_weights = torch.load(args.save_folder + args.basenet)
print('Loading base network...')
new_state_dict = **OrderedDict()**  
for k, v in base_weights.items():
    name = k[4:]   # remove `vgg.`,即只取vgg.0.weights的后面幾位
    new_state_dict[name] = v 
    ssd_net.vgg.load_state_dict(new_state_dict) 

此時(shí)就不會(huì)再出錯(cuò)了。

參考了這個(gè)篇。修改一下就可以應(yīng)用到自己的模型啦。

//www.jb51.net/article/214214.htm

二、pytorch加載預(yù)訓(xùn)練模型遇到的問題:KeyError: ‘bn1.num_batches_tracked‘

最近在使用pytorch1.0加載resnet預(yù)訓(xùn)練模型時(shí),遇到的一個(gè)問題,在此記錄一下。

KeyError: 'layer1.0.bn1.num_batches_tracked'

其實(shí)是使用的版本的問題,pytorch0.4.1之后在BN層加入了track_running_stats這個(gè)參數(shù),

這個(gè)參數(shù)的作用如下:

訓(xùn)練時(shí)用來統(tǒng)計(jì)訓(xùn)練時(shí)的forward過的min-batch數(shù)目,每經(jīng)過一個(gè)min-batch, track_running_stats+=1

如果沒有指定momentum, 則使用1/num_batches_tracked 作為因數(shù)來計(jì)算均值和方差(running mean and variance).

其實(shí),這個(gè)參數(shù)沒啥用.但因?yàn)楣俜教峁┑念A(yù)訓(xùn)練模型是pytorch0.3版本訓(xùn)練出來的,因此沒有這個(gè)參數(shù).

所以,只要過濾一下預(yù)訓(xùn)練權(quán)重字典中的關(guān)鍵字即可,‘num_batches_tracked'.代碼例子,如下.

有問題的代碼:

   def load_specific_param(self, state_dict, param_name, model_path):
        param_dict = torch.load(model_path)
        for i in state_dict:
            key = param_name + '.' + i
            state_dict[i].copy_(param_dict[key])
        del param_dict

對(duì)'num_batches_tracked進(jìn)行過濾:

   def load_specific_param(self, state_dict, param_name, model_path):
        param_dict = torch.load(model_path)
        param_dict = {k: v for k, v in param_dict.items() if 'num_batches_tracked' not in k}
        for i in state_dict:
            key = param_name + '.' + i
            if 'num_batches_tracked' in key:
                continue
            state_dict[i].copy_(param_dict[key])
        del param_dict

以上為個(gè)人經(jīng)驗(yàn),希望能給大家一個(gè)參考,也希望大家多多支持腳本之家。

您可能感興趣的文章:
  • Pytorch通過保存為ONNX模型轉(zhuǎn)TensorRT5的實(shí)現(xiàn)
  • pytorch_pretrained_bert如何將tensorflow模型轉(zhuǎn)化為pytorch模型
  • pytorch模型的保存和加載、checkpoint操作
  • PyTorch 如何檢查模型梯度是否可導(dǎo)
  • pytorch 預(yù)訓(xùn)練模型讀取修改相關(guān)參數(shù)的填坑問題
  • PyTorch模型轉(zhuǎn)TensorRT是怎么實(shí)現(xiàn)的?

標(biāo)簽:新疆 滄州 上海 紅河 河南 樂山 沈陽 長治

巨人網(wǎng)絡(luò)通訊聲明:本文標(biāo)題《解決Pytorch修改預(yù)訓(xùn)練模型時(shí)遇到key不匹配的情況》,本文關(guān)鍵詞  ;如發(fā)現(xiàn)本文內(nèi)容存在版權(quán)問題,煩請(qǐng)?zhí)峁┫嚓P(guān)信息告之我們,我們將及時(shí)溝通與處理。本站內(nèi)容系統(tǒng)采集于網(wǎng)絡(luò),涉及言論、版權(quán)與本站無關(guān)。
  • 相關(guān)文章
  • 收縮
    • 微信客服
    • 微信二維碼
    • 電話咨詢

    • 400-1100-266
    孟州市| 十堰市| 日土县| 柳林县| 阿克| 白银市| 青铜峡市| 洛隆县| 清流县| 双辽市| 黔南| 长顺县| 界首市| 中宁县| 安徽省| 大关县| 太和县| 竹溪县| 乡城县| 贺兰县| 饶阳县| 达州市| 阆中市| 射洪县| 合水县| 武川县| 台北市| 玉龙| 北海市| 沅江市| 临武县| 铜山县| 宁河县| 衡阳县| 珠海市| 仪征市| 华池县| 肥东县| 汝南县| 海口市| 松江区|