package be.bagofwords.util;

import org.codehaus.jackson.map.ObjectMapper;
import org.codehaus.jackson.map.SerializationConfig;
import org.codehaus.jackson.type.JavaType;
import org.xerial.snappy.Snappy;

import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.io.UnsupportedEncodingException;
import java.util.Arrays;

/**
 * Static helpers for converting values to and from JSON strings and byte arrays.
 *
 * <p>Byte-array encoding used by {@link #objectToBytes}, {@link #bytesToObject} and the
 * {@code *CheckForNull} variants:
 * <ul>
 *     <li>{@code Long} / {@code Double}: 8 bytes, big-endian (doubles via their raw long bits);</li>
 *     <li>{@code Integer} / {@code Float}: 4 bytes, big-endian (floats via their raw int bits);</li>
 *     <li>{@code String}: UTF-8 bytes;</li>
 *     <li>anything else: Jackson JSON text, encoded as UTF-8.</li>
 * </ul>
 *
 * <p>The {@code *CheckForNull} variants reserve a sentinel per type ({@link #LONG_NULL}, {@link #DOUBLE_NULL},
 * {@link #INT_NULL}, {@link #FLOAT_NULL}, {@link #STRING_NULL}) to stand in for {@code null}; writing that
 * sentinel as a real value is rejected with a {@link RuntimeException}.
 */
public class SerializationUtils {

    /** Reserved long value that encodes {@code null} in {@link #objectToBytesCheckForNull}. */
    public static final long LONG_NULL = Long.MAX_VALUE - 3;
    /** Reserved double value that encodes {@code null} in {@link #objectToBytesCheckForNull}. */
    public static final double DOUBLE_NULL = Double.MAX_VALUE;
    /** Reserved int value that encodes {@code null} in {@link #objectToBytesCheckForNull}. */
    public static final int INT_NULL = Integer.MAX_VALUE;
    /** Reserved float value that encodes {@code null} in {@link #objectToBytesCheckForNull}. */
    public static final float FLOAT_NULL = Float.MAX_VALUE;
    /** Reserved string that encodes a {@code null} String in {@link #objectToBytesCheckForNull}. */
    public static final String STRING_NULL = "xyNUlLxy";
    private static final String ENCODING = "UTF-8";

    private static final ObjectMapper prettyPrintObjectMapper = new ObjectMapper();
    private static final ObjectMapper defaultObjectMapper = new ObjectMapper();

    static {
        // Only the pretty-print mapper indents; the default mapper keeps Jackson's compact output.
        prettyPrintObjectMapper.enable(SerializationConfig.Feature.INDENT_OUTPUT);
    }

    /**
     * Serializes an object to a compact JSON string.
     *
     * @param object the object to serialize (may be {@code null})
     * @return the JSON representation
     * @throws RuntimeException wrapping any {@link IOException} raised by Jackson
     */
    public static String serializeObject(Object object) {
        return serializeObject(object, false);
    }

