LibTorch对tensor的索引/切片操作:对比PyTorch

目录

一、通过索引获取值

二、通过索引设置值


在PyTorch C++ API(libtorch)中对张量进行索引的方式与Python API的方式很相似。诸如None / ... / integer / boolean / slice / tensor的索引类型在C++ API里同样有效,这样就可以很方便地实现Python代码与C++代码的转换。主要的不同是将Python API里对张量的“[ ]”操作符转换成了以下形式:

torch::Tensor::index ( )    // 获取值

torch::Tensor::index_put_ ( )   // 设置值

有关官方文档请参阅PyTorch C++ API文档中的“Tensor Indexing API”一节。下面通过举例说明libtorch与PyTorch中张量索引/切片的方式,每个例子先给出Python写法,再给出对应的C++写法:

一、通过索引获取值

1、tensor[Ellipsis, ...] --> tensor.index({Ellipsis, "..."})(Python中的 ... 即 Ellipsis,C++中可写作 Ellipsis 或字符串 "...";下例演示 a[..., 2] --> a.index({"...", 2}))

Python代码:
import torch

# Build a 3x3x3 float tensor holding the values 1..27.
values = torch.linspace(1, 27, 27)
a = values.reshape(3, 3, 3)
print(a)

# `...` stands for every leading dimension, so this takes index 2 of the
# last axis at each (i, j) position, i.e. the same as a[:, :, 2].
c = a[..., 2]
print(c)

#===================运行结果===============#
tensor([[[ 1.,  2.,  3.],
         [ 4.,  5.,  6.],
         [ 7.,  8.,  9.]],

        [[10., 11., 12.],
         [13., 14., 15.],
         [16., 17., 18.]],

        [[19., 20., 21.],
         [22., 23., 24.],
         [25., 26., 27.]]])
tensor([[ 3.,  6.,  9.],
        [12., 15., 18.],
        [21., 24., 27.]])
C++代码:
#include "iostream"
#include "torch/script.h"
int main()
{
    torch::Tensor a = torch::linspace(1, 27, 27).reshape({3, 3, 3});
    std::cout << a << std::endl;
    at::Tensor b = a.index({"...", 2});
    std::cout << b << std::endl;

    return 0;
}

/****************输出结果******************/
(1,.,.) =
  1  2  3
  4  5  6
  7  8  9

(2,.,.) =
  10  11  12
  13  14  15
  16  17  18

(3,.,.) =
  19  20  21
  22  23  24
  25  26  27
[ CPUFloatType{3,3,3} ]
  3   6   9
 12  15  18
 21  24  27
[ CPUFloatType{3,3} ]

2、tensor[1, 2] --> tensor.index({1, 2})

Python代码:
import torch

# Build a 3x3x3 float tensor holding the values 1..27.
values = torch.linspace(1, 27, 27)
a = values.reshape(3, 3, 3)
print(a)

# Two integer indices: block 1, then row 2 of that block -> a 1-D tensor.
c = a[1, 2]
print(c)

#===================运行结果=================#
tensor([[[ 1.,  2.,  3.],
         [ 4.,  5.,  6.],
         [ 7.,  8.,  9.]],

        [[10., 11., 12.],
         [13., 14., 15.],
         [16., 17., 18.]],

        [[19., 20., 21.],
         [22., 23., 24.],
         [25., 26., 27.]]])
tensor([16., 17., 18.])
C++代码:
#include "iostream"
#include "torch/script.h"

int main()
{
    torch::Tensor a = torch::linspace(1, 27, 27).reshape({3, 3, 3});
    std::cout << a << std::endl;
    at::Tensor b = a.index({1, 2});
    std::cout << b << std::endl;

    return 0;
}
/*****************运行结果***************/
(1,.,.) =
  1  2  3
  4  5  6
  7  8  9

(2,.,.) =
  10  11  12
  13  14  15
  16  17  18

(3,.,.) =
  19  20  21
  22  23  24
  25  26  27
[ CPUFloatType{3,3,3} ]
 16
 17
 18
[ CPUFloatType{3} ]

3、tensor[1::2] --> tensor.index({Slice(1, None, 2)})

Python代码:
import torch

# 1-D float tensor [1, 2, 3, 4, 5, 6].
a = torch.linspace(1, 6, 6)
print(a)

# start=1, no stop, step=2 -> the elements at positions 1, 3 and 5.
c = a[1::2]
print(c)
#==================运行结果==================#
tensor([1., 2., 3., 4., 5., 6.])
tensor([2., 4., 6.])
C++代码:
#include "iostream"
#include "torch/script.h"
using namespace torch::indexing;
int main()
{
    torch::Tensor a = torch::linspace(1, 6, 6);
    std::cout << a << std::endl;
    at::Tensor b = a.index({Slice(1, None, 2)});
    std::cout << b << std::endl;

    return 0;
}
/*******************运行结果*********************/
 1
 2
 3
 4
 5
 6
[ CPUFloatType{6} ]
 2
 4
 6
[ CPUFloatType{3} ]

3.5、tensor[..., 1:] --> tensor.index({"...", Slice(1)})

Python代码:
import torch

# Build a 3x3x3 float tensor holding the values 1..27.
values = torch.linspace(1, 27, 27)
a = values.reshape(3, 3, 3)

