What is the equivalent of pandas.cut() in PySpark?

pandas.cut() is used to bin values into discrete intervals. For instance,

import numpy as np
import pandas as pd

pd.cut(
    np.array([0.2, 0.25, 0.36, 0.55, 0.67, 0.78]),
    3,
    include_lowest=True,
    right=False
)

Out[9]: 
[[0.2, 0.393), [0.2, 0.393), [0.2, 0.393), [0.393, 0.587), [0.587, 0.781), [0.587, 0.781)]
Categories (3, interval[float64]): [[0.2, 0.393) < [0.393, 0.587) < [0.587, 0.781)]

How can I achieve the same in PySpark? I looked at QuantileDiscretizer, but it is not an equivalent of pd.cut() because it does not return the intervals.

RDD.histogram is a similar function in Spark. Called with an integer, it returns a tuple of the bucket boundaries and the per-bucket counts.
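To make the boundaries concrete, here is a minimal pure-Python sketch of the evenly spaced edges I would expect from an integer bucket count. It is not Spark's implementation, and equal_width_edges and bucket_counts are hypothetical helper names used only for illustration.

```python
from bisect import bisect_right


def equal_width_edges(values, buckets):
    """Evenly spaced boundaries between min(values) and max(values)."""
    lo, hi = min(values), max(values)
    width = (hi - lo) / buckets
    # Append the maximum itself as the last boundary to avoid rounding drift.
    return [lo + i * width for i in range(buckets)] + [hi]


def bucket_counts(values, edges):
    """Count values per bucket; buckets are left-closed, the last one is closed."""
    counts = [0] * (len(edges) - 1)
    for v in values:
        # Index of the bucket whose lower edge is <= v; clamp the maximum
        # into the last bucket.
        idx = min(bisect_right(edges, v) - 1, len(counts) - 1)
        counts[idx] += 1
    return counts


values = [0.2, 0.25, 0.36, 0.55, 0.67, 0.78]
edges = equal_width_edges(values, 3)
counts = bucket_counts(values, edges)
```

For the six sample values this gives four boundaries (roughly 0.2, 0.393, 0.587, 0.78) and three counts, one per interval.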

Assume that the data is in a DataFrame with a column named col1.

+----+
|col1|
+----+
| 0.2|
|0.25|
|0.36|
|0.55|
|0.67|
|0.78|
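+----+

# histogram(3) returns (boundaries, counts); change 3 to the number of intervals you want
h = df.rdd.flatMap(lambda x: x).histogram(3)

# pair consecutive boundaries into (lower, upper) tuples
bins = list(zip(h[0][:-1], h[0][1:]))

def label(b):
    # quoted SQL string literal such as ' 0.20 -  0.39'
    return f"'{b[0]:5.2f} - {b[1]:5.2f}'"

# build a SQL CASE expression; the else branch covers the last bin, including the maximum
e = "case "
for b in bins[:-1]:
    e += f"when col1 >= {b[0]} and col1 < {b[1]} then {label(b)} "
e += f"else {label(bins[-1])} end as bin"

df.selectExpr("col1", e).show()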

Output:

+----+-------------+
|col1|          bin|
+----+-------------+
| 0.2| 0.20 -  0.39|
|0.25| 0.20 -  0.39|
|0.36| 0.20 -  0.39|
|0.55| 0.39 -  0.59|
|0.67| 0.59 -  0.78|
|0.78| 0.59 -  0.78|
+----+-------------+
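One thing to watch in the CASE expression: the else branch gives the last label to every row that fails the earlier conditions, and that would also include nulls. The sketch below is a variant that closes the last interval explicitly, so unmatched rows come out as NULL instead. case_expr is a hypothetical helper name; it only builds the SQL string, so it runs without Spark.

```python
def case_expr(column, edges, alias="bin"):
    """Build a SQL CASE expression that bins `column` by `edges`.

    Every interval is left-closed; only the last one is also right-closed,
    so the maximum is kept while nulls and out-of-range values yield NULL.
    """
    pairs = list(zip(edges[:-1], edges[1:]))
    parts = []
    for i, (lo, hi) in enumerate(pairs):
        upper_op = "<=" if i == len(pairs) - 1 else "<"
        parts.append(
            f"when {column} >= {lo} and {column} {upper_op} {hi} "
            f"then '{lo:5.2f} - {hi:5.2f}'"
        )
    # No else branch: rows that match no interval evaluate to NULL.
    return "case " + " ".join(parts) + f" end as {alias}"


edges = [0.2, 0.39333333333333337, 0.5866666666666667, 0.78]
expr = case_expr("col1", edges)
```

The resulting string should be usable in the same way as e, for example df.selectExpr("col1", expr).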

bins contains the intervals as tuples:

[(0.2, 0.39333333333333337),
 (0.39333333333333337, 0.5866666666666667),
 (0.5866666666666667, 0.78)]
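If you want labels that look like the pandas intervals, the boundaries can be formatted accordingly. This is a rough pure-Python sketch, with interval_label as a hypothetical helper name. Unlike the pd.cut() call with right=False, which reports the last interval as half-open with a slightly widened upper edge, this version marks the last interval as closed.

```python
from bisect import bisect_right


def interval_label(value, edges, precision=3):
    """Return a pandas-style label such as '[0.2, 0.393)' for `value`.

    Intervals are left-closed; the last one is also right-closed so the
    maximum gets a label. Values outside the range (or None) give None.
    """
    if value is None or value < edges[0] or value > edges[-1]:
        return None
    last = len(edges) - 2
    # Index of the interval whose lower edge is <= value, clamped for the maximum.
    idx = min(bisect_right(edges, value) - 1, last)
    left = round(edges[idx], precision)
    right = round(edges[idx + 1], precision)
    closing = "]" if idx == last else ")"
    return f"[{left}, {right}{closing}"


edges = [0.2, 0.39333333333333337, 0.5866666666666667, 0.78]
labels = [interval_label(v, edges) for v in [0.2, 0.25, 0.36, 0.55, 0.67, 0.78]]
```

A function like this could be wrapped in a Python UDF and applied to col1, at the cost of being slower than the pure SQL expression.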
