package com.pusher.client.connection.websocket;

import com.pusher.client.channel.impl.ChannelManager;
import com.pusher.client.connection.ConnectionEventListener;
import com.pusher.client.connection.ConnectionState;
import com.pusher.client.connection.ConnectionStateChange;
import com.pusher.client.connection.impl.InternalConnectionManager;
import jdk.incubator.http.HttpClient;
import jdk.incubator.http.WebSocket;

import java.net.URI;
import java.nio.ByteBuffer;
import java.time.Duration;
import java.util.Collections;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.*;
import java.util.concurrent.atomic.AtomicInteger;

import static com.pusher.client.util.PusherJsonParser.getEscapedJsonValue;
import static com.pusher.client.util.PusherJsonParser.getJsonValue;
import static java.nio.charset.StandardCharsets.UTF_8;

public final class WebSocketConnectionManager implements InternalConnectionManager, WebSocket.Listener {

  private static final System.Logger log = System.getLogger(WebSocketConnectionManager.class.getName());

  private static final String INTERNAL_EVENT_PREFIX = "pusher:";
  private static final byte[] PING_MSG = "ping".getBytes(UTF_8);

  private final HttpClient httpClient;
  private final Executor listenerExecutor;
  private final ScheduledExecutorService timers;
  private final ChannelManager channelManager;
  private final Map<ConnectionState, Set<ConnectionEventListener>> eventListeners;
  private final URI webSocketUri;
  private final int maxReconnectionGap;
  private final long pongTimeout;

  private volatile ConnectionState state = ConnectionState.DISCONNECTED;
  private volatile CompletableFuture<WebSocket> webSocket;
  private volatile StringBuilder partialMsgBuilder;
  private volatile String socketId;
  private volatile int reconnectAttempts;
  private volatile long activityTimeout;
  private volatile ScheduledFuture<?> connectFuture;
  private volatile long lastActivity;
  private volatile long lastPing;
  private volatile ScheduledFuture<?> activityMonitorFuture;

  public WebSocketConnectionManager(
      final String name,
      final URI webSocketUri,
      final HttpClient httpClient,
      final ChannelManager channelManager,
      final long activityTimeout,
      final long pongTimeout,
      final int maxReconnectionGap,
      final Executor listenerExecutor) {
    this.webSocketUri = webSocketUri;
    this.httpClient = httpClient == null ? HttpClient.newHttpClient() : httpClient;
    this.channelManager = channelManager;
    this.activityTimeout = activityTimeout;
    this.pongTimeout = pongTimeout;
    this.eventListeners = new ConcurrentHashMap<>();
    this.maxReconnectionGap = maxReconnectionGap;
    this.listenerExecutor = listenerExecutor;
    this.timers = Executors.newScheduledThreadPool(1, new PusherThreadFactory(name));
  }

  private static final class PusherThreadFactory implements ThreadFactory {

    private final String namePrefix;
    private final AtomicInteger nextId = new AtomicInteger(1);

    private PusherThreadFactory(final String name) {
      this.namePrefix = name + "-pusher-java-client-connection-monitor-";
    }

    @Override
    public Thread newThread(final Runnable r) {
      final var t = new Thread(null, r, namePrefix + nextId.getAndIncrement(), 0, false);
      t.setDaemon(true);
      return t;
    }
  }

  /* ConnectionManager implementation */

  @Override
  public void connect() {
    if (state == ConnectionState.DISCONNECTED) {
      tryConnecting();
    }
  }

  private void tryConnecting() {
    synchronized (timers) {
      try {
        switch (state) {
          case RECONNECTING:
          case DISCONNECTED:
            break;
          case CONNECTING:
          case CONNECTED:
          case DISCONNECTING:
          default:
            return;
        }
        updateState(ConnectionState.CONNECTING);
        this.webSocket = httpClient.newWebSocketBuilder()
            .connectTimeout(Duration.ofMillis(pongTimeout))
            .buildAsync(webSocketUri, this);
      } catch (final RuntimeException e) {
        sendErrorToAllListeners("Error connecting to WebSocket server.", null, e);
      }
    }
  }

