package org.springframework.boot.autoconfigure.web;

import java.io.IOException;
import java.io.PrintWriter;
import java.io.StringWriter;
import java.io.UnsupportedEncodingException;
import java.lang.reflect.Method;
import java.text.SimpleDateFormat;
import java.util.Collections;
import java.util.Date;
import java.util.Enumeration;
import java.util.LinkedHashMap;
import java.util.Locale;
import java.util.Map;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

import javax.servlet.ServletException;
import javax.servlet.http.HttpServletRequest;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.springframework.boot.context.properties.ConfigurationProperties;
import org.springframework.context.MessageSource;
import org.springframework.context.i18n.LocaleContextHolder;
import org.springframework.core.annotation.AnnotationUtils;
import org.springframework.exception.AbstractException;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpStatus;
import org.springframework.util.Assert;
import org.springframework.util.LinkedCaseInsensitiveMap;
import org.springframework.util.ReflectionUtils;
import org.springframework.util.StringUtils;
import org.springframework.util.SystemPropertyUtils;
import org.springframework.validation.BindingResult;
import org.springframework.validation.FieldError;
import org.springframework.web.bind.annotation.ResponseStatus;
import org.springframework.web.context.request.RequestAttributes;
import org.springframework.web.context.request.ServletRequestAttributes;
import org.springframework.web.servlet.DispatcherServlet;
import org.springframework.web.util.ContentCachingRequestWrapper;
import org.springframework.web.util.ContentCachingResponseWrapper;
import org.springframework.web.util.WebUtils;

/**
 * Extension of Spring Boot's {@code DefaultErrorAttributes} that adds further error-model entries
 * (timestamp, exception class, resolved status, request URI, localized description, field errors)
 * and logs a dump of the failing request.
 * <p>
 * When the request carries an exception, the status starts from the servlet error status attribute
 * and is then overridden, in this order, by {@link AbstractException#getStatus()}, by a
 * {@link ResponseStatus} annotation on the exception class, and by the entry of the
 * {@link #getStatus() status} map keyed by the exception's fully qualified class name. The resolved
 * status is written back to the request. Descriptions are looked up in the supplied
 * {@link MessageSource}.
 */
@ConfigurationProperties(prefix = "error")
public class DefaultErrorAttributesCustom extends org.springframework.boot.autoconfigure.web.DefaultErrorAttributes {
  protected final Log logger = LogFactory.getLog(getClass());

  /**
   * Finds a password-like key in a captured request body: group 1 is the key, group 2 the separator
   * characters and group 3 the leading word characters of the value.
   * NOTE(review): value characters after the first non-word character are not part of group 3 and
   * therefore stay unmasked.
   */
  public static final Pattern PATTERN = Pattern.compile("(password[^=:\"\\s]*)([=:\"\\s]*)(\\w*)", Pattern.CASE_INSENSITIVE | Pattern.DOTALL | Pattern.MULTILINE);
  /** Line separator used in the request dump; falls back to CRLF if {@code line.separator} is unset. */
  public static final String LINE_SEPARATOR = SystemPropertyUtils.resolvePlaceholders("${line.separator:\r\n}");

  /** Replacement for sensitive values in the logged request dump. */
  private static final String MASK = "<masked>";
  /** Parameter names containing this text (ignoring case) have their values masked in the dump. */
  private static final String SENSITIVE_NAME = "password";
  /** Separator that precedes an appended cause message in nested-exception messages. */
  private static final String NESTED_EXCEPTION_MARKER = "; nested exception is ";
  /** ISO-8601 timestamp with milliseconds and the local UTC offset ("Z" when the offset is zero). */
  private static final String LOG_TIMESTAMP_PATTERN = "yyyy-MM-dd'T'HH:mm:ss.SSSXXX";
  /** Line breaks in URI/query strings; replaced by spaces so a request cannot inject extra log lines. */
  private static final Pattern LINE_BREAKS = Pattern.compile("\\r\\n|\\r|\\n");

  /** Status overrides keyed by fully qualified exception class name (keys compared ignoring case). */
  private Map<String, Number> status = new LinkedCaseInsensitiveMap<Number>();