    /**
     * Serializes an object to a JSON string. If the object implements {@code Compactable}, its
     * {@code compact()} method is called before serialization.
     *
     * @param object      the object to serialize (may be {@code null})
     * @param prettyPrint {@code true} for indented output, {@code false} for compact output
     * @return the JSON representation
     * @throws RuntimeException wrapping any {@link IOException} raised by Jackson
     */
    public static String serializeObject(Object object, boolean prettyPrint) {
        try {
            if (object instanceof Compactable) {
                ((Compactable) object).compact();
            }
            ObjectMapper objectMapper = prettyPrint ? prettyPrintObjectMapper : defaultObjectMapper;
            return objectMapper.writeValueAsString(object);
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Deserializes a JSON string into an instance of {@code objectClass}.
     *
     * @param object        the JSON text
     * @param objectClass   the target (raw) class
     * @param genericParams optional type arguments; when given, the target type is
     *                      {@code objectClass<genericParams...>} (e.g. {@code List<Foo>})
     * @return the deserialized value
     * @throws RuntimeException if parsing fails; the message contains at most the first 200
     *                          characters of the input
     */
    public static <T> T deserializeObject(String object, Class<T> objectClass, Class<?>... genericParams) {
        try {
            if (genericParams.length > 0) {
                JavaType type = defaultObjectMapper.getTypeFactory().constructParametricType(objectClass, genericParams);
                return defaultObjectMapper.readValue(object, type);
            } else {
                return defaultObjectMapper.readValue(object, objectClass);
            }
        } catch (IOException e) {
            // Truncate the payload so huge inputs do not flood the exception message.
            String objectForMessage = object;
            if (!StringUtils.isEmpty(objectForMessage) && objectForMessage.length() > 200) {
                objectForMessage = objectForMessage.substring(0, 200) + "...";
            }
            throw new RuntimeException("Failed to read " + objectForMessage, e);
        }
    }

    /**
     * Decodes UTF-8 bytes into a String.
     *
     * @param key the UTF-8 encoded bytes (must not be {@code null})
     * @return the decoded string
     */
    public static String bytesToString(byte[] key) {
        try {
            return new String(key, ENCODING);
        } catch (UnsupportedEncodingException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Encodes a String as UTF-8 bytes.
     *
     * @param key the string to encode (must not be {@code null})
     * @return the UTF-8 bytes
     * @throws RuntimeException wrapping an {@link OutOfMemoryError}, with a prefix of the string and
     *                          its length in the message to help diagnose oversized values
     */
    public static byte[] stringToBytes(String key) {
        try {
            return key.getBytes(ENCODING);
        } catch (UnsupportedEncodingException e) {
            throw new RuntimeException(e);
        } catch (OutOfMemoryError outOfMemoryError) {
            throw new RuntimeException("OOM while trying to convert " + key.substring(0, Math.min(100, key.length())) + " of length " + key.length(), outOfMemoryError);
        }
    }

    /**
     * Decodes bytes produced by {@link #objectToBytes(Object, Class)}.
     *
     * @param bytes       the encoded value; {@code null} yields {@code null}
     * @param objectClass the type the bytes were encoded from
     * @return the decoded value, or {@code null} if {@code bytes} is {@code null}
     */
    @SuppressWarnings("unchecked")
    public static <T> T bytesToObject(byte[] bytes, Class<T> objectClass) {
        if (bytes == null) {
            return null;
        }
        if (objectClass == Long.class) {
            return (T) Long.valueOf(bytesToLong(bytes));
        } else if (objectClass == Double.class) {
            return (T) Double.valueOf(Double.longBitsToDouble(bytesToLong(bytes)));
        } else if (objectClass == Integer.class) {
            return (T) Integer.valueOf(bytesToInt(bytes));
        } else if (objectClass == Float.class) {
            return (T) Float.valueOf(Float.intBitsToFloat(bytesToInt(bytes)));
        } else if (objectClass == String.class) {
            return (T) bytesToString(bytes);
        } else {
            String objectAsString = bytesToString(bytes);
            return SerializationUtils.deserializeObject(objectAsString, objectClass);
        }
    }

    /**
     * Encodes a value like {@link #objectToBytes(Object, Class)}, but a {@code null} Long, Double,
     * Integer, Float or String is written as that type's reserved sentinel. Other types are written as
     * JSON, so a {@code null} of another type becomes the JSON text {@code null}.
     *
     * @param value       the value to encode (may be {@code null})
     * @param objectClass the declared type of {@code value}
     * @return the encoded bytes
     * @throws RuntimeException if {@code value} equals its type's reserved null sentinel
     */
    public static <T> byte[] objectToBytesCheckForNull(T value, Class<T> objectClass) {
        if (objectClass == Long.class) {
            if (value == null) {
                return longToBytes(LONG_NULL);
            } else if (value.equals(LONG_NULL)) {
                throw new RuntimeException("Sorry " + value + " is a reserved value to indicate null.");
            } else {
                return longToBytes((Long) value);
            }
        } else if (objectClass == Double.class) {
            long valueAsLong;
            if (value == null) {
                valueAsLong = Double.doubleToLongBits(DOUBLE_NULL);
            } else if (value.equals(DOUBLE_NULL)) {
                throw new RuntimeException("Sorry " + value + " is a reserved value to indicate null");
            } else {
                valueAsLong = Double.doubleToLongBits((Double) value);
            }
            return longToBytes(valueAsLong);
        } else if (objectClass == Integer.class) {
            if (value == null) {
                return intToBytes(INT_NULL);
            } else if (value.equals(INT_NULL)) {
                throw new RuntimeException("Sorry " + value + " is a reserved value to indicate null.");
            } else {
                return intToBytes((Integer) value);
            }
        } else if (objectClass == Float.class) {
            int valueAsInt;
            if (value == null) {
                valueAsInt = Float.floatToIntBits(FLOAT_NULL);
            } else if (value.equals(FLOAT_NULL)) {
                throw new RuntimeException("Sorry " + value + " is a reserved value to indicate null");
            } else {
                valueAsInt = Float.floatToIntBits((Float) value);
            }
            return intToBytes(valueAsInt);
        } else if (objectClass == String.class) {
            if (value == null) {
                return stringToBytes(STRING_NULL);
            } else if (value.equals(STRING_NULL)) {
                throw new RuntimeException("Sorry " + value + " is a reserved value to indicate null");
            } else {
                return stringToBytes((String) value);
            }
        } else {
            return stringToBytes(SerializationUtils.serializeObject(value, false));
        }
    }

    /**
     * Decodes bytes produced by {@link #objectToBytesCheckForNull(Object, Class)}, mapping the
     * reserved sentinels back to {@code null}.
     *
     * @param value       the encoded bytes; {@code null} yields {@code null} (consistent with
     *                    {@link #bytesToObject(byte[], Class)})
     * @param objectClass the type the bytes were encoded from
     * @return the decoded value, or {@code null} for a sentinel or a {@code null} input
     */
    @SuppressWarnings("unchecked")
    public static <T> T bytesToObjectCheckForNull(byte[] value, Class<T> objectClass) {
        if (value == null) {
            // Previously this threw a NullPointerException while reading the array.
            return null;
        }
        if (objectClass == Long.class) {
            long response = bytesToLong(value);
            return response != LONG_NULL ? (T) Long.valueOf(response) : null;
        } else if (objectClass == Double.class) {
            double response = Double.longBitsToDouble(bytesToLong(value));
            return response != DOUBLE_NULL ? (T) Double.valueOf(response) : null;
        } else if (objectClass == Integer.class) {
            int response = bytesToInt(value);
            return response != INT_NULL ? (T) Integer.valueOf(response) : null;
        } else if (objectClass == Float.class) {
            float response = Float.intBitsToFloat(bytesToInt(value));
            return response != FLOAT_NULL ? (T) Float.valueOf(response) : null;
        } else {
            String objectAsString = bytesToString(value);
            if (STRING_NULL.equals(objectAsString)) {
                return null;
            } else if (objectClass == String.class) {
                return (T) objectAsString;
            } else {
                return SerializationUtils.deserializeObject(objectAsString, objectClass);
            }
        }
    }

    /**
     * Encodes a value as bytes (see the class documentation for the per-type format).
     *
     * @param value       the value; must not be {@code null} for Long/Double/Integer/Float (it is unboxed)
     * @param objectClass the declared type of {@code value}
     * @return the encoded bytes
     */
    public static <T> byte[] objectToBytes(T value, Class<T> objectClass) {
        if (objectClass == Long.class) {
            return longToBytes((Long) value);
        } else if (objectClass == Double.class) {
            return longToBytes(Double.doubleToLongBits((Double) value));
        } else if (objectClass == Integer.class) {
            return intToBytes((Integer) value);
        } else if (objectClass == Float.class) {
            return intToBytes(Float.floatToIntBits((Float) value));
        } else if (objectClass == String.class) {
            return stringToBytes((String) value);
        } else {
            return stringToBytes(SerializationUtils.serializeObject(value));
        }
    }

    /**
     * Encodes a value with {@link #objectToBytes(Object, Class)} and compresses the result with Snappy.
     *
     * @throws RuntimeException wrapping any {@link IOException} from Snappy
     */
    public static <T> byte[] objectToCompressedBytes(T value, Class<T> objectClass) {
        try {
            return Snappy.compress(objectToBytes(value, objectClass));
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Inverse of {@link #objectToCompressedBytes(Object, Class)}: Snappy-decompresses, then decodes with
     * {@link #bytesToObject(byte[], Class)}.
     *
     * @throws RuntimeException wrapping any {@link IOException} from Snappy
     */
    public static <T> T compressedBytesToObject(byte[] bytes, Class<T> objectClass) {
        try {
            return bytesToObject(Snappy.uncompress(bytes), objectClass);
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    /** Encodes a long as 8 big-endian bytes. */
    public static byte[] longToBytes(long value) {
        byte[] bytes = new byte[8];
        bytes[0] = (byte) (value >>> 56);
        bytes[1] = (byte) (value >>> 48);
        bytes[2] = (byte) (value >>> 40);
        bytes[3] = (byte) (value >>> 32);
        bytes[4] = (byte) (value >>> 24);
        bytes[5] = (byte) (value >>> 16);
        bytes[6] = (byte) (value >>> 8);
        bytes[7] = (byte) (value);
        return bytes;
    }

    /**
     * Decodes the first 8 bytes as a big-endian long. An array shorter than 8 bytes is zero-padded at
     * the end, i.e. the given bytes become the most significant ones.
     */
    public static long bytesToLong(byte[] bytes) {
        if (bytes.length < 8) {
            bytes = Arrays.copyOf(bytes, 8);
        }
        long result = (((long) bytes[0] << 56) +
                ((long) (bytes[1] & 255) << 48) +
                ((long) (bytes[2] & 255) << 40) +
                ((long) (bytes[3] & 255) << 32) +
                ((long) (bytes[4] & 255) << 24) +
                ((bytes[5] & 255) << 16) +
                ((bytes[6] & 255) << 8) +
                (bytes[7] & 255));
        return result;
    }

    /** Encodes an int as 4 big-endian bytes. */
    public static byte[] intToBytes(int value) {
        byte[] bytes = new byte[4];
        bytes[0] = (byte) ((value >>> 24));
        bytes[1] = (byte) ((value >>> 16));
        bytes[2] = (byte) ((value >>> 8));
        bytes[3] = (byte) ((value));
        return bytes;
    }

    /**
     * Decodes the first 4 bytes as a big-endian int. Unlike {@link #bytesToLong(byte[])}, short arrays
     * are not padded: fewer than 4 bytes causes an {@link ArrayIndexOutOfBoundsException}.
     */
    public static int bytesToInt(byte[] bytes) {
        int result = (bytes[0] << 24) +
                ((bytes[1] & 255) << 16) +
                ((bytes[2] & 255) << 8) +
                (bytes[3] & 255);
        return result;
    }

    /**
     * Writes an object as JSON to the given stream, calling {@code compact()} first on
     * {@code Compactable} objects.
     *
     * <p>Careful: not compatible with {@link #objectToBytes(Object, Class)} /
     * {@link #bytesToObject(byte[], Class)}. Here every value, including Long/Double/Integer/Float, is
     * written as JSON text rather than as fixed-width binary. Read it back with
     * {@link #readObject(Class, InputStream)}.
     *
     * <p>NOTE(review): Jackson's {@code ObjectMapper.writeValue(OutputStream, ...)} may close the target
     * stream depending on its AUTO_CLOSE_TARGET setting — confirm callers expect that.
     *
     * @throws RuntimeException wrapping any {@link IOException}
     */
    public static void writeObject(Object object, OutputStream outputStream) {
        try {
            if (object instanceof Compactable) {
                ((Compactable) object).compact();
            }
            defaultObjectMapper.writeValue(outputStream, object);
        } catch (IOException exp) {
            throw new RuntimeException("Failed to write object to outputstream", exp);
        }
    }

    /**
     * Reads one JSON value from the stream, as written by {@link #writeObject(Object, OutputStream)}.
     *
     * <p>Careful: not compatible with bytes produced by {@link #objectToBytes(Object, Class)}.
     *
     * @throws RuntimeException wrapping any {@link IOException}
     */
    public static <T> T readObject(Class<T> _class, InputStream inputStream) {
        try {
            return defaultObjectMapper.readValue(inputStream, _class);
        } catch (IOException exp) {
            throw new RuntimeException("Failed to read object from inputstream", exp);
        }
    }

    /**
     * Returns the fixed encoded width in bytes for the given class: 8 for Long/Double, 4 for
     * Integer/Float, and -1 for variable-width (String/JSON-encoded) types.
     */
    public static <T> int getWidth(Class<T> objectClass) {
        if (objectClass == Long.class || objectClass == Double.class) {
            return 8;
        } else if (objectClass == Integer.class || objectClass == Float.class) {
            return 4;
        } else {
            return -1;
        }
    }
}
