package systems.comodal.shamir;

import java.math.BigInteger;
import java.util.Arrays;
import java.util.Map;
import java.util.Random;
import java.util.stream.IntStream;

/**
 * Static helpers for Shamir's secret sharing with arithmetic modulo a caller-supplied prime.
 *
 * <p>A secret is stored as the constant term of a polynomial whose remaining coefficients are
 * random; element {@code i} of a share array is that polynomial evaluated at {@code x = i + 1}.
 * The constant term is recovered by Lagrange interpolation at {@code x = 0}. No method here
 * checks that {@code prime} is actually prime; with a composite modulus
 * {@link BigInteger#modInverse} can fail during reconstruction.
 */
public final class Shamir {

  private Shamir() {
  }

  /**
   * Starts a fluent builder for generating shares.
   *
   * @return a new {@link ShamirSharesBuilder}
   */
  public static ShamirSharesBuilder buildShares() {
    return new ShamirSharesBuilder();
  }

  /**
   * Creates {@code requiredShares} random values, each drawn by {@link #createSecret}.
   *
   * @param secureRandom   source of randomness; a {@link java.security.SecureRandom} is
   *                       recommended
   * @param prime          modulus; must be at least 2
   * @param requiredShares number of values to create (the length of the returned array)
   * @return a new array whose elements all lie in {@code [1, prime)}
   * @throws IllegalArgumentException if {@code prime < 2}
   */
  public static BigInteger[] createSecrets(final Random secureRandom, final BigInteger prime, final int requiredShares) {
    final var secrets = new BigInteger[requiredShares];
    createSecrets(secureRandom, prime, secrets);
    return secrets;
  }

  /**
   * Overwrites every element of {@code secrets} with a fresh value from {@link #createSecret}.
   *
   * @param secureRandom source of randomness
   * @param prime        modulus; must be at least 2
   * @param secrets      array to fill in place
   * @throws IllegalArgumentException if {@code prime < 2}
   */
  public static void createSecrets(final Random secureRandom, final BigInteger prime, final BigInteger[] secrets) {
    requireFieldPrime(prime);
    for (int i = 0; i < secrets.length; i++) {
      secrets[i] = createSecret(secureRandom, prime);
    }
  }

  /**
   * Draws a uniformly distributed value from {@code [1, prime)}.
   *
   * @param secureRandom source of randomness
   * @param prime        modulus; must be at least 2
   * @return a value {@code v} with {@code 0 < v < prime}
   * @throws IllegalArgumentException if {@code prime < 2}, since the target range would be empty
   *                                  and sampling could never terminate
   */
  public static BigInteger createSecret(final Random secureRandom, final BigInteger prime) {
    requireFieldPrime(prime);
    // Rejection sampling: each draw is uniform over [0, 2^bitLength); keeping only draws inside
    // [1, prime) leaves a uniform distribution over that range.
    for (BigInteger secret; ; ) {
      secret = new BigInteger(prime.bitLength(), secureRandom);
      if (secret.signum() > 0 && secret.compareTo(prime) < 0) {
        return secret;
      }
    }
  }

  /**
   * Splits {@code secret} into {@code numShares} shares using a polynomial of degree
   * {@code requiredShares - 1} whose constant term is {@code secret} and whose other
   * coefficients come from {@link #createSecret}.
   *
   * <p>Requesting fewer shares than {@code requiredShares} is not rejected, but such a set of
   * shares cannot reconstruct the secret.
   *
   * @param secureRandom   randomness for the non-constant coefficients
   * @param prime          modulus; must exceed {@code numShares}
   * @param secret         value to split; must lie in {@code [0, prime)} or it could not be
   *                       reconstructed unchanged
   * @param requiredShares reconstruction threshold (number of coefficients); at least 1
   * @param numShares      number of shares to create
   * @return the shares; element {@code i} is the share at position {@code i + 1}
   * @throws IllegalArgumentException if {@code requiredShares < 1}, {@code secret} is outside
   *                                  {@code [0, prime)}, {@code numShares >= prime}, or
   *                                  {@code prime < 2} while random coefficients are needed
   */
  public static BigInteger[] createShares(final Random secureRandom,
                                          final BigInteger prime,
                                          final BigInteger secret,
                                          final int requiredShares,
                                          final int numShares) {
    if (requiredShares < 1) {
      throw new IllegalArgumentException("requiredShares must be at least 1, was " + requiredShares);
    }
    // The message deliberately omits the secret value.
    if (secret.signum() < 0 || secret.compareTo(prime) >= 0) {
      throw new IllegalArgumentException("secret must lie in the range [0, prime)");
    }
    final var secrets = new BigInteger[requiredShares];
    secrets[0] = secret;
    for (int i = 1; i < requiredShares; i++) {
      secrets[i] = createSecret(secureRandom, prime);
    }
    return createShares(prime, secrets, numShares);
  }

