
Pyspark udf doesn't work while Python function works

I have a Python function:

from nltk import ngrams  # ngrams("abc", 2) yields ('a', 'b'), ('b', 'c')

def get_log_probability(string, transition_log_probabilities):
    string = ngrams(string, 2)  # character bigrams
    terms = [transition_log_probabilities[bigram]
                       for bigram in string]
    # average log probability; sum(terms) is 0 when there are no bigrams
    log_probability = sum(terms)/len(terms) if len(terms) > 0 else sum(terms)
    return log_probability
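As a quick sanity check outside Spark, the same logic can be reproduced with plain-Python bigrams; zip(string, string[1:]) here is a stand-in for nltk's ngrams(string, 2), an assumption made only to keep the snippet dependency-free:

```python
# Dependency-free sketch of the function above: zip(string, string[1:])
# produces the same character bigrams as nltk's ngrams(string, 2).
def log_probability_plain(string, transition_log_probabilities):
    bigrams = list(zip(string, string[1:]))
    terms = [transition_log_probabilities[b] for b in bigrams]
    return sum(terms) / len(terms) if terms else 0

probs = {(a, b): -3.688879454113936 for a in "abc" for b in "abc"}
print(log_probability_plain("abc", probs))  # -3.688879454113936
```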

I want to apply this function to a Pyspark DataFrame column, with transition_log_probabilities as a constant, as follows:

transition_log_probabilities = {('a', 'a'): -3.688879454113936,
('a', 'b'): -3.688879454113936,
('a', 'c'): -3.688879454113936,
('b', 'a'): -3.688879454113936,
('b', 'b'): -3.688879454113936,
('b', 'c'): -3.688879454113936,
('c', 'a'): -3.688879454113936,
('c', 'b'): -3.688879454113936,
('c', 'c'): -3.688879454113936}

So I converted it into a Pyspark UDF:

from pyspark.sql.functions import udf
from pyspark.sql.types import FloatType

def get_log_prob_udf(dictionary):
    return udf(lambda string: get_log_probability(string, dictionary), FloatType())

Even though get_log_probability("abc", transition_log_probabilities) works and returns -3.688879454113936, when I apply the UDF to the Pyspark DataFrame like this:

df = df \
.withColumn("string_log_probability", get_log_prob_udf(transition_log_probabilities)(col('string')))

it doesn't work and throws this error:

An error occurred while calling o3463.showString.
: org.apache.spark.SparkException: Job aborted due to stage failure: Task 0 in stage 
182.0 failed 1 times, most recent failure: Lost task 0.0 in stage 182.0 (TID 774) 
(kubernetes.docker.internal executor driver): net.razorvine.pickle.PickleException: 
expected zero arguments for construction of ClassDict (for numpy.dtype)

Does anyone know how to fix it? Thank you very much.
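A note on the error itself: a PickleException mentioning numpy.dtype typically means the UDF returned a NumPy scalar (e.g. numpy.float64) rather than a built-in Python float, which Spark's pickler cannot map onto FloatType. A hedged sketch of the usual fix, assuming the dictionary values were originally produced with NumPy (the zip-based bigrams stand in for nltk's ngrams):

```python
# Sketch of the usual fix for this PickleException: if the dictionary values
# are numpy.float64 (e.g. computed with numpy.log), the UDF hands Spark a
# NumPy scalar, which the pickler rejects for FloatType. Casting the return
# value to a built-in float avoids that. Names mirror the question's code.
def get_log_probability(string, transition_log_probabilities):
    bigrams = list(zip(string, string[1:]))  # stand-in for nltk's ngrams(string, 2)
    terms = [transition_log_probabilities[b] for b in bigrams]
    result = sum(terms) / len(terms) if terms else sum(terms)
    return float(result)  # always a plain Python float, safe to pickle
```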

Hope this is the result you are looking for.

from pyspark.sql import functions as F, types as T
from nltk import ngrams

df = spark.createDataFrame([(1, "bc"), (2, "aa"), (3, "ca")], ["id", "string"])
transition_log_probabilities = {('a', 'a'): -3.688879454113936,
        ('a', 'b'): -3.688879454113936,
        ('a', 'c'): -3.688879454113936,
        ('b', 'a'): -3.688879454113936,
        ('b', 'b'): -3.688879454113936,
        ('b', 'c'): -3.688879454113936,
        ('c', 'a'): -3.688879454113936,
        ('c', 'b'): -3.688879454113936,
        ('c', 'c'): -3.688879454113936}
def get_log_probability(string):
    string = ngrams(string, 2)
    terms = [transition_log_probabilities[bigram]
                       for bigram in string]
    log_probability = sum(terms)/len(terms) if len(terms) > 0 else sum(terms)
    return log_probability


get_log_prob_udf = F.udf(get_log_probability, T.FloatType())

df.withColumn('string_log_probability', get_log_prob_udf(F.col('string'))).show()
+---+------+----------------------+
| id|string|string_log_probability|
+---+------+----------------------+
|  1|    bc|            -3.6888795|
|  2|    aa|            -3.6888795|
|  3|    ca|            -3.6888795|
+---+------+----------------------+
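The answer above captures transition_log_probabilities as a global on the workers. If you prefer to keep the dictionary as a parameter, as in the question's get_log_prob_udf, a closure factory works too. A minimal sketch, assuming the original dictionary may have held NumPy scalars (values are cast to built-in floats once, up front; zip again stands in for nltk's ngrams):

```python
def make_log_prob_fn(dictionary):
    # Cast every value to a built-in float once, so the closure can never
    # return a NumPy scalar even if the dict was built with NumPy.
    clean = {k: float(v) for k, v in dictionary.items()}
    def log_prob(string):
        bigrams = list(zip(string, string[1:]))  # stand-in for ngrams(string, 2)
        terms = [clean[b] for b in bigrams]
        return sum(terms) / len(terms) if terms else 0.0
    return log_prob
```

In Spark this would then be registered as udf(make_log_prob_fn(transition_log_probabilities), T.FloatType()) and applied to F.col('string') exactly as in the answer above.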
