package com.example.paging3

import androidx.paging.PagingState
import androidx.room.InvalidationTracker
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.withContext
import java.util.concurrent.atomic.AtomicInteger
import kotlin.coroutines.cancellation.CancellationException
import kotlin.time.measureTimedValue

/**
 * Paging source that serves notes from [repo] in pages keyed by a 1-based page number.
 *
 * The row offset of a page is `(page - 1) * params.loadSize`. The total row count is
 * re-read on every refresh load and used to report `itemsBefore` / `itemsAfter`.
 *
 * On construction an observer is registered on the `"notes"` table of [db]; the first
 * change notification invalidates this source and unregisters the observer.
 *
 * NOTE(review): offsets are derived from the `loadSize` of each individual request, so
 * pages only line up when every request uses the same size — confirm the `PagingConfig`
 * keeps `initialLoadSize` equal to `pageSize`.
 *
 * @param repo source of the note rows and of the total row count.
 * @param db database whose invalidation tracker is observed for changes to `"notes"`.
 * @param factory forwarded unchanged to [InvalidatingPagingSource].
 */
class NotePagingSource(
    private val repo: NoteRepository,
    private val db: NoteDataBase,
    factory: (Invalidator) -> Unit
) : InvalidatingPagingSource<Int, AdapterItem>(factory) {
    /** Total number of notes as of the most recent refresh load. */
    private val allItemsCount = AtomicInteger(0)

    init {
        // NOTE(review): the observer is only removed after a table change. If this source
        // is invalidated through another path it stays registered until the next change to
        // "notes" — confirm this is acceptable. Also confirm that registering from the
        // constructor's thread satisfies the tracker's threading contract.
        db.invalidationTracker.addObserver(object : InvalidationTracker.Observer(arrayOf("notes")) {
            override fun onInvalidated(tables: Set<String>) {
                invalidate()
                // One-shot: an invalidated source is never reused, so stop listening.
                db.invalidationTracker.removeObserver(this)
            }
        })
    }

    /**
     * Loads the page identified by `params.key` (page 1 when the key is `null`) on
     * [Dispatchers.IO].
     *
     * @return [LoadResult.Invalid] if this source is already invalid, [LoadResult.Error]
     * if reading from [repo] throws, otherwise the loaded [LoadResult.Page].
     */
    override suspend fun load(params: LoadParams<Int>): LoadResult<Int, AdapterItem> =
        withContext(Dispatchers.IO) { loadPage(params) }

    /** Performs the actual read for [load]; expected to run on a background dispatcher. */
    private suspend fun loadPage(params: LoadParams<Int>): LoadResult<Int, AdapterItem> {
        // Skip all repository work when the result would be discarded anyway.
        if (invalid) return LoadResult.Invalid()

        val limit = params.loadSize
        val currentPage = params.key ?: 1
        val offset = (currentPage - 1) * limit

        return try {
            // Counting happens inside the try so that a failing count surfaces as
            // LoadResult.Error, exactly like a failing page query.
            if (params is LoadParams.Refresh<*>) {
                allItemsCount.set(repo.countAll())
            }

            val notes = repo.getAll(limit = limit, offset = offset)

            // The placeholder stands in for an empty table only. An empty page past the
            // end of a populated list must stay empty, otherwise the placeholder would be
            // appended after real notes.
            val data: List<AdapterItem> = if (notes.isEmpty() && offset == 0) {
                listOf(
                    AdapterItem.Content(
                        Note(0, "Please wait while we generate sample data", System.currentTimeMillis())
                    )
                )
            } else {
                notes.map { note -> AdapterItem.Content(note) }
            }

            val nextPosToLoad = offset + data.size
            val itemsAfter = (allItemsCount.get() - nextPosToLoad).coerceAtLeast(0)

            LoadResult.Page(
                data = data,
                prevKey = if (currentPage == 1) null else currentPage - 1,
                // End of data is decided by the rows actually read, not by the displayed
                // list, so the placeholder can never request a further page.
                nextKey = if (notes.size < limit) null else currentPage + 1,
                itemsBefore = offset,
                itemsAfter = itemsAfter
            )
        } catch (e: CancellationException) {
            // Cancellation is not a load failure; let it propagate.
            throw e
        } catch (e: Exception) {
            LoadResult.Error(e)
        }
    }

    /**
     * Returns the page number of the page closest to the current anchor position, derived
     * from that page's neighbouring keys, or `null` when there is no anchor or no such page.
     */
    override fun getRefreshKey(state: PagingState<Int, AdapterItem>): Int? =
        state.anchorPosition?.let { anchorPosition ->
            val anchorPage = state.closestPageToPosition(anchorPosition)
            anchorPage?.prevKey?.plus(1) ?: anchorPage?.nextKey?.minus(1)
        }
}