问题
最近在跑代码的过程中无意间碰到了一个错误,是关于Pytorch中nn.Upsampling()
的参数类型问题。现在许多人都会用这种写法,所以特意在这里分享一下解决的方法。
问题详情
注意:Python版本为Python2.7
首先看看报错的信息,如图:
从报错信息中我们可以得知,这里的问题出在了scale_factor
这个参数的类型。我们先来看看代码的写法:
1 | self.upsample = nn.Upsample(scale_factor=(1, 4, 4), mode='trillinear', align_corners=True) |
这里就涉及到了nn.Upsample()
这个函数,我们先来看看这个函数的源码。
这样一看更慌了,因为这里的说明提到scale_factor
参数是可以接受tuple
的,但是现在传入tuple
却报错了。不用着急,其实问题很简单,下面来解决。
解决办法
报错的地方主要是在使用float()
类型转换的时候的问题。既然不能够一次过转换,那么我们对tuple
中的每一个元素单独进行转换就好了。修改代码如下:
首先删除下面一行代码(Row 125):
1 | self.scale_fac = float(scale_factor) if scale_factor else None |
然后在Row 125添加如下代码:
1 | if isinstance(scale_factor, tuple): |
这样就可以完美解决问题了,希望这篇博客能够帮助你!