Error:output with shape [1, 224, 224] doesn‘t match the broadcast shape [3, 224, 224]
生活随笔
收集整理的這篇文章主要介紹了
Error:output with shape [1, 224, 224] doesn‘t match the broadcast shape [3, 224, 224]
小編覺得挺不錯的,現在分享給大家,幫大家做個參考.
Error:output with shape [1, 224, 224] doesn't match the broadcast shape [3, 224, 224]
原模型輸入的圖片為RGB三通道,輸入的為單通道灰度圖片。
解決如下:
from torch import nn from torchvision import datasets from torchvision import transforms as T from torch.utils.data import DataLoader from torchvision.utils import make_grid, save_image import numpy as np import matplotlib.pyplot as plttransform = T.Compose([T.ToTensor(), #這會將介于0到255之間的numpy數組轉換為介于0到1之間的浮點張量T.Normalize((0.5, ), (0.5, )), #在normalize()方法中, 我們指定了用來標準化張量圖像所有通道的均值, 并且還指定了中心偏差。 ]) dataset = datasets.MNIST('data/', download=True, train=False, transform=transform) dataloader = DataLoader(dataset, shuffle=True, batch_size=100)print(type(dataset[0][0]),dataset[0][0].size()) # print(dataset[0][0]) # 要繪制張量圖像, 我們必須將其更改回numpy array。 # 我們將在函數def im_convert()中完成此工作, 該函數包含一個將成為張量圖像的參數。 def im_convert(tensor):image=tensor.clone().detach().numpy()# 使用torch.clone()獲得的新tensor和原來的數據不再共享內存,但仍保留在計算圖中,# clone操作在不共享數據內存的同時支持梯度梯度傳遞與疊加,所以常用在神經網絡中某個單元需要重復使用的場景下。# 通常如果原tensor的requires_grad=True,則:# clone()操作后的tensor requires_grad=True# detach()操作后的tensor requires_grad=False。image=image.transpose(1, 2, 0)# 將轉換為numpy數組的張量具有第一, 第二和第三維的形狀。第一維表示顏色通道, 第二維和第三維表示圖像和像素的寬度和高度。# 我們知道MNIST數據集中的每個圖像都是對應于單個彩色通道的灰度, 其寬度和高度為28 * 28像素。因此, 形狀將為(1、28、28)。# 為了繪制圖像, 要求圖像的形狀為(28, 28, 1)。因此, 通過將軸零, 一和二交換print(image.shape)image=image*(np.array((0.5, 0.5, 0.5))+np.array((0.5, 0.5, 0.5)))print(image.shape)# 我們對圖像進行歸一化, 而之前我們必須對其進行歸一化。通過減去平均值并除以標準偏差來完成歸一化。# 我們將乘以標準偏差, 然后將平均值相加image=image.clip(0, 1)print(image.shape,type(image))return image# 為了確保介于0和1之間的范圍, 我們使用了clip()# 函數并傳遞了零和一作為參數。我們將clip函數應用到最小值0和最大值1并返回圖像。# 它將創建一個對象, 該對象使我們可以一次通過一個可變的訓練加載器。 # 我們通過在dataiter上調用next來一次訪問一個元素。 # next()函數將獲取我們的第一批訓練數據, 并且該訓練數據將被分為以下圖像和標簽 dataiter=iter(dataloader) images, labels=dataiter.next()fig=plt.figure(figsize=(25, 6)) #fig=plt.figure(figsize=(25, 4)) #圖片輸出寬度較上面小 for idx in np.arange(20):ax=fig.add_subplot(2, 10, idx+1)plt.imshow(im_convert(images[idx]))ax.set_title([labels[idx].item()]) plt.show()最終結果如下:
總結
以上是生活随笔為你收集整理的Error:output with shape [1, 224, 224] doesn‘t match the broadcast shape [3, 224, 224]的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: Netgear Readyshare:U
- 下一篇: web中网络编程详解