tf.repeat(), Tensorflow2.1.0以上

今天看代码用到了tf.repeat()

查了下资料没有人专门讲解这个,我来写一下,希望能帮到后人。


官方文档:

https://www.tensorflow.org/api_docs/python/tf/repeat?hl=ca


调用该方法:

1
tf.repeat(input, repeats, axis=None, name=None)

参数:
1)input: 一个tensor

2)repeats: 重复的次数
注意:len(repeats) must equal input.shape[axis] if axis is not None

3)axis:维度
这个看例子来理解
说明:An int. The axis along which to repeat values. By default (axis=None), use the flattened input array, and return a flat output array.
如果axis没有参数,则会先flatten数组,变成一维再重复

例子3:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
tf.repeat([[1, 2], [3, 4]], repeats=[2, 3], axis=1)
                                    -input第0个元素重复2次
                                    -input第1个元素重复3次
[
 [1, 1, 2, 2, 2],
 [3, 3, 4, 4, 4]
]

tf.repeat([[1, 2], [3, 4]], repeats=[2, 3], axis=0)
                                    -input第0个元素重复2次
                                    -input第1个元素重复3次
[
 [1, 2],
 [1, 2],
 [3, 4],
 [3, 4],
 [3, 4]
]

简单记忆的话,axis = 1横着增加,axis = 0竖着增加


例子1:
先看最简单的一维的

1
2
3
4
5
6
7
temp = tf.constant([1,2,3])
tf.Tensor: shape=(1, 3), dtype=int32, numpy=array([[1, 2, 3]], dtype=int32)

tf.repeat(input = temp, repeats=3, axis = 0)
输出:
[1, 1, 1, 2, 2, 2, 3, 3, 3]
每个数重复了3次

当输入不是一维,但是没有赋值axis, 输入会被拍平成一维的

1
2
3
4
temp1 = tf.constant([[1,2],[3,4]])
tf.repeat(input = temp1, repeats = 2)
输出:
[1,1,2,2,3,3,4,4]

例子2:

1
2
temp3 = tf.const([[1],[2],[3]])
temp3.shape ---> (3,1)

上面说到 len(repeats) = input.shape[axis]

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
所以在这个例子里必须满足:
temp3.shape[axis = 0] = 4 = len(repeats)
temp3.shape[axis = 1] = 1 = len(repeats)

1) 当axis = 0
tf.repeat(temp3, repeats = [1,2,3], axis = 0)
输出:
[
[1],
[2],
[2],
[3],
[3],
[3],
]

2) 当axis = 1
tf.repeat(temp3, repeats=2, axis=1)
输出:
[
 [1, 1],
 [2, 2],
 [3, 3],
 [4, 4]
]