佳木斯湛栽影视文化发展公司

主頁 > 知識庫 > pytorch中的model=model.to(device)使用說明

pytorch中的model=model.to(device)使用說明

熱門標(biāo)簽:語音系統(tǒng) 電話運(yùn)營中心 百度AI接口 硅谷的囚徒呼叫中心 呼叫中心市場需求 企業(yè)做大做強(qiáng) 客戶服務(wù) Win7旗艦版

這代表將模型加載到指定設(shè)備上。

其中,device=torch.device("cpu")代表的使用cpu,而device=torch.device("cuda")則代表的使用GPU。

當(dāng)我們指定了設(shè)備之后,就需要將模型加載到相應(yīng)設(shè)備中,此時需要使用model=model.to(device),將模型加載到相應(yīng)的設(shè)備中。

將由GPU保存的模型加載到CPU上。

將torch.load()函數(shù)中的map_location參數(shù)設(shè)置為torch.device('cpu')

device = torch.device('cpu')
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH, map_location=device))

將由GPU保存的模型加載到GPU上。確保對輸入的tensors調(diào)用input = input.to(device)方法。

device = torch.device("cuda")
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.to(device)

將由CPU保存的模型加載到GPU上。

確保對輸入的tensors調(diào)用input = input.to(device)方法。map_location是將模型加載到GPU上,model.to(torch.device('cuda'))是將模型參數(shù)加載為CUDA的tensor。

最后保證使用.to(torch.device('cuda'))方法將需要使用的參數(shù)放入CUDA。

device = torch.device("cuda")
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH, map_location="cuda:0"))  # Choose whatever GPU device number you want
model.to(device)

補(bǔ)充:pytorch中model.to(device)和map_location=device的區(qū)別

一、簡介

在已訓(xùn)練并保存在CPU上的GPU上加載模型時,加載模型時經(jīng)常由于訓(xùn)練和保存模型時設(shè)備不同出現(xiàn)讀取模型時出現(xiàn)錯誤,在對跨設(shè)備的模型讀取時候涉及到兩個參數(shù)的使用,分別是model.to(device)和map_location=devicel兩個參數(shù),簡介一下兩者的不同。

將map_location函數(shù)中的參數(shù)設(shè)置 torch.load()為 cuda:device_id。這會將模型加載到給定的GPU設(shè)備。

調(diào)用model.to(torch.device('cuda'))將模型的參數(shù)張量轉(zhuǎn)換為CUDA張量,無論在cpu上訓(xùn)練還是gpu上訓(xùn)練,保存的模型參數(shù)都是參數(shù)張量不是cuda張量,因此,cpu設(shè)備上不需要使用torch.to(torch.device("cpu"))。

二、實(shí)例

了解了兩者代表的意義,以下介紹兩者的使用。

1、保存在GPU上,在CPU上加載

保存:

torch.save(model.state_dict(), PATH)

加載:

device = torch.device('cpu')
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH, map_location=device))

解釋:

在使用GPU訓(xùn)練的CPU上加載模型時,請傳遞 torch.device('cpu')給map_location函數(shù)中的 torch.load()參數(shù),使用map_location參數(shù)將張量下面的存儲器動態(tài)地重新映射到CPU設(shè)備 。

2、保存在GPU上,在GPU上加載

保存:

torch.save(model.state_dict(), PATH)

加載:

device = torch.device("cuda")
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.to(device)
# Make sure to call input = input.to(device) on any input tensors that you feed to the model

解釋:

在GPU上訓(xùn)練并保存在GPU上的模型時,只需將初始化model模型轉(zhuǎn)換為CUDA優(yōu)化模型即可model.to(torch.device('cuda'))。

此外,請務(wù)必.to(torch.device('cuda'))在所有模型輸入上使用該 功能來準(zhǔn)備模型的數(shù)據(jù)。

請注意,調(diào)用my_tensor.to(device) 返回my_tensorGPU上的新副本。

它不會覆蓋 my_tensor。

因此,請記住手動覆蓋張量: my_tensor = my_tensor.to(torch.device('cuda'))

3、保存在CPU,在GPU上加載

保存:

torch.save(model.state_dict(), PATH)

加載:

device = torch.device("cuda")
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH, map_location="cuda:0"))  # Choose whatever GPU device number you want
model.to(device)
# Make sure to call input = input.to(device) on any input tensors that you feed to the model

解釋:

在已訓(xùn)練并保存在CPU上的GPU上加載模型時,請將map_location函數(shù)中的參數(shù)設(shè)置 torch.load()為 cuda:device_id。

這會將模型加載到給定的GPU設(shè)備。

接下來,請務(wù)必調(diào)用model.to(torch.device('cuda'))將模型的參數(shù)張量轉(zhuǎn)換為CUDA張量。

最后,確保.to(torch.device('cuda'))在所有模型輸入上使用該 函數(shù)來為CUDA優(yōu)化模型準(zhǔn)備數(shù)據(jù)。

請注意,調(diào)用 my_tensor.to(device)返回my_tensorGPU上的新副本。

它不會覆蓋my_tensor。

因此,請記住手動覆蓋張量:my_tensor = my_tensor.to(torch.device('cuda'))

以上為個人經(jīng)驗(yàn),希望能給大家一個參考,也希望大家多多支持腳本之家。

您可能感興趣的文章:
  • 聊聊pytorch測試的時候?yàn)楹我由蟤odel.eval()
  • pytorch中的model.eval()和BN層的使用
  • 解決Pytorch中的神坑:關(guān)于model.eval的問題
  • Pytorch BertModel的使用說明
  • PyTorch中model.zero_grad()和optimizer.zero_grad()用法
  • pytorch掉坑記錄:model.eval的作用說明
  • pytorch:model.train和model.eval用法及區(qū)別詳解
  • pytorch 修改預(yù)訓(xùn)練model實(shí)例
  • pytorch查看torch.Tensor和model是否在CUDA上的實(shí)例

標(biāo)簽:長沙 海南 山西 崇左 山西 喀什 濟(jì)南 安康

巨人網(wǎng)絡(luò)通訊聲明:本文標(biāo)題《pytorch中的model=model.to(device)使用說明》,本文關(guān)鍵詞  ;如發(fā)現(xiàn)本文內(nèi)容存在版權(quán)問題,煩請?zhí)峁┫嚓P(guān)信息告之我們,我們將及時溝通與處理。本站內(nèi)容系統(tǒng)采集于網(wǎng)絡(luò),涉及言論、版權(quán)與本站無關(guān)。
  • 相關(guān)文章
  • 收縮
    • 微信客服
    • 微信二維碼
    • 電話咨詢

    • 400-1100-266
    西平县| 玉溪市| 莱芜市| 东安县| 连州市| 金堂县| 长泰县| 准格尔旗| 视频| 龙川县| 阿城市| 隆安县| 凤山市| 濉溪县| 斗六市| 新密市| 渝北区| 阿克陶县| 巴彦淖尔市| 凤山县| 雷州市| 阳新县| 黄石市| 株洲市| 巩留县| 古浪县| 长寿区| 临沭县| 延吉市| 锡林郭勒盟| 东宁县| 太仆寺旗| 轮台县| 定南县| 肃南| 安达市| 田东县| 黔西| 杨浦区| 十堰市| 龙川县|