Pytorch之Tensor学习

1. 简介

最近发现一个学习Pytorch的教程,有视频版和文字版deeplizard,这里面详细介绍了关于Tensor的知识,真的讲得超级好,解决了我很多关于Tensor运算的疑惑,在此记录下。

参考资料:Tensor官方文档

2. 创建Tensor

创建Tensor有四种方式

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
#创建Tensor
> data = np.array([1,2,3])

> o1 = torch.Tensor(data)
> o2 = torch.tensor(data)
> o3 = torch.as_tensor(data)
> o4 = torch.from_numpy(data)

> print(o1)
> print(o2)
> print(o3)
> print(o4)
tensor([1., 2., 3.])
tensor([1, 2, 3], dtype=torch.int32)
tensor([1, 2, 3], dtype=torch.int32)
tensor([1, 2, 3], dtype=torch.int32)

> print(o1.dtype)
> print(o2.dtype)
> print(o3.dtype)
> print(o4.dtype)
torch.float32
torch.int32
torch.int32
torch.int32

#内存是否共享
> print('old:', data)
old: [1 2 3]

> data[0] = 0

> print('new:', data)
new: [0 2 3]

> print(o1)
> print(o2)
> print(o3)
> print(o4)

tensor([1., 2., 3.])
tensor([1, 2, 3], dtype=torch.int32)
tensor([0, 2, 3], dtype=torch.int32)
tensor([0, 2, 3], dtype=torch.int32)
  • torch.Tensor()torch.tensor()区别
    torch.Tensor()Tensor类的构造函数,torch.tensor()是factory function,该函数将传入的参数构造成一个Tensor对象并返回。
    以上4个函数中,torch.Tensor()是构造函数,其余都是factory function。
  • 这4个函数的主要区别是:torch.Tensor()返回的Tensor默认是float32类型,而其他3个函数返回的Tensor数据类型根据传入的数据而定。并且其他3个函数可以传入dtype来指定数据的类型,但是torch.Tensor()不能传入dtype参数。
  • 通过np.array来创建Tensor,然后改变data的值,可以看到,前2个Tensor的值并没有改变,后2个Tensor的值改变。这是因为torch.Tensor()torch.tensor()是copy输入数据的值,而torch.as_tensor()torch.from_numpy()是share输入数据的memory

    Shara Data | Copy Data
    -|-|-|
    torch.as_tensor() | torch.tensor()|
    torch.from_numpy() | torch.Tensor()|

  • torch.as_tensor()torch.from_numpy()都是factory function,且都是share data,那这2个函数有什么区别?torch.from_numpy()仅仅接受np.array的参数,然而torch.as_tensor()接受array-like objects类型的参数

  • 综上所述,下面2个方法是创建Tensor的推荐方法:

    • torch.tensor()
    • torch.as_tensor()

3. Tensor的4类操作

3.1. Reshape操作

3.1.1. reshape

1
2
3
4
5
> t = torch.tensor([
[1,1,1,1],
[2,2,2,2],
[3,3,3,3]
], dtype=torch.float32)

有2种方式获取Tensor的shape:t.size()和t.shape

3.1.2. squeeze和unsqueeze函数

  1. torch.squeeze(input, dim=None, out=None) → Tensor
    将维度中的1去掉,如果不指定dim,则去掉所有维度上的1;如果指定dim,则只去掉该维度上的1。dim是可选项
  • 输入维度是(A×1×B×C×1×D),不指定dim,输出维度(A×B×CxD)
  • 输入维度是(A×1×B×C×1×D),不指定dim=1,输出维度(A×B×C×1×D)
  1. torch.unsqueeze(input, dim, out=None) → Tensor
    在指定维度上增加1个维度。dim是必填项
    torch.unsqueeze(x, 0)

3.1.3. cat函数

torch.cat(tensors, dim=0, out=None) → Tensor
如果要拼接多个Tensor,需要将多个Tensor包装成tuple,例如:

1
torch.cat((t1,t2,t3),dim=0)

例如t1,t2,t3的维度都是(2,5,7),上面在dim=0上进行拼接,则得到为维度是(6,5,7)

cat不改变数据维度个数,例如原先是3维数据,拼接之后还是3维

3.1.4. stack函数

torch.stack(tensors, dim=0, out=None) → Tensor

1
torch.stack((t1,t2,t3),dim=0)

例如t1,t2,t3的维度都是(2,5,7),上面在dim=0上进行拼接,则得到为维度是(3,2,5,7)

