package io.relayr.java.websocket;

import com.google.gson.Gson;

import org.eclipse.paho.client.mqttv3.MqttException;

import java.util.Collections;
import java.util.HashMap;
import java.util.Map;

import javax.inject.Inject;
import javax.inject.Singleton;

import io.relayr.java.api.ChannelApi;
import io.relayr.java.model.DataPackage;
import io.relayr.java.model.Device;
import io.relayr.java.model.action.Reading;
import io.relayr.java.model.channel.ChannelDefinition;
import io.relayr.java.model.channel.DataChannel;
import io.relayr.java.model.channel.PublishChannel;
import rx.Observable;
import rx.Observer;
import rx.Subscriber;
import rx.functions.Action1;
import rx.functions.Func1;
import rx.schedulers.Schedulers;
import rx.subjects.PublishSubject;

/**
 * Streams device readings from, and publishes readings to, "mqtt" channels obtained through
 * {@link ChannelApi} and served by a {@link WebSocket}.
 *
 * <p>One {@link PublishSubject} is kept per subscribed device id and one channel per device id
 * for each direction. The bookkeeping maps are synchronized wrappers because they are touched
 * from the {@code subscribeOn} threads used below as well as from {@link WebSocketCallback}
 * invocations.
 */
@Singleton
public class WebSocketClient {

    final ChannelApi mChannelApi;
    final WebSocket<DataChannel> mWebSocket;
    /** Channels that currently have an active topic subscription, keyed by device id. */
    final Map<String, DataChannel> mDeviceChannels =
            Collections.synchronizedMap(new HashMap<String, DataChannel>());
    /** Channels already created for publishing, keyed by device id. */
    final Map<String, DataChannel> mPublishChannels =
            Collections.synchronizedMap(new HashMap<String, DataChannel>());
    /** Subjects that fan readings out to subscribers, keyed by device id. */
    final Map<String, PublishSubject<Reading>> mSocketConnections =
            Collections.synchronizedMap(new HashMap<String, PublishSubject<Reading>>());

    /** Shared JSON (de)serializer; one instance is reused instead of allocating per message. */
    private final Gson mGson = new Gson();

    @Inject
    public WebSocketClient(ChannelApi channelApi, WebSocketFactory factory) {
        mChannelApi = channelApi;
        mWebSocket = factory.createWebSocket();
    }

    /**
     * Subscribes to the readings of the given device.
     *
     * @param device device whose id identifies the stream
     * @return observable of readings, see {@link #subscribe(String)}
     */
    public Observable<Reading> subscribe(Device device) {
        return subscribe(device.getId());
    }

    /**
     * Returns the reading stream for a device, starting a new connection when none is registered.
     *
     * <p>Synchronized so that the lookup and the registration done in {@code start} form one
     * atomic step; otherwise two concurrent callers could both start a connection for the same id.
     *
     * @param deviceId id of the device to listen to
     * @return the already registered stream, or a freshly started one
     */
    public synchronized Observable<Reading> subscribe(String deviceId) {
        PublishSubject<Reading> existing = mSocketConnections.get(deviceId);
        if (existing != null) {
            return existing;
        }
        return start(deviceId);
    }

    /**
     * Completes and forgets the reading stream of a device and releases its topic subscription.
     *
     * <p>The channel entry is only dropped when the socket reports a successful unsubscribe.
     *
     * @param deviceId id of the device to stop listening to
     */
    public void unSubscribe(final String deviceId) {
        // Single remove() instead of containsKey/get/remove so a concurrent removal cannot
        // leave us dereferencing a missing entry.
        PublishSubject<Reading> subject = mSocketConnections.remove(deviceId);
        if (subject != null) {
            subject.onCompleted();
        }

        DataChannel channel = mDeviceChannels.get(deviceId);
        if (channel != null && mWebSocket.unSubscribe(channel.getCredentials().getTopic())) {
            mDeviceChannels.remove(deviceId);
        }
    }

    /**
     * Publishes a reading to the device's channel, creating the channel on first use.
     *
     * <p>The publish is triggered immediately, whether or not the caller subscribes to the
     * returned observable. The result is cached, so subscribing to the returned observable
     * reports the outcome of that single publish instead of sending the payload again.
     *
     * @param deviceId id of the target device
     * @param payload  reading to serialize as JSON and send
     * @return observable emitting one {@code null} item and completing on success, or the error
     *         raised while creating the channel or sending
     */
    public Observable<Void> publish(final String deviceId, final Reading payload) {
        Observable<Void> observable = Observable.create(new Observable.OnSubscribe<Void>() {
            @Override public void call(final Subscriber<? super Void> subscriber) {
                DataChannel known = mPublishChannels.get(deviceId);
                if (known != null) {
                    publishTo(known, payload, subscriber);
                    return;
                }

                mChannelApi.createForDevice(new ChannelDefinition(deviceId, "mqtt"), deviceId)
                        .flatMap(new Func1<PublishChannel, Observable<DataChannel>>() {
                            @Override
                            public Observable<DataChannel> call(PublishChannel channel) {
                                return mWebSocket.createClient(channel);
                            }
                        })
                        .subscribe(new Observer<DataChannel>() {
                            @Override public void onCompleted() {}

                            @Override public void onError(Throwable e) {
                                mPublishChannels.remove(deviceId);
                                subscriber.onError(e);
                            }

                            @Override public void onNext(DataChannel channel) {
                                // Keep a channel registered in the meantime, if there is one.
                                DataChannel target = mPublishChannels.get(deviceId);
                                if (target == null) {
                                    mPublishChannels.put(deviceId, channel);
                                    target = channel;
                                }
                                publishTo(target, payload, subscriber);
                            }
                        });
            }
        }).subscribeOn(Schedulers.io()).cache();

        // Eager subscription: start the publish even if the caller ignores the result.
        observable.subscribe(new Observer<Void>() {
            @Override public void onCompleted() {}

            @Override public void onError(Throwable ignored) {
                // Intentionally empty: the cached observable replays the error to the caller.
            }

            @Override public void onNext(Void aVoid) {}
        });
        return observable;
    }

