package io.digital.patterns.keycloak.impersonate;

import com.fasterxml.jackson.annotation.JsonProperty;
import lombok.Data;
import lombok.extern.java.Log;
import org.springframework.cache.CacheManager;
import org.springframework.cache.concurrent.ConcurrentMapCacheManager;
import org.springframework.http.HttpEntity;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod;
import org.springframework.http.MediaType;
import org.springframework.security.oauth2.jwt.JwtDecoders;
import org.springframework.security.oauth2.jwt.JwtTimestampValidator;
import org.springframework.security.oauth2.jwt.NimbusJwtDecoder;
import org.springframework.util.Assert;
import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap;
import org.springframework.web.client.RestTemplate;

import java.util.Objects;
import java.util.Optional;
import java.util.UUID;
import java.util.logging.Level;

import static java.lang.String.format;

@Log
/**
 * Obtains Keycloak access tokens on behalf of arbitrary users (impersonation).
 *
 * <p>A token is issued in two steps:
 * <ol>
 *   <li>log in with the configured service account using the password grant;</li>
 *   <li>exchange that token for one whose subject is the requested user, using the
 *       {@code urn:ietf:params:oauth:grant-type:token-exchange} grant.</li>
 * </ol>
 * Results are cached in memory per user and renewed once they no longer pass JWT timestamp
 * validation.
 */
public class KeycloakImpersonateService {

    /** Path of the OpenID Connect token endpoint, relative to a realm URL. */
    private static final String TOKEN_ENDPOINT_PATH = "/protocol/openid-connect/token";
    private static final String CACHE_NAME = UUID.randomUUID().toString() + "-" + "token-cache";

    private final RestTemplate restTemplate;
    private final KeycloakConfiguration keycloakConfiguration;
    private final NimbusJwtDecoder jwtDecoder;
    private final CacheManager cacheManager;
    private final JwtTimestampValidator jwtTimestampValidator;

    /**
     * Creates the service.
     *
     * <p>Builds a JWT decoder for the realm's issuer and an in-memory token cache owned by this
     * instance.
     *
     * @param keycloakConfiguration connection settings: auth URL, realm, client and service-account
     *                              credentials, and an optional token exchange URL
     */
    public KeycloakImpersonateService(KeycloakConfiguration keycloakConfiguration) {
        this.restTemplate = new RestTemplate();
        this.keycloakConfiguration = keycloakConfiguration;

        String issuerUrl = issuerUrl(keycloakConfiguration);

        this.jwtDecoder = (NimbusJwtDecoder) JwtDecoders.fromOidcIssuerLocation(issuerUrl);
        this.jwtTimestampValidator = new JwtTimestampValidator();
        // The decoder is only used to decide whether a cached token is still usable. This replaces
        // whatever validator the factory installed, so only the token's timestamps are checked.
        this.jwtDecoder.setJwtValidator(jwtTimestampValidator);
        this.cacheManager = new ConcurrentMapCacheManager(CACHE_NAME);
    }

    /**
     * Derives the OIDC issuer (realm) URL used to build the JWT decoder.
     *
     * <p>When a token exchange URL is configured, it is that URL with the token endpoint path
     * removed. Otherwise it is {@code <authUrl>/realms/<authRealm>}.
     *
     * @throws IllegalArgumentException if a configured token exchange URL does not contain
     *                                  {@value #TOKEN_ENDPOINT_PATH}
     */
    private String issuerUrl(KeycloakConfiguration keycloakConfiguration) {
        String exchangeUrl = keycloakConfiguration.getTokenExchangeUrl();
        // Treat "" like null, exactly as tokenExchangeUrl() does.
        if (exchangeUrl == null || exchangeUrl.isEmpty()) {
            return realmUrl(keycloakConfiguration);
        }
        int endpointStart = exchangeUrl.indexOf(TOKEN_ENDPOINT_PATH);
        if (endpointStart < 0) {
            // Without this check substring(0, -1) would throw an opaque StringIndexOutOfBoundsException.
            throw new IllegalArgumentException(format(
                    "Token exchange URL '%s' does not contain '%s'", exchangeUrl, TOKEN_ENDPOINT_PATH));
        }
        return exchangeUrl.substring(0, endpointStart);
    }

    /** Returns {@code <authUrl>/realms/<authRealm>} for the given configuration. */
    private static String realmUrl(KeycloakConfiguration configuration) {
        return format("%s/realms/%s", configuration.getAuthUrl(), configuration.getAuthRealm());
    }

    /**
     * Returns an access token issued to {@code user} through Keycloak token exchange.
     *
     * <p>Tokens are cached per user. If a cached token no longer decodes or passes timestamp
     * validation, it is renewed and the renewed token replaces it in the cache. Renewal uses the
     * refresh token when one is present. It falls back to issuing a brand-new token when there is
     * no refresh token or the refresh request fails.
     *
     * @param user the username to impersonate; must not be null
     * @return the raw access token string
     * @throws IllegalArgumentException if {@code user} is null
     * @throws RuntimeException         if no token can be obtained from Keycloak
     */
    public String generateAccessToken(final String user) {
        log.log(Level.FINE, "Initiating access token request");
        Assert.notNull(user, "User cannot be null");
        // Key by user. A single shared key would hand the first user's token to every caller.
        TokenResult tokenResult = Objects.requireNonNull(this.cacheManager.getCache(CACHE_NAME))
                .get(user, () -> getNewAccessToken(user));
        if (tokenResult == null) {
            throw new IllegalStateException("Access token is null");
        }
        if (isUsable(tokenResult)) {
            return tokenResult.accessToken;
        }
        log.log(Level.FINE, "Token has expired. Requesting a new token using refresh token");
        TokenResult updated = renewToken(user, tokenResult);
        Objects.requireNonNull(cacheManager.getCache(CACHE_NAME))
                .put(user, updated);
        log.log(Level.FINE, "Token refreshed and stored in cache for later use");
        return updated.accessToken;
    }

