package app.raybritton.tokenstorage.crypto

import android.os.Build
import android.os.Process
import android.util.Base64
import android.util.Log
import app.raybritton.tokenstorage.CryptoLogging
import java.io.*
import java.security.*
import java.util.Arrays
import java.util.concurrent.atomic.AtomicBoolean
import javax.crypto.Cipher
import javax.crypto.Mac
import javax.crypto.SecretKey
import javax.crypto.SecretKeyFactory
import javax.crypto.spec.IvParameterSpec
import javax.crypto.spec.PBEKeySpec
import javax.crypto.spec.SecretKeySpec

/**
 * Password-based authenticated encryption: AES/CBC/PKCS5Padding for confidentiality plus an
 * HmacSHA256 MAC computed over `iv || ciphertext` (encrypt-then-MAC) for integrity.
 *
 * Keys are derived with PBKDF2WithHmacSHA1; the derived material is split into a 128-bit AES key
 * and a 256-bit HMAC key (see [SecretKeys]).
 */
internal object AesCbcWithIntegrity: CryptoLogging {
    private const val CIPHER_TRANSFORMATION = "AES/CBC/PKCS5Padding"
    private const val CIPHER = "AES"
    private const val RANDOM_ALGORITHM = "SHA1PRNG"
    private const val AES_KEY_LENGTH_BITS = 128
    private const val IV_LENGTH_BYTES = 16
    private const val PBE_ITERATION_COUNT = 10000
    private const val PBE_ALGORITHM = "PBKDF2WithHmacSHA1"

    /** Flags used for every Base64 encode/decode in this file (single line, no wrapping). */
    const val BASE64_FLAGS = Base64.NO_WRAP

    /** Becomes `true` once [PrngFixes.apply] has been run by [fixPrng]. */
    private val prngFixed = AtomicBoolean(false)

    private const val HMAC_ALGORITHM = "HmacSHA256"
    private const val HMAC_KEY_LENGTH_BITS = 256

    private const val AES_KEY_LENGTH_BYTES = AES_KEY_LENGTH_BITS / 8
    private const val HMAC_KEY_LENGTH_BYTES = HMAC_KEY_LENGTH_BITS / 8

    /**
     * Derives the AES and HMAC keys from [password] and the raw [salt] bytes with PBKDF2.
     *
     * Temporary copies of the password and of the derived key bytes are wiped before returning.
     *
     * @throws GeneralSecurityException if PBKDF2 is unavailable or key derivation fails.
     */
    @Throws(GeneralSecurityException::class)
    private fun generateKeyFromPassword(password: String, salt: ByteArray): SecretKeys {
        debug("generateKeyFromPassword(${password.shrink()}, $salt)")
        fixPrng()
        val passwordChars = password.toCharArray()
        val keySpec = PBEKeySpec(passwordChars, salt,
                PBE_ITERATION_COUNT, AES_KEY_LENGTH_BITS + HMAC_KEY_LENGTH_BITS)
        val keyBytes = try {
            SecretKeyFactory.getInstance(PBE_ALGORITHM).generateSecret(keySpec).encoded
        } finally {
            // PBEKeySpec holds its own copy of the password; wipe both copies once derivation is done.
            keySpec.clearPassword()
            passwordChars.fill('\u0000')
        }

        val confidentialityKeyBytes = keyBytes.copyOfRange(0, AES_KEY_LENGTH_BYTES)
        val integrityKeyBytes = keyBytes.copyOfRange(AES_KEY_LENGTH_BYTES, AES_KEY_LENGTH_BYTES + HMAC_KEY_LENGTH_BYTES)
        try {
            // SecretKeySpec copies the array it is given, so the temporaries can be wiped afterwards.
            val confidentialityKey = SecretKeySpec(confidentialityKeyBytes, CIPHER)
            val integrityKey = SecretKeySpec(integrityKeyBytes, HMAC_ALGORITHM)
            return SecretKeys(confidentialityKey, integrityKey)
        } finally {
            keyBytes.fill(0)
            confidentialityKeyBytes.fill(0)
            integrityKeyBytes.fill(0)
        }
    }

    /**
     * Derives the AES and HMAC keys from [password] and a Base64-encoded [salt].
     *
     * @throws GeneralSecurityException if PBKDF2 is unavailable or key derivation fails.
     */
    @Throws(GeneralSecurityException::class)
    fun generateKeyFromPassword(password: String, salt: String): SecretKeys {
        fine("generateKeyFromPassword(${password.shrink()}, $salt)")
        return generateKeyFromPassword(password, Base64.decode(salt, BASE64_FLAGS))
    }