stack改变数据维度个数,增加一个维度。例如原先是3维数据,拼接之后变成4维

3.1.5. cat和stack的区别

cat和stack的区别可以用一句话描述:

  • cat不会改变数据维度个数,原先是3维数据,n个tensor进行cat之后还是3维数据。
  • stack会增加维度个数,原先是3维,n个tensor进行stack会变成4维数据

3.2. rank和shape

  • rank就是Tensor的维数
  • shape就是Tensor的维度

例如

m = [
[1, 1, 1],
[2, 2, 2],
[3, 3, 3]
]

rank = 2
shape = (3,3)

3.3. Element-wise操作

Broadcasting and Element-wise Operations with PyTorch
Broadcasting Explained

逐元素有以下4种叫法,意思都一样:

  • Element-wise
  • Component-wise
  • Point-wise

逐元素操作有以下几种:

  1. t1+t2维度相同
    其中t1和t2维度相同

  2. t1+2, t1-2, t1*2, t1/2
    实际上是对2进行了broadcasting,然后再和t1运算

  3. t1+t2,rank相同,维度不同
    这种情况比较复杂。

    • 首先我们先看这2个Tensor在所有维度上是否兼容。判断2个Tensor在维度上是否兼容有2个条件,只要满足其中的一个条件就兼容,否则不兼容。

      • 相等
      • 有一个值维1

        例如:t1维度(1,3),t2维度(3,1),从后往前对比,我们先看第二个维度的值,分别是3和1,不相等但是满足第二个条件,即第二个维度上兼容。再看第一个维度,分别是1和3,满足第二个条件,即第一个维度上兼容。所以2个Tensor在所有维度上兼容,可以进行下一步的操作。如果不兼容,则这2个Tensor无法进行逐元素运算。

    • 决定最终结果的输出维度。还是要看2个Tensor的维度。从后往前对比, t1维度(1,3),t2维度(3,1),先看第二维度是3和1,取最大值作为输出的第一个维度,即3,再看第一维度1和3,也是3作为输出的第二个维度。即输出的维度是(3,3)。

    • 分别将t1维度(1,3),t2维度(3,1)进行广播成(3,3),然后再进行相加,得到最终的结果。
  4. t1+t2,rank不同

    • 例子1:t1的维度(2,4),t2的维度是(4,),这2个Tensor也可以进行,实际是先将低rank的t2最后一维和t1的最后一维相等,都等于4,但是t2只有一维,那就在缺失的维度上补1,变成(1,4),然后再广播成(2,4)维度,然后再和t1计算。
    • 例子2:t1的维度(2,4),t2的维度是(2,),这2个Tensor不可以进行。因为t1和t2的最后一维分别是4和2,不相等也不等于1,不兼容,无法进行下一步。
    • 例子3:t1维度(1,2,3),t2维度(3,3),这个Tensor就不能做逐元素操作。先看所有维度是否兼容。最后一个维度3和3,相等即兼容,再看前一个维度2和3,既不相等也不等于1,不兼容。则不能进行逐元素操作

以上2,3,4情况都涉及到了broadcasting的知识。

  1. 比较操作
    比较也是逐元素操作的一种,

    1
    2
    > torch.tensor([1, 2, 3]) < torch.tensor([3, 1, 2])
    tensor([True, False, False])

3.4. Reduction操作

聚合操作:减少Tesnor中元素的个数。

1
2
3
4
5
6
7
> t = torch.tensor([
[0,1,0],
[2,0,2],
[0,3,0]
], dtype=torch.float32)
> t.sum()
tensor(8.)

sum()返回的结果是scalar类型(0维的Tensor),只包含1个元素

3.4.1. 沿着某个axis聚合

1
2
3
4
5
6
7
> t = torch.tensor([
[1,1,1,1],
[2,2,2,2],
[3,3,3,3]
], dtype=torch.float32)
> t.sum(dim=0)
tensor([6., 6., 6., 6.])

3.4.2. Argmax函数介绍

当一个Tensor变量a调用argmax()函数时,返回只包含1个元素的Tensor,该元素表示a中最大值的下标。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
t = torch.tensor([
[1,0,0,2],
[0,3,3,0],
[4,0,0,5]
], dtype=torch.float32)

> t.max()
tensor(5.)

> t.argmax()
tensor(11)

> t.flatten()
tensor([1., 0., 0., 2., 0., 3., 3., 0., 4., 0., 0., 5.])

如果argmax()没有指定axis,则返回整个Tensor最大值的下标。如果指定axis,则返回指定轴上最大值下标。

