
User-defined aggregation functions in Spark


Below, we will discuss user-defined aggregate functions (UDAFs) built with org.apache.spark.sql.expressions.Aggregator, which can aggregate a group of elements in a Dataset into a single value in any user-defined way.

Let’s start by examining an example from the official documentation (with imports added) that implements a simple aggregation:

import org.apache.spark.sql.{Dataset, Encoder, Encoders}
import org.apache.spark.sql.expressions.Aggregator

case class Data(i: Int)

val customSummer = new Aggregator[Data, Int, Int] {
  def zero: Int = 0
  def reduce(b: Int, a: Data): Int = b + a.i
  def merge(b1: Int, b2: Int): Int = b1 + b2
  def finish(r: Int): Int = r
  def bufferEncoder: Encoder[Int] = Encoders.scalaInt
  def outputEncoder: Encoder[Int] = Encoders.scalaInt
}.toColumn

val ds: Dataset[Data] = ...
val aggregated = ds.select(customSummer)

The case class Data is needed because this particular aggregator operates on a typed Dataset (here, Dataset[Data]); as we will see later, an Aggregator can also take Row as its input and work with DataFrames.
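To see it in action, here is a minimal sketch of running this aggregator against a small in-memory Dataset (assuming an active SparkSession named spark; the sample values are made up for illustration):

import spark.implicits._ // provides the encoder needed by .toDS()

val sample: Dataset[Data] = Seq(Data(1), Data(2), Data(3)).toDS()

// Selecting a TypedColumn yields a Dataset[Int] with a single row: 1 + 2 + 3 = 6
sample.select(customSummer).show()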

To create your own aggregator, you need to define six members with predefined names:

  • zero — the initial value (an identity element); it should satisfy the requirement: “something” + zero = “something”

  • reduce — the function that folds an input element into the aggregation buffer

  • merge — the function for merging two buffers

  • finish — the final processing that turns the buffer into the target value; in this particular example it is essentially absent, but it is often useful (as we’ll see in the examples below)

  • bufferEncoder and outputEncoder — the encoders for the buffer and the output value

Why so many? Let’s dive a little deeper. Spark is a distributed data processing framework and performs an aggregation in the following steps:

  • First step: On the executors, a pre-aggregation is performed with the reduce function to minimize the volume of data to be shuffled.

  • Second step: Spark shuffles the data, moving rows with the same aggregation key to the same executor.

  • Third step: The executors finalize the aggregation by combining the partial buffers with the merge function.

  • Fourth step: Final processing of the buffer with the finish function.

Therefore, even if you don’t need a certain step, you still have to define all of these functions. On the other hand, you are unlikely to use a UDAF for a simple summation, and for complex aggregations such staging can be very useful. It is also important to note that the aggregation is essentially performed in two phases: first partially on the executors, and then, after the data is shuffled, the partial results are combined into the final value.
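To make these stages tangible, here is a tiny local sketch in plain Scala (no Spark involved) that mimics the pipeline for the customSummer example above: reduce runs inside each “partition”, merge combines the partial buffers, and finish post-processes the result. The split into two partitions is invented purely for illustration:

val data = Seq(Data(1), Data(2), Data(3), Data(4))

// Pretend the data lives on two executors (two "partitions").
val partitions: Seq[Seq[Data]] = Seq(data.take(2), data.drop(2))

// Step 1: per-partition pre-aggregation, folding each element into the buffer (zero + reduce).
val partialBuffers: Seq[Int] = partitions.map(_.foldLeft(0)((b, a) => b + a.i))

// Steps 2-3: after the shuffle, the partial buffers meet and are merged.
val mergedBuffer: Int = partialBuffers.reduce((b1, b2) => b1 + b2)

// Step 4: final processing (an identity function here, just like customSummer.finish).
val result: Int = mergedBuffer // 10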

Let’s consider an educational ‘word count’ example implemented using aggregators. Suppose we have a simple CSV file with user tweets:

userId,tweet
f6e8252f,cat dog frog cat dog frog frog dog dog bird
f6e8252f,cat dog frog dog dog bird
f6e8252f,cat dog
29d89ee4,frog frog dog bird
29d89ee4,frog cat dog frog frog dog dog bird
29d89ee4,frog bird

Let’s determine the most frequently used word for each user. Naturally, there are many ways to solve this task, but for educational purposes, we will demonstrate how to solve it using a UDAF.

import org.apache.spark.sql.{Encoder, Encoders, Row}
import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.functions.col

val df = spark.read.format("csv").option("header", "true").load(path)

type FrequencyMapType = Map[String, Int]

class MyCustomAggregator(columnName: String) extends Aggregator[Row, FrequencyMapType, String] {

  // Sum the per-word counts of two frequency maps.
  private def combineMaps(map1: FrequencyMapType, map2: FrequencyMapType): FrequencyMapType = {
    val combinedKeys = map1.keySet ++ map2.keySet
    combinedKeys.map { key =>
      key -> (map1.getOrElse(key, 0) + map2.getOrElse(key, 0))
    }.toMap
  }

  // Initial value: an empty frequency map.
  def zero: FrequencyMapType = Map.empty[String, Int]

  // Pre-aggregation on the executors: count the words of one row and fold them into the buffer.
  def reduce(buffer: FrequencyMapType, row: Row): FrequencyMapType = {
    val wordsMap = row.getAs[String](columnName)
      .split(" ")
      .groupBy(identity)
      .map { case (word, occurrences) => word -> occurrences.length }
    combineMaps(buffer, wordsMap)
  }

  // Merging of the partial buffers after the shuffle.
  def merge(map1: FrequencyMapType, map2: FrequencyMapType): FrequencyMapType =
    combineMaps(map1, map2)

  // Final processing: pick the most frequent word.
  def finish(buffer: FrequencyMapType): String =
    if (buffer.isEmpty) "null" else buffer.maxBy(_._2)._1

  def bufferEncoder: Encoder[FrequencyMapType] = Encoders.kryo[Map[String, Int]]

  def outputEncoder: Encoder[String] = Encoders.STRING
}

val mostPopularWordAggregator = new MyCustomAggregator("tweet").toColumn

df.groupBy(col("userID"))
  .agg(mostPopularWordAggregator.name("favoriteWord"))
  .withColumnRenamed("value", "userId")
  .show()

The result:

+--------+---------------+
|userId  |favoriteWord   |
+--------+---------------+
|29d89ee4|frog           |
|f6e8252f|dog            |
+--------+---------------+

Note the specification of the types <IN, BUF, OUT>:

  • The input type of our aggregator is Row (this is also what makes the typed groupByKey variant sketched after this list possible).

  • For the buffer, we use the type Map[String, Int] (aliased as FrequencyMapType).

  • Finally, the output type is simply String, since the most popular word is returned.
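Because the input type is Row, the same aggregator can also be plugged into the typed groupByKey API instead of the untyped groupBy. A minimal sketch, assuming the df and mostPopularWordAggregator defined above (note that, depending on the Spark version, the grouping-key column comes out named key or value, so it is renamed generically here):

import spark.implicits._ // provides the implicit Encoder[String] for the grouping key

val typedResult = df
  .groupByKey(row => row.getAs[String]("userId"))
  .agg(mostPopularWordAggregator.name("favoriteWord"))

// Rename the auto-generated key column without hard-coding its name.
typedResult
  .withColumnRenamed(typedResult.columns.head, "userId")
  .show()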

Astute readers may have noticed that, in essence, we adapted the ‘most common value’ aggregator for our needs. Could we achieve the same result without a UDAF? Of course! One of the naive variants is:

  1. Get an ArrayType column using org.apache.spark.sql.functions.split

  2. Create a new row for each element in the array using org.apache.spark.sql.functions.explode

  3. Perform standard aggregation over userId + word.

  4. For each userId, take the word with the maximum count.

In code, it could look like this:

import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._

val window = Window.partitionBy("userId")

df.withColumn("arrayOfWords", split(col("tweet"), " "))
  .withColumn("word", explode(col("arrayOfWords")))
  .groupBy("userId", "word")
  .agg(count("word").as("count"))
  .withColumn("maxCount", max("count").over(window))
  .filter(col("count") === col("maxCount"))
  .select("userId", "word")
  .show()

And now let’s compare the physical plans. With the UDAF it’s very simple:

== Physical Plan ==
AdaptiveSparkPlan isFinalPlan=false
+- ObjectHashAggregate(keys=[userID#16], functions=[mycustomaggregator($line15.$read$$iw$$iw$$iw$$iw$MyCustomAggregator@69cf253c, Some(createexternalrow(userId#16.toString, tweet#17.toString, StructField(userId,StringType,true), StructField(tweet,StringType,true))), Some(interface org.apache.spark.sql.Row), Some(StructType(StructField(userId,StringType,true), StructField(tweet,StringType,true))), encodeusingserializer(input[0, java.lang.Object, true], true), decodeusingserializer(input[0, binary, true], scala.collection.immutable.Map, true), staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, input[0, java.lang.String, true], true, false), StringType, true, 0, 0)])
   +- Exchange hashpartitioning(userID#16, 200), ENSURE_REQUIREMENTS, [id=#30]
      +- ObjectHashAggregate(keys=[userID#16], functions=[partial_mycustomaggregator($line15.$read$$iw$$iw$$iw$$iw$MyCustomAggregator@69cf253c, Some(createexternalrow(userId#16.toString, tweet#17.toString, StructField(userId,StringType,true), StructField(tweet,StringType,true))), Some(interface org.apache.spark.sql.Row), Some(StructType(StructField(userId,StringType,true), StructField(tweet,StringType,true))), encodeusingserializer(input[0, java.lang.Object, true], true), decodeusingserializer(input[0, binary, true], scala.collection.immutable.Map, true), staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, input[0, java.lang.String, true], true, false), StringType, true, 0, 0)])
         +- FileScan csv [userId#16,tweet#17] Batched: false, DataFilters: [], Format: CSV, Location: InMemoryFileIndex(1 paths)[file:/Users/a616287/Repos/SBX/files/tweets.csv], PartitionFilters: [], PushedFilters: [], ReadSchema: struct<userId:string,tweet:string>

And here is the plan when using only the standard org.apache.spark.sql.functions._:

== Physical Plan ==
AdaptiveSparkPlan isFinalPlan=false
+- Project [userId#16, word#41]
   +- Filter (isnotnull(maxCount#57L) AND (count#52L = maxCount#57L))
      +- Window [max(count#52L) windowspecdefinition(userId#16, specifiedwindowframe(RowFrame, unboundedpreceding$(), unboundedfollowing$())) AS maxCount#57L], [userId#16]
         +- Sort [userId#16 ASC NULLS FIRST], false, 0
            +- Exchange hashpartitioning(userId#16, 200), ENSURE_REQUIREMENTS, [id=#66]
               +- HashAggregate(keys=[userId#16, word#41], functions=[count(word#41)])
                  +- Exchange hashpartitioning(userId#16, word#41, 200), ENSURE_REQUIREMENTS, [id=#63]
                     +- HashAggregate(keys=[userId#16, word#41], functions=[partial_count(word#41)])
                        +- Generate explode(arrayOfWords#36), [userId#16], false, [word#41]
                           +- Project [userId#16, split(tweet#17,  , -1) AS arrayOfWords#36]
                              +- Filter ((size(split(tweet#17,  , -1), true) > 0) AND isnotnull(split(tweet#17,  , -1)))
                                 +- FileScan csv [userId#16,tweet#17] Batched: false, DataFilters: [(size(split(tweet#17,  , -1), true) > 0), isnotnull(split(tweet#17,  , -1))], Format: CSV, Location: InMemoryFileIndex(1 paths)[file:/Users/a616287/Repos/SBX/files/tweets.csv], PartitionFilters: [], PushedFilters: [], ReadSchema: struct<userId:string,tweet:string>

I guess the conclusion regarding these two plans is obvious: the UDAF version needs only a single shuffle, with partial aggregation done on the executors beforehand, while the version built from standard functions requires two shuffles plus a sort and a window.

So, org.apache.spark.sql.expressions.Aggregator is a powerful abstraction in Spark SQL for creating custom aggregation functions that work with the DataFrame API or SQL queries. You should consider using an Aggregator when:

  1. Custom Aggregation Logic: You need to perform custom aggregation operations that are not provided by the built-in Spark SQL functions. Aggregator allows you to define your own aggregation logic.

  2. Performance Optimization: Using Aggregator can provide performance benefits over attempting to perform aggregations using a combination of standard functions from org.apache.spark.sql.functions._. This is particularly advantageous compared to relying solely on UDFs or combinations of standard functions, as Aggregator leverages Spark's Catalyst optimizer to generate optimized physical execution plans.

  3. Code Reusability: By encapsulating custom aggregation logic within an Aggregator, you can reuse the same logic across different parts of your codebase, making your code more modular and easier to maintain.

  4. Integration with DataFrame API and SQL: Aggregator seamlessly integrates with the DataFrame API and can also be registered for use in SQL queries, allowing you to leverage the expressive power of SQL while benefiting from the performance optimizations provided by Aggregator. (Not covered in detail in this article because I prefer the DataFrame API :-), but a minimal registration sketch follows this list.)
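As a hedged illustration of point 4: an Aggregator can be registered for untyped DataFrame and SQL use via org.apache.spark.sql.functions.udaf (available in Spark 3.x). The sketch below uses a deliberately simple aggregator with a primitive input type; the payments table and amount column are hypothetical names used only for illustration:

import org.apache.spark.sql.{Encoder, Encoders}
import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.functions.udaf

// A trivial aggregator with a primitive input type, just to show the registration path.
val longSum = new Aggregator[Long, Long, Long] {
  def zero: Long = 0L
  def reduce(b: Long, a: Long): Long = b + a
  def merge(b1: Long, b2: Long): Long = b1 + b2
  def finish(r: Long): Long = r
  def bufferEncoder: Encoder[Long] = Encoders.scalaLong
  def outputEncoder: Encoder[Long] = Encoders.scalaLong
}

// Register it under a SQL-callable name and use it in a query
// ("payments" and "amount" are hypothetical).
spark.udf.register("long_sum", udaf(longSum))
spark.sql("SELECT userId, long_sum(amount) AS total FROM payments GROUP BY userId").show()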

Overall, Aggregator is a powerful tool for building custom aggregation functions in Spark SQL, especially when you need to perform complex aggregations efficiently.
