网站开发流程框架,产品推广策略怎么写,无锡做网站的公司电话,wordpress图片上传地址本系列教程适用于没有任何pytorch的同学#xff08;简单的python语法还是要的#xff09;#xff0c;从代码的表层出发挖掘代码的深层含义#xff0c;理解具体的意思和内涵。pytorch的很多函数看着非常简单#xff0c;但是其中包含了很多内容#xff0c;不了解其中的意思… 本系列教程适用于没有任何pytorch的同学简单的python语法还是要的从代码的表层出发挖掘代码的深层含义理解具体的意思和内涵。pytorch的很多函数看着非常简单但是其中包含了很多内容不了解其中的意思就只能【看懂代码】无法【理解代码】。 目录 官方定义demoone-hot 官方定义
torch.tensor.scatter_是PyTorch中的一个函数用于将指定索引处的值替换为给定的值。
函数定义
Tensor.scatter_(dim, index, src, reduceNone) → Tensor官方解释 将张量src中的所有值写入索引张量中指定的index处的self。 对于src中的每个值它的输出索引由其在src中的索引(dimension ! dim)和在index中对应的值(dimension dim)指定。
非常难以理解十分抽象从我个人的角度来说就是
第一个参数dim表示维度即在第几维度处理数据保持其它维度不变。reduce参数是一个可选参数用于指定如何在执行散射scatter操作时对重复的索引值进行合并或聚合。index则是需要填充的列的索引即根据维度从src中取对应的值填充到tensor中去。
怎么映射的比如一个一个3维张量
self[index[i][j][k]][j][k] src[i][j][k] # if dim 0
self[i][index[i][j][k]][k] src[i][j][k] # if dim 1
self[i][j][index[i][j][k]] src[i][j][k] # if dim 2官方的文档如下TORCH.TENSOR.SCATTER_: 即使如此理解起来也是很复杂下面从例子中去理解
demo
下面是一个官方文档给出的例子
import torchsrc torch.Tensor([[-1.0276, 0.2673, -1.1752, -0.8823],[-0.6447, -0.8256, 0.1542, -0.4242]])
print(src)output torch.zeros(2, 5)
index torch.tensor([[3, 1, 2, 0], [1, 2, 0, 3]])output output.scatter(1, index, src)
print(output)输出的结果 我们一步步理解代码
首先定义了一个src张量后续output即从src中取值。其次定义了output其值为二行五列的全零张量后续对output进行修改。接着定义了index即从src取值的索引。最后根据index从src取值填充到output中即完成操作。
那么具体是如何取值的呢
首先dim 1意味着从维度值为1的地方取值维度值为0的地方不变那就是
self[i][index[i][j]] src[i][j] # if dim 1具体来说
当i 0, j 0时output[0][index[0][0]] src[0][0]因为index[0][0] 3所以output[0][3] src[0][0] -1.0276这时候我们检查输出的output值确实是-1.0276。
同理
i 0, j 1: output[0][index[0][1]] output[0][1] src[0][1] 0.2673
i 0, j 2: output[0][index[0][2]] output[0][2] src[0][2] -1.1752
one-hot
作者在学习该函数时实在遇到one-hot编码时遇到的而该函数在one-hot中应用很广
index torch.tensor([[3], [2], [0], [1]])
onehot torch.zeros(4, 4)
onehot.scatter_(1, index, 1)
print(onehot)