package com.turbospaces.dispatch;

import java.util.LinkedHashSet;
import java.util.Objects;
import java.util.Set;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
import java.util.function.BiConsumer;
import java.util.function.Supplier;
import java.util.stream.Collectors;

import org.apache.commons.lang3.BooleanUtils;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.exception.ExceptionUtils;
import org.apache.commons.lang3.time.StopWatch;
import org.slf4j.MDC;

import com.google.common.util.concurrent.ListenableFuture;
import com.google.common.util.concurrent.MoreExecutors;
import com.google.common.util.concurrent.SettableFuture;
import com.google.protobuf.Any;
import com.google.protobuf.Message;
import com.turbospaces.api.facade.RequestWrapperFacade;
import com.turbospaces.boot.AbstractBootstrapAware;
import com.turbospaces.common.NonBlockingCallOnly;
import com.turbospaces.common.PlatformUtil;
import com.turbospaces.dispatch.TransactionalRequestOutcome.TransactionalRequestOutcomeBuilder;
import com.turbospaces.executor.WorkUnit;
import com.turbospaces.mdc.MdcTags;
import com.turbospaces.metrics.MetricTags;

import api.v1.ApiFactory;
import api.v1.ApplicationException;
import io.micrometer.core.instrument.Tag;
import io.micrometer.core.instrument.Timer;
import io.opentracing.Scope;
import io.opentracing.Span;
import io.vavr.CheckedRunnable;
import lombok.Getter;

/**
 * Executes a single {@link TransactionalRequestHandler} against its {@link TransactionalRequest} and publishes the
 * resulting {@link TransactionalRequestOutcome} through the {@link WorkerCompletableTask} returned by {@link #get()}.
 *
 * <p>
 * Besides invoking the handler, this class:
 * <ul>
 * <li>suffixes the executing thread's name with {@code |operationName} for the duration of {@link #run()};</li>
 * <li>activates the supplied tracing {@link Span} while the handler runs;</li>
 * <li>records a {@value #METRIC_NAME} {@link Timer} tagged with the operation, request header tags, transaction
 * tags, error class/code and reply-to;</li>
 * <li>converts failures into an exceptional reply via {@code reqw.toExceptionalReply(...)};</li>
 * <li>calls {@code transaction.ack()} unless {@code transaction.isAckOrNack()} is already true.</li>
 * </ul>
 *
 * <p>
 * The outcome is produced by a listener on {@code transaction.replyCondition()}, registered with a direct executor,
 * so it may run on a different thread than {@link #run()}. State shared by both sides ({@code tags}, the future,
 * the stop watch) is mutated under {@code mutex} once the listener can be running.
 */
