/*
 * Decompiled with CFR 0.152.
 */
package models;

import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.TreeMap;
import utility.FuncUtils;

/**
 * Topic model trained by Gibbs sampling over a whitespace-tokenised corpus
 * (one document per line).
 *
 * <p>The constructor reads the corpus, builds the vocabulary and draws an initial
 * topic for every token; {@link #inference()} then runs the sampler and writes the
 * result files into {@link #folderPath}, each named {@code expName + extension}.
 *
 * <p>Count arrays kept in sync by the sampler:
 * <ul>
 *   <li>{@code nw[w][k]} - number of tokens of word {@code w} assigned to topic {@code k}</li>
 *   <li>{@code nd[m][k]} - number of tokens in document {@code m} assigned to topic {@code k}</li>
 *   <li>{@code nwsum[k]} - total number of tokens assigned to topic {@code k}</li>
 *   <li>{@code ndsum[m]} - number of tokens in document {@code m}</li>
 *   <li>{@code z[m][n]} - topic currently assigned to the n-th token of document {@code m}</li>
 * </ul>
 */
public class LDA {
    /** Hyper-parameter added to the document-topic counts. */
    public double alpha;
    /** Hyper-parameter added to the topic-word counts. */
    public double beta;
    /** Number of topics. */
    public int K;
    /** Number of highest-count words written per topic by {@link #writeTopTopicalWords()}. */
    public int topWords;
    /** Documents as lists of word ids, in the order they were read. */
    public List<List<Integer>> corpus;
    public int numDocuments;
    public int numWordsInCorpus;
    public HashMap<String, Integer> word2IdVocabulary;
    public HashMap<Integer, String> id2WordVocabulary;
    /** Vocabulary size. */
    public int V;
    int[][] z;
    int[][] nw;
    int[][] nd;
    int[] nwsum;
    int[] ndsum;
    /** Accumulated document-topic estimates; only allocated when SAMPLE_LAG &gt; 0. */
    double[][] thetasum;
    /** Accumulated topic-word estimates; only allocated when SAMPLE_LAG &gt; 0. */
    double[][] phisum;
    /** Number of samples accumulated into thetasum/phisum. */
    int numstats;
    private static int THIN_INTERVAL = 20;
    private static int BURN_IN = 100;
    // Per-instance on purpose: a shared static would let one model's constructor
    // overwrite the iteration count of every other model.
    private int iterations = 1000;
    private static int dispcol = 0;
    // Never assigned anywhere in this class, so it keeps its default of 0 and the
    // averaging branches guarded by "SAMPLE_LAG > 0" are currently inactive.
    private static int SAMPLE_LAG;
    /** Uniform distribution over the K topics, used to draw the initial assignments. */
    public double[] multiPros;
    public String folderPath;
    public String corpusPath;
    public String expName = "LDAmodel";
    public String orgExpName = "LDAmodel";
    /** Wall-clock time of {@link #initialize()} in milliseconds. */
    public double initTime = 0.0;
    /** Wall-clock time of the sampling loop in milliseconds. */
    public double iterTime = 0.0;

    /**
     * Creates a model with the default experiment name {@code "LDAmodel"}.
     *
     * @see #LDA(String, int, double, double, int, int, String)
     */
    public LDA(String pathToCorpus, int inNumTopics, double inAlpha, double inBeta, int inNumIterations, int inTopWords) throws Exception {
        this(pathToCorpus, inNumTopics, inAlpha, inBeta, inNumIterations, inTopWords, "LDAmodel");
    }

    /**
     * Reads the corpus, allocates the count arrays and randomly initialises the
     * topic assignments.
     *
     * @param pathToCorpus    text file with one whitespace-separated document per line
     * @param inNumTopics     number of topics (K)
     * @param inAlpha         value for {@link #alpha}
     * @param inBeta          value for {@link #beta}
     * @param inNumIterations number of sampling sweeps performed by {@link #inference()}
     * @param inTopWords      number of words per topic written to the ".topWords" file
     * @param inExpName       prefix of every output file name
     * @throws Exception if initialisation fails; note that a failure while reading the
     *                   corpus is only printed and leaves the model with the documents
     *                   read so far
     */
    public LDA(String pathToCorpus, int inNumTopics, double inAlpha, double inBeta, int inNumIterations, int inTopWords, String inExpName) throws Exception {
        this.alpha = inAlpha;
        this.beta = inBeta;
        this.K = inNumTopics;
        this.iterations = inNumIterations;
        this.topWords = inTopWords;
        this.expName = inExpName;
        this.orgExpName = inExpName;
        this.corpusPath = pathToCorpus;
        this.folderPath = "results/";
        File dir = new File(this.folderPath);
        if (!dir.exists()) {
            dir.mkdir();
        }
        System.out.println("Reading topic modeling corpus: " + pathToCorpus);
        this.word2IdVocabulary = new HashMap<String, Integer>();
        this.id2WordVocabulary = new HashMap<Integer, String>();
        this.corpus = new ArrayList<List<Integer>>();
        this.numDocuments = 0;
        this.numWordsInCorpus = 0;
        this.readCorpus(pathToCorpus);
        this.V = this.word2IdVocabulary.size();
        this.nw = new int[this.V][this.K];
        this.nd = new int[this.numDocuments][this.K];
        this.nwsum = new int[this.K];
        this.ndsum = new int[this.numDocuments];
        this.multiPros = new double[this.K];
        for (int i = 0; i < this.K; ++i) {
            this.multiPros[i] = 1.0 / (double)this.K;
        }
        this.z = new int[this.numDocuments][];
        System.out.println("Corpus size: " + this.numDocuments + " docs, " + this.numWordsInCorpus + " words");
        System.out.println("Vocabuary size: " + this.V);
        System.out.println("Number of topics: " + this.K);
        System.out.println("alpha: " + this.alpha);
        System.out.println("beta: " + this.beta);
        System.out.println("Number of sampling iterations: " + this.iterations);
        System.out.println("Number of top topical words: " + this.topWords);
        this.initialize();
    }

