对loss进行mask

简介

在计算loss和评价指标时,对一些不关注的值进行mask。下面介绍mask的使用。

对loss进行mask

在NLP中的Seq2Seq中经常会对loss进行mask,因为一个batch中句子的长度通常不一样,一个batch中不足长度的位置用0填充,最后生成句子计算loss时需要忽略掉原先那些padding的值,即只保留mask中值为1的位置,忽略值为0的位置。在计算loss时,将那些本不应该计算的mask掉,使其loss为0,这样就不会反向传播了。

1
2
3
masked_predicts = torch.masked_select(predicts, mask)
masked_targets = torch.masked_select(targets, mask)
loss = my_criterion(masked_predicts, masked_targets)
1
2
3
diff2 = (torch.flatten(input) - torch.flatten(target)) ** 2.0 * torch.flatten(mask)
loss = torch.sum(diff2) / torch.sum(mask)
out.backward()

有时候mask是舍弃一些不想关注的值,比如预测车流量时,真实车流量小于5的值则舍弃,即不关注那些车流量小的值预测结果,只关注大约5的值的预测结果。

1
2
3
def masked_mean_squared_error(y_true, y_pred):
idx = (y_true > 5).nonzero()
return K.mean(K.square(y_pred[idx] - y_true[idx]))

Pytorch的mask_select函数

torch.masked_select(input, mask, out=None) → Tensor
返回1-D的Tensor

1
2
3
4
5
6
7
8
9
10
11
12
>>> x = torch.randn(3, 4)
>>> x
tensor([[ 0.3552, -2.3825, -0.8297, 0.3477],
[-1.2035, 1.2252, 0.5002, 0.6248],
[ 0.1307, -2.0608, 0.1244, 2.0139]])
>>> mask = x.ge(0.5)
>>> mask
tensor([[False, False, False, False],
[False, True, True, True],
[False, False, False, True]])
>>> torch.masked_select(x, mask)
tensor([ 1.2252, 0.5002, 0.6248, 2.0139])

【参考资料】

浅谈mask矩阵
pytorch-DCRNN

打赏
0%