package at.jku.isse.gradient.lang.java

import at.jku.isse.gradient.ProfilerState
import at.jku.isse.gradient.Util
import at.jku.isse.gradient.model.*
import at.jku.isse.gradient.profiledTrace
import mu.KotlinLogging
import spoon.reflect.code.*
import spoon.reflect.declaration.*
import spoon.reflect.reference.*
import spoon.reflect.visitor.filter.TypeFilter
import spoon.support.reflect.code.CtArrayReadImpl

// File-level logger used by the transformer below for trace and error output.
private val logger = KotlinLogging.logger {}

/**
 * Links the elements of a [StructuralModel] for one type at a time.
 *
 * For the type handed to [transform], the inner visitor looks up model elements in the
 * structural model's components and connects them: extended types, declared and inherited
 * properties and executables, overrides, parameters, field accesses, invocations and
 * returned element types.
 */
class AsgRelationshipTransformer(declarationCache: DeclarationCache) : BaseTransformer(declarationCache) {

    /**
     * Runs the relationship pass for [ctTypeReference] against [structuralModel].
     *
     * Any [Exception] raised while building the visitor or scanning is logged and
     * swallowed, so it does not propagate to the caller.
     */
    fun transform(ctTypeReference: CtTypeReference<*>, structuralModel: StructuralModel) {
        try {
            logger.profiledTrace { "Transforming ${ctTypeReference.canonicalName()}" }
            Transformer(ctTypeReference, structuralModel).scan(ctTypeReference)
            logger.profiledTrace(ProfilerState.STOP) { "Finished transforming ${ctTypeReference.canonicalName()}" }
        } catch (failure: Exception) {
            logger.error(failure) { "Error while transforming type: ${ctTypeReference.qualifiedName}" }
        }
    }

