佳木斯湛栽影视文化发展公司

主頁 > 知識(shí)庫 > pytorch 6 batch_train 批訓(xùn)練操作

pytorch 6 batch_train 批訓(xùn)練操作

熱門標(biāo)簽:呼叫中心市場需求 銀行業(yè)務(wù) 鐵路電話系統(tǒng) 智能手機(jī) 服務(wù)器配置 檢查注冊表項(xiàng) 網(wǎng)站文章發(fā)布 美圖手機(jī)

看代碼吧~

import torch
import torch.utils.data as Data
torch.manual_seed(1)    # reproducible
# BATCH_SIZE = 5  
BATCH_SIZE = 8      # 每次使用8個(gè)數(shù)據(jù)同時(shí)傳入網(wǎng)路
x = torch.linspace(1, 10, 10)       # this is x data (torch tensor)
y = torch.linspace(10, 1, 10)       # this is y data (torch tensor)
torch_dataset = Data.TensorDataset(x, y)
loader = Data.DataLoader(
    dataset=torch_dataset,      # torch TensorDataset format
    batch_size=BATCH_SIZE,      # mini batch size
    shuffle=False,              # 設(shè)置不隨機(jī)打亂數(shù)據(jù) random shuffle for training
    num_workers=2,              # 使用兩個(gè)進(jìn)程提取數(shù)據(jù),subprocesses for loading data
)
def show_batch():
    for epoch in range(3):   # 全部的數(shù)據(jù)使用3遍,train entire dataset 3 times
        for step, (batch_x, batch_y) in enumerate(loader):  # for each training step
            # train your data...
            print('Epoch: ', epoch, '| Step: ', step, '| batch x: ',
                  batch_x.numpy(), '| batch y: ', batch_y.numpy())
if __name__ == '__main__':
    show_batch()

BATCH_SIZE = 8 , 所有數(shù)據(jù)利用三次

Epoch:  0 | Step:  0 | batch x:  [1. 2. 3. 4. 5. 6. 7. 8.] | batch y:  [10.  9.  8.  7.  6.  5.  4.  3.]
Epoch:  0 | Step:  1 | batch x:  [ 9. 10.] | batch y:  [2. 1.]
Epoch:  1 | Step:  0 | batch x:  [1. 2. 3. 4. 5. 6. 7. 8.] | batch y:  [10.  9.  8.  7.  6.  5.  4.  3.]
Epoch:  1 | Step:  1 | batch x:  [ 9. 10.] | batch y:  [2. 1.]
Epoch:  2 | Step:  0 | batch x:  [1. 2. 3. 4. 5. 6. 7. 8.] | batch y:  [10.  9.  8.  7.  6.  5.  4.  3.]
Epoch:  2 | Step:  1 | batch x:  [ 9. 10.] | batch y:  [2. 1.]

補(bǔ)充:pytorch批訓(xùn)練bug

問題描述:

在進(jìn)行pytorch神經(jīng)網(wǎng)絡(luò)批訓(xùn)練的時(shí)候,有時(shí)會(huì)出現(xiàn)報(bào)錯(cuò) 

TypeError: batch must contain tensors, numbers, dicts or lists; found class 'torch.autograd.variable.Variable'>

解決辦法:

第一步:

檢查(重點(diǎn)!!?。?!):

train_dataset = Data.TensorDataset(train_x, train_y)

train_x,和train_y格式,要求是tensor類,我第一次出錯(cuò)就是因?yàn)閭魅氲氖莢ariable

可以這樣將數(shù)據(jù)變?yōu)閠ensor類:

train_x = torch.FloatTensor(train_x)

第二步:

train_loader = Data.DataLoader(
        dataset=train_dataset,
        batch_size=batch_size,
        shuffle=True
    )

實(shí)例化一個(gè)DataLoader對象

第三步:

    for epoch in range(epochs):
        for step, (batch_x, batch_y) in enumerate(train_loader):
            batch_x, batch_y = Variable(batch_x), Variable(batch_y)

這樣就可以批訓(xùn)練了

需要注意的是:train_loader輸出的是tensor,在訓(xùn)練網(wǎng)絡(luò)時(shí),需要變成Variable

以上為個(gè)人經(jīng)驗(yàn),希望能給大家一個(gè)參考,也希望大家多多支持腳本之家。

您可能感興趣的文章:
  • 詳解PyTorch批訓(xùn)練及優(yōu)化器比較
  • pytorch 固定部分參數(shù)訓(xùn)練的方法
  • pytorch 準(zhǔn)備、訓(xùn)練和測試自己的圖片數(shù)據(jù)的方法
  • pytorch 在網(wǎng)絡(luò)中添加可訓(xùn)練參數(shù),修改預(yù)訓(xùn)練權(quán)重文件的方法

標(biāo)簽:新疆 長治 樂山 沈陽 滄州 河南 上海 紅河

巨人網(wǎng)絡(luò)通訊聲明:本文標(biāo)題《pytorch 6 batch_train 批訓(xùn)練操作》,本文關(guān)鍵詞  ;如發(fā)現(xiàn)本文內(nèi)容存在版權(quán)問題,煩請?zhí)峁┫嚓P(guān)信息告之我們,我們將及時(shí)溝通與處理。本站內(nèi)容系統(tǒng)采集于網(wǎng)絡(luò),涉及言論、版權(quán)與本站無關(guān)。
  • 相關(guān)文章
  • 收縮
    • 微信客服
    • 微信二維碼
    • 電話咨詢

    • 400-1100-266
    东台市| 德格县| 上思县| 永修县| 郓城县| 蓝田县| 垫江县| 田阳县| 藁城市| 盐源县| 西吉县| 汾阳市| 延长县| 子长县| 永康市| 永安市| 城步| 江川县| 云和县| 金湖县| 边坝县| 浦江县| 泸州市| 锡林浩特市| 济南市| 鄂温| 承德市| 彭水| 周至县| 天水市| 蓬溪县| 石嘴山市| 兴海县| 垫江县| 平江县| 兰西县| 吉隆县| 余姚市| 安义县| 铜鼓县| 揭西县|