Group and average NumPy matrix
Say I have an arbitrary NumPy matrix that looks like this:
    arr = [[ 6.0  12.0  1.0]
           [ 7.0   9.0  1.0]
           [ 8.0   7.0  1.0]
           [ 4.0   3.0  2.0]
           [ 6.0   1.0  2.0]
           [ 2.0   5.0  2.0]
           [ 9.0   4.0  3.0]
           [ 2.0   1.0  4.0]
           [ 8.0   4.0  4.0]
           [ 3.0   5.0  4.0]]
What is an efficient way to average the rows, grouped by the value in the third column?
The expected output is:
    result = [[ 7.0   9.33  1.0]
              [ 4.0   3.0   2.0]
              [ 9.0   4.0   3.0]
              [ 4.33  3.33  4.0]]
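For comparison with the answers below, here is a minimal NumPy-only sketch (not taken from any of the answers) that avoids a Python-level loop over groups: np.unique with return_inverse=True maps every row to a group index, and a weighted np.bincount sums each column per group. The function name group_mean_bincount is hypothetical, and the sketch assumes the key column holds values for which exact equality is meaningful, such as the whole-number floats here.

```python
import numpy as np

# Sample data from the question: the third column holds the group key.
arr = np.array([[6.0, 12.0, 1.0], [7.0, 9.0, 1.0], [8.0, 7.0, 1.0],
                [4.0, 3.0, 2.0], [6.0, 1.0, 2.0], [2.0, 5.0, 2.0],
                [9.0, 4.0, 3.0], [2.0, 1.0, 4.0], [8.0, 4.0, 4.0],
                [3.0, 5.0, 4.0]])


def group_mean_bincount(a, key_col=2):
    """Average every column of `a` per distinct value in column `key_col`."""
    # keys: sorted distinct keys; inv: group index (0..len(keys)-1) of each row
    keys, inv = np.unique(a[:, key_col], return_inverse=True)
    counts = np.bincount(inv)  # number of rows in each group
    # A weighted bincount adds up one column within each group.
    sums = np.column_stack([np.bincount(inv, weights=a[:, j])
                            for j in range(a.shape[1])])
    return sums / counts[:, None]


result = group_mean_bincount(arr)
print(np.round(result, 2))
```

Rows come out in ascending key order, and the key column averages to the key itself, so the result has the same layout as the expected output.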
One compact option is numpy_indexed (disclaimer: I am its author), which provides a fully vectorized solution:
    import numpy_indexed as npi
    # group the rows by the third column and average each group
    npi.group_by(arr[:, 2]).mean(arr)
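If adding the numpy_indexed dependency is not an option, a sort-then-reduce approach can be written with plain NumPy. This is an illustrative sketch, not numpy_indexed's internal code: it stable-sorts the rows by the key column, finds where each group starts with np.unique(..., return_index=True, return_counts=True), and sums each contiguous block with np.add.reduceat. The helper name group_mean_reduceat is hypothetical.

```python
import numpy as np

arr = np.array([[6.0, 12.0, 1.0], [7.0, 9.0, 1.0], [8.0, 7.0, 1.0],
                [4.0, 3.0, 2.0], [6.0, 1.0, 2.0], [2.0, 5.0, 2.0],
                [9.0, 4.0, 3.0], [2.0, 1.0, 4.0], [8.0, 4.0, 4.0],
                [3.0, 5.0, 4.0]])


def group_mean_reduceat(a, key_col=2):
    """Return (sorted distinct keys, per-group column means) for 2-D array a."""
    # Stable sort so that rows with equal keys form one contiguous block.
    s = a[np.argsort(a[:, key_col], kind="stable")]
    # Start index and size of each block in the sorted array.
    keys, starts, counts = np.unique(s[:, key_col], return_index=True,
                                     return_counts=True)
    # Sum every column over each block, then divide by the block size.
    sums = np.add.reduceat(s, starts, axis=0)
    return keys, sums / counts[:, None]


keys, means = group_mean_reduceat(arr)
print(keys)
print(means)
```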
You can do the following:
    import numpy as np
    results = []
    for x in sorted(np.unique(arr[..., 2])):
        # average columns 0 and 1 over the rows whose third column equals x
        results.append([np.average(arr[np.where(arr[..., 2] == x)][..., 0]),
                        np.average(arr[np.where(arr[..., 2] == x)][..., 1]),
                        x])
Test:
    >>> arr
    array([[  6.,  12.,   1.],
           [  7.,   9.,   1.],
           [  8.,   7.,   1.],
           [  4.,   3.,   2.],
           [  6.,   1.,   2.],
           [  2.,   5.,   2.],
           [  9.,   4.,   3.],
           [  2.,   1.,   4.],
           [  8.,   4.,   4.],
           [  3.,   5.,   4.]])
    >>> results=[]
    >>> for x in sorted(np.unique(arr[...,2])):
    ...     results.append([np.average(arr[np.where(arr[...,2]==x)][...,0]),
    ...                     np.average(arr[np.where(arr[...,2]==x)][...,1]),
    ...                     x])
    ...
    >>> results
    [[7.0, 9.3333333333333339, 1.0], [4.0, 3.0, 2.0], [9.0, 4.0, 3.0], [4.333333333333333, 3.3333333333333335, 4.0]]
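The loop above repeats the same np.where selection for each column. A slightly tighter sketch of the same per-key idea, assuming the key is in the third column, builds one boolean mask per distinct key and averages all columns at once with mean(axis=0). It still scans the whole array once per group, so it suits a modest number of groups.

```python
import numpy as np

arr = np.array([[6.0, 12.0, 1.0], [7.0, 9.0, 1.0], [8.0, 7.0, 1.0],
                [4.0, 3.0, 2.0], [6.0, 1.0, 2.0], [2.0, 5.0, 2.0],
                [9.0, 4.0, 3.0], [2.0, 1.0, 4.0], [8.0, 4.0, 4.0],
                [3.0, 5.0, 4.0]])

keys = np.unique(arr[:, 2])  # distinct group keys, ascending
# For each key, select its rows with a boolean mask and average every column.
result = np.array([arr[arr[:, 2] == k].mean(axis=0) for k in keys])
print(result)
```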
No need to … the array:
    import numpy as np
    arr = np.array(
        [[6.0, 12.0, 1.0],
         [7.0,  9.0, 1.0],
         [8.0,  7.0, 1.0],
         [4.0,  3.0, 2.0],
         [6.0,  1.0, 2.0],
         [2.0,  5.0, 2.0],
         [9.0,  4.0, 3.0],
         [2.0,  1.0, 4.0],
         [8.0,  4.0, 4.0],
         [3.0,  5.0, 4.0]])
    # A new group starts wherever the third column changes value, so rows with
    # equal keys must be adjacent. np.flatnonzero yields a flat index array
    # that np.split accepts.
    np.array([a.mean(0) for a in np.split(arr, np.flatnonzero(np.diff(arr[:, 2])) + 1)])
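The np.split one-liner starts a new group wherever the third column changes, so it gives correct per-key averages only when rows with equal keys are adjacent, as they are in the sample. A minimal sketch for unordered input, assuming a stable sort on the key column is acceptable, sorts first and then splits; the helper name group_mean_split is hypothetical.

```python
import numpy as np

arr = np.array([[6.0, 12.0, 1.0], [7.0, 9.0, 1.0], [8.0, 7.0, 1.0],
                [4.0, 3.0, 2.0], [6.0, 1.0, 2.0], [2.0, 5.0, 2.0],
                [9.0, 4.0, 3.0], [2.0, 1.0, 4.0], [8.0, 4.0, 4.0],
                [3.0, 5.0, 4.0]])


def group_mean_split(a, key_col=2):
    """Per-key column means that do not require the rows to be pre-sorted."""
    # Stable sort by the key column so that equal keys become adjacent.
    s = a[np.argsort(a[:, key_col], kind="stable")]
    # Indices where the key changes mark the start of the next group.
    boundaries = np.flatnonzero(np.diff(s[:, key_col])) + 1
    return np.array([g.mean(axis=0) for g in np.split(s, boundaries)])


# Shuffle the rows so that groups are no longer contiguous.
shuffled = arr[[3, 0, 9, 6, 1, 4, 7, 2, 5, 8]]
result = group_mean_split(shuffled)
print(result)
```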
A pure-Python solution using itertools.groupby:
    from itertools import groupby
    from operator import itemgetter
    arr = [[6.0, 12.0, 1.0],
           [7.0,  9.0, 1.0],
           [8.0,  7.0, 1.0],
           [4.0,  3.0, 2.0],
           [6.0,  1.0, 2.0],
           [2.0,  5.0, 2.0],
           [9.0,  4.0, 3.0],
           [2.0,  1.0, 4.0],
           [8.0,  4.0, 4.0],
           [3.0,  5.0, 4.0]]
    result = []
    # groupby only merges consecutive rows with the same key, so the rows
    # must already be ordered by the third column
    for groupByID, rows in groupby(arr, key=itemgetter(2)):
        position1, position2, counter = 0, 0, 0
        for row in rows:
            position1 += row[0]
            position2 += row[1]
            counter += 1
        result.append([position1 / counter, position2 / counter, groupByID])
    print(result)
This prints:
    [[7.0, 9.333333333333334, 1.0], [4.0, 3.0, 2.0], [9.0, 4.0, 3.0], [4.333333333333333, 3.3333333333333335, 4.0]]
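itertools.groupby has the same adjacency requirement as np.split. A pure-Python alternative that accepts rows in any order keeps running sums per key in a dictionary. The sketch below assumes the [x, y, key] row layout of the question, and group_mean_dict is a hypothetical name.

```python
from collections import defaultdict

arr = [[6.0, 12.0, 1.0], [7.0, 9.0, 1.0], [8.0, 7.0, 1.0],
       [4.0, 3.0, 2.0], [6.0, 1.0, 2.0], [2.0, 5.0, 2.0],
       [9.0, 4.0, 3.0], [2.0, 1.0, 4.0], [8.0, 4.0, 4.0],
       [3.0, 5.0, 4.0]]


def group_mean_dict(rows):
    """Average columns 0 and 1 per key in column 2; rows may be in any order."""
    # key -> [running sum of column 0, running sum of column 1, row count]
    acc = defaultdict(lambda: [0.0, 0.0, 0])
    for x, y, key in rows:
        entry = acc[key]
        entry[0] += x
        entry[1] += y
        entry[2] += 1
    # Emit one [mean_x, mean_y, key] row per key, in ascending key order.
    return [[sx / n, sy / n, key] for key, (sx, sy, n) in sorted(acc.items())]


# Reversed input still groups correctly because no ordering is assumed.
print(group_mean_dict(arr[::-1]))
```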