Apache Spark RDD’s flatMap transformation
In our previous post, we talked about the map transformation in Spark. In this post, we will learn about the flatMap transformation.
As per the Apache Spark documentation, flatMap(func) is similar to map, but each input item can be mapped to 0 or more output items, which means func should return a scala.collection.Seq rather than a single item. Let's take an example with the input data shown below; a quick standalone sketch of the 0-or-more idea follows the sample data.
CAT,BAT,RAT,ELEPHANT
RAT,BAT,BAT,BAT,CAT
CAT,ELEPHANT,RAT,ELEPHANT
RAT,RAT,RAT,BAT,CAT
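Before moving to Spark, here is a minimal standalone sketch of the "0 or more output items" idea on a plain Scala List (the numbers and the even-number predicate are made up purely for illustration):

// Each element can expand to zero, one, or many items.
val numbers = List(1, 2, 3)

// Each n maps to n copies of itself.
val expanded = numbers.flatMap(n => Seq.fill(n)(n))
// expanded: List(1, 2, 2, 3, 3, 3)

// An element can also map to zero items, effectively filtering it out.
val evensOnly = numbers.flatMap(n => if (n % 2 == 0) Seq(n) else Seq.empty)
// evensOnly: List(2)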
Let's first use the map() function as the starting point of a word count program.
// Prepare the data.
val data = Seq(
  "CAT,BAT,RAT,ELEPHANT",
  "RAT,BAT,BAT,BAT,CAT",
  "CAT,ELEPHANT,RAT,ELEPHANT",
  "RAT,RAT,RAT,BAT,CAT")

// Convert the data to an RDD.
val inputRDD = sparkContext.parallelize(data)

// Using the map transformation.
val tupleRDD = inputRDD.map(line => line.split(","))

// Printing the tuple RDD.
tupleRDD.collect.foreach(println)
In the above code we used the map function on inputRDD and split each line by comma. Let's see the output below:
[Ljava.lang.String;@2bef09c0
[Ljava.lang.String;@62ce72ff
[Ljava.lang.String;@58a63629
[Ljava.lang.String;@7de843ef
So what happened above?
Surprisingly, we do not see the individual words, only each object's default toString representation. Let's understand why. The split function returns an Array[String], but the map function does not have the capability to flatten that array: map always produces exactly one output element per input element, so the returned RDD has the same number of elements as the original RDD, and tupleRDD holds four arrays instead of individual words.
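The words are still there, just wrapped inside arrays. One way to confirm this (a small sketch using Array.mkString, not part of the original example) is to render each array's contents explicitly:

// Each element of tupleRDD is an Array[String]; print its contents explicitly.
tupleRDD.collect.foreach(arr => println(arr.mkString(" ")))
// CAT BAT RAT ELEPHANT
// RAT BAT BAT BAT CAT
// CAT ELEPHANT RAT ELEPHANT
// RAT RAT RAT BAT CAT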
Correct Solution
In such scenarios we can use the flatMap function. flatMap knows how to flatten the collection returned by func. Let's solve the same scenario using flatMap now.
// Using flatMap
val tupleRDD1 = inputRDD.flatMap(line => line.split(","))

// Printing the RDD
tupleRDD1.collect.foreach(println)

// Output
// CAT
// BAT
// RAT
// ELEPHANT
// RAT
// BAT
// BAT
// ... (remaining words follow)
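With the individual words available, finishing the word count mentioned earlier takes one more step. The following continuation is a sketch of the usual map-plus-reduceByKey pattern and is not part of the repository example; the counts follow from the sample data:

// Pair each word with 1, then sum the counts per word.
val wordCounts = tupleRDD1.map(word => (word, 1)).reduceByKey(_ + _)
wordCounts.collect.foreach(println)
// Output (ordering may vary):
// (CAT,4)
// (BAT,5)
// (RAT,6)
// (ELEPHANT,3)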
Complete Code
The complete example is also available in our GitHub repository: https://github.com/proedu-organisation/spark-scala-examples/blob/main/src/main/scala/rdd/transformations/FlatMapExample.scala
import org.apache.spark.sql.SparkSession

object FlatMapExample extends App {

  // Prepare the data.
  val data = Seq(
    "CAT,BAT,RAT,ELEPHANT",
    "RAT,BAT,BAT,BAT,CAT",
    "CAT,ELEPHANT,RAT,ELEPHANT",
    "RAT,RAT,RAT,BAT,CAT"
  )

  // Creating a SparkContext object.
  val sparkContext = SparkSession.builder()
    .master("local[*]")
    .appName("Proedu.co examples")
    .getOrCreate()
    .sparkContext

  val inputRDD = sparkContext.parallelize(data)
  inputRDD.collect.foreach(println)

  val tupleRDD = inputRDD.map(line => line.split(","))
  tupleRDD.collect.foreach(println)

  /** Output. map does not know how to flatten an Array.
   * [Ljava.lang.String;@2bef09c0
   * [Ljava.lang.String;@62ce72ff
   * [Ljava.lang.String;@58a63629
   * [Ljava.lang.String;@7de843ef
   */

  val tupleRDD1 = inputRDD.flatMap(line => line.split(","))
  tupleRDD1.collect.foreach(println)

  /** Output
   * CAT
   * BAT
   * RAT
   * ELEPHANT
   * RAT
   * BAT
   * BAT
   */
}
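As a quick sanity check on the difference between the two transformations, we can compare element counts (a small sketch; the numbers follow from the four sample lines and their eighteen words):

// map preserves the number of elements; flatMap can change it.
println(inputRDD.count())  // 4  (lines)
println(tupleRDD.count())  // 4  (one Array[String] per line)
println(tupleRDD1.count()) // 18 (individual words)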
Happy Learning 🙂