Apache Spark RDD mapPartitions and mapPartitionsWithIndex

In our previous posts we talked about the map function. In this post we will learn about the RDD mapPartitions and mapPartitionsWithIndex transformations in Apache Spark.

As per the Apache Spark documentation, mapPartitions returns a new RDD by applying a function to each partition of the RDD, so the function runs once per partition rather than once per element.

We can also say that mapPartitions is a specialized map that is called only once per partition, where the entire content of the respective partition is available as a sequential stream of values via the input argument (Iterator[T]).

The custom function must return yet another Iterator[U]. Spark converts the combined result iterators into a new RDD automatically. 
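
In fact, map can be expressed in terms of mapPartitions: mapping over a partition’s iterator produces exactly the same elements as mapping over each element individually. Below is a minimal sketch of that equivalence, assuming the sparkContext object created later in this post.

// Both RDDs contain exactly the same elements.
val numbers = sparkContext.parallelize(1 to 10)

val viaMap           = numbers.map(n => n * 2)
val viaMapPartitions = numbers.mapPartitions(iter => iter.map(n => n * 2))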

Now let’s see the signature of mapPartitions below

def mapPartitions[U: ClassTag](f: Iterator[T] => Iterator[U], preservesPartitioning: Boolean = false): RDD[U]

Here, the custom function f: Iterator[T] => Iterator[U] must return yet another Iterator[U].

Also, preservesPartitioning indicates whether the input function preserves the partitioner. It should be false unless this is a pair RDD and the input function does not modify the keys.
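
For example, if a pair RDD has already been partitioned by key and our function only changes the values, we can pass preservesPartitioning = true so Spark keeps the existing partitioner. The snippet below is a sketch of that idea, again assuming the sparkContext object created below.

import org.apache.spark.HashPartitioner

// A pair RDD partitioned by key.
val pairs = sparkContext
  .parallelize(Seq(("a", 1), ("b", 2), ("a", 3)))
  .partitionBy(new HashPartitioner(2))

// Only the values change and the keys stay the same,
// so the existing partitioner is still valid.
val scaled = pairs.mapPartitions(
  iter => iter.map { case (k, v) => (k, v * 10) },
  preservesPartitioning = true
)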

Let’s understand this with an example. We have an RDD containing the days of the week, as shown below.

// Preparing the data as a local Scala List.
val days = List("Sunday", "Monday", "Tuesday", "Wednesday", "Thursday", "Friday", "Saturday")

// Creating a SparkSession and getting its SparkContext.
val sparkContext = SparkSession.builder()
  .master("local[*]")
  .appName("Proedu.co examples")
  .getOrCreate()
  .sparkContext

// Converting the local Scala List to an RDD.
val daysRDD = sparkContext.parallelize(days)

// Let's print some data.
daysRDD.collect.foreach(println)

// Output
Sunday
Monday
Tuesday
Wednesday
Thursday
Friday
Saturday
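
Since mapPartitions is called once per partition, it can be useful to check how many partitions the RDD actually has. On local[*], parallelize usually creates one partition per available core, so the number below depends on the machine running the example.

// Checking how many partitions daysRDD was split into.
println(s"Number of partitions: ${daysRDD.getNumPartitions}")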

In the next step we will convert the above RDD of Strings into an RDD of tuples containing each element and its length.

// Applying mapPartitions transformations.
val mapPartitionRDD = daysRDD.mapPartitions(iterator => {
    val list = iterator.toList
    val tuple = list.map(word => (word, word.length))
    // The custom function must return yet another Iterator.
    tuple.iterator
})

// Let's print some data.
mapPartitionRDD.collect.foreach(println)

// Output
(Sunday,6)
(Monday,6)
(Tuesday,7)
(Wednesday,9)
(Thursday,8)
(Friday,6)
(Saturday,8)
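
Note that calling iterator.toList materializes the entire partition in memory. When no per-partition setup is required, the same result can be produced lazily by mapping over the iterator directly, as in this sketch:

// A lazier variant of the same transformation: no intermediate List is built,
// the iterator is consumed element by element as Spark pulls results from it.
val lazyMapPartitionRDD = daysRDD.mapPartitions(iterator =>
  iterator.map(word => (word, word.length))
)

lazyMapPartitionRDD.collect.foreach(println)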

When should we use mapPartitions?

We can choose mapPartitions over map when we have a heavy initialization step, such as creating a database connection, that should run once per partition instead of once per element. For example, using a placeholder DatabaseConnection class:

val dataRDD = recordsRDD.mapPartitions(partition => {

  // Creating a DB connection once per RDD partition rather than once per element.
  val connection = new DatabaseConnection

  // Materializing the partition so the connection can be closed before we return.
  val list = partition.toList
  val tupleList = list.map(word => (word, word.length))

  // Closing the connection.
  connection.close()

  // Returning a new Iterator.
  tupleList.iterator
})

As shown in the above code, we create a database connection once per partition, which is far more efficient than creating it once per RDD element.

Let’s assume we have 1000 elements in one RDD partition. If we use the map function, it will create a database connection 1000 times.

On the other hand, if we use mapPartitions, it will create the database connection just once for the entire partition. This is a perfect scenario for choosing mapPartitions over map.
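
For contrast, here is a sketch of the per-element alternative written with plain map. DatabaseConnection is the same hypothetical placeholder class as above; the point is only that the connection is opened and closed once for every single element.

// Anti-pattern sketch: with 1000 elements in a partition,
// this opens and closes 1000 connections.
val perElementRDD = recordsRDD.map(word => {
  val connection = new DatabaseConnection
  val result = (word, word.length)
  connection.close()
  result
})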

Complete Example in Scala

The complete example is also available in our GitHub repository: https://github.com/proedu-organisation/spark-scala-examples/blob/main/src/main/scala/rdd/transformations/MapPartitionsExample.scala

import org.apache.spark.sql.SparkSession
object MapPartitionsExample extends App {
  // Preparing the data as a local Scala List.
  val days = List("Sunday", "Monday", "Tuesday", "Wednesday", "Thursday", "Friday", "Saturday")
  // Creating a SparkSession and getting its SparkContext.
  val sparkContext = SparkSession.builder()
    .master("local[*]")
    .appName("Proedu.co examples")
    .getOrCreate()
    .sparkContext
  // Converting the local Scala List to an RDD.
  val daysRDD = sparkContext.parallelize(days)
  // Let's print some data.
  daysRDD.collect.foreach(println)
  /** Output
   * Sunday
   * Monday
   * Tuesday
   * Wednesday
   * Thursday
   * Friday
   * Saturday
   */
  // Using mapPartitions transformation.
  val mapPartitionRDD = daysRDD.mapPartitions(iterator => {
    val list = iterator.toList
    val tupleList = list.map(word => (word, word.length))
    tupleList.iterator
  })
  // Let's print some data.
  mapPartitionRDD.collect.foreach(println)
  /** Output
   * (Sunday,6)
   * (Monday,6)
   * (Tuesday,7)
   * (Wednesday,9)
   * (Thursday,8)
   * (Friday,6)
   * (Saturday,8)
   */
}

mapPartitionsWithIndex

Now we will talk about a similar transformation called mapPartitionsWithIndex. It works like mapPartitions, but the function it accepts takes two parameters: the first is the index of the partition and the second is an iterator over all the items in that partition.

def mapPartitionsWithIndex[U: ClassTag](f: (Int, Iterator[T]) => Iterator[U], preservesPartitioning: Boolean = false): RDD[U]

Let’s see the example below. This time we are calling mapPartitionsWithIndex on daysRDD.

// Using the mapPartitionsWithIndex transformation.
val mappedRDD = daysRDD.mapPartitionsWithIndex((partition, iterator) => {
  println(s"Called in Partition $partition")
  val list = iterator.toList
  // Adding the partition index as the third element of each tuple.
  val tupleList = list.map(element => (element, element.length, partition))
  tupleList.iterator
})

// Let's print some data.
mappedRDD.collect.foreach(println)

// Output. The third value in each tuple is the partition index.
// The exact indices depend on how many partitions the RDD has on your machine.
(Sunday,6,1)
(Monday,6,2)
(Tuesday,7,3)
(Wednesday,9,4)
(Thursday,8,5)
(Friday,6,6)
(Saturday,8,7)
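
A common use of mapPartitionsWithIndex is to inspect how records are distributed across partitions. The sketch below counts the elements in each partition of daysRDD:

// Counting how many records ended up in each partition.
val partitionCounts = daysRDD
  .mapPartitionsWithIndex((index, iterator) => Iterator((index, iterator.size)))
  .collect()

partitionCounts.foreach { case (index, count) =>
  println(s"Partition $index contains $count record(s)")
}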

Complete Example in Scala

The complete example is also available in our GitHub repository: https://github.com/proedu-organisation/spark-scala-examples/blob/main/src/main/scala/rdd/transformations/MapPartitionsWithIndexExample.scala

import org.apache.spark.sql.SparkSession
object MapPartitionsWithIndexExample extends App {
  // Local Scala collection.
  val days = List("Sunday", "Monday", "Tuesday", "Wednesday", "Thursday", "Friday", "Saturday")
  // Creating a SparkSession and getting its SparkContext.
  val sparkContext = SparkSession.builder()
    .master("local[*]")
    .appName("Proedu.co examples")
    .getOrCreate()
    .sparkContext
  // Converting local Scala collection to RDD.
  val daysRDD = sparkContext.parallelize(days)
  // Let's print some data.
  daysRDD.collect.foreach(println)
  /** Output
   * Sunday
   * Monday
   * Tuesday
   * Wednesday
   * Thursday
   * Friday
   * Saturday
   */
  // Using the mapPartitionsWithIndex transformation.
  val mappedRDD = daysRDD.mapPartitionsWithIndex((partition, iterator) => {
    val list = iterator.toList
    val tupleList = list.map(element => (element, element.length, partition))
    tupleList.iterator
  })
  // Let's print some data.
  mappedRDD.collect.foreach(println)
  /** Output
   * (Sunday,6,1)
   * (Monday,6,2)
   * (Tuesday,7,3)
   * (Wednesday,9,4)
   * (Thursday,8,5)
   * (Friday,6,6)
   * (Saturday,8,7)
   */
}

Happy Learning 🙂