    /**
     * Sends the payload as JSON to the channel's topic suffixed with {@code "data"} and reports
     * the outcome to the subscriber.
     */
    private void publishTo(DataChannel channel, Reading payload,
                           Subscriber<? super Void> subscriber) {
        try {
            mWebSocket.publish(channel.getCredentials().getTopic() + "data",
                    mGson.toJson(payload));
        } catch (MqttException e) {
            subscriber.onError(e);
            return;
        }
        subscriber.onNext(null);
        subscriber.onCompleted();
    }

    /**
     * Registers a new subject for the device and asynchronously creates the channel and client
     * that will feed it.
     *
     * @param deviceId id of the device to listen to
     * @return the subject, decorated so that an error also triggers {@link #unSubscribe(String)}
     */
    private synchronized Observable<Reading> start(final String deviceId) {
        final PublishSubject<Reading> subject = PublishSubject.create();
        mSocketConnections.put(deviceId, subject);

        mChannelApi.create(new ChannelDefinition(deviceId, "mqtt"))
                .flatMap(new Func1<DataChannel, Observable<DataChannel>>() {
                    @Override
                    public Observable<DataChannel> call(final DataChannel channel) {
                        return mWebSocket.createClient(channel);
                    }
                })
                .subscribeOn(Schedulers.newThread())
                .subscribe(new Subscriber<DataChannel>() {
                    @Override
                    public void onCompleted() {
                    }

                    @Override
                    public void onError(Throwable e) {
                        // Forward the failure; otherwise subscribers of the subject would
                        // wait forever for a connection that was never established.
                        failConnection(deviceId, subject, e);
                    }

                    @Override
                    public void onNext(DataChannel channel) {
                        subscribeToChannel(channel, deviceId, subject);
                    }
                });

        return subject.doOnError(new Action1<Throwable>() {
            @Override
            public void call(Throwable throwable) {
                unSubscribe(deviceId);
            }
        });
    }

    /**
     * Subscribes to the channel's topic and forwards every received data point to the subject.
     */
    private void subscribeToChannel(final DataChannel channel, final String deviceId,
                                    final PublishSubject<Reading> subject) {
        mWebSocket.subscribe(channel.getCredentials().getTopic(), channel.getChannelId(),
                new WebSocketCallback() {
                    @Override
                    public void connectCallback(Object message) {
                        if (!mDeviceChannels.containsKey(deviceId)) {
                            mDeviceChannels.put(deviceId, channel);
                        }
                    }

                    @Override
                    public void disconnectCallback(Object message) {
                        // The callback only promises an Object; avoid a ClassCastException
                        // that would skip the cleanup below.
                        Throwable cause = message instanceof Throwable
                                ? (Throwable) message
                                : new IllegalStateException("WebSocket disconnected: " + message);
                        failConnection(deviceId, subject, cause);
                    }

                    @Override
                    public void successCallback(Object message) {
                        if (message == null) {
                            return;
                        }
                        DataPackage dataPackage =
                                mGson.fromJson(message.toString(), DataPackage.class);
                        // fromJson yields null for empty input; readings may be absent as well.
                        if (dataPackage == null || dataPackage.readings == null) {
                            return;
                        }
                        for (DataPackage.Data dataPoint : dataPackage.readings) {
                            subject.onNext(new Reading(dataPackage.received, dataPoint.recorded,
                                    dataPoint.meaning, dataPoint.path, dataPoint.value));
                        }
                    }

                    @Override
                    public void errorCallback(Throwable e) {
                        failConnection(deviceId, subject, e);
                    }
                });
    }

    /**
     * Terminates the subject with the error and then forgets the device's bookkeeping entries.
     *
     * <p>The subject is notified first: the {@code doOnError} hook installed in {@code start}
     * calls {@link #unSubscribe(String)}, which needs the channel entry to release the topic
     * subscription. The cleanup runs in {@code finally} so it happens even if a subscriber's
     * error handler throws.
     */
    private void failConnection(String deviceId, PublishSubject<Reading> subject,
                                Throwable error) {
        try {
            subject.onError(error);
        } finally {
            mDeviceChannels.remove(deviceId);
            mSocketConnections.remove(deviceId);
        }
    }
}