  @Override
  public void disconnect() {
    synchronized (timers) {
      if (state == ConnectionState.CONNECTED) {
        webSocket = webSocket.join().sendClose(WebSocket.NORMAL_CLOSURE, "");
        updateState(ConnectionState.DISCONNECTING);
      }
    }
  }

  @Override
  public void bind(final ConnectionState state, final ConnectionEventListener eventListener) {
    eventListeners.computeIfAbsent(state, _state -> Collections.newSetFromMap(new ConcurrentHashMap<>())).add(eventListener);
  }

  @Override
  public boolean unbind(final ConnectionState state, final ConnectionEventListener eventListener) {
    final var listeners = eventListeners.get(state);
    return listeners != null && listeners.remove(eventListener);
  }

  @Override
  public ConnectionState getState() {
    return state;
  }

  /* InternalConnectionManager implementation detail */

  @Override
  public void sendMessage(final String message) {
    synchronized (timers) {
      try {
        webSocket = webSocket.join().sendText(message, true);
      } catch (final Exception e) {
        sendErrorToAllListeners("An exception occurred while sending message [" + message + "]", null, e);
      }
    }
  }

  @Override
  public String getSocketId() {
    return socketId;
  }

  /* implementation detail */

  private void updateState(final ConnectionState newState) {
    log.log(System.Logger.Level.TRACE, "State transition requested, current [" + state + "], new [" + newState + "]");
    final var change = new ConnectionStateChange(state, newState);
    state = newState;
    final Set<ConnectionEventListener> dedupe = new HashSet<>();
    final var allStateListeners = eventListeners.get(ConnectionState.ALL);
    if (allStateListeners != null) {
      allStateListeners.stream().filter(dedupe::add)
          .forEach(listener -> listenerExecutor.execute(() -> listener.onConnectionStateChange(change)));
    }
    final var newStateListeners = eventListeners.get(newState);
    if (newStateListeners != null) {
      newStateListeners.stream().filter(dedupe::add)
          .forEach(listener -> listenerExecutor.execute(() -> listener.onConnectionStateChange(change)));
    }
  }

  private void handleEvent(final String rawJson) {
    final var event = getJsonValue(rawJson, "\"event\":");
    if (event == null) {
      throw new IllegalArgumentException("Message does not contain an event field: " + rawJson);
    }
    if (event.startsWith(INTERNAL_EVENT_PREFIX)) {
      handleInternalEvent(event, rawJson);
      return;
    }
    channelManager.onMessage(event, rawJson);
  }

  private void handleInternalEvent(final String event, final String rawJson) {
    if (event.equals("pusher:connection_established")) {
      handleConnectionMessage(rawJson);
    } else if (event.equals("pusher:error")) {
      handleError(rawJson);
    }
  }

  private void handleError(final String rawJson) {
    final var message = getJsonValue(rawJson, "\"message\":");
    final var code = getJsonValue(rawJson, "\"code\":");
    sendErrorToAllListeners(message, code, null);
  }

  private void sendErrorToAllListeners(final String message, final String code, final Throwable e) {
    final Set<ConnectionEventListener> dedupe = new HashSet<>();
    eventListeners.values().stream().flatMap(Set::stream).filter(dedupe::add)
        .forEach(listener -> listenerExecutor.execute(() -> listener.onError(message, code, e)));
  }

  /* WebSocketListener implementation */

  @Override
  public void onOpen(final WebSocket webSocket) {
    webSocket.request(Long.MAX_VALUE);
  }


  @Override
  public CompletionStage<?> onPing(final WebSocket webSocket, final ByteBuffer message) {
    lastActivity = System.currentTimeMillis();
    return webSocket.sendPong(message);
  }

  @Override
  public CompletionStage<?> onPong(final WebSocket webSocket, final ByteBuffer message) {
    lastActivity = System.currentTimeMillis();
    return null;
  }