  /**
   * Evaluates the polynomial with coefficients {@code secrets} (constant term first) at positions
   * {@code 1..numShares}, modulo {@code prime}.
   *
   * @param prime     modulus; must exceed {@code numShares}
   * @param secrets   polynomial coefficients, {@code secrets[0]} being the shared secret; must not
   *                  be empty
   * @param numShares number of shares to create
   * @return the shares, each in {@code [0, prime)}; element {@code i} is the value at
   *         {@code x = i + 1}
   * @throws IllegalArgumentException if {@code secrets} is empty or {@code numShares >= prime}
   */
  public static BigInteger[] createShares(final BigInteger prime, final BigInteger[] secrets, final int numShares) {
    if (secrets.length == 0) {
      throw new IllegalArgumentException("secrets must contain at least the constant term");
    }
    // Position x = prime is congruent to 0, so its share would equal the secret itself; positions
    // beyond prime would duplicate earlier ones and break reconstruction.
    if (BigInteger.valueOf(numShares).compareTo(prime) >= 0) {
      throw new IllegalArgumentException("numShares must be less than prime, was " + numShares);
    }
    final var shares = new BigInteger[numShares];
    for (int shareIndex = 0; shareIndex < numShares; shareIndex++) {
      shares[shareIndex] = evaluatePolynomial(secrets, BigInteger.valueOf(shareIndex + 1), prime);
    }
    return shares;
  }

  /**
   * Evaluates {@code sum(coefficients[i] * x^i)} modulo {@code prime} with Horner's rule, which
   * keeps every intermediate value below {@code prime * x + max coefficient} instead of building
   * full powers of {@code x}.
   */
  private static BigInteger evaluatePolynomial(final BigInteger[] coefficients,
                                               final BigInteger x,
                                               final BigInteger prime) {
    var result = BigInteger.ZERO;
    for (int exp = coefficients.length - 1; exp >= 0; exp--) {
      result = result.multiply(x).add(coefficients[exp]).mod(prime);
    }
    return result;
  }

  /**
   * Recovers the polynomial's constant term from share coordinates by Lagrange interpolation at
   * {@code x = 0}, modulo {@code prime}. The result equals the original secret only when at least
   * the threshold number of distinct shares is supplied.
   *
   * @param coordinates map from share position ({@code x}) to share value ({@code y})
   * @param prime       modulus the shares were created with
   * @return the interpolated constant term in {@code [0, prime)}; {@code 0} for an empty map
   * @throws ArithmeticException if two positions are congruent modulo {@code prime} (their
   *                             difference has no inverse)
   */
  public static BigInteger reconstructSecret(final Map<BigInteger, BigInteger> coordinates, final BigInteger prime) {
    return interpolateAtZero(coordinates.entrySet(), prime);
  }

  /** Array-based variant of {@link #reconstructSecret(Map, BigInteger)} used by validation. */
  private static BigInteger reconstructSecret(final Map.Entry<BigInteger, BigInteger>[] coordinates, final BigInteger prime) {
    return interpolateAtZero(Arrays.asList(coordinates), prime);
  }

  /**
   * Shared Lagrange-interpolation core. Iterates {@code coordinates} once per point, so the
   * iterable must support repeated iteration.
   */
  private static BigInteger interpolateAtZero(final Iterable<? extends Map.Entry<BigInteger, BigInteger>> coordinates,
                                              final BigInteger prime) {
    var freeCoefficient = BigInteger.ZERO;
    for (final var referencePoint : coordinates) {
      final var referencePosition = referencePoint.getKey();
      var numerator = BigInteger.ONE;
      var denominator = BigInteger.ONE;
      for (final var point : coordinates) {
        final var position = point.getKey();
        if (referencePosition.equals(position)) {
          continue;
        }
        // Basis polynomial at x = 0: product of (0 - x_j) / (x_i - x_j) over all j != i.
        numerator = numerator.multiply(position.negate()).mod(prime);
        denominator = denominator.multiply(referencePosition.subtract(position)).mod(prime);
      }
      final var term = referencePoint.getValue()
          .multiply(numerator)
          .multiply(denominator.modInverse(prime));
      // BigInteger.mod never returns a negative value, so no extra offset by prime is needed.
      freeCoefficient = freeCoefficient.add(term).mod(prime);
    }
    return freeCoefficient;
  }

