[英]Iterate over elements of columns Scala
我有一个由两个双打阵列组成的数据帧。 我想创建一个新列,它是将欧几里德距离函数应用于前两列的结果,即如果我有:
A B
(1,2) (1,3)
(2,3) (3,4)
创建:
A B C
(1,2) (1,3) 1
(2,3) (3,4) 1.4
我的数据架构是:
df.schema.foreach(println)
StructField(col1,ArrayType(DoubleType,false),false)
StructField(col2,ArrayType(DoubleType,false),true)
每当我调用这个距离函数时:
def distance(xs: Array[Double], ys: Array[Double]) = {
sqrt((xs zip ys).map { case (x,y) => pow(y - x, 2) }.sum)
}
我收到类型错误:
df.withColumn("distances" , distance($"col1",$"col2"))
<console>:68: error: type mismatch;
found : org.apache.spark.sql.ColumnName
required: Array[Double]
ids_with_predictions_centroids3.withColumn("distances" , distance($"col1",$"col2"))
我知道我必须遍历每列的元素,但我无法找到如何在任何地方执行此操作的说明。 我是Scala编程的新手。
要在数据帧上使用自定义函数,需要将其定义为UDF
。 例如,这可以完成如下:
val distance = udf((xs: WrappedArray[Double], ys: WrappedArray[Double]) => {
math.sqrt((xs zip ys).map { case (x,y) => math.pow(y - x, 2) }.sum)
})
df.withColumn("C", distance($"A", $"B")).show()
请注意, WrappedArray
需要使用WrappedArray
(或Seq
)。
结果数据帧:
+----------+----------+------------------+
| A| B| C|
+----------+----------+------------------+
|[1.0, 2.0]|[1.0, 3.0]| 1.0|
|[2.0, 3.0]|[3.0, 4.0]|1.4142135623730951|
+----------+----------+------------------+
Spark函数基于列工作 , 你唯一的错误就是你在函数中混合了列和基元
并且错误消息足够清楚,表示您正在传递距离函数中的列,即$"col1"
和$"col2"
是列,但距离函数定义为distance(xs: Array[Double], ys: Array[Double])
采用原始类型 。
解决方案是使距离函数完全基于列
import org.apache.spark.sql.Column
import org.apache.spark.sql.functions._
def distance(xs: Column, ys: Column) = {
sqrt(pow(ys(0)-xs(0), 2) + pow(ys(1)-xs(1), 2))
}
df.withColumn("distances" , distance($"col1",$"col2")).show(false)
这应该给你正确的结果没有错误
+------+------+------------------+
|col1 |col2 |distances |
+------+------+------------------+
|[1, 2]|[1, 3]|1.0 |
|[2, 3]|[3, 4]|1.4142135623730951|
+------+------+------------------+
我希望答案是有帮助的
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.