1
2
3
4
5
6
7
8
9
10
11
> t.max(dim=0)
(tensor([4., 3., 3., 5.]), tensor([2, 1, 1, 2]))

> t.argmax(dim=0)
tensor([2, 1, 1, 2])

> t.max(dim=1)
(tensor([2., 3., 5.]), tensor([3, 1, 3]))

> t.argmax(dim=1)
tensor([3, 1, 3])

当调用max()函数时,返回2个Tensor,第一个Tensor表示返回轴上最大的值,第2个Tensor返回最大值的下标,也就是argmax()的返回值。
通常argmax()通常用在分类任务的输出上,决定哪类有最高的预测值。

3.5. Access操作

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
> t = torch.tensor([
[1,2,3],
[4,5,6],
[7,8,9]
], dtype=torch.float32)

> t.mean()
tensor(5.)

> t.mean().item()
5.0

> t.mean(dim=0).tolist()
[4.0, 5.0, 6.0]

> t.mean(dim=0).numpy()
array([4., 5., 6.], dtype=float32)

如果返回的结果是scalar,只有1个元素,使用item()来获取其中的值。
如果返回的结果有多个值,可以将Tensor转换为pyhton中的list和array.

4. Tensor的4种乘法

参考资料

torch.Tensor有4种常见的乘法:*, torch.mul, torch.mm, torch.matmul

4.1. 点乘 *

点乘需要满足以下3个条件的任何一个

  • 2个Tensor维度完全一样
  • 2个Tensor可以boardcast到相同的维度
  • 1个Tensor和1个实数的操作

例如

1
2
a = torch.ones(3,4)
a * 2
1
2
3
4
a = torch.ones(3,4)
b = torch.Tensor([1,2,3,4])
#a的维度(3,4),b的维度(4,),可以将b进行boardcast成(3,4)的维度,再和a相乘
a * b
1
2
3
a = torch.ones(3,4)
b = torch.ones(3,4)
a * b

4.2. torch.mul

官网关于torch.mul介绍

用法和*相同,也是element-wise的乘法,支持boardcast

  1. 乘实数

    1
    2
    3
    4
    5
    6

    >>> a = torch.randn(3)
    >>> a
    tensor([ 0.2015, -0.4255, 2.6087])
    >>> torch.mul(a, 100)
    tensor([ 20.1494, -42.5491, 260.8663])
  2. 乘矩阵

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11

    >>> a = torch.randn(4, 1)
    >>> a
    tensor([[ 1.1207],
    [-0.3137],
    [ 0.0700],
    [ 0.8378]])
    >>> b = torch.randn(1, 4)
    >>> b
    tensor([[ 0.5146, 0.1216, -0.5244, 2.2382]])
    >>> torch.mul(a, b)

4.3. torch.mm

官网关于torch.mm介绍

数学中的矩阵乘法,要求2个矩阵满足矩阵乘法的要求

1
2
3
4
5
6
7

>>> a = torch.ones(3,4)
>>> b = torch.ones(4,2)
>>> torch.mm(a, b)
tensor([[4., 4.],
[4., 4.],
[4., 4.]])

torch.matmul

官网关于torch.matmul介绍. torch.mm的broadcast版本

  1. 当输入都是二维Tensor时,就是普通的矩阵乘法,与tensor.mm用法相同

    1
    2
    3
    4
    5
    >>> a = torch.ones(3,4)
    >>> b = torch.ones(4,5)
    >>> c = torch.matmul(a,b)
    >>> c.shape
    torch.Size([3, 5])
  2. 当输入有多维时,把多出来的一维看做是batch,其他部分做矩阵乘法

  3. 2个Tensor都是3维
    将b的第0维先boardcast成2之后,然后提出来作为batch,后两维做矩阵乘法即可

    1
    2
    3
    4
    5
    >>> a = torch.ones(2,5,3)
    >>> b = torch.ones(1,3,4)
    >>> c = torch.matmul(a,b)
    >>> c.shape
    torch.Size([2, 5, 4])
  4. 更复杂

    • 先将a的第0维的2提出来作为batch
    • 剩下的a和b都可以看做3维,再将a中1进行boardcast成5,再提出来作为batch
    • 剩下的(3,4)和(4,2)进行普通矩阵乘法
    1
    2
    3
    4
    5
    >>> a = torch.ones(2,1,3,4)
    >>> b = torch.ones(5,4,2)
    >>> c = torch.matmul(a,b)
    >>> c.shape
    torch.Size([2, 5, 3, 2])
打赏
0%