package in.sourceshift.genericmodules.securityutils;

import java.nio.charset.StandardCharsets;
import java.security.NoSuchAlgorithmException;
import java.security.SecureRandom;

import in.sourceshift.genericmodules.securityutils.exception.SecurityUtilsException;
import in.sourceshift.genericmodules.securityutils.hashalgorithms.BCrypt;
import in.sourceshift.genericmodules.securityutils.hashalgorithms.HashType;
import in.sourceshift.genericmodules.securityutils.hashalgorithms.PBKDF2;
import in.sourceshift.genericmodules.securityutils.hashalgorithms.SCrypt;

/**
 * Fluent builder for creating and verifying password hashes.
 *
 * <p>Start with {@link #password(char[])}, optionally configure the algorithm, hash/salt lengths,
 * work factor, a user-supplied salt and a pepper, then call {@link #create()} to produce a
 * storable string, or {@link #verify(String)} to check the password against a stored string.
 *
 * <p>Stored format: eight colon-separated fields,
 * {@code algorithm:factor:hashLength:saltLength:usersaltFlag:salt:hash:pepperFlag}. For bcrypt and
 * scrypt the salt field is left empty (presumably the salt travels inside those algorithms' own
 * hash strings, since verification passes the stored hash back to them).
 *
 * <p>Instances are mutable builders; this class does not synchronize access to them.
 */
public class Hash {

    // These constants define the encoding and may not be changed.
    private static final String SCRYPT = "scrypt";
    private static final String BCRYPT = "bcrypt";
    private static final String PBKDF2_HMACSHA1 = "PBKDF2WithHmacSHA1";
    private static final String PBKDF2_HMACSHA256 = "PBKDF2WithHmacSHA256";
    private static final String PBKDF2_HMACSHA512 = "PBKDF2WithHmacSHA512";
    private static final int HASH_SECTIONS = 8;
    private static final int HASH_ALGORITHM_INDEX = 0;
    private static final int ITERATION_INDEX = 1;
    private static final int HASH_SIZE_INDEX = 2;
    private static final int SALT_SIZE_INDEX = 3;
    private static final int USERSALT_INDEX = 4;
    private static final int SALT_INDEX = 5;
    private static final int HASH_INDEX = 6;
    private static final int PEPPER_INDEX = 7;

    /**
     * Source of randomness for choosing a pepper character. SecureRandom is safe for concurrent
     * use, and the default constructor avoids the potentially blocking entropy source that
     * SecureRandom.getInstanceStrong() may select.
     */
    private static final SecureRandom RANDOM = new SecureRandom();

    // Hash parameters; values <= 0 mean "use the algorithm's default".
    private char[] password;
    private char[] usersalt;
    private boolean isPepper = false;
    private int hashLength = 0;
    private int saltLength = 0;
    private int factor = 0;
    private HashType algorithm = HashType.PBKDF2_SHA1;

    /**
     * Starts a new builder for the given password.
     *
     * @param password the password; must be non-null and non-empty (the array is not copied)
     * @return a new builder using PBKDF2-SHA1 and algorithm defaults
     * @throws SecurityUtilsException if {@code password} is null or empty
     */
    public static Hash password(char[] password) throws SecurityUtilsException {
        if ((password == null) || (password.length < 1)) {
            throw new SecurityUtilsException("Password cannot be null or empty.");
        }
        Hash hash = new Hash();
        hash.password = password;
        return hash;
    }

    /**
     * Enables the pepper. {@link #create()} appends one randomly chosen character from
     * {@code HashUtils.pepperarray}; the character itself is not stored, so verification tries
     * every candidate.
     *
     * @return this builder
     */
    public Hash addPepper() {
        isPepper = true;
        return this;
    }

    /**
     * Sets a user-specific salt that is prepended to the password before hashing.
     *
     * @param usersalt the user salt; null or empty disables it for {@link #create()}
     * @return this builder
     */
    public Hash usersalt(char[] usersalt) {
        this.usersalt = usersalt;
        return this;
    }

    /**
     * Sets the derived hash length (used by PBKDF2 only).
     *
     * @param hashLength the length; values &lt;= 0 select the PBKDF2 default
     * @return this builder
     */
    public Hash hashLength(int hashLength) {
        this.hashLength = hashLength;
        return this;
    }

