Spark RDD Cheat Sheet with Scala

A cheat sheet on Spark RDD operations with Scala.

The main abstraction Spark provides is a resilient distributed dataset (RDD), which is a collection of elements partitioned across the nodes of the cluster that can be operated on in parallel. RDDs are created by starting with a file in the Hadoop file system (or any other Hadoop-supported file system), or an existing collection in the driver program, and transforming it. Users may also ask Spark to persist an RDD in memory, allowing it to be reused efficiently across parallel operations. Finally, RDDs automatically recover from node failures.

Cf. the Spark documentation.
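
As a minimal sketch of that description (reusing the SparkSession built in the Load Data as RDD section below), an RDD can also be created from an existing collection in the driver with parallelize and kept in memory with persist:

import org.apache.spark.storage.StorageLevel

// Build an RDD from a collection in the driver program
val numbers = spark.sparkContext.parallelize(1 to 100)

// Ask Spark to keep it in memory so later actions reuse it
numbers.persist(StorageLevel.MEMORY_ONLY)

println(numbers.sum())   // first action computes and caches the RDD
println(numbers.count()) // second action reads the cached data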

  • Dataset: kaggle.com/fedesoriano/heart-failure-predic..
  • A simple CSV (comma-separated values) dataset for heart failure prediction from Kaggle
  • Tip: If you run the code snippets in spark-shell, you can just call [rdd].collect() instead of [rdd].collect().foreach(println)

Dataset preview

Age,Sex,ChestPainType,RestingBP,Cholesterol,FastingBS,RestingECG,MaxHR,ExerciseAngina,Oldpeak,ST_Slope,HeartDisease
40,M,ATA,140,289,0,Normal,172,N,0,Up,0
49,F,NAP,160,180,0,Normal,156,N,1,Flat,1
37,M,ATA,130,283,0,ST,98,N,0,Up,0
48,F,ASY,138,214,0,Normal,108,Y,1.5,Flat,1
54,M,NAP,150,195,0,ST,122,N,0,Up,0
39,M,NAP,120,339,0,ST,170,N,0,Up,0

Load Data as RDD

import org.apache.spark.sql.SparkSession

val spark = SparkSession
    .builder()
    .appName("Spark RDD Cheat Sheet with scala")
    .master("local")
    .getOrCreate()

val rdd = spark.sparkContext.textFile("data/heart.csv")

Map

val rdd = spark.sparkContext.textFile("data/heart.csv")

rdd
  .map(line => line)
  .collect()
  .foreach(println)
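
The identity map above returns each line unchanged; in practice map transforms every element. A minimal sketch on the same dataset, turning each row into a (Sex, Age) pair (column positions taken from the preview above):

// Parse each data row into a (Sex, Age) pair
rdd
  .filter(line => !line.startsWith("Age")) // drop the header row
  .map(line => {
    val cols = line.split(",")
    (cols(1), cols(0).toInt)
  })
  .take(5)
  .foreach(println)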

FlatMap

rdd
  .flatMap(line => line.split(","))
  .collect()
  .foreach(println)
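
Where map produces exactly one output element per input line, flatMap flattens the iterators it returns, so the element count changes. A quick check:

println(rdd.count())                        // one element per line
println(rdd.flatMap(_.split(",")).count())  // one element per CSV value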

Map Partitions

val collection = spark.sparkContext.parallelize(Array.range(1, 11), 2)

collection
  .mapPartitions(partition => partition)
  .collect()
  .foreach(println)
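
The identity call above just passes each partition through. mapPartitions is typically used when some work should happen once per partition rather than once per element; a small sketch that emits one sum per partition:

// One sum per partition of the 1..10 collection (two partitions here)
collection
  .mapPartitions(partition => Iterator(partition.sum))
  .collect()
  .foreach(println)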

Map Partitions With Index

val collection = spark.sparkContext.parallelize(Array.range(1, 11), 2)

collection
  .mapPartitionsWithIndex((index, partition) =>
    if (index == 0) partition
    else Iterator()
  )
  .collect()
  .foreach(println)

For Each Partition

val collection = spark.sparkContext.parallelize(Array.range(1, 11), 2)

collection.foreachPartition(p => {
  p.toArray.foreach(println)
  println()
})

ReduceByKey

// Count the number of M and F values in the Sex column
rdd
  .map(line => (line.split(",")(1), 1))
  .reduceByKey((a, b) => a + b)
  .collect()
  .foreach(println)
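
reduceByKey also covers the classic sum-and-count pattern. A sketch computing the average cholesterol per sex, with the header row filtered out and column positions taken from the preview above:

// Average cholesterol per sex: reduce (sum, count) pairs, then divide
rdd
  .filter(line => !line.startsWith("Age"))
  .map(line => {
    val cols = line.split(",")
    (cols(1), (cols(4).toDouble, 1)) // (Sex, (Cholesterol, 1))
  })
  .reduceByKey((a, b) => (a._1 + b._1, a._2 + b._2))
  .mapValues { case (sum, count) => sum / count }
  .collect()
  .foreach(println)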

Filter

// Female observations with chestPainType = ATA and RestingECG = Normal
// column indexes: Sex = 1, ChestPainType = 2, RestingECG = 6

println(
  rdd
    .map(line => {
      val splitted_line = line.split(",")

      (splitted_line(1), splitted_line(2), splitted_line(6))
    })
    .filter(line =>
      line._1 == "F" && line._2 == "ATA" && line._3 == "Normal"
    )
    .count()
)

Sample

val rdd_sample = rdd.sample(false, 0.2, 42)

rdd_sample.collect().foreach(println)
println(rdd_sample.count())
println(rdd.count())

Union

val sample1 = rdd.sample(withReplacement = false, fraction = 0.3, seed = 42)
val sample2 = rdd.sample(withReplacement = false, fraction = 0.7, seed = 42)

println(f"Sample 1: ${sample1.count()}, Sample 2: ${sample2.count()}")

val union_sample = sample1.union(sample2)

println(f"Union Sample: ${union_sample.count()}")

Intersection

val sample1 = rdd.sample(withReplacement = false, fraction = 0.6, seed = 42)
val sample2 = rdd.sample(withReplacement = false, fraction = 0.6, seed = 41)

println(f"Sample 1: ${sample1.count()}, Sample 2: ${sample2.count()}")

val intersection_rdd = sample1.intersection(sample2)

println(f"Intersection Sample: ${intersection_rdd.count()}")

Distinct

val sample1 = rdd.sample(withReplacement = true, fraction = 0.8, seed = 42)
val sample2 = rdd.sample(withReplacement = true, fraction = 0.8, seed = 4)

println(s"Sample 1: ${sample1.count()}, Sample 2: ${sample2.count()}")

val union_sample = sample1.union(sample2)

println(f"Union Sample: ${union_sample.count()}")

val distinct_sample = union_sample.distinct()

println(f"Distint Sample: ${distinct_sample.count()}")

GroupBy

rdd
  .map(line => line.split(","))
  .groupBy(line => line(6))
  .map(line => (line._1, line._2.size))
  .collect()
  .foreach(line => println(f"${line._1}: ${line._2}"))
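
groupBy shuffles every record for a key to one place before counting; when only a reduction such as a count is needed, the same result can usually be obtained more cheaply with reduceByKey. A sketch producing the same RestingECG counts:

// Same counts without materializing the full groups
rdd
  .map(line => (line.split(",")(6), 1))
  .reduceByKey(_ + _)
  .collect()
  .foreach { case (ecg, count) => println(f"$ecg: $count") }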

Aggregate

// Number of data rows (the CSV header line is excluded)
val N = rdd.count() - 1
println(N)

def sumAggPartition = (partitionAccumulator: Int, currentValue: Int) =>
  partitionAccumulator + currentValue
def sumAggGlobal = (globalAccumulator: Int, currentValue: Int) =>
  (globalAccumulator + currentValue)

val mean = rdd
  .map(_.split(",")(0))
  .filter(line => !line.contains("Age"))
  .map(line => line.toInt)
  .aggregate(0)(sumAggPartition, sumAggGlobal)
  .toDouble
  ./(N)

println(mean)

val mean2 = rdd
  .map(_.split(",")(0))
  .filter(line => !line.contains("Age"))
  .map(line => line.toInt)
  .reduce((a, b) => a + b)
  .toDouble
  ./(N)

println(mean2)
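
Both versions above still need the separate count N to finish the mean. aggregate can carry the count along in a (sum, count) accumulator, giving the mean in a single pass over the data; a minimal sketch:

// Single-pass mean: the accumulator is a (sum, count) tuple
val (ageSum, ageCount) = rdd
  .map(_.split(",")(0))
  .filter(line => !line.contains("Age"))
  .map(_.toInt)
  .aggregate((0, 0))(
    (acc, age) => (acc._1 + age, acc._2 + 1), // within a partition
    (a, b) => (a._1 + b._1, a._2 + b._2)      // across partitions
  )

println(ageSum.toDouble / ageCount)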

Aggregate (2)

val collection = spark.sparkContext.parallelize(Array.range(1, 11), 2)

println(
  collection.aggregate(0)(
    (acc, value) => {
      println(f"from-seqOp: $acc, $value")
      acc + value
    },
    (endAcc, endValue) => {
      println(f"from-combOp: $endAcc, $endValue")
      endAcc + endValue
    }
  )
)

Sort By

rdd
  .map(_.split(",")(0))
  .filter(line => !line.contains("Age"))
  .map(line => line.toInt)
  .sortBy(line => line)
  .collect()
  .foreach(println)
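
sortBy sorts in ascending order by default; passing ascending = false reverses the order. A short sketch listing the five highest ages:

rdd
  .map(_.split(",")(0))
  .filter(line => !line.contains("Age"))
  .map(_.toInt)
  .sortBy(age => age, ascending = false)
  .take(5)
  .foreach(println)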

Save As Text File

// Add a new "Id" column; zipWithIndex assigns each line a stable index,
// which works across partitions unlike a mutable counter captured in map
rdd
  .zipWithIndex()
  .map { case (line, index) =>
    if (index == 0) f"Id,$line" else f"$index,$line"
  }
  .saveAsTextFile("data/heart2")
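
saveAsTextFile writes a directory with one part file per partition. If a single output file is wanted, the RDD can be coalesced to one partition first; a sketch writing to a hypothetical data/heart2_single directory:

// One partition -> one part file inside data/heart2_single
rdd
  .zipWithIndex()
  .map { case (line, index) => if (index == 0) f"Id,$line" else f"$index,$line" }
  .coalesce(1)
  .saveAsTextFile("data/heart2_single")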

Join

// saveAsTextFile above wrote the directory data/heart2 (one part file per partition),
// which textFile can read back; Patient is the case class defined under Advanced examples
val rdd2 = spark.sparkContext.textFile("data/heart2")

val part1 = rdd2
  .map(_.split(","))
  .map(line =>
    (
      line(0),
      Patient(
        id = line(0),
        sex = line(2),
        chestPainType = line(3),
        restingECG = line(7)
      )
    )
  )

val part2 = rdd2
  .map(_.split(","))
  .map(line =>
    (
      line(0),
      Patient(
        id = line(0),
        age = line(1),
        restingBP = line(4),
        cholesterol = line(5),
        fastingBS = line(6)
      )
    )
  )

part1.take(3).foreach(println)
println()

part2.take(3).foreach(println)
println()

val joinned_part = part1.join(part2)
joinned_part.take(3).foreach(println)
println()

CoGroup VS Join VS Cartesian

val part1 =
  spark.sparkContext.parallelize(
    Seq(("A", "Diaf-From-1"), ("A", "Diaf-From-1"), ("B", "Yeah-From-1"))
  )
val part2 =
  spark.sparkContext.parallelize(
    Seq(("A", "Bro-From-2"), ("A", "Walabook-From-2"))
  )

val cogroupped_part = part1.cogroup(part2)
cogroupped_part.collect().foreach(println)
println()

val joinned_part = part1.join(part2)
joinned_part.collect().foreach(println)
println()

val cartesian_part = part1.cartesian(part2)
cartesian_part.collect().foreach(println)
println()
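
join keeps only keys that appear on both sides; the outer variants also keep unmatched keys. A sketch on the same pair RDDs:

// leftOuterJoin keeps every key from part1, with None where part2 has no match
val left_outer_part = part1.leftOuterJoin(part2)
left_outer_part.collect().foreach(println)
println()

// fullOuterJoin keeps keys from both sides
val full_outer_part = part1.fullOuterJoin(part2)
full_outer_part.collect().foreach(println)
println()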

Pipe

val collection =
  spark.sparkContext.parallelize(
    Seq(("A", "Diaf-From-1"), ("A", "Diaf-From-1"), ("B", "Yeah-From-1"))
  )

// pipe() streams each partition's elements to the external command's stdin;
// `head` here reads from its file argument instead of stdin, so it simply
// prints the first five lines of heart.csv for each partition
collection.pipe("head -n 5 data/heart.csv").collect().foreach(println)

Glom

// glom() turns each partition into an array, yielding one array per partition
val collection = spark.sparkContext.parallelize(Array.range(1, 11), 2)
collection
  .glom()
  .collect()
  .foreach(p => {
    println
    p.foreach(i => print(f"$i "))
    println
  })

Coalesce

val collectionWithThreePartitions = spark.sparkContext.parallelize(Array.range(1, 1001), 3)

val collectionWithOnePartition = collectionWithThreePartitions.coalesce(1)

println(f"Before Coalesce: ${collectionWithThreePartitions.getNumPartitions}")
println(f"After Coalesce: ${collectionWithOnePartition.getNumPartitions}")

Repartition

val collectionWith3Partitions = spark.sparkContext.parallelize(Array.range(1, 1001), 3)

val collectionWith2Partitions = collectionWith3Partitions.repartition(2)

println(f"Before Repartition: ${collectionWith3Partitions.getNumPartitions}")
println(f"After Repartition: ${collectionWith2Partitions.getNumPartitions}")

Repartition And Sort Within Partitions

import org.apache.spark.RangePartitioner

// Collection with keys from 1 to 15
val collectionBeforeRepartition = spark.sparkContext.parallelize(
  Seq(
    (7, 7), (15, 15), (14, 14), (2, 2), (9, 9),
    (1, 1), (10, 10), (3, 3), (6, 6), (8, 8),
    (4, 4), (12, 12), (13, 13), (11, 11), (5, 5)
  )
)

collectionBeforeRepartition.foreachPartition(p => {
  println()
  print(f"Before Partition: ")
  p.toArray.foreach(i => print(f"$i "))
  println()
})

val collectionAfterRepartition =
  collectionBeforeRepartition.repartitionAndSortWithinPartitions(
    new RangePartitioner(2, collectionBeforeRepartition)
  )

collectionAfterRepartition.foreachPartition(p => {
  println()
  print(f"After Partition: ")
  p.toArray.foreach(i => print(f"$i "))
  println()
})

Advanced examples

Create an RDD Schema from a Case Class

// Case Class
case class Patient(
  id: String = "",
  age: String = "",
  sex: String = "",
  chestPainType: String = "",
  restingBP: String = "",
  cholesterol: String = "",
  fastingBS: String = "",
  restingECG: String = "",
  maxHR: String = "",
  exerciseAngina: String = "",
  oldpeak: String = ""
)

// Filter using case class fields and pattern matching
println(
  rdd
    .map(line => line.split(","))
    .map(line =>
      Patient(
        sex = line(1),
        chestPainType = line(2),
        restingECG = line(6)
      )
    )
    .filter(patient =>
      (
        patient.chestPainType match {
          case "ATA" => true
          case _     => false
        }
      ) &&
        (
          patient.restingECG match {
            case "Normal" => true
            case _        => false
          }
        ) &&
        (
          patient.sex match {
            case "F" => true
            case _   => false
          }
        )
    )
    .count()
)

// A simpler example using filter with the case class

println(
  rdd
    .map(_.split(","))
    .map(line =>
      Patient(sex = line(1), chestPainType = line(2), restingECG = line(6))
    )
    .filter(patient =>
      patient.sex == "F" && patient.chestPainType == "ATA" && patient.restingECG == "Normal"
    )
    .count()
)