This article looks at PyTorch's transforms.Normalize and what it actually does.
It is widely held that normalizing image data during training helps gradient descent converge better. So what exactly does PyTorch's transforms.Normalize do? That is worth knowing.
Consider the following formula, where x is drawn from a set of data C, mean is the mean of that data, and std is its standard deviation:
x = (x - mean) / std
In short, Normalize simply updates the input data according to this formula.
Consider this piece of code:
import numpy as np
List1 = np.array([1, 2, 3, 4])
mean = np.mean(List1)
std = np.std(List1)
List2 = (List1 - mean) / std
>>> List1
array([1, 2, 3, 4])
>>> List2
array([-1.34164079, -0.4472136 ,  0.4472136 ,  1.34164079])
List1 becomes List2 after being normalized.
So how, concretely, does Normalize work on image data?
Suppose we have four images. Using the data-loading approach from an earlier article, import the data:
import os
from PIL import Image
import numpy as np
from torchvision import transforms
import torch

path = "E:\\3-10\\dogandcats\\source"
IMG = []
filenames = [name for name in os.listdir(path)]
for i, filename in enumerate(filenames):
    img = Image.open(os.path.join(path, filename))
    img = img.resize((28, 28))                   # resize each image to 28x28
    img = np.array(img)                          # convert the image to a numpy array
    img = torch.tensor(img, dtype=torch.float)   # convert to a float tensor (torch.mean/std need floats)
    img = img.permute(2, 0, 1)                   # reorder H,W,C -> C,H,W
    IMG.append(img)                              # collect the images in list IMG
IMGEND = torch.stack(IMG, dim=0)                 # stack into a single (N,C,H,W) tensor
>>> IMGEND.size()
torch.Size([4, 3, 28, 28])
The four images have been imported and converted into a tensor.
Compute the mean of the r, g, b channels:
>>> mean=torch.mean(IMGEND,dim=(0,2,3),keepdim=True)
>>> mean
tensor([[[[160.8753]],

         [[149.3600]],

         [[126.5810]]]])
Compute the standard deviation of the r, g, b channels:
>>> std=torch.std(IMGEND,dim=(0,2,3),keepdim=True)
>>> std
tensor([[[[61.7317]],

         [[65.0915]],

         [[84.2025]]]])
Normalize:
>>> process=transforms.Normalize([160.8753, 149.3600, 126.5810],[61.7317, 65.0915, 84.2025])
>>> dataend1=process(IMGEND)
>>> dataend1
tensor([[[[-1.3587, -0.9213, -0.7269,  ..., -0.3382, -0.3868, -0.4516],
          [-1.4397, -0.8727, -0.6135,  ..., -0.1114, -0.1762, -0.2248],
          [-1.8771, -1.3587, -0.9375,  ...,  0.1640,  0.0830, -0.1438],
          ...,
          [-1.9257, -1.8447, -1.8447,  ..., -2.3307, -2.3307, -2.2821]],
        ...
         [[ 1.2638,  1.2638,  1.2757,  ...,  1.2638,  1.2638,  1.2638],
          ...,
          [ 1.3351,  1.3351,  1.3114,  ...,  1.3351,  1.3351,  1.3351]]]])
(remaining output elided)
Now compute the same thing directly from the transformation formula:
>>> enddata=(IMGEND-mean)/std
>>> enddata
tensor([[[[-1.3587, -0.9213, -0.7269,  ..., -0.3382, -0.3868, -0.4516],
          [-1.4397, -0.8727, -0.6135,  ..., -0.1114, -0.1762, -0.2248],
          [-1.8771, -1.3587, -0.9375,  ...,  0.1640,  0.0830, -0.1438],
          ...,
          [-1.9257, -1.8447, -1.8447,  ..., -2.3307, -2.3307, -2.2821]],
        ...
         [[ 1.2638,  1.2638,  1.2757,  ...,  1.2638,  1.2638,  1.2638],
          ...,
          [ 1.3351,  1.3351,  1.3114,  ...,  1.3351,  1.3351,  1.3351]]]])
(remaining output elided)
Clearly the two results are identical, which shows that transforms.Normalize in essence just applies this formula to the input data.
Moreover, when the mean and standard deviation passed to transforms.Normalize are exactly those of the data being transformed, the transformed data is guaranteed to have mean 0 and standard deviation 1.
But when the mean and standard deviation passed to transforms.Normalize are not those of the data being transformed, the resulting data need not have mean 0 or standard deviation 1; the formula is simply applied as given.
Like this:
>>> process=transforms.Normalize([0.5, 0.6, 0.4],[0.36, 0.45, 0.45])
>>> data=process(inputdata)
Here [0.5, 0.6, 0.4] and [0.36, 0.45, 0.45] are not the mean and standard deviation of inputdata; they were chosen arbitrarily, simply to transform the original data, so the resulting data's mean is naturally not necessarily 0 and its standard deviation not necessarily 1.
Of course, when preprocessing images you will often see these two transforms appear together:
transform=transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
])
Here transforms.ToTensor() converts the input into a tensor, changes the shape from H,W,C to C,H,W, and divides every value by 255, scaling the data into [0, 1].
Applying the formula x = (x - mean) / std to the endpoints of [0, 1], with mean = std = 0.5, gives:
(0-0.5)/0.5=-1
(1-0.5)/0.5=1
So the new data lies in the range [-1, 1]; but note that its mean is not necessarily 0, nor its standard deviation necessarily 1.
That is because [0.5, 0.5, 0.5] and [0.5, 0.5, 0.5] are not necessarily the mean and standard deviation of the original data.