/*
 * Decompiled with CFR 0.152.
 */
package models;

import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.TreeMap;
import utility.FuncUtils;

/**
 * Word Network Topic Model (WNTM) trained with collapsed Gibbs sampling.
 *
 * <p>The corpus file is read one document per non-blank line, with tokens separated by
 * whitespace. A weighted word co-occurrence graph is built using a sliding window of
 * {@code windowSize} tokens. Every word node that has at least one neighbour becomes a
 * pseudo-document made of its re-weighted neighbours, and LDA-style Gibbs sampling is run over
 * those pseudo-documents. Results are written under {@code results/} using {@code expName} as
 * the file prefix.
 */
public class WNTM {
    /** Dirichlet smoothing added to every pseudo-document/topic count. */
    public double alpha;
    /** Dirichlet smoothing added to every topic/word count. */
    public double beta;
    public int numTopics;
    public int numIterations;
    /** Number of words written per topic by {@link #writeTopTopicalWords()}. */
    public int topWords;
    /** {@code numTopics * alpha}. */
    public double alphaSum;
    /** {@code vocabularySize * beta}; smoothing term in the denominator of the word likelihood. */
    public double betaSum;
    /** Original documents, each a list of word ids in reading order. */
    public List<List<Integer>> corpus;
    /** Adjacency map: word id -> (neighbour word id -> edge weight). */
    Map<Integer, Map<Integer, Double>> wordGraph = new HashMap<Integer, Map<Integer, Double>>();
    /** Word id -> number of co-occurrence increments recorded for it by {@link #addEdge(int, int)}. */
    Map<Integer, Integer> wordDegree = new HashMap<Integer, Integer>();
    /** One pseudo-document per graph node (in wordGraph iteration order), holding neighbour word ids. */
    public List<List<Integer>> pseudo_corpus;
    /** Topic assignment of each token of each pseudo-document. */
    public int[][] z;
    public int numDocuments;
    public int numPseudoDocuments;
    public int numWordsInCorpus;
    public HashMap<String, Integer> word2IdVocabulary;
    public HashMap<Integer, String> id2WordVocabulary;
    public int vocabularySize;
    /** [pseudo-document][topic] token counts. */
    public int[][] pseudocTopicCount;
    /** [topic][word] token counts. */
    public int[][] topicWordCount;
    /** Per-topic total token count (row sums of topicWordCount). */
    public int[] sumTopicWordCount;
    /** Scratch buffer of unnormalised topic weights passed to FuncUtils.nextDiscrete. */
    public double[] multiPros;
    /** Topic/word probabilities; filled only by {@link #writeTopicWordPros()}. */
    public double[][] phi;
    /** Sliding-window length (in tokens) used to build the co-occurrence graph. */
    public int windowSize;
    public String folderPath;
    public String corpusPath;
    public String expName = "WNTMmodel";
    public String orgExpName = "WNTMmodel";
    /** Recorded and echoed into the .paras file; not read by this class. */
    public String tAssignsFilePath = "";
    /** Recorded and echoed into the .paras file; not otherwise used by this class. */
    public int savestep = 0;
    /** Milliseconds spent building the graph, pseudo-corpus and initial assignments. */
    public double initTime = 0.0;
    /** Milliseconds spent in the Gibbs sampling loop. */
    public double iterTime = 0.0;

    /** Creates a model with experiment name {@code "WNTMmodel"}. */
    public WNTM(String pathToCorpus, int inNumTopics, double inAlpha, double inBeta, int inNumIterations, int inTopWords, int inWindowSize) throws Exception {
        this(pathToCorpus, inNumTopics, inAlpha, inBeta, inNumIterations, inTopWords, inWindowSize, "WNTMmodel");
    }

    /** Creates a model with the given experiment name and no save step. */
    public WNTM(String pathToCorpus, int inNumTopics, double inAlpha, double inBeta, int inNumIterations, int inTopWords, int inWindowSize, String inExpName) throws Exception {
        this(pathToCorpus, inNumTopics, inAlpha, inBeta, inNumIterations, inTopWords, inWindowSize, inExpName, "", 0);
    }

