1. 简介
最近发现一个学习Pytorch的教程,有视频版和文字版deeplizard,这里面详细介绍了关于Tensor的知识,真的讲得超级好,解决了我很多关于Tensor运算的疑惑,在此记录下。
参考资料:Tensor官方文档
2. 创建Tensor
创建Tensor有四种方式
1 | #创建Tensor |
torch.Tensor()
和torch.tensor()
区别torch.Tensor()
是Tensor
类的构造函数,torch.tensor()
是factory function,该函数将传入的参数构造成一个Tensor
对象并返回。
以上4个函数中,torch.Tensor()
是构造函数,其余都是factory function。- 这4个函数的主要区别是:
torch.Tensor()
返回的Tensor
默认是float32
类型,而其他3个函数返回的Tensor
数据类型根据传入的数据而定。并且其他3个函数可以传入dtype
来指定数据的类型,但是torch.Tensor()
不能传入dtype
参数。 通过
np.array
来创建Tensor
,然后改变data的值,可以看到,前2个Tensor
的值并没有改变,后2个Tensor
的值改变。这是因为torch.Tensor()
和torch.tensor()
是copy输入数据的值,而torch.as_tensor()
和torch.from_numpy()
是share输入数据的memoryShara Data | Copy Data
-|-|-|
torch.as_tensor() | torch.tensor()|
torch.from_numpy() | torch.Tensor()|torch.as_tensor()
和torch.from_numpy()
都是factory function,且都是share data,那这2个函数有什么区别?torch.from_numpy()
仅仅接受np.array
的参数,然而torch.as_tensor()
接受array-like objects类型的参数综上所述,下面2个方法是创建Tensor的推荐方法:
torch.tensor()
torch.as_tensor()
3. Tensor的4类操作
3.1. Reshape操作
3.1.1. reshape
1 | > t = torch.tensor([ |
有2种方式获取Tensor的shape:t.size()和t.shape
3.1.2. squeeze和unsqueeze函数
torch.squeeze(input, dim=None, out=None) → Tensor
将维度中的1去掉,如果不指定dim,则去掉所有维度上的1;如果指定dim,则只去掉该维度上的1。dim是可选项
- 输入维度是(A×1×B×C×1×D),不指定dim,输出维度(A×B×CxD)
- 输入维度是(A×1×B×C×1×D),不指定dim=1,输出维度(A×B×C×1×D)
torch.unsqueeze(input, dim, out=None) → Tensor
在指定维度上增加1个维度。dim是必填项torch.unsqueeze(x, 0)
3.1.3. cat函数
torch.cat(tensors, dim=0, out=None) → Tensor
如果要拼接多个Tensor,需要将多个Tensor包装成tuple,例如:
1 | torch.cat((t1,t2,t3),dim=0) |
例如t1,t2,t3
的维度都是(2,5,7),上面在dim=0
上进行拼接,则得到为维度是(6,5,7)
cat不改变数据维度个数,例如原先是3维数据,拼接之后还是3维
3.1.4. stack函数
torch.stack(tensors, dim=0, out=None) → Tensor
1 | torch.stack((t1,t2,t3),dim=0) |
例如t1,t2,t3
的维度都是(2,5,7),上面在dim=0
上进行拼接,则得到为维度是(3,2,5,7)
stack改变数据维度个数,增加一个维度。例如原先是3维数据,拼接之后变成4维
3.1.5. cat和stack的区别
cat和stack的区别可以用一句话描述:
- cat不会改变数据维度个数,原先是3维数据,n个tensor进行cat之后还是3维数据。
- stack会增加维度个数,原先是3维,n个tensor进行stack会变成4维数据
3.2. rank和shape
- rank就是Tensor的维数
- shape就是Tensor的维度
例如
m = [
[1, 1, 1],
[2, 2, 2],
[3, 3, 3]
]
rank = 2
shape = (3,3)
3.3. Element-wise操作
Broadcasting and Element-wise Operations with PyTorch
Broadcasting Explained
逐元素有以下4种叫法,意思都一样:
- Element-wise
- Component-wise
- Point-wise
逐元素操作有以下几种:
t1+t2
维度相同
其中t1和t2维度相同t1+2, t1-2, t1*2, t1/2
实际上是对2进行了broadcasting
,然后再和t1运算t1+t2
,rank相同,维度不同
这种情况比较复杂。首先我们先看这2个Tensor在所有维度上是否兼容。判断2个Tensor在维度上是否兼容有2个条件,只要满足其中的一个条件就兼容,否则不兼容。
- 相等
有一个值维1
例如:t1维度(1,3),t2维度(3,1),从后往前对比,我们先看第二个维度的值,分别是3和1,不相等但是满足第二个条件,即第二个维度上兼容。再看第一个维度,分别是1和3,满足第二个条件,即第一个维度上兼容。所以2个Tensor在所有维度上兼容,可以进行下一步的操作。如果不兼容,则这2个Tensor无法进行逐元素运算。
决定最终结果的输出维度。还是要看2个Tensor的维度。从后往前对比, t1维度(1,3),t2维度(3,1),先看第二维度是3和1,取最大值作为输出的第一个维度,即3,再看第一维度1和3,也是3作为输出的第二个维度。即输出的维度是(3,3)。
- 分别将t1维度(1,3),t2维度(3,1)进行广播成(3,3),然后再进行相加,得到最终的结果。
t1+t2
,rank不同- 例子1:t1的维度(2,4),t2的维度是(4,),这2个Tensor也可以进行,实际是先将低rank的t2最后一维和t1的最后一维相等,都等于4,但是t2只有一维,那就在缺失的维度上补1,变成(1,4),然后再广播成(2,4)维度,然后再和t1计算。
- 例子2:t1的维度(2,4),t2的维度是(2,),这2个Tensor不可以进行。因为t1和t2的最后一维分别是4和2,不相等也不等于1,不兼容,无法进行下一步。
- 例子3:t1维度(1,2,3),t2维度(3,3),这个Tensor就不能做逐元素操作。先看所有维度是否兼容。最后一个维度3和3,相等即兼容,再看前一个维度2和3,既不相等也不等于1,不兼容。则不能进行逐元素操作
以上2,3,4情况都涉及到了broadcasting的知识。
比较操作
比较也是逐元素操作的一种,1
2> torch.tensor([1, 2, 3]) < torch.tensor([3, 1, 2])
tensor([True, False, False])
3.4. Reduction操作
聚合操作:减少Tesnor中元素的个数。
1 | > t = torch.tensor([ |
sum()
返回的结果是scalar类型(0维的Tensor),只包含1个元素
3.4.1. 沿着某个axis聚合
1 | > t = torch.tensor([ |
3.4.2. Argmax函数介绍
当一个Tensor变量a调用argmax()
函数时,返回只包含1个元素的Tensor,该元素表示a中最大值的下标。
1 | t = torch.tensor([ |
如果argmax()
没有指定axis,则返回整个Tensor最大值的下标。如果指定axis,则返回指定轴上最大值下标。
1 | > t.max(dim=0) |
当调用max()
函数时,返回2个Tensor,第一个Tensor表示返回轴上最大的值,第2个Tensor返回最大值的下标,也就是argmax()
的返回值。
通常argmax()
通常用在分类任务的输出上,决定哪类有最高的预测值。
3.5. Access操作
1 | > t = torch.tensor([ |
如果返回的结果是scalar,只有1个元素,使用item()来获取其中的值。
如果返回的结果有多个值,可以将Tensor转换为pyhton中的list和array.
4. Tensor的4种乘法
torch.Tensor有4种常见的乘法:*, torch.mul, torch.mm, torch.matmul
4.1. 点乘 *
点乘需要满足以下3个条件的任何一个
- 2个Tensor维度完全一样
- 2个Tensor可以boardcast到相同的维度
- 1个Tensor和1个实数的操作
例如
1 | a = torch.ones(3,4) |
1 | a = torch.ones(3,4) |
1 | a = torch.ones(3,4) |
4.2. torch.mul
用法和*相同,也是element-wise的乘法,支持boardcast
乘实数
1
2
3
4
5
6
3) a = torch.randn(
a
tensor([ 0.2015, -0.4255, 2.6087])
100) torch.mul(a,
tensor([ 20.1494, -42.5491, 260.8663])乘矩阵
1
2
3
4
5
6
7
8
9
10
11
4, 1) a = torch.randn(
a
tensor([[ 1.1207],
[-0.3137],
[ 0.0700],
[ 0.8378]])
1, 4) b = torch.randn(
b
tensor([[ 0.5146, 0.1216, -0.5244, 2.2382]])
torch.mul(a, b)
4.3. torch.mm
数学中的矩阵乘法,要求2个矩阵满足矩阵乘法的要求
1 |
|
torch.matmul
官网关于torch.matmul介绍. torch.mm的broadcast版本
当输入都是二维Tensor时,就是普通的矩阵乘法,与
tensor.mm
用法相同1
2
3
4
53,4) a = torch.ones(
4,5) b = torch.ones(
c = torch.matmul(a,b)
c.shape
torch.Size([3, 5])当输入有多维时,把多出来的一维看做是batch,其他部分做矩阵乘法
2个Tensor都是3维
将b的第0维先boardcast成2之后,然后提出来作为batch,后两维做矩阵乘法即可1
2
3
4
52,5,3) a = torch.ones(
1,3,4) b = torch.ones(
c = torch.matmul(a,b)
c.shape
torch.Size([2, 5, 4])更复杂
- 先将a的第0维的2提出来作为batch
- 剩下的a和b都可以看做3维,再将a中1进行boardcast成5,再提出来作为batch
- 剩下的(3,4)和(4,2)进行普通矩阵乘法
1
2
3
4
52,1,3,4) a = torch.ones(
5,4,2) b = torch.ones(
c = torch.matmul(a,b)
c.shape
torch.Size([2, 5, 3, 2])