public class SafeRequestHandler<REQ extends Message, RESP extends Message.Builder>
        extends AbstractBootstrapAware
        implements CheckedRunnable, Supplier<WorkerCompletableTask> {
    public static final String METRIC_NAME = "dispatcher";

    private final Object mutex = new Object();
    private final SettableFuture<TransactionalRequestOutcome> future = SettableFuture.create();
    private final WorkerCompletableTask toReturn;
    private final TransactionalRequestHandler<REQ, RESP> action;
    @Getter
    private final TransactionalRequest<REQ, RESP> transaction;
    private final RequestWrapperFacade reqw;
    private final ApiFactory apiFactory;
    private final Span span;
    private final StopWatch stopWatch;
    private final String operationName;
    // ~ metric tags accumulated over the request's lifetime; guarded by mutex once the listener may run
    private final Set<Tag> tags;

    /**
     * Creates the handler and starts its stop watch immediately, so the recorded time includes queueing until
     * {@link #run()} is invoked.
     *
     * @param reqw        request wrapper (headers, body, reply factories); must not be null
     * @param apiFactory  used to pack notifications; must not be null
     * @param span        tracing span activated during {@link #run()}; must not be null
     * @param record      work unit handed to the {@link WorkerCompletableTask}
     * @param transaction the transaction the handler operates on; must not be null
     * @param action      the business handler; must not be null
     */
    public SafeRequestHandler(
            RequestWrapperFacade reqw,
            ApiFactory apiFactory,
            Span span,
            WorkUnit record,
            TransactionalRequest<REQ, RESP> transaction,
            TransactionalRequestHandler<REQ, RESP> action) {
        this.span = Objects.requireNonNull(span);
        this.reqw = Objects.requireNonNull(reqw);
        this.apiFactory = Objects.requireNonNull(apiFactory);
        this.transaction = Objects.requireNonNull(transaction);
        this.toReturn = new WorkerCompletableTask(future, record, reqw.headers().getMessageId());
        this.action = Objects.requireNonNull(action);

        this.stopWatch = StopWatch.createStarted();
        this.operationName = PlatformUtil.toLowerUnderscore(reqw.body().getTypeUrl());
        this.tags = new LinkedHashSet<>();
    }

    /**
     * @return the task whose future completes with this request's outcome
     */
    @Override
    public WorkerCompletableTask get() {
        return toReturn;
    }

    /**
     * Runs the handler on the calling thread and arranges for the outcome to be published once the transaction's
     * reply condition completes.
     *
     * <p>
     * Failures thrown synchronously (span activation or the handler itself) are converted into an exceptional
     * outcome here. Failures of the reply condition are converted inside the completion listener. In every case
     * {@code transaction.ack()} is called unless {@code transaction.isAckOrNack()} is true, and the original thread
     * name is restored.
     */
    @Override
    public void run() {
        var currentThread = Thread.currentThread();
        var oldName = currentThread.getName();
        currentThread.setName(oldName + "|" + operationName);

        //
        // ~ add tags (before the listener is registered, so the listener observes them)
        //
        tags.add(Tag.of(MetricTags.OPERATION, operationName));
        reqw.headers().tags().forEach((key, value) -> tags.add(Tag.of(key, value)));

        try (Scope activate = bootstrap.tracer().activateSpan(span)) {
            Objects.requireNonNull(bootstrap.tracer().activeSpan());

            setGuards();
            action.apply(transaction);

            ListenableFuture<?> replyCondition = transaction.replyCondition();
            // ~ direct executor: runs on the completing thread, or right here if replyCondition is already done
            replyCondition.addListener(() -> onReplyConditionDone(replyCondition), MoreExecutors.directExecutor());
        } catch (Throwable err) {
            captureFailure(err);
            reportMetricsSafely();
        } finally {
            try {
                if (BooleanUtils.isFalse(transaction.isAckOrNack())) {
                    transaction.ack(); // ~ we mark offset for commit immediately
                }
            } finally {
                currentThread.setName(oldName);
                removeMDCKeys();
            }
        }
    }

    /**
     * Completion listener for {@code transaction.replyCondition()}. It builds the success outcome or converts the
     * failure, reports metrics, and only then completes {@code future}.
     *
     * @param replyCondition the (already completed) reply condition of the transaction
     */
    private void onReplyConditionDone(ListenableFuture<?> replyCondition) {
        TransactionalRequestOutcomeBuilder outcome = null;

        try {
            Object get = replyCondition.get();
            if (Objects.nonNull(get)) {
                logger.trace("completed with reply: {}", get);
            }

            //
            // ~ store as MDC field
            //
            long took = System.currentTimeMillis() - stopWatch.getStartTime();
            MDC.put(MdcTags.MDC_TOOK, String.valueOf(took));

            var reply = transaction.reply().build();
            var respw = reqw.toReply(reply);

            outcome = TransactionalRequestOutcome.builder();
            outcome.key(transaction.routingKey());
            outcome.reply(respw);

            registerNotifications(outcome);
            registerEventStream(outcome);
        } catch (ExecutionException err) {
            // ~ unwrap to the actual failure, falling back to the wrapper if no cause is attached
            Throwable cause = err.getCause();
            captureFailure(Objects.nonNull(cause) ? cause : err);
        } catch (InterruptedException err) {
            Thread.currentThread().interrupt(); // ~ preserve interrupt status for the thread's owner
            captureFailure(err);
        } catch (Exception err) {
            captureFailure(err);
        } finally {
            // ~ ensure we report metrics; a metrics failure must not prevent completion below
            reportMetricsSafely();
            //
            // ~ only then complete and return control to high level dispatcher
            // ~ (if captureFailure already completed the future, SettableFuture.set is a no-op)
            //
            if (Objects.nonNull(outcome)) {
                future.set(outcome.build());
            }
            logger.debug("finished {} : {}", operationName, stopWatch);
            removeMDCKeys(); // ~ and finally remove MDC tags
        }
    }

    /**
     * Removes the MDC entries this class may have put on the current thread (took, error code, error class).
     */
    protected void removeMDCKeys() {
        synchronized (mutex) {
            MDC.remove(MdcTags.MDC_TOOK);
            MDC.remove(MdcTags.MDC_ERROR_CODE);
            MDC.remove(MdcTags.MDC_ERROR_CLASS);
        }
    }

    /**
     * Copies {@code transaction.tags()} into the metric tags. Failures are logged and otherwise ignored, so that
     * metric reporting continues.
     */
    protected void addTagsFromContext() {
        synchronized (mutex) {
            try {
                transaction.tags().forEach((k, v) -> tags.add(Tag.of(k, v)));
            } catch (Exception err) {
                logger.error("Error when adding tags", err);
            }
        }
    }

    /**
     * Stops the stop watch and records the {@value #METRIC_NAME} timer.
     *
     * <p>
     * Before recording, it adds:
     * <ul>
     * <li>the transaction's context tags;</li>
     * <li>{@code MetricTags.ERROR="none"} when no error tag is present;</li>
     * <li>the reply-to header, when it is non-empty.</li>
     * </ul>
     * Only the first call records anything. Later calls return immediately, because a second
     * {@code StopWatch.stop()} would throw {@link IllegalStateException}.
     */
    protected void reportMetrics() {
        synchronized (mutex) {
            if (BooleanUtils.isFalse(stopWatch.isStarted())) {
                return; // ~ already reported
            }

            addTagsFromContext();
            boolean isError = tags.stream().anyMatch(t -> t.getKey().equals(MetricTags.ERROR));
            if (BooleanUtils.isFalse(isError)) {
                tags.add(Tag.of(MetricTags.ERROR, "none"));
            }
            var replyTo = reqw.headers().getReplyTo();
            if (StringUtils.isNotEmpty(replyTo)) {
                tags.add(Tag.of(MetricTags.REPLY_TO, replyTo));
            }
            stopWatch.stop();

            Timer timer = bootstrap.meterRegistry().timer(METRIC_NAME, tags);
            timer.record(stopWatch.getTime(), TimeUnit.MILLISECONDS);
        }
    }

    /**
     * Calls {@link #reportMetrics()} and logs any failure instead of propagating it. This ensures a metrics problem
     * never prevents the outcome from being published or masks the request's own result.
     */
    private void reportMetricsSafely() {
        try {
            reportMetrics();
        } catch (Exception err) {
            logger.error("unable to report metrics for {}", operationName, err);
        }
    }

    /**
     * Converts a failure into an exceptional outcome and completes the future with it. Does nothing if the task is
     * already done.
     *
     * <p>
     * Effects:
     * <ul>
     * <li>Uses the root cause (or {@code err} itself when it has no cause) for the error tag.</li>
     * <li>For {@link ApplicationException}, also sets the error-code tag and the MDC error fields.</li>
     * <li>Clears the reply builder unless {@code isPreserveReply()} is true.</li>
     * <li>Logs at ERROR for non-application exceptions and for system/timeout statuses, otherwise at WARN.</li>
     * <li>Keeps notifications / event stream only when the transaction asks to preserve them.</li>
     * </ul>
     *
     * @param err the failure to report
     */
    protected void captureFailure(Throwable err) {
        synchronized (mutex) {
            if (toReturn.isDone()) {
                return;
            }

            Throwable cause = ExceptionUtils.getRootCause(err);
            if (Objects.isNull(cause)) {
                cause = err;
            }

            tags.add(Tag.of(MetricTags.ERROR, cause.getClass().getSimpleName()));

            boolean isApplicationError = cause instanceof ApplicationException;
            if (isApplicationError) {
                ApplicationException app = (ApplicationException) cause;
                String code = app.getCode().toString();
                tags.add(Tag.of(MetricTags.ERROR_CODE, code));

                MDC.put(MdcTags.MDC_ERROR_CLASS, app.getClass().getSimpleName());
                MDC.put(MdcTags.MDC_ERROR_CODE, code);
            }

            Message.Builder reply = transaction.reply();

            //
            // ~ we want to reset object to default instance (clear dirty state)
            //
            if (BooleanUtils.isFalse(transaction.isPreserveReply())) {
                reply.clear();
            }

            var respw = reqw.toExceptionalReply(reply.build(), cause);

            //
            // ~ we don't want to bombard sentry in unnecessary cases:
            // ~ application errors are posted as error only for system/timeout statuses, subject for later changes
            //
            boolean logAsFailure = respw.status().isSystem() || respw.status().isTimeout();
            if (isApplicationError && BooleanUtils.isFalse(logAsFailure)) {
                logger.warn(err.getMessage(), err);
            } else {
                logger.error(err.getMessage(), err);
            }

            var outcome = TransactionalRequestOutcome.builder();
            outcome.key(transaction.routingKey());
            outcome.reply(respw);

            //
            // ~ in some cases we would like to preserve notifications
            //
            if (transaction.isPreserveNotifications()) {
                registerNotifications(outcome);
            }
            if (transaction.isPreserveEventStream()) {
                registerEventStream(outcome);
            }

            future.set(outcome.build());
        }
    }

    /**
     * @return the request body's type URL
     */
    @Override
    public String toString() {
        return reqw.body().getTypeUrl();
    }

    /**
     * Sets {@code NonBlockingCallOnly.MARKER} to {@code TRUE} when the transaction has not been acked/nacked yet.
     *
     * <p>
     * NOTE(review): the marker is set here but never cleared in this class. If MARKER is thread-local, confirm that
     * it is reset elsewhere before a pooled thread is reused.
     */
    private void setGuards() {
        if (BooleanUtils.isFalse(transaction.isAckOrNack())) {
            NonBlockingCallOnly.MARKER.set(Boolean.TRUE);
        }
    }

    /**
     * Packs the transaction's notifications with the API notification mapper and attaches them to the outcome.
     *
     * @param outcome outcome builder to populate
     */
    protected void registerNotifications(TransactionalRequestOutcomeBuilder outcome) {
        synchronized (mutex) {
            outcome.notifications(transaction
                    .notifications()
                    .stream()
                    .map(it -> apiFactory.notificationMapper().pack(it))
                    .collect(Collectors.toUnmodifiableList()));
        }
    }

    /**
     * Builds each event-stream entry, packs it into {@link Any}, and attaches the list to the outcome.
     *
     * @param outcome outcome builder to populate
     */
    protected void registerEventStream(TransactionalRequestOutcomeBuilder outcome) {
        synchronized (mutex) {
            outcome.eventStream(transaction
                    .eventStreaming()
                    .stream()
                    .map(it -> Any.pack(it.build()))
                    .collect(Collectors.toUnmodifiableList()));
        }
    }
}
