/*
 * Copyright 2022 Nedra Team
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package digital.nedra.commons.starter.keycloak.session.config.support;

import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Optional;
import java.util.stream.Stream;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.security.core.GrantedAuthority;
import org.springframework.security.oauth2.client.oidc.userinfo.OidcUserRequest;
import org.springframework.security.oauth2.client.oidc.userinfo.OidcUserService;
import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
import org.springframework.security.oauth2.core.OAuth2Error;
import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
import org.springframework.security.oauth2.core.oidc.user.DefaultOidcUser;
import org.springframework.security.oauth2.core.oidc.user.OidcUser;
import org.springframework.security.oauth2.jwt.Jwt;
import org.springframework.security.oauth2.jwt.JwtDecoder;
import org.springframework.security.oauth2.jwt.JwtException;
import org.springframework.stereotype.Component;

/**
 * {@link OidcUserService} that enriches the loaded {@link OidcUser} with authorities
 * read from the {@code realm_access} and {@code resource_access} claims of the
 * access token.
 */
@RequiredArgsConstructor
@Slf4j
@Component
public class KeycloakOauth2UserService extends OidcUserService {

  /** Name of the access-token claim handed to the realm-authorities extractor. */
  public static final String REALM_ACCESS = "realm_access";
  /** Name of the access-token claim handed to the client-authorities extractor. */
  public static final String RESOURCE_ACCESS = "resource_access";
  private static final OAuth2Error INVALID_REQUEST =
      new OAuth2Error(OAuth2ErrorCodes.INVALID_REQUEST);

  private final JwtDecoder jwtDecoder;
  private final KeycloakAuthoritiesExtractor keycloakAuthoritiesExtractor;
  // Non-final, so it is not part of the Lombok-generated constructor; the value is
  // injected from configuration via @Value.
  @Value("${spring.security.oauth2.client.provider.sso.user-name-attribute}")
  private String nameAttribute;

  /**
   * Augments {@link OidcUserService#loadUser(OidcUserRequest)} to add authorities
   * provided by Keycloak.
   * Needed because {@link OidcUserService#loadUser(OidcUserRequest)} (currently)
   * does not provide a hook for adding custom authorities from a
   * {@link OidcUserRequest}.
   *
   * @param userRequest the user request carrying the access token and client registration
   * @return a new {@link OidcUser} holding the authorities of the user loaded by the
   *     parent service plus the authorities extracted from the access token
   * @throws OAuth2AuthenticationException if the parent service fails or the access
   *     token cannot be decoded
   */
  @Override
  public OidcUser loadUser(OidcUserRequest userRequest) throws OAuth2AuthenticationException {
    OidcUser user = super.loadUser(userRequest);
    Collection<GrantedAuthority> keycloakAuthorities = extractKeycloakAuthorities(userRequest);
    return withAuthorities(user, keycloakAuthorities);
  }

  /**
   * Creates a new {@link OidcUser} with additional given {@code authorities}.
   *
   * <p>The authorities already present on {@code user} are kept and the given ones are
   * appended, so nothing granted by the parent service is lost.
   *
   * @param user the user returned by the parent service
   * @param authorities the extra authorities to grant
   * @return a {@link DefaultOidcUser} built from the merged authorities, the ID token and
   *     user info of {@code user}, and the configured name attribute
   */
  private OidcUser withAuthorities(OidcUser user,
                                   Collection<? extends GrantedAuthority> authorities) {
    List<GrantedAuthority> merged =
        Stream.<GrantedAuthority>concat(user.getAuthorities().stream(), authorities.stream())
            .toList();
    return new DefaultOidcUser(
        merged,
        user.getIdToken(),
        user.getUserInfo(),
        nameAttribute
    );
  }

  /**
   * Extracts {@link GrantedAuthority GrantedAuthorities} from the AccessToken in
   * the {@link OidcUserRequest}.
   *
   * <p>Client authorities come from the {@value #RESOURCE_ACCESS} claim (resolved for the
   * client id of the request's client registration), realm authorities from the
   * {@value #REALM_ACCESS} claim. A claim that is absent, or for which the extractor
   * returns {@code null}, contributes no authorities.
   *
   * @param userRequest the user request carrying the access token
   * @return client authorities followed by realm authorities
   * @throws OAuth2AuthenticationException if the access token cannot be decoded
   */
  private Collection<GrantedAuthority> extractKeycloakAuthorities(OidcUserRequest userRequest) {
    Jwt token = parseJwt(userRequest.getAccessToken().getTokenValue());

    String clientId = userRequest.getClientRegistration().getClientId();
    // Parameterized logging already skips formatting when TRACE is disabled,
    // so no explicit level guard is needed.
    log.trace("Client name: {}", clientId);

    List<GrantedAuthority> clientAuthorities =
        Optional.ofNullable(token.getClaimAsMap(RESOURCE_ACCESS))
            .map(m -> keycloakAuthoritiesExtractor.extractClientAuthorities(clientId, m))
            .orElseGet(Collections::emptyList);

    List<GrantedAuthority> realmAuthorities =
        Optional.ofNullable(token.getClaimAsMap(REALM_ACCESS))
            .map(keycloakAuthoritiesExtractor::extractRealmAuthorities)
            .orElseGet(Collections::emptyList);

    return Stream.concat(clientAuthorities.stream(), realmAuthorities.stream()).toList();
  }

  /**
   * Decodes the raw access token value into a {@link Jwt}.
   *
   * @param accessTokenValue the serialized access token
   * @return the decoded token
   * @throws OAuth2AuthenticationException with an {@code invalid_request} error, wrapping
   *     the original {@link JwtException}, if decoding fails
   */
  private Jwt parseJwt(String accessTokenValue) {
    try {
      // Token is already verified by spring security infrastructure.
      return jwtDecoder.decode(accessTokenValue);
    } catch (JwtException e) {
      throw new OAuth2AuthenticationException(INVALID_REQUEST, e);
    }
  }
}
