佳木斯湛栽影视文化发展公司

主頁 > 知識庫 > pytorch固定BN層參數(shù)的操作

pytorch固定BN層參數(shù)的操作

熱門標(biāo)簽:語音系統(tǒng) 百度AI接口 呼叫中心市場需求 企業(yè)做大做強 客戶服務(wù) Win7旗艦版 硅谷的囚徒呼叫中心 電話運營中心

背景:

基于PyTorch的模型,想固定主分支參數(shù),只訓(xùn)練子分支,結(jié)果發(fā)現(xiàn)在不同epoch相同的測試數(shù)據(jù)經(jīng)過主分支輸出的結(jié)果不同。

原因:

未固定主分支BN層中的running_mean和running_var。

解決方法:

將需要固定的BN層狀態(tài)設(shè)置為eval。

問題示例:

環(huán)境:torch:1.7.0

# -*- coding:utf-8 -*-
import torch
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):

    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 6, 3)
        self.bn1 = nn.BatchNorm2d(6)
        self.conv2 = nn.Conv2d(6, 16, 3)
        self.bn2 = nn.BatchNorm2d(16)
        # an affine operation: y = Wx + b
        self.fc1 = nn.Linear(16 * 6 * 6, 120)  # 6*6 from image dimension
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 5)

    def forward(self, x):
        # Max pooling over a (2, 2) window
        x = F.max_pool2d(F.relu(self.bn1(self.conv1(x))), (2, 2))
        # If the size is a square you can only specify a single number
        x = F.max_pool2d(F.relu(self.bn2(self.conv2(x))), 2)
        x = x.view(-1, self.num_flat_features(x))
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

    def num_flat_features(self, x):
        size = x.size()[1:]  # all dimensions except the batch dimension
        num_features = 1
        for s in size:
            num_features *= s
        return num_features

def print_parameter_grad_info(net):
    print('-------parameters requires grad info--------')
    for name, p in net.named_parameters():
        print(f'{name}:\t{p.requires_grad}')

def print_net_state_dict(net):
    for key, v in net.state_dict().items():
        print(f'{key}')

if __name__ == "__main__":
    net = Net()

    print_parameter_grad_info(net)
    net.requires_grad_(False)
    print_parameter_grad_info(net)

    torch.random.manual_seed(5)
    test_data = torch.rand(1, 1, 32, 32)
    train_data = torch.rand(5, 1, 32, 32)

    # print(test_data)
    # print(train_data[0, ...])
    for epoch in range(2):
        # training phase, 假設(shè)每個epoch只迭代一次
        net.train()
        pre = net(train_data)
        # 計算損失和參數(shù)更新等
        # ....

        # test phase
        net.eval()
        x = net(test_data)
        print(f'epoch:{epoch}', x)

運行結(jié)果:

-------parameters requires grad info--------
conv1.weight: True
conv1.bias: True
bn1.weight: True
bn1.bias: True
conv2.weight: True
conv2.bias: True
bn2.weight: True
bn2.bias: True
fc1.weight: True
fc1.bias: True
fc2.weight: True
fc2.bias: True
fc3.weight: True
fc3.bias: True
-------parameters requires grad info--------
conv1.weight: False
conv1.bias: False
bn1.weight: False
bn1.bias: False
conv2.weight: False
conv2.bias: False
bn2.weight: False
bn2.bias: False
fc1.weight: False
fc1.bias: False
fc2.weight: False
fc2.bias: False
fc3.weight: False
fc3.bias: False
epoch:0 tensor([[-0.0755, 0.1138, 0.0966, 0.0564, -0.0224]])
epoch:1 tensor([[-0.0763, 0.1113, 0.0970, 0.0574, -0.0235]])

可以看到:

net.requires_grad_(False)已經(jīng)將網(wǎng)絡(luò)中的各參數(shù)設(shè)置成了不需要梯度更新的狀態(tài),但是同樣的測試數(shù)據(jù)test_data在不同epoch中前向之后出現(xiàn)了不同的結(jié)果。

