package lighttunnel.server.http

import io.netty.buffer.ByteBufUtil
import io.netty.buffer.Unpooled
import io.netty.channel.ChannelFutureListener
import io.netty.channel.ChannelHandlerContext
import io.netty.channel.SimpleChannelInboundHandler
import io.netty.handler.codec.http.FullHttpRequest
import lighttunnel.logger.loggerDelegate
import lighttunnel.proto.ProtoMessage
import lighttunnel.proto.ProtoMessageType
import lighttunnel.proto.RemoteInfo
import lighttunnel.server.util.AK_HTTP_HOST
import lighttunnel.server.util.AK_SESSION_ID
import lighttunnel.util.HttpUtil
import lighttunnel.util.LongUtil

/**
 * Netty inbound handler for the public-facing HTTP side of the tunnel server.
 *
 * Each [FullHttpRequest] is handled in this order:
 * 1. Offered to the optional [httpPlugin]. If it produces a response, that response is written,
 *    the connection is closed, and nothing else happens.
 * 2. Routed by host name (via `HttpUtil.getHostWithoutPort`) to a tunnel looked up in [registry].
 *    If there is no host, or no tunnel for it, the connection is closed.
 * 3. Offered to [interceptor]. If it produces a response, that response is written and the request
 *    is not forwarded.
 * 4. Otherwise a session is allocated on the tunnel. The request is announced with
 *    [ProtoMessageType.REMOTE_CONNECTED] and its bytes are sent with [ProtoMessageType.TRANSFER].
 *
 * @param registry lookup of HTTP tunnels by host name.
 * @param interceptor hook that may answer a request itself instead of forwarding it.
 * @param httpPlugin optional hook consulted before any host routing.
 */
internal class HttpServerChannelHandler(
    private val registry: HttpRegistry,
    private val interceptor: HttpRequestInterceptor,
    private val httpPlugin: HttpPlugin? = null
) : SimpleChannelInboundHandler<FullHttpRequest>() {
    private val logger by loggerDelegate()

    @Throws(Exception::class)
    override fun channelActive(ctx: ChannelHandlerContext?) {
        logger.trace("channelActive: {}", ctx)
        super.channelActive(ctx)
    }

    /**
     * Tells the tunnel that a public connection has gone away, then delegates to the superclass.
     */
    @Throws(Exception::class)
    override fun channelInactive(ctx: ChannelHandlerContext?) {
        logger.trace("channelInactive: {}", ctx)
        if (ctx != null) {
            notifyRemoteDisconnect(ctx)
        }
        super.channelInactive(ctx)
    }

    /**
     * Sends [ProtoMessageType.REMOTE_DISCONNECT] for this channel's session and clears the
     * [AK_HTTP_HOST] / [AK_SESSION_ID] attributes.
     *
     * Does nothing unless both attributes are set, i.e. unless a request on this channel was
     * forwarded. If the host no longer has a tunnel in [registry], no message is sent, but the
     * attributes are still cleared.
     */
    private fun notifyRemoteDisconnect(ctx: ChannelHandlerContext) {
        val channel = ctx.channel()
        val httpHost = channel.attr(AK_HTTP_HOST).get() ?: return
        val sessionId = channel.attr(AK_SESSION_ID).get() ?: return
        registry.getHttpFd(httpHost)?.let { httpFd ->
            val head = LongUtil.toBytes(httpFd.tunnelId, sessionId)
            httpFd.tunnelChannel.writeAndFlush(
                ProtoMessage(ProtoMessageType.REMOTE_DISCONNECT, head, RemoteInfo(channel.remoteAddress()).toBytes())
            )
        }
        channel.attr(AK_HTTP_HOST).set(null)
        channel.attr(AK_SESSION_ID).set(null)
    }

    /** Logs the failure at trace level and closes the connection once pending writes are flushed. */
    @Throws(Exception::class)
    override fun exceptionCaught(ctx: ChannelHandlerContext?, cause: Throwable?) {
        logger.trace("exceptionCaught: {}", ctx, cause)
        ctx ?: return
        flushAndClose(ctx)
    }

    /**
     * Handles one aggregated HTTP request from a public client (see the class KDoc for the
     * order of handling).
     */
    override fun channelRead0(ctx: ChannelHandlerContext?, msg: FullHttpRequest?) {
        logger.trace("channelRead0: {}", ctx)
        ctx ?: return
        msg ?: return
        val channel = ctx.channel()

        val httpPluginResponse = httpPlugin?.doHandle(msg)
        if (httpPluginResponse != null) {
            channel.writeAndFlush(HttpUtil.toByteBuf(httpPluginResponse)).addListener(ChannelFutureListener.CLOSE)
            // The plugin fully answered this request. Stop here so it is not also routed into a
            // tunnel; without this return, the request was forwarded on a connection already closing.
            return
        }

        val httpHost = HttpUtil.getHostWithoutPort(msg)
        if (httpHost == null) {
            flushAndClose(ctx)
            return
        }
        channel.attr(AK_HTTP_HOST).set(httpHost)

        val httpFd = registry.getHttpFd(httpHost)
        if (httpFd == null) {
            flushAndClose(ctx)
            return
        }

        val httpInterceptorResponse = interceptor.handleHttpRequest(ctx, httpFd.tunnelRequest, msg)
        if (httpInterceptorResponse != null) {
            // NOTE(review): the connection is left open here, unlike the plugin path. This is presumably
            // deliberate (e.g. letting the client retry on the same connection); confirm.
            channel.writeAndFlush(HttpUtil.toByteBuf(httpInterceptorResponse))
            return
        }

        val sessionId = httpFd.putChannel(channel)
        channel.attr(AK_SESSION_ID).set(sessionId)
        val head = LongUtil.toBytes(httpFd.tunnelId, sessionId)
        httpFd.tunnelChannel.writeAndFlush(
            ProtoMessage(ProtoMessageType.REMOTE_CONNECTED, head, RemoteInfo(channel.remoteAddress()).toBytes())
        )
        // NOTE(review): the ByteBuf returned by HttpUtil.toByteBuf(msg) is copied into a byte array and
        // never released. If it is a fresh buffer (not a view over msg.content()), it should be released
        // after getBytes. Verify ownership before adding a release().
        val data = ByteBufUtil.getBytes(HttpUtil.toByteBuf(msg))
        httpFd.tunnelChannel.writeAndFlush(ProtoMessage(ProtoMessageType.TRANSFER, head, data))
    }

    /** Flushes pending writes on [ctx]'s channel and closes the channel once the flush completes. */
    private fun flushAndClose(ctx: ChannelHandlerContext) {
        ctx.channel().writeAndFlush(Unpooled.EMPTY_BUFFER).addListener(ChannelFutureListener.CLOSE)
    }

}