package com.pusher.client.connection.websocket;

import com.pusher.client.connection.ConnectionEventListener;
import com.pusher.client.connection.ConnectionState;
import com.pusher.client.connection.ConnectionStateChange;
import com.pusher.client.connection.impl.InternalConnection;
import com.pusher.client.util.Factory;
import jdk.incubator.http.HttpClient;
import jdk.incubator.http.WebSocket;

import java.net.URI;
import java.net.URISyntaxException;
import java.util.Collections;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.CompletionStage;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;

import static com.pusher.client.util.PusherJsonParser.getEscapedJsonValue;
import static com.pusher.client.util.PusherJsonParser.getJsonValue;

/**
 * Pusher {@link InternalConnection} backed by the JDK incubator WebSocket client.
 *
 * <p>{@link #connect()}, {@link #disconnect()} and {@link #sendMessage(String)} hand their work to
 * the factory's event thread. The {@link WebSocket.Listener} callbacks are invoked by the WebSocket
 * client itself (presumably on its own threads — not confirmed here), and the ping/pong and
 * reconnect timers run on {@code factory.getTimers()}.
 *
 * <p>When the socket closes, errors, or fails to open while {@code CONNECTING} or
 * {@code CONNECTED}, the connection moves to {@code RECONNECTING}. It retries after
 * {@code min(maxReconnectionGap, attempts * attempts)} seconds, up to
 * {@code maxReconnectionAttempts} times. After that it settles in {@code DISCONNECTED} and shuts
 * down the factory's threads.
 */
public final class WebSocketConnection implements InternalConnection, WebSocket.Listener {

  private static final System.Logger log = System.getLogger(WebSocketConnection.class.getName());

  /** Events with this prefix are handled here rather than forwarded to the channel manager. */
  private static final String INTERNAL_EVENT_PREFIX = "pusher:";
  /** Message sent to the server after a period without any incoming activity. */
  private static final String PING_EVENT_SERIALIZED = "{\"event\": \"pusher:ping\"}";

  private final HttpClient httpClient;
  private final Factory factory;
  private final ActivityTimer activityTimer;
  /** Listeners per state; every {@link ConnectionState} (including ALL) gets a concurrent set. */
  private final Map<ConnectionState, Set<ConnectionEventListener>> eventListeners = new ConcurrentHashMap<>();
  private final URI webSocketUri;
  private final int maxReconnectionAttempts;
  private final int maxReconnectionGap;

  private volatile ConnectionState state = ConnectionState.DISCONNECTED;
  private WebSocket underlyingConnection;
  private String socketId;
  private int reconnectAttempts = 0;
  /** Accumulates a text message delivered in several parts; cleared whenever the socket closes. */
  private StringBuilder partialBuilder;

  /**
   * Creates a disconnected connection; nothing is opened until {@link #connect()} is called.
   *
   * @param httpClient client used to open the WebSocket, or {@code null} to create a default
   *     client for every connection attempt
   * @param url WebSocket endpoint
   * @param activityTimeout milliseconds without incoming messages before a ping is sent
   * @param pongTimeout milliseconds to wait for any message after a ping before reconnecting
   * @param maxReconnectionAttempts number of reconnects tried before giving up
   * @param maxReconnectionGap upper bound, in seconds, of the delay between reconnects
   * @param factory source of the event thread, timers and channel manager
   * @throws URISyntaxException if {@code url} is not a valid URI
   */
  public WebSocketConnection(
      final HttpClient httpClient,
      final String url,
      final long activityTimeout,
      final long pongTimeout,
      int maxReconnectionAttempts,
      int maxReconnectionGap,
      final Factory factory) throws URISyntaxException {
    this.httpClient = httpClient;
    this.webSocketUri = new URI(url);
    this.activityTimer = new ActivityTimer(activityTimeout, pongTimeout);
    this.maxReconnectionAttempts = maxReconnectionAttempts;
    this.maxReconnectionGap = maxReconnectionGap;
    this.factory = factory;
    for (final var state : ConnectionState.values()) {
      this.eventListeners.put(state, Collections.newSetFromMap(new ConcurrentHashMap<>()));
    }
  }

  /* Connection implementation */

