/*
 * Copyright 2012 Takao Nakaguchi
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.trie4j.doublearray;

import java.io.Externalizable;
import java.io.IOException;
import java.io.InputStream;
import java.io.ObjectInput;
import java.io.ObjectInputStream;
import java.io.ObjectOutput;
import java.io.ObjectOutputStream;
import java.io.OutputStream;
import java.io.PrintWriter;
import java.io.Writer;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.Deque;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.TreeMap;
import java.util.TreeSet;

import org.trie4j.AbstractTermIdTrie;
import org.trie4j.Node;
import org.trie4j.TermIdNode;
import org.trie4j.TermIdTrie;
import org.trie4j.Trie;
import org.trie4j.bv.BytesRank1OnlySuccinctBitVector;
import org.trie4j.bv.SuccinctBitVector;
import org.trie4j.util.BitSet;
import org.trie4j.util.FastBitSet;
import org.trie4j.util.Pair;

/**
 * Double-array trie whose BASE and CHECK arrays are interleaved in a single
 * {@code int[]}: {@code base(i)} is stored at {@code 2*i} and
 * {@code check(i)} at {@code 2*i + 1}.
 *
 * <p>Node 0 is the root. Node {@code c} is the child of node {@code p} for
 * character {@code ch} when {@code c == base(p) + code(ch)} and
 * {@code check(c) == p}. Character codes are assigned from 1 upwards in the
 * order characters are first met during construction; code 0 means "unknown
 * character". Unused slots have {@code base == BASE_EMPTY} and a negative
 * check whose magnitude is a hint to the next candidate empty slot.
 * Terminal nodes are recorded in {@link #getTerm()}, and a term id is
 * {@code rank1(node) - 1}.
 *
 * <p>NOTE(review): {@link #load(InputStream)} uses Java native
 * deserialization ({@code readObject}); do not feed it untrusted data.
 */