    /**
     * Sets the random salt length passed to the selected algorithm.
     *
     * @param saltLength the length; values &lt;= 0 select the algorithm's default
     * @return this builder
     */
    public Hash saltLength(int saltLength) {
        this.saltLength = saltLength;
        return this;
    }

    /**
     * Selects the hashing algorithm.
     *
     * @param algorithm the algorithm; unsupported values (including null) make {@link #create()}
     *     fail
     * @return this builder
     */
    public Hash algorithm(HashType algorithm) {
        this.algorithm = algorithm;
        return this;
    }

    /**
     * Sets the work factor. It is the PBKDF2 iteration count, the bcrypt log2 rounds, or the
     * scrypt cost, depending on the algorithm.
     *
     * @param factor the factor; values &lt;= 0 select the algorithm's default
     * @return this builder
     */
    public Hash factor(int factor) {
        this.factor = factor;
        return this;
    }

    /**
     * Hashes the configured password and returns it in the storage format described on the class.
     *
     * <p>Defaults for non-positive lengths and factors are resolved per call and are NOT written
     * back into this builder, so a builder can be reused with a different algorithm.
     *
     * @return the encoded hash string
     * @throws SecurityUtilsException if the algorithm is unsupported, a pepper was requested but
     *     no pepper characters exist, or the underlying algorithm fails
     */
    public String create() throws SecurityUtilsException {
        // Prepend the user salt when one was supplied.
        boolean usersalted = (usersalt != null) && (usersalt.length > 0);
        String input = usersalted ? new String(usersalt) + new String(password) : new String(password);
        char usersaltFlag = usersalted ? 'y' : 'n';

        // Append one random pepper character; nextInt(bound) is exclusive, so every index is eligible.
        if (isPepper) {
            if (HashUtils.pepperarray.length == 0) {
                throw new SecurityUtilsException("Pepper requested but no pepper characters are available.");
            }
            input = input + HashUtils.pepperarray[RANDOM.nextInt(HashUtils.pepperarray.length)];
        }

        if ((algorithm == HashType.PBKDF2_SHA1) || (algorithm == HashType.PBKDF2_SHA256) || (algorithm == HashType.PBKDF2_SHA512)) {
            return createPbkdf2(input, usersaltFlag);
        } else if (algorithm == HashType.BCRYPT) {
            return createBcrypt(input, usersaltFlag);
        } else if (algorithm == HashType.SCRYPT) {
            return createScrypt(input, usersaltFlag);
        }
        throw new SecurityUtilsException("Unsupported algorithm type. Expected Type.BCRYPT, Type.SCRYPT, or other Type enum.");
    }

    /** Hashes {@code input} with the configured PBKDF2 variant and formats it for storage. */
    private String createPbkdf2(String input, char usersaltFlag) throws SecurityUtilsException {
        PBKDF2 pbkdf2 = new PBKDF2();

        // JCA algorithm name used for hashing, and the short name written into the stored string.
        String jcaName;
        String storedName;
        if (algorithm == HashType.PBKDF2_SHA256) {
            jcaName = PBKDF2_HMACSHA256;
            storedName = "pbkdf2sha256";
        } else if (algorithm == HashType.PBKDF2_SHA512) {
            jcaName = PBKDF2_HMACSHA512;
            storedName = "pbkdf2sha512";
        } else {
            jcaName = PBKDF2_HMACSHA1;
            storedName = "pbkdf2sha1";
        }

        int effectiveHashLength = (hashLength > 0) ? hashLength : pbkdf2.DEFAULT_HASH_LENGTH;
        int effectiveSaltLength = (saltLength > 0) ? saltLength : pbkdf2.DEFAULT_SALT_LENGTH;
        int effectiveFactor = (factor > 0) ? factor : pbkdf2.DEFAULT_ITERATIONS;

        byte[] salt = HashUtils.randomSalt(effectiveSaltLength);
        byte[] hash = pbkdf2.create(input.toCharArray(), salt, jcaName, effectiveFactor, effectiveHashLength);

        return new StringBuilder(storedName).append(":").append(effectiveFactor).append(":").append(hash.length).append(":").append(salt.length).append(":")
                .append(usersaltFlag).append(":").append(BaseEncoder.encodeBase64toString(salt)).append(":").append(BaseEncoder.encodeBase64toString(hash)).append(":")
                .append(isPepper).toString();
    }

