Java Stream flatMap()


Java Stream flatMap() is an intermediate operation that applies a mapping function to every element of the stream, where the function returns a stream for each element, and then flattens all of those streams into a single stream.

Java Stream flatMap

Let’s look at the syntax of Stream flatMap() function.


<R> Stream<R> flatMap(Function<? super T, ? extends Stream<? extends R>> mapper);
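The signature requires the mapper to return a Stream for each input element. As a minimal sketch (the class name FlatMapMapperExample, the method words(), the TO_WORDS mapper and the sample sentences are all hypothetical), here is a mapper of type Function<String, Stream<String>> that splits each sentence into words, with flatMap() joining the per-sentence streams into one:

```java
import java.util.Arrays;
import java.util.List;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.Stream;

public class FlatMapMapperExample {

    // Hypothetical mapper: turns one sentence into a stream of its words
    private static final Function<String, Stream<String>> TO_WORDS =
            s -> Arrays.stream(s.split(" "));

    // flatMap() calls TO_WORDS on every sentence and concatenates the
    // resulting word streams, in order, into a single Stream<String>
    static List<String> words(List<String> sentences) {
        return sentences.stream()
                .flatMap(TO_WORDS)
                .collect(Collectors.toList());
    }

    public static void main(String[] args) {
        List<String> sentences = Arrays.asList("hello world", "java streams");
        System.out.println(words(sentences)); // prints [hello, world, java, streams]
    }
}
```

Because the mapper's result type is Stream<String>, the type parameter R of flatMap() is inferred as String, so the pipeline yields a Stream<String> rather than a stream of streams.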

In simple words, flatMap() is used when we have a stream of collections and want to flatten it into a stream of their elements. Using the map() function instead would give us a nested stream, for example a Stream<Stream<String>> instead of a Stream<String>.
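To make the difference between map() and flatMap() concrete, here is a minimal sketch (the class MapVsFlatMap and its two methods are hypothetical names) that applies both operations with the same List::stream function to two small lists:

```java
import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.Stream;

public class MapVsFlatMap {

    // map(List::stream) wraps each list in its own stream, so the result
    // is a Stream<Stream<String>> with one element per input list
    static long countWithMap(List<String> a, List<String> b) {
        Stream<Stream<String>> nested = Stream.of(a, b).map(List::stream);
        return nested.count();
    }

    // flatMap(List::stream) merges those inner streams into a single
    // Stream<String> with one element per string
    static List<String> flattenWithFlatMap(List<String> a, List<String> b) {
        Stream<String> flat = Stream.of(a, b).flatMap(List::stream);
        return flat.collect(Collectors.toList());
    }

    public static void main(String[] args) {
        List<String> l1 = Arrays.asList("a", "b");
        List<String> l2 = Arrays.asList("c", "d");
        System.out.println(countWithMap(l1, l2));       // prints 2
        System.out.println(flattenWithFlatMap(l1, l2)); // prints [a, b, c, d]
    }
}
```

The map() version still has two elements (the two inner streams), while the flatMap() version has four (the individual strings).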

Let’s look at an example to better understand this scenario. Suppose we have a few lists of Strings:


List<String> l1 = Arrays.asList("a","b");
List<String> l2 = Arrays.asList("c","d");

Now we want to merge these lists into a new list of Strings, with every letter changed to uppercase. Since we have multiple lists, we first have to merge them into a single list and then apply the map() function, as in the code below:


List<String> l = new ArrayList<>();
l.addAll(l1);
l.addAll(l2);

List<String> letters = l.stream()
			.map(String::toUpperCase)
			.collect(Collectors.toList());

This takes extra work: we have to merge the lists manually into a single list of elements before we can apply the map() function. This is where flatMap() is useful. Let’s see how we can use flatMap() to perform the same operation.


List<String> betterLetters = Stream.of(l1, l2)
				.flatMap(List::stream)
				.map(String::toUpperCase)
				.collect(Collectors.toList());

Here flatMap(List::stream) turns the Stream<List<String>> created by Stream.of(l1, l2) into a Stream<String> of the individual letters, which map() then converts to uppercase.
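Stream.of(l1, l2) works when the lists are known up front. When the lists already sit inside a List<List<T>>, the same flatMap(List::stream) step applies to any number of them. The sketch below assumes a hypothetical helper named flatten() in a hypothetical class FlattenHelper:

```java
import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;

public class FlattenHelper {

    // Hypothetical helper: flattens a list of lists into one list,
    // keeping the order of the outer list and of each inner list
    static <T> List<T> flatten(List<List<T>> lists) {
        return lists.stream()
                .flatMap(List::stream)
                .collect(Collectors.toList());
    }

    public static void main(String[] args) {
        List<List<String>> nested = Arrays.asList(
                Arrays.asList("a", "b"),
                Arrays.asList("c", "d"),
                Arrays.asList("e"));

        // Flatten, then convert to uppercase, like the betterLetters pipeline
        List<String> upper = flatten(nested).stream()
                .map(String::toUpperCase)
                .collect(Collectors.toList());

        System.out.println(upper); // prints [A, B, C, D, E]
    }
}
```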

Java Stream flatMap() Real Life Example

Let’s look at a real-life example where the flatMap() function is really helpful. Suppose we have a State class that contains a list of cities, and we have a list of States from which we want the list of all cities. flatMap() saves us from writing nested for loops and iterating over the lists manually. Below is a complete example of this scenario.


package com.journaldev.streams;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;

public class JavaStreamFlatMapAggregateExample {
	public static void main(String[] args) {
		State karnataka = new State();
		karnataka.addCity("Bangalore");
		karnataka.addCity("Mysore");

		State punjab = new State();
		punjab.addCity("Chandigarh");
		punjab.addCity("Ludhiana");

		List<State> allStates = Arrays.asList(karnataka, punjab);

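		// Java Stream flatMap() way: each State is mapped to a stream of its
		// cities, and flatMap() merges those streams into one Stream<String>
		List<String> allCities = allStates.stream()
				.flatMap(state -> state.getCities().stream())
				.collect(Collectors.toList());

		System.out.println(allCities); // [Bangalore, Mysore, Chandigarh, Ludhiana]

		// Legacy way: nested for loops over the states and their cities
		allCities = new ArrayList<>();
		for (State state : allStates) {
			for (String city : state.getCities()) {
				allCities.add(city);
			}
		}
		System.out.println(allCities); // same output as above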
	}

}

class State {
	private List<String> cities = new ArrayList<>();

	public void addCity(String city) {
		cities.add(city);
	}

	public List<String> getCities() {
		return this.cities;
	}
}

As the example shows, flatMap() is useful whenever we work with a list of lists and need a single, flat list of their elements.
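Each additional level of nesting needs one more flatMap() step. As a minimal sketch, assuming the cities are grouped into plain nested lists (regions, then states, then cities) instead of State objects, and using the hypothetical names DeepFlatMapExample and flattenTwice(), two chained flatMap() calls reduce three levels to one:

```java
import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;

public class DeepFlatMapExample {

    // List<List<List<String>>> -> Stream<List<List<String>>>
    //   -> Stream<List<String>> -> Stream<String>
    static List<String> flattenTwice(List<List<List<String>>> data) {
        return data.stream()
                .flatMap(List::stream) // Stream<List<String>>: one element per state
                .flatMap(List::stream) // Stream<String>: one element per city
                .collect(Collectors.toList());
    }

    public static void main(String[] args) {
        // Hypothetical grouping: regions -> states -> cities
        List<List<List<String>>> regions = Arrays.asList(
                Arrays.asList(Arrays.asList("Bangalore", "Mysore")),
                Arrays.asList(Arrays.asList("Chandigarh", "Ludhiana")));

        System.out.println(flattenTwice(regions));
        // prints [Bangalore, Mysore, Chandigarh, Ludhiana]
    }
}
```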

You can download the example code from my GitHub Repository.

Reference: API Doc
