package pw.ian.sangria_scalapb

import scala.collection.mutable.Map
import sangria.schema._
import scalapb.descriptors._
import com.trueaccord.scalapb._
import cats.implicits._
import org.log4s._

/**
  * Creates GraphQL types from ScalaPB enums and message types.
  *
  * Generated types are cached by descriptor, so each enum or message is
  * converted at most once and later lookups return the cached instance.
  * The caches are plain mutable maps; no synchronization is performed here.
  *
  * @param transformer supplies type renames (`renames`), foreign-key fields
  *                    (`buildForeignKeys`) and extra message fields
  *                    (`buildMessageFields`)
  */
class TypeRegistry(
  transformer: SchemaTransformer,
) {
  private[this] val logger = getLogger

  /** Generated object types, keyed by message descriptor. */
  val messages: Map[Descriptor, ObjectType[Unit, FieldContext]] = Map.empty

  /** GraphQL type name -> full protobuf name of the message last registered under it. */
  val messageNames: Map[String, String] = Map.empty

  /** Generated enum types, keyed by enum descriptor. */
  val enums: Map[EnumDescriptor, EnumType[PEnum]] = Map.empty

  /**
    * Returns the GraphQL enum type for the ScalaPB enum `T`, generating it on
    * first use.
    *
    * @tparam T   generated enum type
    * @param comp companion supplying the enum's descriptor
    */
  def fetchEnumType[T <: GeneratedEnum](
    implicit comp: GeneratedEnumCompanion[T],
  ): EnumType[PEnum] =
    fetchPEnumType(comp.scalaDescriptor)

  /**
    * Returns the cached GraphQL enum type for `desc`, generating and caching it
    * if it has not been seen yet.
    */
  def fetchPEnumType(
    desc: EnumDescriptor,
  ): EnumType[PEnum] =
    enums.getOrElseUpdate(desc, generatePEnumType(desc))

  /**
    * Builds a GraphQL enum whose values are the protobuf enum's value names,
    * each carrying the corresponding `PEnum` as its runtime value.
    */
  private def generatePEnumType(
    desc: EnumDescriptor,
  ): EnumType[PEnum] =
    EnumType(
      typeName(desc.fullName, desc.name),
      None,
      desc.values.toList.map { v =>
        EnumValue(
          v.name,
          value = PEnum(v),
        )
      },
    )

  /**
    * Returns the GraphQL object type for the ScalaPB message `T`, generating it
    * on first use.
    *
    * @tparam T   generated message type
    * @param comp companion supplying the message's descriptor
    */
  def fetchMessageType[T <: GeneratedMessage with Message[T]](
    implicit comp: GeneratedMessageCompanion[T],
  ): ObjectType[Unit, FieldContext] =
    fetchPMessageType(comp.scalaDescriptor)

  /**
    * Returns the cached GraphQL object type for `desc`, generating and caching
    * it if it has not been seen yet.
    */
  def fetchPMessageType(
    desc: Descriptor,
  ): ObjectType[Unit, FieldContext] =
    messages.getOrElseUpdate(desc, generatePMessageType(desc))

  /** GraphQL name for a protobuf type: the configured rename if any, else its short name. */
  private def typeName(fullName: String, shortName: String): String =
    transformer.renames.get(fullName).getOrElse(shortName)

  /**
    * GraphQL field name for a protobuf field. A field named `id` is exposed as
    * `raw_id` (presumably so it cannot clash with an `id` field contributed by
    * the transformer; TODO confirm).
    */
  private def fieldName(fd: FieldDescriptor): String =
    if (fd.name === "id") "raw_id" else fd.name

  /**
    * Renders a `bytes` value as standard base64, the encoding used by the
    * proto3 JSON mapping. `ByteString.toString` is not used because it prints a
    * debug summary tagged with the object's identity hash rather than the bytes.
    */
  private def encodeBytes(pv: PValue): String =
    java.util.Base64.getEncoder.encodeToString(pv.asInstanceOf[PByteString].value.toByteArray)

  /** Builds the GraphQL field for `fd`, as a list field when `fd` is repeated. */
  private def makeField(
    fd: FieldDescriptor,
  ): Field[Unit, FieldContext] =
    if (fd.isRepeated) makeFieldRepeated(fd) else makeFieldSingular(fd)

  /**
    * Builds a GraphQL list field for a repeated protobuf field.
    *
    * Each element is cast to the `PValue` subtype matching the field's Scala
    * type, so a mismatched value fails when the field is resolved.
    */
  private def makeFieldRepeated(
    fd: FieldDescriptor,
  ): Field[Unit, FieldContext] = {
    val name = fieldName(fd)
    // Elements of this repeated field on the message being resolved.
    def items(fc: FieldContext) =
      fc.pmessage.value(fd).asInstanceOf[PRepeated].value

    fd.scalaType match {
      case ScalaType.Boolean =>
        Field(
          name,
          ListType(BooleanType),
          resolve = ctx => items(ctx.value).map(_.asInstanceOf[PBoolean].value),
        )
      case ScalaType.ByteString =>
        Field(
          name,
          ListType(StringType),
          resolve = ctx => items(ctx.value).map(encodeBytes),
        )
      case ScalaType.String =>
        Field(
          name,
          ListType(StringType),
          resolve = ctx => items(ctx.value).map(_.asInstanceOf[PString].value),
        )
      case ScalaType.Double =>
        Field(
          name,
          ListType(FloatType),
          resolve = ctx => items(ctx.value).map(_.asInstanceOf[PDouble].value),
        )
      case ScalaType.Float =>
        Field(
          name,
          ListType(FloatType),
          resolve = ctx => items(ctx.value).map(_.asInstanceOf[PFloat].value.toDouble),
        )
      case ScalaType.Int =>
        Field(
          name,
          ListType(IntType),
          resolve = ctx => items(ctx.value).map(_.asInstanceOf[PInt].value),
        )
      case ScalaType.Long =>
        // Exposed as Float via toDouble: magnitudes above 2^53 lose precision.
        Field(
          name,
          ListType(FloatType),
          resolve = ctx => items(ctx.value).map(_.asInstanceOf[PLong].value.toDouble),
        )
      case ScalaType.Message(msgDesc) =>
        Field(
          name,
          ListType(fetchPMessageType(msgDesc)),
          resolve = ctx => {
            val fc = ctx.value
            items(fc).map(pm => fc.transform(fd, pm.asInstanceOf[PMessage]))
          },
        )
      case ScalaType.Enum(enumDesc) =>
        Field(
          name,
          ListType(fetchPEnumType(enumDesc)),
          resolve = ctx => items(ctx.value).map(_.asInstanceOf[PEnum]),
        )
    }
  }

  /**
    * Builds a GraphQL field for a singular (non-repeated) protobuf field.
    *
    * Scalar and enum values are cast to the expected `PValue` subtype. Message
    * fields are nullable: any non-message value resolves to `None`.
    */
  private def makeFieldSingular(
    fd: FieldDescriptor,
  ): Field[Unit, FieldContext] = {
    val name = fieldName(fd)
    // Raw value of this field on the message being resolved.
    def raw(fc: FieldContext) = fc.pmessage.value(fd)

    fd.scalaType match {
      case ScalaType.Boolean =>
        Field(
          name,
          BooleanType,
          resolve = ctx => raw(ctx.value).asInstanceOf[PBoolean].value,
        )
      case ScalaType.ByteString =>
        Field(
          name,
          StringType,
          resolve = ctx => encodeBytes(raw(ctx.value)),
        )
      case ScalaType.String =>
        Field(
          name,
          StringType,
          resolve = ctx => raw(ctx.value).asInstanceOf[PString].value,
        )
      case ScalaType.Double =>
        Field(
          name,
          FloatType,
          resolve = ctx => raw(ctx.value).asInstanceOf[PDouble].value,
        )
      case ScalaType.Float =>
        Field(
          name,
          FloatType,
          resolve = ctx => raw(ctx.value).asInstanceOf[PFloat].value.toDouble,
        )
      case ScalaType.Int =>
        Field(
          name,
          IntType,
          resolve = ctx => raw(ctx.value).asInstanceOf[PInt].value,
        )
      case ScalaType.Long =>
        // Exposed as Float via toDouble: magnitudes above 2^53 lose precision.
        Field(
          name,
          FloatType,
          resolve = ctx => raw(ctx.value).asInstanceOf[PLong].value.toDouble,
        )
      case ScalaType.Message(msgDesc) =>
        Field(
          name,
          OptionType(fetchPMessageType(msgDesc)),
          resolve = ctx => {
            val fc = ctx.value
            raw(fc) match {
              case v: PMessage => Some(fc.transform(fd, v))
              // Anything else (presumably PEmpty for an unset field) is null.
              case _ => None
            }
          },
        )
      case ScalaType.Enum(enumDesc) =>
        Field(
          name,
          fetchPEnumType(enumDesc),
          resolve = ctx => raw(ctx.value).asInstanceOf[PEnum],
        )
    }
  }

  /**
    * Builds the GraphQL object type for a message.
    *
    * The field list combines one field per protobuf field, any foreign-key
    * fields from the transformer for each protobuf field, and any extra fields
    * the transformer supplies for the message.
    */
  private def generatePMessageType(
    desc: Descriptor,
  ): ObjectType[Unit, FieldContext] = {
    val name = typeName(desc.fullName, desc.name)
    // A name collision is only logged; this message then replaces the earlier
    // entry in `messageNames`.
    messageNames.get(name).foreach { currentName =>
      logger.warn(s"Duplicate name encountered for type ${desc.fullName}: ${currentName}")
    }
    messageNames.put(name, desc.fullName)

    ObjectType[Unit, FieldContext](
      name,
      // Supplied as a thunk, so nested message lookups (makeField ->
      // fetchPMessageType) run only when it is invoked, not during construction.
      () => {
        val protoFields = desc.fields.toList
        protoFields.map(makeField) ++
          protoFields.map(fd => transformer.buildForeignKeys(this, fd)).flatten ++
          transformer.buildMessageFields(this, desc)
      },
    )
  }

}