  @Override
  public CompletionStage<?> onText(final WebSocket webSocket,
                                   final CharSequence message,
                                   final WebSocket.MessagePart part) {
    switch (part) {
      case WHOLE:
        final var msgCopy = message.toString();
        listenerExecutor.execute(() -> handleEvent(msgCopy));
        return null;
      case FIRST:
        if (partialMsgBuilder == null) {
          partialMsgBuilder = new StringBuilder(message);
          return null;
        }
        partialMsgBuilder.setLength(0);
        partialMsgBuilder.append(message);
        return null;
      case PART:
        partialMsgBuilder.append(message);
        return null;
      case LAST:
        final var fullMessage = partialMsgBuilder.append(message).toString();
        listenerExecutor.execute(() -> handleEvent(fullMessage));
        return null;
      default:
        throw new IllegalStateException(part + " WebSocket messages are not currently supported. Please open an issue with replication details.");
    }
  }

  @Override
  public CompletionStage<?> onClose(final WebSocket webSocket,
                                    final int statusCode,
                                    final String reason) {
    switch (state) {
      case DISCONNECTED:
      case RECONNECTING:
        log.log(System.Logger.Level.WARNING,
            "Received close from underlying socket when already disconnected.  Close code ["
                + statusCode + "], Reason [" + reason + "]");
        break;
      case CONNECTING:
      case CONNECTED:
        tryReconnecting();
        break;
      case DISCONNECTING:
        cancelTimeoutsAndTransitionToDisconnected();
        break;
      default:
        break;
    }
    partialMsgBuilder = null;
    return null;
  }

  @Override
  public void onError(final WebSocket webSocket, final Throwable error) {
    sendErrorToAllListeners("An exception was thrown by the WebSocket.", null, error);
    switch (state) {
      case CONNECTING:
      case CONNECTED:
        tryReconnecting();
        break;
      case DISCONNECTING:
        cancelTimeoutsAndTransitionToDisconnected();
        break;
      default:
        break;
    }
    partialMsgBuilder = null;
  }

  private void tryReconnecting() {
    synchronized (timers) {
      if (connectFuture != null && !connectFuture.isDone()) {
        return;
      }
      reconnectAttempts++;
      updateState(ConnectionState.RECONNECTING);
      final long reconnectInterval = Math.min(maxReconnectionGap, Math.abs(reconnectAttempts * reconnectAttempts));
      connectFuture = timers.schedule(this::tryConnecting, reconnectInterval, TimeUnit.SECONDS);
    }
  }

  private void cancelTimeoutsAndTransitionToDisconnected() {
    if (activityMonitorFuture != null) {
      activityMonitorFuture.cancel(false);
    }
    updateState(ConnectionState.DISCONNECTED);
  }

  private void handleConnectionMessage(final String rawJson) {
    socketId = getEscapedJsonValue(rawJson, "\\\"socket_id\\\":");
    if (state != ConnectionState.CONNECTED) {
      updateState(ConnectionState.CONNECTED);
    }
    final var activityTimeoutString = getEscapedJsonValue(rawJson, "\\\"activity_timeout\\\":");
    if (activityTimeoutString != null) {
      activityTimeout = Math.min(activityTimeout, 1_000 * Long.parseLong(activityTimeoutString));
    }
    lastActivity = System.currentTimeMillis();
    scheduleActivityCheck();
    reconnectAttempts = 0;
  }

  private void scheduleActivityCheck() {
    activityMonitorFuture = timers.scheduleWithFixedDelay(() -> {
      final long now = System.currentTimeMillis();
      if (now - lastActivity < activityTimeout) {
        return;
      }
      if (lastPing > lastActivity && now - lastPing > pongTimeout) {
        log.log(System.Logger.Level.TRACE, "Timed out awaiting pong from server - disconnecting");
        disconnect();
        // Proceed immediately to handle the close
        // The WebSocketClient will attempt a graceful WebSocket shutdown by exchanging the close frames
        // but may not succeed if this disconnect was called due to pong timeout...
        onClose(null, -1, "Pong timeout");
        return;
      }
      synchronized (timers) {
        switch (state) {
          case CONNECTING:
            tryReconnecting();
            return;
          case CONNECTED:
            log.log(System.Logger.Level.TRACE, "Sending ping");
            webSocket = webSocket.join().sendPing(ByteBuffer.wrap(PING_MSG)).whenComplete((ws, ex) -> {
              if (ex == null) {
                lastPing = System.currentTimeMillis();
              }
            });
            return;
          default:
        }

      }
    }, activityTimeout, pongTimeout / 2, TimeUnit.MILLISECONDS);
  }
}