    /** Creates a model that records {@code pathToTAfile} in its parameter file. */
    public WNTM(String pathToCorpus, int inNumTopics, double inAlpha, double inBeta, int inNumIterations, int inTopWords, int inWindowSize, String inExpName, String pathToTAfile) throws Exception {
        this(pathToCorpus, inNumTopics, inAlpha, inBeta, inNumIterations, inTopWords, inWindowSize, inExpName, pathToTAfile, 0);
    }

    /** Creates a model that records {@code inSaveStep} in its parameter file. */
    public WNTM(String pathToCorpus, int inNumTopics, double inAlpha, double inBeta, int inNumIterations, int inTopWords, int inWindowSize, String inExpName, int inSaveStep) throws Exception {
        this(pathToCorpus, inNumTopics, inAlpha, inBeta, inNumIterations, inTopWords, inWindowSize, inExpName, "", inSaveStep);
    }

    /**
     * Reads the corpus, builds the word graph and pseudo-corpus, and randomly initialises topic
     * assignments.
     *
     * @param pathToCorpus text file with one whitespace-tokenised document per line
     * @param inNumTopics number of topics
     * @param inAlpha smoothing on pseudo-document/topic counts
     * @param inBeta smoothing on topic/word counts
     * @param inNumIterations number of Gibbs sweeps run by {@link #inference()}
     * @param inTopWords words per topic written to the .topWords file
     * @param inWindowSize co-occurrence window length in tokens
     * @param inExpName prefix for all output files
     * @param pathToTAfile only recorded in the .paras file; assignments are always random
     * @param inSaveStep only recorded in the .paras file
     * @throws IOException if the corpus file cannot be read
     */
    public WNTM(String pathToCorpus, int inNumTopics, double inAlpha, double inBeta, int inNumIterations, int inTopWords, int inWindowSize, String inExpName, String pathToTAfile, int inSaveStep) throws Exception {
        this.alpha = inAlpha;
        this.beta = inBeta;
        this.numTopics = inNumTopics;
        this.windowSize = inWindowSize;
        this.numIterations = inNumIterations;
        this.topWords = inTopWords;
        this.savestep = inSaveStep;
        this.orgExpName = this.expName = inExpName;
        this.corpusPath = pathToCorpus;
        this.folderPath = "results/";
        File dir = new File(this.folderPath);
        if (!dir.exists()) {
            dir.mkdir();
        }
        System.out.println("Reading topic modeling corpus: " + pathToCorpus);
        this.word2IdVocabulary = new HashMap<String, Integer>();
        this.id2WordVocabulary = new HashMap<Integer, String>();
        this.corpus = new ArrayList<List<Integer>>();
        this.pseudo_corpus = new ArrayList<List<Integer>>();
        this.numDocuments = 0;
        this.numWordsInCorpus = 0;
        // Let read failures propagate: training on an empty or partial corpus silently is worse.
        this.readCorpus(pathToCorpus);
        this.vocabularySize = this.word2IdVocabulary.size();
        long startTime = System.currentTimeMillis();
        this.constructWordGraph();
        this.constructPseudoCorpus();
        this.numPseudoDocuments = this.pseudo_corpus.size();
        this.pseudocTopicCount = new int[this.numPseudoDocuments][this.numTopics];
        this.topicWordCount = new int[this.numTopics][this.vocabularySize];
        this.sumTopicWordCount = new int[this.numTopics];
        this.phi = new double[this.numTopics][this.vocabularySize];
        this.multiPros = new double[this.numTopics];
        // Uniform distribution used for the random initial assignments.
        for (int i = 0; i < this.numTopics; ++i) {
            this.multiPros[i] = 1.0 / (double) this.numTopics;
        }
        this.alphaSum = (double) this.numTopics * this.alpha;
        this.betaSum = (double) this.vocabularySize * this.beta;
        this.tAssignsFilePath = pathToTAfile;
        // Assignments are always initialised randomly; tAssignsFilePath is not loaded.
        this.initialize();
        this.initTime = System.currentTimeMillis() - startTime;
        System.out.println("Corpus size: " + this.numDocuments + " docs, " + this.numWordsInCorpus + " words");
        System.out.println("Vocabuary size: " + this.vocabularySize);
        System.out.println("Number of topics: " + this.numTopics);
        System.out.println("alpha: " + this.alpha);
        System.out.println("beta: " + this.beta);
        System.out.println("Number of sampling iterations: " + this.numIterations);
        System.out.println("Number of top topical words: " + this.topWords);
    }

