
How to create a custom Estimator in PySpark

I am trying to build a simple custom Estimator in PySpark MLlib. I have read here that it is possible to write a custom Transformer, but I am not sure how to do it for an Estimator. I also don't understand what @keyword_only does and why I need so many setters and getters. Scikit-learn seems to have proper documentation for custom models (see here), but PySpark doesn't.

Pseudocode of an example model:

class NormalDeviation():
    def __init__(self, threshold=3):
        self.threshold = threshold

    def fit(self, x, y=None):
        self.model = {'mean': x.mean(), 'std': x.std()}

    def predict(self, x):
        return (x - self.model['mean']) > self.threshold * self.model['std']

    def decision_function(self, x):  # does ml-lib support this?
        ...

Generally speaking there is no documentation, because as of Spark 1.6 / 2.0 most of the related API is not intended to be public. This should change in Spark 2.1.0 (see SPARK-7146).

The API is relatively complex because it has to follow specific conventions in order to make a given Transformer or Estimator compatible with the Pipeline API. Some of these methods may be required for features like reading and writing or grid search. Others, like keyword_only, are just simple helpers and not strictly required.
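
Regarding keyword_only: it is a small decorator from the pyspark package that forces the decorated method to be called with keyword arguments only and records those arguments on the instance as _input_kwargs, which is why __init__ and setParams below can simply forward them to _set. A simplified sketch of the idea (not the actual implementation, which in newer releases also handles thread safety):

from functools import wraps

def keyword_only_sketch(func):
    # Stand-in illustrating pyspark.keyword_only: reject positional arguments
    # and stash the keyword arguments on the instance so that setParams can
    # pick them up later via self._input_kwargs.
    @wraps(func)
    def wrapper(self, *args, **kwargs):
        if args:
            raise TypeError("Method %s only takes keyword arguments." % func.__name__)
        self._input_kwargs = kwargs
        return func(self, **kwargs)
    return wrapper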

Assuming you have defined the following mix-ins for the mean parameter:

from pyspark.ml.pipeline import Estimator, Model, Pipeline
from pyspark.ml.param.shared import *
from pyspark.sql.functions import avg, stddev_samp


class HasMean(Params):

    mean = Param(Params._dummy(), "mean", "mean", 
        typeConverter=TypeConverters.toFloat)

    def __init__(self):
        super(HasMean, self).__init__()

    def setMean(self, value):
        return self._set(mean=value)

    def getMean(self):
        return self.getOrDefault(self.mean)

the standard deviation parameter:

class HasStandardDeviation(Params):

    standardDeviation = Param(Params._dummy(),
        "standardDeviation", "standardDeviation", 
        typeConverter=TypeConverters.toFloat)

    def __init__(self):
        super(HasStandardDeviation, self).__init__()

    def setStddev(self, value):
        return self._set(standardDeviation=value)

    def getStddev(self):
        return self.getOrDefault(self.standardDeviation)

and the threshold:

class HasCenteredThreshold(Params):

    centeredThreshold = Param(Params._dummy(),
            "centeredThreshold", "centeredThreshold",
            typeConverter=TypeConverters.toFloat)

    def __init__(self):
        super(HasCenteredThreshold, self).__init__()

    def setCenteredThreshold(self, value):
        return self._set(centeredThreshold=value)

    def getCenteredThreshold(self):
        return self.getOrDefault(self.centeredThreshold)
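
As a quick sanity check (a sketch that only assumes the mix-ins defined above), any class that combines them inherits the Param declarations together with the getters and setters, which is exactly what the Pipeline machinery and grid search rely on when they inspect or copy parameters:

class ParamsDemo(HasMean, HasStandardDeviation, HasCenteredThreshold):
    pass

demo = ParamsDemo()
demo.setMean(0.0).setStddev(1.0).setCenteredThreshold(3.0)
print(demo.explainParams())  # one line per declared Param with its current value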

you could create a basic Estimator as follows:

from pyspark.ml.util import DefaultParamsReadable, DefaultParamsWritable 
from pyspark import keyword_only  

class NormalDeviation(Estimator, HasInputCol, 
        HasPredictionCol, HasCenteredThreshold,
        DefaultParamsReadable, DefaultParamsWritable):

    @keyword_only
    def __init__(self, inputCol=None, predictionCol=None, centeredThreshold=1.0):
        super(NormalDeviation, self).__init__()
        kwargs = self._input_kwargs
        self.setParams(**kwargs)

    # Required in Spark >= 3.0
    def setInputCol(self, value):
        """
        Sets the value of :py:attr:`inputCol`.
        """
        return self._set(inputCol=value)

    # Required in Spark >= 3.0
    def setPredictionCol(self, value):
        """
        Sets the value of :py:attr:`predictionCol`.
        """
        return self._set(predictionCol=value)

    @keyword_only
    def setParams(self, inputCol=None, predictionCol=None, centeredThreshold=1.0):
        kwargs = self._input_kwargs
        return self._set(**kwargs)        
        
    def _fit(self, dataset):
        c = self.getInputCol()
        mu, sigma = dataset.agg(avg(c), stddev_samp(c)).first()
        return NormalDeviationModel(
            inputCol=c, mean=mu, standardDeviation=sigma, 
            centeredThreshold=self.getCenteredThreshold(),
            predictionCol=self.getPredictionCol())


class NormalDeviationModel(Model, HasInputCol, HasPredictionCol,
        HasMean, HasStandardDeviation, HasCenteredThreshold,
        DefaultParamsReadable, DefaultParamsWritable):

    @keyword_only
    def __init__(self, inputCol=None, predictionCol=None,
                mean=None, standardDeviation=None,
                centeredThreshold=None):
        super(NormalDeviationModel, self).__init__()
        kwargs = self._input_kwargs
        self.setParams(**kwargs)  

    @keyword_only
    def setParams(self, inputCol=None, predictionCol=None,
                mean=None, standardDeviation=None,
                centeredThreshold=None):
        kwargs = self._input_kwargs
        return self._set(**kwargs)           

    def _transform(self, dataset):
        x = self.getInputCol()
        y = self.getPredictionCol()
        threshold = self.getCenteredThreshold()
        mu = self.getMean()
        sigma = self.getStddev()

        return dataset.withColumn(y, (dataset[x] - mu) > threshold * sigma)    

Credits to Benjamin-Manns for the use of DefaultParamsReadable, DefaultParamsWritable, available in PySpark >= 2.3.0.

Finally, it could be used as follows:

df = sc.parallelize([(1, 2.0), (2, 3.0), (3, 0.0), (4, 99.0)]).toDF(["id", "x"])

normal_deviation = NormalDeviation().setInputCol("x").setCenteredThreshold(1.0)
model  = Pipeline(stages=[normal_deviation]).fit(df)

model.transform(df).show()
## +---+----+----------+
## | id|   x|prediction|
## +---+----+----------+
## |  1| 2.0|     false|
## |  2| 3.0|     false|
## |  3| 0.0|     false|
## |  4|99.0|      true|
## +---+----+----------+
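
Because both classes mix in DefaultParamsReadable and DefaultParamsWritable, the fitted pipeline can also be persisted and reloaded (a sketch assuming PySpark >= 2.3.0; the path below is just a placeholder, and the custom classes must be importable in the session that loads the model):

from pyspark.ml import PipelineModel

model.save("/tmp/normal_deviation_model")  # placeholder path
reloaded = PipelineModel.load("/tmp/normal_deviation_model")
reloaded.transform(df).show()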

I disagree with @Shteingarts' solution, as he creates members at class level and even mixes them with instance-level ones. This will lead to issues if you create several HasMean instances. Why not use the (imho correct) approach with instance variables? The same holds for the other code samples.

from pyspark.ml.pipeline import Estimator, Model, Pipeline
from pyspark.ml.param.shared import *
from pyspark.sql.functions import avg, stddev_samp


class HasMean(Params):
    def __init__(self):
        super(HasMean, self).__init__()
        self.mean = Param(self, "mean", "mean", typeConverter=TypeConverters.toFloat)

    def setMean(self, value):
        return self.set(self.mean, value)

    def getMean(self):
        return self.getOrDefault(self.mean)
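
A quick check of this instance-variable variant (a sketch using only the HasMean class above): each instance owns its own Param object, so values set on one instance do not leak into another:

a = HasMean()
b = HasMean()
a.setMean(1.0)
b.setMean(2.0)
print(a.getMean(), b.getMean())  # 1.0 2.0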
