package at.jku.isse.gradient.service.grpc

import at.jku.isse.gradient.GradientConfig
import at.jku.isse.gradient.GradientEvents
import at.jku.isse.gradient.Util
import at.jku.isse.gradient.bytes
import at.jku.isse.gradient.dal.ServerManager
import at.jku.isse.gradient.message.Common
import at.jku.isse.gradient.message.Event
import at.jku.isse.gradient.message.EventServiceGrpc
import at.jku.isse.gradient.model.GradientType
import at.jku.isse.gradient.model.StructuralCache
import at.jku.isse.gradient.runtime.__Gradient_Observable__
import at.jku.isse.gradient.service.EventService
import com.google.common.eventbus.EventBus
import com.google.common.eventbus.Subscribe
import com.google.common.util.concurrent.AtomicDoubleArray
import com.google.inject.Inject
import com.google.inject.Provider
import com.google.inject.name.Named
import com.google.protobuf.Empty
import io.grpc.stub.StreamObserver
import mu.KotlinLogging
import java.io.File
import java.net.URL
import java.nio.file.Path
import java.time.Instant
import java.util.*
import java.util.concurrent.CountDownLatch
import java.util.concurrent.atomic.AtomicBoolean
import java.util.concurrent.atomic.AtomicIntegerArray
import java.util.concurrent.atomic.AtomicLongArray
import java.util.logging.Level
import java.util.regex.Pattern

/** Logger shared by all declarations in this file. */
private val logger = KotlinLogging.logger {}


/**
 * Builders for the protobuf messages (ids, keys, events, envelopes) sent to the Gradient event service.
 */
private object GrpcFactory {
    /** The all-zero UUID; [frameEvent] never writes it as a parent frame. */
    val emptyUUID = toUUID(UUID(0, 0))

    /** Converts a JVM [UUID] into its protobuf representation. */
    fun toUUID(uuid: UUID): Common.UUID {
        return Common.UUID.newBuilder().setBytes(uuid.bytes()).build()
    }

    /** Creates a fresh protobuf UUID from [Util.uuid]. */
    fun uuid(): Common.UUID {
        return Common.UUID.newBuilder().setBytes(Util.uuid().bytes()).build()
    }

    /** Builds the composite key identifying a single monitoring event. */
    fun eventKey(runtimeId: Common.UUID, processId: Common.UUID, objectId: Common.UUID, eventId: Common.UUID, frameId: Common.UUID): Event.EventKey {
        return Event.EventKey.newBuilder()
                .setRuntimeId(runtimeId)
                .setProcessId(processId)
                .setObjectId(objectId)
                .setFrameId(frameId)
                .setEventId(eventId)
                .build()
    }

    /**
     * Builds a monitoring event.
     *
     * @param previousEvent id of the preceding event of the same thread, if any.
     * @param elementId structural element the event refers to, if known.
     * @param timestamp event time in epoch seconds.
     * @param valueType protobuf type tag of [value]; must be `VOID` when [value] is `null`.
     * @param value a [String] (TEXT/UNKNOWN), [Double] (NUMBER) or [UUID] (REFERENCE).
     * @throws IllegalArgumentException if [value] is non-null and [valueType] is not handled.
     */
    fun monitoringEvent(id: Event.EventKey, type: Event.MonitoringEventType, previousEvent: Common.UUID? = null, elementId: Common.UUID? = null,
                        timestamp: Long = Instant.now().epochSecond, valueType: Common.GradientType = Common.GradientType.VOID, value: Any? = null): Event.MonitoringEvent {
        assert(value != null || valueType == Common.GradientType.VOID) { "The value is null while still having a gradient type that is not void." }

        val eventBuilder = Event.MonitoringEvent.newBuilder()
                .setId(id)
                .setType(type)
                .setTimestamp(timestamp)
                .setValueType(valueType)

        previousEvent?.let { eventBuilder.previousEvent = it }
        elementId?.let { eventBuilder.element = it }

        if (value != null) {
            assert(value is String || value is Double || value is UUID) { "Value did not have the proper type: ${value::class}" }
            // Check the declared type tag: comparing the value itself against the enum constant could never fail.
            assert(valueType != Common.GradientType.NUMBER || value is Double) { "Number value is not a double: $elementId" }
            when (valueType) {
                Common.GradientType.TEXT -> eventBuilder.textValue = value as String
                Common.GradientType.NUMBER -> eventBuilder.numberValue = value as Double
                Common.GradientType.REFERENCE -> eventBuilder.referenceValue = toUUID(value as UUID)
                Common.GradientType.UNKNOWN -> eventBuilder.textValue = value as String
                Common.GradientType.VOID -> {
                }
                else -> {
                    throw IllegalArgumentException("Unknown gradient type: $valueType")
                }
            }
        }
        return eventBuilder.build()
    }

