How to create a custom Transformer from a UDF?

I was trying to create and save a Pipeline with custom stages. I need to add a column to my DataFrame by using a UDF, so I was wondering whether it is possible to convert a UDF, or a similar action, into a Transformer.

My custom UDF looks like this, and I'd like to learn how to turn it into a custom Transformer:

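def getFeatures(n: String): Seq[String] = {
    val NUMBER_FEATURES = 4
    // Take the first whitespace-separated word, lower-cased
    val name = n.split(" +")(0).toLowerCase
    // Return its suffixes of length 1..NUMBER_FEATURES (only those that fit)
    ((1 to NUMBER_FEATURES)
         .filter(size => size <= name.length)
         .map(size => name.substring(name.length - size)))
}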

val tokenizeUDF = sqlContext.udf.register("tokenize", (name: String) => getFeatures(name))

It is not a fully featured solution, but you can start with something like this:

import org.apache.spark.ml.{UnaryTransformer}
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.sql.types.{ArrayType, DataType, StringType}

class NGramTokenizer(override val uid: String)
  extends UnaryTransformer[String, Seq[String], NGramTokenizer]  {

  def this() = this(Identifiable.randomUID("ngramtokenizer"))

  override protected def createTransformFunc: String => Seq[String] = {
    getFeatures _
  }

  override protected def validateInputType(inputType: DataType): Unit = {
    require(inputType == StringType)
  }

  override protected def outputDataType: DataType = {
    new ArrayType(StringType, true)
  }
}

Quick check:

import spark.implicits._  // needed for toDF; already imported in spark-shell
val df = Seq((1L, "abcdef"), (2L, "foobar")).toDF("k", "v")
val transformer = new NGramTokenizer().setInputCol("v").setOutputCol("vs")

transformer.transform(df).show
// +---+------+------------------+
// |  k|     v|                vs|
// +---+------+------------------+
// |  1|abcdef|[f, ef, def, cdef]|
// |  2|foobar|[r, ar, bar, obar]|
// +---+------+------------------+
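Since the question's end goal is to save a Pipeline, here is a hedged sketch of the same tokenizer made persistable. It assumes Spark 2.x or later, where `DefaultParamsWritable` and `DefaultParamsReadable` are public, and relies on the stage's only state being its `inputCol`/`outputCol` params. The names `WritableNGramTokenizer`, `SuffixFeatures` and the path `/tmp/ngram-pipeline` are placeholders. Compile it into a JAR rather than pasting it into spark-shell, because loading looks the class and its companion object up by name.

```scala
import org.apache.spark.ml.{Pipeline, UnaryTransformer}
import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable}
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.types.{ArrayType, DataType, StringType}

object SuffixFeatures {
  // Same logic as getFeatures in the question: up to 4 suffixes of the first word.
  def getFeatures(n: String): Seq[String] = {
    val name = n.split(" +")(0).toLowerCase
    (1 to 4).filter(_ <= name.length).map(size => name.substring(name.length - size))
  }
}

class WritableNGramTokenizer(override val uid: String)
    extends UnaryTransformer[String, Seq[String], WritableNGramTokenizer]
    with DefaultParamsWritable {

  // The (uid: String) constructor is what defaultCopy and the reader call reflectively.
  def this() = this(Identifiable.randomUID("ngramtokenizer"))

  override protected def createTransformFunc: String => Seq[String] =
    SuffixFeatures.getFeatures

  override protected def validateInputType(inputType: DataType): Unit =
    require(inputType == StringType, s"Expected StringType but got $inputType")

  override protected def outputDataType: DataType = ArrayType(StringType, containsNull = true)
}

// The companion supplies read/load, which Pipeline.load uses to rebuild the stage.
object WritableNGramTokenizer extends DefaultParamsReadable[WritableNGramTokenizer]

object SavePipelineExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder.master("local[*]").appName("ngram-pipeline").getOrCreate()
    import spark.implicits._

    val df = Seq((1L, "abcdef"), (2L, "foobar")).toDF("k", "v")
    val pipeline = new Pipeline().setStages(Array(
      new WritableNGramTokenizer().setInputCol("v").setOutputCol("vs")))

    // Placeholder path; overwrite() replaces whatever is already there.
    pipeline.write.overwrite().save("/tmp/ngram-pipeline")
    val restored = Pipeline.load("/tmp/ngram-pipeline")
    restored.fit(df).transform(df).show()

    spark.stop()
  }
}
```

Nothing else has to be written by hand: the default writer stores the uid and param values, and the default reader re-creates the stage through its `String` constructor.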

You can even try to generalize it to something like this:

import org.apache.spark.sql.catalyst.ScalaReflection.schemaFor
import scala.reflect.runtime.universe._

class UnaryUDFTransformer[T : TypeTag, U : TypeTag](
  override val uid: String,
  f: T => U
) extends UnaryTransformer[T, U, UnaryUDFTransformer[T, U]]  {

  override protected def createTransformFunc: T => U = f

  override protected def validateInputType(inputType: DataType): Unit = 
    require(inputType == schemaFor[T].dataType)

  override protected def outputDataType: DataType = schemaFor[U].dataType
}

val transformer = new UnaryUDFTransformer("featurize", getFeatures)
  .setInputCol("v")
  .setOutputCol("vs")
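One caveat with the generic wrapper, as far as I can tell: `UnaryTransformer.copy` falls back to `defaultCopy`, which reflectively calls a constructor taking only a `String` uid. `UnaryUDFTransformer` has no such constructor, so anything that copies the stage (for example `Pipeline.copy`) would fail. A minimal sketch of a workaround is to override `copy` and rebuild the instance with the same function. The `strlen` usage is only an illustration, and the class still cannot be reloaded with `DefaultParamsReadable`, because the wrapped function is not a param.

```scala
import org.apache.spark.ml.UnaryTransformer
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.sql.catalyst.ScalaReflection.schemaFor
import org.apache.spark.sql.types.DataType
import scala.reflect.runtime.universe.TypeTag

class UnaryUDFTransformer[T: TypeTag, U: TypeTag](
  override val uid: String,
  f: T => U
) extends UnaryTransformer[T, U, UnaryUDFTransformer[T, U]] {

  override protected def createTransformFunc: T => U = f

  override protected def validateInputType(inputType: DataType): Unit =
    require(inputType == schemaFor[T].dataType, s"Unexpected input type $inputType")

  override protected def outputDataType: DataType = schemaFor[U].dataType

  // defaultCopy needs a (uid: String) constructor, which this class lacks,
  // so build the copy explicitly and transfer the param values onto it.
  override def copy(extra: ParamMap): UnaryUDFTransformer[T, U] =
    copyValues(new UnaryUDFTransformer[T, U](uid, f), extra)
}

object UnaryUDFTransformerExample {
  def main(args: Array[String]): Unit = {
    val strlen = new UnaryUDFTransformer[String, Int]("strlen", (s: String) => s.length)
      .setInputCol("v")
      .setOutputCol("len")
    // Works thanks to the override; defaultCopy would throw NoSuchMethodException here.
    val copied = strlen.copy(ParamMap.empty)
    println(copied.getOutputCol)
  }
}
```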

If you want to use a UDF rather than the wrapped function, you'll have to extend Transformer directly and override the transform method. Unfortunately, the majority of the useful classes are private, so it can be rather tricky.
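A minimal sketch of that route, assuming the Spark 2.x+ `Dataset` API: the class extends `Transformer` directly, declares its own `inputCol`/`outputCol` params (Spark's `HasInputCol`/`HasOutputCol` are `private[ml]`), and applies a `UserDefinedFunction` inside `transform`. `SuffixFeatureTransformer` is a hypothetical name, and the UDF body is the question's `getFeatures` logic.

```scala
import org.apache.spark.ml.Transformer
import org.apache.spark.ml.param.{Param, ParamMap}
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.functions.{col, udf}
import org.apache.spark.sql.types.{ArrayType, StringType, StructField, StructType}

class SuffixFeatureTransformer(override val uid: String) extends Transformer {

  def this() = this(Identifiable.randomUID("suffixFeatures"))

  // Own params, since Spark's HasInputCol / HasOutputCol traits are private[ml].
  final val inputCol = new Param[String](this, "inputCol", "input column name")
  final val outputCol = new Param[String](this, "outputCol", "output column name")

  def setInputCol(value: String): this.type = set(inputCol, value)
  def setOutputCol(value: String): this.type = set(outputCol, value)

  // The UDF applied in transform; the actual logic lives in the companion object.
  private val featuresUdf = udf((n: String) => SuffixFeatureTransformer.getFeatures(n))

  override def transform(dataset: Dataset[_]): DataFrame = {
    transformSchema(dataset.schema) // fail fast on a bad input column
    dataset.withColumn($(outputCol), featuresUdf(col($(inputCol))))
  }

  override def transformSchema(schema: StructType): StructType = {
    val inputType = schema($(inputCol)).dataType
    require(inputType == StringType, s"Column ${$(inputCol)} must be StringType, got $inputType")
    require(!schema.fieldNames.contains($(outputCol)), s"Column ${$(outputCol)} already exists")
    schema.add(StructField($(outputCol), ArrayType(StringType, containsNull = true)))
  }

  // A (uid: String) constructor exists, so the reflective defaultCopy works.
  override def copy(extra: ParamMap): SuffixFeatureTransformer = defaultCopy(extra)
}

object SuffixFeatureTransformer {
  // The question's getFeatures: up to 4 suffixes of the lower-cased first word.
  def getFeatures(n: String): Seq[String] = {
    val name = n.split(" +")(0).toLowerCase
    (1 to 4).filter(_ <= name.length).map(size => name.substring(name.length - size))
  }
}
```

Keeping the function in the companion object means the UDF closure does not capture the transformer instance.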

Alternatively, you can register the UDF:

spark.udf.register("getFeatures", getFeatures _)

and use SQLTransformer:

import org.apache.spark.ml.feature.SQLTransformer

val transformer = new SQLTransformer()
  .setStatement("SELECT *, getFeatures(v) AS vs FROM __THIS__")

transformer.transform(df).show
// +---+------+------------------+
// |  k|     v|                vs|
// +---+------+------------------+
// |  1|abcdef|[f, ef, def, cdef]|
// |  2|foobar|[r, ar, bar, obar]|
// +---+------+------------------+
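Worth knowing if the point is to save the Pipeline: `SQLTransformer` is itself writable, but what it persists is only the SQL statement. The UDF registration lives in the session, so the sketch below (assuming Spark 2.x+, with a placeholder path and a hypothetical `SqlTransformerRoundTrip` object) registers `getFeatures` before using the reloaded stage. In a fresh application it would have to be registered again the same way.

```scala
import org.apache.spark.ml.feature.SQLTransformer
import org.apache.spark.sql.SparkSession

object SqlTransformerRoundTrip {
  // The question's getFeatures, annotated to return Seq[String].
  def getFeatures(n: String): Seq[String] = {
    val name = n.split(" +")(0).toLowerCase
    (1 to 4).filter(_ <= name.length).map(size => name.substring(name.length - size))
  }

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder.master("local[*]").appName("sql-transformer").getOrCreate()
    import spark.implicits._

    val path = "/tmp/getfeatures-sql-transformer" // placeholder path
    new SQLTransformer()
      .setStatement("SELECT *, getFeatures(v) AS vs FROM __THIS__")
      .write.overwrite().save(path)

    // Only the statement was saved: the UDF must be registered in every session
    // before the loaded transformer can resolve getFeatures(v).
    spark.udf.register("getFeatures", getFeatures _)
    val loaded = SQLTransformer.load(path)
    loaded.transform(Seq((1L, "abcdef"), (2L, "foobar")).toDF("k", "v")).show()

    spark.stop()
  }
}
```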

I initially tried to extend the Transformer and UnaryTransformer abstract classes but ran into trouble with my application being unable to reach DefaultParamsWritable. As an example that may be relevant to your problem, I created a simple term normalizer as a UDF, following along from this example. My goal is to match terms against patterns and sets and replace them with generic terms. For example:

"""\b[A-Z0-9._%+-]+@[A-Z0-9.-]+\.[A-Z]{2,}\b""".r -> "emailaddr"

This is the class:

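import scala.util.matching.Regex

// Serializable because the UDF closure below captures an instance of this class
class TermNormalizer(normMap: Map[Any, String]) extends Serializable {
  val normalizationMap = normMap

  // Replaces each term that matches a regex key (or is contained in a set key)
  // with the mapped generic term; if several keys match, the last one in the
  // map's iteration order wins.
  def normalizeTerms(terms: Seq[String]): Seq[String] = {
    var termsUpdated = terms
    for ((term, idx) <- terms.zipWithIndex) {
      for (normalizer <- normalizationMap.keys) {
        normalizer match {
          case regex: Regex =>
            if (regex.findFirstIn(term).isDefined)
              termsUpdated = termsUpdated.updated(idx, normalizationMap(regex))
          case set: Set[String @unchecked] =>
            if (set.contains(term))
              termsUpdated = termsUpdated.updated(idx, normalizationMap(set))
          case _ => // ignore keys of any other type instead of throwing a MatchError
        }
      }
    }
    termsUpdated
  }
}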

I use it like this:

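import org.apache.spark.ml.feature.Tokenizer
import org.apache.spark.sql.functions.udf
import sqlContext.implicits._  // for $"col"; already imported in spark-shell

val testMap: Map[Any, String] = Map("hadoop".r -> "elephant",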
  "spark".r -> "sparky", "cool".r -> "neat", 
  Set("123", "456") -> "set1",
  Set("789", "10") -> "set2")

val testTermNormalizer = new TermNormalizer(testMap)
val termNormalizerUdf = udf(testTermNormalizer.normalizeTerms(_: Seq[String]))

val trainingTest = sqlContext.createDataFrame(Seq(
  (0L, "spark is cool 123", 1.0),
  (1L, "adsjkfadfk akjdsfhad 456", 0.0),
  (2L, "spark rocks my socks 789 10", 1.0),
  (3L, "hadoop is cool 10", 0.0)
)).toDF("id", "text", "label")

val testTokenizer = new Tokenizer()
  .setInputCol("text")
  .setOutputCol("words")

val tokenizedTrainingTest = testTokenizer.transform(trainingTest)
// show() prints the table itself and returns Unit, so it is not wrapped in println
tokenizedTrainingTest
  .select($"id", $"text", $"words", termNormalizerUdf($"words"), $"label")
  .show(false)

Now that I read the question a little more closely, it sounds like you're asking how to avoid doing it this way, lol. Anyway, I'll still post it in case someone in the future is looking for an easy way to apply transformer-like functionality.

If you also want the transformer to be writable, you can re-implement traits such as HasInputCol from the sharedParams library in a public package of your choice and then combine them with the DefaultParamsWritable trait to make the transformer persistable.

This way you also avoid having to place part of your code inside Spark's own ml packages, at the cost of maintaining a parallel set of params in your own package. This isn't really a problem, given that they hardly ever change.

But do track the issue on their JIRA board here that asks for some of the common sharedParams to be made public instead of private[ml], so that people can use them directly from outside classes.
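A sketch of that approach, assuming Spark 2.x+ (where `DefaultParamsWritable` and `DefaultParamsReadable` are public). The package `com.example.ml`, the traits `MyHasInputCol`/`MyHasOutputCol` and the `SuffixTokenizer` stage are hypothetical names; the traits simply mirror what Spark's `private[ml]` shared params do.

```scala
package com.example.ml // hypothetical package

import org.apache.spark.ml.Transformer
import org.apache.spark.ml.param.{Param, ParamMap, Params}
import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable}
import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.functions.{col, udf}
import org.apache.spark.sql.types.{ArrayType, StringType, StructField, StructType}

// Public stand-ins for Spark's private[ml] HasInputCol / HasOutputCol.
trait MyHasInputCol extends Params {
  final val inputCol: Param[String] = new Param[String](this, "inputCol", "input column name")
  final def getInputCol: String = $(inputCol)
}

trait MyHasOutputCol extends Params {
  final val outputCol: Param[String] = new Param[String](this, "outputCol", "output column name")
  final def getOutputCol: String = $(outputCol)
}

class SuffixTokenizer(override val uid: String)
    extends Transformer with MyHasInputCol with MyHasOutputCol with DefaultParamsWritable {

  def this() = this(Identifiable.randomUID("suffixTokenizer"))

  def setInputCol(value: String): this.type = set(inputCol, value)
  def setOutputCol(value: String): this.type = set(outputCol, value)

  override def transform(dataset: Dataset[_]): DataFrame = {
    transformSchema(dataset.schema)
    dataset.withColumn($(outputCol), SuffixTokenizer.featuresUdf(col($(inputCol))))
  }

  override def transformSchema(schema: StructType): StructType = {
    require(schema($(inputCol)).dataType == StringType, s"${$(inputCol)} must be a string column")
    schema.add(StructField($(outputCol), ArrayType(StringType, containsNull = true)))
  }

  override def copy(extra: ParamMap): SuffixTokenizer = defaultCopy(extra)
}

// Only params are persisted; DefaultParamsReadable rebuilds the stage from its uid and params.
object SuffixTokenizer extends DefaultParamsReadable[SuffixTokenizer] {
  // Same logic as the question's getFeatures.
  private val featuresUdf = udf { (n: String) =>
    val name = n.split(" +")(0).toLowerCase
    (1 to 4).filter(_ <= name.length).map(size => name.substring(name.length - size))
  }
}
```

Because the stage and its param traits live in your own package, this compiles without touching `org.apache.spark.ml`, and `SuffixTokenizer.load(path)` or `Pipeline.load(path)` can restore it.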

The technical post webpages of this site follow the CC BY-SA 4.0 protocol. If you need to reprint, please indicate the site URL or the original address.Any question please contact:yoyou2525@163.com.

 