調(diào)用print_net_state_dict可以看到BN層中的參數(shù)running_mean和running_var并沒在可優(yōu)化參數(shù)net.parameters中

bn1.weight
bn1.bias
bn1.running_mean
bn1.running_var
bn1.num_batches_tracked

但在training pahse的前向過程中,這兩個參數(shù)被更新了。導(dǎo)致整個網(wǎng)絡(luò)在freeze的情況下,同樣的測試數(shù)據(jù)出現(xiàn)了不同的結(jié)果

Also by default, during training this layer keeps running estimates of its computed mean and variance, which are then used for normalization during evaluation. The running estimates are kept with a defaultmomentumof 0.1. source

因此在training phase時對BN層顯式設(shè)置eval狀態(tài):

if __name__ == "__main__":
    net = Net()
    net.requires_grad_(False)

    torch.random.manual_seed(5)
    test_data = torch.rand(1, 1, 32, 32)
    train_data = torch.rand(5, 1, 32, 32)

    # print(test_data)
    # print(train_data[0, ...])
    for epoch in range(2):
        # training phase, 假設(shè)每個epoch只迭代一次
        net.train()
        net.bn1.eval()
        net.bn2.eval()
        pre = net(train_data)
        # 計算損失和參數(shù)更新等
        # ....

        # test phase
        net.eval()
        x = net(test_data)
        print(f'epoch:{epoch}', x)

可以看到結(jié)果正常了:

epoch:0 tensor([[ 0.0944, -0.0372, 0.0059, -0.0625, -0.0048]])
epoch:1 tensor([[ 0.0944, -0.0372, 0.0059, -0.0625, -0.0048]])

補充:pytorch---之BN層參數(shù)詳解及應(yīng)用(1,2,3)(1,2)?

BN層參數(shù)詳解(1,2)

一般來說pytorch中的模型都是繼承nn.Module類的,都有一個屬性trainning指定是否是訓(xùn)練狀態(tài),訓(xùn)練狀態(tài)與否將會影響到某些層的參數(shù)是否是固定的,比如BN層(對于BN層測試的均值和方差是通過統(tǒng)計訓(xùn)練的時候所有的batch的均值和方差的平均值)或者Dropout層(對于Dropout層在測試的時候所有神經(jīng)元都是激活的)。通常用model.train()指定當(dāng)前模型model為訓(xùn)練狀態(tài),model.eval()指定當(dāng)前模型為測試狀態(tài)。

同時,BN的API中有幾個參數(shù)需要比較關(guān)心的,一個是affine指定是否需要仿射,還有個是track_running_stats指定是否跟蹤當(dāng)前batch的統(tǒng)計特性。容易出現(xiàn)問題也正好是這三個參數(shù):trainning,affine,track_running_stats。

其中的affine指定是否需要仿射,也就是是否需要上面算式的第四個,如果affine=False則γ=1,β=0 \gamma=1,\beta=0γ=1,β=0,并且不能學(xué)習(xí)被更新。一般都會設(shè)置成affine=True。(這里是一個可學(xué)習(xí)參數(shù))

trainning和track_running_stats,track_running_stats=True表示跟蹤整個訓(xùn)練過程中的batch的統(tǒng)計特性,得到方差和均值,而不只是僅僅依賴與當(dāng)前輸入的batch的統(tǒng)計特性(意思就是說新的batch依賴于之前的batch的均值和方差這里使用momentum參數(shù),參考了指數(shù)移動平均的算法EMA)。相反的,如果track_running_stats=False那么就只是計算當(dāng)前輸入的batch的統(tǒng)計特性中的均值和方差了。當(dāng)在推理階段的時候,如果track_running_stats=False,此時如果batch_size比較小,那么其統(tǒng)計特性就會和全局統(tǒng)計特性有著較大偏差,可能導(dǎo)致糟糕的效果。