    /**
     * Returns whether the token's access JWT still decodes and passes the decoder's
     * (timestamp-only) validation.
     */
    private boolean isUsable(TokenResult tokenResult) {
        try {
            jwtDecoder.decode(tokenResult.accessToken);
            return true;
        } catch (Exception e) {
            log.log(Level.WARNING, "Token invalid " + e.getMessage());
            return false;
        }
    }

    /**
     * Renews an unusable token.
     *
     * <p>Uses the token's refresh token when one is present. Otherwise, or if the refresh request
     * fails, it re-runs the full impersonation flow for {@code user}.
     */
    private TokenResult renewToken(String user, TokenResult expired) {
        String refresh = expired.refreshToken;
        if (refresh == null || refresh.isEmpty()) {
            log.log(Level.FINE, "No refresh token available. Requesting a new token");
            return getNewAccessToken(user);
        }
        try {
            return refreshToken(expired);
        } catch (RuntimeException e) {
            // Refresh tokens expire as well. Without this fallback, the stale cache entry would make
            // every later call for this user fail.
            log.log(Level.WARNING, "Unable to refresh token, requesting a new one: " + e.getMessage(), e);
            return getNewAccessToken(user);
        }
    }

    /**
     * Exchanges the token's refresh token for a new token pair.
     *
     * <p>Uses the same endpoint that issued the original token.
     *
     * @throws IllegalStateException if Keycloak responds without a body
     */
    private TokenResult refreshToken(TokenResult tokenResult) {
        MultiValueMap<String, String> form = new LinkedMultiValueMap<>();
        form.add("refresh_token", tokenResult.refreshToken);
        form.add("grant_type", "refresh_token");
        form.add("client_id", keycloakConfiguration.getClientId());
        form.add("client_secret", keycloakConfiguration.getClientSecret());
        return postForToken(form, "Updated token is null");
    }

    /**
     * Returns the token endpoint URL.
     *
     * <p>This is the configured token exchange URL when it is non-empty. Otherwise it is the
     * realm's standard OIDC token endpoint.
     */
    private String tokenExchangeUrl() {
        String configured = this.keycloakConfiguration.getTokenExchangeUrl();
        if (configured != null && !configured.isEmpty()) {
            return configured;
        }
        return realmUrl(this.keycloakConfiguration) + TOKEN_ENDPOINT_PATH;
    }

    /**
     * Performs the full impersonation flow for {@code user}.
     *
     * <p>First logs in with the configured service-account credentials (password grant). Then
     * exchanges the resulting access token for one whose requested subject is {@code user}.
     *
     * @throws IllegalStateException if either request returns no body
     */
    private TokenResult getNewAccessToken(String user) {
        log.info("Getting new access token on initial start...");
        String clientId = keycloakConfiguration.getClientId();
        String clientSecret = keycloakConfiguration.getClientSecret();
        String failureMessage = format("Unable to get token for user %s", user);

        MultiValueMap<String, String> login = new LinkedMultiValueMap<>();
        login.add("grant_type", "password");
        login.add("username", keycloakConfiguration.getUsername());
        login.add("password", keycloakConfiguration.getPassword());
        login.add("client_id", clientId);
        login.add("client_secret", clientSecret);
        TokenResult serviceToken = postForToken(login, failureMessage);

        MultiValueMap<String, String> exchange = new LinkedMultiValueMap<>();
        exchange.add("grant_type", "urn:ietf:params:oauth:grant-type:token-exchange");
        exchange.add("client_id", clientId);
        exchange.add("client_secret", clientSecret);
        exchange.add("subject_token", serviceToken.getAccessToken());
        exchange.add("requested_subject", user);
        exchange.add("subject_token_type", "urn:ietf:params:oauth:token-type:access_token");
        return postForToken(exchange, failureMessage);
    }

    /**
     * POSTs a form-encoded request to the token endpoint and returns the parsed response body.
     *
     * @param form         the form parameters to send
     * @param errorMessage message for the exception thrown when the response has no body
     * @throws IllegalStateException if the response has no body
     */
    private TokenResult postForToken(MultiValueMap<String, String> form, String errorMessage) {
        HttpHeaders headers = new HttpHeaders();
        headers.setContentType(MediaType.APPLICATION_FORM_URLENCODED);
        HttpEntity<MultiValueMap<String, String>> entity = new HttpEntity<>(form, headers);
        TokenResult token = restTemplate.exchange(tokenExchangeUrl(),
                HttpMethod.POST, entity, TokenResult.class).getBody();
        if (token == null) {
            throw new IllegalStateException(errorMessage);
        }
        return token;
    }

    /** Token endpoint response: the access token and, if issued, its refresh token. */
    @Data
    public static class TokenResult {
        @JsonProperty("access_token")
        private String accessToken;
        @JsonProperty("refresh_token")
        private String refreshToken;
    }

    /**
     * Exposes the timestamp validator installed on the decoder.
     *
     * <p>This accessor is package-private. Presumably it exists so tests can tune clock/skew;
     * confirm with callers.
     */
    JwtTimestampValidator jwtTimestampValidator() {
        return this.jwtTimestampValidator;
    }
}