    /**
     * Reads one document per non-blank line, assigning consecutive word ids (from 0) in
     * first-seen order, and appends each document to {@link #corpus}.
     */
    private void readCorpus(String pathToCorpus) throws IOException {
        try (BufferedReader br = new BufferedReader(new FileReader(pathToCorpus))) {
            String line;
            while ((line = br.readLine()) != null) {
                String trimmed = line.trim();
                if (trimmed.isEmpty()) {
                    continue;
                }
                String[] words = trimmed.split("\\s+");
                List<Integer> document = new ArrayList<Integer>(words.length);
                for (String word : words) {
                    Integer id = this.word2IdVocabulary.get(word);
                    if (id == null) {
                        id = this.word2IdVocabulary.size();
                        this.word2IdVocabulary.put(word, id);
                        this.id2WordVocabulary.put(id, word);
                    }
                    document.add(id);
                }
                ++this.numDocuments;
                this.numWordsInCorpus += document.size();
                this.corpus.add(document);
            }
        }
    }

    /** Returns whether both directions of the edge p1-p2 are present in the graph. */
    public boolean containsEdge(int p1, int p2) {
        if (!this.wordGraph.containsKey(p1)) {
            return false;
        }
        if (!this.wordGraph.containsKey(p2)) {
            return false;
        }
        if (!this.wordGraph.get(p1).containsKey(p2)) {
            return false;
        }
        return this.wordGraph.get(p2).containsKey(p1);
    }

    /**
     * Records one co-occurrence of {@code p1} and {@code p2}.
     *
     * <p>The edge weight grows by 1 in both directions, and each endpoint's degree grows by 1.
     * A self co-occurrence ({@code p1 == p2}) increments its loop weight and degree once, so a
     * word's degree always equals the number of increments recorded on its edges.
     */
    public void addEdge(int p1, int p2) {
        this.ensureNode(p1);
        this.ensureNode(p2);
        this.incrementEdge(p1, p2);
        this.incrementDegree(p1);
        if (p1 != p2) {
            this.incrementEdge(p2, p1);
            this.incrementDegree(p2);
        }
    }

    /** Adds an empty adjacency map and zero degree for {@code wordId} if it is not yet a node. */
    private void ensureNode(int wordId) {
        if (!this.wordGraph.containsKey(wordId)) {
            this.wordGraph.put(wordId, new HashMap<Integer, Double>());
            this.wordDegree.put(wordId, 0);
        }
    }

    /** Adds 1.0 to the directed weight from -> to, creating the entry at 1.0 if absent. */
    private void incrementEdge(int from, int to) {
        Map<Integer, Double> adjacency = this.wordGraph.get(from);
        Double weight = adjacency.get(to);
        adjacency.put(to, weight == null ? 1.0 : weight + 1.0);
    }

    /** Adds 1 to the degree of an existing node. */
    private void incrementDegree(int wordId) {
        this.wordDegree.put(wordId, this.wordDegree.get(wordId) + 1);
    }

