佳木斯湛栽影视文化发展公司

主頁(yè) > 知識(shí)庫(kù) > 聊聊Pytorch torch.cat與torch.stack的區(qū)別

聊聊Pytorch torch.cat與torch.stack的區(qū)別

熱門(mén)標(biāo)簽:客戶(hù)服務(wù) 呼叫中心市場(chǎng)需求 硅谷的囚徒呼叫中心 Win7旗艦版 百度AI接口 語(yǔ)音系統(tǒng) 電話(huà)運(yùn)營(yíng)中心 企業(yè)做大做強(qiáng)

torch.cat()函數(shù)可以將多個(gè)張量拼接成一個(gè)張量。torch.cat()有兩個(gè)參數(shù),第一個(gè)是要拼接的張量的列表或是元組;第二個(gè)參數(shù)是拼接的維度。

torch.cat()的示例如下圖1所示

圖1 torch.cat()

torch.stack()函數(shù)同樣有張量列表和維度兩個(gè)參數(shù)。stack與cat的區(qū)別在于,torch.stack()函數(shù)要求輸入張量的大小完全相同,得到的張量的維度會(huì)比輸入的張量的大小多1,并且多出的那個(gè)維度就是拼接的維度,那個(gè)維度的大小就是輸入張量的個(gè)數(shù)。

torch.stack()的示例如下圖2所示:

圖2 torch.stack()

補(bǔ)充:torch.stack()的官方解釋?zhuān)斀庖约袄?/strong>

可以直接看最下面的【3.例子】,再回頭看前面的解釋

在pytorch中,常見(jiàn)的拼接函數(shù)主要是兩個(gè),分別是:

1、stack()

2、cat()

實(shí)際使用中,這兩個(gè)函數(shù)互相輔助:關(guān)于cat()參考torch.cat(),但是本文主要說(shuō)stack()。

函數(shù)的意義:使用stack可以保留兩個(gè)信息:[1. 序列] 和 [2. 張量矩陣] 信息,屬于【擴(kuò)張?jiān)倨唇印康暮瘮?shù)。

形象的理解:假如數(shù)據(jù)都是二維矩陣(平面),它可以把這些一個(gè)個(gè)平面(矩陣)按第三維(例如:時(shí)間序列)壓成一個(gè)三維的立方體,而立方體的長(zhǎng)度就是時(shí)間序列長(zhǎng)度。

該函數(shù)常出現(xiàn)在自然語(yǔ)言處理(NLP)和圖像卷積神經(jīng)網(wǎng)絡(luò)(CV)中。

1. stack()

官方解釋?zhuān)貉刂粋€(gè)新維度對(duì)輸入張量序列進(jìn)行連接。 序列中所有的張量都應(yīng)該為相同形狀。

淺顯說(shuō)法:把多個(gè)2維的張量湊成一個(gè)3維的張量;多個(gè)3維的湊成一個(gè)4維的張量…以此類(lèi)推,也就是在增加新的維度進(jìn)行堆疊。

outputs = torch.stack(inputs, dim=?) → Tensor

參數(shù)

inputs : 待連接的張量序列。

注:python的序列數(shù)據(jù)只有l(wèi)ist和tuple。

dim : 新的維度, 必須在0到len(outputs)之間。

注:len(outputs)是生成數(shù)據(jù)的維度大小,也就是outputs的維度值。

2. 重點(diǎn)

函數(shù)中的輸入inputs只允許是序列;且序列內(nèi)部的張量元素,必須shape相等

----舉例:[tensor_1, tensor_2,..]或者(tensor_1, tensor_2,..),且必須tensor_1.shape == tensor_2.shape

dim是選擇生成的維度,必須滿(mǎn)足0=dimlen(outputs);len(outputs)是輸出后的tensor的維度大小

不懂的看例子,再回過(guò)頭看就懂了。

3. 例子

1.準(zhǔn)備2個(gè)tensor數(shù)據(jù),每個(gè)的shape都是[3,3]

