Apache Spark RDD’s flatMap transformation

In our previous post, we talked about the map transformation in Spark. In this post, we will learn about the flatMap transformation.

As per the Apache Spark documentation, flatMap(func) is similar to map, but each input item can be mapped to 0 or more output items. That means func should return a collection, such as a Seq, rather than a single item. An Array, like the one returned by String.split, works as well. Let’s take an example. We have the input data shown below:

CAT,BAT,RAT,ELEPHANT
RAT,BAT,BAT,BAT,CAT
CAT,ELEPHANT,RAT,ELEPHANT
RAT,RAT,RAT,BAT,CAT

Let’s first use the map() transformation to split each line into words. This is the first step of a word count program.

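// Prepare the data.
val data = Seq(
  "CAT,BAT,RAT,ELEPHANT",
  "RAT,BAT,BAT,BAT,CAT",
  "CAT,ELEPHANT,RAT,ELEPHANT",
  "RAT,RAT,RAT,BAT,CAT")

// Convert the data to an RDD. This snippet assumes a SparkContext
// named sparkContext, created as in the complete code below.
val inputRDD = sparkContext.parallelize(data)

// Using the map transformation: each line becomes one Array[String],
// so tupleRDD is an RDD[Array[String]] with one element per line.
val tupleRDD = inputRDD.map(line => line.split(","))

// Printing the RDD
tupleRDD.collect.foreach(println)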

In the code above, we applied map to inputRDD and split each line on commas. Let’s look at the output:

[Ljava.lang.String;@2bef09c0
[Ljava.lang.String;@62ce72ff
[Ljava.lang.String;@58a63629
[Ljava.lang.String;@7de843ef

So what happened above?

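Surprisingly, we don’t see the individual words. Instead, we see the default toString representation of some objects. Why? The split function returns an Array[String], and map produces exactly one output element for each input element. As a result, every line becomes a single array, and map does not flatten it. When println prints a Java array, it shows the array’s type followed by @ and a hash code, not its contents. Here, [Ljava.lang.String; means "array of String". The returned RDD therefore always has the same number of elements as the original RDD: four arrays for four lines.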
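
The same behavior can be reproduced with plain Scala collections, so no Spark is needed. Seq.map follows the same one-output-per-input contract as RDD.map. This is a minimal sketch, and the object name MapKeepsOneElementPerLine is hypothetical. It shows that the result holds exactly four arrays. It also shows that mkString is one way to print an array’s contents instead of its default toString.

```scala
// Plain Scala collections stand in for the RDD here; Seq.map, like RDD.map,
// produces exactly one output element per input element.
object MapKeepsOneElementPerLine {
  val lines: Seq[String] = Seq(
    "CAT,BAT,RAT,ELEPHANT",
    "RAT,BAT,BAT,BAT,CAT",
    "CAT,ELEPHANT,RAT,ELEPHANT",
    "RAT,RAT,RAT,BAT,CAT")

  // One Array[String] per line: the arrays are not flattened.
  val arrays: Seq[Array[String]] = lines.map(line => line.split(","))

  def main(args: Array[String]): Unit = {
    // Default Array.toString: prints [Ljava.lang.String;@<hash> for each array.
    arrays.foreach(println)

    // mkString renders the array contents instead, e.g. [CAT, BAT, RAT, ELEPHANT].
    arrays.foreach(words => println(words.mkString("[", ", ", "]")))

    // 4: the same as the number of input lines.
    println(arrays.size)
  }
}
```

With the RDD from this post, the analogous call would be tupleRDD.collect.foreach(words => println(words.mkString(", "))). The RDD still holds four arrays, one per line.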

Correct Solution

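In such scenarios, we can use the flatMap transformation instead. flatMap applies the function to each element and then flattens the returned collections, so each word becomes a separate element of the resulting RDD. Let’s solve the same scenario using flatMap now.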

// Using flatMap
val tupleRDD1 = inputRDD.flatMap(line => line.split(","))

// Printing the RDD
tupleRDD1.collect.foreach(println)

// Output
CAT
BAT
RAT
ELEPHANT
RAT
BAT
BAT
BAT
CAT
CAT
ELEPHANT
RAT
ELEPHANT
RAT
RAT
RAT
BAT
CAT
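
The flattening behavior can be sketched with plain Scala collections, whose map, flatten, and flatMap mirror the semantics described above. This is a minimal sketch, and the object name FlatMapSemantics and the helper repeatSelf are hypothetical. It checks two things. First, flatMap gives the same result as map followed by flatten. Second, because each input may produce zero or more outputs, the result can have fewer, the same, or more elements than the input.

```scala
object FlatMapSemantics {
  val lines: Seq[String] = Seq(
    "CAT,BAT,RAT,ELEPHANT",
    "RAT,BAT,BAT,BAT,CAT",
    "CAT,ELEPHANT,RAT,ELEPHANT",
    "RAT,RAT,RAT,BAT,CAT")

  // flatMap: split each line, then concatenate all the resulting words.
  val viaFlatMap: Seq[String] = lines.flatMap(line => line.split(","))

  // Equivalent two-step version: map to one Seq per line, then flatten.
  val viaMapThenFlatten: Seq[String] =
    lines.map(line => line.split(",").toSeq).flatten

  // Hypothetical helper: input n maps to n copies of itself,
  // so 0 produces no output at all (the "0 or more" case).
  def repeatSelf(n: Int): Seq[Int] = Seq.fill(n)(n)
  val zeroOrMore: Seq[Int] = Seq(0, 1, 2, 3).flatMap(repeatSelf)

  def main(args: Array[String]): Unit = {
    println(viaFlatMap.size)                 // 18 words from 4 lines
    println(viaFlatMap == viaMapThenFlatten) // true
    println(zeroOrMore)                      // List(1, 2, 2, 3, 3, 3)
  }
}
```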

Complete Code

The complete example is also available in our GitHub repository: https://github.com/proedu-organisation/spark-scala-examples/blob/main/src/main/scala/rdd/transformations/FlatMapExample.scala

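import org.apache.spark.sql.SparkSession

// The Spark docs recommend a main() method instead of extending scala.App,
// because subclasses of scala.App may not work correctly with Spark.
object FlatMapExample {
  def main(args: Array[String]): Unit = {
    // Prepare the data.
    val data = Seq(
      "CAT,BAT,RAT,ELEPHANT",
      "RAT,BAT,BAT,BAT,CAT",
      "CAT,ELEPHANT,RAT,ELEPHANT",
      "RAT,RAT,RAT,BAT,CAT"
    )

    // Create a local SparkSession and get its SparkContext.
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("Proedu.co examples")
      .getOrCreate()
    val sparkContext = spark.sparkContext

    val inputRDD = sparkContext.parallelize(data)
    inputRDD.collect.foreach(println)

    // map returns one Array[String] per line.
    val tupleRDD = inputRDD.map(line => line.split(","))
    tupleRDD.collect.foreach(println)
    /** Output. map does not flatten the arrays, so println shows each
     * array's default toString (the hash codes vary from run to run).
     * [Ljava.lang.String;@2bef09c0
     * [Ljava.lang.String;@62ce72ff
     * [Ljava.lang.String;@58a63629
     * [Ljava.lang.String;@7de843ef
     */

    // flatMap flattens each Array[String] into individual words.
    val tupleRDD1 = inputRDD.flatMap(line => line.split(","))
    tupleRDD1.collect.foreach(println)
    /** Output
     * CAT
     * BAT
     * RAT
     * ELEPHANT
     * RAT
     * BAT
     * BAT
     * BAT
     * CAT
     * CAT
     * ELEPHANT
     * RAT
     * ELEPHANT
     * RAT
     * RAT
     * RAT
     * BAT
     * CAT
     */

    // Release Spark resources.
    spark.stop()
  }
}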
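
Splitting lines into words was the first step of the word count program mentioned earlier. Below is a minimal sketch of the remaining steps, again using plain Scala collections instead of an RDD (the object name WordCountSketch is hypothetical). It flatMaps the lines into words, groups identical words, and counts each group. With an RDD, the usual equivalent of the counting step is map(word => (word, 1)) followed by reduceByKey(_ + _).

```scala
object WordCountSketch {
  val lines: Seq[String] = Seq(
    "CAT,BAT,RAT,ELEPHANT",
    "RAT,BAT,BAT,BAT,CAT",
    "CAT,ELEPHANT,RAT,ELEPHANT",
    "RAT,RAT,RAT,BAT,CAT")

  // Step 1: flatMap turns 4 lines into 18 individual words.
  val words: Seq[String] = lines.flatMap(line => line.split(","))

  // Step 2: group identical words and count the size of each group.
  val counts: Map[String, Int] =
    words.groupBy(identity).map { case (word, occurrences) => word -> occurrences.size }

  def main(args: Array[String]): Unit = {
    // Sort by word so the printed order is deterministic.
    counts.toSeq.sortBy(_._1).foreach { case (word, n) => println(s"$word: $n") }
    // BAT: 5, CAT: 4, ELEPHANT: 3, RAT: 6 (one per line)
  }
}
```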

Happy Learning 🙂