    /** Returns [IV_LENGTH_BYTES] fresh random bytes for use as a CBC IV. */
    @Throws(GeneralSecurityException::class)
    private fun generateIv(): ByteArray {
        fine("generateIv()")
        return randomBytes(IV_LENGTH_BYTES)
    }

    /** Returns [length] random bytes from a [RANDOM_ALGORITHM] SecureRandom, after [fixPrng]. */
    @Throws(GeneralSecurityException::class)
    private fun randomBytes(length: Int): ByteArray {
        fine("randomBytes($length)")
        fixPrng()
        val random = SecureRandom.getInstance(RANDOM_ALGORITHM)
        val b = ByteArray(length)
        random.nextBytes(b)
        return b
    }

    /**
     * Encrypts [plaintext] (encoded with the charset named by [encoding]) and MACs the result.
     *
     * @throws UnsupportedEncodingException declared for Java callers; see [charset] for the lookup.
     * @throws GeneralSecurityException if the cipher or MAC cannot be created or initialised.
     */
    @Throws(UnsupportedEncodingException::class, GeneralSecurityException::class)
    @JvmOverloads
    fun encrypt(plaintext: String, secretKeys: SecretKeys, encoding: String = "UTF-8"): CipherTextIvMac {
        return encrypt(plaintext.toByteArray(charset(encoding)), secretKeys)
    }

    /** Encrypts [plaintext] with a fresh random IV, then MACs `iv || ciphertext`. */
    @Throws(GeneralSecurityException::class)
    private fun encrypt(plaintext: ByteArray, secretKeys: SecretKeys): CipherTextIvMac {
        val requestedIv = generateIv()
        val aesCipherForEncryption = Cipher.getInstance(CIPHER_TRANSFORMATION)
        aesCipherForEncryption.init(Cipher.ENCRYPT_MODE, secretKeys.confidentialityKey, IvParameterSpec(requestedIv))

        // Read the IV back from the cipher so the stored IV is exactly the one it used.
        val iv = aesCipherForEncryption.iv
        val byteCipherText = aesCipherForEncryption.doFinal(plaintext)
        val ivCipherConcat = CipherTextIvMac.ivCipherConcat(iv, byteCipherText)

        val integrityMac = generateMac(ivCipherConcat, secretKeys.integrityKey)
        return CipherTextIvMac(byteCipherText, iv, integrityMac)
    }

    /**
     * Runs [PrngFixes.apply] at most once, using double-checked locking on [prngFixed]
     * so the fast path after the first call takes no lock.
     */
    private fun fixPrng() {
        debug("fixPrng()")
        if (!prngFixed.get()) {
            synchronized(PrngFixes::class.java) {
                if (!prngFixed.get()) {
                    PrngFixes.apply()
                    prngFixed.set(true)
                }
            }
        }
    }

    /**
     * Verifies the MAC of [civ] and, only if it matches, decrypts its ciphertext.
     *
     * @return the decrypted plaintext bytes.
     * @throws GeneralSecurityException if the MAC does not match or decryption fails.
     */
    @Throws(GeneralSecurityException::class)
    fun decrypt(civ: CipherTextIvMac, secretKeys: SecretKeys): ByteArray {
        val ivCipherConcat = CipherTextIvMac.ivCipherConcat(civ.iv, civ.cipherText)
        val computedMac = generateMac(ivCipherConcat, secretKeys.integrityKey)
        // Check integrity before touching the ciphertext (encrypt-then-MAC).
        if (!constantTimeEq(computedMac, civ.mac)) {
            throw GeneralSecurityException("MAC stored in civ does not match computed MAC.")
        }
        val aesCipherForDecryption = Cipher.getInstance(CIPHER_TRANSFORMATION)
        aesCipherForDecryption.init(Cipher.DECRYPT_MODE, secretKeys.confidentialityKey,
                IvParameterSpec(civ.iv))
        return aesCipherForDecryption.doFinal(civ.cipherText)
    }

    /** Computes the [HMAC_ALGORITHM] tag of [data] (here always `iv || ciphertext`) under [integrityKey]. */
    @Throws(NoSuchAlgorithmException::class, InvalidKeyException::class)
    private fun generateMac(data: ByteArray, integrityKey: SecretKey): ByteArray {
        debug("generateMac(...)")
        val sha256HMAC = Mac.getInstance(HMAC_ALGORITHM)
        sha256HMAC.init(integrityKey)
        return sha256HMAC.doFinal(data)
    }