    /** Builds a marker event on stream [streamId], optionally chained to [previousMarker]. */
    fun markerEvent(id: Common.UUID = uuid(), streamId: Common.UUID, type: Event.MetaEventType, trigger: Event.Trigger, previousMarker: Common.UUID? = null): Event.MarkerEvent {
        val event = Event.MarkerEvent.newBuilder()
                .setId(id)
                .setStreamId(streamId)
                .setType(type)
                .setTrigger(trigger)
        previousMarker?.let { event.previousMarkerEvent = it }

        return event.build()
    }

    /**
     * Builds a frame event for [frameId].
     *
     * [parentFrame] is omitted from the message when it is `null` or equal to [emptyUUID].
     */
    fun frameEvent(id: Common.UUID = uuid(), type: Event.MetaEventType, trigger: Event.Trigger, previousFrame: Common.UUID? = null,
                   parentFrame: Common.UUID? = null, frameId: Common.UUID = uuid(), associatedEvent: Common.UUID): Event.FrameEvent {
        val event = Event.FrameEvent.newBuilder()
                .setId(id)
                .setType(type)
                .setTrigger(trigger)
                .setFrameId(frameId)
                .setAssociatedEvent(associatedEvent)
        previousFrame?.let { event.previousFrame = it }
        if (parentFrame != null && parentFrame != emptyUUID) {
            event.parentFrame = parentFrame
        }

        return event.build()
    }

    /** Wraps a monitoring event in an [Event.EventMessage] envelope. */
    fun eventMessage(event: Event.MonitoringEvent): Event.EventMessage {
        return Event.EventMessage.newBuilder()
                .setMonitoringEvent(event)
                .build()
    }

    /** Wraps a frame event in an [Event.EventMessage] envelope. */
    fun eventMessage(event: Event.FrameEvent): Event.EventMessage {
        return Event.EventMessage.newBuilder()
                .setFrameEvent(event)
                .build()
    }

    /** Wraps a marker event in an [Event.EventMessage] envelope. */
    fun eventMessage(event: Event.MarkerEvent): Event.EventMessage {
        return Event.EventMessage.newBuilder()
                .setMarkerEvent(event)
                .build()
    }
}


/**
 * Reporting state of one thread.
 *
 * Holds the process id stamped on this thread's events and the stack of open frames. The stack's bottom entry is
 * [GrpcFactory.emptyUUID]. It also keeps the ids of the last event and last frame event, which are used to chain
 * new events.
 */
private data class LocalStreamState(
        val threadId: Long = Thread.currentThread().id,
        val processId: Common.UUID = GrpcFactory.uuid(),
        val parentFrames: Stack<Common.UUID> = Stack(),
        var previousEvent: Common.UUID? = null,
        var previousFrame: Common.UUID? = null
) {
    init {
        // Sentinel root entry: the outermost frame of the thread gets it as its (omitted) parent.
        parentFrames.push(GrpcFactory.emptyUUID)
    }
}

/**
 * Session-wide reporting state.
 *
 * Holds the project, runtime and stream identity, the outgoing batch stream, and a latch that the owner counts
 * down once the server side of the stream has completed or failed.
 */
private data class GlobalStreamState(
        val projectContext: Common.ProjectContext,
        val runtimeId: Common.UUID,
        val streamId: Common.UUID,
        val outStream: StreamObserver<Event.EventBatch>,
        // NOTE(review): not read or written anywhere in this file.
        var previousMarkerId: Common.UUID? = null,
        val serverLatch: CountDownLatch = CountDownLatch(1)
)