    /** Hashes {@code input} with bcrypt and formats it for storage (empty salt field). */
    private String createBcrypt(String input, char usersaltFlag) throws SecurityUtilsException {
        BCrypt bc = new BCrypt();
        int effectiveFactor = (factor > 0) ? factor : bc.DEFAULT_LOG2_ROUNDS;
        int effectiveSaltLength = (saltLength > 0) ? saltLength : bc.DEFAULT_SALT_LENGTH;

        String hash = bc.create(input, null, effectiveSaltLength, effectiveFactor);

        return new StringBuilder(BCRYPT).append(":").append(effectiveFactor).append(":").append(hash.length()).append(":").append(effectiveSaltLength).append(":")
                .append(usersaltFlag).append("::").append(hash).append(":").append(isPepper).toString();
    }

    /** Hashes {@code input} with scrypt and formats it for storage (empty salt field). */
    private String createScrypt(String input, char usersaltFlag) throws SecurityUtilsException {
        SCrypt sc = new SCrypt();
        int effectiveFactor = (factor > 0) ? factor : sc.COST;
        int effectiveSaltLength = (saltLength > 0) ? saltLength : sc.DEFAULT_SALT_LENGTH;

        String hash = sc.create(input, effectiveSaltLength, effectiveFactor);

        return new StringBuilder(SCRYPT).append(":").append(effectiveFactor).append(":").append(hash.length()).append(":").append(effectiveSaltLength).append(":")
                .append(usersaltFlag).append("::").append(hash).append(":").append(isPepper).toString();
    }

    /**
     * Checks the configured password (and user salt, if the stored hash requires one) against a
     * stored hash produced by {@link #create()}.
     *
     * @param correctHash the stored hash string
     * @return true if the password matches; false otherwise (including when a peppered hash
     *     matches none of the pepper candidates)
     * @throws SecurityUtilsException if the stored hash is null, empty, or malformed, uses an
     *     unsupported algorithm, requires a user salt that was not supplied, or the algorithm
     *     fails
     */
    public boolean verify(String correctHash) throws SecurityUtilsException {
        if ((correctHash == null) || correctHash.isEmpty()) {
            throw new SecurityUtilsException("Correct hash cannot be null or empty.");
        }

        // Decode the hash into its parameters.
        String[] params = correctHash.split(":");
        if (params.length != HASH_SECTIONS) {
            throw new SecurityUtilsException("Fields are missing from the correct hash. Double-check JHash version and hash format.");
        }

        int iterations = parseIntField(params[ITERATION_INDEX], "Could not parse the iteration count as an integer.");
        if (iterations < 1) {
            throw new SecurityUtilsException("Invalid number of iterations. Must be >= 1.");
        }

        // Every input that may have produced the stored hash (one per pepper char if peppered).
        String[] candidates = candidateInputs(params[USERSALT_INDEX], Boolean.parseBoolean(params[PEPPER_INDEX]));

        byte[] salt;
        try {
            salt = BaseEncoder.decodeBase64(params[SALT_INDEX]);
        } catch (IllegalArgumentException ex) {
            throw new SecurityUtilsException("Base64 decoding of salt failed.", ex);
        }

        int storedHashSize = parseIntField(params[HASH_SIZE_INDEX], "Could not parse the hash size as an integer.");
        int storedSaltSize = parseIntField(params[SALT_SIZE_INDEX], "Could not parse the salt size as an integer.");

        String storedAlgorithm = params[HASH_ALGORITHM_INDEX];
        if (storedAlgorithm.toLowerCase().startsWith("pbkdf2")) {
            return verifyPbkdf2(storedAlgorithm, params[HASH_INDEX], storedHashSize, salt, iterations, candidates);
        } else if (storedAlgorithm.equals(BCRYPT)) {
            return verifyBcrypt(params[HASH_INDEX], storedHashSize, storedSaltSize, iterations, candidates);
        } else if (storedAlgorithm.equals(SCRYPT)) {
            return verifyScrypt(params[HASH_INDEX], storedHashSize, candidates);
        }
        throw new SecurityUtilsException("Unsupported algorithm type: " + storedAlgorithm);
    }