    /**
     * Visitor that wires up the relationships of the single type [thisCtType].
     *
     * Construction fails with a null-pointer error when [thisCtType] is not registered in
     * the structural model's types (see the `init` block).
     */
    private inner class Transformer(private var thisCtType: CtTypeReference<*>,
                                    structuralModel: StructuralModel) : SpoonVisitor() {

        // Model element lookups, delegated to structuralModel.components.
        // In this class they are queried by canonical name (types, executables, parameters,
        // type parameter mappings) or by qualified name (properties).
        private val types: MutableMap<String, Type> by structuralModel.components
        private val typeParametersMappings: MutableMap<String, TypeParameterMapping> by structuralModel.components
        private val properties: MutableMap<String, Property> by structuralModel.components
        private val executables: MutableMap<String, Executable> by structuralModel.components
        private val parameters: MutableMap<String, Parameter> by structuralModel.components

        // The current invocation / access per (calling executable, callee) pair; entries are
        // replaced when a repeated call or access changes cardinality or access type.
        private val invokes: MutableMap<Pair<Executable, Executable>, Invocation> = mutableMapOf()
        private val accesses: MutableMap<Pair<Executable, Property>, Access> = mutableMapOf()
        // Cache of element types already created for a given type reference.
        private val elementTypes: MutableMap<CtTypeReference<*>, ElementType> = mutableMapOf()

        // Model type corresponding to thisCtType; assigned in the init block.
        private var thisType: Type

        // Cursor for the executable whose body is currently being scanned; all three are set
        // together in visitCtExecutableReference before parameters and body elements are visited.
        private var currentExecutable: Executable? = null
        private var currentCtExecutable: CtExecutable<*>? = null
        private var currentCtExecutableReference: CtExecutableReference<*>? = null
        // Executables of the super interfaces whose declaration is available; refreshed in
        // visitCtTypeReference and consulted when detecting interface implementations.
        private var currentCtInterfaceExecutables = mutableSetOf<CtExecutableReference<*>>()

        init {
            // The type under transformation must already be present in the model.
            thisType = types[thisCtType.canonicalName()]!!
        }

        // Canonical names of elements that have already been scanned.
        private val visitedElements = mutableSetOf<String>()

        /**
         * Scans [element] unless an element with the same canonical name was scanned before.
         *
         * Field references, invocations and field accesses are exempt from that check and
         * are scanned every time they are encountered. A `null` element is ignored.
         */
        override fun scan(element: CtElement?) {
            if (element == null) return

            val name = element.canonicalName()
            val alwaysRevisited = element is CtFieldReference<*> ||
                    element is CtInvocation<*> ||
                    element is CtFieldAccess<*>
            if (name in visitedElements && !alwaysRevisited) return

            logger.trace { "Visiting: $name" }
            visitedElements.add(name)
            super.scan(element)
        }

        /**
         * Entry point for the type under transformation.
         *
         * Registers the super class and, when the declaration is available, the super
         * interfaces as extended types; it then scans the type's fields, executables and
         * formal type parameters. Nothing happens when the type is not part of the model.
         */
        override fun <T : Any?> visitCtTypeReference(r: CtTypeReference<T>) {
            assert(r == thisCtType)

            val type = types[r.canonicalName()] ?: return

            val superType = r.superclass?.let { types[it.canonicalName()] }
            if (superType != null) {
                type.isExtending.add(superType)
            }

            r.safeDeclaration().ifPresent { ctType ->
                val interfaceTypes = r.superInterfaces.mapNotNull { types[it.canonicalName()] }
                type.isExtending.addAll(interfaceTypes)

                // Remember the interface executables so overriding detection can find them later.
                currentCtInterfaceExecutables.clear()
                val interfaceExecutables = ctType.superInterfaces
                        .filter { it.safeDeclaration().isPresent }
                        .flatMap { it.allExecutables }
                currentCtInterfaceExecutables.addAll(interfaceExecutables)

                scan(ctType.allFields)
                scan(ctType.allExecutables)
                scan(ctType.formalCtTypeParameters)
            }
        }

        /**
         * Links a type parameter to its bounding type and to the type or method declaring it.
         *
         * Type parameters, bounds and declarers that are not present in the model are skipped.
         */
        override fun visitCtTypeParameter(t: CtTypeParameter) {
            val parameter = types[t.reference.canonicalName()] ?: return

            val bound = types[t.reference.boundingType.canonicalName()]
            if (bound != null) {
                parameter.isExtending.add(bound)
            }

            val declarer = t.typeParameterDeclarer
            if (declarer is CtType<*>) {
                types[declarer.reference.canonicalName()]?.typeParameters?.add(parameter)
            } else if (declarer is CtMethod<*>) {
                executables[declarer.reference.canonicalName()]?.typeParameters?.add(parameter)
            }
        }

        /**
         * Attaches a field to the current type, either as declared or as inherited property.
         *
         * Fields declared by the current type get their element type and, if still unknown,
         * their gradient type resolved, and shadow same-named non-static inherited properties.
         * Non-private, non-static fields declared by another modelled type are recorded as
         * inherited and are added to the shadows of a same-named non-static declared property.
         */
        override fun <T : Any?> visitCtFieldReference(r: CtFieldReference<T>) {
            val property = properties[r.qualifiedName] ?: return

            when {
                r.declaringType == thisCtType -> {
                    r.safeDeclaration().ifPresent { ctField ->
                        val created = createElementType(ctField.type)
                        if (created != null) {
                            property.type = created
                        }
                        if (property.gradientType == GradientType.UNKNOWN) {
                            val declaredType = property.type
                            property.gradientType = when {
                                declaredType == null -> GradientType.UNKNOWN
                                declaredType.typeOf.isGradientModel -> GradientType.REFERENCE
                                declaredType.typeOf.isGenerics &&
                                        declaredType.typeParameterMappings.isNotEmpty() &&
                                        declaredType.typeParameterMappings.first().actualType.isGradientModel -> GradientType.REFERENCE
                                else -> r.type.toGradientType()
                            }
                        }
                    }

                    thisType.properties.add(property)
                    val shadowed = thisType.inheritedProperties
                            .filter { it.name == property.name && !it.isStatic && !property.isStatic }
                    property.shadows.addAll(shadowed)
                }

                r.declaringType.canonicalName() in types &&
                        property.visibility != Visibility.PRIVATE &&
                        !property.isStatic -> {
                    thisType.inheritedProperties.add(property)
                    thisType.properties
                            .find { it.name == property.name && !it.isStatic }
                            ?.shadows
                            ?.add(property)
                }
            }
        }

        /**
         * Attaches an executable to the current type and scans its parameters and body.
         *
         * Executables declared by the current type get their element type and, if still
         * unknown, their gradient type resolved; their overrides, parameters, field accesses,
         * constructor calls, invocations and return statements are then visited.
         * Non-private, non-static executables declared by another modelled type are recorded
         * as inherited. Formal type parameters of methods are scanned in either case.
         */
        override fun <T : Any?> visitCtExecutableReference(r: CtExecutableReference<T>) {
            val declarationOptional = r.safeDeclaration()
            val executable = executables[r.canonicalName()]

            if (executable != null) {
                if (r.declaringType == thisCtType) {
                    thisType.executables.add(executable)

                    // NOTE(review): the declaration is resolved a second time here; the result
                    // held in declarationOptional could probably be reused - confirm.
                    r.safeDeclaration().ifPresent { ctExecutable ->
                        val created = createElementType(ctExecutable.type)
                        if (created != null) {
                            executable.type = created
                        }
                        if (executable.gradientType == GradientType.UNKNOWN) {
                            val resultType = executable.type
                            executable.gradientType = when {
                                resultType == null -> GradientType.UNKNOWN
                                executable.isConstructor -> GradientType.VOID
                                resultType.typeOf.isGradientModel -> GradientType.REFERENCE
                                resultType.typeOf.isGenerics &&
                                        resultType.typeParameterMappings.isNotEmpty() &&
                                        resultType.typeParameterMappings.first().actualType.isGradientModel -> GradientType.REFERENCE
                                else -> r.type.toGradientType()
                            }
                        }
                    }

                    declarationOptional.ifPresent { ctExecutable ->
                        // Move the cursor before visiting anything that belongs to this executable.
                        currentCtExecutable = ctExecutable
                        currentCtExecutableReference = r
                        currentExecutable = executable

                        visitOverridingExecutable()
                        scan(ctExecutable.parameters)
                        scan(ctExecutable.getElements { element ->
                            when (element) {
                                is CtFieldAccess<*>,
                                is CtConstructorCall<*>,
                                is CtInvocation<*>,
                                is CtReturn<*> -> true
                                else -> false
                            }
                        })
                    }
                } else if (r.declaringType.canonicalName() in types &&
                        executable.visibility != Visibility.PRIVATE &&
                        !executable.isStatic) {
                    thisType.inheritedExecutables.add(executable)
                }
            }

            declarationOptional.ifPresent { declaration ->
                if (declaration is CtMethod<*> && declaration.formalCtTypeParameters.isNotEmpty()) {
                    scan(declaration.formalCtTypeParameters)
                }
            }
        }

        /**
         * Records which executables the current executable overrides and in which manner.
         *
         * Every matching super-interface executable is recorded as IMPLEMENTING. For the
         * overridden super executable, an abstract one is IMPLEMENTING as well; otherwise the
         * quality is derived from the number of body statements and the `super` calls found.
         */
        private fun visitOverridingExecutable() {
            val reference = currentCtExecutableReference!!

            val implemented = currentCtInterfaceExecutables
                    .filter { reference.isOverriding(it) }
                    .mapNotNull { executables[it.canonicalName()] }
            for (interfaceExecutable in implemented) {
                val overrider = currentExecutable!!
                val canonicalName = Overriding.canonicalNameOf(overrider, interfaceExecutable, OverridingQuality.IMPLEMENTING)
                overrider.overrides.add(Overriding(Util.uuid(), canonicalName, overrider, interfaceExecutable, OverridingQuality.IMPLEMENTING))
            }

            val superExec = reference.overridingExecutable
                    ?.let { executables[it.canonicalName()] }
                    ?: return

            val quality = if (superExec.isAbstract) {
                OverridingQuality.IMPLEMENTING
            } else {
                val ctExecutable = currentCtExecutable!!
                val superCalls = ctExecutable.getElements<CtInvocation<*>> {
                    it is CtInvocation<*> && it.target is CtSuperAccess<*>
                }

                // A "reverting" call goes to a different super executable than the overridden one.
                val revertingCall = superCalls.any { it.executable.canonicalName() != superExec.canonicalName }
                val overridingCall = superCalls.any { it.executable.canonicalName() == superExec.canonicalName }

                val statementCount = ctExecutable.body?.statements?.size
                when {
                    statementCount == null -> OverridingQuality.REDECLARING
                    !overridingCall && statementCount == 0 -> OverridingQuality.CLEARING
                    revertingCall && statementCount == 1 -> OverridingQuality.REVERTING
                    overridingCall && statementCount == 1 -> OverridingQuality.DEFERRING
                    overridingCall && statementCount > 1 -> OverridingQuality.EXTENDING
                    else -> OverridingQuality.REPLACING
                }
            }

            val overrider = currentExecutable!!
            val canonicalName = Overriding.canonicalNameOf(overrider, superExec, quality)
            overrider.overrides.add(Overriding(Util.uuid(), canonicalName, overrider, superExec, quality))
        }

        /**
         * Attaches a parameter of the current executable and resolves its element type and,
         * if still unknown, its gradient type. Parameters missing from the model are skipped.
         */
        override fun <T : Any?> visitCtParameter(p: CtParameter<T>) {
            val owner = currentExecutable!!
            val position = currentCtExecutable!!.parameters.indexOf(p)
            val parameter = parameters[Parameter.canonicalNameOf(owner, position)] ?: return

            currentExecutable!!.parameters.add(parameter)

            val created = createElementType(p.type)
            if (created != null) {
                parameter.type = created
            }
            if (parameter.gradientType != GradientType.UNKNOWN) return

            val declaredType = parameter.type
            parameter.gradientType = when {
                declaredType == null -> GradientType.UNKNOWN
                declaredType.typeOf.isGradientModel -> GradientType.REFERENCE
                declaredType.typeOf.isGenerics &&
                        declaredType.typeParameterMappings.isNotEmpty() &&
                        declaredType.typeParameterMappings.first().actualType.isGradientModel -> GradientType.REFERENCE
                else -> p.type.toGradientType()
            }
        }

        /** Records a read access to the field referenced by [e]. */
        override fun <T : Any?> visitCtFieldRead(e: CtFieldRead<T>) {
            val delegate = visitDelegateExpression(e.target)
            visitCtFieldAccess(e, delegate, AccessType.READ)
        }

        /** Records a write access to the field referenced by [e]. */
        override fun <T : Any?> visitCtFieldWrite(e: CtFieldWrite<T>) {
            val delegate = visitDelegateExpression(e.target)
            visitCtFieldAccess(e, delegate, AccessType.WRITE)
        }

        /**
         * Records that the current executable accesses the property behind [fieldAccess].
         *
         * The first access of a property creates an [Access] (UNBOUND inside a loop, ONE
         * otherwise) and triggers accessor detection. A repeated access replaces the stored
         * one: differing access types merge to FULL, and the cardinality becomes UNBOUND
         * inside a loop or is raised to at least N otherwise.
         *
         * @param target entity the field is accessed through, if one could be resolved
         */
        private fun visitCtFieldAccess(fieldAccess: CtFieldAccess<*>, target: CanonicalEntity?, accessType: AccessType) {
            fieldAccess.variable.safeDeclaration().ifPresent { declaration ->
                val callee = properties[declaration.reference.qualifiedName] ?: return@ifPresent
                val insideLoop = fieldAccess.getParent(TypeFilter(CtLoop::class.java)) != null

                val caller = currentExecutable!!
                val key = Pair(caller, callee)
                val previous = accesses[key]

                if (previous == null) {
                    val cardinality = if (insideLoop) Cardinality.UNBOUND else Cardinality.ONE
                    val accessName = Access.canonicalNameOf(caller, target, callee, cardinality, accessType)
                    val created = Access(Util.uuid(), accessName, caller, target, callee, cardinality, accessType)

                    caller.accesses.add(created)
                    accesses[key] = created
                    visitAccessor(created)
                } else {
                    val mergedType = if (previous.accessType == accessType) accessType else AccessType.FULL
                    val mergedCardinality = when {
                        insideLoop -> Cardinality.UNBOUND
                        previous.cardinality < Cardinality.N -> Cardinality.N
                        else -> previous.cardinality
                    }
                    val accessName = Access.canonicalNameOf(caller, target, callee, mergedCardinality, mergedType)
                    val replacement = Access(Util.uuid(), accessName, caller, target, callee, mergedCardinality, mergedType)

                    caller.accesses.remove(previous)
                    caller.accesses.add(replacement)
                    accesses[key] = replacement
                }
            }
        }

        /**
         * Marks the current executable as getter or setter based on a naming heuristic.
         *
         * The lower-cased source name must contain "get" or "set" followed by the lower-cased
         * target name, and the access type must be READ or WRITE respectively.
         */
        private fun visitAccessor(access: Access) {
            val executableName = access.source.name.toLowerCase()
            val targetName = access.target.name.toLowerCase()

            when {
                access.accessType == AccessType.READ && ("get" + targetName) in executableName -> {
                    currentExecutable!!.accessor = Accessor.GETTER
                }
                access.accessType == AccessType.WRITE && ("set" + targetName) in executableName -> {
                    currentExecutable!!.accessor = Accessor.SETTER
                }
            }
        }

        /** Records the constructor call [e] as an invocation made by the current executable. */
        override fun <T : Any?> visitCtConstructorCall(e: CtConstructorCall<T>) {
            val delegate = visitDelegateExpression(e.target)
            visitCtAbstractInvocation(e, delegate)
        }

        /** Records the method invocation [e] as an invocation made by the current executable. */
        override fun <T : Any?> visitCtInvocation(e: CtInvocation<T>) {
            val delegate = visitDelegateExpression(e.target)
            visitCtAbstractInvocation(e, delegate)
        }

        /**
         * Resolves the model entity a field access or invocation is made through.
         *
         * Recognised shapes of [expression]:
         * - a field access whose field has a declaring type: the matching [Property];
         * - an array read whose target is such a field access: the [Property] of that field;
         * - an access to a parameter of the current executable: the matching [Parameter];
         * - a type access: the matching [Type].
         *
         * @return the resolved entity, or `null` when [expression] is `null`, has another
         * shape, or the entity is not registered in the structural model
         */
        private fun <T : Any?> visitDelegateExpression(expression: CtExpression<T>?): CanonicalEntity? {
            // Shared by the direct field access and the array-read-on-field cases.
            fun propertyOf(fieldAccess: CtFieldAccess<*>): Property? =
                    if (fieldAccess.variable.declaringType != null) {
                        properties[fieldAccess.variable.qualifiedName]
                    } else {
                        null
                    }

            return when (expression) {
                is CtFieldAccess<*> -> propertyOf(expression)

                // Check against the CtArrayRead interface instead of Spoon's implementation class.
                is CtArrayRead<*> -> (expression.target as? CtFieldAccess<*>)?.let { propertyOf(it) }

                is CtVariableAccess<*> -> {
                    val variable = expression.variable
                    val declaration = if (variable is CtParameterReference<*>) variable.safeDeclaration() else null
                    if (declaration != null && declaration.isPresent) {
                        val index = currentCtExecutable!!.parameters.indexOf(declaration.get())
                        // A parameter that does not belong to the current executable (for
                        // example a lambda parameter) has no index here; do not build a
                        // canonical name from -1.
                        if (index >= 0) {
                            parameters[Parameter.canonicalNameOf(currentExecutable!!, index)]
                        } else {
                            null
                        }
                    } else {
                        null
                    }
                }

                is CtTypeAccess<*> -> types[expression.accessedType.canonicalName()]

                else -> null
            }
        }

        /**
         * Records that the current executable invokes the executable behind [ctInvocation].
         *
         * The first call of a callee creates an [Invocation] (UNBOUND inside a loop, ONE
         * otherwise). A repeated call replaces the stored invocation unless it is already
         * UNBOUND: the cardinality becomes UNBOUND inside a loop or is raised to at least N.
         *
         * @param delegate entity the call is made through, if one could be resolved
         */
        private fun <T : Any?> visitCtAbstractInvocation(ctInvocation: CtAbstractInvocation<T>, delegate: CanonicalEntity?) {
            ctInvocation.executable.safeDeclaration().ifPresent { declaration ->
                val callee = executables[declaration.reference.canonicalName()] ?: return@ifPresent
                val insideLoop = ctInvocation.getParent(TypeFilter(CtLoop::class.java)) != null

                val caller = currentExecutable!!
                val key = Pair(caller, callee)
                val known = invokes[key]

                // An invocation that is already unbound cannot grow any further.
                if (known == null || known.cardinality != Cardinality.UNBOUND) {
                    val cardinality = when {
                        insideLoop -> Cardinality.UNBOUND
                        known == null -> Cardinality.ONE
                        known.cardinality < Cardinality.N -> Cardinality.N
                        else -> known.cardinality
                    }

                    val invocationName = Invocation.canonicalNameOf(caller, delegate, callee, cardinality)
                    val invocation = Invocation(Util.uuid(), invocationName, caller, delegate, callee, cardinality)

                    if (known != null) {
                        caller.invokes.remove(known)
                    }
                    caller.invokes.add(invocation)
                    invokes[key] = invocation
                }
            }
        }

        /**
         * Adds the element type of the returned expression to the current executable's returns.
         * Return statements without an expression, without an expression type, or whose type
         * cannot be mapped to an element type are ignored.
         */
        override fun <T : Any?> visitCtReturn(returnStatement: CtReturn<T>) {
            val returnedType = returnStatement.returnedExpression?.type ?: return
            val elementType = createElementType(returnedType) ?: return
            currentExecutable!!.returns.add(elementType)
        }

        /**
         * Returns the element type for [r], creating and caching it on first use.
         *
         * A non-array reference with type arguments whose declaration is available, is not
         * generic itself and is a subtype of `Iterable` is mapped to an unbound element type;
         * every other reference is mapped to a bound one.
         *
         * @return the element type, or `null` when none could be created
         */
        private fun createElementType(r: CtTypeReference<*>): ElementType? {
            val resolved = elementTypes[r] ?: run {
                val declaration = r.safeDeclaration()
                val isTypedIterable = r !is CtArrayTypeReference<*> &&
                        declaration.isPresent &&
                        !declaration.get().isGenerics &&
                        declaration.get().isSubtypeOf(r.factory.Type().ITERABLE) &&
                        r.actualTypeArguments.isNotEmpty()

                if (isTypedIterable) createUnboundElementType(r) else createBoundElementType(r)
            }

            if (resolved != null) {
                elementTypes[r] = resolved
            }
            return resolved
        }

        /**
         * Builds an UNBOUND element type from the first type argument of [r].
         *
         * @return the element type, or `null` when that type argument is not in the model
         */
        private fun createUnboundElementType(r: CtTypeReference<*>): ElementType? {
            val itemType = types[r.actualTypeArguments.first().canonicalName()] ?: return null

            val canonicalName = ElementType.canonicalNameOf(itemType, Cardinality.UNBOUND, emptyList())
            return ElementType(Util.uuid(), canonicalName, itemType, emptyList(), Cardinality.UNBOUND)
        }

        /**
         * Builds an element type for [r] itself: cardinality N for array references, ONE
         * otherwise, together with the mappings of its actual type arguments.
         *
         * @return the element type, or `null` when [r] is not in the model
         */
        private fun createBoundElementType(r: CtTypeReference<*>): ElementType? {
            return types[r.canonicalName()]?.let { baseType ->
                val cardinality = if (r is CtArrayTypeReference<*>) Cardinality.N else Cardinality.ONE
                val typeParameters = createElementTypeParameters(r)
                val canonicalName = ElementType.canonicalNameOf(baseType, cardinality, typeParameters)

                ElementType(Util.uuid(), canonicalName, baseType, typeParameters, cardinality)
            }
        }

        /**
         * Creates (or reuses) a [TypeParameterMapping] for each actual type argument of [r].
         *
         * Arguments without an available declaration, or whose types are not in the model,
         * are left out. Mappings are shared through the structural model by canonical name.
         */
        private fun createElementTypeParameters(r: CtTypeReference<*>): List<TypeParameterMapping> {
            val mappings = mutableListOf<TypeParameterMapping>()

            for (actualCtType in r.actualTypeArguments) {
                val declaration = actualCtType.safeDeclaration()
                if (!declaration.isPresent) continue

                // NOTE(review): both lookups start from the actual type argument (its
                // declaration's reference and the argument itself), not from the formal type
                // parameter of r - confirm this is intended.
                val parameter = types[declaration.get().reference.canonicalName()] ?: continue
                val actualType = types[actualCtType.canonicalName()] ?: continue

                val canonicalName = TypeParameterMapping.canonicalNameOf(parameter, actualType)
                val mapping = typeParametersMappings.getOrPut(canonicalName) {
                    TypeParameterMapping(
                            id = Util.uuid(),
                            canonicalName = canonicalName,
                            parameter = parameter,
                            actualType = actualType
                    )
                }
                mappings.add(mapping)
            }

            return mappings
        }
    }
}