  /** Starts connecting on the event thread; ignored unless currently DISCONNECTED. */
  @Override
  public void connect() {
    factory.queueOnEventThread(() -> {
      if (state == ConnectionState.DISCONNECTED) {
        tryConnecting();
      }
    });
  }

  /**
   * Opens a new WebSocket to {@link #webSocketUri}, blocking until the opening handshake
   * completes. A failed attempt is reported to every listener and then handled like an
   * unexpected close, so it is retried (or ends in DISCONNECTED) instead of staying CONNECTING.
   */
  private void tryConnecting() {
    // Enter CONNECTING before anything that can fail, so the failure handling below always
    // starts from a state it can reconnect from (previously a synchronous failure during a
    // reconnect left the connection stuck in RECONNECTING).
    updateState(ConnectionState.CONNECTING);
    try {
      final var client = httpClient == null ? HttpClient.newHttpClient() : httpClient;
      underlyingConnection = client.newWebSocketBuilder().buildAsync(webSocketUri, this).join();
    } catch (final RuntimeException e) {
      sendErrorToAllListeners("Error connecting to WebSocket server.", null, e);
      // A handshake that never completed is not expected to produce an onClose callback,
      // so run the close/reconnect handling here.
      handleClosed(-1, "Connection failed");
    }
  }

  /** Starts a graceful close on the event thread; ignored unless currently CONNECTED. */
  @Override
  public void disconnect() {
    factory.queueOnEventThread(() -> {
      if (state == ConnectionState.CONNECTED) {
        updateState(ConnectionState.DISCONNECTING);
        underlyingConnection.sendClose(WebSocket.NORMAL_CLOSURE, "");
      }
    });
  }

  /** Registers {@code eventListener} for changes into {@code state} (or every change for ALL). */
  @Override
  public void bind(final ConnectionState state, final ConnectionEventListener eventListener) {
    eventListeners.get(state).add(eventListener);
  }

  /** Removes a binding made by {@link #bind}; returns whether it was present. */
  @Override
  public boolean unbind(final ConnectionState state, final ConnectionEventListener eventListener) {
    return eventListeners.get(state).remove(eventListener);
  }

  @Override
  public ConnectionState getState() {
    return state;
  }

  /* InternalConnection implementation detail */

  /**
   * Sends {@code message} as a single text frame on the event thread. Failures, both
   * synchronous and those of the asynchronous send itself, are reported to every listener.
   */
  @Override
  public void sendMessage(final String message) {
    factory.queueOnEventThread(() -> {
      final var currentState = state;
      try {
        if (currentState == ConnectionState.CONNECTED) {
          // sendText is asynchronous; without this hook a failed send would go unnoticed.
          underlyingConnection.sendText(message, true).whenComplete((ignored, error) -> {
            if (error != null) {
              reportSendFailure(message, error);
            }
          });
        } else {
          sendErrorToAllListeners("Cannot send a message while in " + currentState + " state", null, null);
        }
      } catch (final Exception e) {
        reportSendFailure(message, e);
      }
    });
  }

  /** Reports that {@code message} could not be sent. */
  private void reportSendFailure(final String message, final Throwable error) {
    sendErrorToAllListeners("An exception occurred while sending message [" + message + "]", null, error);
  }

  /** Returns the socket id from the last {@code pusher:connection_established}, or null. */
  @Override
  public String getSocketId() {
    return socketId;
  }

  /* implementation detail */

  /**
   * Sets the state and notifies listeners: first those bound to ALL, then those bound to
   * {@code newState}; a listener bound to both is notified only once.
   */
  private void updateState(final ConnectionState newState) {
    log.log(System.Logger.Level.TRACE, "State transition requested, current [" + state + "], new [" + newState + "]");
    final var change = new ConnectionStateChange(state, newState);
    state = newState;
    final Set<ConnectionEventListener> dedupe = new HashSet<>();
    eventListeners.get(ConnectionState.ALL).stream().filter(dedupe::add)
        .forEach(listener -> listener.onConnectionStateChange(change));
    eventListeners.get(newState).stream().filter(dedupe::add)
        .forEach(listener -> listener.onConnectionStateChange(change));
  }