  public static final String DEFAULT_ERROR = "error";
  public static final String DEFAULT_ERROR_TIMESTAMP = "error_timestamp";
  public static final String DEFAULT_ERROR_MESSAGE = "error_message";
  public static final String DEFAULT_ERROR_DESCRIPTION = "error_description";
  public static final String DEFAULT_ERROR_TRACE = "error_trace";
  public static final String DEFAULT_ERROR_STATUS = "error_status";
  public static final String DEFAULT_ERROR_URI = "error_uri";
  public static final String DEFAULT_ERROR_THROWABLE = "error_throwable";
  public static final String DEFAULT_ERROR_EXCEPTION = "error_exception";
  public static final String DEFAULT_ERRORS = "errors";
  public static final String DEFAULT_ERRORS_COUNT = "errors_count";

  // Keys written into the error model.
  // NOTE(review): no setters are declared for these fields, so they always hold the DEFAULT_* values.
  private String error = DEFAULT_ERROR;
  private String error_timestamp = DEFAULT_ERROR_TIMESTAMP;
  private String error_message = DEFAULT_ERROR_MESSAGE;
  private String error_description = DEFAULT_ERROR_DESCRIPTION;
  private String error_trace = DEFAULT_ERROR_TRACE;
  private String error_status = DEFAULT_ERROR_STATUS;
  private String error_uri = DEFAULT_ERROR_URI;
  private String error_throwable = DEFAULT_ERROR_THROWABLE;
  private String error_exception = DEFAULT_ERROR_EXCEPTION;
  private String errors = DEFAULT_ERRORS;
  private String errors_count = DEFAULT_ERRORS_COUNT;

  private final MessageSource messageSource;

  /**
   * @param messageSource resolves error descriptions and field-error messages; must not be null
   */
  public DefaultErrorAttributesCustom(MessageSource messageSource) {
    Assert.notNull(messageSource, "'messageSource' must not be null");
    this.messageSource = messageSource;
  }

  /**
   * Returns the parent's error attributes enriched with this class's entries.
   * <p>
   * Without an exception only the timestamp and a description looked up by the numeric status code
   * are added. With an exception the request is logged, the status is resolved (see class docs) and
   * written back to the request, and exception class, status, message, code, URI, optional stack
   * trace, field errors and a localized description are added.
   *
   * @param requestAttributes the current request
   * @param includeStackTrace whether to add the stack trace (also controls the trace in the log dump)
   * @return the mutable attribute map produced by the parent, with the extra entries added
   */
  @Override
  public Map<String, Object> getErrorAttributes(RequestAttributes requestAttributes, boolean includeStackTrace) {
    Map<String, Object> errorAttributes = super.getErrorAttributes(requestAttributes, includeStackTrace);
    errorAttributes.put(this.error_timestamp, new Date());

    Throwable exception = getError(requestAttributes);
    int statusCode = getStatus(requestAttributes);
    if (exception == null) {
      // No exception available: describe the bare status code (message key e.g. "404").
      errorAttributes.put(this.error_description, messageSource.getMessage(String.valueOf(statusCode), null, "", getLocale()));
      return errorAttributes;
    }

    print(exception, requestAttributes, includeStackTrace);
    statusCode = resolveStatus(exception, statusCode);
    // Publish the resolved status so later readers of the request attribute see the same value.
    if (requestAttributes instanceof ServletRequestAttributes) {
      ((ServletRequestAttributes) requestAttributes).getRequest().setAttribute(WebUtils.ERROR_STATUS_CODE_ATTRIBUTE, statusCode);
    }

    String code = getCode(exception);
    errorAttributes.put(this.error_exception, exception.getClass().getName());
    errorAttributes.put(this.error_status, statusCode);
    errorAttributes.put(this.error_message, exception.getMessage());
    errorAttributes.put(this.error, code);
    errorAttributes.put(this.error_uri, getPath(requestAttributes));
    if (includeStackTrace) {
      errorAttributes.put(this.error_trace, getStackTrace(exception));
    }

    BindingResult result = extractBindingResult(exception);
    if (result != null) {
      errorAttributes.put(this.errors, resolveFieldErrors(result));
      errorAttributes.put(this.errors_count, result.getErrorCount());
    }

    Object[] args = (exception instanceof AbstractException) ? ((AbstractException) exception).getArgs() : null;
    String defaultDescription = stripNestedMessage(exception.getLocalizedMessage());
    errorAttributes.put(this.error_description, messageSource.getMessage(code, args, defaultDescription, getLocale()));
    return errorAttributes;
  }

