package org.zowe.apiml.gateway.loadbalancer;

import io.jsonwebtoken.Claims;
import io.jsonwebtoken.Clock;
import io.jsonwebtoken.Jwts;
import io.jsonwebtoken.MalformedJwtException;
import java.nio.charset.StandardCharsets;
import java.time.LocalDateTime;
import java.util.Base64;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.atomic.AtomicReference;
import java.util.stream.Stream;
import lombok.Generated;
import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.cloud.client.ServiceInstance;
import org.springframework.cloud.client.loadbalancer.Request;
import org.springframework.cloud.client.loadbalancer.RequestDataContext;
import org.springframework.cloud.client.loadbalancer.reactive.ReactiveLoadBalancer;
import org.springframework.cloud.loadbalancer.core.SameInstancePreferenceServiceInstanceListSupplier;
import org.springframework.cloud.loadbalancer.core.ServiceInstanceListSupplier;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpStatus;
import org.springframework.web.server.ResponseStatusException;
import org.zowe.apiml.gateway.caching.LoadBalancerCache;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;

/* loaded from: input_file:org/zowe/apiml/gateway/loadbalancer/DeterministicLoadBalancer.class */
/**
 * Service instance list supplier that keeps an authenticated user on one service instance.
 *
 * <p>When the instances of a service declare metadata {@code apiml.lb.type=authentication}, the
 * user is identified by the {@code sub} claim of the JWT carried in the
 * {@code apimlAuthenticationToken} cookie. The chosen instance is remembered in the
 * {@link LoadBalancerCache}. Cached choices older than {@code expirationTime} hours are deleted
 * and a new instance is chosen.
 *
 * <p>When that filtering does not apply (no user, no instances, or a different lb type), the
 * optional {@code X-InstanceId} request header may be used to select a specific instance.
 */
public class DeterministicLoadBalancer extends SameInstancePreferenceServiceInstanceListSupplier {

    /** Cookie that carries the API ML authentication JWT. */
    private static final String AUTHENTICATION_COOKIE = "apimlAuthenticationToken";
    /** Request header that can request a specific service instance. */
    private static final String INSTANCE_ID_HEADER = "X-InstanceId";
    /** Instance metadata key that selects the load-balancing type. */
    private static final String LB_TYPE_METADATA_KEY = "apiml.lb.type";
    /** Metadata value that enables per-user (authentication based) instance affinity. */
    private static final String LB_TYPE_AUTHENTICATION = "authentication";
    private static final String INSTANCE_NOT_FOUND_MESSAGE = "Service instance not found for the provided instance ID";

    @Generated
    private static final Logger log = LoggerFactory.getLogger(DeterministicLoadBalancer.class);
    /** Base64 of an unsecured ("alg":"none") JWT header, used to replace the token's real header. */
    private static final String HEADER_NONE_SIGNATURE = Base64.getEncoder().encodeToString("{\"typ\":\"JWT\",\"alg\":\"none\"}".getBytes(StandardCharsets.UTF_8));

    private final LoadBalancerCache cache;
    private final Clock clock;
    /** Maximum age, in hours, of a cached user-to-instance assignment. */
    private final int expirationTime;

    /**
     * Creates the supplier.
     *
     * @param serviceInstanceListSupplier delegate supplying the candidate instances
     * @param factory                     load balancer factory passed to the superclass
     * @param loadBalancerCache           cache storing the user-to-instance assignments
     * @param clock                       clock used when parsing the JWT
     * @param expirationTime              maximum age, in hours, of a cached assignment
     */
    public DeterministicLoadBalancer(ServiceInstanceListSupplier serviceInstanceListSupplier, ReactiveLoadBalancer.Factory<ServiceInstance> factory, LoadBalancerCache loadBalancerCache, Clock clock, int expirationTime) {
        super(serviceInstanceListSupplier, factory);
        this.cache = loadBalancerCache;
        this.clock = clock;
        this.expirationTime = expirationTime;
        log.debug("DeterministicLoadBalancer instantiated");
    }

    /**
     * Returns the instances to route to for the given request.
     *
     * <p>For each instance list emitted by the delegate, the user's subject is extracted from the
     * request. Any cached assignment is then looked up, and the list is filtered via
     * {@link #filterInstances}. Cache lookup errors are treated as "no cached record".
     *
     * @param request load balancer request; its context is inspected when it is a {@link RequestDataContext}
     * @return filtered instance lists, or an empty flux when the service id is unknown
     */
    @Override
    public Flux<List<ServiceInstance>> get(Request request) {
        String serviceId = getServiceId();
        if (serviceId == null) {
            return Flux.empty();
        }
        // Carries the subject from the lookup stage to the filtering stage of the pipeline.
        AtomicReference<String> subject = new AtomicReference<>();
        return this.delegate.get(request).flatMap(instances -> getSub(request.getContext())
                .switchIfEmpty(Mono.just(""))
                .flatMap(sub -> {
                    if (sub == null || sub.isEmpty()) {
                        log.debug("No authentication present on request, not filtering the service: {}", serviceId);
                        return Mono.empty();
                    }
                    subject.set(sub);
                    // A failing cache must not break routing; fall back to "no cached record".
                    return this.cache.retrieve(sub, serviceId).onErrorResume(e -> Mono.empty());
                })
                .switchIfEmpty(Mono.just(LoadBalancerCache.LoadBalancerCacheRecord.NONE))
                .flatMapMany(cacheRecord -> filterInstances(subject.get(), serviceId, cacheRecord, instances, request.getContext())))
            .doOnError(e -> log.debug("Error in determining service instances", e));
    }