    /** The AES confidentiality key and the HMAC integrity key, used as a pair. */
    class SecretKeys(val confidentialityKey: SecretKey, val integrityKey: SecretKey) {

        /**
         * Encodes both keys as `base64(confidentialityKey):base64(integrityKey)`.
         *
         * NOTE: the result contains raw key material; never log it.
         */
        override fun toString(): String {
            return Base64.encodeToString(confidentialityKey.encoded, BASE64_FLAGS) + ":" + Base64.encodeToString(integrityKey.encoded, BASE64_FLAGS)
        }

        override fun hashCode(): Int {
            val prime = 31
            var result = 1
            result = prime * result + confidentialityKey.hashCode()
            result = prime * result + integrityKey.hashCode()
            return result
        }

        override fun equals(other: Any?): Boolean {
            if (this === other) return true
            if (other == null || javaClass != other.javaClass) return false
            other as SecretKeys
            return integrityKey == other.integrityKey && confidentialityKey == other.confidentialityKey
        }
    }

    /**
     * Compares two byte arrays without exiting early on the first mismatch, so the running time
     * depends only on the length (avoids leaking MAC-prefix matches through timing).
     */
    private fun constantTimeEq(a: ByteArray, b: ByteArray): Boolean {
        if (a.size != b.size) {
            return false
        }
        var result = 0
        for (i in a.indices) {
            result = result or (a[i].toInt() xor b[i].toInt())
        }
        return result == 0
    }

    /**
     * Holder for the ciphertext, IV and MAC of one encryption.
     *
     * The string form (see [toString] and the `String` constructor) is
     * `base64(iv):base64(mac):base64(cipherText)`.
     */
    class CipherTextIvMac {
        val cipherText: ByteArray
        val iv: ByteArray
        val mac: ByteArray

        /** Creates an instance holding defensive copies of the given arrays. */
        constructor(cipherTextBytes: ByteArray, ivBytes: ByteArray, macBytes: ByteArray) {
            cipherText = cipherTextBytes.copyOf()
            iv = ivBytes.copyOf()
            mac = macBytes.copyOf()
        }

        /**
         * Parses the `iv:mac:cipherText` Base64 form produced by [toString].
         *
         * @throws IllegalArgumentException if the input does not have exactly three `:`-separated parts.
         */
        constructor(base64IvAndCiphertext: String) {
            // Mirrors Java's String.split: trailing empty segments are discarded.
            val civArray = base64IvAndCiphertext.split(":".toRegex()).dropLastWhile { it.isEmpty() }.toTypedArray()
            require(civArray.size == 3) { "Cannot parse iv:mac:ciphertext" }
            iv = Base64.decode(civArray[0], BASE64_FLAGS)
            mac = Base64.decode(civArray[1], BASE64_FLAGS)
            cipherText = Base64.decode(civArray[2], BASE64_FLAGS)
        }

        /** Encodes this instance as `base64(iv):base64(mac):base64(cipherText)`. */
        override fun toString(): String {
            val ivString = Base64.encodeToString(iv, BASE64_FLAGS)
            val cipherTextString = Base64.encodeToString(cipherText, BASE64_FLAGS)
            val macString = Base64.encodeToString(mac, BASE64_FLAGS)
            return "$ivString:$macString:$cipherTextString"
        }

        override fun hashCode(): Int {
            val prime = 31
            var result = 1
            result = prime * result + Arrays.hashCode(cipherText)
            result = prime * result + Arrays.hashCode(iv)
            result = prime * result + Arrays.hashCode(mac)
            return result
        }

        override fun equals(other: Any?): Boolean {
            if (this === other) return true
            if (other == null || javaClass != other.javaClass) return false
            other as CipherTextIvMac
            return Arrays.equals(cipherText, other.cipherText) &&
                    Arrays.equals(iv, other.iv) &&
                    Arrays.equals(mac, other.mac)
        }

        companion object {
            /** Returns a new array containing [iv] followed by [cipherText]: the data that is MACed. */
            fun ivCipherConcat(iv: ByteArray, cipherText: ByteArray): ByteArray = iv + cipherText
        }
    }

    /**
     * Workarounds for SecureRandom seeding on old Android releases.
     *
     * [apply] reseeds OpenSSL's PRNG on API levels 16..18 and, on API levels <= 18, installs a
     * `/dev/urandom`-backed SHA1PRNG provider at the highest priority.
     */
    object PrngFixes {

