Below, we will discuss user-defined aggregate functions (UDAFs) built with org.apache.spark.sql.expressions.Aggregator, which let you aggregate the elements of a Dataset group into a single value in any user-defined way.
Let’s start by examining an example from the official documentation that implements a simple aggregation:
import org.apache.spark.sql.{Dataset, Encoder, Encoders}
import org.apache.spark.sql.expressions.Aggregator

case class Data(i: Int)

val customSummer = new Aggregator[Data, Int, Int] {
  def zero: Int = 0                          // neutral initial value of the buffer
  def reduce(b: Int, a: Data): Int = b + a.i // fold one input element into the buffer
  def merge(b1: Int, b2: Int): Int = b1 + b2 // combine two partially aggregated buffers
  def finish(r: Int): Int = r                // final transformation of the buffer
  def bufferEncoder: Encoder[Int] = Encoders.scalaInt
  def outputEncoder: Encoder[Int] = Encoders.scalaInt
}.toColumn
val ds: Dataset[Data] = ...
val aggregated = ds.select(customSummer)
The case class Data is needed because this aggregator operates on a typed Dataset.
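For completeness, here is a minimal, hedged sketch of how the aggregator above could be exercised end to end. The sampleDs name and its contents are assumptions made purely for illustration, and a SparkSession is assumed to be available as spark (as it is in spark-shell):

import spark.implicits._

// Reuses the Data case class and the customSummer typed column defined above.
val sampleDs: Dataset[Data] = Seq(Data(1), Data(2), Data(3)).toDS()
val total: Int = sampleDs.select(customSummer).collect().head // 1 + 2 + 3 = 6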
To create your own aggregator, you need to implement six methods with predefined names:
zero: the initial value, also known as the neutral element; it should satisfy the requirement “something” + zero = “something”
reduce: the function that folds input elements into a buffer, performing the actual aggregation
merge: the function for combining buffers
finish: the final processing that turns the buffer into the target value; in this particular example it is essentially absent, but it is often useful (see the sketch right after this list and the word-count example below)
bufferEncoder and outputEncoder: the encoders for the buffer and the output value
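To see why finish earns its place, here is a hedged sketch (not part of the official example) of an average aggregator: the buffer carries a running sum and count, and finish turns that pair into the mean. It reuses the Data case class from above; the averager name is hypothetical:

import org.apache.spark.sql.{Encoder, Encoders}
import org.apache.spark.sql.expressions.Aggregator

val averager = new Aggregator[Data, (Long, Long), Double] {
  def zero: (Long, Long) = (0L, 0L)
  def reduce(b: (Long, Long), a: Data): (Long, Long) = (b._1 + a.i, b._2 + 1)
  def merge(b1: (Long, Long), b2: (Long, Long)): (Long, Long) = (b1._1 + b2._1, b1._2 + b2._2)
  // finish is where the real "final processing" happens: sum / count.
  def finish(b: (Long, Long)): Double = if (b._2 == 0) 0.0 else b._1.toDouble / b._2
  def bufferEncoder: Encoder[(Long, Long)] = Encoders.tuple(Encoders.scalaLong, Encoders.scalaLong)
  def outputEncoder: Encoder[Double] = Encoders.scalaDouble
}.toColumn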
Why so many? Let’s dive a little deeper. Spark is a distributed data processing framework, and it performs an aggregation in the following steps:
First step: on the executors, Spark pre-aggregates the data with the reduce function to minimize the volume of data to be shuffled.
Second step: Spark shuffles the data, moving the buffers with the same aggregation key to the same executor.
Third step: the executors finalize the aggregation by combining the shuffled buffers with the merge function.
Fourth step: the finish function performs the final processing of the merged buffer.
Therefore, even if you don’t need a certain step, you still have to define all of these functions. On the other hand, you are unlikely to use a UDAF for simple summation, and for complex aggregations such staging can be very useful. It is also important to note that the aggregation effectively happens in two rounds: a partial one on the executors before the shuffle, and a final one after the data has been shuffled.
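As a rough mental model of these stages (plain Scala, no Spark involved), the staged execution of customSummer could be pictured like this; the split into two partitions and the variable names are purely assumptions for illustration:

// Step 1: pre-aggregation on each executor, starting from zero and applying reduce.
val partition1 = Seq(Data(1), Data(2))
val partition2 = Seq(Data(3), Data(4))
val buf1 = partition1.foldLeft(0)((b, a) => b + a.i) // 3
val buf2 = partition2.foldLeft(0)((b, a) => b + a.i) // 7

// Steps 2-3: after the shuffle, the buffers for the same key land on one executor and are merged.
val merged = buf1 + buf2 // merge -> 10

// Step 4: final processing with finish (the identity function in this example).
val result = merged // 10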
Let’s consider an educational ‘word count’ example implemented using aggregators. Suppose we have a simple CSV file with user tweets:
userId,tweet
f6e8252f,cat dog frog cat dog frog frog dog dog bird
f6e8252f,cat dog frog dog dog bird
f6e8252f,cat dog
29d89ee4,frog frog dog bird
29d89ee4,frog cat dog frog frog dog dog bird
29d89ee4,frog bird
Let’s determine the most frequently used word for each user. Naturally, there are many ways to solve this task, but for educational purposes, we will demonstrate how to solve it using a UDAF.
import org.apache.spark.sql.{Encoder, Encoders, Row}
import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.functions.col

val df = spark.read.format("csv").option("header", "true").load(path)

type FrequencyMapType = Map[String, Int]

class MyCustomAggregator(columnName: String) extends Aggregator[Row, FrequencyMapType, String] {

  // Sum the word counts of two frequency maps.
  private def combineMaps(map1: FrequencyMapType, map2: FrequencyMapType): FrequencyMapType = {
    val combinedKeys = map1.keySet ++ map2.keySet
    combinedKeys.map { key =>
      key -> (map1.getOrElse(key, 0) + map2.getOrElse(key, 0))
    }.toMap
  }

  // Empty buffer: no words counted yet.
  def zero: FrequencyMapType = Map.empty[String, Int]

  // Fold one row into the buffer: count the words of this tweet and add them to the running totals.
  def reduce(buffer: FrequencyMapType, row: Row): FrequencyMapType = {
    val wordsMap = row.getAs[Any](columnName).toString.split(" ").groupBy(identity).mapValues(_.length).toMap
    combineMaps(buffer, wordsMap)
  }

  // Combine two partially aggregated buffers.
  def merge(map1: FrequencyMapType, map2: FrequencyMapType): FrequencyMapType =
    combineMaps(map1, map2)

  // Final processing: return the most frequent word, or "null" for an empty buffer.
  def finish(buffer: FrequencyMapType): String =
    if (buffer.isEmpty) "null" else buffer.maxBy(_._2)._1

  def bufferEncoder: Encoder[FrequencyMapType] = Encoders.kryo[Map[String, Int]]
  def outputEncoder: Encoder[String] = Encoders.STRING
}

val mostPopularWordAggregator = new MyCustomAggregator("tweet").toColumn

df.groupBy(col("userID"))
  .agg(mostPopularWordAggregator.name("favoriteWord"))
  .withColumnRenamed("value", "userId")
  .show()
The result:
+--------+---------------+
|userId |favoriteWord |
+--------+---------------+
|29d89ee4|frog |
|f6e8252f|dog |
+--------+---------------+
Note the specification of types <IN, BUF, OUT>:
The input type for our aggregator is Row.
For the buffer, we use the type Map[String, Int] (aliased as FrequencyMapType).
Finally, the output type is simply String, since the most popular word is returned.
Astute readers may have noticed that, in essence, we adapted a ‘most common value’ aggregator to our needs. Could we achieve the same result without a UDAF? Of course! One of the naive variants is:
Get an ArrayType column using org.apache.spark.sql.functions.split.
Create a new row for each element of the array using org.apache.spark.sql.functions.explode.
Perform a standard count aggregation over userId + word.
For each userId, take the word with the maximum value of that count.
In the code it could look like this:
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.{col, count, explode, max, split}

val window = Window.partitionBy("userId")

df.withColumn("arrayOfWords", split(col("tweet"), " "))
  .withColumn("word", explode(col("arrayOfWords")))
  .groupBy("userId", "word")
  .agg(count("word").as("count"))
  .withColumn("maxCount", max("count").over(window))
  .filter(col("count") === col("maxCount"))
  .select("userId", "word")
  .show()
And now let’s compare the plans. With UDAF it’s very simple:
== Physical Plan ==
AdaptiveSparkPlan isFinalPlan=false
+- ObjectHashAggregate(keys=[userID#16], functions=[mycustomaggregator($line15.$read$$iw$$iw$$iw$$iw$MyCustomAggregator@69cf253c, Some(createexternalrow(userId#16.toString, tweet#17.toString, StructField(userId,StringType,true), StructField(tweet,StringType,true))), Some(interface org.apache.spark.sql.Row), Some(StructType(StructField(userId,StringType,true), StructField(tweet,StringType,true))), encodeusingserializer(input[0, java.lang.Object, true], true), decodeusingserializer(input[0, binary, true], scala.collection.immutable.Map, true), staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, input[0, java.lang.String, true], true, false), StringType, true, 0, 0)])
+- Exchange hashpartitioning(userID#16, 200), ENSURE_REQUIREMENTS, [id=#30]
+- ObjectHashAggregate(keys=[userID#16], functions=[partial_mycustomaggregator($line15.$read$$iw$$iw$$iw$$iw$MyCustomAggregator@69cf253c, Some(createexternalrow(userId#16.toString, tweet#17.toString, StructField(userId,StringType,true), StructField(tweet,StringType,true))), Some(interface org.apache.spark.sql.Row), Some(StructType(StructField(userId,StringType,true), StructField(tweet,StringType,true))), encodeusingserializer(input[0, java.lang.Object, true], true), decodeusingserializer(input[0, binary, true], scala.collection.immutable.Map, true), staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, input[0, java.lang.String, true], true, false), StringType, true, 0, 0)])
+- FileScan csv [userId#16,tweet#17] Batched: false, DataFilters: [], Format: CSV, Location: InMemoryFileIndex(1 paths)[file:/Users/a616287/Repos/SBX/files/tweets.csv], PartitionFilters: [], PushedFilters: [], ReadSchema: struct<userId:string,tweet:string>
And here is the plan when using only the standard functions from org.apache.spark.sql.functions._:
== Physical Plan ==
AdaptiveSparkPlan isFinalPlan=false
+- Project [userId#16, word#41]
+- Filter (isnotnull(maxCount#57L) AND (count#52L = maxCount#57L))
+- Window [max(count#52L) windowspecdefinition(userId#16, specifiedwindowframe(RowFrame, unboundedpreceding$(), unboundedfollowing$())) AS maxCount#57L], [userId#16]
+- Sort [userId#16 ASC NULLS FIRST], false, 0
+- Exchange hashpartitioning(userId#16, 200), ENSURE_REQUIREMENTS, [id=#66]
+- HashAggregate(keys=[userId#16, word#41], functions=[count(word#41)])
+- Exchange hashpartitioning(userId#16, word#41, 200), ENSURE_REQUIREMENTS, [id=#63]
+- HashAggregate(keys=[userId#16, word#41], functions=[partial_count(word#41)])
+- Generate explode(arrayOfWords#36), [userId#16], false, [word#41]
+- Project [userId#16, split(tweet#17, , -1) AS arrayOfWords#36]
+- Filter ((size(split(tweet#17, , -1), true) > 0) AND isnotnull(split(tweet#17, , -1)))
+- FileScan csv [userId#16,tweet#17] Batched: false, DataFilters: [(size(split(tweet#17, , -1), true) > 0), isnotnull(split(tweet#17, , -1))], Format: CSV, Location: InMemoryFileIndex(1 paths)[file:/Users/a616287/Repos/SBX/files/tweets.csv], PartitionFilters: [], PushedFilters: [], ReadSchema: struct<userId:string,tweet:string>
I guess the conclusion regarding these two plans is obvious: the UDAF version boils down to a partial ObjectHashAggregate, a single shuffle, and a final ObjectHashAggregate, while the functions-only version needs an explode, a two-stage count aggregation, two shuffles, a sort, and a window on top.
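If you want to reproduce the comparison yourself, the physical plans shown above can be printed by calling explain() on each variant before triggering an action; a minimal sketch for the UDAF variant (the same call applies to the functions-only one):

df.groupBy(col("userID"))
  .agg(mostPopularWordAggregator.name("favoriteWord"))
  .explain() // on Spark 3.x, explain("formatted") gives an even more readable layout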
So, org.apache.spark.sql.expressions.Aggregator is a powerful abstraction in Spark SQL for creating custom aggregation functions that can be used through the DataFrame API or in SQL queries. You should consider using Aggregator when:
Custom Aggregation Logic: you need to perform custom aggregation operations that are not provided by the built-in Spark SQL functions. Aggregator allows you to define your own aggregation logic.
Performance Optimization: using Aggregator can provide performance benefits over attempting to express the aggregation as a combination of standard functions from org.apache.spark.sql.functions._. This is particularly advantageous compared to relying solely on UDFs or combinations of standard functions, as Aggregator leverages Spark's Catalyst optimizer to generate optimized physical execution plans, as the comparison above showed.
Code Reusability: by encapsulating custom aggregation logic within an Aggregator, you can reuse the same logic across different parts of your codebase, making your code more modular and easier to maintain.
Integration with the DataFrame API and SQL: Aggregator seamlessly integrates with the DataFrame API and can be used directly in SQL queries, allowing you to leverage the expressive power of SQL while benefiting from the performance characteristics of Aggregator. (Not covered in this article because I prefer the DataFrame API :-))
Overall, Aggregator is a powerful tool for building custom aggregation functions in Spark SQL, especially when you need to perform complex aggregations efficiently.