根据Spark DataFrame Scala中的列值过滤行

Filtering rows based on column values in spark dataframe scala

我有一个数据框(火花):

1
2
3
4
5
6
7
id  value
3     0
3     1
3     0
4     1
4     0
4     0

我想创建一个新的数据框:

1
2
3
3 0
3 1
4 1

需要为每个id删除1(value)之后的所有行。我尝试在spark dateframe(Scala)中使用窗口函数。但是找不到解决方案。看来我走错了方向。

我正在Scala中寻找解决方案。谢谢

使用monotonically_increasing_id

的输出

1
2
3
4
5
6
7
8
9
10
11
12
13
14
 scala> val data = Seq((3,0),(3,1),(3,0),(4,1),(4,0),(4,0)).toDF("id","value")
data: org.apache.spark.sql.DataFrame = [id: int, value: int]

scala> val minIdx = dataWithIndex.filter($"value" === 1).groupBy($"id").agg(min($"idx")).toDF("r_id","min_idx")
minIdx: org.apache.spark.sql.DataFrame = [r_id: int, min_idx: bigint]

scala> dataWithIndex.join(minIdx,($"r_id" === $"id") && ($"idx" <= $"min_idx")).select($"id", $"value").show
+---+-----+
| id|value|
+---+-----+
|  3|    0|
|  3|    1|
|  4|    1|
+---+-----+

如果我们在原始数据帧中进行了排序转换,则该解决方案将无法工作。那个时候monotonically_increasing_id()是基于原始DF而不是排序的DF生成的。我以前错过了这个要求。

欢迎提出所有建议。


一种方法是使用monotonically_increasing_id()和自联接:

1
2
3
4
5
6
7
8
9
10
11
12
val data = Seq((3,0),(3,1),(3,0),(4,1),(4,0),(4,0)).toDF("id","value")
data.show
+---+-----+
| id|value|
+---+-----+
|  3|    0|
|  3|    1|
|  3|    0|
|  4|    1|
|  4|    0|
|  4|    0|
+---+-----+

现在,我们生成一个名为idx的列,该列的Long递增:

1
2
val dataWithIndex = data.withColumn("idx", monotonically_increasing_id())
// dataWithIndex.cache()

现在我们为每个id获得min(idx),其中value = 1

1
2
3
4
5
val minIdx = dataWithIndex
               .filter($"value" === 1)
               .groupBy($"id")
               .agg(min($"idx"))
               .toDF("r_id","min_idx")

现在,我们将min(idx)加入到原始的DataFrame中:

1
2
3
4
5
6
7
8
9
10
11
dataWithIndex.join(
  minIdx,
  ($"r_id" === $"id") && ($"idx" <= $"min_idx")
).select($"id", $"value").show
+---+-----+
| id|value|
+---+-----+
|  3|    0|
|  3|    1|
|  4|    1|
+---+-----+

注意:monotonically_increasing_id()根据行的分区生成其值。每次重新评估dataWithIndex时,此值可能都会更改。在上面的代码中,由于延迟求值,只有当我调用最终的show时,才会对monotonically_increasing_id()进行求值。

例如,如果要强制将值保持不变,则可以使用show逐步评估上述内容,请在上方取消注释以下行:

1
//  dataWithIndex.cache()


嗨,我找到了使用Window和自连接的解决方案。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
val data = Seq((3,0,2),(3,1,3),(3,0,1),(4,1,6),(4,0,5),(4,0,4),(1,0,7),(1,1,8),(1,0,9),(2,1,10),(2,0,11),(2,0,12)).toDF("id","value","sorted")

data.show

scala> data.show
+---+-----+------+
| id|value|sorted|
+---+-----+------+
|  3|    0|     2|
|  3|    1|     3|
|  3|    0|     1|
|  4|    1|     6|
|  4|    0|     5|
|  4|    0|     4|
|  1|    0|     7|
|  1|    1|     8|
|  1|    0|     9|
|  2|    1|    10|
|  2|    0|    11|
|  2|    0|    12|
+---+-----+------+




val sort_df=data.sort($"sorted")

scala> sort_df.show
+---+-----+------+
| id|value|sorted|
+---+-----+------+
|  3|    0|     1|
|  3|    0|     2|
|  3|    1|     3|
|  4|    0|     4|
|  4|    0|     5|
|  4|    1|     6|
|  1|    0|     7|
|  1|    1|     8|
|  1|    0|     9|
|  2|    1|    10|
|  2|    0|    11|
|  2|    0|    12|
+---+-----+------+



var window=Window.partitionBy("id").orderBy("$sorted")

 val sort_idx=sort_df.select($"*",rowNumber.over(window).as("count_index"))

val minIdx=sort_idx.filter($"value"===1).groupBy("id").agg(min("count_index")).toDF("idx","min_idx")

val result_id=sort_idx.join(minIdx,($"id"===$"idx") &&($"count_index" <= $"min_idx"))

result_id.show

+---+-----+------+-----------+---+-------+
| id|value|sorted|count_index|idx|min_idx|
+---+-----+------+-----------+---+-------+
|  1|    0|     7|          1|  1|      2|
|  1|    1|     8|          2|  1|      2|
|  2|    1|    10|          1|  2|      1|
|  3|    0|     1|          1|  3|      3|
|  3|    0|     2|          2|  3|      3|
|  3|    1|     3|          3|  3|      3|
|  4|    0|     4|          1|  4|      3|
|  4|    0|     5|          2|  4|      3|
|  4|    1|     6|          3|  4|      3|
+---+-----+------+-----------+---+-------+

仍在寻找更优化的解决方案。谢谢


1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
use isin method and filter as below:

val data = Seq((3,0,2),(3,1,3),(3,0,1),(4,1,6),(4,0,5),(4,0,4),(1,0,7),(1,1,8),(1,0,9),(2,1,10),(2,0,11),(2,0,12)).toDF("id","value","sorted")
val idFilter = List(1, 2)
 data.filter($"id".isin(idFilter:_*)).show
+---+-----+------+
| id|value|sorted|
+---+-----+------+
|  1|    0|     7|
|  1|    1|     8|
|  1|    0|     9|
|  2|    1|    10|
|  2|    0|    11|
|  2|    0|    12|
+---+-----+------+

Ex: filter based on val
val valFilter = List(0)
data.filter($"value".isin(valFilter:_*)).show
+---+-----+------+
| id|value|sorted|
+---+-----+------+
|  3|    0|     2|
|  3|    0|     1|
|  4|    0|     5|
|  4|    0|     4|
|  1|    0|     7|
|  1|    0|     9|
|  2|    0|    11|
|  2|    0|    12|
+---+-----+------+


您可以像这样简单地使用groupBy

1
val df2 = df1.groupBy("id","value").count().select("id","value")

您的df1

1
2
3
4
5
6
7
id  value
3     0
3     1
3     0
4     1
4     0
4     0

结果数据帧为df2,这是您期望的输出,例如

1
2
3
4
5
id  value
3     0
3     1
4     1
4     0