简介
在计算loss和评价指标时,对一些不关注的值进行mask。下面介绍mask的使用。
对loss进行mask
在NLP中的Seq2Seq中经常会对loss进行mask,因为一个batch中句子的长度通常不一样,一个batch中不足长度的位置用0填充,最后生成句子计算loss时需要忽略掉原先那些padding的值,即只保留mask中值为1的位置,忽略值为0的位置。在计算loss时,将那些本不应该计算的mask掉,使其loss为0,这样就不会反向传播了。
1  | masked_predicts = torch.masked_select(predicts, mask)  | 
1  | diff2 = (torch.flatten(input) - torch.flatten(target)) ** 2.0 * torch.flatten(mask)  | 
有时候mask是舍弃一些不想关注的值,比如预测车流量时,真实车流量小于5的值则舍弃,即不关注那些车流量小的值预测结果,只关注大约5的值的预测结果。
1  | def masked_mean_squared_error(y_true, y_pred):  | 
Pytorch的mask_select函数
torch.masked_select(input, mask, out=None) → Tensor
返回1-D的Tensor
1  | x = torch.randn(3, 4)  | 
【参考资料】