package com.turbospaces.common;

import java.time.Instant;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.ScheduledFuture;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.BiConsumer;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.slf4j.MDC;
import org.springframework.beans.factory.DisposableBean;
import org.springframework.beans.factory.InitializingBean;
import org.springframework.scheduling.concurrent.ThreadPoolTaskScheduler;

/**
 * {@link RequestReplyMapper} that tracks pending requests as {@link RequestReplyFuture}s keyed by correlation id.
 * <p>
 * Each request registered via {@link #acquire} gets its own timeout task on a private daemon
 * {@link ThreadPoolTaskScheduler}. That scheduler is started in {@link #afterPropertiesSet()} and destroyed in
 * {@link #destroy()}.
 * <p>
 * Pending requests live in a {@link ConcurrentHashMap}. Every completion path (reply, failure, timeout) first
 * removes the entry atomically and only then completes the future. As a result, at most one path takes ownership
 * of completing a given request.
 */
public class CompletableRequestReplyMapper<K, V> implements RequestReplyMapper<K, V>, InitializingBean, DisposableBean {
    private final Logger logger = LoggerFactory.getLogger( getClass() );
    /** Pending requests by correlation key. */
    private final ConcurrentMap<K, RequestReplyFuture<V>> corr = new ConcurrentHashMap<>();
    /** Daemon scheduler that fires per-request timeouts. */
    private final ThreadPoolTaskScheduler timer = new ThreadPoolTaskScheduler();

    public CompletableRequestReplyMapper() {
        super();

        timer.setDaemon( true );
        // cancelled timeout tasks are dropped from the work queue right away instead of waiting for their deadline
        timer.setRemoveOnCancelPolicy( true );
        timer.setPoolSize( Runtime.getRuntime().availableProcessors() );
    }
    /** Initializes (starts) the timeout scheduler. */
    @Override
    public void afterPropertiesSet() {
        timer.afterPropertiesSet();
    }
    /** Destroys the timeout scheduler; timeouts that have not fired yet will not be delivered by it afterwards. */
    @Override
    public void destroy() {
        timer.destroy();
    }
    /**
     * Registers a new pending request under {@code key} and arms its timeout.
     * <p>
     * If the request is still registered when the timeout elapses, it is removed and its future is completed
     * exceptionally with a {@link RequestReplyTimeout}. The caller's MDC is captured here and restored on the
     * scheduler thread while the timeout is delivered.
     *
     * @param key correlation id; must not already be pending
     * @param timeout how long to wait for a reply, expressed in {@code unit}
     * @param unit time unit of {@code timeout}
     * @return a future completed by {@link #complete}, {@link #completeExceptionally}, or the timeout
     * @throws IllegalStateException if a request with the same {@code key} is already pending
     */
    @Override
    public RequestReplyFuture<V> acquire(K key, int timeout, TimeUnit unit) {
        Map<String, String> mdc = MDC.getCopyOfContextMap(); // ~ capture MDC (null when the caller has none)

        // The future needs its timeout task at construction time, so the task reaches "its" future through this
        // holder. It then expires exactly that future and never another request that reuses or duplicates the key.
        AtomicReference<RequestReplyFuture<V>> self = new AtomicReference<>();
        // honour the caller's unit (the value must not be interpreted as seconds unconditionally)
        Instant deadline = Instant.now().plusNanos( unit.toNanos( timeout ) );
        ScheduledFuture<?> timeoutTask = timer.schedule( () -> expireOnTimeout( key, self.get(), timeout, unit, mdc ), deadline );

        RequestReplyFuture<V> toReturn = new RequestReplyFuture<V>( timeoutTask );
        self.set( toReturn );

        long startNanos = System.nanoTime(); // monotonic clock for measuring latency
        BiConsumer<V, Throwable> logLatency = (resp, err) -> {
            if ( err == null ) {
                logger.debug( "request-reply(m={}) completed in {} 'ms'", key, TimeUnit.NANOSECONDS.toMillis( System.nanoTime() - startNanos ) );
            }
        };
        toReturn.whenComplete( logLatency );

        if ( corr.putIfAbsent( key, toReturn ) != null ) {
            timeoutTask.cancel( true );
            throw new IllegalStateException( "duplicate key violation corrId: " + key );
        }

        return toReturn;
    }
    /**
     * Timeout handler, run on the scheduler thread.
     * <p>
     * It times out {@code pending} only if that exact future is still registered under {@code key}. It does
     * nothing when:
     * <ul>
     * <li>{@code pending} is null, meaning the task fired before {@link #acquire} finished building it;</li>
     * <li>another path (reply, failure, {@link #clear()}) already removed it;</li>
     * <li>it is already done.</li>
     * </ul>
     */
    private void expireOnTimeout(K key, RequestReplyFuture<V> pending, int timeout, TimeUnit unit, Map<String, String> mdc) {
        if ( pending == null || !corr.remove( key, pending ) || pending.isDone() ) {
            return;
        }
        if ( mdc != null ) {
            MDC.setContextMap( mdc ); // ~ set MDC stuff back if necessary
        }
        try {
            TimeoutException timeoutException = new RequestReplyTimeout( timeout, unit, key );
            pending.completeExceptionally( timeoutException );
            logger.debug( "request-reply(m={}) removing subj due to timeout", key );
        }
        finally {
            MDC.clear(); // ~ scheduler threads are pooled; never leak the captured MDC into the next task
        }
    }
    /**
     * Delivers {@code value} to the request pending under {@code key}, removing it from the registry.
     * An unknown or already-completed key is logged at TRACE level and otherwise ignored.
     *
     * @throws NullPointerException if {@code key} is null
     */
    @Override
    public void complete(K key, V value) {
        CompletableFuture<V> subj = corr.remove( Objects.requireNonNull( key ) );
        if ( subj != null ) {
            subj.complete( value );
        }
        else {
            logger.trace( "no such correlation for key: {}", key );
        }
    }
    /**
     * Fails the request pending under {@code key} with {@code reason}, removing it from the registry.
     * Unknown keys are silently ignored.
     *
     * @throws NullPointerException if {@code key} is null
     */
    @Override
    public void completeExceptionally(K key, Throwable reason) {
        CompletableFuture<V> subj = corr.remove( Objects.requireNonNull( key ) );
        if ( subj != null ) {
            subj.completeExceptionally( reason );
        }
    }
    /**
     * Fails every currently pending request with {@code reason} and removes it from the registry.
     * As with the keyed variants, removal happens first, so a failed request no longer counts as pending.
     */
    @Override
    public void completeExceptionally(Throwable reason) {
        for ( Map.Entry<K, RequestReplyFuture<V>> entry : corr.entrySet() ) {
            // skip entries a concurrent reply/timeout removed first; that path owns their completion
            if ( corr.remove( entry.getKey(), entry.getValue() ) ) {
                entry.getValue().completeExceptionally( reason );
            }
        }
    }
    /**
     * Drops every pending request from the registry <em>without</em> completing it.
     * A timeout only completes a future it can still remove from the registry. Futures dropped here are therefore
     * never completed by this mapper, and callers waiting on them must rely on their own timeouts.
     */
    @Override
    public void clear() {
        corr.clear();
    }
    /** @return the number of requests currently registered and awaiting completion */
    @Override
    public int pendingCount() {
        return corr.size();
    }
}
