
How to get row_number in a PySpark DataFrame

To rank rows, I need to get the row_number in a PySpark DataFrame. I saw that PySpark has a row_number function among its window functions, but it requires using HiveContext.

I tried to replace the sqlContext with HiveContext:

        import pyspark
        self.sc = pyspark.SparkContext()
        #self.sqlContext = pyspark.sql.SQLContext(self.sc)
        self.sqlContext = pyspark.sql.HiveContext(self.sc)

But now it throws the exception TypeError: 'JavaPackage' object is not callable. Can you help me either get HiveContext working or get the row number in a different way?

I want to first rank by my prediction and then calculate a loss function (NDCG) based on this ranking. To calculate the loss function I need the ranking, i.e. the position of each prediction in the sorted order.

So the first step is to sort the data by pred, but then I need a running counter over the sorted data.

Example of data:

+-----+--------------------+
|label|                pred|
+-----+--------------------+
|  1.0|[0.25313606997906...|
|  0.0|[0.40893413256608...|
|  0.0|[0.18353492079000...|
|  0.0|[0.77719741215204...|
|  1.0|[0.62766290642569...|
|  1.0|[0.40893413256608...|
|  1.0|[0.63084085591913...|
|  0.0|[0.77719741215204...|
|  1.0|[0.36752166787523...|
|  0.0|[0.40893413256608...|
|  1.0|[0.25528507573737...|
|  1.0|[0.25313606997906...|
+-----+--------------------+
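As a rough illustration of the rank-then-NDCG step described above, here is a plain-Python sketch (not Spark code) that sorts (label, score) pairs by score and uses each row's 1-based position as its rank. It assumes one scalar score per row; the pred column above holds vectors, so you would first have to pick out the relevant element (for example the positive-class probability, which is an assumption here). It uses the linear-gain form of DCG, label / log2(rank + 1); with 0/1 labels this gives the same result as the 2^label - 1 gain variant. The function name ndcg and the sample rows are hypothetical.

```python
import math
from typing import List, Tuple

def ndcg(rows: List[Tuple[float, float]]) -> float:
    """NDCG for (label, score) pairs; assumes one scalar score per row."""
    # Rank = 1-based position after sorting by score, best first
    ranked = sorted(rows, key=lambda r: r[1], reverse=True)
    dcg = sum(label / math.log2(rank + 1)
              for rank, (label, _) in enumerate(ranked, start=1))
    # Ideal DCG: the same sum with labels placed in the best possible order
    ideal = sorted((label for label, _ in rows), reverse=True)
    idcg = sum(label / math.log2(rank + 1)
               for rank, label in enumerate(ideal, start=1))
    # No relevant rows at all: define NDCG as 0 to avoid dividing by zero
    return dcg / idcg if idcg > 0 else 0.0

# Hypothetical rows: (label, scalar score taken from the prediction)
sample = [(1.0, 0.25), (0.0, 0.41), (0.0, 0.18), (1.0, 0.63)]
print(round(ndcg(sample), 4))
```

The enumerate(..., start=1) call plays the role of the running counter the question asks for; in Spark the same position would have to come from a row number computed over the sorted data.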

Thanks.

You don't need to create a HiveContext if your data is not in Hive. You can just carry on with your sqlContext.

There is no row number for your DataFrame unless you create one. pyspark.sql.functions.row_number is for a different purpose: it is a window function, so it only works when evaluated over a window specification.
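To make "only works over a window" concrete, the following plain-Python model (not the Spark API) mimics what row_number() computes: rows are grouped by a partition key, sorted within each group, and numbered from 1, restarting in every group. Passing a constant partition key treats the whole dataset as one window, which gives a global running counter. In PySpark itself, as far as I know, that corresponds to row_number().over(Window.orderBy(...)) with Window imported from pyspark.sql, and from Spark 2.0 onward window functions no longer need a HiveContext; an ordered window without partitionBy moves all rows into a single partition, so it can be slow on large data. The helper row_number_over and the sample rows are hypothetical.

```python
from itertools import groupby
from typing import Any, Callable, List, Sequence, Tuple

def row_number_over(
    rows: Sequence[Any],
    partition_key: Callable[[Any], Any],
    order_key: Callable[[Any], Any],
    descending: bool = False,
) -> List[Tuple[Any, int]]:
    """Plain-Python model of row_number() OVER (PARTITION BY ... ORDER BY ...)."""
    result = []
    # groupby only merges adjacent items, so sort by the partition key first
    for _, part in groupby(sorted(rows, key=partition_key), key=partition_key):
        ordered = sorted(part, key=order_key, reverse=descending)
        # Numbering starts at 1 and restarts in every partition
        result.extend((row, n) for n, row in enumerate(ordered, start=1))
    return result

# One window over all rows (constant key), ordered by prediction, highest first
preds = [(1.0, 0.25), (0.0, 0.78), (1.0, 0.63)]  # hypothetical (label, pred) rows
for row, n in row_number_over(preds, lambda r: 0, lambda r: r[1], descending=True):
    print(n, row)
# 1 (0.0, 0.78)
# 2 (1.0, 0.63)
# 3 (1.0, 0.25)
```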

What you may need instead is to create a new row_id column using monotonically_increasing_id and then query it later.

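from pyspark.sql import Row
from pyspark.sql.functions import monotonically_increasing_id

# `sc` is the existing SparkContext; an SQLContext (or SparkSession)
# must already have been created for RDD.toDF() to work.
data = sc.parallelize([
  Row(key=1, val='a'),
  Row(key=2, val='b'),
  Row(key=3, val='c'),
]).toDF()

# Add a unique, increasing (but not necessarily consecutive) id per row
data = data.withColumn(
  'row_id',
  monotonically_increasing_id()
)

data.collect()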


Out[8]: 
[Row(key=1, val=u'a', row_id=17179869184),
 Row(key=2, val=u'b', row_id=42949672960),
 Row(key=3, val=u'c', row_id=60129542144)]
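The row_id values above are unique and increasing but not consecutive. According to the Spark documentation, the current implementation of monotonically_increasing_id puts the partition ID in the upper 31 bits and the record number within that partition in the lower 33 bits. The plain-Python sketch below (helper names are hypothetical) decodes the three IDs from the output and shows one way to turn sparse IDs into a 0-based consecutive index by sorting them; note that such an index follows the DataFrame's current row order, not a ranking by pred.

```python
from typing import Dict, Iterable, Tuple

# Documented bit layout for monotonically_increasing_id:
# upper 31 bits = partition ID, lower 33 bits = record number in that partition.
RECORD_BITS = 33

def decode_row_id(row_id: int) -> Tuple[int, int]:
    """Split a generated id into (partition_id, record_number)."""
    return row_id >> RECORD_BITS, row_id & ((1 << RECORD_BITS) - 1)

def consecutive_index(row_ids: Iterable[int]) -> Dict[int, int]:
    """Map each sparse id to its 0-based position in ascending id order."""
    return {rid: i for i, rid in enumerate(sorted(row_ids))}

ids = [17179869184, 42949672960, 60129542144]  # row_id values from the output above
for rid in ids:
    print(rid, decode_row_id(rid))
# 17179869184 (2, 0)
# 42949672960 (5, 0)
# 60129542144 (7, 0)
print(consecutive_index(ids))
# {17179869184: 0, 42949672960: 1, 60129542144: 2}
```

So each of the three rows was record 0 of a different partition, which is why the IDs jump. If you need consecutive numbers inside Spark, RDD.zipWithIndex is a commonly used alternative.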

The technical post webpages of this site follow the CC BY-SA 4.0 protocol. If you need to reprint, please indicate the site URL or the original address.Any question please contact:yoyou2525@163.com.

 