Improving Coroutines stateIn operator

State management has always been a hot topic in the Android world: which architecture to use, how to model UI state, how to make everything lifecycle-aware.
Thanks to recent efforts from Google and JetBrains, we now have a set of recommendations and technologies that make it easier than ever. Still, some limitations and shortcomings remain that we can try to improve.

Let's imagine a modern tech stack: Jetpack Compose to define the UI, a Jetpack ViewModel to control it, and Kotlin Coroutines to exchange data between those two layers.

This is how the UI could be defined:

@Composable
fun MyScreen() {
    val viewModel: MyViewModel = viewModel()
    val state: MyState by viewModel.state.collectAsStateWithLifecycle()
    ...
}

The recommendation is to use collectAsStateWithLifecycle() so that the Composable stops collecting based on the Lifecycle state and doesn't waste resources. By default, it collects only while the Lifecycle is at least in the STARTED state and stops once it drops below that. You can read more about it in a great article by Manuel Vivo.

We can do the same optimization in the ViewModel layer: make it somewhat Lifecycle-aware by using the stateIn operator with a stop timeout on any flows we expose to the UI.

class MyViewModel(
    repo: MyRepository
) : ViewModel() {

    val state: StateFlow<MyState> = repo.data1()
        .map { /* Map to MyState */ }
        .stateIn(
            scope = viewModelScope,
            started = SharingStarted.WhileSubscribed(stopTimeoutMillis = 5_000),
            initialValue = MyState()
        )
    ...
}

Thanks to stateIn and the timeout behavior of WhileSubscribed, we can differentiate between situations like changing the orientation of the screen and putting the app into the background. If the StateFlow has no subscribers for the specified amount of time, it cancels the upstream data1 flow to avoid wasting resources.
We can configure the stateIn operator with three different strategies: it can start eagerly, lazily, or while subscribed with a specified timeout.
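
To see the stop timeout in isolation, here is a minimal sketch that runs on the plain JVM without Android. It assumes only kotlinx-coroutines-core. The function name countUpstreamStarts, the ticking upstream flow, and the short delays are made up for illustration, and sharingScope is just a stand-in for viewModelScope. It counts how often the upstream flow is started while subscribers come and go.

```kotlin
import kotlinx.coroutines.*
import kotlinx.coroutines.flow.*

// Hypothetical helper: returns the number of upstream starts seen
// after a short gap and after a long gap without subscribers.
fun countUpstreamStarts(): Pair<Int, Int> = runBlocking {
    var upstreamStarts = 0
    val upstream = flow {
        upstreamStarts++ // incremented on every (re)start of the upstream
        var tick = 0
        while (true) {
            emit(tick++)
            delay(20)
        }
    }

    // Stand-in for viewModelScope. It reuses the runBlocking dispatcher,
    // so everything in this demo runs on a single thread.
    val sharingScope = CoroutineScope(coroutineContext + Job())
    val state = upstream.stateIn(
        scope = sharingScope,
        started = SharingStarted.WhileSubscribed(stopTimeoutMillis = 200),
        initialValue = -1,
    )

    // First subscriber: starts the upstream.
    val first = launch { state.collect { } }
    delay(100)
    first.cancelAndJoin()

    // Gap shorter than the stop timeout (think: screen rotation).
    delay(50)
    val second = launch { state.collect { } }
    delay(100)
    second.cancelAndJoin()
    val afterShortGap = upstreamStarts

    // Gap longer than the stop timeout (think: app in the background).
    delay(400)
    val third = launch { state.collect { } }
    delay(100)
    third.cancelAndJoin()
    val afterLongGap = upstreamStarts

    sharingScope.cancel()
    afterShortGap to afterLongGap
}

fun main() {
    println(countUpstreamStarts())
}
```

With these timings the upstream should survive the short gap and be restarted only after the long one, so the pair is expected to be (1, 2).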

Combining collectAsStateWithLifecycle and stateIn gives you great control over how your application uses resources.

The problem - complex state management

The current design of stateIn encourages exposing a separate stream from the ViewModel for each data stream or other value you want to cache. But the more complex the UI, the harder it becomes to manage this way. As you can see below, even a simple example already requires quite a bit of boilerplate.

class MyViewModel(
    repo: MyRepository
) : ViewModel() {

    val data1: StateFlow<List<Data1>> = repo.data1()
        .stateIn(
            scope = viewModelScope,
            started = SharingStarted.WhileSubscribed(stopTimeoutMillis = 5_000),
            initialValue = emptyList()
        )
        
    val data2: StateFlow<List<Data2>> = repo.data2()
        .stateIn(
            scope = viewModelScope,
            started = SharingStarted.WhileSubscribed(stopTimeoutMillis = 5_000),
            initialValue = emptyList()
        )

    val showDialog: StateFlow<Boolean>
        get() = _showDialog
    private val _showDialog: MutableStateFlow<Boolean> = MutableStateFlow(false)
    
    val navigateToHome: StateFlow<Boolean>
        get() = _navigateToHome
    private val _navigateToHome: MutableStateFlow<Boolean> = MutableStateFlow(false)
    ...
}


A great way to mitigate this is to use a single UI state object that represents the whole UI.
Having a single data class that describes the screen is something that can improve readability, ease of state management and testing for the UI.

With this convention, the code above becomes:

class MyViewModel(
    repo: MyRepository
) : ViewModel() {

    val state: StateFlow<MyState> 
        get() = _state
    private val _state: MutableStateFlow<MyState> = MutableStateFlow(MyState())

    init {
        repo.data1()
            // Name the emitted value so it isn't shadowed by `it` inside update.
            .onEach { data -> _state.update { it.copy(data1 = data) } }
            .launchIn(viewModelScope)
        repo.data2()
            .onEach { data -> _state.update { it.copy(data2 = data) } }
            .launchIn(viewModelScope)
    }
    ...
}

data class MyState(
    val data1: List<Data1> = emptyList(),
    val data2: List<Data2> = emptyList(),
    val showDialog: Boolean = false,
    val navigateToHome: Boolean = false,
)

stateIn doesn't play well with that idea, but we still have some options with vanilla Coroutines. We could use the combine function, but depending on the number of streams, handling it can be awkward. combine only has typed overloads for up to five flows, so with more than four streams plus your MutableStateFlow you're forced to unpack and cast the values from an array of Any.

private val _state: MutableStateFlow<MyState> = MutableStateFlow(MyState())
val state: StateFlow<MyState> = combine(
    _state,
    repo.fetchData1(),
    ...
    repo.fetchData5(),
) { values ->
    (values[0] as MyState).copy(
        data1 = values[1] as Data1,
        ...
        data5 = values[5] as Data5,
    )
}.stateIn(
    scope = viewModelScope,
    started = SharingStarted.WhileSubscribed(timeout),
    initialValue = MyState()
)

Adding more typed overloads would work, but you'd have to write and maintain a lot of them. Alternatively, we can do it with a single custom extension that works for any number of streams.
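
To make the difference between the two combine variants concrete, here is a small self-contained sketch. The Dashboard class and the flowOf inputs are hypothetical and exist only for this illustration. The first function uses a typed overload; the second passes six flows, so it falls back to the vararg overload, where every value has to be cast back from an array.

```kotlin
import kotlinx.coroutines.flow.*
import kotlinx.coroutines.runBlocking

// Hypothetical state class used only for this illustration.
data class Dashboard(
    val a: Int,
    val b: String,
    val c: Boolean,
    val d: Long,
    val e: Double,
    val f: Char,
)

// Two flows: the typed overload keeps the element types.
fun typedCombine(): Flow<String> =
    combine(flowOf(1), flowOf("two")) { number, text -> "$number-$text" }

// Six flows: only the vararg overload applies, so the lambda
// receives an Array<Any> and every element needs a cast.
fun arrayCombine(): Flow<Dashboard> = combine<Any, Dashboard>(
    flowOf(1), flowOf("two"), flowOf(true), flowOf(4L), flowOf(5.0), flowOf('6'),
) { values ->
    Dashboard(
        a = values[0] as Int,
        b = values[1] as String,
        c = values[2] as Boolean,
        d = values[3] as Long,
        e = values[4] as Double,
        f = values[5] as Char,
    )
}

fun main() = runBlocking {
    println(typedCombine().first())
    println(arrayCombine().first())
}
```

The casts are checked only at runtime, so reordering the flows without reordering the indices compiles fine and fails later. That is the maintenance cost the extension below tries to avoid.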

Solution - extension functions to the rescue!

Thanks to the design of the Kotlin language, we can solve this problem quite elegantly.
We can write our own version of the stateIn operator as an extension on the existing MutableStateFlow that holds our single UI state object, one that can merge any streams we want into it. Let's call it stateInMerge. We can model its API on the one from the Coroutines library.

Here's the example from above with our new extension.

class MyViewModel(
    repo: MyRepository
) : ViewModel() {

    val state: StateFlow<MyState> 
        get() = _state
    private val _state: MutableStateFlow<MyState> = MutableStateFlow(MyState())
        .stateInMerge(
            scope = viewModelScope,
            launched = Launched.WhileSubscribed(stopTimeoutMillis = 5_000),
            { repo.data1().onEach { data -> state.update { it.copy(data1 = data) } } }, 
            { repo.data2().onEach { data -> state.update { it.copy(data2 = data) } } }, 
            ...
        )
    ...
}

With this approach, we can use our private _state field in the rest of the ViewModel code as we normally would to update any other values in the UI state.

class MyViewModel(repo: MyRepository) : ViewModel() {
    ...
    
    fun dialogClicked() {
        _state.update { it.copy(showDialog = true) } 
    }
    
    ...
}

Additionally, we can create a shorthand for onEach with a state update, so we don't have to nest the lambdas.

class MyViewModel(
    repo: MyRepository
) : ViewModel() {

    val state: StateFlow<MyState> 
        get() = _state
    private val _state: MutableStateFlow<MyState> = MutableStateFlow(MyState())
        .stateInMerge(
            scope = viewModelScope,
            launched = Launched.WhileSubscribed(stopTimeoutMillis = 5_000),
            { repo.data1().onEachToState { state, data -> state.copy(data1 = data) } }, 
            { repo.data2().onEachToState { state, data -> state.copy(data2 = data) } }, 
        )
    ...
}

Also, if you want to use different strategies for different streams, you can chain stateInMerge calls on the same MutableStateFlow object.

private val _state: MutableStateFlow<MyState> = MutableStateFlow(MyState())
    .stateInMerge(
        scope = viewModelScope,
        launched = Launched.Lazily,
        { repo.data1().onEachToState { state, data -> state.copy(data1 = data) } },
        ...
    )
    .stateInMerge(
        scope = viewModelScope,
        launched = Launched.WhileSubscribed(stopTimeoutMillis = 5_000),
        { repo.data2().onEachToState { state, data -> state.copy(data2 = data) } }, 
        ...
    )

Implementation

In short, we have to write our own implementation of the stateIn behavior that also merges the streams into our MutableStateFlow. Let's look at the signature of our extension.

fun <T> MutableStateFlow<T>.stateInMerge(
    scope: CoroutineScope,
    launched: Launched,
    vararg flows: StateInMergeContext<T>.() -> Flow<*>,
): MutableStateFlow<T> = MutableStateFlowWithStateInMerge(
    scope = scope,
    state = this,
    launched = launched,
    lambdas = flows,
)

The function takes three parameters: a coroutine scope, a launch strategy (eagerly, lazily, or while subscribed), and a vararg of lambdas, each returning a Flow of any type. We don't care about the element type of the flow. What matters is that the flow returned from the lambda already carries the operation that updates the UI state. We can see that in our first example:

.stateInMerge(
    ...
    { repo.data1().onEach { data -> state.update { it.copy(data1 = data) } } }, 
    ...
)

But for this to happen, the lambdas need access to our MutableStateFlow object. That's why StateInMergeContext<T> is the receiver of the lambdas. It's an interface that exposes the state object and the onEachToState shorthand.

interface StateInMergeContext<T> {
    val state: MutableStateFlow<T>
    fun <R> Flow<R>.onEachToState(mapper: (T, R) -> T): Flow<R>
}
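
Before the full implementation, a stripped-down sketch may help show how the receiver gives each lambda access to state. Everything here is hypothetical and simplified: SimpleMergeContext, mergeEagerly, and CounterState are not part of the real extension, and the helper only launches the flows eagerly without any subscriber tracking.

```kotlin
import kotlinx.coroutines.*
import kotlinx.coroutines.flow.*

// Hypothetical UI state for the demo.
data class CounterState(val ticks: Int = 0, val label: String = "")

// Hypothetical, reduced counterpart of StateInMergeContext.
interface SimpleMergeContext<T> {
    val state: MutableStateFlow<T>
}

// Hypothetical eager-only helper: builds each flow with the context
// as receiver and launches it in the given scope.
fun <T> MutableStateFlow<T>.mergeEagerly(
    scope: CoroutineScope,
    vararg flows: SimpleMergeContext<T>.() -> Flow<*>,
): MutableStateFlow<T> {
    val target = this
    val context = object : SimpleMergeContext<T> {
        override val state: MutableStateFlow<T> = target
    }
    flows.forEach { produceFlow -> context.produceFlow().launchIn(scope) }
    return this
}

fun runMergeDemo(): CounterState = runBlocking {
    val uiState = MutableStateFlow(CounterState())
    // coroutineScope returns once both finite flows have completed.
    coroutineScope {
        uiState.mergeEagerly(
            this,
            // `state` below resolves to SimpleMergeContext.state, the lambda's receiver.
            { flowOf(1, 2, 3).onEach { tick -> state.update { it.copy(ticks = tick) } } },
            { flowOf("ready").onEach { text -> state.update { it.copy(label = text) } } },
        )
    }
    uiState.value
}

fun main() {
    println(runMergeDemo())
}
```

Both flows write into the same state object, so the final value holds the last tick and the label. The real extension adds the launch strategies on top of this idea.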

Here's the full implementation. There isn't much code in this solution, so if you want to use it, you can just copy it into your project and enjoy your new extension.

https://gist.github.com/tomczyn/d8f23c5e313d40c45fef87935c9c14cc

fun <T> MutableStateFlow<T>.stateInMerge(
    scope: CoroutineScope,
    launched: Launched,
    vararg flows: StateInMergeContext<T>.() -> Flow<*>,
): MutableStateFlow<T> = MutableStateFlowWithStateInMerge(
    scope = scope,
    state = this,
    launched = launched,
    lambdas = flows,
)

interface StateInMergeContext<T> {
    val state: MutableStateFlow<T>
    fun <R> Flow<R>.onEachToState(mapper: (T, R) -> T): Flow<R>
}

sealed interface Launched {
    data object Eagerly : Launched
    data class WhileSubscribed(val stopTimeoutMillis: Long = 0L) : Launched
    data object Lazily : Launched
}

private class MutableStateFlowWithStateInMerge<T>(
    private val scope: CoroutineScope,
    launched: Launched,
    private val state: MutableStateFlow<T>,
    lambdas: Array<out StateInMergeContext<T>.() -> Flow<*>>,
) : MutableStateFlow<T> by state {

    private val context: StateInMergeContext<T> = object : StateInMergeContext<T> {
        override val state: MutableStateFlow<T>
            get() = this@MutableStateFlowWithStateInMerge

        override fun <R> Flow<R>.onEachToState(mapper: (T, R) -> T): Flow<R> =
            onEach { value -> state.update { state -> mapper(state, value) } }
    }

    private val flows: List<Flow<*>> = lambdas
        .map { produceFlow -> produceFlow(context) }

    init {
        when (launched) {
            Launched.Eagerly -> launchAll()
            Launched.Lazily -> scope.launch {
                waitForFirstSubscriber()
                launchAll()
            }

            is Launched.WhileSubscribed -> {
                var jobs: Array<Job> = emptyArray()
                state.subscriptionCount
                    .map { it > 0 }
                    .distinctUntilChanged()
                    .flatMapLatest { subscribed ->
                        flow<Unit> {
                            when {
                                subscribed && jobs.isEmpty() -> jobs = launchAll()
                                subscribed -> launchInactive(jobs)
                                !subscribed && jobs.isNotEmpty() -> {
                                    delay(launched.stopTimeoutMillis)
                                    jobs.cancelActive()
                                }
                            }
                        }
                    }
                    .launchIn(scope)
            }
        }
    }

    private suspend fun waitForFirstSubscriber() {
        state.subscriptionCount.first { it > 0 }
    }

    private fun launchAll(): Array<Job> = flows
        .map { flow -> flow.launchIn(scope) }
        .toTypedArray()

    private fun launchInactive(jobs: Array<Job>) {
        check(jobs.size == flows.size)
        jobs.forEachIndexed { index, job ->
            if (job.isCancelled) jobs[index] = flows[index].launchIn(scope)
        }
    }

    private suspend fun Array<Job>.cancelActive() {
        forEach { job -> if (job.isActive) job.cancelAndJoin() }
    }
}

Or if you prefer, you can add this extension as a library: https://github.com/tomczyn/state-in-merge

dependencies {
    implementation("com.tomczyn:state-in-merge:1.2.0")
}

Deep dive - Optional

So how does this actually work? The extension creates a custom implementation of the MutableStateFlow interface, the MutableStateFlowWithStateInMerge class, which delegates all functionality to the original object created in the ViewModel. The logic for launching the flows lives in its init block. And that's about it.
This is a great example of why Kotlin extension functions are such a powerful tool for crafting awesome APIs.

The implementation of the strategies is also fairly bare-bones. Eagerly just launches all the flows immediately.

Launched.Eagerly -> launchAll()

...

private fun launchAll(): Array<Job> = flows
    .map { flow -> flow.launchIn(scope) }
    .toTypedArray()

Lazily waits for the first subscriber and then launches everything.

Launched.Lazily -> scope.launch {
    waitForFirstSubscriber()
    launchAll()
}

...

private suspend fun waitForFirstSubscriber() {
    state.subscriptionCount.first { it > 0 }
}
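
The waiting step can be tried out on its own. The following sketch is only an illustration with made-up names (runLazyDemo, the events list); it shows that subscriptionCount.first { it > 0 } suspends until somebody starts collecting the MutableStateFlow.

```kotlin
import kotlinx.coroutines.*
import kotlinx.coroutines.flow.*

// Hypothetical demo: records the order in which things happen.
fun runLazyDemo(): List<String> = runBlocking {
    val events = mutableListOf<String>()
    val state = MutableStateFlow(0)

    // Plays the role of the Lazily branch: wait, then "launch everything".
    val starter = launch {
        state.subscriptionCount.first { it > 0 }
        events.add("started")
    }

    delay(50)
    events.add("no subscribers yet")

    // The first collector bumps subscriptionCount to 1 and releases the starter.
    val collector = launch { state.collect { } }
    starter.join()
    collector.cancel()
    events
}

fun main() {
    println(runLazyDemo())
}
```

The "started" entry is recorded only after the collector shows up, which is the behavior Lazily relies on.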

The only interesting implementation is the WhileSubscribed strategy.

is Launched.WhileSubscribed -> {
    var jobs: Array<Job> = emptyArray()
    state.subscriptionCount
        .map { it > 0 }
        .distinctUntilChanged()
        .flatMapLatest { subscribed ->
            flow<Unit> {
                when {
                    subscribed && jobs.isEmpty() -> jobs = launchAll()
                    subscribed -> launchInactive(jobs)
                    !subscribed && jobs.isNotEmpty() -> {
                        delay(launched.stopTimeoutMillis)
                        jobs.cancelActive()
                    }
                }
            }
        }
        .launchIn(scope)
}

...

private fun launchAll(): Array<Job> = flows
    .map { flow -> flow.launchIn(scope) }
    .toTypedArray()

private fun launchInactive(jobs: Array<Job>) {
    check(jobs.size == flows.size)
    jobs.forEachIndexed { index, job ->
        if (job.isCancelled) jobs[index] = flows[index].launchIn(scope)
    }
}

private suspend fun Array<Job>.cancelActive() {
    forEach { job -> if (job.isActive) job.cancelAndJoin() }
}

Here, we collect the subscriptionCount flow of our MutableStateFlow object. We map the count to a Boolean that represents whether there are any subscribers. Note that we use distinctUntilChanged to avoid rerunning flatMapLatest with the same value.

Next, we use flatMapLatest to handle launching, timeout, and cancellation. flatMapLatest switches to a new flow with each new emission and stops the old one. This is what we need because it cancels the timeout if we get a new subscription before the timeout is up.
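
The cancellation of a pending timeout can be observed with a small standalone sketch. It is a simplified model, not the extension itself: the subscribed flag is flipped by hand instead of being derived from subscriptionCount, and the names runTimeoutDemo and events are invented for the demo.

```kotlin
import kotlinx.coroutines.*
import kotlinx.coroutines.flow.*

// Hypothetical demo: records what happens as the "subscribed" flag flips.
@OptIn(ExperimentalCoroutinesApi::class) // flatMapLatest is marked experimental
fun runTimeoutDemo(): List<String> = runBlocking {
    val events = mutableListOf<String>()
    val subscribed = MutableStateFlow(true)

    val watcher = subscribed
        .flatMapLatest { isSubscribed ->
            flow<Unit> {
                if (isSubscribed) {
                    events.add("active")
                } else {
                    delay(200)            // stand-in for stopTimeoutMillis
                    events.add("stopped") // reached only if nobody resubscribed in time
                }
            }
        }
        .launchIn(this)

    delay(50)
    subscribed.value = false // the timeout starts
    delay(50)
    subscribed.value = true  // arrives before the timeout, so the pending stop is cancelled
    delay(50)
    subscribed.value = false // the timeout starts again
    delay(300)               // longer than the timeout, so this time it fires
    watcher.cancel()
    events
}

fun main() {
    println(runTimeoutDemo())
}
```

The first drop to false never produces a "stopped" entry because the inner flow is cancelled while it is still inside delay. Only the second, uninterrupted timeout runs to the end.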

Inside the flatMapLatest flow, we have three separate scenarios:

  • Subscribed for the first time - launches all flows and caches jobs.

subscribed && jobs.isEmpty() -> jobs = launchAll()

...

private fun launchAll(): Array<Job> = flows
    .map { flow -> flow.launchIn(scope) }
    .toTypedArray()

  • Resubscribed after losing all subscriptions - checks for cancelled jobs and launches them again.

subscribed -> launchInactive(jobs)

...

private fun launchInactive(jobs: Array<Job>) {
    check(jobs.size == flows.size)
    jobs.forEachIndexed { index, job ->
        if (job.isCancelled) jobs[index] = flows[index].launchIn(scope)
    }
}

  • Lost all subscriptions - starts the delay and then cancels all active jobs.

!subscribed && jobs.isNotEmpty() -> {
    delay(launched.stopTimeoutMillis)
    jobs.cancelActive()
}

...

private suspend fun Array<Job>.cancelActive() {
    forEach { job -> if (job.isActive) job.cancelAndJoin() }
}

And that's all. As you can see, the implementation is actually quite straightforward.