  /**
   * Resolves the status for {@code exception}. Each later source overrides the earlier ones:
   * {@code fallback}, {@link AbstractException#getStatus()}, a {@link ResponseStatus} annotation on
   * the exception class, and the configured {@link #getStatus() status} map.
   */
  private int resolveStatus(Throwable exception, int fallback) {
    int resolved = fallback;
    if (exception instanceof AbstractException) {
      resolved = ((AbstractException) exception).getStatus();
    }
    ResponseStatus responseStatus = AnnotationUtils.findAnnotation(exception.getClass(), ResponseStatus.class);
    if (responseStatus != null) {
      resolved = responseStatus.value().value();
    }
    if (this.status != null) {
      // get() rather than getOrDefault(): get() is the lookup that honours case-insensitive keys.
      Number configured = this.status.get(exception.getClass().getName());
      if (configured != null) {
        resolved = configured.intValue();
      }
    }
    return resolved;
  }

  /**
   * Maps each field error to {@code "objectName.field"} -> message looked up under that key, falling
   * back to the field error's default message.
   */
  private Map<String, String> resolveFieldErrors(BindingResult result) {
    Map<String, String> fieldErrors = new LinkedHashMap<String, String>();
    for (FieldError fieldError : result.getFieldErrors()) {
      String key = fieldError.getObjectName() + '.' + fieldError.getField();
      fieldErrors.put(key, messageSource.getMessage(key, fieldError.getArguments(), fieldError.getDefaultMessage(), getLocale()));
    }
    return fieldErrors;
  }

  /** Cuts {@code message} at the nested-exception marker, if present; null/blank input is returned as-is. */
  private static String stripNestedMessage(String message) {
    if (StringUtils.hasText(message)) {
      int markerIndex = message.indexOf(NESTED_EXCEPTION_MARKER);
      if (markerIndex >= 0) {
        return message.substring(0, markerIndex);
      }
    }
    return message;
  }

  /**
   * Returns the servlet error status code stored on the request, or 500 when it is absent, not
   * numeric, or the request is not a servlet request.
   */
  protected Integer getStatus(RequestAttributes requestAttributes) {
    if (requestAttributes instanceof ServletRequestAttributes) {
      Object statusCode = ((ServletRequestAttributes) requestAttributes).getRequest().getAttribute(WebUtils.ERROR_STATUS_CODE_ATTRIBUTE);
      if (statusCode instanceof Number) {
        return ((Number) statusCode).intValue();
      }
    }
    return HttpStatus.INTERNAL_SERVER_ERROR.value();
  }

  /**
   * Logs a dump of the failing servlet request: at TRACE (with the exception attached) when trace is
   * enabled, otherwise at ERROR. Does nothing for non-servlet requests or when neither level is on.
   */
  protected void print(Throwable exception, RequestAttributes requestAttributes, boolean includeStackTrace) {
    if (!(requestAttributes instanceof ServletRequestAttributes)) {
      return;
    }
    boolean traceEnabled = logger.isTraceEnabled();
    if (!traceEnabled && !logger.isErrorEnabled()) {
      return; // skip building the (potentially large) dump when it would not be logged
    }
    String message = getMessage(exception, ((ServletRequestAttributes) requestAttributes).getRequest(), includeStackTrace);
    if (traceEnabled) {
      logger.trace(message, exception);
    }
    else {
      logger.error(message);
    }
  }

