0%

SourceChangeWarning:验证集上准确率很高,但是测试集上很低

按照国(Ge)际(Ren)惯例,这种问题的解决办法直接写在最前面。答案是:训练环境中的 pytorch 版本与测试环境中的版本不一致。

今天训练一个 DST 模型。奇怪的是训练集上拥有 70%+ 的 joint acc,验证集上也有 57%+,但在测试集上仅为 2%。

经过一系列排查之后,虽然修复了一些 bug,但结果依旧没有变化。那就没辙了。。。做深度学习的人都知道,神经网络模型压根就是一个黑盒子,根本无法像写 web 应用一样调试程序。

私以为代码不可能出错,但是它在测试的时候就是会出问题。后来,在经过一下午的人肉排查之后,突然发现一个一直被我无视的警告,这是因为之前一直感觉它没什么用。错误提示如下:

1
2
3
SourceChangeWarning: source code of class 'torch.nn.modules.rnn.LSTM' has changed. you can retrieve the original source code by accessing the object's source attribute or set `torch.nn.Module.dump_patches = True` and use the patch tool to revert the changes.
SourceChangeWarning: source code of class 'torch.nn.modules.linear.Linear' has changed. you can retrieve the original source code by accessing the object's source attribute or set `torch.nn.Module.dump_patches = True` and use the patch tool to revert the changes.
SourceChangeWarning: source code of class 'torch.nn.modules.loss.NLLLoss' has changed. you can retrieve the original source code by accessing the object's source attribute or set `torch.nn.Module.dump_patches = True` and use the patch tool to revert the changes.

它表明源码已发生改变,但是我压根就没改过源码。经过一番排查后,我突然发现这个警告不是提醒我,我的代码发生改变,而是 pytorch 的源码发生了变化。

想到这一点,就全可以解释通了。

由于现在是暑期,我一直用远程连接,连接着实验室的服务器。而服务器的 pytorch 代码的版本与我笔记本上的不一样!

服务器上的是 pytorch=1.5.0, torchvision=0.6.0 cuda 版本,而本地的是 torch=1.4.0, torchvision=0.5.0 cpu 版本。

后来我直接在服务器上测试我的模型,发现结果终于合理了。