pytorch(5)torch.index_select()和torch.masked_select()函数

  1. torch.index_select()是访问指定维度的下标张量。
  2. torch.masked_select()是传入两个等大小的张量一个是常规张量,一个是布尔张量。常规张量根据布尔张量保存对应位置布尔张量为True的常规常量。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
import torch
import numpy as np


def tensor_indexSelect():
    arr = np.array([[[1, 2, 3], [4, 5, 6], [7, 8, 9]],
                    [[1, 2, 3], [4, 5, 6], [7, 8, 9]],
                    [[1, 2, 3], [4, 5, 6], [7, 8, 9]]])
    print(arr)
    t = torch.tensor(arr)
    t1 = torch.index_select(t, dim=0, index=torch.tensor([0, 2]))
    print(t1)


def tensor_maskSelect():
    arr = np.array([[[1, 2, 3], [4, 5, 6], [7, 8, 9]],
                    [[1, 2, 3], [4, 5, 6], [7, 8, 9]],
                    [[1, 2, 3], [4, 5, 6], [7, 8, 9]]])

    print(arr)
    t = torch.normal(1.0, 1.0, (3, 3))
    print(t)
    result = t > 0.5
    t1 = torch.masked_select(t, result)
    print(t1)


if __name__ == '__main__':
    tensor_indexSelect()
    tensor_maskSelect()