Part 1: Deep Dive - Spark Window Functions
- Rishaab
- Nov 16, 2024
Updated: Nov 19, 2024
This is part 1 of a deep-dive series on the internals of Apache Spark window functions.
Introduction
A window function in query processing computes a result for each row over a subset of the dataset, called the window frame. The window frame defines the boundaries (i.e. the start and end rows) of the rows to which the window function is applied for the current row.
For example, given the dataset [1, 2, 3, 4, 5], suppose we want to compute a SUM over a window frame that contains the current element and the one element immediately before it. The SUM window function then computes the sums as follows:
For the first element (1), there are no elements before it, so the sum is 1.
For the second element (2), we add the previous element (1) to it, resulting in a sum of 3.
For the third element (3), we add the previous element (2) to it, resulting in a sum of 5.
For the fourth element (4), we add the previous element (3) to it, resulting in a sum of 7.
For the fifth element (5), we add the previous element (4) to it, resulting in a sum of 9.
Hence, the result of applying SUM over this window frame is [1, 3, 5, 7, 9].
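The same calculation can be sketched in plain Scala, without Spark, to make the frame boundaries explicit. This is a minimal illustration that assumes a frame of `preceding` rows before the current row plus the current row itself; `frameSum` is a hypothetical helper, not a Spark API. In Spark, a comparable frame would be written as `rowsBetween(-1, Window.currentRow)`.

```scala
object FrameSumSketch {
  // Sum each element with up to `preceding` elements before it.
  // The frame for index i spans [max(0, i - preceding), i], inclusive.
  def frameSum(values: Seq[Int], preceding: Int): Seq[Int] =
    values.indices.map { i =>
      val start = math.max(0, i - preceding) // frame start is clipped at the first row
      values.slice(start, i + 1).sum          // slice end is exclusive, so i + 1 includes the current row
    }

  def main(args: Array[String]): Unit = {
    val data = Seq(1, 2, 3, 4, 5)
    println(frameSum(data, preceding = 1)) // Vector(1, 3, 5, 7, 9)
  }
}
```

Setting `preceding` to the length of the input turns the same helper into a running total, which is the shape of frame used by the rank() example below.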
In this series, we will dive deeper into the internal workings of Apache Spark window functions, studying the implementation and codebase of Spark's window functions through a test example. This series assumes some familiarity with Apache Spark concepts such as DataFrames, RDDs, and query plans. If these are new to you, the Apache Spark documentation is a great place to learn them.
Rank Window Function
For our study, we will use the following Spark DataFrame example in Scala and trace how Spark processes its rank() window function.
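import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.rank
import spark.implicits._ // provides toDF and the $"..." column syntax (already imported in spark-shell)
spark.conf.set("spark.sql.codegen.wholeStage", "false")
spark.conf.set("spark.sql.codegen.factoryMode", "NO_CODEGEN")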
val data = Seq(
(1, "a1"),
(2, "b1"),
(1, "a2"),
(2, "b2")).toDF("col1", "col2")
val w = Window
.partitionBy("col1")
.orderBy("col2")
.rowsBetween(Window.unboundedPreceding, Window.currentRow)
val df = data.select($"col1", $"col2", rank().over(w) as "rank")
df.show()
Let's quickly understand the above code.
The two lines below disable code generation in Apache Spark. During query execution, Spark generates Java code at runtime to improve performance, but this generated code can be hard to follow when reading and debugging code paths. For learning purposes, we therefore disable code generation with these two configurations: the first turns off whole-stage code generation, and the second makes Spark use interpreted implementations for expression evaluation instead of generated code. Together they force Spark to execute the query in interpreted mode.
spark.conf.set("spark.sql.codegen.wholeStage", "false")
spark.conf.set("spark.sql.codegen.factoryMode", "NO_CODEGEN")
Next, we create a DataFrame with two columns, col1 and col2, from a small dataset of four rows.
val data = Seq(
(1, "a1"),
(2, "b1"),
(1, "a2"),
(2, "b2")).toDF("col1", "col2")
Then, we define a window specification in which the dataset is partitioned by column col1 and the rows in each partition are ordered by column col2 in ascending order. Window.unboundedPreceding and Window.currentRow specify the lower and upper bounds of the window frame, respectively, so for each row the frame includes all rows from the beginning of its partition up to and including the current row.
val w = Window
.partitionBy("col1")
.orderBy("col2")
.rowsBetween(Window.unboundedPreceding, Window.currentRow)
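To make these frame bounds concrete, the following plain-Scala sketch (no Spark involved) lists which rows fall inside each row's frame when the frame runs from the start of the partition (Window.unboundedPreceding) to the current row (Window.currentRow). It uses the first partition of our dataset, the rows (1, "a1") and (1, "a2"). The helper `framesFor` is hypothetical and only models row-based frame semantics, not Spark's actual implementation.

```scala
object FrameBoundsSketch {
  // For each row i of an already-sorted partition, return the rows in its frame:
  // ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW covers indices 0 to i, inclusive.
  def framesFor[T](sortedPartition: Seq[T]): Seq[Seq[T]] =
    sortedPartition.indices.map(i => sortedPartition.take(i + 1))

  def main(args: Array[String]): Unit = {
    val partition1 = Seq((1, "a1"), (1, "a2"))
    framesFor(partition1).zip(partition1).foreach { case (frame, row) =>
      println(s"row $row -> frame $frame")
    }
    // row (1,a1) -> frame List((1,a1))
    // row (1,a2) -> frame List((1,a1), (1,a2))
  }
}
```

Each frame grows by one row as the current row advances, which is why this frame is often described as a running or cumulative frame.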
For example, the sorted partitions for our test dataset look like this. The window function processes each partition independently to compute the rank of each record.
partition1: [(1, "a1"), (1, "a2")]
partition2: [(2, "b1"), (2, "b2")]
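The effect of partitionBy("col1") and orderBy("col2") on this dataset can be modeled with ordinary Scala collections as a rough sketch: group the rows by the partition key, then sort each group by the ordering column. This deliberately ignores how Spark actually distributes and sorts data across executors, and `partitionAndSort` is a hypothetical name used only for illustration.

```scala
object PartitionSketch {
  type Row = (Int, String) // (col1, col2)

  // Group rows by col1 (the partition key), then sort each group by col2 ascending.
  // Partitions are returned in ascending key order so the output is deterministic.
  def partitionAndSort(rows: Seq[Row]): Seq[(Int, Seq[Row])] =
    rows
      .groupBy(_._1)
      .toSeq
      .sortBy(_._1)
      .map { case (key, group) => key -> group.sortBy(_._2) }

  def main(args: Array[String]): Unit = {
    val data = Seq((1, "a1"), (2, "b1"), (1, "a2"), (2, "b2"))
    partitionAndSort(data).foreach { case (key, rows) =>
      println(s"partition for col1=$key: $rows")
    }
    // partition for col1=1: List((1,a1), (1,a2))
    // partition for col1=2: List((2,b1), (2,b2))
  }
}
```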
Below, we project columns col1 and col2 along with the rank computed by the rank() window function. Ranks are assigned according to the ordering within each partition. Rows with equal values in the ordering column (ties) receive the same rank, and the next rank leaves a gap: for example, two rows tied for rank 1 are followed by rank 3.
val df = data.select($"col1", $"col2", rank().over(w) as "rank")
df.show()
The output for our test case looks like this:
+----+----+----+
|col1|col2|rank|
+----+----+----+
|   1|  a1|   1|
|   1|  a2|   2|
|   2|  b1|   1|
|   2|  b2|   2|
+----+----+----+
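The ranks in this output, and the gap-after-ties behavior described earlier, can be reproduced with a small plain-Scala sketch. It assumes each partition is already sorted by the ordering column and assigns each row its 1-based position, except that a row tied with the previous row reuses that row's rank. This models the semantics of rank() only; it is not Spark's implementation, and `rankOf` is a hypothetical name.

```scala
object RankSketch {
  // rank() semantics over the ordering keys of one sorted partition:
  // a new key gets its 1-based position, a tied key reuses the previous rank,
  // so ties leave gaps in the sequence (e.g. 1, 1, 3).
  def rankOf[K](sortedKeys: Seq[K]): Seq[Int] = {
    val keys = sortedKeys.toVector
    keys.indices.foldLeft(Vector.empty[Int]) { (ranks, i) =>
      if (i > 0 && keys(i - 1) == keys(i)) ranks :+ ranks(i - 1) // tie: reuse previous rank
      else ranks :+ (i + 1)                                      // otherwise: 1-based position
    }
  }

  def main(args: Array[String]): Unit = {
    val partitions = Seq(
      Seq((1, "a1"), (1, "a2")),
      Seq((2, "b1"), (2, "b2"))
    )
    // Rank each partition independently by col2, as the window specification does.
    for (partition <- partitions; (row, r) <- partition.zip(rankOf(partition.map(_._2))))
      println(s"${row._1} ${row._2} $r")
    // 1 a1 1
    // 1 a2 2
    // 2 b1 1
    // 2 b2 2

    println(rankOf(Seq("x", "x", "y"))) // Vector(1, 1, 3)
  }
}
```

Because ranking restarts in every partition, each partition's first row gets rank 1, matching the output table.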
Now that we understand how to use window functions, it's time to explore their inner workings. In the next part of this series, we'll examine the implementation details of window functions through our rank() test example.