    /**
     * Loads the corpus file into {@link #corpus}, assigning word ids in order of first
     * appearance (starting at 0). Blank lines are skipped after printing the current
     * document count. Read failures are printed, not propagated; the reader is always
     * closed.
     */
    private void readCorpus(String pathToCorpus) {
        BufferedReader br = null;
        try {
            br = new BufferedReader(new FileReader(pathToCorpus));
            int indexWord = -1;
            String doc;
            while ((doc = br.readLine()) != null) {
                String trimmed = doc.trim();
                if (trimmed.length() == 0) {
                    System.out.println(this.numDocuments);
                    continue;
                }
                // A non-empty trimmed line always yields at least one token.
                String[] words = trimmed.split("\\s+");
                List<Integer> document = new ArrayList<Integer>(words.length);
                for (String word : words) {
                    Integer id = this.word2IdVocabulary.get(word);
                    if (id == null) {
                        id = Integer.valueOf(++indexWord);
                        this.word2IdVocabulary.put(word, id);
                        this.id2WordVocabulary.put(id, word);
                    }
                    document.add(id);
                }
                ++this.numDocuments;
                this.numWordsInCorpus += document.size();
                this.corpus.add(document);
            }
        }
        catch (Exception e) {
            e.printStackTrace();
        }
        finally {
            if (br != null) {
                try {
                    br.close();
                }
                catch (IOException e) {
                    e.printStackTrace();
                }
            }
        }
    }

    /**
     * Draws a topic for every token from the uniform distribution {@link #multiPros}
     * and updates all count arrays accordingly. Records the elapsed time in
     * {@link #initTime}.
     */
    public void initialize() throws IOException {
        System.out.println("Randomly initializing topic assignments ...");
        long startTime = System.currentTimeMillis();
        for (int i = 0; i < this.numDocuments; ++i) {
            List<Integer> doc = this.corpus.get(i);
            int docSize = doc.size();
            this.z[i] = new int[docSize];
            for (int j = 0; j < docSize; ++j) {
                int topic = FuncUtils.nextDiscrete(this.multiPros);
                this.nd[i][topic]++;
                this.nw[doc.get(j)][topic]++;
                this.ndsum[i]++;
                this.nwsum[topic]++;
                this.z[i][j] = topic;
            }
        }
        this.initTime = System.currentTimeMillis() - startTime;
    }

    /**
     * Writes the vocabulary file, runs the configured number of sampling sweeps over
     * every token, then writes all result files via {@link #write()}.
     *
     * <p>Progress is printed as "B" (before burn-in) / "S" (after burn-in) every
     * THIN_INTERVAL iterations, wrapped after 100 markers.
     *
     * @throws IOException if any output file cannot be written
     */
    public void inference() throws IOException {
        this.writeDictionary();
        System.out.println("Running Gibbs sampling inference: ");
        if (SAMPLE_LAG > 0) {
            this.thetasum = new double[this.numDocuments][this.K];
            this.phisum = new double[this.K][this.V];
            this.numstats = 0;
        }
        System.out.println("Sampling " + this.iterations + " iterations with burn-in of " + BURN_IN + " (B/S=" + THIN_INTERVAL + ").");
        long startTime = System.currentTimeMillis();
        for (int i = 0; i < this.iterations; ++i) {
            for (int m = 0; m < this.z.length; ++m) {
                for (int n = 0; n < this.z[m].length; ++n) {
                    this.z[m][n] = this.sampleFullConditional(m, n);
                }
            }
            boolean onThinBoundary = i % THIN_INTERVAL == 0;
            if (onThinBoundary && i < BURN_IN) {
                System.out.print("B");
                ++dispcol;
            }
            if (onThinBoundary && i > BURN_IN) {
                System.out.print("S");
                ++dispcol;
            }
            // The SAMPLE_LAG > 0 test must stay before the modulo to avoid division by zero.
            if (i > BURN_IN && SAMPLE_LAG > 0 && i % SAMPLE_LAG == 0) {
                this.updateParams();
                System.out.print("|");
                if (!onThinBoundary) {
                    ++dispcol;
                }
            }
            if (dispcol >= 100) {
                System.out.println();
                dispcol = 0;
            }
        }
        this.iterTime = System.currentTimeMillis() - startTime;
        this.expName = this.orgExpName;
        System.out.println("Writing output from the last sample ...");
        this.write();
        System.out.println("Sampling completed!");
    }

