今天看代码用到了tf.repeat()
查了下资料没有人专门讲解这个,我来写一下,希望能帮到后人。
官方文档:
https://www.tensorflow.org/api_docs/python/tf/repeat?hl=ca
调用该方法:
1 | tf.repeat(input, repeats, axis=None, name=None) |
参数:
1)input: 一个tensor
2)repeats: 重复的次数
注意:len(repeats) must equal input.shape[axis] if axis is not None
3)axis:维度
这个看例子来理解
说明:An int. The axis along which to repeat values. By default (axis=None), use the flattened input array, and return a flat output array.
如果axis没有参数,则会先flatten数组,变成一维再重复
例子3:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 | tf.repeat([[1, 2], [3, 4]], repeats=[2, 3], axis=1) -input第0个元素重复2次 -input第1个元素重复3次 [ [1, 1, 2, 2, 2], [3, 3, 4, 4, 4] ] tf.repeat([[1, 2], [3, 4]], repeats=[2, 3], axis=0) -input第0个元素重复2次 -input第1个元素重复3次 [ [1, 2], [1, 2], [3, 4], [3, 4], [3, 4] ] |
简单记忆的话,axis = 1横着增加,axis = 0竖着增加
例子1:
先看最简单的一维的
1 2 3 4 5 6 7 | temp = tf.constant([1,2,3]) tf.Tensor: shape=(1, 3), dtype=int32, numpy=array([[1, 2, 3]], dtype=int32) tf.repeat(input = temp, repeats=3, axis = 0) 输出: [1, 1, 1, 2, 2, 2, 3, 3, 3] 每个数重复了3次 |
当输入不是一维,但是没有赋值axis, 输入会被拍平成一维的
1 2 3 4 | temp1 = tf.constant([[1,2],[3,4]]) tf.repeat(input = temp1, repeats = 2) 输出: [1,1,2,2,3,3,4,4] |
例子2:
1 2 | temp3 = tf.const([[1],[2],[3]]) temp3.shape ---> (3,1) |
上面说到 len(repeats) = input.shape[axis]
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 | 所以在这个例子里必须满足: temp3.shape[axis = 0] = 4 = len(repeats) temp3.shape[axis = 1] = 1 = len(repeats) 1) 当axis = 0 tf.repeat(temp3, repeats = [1,2,3], axis = 0) 输出: [ [1], [2], [2], [3], [3], [3], ] 2) 当axis = 1 tf.repeat(temp3, repeats=2, axis=1) 输出: [ [1, 1], [2, 2], [3, 3], [4, 4] ] |