scatter_(input, dim, index, src)将src中数据根据index中的索引按照dim的方向填进input中
这个函数可以从转换成onehot编码来理解。
看下面代码:
1 2 3 4 5 6 7 | index = torch.tensor([1,2,1,2,0]) torch.zeros(5,3).scatter_(1, index.unsqueeze(1), 1) # tensor([[0., 1., 0.], # [0., 0., 1.], # [0., 1., 0.], # [0., 0., 1.], # [1., 0., 0.]]) |
简要说明:这段代码的目的就是将李表[1,2,1,2,0]转成one-hot编码的形式。
因为有5个数据,然后数值范围从0~2,所以需要设置3列,所以目标矩阵应该是5x3。
比如以输出结果的第一行为例,原本
不过其实得到onehot编码可以用pandas.get_dummies
1 2 | index = [1,2,1,2,0] pd.get_dummies(index) |