  /**
   * Dispatches one incoming message: {@code pusher:}-prefixed events are handled here,
   * everything else goes to the channel manager.
   *
   * @throws IllegalArgumentException if the message has no {@code event} field
   */
  private void handleEvent(final String rawJson) {
    final var event = getJsonValue(rawJson, "\"event\"");
    if (event == null) {
      throw new IllegalArgumentException("Message does not contain an event field: " + rawJson);
    }
    if (event.startsWith(INTERNAL_EVENT_PREFIX)) {
      handleInternalEvent(event, rawJson);
      return;
    }
    factory.getChannelManager().onMessage(event, rawJson);
  }

  /** Handles connection_established and error events; other internal events are ignored. */
  private void handleInternalEvent(final String event, final String rawJson) {
    if (event.equals("pusher:connection_established")) {
      handleConnectionMessage(rawJson);
    } else if (event.equals("pusher:error")) {
      handleError(rawJson);
    }
  }

  /** Records the socket id, moves to CONNECTED and resets the reconnect counter. */
  private void handleConnectionMessage(final String rawJson) {
    // The socket_id key appears with escaped quotes in the raw message, hence the escaped lookup.
    socketId = getEscapedJsonValue(rawJson, "\\\"socket_id\\\"");
    if (state != ConnectionState.CONNECTED) {
      updateState(ConnectionState.CONNECTED);
    }
    reconnectAttempts = 0;
  }

  /** Forwards a server {@code pusher:error} (message and code) to every listener. */
  private void handleError(final String rawJson) {
    final var message = getJsonValue(rawJson, "\"message\"");
    final var code = getJsonValue(rawJson, "\"code\"");
    sendErrorToAllListeners(message, code, null);
  }

  /** Queues one onError call on the event thread for each distinct listener in any state. */
  private void sendErrorToAllListeners(final String message, final String code, final Throwable e) {
    final Set<ConnectionEventListener> dedupe = new HashSet<>();
    eventListeners.values().stream().flatMap(Set::stream).filter(dedupe::add)
        .forEach(listener -> factory.queueOnEventThread(() -> listener.onError(message, code, e)));
  }

  /* WebSocketListener implementation */

  @Override
  public void onOpen(final WebSocket webSocket) {
    // Request an unbounded number of messages so the client never stalls waiting for demand.
    webSocket.request(Long.MAX_VALUE);
    // TODO: log the handshake data
  }

  /**
   * Records activity, reassembles multi-part text messages and queues each complete message
   * for {@link #handleEvent} on the event thread. Returns null because the text is copied
   * before returning.
   */
  @Override
  public CompletionStage<?> onText(final WebSocket webSocket,
                                   final CharSequence message,
                                   final WebSocket.MessagePart part) {
    activityTimer.activity();
    switch (part) {
      case WHOLE:
        final var msgCopy = message.toString();
        factory.queueOnEventThread(() -> handleEvent(msgCopy));
        return null;
      case FIRST:
        if (partialBuilder == null) {
          partialBuilder = new StringBuilder(message);
          return null;
        }
        partialBuilder.setLength(0);
        partialBuilder.append(message);
        return null;
      case PART:
        partialBuilder.append(message);
        return null;
      case LAST:
        final var fullMessage = partialBuilder.append(message).toString();
        factory.queueOnEventThread(() -> handleEvent(fullMessage));
        return null;
      default:
        throw new IllegalStateException(part + " WebSocket messages are not currently supported. Please open an issue with replication details.");
    }
  }

  @Override
  public CompletionStage<?> onClose(final WebSocket webSocket,
                                    final int statusCode,
                                    final String reason) {
    handleClosed(statusCode, reason);
    return null;
  }

  /**
   * Shared handling for every way the socket can go away: a close frame, a WebSocket error, a
   * failed connection attempt or a pong timeout.
   * <ul>
   *   <li>DISCONNECTED / RECONNECTING: nothing to do (logged).</li>
   *   <li>CONNECTING / CONNECTED: reconnect while attempts remain, otherwise disconnect.</li>
   *   <li>DISCONNECTING: finish the disconnect.</li>
   * </ul>
   */
  private void handleClosed(final int statusCode, final String reason) {
    partialBuilder = null;
    if (state == ConnectionState.DISCONNECTED || state == ConnectionState.RECONNECTING) {
      log.log(System.Logger.Level.WARNING, "Received close from underlying socket when already disconnected.  Close code ["
          + statusCode + "], Reason [" + reason + "]");
      return;
    }
    //Reconnection logic
    if (state == ConnectionState.CONNECTED || state == ConnectionState.CONNECTING) {
      if (reconnectAttempts < maxReconnectionAttempts) {
        tryReconnecting();
      } else {
        updateState(ConnectionState.DISCONNECTING);
        cancelTimeoutsAndTransitionToDisconnected();
      }
    } else if (state == ConnectionState.DISCONNECTING) {
      cancelTimeoutsAndTransitionToDisconnected();
    }
  }