    /**
     * Builds the candidate inputs for verification: the password, prefixed by the user salt when
     * the stored flag starts with 'y', and, when peppered, one candidate per pepper character.
     */
    private String[] candidateInputs(String usersaltFlag, boolean peppered) throws SecurityUtilsException {
        if (usersaltFlag.isEmpty()) {
            throw new SecurityUtilsException("Could not parse the usersalt flag.");
        }
        String base = new String(password);
        if (usersaltFlag.charAt(0) == 'y') {
            if (usersalt == null) {
                throw new SecurityUtilsException("The hash was created with a usersalt, but no usersalt was provided.");
            }
            base = new String(usersalt) + base;
        }
        if (!peppered) {
            return new String[] { base };
        }
        String[] candidates = new String[HashUtils.pepperarray.length];
        for (int i = 0; i < candidates.length; i++) {
            candidates[i] = base + HashUtils.pepperarray[i];
        }
        return candidates;
    }

    /** Verifies PBKDF2 candidates against the Base64-encoded stored hash using a constant-time compare. */
    private boolean verifyPbkdf2(String storedAlgorithm, String encodedHash, int storedHashSize, byte[] salt, int iterations, String[] candidates)
            throws SecurityUtilsException {
        PBKDF2 pbkdf2 = new PBKDF2();

        // Map the stored short name back to the JCA name; unknown names are passed through as-is.
        String jcaName = storedAlgorithm;
        if ("pbkdf2sha1".equals(storedAlgorithm)) {
            jcaName = PBKDF2_HMACSHA1;
        } else if ("pbkdf2sha256".equals(storedAlgorithm)) {
            jcaName = PBKDF2_HMACSHA256;
        } else if ("pbkdf2sha512".equals(storedAlgorithm)) {
            jcaName = PBKDF2_HMACSHA512;
        }

        byte[] hash;
        try {
            hash = BaseEncoder.decodeBase64(encodedHash);
        } catch (IllegalArgumentException ex) {
            throw new SecurityUtilsException("Base64 decoding of hash failed.", ex);
        }
        if (storedHashSize != hash.length) {
            throw new SecurityUtilsException("Hash length doesn't match stored hash length.");
        }

        for (String candidate : candidates) {
            // Recompute with the same salt, iteration count and hash length, then compare in constant time.
            byte[] testHash = pbkdf2.create(candidate.toCharArray(), salt, jcaName, iterations, hash.length);
            if (HashUtils.slowEquals(hash, testHash)) {
                return true;
            }
        }
        return false;
    }

    /** Verifies bcrypt candidates by re-hashing with the stored hash as the salt argument. */
    private boolean verifyBcrypt(String storedHash, int storedHashSize, int storedSaltSize, int iterations, String[] candidates)
            throws SecurityUtilsException {
        BCrypt bc = new BCrypt();
        byte[] hash = storedHash.getBytes(StandardCharsets.UTF_8);
        if (storedHashSize != hash.length) {
            throw new SecurityUtilsException("Hash length doesn't match stored hash length.");
        }

        for (String candidate : candidates) {
            byte[] testHash = bc.create(candidate, storedHash, storedSaltSize, iterations).getBytes(StandardCharsets.UTF_8);
            if (HashUtils.slowEquals(hash, testHash)) {
                return true;
            }
        }
        return false;
    }

    /** Verifies scrypt candidates using the scrypt implementation's own verify. */
    private boolean verifyScrypt(String storedHash, int storedHashSize, String[] candidates) throws SecurityUtilsException {
        SCrypt sc = new SCrypt();
        if (storedHashSize != storedHash.getBytes(StandardCharsets.UTF_8).length) {
            throw new SecurityUtilsException("Hash length doesn't match stored hash length.");
        }

        for (String candidate : candidates) {
            if (sc.verify(candidate, storedHash)) {
                return true;
            }
        }
        return false;
    }

    /** Parses a decimal integer field, wrapping format errors in a SecurityUtilsException. */
    private static int parseIntField(String value, String errorMessage) throws SecurityUtilsException {
        try {
            return Integer.parseInt(value);
        } catch (NumberFormatException ex) {
            throw new SecurityUtilsException(errorMessage, ex);
        }
    }

}