public class UnifiedIntDoubleArray
extends AbstractTermIdTrie
implements Externalizable, TermIdTrie{
	/** Callback invoked during construction for every terminal node. */
	public static interface TermNodeListener{
		/**
		 * @param node the terminal node of the source trie
		 * @param nodeIndex the slot that node was assigned in this double array
		 */
		void listen(Node node, int nodeIndex);
	}

	/**
	 * Creates an empty instance. Intended to be filled by
	 * {@link #readExternal(ObjectInput)} or {@link #load(InputStream)}.
	 */
	public UnifiedIntDoubleArray() {
	}

	/** Builds a double array from {@code trie}, sizing the initial array at twice its word count. */
	public UnifiedIntDoubleArray(Trie trie){
		this(trie, trie.size() * 2);
	}

	/** Builds a double array from {@code trie} with the given initial array size hint. */
	public UnifiedIntDoubleArray(Trie trie, int arraySize){
		this(trie, arraySize, new TermNodeListener(){
			@Override
			public void listen(Node node, int nodeIndex) {
			}
		});
	}

	/**
	 * Builds a double array from {@code trie}.
	 *
	 * @param trie source trie
	 * @param arraySize initial array size hint (values below 2 are raised to 2)
	 * @param listener notified of the slot assigned to every terminal node
	 */
	public UnifiedIntDoubleArray(Trie trie, int arraySize, TermNodeListener listener){
		if(arraySize <= 1) arraySize = 2;
		size = trie.size();
		baseAndCheckInt = new int[]{};
		extend(arraySize - 1);
		FastBitSet bs = new FastBitSet(arraySize);
		nodeSize = 1; // for root node because it has no letter;
		build(trie.getRoot(), 0, bs, listener);
		term = new BytesRank1OnlySuccinctBitVector(bs.getBytes(), bs.size());
		// Shrink to the smallest array that keeps every base(p) + code in range
		// (and always keeps slot 0 for the root, even for a childless root).
		resize(requiredLength());
	}

	@Override
	public int nodeSize() {
		return nodeSize;
	}

	@Override
	public int size() {
		return size;
	}

	@Override
	public TermIdNode getRoot() {
		return newDoubleArrayNode(0);
	}

	/** Returns the internal interleaved base/check array (not a copy). */
	public int[] getBaseAndCheck(){
		return baseAndCheckInt;
	}

	/** Returns the bit vector marking terminal node slots. */
	public BitSet getTerm() {
		return term;
	}

	/** Lightweight view of one slot of the double array as a trie node. */
	protected class DoubleArrayNode implements TermIdNode{
		public DoubleArrayNode(int nodeId){
			this(nodeId, (char)0);
		}

		public DoubleArrayNode(int nodeId, char firstChar){
			this.nodeId = nodeId;
			this.firstChar = firstChar;
		}

		@Override
		public boolean isTerminate() {
			return term.get(nodeId);
		}

		/**
		 * Returns the single letter on the edge into this node, or an empty
		 * array when no letter was recorded (e.g. the root).
		 */
		@Override
		public char[] getLetters() {
			return firstChar != 0 ? new char[]{firstChar} : new char[0];
		}

		/** Returns the children in the order of the sorted alphabet. */
		@Override
		public DoubleArrayNode[] getChildren() {
			List<DoubleArrayNode> children = new ArrayList<DoubleArrayNode>();
			for(char c : chars){
				int child = childIndex(nodeId, c);
				if(child != -1) children.add(newDoubleArrayNode(child, c));
			}
			if(children.isEmpty()) return emptyNodes;
			return children.toArray(new DoubleArrayNode[children.size()]);
		}

		/** Returns the child reached by {@code c}, or {@code null} if there is none. */
		@Override
		public DoubleArrayNode getChild(char c) {
			int child = childIndex(nodeId, c);
			return child == -1 ? null : newDoubleArrayNode(child, c);
		}

		public int getNodeId() {
			return nodeId;
		}

		/** Returns this node's term id, or -1 if it is not terminal. */
		@Override
		public int getTermId(){
			return term.get(nodeId) ? term.rank1(nodeId) - 1 : -1;
		}

		private final char firstChar;
		private final int nodeId;
	}

	@Override
	public boolean contains(String text){
		int nodeIndex = getNodeId(text);
		return nodeIndex != -1 && term.get(nodeIndex);
	}

	/**
	 * Follows {@code text} from the root.
	 *
	 * @return the slot reached after consuming all of {@code text} (0 for the
	 *         empty string), or -1 if some character has no matching edge
	 */
	public int getNodeId(String text) {
		int nodeIndex = 0; // root
		for(int i = 0, n = text.length(); i < n; i++){
			nodeIndex = childIndex(nodeIndex, text.charAt(i));
			if(nodeIndex == -1) return -1;
		}
		return nodeIndex;
	}

	@Override
	public int getTermId(String text) {
		int nid = getNodeId(text);
		if(nid == -1) return -1;
		return term.get(nid) ? term.rank1(nid) - 1 : -1;
	}

	/** Returns every stored word that is a non-empty prefix of {@code query}, shortest first. */
	@Override
	public Iterable<String> commonPrefixSearch(String query) {
		List<String> ret = new ArrayList<String>();
		for(Pair<String, Integer> p : commonPrefixNodes(query)){
			ret.add(p.getFirst());
		}
		return ret;
	}

	/** Like {@link #commonPrefixSearch(String)} but pairs each word with its term id. */
	@Override
	public Iterable<Pair<String, Integer>> commonPrefixSearchWithTermId(
			String query) {
		List<Pair<String, Integer>> ret = new ArrayList<Pair<String, Integer>>();
		for(Pair<String, Integer> p : commonPrefixNodes(query)){
			ret.add(Pair.create(p.getFirst(), term.rank1(p.getSecond()) - 1));
		}
		return ret;
	}

	/**
	 * Finds the first position in {@code [start, end)} at which a stored word
	 * begins, using the shortest word that matches there.
	 *
	 * @param word if non-null, the matched word is appended to it
	 * @return the start index of the match, or -1 if no word occurs
	 */
	@Override
	public int findWord(CharSequence chars, int start, int end, StringBuilder word) {
		for(int i = start; i < end; i++){
			int nodeIndex = 0;
			for(int j = i; j < end; j++){
				nodeIndex = childIndex(nodeIndex, chars.charAt(j));
				if(nodeIndex == -1) break; // no word starts at i; try i + 1
				if(term.get(nodeIndex)){
					if(word != null) word.append(chars, i, j + 1);
					return i;
				}
			}
		}
		return -1;
	}

	/** Returns every stored word starting with {@code prefix} (including {@code prefix} itself). */
	@Override
	public Iterable<String> predictiveSearch(String prefix) {
		List<String> ret = new ArrayList<String>();
		for(Pair<String, Integer> p : predictiveNodes(prefix)){
			ret.add(p.getFirst());
		}
		return ret;
	}

	/**
	 * Like {@link #predictiveSearch(String)} but pairs each word with its term
	 * id. An empty prefix now matches every word, consistent with
	 * {@link #predictiveSearch(String)}.
	 */
	@Override
	public Iterable<Pair<String, Integer>> predictiveSearchWithTermId(
			String prefix) {
		List<Pair<String, Integer>> ret = new ArrayList<Pair<String, Integer>>();
		for(Pair<String, Integer> p : predictiveNodes(prefix)){
			ret.add(Pair.create(p.getFirst(), term.rank1(p.getSecond()) - 1));
		}
		return ret;
	}

	@Override
	public void writeExternal(ObjectOutput out) throws IOException {
		out.writeInt(size);
		out.writeInt(nodeSize);
		out.writeInt(baseAndCheckInt.length);
		for(int v : baseAndCheckInt){
			out.writeInt(v);
		}
		out.writeObject(term);
		out.writeInt(firstEmptyCheck);
		out.writeInt(chars.size());
		for(char c : chars){
			out.writeChar(c);
			out.writeChar(charToCode[c]);
		}
	}

	/** Writes this trie to {@code os}; the stream is flushed but not closed. */
	public void save(OutputStream os) throws IOException{
		ObjectOutputStream out = new ObjectOutputStream(os);
		try{
			writeExternal(out);
		} finally{
			out.flush();
		}
	}

	@Override
	public void readExternal(ObjectInput in) throws IOException,
			ClassNotFoundException {
		size = in.readInt();
		nodeSize = in.readInt();
		int len = in.readInt();
		baseAndCheckInt = new int[len];
		for(int i = 0; i < len; i++){
			baseAndCheckInt[i] = in.readInt();
		}
		try{
			term = (SuccinctBitVector)in.readObject();
		} catch(ClassNotFoundException e){
			throw new IOException(e);
		}
		firstEmptyCheck = in.readInt();
		// Drop any alphabet left over from an earlier load into this instance.
		chars.clear();
		Arrays.fill(charToCode, (char)0);
		int n = in.readInt();
		for(int i = 0; i < n; i++){
			char c = in.readChar();
			char v = in.readChar();
			chars.add(c);
			charToCode[c] = v;
		}
		// 'last' is not serialized; recover it so that trimToSize() does not
		// truncate a loaded array.
		last = findLastUsedIndex();
	}

	/** Reads a trie previously written by {@link #save(OutputStream)}. */
	public void load(InputStream is) throws IOException{
		try{
			readExternal(new ObjectInputStream(is));
		} catch(ClassNotFoundException e){
			throw new IOException(e);
		}
	}

	/** Shrinks the array to the minimum length that keeps all lookups in range. */
	@Override
	public void trimToSize(){
		resize(requiredLength());
	}

	@Override
	public void dump(Writer w){
		PrintWriter writer = new PrintWriter(w);
		try{
			int n = Math.min(16, last);
			writer.println("array size: " + getBaseAndCheckLength());
			writer.print("      |");
			for(int i = 0; i < n; i++){
				writer.print(String.format("%3d|", i));
			}
			writer.println();
			writer.print("|base |");
			for(int i = 0; i < n; i++){
				if(getBase(i) == BASE_EMPTY){
					writer.print("N/A|");
				} else{
					writer.print(String.format("%3d|", getBase(i)));
				}
			}
			writer.println();
			writer.print("|check|");
			for(int i = 0; i < n; i++){
				if(getCheck(i) < 0){
					writer.print("N/A|");
				} else{
					writer.print(String.format("%3d|", getCheck(i)));
				}
			}
			writer.println();
			writer.print("|term |");
			for(int i = 0; i < n; i++){
				writer.print(String.format("%3d|", term.get(i) ? 1 : 0));
			}
			writer.println();
			writer.print("chars: ");
			int c = 0;
			for(char e : chars){
				writer.print(String.format("%c:%d,", e, (int)charToCode[e]));
				c++;
				if(c > 16) break;
			}
			writer.println();
			writer.println("chars count: " + chars.size());
			writer.println();
		} finally{
			writer.flush();
		}
	}

	/**
	 * Places {@code node}, whose first letter already occupies slot
	 * {@code nodeIndex}, and then recursively places its subtree.
	 * Letters after the first each get their own slot, chained from the
	 * previous one. The node's children are placed together at a common base
	 * offset. Children with more grandchildren are recursed into first.
	 */
	private void build(Node node, int nodeIndex,
			FastBitSet bs, TermNodeListener listener){
		// letters
		char[] letters = node.getLetters();
		int lettersLen = letters.length;
		if(lettersLen > 0) nodeSize++; // for first letter
		for(int i = 1; i < lettersLen; i++){
			bs.unsetIfLE(nodeIndex);
			int cid = getCharId(letters[i]);
			int empty = findFirstEmptyCheck();
			setCheck(empty, nodeIndex);
			setBaseFast(nodeIndex, empty - cid);
			nodeSize++;
			nodeIndex = empty;
		}
		if(node.isTerminate()){
			bs.set(nodeIndex);
			listener.listen(node, nodeIndex);
		} else{
			bs.unsetIfLE(nodeIndex);
		}

		// children
		Node[] children = node.getChildren();
		if(children == null || children.length == 0) return;
		int childrenLen = children.length;
		int[] heads = new int[childrenLen];
		int maxHead = 0;
		int minHead = Integer.MAX_VALUE;
		for(int i = 0; i < childrenLen; i++){
			heads[i] = getCharId(children[i].getLetters()[0]);
			maxHead = Math.max(maxHead, heads[i]);
			minHead = Math.min(minHead, heads[i]);
		}

		int offset = findInsertOffset(heads, minHead, maxHead);
		setBaseFast(nodeIndex, offset);
		for(int cid : heads){
			setCheck(offset + cid, nodeIndex);
		}

		// Group children by their own child count, largest count first.
		Map<Integer, List<Pair<Node, Integer>>> byChildCount = new TreeMap<Integer, List<Pair<Node, Integer>>>(new Comparator<Integer>() {
			@Override
			public int compare(Integer a, Integer b) {
				return b.compareTo(a);
			}
		});
		for(int i = 0; i < childrenLen; i++){
			Node[] grandChildren = children[i].getChildren();
			int count = grandChildren == null ? 0 : grandChildren.length;
			List<Pair<Node, Integer>> group = byChildCount.get(count);
			if(group == null){
				group = new ArrayList<Pair<Node, Integer>>();
				byChildCount.put(count, group);
			}
			group.add(Pair.create(children[i], heads[i]));
		}
		for(List<Pair<Node, Integer>> group : byChildCount.values()){
			for(Pair<Node, Integer> child : group){
				build(child.getFirst(), child.getSecond() + offset, bs, listener);
			}
		}
	}

	private DoubleArrayNode newDoubleArrayNode(int id){
		return new DoubleArrayNode(id);
	}

	private DoubleArrayNode newDoubleArrayNode(int id, char s){
		return new DoubleArrayNode(id, s);
	}

	/**
	 * Returns the slot of the child of {@code nodeIndex} reached by {@code c},
	 * or -1 if there is none. Bases may be negative, so the computed slot is
	 * bounds-checked on both sides.
	 */
	private int childIndex(int nodeIndex, char c){
		int code = charToCode[c];
		if(code == 0) return -1; // character never seen during construction
		int base = getBase(nodeIndex);
		if(base == BASE_EMPTY) return -1; // node has no children
		int next = base + code;
		if(next < 0 || next >= getBaseAndCheckLength()) return -1;
		return getCheck(next) == nodeIndex ? next : -1;
	}

	/** Returns (matched prefix, slot) for every terminal node along {@code query}'s path. */
	private List<Pair<String, Integer>> commonPrefixNodes(String query){
		List<Pair<String, Integer>> ret = new ArrayList<Pair<String, Integer>>();
		int nodeIndex = 0;
		for(int i = 0, n = query.length(); i < n; i++){
			nodeIndex = childIndex(nodeIndex, query.charAt(i));
			if(nodeIndex == -1) break;
			if(term.get(nodeIndex)){
				ret.add(Pair.create(query.substring(0, i + 1), nodeIndex));
			}
		}
		return ret;
	}

	/**
	 * Returns (word, slot) for every terminal node in the subtree under
	 * {@code prefix}. The prefix itself comes first when it is terminal; the
	 * rest follow in depth-first order.
	 */
	private List<Pair<String, Integer>> predictiveNodes(String prefix){
		List<Pair<String, Integer>> ret = new ArrayList<Pair<String, Integer>>();
		if(nodeSize == 0) return ret; // neither built nor loaded
		int start = getNodeId(prefix);
		if(start == -1) return ret;
		if(term.get(start)){
			ret.add(Pair.create(prefix, start));
		}
		Deque<Pair<Integer, String>> stack = new ArrayDeque<Pair<Integer, String>>();
		stack.push(Pair.create(start, prefix));
		while(!stack.isEmpty()){
			Pair<Integer, String> top = stack.pop();
			int parent = top.getFirst();
			if(getBase(parent) == BASE_EMPTY) continue; // leaf
			String path = top.getSecond();
			for(char c : chars){
				int child = childIndex(parent, c);
				if(child == -1) continue;
				String word = path + c;
				if(term.get(child)){
					ret.add(Pair.create(word, child));
				}
				stack.push(Pair.create(child, word));
			}
		}
		return ret;
	}

	/**
	 * Finds a base offset such that {@code offset + cid} is free for every
	 * {@code cid} in {@code heads}, growing the array as needed.
	 */
	private int findInsertOffset(int[] heads, int minHead, int maxHead){
		for(int empty = findFirstEmptyCheck(); ; empty = findNextEmptyCheck(empty)){
			int offset = empty - minHead;
			if((offset + maxHead) >= getBaseAndCheckLength()){
				extend(offset + maxHead);
			}
			// find space
			boolean found = true;
			for(int cid : heads){
				if(getCheck(offset + cid) >= 0){
					found = false;
					break;
				}
			}
			if(found) return offset;
		}
	}

	/** Returns the code of {@code c}, assigning the next free code (starting at 1) if it is new. */
	private int getCharId(char c){
		char v = charToCode[c];
		if(v != 0) return v;
		v = (char)(chars.size() + 1);
		chars.add(c);
		charToCode[c] = v;
		return v;
	}

	/**
	 * Grows the array so that slot {@code i} exists. It grows by at least 0xFFFF
	 * extra slots and at least 1.5 times the current length.
	 */
	private void extend(int i){
		int sz = getBaseAndCheckLength();
		resize(Math.max(i + 0xFFFF, (int)(sz * 1.5)));
	}

	/**
	 * Resizes the array to {@code newSize} slots. Any newly added slots are
	 * marked empty. (Plain zero padding would read as check == 0, i.e.
	 * "child of the root".)
	 */
	private void resize(int newSize){
		int oldSize = getBaseAndCheckLength();
		baseAndCheckInt = Arrays.copyOf(baseAndCheckInt, newSize * 2);
		for(int i = oldSize; i < newSize; i++){
			setBaseFast(i, BASE_EMPTY);
			setCheckFast(i, CHECK_EMPTY);
		}
	}

	/**
	 * Minimum slot count for lookups to stay in range. Every placed base is
	 * below {@code last} and every code is at most {@code chars.size()}; the
	 * extra slot keeps index 0 present even for a childless root.
	 */
	private int requiredLength(){
		return last + 1 + chars.size();
	}

	/** Returns the highest slot holding a non-negative check (0 if none). */
	private int findLastUsedIndex(){
		for(int i = getBaseAndCheckLength() - 1; i > 0; i--){
			if(getCheck(i) >= 0) return i;
		}
		return 0;
	}

	/**
	 * Returns the first slot at or after {@code firstEmptyCheck} that is fully
	 * empty (negative check and empty base), growing the array if needed.
	 */
	private int findFirstEmptyCheck(){
		int i = firstEmptyCheck;
		while(true){
			if(i >= getBaseAndCheckLength()) extend(i);
			if(getCheck(i) < 0 && getBase(i) == BASE_EMPTY) break;
			i++;
		}
		firstEmptyCheck = i;
		return i;
	}

	/**
	 * Returns the next empty slot after the empty slot {@code i}.
	 * The negative check of {@code i} is used as a skip distance, and that
	 * hint is updated when a longer scan was needed.
	 *
	 * @throws IllegalStateException if slot {@code i} is not empty
	 */
	private int findNextEmptyCheck(int i){
		int d = -getCheck(i);
		if(d <= 0){
			throw new IllegalStateException(
					"slot " + i + " is not empty (check=" + getCheck(i) + ")");
		}
		int prev = i;
		i += d;
		if(getBaseAndCheckLength() <= i){
			extend(i);
			return i;
		}
		if(getCheck(i) < 0){
			return i;
		}
		for(i++; i < getBaseAndCheckLength(); i++){
			if(getCheck(i) < 0){
				setCheckFast(prev, prev - i);
				return i;
			}
		}
		extend(i);
		setCheckFast(prev, prev - i);
		return i;
	}

	private int getBaseAndCheckLength(){
		return baseAndCheckInt.length / 2;
	}

	private int getBase(int index){
		return baseAndCheckInt[index * 2];
	}
	private void setBaseFast(int index, int value){
		baseAndCheckInt[index * 2] = value;
	}

	private int getCheck(int index){
		return baseAndCheckInt[index * 2 + 1];
	}
	private void setCheckFast(int index, int value){
		baseAndCheckInt[index * 2 + 1] = value;
	}
	/** Occupies slot {@code index}, advancing the first-empty cursor and {@code last}. */
	private void setCheck(int index, int value){
		if(firstEmptyCheck == index){
			firstEmptyCheck = findNextEmptyCheck(firstEmptyCheck);
		}
		baseAndCheckInt[index * 2 + 1] = value;
		last = Math.max(last, index);
	}

	private int size;
	private int nodeSize;
	private int[] baseAndCheckInt;
	private int firstEmptyCheck = 1;
	// Highest slot ever passed to setCheck.
	private int last;
	private SuccinctBitVector term;
	// Sorted alphabet; also fixes the order children are enumerated in.
	private final Set<Character> chars = new TreeSet<Character>();
	// One entry per possible char value, including Character.MAX_VALUE itself.
	private final char[] charToCode = new char[Character.MAX_VALUE + 1];
	private static final int BASE_EMPTY = 0x7fffffff;
	private static final int CHECK_EMPTY = 0xffffffff;
	private static final DoubleArrayNode[] emptyNodes = {};
}
