package com.turbospaces.common;

import java.time.Instant;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.ScheduledFuture;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.function.BiConsumer;

import org.apache.commons.lang3.BooleanUtils;
import org.apache.commons.lang3.time.StopWatch;
import org.slf4j.MDC;
import org.springframework.scheduling.concurrent.ThreadPoolTaskScheduler;

import lombok.extern.slf4j.Slf4j;

/**
 * Correlates asynchronous requests with their replies by key.
 *
 * <p>
 * {@link #acquire(Object, int, TimeUnit)} registers a pending {@link CompletableFuture} under a correlation key and
 * schedules a timeout task on this scheduler. The future is completed by {@link #complete(Object, Object)} or
 * {@link #completeExceptionally(Object, Throwable)} when the reply arrives. If no reply arrives before the deadline,
 * the future is completed exceptionally with a {@code RequestReplyTimeout}.
 *
 * <p>
 * Pending futures are kept in a {@link ConcurrentHashMap}. Every completion path first removes the entry atomically,
 * so only the path that actually removed it completes the future.
 */
@SuppressWarnings("serial")
@Slf4j
public class CompletableRequestReplyMapper<K, V> extends ThreadPoolTaskScheduler implements RequestReplyMapper<K, V> {
    private final ConcurrentMap<K, CompletableFuture<V>> corr = new ConcurrentHashMap<>();

    /**
     * Creates a mapper backed by a daemon scheduler with one thread per available processor.
     * Cancelled timeout tasks are removed from the scheduler queue immediately.
     */
    public CompletableRequestReplyMapper() {
        super();

        setDaemon(true);
        setRemoveOnCancelPolicy(true);
        setPoolSize(Runtime.getRuntime().availableProcessors());
    }

    /**
     * Registers a pending reply for {@code key} and arms its timeout.
     *
     * @param key correlation id; must not be {@code null} and must not already be pending
     * @param timeout how long to wait for the reply, expressed in {@code unit}
     * @param unit unit of {@code timeout}; must not be {@code null}
     * @return the registered future. It is completed with the reply, completed exceptionally via
     *         {@link #completeExceptionally(Object, Throwable)}, or completed with a {@code RequestReplyTimeout}
     *         once the deadline passes.
     * @throws NullPointerException if {@code key} or {@code unit} is {@code null}
     * @throws IllegalStateException if {@code key} is already pending (duplicate correlation id)
     */
    @Override
    public CompletableFuture<V> acquire(K key, int timeout, TimeUnit unit) {
        // validate before scheduling anything, so a bad call cannot leave a stray timeout task behind
        Objects.requireNonNull(key, "key");
        Objects.requireNonNull(unit, "unit");

        Map<String, String> mdc = MDC.getCopyOfContextMap(); // ~ capture caller MDC for the timeout thread
        // honour the caller's unit: the timeout used to be interpreted as seconds regardless of 'unit'
        Instant deadline = Instant.now().plusNanos(unit.toNanos(timeout));
        ScheduledFuture<?> timeoutTask = schedule(() -> expire(key, timeout, unit, mdc), deadline);

        StopWatch stopWatch = StopWatch.createStarted();
        CompletableFuture<V> future = new RequestReplyFuture<V>(key, timeoutTask);
        // Attach the callback as a side effect on the future that is stored and returned. The stage returned by
        // whenComplete(...) is a different future; storing that one meant this future (and the callback) never
        // completed.
        future.whenComplete((resp, err) -> {
            timeoutTask.cancel(false); // reply arrived (or timed out): the timeout task is no longer needed
            if (Objects.isNull(err)) {
                stopWatch.stop();
                log.debug("request-reply(m={}) completed in {}", key, stopWatch);
            }
        });

        if (corr.putIfAbsent(key, future) != null) {
            timeoutTask.cancel(true);
            throw new IllegalStateException("duplicate key violation corrId: " + key);
        }

        return future;
    }

    /**
     * Timeout handler, run on a scheduler thread.
     *
     * <p>
     * If {@code key} is still pending, removes it and fails its future with a {@code RequestReplyTimeout}. The
     * caller's MDC snapshot is restored while doing so and cleared afterwards.
     */
    private void expire(K key, int timeout, TimeUnit unit, Map<String, String> mdc) {
        CompletableFuture<V> pending = corr.remove(key);
        if (pending == null || pending.isDone()) {
            return; // already completed or cleared: nothing to time out
        }

        if (mdc != null) {
            MDC.setContextMap(mdc); // ~ set MDC stuff back if necessary
        }

        try {
            TimeoutException timeoutException = new RequestReplyTimeout(timeout, unit, key);
            pending.completeExceptionally(timeoutException);
            log.debug("request-reply(m={}) removing subj due to timeout", key);
        } catch (Exception err) {
            log.error(err.getMessage(), err);
        } finally {
            MDC.clear(); // ~ scheduler threads are pooled: never leak the captured MDC
        }
    }

    /**
     * Completes the pending future for {@code key} with {@code value}.
     *
     * @param key correlation id; must not be {@code null}
     * @param value reply value
     * @return {@code true} if a pending entry existed and was completed, {@code false} otherwise
     *         (already completed, timed out, cleared, or never registered)
     */
    @Override
    public boolean complete(K key, V value) {
        CompletableFuture<V> subj = corr.remove(Objects.requireNonNull(key));
        if (subj != null) {
            subj.complete(value);
            return true;
        }

        log.trace("no such correlation for key: {}", key);
        return false;
    }

    /**
     * Completes the pending future for {@code key} exceptionally with {@code reason}.
     * Does nothing if no entry is pending for {@code key}.
     *
     * @param key correlation id; must not be {@code null}
     * @param reason failure to propagate to the waiting caller
     */
    @Override
    public void completeExceptionally(K key, Throwable reason) {
        CompletableFuture<V> subj = corr.remove(Objects.requireNonNull(key));
        if (subj != null) {
            subj.completeExceptionally(reason);
            return;
        }

        log.trace("no such correlation for key: {}", key);
    }

    /**
     * Drops every pending correlation.
     *
     * <p>
     * NOTE(review): the dropped futures are NOT completed. Their timeout tasks will later find no entry and do
     * nothing, so a caller still waiting on one is never released by this mapper.
     */
    @Override
    public void clear() {
        corr.clear();
    }

    /**
     * Returns the number of correlations currently registered and not yet completed through this mapper.
     *
     * @return the number of pending correlations
     */
    @Override
    public int pendingCount() {
        return corr.size();
    }
}