    /**
     * Tells whether a cached assignment created at the given time has expired.
     *
     * @param creationTime creation time of the cached record
     * @return {@code true} when the record is older than {@code expirationTime} hours
     */
    private boolean isTooOld(LocalDateTime creationTime) {
        return LocalDateTime.now().minusHours(this.expirationTime).isAfter(creationTime);
    }

    /**
     * Extracts the user subject from the request context.
     *
     * @param requestContext the load balancer request context
     * @return a Mono of the subject; never empty and never holding {@code null}. It holds
     *         {@code ""} when the context is not a {@link RequestDataContext}, the cookie is
     *         missing, or no subject can be read from the token.
     */
    private Mono<String> getSub(Object requestContext) {
        if (!(requestContext instanceof RequestDataContext dataContext)) {
            return Mono.just("");
        }
        return Mono.just(extractSubFromToken(getAuthenticationCookie(dataContext)));
    }

    /**
     * Reads the first value of the authentication cookie.
     *
     * @param dataContext request context
     * @return the cookie value, or {@code ""} when the request, cookies or value are absent
     */
    private String getAuthenticationCookie(RequestDataContext dataContext) {
        var clientRequest = dataContext.getClientRequest();
        if (clientRequest == null || clientRequest.getCookies() == null) {
            return "";
        }
        var values = clientRequest.getCookies().get(AUTHENTICATION_COOKIE);
        if (values == null || values.isEmpty() || values.get(0) == null) {
            return "";
        }
        return values.get(0);
    }

    /**
     * Chooses the instances for a request.
     *
     * <p>When filtering does not apply (see {@link #shouldIgnore}), the {@code X-InstanceId}
     * header, if present, narrows the list. Otherwise behaviour depends on the cached record:
     * <ul>
     *   <li>An expired assignment is deleted and a new instance is chosen.</li>
     *   <li>A valid assignment is reused.</li>
     *   <li>Without an assignment, a new instance is chosen.</li>
     * </ul>
     *
     * @param sub           user subject, possibly {@code null} or empty
     * @param serviceId     id of the target service
     * @param cacheRecord   cached assignment, or {@code NONE}
     * @param instances     candidate instances
     * @param requestContext load balancer request context
     * @return the selected instance list, or a {@link ResponseStatusException} (404) error when
     *         the requested {@code X-InstanceId} does not exist
     */
    private Flux<List<ServiceInstance>> filterInstances(String sub, String serviceId, LoadBalancerCache.LoadBalancerCacheRecord cacheRecord, List<ServiceInstance> instances, Object requestContext) {
        if (shouldIgnore(instances, sub)) {
            try {
                return Flux.just(checkInstanceIdHeader(getInstanceId(requestContext), instances));
            } catch (ResponseStatusException e) {
                return Flux.error(e);
            }
        }
        boolean hasCachedInstance = StringUtils.isNotBlank(cacheRecord.getInstanceId());
        if (hasCachedInstance && isTooOld(cacheRecord.getCreationTime())) {
            return this.cache.delete(sub, serviceId).thenMany(chooseOne(sub, instances));
        }
        if (hasCachedInstance) {
            return chooseOne(cacheRecord.getInstanceId(), sub, instances);
        }
        return chooseOne(sub, instances);
    }

    /**
     * Returns the instance id requested via header, if the context carries request data.
     *
     * @param requestContext load balancer request context
     * @return the header value, or {@code null}
     */
    private String getInstanceId(Object requestContext) {
        if (requestContext instanceof RequestDataContext dataContext) {
            return getInstanceFromHeader(dataContext);
        }
        return null;
    }

    /**
     * Reads the {@code X-InstanceId} header from the client request.
     *
     * @param requestDataContext request context, may be {@code null}
     * @return the first header value, or {@code null} when any part of the chain is absent
     */
    private String getInstanceFromHeader(RequestDataContext requestDataContext) {
        if (requestDataContext == null || requestDataContext.getClientRequest() == null) {
            return null;
        }
        HttpHeaders headers = requestDataContext.getClientRequest().getHeaders();
        return headers == null ? null : headers.getFirst(INSTANCE_ID_HEADER);
    }