# 假設(shè)是時(shí)間步T1的輸出
T1 = torch.tensor([[1, 2, 3],
          [4, 5, 6],
          [7, 8, 9]])
# 假設(shè)是時(shí)間步T2的輸出
T2 = torch.tensor([[10, 20, 30],
          [40, 50, 60],
          [70, 80, 90]])

2.測(cè)試stack函數(shù)

print(torch.stack((T1,T2),dim=0).shape)
print(torch.stack((T1,T2),dim=1).shape)
print(torch.stack((T1,T2),dim=2).shape)
print(torch.stack((T1,T2),dim=3).shape)
# outputs:
torch.Size([2, 3, 3])
torch.Size([3, 2, 3])
torch.Size([3, 3, 2])
'選擇的dim>len(outputs),所以報(bào)錯(cuò)'
IndexError: Dimension out of range (expected to be in range of [-3, 2], but got 3)

可以復(fù)制代碼運(yùn)行試試:拼接后的tensor形狀,會(huì)根據(jù)不同的dim發(fā)生變化。

dim shape
0 [2, 3, 3]
1 [3, 2, 3]
2 [3, 3, 2]
3 溢出報(bào)錯(cuò)

4. 總結(jié)

1、函數(shù)作用:

函數(shù)stack()對(duì)序列數(shù)據(jù)內(nèi)部的張量進(jìn)行擴(kuò)維拼接,指定維度由程序員選擇、大小是生成后數(shù)據(jù)的維度區(qū)間。

2、存在意義:

在自然語(yǔ)言處理和卷及神經(jīng)網(wǎng)絡(luò)中, 通常為了保留–[序列(先后)信息] 和 [張量的矩陣信息] 才會(huì)使用stack。

函數(shù)存在意義?》》》

手寫(xiě)過(guò)RNN的同學(xué),知道在循環(huán)神經(jīng)網(wǎng)絡(luò)中輸出數(shù)據(jù)是:一個(gè)list,該列表插入了seq_len個(gè)形狀是[batch_size, output_size]的tensor,不利于計(jì)算,需要使用stack進(jìn)行拼接,保留–[1.seq_len這個(gè)時(shí)間步]和–[2.張量屬性[batch_size, output_size]]。

以上為個(gè)人經(jīng)驗(yàn),希望能給大家一個(gè)參考,也希望大家多多支持腳本之家。

您可能感興趣的文章:
  • 淺談pytorch中stack和cat的及to_tensor的坑
  • 對(duì)PyTorch torch.stack的實(shí)例講解
  • PyTorch的torch.cat用法
  • PyTorch中torch.tensor與torch.Tensor的區(qū)別詳解

標(biāo)簽:長(zhǎng)沙 山西 濟(jì)南 山西 崇左 安康 喀什 海南

巨人網(wǎng)絡(luò)通訊聲明:本文標(biāo)題《聊聊Pytorch torch.cat與torch.stack的區(qū)別》,本文關(guān)鍵詞  ;如發(fā)現(xiàn)本文內(nèi)容存在版權(quán)問(wèn)題,煩請(qǐng)?zhí)峁┫嚓P(guān)信息告之我們,我們將及時(shí)溝通與處理。本站內(nèi)容系統(tǒng)采集于網(wǎng)絡(luò),涉及言論、版權(quán)與本站無(wú)關(guān)。
  • 相關(guān)文章
    • 微信客服
    • 微信二維碼
    • 電話(huà)咨詢(xún)

    • 400-1100-266
    韶山市| 涟源市| 安吉县| 探索| 张家界市| 庄河市| 大田县| 梧州市| 鲁山县| 永吉县| 石河子市| 团风县| 黑龙江省| 青神县| 西乡县| 宾阳县| 望谟县| 石屏县| 延长县| 徐州市| 苏尼特右旗| 六安市| 延川县| 和顺县| 徐闻县| 平罗县| 乌拉特后旗| 德化县| 商河县| 石门县| 钦州市| 昌都县| 工布江达县| 龙胜| 九台市| 宜良县| 开化县| 湛江市| 普宁市| 泗阳县| 若羌县|