應(yīng)用技巧:(1,2)

通常pytorch都會用到optimizer.zero_grad() 來清空以前的batch所累加的梯度,因為pytorch中Variable計算的梯度會進(jìn)行累計,所以每一個batch都要重新清空一次梯度,原始的做法是下面這樣的:

問題:參數(shù)non_blocking,以及pytorch的整體框架??

代碼(1)

for index,data,target in enumerate(dataloader):
    data = data.cuda(non_blocking=True)
    target = torch.from_numpy(np.array(target)).float().cuda(non_blocking = Trye)
    output = model(data)
    loss = criterion(output,target)
    
    #清空梯度
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

而這里為了模仿minibacth,我們每次batch不清0,累積到一定次數(shù)再清0,再更新權(quán)重:

for index, data, target in enumerate(dataloader):
    #如果不是Tensor,一般要用到torch.from_numpy()
    data = data.cuda(non_blocking = True)
    target = torch.from_numpy(np.array(target)).float().cuda(non_blocking = True)
    output = model(data)
    loss = criterion(data, target)
    loss.backward()
    if index%accumulation == 0:
        #用累積的梯度更新權(quán)重
        optimizer.step()
        #清空梯度
        optimizer.zero_grad()

雖然這里的梯度是相當(dāng)于原來的accumulation倍,但是實際在前向傳播的過程中,對于BN幾乎沒有影響,因為前向的BN還是只是一個batch的均值和方差,這個時候可以用pytorch中BN的momentum參數(shù),默認(rèn)是0.1,BN參數(shù)如下,就是指數(shù)移動平均

x_new_running = (1 - momentum) * x_running + momentum * x_new_observed. momentum

以上為個人經(jīng)驗,希望能給大家一個參考,也希望大家多多支持腳本之家。

您可能感興趣的文章:
  • pytorch 如何自定義卷積核權(quán)值參數(shù)
  • pytorch交叉熵?fù)p失函數(shù)的weight參數(shù)的使用
  • Pytorch 統(tǒng)計模型參數(shù)量的操作 param.numel()
  • pytorch 一行代碼查看網(wǎng)絡(luò)參數(shù)總量的實現(xiàn)
  • pytorch查看網(wǎng)絡(luò)參數(shù)顯存占用量等操作
  • pytorch 優(yōu)化器(optim)不同參數(shù)組,不同學(xué)習(xí)率設(shè)置的操作
  • pytorch LayerNorm參數(shù)的用法及計算過程

標(biāo)簽:崇左 喀什 濟南 長沙 安康 海南 山西 山西

巨人網(wǎng)絡(luò)通訊聲明:本文標(biāo)題《pytorch固定BN層參數(shù)的操作》,本文關(guān)鍵詞  ;如發(fā)現(xiàn)本文內(nèi)容存在版權(quán)問題,煩請?zhí)峁┫嚓P(guān)信息告之我們,我們將及時溝通與處理。本站內(nèi)容系統(tǒng)采集于網(wǎng)絡(luò),涉及言論、版權(quán)與本站無關(guān)。
  • 相關(guān)文章
  • 收縮
    • 微信客服
    • 微信二維碼
    • 電話咨詢

    • 400-1100-266
    天峻县| 阿瓦提县| 普洱| 客服| 张掖市| 扶沟县| 宜昌市| 长宁区| 沧源| 永吉县| 瑞丽市| 自治县| 泰兴市| 株洲县| 合作市| 巴彦淖尔市| 霍城县| 盖州市| 樟树市| 册亨县| 乌拉特中旗| 阳西县| 九龙城区| 元江| 专栏| 凌云县| 南澳县| 方城县| 阿图什市| 宜川县| 清丰县| 阿克苏市| 民权县| 塔城市| 奉新县| 丽江市| 台安县| 昭平县| 天柱县| 宣威市| 孟村|