    /**
     * Adds the current smoothed document-topic and topic-word estimates to
     * {@code thetasum} / {@code phisum} and increments {@code numstats}.
     */
    private void updateParams() {
        double kAlpha = (double)this.K * this.alpha;
        double vBeta = (double)this.V * this.beta;
        for (int m = 0; m < this.numDocuments; ++m) {
            for (int k = 0; k < this.K; ++k) {
                this.thetasum[m][k] += ((double)this.nd[m][k] + this.alpha) / ((double)this.ndsum[m] + kAlpha);
            }
        }
        for (int k = 0; k < this.K; ++k) {
            for (int w = 0; w < this.V; ++w) {
                this.phisum[k][w] += ((double)this.nw[w][k] + this.beta) / ((double)this.nwsum[k] + vBeta);
            }
        }
        ++this.numstats;
    }

    /**
     * Resamples the topic of the n-th token of document m: the token is removed from
     * the counts, a topic is drawn proportionally to
     * {@code (nw + beta) / (nwsum + V * beta) * (nd + alpha)}, and the counts are
     * updated for the drawn topic.
     *
     * @param m document index
     * @param n token position inside the document
     * @return the newly drawn topic (the caller stores it in {@code z[m][n]})
     */
    private int sampleFullConditional(int m, int n) {
        int word = this.corpus.get(m).get(n);
        int topic = this.z[m][n];
        // Take the token out of the statistics before computing its conditional.
        this.nw[word][topic]--;
        this.nd[m][topic]--;
        this.nwsum[topic]--;

        // p[k] holds the cumulative (unnormalised) probability mass up to topic k.
        double[] p = new double[this.K];
        double vBeta = (double)this.V * this.beta;
        double cumulative = 0.0;
        for (int k = 0; k < this.K; ++k) {
            cumulative += ((double)this.nw[word][k] + this.beta) / ((double)this.nwsum[k] + vBeta) * ((double)this.nd[m][k] + this.alpha);
            p[k] = cumulative;
        }

        double u = Math.random() * p[this.K - 1];
        topic = 0;
        while (topic < p.length && !(u < p[topic])) {
            ++topic;
        }
        // If the total mass is 0 or NaN no entry satisfies u < p[k]; fall back to the
        // last topic instead of indexing one past the end of the count arrays.
        if (topic >= this.K) {
            topic = this.K - 1;
        }

        this.nw[word][topic]++;
        this.nd[m][topic]++;
        this.nwsum[topic]++;
        return topic;
    }

    /**
     * Writes the run configuration and timings (milliseconds) to
     * {@code folderPath + expName + ".paras"}.
     */
    public void writeParameters() throws IOException {
        BufferedWriter writer = new BufferedWriter(new FileWriter(this.folderPath + this.expName + ".paras"));
        try {
            writer.write("-model\tLDA");
            writer.write("\n-corpus\t" + this.corpusPath);
            writer.write("\n-ntopics\t" + this.K);
            writer.write("\n-alpha\t" + this.alpha);
            writer.write("\n-beta\t" + this.beta);
            writer.write("\n-niters\t" + this.iterations);
            writer.write("\n-twords\t" + this.topWords);
            writer.write("\n-name\t" + this.expName);
            writer.write("\n-initiation time\t" + this.initTime);
            writer.write("\n-one iteration time\t" + this.iterTime / (double)this.iterations);
            writer.write("\n-total time\t" + (this.initTime + this.iterTime));
        }
        finally {
            writer.close();
        }
    }

    /**
     * Writes one "word id" line per vocabulary entry, in id order, to
     * {@code folderPath + expName + ".vocabulary"}.
     */
    public void writeDictionary() throws IOException {
        BufferedWriter writer = new BufferedWriter(new FileWriter(this.folderPath + this.expName + ".vocabulary"));
        try {
            for (int id = 0; id < this.V; ++id) {
                writer.write(this.id2WordVocabulary.get(id) + " " + id + "\n");
            }
        }
        finally {
            writer.close();
        }
    }

    /**
     * Writes the current topic of every token, one document per line, to
     * {@code folderPath + expName + ".topicAssignments"}.
     */
    public void writeTopicAssignments() throws IOException {
        BufferedWriter writer = new BufferedWriter(new FileWriter(this.folderPath + this.expName + ".topicAssignments"));
        try {
            for (int dIndex = 0; dIndex < this.numDocuments; ++dIndex) {
                int docSize = this.corpus.get(dIndex).size();
                for (int wIndex = 0; wIndex < docSize; ++wIndex) {
                    writer.write(this.z[dIndex][wIndex] + " ");
                }
                writer.write("\n");
            }
        }
        finally {
            writer.close();
        }
    }

    /**
     * Writes, for each topic, up to {@link #topWords} words ordered by
     * {@code FuncUtils.sortByValueDescending} on their topic counts, one topic per
     * line, to {@code folderPath + expName + ".topWords"}.
     *
     * <p>Every topic line is terminated with a newline, including when the vocabulary
     * holds no more than {@code topWords} words.
     */
    public void writeTopTopicalWords() throws IOException {
        BufferedWriter writer = new BufferedWriter(new FileWriter(this.folderPath + this.expName + ".topWords"));
        try {
            for (int tIndex = 0; tIndex < this.K; ++tIndex) {
                Map<Integer, Integer> wordCount = new TreeMap<Integer, Integer>();
                for (int wIndex = 0; wIndex < this.V; ++wIndex) {
                    wordCount.put(wIndex, this.nw[wIndex][tIndex]);
                }
                Map<Integer, Integer> ranked = FuncUtils.sortByValueDescending(wordCount);
                Set<Integer> mostLikelyWords = ranked.keySet();
                int count = 0;
                for (Integer index : mostLikelyWords) {
                    if (count >= this.topWords) {
                        break;
                    }
                    writer.write(this.id2WordVocabulary.get(index) + " ");
                    ++count;
                }
                // Written unconditionally so topics never run together on one line.
                writer.write("\n");
            }
        }
        finally {
            writer.close();
        }
    }

    /**
     * Writes the K x V topic-word probabilities to {@code folderPath + expName + ".phi"}:
     * the average of the accumulated samples when SAMPLE_LAG &gt; 0, otherwise the
     * smoothed estimate from the current counts.
     */
    public void writeTopicWordPros() throws IOException {
        BufferedWriter writer = new BufferedWriter(new FileWriter(this.folderPath + this.expName + ".phi"));
        try {
            boolean averaged = SAMPLE_LAG > 0;
            double vBeta = (double)this.V * this.beta;
            for (int k = 0; k < this.K; ++k) {
                for (int w = 0; w < this.V; ++w) {
                    double pro;
                    if (averaged) {
                        pro = this.phisum[k][w] / (double)this.numstats;
                    } else {
                        pro = ((double)this.nw[w][k] + this.beta) / ((double)this.nwsum[k] + vBeta);
                    }
                    writer.write(pro + " ");
                }
                writer.write("\n");
            }
        }
        finally {
            writer.close();
        }
    }

    /**
     * Writes the document-topic probabilities (one document per line) to
     * {@code folderPath + expName + ".theta"}: the average of the accumulated samples
     * when SAMPLE_LAG &gt; 0, otherwise the smoothed estimate from the current counts.
     */
    public void writeDocTopicPros() throws IOException {
        BufferedWriter writer = new BufferedWriter(new FileWriter(this.folderPath + this.expName + ".theta"));
        try {
            boolean averaged = SAMPLE_LAG > 0;
            double kAlpha = (double)this.K * this.alpha;
            for (int m = 0; m < this.numDocuments; ++m) {
                for (int k = 0; k < this.K; ++k) {
                    double pro;
                    if (averaged) {
                        pro = this.thetasum[m][k] / (double)this.numstats;
                    } else {
                        pro = ((double)this.nd[m][k] + this.alpha) / ((double)this.ndsum[m] + kAlpha);
                    }
                    writer.write(pro + " ");
                }
                writer.write("\n");
            }
        }
        finally {
            writer.close();
        }
    }

    /** Writes every result file except the vocabulary (written at the start of {@link #inference()}). */
    public void write() throws IOException {
        this.writeTopTopicalWords();
        this.writeDocTopicPros();
        this.writeTopicAssignments();
        this.writeTopicWordPros();
        this.writeParameters();
    }

    /** Example run on "dataset/Tweet.txt" with 100 topics and 1000 iterations. */
    public static void main(String[] args) throws Exception {
        LDA lda = new LDA("dataset/Tweet.txt", 100, 0.1, 0.1, 1000, 10, "TweetLDA");
        lda.inference();
    }
}

