/*
 * Decompiled with CFR 0.152.
 */
package io.airlift.security.jwks;

import com.google.common.io.Closer;
import io.airlift.concurrent.Threads;
import io.airlift.http.client.HttpClient;
import io.airlift.http.client.Request;
import io.airlift.http.client.ResponseHandler;
import io.airlift.http.client.StringResponseHandler;
import io.airlift.log.Logger;
import io.airlift.security.jwks.JwksDecoder;
import io.airlift.units.Duration;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.net.URI;
import java.security.PublicKey;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.ScheduledFuture;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference;
import javax.annotation.PostConstruct;
import javax.annotation.PreDestroy;

/**
 * Holds the set of public keys published by a JWKS endpoint and optionally
 * refreshes them on a background thread.
 *
 * <p>The keys are fetched once in the constructor, so constructing the service
 * fails if the endpoint cannot be read or decoded. Periodic refreshing only
 * begins after {@link #start()} is called and ends with {@link #stop()}.
 *
 * <p>The current key map is held in an {@link AtomicReference} and replaced
 * wholesale on each refresh; {@code start()} and {@code stop()} are
 * synchronized on this instance.
 */
public final class JwksService {
    private static final Logger log = Logger.get(JwksService.class);

    private final URI address;
    private final HttpClient httpClient;
    private final Duration refreshDelay;
    private final AtomicReference<Map<String, PublicKey>> keys;

    // Non-null only while the background refresh is running; guarded by "this".
    private Closer closer;

    /**
     * Creates the service and performs the initial key fetch.
     *
     * @param address URI of the JWKS endpoint
     * @param httpClient client used to issue the GET request
     * @param refreshDelay initial delay and delay between background refreshes
     * @throws NullPointerException if any argument is null
     * @throws RuntimeException if the initial fetch or decode fails
     */
    public JwksService(URI address, HttpClient httpClient, Duration refreshDelay) {
        this.address = Objects.requireNonNull(address, "address is null");
        this.httpClient = Objects.requireNonNull(httpClient, "httpClient is null");
        this.refreshDelay = Objects.requireNonNull(refreshDelay, "refreshDelay is null");
        // address and httpClient must be assigned before this call: fetchKeys() reads both
        this.keys = new AtomicReference<>(fetchKeys());
    }

    /**
     * Starts the background refresh job. Does nothing if it is already running.
     *
     * <p>If the job cannot be scheduled, everything created so far is released
     * and the service stays stopped, so a later call can try again.
     */
    @PostConstruct
    public synchronized void start() {
        if (closer != null) {
            return;
        }

        // Collect resources locally and publish them only once everything is set up;
        // otherwise a failure would leave "closer" non-null and turn every later
        // start() into a silent no-op with nothing scheduled.
        Closer resources = Closer.create();
        try {
            ScheduledExecutorService executorService = Executors.newSingleThreadScheduledExecutor(Threads.daemonThreadsNamed("JWKS loader"));
            resources.register(executorService::shutdownNow);

            long delayMillis = refreshDelay.toMillis();
            ScheduledFuture<?> refreshJob = executorService.scheduleWithFixedDelay(this::refreshKeysLoggingFailures, delayMillis, delayMillis, TimeUnit.MILLISECONDS);
            resources.register(() -> refreshJob.cancel(true));
        }
        catch (RuntimeException | Error e) {
            try {
                resources.close();
            }
            catch (IOException | RuntimeException closeFailure) {
                e.addSuppressed(closeFailure);
            }
            throw e;
        }
        closer = resources;
    }

    /**
     * Stops the background refresh job. Does nothing if it is not running.
     *
     * @throws UncheckedIOException if releasing the resources fails with an {@link IOException}
     */
    @PreDestroy
    public synchronized void stop() {
        if (closer == null) {
            return;
        }
        try {
            closer.close();
        }
        catch (IOException e) {
            throw new UncheckedIOException("Error stopping JWKS service", e);
        }
        finally {
            // Always reset so the service can be started again even if close() failed.
            closer = null;
        }
    }

    /**
     * Returns the most recently fetched keys, indexed by key id.
     *
     * @return the current key map
     */
    public Map<String, PublicKey> getKeys() {
        return keys.get();
    }

    /**
     * Looks up a single key in the most recently fetched key map.
     *
     * @param keyId id of the key to look up
     * @return the key, or empty if the current map has no mapping for {@code keyId}
     */
    public Optional<PublicKey> getKey(String keyId) {
        return Optional.ofNullable(keys.get().get(keyId));
    }

    /**
     * Fetches the keys from the endpoint and replaces the current key map.
     * On failure the previous map is left in place.
     *
     * @throws RuntimeException if the request fails, the status is not 200,
     *         or the response cannot be decoded
     */
    public void refreshKeys() {
        keys.set(fetchKeys());
    }

    /**
     * Body of the scheduled job: refreshes the keys and logs any failure
     * instead of propagating it, because an exception escaping a periodic
     * task suppresses all of its subsequent executions.
     */
    private void refreshKeysLoggingFailures() {
        try {
            refreshKeys();
        }
        catch (Throwable e) {
            log.error(e, "Error fetching JWKS keys");
        }
    }

    /**
     * Issues a GET to the configured address and decodes the body as a JWKS document.
     *
     * @return the decoded keys
     * @throws RuntimeException if the request fails, the status is not 200,
     *         or the body cannot be decoded
     */
    private Map<String, PublicKey> fetchKeys() {
        Request request = Request.Builder.prepareGet().setUri(address).build();

        StringResponseHandler.StringResponse response;
        try {
            response = httpClient.execute(request, StringResponseHandler.createStringResponseHandler());
        }
        catch (RuntimeException e) {
            throw new RuntimeException("Error reading JWKS keys from " + address, e);
        }

        int statusCode = response.getStatusCode();
        if (statusCode != 200) {
            throw new RuntimeException("Unexpected response code " + statusCode + " from JWKS service at " + address);
        }

        try {
            return JwksDecoder.decodeKeys(response.getBody());
        }
        catch (RuntimeException e) {
            throw new RuntimeException("Unable to decode JWKS response from " + address, e);
        }
    }
}