  /**
   * Builds the multi-line request dump: time, server URL, remote address, exception message,
   * optional stack trace, then the raw request line, headers (credentials masked) and either the
   * cached body or the request parameters (password-like values masked).
   */
  private String getMessage(Throwable exception, HttpServletRequest request, boolean includeStackTrace) {
    String location = request.getScheme() + "://" + request.getServerName() + ':' + request.getServerPort()
        + getRequestUri(request) + getQueryString(request);

    // SimpleDateFormat is not thread-safe, hence a fresh instance per call.
    String now = new SimpleDateFormat(LOG_TIMESTAMP_PATTERN, Locale.ROOT).format(new Date());

    StringBuilder dump = new StringBuilder(LINE_SEPARATOR);
    appendSection(dump, "TIME: ", now);
    appendSection(dump, "SERVER: ", location);
    appendSection(dump, "REMOTE: ", request.getRemoteAddr() + ':' + request.getRemotePort());
    appendSection(dump, "EXCEPTION: ", exception.getMessage());
    if (includeStackTrace) {
      appendSection(dump, "TRACE: ", getStackTrace(exception));
    }

    dump.append("RAW: ").append(LINE_SEPARATOR);
    dump.append(request.getMethod()).append(' ').append(location).append(' ').append(request.getProtocol());
    dump.append(LINE_SEPARATOR);
    appendHeaders(dump, request);
    dump.append(LINE_SEPARATOR);
    appendBody(dump, request);
    return dump.toString();
  }

  /** Appends {@code label}, {@code value} and an empty line. */
  private static void appendSection(StringBuilder dump, String label, String value) {
    dump.append(label).append(value).append(LINE_SEPARATOR).append(LINE_SEPARATOR);
  }

  /** Appends one {@code Name: values} line per request header; credential headers are masked. */
  private static void appendHeaders(StringBuilder dump, HttpServletRequest request) {
    Enumeration<String> headerNames = request.getHeaderNames();
    if (headerNames == null) {
      return; // the Servlet API allows null when the container denies header access
    }
    for (String headerName : Collections.list(headerNames)) {
      dump.append(headerName).append(": ");
      if (isCredentialHeader(headerName)) {
        dump.append(MASK);
      }
      else if (HttpHeaders.COOKIE.equalsIgnoreCase(headerName)) { // RFC 6265: cookie pairs are joined with "; "
        dump.append(StringUtils.collectionToDelimitedString(Collections.list(request.getHeaders(headerName)), "; "));
      }
      else {
        dump.append(StringUtils.collectionToCommaDelimitedString(Collections.list(request.getHeaders(headerName))));
      }
      dump.append(LINE_SEPARATOR);
    }
  }

  /** Whether the header carries client credentials that must not be written to the log. */
  private static boolean isCredentialHeader(String headerName) {
    return HttpHeaders.AUTHORIZATION.equalsIgnoreCase(headerName) || HttpHeaders.PROXY_AUTHORIZATION.equalsIgnoreCase(headerName);
  }

  /**
   * Appends the cached request body (with {@link #PATTERN} matches masked) when the request is
   * wrapped in a {@link ContentCachingRequestWrapper}; otherwise appends the request parameters.
   */
  private static void appendBody(StringBuilder dump, HttpServletRequest request) {
    ContentCachingRequestWrapper cachingRequest = WebUtils.getNativeRequest(request, ContentCachingRequestWrapper.class);
    if (cachingRequest == null) {
      appendParameters(dump, request);
      return;
    }
    byte[] content = cachingRequest.getContentAsByteArray();
    if (content.length == 0) {
      return;
    }
    String encoding = cachingRequest.getCharacterEncoding();
    if (encoding == null) {
      encoding = WebUtils.DEFAULT_CHARACTER_ENCODING;
    }
    try {
      dump.append(maskBody(new String(content, encoding)));
    }
    catch (UnsupportedEncodingException ignored) {
      // Best effort: the body is informational only, so an unknown charset simply omits it.
    }
  }

  /** Appends {@code name=v1,v2&...}; values of password-like parameter names are masked. */
  private static void appendParameters(StringBuilder dump, HttpServletRequest request) {
    String separator = "";
    for (String parameterName : Collections.list(request.getParameterNames())) {
      dump.append(separator).append(parameterName).append('=');
      if (isSensitiveName(parameterName)) {
        dump.append(MASK);
      }
      else {
        dump.append(StringUtils.arrayToCommaDelimitedString(request.getParameterValues(parameterName)));
      }
      separator = "&";
    }
  }

