How to test LiveData with coroutines in a unit test

I am using Mockito, JUnit 5 and coroutines to fetch data in a Repository, but none of the mocked methods get invoked in the test case. When I tried a normal suspend function without any Dispatchers or emit() calls, it worked. Therefore, I suspect the cause is the liveData coroutine builder.

GitReposRepository.kt

fun loadReposSuspend(owner: String) = liveData(Dispatchers.IO) {
    emit(Result.Loading)
    val response = githubService.getReposNormal(owner)
    val repos = response.body()!!
    if (repos.isEmpty()) {
        emit(Result.Success(repos))
        repoDao.insert(*repos.toTypedArray())
    } else {
        emitSource(repoDao.loadRepositories(owner)
                           .map { Result.Success(it) })
    }
}

GitReposRepositoryTest.kt

internal class GitRepoRepositoryTest {

    private lateinit var appExecutors: AppExecutors
    private lateinit var repoDao: RepoDao
    private lateinit var githubService: GithubService
    private lateinit var gitRepoRepository: GitRepoRepository

    @BeforeEach
    internal fun setUp() {
        appExecutors = mock(AppExecutors::class.java)
        repoDao = mock(RepoDao::class.java)
        githubService = mock(GithubService::class.java)
        gitRepoRepository = GitRepoRepository(appExecutors,
                                              repoDao,
                                              githubService)
    }

    @Test
    internal fun `should call network to fetch result and insert to db`() = runBlocking {
        //given
        val owner = "Testing"
        val response = Response.success(listOf(Repo(),Repo()))
        `when`(githubService.getReposNormal(ArgumentMatchers.anyString())).thenReturn(response)
        //when
        gitRepoRepository.loadReposSuspend(owner)
        //then
        verify(githubService).getReposNormal(owner)
        verify(repoDao).insertRepos(ArgumentMatchers.anyList())
    }
}

After a few days of searching on the internet, I found out how to unit test coroutines inside liveData and came up with the following approach. It might not be the best one, but I hope it gives some insight to people who have similar problems.
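Why does nothing get invoked in the original test? The `liveData { }` builder starts running its block only when the returned LiveData becomes active, that is, when it gets its first active observer. Calling `loadReposSuspend(owner)` on its own only builds the LiveData.

Below is a stdlib-only toy model of that lazy start. `LazySource` and `demoLazyStart` are hypothetical names, not Android APIs. The producer block stands in for the network call, and it runs zero times until something observes the source.

```kotlin
// Hypothetical stand-in for the liveData { } builder (not an Android API):
// the producer block runs when the first observer is attached, not at construction.
class LazySource<T>(private val block: ((T) -> Unit) -> Unit) {
    private val observers = mutableListOf<(T) -> Unit>()
    private var started = false

    fun observe(observer: (T) -> Unit) {
        observers += observer
        if (!started) {
            started = true
            // The emit callback forwards every produced value to all observers.
            block { value -> observers.forEach { it(value) } }
        }
    }
}

// Usage: counts the "network calls" made before and after observing.
fun demoLazyStart(): Pair<Int, Int> {
    var networkCalls = 0
    val repos = LazySource<String> { emit ->
        networkCalls++ // stands in for githubService.getReposNormal(owner)
        emit("Loading")
    }
    val before = networkCalls     // building the source runs nothing
    repos.observe { }             // the first observer starts the block
    return before to networkCalls // (0, 1)
}
```

This is why the tests below always observe the LiveData (through `getOrAwaitValue()` or `captureValues`) before verifying any mock.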

There are a few necessary parts for unit testing coroutines with LiveData:

  1. You need to add two rules to the unit test (a coroutine rule and an InstantExecutor rule). If you use JUnit 5 like me, you should use extensions instead. The coroutine rule lets you use the TestCoroutineDispatcher in a JVM unit test. The InstantExecutor rule makes LiveData run its work synchronously, so you can observe the values LiveData emits in a JVM unit test. Be careful: the rule's dispatcher (coroutinesRule.dispatcher) is the most important part of testing coroutines in a JVM unit test. I suggest watching this video about testing coroutines in Kotlin: https://youtu.be/KMb0Fs8rCRs

  2. Inject the CoroutineDispatcher through the constructor.

    You should ALWAYS inject Dispatchers (https://youtu.be/KMb0Fs8rCRs?t=850).

  3. Add some LiveData extensions to help you verify the values emitted by LiveData.
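The tests below use `InstantExecutorExtension` and `CoroutinesTestExtension`, including its `dispatcher` property and `runBlocking` function, but the answer does not show those classes. Here is a hypothetical sketch that matches those names.

It assumes three dependencies: JUnit 5 (`junit-jupiter-api`), AndroidX `core-runtime` (for `ArchTaskExecutor`, the same mechanism `InstantTaskExecutorRule` uses) and `kotlinx-coroutines-test` 1.6 or later. `TestCoroutineDispatcher` is deprecated since 1.6, so this sketch uses `UnconfinedTestDispatcher` and `runTest` instead. It is not necessarily what the author used.

```kotlin
import androidx.arch.core.executor.ArchTaskExecutor
import androidx.arch.core.executor.TaskExecutor
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.ExperimentalCoroutinesApi
import kotlinx.coroutines.test.TestDispatcher
import kotlinx.coroutines.test.TestScope
import kotlinx.coroutines.test.UnconfinedTestDispatcher
import kotlinx.coroutines.test.resetMain
import kotlinx.coroutines.test.runTest
import kotlinx.coroutines.test.setMain
import org.junit.jupiter.api.extension.AfterEachCallback
import org.junit.jupiter.api.extension.BeforeEachCallback
import org.junit.jupiter.api.extension.ExtensionContext

// JUnit 5 counterpart of InstantTaskExecutorRule: LiveData's internal tasks run
// synchronously on the calling thread, which is reported as the main thread.
class InstantExecutorExtension : BeforeEachCallback, AfterEachCallback {
    override fun beforeEach(context: ExtensionContext?) {
        ArchTaskExecutor.getInstance().setDelegate(object : TaskExecutor() {
            override fun executeOnDiskIO(runnable: Runnable) = runnable.run()
            override fun postToMainThread(runnable: Runnable) = runnable.run()
            override fun isMainThread(): Boolean = true
        })
    }

    override fun afterEach(context: ExtensionContext?) {
        ArchTaskExecutor.getInstance().setDelegate(null) // restore the default executor
    }
}

// Replaces Dispatchers.Main with a test dispatcher and exposes that dispatcher so it
// can be injected into the class under test (GitRepoRepository's `dispatcher`).
@OptIn(ExperimentalCoroutinesApi::class)
class CoroutinesTestExtension(
    val dispatcher: TestDispatcher = UnconfinedTestDispatcher()
) : BeforeEachCallback, AfterEachCallback {
    override fun beforeEach(context: ExtensionContext?) = Dispatchers.setMain(dispatcher)
    override fun afterEach(context: ExtensionContext?) = Dispatchers.resetMain()

    // Runs the test body on the same scheduler as `dispatcher`.
    fun runBlocking(block: suspend TestScope.() -> Unit) = runTest(dispatcher) { block() }
}
```

`UnconfinedTestDispatcher` starts coroutines eagerly. As a result, the liveData block runs as soon as an observer is attached, which is the behaviour the tests rely on.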

Here is my repository (I follow the official Android recommended app architecture):

GitRepoRepository.kt (this idea comes from two sources: LegoThemeRepository and NetworkBoundResource)

@Singleton
class GitRepoRepository @Inject constructor(private val appExecutors: AppExecutors,
                                            private val repoDao: RepoDao,
                                            private val githubService: GithubService,
                                            private val dispatcher: CoroutineDispatcher = Dispatchers.IO,
                                            private val repoListRateLimit: RateLimiter<String> = RateLimiter(
                                                    10,
                                                    TimeUnit.MINUTES)
) {

    fun loadRepo(owner: String
    ): LiveData<Result<List<Repo>>> = repositoryLiveData(
            localResult = { repoDao.loadRepositories(owner) },
            remoteResult = {
                transformResult { githubService.getRepo(owner) }.apply {
                    if (this is Result.Error) {
                        repoListRateLimit.reset(owner)
                    }
                }
            },
            shouldFetch = { repoListRateLimit.shouldFetch(owner) },
            saveFetchResult = { repoDao.insertRepos(it) },
            dispatcher = this.dispatcher
    )
    ...
}
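Injecting the dispatcher means a test can swap in one that runs work immediately. Here is the same idea in plain JVM terms, with a `java.util.concurrent.Executor` standing in for the `CoroutineDispatcher`. `ReposLoader` and `directExecutor` are hypothetical names used only for this illustration.

```kotlin
import java.util.concurrent.Executor
import java.util.concurrent.Executors

// Hypothetical loader: the Executor is injected with a production default, the same
// shape as `dispatcher: CoroutineDispatcher = Dispatchers.IO` in GitRepoRepository.
class ReposLoader(
    private val fetch: (String) -> List<String>,
    private val executor: Executor = Executors.newSingleThreadExecutor()
) {
    fun load(owner: String, onResult: (List<String>) -> Unit) {
        executor.execute { onResult(fetch(owner)) }
    }
}

// A test passes a direct executor, so the work finishes before load() returns
// and the result can be checked right away, with no sleeping or polling.
val directExecutor = Executor { it.run() }
```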

GitRepoRepositoryTest.kt

@ExperimentalCoroutinesApi
@ExtendWith(InstantExecutorExtension::class)
class GitRepoRepositoryTest {

    // Set the main coroutines dispatcher for unit testing
    companion object {
        @JvmField
        @RegisterExtension
        var coroutinesRule = CoroutinesTestExtension()
    }

    private lateinit var appExecutors: AppExecutors
    private lateinit var repoDao: RepoDao
    private lateinit var githubService: GithubService
    private lateinit var gitRepoRepository: GitRepoRepository
    private lateinit var rateLimiter: RateLimiter<String>

    @BeforeEach
    fun setUp() {
        appExecutors = mock(AppExecutors::class.java)
        repoDao = mock(RepoDao::class.java)
        githubService = mock(GithubService::class.java)
        rateLimiter = mock(RateLimiter::class.java) as RateLimiter<String>
        gitRepoRepository = GitRepoRepository(appExecutors,
                                              repoDao,
                                              githubService,
                                              coroutinesRule.dispatcher,
                                              rateLimiter)
    }

    @Test
    fun `should not call network to fetch result if the process in rate limiter is not valid`() = coroutinesRule.runBlocking {
        //given
        val owner = "Tom"
        val response = Response.success(listOf(Repo(), Repo()))
        `when`(githubService.getRepo(anyString())).thenReturn(
                response)
        `when`(rateLimiter.shouldFetch(anyString())).thenReturn(false)
        //when
        gitRepoRepository.loadRepo(owner).getOrAwaitValue()
        //then
        verify(githubService, never()).getRepo(owner)
        verify(repoDao, never()).insertRepos(anyList())
    }

    @Test
    fun `should reset ratelimiter if the network response contains error`() = coroutinesRule.runBlocking {
        //given
        val owner = "Tom"
        val response = Response.error<List<Repo>>(500,
                                                  "Test Server Error".toResponseBody(
                                                          "text/plain".toMediaTypeOrNull()))
        `when`(githubService.getRepo(anyString())).thenReturn(
                response)
        `when`(rateLimiter.shouldFetch(anyString())).thenReturn(true)
        //when
        gitRepoRepository.loadRepo(owner).getOrAwaitValue()
        //then
        verify(rateLimiter, times(1)).reset(owner)
    }
}

CoroutineUtil.kt (the idea also came from here). This is where you would put a custom implementation, for example if you want to log some information; the test cases that follow show how to test it with coroutines.

sealed class Result<out R> {
    data class Success<out T>(val data: T) : Result<T>()
    object Loading : Result<Nothing>()
    data class Error<T>(val message: String) : Result<T>()
    object Finish : Result<Nothing>()
}

fun <T, A> repositoryLiveData(localResult: (() -> LiveData<T>) = { MutableLiveData() },
                              remoteResult: (suspend () -> Result<A>)? = null,
                              saveFetchResult: suspend (A) -> Unit = { Unit },
                              dispatcher: CoroutineDispatcher = Dispatchers.IO,
                              shouldFetch: () -> Boolean = { true }
): LiveData<Result<T>> =
        liveData(dispatcher) {
            emit(Result.Loading)
            val source: LiveData<Result<T>> = localResult.invoke()
                    .map { Result.Success(it) }
            emitSource(source)
            try {
                remoteResult?.let {
                    if (shouldFetch.invoke()) {
                        when (val response = it.invoke()) {
                            is Result.Success -> {
                                saveFetchResult(response.data)
                            }
                            is Result.Error -> {
                                emit(Result.Error<T>(response.message))
                                emitSource(source)
                            }
                            else -> {
                            }
                        }
                    }
                }
            } catch (e: Exception) {
                emit(Result.Error<T>(e.message.toString()))
                emitSource(source)
            } finally {
                emit(Result.Finish)
            }
        }
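To see which sequence of states `repositoryLiveData` emits in each case, here is a synchronous, LiveData-free model of its control flow. `modelEmissions` is a hypothetical helper. `emit(x)` is modelled as appending `x`, and `emitSource(source)` as appending the local value that the re-attached source delivers. The model assumes the local source holds a single value that does not change during the call, as the `MutableLiveData(LOCAL_RESULT)` in the tests does.

```kotlin
sealed class Result<out R> {
    data class Success<out T>(val data: T) : Result<T>()
    object Loading : Result<Nothing>()
    data class Error<T>(val message: String) : Result<T>()
    object Finish : Result<Nothing>()
}

// Hypothetical synchronous model of repositoryLiveData; `remote` returns a Result or throws.
fun <T> modelEmissions(local: T, shouldFetch: Boolean, remote: () -> Result<T>): List<Result<T>> {
    val out = mutableListOf<Result<T>>()
    out += Result.Loading
    out += Result.Success(local)                 // emitSource(source)
    try {
        if (shouldFetch) {
            val response = remote()
            if (response is Result.Error<*>) {
                out += Result.Error<T>(response.message)
                out += Result.Success(local)     // emitSource(source) again
            }                                    // on Success only saveFetchResult runs
        }
    } catch (e: Exception) {
        out += Result.Error<T>(e.message.toString())
        out += Result.Success(local)
    } finally {
        out += Result.Finish
    }
    return out
}
```

A successful remote call therefore yields Loading, Success(local), Finish. A thrown exception yields Loading, Success(local), Error, Success(local), Finish. These are the two sequences asserted in CoroutineUtilKtTest.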

suspend fun <T> transformResult(call: suspend () -> Response<T>): Result<T> {
    try {
        val response = call()
        if (response.isSuccessful) {
            val body = response.body()
            if (body != null) return Result.Success(body)
        }
        return error("${response.code()} ${response.message()}")
    } catch (e: Exception) {
        return error(e.message ?: e.toString())
    }
}

fun <T> error(message: String): Result<T> {
    return Result.Error("Network call has failed for the following reason: $message")
}
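`transformResult` has three outcomes: a successful response with a body, any other response, and a thrown exception. Note that a successful response whose body is `null` (for example HTTP 204) also ends up as an error.

Below is a stdlib-only sketch of those paths. `FakeResponse` is a hypothetical stand-in for `retrofit2.Response`, not the real class, and it uses the same 200..299 success range. The call is non-suspending here, and the "Network call has failed..." prefix added by `error()` is omitted.

```kotlin
// Hypothetical minimal stand-in for retrofit2.Response.
data class FakeResponse<T>(val code: Int, val message: String, val body: T?) {
    val isSuccessful get() = code in 200..299
}

sealed class Result<out R> {
    data class Success<out T>(val data: T) : Result<T>()
    data class Error<T>(val message: String) : Result<T>()
}

// Mirrors transformResult: Success only for a successful response with a non-null body.
fun <T> transform(call: () -> FakeResponse<T>): Result<T> = try {
    val response = call()
    val body = response.body
    if (response.isSuccessful && body != null) Result.Success<T>(body)
    else Result.Error<T>("${response.code} ${response.message}")
} catch (e: Exception) {
    Result.Error<T>(e.message ?: e.toString())
}
```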

CoroutineUtilKtTest.kt

interface Delegation {
    suspend fun remoteResult(): Result<String>
    suspend fun saveResult(s: String)
    fun localResult(): MutableLiveData<String>
    fun shouldFetch(): Boolean
}

fun <T> givenSuspended(block: suspend () -> T) = BDDMockito.given(runBlocking { block() })

@ExperimentalCoroutinesApi
@ExtendWith(InstantExecutorExtension::class)
class CoroutineUtilKtTest {
    // Set the main coroutines dispatcher for unit testing
    companion object {
        @JvmField
        @RegisterExtension
        var coroutinesRule = CoroutinesTestExtension()
    }

    val delegation: Delegation = mock()
    private val LOCAL_RESULT = "Local Result Fetch"
    private val REMOTE_RESULT = "Remote Result Fetch"
    private val REMOTE_CRASH = "Remote Result Crash"

    @BeforeEach
    fun setUp() {
        given { delegation.shouldFetch() }
                .willReturn(true)
        given { delegation.localResult() }
                .willReturn(MutableLiveData(LOCAL_RESULT))
        givenSuspended { delegation.remoteResult() }
                .willReturn(Result.Success(REMOTE_RESULT))
    }

    @Test
    fun `should call local result only if the remote result should not fetch`() = coroutinesRule.runBlocking {
        //given
        given { delegation.shouldFetch() }.willReturn(false)

        //when
        repositoryLiveData<String, String>(
                localResult = { delegation.localResult() },
                remoteResult = { delegation.remoteResult() },
                shouldFetch = { delegation.shouldFetch() },
                dispatcher = coroutinesRule.dispatcher
        ).getOrAwaitValue()
        //then
        verify(delegation, times(1)).localResult()
        verify(delegation, never()).remoteResult()
    }


    @Test
    fun `should call remote result and then save result`() = coroutinesRule.runBlocking {
        //when
        repositoryLiveData<String, String>(
                shouldFetch = { delegation.shouldFetch() },
                remoteResult = { delegation.remoteResult() },
                saveFetchResult = { s -> delegation.saveResult(s) },
                dispatcher = coroutinesRule.dispatcher
        ).getOrAwaitValue()
        //then
        verify(delegation, times(1)).remoteResult()
        verify(delegation,
               times(1)).saveResult(REMOTE_RESULT)
    }

    @Test
    fun `should emit Loading, Success, Finish Status when we fetch local and then remote`() = coroutinesRule.runBlocking {
        //when
        val ld = repositoryLiveData<String, String>(
                localResult = { delegation.localResult() },
                shouldFetch = { delegation.shouldFetch() },
                remoteResult = { delegation.remoteResult() },
                saveFetchResult = { s -> delegation.saveResult(s) },
                dispatcher = coroutinesRule.dispatcher
        )
        //then
        ld.captureValues {
            assertEquals(arrayListOf(Result.Loading,
                                     Result.Success(LOCAL_RESULT),
                                     Result.Finish), values)
        }
    }

    @Test
    fun `should emit Loading,Success, Error, Success, Finish Status when we fetch remote but fail`() = coroutinesRule.runBlocking {
        givenSuspended { delegation.remoteResult() }
                .willThrow(RuntimeException(REMOTE_CRASH))
        //when
        val ld = repositoryLiveData<String, String>(
                localResult = { delegation.localResult() },
                shouldFetch = { delegation.shouldFetch() },
                remoteResult = { delegation.remoteResult() },
                saveFetchResult = { s -> delegation.saveResult(s) },
                dispatcher = coroutinesRule.dispatcher
        )
        //then
        ld.captureValues {
            assertEquals(arrayListOf(Result.Loading,
                                     Result.Success(LOCAL_RESULT),
                                     Result.Error(REMOTE_CRASH),
                                     Result.Success(LOCAL_RESULT),
                                     Result.Finish
            ), values)
        }
    }


}

LiveDataTestUtil.kt (this idea comes from the AAC sample and kotlin-coroutine)

fun <T> LiveData<T>.getOrAwaitValue(
        time: Long = 2,
        timeUnit: TimeUnit = TimeUnit.SECONDS,
        afterObserve: () -> Unit = {}
): T {
    var data: T? = null
    val latch = CountDownLatch(1)
    val observer = object : Observer<T> {
        // `value: T` compiles against both the older Java Observer and the newer Kotlin one
        override fun onChanged(value: T) {
            data = value
            latch.countDown()
            this@getOrAwaitValue.removeObserver(this)
        }
    }
    this.observeForever(observer)

    afterObserve.invoke()

    // Don't wait indefinitely if the LiveData is not set.
    if (!latch.await(time, timeUnit)) {
        this.removeObserver(observer)
        throw TimeoutException("LiveData value was never set.")
    }

    @Suppress("UNCHECKED_CAST")
    return data as T
}
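`getOrAwaitValue` follows a latch-with-timeout pattern: register an observer, count down on the first value, then remove the observer on both the success and the timeout path. Here is the same pattern without LiveData, runnable on a plain JVM. `Signal` and `awaitFirst` are hypothetical names used only for this illustration. As in the original, `afterObserve` runs after the listener is registered.

```kotlin
import java.util.concurrent.CountDownLatch
import java.util.concurrent.TimeUnit
import java.util.concurrent.TimeoutException

// Hypothetical minimal observable used only to illustrate the latch pattern.
class Signal<T> {
    private val listeners = mutableListOf<(T) -> Unit>()
    @Synchronized fun add(l: (T) -> Unit) { listeners += l }
    @Synchronized fun remove(l: (T) -> Unit) { listeners -= l }
    // Iterates over a copy so listeners may remove themselves while being notified.
    fun send(value: T) = synchronized(this) { listeners.toList() }.forEach { it(value) }
}

fun <T> Signal<T>.awaitFirst(
    time: Long = 2,
    unit: TimeUnit = TimeUnit.SECONDS,
    afterObserve: () -> Unit = {}
): T {
    var data: T? = null
    val latch = CountDownLatch(1)
    lateinit var listener: (T) -> Unit
    listener = { value ->
        data = value
        latch.countDown()
        remove(listener) // only the first value matters
    }
    add(listener)
    afterObserve()
    // Don't wait indefinitely if no value is ever sent.
    if (!latch.await(time, unit)) {
        remove(listener)
        throw TimeoutException("Value was never sent.")
    }
    @Suppress("UNCHECKED_CAST")
    return data as T
}
```

Because only the first value is returned, `getOrAwaitValue()` on `repositoryLiveData` yields `Result.Loading`. The repository tests use it only to activate the LiveData, and `captureValues` is what checks the full sequence.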

class LiveDataValueCapture<T> {

    val lock = Any()

    private val _values = mutableListOf<T?>()
    val values: List<T?>
        get() = synchronized(lock) {
            _values.toList() // copy to avoid returning reference to mutable list
        }

    fun addValue(value: T?) = synchronized(lock) {
        _values += value
    }
}

inline fun <T> LiveData<T>.captureValues(block: LiveDataValueCapture<T>.() -> Unit) {
    val capture = LiveDataValueCapture<T>()
    val observer = Observer<T> {
        capture.addValue(it)
    }
    observeForever(observer)
    try {
        capture.block()
    } finally {
        removeObserver(observer)
    }
}
