简介
在计算loss和评价指标时,对一些不关注的值进行mask。下面介绍mask的使用。
对loss进行mask
在NLP中的Seq2Seq中经常会对loss进行mask,因为一个batch中句子的长度通常不一样,一个batch中不足长度的位置用0填充,最后生成句子计算loss时需要忽略掉原先那些padding的值,即只保留mask中值为1的位置,忽略值为0的位置。在计算loss时,将那些本不应该计算的mask掉,使其loss为0,这样就不会反向传播了。
1 | masked_predicts = torch.masked_select(predicts, mask) |
1 | diff2 = (torch.flatten(input) - torch.flatten(target)) ** 2.0 * torch.flatten(mask) |
有时候mask是舍弃一些不想关注的值,比如预测车流量时,真实车流量小于5的值则舍弃,即不关注那些车流量小的值预测结果,只关注大约5的值的预测结果。
1 | def masked_mean_squared_error(y_true, y_pred): |
Pytorch的mask_select函数
torch.masked_select(input, mask, out=None) → Tensor
返回1-D的Tensor
1 | 3, 4) x = torch.randn( |
【参考资料】