    /** Prints each non-empty adjacency list as "word -> neighbour(weight) ..." to stdout. */
    public void show() {
        for (Map.Entry<Integer, Map<Integer, Double>> e : this.wordGraph.entrySet()) {
            Set<Map.Entry<Integer, Double>> temp = e.getValue().entrySet();
            if (temp.isEmpty()) {
                continue;
            }
            System.out.print(this.id2WordVocabulary.get(e.getKey()) + " -> ");
            for (Map.Entry<Integer, Double> e1 : temp) {
                System.out.print(this.id2WordVocabulary.get(e1.getKey()) + "(" + e1.getValue() + ") ");
            }
            System.out.println();
        }
    }

    /**
     * Builds the co-occurrence graph from {@link #corpus}.
     *
     * <p>A document no longer than {@code windowSize} contributes every ordered pair (k < m) once.
     * A longer document contributes every pair inside each window of {@code windowSize}
     * consecutive tokens. Overlapping windows therefore count a nearby pair several times.
     */
    public void constructWordGraph() {
        for (List<Integer> doc : this.corpus) {
            int docSize = doc.size();
            if (docSize <= this.windowSize) {
                this.addPairsInRange(doc, 0, docSize);
                continue;
            }
            for (int start = 0; start < docSize - this.windowSize + 1; ++start) {
                this.addPairsInRange(doc, start, start + this.windowSize);
            }
        }
    }

    /** Calls {@link #addEdge(int, int)} for every index pair from <= k < m < to of {@code doc}. */
    private void addPairsInRange(List<Integer> doc, int from, int to) {
        for (int k = from; k < to - 1; ++k) {
            int wordId = doc.get(k);
            for (int m = k + 1; m < to; ++m) {
                this.addEdge(wordId, doc.get(m));
            }
        }
    }

    /** Average edge weight of a node: its degree divided by its number of distinct neighbours. */
    private double wordActivity(int wordId) {
        return (double) this.wordDegree.get(wordId) / (double) this.wordGraph.get(wordId).size();
    }

    /**
     * Builds one pseudo-document per graph node with at least one neighbour.
     *
     * <p>Each neighbour is repeated {@code ceil(weight / min(activity(word), activity(neighbour)))}
     * times. That re-weighted count then overwrites the node's own outgoing edge weight. Pseudo-
     * documents follow wordGraph iteration order, and no pseudo-document -> word id mapping is kept.
     */
    public void constructPseudoCorpus() {
        for (Map.Entry<Integer, Map<Integer, Double>> node : this.wordGraph.entrySet()) {
            Map<Integer, Double> neighbours = node.getValue();
            if (neighbours.isEmpty()) {
                continue;
            }
            double nodeActivity = this.wordActivity(node.getKey());
            List<Integer> onePseudo = new ArrayList<Integer>();
            for (Map.Entry<Integer, Double> edge : neighbours.entrySet()) {
                int neighbourId = edge.getKey();
                // Per-edge minimum: must not leak into the next neighbour's computation.
                double edgeActivity = Math.min(nodeActivity, this.wordActivity(neighbourId));
                int reweight = (int) Math.ceil(edge.getValue() / edgeActivity);
                for (int i = 0; i < reweight; ++i) {
                    onePseudo.add(neighbourId);
                }
                // Only this node's outgoing entry is rewritten; the reverse edge keeps its count.
                edge.setValue((double) reweight);
            }
            this.pseudo_corpus.add(onePseudo);
        }
    }

    /**
     * Assigns every pseudo-document token a topic drawn from {@link #multiPros} (uniform at this
     * point) and fills the count matrices accordingly.
     */
    public void initialize() throws IOException {
        System.out.println("Randomly initializing topic assignments ...");
        this.z = new int[this.numPseudoDocuments][];
        for (int i = 0; i < this.numPseudoDocuments; ++i) {
            List<Integer> pseudoDoc = this.pseudo_corpus.get(i);
            int docSize = pseudoDoc.size();
            this.z[i] = new int[docSize];
            for (int j = 0; j < docSize; ++j) {
                int topic = FuncUtils.nextDiscrete(this.multiPros);
                ++this.pseudocTopicCount[i][topic];
                ++this.topicWordCount[topic][pseudoDoc.get(j)];
                ++this.sumTopicWordCount[topic];
                this.z[i][j] = topic;
            }
        }
    }