        private const val VERSION_CODE_JELLY_BEAN = 16
        private const val VERSION_CODE_JELLY_BEAN_MR2 = 18
        private val BUILD_FINGERPRINT_AND_DEVICE_SERIAL = buildFingerprintAndDeviceSerial

        /**
         * Applies both fixes.
         *
         * @throws SecurityException if either fix cannot be applied.
         */
        fun apply() {
            applyOpenSSLFix()
            installLinuxPRNGSecureRandom()
        }

        /**
         * On API 16..18, seeds OpenSSL's PRNG via reflection on the platform `NativeCrypto` class
         * and then mixes in 1024 bytes from `/dev/urandom`.
         */
        @Throws(SecurityException::class)
        private fun applyOpenSSLFix() {
            if (Build.VERSION.SDK_INT < VERSION_CODE_JELLY_BEAN || Build.VERSION.SDK_INT > VERSION_CODE_JELLY_BEAN_MR2) {
                return
            }

            try {
                val nativeCrypto = Class.forName("org.apache.harmony.xnet.provider.jsse.NativeCrypto")

                // RAND_seed takes ONE byte[] parameter: pass the array itself as the single
                // reflective argument. Spreading it would pass N separate Byte arguments and fail.
                nativeCrypto.getMethod("RAND_seed", ByteArray::class.java)
                        .invoke(null, generateSeed())

                // RAND_load_file is declared (String, long), so pass a Long.
                val bytesRead = nativeCrypto
                        .getMethod("RAND_load_file", String::class.java, Long::class.javaPrimitiveType)
                        .invoke(null, "/dev/urandom", 1024L) as Int
                if (bytesRead != 1024) {
                    throw IOException("Unexpected number of bytes read from Linux PRNG: $bytesRead")
                }
            } catch (e: Exception) {
                throw SecurityException("Failed to seed OpenSSL PRNG", e)
            }
        }

        /**
         * On API <= 18, installs [LinuxPRNGSecureRandomProvider] as the top-priority provider
         * (unless it is already first for SHA1PRNG). It then verifies that both `SecureRandom()`
         * and `SecureRandom.getInstance("SHA1PRNG")` resolve to it.
         */
        @Throws(SecurityException::class)
        private fun installLinuxPRNGSecureRandom() {
            if (Build.VERSION.SDK_INT > VERSION_CODE_JELLY_BEAN_MR2) {
                return
            }

            val secureRandomProviders = Security.getProviders("SecureRandom.SHA1PRNG")

            synchronized(Security::class.java) {
                // A null/empty provider list yields null here, which also triggers installation.
                if (secureRandomProviders?.firstOrNull()?.javaClass?.simpleName != "LinuxPRNGSecureRandomProvider") {
                    Security.insertProviderAt(LinuxPRNGSecureRandomProvider(), 1)
                }

                val rng1 = SecureRandom()
                if (rng1.provider.javaClass.simpleName != "LinuxPRNGSecureRandomProvider") {
                    throw SecurityException("new SecureRandom() backed by wrong Provider: " + rng1.provider.javaClass)
                }

                val rng2 = try {
                    SecureRandom.getInstance("SHA1PRNG")
                } catch (e: NoSuchAlgorithmException) {
                    throw SecurityException("SHA1PRNG not available", e)
                }

                if (rng2.provider.javaClass.simpleName != "LinuxPRNGSecureRandomProvider") {
                    throw SecurityException(
                            "SecureRandom.getInstance(\"SHA1PRNG\") backed by wrong" + " Provider: "
                                    + rng2.provider.javaClass)
                }
            }
        }

        /** Provider that maps `SecureRandom.SHA1PRNG` to [LinuxPRNGSecureRandom]. */
        private class LinuxPRNGSecureRandomProvider : Provider("LinuxPRNG", 1.0, "A Linux-specific random number provider that uses" + " /dev/urandom") {
            init {
                put("SecureRandom.SHA1PRNG", LinuxPRNGSecureRandom::class.java.name)
                put("SecureRandom.SHA1PRNG ImplementedIn", "Software")
            }
        }

        /**
         * SecureRandom implementation that passes byte requests and seed material straight
         * through to the Linux PRNG (`/dev/urandom`).
         *
         * Each instance seeds itself once, before its first read, by mixing in [generateSeed]
         * output: current time, nano time, PID, UID, build fingerprint and device serial.
         *
         * Concurrency: the streams are shared process-wide and opened lazily under `sLock`. Reads
         * synchronize on the shared input stream, so concurrent callers never receive the same bytes.
         *
         * Must stay public with a no-arg constructor: the provider instantiates it by class name.
         */
        class LinuxPRNGSecureRandom : SecureRandomSpi() {

