In order to rank rows, I need to get the row number in a PySpark DataFrame. I saw that there is a row_number function among PySpark's window functions, but it requires a HiveContext.
I tried to replace the SQLContext with a HiveContext:
import pyspark
self.sc = pyspark.SparkContext()
#self.sqlContext = pyspark.sql.SQLContext(self.sc)
self.sqlContext = pyspark.sql.HiveContext(self.sc)
But it now throws the exception TypeError: 'JavaPackage' object is not callable. Can you help me either get the HiveContext working or get the row number in a different way?
Example of data: I want to first rank by my prediction and then calculate a loss function (NDCG) based on this ranking. To calculate the loss function I need the ranking (i.e., the position of each prediction in the sorted order).
So the first step is to sort the data by pred, but then I need a running counter over the sorted data.
+-----+--------------------+
|label|                pred|
+-----+--------------------+
|  1.0|[0.25313606997906...|
|  0.0|[0.40893413256608...|
|  0.0|[0.18353492079000...|
|  0.0|[0.77719741215204...|
|  1.0|[0.62766290642569...|
|  1.0|[0.40893413256608...|
|  1.0|[0.63084085591913...|
|  0.0|[0.77719741215204...|
|  1.0|[0.36752166787523...|
|  0.0|[0.40893413256608...|
|  1.0|[0.25528507573737...|
|  1.0|[0.25313606997906...|
Thanks.
You don't need to create a HiveContext if your data is not in Hive. You can just carry on with your sqlContext.
There is no row_number for your DataFrame unless you create one. pyspark.sql.functions.row_number is for a different purpose: it is a window function, so it only works over a window specification.
What you may need is to create a new column as the row_id using monotonically_increasing_id and then query it later. Note that the generated IDs are guaranteed to be unique and increasing, but not consecutive.
from pyspark.sql.functions import monotonically_increasing_id
from pyspark.sql.types import Row
data = sc.parallelize([
    Row(key=1, val='a'),
    Row(key=2, val='b'),
    Row(key=3, val='c'),
]).toDF()
data = data.withColumn(
    'row_id',
    monotonically_increasing_id()
)
data.collect()
Out[8]:
[Row(key=1, val=u'a', row_id=17179869184),
Row(key=2, val=u'b', row_id=42949672960),
Row(key=3, val=u'c', row_id=60129542144)]
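The large gaps between the row_id values above come from how the IDs are built: according to the Spark documentation, the current implementation puts the partition ID in the upper 31 bits and the record number within each partition in the lower 33 bits. A small plain-Python sketch (the helper name decode_row_id is made up for this illustration) that splits the IDs from the output accordingly:

```python
def decode_row_id(row_id):
    """Split an ID from monotonically_increasing_id into
    (partition_id, record_number_within_partition).

    Assumes the documented layout: partition ID in the upper 31 bits,
    record number in the lower 33 bits.
    """
    partition_id = row_id >> 33
    record_number = row_id & ((1 << 33) - 1)
    return partition_id, record_number


# The three row_id values from the output above.
for row_id in (17179869184, 42949672960, 60129542144):
    print(row_id, decode_row_id(row_id))
```

Under that layout, the three rows are the first records of partitions 2, 5, and 7, which is why the IDs are increasing but cannot be used directly as a 1, 2, 3, ... position.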