package app.raybritton.tokenstorage.crypto

import android.annotation.TargetApi
import android.content.Context
import android.os.Build
import android.security.KeyPairGeneratorSpec
import android.security.keystore.KeyGenParameterSpec
import android.security.keystore.KeyProperties
import android.util.Base64
import java.math.BigInteger
import java.security.KeyPairGenerator
import java.security.KeyStore
import java.security.PrivateKey
import java.security.PublicKey
import java.util.Calendar
import javax.crypto.Cipher
import javax.security.auth.x500.X500Principal

/**
 * Encrypts and decrypts strings with an RSA key pair held in the Android KeyStore,
 * using the asymmetric `RSA/ECB/PKCS1Padding` transformation.
 *
 * [verify] must be called before [encrypt] or [decrypt]: it loads the key pair stored under
 * [alias], or generates a new one if the alias is not present.
 *
 * NOTE: PKCS#1 v1.5 padding limits a single plaintext to the key's modulus length minus 11 bytes
 * (measured after UTF-8 encoding); [encrypt] throws for longer input.
 *
 * @param context Used to build the pre-Marshmallow key pair spec and, via its package name,
 * the subject of the pre-Marshmallow certificate
 *
 * @param alias The KeyStore alias for the key pair and its certificate
 *
 * @param userMustBeAuthenticated If true, the device must have a keyguard set up.
 * Only applied when the key pair is generated on Android M+; the pre-Marshmallow
 * generation path does not use this flag.
 * NOTE: If set to true and the device's keyguard is disabled or reset (by a device admin, etc.), all data is lost
 *
 * @param invalidateOnBiometricChange Ignored unless userMustBeAuthenticated is true.
 * If true, when the biometrics on the device change (adding or removing fingerprints, etc.), all data is lost.
 * Only used on Android N+
 */
class CertCrypto(private val context: Context,
                 private val alias: String,
                 private val userMustBeAuthenticated: Boolean = false,
                 private val invalidateOnBiometricChange: Boolean = false) : Crypto {

    // Populated by verify() (either loaded from the KeyStore or freshly generated).
    private var privateKey: PrivateKey? = null
    private var publicKey: PublicKey? = null

    /**
     * Loads the key pair stored under [alias] into [privateKey] and [publicKey].
     *
     * @return true if the alias exists and both keys were loaded, false if the alias is absent
     * @throws IllegalStateException if the alias exists but does not hold a private key
     * with an associated certificate
     */
    private fun loadCert(): Boolean {
        val keyStore = KeyStore.getInstance(KEYSTORE_TYPE)
        keyStore.load(null)
        if (!keyStore.containsAlias(alias)) return false

        // Read both halves before assigning anything so a malformed entry can never
        // leave this object with only one of the two keys set.
        val storedPrivate = keyStore.getKey(alias, null) as? PrivateKey
            ?: error("KeyStore entry '$alias' does not contain a private key")
        val storedPublic = keyStore.getCertificate(alias)?.publicKey
            ?: error("KeyStore entry '$alias' has no certificate")

        privateKey = storedPrivate
        publicKey = storedPublic
        return true
    }

    /**
     * Encrypts the UTF-8 bytes of [plaintext] with the public key.
     *
     * @return the ciphertext as unwrapped (single-line) Base64
     * @throws IllegalStateException if [verify] has not been called
     */
    override fun encrypt(plaintext: String): String {
        val key = checkNotNull(publicKey) { "verify() must be called before encrypt()" }
        val cipher = Cipher.getInstance(ENCRYPT_ALGORITHM)
        cipher.init(Cipher.ENCRYPT_MODE, key)

        val encrypted = cipher.doFinal(plaintext.toByteArray(Charsets.UTF_8))
        return Base64.encodeToString(encrypted, Base64.NO_WRAP)
    }

    /**
     * Decrypts a Base64 string produced by [encrypt] with the private key.
     *
     * @return the original plaintext, decoded as UTF-8
     * @throws IllegalStateException if [verify] has not been called
     */
    override fun decrypt(encrypted: String): String {
        val key = checkNotNull(privateKey) { "verify() must be called before decrypt()" }
        val cipher = Cipher.getInstance(ENCRYPT_ALGORITHM)
        cipher.init(Cipher.DECRYPT_MODE, key)

        val encryptedData = Base64.decode(encrypted, Base64.DEFAULT)
        return cipher.doFinal(encryptedData).toString(Charsets.UTF_8)
    }

    /**
     * Ensures a key pair is available: loads the one stored under [alias], or generates
     * a new pair in the AndroidKeyStore if the alias is absent.
     */
    override fun verify() {
        if (loadCert()) return

        if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.M) {
            verifyMarshmallow()
        } else {
            verifyPreMarshmallow()
        }
    }

    /** Generates the key pair with [KeyGenParameterSpec] (Android M+). */
    @TargetApi(Build.VERSION_CODES.M)
    private fun verifyMarshmallow() {
        val keyPairGenerator = KeyPairGenerator.getInstance(
                KeyProperties.KEY_ALGORITHM_RSA, KEYSTORE_TYPE)
        keyPairGenerator.initialize(
                KeyGenParameterSpec.Builder(
                        alias,
                        KeyProperties.PURPOSE_ENCRYPT or KeyProperties.PURPOSE_DECRYPT)
                        .setBlockModes(KeyProperties.BLOCK_MODE_ECB)
                        .setEncryptionPaddings(KeyProperties.ENCRYPTION_PADDING_RSA_PKCS1)
                        .setUserAuthenticationRequired(userMustBeAuthenticated)
                        .setupBiometrics()
                        .build())
        val keyPair = keyPairGenerator.generateKeyPair()
        privateKey = keyPair.private
        publicKey = keyPair.public
    }

    /**
     * Generates the key pair with the deprecated [KeyPairGeneratorSpec] (pre-Android M).
     * Note that [userMustBeAuthenticated] and [invalidateOnBiometricChange] are not applied here.
     */
    @Suppress("DEPRECATION")
    private fun verifyPreMarshmallow() {
        val keyPairGenerator = KeyPairGenerator.getInstance(
                KEYSTORE_ALGORITHM, KEYSTORE_TYPE)
        val builder = KeyPairGeneratorSpec.Builder(context)
                .setAlias(alias)
                // Fixed serial number for the generated certificate.
                .setSerialNumber(BigInteger.valueOf(32945367343536L))
                .setSubject(X500Principal("CN=$alias Certificate, O=${context.packageName}"))
                // Validity starts one year in the past (presumably to tolerate clock skew)
                // and ends 30 years from now.
                .setStartDate(Calendar.getInstance().also { it.add(Calendar.YEAR, -1) }.time)
                .setEndDate(Calendar.getInstance().also { it.add(Calendar.YEAR, 30) }.time)

        keyPairGenerator.initialize(builder.build())
        val keyPair = keyPairGenerator.generateKeyPair()
        privateKey = keyPair.private
        publicKey = keyPair.public
    }

    /**
     * Applies [invalidateOnBiometricChange] where the platform supports it
     * (`setInvalidatedByBiometricEnrollment` requires Android N+); a no-op on older versions.
     */
    private fun KeyGenParameterSpec.Builder.setupBiometrics(): KeyGenParameterSpec.Builder {
        if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.N) {
            setInvalidatedByBiometricEnrollment(invalidateOnBiometricChange)
        }
        return this
    }

    private companion object {
        const val KEYSTORE_TYPE = "AndroidKeyStore"
        const val KEYSTORE_ALGORITHM = "RSA"
        const val ENCRYPT_ALGORITHM = "RSA/ECB/PKCS1Padding"
    }
}