# `...` keeps every leading dimension; `1:` drops the first column of the
# last axis, giving shape (3, 3, 2).
b = a[..., 1:]
print(b)
#===============运行结果===================#
tensor([[[ 2.,  3.],
         [ 5.,  6.],
         [ 8.,  9.]],

        [[11., 12.],
         [14., 15.],
         [17., 18.]],

        [[20., 21.],
         [23., 24.],
         [26., 27.]]])
C++代码:
#include "iostream"
#include "torch/script.h"
using namespace torch::indexing;
int main()
{
    torch::Tensor a = torch::linspace(1, 27, 27).reshape({3, 3, 3});
    torch::Tensor b = a.index({"...", Slice(1)});
    std::cout << b << std::endl;

    return 0;
}
/******************运行结果**********************/
(1,.,.) =
  2  3
  5  6
  8  9

(2,.,.) =
  11  12
  14  15
  17  18

(3,.,.) =
  20  21
  23  24
  26  27
[ CPUFloatType{3,3,2} ]

4、tensor[torch.tensor([1, 2])] --> tensor.index({torch::tensor({1, 2})})

Python代码:
import torch

# Source values [1, 2, 3, 4] and the positions to gather, in order.
a = torch.linspace(1, 4, 4)
b = torch.tensor([0, 1, 3, 2])

# Indexing with an integer tensor gathers a[0], a[1], a[3], a[2].
c = a[b]

print(a)
print(c)
#===============运行结果===============#
tensor([1., 2., 3., 4.])
tensor([1., 2., 4., 3.])

C++代码:
#include "iostream"
#include "torch/script.h"
using namespace torch::indexing;
int main()
{
    torch::Tensor a = torch::linspace(1, 4, 4);
    torch::Tensor b = torch::tensor({0, 1, 3, 2});
    torch::Tensor c = a.index({b});
    std::cout << a << std::endl;
    std::cout << b << std::endl;

    return 0;
}
/*******************运行结果********************/
 1
 2
 3
 4
[ CPUFloatType{4} ]
 0
 1
 3
 2
[ CPULongType{4} ]

二、通过索引设置值

1、tensor[1, 2] = 1 --> tensor.index_put_({1, 2}, 1)

Python代码:
import torch

# 2x2 float tensor [[1, 2], [3, 4]].
values = torch.linspace(1, 4, 4)
a = values.reshape(2, 2)
print(a)

# In-place write of a single element: row 1, column 1.
a[1, 1] = 100
print(a)
#==================运行结果=====================#
tensor([[1., 2.],
        [3., 4.]])
tensor([[  1.,   2.],
        [  3., 100.]])
C++代码:
#include "iostream"
#include "torch/script.h"
using namespace torch::indexing;
int main()
{
    torch::Tensor a = torch::linspace(1, 4, 4).reshape({2, 2});
    std::cout << a << std::endl;
    a.index_put_({1, 1}, 100);
    std::cout << a << std::endl;

    return 0;
}
/***************运行结果****************/
 1  2
 3  4
[ CPUFloatType{2,2} ]
   1    2
   3  100
[ CPUFloatType{2,2} ]

2、tensor[Ellipsis, ...] = 1 --> tensor.index_put_({Ellipsis, "..."}, 1)(下例演示 a[..., 1] = 100 --> a.index_put_({"...", 1}, 100))

Python代码:
import torch

# 2x2 float tensor [[1, 2], [3, 4]].
values = torch.linspace(1, 4, 4)
a = values.reshape(2, 2)
print(a)

# `...` covers every leading dimension: column 1 of every row becomes 100.
a[..., 1] = 100
print(a)
#====================运行结果=====================#
tensor([[1., 2.],
        [3., 4.]])
tensor([[  1., 100.],
        [  3., 100.]])
C++代码:
#include "iostream"
#include "torch/script.h"
using namespace torch::indexing;
int main()
{
    torch::Tensor a = torch::linspace(1, 4, 4).reshape({2, 2});
    std::cout << a << std::endl;
    a.index_put_({"...", 1}, 100);
    std::cout << a << std::endl;

    return 0;
}
/***************运行结果****************/
 1  2
 3  4
[ CPUFloatType{2,2} ]
   1  100
   3  100
[ CPUFloatType{2,2} ]

3、tensor[torch.tensor([1, 2])] = 1 --> tensor.index_put_({torch::tensor({1, 2})}, 1)

Python代码:
import torch

# Values [1, 2, 3, 4] and the positions to overwrite.
a = torch.linspace(1, 4, 4)
b = torch.tensor([0, 2])
print(a)

# Every position listed in b (0 and 2) is set to 100 in place.
a[b] = 100
print(a)
#===============运行结果==================#
tensor([1., 2., 3., 4.])
tensor([100.,   2., 100.,   4.])
C++代码:
#include "iostream"
#include "torch/script.h"
using namespace torch::indexing;
int main()
{
    torch::Tensor a = torch::linspace(1, 4, 4);
    torch::Tensor b = torch::tensor({0, 2});
    std::cout << a << std::endl;
    a.index_put_({b}, 100);
    std::cout << a << std::endl;

    return 0;
}
/*****************运行结果*****************/
 1
 2
 3
 4
[ CPUFloatType{4} ]
 100
   2
 100
   4
[ CPUFloatType{4} ]