上海网站建设开发电话,网站接电话,生产管理网站开发,网站推广软件推荐在用 torch.nn.Upsample 给分割 label 上采样时报错#xff1a;RuntimeError: upsample_nearest2d_out_frame not implemented for Long。
参考 [1-3]#xff0c;用 [3] 给出的实现。稍微扩展一下#xff0c;支持 h、w 用不同的 scale factor#xff0c;并测试…在用 torch.nn.Upsample 给分割 label 上采样时报错RuntimeError: upsample_nearest2d_out_frame not implemented for Long。
参考 [1-3]用 [3] 给出的实现。稍微扩展一下支持 h、w 用不同的 scale factor并测试其与 PyTorch 的几个 upsample 类的异同验证 [3] 的实现用 nearest 插值。
Code
linear 要 3D 输入、trilinear 要 5D 输入故此两种插值法没比。
import torch
import torch.nn as nnclass UpsampleDeterministic(nn.Module):deterministic upsample with nearest interpolationdef __init__(self, scale_factor2):Input:scale_factor: int or (int, int), ratio to scale (along heigth width)super(UpsampleDeterministic, self).__init__()if isinstance(scale_factor, (tuple, list)):assert len(scale_factor) 2self.scale_h, self.scale_w scale_factorelse:self.scale_h self.scale_w scale_factorassert isinstance(self.scale_h, int) and isinstance(self.scale_w, int)def forward(self, x):Input:x: [n, c, h, w], torch.TensorOutput:upsampled x: [n, c, h * scale_h, w * scale_w]return x[:, :, :, None, :, None].expand(-1, -1, -1, self.scale_h, -1, self.scale_w).reshape(x.size(0), x.size(1), x.size(2) * self.scale_h, x.size(3) * self.scale_w)# 随机数据
x torch.rand(2, 3, 4, 4) # [n, c, h, w]
# [3] 的实现
us_det UpsampleDeterministic((2, 3))
# pytorch 自带的几种实现
us_list {mode: nn.Upsample(scale_factor(2, 3), modemode)for mode in (nearest, bilinear, bicubic)}
# linear: 3D
# trilinear: 5Dy_det us_det(x)
print(y_det.size())
for us_name, us in us_list.items():y us(x)print(us_name, y.size(), (y_det ! y).sum())输出
torch.Size([2, 3, 8, 12])
nearest torch.Size([2, 3, 8, 12]) tensor(0)
bilinear torch.Size([2, 3, 8, 12]) tensor(507)
bicubic torch.Size([2, 3, 8, 12]) tensor(576)可见 [3] 的实现与 nearest 结果一致。
References
请慎用torch.nn.UpsamplePyTorch中模型的可复现性Non Deterministic Behaviour even after cudnn.deterministic True and cudnn.benchmarkFalse #12207