佳木斯湛栽影视文化发展公司

主頁 > 知識庫 > pytorch 實現(xiàn)多個Dataloader同時訓練

pytorch 實現(xiàn)多個Dataloader同時訓練

熱門標簽:智能手機 檢查注冊表項 鐵路電話系統(tǒng) 銀行業(yè)務(wù) 呼叫中心市場需求 網(wǎng)站文章發(fā)布 美圖手機 服務(wù)器配置

看代碼吧~

如果兩個dataloader的長度不一樣,那就加個:

from itertools import cycle

僅使用zip,迭代器將在長度等于最小數(shù)據(jù)集的長度時耗盡。 但是,使用cycle時,我們將再次重復最小的數(shù)據(jù)集,除非迭代器查看最大數(shù)據(jù)集中的所有樣本。

補充:pytorch技巧:自定義數(shù)據(jù)集 torch.utils.data.DataLoader 及Dataset的使用

本博客中有可直接運行的例子,便于直觀的理解,在torch環(huán)境中運行即可。

1. 數(shù)據(jù)傳遞機制

在 pytorch 中數(shù)據(jù)傳遞按一下順序:

1、創(chuàng)建 datasets ,也就是所需要讀取的數(shù)據(jù)集。

2、把 datasets 傳入DataLoader。

3、DataLoader迭代產(chǎn)生訓練數(shù)據(jù)提供給模型。

2. torch.utils.data.Dataset

Pytorch提供兩種數(shù)據(jù)集:

Map式數(shù)據(jù)集 Iterable式數(shù)據(jù)集。其中Map式數(shù)據(jù)集繼承torch.utils.data.Dataset,Iterable式數(shù)據(jù)集繼承torch.utils.data.IterableDataset。

本文只介紹 Map式數(shù)據(jù)集。

一個Map式的數(shù)據(jù)集必須要重寫 __getitem__(self, index)、 __len__(self) 兩個方法,用來表示從索引到樣本的映射(Map)。 __getitem__(self, index)按索引映射到對應(yīng)的數(shù)據(jù), __len__(self)則會返回這個數(shù)據(jù)集的長度。

基本格式如下:

 import torch.utils.data as data
class VOCDetection(data.Dataset):
    '''
    必須繼承data.Dataset類
    '''
    def __init__(self):
        '''
        在這里進行初始化,一般是初始化文件路徑或文件列表
        '''
        pass
    def __getitem__(self, index):
        '''
        1. 按照index,讀取文件中對應(yīng)的數(shù)據(jù)  (讀取一個數(shù)據(jù)!?。。∥覀兂Wx取的數(shù)據(jù)是圖片,一般我們送入模型的數(shù)據(jù)成批的,但在這里只是讀取一張圖片,成批后面會說到)
        2. 對讀取到的數(shù)據(jù)進行數(shù)據(jù)增強 (數(shù)據(jù)增強是深度學習中經(jīng)常用到的,可以提高模型的泛化能力)
        3. 返回數(shù)據(jù)對 (一般我們要返回 圖片,對應(yīng)的標簽) 在這里因為我沒有寫完整的代碼,返回值用 0 代替
        '''
        return 0
    def __len__(self):
        '''
        返回數(shù)據(jù)集的長度
        '''
        return 0

可直接運行的例子:

import torch.utils.data as data
import numpy as np
x = np.array(range(80)).reshape(8, 10) # 模擬輸入, 8個樣本,每個樣本長度為10
y = np.array(range(8))  # 模擬對應(yīng)樣本的標簽, 8個標簽 
class Mydataset(data.Dataset):
    def __init__(self, x, y):
        self.x = x
        self.y = y
        self.idx = list()
        for item in x:
            self.idx.append(item)
        pass
    def __getitem__(self, index):
        input_data = self.idx[index] #可繼續(xù)進行數(shù)據(jù)增強,這里沒有進行數(shù)據(jù)增強操作
        target = self.y[index]
        return input_data, target
    def __len__(self):
        return len(self.idx)
datasets = Mydataset(x, y)  # 初始化
print(datasets.__len__())  # 調(diào)用__len__() 返回數(shù)據(jù)的長度
for i in range(len(y)):
    input_data, target = datasets.__getitem__(i)  # 調(diào)用__getitem__(index) 返回讀取的數(shù)據(jù)對
    print('input_data%d =' % i, input_data)
    print('target%d = ' % i, target)

結(jié)果如下:

3. torch.utils.data.DataLoader

PyTorch中數(shù)據(jù)讀取的一個重要接口是 torch.utils.data.DataLoader。

該接口主要用來將自定義的數(shù)據(jù)讀取接口的輸出或者PyTorch已有的數(shù)據(jù)讀取接口的輸入按照batch_size封裝成Tensor,后續(xù)只需要再包裝成Variable即可作為模型的輸入。

torch.utils.data.DataLoader(onject)的可用參數(shù)如下:

1.dataset(Dataset): 數(shù)據(jù)讀取接口,該輸出是torch.utils.data.Dataset類的對象(或者繼承自該類的自定義類的對象)。

2.batch_size (int, optional): 批訓練數(shù)據(jù)量的大小,根據(jù)具體情況設(shè)置即可。一般為2的N次方(默認:1)

3.shuffle (bool, optional):是否打亂數(shù)據(jù),一般在訓練數(shù)據(jù)中會采用。(默認:False)

4.sampler (Sampler, optional):從數(shù)據(jù)集中提取樣本的策略。如果指定,“shuffle”必須為false。我沒有用過,不太了解。

5.batch_sampler (Sampler, optional):和batch_size、shuffle等參數(shù)互斥,一般用默認。

6.num_workers:這個參數(shù)必須大于等于0,為0時默認使用主線程讀取數(shù)據(jù),其他大于0的數(shù)表示通過多個進程來讀取數(shù)據(jù),可以加快數(shù)據(jù)讀取速度,一般設(shè)置為2的N次方,且小于batch_size(默認:0)

7.collate_fn (callable, optional): 合并樣本清單以形成小批量。用來處理不同情況下的輸入dataset的封裝。

8.pin_memory (bool, optional):如果設(shè)置為True,那么data loader將會在返回它們之前,將tensors拷貝到CUDA中的固定內(nèi)存中.

9.drop_last (bool, optional): 如果數(shù)據(jù)集大小不能被批大小整除,則設(shè)置為“true”以除去最后一個未完成的批。如果“false”那么最后一批將更小。(默認:false)

10.timeout(numeric, optional):設(shè)置數(shù)據(jù)讀取時間限制,超過這個時間還沒讀取到數(shù)據(jù)的話就會報錯。(默認:0)

11.worker_init_fn (callable, optional): 每個worker初始化函數(shù)(默認:None)

可直接運行的例子:

import torch.utils.data as data
import numpy as np
x = np.array(range(80)).reshape(8, 10) # 模擬輸入, 8個樣本,每個樣本長度為10
y = np.array(range(8))  # 模擬對應(yīng)樣本的標簽, 8個標簽
class Mydataset(data.Dataset):
    def __init__(self, x, y):
        self.x = x
        self.y = y
        self.idx = list()
        for item in x:
            self.idx.append(item)
        pass
    def __getitem__(self, index):
        input_data = self.idx[index]
        target = self.y[index]
        return input_data, target
    def __len__(self):
        return len(self.idx)
if __name__ ==('__main__'):
    datasets = Mydataset(x, y)  # 初始化
    dataloader = data.DataLoader(datasets, batch_size=4, num_workers=2) 
    for i, (input_data, target) in enumerate(dataloader):
        print('input_data%d' % i, input_data)
        print('target%d' % i, target)

結(jié)果如下:(注意看類別,DataLoader把數(shù)據(jù)封裝為Tensor)

以上為個人經(jīng)驗,希望能給大家一個參考,也希望大家多多支持腳本之家。

您可能感興趣的文章:
  • pytorch鎖死在dataloader(訓練時卡死)
  • pytorch Dataset,DataLoader產(chǎn)生自定義的訓練數(shù)據(jù)案例
  • 解決Pytorch dataloader時報錯每個tensor維度不一樣的問題
  • pytorch中DataLoader()過程中遇到的一些問題
  • Pytorch dataloader在加載最后一個batch時卡死的解決
  • Pytorch 如何加速Dataloader提升數(shù)據(jù)讀取速度
  • pytorch DataLoader的num_workers參數(shù)與設(shè)置大小詳解

標簽:新疆 長治 沈陽 滄州 樂山 紅河 上海 河南

巨人網(wǎng)絡(luò)通訊聲明:本文標題《pytorch 實現(xiàn)多個Dataloader同時訓練》,本文關(guān)鍵詞  ;如發(fā)現(xiàn)本文內(nèi)容存在版權(quán)問題,煩請?zhí)峁┫嚓P(guān)信息告之我們,我們將及時溝通與處理。本站內(nèi)容系統(tǒng)采集于網(wǎng)絡(luò),涉及言論、版權(quán)與本站無關(guān)。
  • 相關(guān)文章
  • 收縮
    • 微信客服
    • 微信二維碼
    • 電話咨詢

    • 400-1100-266
    舟曲县| 余姚市| 手游| 汝南县| 南皮县| 佛山市| 龙陵县| 泌阳县| 苏尼特右旗| 昔阳县| 江陵县| 南平市| 京山县| 丽水市| 玉田县| 塔河县| 无极县| 林州市| 新和县| 青岛市| 呼伦贝尔市| 汉川市| 加查县| 固始县| 德格县| 弥勒县| 河东区| 新乐市| 怀仁县| 临澧县| 五大连池市| 尚义县| 五峰| 海丰县| 车险| 清新县| 南汇区| 景东| 遂宁市| 无为县| 牟定县|