package com.turbospaces.ebean;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;
import java.util.Objects;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.TimeUnit;

import org.apache.commons.lang3.SerializationUtils;
import org.apache.commons.lang3.exception.ExceptionUtils;
import org.apache.commons.lang3.time.StopWatch;
import org.jgroups.blocks.MethodCall;
import org.jgroups.blocks.RequestOptions;
import org.jgroups.blocks.RpcDispatcher;
import org.jgroups.util.RspList;

import com.google.common.util.concurrent.FluentFuture;
import com.google.common.util.concurrent.FutureCallback;
import com.google.common.util.concurrent.MoreExecutors;
import com.turbospaces.boot.AbstractBootstrapAware;
import com.turbospaces.boot.Bootstrap;

import io.ebean.cache.ServerCacheConfig;
import io.ebean.cache.ServerCacheStatistics;
import io.ebean.cache.ServerCacheType;
import io.ebeaninternal.server.cache.CachedBeanData;
import io.ebeaninternal.server.cache.CachedManyIds;
import io.vavr.CheckedFunction0;

/**
 * Ebean server cache that stores its entries in a {@link LocalCache} and broadcasts every local mutation
 * ({@link #put(Object, Object)}, {@link #remove(Object)}, {@link #clear()}) to the cluster through a JGroups
 * {@link RpcDispatcher}.
 *
 * <p>
 * Reads ({@link #get(Object)}, {@link #size()}, statistics) are served from the local cache only. The
 * {@code on*} methods apply an already serialized mutation to the local cache without broadcasting it again;
 * presumably they are invoked by the cache manager when another node replicates a change - verify against
 * {@code JGroupsCacheManager}.
 */
public class ReplicatedEbeanCache extends AbstractBootstrapAware implements ReplicatedCache {
    private final String cacheKey;
    private final RpcDispatcher dispatcher;
    private final LocalCache local;
    private final ServerCacheConfig config;