  /** Moves to RECONNECTING and schedules the next attempt after a quadratic, capped delay. */
  private void tryReconnecting() {
    // No ping/pong while there is no live socket: a ping would only fail with a
    // "Cannot send a message" error and arm a pointless pong timeout.
    activityTimer.cancelTimeouts();
    reconnectAttempts++;
    updateState(ConnectionState.RECONNECTING);
    final long reconnectInterval = Math.min(maxReconnectionGap, reconnectAttempts * reconnectAttempts);
    factory.getTimers().schedule(this::tryConnecting, reconnectInterval, TimeUnit.SECONDS);
  }

  /** Stops the ping/pong timers, then on the event thread enters DISCONNECTED and stops threads. */
  private void cancelTimeoutsAndTransitionToDisconnected() {
    activityTimer.cancelTimeouts();
    factory.queueOnEventThread(() -> {
      updateState(ConnectionState.DISCONNECTED);
      factory.shutdownThreads();
    });
  }

  /**
   * Reports the error, then runs the close handling. The JDK WebSocket treats onError as
   * terminal (onClose is not expected to follow), so without this the connection would stay
   * CONNECTED on a dead socket.
   */
  @Override
  public void onError(final WebSocket webSocket, final Throwable error) {
    sendErrorToAllListeners("An exception was thrown by the WebSocket", null, error);
    handleClosed(-1, "WebSocket error: " + error);
  }

  /** Sends a ping after inactivity and treats a missing reply as a dead connection. */
  private class ActivityTimer {

    private final long activityTimeout;
    private final long pongTimeout;

    private Future<?> pingTimer;
    private Future<?> pongTimer;

    ActivityTimer(final long activityTimeout, final long pongTimeout) {
      this.activityTimeout = activityTimeout;
      this.pongTimeout = pongTimeout;
    }

    /**
     * On any activity from the server: cancel the pong timeout, cancel the current ping
     * timeout and schedule a new one {@code activityTimeout} ms from now.
     */
    synchronized void activity() {
      // Never interrupt: a running pong task may be in the middle of tearing the socket down.
      if (pongTimer != null) {
        pongTimer.cancel(false);
      }
      if (pingTimer != null) {
        pingTimer.cancel(false);
      }
      pingTimer = factory.getTimers().schedule(() -> {
        log.log(System.Logger.Level.TRACE, "Sending ping");
        sendMessage(PING_EVENT_SERIALIZED);
        schedulePongCheck();
      }, activityTimeout, TimeUnit.MILLISECONDS);
    }

    /**
     * Cancel any pending timeouts, for example because we are disconnected.
     */
    synchronized void cancelTimeouts() {
      if (pingTimer != null) {
        pingTimer.cancel(false);
      }
      if (pongTimer != null) {
        pongTimer.cancel(false);
      }
    }

    /**
     * Called when a ping is sent, to await the response: cancel any existing pong timeout and
     * schedule a new one {@code pongTimeout} ms from now.
     */
    private synchronized void schedulePongCheck() {
      if (pongTimer != null) {
        pongTimer.cancel(false);
      }
      pongTimer = factory.getTimers().schedule(() -> {
        log.log(System.Logger.Level.TRACE, "Timed out awaiting pong from server - disconnecting");
        final var deadConnection = underlyingConnection;
        // Handle the close first (RECONNECTING or DISCONNECTED), so that anything the dead
        // socket still reports is treated as stale. Then drop the socket abruptly; a graceful
        // close handshake is unlikely to succeed when the server has stopped answering.
        // This replaces a queued disconnect() that raced with the close handling and either
        // suppressed the reconnect or never closed the socket.
        handleClosed(-1, "Pong timeout");
        if (deadConnection != null) {
          deadConnection.abort();
        }
      }, pongTimeout, TimeUnit.MILLISECONDS);
    }
  }
}