  /** Whether {@code name} contains {@link #SENSITIVE_NAME}, ignoring case. */
  private static boolean isSensitiveName(String name) {
    return name != null && name.toLowerCase(Locale.ROOT).contains(SENSITIVE_NAME);
  }

  /** Replaces the value part (group 3) of every {@link #PATTERN} match in {@code body} with the mask. */
  private static String maskBody(String body) {
    // $1/$2 are substituted literally, so special characters in the matched key can neither break
    // the replacement nor be misread as regex syntax.
    return PATTERN.matcher(body).replaceAll("$1$2" + MASK);
  }

  /** Returns the current request locale, or the JVM default if none is available. */
  private Locale getLocale() {
    Locale locale = LocaleContextHolder.getLocale();
    return (locale != null ? locale : Locale.getDefault());
  }

  /**
   * Returns the exception to report. After the parent's lookup it tries, in order, the
   * DispatcherServlet exception attribute and the Spring Security 403 / last-exception attributes
   * (request scope before session scope). Non-{@link Throwable} attribute values are ignored.
   * Wrapping {@link ServletException}s are unwrapped to their cause.
   *
   * @see org.springframework.web.util.WebUtils#ERROR_EXCEPTION_ATTRIBUTE
   * @see org.springframework.security.web.WebAttributes#ACCESS_DENIED_403
   * @see org.springframework.security.web.WebAttributes#AUTHENTICATION_EXCEPTION
   */
  @Override
  public Throwable getError(RequestAttributes requestAttributes) {
    Throwable exception = super.getError(requestAttributes);
    if (exception == null) {
      exception = getThrowable(requestAttributes, DispatcherServlet.EXCEPTION_ATTRIBUTE, RequestAttributes.SCOPE_REQUEST);
    }
    if (exception == null) {
      exception = getThrowable(requestAttributes, "SPRING_SECURITY_403_EXCEPTION", RequestAttributes.SCOPE_REQUEST);
    }
    if (exception == null) {
      exception = getThrowable(requestAttributes, "SPRING_SECURITY_403_EXCEPTION", RequestAttributes.SCOPE_SESSION);
    }
    if (exception == null) {
      exception = getThrowable(requestAttributes, "SPRING_SECURITY_LAST_EXCEPTION", RequestAttributes.SCOPE_REQUEST);
    }
    if (exception == null) {
      // NOTE(review): session-scoped exceptions are only read, never removed, so one may be reported
      // again for later, unrelated errors in the same session — confirm this is intended.
      exception = getThrowable(requestAttributes, "SPRING_SECURITY_LAST_EXCEPTION", RequestAttributes.SCOPE_SESSION);
    }
    while (exception instanceof ServletException && exception.getCause() != null) {
      exception = exception.getCause();
    }
    return exception;
  }

  /** Returns the attribute {@code name} in {@code scope} if it is a {@link Throwable}, otherwise null. */
  private static Throwable getThrowable(RequestAttributes requestAttributes, String name, int scope) {
    Object value = requestAttributes.getAttribute(name, scope);
    return (value instanceof Throwable ? (Throwable) value : null);
  }

  @SuppressWarnings("unchecked")
  private <T> T getAttribute(RequestAttributes requestAttributes, String name) {
    return (T) requestAttributes.getAttribute(name, RequestAttributes.SCOPE_REQUEST);
  }

  /**
   * Returns the non-blank String produced by a no-arg {@code getCode()} on the exception, or the
   * exception's fully qualified class name otherwise.
   */
  private String getCode(Throwable exception) {
    Object code = invokeNoArg(exception, "getCode");
    if (code instanceof String && StringUtils.hasText((String) code)) {
      return (String) code;
    }
    return exception.getClass().getName();
  }

  /**
   * Returns the exception itself if it is a {@link BindingResult}, else the result of a no-arg
   * {@code getBindingResult()} if that yields one, else null.
   *
   * @see org.springframework.web.bind.MethodArgumentNotValidException#getBindingResult()
   */
  private BindingResult extractBindingResult(Throwable exception) {
    if (exception instanceof BindingResult) {
      return (BindingResult) exception;
    }
    Object result = invokeNoArg(exception, "getBindingResult");
    return (result instanceof BindingResult ? (BindingResult) result : null);
  }

