佳木斯湛栽影视文化发展公司

主頁(yè) > 知識(shí)庫(kù) > Pytorch反向傳播中的細(xì)節(jié)-計(jì)算梯度時(shí)的默認(rèn)累加操作

Pytorch反向傳播中的細(xì)節(jié)-計(jì)算梯度時(shí)的默認(rèn)累加操作

熱門(mén)標(biāo)簽:美圖手機(jī) 檢查注冊(cè)表項(xiàng) 智能手機(jī) 服務(wù)器配置 鐵路電話系統(tǒng) 呼叫中心市場(chǎng)需求 銀行業(yè)務(wù) 網(wǎng)站文章發(fā)布

Pytorch反向傳播計(jì)算梯度默認(rèn)累加

今天學(xué)習(xí)pytorch實(shí)現(xiàn)簡(jiǎn)單的線性回歸,發(fā)現(xiàn)了pytorch的反向傳播時(shí)計(jì)算梯度采用的累加機(jī)制, 于是百度來(lái)一下,好多博客都說(shuō)了累加機(jī)制,但是好多都沒(méi)有說(shuō)明這個(gè)累加機(jī)制到底會(huì)有啥影響, 所以我趁著自己練習(xí)的一個(gè)例子正好直觀的看一下以及如何解決:

pytorch實(shí)現(xiàn)線性回歸

先附上試驗(yàn)代碼來(lái)感受一下:

torch.manual_seed(6)
lr = 0.01   # 學(xué)習(xí)率
result = []

# 創(chuàng)建訓(xùn)練數(shù)據(jù)
x = torch.rand(20, 1) * 10
y = 2 * x + (5 + torch.randn(20, 1)) 

# 構(gòu)建線性回歸函數(shù)
w = torch.randn((1), requires_grad=True)
b = torch.zeros((1), requires_grad=True)
# 這里是迭代過(guò)程,為了看pytorch的反向傳播計(jì)算梯度的細(xì)節(jié),我先迭代兩次
for iteration in range(2):

    # 前向傳播
    wx = torch.mul(w, x)
    y_pred = torch.add(wx, b)

    # 計(jì)算 MSE loss
    loss = (0.5 * (y - y_pred) ** 2).mean()
    
    # 反向傳播
    loss.backward()
    
    # 這里看一下反向傳播計(jì)算的梯度
    print("w.grad:", w.grad)
    print("b.grad:", b.grad)
    
    # 更新參數(shù)
    b.data.sub_(lr * b.grad)
    w.data.sub_(lr * w.grad)

上面的代碼比較簡(jiǎn)單,迭代了兩次, 看一下計(jì)算的梯度結(jié)果:

w.grad: tensor([-74.6261])
b.grad: tensor([-12.5532])
w.grad: tensor([-122.9075])
b.grad: tensor([-20.9364])

然后我稍微加兩行代碼, 就是在反向傳播上面,我手動(dòng)添加梯度清零操作的代碼,再感受一下結(jié)果:

torch.manual_seed(6)
lr = 0.01
result = []
# 創(chuàng)建訓(xùn)練數(shù)據(jù)
x = torch.rand(20, 1) * 10
#print(x)
y = 2 * x + (5 + torch.randn(20, 1)) 
#print(y)
# 構(gòu)建線性回歸函數(shù)
w = torch.randn((1), requires_grad=True)
#print(w)
b = torch.zeros((1), requires_grad=True)
#print(b)
for iteration in range(2):
    # 前向傳播
    wx = torch.mul(w, x)
    y_pred = torch.add(wx, b)

    # 計(jì)算 MSE loss
    loss = (0.5 * (y - y_pred) ** 2).mean()
    
    # 由于pytorch反向傳播中,梯度是累加的,所以如果不想先前的梯度影響當(dāng)前梯度的計(jì)算,需要手動(dòng)清0
     if iteration > 0: 
        w.grad.data.zero_()
        b.grad.data.zero_()
    
    # 反向傳播
    loss.backward()
    
    # 看一下梯度
    print("w.grad:", w.grad)
    print("b.grad:", b.grad)
    
    # 更新參數(shù)
    b.data.sub_(lr * b.grad)
    w.data.sub_(lr * w.grad)

w.grad: tensor([-74.6261])
b.grad: tensor([-12.5532])
w.grad: tensor([-48.2813])
b.grad: tensor([-8.3831])

從上面可以發(fā)現(xiàn),pytorch在反向傳播的時(shí)候,確實(shí)是默認(rèn)累加上了上一次求的梯度, 如果不想讓上一次的梯度影響自己本次梯度計(jì)算的話,需要手動(dòng)的清零。

但是, 如果不進(jìn)行手動(dòng)清零的話,會(huì)有什么后果呢? 我在這次線性回歸試驗(yàn)中,遇到的后果就是loss值反復(fù)的震蕩不收斂。下面感受一下:

