package com.turbospaces.executor;

import java.util.Map;
import java.util.Map.Entry;
import java.util.Objects;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.TimeUnit;

import org.apache.commons.lang3.exception.ExceptionUtils;
import org.slf4j.MDC;

import com.google.common.cache.CacheBuilder;
import com.google.common.cache.CacheLoader;
import com.google.common.cache.LoadingCache;
import com.turbospaces.boot.AbstractBootstrapAware;
import com.turbospaces.boot.BootstrapAware;

import io.netty.util.AsciiString;
import io.vavr.CheckedRunnable;

/**
 * {@link ContextWorker} backed by a shared {@link ExecutorService}.
 *
 * <p>{@link #schedule(CheckedRunnable)} submits commands directly to the executor. The caller's SLF4J
 * {@link MDC} is propagated to the executing thread, and the executing thread's own MDC is restored afterwards.
 * {@link #forKey(WorkUnit)} hands out a {@link SerialContextWorker} per work-unit key. These are cached and
 * expire one hour after last access.
 */
public class ThreadPoolContextWorker extends AbstractBootstrapAware implements ContextWorker {
    /** Per-key workers, lazily created on first lookup and evicted 1 hour after last access. */
    private final LoadingCache<AsciiString, SerialContextWorker> executors;
    private final ExecutorService executor;

    /**
     * Creates a worker that runs all tasks, including those of the per-key workers, on {@code executor}.
     *
     * @param executor the executor that runs submitted commands; must not be {@code null}
     * @throws NullPointerException if {@code executor} is {@code null}
     */
    public ThreadPoolContextWorker(ExecutorService executor) {
        this.executor = Objects.requireNonNull(executor, "executor");
        this.executors = CacheBuilder.newBuilder()
                // NOTE(review): if an entry expires while its SerialContextWorker still has queued work,
                // the next forKey() for that key builds a second instance. Confirm SerialContextWorker tolerates that.
                .expireAfterAccess(1, TimeUnit.HOURS)
                .build(new CacheLoader<AsciiString, SerialContextWorker>() {
                    @Override
                    public SerialContextWorker load(AsciiString key) {
                        SerialContextWorker serial = new SerialContextWorker(key, executor);
                        // 'bootstrap' is read at creation time. Workers created before a later
                        // setBootstrap() call keep the value they were built with.
                        serial.setBootstrap(bootstrap);
                        return serial;
                    }
                });
    }

    /**
     * Returns the executor that runs submitted commands.
     *
     * @return the executor passed to the constructor
     */
    @Override
    public ExecutorService executor() {
        return executor;
    }

    /**
     * Returns the cached {@link SerialContextWorker} for {@code unit.key()}, creating it on first use.
     *
     * @param unit the work unit whose key selects the worker
     * @return the per-key worker; repeated calls with an equal key return the same instance while it stays cached
     */
    @Override
    public ContextWorker forKey(WorkUnit unit) {
        AsciiString partitionKey = new AsciiString(unit.key());
        return executors.getUnchecked(partitionKey);
    }

    /**
     * Submits {@code command} to the executor.
     *
     * <p>If the command is {@link BootstrapAware}, it receives this worker's bootstrap before submission. The
     * submitting thread's MDC is captured here and installed on the executing thread for the duration of the
     * command. Afterwards the executing thread's previous MDC is restored, so nothing the command put into
     * MDC leaks into later tasks on the same pooled thread.
     *
     * <p>A failure is logged at error level and then rethrown via {@link ExceptionUtils#wrapAndThrow(Throwable)}
     * on the executing thread.
     *
     * @param command the command to run; must not be {@code null}
     * @throws NullPointerException if {@code command} is {@code null}
     */
    @Override
    public void schedule(CheckedRunnable command) {
        Objects.requireNonNull(command, "command");
        if (command instanceof BootstrapAware) {
            ((BootstrapAware) command).setBootstrap(bootstrap);
        }

        // Capture on the submitting thread; null means "no MDC".
        Map<String, String> mdc = MDC.getCopyOfContextMap();

        executor.execute(new Runnable() {
            @Override
            public void run() {
                // Remember whatever the pooled thread carried, so it is restored exactly afterwards.
                Map<String, String> previous = MDC.getCopyOfContextMap();
                installMdc(mdc);
                try {
                    logger.debug("before apply: {}", command);
                    command.run();
                } catch (Throwable err) {
                    logger.error(err.getMessage(), err);
                    ExceptionUtils.wrapAndThrow(err);
                } finally {
                    installMdc(previous);
                }
            }
        });
    }

    /**
     * Replaces the current thread's entire MDC with {@code context}, or clears it when {@code context} is
     * {@code null}.
     */
    private static void installMdc(Map<String, String> context) {
        if (context == null) {
            MDC.clear();
        } else {
            MDC.setContextMap(context);
        }
    }
}