    /**
     * Writes the vocabulary, runs {@code numIterations} Gibbs sweeps (printing every 50th
     * iteration number) and writes all outputs.
     */
    public void inference() throws IOException {
        this.writeDictionary();
        System.out.println("Running Gibbs sampling inference: ");
        long startTime = System.currentTimeMillis();
        for (int iter = 1; iter <= this.numIterations; ++iter) {
            if (iter % 50 == 0) {
                System.out.print(" " + iter);
            }
            this.sampleInSingleIteration();
        }
        this.expName = this.orgExpName;
        this.iterTime = System.currentTimeMillis() - startTime;
        System.out.println();
        System.out.println("Writing output from the last sample ...");
        this.write();
        System.out.println("Sampling completed!");
    }

    /**
     * Resamples the topic of token {@code n} in pseudo-document {@code m}.
     *
     * <p>The token's current assignment is removed from the counts. A new topic is drawn with
     * weight {@code (n_kw + beta) / (n_k + betaSum) * (n_mk + alpha)}, added back to the counts
     * and returned.
     */
    private int sampleFullConditional(int m, int n) {
        int topic = this.z[m][n];
        int wordId = this.pseudo_corpus.get(m).get(n);
        --this.topicWordCount[topic][wordId];
        --this.pseudocTopicCount[m][topic];
        --this.sumTopicWordCount[topic];
        for (int k = 0; k < this.numTopics; ++k) {
            this.multiPros[k] = ((double) this.topicWordCount[k][wordId] + this.beta) / ((double) this.sumTopicWordCount[k] + this.betaSum) * ((double) this.pseudocTopicCount[m][k] + this.alpha);
        }
        topic = FuncUtils.nextDiscrete(this.multiPros);
        ++this.topicWordCount[topic][wordId];
        ++this.pseudocTopicCount[m][topic];
        ++this.sumTopicWordCount[topic];
        return topic;
    }

    /** Performs one full Gibbs sweep over every token of every pseudo-document. */
    public void sampleInSingleIteration() {
        for (int dIndex = 0; dIndex < this.numPseudoDocuments; ++dIndex) {
            int docSize = this.pseudo_corpus.get(dIndex).size();
            for (int wIndex = 0; wIndex < docSize; ++wIndex) {
                this.z[dIndex][wIndex] = this.sampleFullConditional(dIndex, wIndex);
            }
        }
    }

    /** Writes the run configuration and timings to {@code <folderPath><expName>.paras}. */
    public void writeParameters() throws IOException {
        try (BufferedWriter writer = new BufferedWriter(new FileWriter(this.folderPath + this.expName + ".paras"))) {
            writer.write("-model\tWNTM");
            writer.write("\n-corpus\t" + this.corpusPath);
            writer.write("\n-ntopics\t" + this.numTopics);
            writer.write("\n-alpha\t" + this.alpha);
            writer.write("\n-beta\t" + this.beta);
            writer.write("\n-niters\t" + this.numIterations);
            writer.write("\n-twords\t" + this.topWords);
            writer.write("\n-name\t" + this.expName);
            if (this.tAssignsFilePath.length() > 0) {
                writer.write("\n-initFile\t" + this.tAssignsFilePath);
            }
            if (this.savestep > 0) {
                writer.write("\n-sstep\t" + this.savestep);
            }
            writer.write("\n-initiation time\t" + this.initTime);
            writer.write("\n-one iteration time\t" + this.iterTime / (double) this.numIterations);
            writer.write("\n-total time\t" + (this.initTime + this.iterTime));
        }
    }