    /**
     * Narrows the instances to the one with the requested id.
     *
     * @param instanceId requested instance id, or {@code null} for no restriction
     * @param instances  candidate instances
     * @return the unchanged list when {@code instanceId} is {@code null}, otherwise the matches
     * @throws ResponseStatusException with 404 when no instance matches
     */
    private List<ServiceInstance> checkInstanceIdHeader(String instanceId, List<ServiceInstance> instances) {
        if (instanceId == null) {
            return instances;
        }
        List<ServiceInstance> matching = instances.stream()
            .filter(instance -> instanceId.equals(instance.getInstanceId()))
            .toList();
        if (matching.isEmpty()) {
            throw new ResponseStatusException(HttpStatus.NOT_FOUND, INSTANCE_NOT_FOUND_MESSAGE);
        }
        return matching;
    }

    /**
     * Picks one instance and stores it as the user's assignment.
     *
     * <p>The instance with {@code instanceId} is preferred. When it is {@code null} or not
     * present, the first instance of the list is used.
     *
     * @param instanceId preferred instance id, may be {@code null}
     * @param sub        user subject used as cache key
     * @param instances  non-empty candidate list (guaranteed by {@link #shouldIgnore})
     * @return a single-element list with the chosen instance, emitted after the cache store completes
     */
    private Flux<List<ServiceInstance>> chooseOne(String instanceId, String sub, List<ServiceInstance> instances) {
        ServiceInstance chosen = instances.stream()
            .filter(instance -> instanceId == null || instanceId.equals(instance.getInstanceId()))
            .findFirst()
            .orElseGet(() -> instances.get(0));
        return this.cache.store(sub, chosen.getServiceId(), new LoadBalancerCache.LoadBalancerCacheRecord(chosen.getInstanceId()))
            .thenMany(Flux.just(Collections.singletonList(chosen)));
    }

    /** Picks the first instance for the user and stores the assignment. */
    private Flux<List<ServiceInstance>> chooseOne(String sub, List<ServiceInstance> instances) {
        return chooseOne(null, sub, instances);
    }

    /**
     * Tells whether per-user filtering must be skipped.
     *
     * @param list candidate instances
     * @param str  user subject
     * @return {@code true} when the subject is empty, the list is empty, or the first instance
     *         is not configured for authentication based load balancing
     */
    boolean shouldIgnore(List<ServiceInstance> list, String str) {
        return StringUtils.isEmpty(str) || list.isEmpty() || !lbTypeIsAuthentication(list.get(0));
    }

    /** Returns whether the instance metadata sets {@code apiml.lb.type} to {@code authentication}. */
    private boolean lbTypeIsAuthentication(ServiceInstance serviceInstance) {
        Map<String, String> metadata = serviceInstance.getMetadata();
        return metadata != null && LB_TYPE_AUTHENTICATION.equals(metadata.get(LB_TYPE_METADATA_KEY));
    }

    /**
     * Rewrites a JWT into an unsecured one: the header is replaced by an {@code "alg":"none"}
     * header and the signature is dropped.
     *
     * @param token compact JWT, may be {@code null}
     * @return the unsecured token, or {@code null} for {@code null} input
     * @throws MalformedJwtException when the token does not contain two dots
     */
    private String removeJwtSign(String token) {
        if (token == null) {
            return null;
        }
        int firstDot = token.indexOf('.');
        int lastDot = token.lastIndexOf('.');
        if (firstDot < 0 || firstDot >= lastDot) {
            throw new MalformedJwtException("Invalid JWT format");
        }
        // Keeps ".<payload>." so the result has an empty signature section.
        return HEADER_NONE_SIGNATURE + token.substring(firstDot, lastDot + 1);
    }

    /**
     * Parses the token's claims without verifying its signature.
     *
     * <p>NOTE(review): the subject is used only to pick an instance. Presumably authentication
     * is verified elsewhere; confirm.
     *
     * @param token compact JWT
     * @return the claims, or {@code null} when the token cannot be parsed
     */
    private Claims getJwtClaims(String token) {
        try {
            return Jwts.parser().unsecured().clock(this.clock).build().parseUnsecuredClaims(removeJwtSign(token)).getPayload();
        } catch (RuntimeException e) {
            // Do not log the token itself: it is a credential.
            log.debug("Exception when trying to parse the JWT token: {}", e.getMessage());
            return null;
        }
    }

    /**
     * Returns the {@code sub} claim of the token.
     *
     * @param token compact JWT, may be {@code null} or empty
     * @return the subject, or {@code ""} when the token is empty, unparsable or has no subject.
     *         Never returns {@code null}, so callers may pass the result to {@code Mono.just}.
     */
    private String extractSubFromToken(String token) {
        if (token == null || token.isEmpty()) {
            return "";
        }
        Claims claims = getJwtClaims(token);
        if (claims == null) {
            return "";
        }
        return StringUtils.defaultString(claims.getSubject());
    }
}