            /** Whether this instance has mixed its own seed into the Linux PRNG yet. */
            private var mSeeded: Boolean = false

            /**
             * Writes [bytes] into `/dev/urandom`. This is best effort: an I/O failure is logged,
             * not rethrown, and the instance is marked as seeded either way.
             */
            override fun engineSetSeed(bytes: ByteArray) {
                try {
                    val out = urandomOutputStream
                    out.write(bytes)
                    out.flush()
                } catch (e: IOException) {
                    Log.w(PrngFixes::class.java.simpleName, "Failed to mix seed into $URANDOM_FILE")
                } finally {
                    mSeeded = true
                }
            }

            /**
             * Fills [bytes] from `/dev/urandom`, seeding this instance first if needed.
             *
             * @throws SecurityException if the read fails.
             */
            override fun engineNextBytes(bytes: ByteArray) {
                if (!mSeeded) {
                    engineSetSeed(generateSeed())
                }

                try {
                    val input = urandomInputStream
                    synchronized(input) {
                        input.readFully(bytes)
                    }
                } catch (e: IOException) {
                    throw SecurityException("Failed to read from $URANDOM_FILE", e)
                }
            }

            /** Returns [size] bytes read via [engineNextBytes]. */
            override fun engineGenerateSeed(size: Int): ByteArray {
                val seed = ByteArray(size)
                engineNextBytes(seed)
                return seed
            }

            /** Shared input stream on `/dev/urandom`, opened on first use under `sLock`. */
            private val urandomInputStream: DataInputStream
                get() = synchronized(sLock) {
                    sUrandomIn ?: run {
                        val opened = try {
                            DataInputStream(FileInputStream(URANDOM_FILE))
                        } catch (e: IOException) {
                            throw SecurityException("Failed to open $URANDOM_FILE for reading", e)
                        }
                        sUrandomIn = opened
                        opened
                    }
                }

            /** Shared output stream on `/dev/urandom`, opened on first use under `sLock`. */
            private val urandomOutputStream: OutputStream
                @Throws(IOException::class)
                get() = synchronized(sLock) {
                    sUrandomOut ?: FileOutputStream(URANDOM_FILE).also { sUrandomOut = it }
                }

            companion object {
                private val URANDOM_FILE = File("/dev/urandom")

                /** Guards lazy opening of [sUrandomIn] and [sUrandomOut]. */
                private val sLock = Any()

                /** Input stream for reading from the Linux PRNG, or `null` if not yet opened (guarded by `sLock`). */
                private var sUrandomIn: DataInputStream? = null

                /** Output stream for writing to the Linux PRNG, or `null` if not yet opened (guarded by `sLock`). */
                private var sUrandomOut: OutputStream? = null
            }
        }

        /**
         * Generates a device- and invocation-specific seed to be mixed into the Linux PRNG.
         *
         * The seed is built from the wall-clock time, nano time, PID, UID, and
         * [BUILD_FINGERPRINT_AND_DEVICE_SERIAL].
         */
        private fun generateSeed(): ByteArray {
            try {
                val seedBuffer = ByteArrayOutputStream()
                DataOutputStream(seedBuffer).use { seedBufferOut ->
                    seedBufferOut.writeLong(System.currentTimeMillis())
                    seedBufferOut.writeLong(System.nanoTime())
                    seedBufferOut.writeInt(Process.myPid())
                    seedBufferOut.writeInt(Process.myUid())
                    seedBufferOut.write(BUILD_FINGERPRINT_AND_DEVICE_SERIAL)
                }
                return seedBuffer.toByteArray()
            } catch (e: IOException) {
                throw SecurityException("Failed to generate seed", e)
            }
        }

        /** `Build.SERIAL` read reflectively, or `null` if it is absent, not a String, or unreadable. */
        private val deviceSerialNumber: String?
            get() = try {
                Build::class.java.getField("SERIAL").get(null) as? String
            } catch (ignored: Exception) {
                null
            }

        /** UTF-8 bytes of `Build.FINGERPRINT` followed by the device serial (each skipped if null). */
        private val buildFingerprintAndDeviceSerial: ByteArray
            get() = buildString {
                Build.FINGERPRINT?.let { append(it) }
                deviceSerialNumber?.let { append(it) }
            }.toByteArray(Charsets.UTF_8)
    }
}