
Pyspark - Create new column with the RMSE of two other columns in dataframe

I am fairly new to PySpark. I have a DataFrame, and I would like to create a third column containing the RMSE between col_1 and col_2. I am using a user-defined lambda function (UDF) to make the RMSE calculation, but I keep getting this error: AttributeError: 'int' object has no attribute 'mean'

from pyspark.sql.functions import udf, col
from pyspark.sql.types import IntegerType
from pyspark.sql import SparkSession

spark = SparkSession.builder.config("spark.driver.memory", "30g").appName('linear_data_pipeline').getOrCreate()
sc = spark.sparkContext  # `sc` was used below but never defined

old_df = spark.createDataFrame(sc.parallelize(
    [(0, 1), (1, 3), (2, 5)]), ('col_1', 'col_2'))
# The UDF runs once per row, so col1 and col2 are plain ints here;
# calling .mean() on an int raises AttributeError: 'int' object has no attribute 'mean'
function = udf(lambda col1, col2 : (((col1 - col2)**2).mean())**.5)
new_df = old_df.withColumn('col_n', function(col('col_1'), col('col_2')))
new_df.show()
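The error can be reproduced without Spark. A Python UDF is called once per row with that row's scalar values, so `col1 - col2` is a plain `int`, which has no `.mean()` method. A minimal sketch calling the same lambda directly (the name `rmse_per_row` is only for illustration):

```python
# Same lambda as in the UDF above, called the way Spark calls it: one row at a time
rmse_per_row = lambda col1, col2: (((col1 - col2) ** 2).mean()) ** .5

try:
    rmse_per_row(0, 1)  # values from the first row of old_df
except AttributeError as exc:
    print(exc)  # 'int' object has no attribute 'mean'
```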

How do I best go about fixing this issue? I would also like to compute the RMSE/mean, mean absolute error, mean absolute error/mean, median absolute error, and median percent error, but once I figure out how to calculate one, I should be able to handle the others.

I think you are a bit confused. RMSE is computed over a set of points, so you should not calculate it separately for each pair of values in the two columns. You have to calculate RMSE using all the values in each column.
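For reference, with x_i taken from col_1 and y_i from col_2 over n rows, RMSE is a single number for the whole pair of columns. Worked through on the sample data from the question:

$$
\mathrm{RMSE} = \sqrt{\frac{1}{n}\sum_{i=1}^{n}\left(x_i - y_i\right)^2}
$$

$$
\mathrm{RMSE} = \sqrt{\frac{(0-1)^2 + (1-3)^2 + (2-5)^2}{3}} = \sqrt{\frac{14}{3}} \approx 2.160
$$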

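This could work (it collects the rows to the driver, so it is only suitable for small DataFrames):

# Collect both columns to the driver, then average the squared differences over all rows
rows = old_df.select('col_1', 'col_2').collect()
rmse = (sum((r['col_1'] - r['col_2'])**2 for r in rows) / len(rows))**.5
print(rmse)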

I don't think you need a UDF in this case; it can be done using only pyspark.sql.functions.

Here is an untested option:

import pyspark.sql.functions as psf

rmse = (
    old_df
    .withColumn("squarederror",
                psf.pow(psf.col("col_1") - psf.col("col_2"), psf.lit(2)))
    .agg(psf.avg(psf.col("squarederror")).alias("mse"))  # mean squared error over all rows
    .withColumn("rmse", psf.sqrt(psf.col("mse")))         # RMSE = sqrt(MSE)
)

rmse.collect()

Using the same logic, you can compute the other performance statistics.

The technical posts on this site are licensed under CC BY-SA 4.0. If you reprint them, please cite the site URL or the original address. For any questions, contact: yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM