/*
 * Decompiled with CFR 0.152.
 */
package models;

import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.TreeMap;
import utility.FuncUtils;

/**
 * Gibbs sampler for a mixture model in which every document is assigned a
 * single topic ({@link #z} holds one topic index per document).
 *
 * <p>The constructor reads a whitespace-tokenised corpus (one document per
 * non-blank line), builds the vocabulary and draws an initial random topic
 * for every document. {@link #inference()} then runs the sampling iterations
 * and writes the result files into {@link #folderPath}.
 *
 * <p>Instances are not thread-safe: all count arrays and the scratch buffer
 * {@link #multiPros} are mutated in place while sampling and while writing.
 */
public class DMM {
    /** Vocabulary size: number of distinct tokens read from the corpus. */
    int V;
    /** Number of topics. */
    int K;
    /** Smoothing constant added to the per-topic document counts. */
    double alpha = 0.1;
    /** Smoothing constant added to the per-topic word counts. */
    double beta = 0.1;
    /** z[d] is the topic currently assigned to document d. */
    int[] z;
    /** m_z[k] is the number of documents currently assigned to topic k. */
    int[] m_z;
    /** n_z[k] is the total number of tokens in documents assigned to topic k. */
    int[] n_z;
    /** n_w_z[k][w] is the number of occurrences of word w in documents assigned to topic k. */
    int[][] n_w_z;
    /** N_d[d] is the number of tokens in document d. */
    int[] N_d;
    /** Number of highest-count words written per topic by {@link #writeTopTopicalWords()}. */
    public int topWords;
    /** Documents as lists of word ids, in corpus order. */
    public List<List<Integer>> corpus;
    public int numDocuments;
    public int numWordsInCorpus;
    public HashMap<String, Integer> word2IdVocabulary;
    public HashMap<Integer, String> id2WordVocabulary;
    /**
     * For every token position of every document: how many times that token's
     * word has occurred in the document up to and including this position
     * (1 for the first occurrence, 2 for the second, ...).
     */
    public List<List<Integer>> occurenceToIndexCount;
    /** Scratch buffer of unnormalised topic weights, length {@link #K}. */
    public double[] multiPros;
    public String folderPath;
    public String corpusPath;
    /** Number of sampling sweeps; per instance so that models do not overwrite each other's setting. */
    private int numIterations = 500;
    public String expName = "DMMmodel";
    public String orgExpName = "DMMmodel";
    /** Wall-clock milliseconds spent in {@link #initialize()}. */
    public double initTime = 0.0;
    /** Wall-clock milliseconds spent in the sampling loop of {@link #inference()}. */
    public double iterTime = 0.0;

    /**
     * Creates a model with the default experiment name {@code "DMMmodel"}.
     *
     * @see #DMM(String, int, double, double, int, int, String)
     */
    public DMM(String pathToCorpus, int inNumTopics, double inAlpha, double inBeta, int inNumIterations, int inTopWords) throws Exception {
        this(pathToCorpus, inNumTopics, inAlpha, inBeta, inNumIterations, inTopWords, "DMMmodel");
    }

    /**
     * Reads the corpus, allocates the count tables and randomly initialises
     * the topic assignments.
     *
     * @param pathToCorpus    text file with one whitespace-tokenised document per line
     * @param inNumTopics     number of topics
     * @param inAlpha         smoothing constant for the per-topic document counts
     * @param inBeta          smoothing constant for the per-topic word counts
     * @param inNumIterations number of sampling sweeps run by {@link #inference()}
     * @param inTopWords      number of top words written per topic
     * @param inExpName       prefix of every output file name
     * @throws Exception in particular an {@link IOException} if the output
     *                   folder cannot be created or the corpus cannot be read
     */
    public DMM(String pathToCorpus, int inNumTopics, double inAlpha, double inBeta, int inNumIterations, int inTopWords, String inExpName) throws Exception {
        this.alpha = inAlpha;
        this.beta = inBeta;
        this.K = inNumTopics;
        this.numIterations = inNumIterations;
        this.topWords = inTopWords;
        this.expName = inExpName;
        this.orgExpName = inExpName;
        this.corpusPath = pathToCorpus;
        this.folderPath = "results/";
        // Fail before training starts rather than when the first result file is written.
        File dir = new File(this.folderPath);
        if (!dir.isDirectory() && !dir.mkdirs()) {
            throw new IOException("Cannot create output folder: " + dir.getAbsolutePath());
        }
        System.out.println("Reading topic modeling corpus: " + pathToCorpus);
        this.word2IdVocabulary = new HashMap<String, Integer>();
        this.id2WordVocabulary = new HashMap<Integer, String>();
        this.corpus = new ArrayList<List<Integer>>();
        this.occurenceToIndexCount = new ArrayList<List<Integer>>();
        this.numDocuments = 0;
        this.numWordsInCorpus = 0;
        this.readCorpus(pathToCorpus);
        this.V = this.word2IdVocabulary.size();
        this.m_z = new int[this.K];
        this.n_z = new int[this.K];
        this.n_w_z = new int[this.K][this.V];
        this.N_d = new int[this.numDocuments];
        this.multiPros = new double[this.K];
        System.out.println("Corpus size: " + this.numDocuments + " docs, " + this.numWordsInCorpus + " words");
        System.out.println("Vocabuary size: " + this.V);
        System.out.println("Number of topics: " + this.K);
        System.out.println("alpha: " + this.alpha);
        System.out.println("beta: " + this.beta);
        System.out.println("Number of sampling iterations: " + this.numIterations);
        System.out.println("Number of top topical words: " + this.topWords);
        this.initialize();
    }

    /**
     * Fills {@link #corpus}, {@link #occurenceToIndexCount} and both vocabulary
     * maps from the given file. Blank lines are skipped; word ids are handed
     * out from 0 in order of first appearance.
     */
    private void readCorpus(String pathToCorpus) throws IOException {
        int nextWordId = 0;
        try (BufferedReader reader = new BufferedReader(new FileReader(pathToCorpus))) {
            String line;
            while ((line = reader.readLine()) != null) {
                String trimmed = line.trim();
                if (trimmed.isEmpty()) {
                    continue;
                }
                String[] words = trimmed.split("\\s+");
                List<Integer> document = new ArrayList<Integer>(words.length);
                List<Integer> occurrenceIndexInDoc = new ArrayList<Integer>(words.length);
                Map<Integer, Integer> timesSeenInDoc = new HashMap<Integer, Integer>();
                for (String word : words) {
                    Integer wordId = this.word2IdVocabulary.get(word);
                    if (wordId == null) {
                        wordId = nextWordId++;
                        this.word2IdVocabulary.put(word, wordId);
                        this.id2WordVocabulary.put(wordId, word);
                    }
                    document.add(wordId);
                    // Count repeats of THIS word within the document. The key must be
                    // the current word's id, not the id of the last word added to the
                    // vocabulary, otherwise known words share one counter.
                    Integer seenBefore = timesSeenInDoc.get(wordId);
                    int times = seenBefore == null ? 1 : seenBefore + 1;
                    timesSeenInDoc.put(wordId, times);
                    occurrenceIndexInDoc.add(times);
                }
                ++this.numDocuments;
                this.numWordsInCorpus += document.size();
                this.corpus.add(document);
                this.occurenceToIndexCount.add(occurrenceIndexInDoc);
            }
        }
    }

    /**
     * Draws a topic for every document uniformly at random and rebuilds all
     * count tables from those assignments. Safe to call again: the counts and
     * the sampling weights are reset first.
     */
    public void initialize() throws IOException {
        System.out.println("Randomly initialzing topic assignments ...");
        long startTime = System.currentTimeMillis();
        // Start from a clean slate so that a repeated call does not accumulate
        // counts or sample from the weights left behind by the last sweep.
        java.util.Arrays.fill(this.m_z, 0);
        java.util.Arrays.fill(this.n_z, 0);
        for (int[] wordCountsOfTopic : this.n_w_z) {
            java.util.Arrays.fill(wordCountsOfTopic, 0);
        }
        java.util.Arrays.fill(this.multiPros, 1.0 / (double)this.K);
        this.z = new int[this.numDocuments];
        for (int d = 0; d < this.numDocuments; ++d) {
            List<Integer> doc = this.corpus.get(d);
            int topic = FuncUtils.nextDiscrete(this.multiPros);
            this.z[d] = topic;
            this.m_z[topic]++;
            for (int wordId : doc) {
                this.n_w_z[topic][wordId]++;
            }
            this.n_z[topic] += doc.size();
            this.N_d[d] = doc.size();
        }
        this.initTime = System.currentTimeMillis() - startTime;
    }

    /**
     * Writes the vocabulary file, runs the configured number of sampling
     * sweeps over all documents, then writes every result file and prints the
     * timings. Progress is printed every 50 sweeps.
     */
    public void inference() throws IOException {
        this.writeDictionary();
        System.out.println("Running Gibbs sampling inference: ");
        long startTime = System.currentTimeMillis();
        for (int iter = 1; iter <= this.numIterations; ++iter) {
            if (iter % 50 == 0) {
                System.out.print(" " + iter);
            }
            for (int d = 0; d < this.numDocuments; ++d) {
                this.z[d] = this.sampleFullConditional(d);
            }
        }
        this.iterTime = System.currentTimeMillis() - startTime;
        this.expName = this.orgExpName;
        System.out.println();
        System.out.println("Writing output from the last sample ...");
        this.write();
        System.out.println("Sampling completed!");
        System.out.println("the initiation tims: " + this.initTime);
        System.out.println("the iteration tims: " + this.iterTime);
    }

    /**
     * Resamples the topic of document {@code m}: removes the document from the
     * counts of its current topic, computes an unnormalised weight per topic,
     * draws a new topic from those weights and adds the document back.
     *
     * @param m document index
     * @return the newly drawn topic (the caller stores it in {@link #z})
     */
    private int sampleFullConditional(int m) {
        List<Integer> doc = this.corpus.get(m);
        List<Integer> occurrenceIndex = this.occurenceToIndexCount.get(m);
        int docSize = doc.size();

        int topic = this.z[m];
        this.m_z[topic]--;
        this.n_z[topic] -= this.N_d[m];
        for (int wordId : doc) {
            this.n_w_z[topic][wordId]--;
        }

        for (int tIndex = 0; tIndex < this.K; ++tIndex) {
            this.multiPros[tIndex] = (double)this.m_z[tIndex] + this.alpha;
            for (int wIndex = 0; wIndex < docSize; ++wIndex) {
                int word = doc.get(wIndex);
                // The j-th occurrence of a word in this document adds (j - 1) to the
                // numerator and the i-th token adds i to the denominator.
                // NOTE(review): this is a raw product of ratios, so it may underflow
                // for long documents - confirm the corpus only holds short texts.
                this.multiPros[tIndex] = this.multiPros[tIndex] * (((double)this.n_w_z[tIndex][word] + this.beta + (double)occurrenceIndex.get(wIndex).intValue() - 1.0) / ((double)this.n_z[tIndex] + this.beta * (double)this.V + (double)wIndex));
            }
        }

        topic = FuncUtils.nextDiscrete(this.multiPros);
        this.m_z[topic]++;
        this.n_z[topic] += this.N_d[m];
        for (int wordId : doc) {
            this.n_w_z[topic][wordId]++;
        }
        return topic;
    }

    /** Writes the run configuration and timings to {@code <expName>.paras}. */
    public void writeParameters() throws IOException {
        try (BufferedWriter writer = new BufferedWriter(new FileWriter(this.folderPath + this.expName + ".paras"))) {
            writer.write("-model\tDMM");
            writer.write("\n-corpus\t" + this.corpusPath);
            writer.write("\n-ntopics\t" + this.K);
            writer.write("\n-alpha\t" + this.alpha);
            writer.write("\n-beta\t" + this.beta);
            writer.write("\n-niters\t" + this.numIterations);
            writer.write("\n-twords\t" + this.topWords);
            writer.write("\n-name\t" + this.expName);
            writer.write("\n-initiation time\t" + this.initTime);
            writer.write("\n-one iteration time\t" + this.iterTime / (double)this.numIterations);
            writer.write("\n-total time\t" + (this.initTime + this.iterTime));
        }
    }

    /** Writes one {@code "<word> <id>"} line per vocabulary entry to {@code <expName>.vocabulary}. */
    public void writeDictionary() throws IOException {
        try (BufferedWriter writer = new BufferedWriter(new FileWriter(this.folderPath + this.expName + ".vocabulary"))) {
            for (int id = 0; id < this.V; ++id) {
                writer.write(this.id2WordVocabulary.get(id) + " " + id + "\n");
            }
        }
    }

    /**
     * Writes one line per document to {@code <expName>.topicAssignments}; the
     * document's topic is repeated once for every token of the document.
     */
    public void writeTopicAssignments() throws IOException {
        try (BufferedWriter writer = new BufferedWriter(new FileWriter(this.folderPath + this.expName + ".topicAssignments"))) {
            for (int dIndex = 0; dIndex < this.numDocuments; ++dIndex) {
                int docSize = this.corpus.get(dIndex).size();
                int topic = this.z[dIndex];
                for (int wIndex = 0; wIndex < docSize; ++wIndex) {
                    writer.write(topic + " ");
                }
                writer.write("\n");
            }
        }
    }

    /**
     * Writes one line per topic to {@code <expName>.topWords}, holding up to
     * {@link #topWords} words in the order returned by
     * {@code FuncUtils.sortByValueDescending} for that topic's word counts.
     * Every topic line is terminated, even when the vocabulary has fewer than
     * {@link #topWords} entries.
     */
    public void writeTopTopicalWords() throws IOException {
        try (BufferedWriter writer = new BufferedWriter(new FileWriter(this.folderPath + this.expName + ".topWords"))) {
            for (int tIndex = 0; tIndex < this.K; ++tIndex) {
                Map<Integer, Integer> wordCount = new TreeMap<Integer, Integer>();
                for (int wIndex = 0; wIndex < this.V; ++wIndex) {
                    wordCount.put(wIndex, this.n_w_z[tIndex][wIndex]);
                }
                wordCount = FuncUtils.sortByValueDescending(wordCount);
                int count = 0;
                for (Integer index : wordCount.keySet()) {
                    if (count >= this.topWords) {
                        break;
                    }
                    writer.write(this.id2WordVocabulary.get(index) + " ");
                    ++count;
                }
                writer.write("\n");
            }
        }
    }

    /**
     * Writes the K x V matrix of smoothed topic-word proportions
     * {@code (n_w_z + beta) / (n_z + V * beta)} to {@code <expName>.phi},
     * one topic per line.
     */
    public void writeTopicWordPros() throws IOException {
        try (BufferedWriter writer = new BufferedWriter(new FileWriter(this.folderPath + this.expName + ".phi"))) {
            for (int i = 0; i < this.K; ++i) {
                for (int j = 0; j < this.V; ++j) {
                    double pro = ((double)this.n_w_z[i][j] + this.beta) / ((double)this.n_z[i] + this.beta * (double)this.V);
                    writer.write(pro + " ");
                }
                writer.write("\n");
            }
        }
    }

    /**
     * Writes one line per document to {@code <expName>.theta} with the
     * normalised weight of every topic for that document. Uses
     * {@link #multiPros} as scratch space, so its previous content is lost.
     */
    public void writeDocTopicPros() throws IOException {
        try (BufferedWriter writer = new BufferedWriter(new FileWriter(this.folderPath + this.expName + ".theta"))) {
            for (int i = 0; i < this.numDocuments; ++i) {
                List<Integer> doc = this.corpus.get(i);
                double sum = 0.0;
                for (int tIndex = 0; tIndex < this.K; ++tIndex) {
                    this.multiPros[tIndex] = (double)this.m_z[tIndex] + this.alpha;
                    for (int word : doc) {
                        this.multiPros[tIndex] = this.multiPros[tIndex] * (((double)this.n_w_z[tIndex][word] + this.beta) / ((double)this.n_z[tIndex] + this.beta * (double)this.V));
                    }
                    sum += this.multiPros[tIndex];
                }
                for (int tIndex = 0; tIndex < this.K; ++tIndex) {
                    writer.write(this.multiPros[tIndex] / sum + " ");
                }
                writer.write("\n");
            }
        }
    }

    /** Writes every result file except the vocabulary, which {@link #inference()} writes up front. */
    public void write() throws IOException {
        this.writeParameters();
        this.writeTopTopicalWords();
        this.writeDocTopicPros();
        this.writeTopicAssignments();
        this.writeTopicWordPros();
    }

    /** Example run on {@code dataset/GoogleNews.txt} with 200 topics and 50 sweeps. */
    public static void main(String[] args) throws Exception {
        DMM dmm = new DMM("dataset/GoogleNews.txt", 200, 0.1, 0.1, 50, 10, "GoogleNewsDMM");
        dmm.inference();
    }
}