    /**
     * @param cacheKey name of the cache, sent as the first argument of every replicated call
     * @param dispatcher JGroups dispatcher used to invoke the replication methods remotely
     * @param local cache holding the entries of this node
     * @param config Ebean cache configuration; its type selects the value wire format
     * @throws NullPointerException if any argument is {@code null}
     */
    public ReplicatedEbeanCache(String cacheKey, RpcDispatcher dispatcher, LocalCache local, ServerCacheConfig config) {
        this.cacheKey = Objects.requireNonNull(cacheKey);
        this.dispatcher = Objects.requireNonNull(dispatcher);
        this.local = Objects.requireNonNull(local);
        this.config = Objects.requireNonNull(config);
    }
    /** @return number of entries reported by the local cache */
    @Override
    public int size() {
        return local.size();
    }
    /** @return the value the local cache holds for {@code id}; remote nodes are never queried */
    @Override
    public Object get(Object id) {
        return local.get(id);
    }
    /** @return hit ratio as reported by the local cache */
    @Override
    public int hitRatio() {
        return local.hitRatio();
    }
    /**
     * @param reset forwarded unchanged to the local cache
     * @return statistics of the local cache
     */
    @Override
    public ServerCacheStatistics statistics(boolean reset) {
        return local.statistics(reset);
    }
    /**
     * @param bootstrap bootstrap whose global platform executes the replication tasks
     * @throws NullPointerException if {@code bootstrap} is {@code null}
     */
    @Override
    public void setBootstrap(Bootstrap bootstrap) {
        this.bootstrap = Objects.requireNonNull(bootstrap);
    }
    /**
     * Stores the entry locally and then broadcasts it to the cluster without waiting for the broadcast.
     *
     * <p>
     * Key and value are serialized on the calling thread, so a serialization problem is thrown to the caller.
     * The local write has already happened at that point and is not rolled back.
     *
     * @param id cache key, must be {@link Serializable}
     * @param value value to cache; its expected class depends on the configured cache type
     * @throws IllegalArgumentException if the configured cache type is not supported
     */
    @Override
    public void put(Object id, Object value) {
        local.put(id, value);

        byte[] keyAsBytes = SerializationUtils.serialize((Serializable) id);
        byte[] valueAsBytes = serializeValue(value);

        //
        // ~ we should not block main thread as such
        //
        StopWatch stopWatch = StopWatch.createStarted();
        broadcast(
                "put",
                stopWatch,
                () -> new MethodCall(JGroupsCacheManager.METHOD_ON_CACHE_PUT, cacheKey, keyAsBytes, valueAsBytes),
                () -> logger.trace("putted {} entry on remote nodes by key: {} value: {} in: {}", cacheKey, id, value, stopWatch));
    }
    /**
     * Removes the entry locally and then broadcasts the removal to the cluster without waiting for it.
     *
     * @param id cache key, must be {@link Serializable}
     */
    @Override
    public void remove(Object id) {
        local.remove(id);

        //
        // ~ we should not block main thread
        //
        StopWatch stopWatch = StopWatch.createStarted();
        broadcast(
                "remove",
                stopWatch,
                () -> {
                    // the key is serialized inside the submitted task, so a failure here is logged
                    // by the failure callback instead of being thrown to the caller
                    byte[] keyAsBytes = SerializationUtils.serialize((Serializable) id);
                    return new MethodCall(JGroupsCacheManager.METHOD_ON_CHANGE_REMOVE, cacheKey, keyAsBytes);
                },
                () -> logger.debug("removed {} entry on remote nodes by key: {} in: {}", cacheKey, id, stopWatch));
    }
    /**
     * Clears the local cache and then broadcasts the clear request to the cluster without waiting for it.
     */
    @Override
    public void clear() {
        local.clear();

        //
        // ~ we should not block main thread as such
        //
        StopWatch stopWatch = StopWatch.createStarted();
        broadcast(
                "clear",
                stopWatch,
                () -> new MethodCall(JGroupsCacheManager.METHOD_ON_CACHE_CLEAR, cacheKey),
                () -> logger.info("cleared {} on remote nodes in: {}", cacheKey, stopWatch));
    }
    /**
     * Applies a replicated put to the local cache only.
     *
     * <p>
     * NOTE(review): key and value are restored with Java native deserialization; this presumes the bytes
     * originate from trusted cluster members - confirm the cluster transport is not reachable by untrusted
     * peers.
     *
     * @param keyData key serialized the way {@link #put(Object, Object)} serializes it
     * @param valueData value encoded according to the configured cache type
     * @throws IllegalArgumentException if the configured cache type is not supported
     */
    @Override
    public void onPut(byte[] keyData, byte[] valueData) {
        Object key = SerializationUtils.deserialize(keyData);
        ByteArrayInputStream is = new ByteArrayInputStream(valueData);

        switch (config.getType()) {
            case NATURAL_KEY: {
                // natural key values travel as plain serialized objects
                try (ObjectInputStream ois = new ObjectInputStream(is)) {
                    Object read = ois.readObject();

                    logger.trace("onPut {} by key: {} value: {}", cacheKey, key, read);
                    local.put(key, read);
                } catch (ClassNotFoundException | IOException err) {
                    ExceptionUtils.wrapAndThrow(err);
                }

                break;
            }
            case BEAN: {
                // bean data travels in its own external form and is read back into a fresh instance
                try (ObjectInputStream ois = new ObjectInputStream(is)) {
                    CachedBeanData read = new CachedBeanData();
                    read.readExternal(ois);

                    logger.trace("onPut {} by key: {} bean value: {}", cacheKey, key, read);
                    local.put(key, read);
                } catch (ClassNotFoundException | IOException err) {
                    ExceptionUtils.wrapAndThrow(err);
                }

                break;
            }
            case COLLECTION_IDS: {
                // collection ids use the same external-form approach as bean data
                try (ObjectInputStream ois = new ObjectInputStream(is)) {
                    CachedManyIds read = new CachedManyIds();
                    read.readExternal(ois);

                    logger.trace("onPut {} by key: {} ids value: {}", cacheKey, key, read);
                    local.put(key, read);
                } catch (ClassNotFoundException | IOException err) {
                    ExceptionUtils.wrapAndThrow(err);
                }

                break;
            }
            default: {
                throw new IllegalArgumentException("unexpected cache type: " + config.getType());
            }
        }
    }
    /**
     * Applies a replicated removal to the local cache only.
     *
     * @param data key serialized the way {@link #remove(Object)} serializes it
     */
    @Override
    public void onRemove(byte[] data) {
        Object key = SerializationUtils.deserialize(data);

        logger.trace("onRemove {} by key: {}", cacheKey, key);
        local.remove(key);
    }
    /**
     * Applies a replicated clear to the local cache only.
     *
     * @return the local cache size read immediately before it was cleared
     */
    @Override
    public int onClear() {
        logger.info("onClear {}", cacheKey);
        int size = local.size();
        local.clear();
        return size;
    }
    /** Delegates clean-up to the local cache. */
    @Override
    public void cleanUp() {
        local.cleanUp();
    }
    /** @return the cache type taken from the Ebean cache configuration */
    @Override
    public ServerCacheType type() {
        return config.getType();
    }