torch.manual_seed(6)
lr = 0.01
result = []
# 創(chuàng)建訓(xùn)練數(shù)據(jù)
x = torch.rand(20, 1) * 10
#print(x)
y = 2 * x + (5 + torch.randn(20, 1)) 
#print(y)
# 構(gòu)建線性回歸函數(shù)
w = torch.randn((1), requires_grad=True)
#print(w)
b = torch.zeros((1), requires_grad=True)
#print(b)

for iteration in range(1000):
    # 前向傳播
    wx = torch.mul(w, x)
    y_pred = torch.add(wx, b)

    # 計(jì)算 MSE loss
    loss = (0.5 * (y - y_pred) ** 2).mean()
#     print("iteration {}: loss {}".format(iteration, loss))
    result.append(loss)
    
    # 由于pytorch反向傳播中,梯度是累加的,所以如果不想先前的梯度影響當(dāng)前梯度的計(jì)算,需要手動(dòng)清0
    #if iteration > 0: 
    #    w.grad.data.zero_()
    #    b.grad.data.zero_()
  
    # 反向傳播
    loss.backward()
 
    # 更新參數(shù)
    b.data.sub_(lr * b.grad)
    w.data.sub_(lr * w.grad)
    
    if loss.data.numpy()  1:
        break
   plt.plot(result)

上面的代碼中,我沒(méi)有進(jìn)行手動(dòng)清零,迭代1000次, 把每一次的loss放到來(lái)result中, 然后畫(huà)出圖像,感受一下結(jié)果:

接下來(lái),我把手動(dòng)清零的注釋打開(kāi),進(jìn)行每次迭代之后的手動(dòng)清零操作,得到的結(jié)果:

可以看到,這個(gè)才是理想中的反向傳播求導(dǎo),然后更新參數(shù)后得到的loss值的變化。

總結(jié)

這次主要是記錄一下,pytorch在進(jìn)行反向傳播計(jì)算梯度的時(shí)候的累加機(jī)制到底是什么樣子? 至于為什么采用這種機(jī)制,我也搜了一下,大部分給出的結(jié)果是這樣子的:

但是如果不想累加的話,可以采用手動(dòng)清零的方式,只需要在每次迭代時(shí)加上即可

w.grad.data.zero_()
b.grad.data.zero_()

另外, 在搜索資料的時(shí)候,在一篇博客上看到兩個(gè)不錯(cuò)的線性回歸時(shí)pytorch的計(jì)算圖在這里借用一下:


以上為個(gè)人經(jīng)驗(yàn),希望能給大家一個(gè)參考,也希望大家多多支持腳本之家。

您可能感興趣的文章:
  • pytorch 梯度NAN異常值的解決方案
  • pytorch 權(quán)重weight 與 梯度grad 可視化操作
  • PyTorch 如何檢查模型梯度是否可導(dǎo)
  • 淺談pytorch中為什么要用 zero_grad() 將梯度清零
  • PyTorch梯度裁剪避免訓(xùn)練loss nan的操作
  • PyTorch 如何自動(dòng)計(jì)算梯度
  • Pytorch獲取無(wú)梯度TorchTensor中的值

標(biāo)簽:上海 沈陽(yáng) 河南 樂(lè)山 滄州 新疆 長(zhǎng)治 紅河

巨人網(wǎng)絡(luò)通訊聲明:本文標(biāo)題《Pytorch反向傳播中的細(xì)節(jié)-計(jì)算梯度時(shí)的默認(rèn)累加操作》,本文關(guān)鍵詞  ;如發(fā)現(xiàn)本文內(nèi)容存在版權(quán)問(wèn)題,煩請(qǐng)?zhí)峁┫嚓P(guān)信息告之我們,我們將及時(shí)溝通與處理。本站內(nèi)容系統(tǒng)采集于網(wǎng)絡(luò),涉及言論、版權(quán)與本站無(wú)關(guān)。
  • 相關(guān)文章
  • 收縮
    • 微信客服
    • 微信二維碼
    • 電話咨詢

    • 400-1100-266
    原平市| 大关县| 苏尼特右旗| 温宿县| 紫云| 金川县| 临猗县| 东乌| 宁波市| 万全县| 嘉禾县| 定州市| 台湾省| 阿勒泰市| 锦屏县| 长治市| 乌审旗| 淮安市| 河源市| 三亚市| 永靖县| 搜索| 合水县| 佛冈县| 中方县| 垫江县| 乐平市| 新乐市| 剑阁县| 岢岚县| 黄陵县| 新宁县| 葫芦岛市| 榆林市| 芜湖市| 马龙县| 宁海县| 靖宇县| 邹城市| 吴旗县| 昭通市|