    /** Writes one "word id" line per vocabulary entry to {@code <folderPath><expName>.vocabulary}. */
    public void writeDictionary() throws IOException {
        try (BufferedWriter writer = new BufferedWriter(new FileWriter(this.folderPath + this.expName + ".vocabulary"))) {
            for (int id = 0; id < this.vocabularySize; ++id) {
                writer.write(this.id2WordVocabulary.get(id) + " " + id + "\n");
            }
        }
    }

    /**
     * Writes one line per topic to {@code <folderPath><expName>.topWords}. Each line holds the
     * topic's {@code topWords} most frequent words (fewer if the vocabulary is smaller), each
     * followed by a space.
     */
    public void writeTopTopicalWords() throws IOException {
        try (BufferedWriter writer = new BufferedWriter(new FileWriter(this.folderPath + this.expName + ".topWords"))) {
            for (int tIndex = 0; tIndex < this.numTopics; ++tIndex) {
                Map<Integer, Integer> wordCount = new TreeMap<Integer, Integer>();
                for (int wIndex = 0; wIndex < this.vocabularySize; ++wIndex) {
                    wordCount.put(wIndex, this.topicWordCount[tIndex][wIndex]);
                }
                Map<Integer, Integer> sorted = FuncUtils.sortByValueDescending(wordCount);
                int count = 0;
                for (Integer index : sorted.keySet()) {
                    if (count >= this.topWords) {
                        break;
                    }
                    writer.write(this.id2WordVocabulary.get(index) + " ");
                    ++count;
                }
                // Always terminate the line, even when the vocabulary has <= topWords words.
                writer.write("\n");
            }
        }
    }

    /**
     * Computes {@code phi[k][w] = (n_kw + beta) / (n_k + betaSum)}, stores it in {@link #phi}
     * and writes it row by row to {@code <folderPath><expName>.phi}.
     */
    public void writeTopicWordPros() throws IOException {
        try (BufferedWriter writer = new BufferedWriter(new FileWriter(this.folderPath + this.expName + ".phi"))) {
            for (int i = 0; i < this.numTopics; ++i) {
                for (int j = 0; j < this.vocabularySize; ++j) {
                    double pro = ((double) this.topicWordCount[i][j] + this.beta) / ((double) this.sumTopicWordCount[i] + this.betaSum);
                    this.phi[i][j] = pro;
                    writer.write(pro + " ");
                }
                writer.write("\n");
            }
        }
    }

    /**
     * Writes, for each original document and topic, the mean of {@code phi[topic][w]} over the
     * document's tokens to {@code <folderPath><expName>.theta}.
     *
     * <p>Reads {@link #phi}, so {@link #writeTopicWordPros()} must run first (as {@link #write()}
     * does). NOTE(review): this averages p(w|z) rather than p(z|w), and rows are not normalised
     * over topics; confirm this is the intended document/topic estimate.
     */
    public void writeDocTopicPros() throws IOException {
        try (BufferedWriter writer = new BufferedWriter(new FileWriter(this.folderPath + this.expName + ".theta"))) {
            for (List<Integer> doc : this.corpus) {
                int len = doc.size();
                for (int j = 0; j < this.numTopics; ++j) {
                    double pro = 0.0;
                    for (int wIndex = 0; wIndex < len; ++wIndex) {
                        pro += this.phi[j][doc.get(wIndex)] / (double) len;
                    }
                    writer.write(pro + " ");
                }
                writer.write("\n");
            }
        }
    }

    /** Writes top words, phi, theta and parameters, in an order that fills phi before theta. */
    public void write() throws IOException {
        this.writeTopTopicalWords();
        this.writeTopicWordPros();
        this.writeDocTopicPros();
        this.writeParameters();
    }

    public static void main(String[] args) throws Exception {
        WNTM wntm = new WNTM("dataset/Pascal_Flickr.txt", 100, 0.1, 0.01, 1000, 10, 10, "Pascal_FlickrWNTM");
        wntm.inference();
    }
}

