Halo

A magic place for coding

0%

pytorch的Upsampling函数参数问题

问题

  最近在跑代码的过程中无意间碰到了一个错误,是关于Pytorch中nn.Upsampling()的参数类型问题。现在许多人都会用这种写法,所以特意在这里分享一下解决的方法。

问题详情

注意:Python版本为Python2.7
  首先看看报错的信息,如图:
Pytorch upsampling问题
  从报错信息中我们可以得知,这里的问题出在了scale_factor这个参数的类型。我们先来看看代码的写法:

1
self.upsample = nn.Upsample(scale_factor=(1, 4, 4), mode='trillinear', align_corners=True)

  这里就涉及到了nn.Upsample()这个函数,我们先来看看这个函数的源码。
Upsampling源码
  这样一看更慌了,因为这里的说明提到scale_factor参数是可以接受tuple的,但是现在传入tuple却报错了。不用着急,其实问题很简单,下面来解决。

解决办法

  报错的地方主要是在使用float()类型转换的时候的问题。既然不能够一次过转换,那么我们对tuple中的每一个元素单独进行转换就好了。修改代码如下:
首先删除下面一行代码(Row 125):

1
self.scale_fac = float(scale_factor) if scale_factor else None

然后在Row 125添加如下代码:

1
2
3
4
5
6
if isinstance(scale_factor, tuple):
self.scale_factor = ()
for factor in scale_factor:
self.scale_factor += (float(factor),)
else:
self.scale_factor = float(scale_factor) if scale_factor else None

  这样就可以完美解决问题了,希望这篇博客能够帮助你!

Welcome to my other publishing channels