  /**
   * Reflectively invokes the no-arg method {@code methodName} on {@code target}. Returns null when
   * no such method exists or the call fails: this runs while an error is already being handled, so
   * a failing accessor must not abort error rendering.
   */
  private Object invokeNoArg(Object target, String methodName) {
    Method method = ReflectionUtils.findMethod(target.getClass(), methodName);
    if (method == null) {
      return null;
    }
    try {
      return ReflectionUtils.invokeMethod(method, target);
    }
    catch (RuntimeException ex) {
      if (logger.isDebugEnabled()) {
        logger.debug("Could not invoke " + methodName + "() on " + target.getClass().getName(), ex);
      }
      return null;
    }
  }

  /** Renders the exception's stack trace as printed by {@link Throwable#printStackTrace(PrintWriter)}. */
  private static String getStackTrace(Throwable exception) {
    StringWriter stackTrace = new StringWriter();
    try (PrintWriter writer = new PrintWriter(stackTrace)) {
      exception.printStackTrace(writer);
    }
    return stackTrace.toString();
  }

  /** Returns the original URI of the failing request: error URI, then forward URI, then request URI. */
  private String getPath(RequestAttributes requestAttributes) {
    String path = getAttribute(requestAttributes, WebUtils.ERROR_REQUEST_URI_ATTRIBUTE);
    if (path == null) {
      path = getAttribute(requestAttributes, WebUtils.FORWARD_REQUEST_URI_ATTRIBUTE);
    }
    if (path == null && requestAttributes instanceof ServletRequestAttributes) {
      path = ((ServletRequestAttributes) requestAttributes).getRequest().getRequestURI();
    }
    return path;
  }

  /**
   * Returns the most original request URI (error, forward, include attribute, then the request's
   * own URI) with line breaks replaced by spaces, or {@code ""} if none is available.
   *
   * @see org.springframework.web.util.WebUtils#INCLUDE_REQUEST_URI_ATTRIBUTE
   */
  public static String getRequestUri(HttpServletRequest request) {
    Object requestURI = request.getAttribute(WebUtils.ERROR_REQUEST_URI_ATTRIBUTE);
    if (requestURI == null) {
      requestURI = request.getAttribute(WebUtils.FORWARD_REQUEST_URI_ATTRIBUTE);
    }
    if (requestURI == null) {
      requestURI = request.getAttribute(WebUtils.INCLUDE_REQUEST_URI_ATTRIBUTE);
    }
    if (requestURI == null) {
      requestURI = request.getRequestURI();
    }
    return (requestURI == null ? "" : flattenLineBreaks(String.valueOf(requestURI)));
  }

  /**
   * Returns {@code "?" + query} (forward, include attribute, then the request's own query string)
   * with line breaks replaced by spaces, or {@code ""} if there is no query string.
   */
  public static String getQueryString(HttpServletRequest request) {
    Object queryString = request.getAttribute(WebUtils.FORWARD_QUERY_STRING_ATTRIBUTE);
    if (queryString == null) {
      queryString = request.getAttribute(WebUtils.INCLUDE_QUERY_STRING_ATTRIBUTE);
    }
    if (queryString == null) {
      queryString = request.getQueryString();
    }
    return (queryString == null ? "" : '?' + flattenLineBreaks(String.valueOf(queryString)));
  }

  /** Replaces every CRLF, CR or LF in {@code value} with a single space. */
  private static String flattenLineBreaks(String value) {
    return LINE_BREAKS.matcher(value).replaceAll(" ");
  }

  /** Returns the status overrides keyed by fully qualified exception class name. */
  public Map<String, Number> getStatus() {
    return status;
  }

  /**
   * Replaces the status overrides. Entries are copied into a case-insensitive map so lookups keep
   * ignoring key case; a null argument clears all overrides.
   */
  public void setStatus(Map<String, Number> status) {
    Map<String, Number> overrides = new LinkedCaseInsensitiveMap<Number>();
    if (status != null) {
      for (Map.Entry<String, Number> entry : status.entrySet()) {
        overrides.put(entry.getKey(), entry.getValue());
      }
    }
    this.status = overrides;
  }
}