/**
 * [EventService] that streams monitoring events to the Gradient server over gRPC.
 *
 * Construction registers a monitoring session with the server and blocks until the server has answered. It then
 * opens a client stream for [Event.EventBatch]es. Reported events are buffered and sent once the buffer holds
 * [GradientConfig.observationCacheSize] messages. A [GradientEvents.Cleanup] posted on the [EventBus] flushes the
 * remainder, completes the stream and waits for the server to finish it.
 *
 * If no [StructuralCache] is available, or no stream id is obtained, every `report*` call is a no-op.
 *
 * Thread-safety: call/frame bookkeeping is kept per thread in a [ThreadLocal]. The shared buffer, the closed flag
 * and all writes to the outgoing stream are guarded by the buffer's monitor. Events reported after the stream has
 * been completed are dropped.
 */
class GrpcEventService
@Inject internal constructor(@Named("runtimeId") runtimeId: UUID,
                             structuralCacheProvider: Provider<StructuralCache?>,
                             private val gradientConfig: GradientConfig,
                             eventBus: EventBus,
                             serverManager: ServerManager) : EventService {

    private val grpcEventService: EventServiceGrpc.EventServiceStub = serverManager.eventService()

    /** Element name -> structural element id; `null` when no structural cache is available. */
    private val elementIds: Map<String, UUID>?

    /** Session-wide stream state; `null` disables all reporting. */
    private val globalState: GlobalStreamState?

    /** Call/frame bookkeeping of the calling thread. */
    private val localState = ThreadLocal.withInitial { LocalStreamState() }

    /** Messages not yet sent. Its monitor also guards [closed] and every write to the outgoing stream. */
    private val eventBuffer: Vector<Event.EventMessage> = Vector(gradientConfig.observationCacheSize())

    /** `true` once the outgoing stream has been completed; only accessed while holding the [eventBuffer] monitor. */
    private var closed = false

    init {
        val structuralCache = structuralCacheProvider.get()
        if (structuralCache != null) {
            elementIds = mutableMapOf<String, UUID>().apply { putAll(structuralCache.elements) }

            val projectContext = Common.ProjectContext.newBuilder()
                    .setProjectName(structuralCache.projectName)
                    .setVersionId(GrpcFactory.toUUID(structuralCache.versionId))
                    .build()
            val grpcRuntimeId = GrpcFactory.toUUID(runtimeId)

            logger.debug { "Reporting for ${structuralCache.projectName} @ ${structuralCache.versionId}" }

            val eventReport = Event.EventReport.newBuilder()
                    .setProjectContext(projectContext)
                    .setRuntimeId(grpcRuntimeId)
                    .build()

            globalState = registerStream(eventReport)?.let { streamId -> openReportStream(projectContext, grpcRuntimeId, streamId) }

            // Subscribe only after all state is assigned, so a Cleanup can never see a half-constructed service.
            eventBus.register(this)
        } else {
            elementIds = null
            globalState = null
            logger.warn { "Structural cache is not available, no events will be reported." }
        }
    }

    /**
     * Registers a monitoring session for [eventReport] and blocks until the server completes or fails the call.
     *
     * @return the stream id sent by the server, or `null` if none was received.
     */
    private fun registerStream(eventReport: Event.EventReport): Common.UUID? {
        val answered = CountDownLatch(1)
        // Written on the gRPC callback thread; the countDown/await pair publishes it to this thread.
        var streamId: Common.UUID? = null
        grpcEventService.register(eventReport, object : StreamObserver<Common.UUID> {
            override fun onNext(value: Common.UUID) {
                logger.debug { "Successfully registered monitoring session: $value" }
                streamId = value
            }

            override fun onError(t: Throwable?) {
                logger.error(t) { "Could not register monitoring stream on server." }
                answered.countDown()
            }

            override fun onCompleted() {
                logger.debug { "Monitoring stream registration complete." }
                if (streamId == null) {
                    logger.error { "Did not receive a streaming id. Will not report the events." }
                }
                answered.countDown()
            }
        })
        answered.await()
        return streamId
    }

    /**
     * Opens the client stream that event batches are written to, and queues the OPEN marker as its first message.
     */
    private fun openReportStream(projectContext: Common.ProjectContext, runtimeId: Common.UUID, streamId: Common.UUID): GlobalStreamState {
        // Released when the server completes or fails the stream; awaited by cleanupSignal.
        val serverClosed = CountDownLatch(1)
        val outStream = grpcEventService.report(object : StreamObserver<Empty> {
            override fun onNext(value: Empty?) {
            }

            override fun onError(t: Throwable?) {
                logger.error(t) { "Could not report event." }
                serverClosed.countDown()
            }

            override fun onCompleted() {
                logger.debug { "Server closing stream connection." }
                serverClosed.countDown()
            }
        })
        val openMarker = GrpcFactory.markerEvent(streamId = streamId, type = Event.MetaEventType.OPEN, trigger = Event.Trigger.BEFORE)
        eventBuffer.add(GrpcFactory.eventMessage(openMarker))

        return GlobalStreamState(projectContext, runtimeId, streamId, outStream, serverLatch = serverClosed)
    }

    /**
     * Sends the CLOSE marker and all buffered events, completes the outgoing stream and blocks until the server has
     * finished it. Subsequent calls are no-ops.
     */
    @Subscribe
    fun cleanupSignal(@Suppress("UNUSED_PARAMETER") cleanup: GradientEvents.Cleanup) {
        val global = globalState ?: return
        synchronized(eventBuffer) {
            // Completing a gRPC stream twice throws; a repeated Cleanup must be harmless.
            if (closed) return
            val closeMarker = GrpcFactory.markerEvent(streamId = global.streamId, type = Event.MetaEventType.CLOSE, trigger = Event.Trigger.AFTER)
            eventBuffer.add(GrpcFactory.eventMessage(closeMarker))
            flush()
            global.outStream.onCompleted()
            closed = true
        }
        global.serverLatch.await()
    }

    override fun reportPropertyRead(elementName: String, objectId: UUID, obj: Any?) {
        reportValues(Event.MonitoringEventType.READ, elementName, objectId, obj)
    }

    override fun reportPropertyWrite(elementName: String, objectId: UUID, obj: Any?) {
        reportValues(Event.MonitoringEventType.WRITE, elementName, objectId, obj)
    }

    /** Opens a new frame for the call and reports the frame event followed by the CALL event. */
    override fun reportExecutableCall(elementName: String, objectId: UUID) {
        val global = globalState ?: return
        val state = localState.get()
        assert(state.parentFrames.isNotEmpty())

        val eventId = GrpcFactory.uuid()
        val frameEvent = GrpcFactory.frameEvent(type = Event.MetaEventType.OPEN, trigger = Event.Trigger.BEFORE,
                previousFrame = state.previousFrame, parentFrame = state.parentFrames.peek(), associatedEvent = eventId)
        val id = GrpcFactory.eventKey(global.runtimeId, state.processId, GrpcFactory.toUUID(objectId), eventId, frameEvent.frameId)
        val event = GrpcFactory.monitoringEvent(id, Event.MonitoringEventType.CALL, state.previousEvent, elementId(elementName))

        state.previousEvent = id.eventId
        state.previousFrame = frameEvent.id
        // The new frame is the parent of everything this thread reports until the matching return/exception.
        state.parentFrames.push(frameEvent.frameId)

        report(listOf(GrpcFactory.eventMessage(frameEvent), GrpcFactory.eventMessage(event)))
    }

    override fun reportExecutableParameter(elementName: String, objectId: UUID, obj: Any?) {
        reportValues(Event.MonitoringEventType.RECEIVE, elementName, objectId, obj)
    }

    /** Leaves the current frame: reports the RETURN value event(s) and then the closing frame event. */
    override fun reportExecutableReturn(elementName: String, objectId: UUID, obj: Any?) {
        val global = globalState ?: return
        val state = localState.get()
        assert(state.parentFrames.isNotEmpty())

        val frame = state.parentFrames.pop()
        val events = valueEvents(global, state, Event.MonitoringEventType.RETURN, elementName, objectId, frame, obj)

        val previousEvent = checkNotNull(state.previousEvent) { "Return reported without any preceding event on this thread." }
        val frameEvent = GrpcFactory.frameEvent(type = Event.MetaEventType.CLOSE, trigger = Event.Trigger.AFTER,
                previousFrame = state.previousFrame, parentFrame = state.parentFrames.peek(),
                frameId = frame, associatedEvent = previousEvent)
        state.previousFrame = frameEvent.id

        report(events + GrpcFactory.eventMessage(frameEvent))
    }

    /** Leaves the current frame: reports an EXCEPT event carrying the exception class name and the closing frame event. */
    override fun reportExecutableException(elementName: String, objectId: UUID, exception: Throwable) {
        val global = globalState ?: return
        val state = localState.get()
        assert(state.parentFrames.isNotEmpty())

        val frame = state.parentFrames.pop()
        val id = GrpcFactory.eventKey(global.runtimeId, state.processId, GrpcFactory.toUUID(objectId), GrpcFactory.uuid(), frame)
        // qualifiedName is null for anonymous/local classes; fall back to the JVM name so TEXT never carries null.
        val exceptionName = exception::class.qualifiedName ?: exception.javaClass.name
        val event = GrpcFactory.monitoringEvent(id, Event.MonitoringEventType.EXCEPT, state.previousEvent, elementId(elementName),
                valueType = Common.GradientType.forNumber(GradientType.TEXT.ordinal), value = exceptionName)
        state.previousEvent = id.eventId

        val frameEvent = GrpcFactory.frameEvent(type = Event.MetaEventType.CLOSE, trigger = Event.Trigger.AFTER,
                previousFrame = state.previousFrame, parentFrame = state.parentFrames.peek(),
                frameId = frame, associatedEvent = id.eventId)
        state.previousFrame = frameEvent.id

        report(listOf(GrpcFactory.eventMessage(event), GrpcFactory.eventMessage(frameEvent)))
    }

    /** Reports one [type] event per value extracted from [obj], attached to the calling thread's current frame. */
    private fun reportValues(type: Event.MonitoringEventType, elementName: String, objectId: UUID, obj: Any?) {
        val global = globalState ?: return
        val state = localState.get()
        assert(state.parentFrames.isNotEmpty())

        report(valueEvents(global, state, type, elementName, objectId, state.parentFrames.peek(), obj))
    }

    /** Builds one [type] event per value of [obj] in [frame], chaining each to the thread's previous event. */
    private fun valueEvents(global: GlobalStreamState, state: LocalStreamState, type: Event.MonitoringEventType,
                            elementName: String, objectId: UUID, frame: Common.UUID, obj: Any?): List<Event.EventMessage> {
        val elementId = elementId(elementName)
        val grpcObjectId = GrpcFactory.toUUID(objectId)
        return toValue(obj).map { (valueType, value) ->
            val id = GrpcFactory.eventKey(global.runtimeId, state.processId, grpcObjectId, GrpcFactory.uuid(), frame)
            val event = GrpcFactory.monitoringEvent(id, type, state.previousEvent, elementId,
                    valueType = Common.GradientType.forNumber(valueType.ordinal), value = value)
            state.previousEvent = id.eventId
            GrpcFactory.eventMessage(event)
        }
    }

    /** Structural id of [elementName], or `null` if the element is not in the structural cache. */
    private fun elementId(elementName: String): Common.UUID? =
            elementIds?.get(elementName)?.let { GrpcFactory.toUUID(it) }

    /**
     * Appends [events] to the buffer, flushing once it reaches the configured cache size.
     * Events reported after the stream was completed are dropped.
     */
    private fun report(events: List<Event.EventMessage>) {
        if (globalState == null) return

        val shouldFlush = synchronized(eventBuffer) {
            if (closed) return
            eventBuffer.addAll(events)
            eventBuffer.size >= gradientConfig.observationCacheSize()
        }
        if (shouldFlush) flush()
    }

    /** Sends all buffered messages as one batch; does nothing if the buffer is empty or the stream is completed. */
    private fun flush() {
        val global = globalState ?: return

        synchronized(eventBuffer) {
            // Another thread may already have drained the buffer; don't send empty batches.
            if (closed || eventBuffer.isEmpty()) return
            val batch = Event.EventBatch.newBuilder()
                    .addAllMessages(eventBuffer)
                    .build()

            global.outStream.onNext(batch)
            eventBuffer.clear()
        }
    }

    /**
     * Converts [obj] into (type, value) pairs for monitoring events.
     *
     * Scalars yield one pair. Arrays, collections and map values are unwound element-wise into at most
     * [unwindDepth] pairs in total. Iterators are not unwound, because reading them would advance the observed
     * program's own iterator. Unsupported objects are reported as UNKNOWN with their class name.
     */
    private fun toValue(obj: Any?, unwindDepth: Int = gradientConfig.iterableUnwindDepth()): List<Pair<GradientType, Any?>> {
        return when (obj) {
            is __Gradient_Observable__ -> listOf(GradientType.REFERENCE to obj.__gradient_id__())
            is Boolean -> listOf(GradientType.NUMBER to (if (obj) 1 else 0).toDouble())
            is AtomicBoolean -> listOf(GradientType.NUMBER to (if (obj.get()) 1 else 0).toDouble())
            is Number -> listOf(GradientType.NUMBER to obj.toDouble())
            is Date -> listOf(GradientType.NUMBER to obj.time.toDouble())
            // Indices run 0 until length; the element at index `length` does not exist.
            is AtomicIntegerArray -> iterableToValues((0 until obj.length()).asSequence().map { obj.get(it) }.asIterable(), unwindDepth)
            is AtomicLongArray -> iterableToValues((0 until obj.length()).asSequence().map { obj.get(it) }.asIterable(), unwindDepth)
            is AtomicDoubleArray -> iterableToValues((0 until obj.length()).asSequence().map { obj.get(it) }.asIterable(), unwindDepth)
            is Char,
            is CharSequence -> listOf(GradientType.TEXT to obj.toString())
            is File -> listOf(GradientType.TEXT to obj.absolutePath)
            is Path -> listOf(GradientType.TEXT to obj.toAbsolutePath().toString())
            is URL -> listOf(GradientType.TEXT to obj.toString())
            is Pattern -> listOf(GradientType.TEXT to obj.pattern())
            is Level -> listOf(GradientType.TEXT to obj.name)
            // canonicalName is null for anonymous/local classes.
            is Class<*> -> listOf(GradientType.TEXT to (obj.canonicalName ?: obj.name))
            is Enum<*> -> listOf(GradientType.TEXT to obj.name)
            is CharArray -> listOf(GradientType.TEXT to obj.joinToString(""))
            is ByteArray -> iterableToValues(obj.asIterable(), unwindDepth)
            is ShortArray -> iterableToValues(obj.asIterable(), unwindDepth)
            is IntArray -> iterableToValues(obj.asIterable(), unwindDepth)
            is LongArray -> iterableToValues(obj.asIterable(), unwindDepth)
            is FloatArray -> iterableToValues(obj.asIterable(), unwindDepth)
            is DoubleArray -> iterableToValues(obj.asIterable(), unwindDepth)
            is Array<*> -> iterableToValues(obj.asIterable(), unwindDepth)
            is Map.Entry<*, *> -> toValue(obj.value, unwindDepth)
            is Map<*, *> -> iterableToValues(obj.values, unwindDepth)
            is Collection<*> -> iterableToValues(obj, unwindDepth)
            // qualifiedName is null for anonymous/local classes and lambdas; an UNKNOWN value must not be null.
            else -> if (obj == null) listOf(GradientType.VOID to null) else listOf(GradientType.UNKNOWN to (obj::class.qualifiedName ?: obj.javaClass.name))
        }
    }

    /** Unwinds [iterable] element by element until [unwindDepth] pairs have been collected. */
    private fun iterableToValues(iterable: Iterable<Any?>, unwindDepth: Int): List<Pair<GradientType, Any?>> {
        val result = mutableListOf<Pair<GradientType, Any?>>()
        for (element in iterable) {
            // Stop early instead of walking the remainder of a (possibly huge) collection for nothing.
            if (result.size >= unwindDepth) break
            result.addAll(toValue(element, unwindDepth - result.size))
        }
        return result
    }
}