Map, FlatMap and Filter in Scala
This post is part 4 of the Functional Programming in Scala series. You can view all the posts in the series here.
TL;DR
Three of the most common methods used on collections in Scala are map, flatMap and filter:
map will apply the function given in parentheses to every element in the collection
filter will only return the elements of a collection that satisfy the predicate provided
flatMap will first perform the map method on a list and then the flatten method
We often chain these methods together to iterate over collections of elements.
A more readable way to write these chains is to use a for-comprehension.
Map, FlatMap and Filter in Scala
If you have done any significant programming in Scala, you will likely have used one or all of map, flatMap and filter at some point. We will look at all three of them in this post.
Let’s start by creating and printing out a very simple List:
val list = List(1, 2, 3)
println(list) // prints out 'List(1, 2, 3)'
Note that this list is created by calling the apply method on the List companion object.
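To make the point about apply concrete, the short sketch below compares the two spellings. The names viaSugar and viaApply are only illustrative and do not come from the original code.

```scala
// List(1, 2, 3) is shorthand for calling apply on the List companion object
val viaSugar = List(1, 2, 3)
val viaApply = List.apply(1, 2, 3)

println(viaSugar == viaApply) // prints out 'true'
```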
List comes with standard methods to extract the head (the first element) and the tail (everything after the first element):
println(list.head) // prints out '1'
println(list.tail) // prints out 'List(2, 3)'
Map
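We can call the map method on the list like so:
println(list.map(_ + 1)) // prints out 'List(2, 3, 4)'
// String interpolation avoids the deprecated Int + String concatenation
println(list.map(x => s"$x is a number")) // prints out 'List(1 is a number, 2 is a number, 3 is a number)'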
This will iterate over every element in the list, apply the function in the parentheses to it, and return a new list containing the results. The original list is left unchanged.
Another way of writing the call to map is with braces:
list.map { x =>
  x * 2
}
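As a rough illustration, the spellings below should all produce the same result, so the choice between them is mostly a matter of readability. The helper double and the val names are hypothetical and are not part of the original post.

```scala
val list = List(1, 2, 3) // same list as earlier in the post

// Hypothetical named helper, used only for this illustration
def double(x: Int): Int = x * 2

val withPlaceholder = list.map(_ * 2)         // placeholder syntax
val withLambda      = list.map(x => x * 2)    // explicit lambda
val withBraces      = list.map { x => x * 2 } // braces instead of parentheses
val withNamed       = list.map(double)        // a method passed as a function

println(withPlaceholder) // prints out 'List(2, 4, 6)'
```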
Filter
We can use the filter method if we only want to keep certain members of the list. For example, to only keep the even numbers:
println(list.filter(_ % 2 == 0)) // prints out 'List(2)'
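A small sketch of the same idea with a named predicate may help. Here isEven, evens and odds are hypothetical names; filterNot is the standard-library counterpart that keeps the elements the predicate rejects.

```scala
val list = List(1, 2, 3) // same list as earlier in the post

// Hypothetical predicate, named only for this illustration
val isEven = (x: Int) => x % 2 == 0

val evens = list.filter(isEven)    // keeps elements where the predicate is true
val odds  = list.filterNot(isEven) // keeps elements where the predicate is false

println(evens) // prints out 'List(2)'
println(odds)  // prints out 'List(1, 3)'
println(list)  // prints out 'List(1, 2, 3)', the original list is unchanged
```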
FlatMap
The flatMap method is similar to the map method. The difference is that the function we pass in returns a collection for each element, and flatMap removes that inner grouping so that the result is a single flat sequence. Let’s say that we have a function that turns a single element into another list:
val toPair = (x: Int) => List(x, x+1)
If we call flatMap on our original list with this function, we should see the concatenation of the results of applying toPair to each element of the list.
println(list.flatMap(toPair)) // prints out 'List(1, 2, 2, 3, 3, 4)'
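One way to see what flatMap does is to perform the two steps separately. This is only a sketch: mapped and flattened are hypothetical names, while list and toPair are redefined here so the snippet stands on its own.

```scala
val list = List(1, 2, 3)
val toPair = (x: Int) => List(x, x + 1)

// Step 1: map produces a list of lists
val mapped = list.map(toPair)
println(mapped) // prints out 'List(List(1, 2), List(2, 3), List(3, 4))'

// Step 2: flatten removes the inner grouping
val flattened = mapped.flatten
println(flattened) // prints out 'List(1, 2, 2, 3, 3, 4)'

// flatMap gives the same result in a single call
println(flattened == list.flatMap(toPair)) // prints out 'true'
```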
Print all Combinations of two Lists
Say that we wanted to print out all the combinations of the following two lists:
val numbers = List(1,2,3,4)
val chars = List('a', 'b', 'c', 'd')
In imperative programming, we might be tempted to do this with a couple of loops. In functional programming in Scala, we could instead do something like this:
val combinations = numbers.flatMap(n => chars.map(c => "" + c + n))
We use flatMap because, for each element in numbers, we are going to generate a new list. Then, for every element in chars, we return the string composed of the character and the number.
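To see why the outer call needs to be flatMap, it can help to compare it with what a plain map would return. In this sketch the name nested is hypothetical; the lists and the combining expression are the ones used above.

```scala
val numbers = List(1, 2, 3, 4)
val chars = List('a', 'b', 'c', 'd')

// Outer flatMap: one flat list of strings
val combinations = numbers.flatMap(n => chars.map(c => "" + c + n))
println(combinations.take(5)) // prints out 'List(a1, b1, c1, d1, a2)'
println(combinations.size)    // prints out '16'

// Outer map instead: a list of lists, one inner list per number
val nested = numbers.map(n => chars.map(c => "" + c + n))
println(nested.head) // prints out 'List(a1, b1, c1, d1)'
println(nested.size) // prints out '4'
```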
If you add a third list into the mix, such as colors:
val colors = List("black", "white")
Then you use two flatMap calls and put a map in the innermost block:
// Named differently from the two-list version so both can live in the same scope
val combinationsWithColor = numbers.flatMap(n => chars.flatMap(c => colors.map(color => "" + c + n + "-" + color)))
Foreach
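Next let’s look at the foreach method. This is similar to map, except that the result of the function it receives is discarded, so it is typically given a function that returns Unit and is called only for its side effects. The foreach call itself returns Unit rather than a new collection.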
This will print each number in the list on a separate line:
list.foreach(println)
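The difference from map can be sketched as follows. The mutable buffer seen is a hypothetical stand-in for any side effect (printing, logging, writing to a file), and is used here only because its contents are easy to inspect.

```scala
import scala.collection.mutable.ListBuffer

val list = List(1, 2, 3) // same list as earlier in the post

// Hypothetical buffer that makes the side effect visible
val seen = ListBuffer.empty[Int]

// foreach runs the function for its side effect and returns Unit
val result: Unit = list.foreach(x => seen += x)
println(seen.toList) // prints out 'List(1, 2, 3)'

// map, by contrast, returns a new list built from the function's results
val doubled = list.map(_ * 2)
println(doubled) // prints out 'List(2, 4, 6)'
```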
For-Comprehensions
The nested flatMap calls that we wrote above are quite difficult to read. In Scala, there is a more readable format that we can use, called a for-comprehension. We could rewrite the three-list combinations expression above like this:
val forCombinations = for {
  n <- numbers
  c <- chars
  color <- colors
} yield "" + c + n + "-" + color
This is exactly equivalent to the nested flatMap and map version that we wrote previously.
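The equivalence can be checked directly. The sketch below redefines the three lists so that it runs on its own; nested and comprehension are hypothetical names for the two versions.

```scala
val numbers = List(1, 2, 3, 4)
val chars = List('a', 'b', 'c', 'd')
val colors = List("black", "white")

// Explicit version: every generator but the last is a flatMap, the last is a map
val nested = numbers.flatMap(n => chars.flatMap(c => colors.map(color => "" + c + n + "-" + color)))

// For-comprehension version: one generator per line
val comprehension = for {
  n <- numbers
  c <- chars
  color <- colors
} yield "" + c + n + "-" + color

println(nested == comprehension) // prints out 'true'
println(comprehension.head)      // prints out 'a1-black'
println(comprehension.size)      // prints out '32'
```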
Filter in For-Comprehensions
If you want to filter something out, you can put a guard in the for-comprehension. For example, to only keep even numbers:
val forCombinationsGuard = for {
  n <- numbers if n % 2 == 0
  c <- chars
  color <- colors
} yield "" + c + n + "-" + color
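This applies a filter to numbers before the other generators run. (Strictly speaking, the compiler translates a guard into a call to withFilter, a lazy variant of filter, but the resulting list is the same.) If we rewrote this in full, it would look like this: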
val combinationsWithColorWithFilter = numbers.filter(_ % 2 == 0).flatMap(n => chars.flatMap(c => colors.map(color => "" + c + n + "-" + color)))
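As a sketch of how the guard relates to the explicit calls, the snippet below builds the result three ways and compares them. The names withGuard, viaWithFilter and viaFilter are hypothetical; withFilter is the standard-library method that guards are translated to.

```scala
val numbers = List(1, 2, 3, 4)
val chars = List('a', 'b', 'c', 'd')
val colors = List("black", "white")

// Guard inside the for-comprehension
val withGuard = for {
  n <- numbers if n % 2 == 0
  c <- chars
  color <- colors
} yield "" + c + n + "-" + color

// withFilter does not build an intermediate list
val viaWithFilter = numbers.withFilter(_ % 2 == 0).flatMap(n => chars.flatMap(c => colors.map(color => "" + c + n + "-" + color)))

// filter builds an intermediate List(2, 4) first, then continues as before
val viaFilter = numbers.filter(_ % 2 == 0).flatMap(n => chars.flatMap(c => colors.map(color => "" + c + n + "-" + color)))

println(withGuard == viaWithFilter) // prints out 'true'
println(withGuard == viaFilter)     // prints out 'true'
println(withGuard.head)             // prints out 'a2-black'
```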
If you want to do something with side effects in a for-comprehension, such as println, you can leave out the yield. The body then runs once per element and nothing is returned:
for {
  n <- numbers
} println(n)
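A for without yield is translated into foreach calls rather than flatMap and map. The sketch below illustrates this with two generators; the buffers fromFor and fromForeach are hypothetical and exist only so that the two forms can be compared.

```scala
import scala.collection.mutable.ListBuffer

val numbers = List(1, 2, 3, 4)
val chars = List('a', 'b', 'c', 'd')

// Hypothetical buffers that record the side effects of each form
val fromFor = ListBuffer.empty[String]
val fromForeach = ListBuffer.empty[String]

// for without yield: the body runs for every (n, c) pair
for {
  n <- numbers
  c <- chars
} fromFor += ("" + c + n)

// The same loop written with nested foreach calls
numbers.foreach(n => chars.foreach(c => fromForeach += ("" + c + n)))

println(fromFor == fromForeach) // prints out 'true'
println(fromFor.size)           // prints out '16'
```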
Source Code
As always, the source code for this post is available on GitHub.