  /**
   * Reconstructs the secret from every {@code numRequiredShares}-sized subset of {@code shares}
   * and checks each result against {@code expectedSecret}.
   *
   * <p>{@code shares[i]} is treated as the share at position {@code i + 1}, matching the layout
   * returned by {@link #createShares(BigInteger, BigInteger[], int)}. The work grows with the
   * binomial coefficient {@code C(shares.length, numRequiredShares)}.
   *
   * @param expectedSecret    value every subset must reconstruct
   * @param prime             modulus the shares were created with
   * @param numRequiredShares subset size to test
   * @param shares            shares ordered by position
   * @return the number of subsets checked; {@code 0} when {@code numRequiredShares} exceeds
   *         {@code shares.length}, in which case nothing was validated
   * @throws IllegalStateException if any subset reconstructs a different value
   */
  public static int validateShareCombinations(final BigInteger expectedSecret,
                                              final BigInteger prime,
                                              final int numRequiredShares,
                                              final BigInteger[] shares) {
    // Generic arrays cannot be created directly; the raw arrays only ever hold
    // Map.Entry<BigInteger, BigInteger> instances.
    @SuppressWarnings("unchecked")
    final Map.Entry<BigInteger, BigInteger>[] coordinates = IntStream.range(0, shares.length)
        .mapToObj(i -> Map.entry(BigInteger.valueOf(i + 1), shares[i]))
        .toArray(Map.Entry[]::new);
    @SuppressWarnings("unchecked")
    final Map.Entry<BigInteger, BigInteger>[] subset = new Map.Entry[numRequiredShares];
    return Shamir.shareCombinations(coordinates, 0, numRequiredShares, subset, expectedSecret, prime);
  }

  /**
   * Recursively enumerates the size-{@code len} subsets of {@code coordinates[startPos..]},
   * writing chosen points into the trailing {@code len} slots of {@code result} and validating
   * each completed subset.
   *
   * @return the number of completed subsets validated
   */
  private static int shareCombinations(final Map.Entry<BigInteger, BigInteger>[] coordinates,
                                       final int startPos,
                                       final int len,
                                       final Map.Entry<BigInteger, BigInteger>[] result,
                                       final BigInteger expectedSecret,
                                       final BigInteger prime) {
    if (len == 0) {
      validateReconstruction(expectedSecret, prime, result);
      return 1;
    }
    int numSubSets = 0;
    // Stop early enough that len - 1 more points can still be chosen after index i.
    for (int i = startPos; i <= coordinates.length - len; i++) {
      result[result.length - len] = coordinates[i];
      numSubSets += shareCombinations(coordinates, i + 1, len - 1, result, expectedSecret, prime);
    }
    return numSubSets;
  }

  /**
   * Throws if {@code coordinates} does not reconstruct {@code expectedSecret}.
   *
   * <p>NOTE(review): the exception message embeds the expected and reconstructed secrets and the
   * shares; keep this path to tests/diagnostics and out of production logs.
   */
  private static void validateReconstruction(final BigInteger expectedSecret,
                                             final BigInteger prime,
                                             final Map.Entry<BigInteger, BigInteger>[] coordinates) {
    final var reconstructedSecret = reconstructSecret(coordinates, prime);
    if (!expectedSecret.equals(reconstructedSecret)) {
      throw new IllegalStateException(String.format("Reconstructed secret does not equal expected secret. %nReconstructed: '%s' %nExpected: '%s' %nWith %d shares: %n%s",
          reconstructedSecret, expectedSecret, coordinates.length, Arrays.toString(coordinates)));
    }
  }

  /**
   * Rejects moduli below 2, for which {@code [1, prime)} is empty. Only this lower bound is
   * checked; primality is the caller's responsibility.
   */
  private static void requireFieldPrime(final BigInteger prime) {
    if (prime.compareTo(BigInteger.TWO) < 0) {
      throw new IllegalArgumentException("prime must be at least 2, was " + prime);
    }
  }
}
