package com.turbospaces.common;

import java.time.Duration;
import java.time.Instant;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.CancellationException;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ScheduledFuture;

import org.apache.commons.lang3.BooleanUtils;
import org.apache.commons.lang3.time.StopWatch;
import org.slf4j.MDC;
import org.springframework.scheduling.concurrent.ThreadPoolTaskScheduler;
import org.springframework.util.ConcurrentReferenceHashMap;

import com.google.common.util.concurrent.MoreExecutors;
import com.google.common.util.concurrent.SettableFuture;

import lombok.extern.slf4j.Slf4j;

/**
 * Correlates asynchronous replies with the requests that are waiting for them.
 * <p>
 * {@link #acquire(Object, Duration)} registers a {@link SettableFuture} under a correlation key and schedules a
 * timeout task on this scheduler. {@link #complete(Object, Object)} and
 * {@link #completeExceptionally(Object, Throwable)} resolve the registered future when the reply arrives. If no
 * reply arrives within the requested duration, the future fails with {@link RequestReplyTimeout}.
 * <p>
 * Pending correlations are kept in a strongly-referenced {@link ConcurrentHashMap}. A soft/weak-reference map would
 * let the garbage collector silently drop in-flight requests under memory pressure, after which replies could no
 * longer be matched. Null keys are rejected.
 */
@SuppressWarnings("serial")
@Slf4j
public class CompletableRequestReplyMapper<K, V> extends ThreadPoolTaskScheduler implements RequestReplyMapper<K, V> {
    private final ConcurrentMap<K, SettableFuture<V>> corr = new ConcurrentHashMap<>();

    /**
     * Creates a mapper backed by a daemon scheduler with one thread per available processor. Cancelled timeout tasks
     * are removed from the scheduler queue immediately.
     */
    public CompletableRequestReplyMapper() {
        super();

        setDaemon(true);
        setRemoveOnCancelPolicy(true);
        setPoolSize(Runtime.getRuntime().availableProcessors());
    }

    /**
     * Registers a new pending request under {@code key} and arms its timeout.
     *
     * @param key correlation id; must not be {@code null}
     * @param duration how long to wait for a reply before the future fails with {@link RequestReplyTimeout}; must not
     *        be {@code null}
     * @return a future that is completed by {@link #complete}, {@link #completeExceptionally} or the timeout; if
     *         {@code key} is already pending, a future already failed with {@link IllegalArgumentException}
     * @throws NullPointerException if {@code key} or {@code duration} is {@code null}
     */
    @Override
    public SettableFuture<V> acquire(K key, Duration duration) {
        // ~ validate before registering, otherwise a failure below would leave an orphaned map entry
        Objects.requireNonNull(key);
        Objects.requireNonNull(duration);

        Map<String, String> mdc = MDC.getCopyOfContextMap(); // ~ capture MDC
        StopWatch stopWatch = StopWatch.createStarted();
        SettableFuture<V> toReturn = SettableFuture.create();
        if (Objects.nonNull(corr.putIfAbsent(key, toReturn))) {
            toReturn.setException(new IllegalArgumentException("duplicate key violation for correlation id: " + key));
            return toReturn;
        }

        ScheduledFuture<?> timerTask;
        try {
            timerTask = schedule(() -> onTimeout(key, toReturn, duration, mdc), Instant.now().plus(duration));
        } catch (RuntimeException err) {
            // ~ scheduler not initialized / shut down, or deadline overflow: nothing would ever time this entry out
            corr.remove(key, toReturn);
            throw err;
        }

        toReturn.addListener(() -> onCompleted(key, toReturn, timerTask, stopWatch), MoreExecutors.directExecutor());
        return toReturn;
    }

    /**
     * Timeout task: fails {@code future} with {@link RequestReplyTimeout}, but only while it is still the future
     * registered under {@code key}.
     */
    private void onTimeout(K key, SettableFuture<V> future, Duration duration, Map<String, String> mdc) {
        // ~ conditional remove: if complete() already claimed the entry, or a newer request now uses the same key,
        // this (stale) timer must not touch it
        if (corr.remove(key, future)) {
            try {
                MDCUtil.propagete(mdc);
                // ~ setException returns false when the future is already done, so only complete when necessary
                if (future.setException(new RequestReplyTimeout(duration, key))) {
                    log.info("request-reply(m={}) removed subj due to timeout", key);
                }
            } catch (Exception err) {
                log.error(err.getMessage(), err);
            } finally {
                MDC.clear();
            }
        }
    }

    /**
     * Completion listener: cancels the pending timeout, unregisters {@code future} and logs how the request ended.
     * Runs once, for every outcome (value, failure or caller-side cancellation).
     */
    private void onCompleted(K key, SettableFuture<V> future, ScheduledFuture<?> timerTask, StopWatch stopWatch) {
        log.trace("about to cancel timeout task: {} for key: {}", timerTask, key);
        timerTask.cancel(false);

        // ~ unconditional on outcome so the entry can never leak, but conditional on identity so a newer request
        // registered under the same key is not evicted
        corr.remove(key, future);
        stopWatch.stop();

        try {
            future.get(); // ~ the future is already done here, so this never blocks
            log.debug("request-reply(m={}) completed in {}", key, stopWatch);
        } catch (ExecutionException | CancellationException err) {
            log.debug("request-reply(m={}) failed in {}", key, stopWatch, err);
        } catch (InterruptedException err) {
            Thread.currentThread().interrupt();
        }
    }

    /**
     * Completes the pending request registered under {@code key} with {@code value}.
     *
     * @param key correlation id; must not be {@code null}
     * @param value the reply ({@code null} allowed)
     * @return {@code true} if a pending request was found and this call delivered the value, {@code false} otherwise
     */
    @Override
    public boolean complete(K key, V value) {
        Objects.requireNonNull(key);
        SettableFuture<V> subj = corr.remove(key);
        if (Objects.nonNull(subj)) {
            // ~ report whether the value actually took effect (the future may already have been cancelled)
            return subj.set(value);
        }

        log.trace("no such correlation for key: {}", key);
        return false;
    }

    /**
     * Fails the pending request registered under {@code key} with {@code reason}; does nothing if none is pending.
     *
     * @param key correlation id; must not be {@code null}
     * @param reason failure to deliver; must not be {@code null}
     */
    @Override
    public void completeExceptionally(K key, Throwable reason) {
        Objects.requireNonNull(key);
        // ~ checked before removal: removing first and then failing in setException(null) would orphan the future
        Objects.requireNonNull(reason);
        SettableFuture<V> subj = corr.remove(key);
        if (Objects.nonNull(subj)) {
            subj.setException(reason);
        }
    }

    /**
     * Drops all pending correlations.
     * <p>
     * NOTE(review): the removed futures are not completed here, and their timeout tasks no longer find them in the
     * map, so callers waiting on them without their own timeout will wait indefinitely. Confirm this is intended.
     */
    @Override
    public void clear() {
        corr.clear();
    }

    /**
     * @return the number of requests currently registered and awaiting a reply
     */
    @Override
    public int pendingCount() {
        return corr.size();
    }
}