    /**
     * Encodes a cache value in the wire format matching the configured cache type: plain object serialization
     * for natural keys, the external form of {@link CachedBeanData} / {@link CachedManyIds} otherwise.
     *
     * @param value value handed to {@link #put(Object, Object)}
     * @return encoded value
     * @throws IllegalArgumentException if the configured cache type is not supported
     */
    private byte[] serializeValue(Object value) {
        ByteArrayOutputStream out = new ByteArrayOutputStream();

        try (ObjectOutputStream oos = new ObjectOutputStream(out)) {
            switch (config.getType()) {
                case NATURAL_KEY: {
                    oos.writeObject(value);
                    break;
                }
                case BEAN: {
                    CachedBeanData data = (CachedBeanData) value;
                    data.writeExternal(oos);
                    break;
                }
                case COLLECTION_IDS: {
                    CachedManyIds data = (CachedManyIds) value;
                    data.writeExternal(oos);
                    break;
                }
                default: {
                    throw new IllegalArgumentException("Unexpected cache type: " + config.getType());
                }
            }
            oos.flush();
        } catch (IOException err) {
            ExceptionUtils.wrapAndThrow(err);
        }

        // read the buffer only after the object stream is closed so everything is flushed into it
        return out.toByteArray();
    }
    /**
     * Submits a task to the bootstrap's global platform that builds the method call and dispatches it to the
     * cluster with asynchronous request options, then logs the outcome once that task has finished.
     *
     * <p>
     * The callback reacts to the completion of the submitted task; the future returned by the dispatcher
     * itself is not inspected.
     *
     * @param operation short operation name used in the slow/failure log messages
     * @param stopWatch already started watch; stopped as soon as the call has been handed to the dispatcher
     * @param callFactory builds the method call inside the submitted task
     * @param onReplicated logs the operation specific success message
     */
    @SuppressWarnings("serial")
    private void broadcast(String operation, StopWatch stopWatch, CheckedFunction0<MethodCall> callFactory, Runnable onReplicated) {
        FluentFuture.from(bootstrap.globalPlatform().submit(new CheckedFunction0<CompletableFuture<RspList<Object>>>() {
            @Override
            public CompletableFuture<RspList<Object>> apply() throws Throwable {
                MethodCall call = callFactory.apply();
                CompletableFuture<RspList<Object>> future = dispatcher.callRemoteMethodsWithFuture(null, call, RequestOptions.ASYNC());
                stopWatch.stop();
                return future;
            }
        })).addCallback(new FutureCallback<CompletableFuture<RspList<Object>>>() {
            @Override
            public void onSuccess(CompletableFuture<RspList<Object>> result) {
                onReplicated.run();
                // anything that needed a full second or more just to be dispatched is reported loudly
                long time = stopWatch.getTime(TimeUnit.SECONDS);
                if (time > 0) {
                    logger.error("{} operation took too long: {}", operation, stopWatch);
                }
            }
            @Override
            public void onFailure(Throwable t) {
                // name the operation and cache, the bare exception message alone may be null or ambiguous
                logger.error("{} replication failed for cache {}: {}", operation, cacheKey, t.getMessage(), t);
            }
        }, MoreExecutors.directExecutor());
    }
}
