- torch.index_select()是访问指定维度的下标张量。
- torch.masked_select()是传入两个等大小的张量一个是常规张量,一个是布尔张量。常规张量根据布尔张量保存对应位置布尔张量为True的常规常量。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 | import torch import numpy as np def tensor_indexSelect(): arr = np.array([[[1, 2, 3], [4, 5, 6], [7, 8, 9]], [[1, 2, 3], [4, 5, 6], [7, 8, 9]], [[1, 2, 3], [4, 5, 6], [7, 8, 9]]]) print(arr) t = torch.tensor(arr) t1 = torch.index_select(t, dim=0, index=torch.tensor([0, 2])) print(t1) def tensor_maskSelect(): arr = np.array([[[1, 2, 3], [4, 5, 6], [7, 8, 9]], [[1, 2, 3], [4, 5, 6], [7, 8, 9]], [[1, 2, 3], [4, 5, 6], [7, 8, 9]]]) print(arr) t = torch.normal(1.0, 1.0, (3, 3)) print(t) result = t > 0.5 t1 = torch.masked_select(t, result) print(t1) if __name__ == '__main__': tensor_indexSelect() tensor_maskSelect() |