/*
 * Decompiled with CFR 0.152.
 */
package models;

import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.TreeMap;
import utility.FuncUtils;

/**
 * Infers topics for an unseen corpus from a pre-trained DMM model, in which every
 * document carries exactly one topic.
 *
 * <p>The constructor reads the training parameter file (which must declare
 * {@code -model DMM}), rebuilds the topic-word counts from the training corpus and its
 * saved {@code .topicAssignments} file, loads the unseen corpus (tokens absent from the
 * training vocabulary are dropped) and randomly assigns one topic to each unseen
 * document. {@link #inference()} then runs Gibbs sampling over the unseen documents and
 * writes its outputs under {@link #folderPath}.
 */
public class DMM_Inf {
    /** Hyper-parameter read from {@code -alpha}; smooths the per-topic document counts. */
    public double alpha;
    /** Hyper-parameter read from {@code -beta}; smooths the topic-word counts. */
    public double beta;
    /** Number of topics, read from {@code -ntopics}. */
    public int numTopics;
    /** Number of sampling iterations run by {@link #inference()}. */
    public int numIterations;
    /** Number of words listed per topic in the {@code .topWords} output. */
    public int topWords;
    /** {@code numTopics * alpha}. */
    public double alphaSum;
    /** {@code vocabularySize * beta}. */
    public double betaSum;
    /** Unseen documents, each a list of training-vocabulary word ids. */
    public List<List<Integer>> corpus;
    /** Current topic of each unseen document. */
    public List<Integer> topicAssignments;
    /** Number of non-blank lines read from the unseen corpus. */
    public int numDocuments;
    /** Total number of kept (in-vocabulary) tokens in the unseen corpus. */
    public int numWordsInCorpus;
    /** Training vocabulary: word to id (ids assigned in order of first appearance). */
    public HashMap<String, Integer> word2IdVocabulary;
    /** Inverse of {@link #word2IdVocabulary}. */
    public HashMap<Integer, String> id2WordVocabulary;
    /** Size of the training vocabulary. */
    public int vocabularySize;
    /** Number of unseen documents currently assigned to each topic. */
    public int[] docTopicCount;
    /** Word counts per topic: training assignments plus current unseen-document assignments. */
    public int[][] topicWordCount;
    /** Row sums of {@link #topicWordCount}. */
    public int[] sumTopicWordCount;
    /** Scratch buffer of unnormalised per-topic weights (uniform right after construction). */
    public double[] multiPros;
    /** Output folder for every file written by this class. */
    public String folderPath;
    /** Path of the unseen corpus. */
    public String corpusPath;
    /**
     * For each unseen document and each kept token: how many times that word has occurred
     * in the document up to and including this position (1-based).
     */
    public List<List<Integer>> occurenceToIndexCount;
    public String expName = "DMMinf";
    public String orgExpName = "DMMinf";
    public String tAssignsFilePath = "";
    public int savestep = 0;

    /**
     * Loads the pre-trained model and the unseen corpus, then randomly initialises the topic
     * of every unseen document.
     *
     * @param pathToTrainingParasFile training parameter file, one {@code key value} pair per line
     * @param pathToUnseenCorpus      one whitespace-tokenised document per line
     * @param inNumIterations         number of Gibbs sampling iterations
     * @param inTopWords              number of top words written per topic
     * @param inExpName               base name of the output files
     * @param inSaveStep              write intermediate output every this many iterations;
     *                                {@code <= 0} disables intermediate output
     * @throws Exception if the parameter file does not describe a DMM model, lacks a required
     *                   parameter, or an input file cannot be read or is inconsistent
     */
    public DMM_Inf(String pathToTrainingParasFile, String pathToUnseenCorpus, int inNumIterations, int inTopWords, String inExpName, int inSaveStep) throws Exception {
        HashMap<String, String> paras = this.parseTrainingParasFile(pathToTrainingParasFile);
        // Null-safe: a missing -model entry is reported as a wrong model, not as an NPE.
        if (!"DMM".equals(paras.get("-model"))) {
            throw new Exception("Wrong pre-trained model!!!");
        }
        this.alpha = Double.parseDouble(requirePara(paras, "-alpha"));
        this.beta = Double.parseDouble(requirePara(paras, "-beta"));
        this.numTopics = Integer.parseInt(requirePara(paras, "-ntopics"));
        if (this.numTopics <= 0) {
            throw new IllegalArgumentException("-ntopics must be positive, got " + this.numTopics);
        }
        this.numIterations = inNumIterations;
        this.topWords = inTopWords;
        this.savestep = inSaveStep;
        this.orgExpName = this.expName = inExpName;

        // The training topic assignments are expected next to the training corpus.
        String trainingCorpus = requirePara(paras, "-corpus");
        String trainingCorpusfolder = trainingCorpus.substring(0, Math.max(trainingCorpus.lastIndexOf("/"), trainingCorpus.lastIndexOf("\\")) + 1);
        String topicAssignment4TrainFile = trainingCorpusfolder + requirePara(paras, "-name") + ".topicAssignments";
        this.word2IdVocabulary = new HashMap<String, Integer>();
        this.id2WordVocabulary = new HashMap<Integer, String>();
        this.initializeWordCount(trainingCorpus, topicAssignment4TrainFile);

        this.corpusPath = pathToUnseenCorpus;
        this.folderPath = "results/";
        File dir = new File(this.folderPath);
        if (!dir.exists() && !dir.mkdirs()) {
            throw new IOException("Cannot create output folder: " + this.folderPath);
        }
        System.out.println("Reading unseen corpus: " + pathToUnseenCorpus);
        this.readUnseenCorpus(pathToUnseenCorpus);

        this.docTopicCount = new int[this.numTopics];
        // Uniform weights: initialize() draws each document's first topic uniformly.
        this.multiPros = new double[this.numTopics];
        for (int i = 0; i < this.numTopics; ++i) {
            this.multiPros[i] = 1.0 / (double) this.numTopics;
        }
        this.alphaSum = (double) this.numTopics * this.alpha;
        this.betaSum = (double) this.vocabularySize * this.beta;
        System.out.println("Corpus size: " + this.numDocuments + " docs, " + this.numWordsInCorpus + " words");
        System.out.println("Vocabuary size: " + this.vocabularySize);
        System.out.println("Number of topics: " + this.numTopics);
        System.out.println("alpha: " + this.alpha);
        System.out.println("beta: " + this.beta);
        System.out.println("Number of sampling iterations: " + this.numIterations);
        System.out.println("Number of top topical words: " + this.topWords);
        this.initialize();
    }

    /** Returns the value stored under {@code key}, failing with a descriptive message if absent. */
    private static String requirePara(Map<String, String> paras, String key) {
        String value = paras.get(key);
        if (value == null) {
            throw new IllegalArgumentException("Missing parameter " + key + " in training parameter file");
        }
        return value;
    }

    /**
     * Reads a parameter file with one {@code key value} pair per line; blank lines are skipped.
     *
     * @throws IOException if the file cannot be read or a non-blank line has no value
     */
    private HashMap<String, String> parseTrainingParasFile(String pathToTrainingParasFile) throws IOException {
        HashMap<String, String> paras = new HashMap<String, String>();
        try (BufferedReader br = new BufferedReader(new FileReader(pathToTrainingParasFile))) {
            String line;
            while ((line = br.readLine()) != null) {
                String trimmed = line.trim();
                if (trimmed.length() == 0) {
                    continue;
                }
                // Split on the first whitespace run only, so values containing spaces
                // (e.g. corpus paths) are kept whole.
                String[] paraOptions = trimmed.split("\\s+", 2);
                if (paraOptions.length < 2) {
                    throw new IOException("Malformed line in " + pathToTrainingParasFile + ": \"" + line + "\"");
                }
                paras.put(paraOptions[0], paraOptions[1]);
            }
        }
        return paras;
    }

    /**
     * Builds the vocabulary from the training corpus, then adds one count per training token
     * to {@link #topicWordCount} using the topics saved in the assignment file.
     *
     * @throws IOException if either file cannot be read, or the assignment file does not line
     *                     up token-for-token with the training corpus
     */
    private void initializeWordCount(String pathToTrainingCorpus, String pathToTopicAssignmentFile) throws IOException {
        System.out.println("Loading pre-trained model...");
        List<List<Integer>> trainCorpus = new ArrayList<List<Integer>>();
        try (BufferedReader br = new BufferedReader(new FileReader(pathToTrainingCorpus))) {
            String doc;
            while ((doc = br.readLine()) != null) {
                if (doc.trim().length() == 0) {
                    continue;
                }
                String[] words = doc.trim().split("\\s+");
                List<Integer> document = new ArrayList<Integer>(words.length);
                for (String word : words) {
                    Integer wordId = this.word2IdVocabulary.get(word);
                    if (wordId == null) {
                        // New word: ids are dense, in order of first appearance.
                        wordId = this.word2IdVocabulary.size();
                        this.word2IdVocabulary.put(word, wordId);
                        this.id2WordVocabulary.put(wordId, word);
                    }
                    document.add(wordId);
                }
                trainCorpus.add(document);
            }
        }
        this.vocabularySize = this.word2IdVocabulary.size();
        this.topicWordCount = new int[this.numTopics][this.vocabularySize];
        this.sumTopicWordCount = new int[this.numTopics];

        try (BufferedReader br = new BufferedReader(new FileReader(pathToTopicAssignmentFile))) {
            String line;
            int docId = 0;
            while ((line = br.readLine()) != null) {
                // Training documents always have >= 1 token, so a blank line is never a document.
                if (line.trim().length() == 0) {
                    continue;
                }
                if (docId >= trainCorpus.size()) {
                    throw new IOException(pathToTopicAssignmentFile + " has more lines than the "
                            + trainCorpus.size() + " documents of " + pathToTrainingCorpus);
                }
                String[] strTopics = line.trim().split("\\s+");
                List<Integer> document = trainCorpus.get(docId);
                if (strTopics.length != document.size()) {
                    throw new IOException(pathToTopicAssignmentFile + ": document " + docId + " has "
                            + strTopics.length + " topics but " + document.size() + " tokens");
                }
                for (int j = 0; j < strTopics.length; ++j) {
                    int topic = Integer.parseInt(strTopics[j]);
                    if (topic < 0 || topic >= this.numTopics) {
                        throw new IOException(pathToTopicAssignmentFile + ": topic " + topic
                                + " out of range [0, " + this.numTopics + ")");
                    }
                    this.topicWordCount[topic][document.get(j)]++;
                    this.sumTopicWordCount[topic]++;
                }
                ++docId;
            }
            if (docId != trainCorpus.size()) {
                throw new IOException(pathToTopicAssignmentFile + " covers " + docId + " of the "
                        + trainCorpus.size() + " training documents");
            }
        }
    }

    /**
     * Reads the unseen corpus, keeping only tokens in the training vocabulary, and records
     * each kept token's running in-document occurrence count.
     */
    private void readUnseenCorpus(String pathToUnseenCorpus) throws IOException {
        this.corpus = new ArrayList<List<Integer>>();
        this.occurenceToIndexCount = new ArrayList<List<Integer>>();
        this.numDocuments = 0;
        this.numWordsInCorpus = 0;
        try (BufferedReader br = new BufferedReader(new FileReader(pathToUnseenCorpus))) {
            String doc;
            while ((doc = br.readLine()) != null) {
                if (doc.trim().length() == 0) {
                    continue;
                }
                String[] words = doc.trim().split("\\s+");
                List<Integer> document = new ArrayList<Integer>();
                List<Integer> wordOccurenceToIndexInDoc = new ArrayList<Integer>();
                Map<String, Integer> occurrencesSoFar = new HashMap<String, Integer>();
                for (String word : words) {
                    Integer wordId = this.word2IdVocabulary.get(word);
                    if (wordId == null) {
                        continue; // the trained model has no counts for this word
                    }
                    document.add(wordId);
                    Integer previous = occurrencesSoFar.get(word);
                    int times = (previous == null ? 0 : previous) + 1;
                    occurrencesSoFar.put(word, times);
                    wordOccurenceToIndexInDoc.add(times);
                }
                ++this.numDocuments;
                this.numWordsInCorpus += document.size();
                this.corpus.add(document);
                this.occurenceToIndexCount.add(wordOccurenceToIndexInDoc);
            }
        }
    }

    /** Adds {@code delta} to every count contributed by {@code document} being in {@code topic}. */
    private void adjustCounts(List<Integer> document, int topic, int delta) {
        this.docTopicCount[topic] += delta;
        int[] wordCounts = this.topicWordCount[topic];
        for (int wordId : document) {
            wordCounts[wordId] += delta;
        }
        this.sumTopicWordCount[topic] += delta * document.size();
    }

    /**
     * Fills {@link #multiPros} with {@code exp(logWeights[t] - max)}.
     *
     * <p>This is proportional to {@code exp(logWeights[t])} but does not underflow to zero
     * when every raw product is tiny (long documents).
     */
    private void setMultiProsFromLogWeights(double[] logWeights) {
        double max = Double.NEGATIVE_INFINITY;
        for (double logWeight : logWeights) {
            if (logWeight > max) {
                max = logWeight;
            }
        }
        for (int t = 0; t < this.numTopics; ++t) {
            // All weights zero (log = -inf): keep them zero rather than producing NaN.
            this.multiPros[t] = max == Double.NEGATIVE_INFINITY ? 0.0 : Math.exp(logWeights[t] - max);
        }
    }

    /**
     * Draws a topic for every unseen document from the current {@link #multiPros} and adds
     * the document's counts under that topic.
     */
    public void initialize() throws IOException {
        System.out.println("Randomly initialzing topic assignments ...");
        this.topicAssignments = new ArrayList<Integer>(this.numDocuments);
        for (int i = 0; i < this.numDocuments; ++i) {
            int topic = FuncUtils.nextDiscrete(this.multiPros);
            this.adjustCounts(this.corpus.get(i), topic, 1);
            this.topicAssignments.add(topic);
        }
    }

    /**
     * Writes the parameter and vocabulary files, runs {@link #numIterations} sampling sweeps,
     * and writes the final sample.
     *
     * <p>When {@link #savestep} is positive, output is also written every {@code savestep}
     * iterations (except the last) under the name {@code orgExpName-iter}.
     */
    public void inference() throws IOException {
        this.writeParameters();
        this.writeDictionary();
        System.out.println("Running Gibbs sampling inference: ");
        for (int iter = 1; iter <= this.numIterations; ++iter) {
            System.out.println("\tSampling iteration: " + iter);
            this.sampleInSingleIteration();
            boolean intermediateSave = this.savestep > 0 && iter % this.savestep == 0 && iter < this.numIterations;
            if (intermediateSave) {
                System.out.println("\t\tSaving the output from the " + iter + "^{th} sample");
                this.expName = this.orgExpName + "-" + iter;
                this.write();
            }
        }
        this.expName = this.orgExpName;
        System.out.println("Writing output from the last sample ...");
        this.write();
        System.out.println("Sampling completed!");
    }

    /**
     * One Gibbs sweep: for each unseen document, remove its counts, weight every topic by
     * {@code (docTopicCount[t] + alpha) * prod_i (n_{t,w_i} + beta + occ_i - 1)
     * / (n_t + betaSum + i)}, sample a new topic and add the counts back.
     *
     * <p>The product is accumulated in log space and rescaled before sampling, so long
     * documents do not underflow every weight to zero.
     */
    public void sampleInSingleIteration() {
        double[] logWeights = new double[this.numTopics];
        for (int dIndex = 0; dIndex < this.numDocuments; ++dIndex) {
            List<Integer> document = this.corpus.get(dIndex);
            List<Integer> occurrences = this.occurenceToIndexCount.get(dIndex);
            int docSize = document.size();
            this.adjustCounts(document, this.topicAssignments.get(dIndex), -1);
            for (int tIndex = 0; tIndex < this.numTopics; ++tIndex) {
                double logWeight = Math.log((double) this.docTopicCount[tIndex] + this.alpha);
                for (int wIndex = 0; wIndex < docSize; ++wIndex) {
                    int word = document.get(wIndex);
                    // occ - 1 and + wIndex account for earlier tokens of this same document.
                    logWeight += Math.log(((double) this.topicWordCount[tIndex][word] + this.beta + (double) occurrences.get(wIndex).intValue() - 1.0)
                            / ((double) this.sumTopicWordCount[tIndex] + this.betaSum + (double) wIndex));
                }
                logWeights[tIndex] = logWeight;
            }
            this.setMultiProsFromLogWeights(logWeights);
            int topic = FuncUtils.nextDiscrete(this.multiPros);
            this.adjustCounts(document, topic, 1);
            this.topicAssignments.set(dIndex, topic);
        }
    }

    /** Opens {@code folderPath + expName + extension} for writing (truncating it). */
    private BufferedWriter openOutputFile(String extension) throws IOException {
        return new BufferedWriter(new FileWriter(this.folderPath + this.expName + extension));
    }

    /** Writes the run's parameters to {@code <expName>.paras}. */
    public void writeParameters() throws IOException {
        try (BufferedWriter writer = this.openOutputFile(".paras")) {
            writer.write("-model\tDMM");
            writer.write("\n-corpus\t" + this.corpusPath);
            writer.write("\n-ntopics\t" + this.numTopics);
            writer.write("\n-alpha\t" + this.alpha);
            writer.write("\n-beta\t" + this.beta);
            writer.write("\n-niters\t" + this.numIterations);
            writer.write("\n-twords\t" + this.topWords);
            writer.write("\n-name\t" + this.expName);
            if (this.tAssignsFilePath.length() > 0) {
                writer.write("\n-initFile\t" + this.tAssignsFilePath);
            }
            if (this.savestep > 0) {
                writer.write("\n-sstep\t" + this.savestep);
            }
        }
    }

    /** Writes {@code word id} lines for ids {@code 0..vocabularySize-1} to {@code <expName>.vocabulary}. */
    public void writeDictionary() throws IOException {
        try (BufferedWriter writer = this.openOutputFile(".vocabulary")) {
            for (int id = 0; id < this.vocabularySize; ++id) {
                writer.write(this.id2WordVocabulary.get(id) + " " + id + "\n");
            }
        }
    }

    /** Writes each unseen document as space-separated word ids to {@code <expName>.IDcorpus}. */
    public void writeIDbasedCorpus() throws IOException {
        try (BufferedWriter writer = this.openOutputFile(".IDcorpus")) {
            for (int dIndex = 0; dIndex < this.numDocuments; ++dIndex) {
                for (int wordId : this.corpus.get(dIndex)) {
                    writer.write(wordId + " ");
                }
                writer.write("\n");
            }
        }
    }

    /** Writes, per unseen document, its topic repeated once per kept token to {@code <expName>.topicAssignments}. */
    public void writeTopicAssignments() throws IOException {
        try (BufferedWriter writer = this.openOutputFile(".topicAssignments")) {
            for (int dIndex = 0; dIndex < this.numDocuments; ++dIndex) {
                int docSize = this.corpus.get(dIndex).size();
                int topic = this.topicAssignments.get(dIndex);
                for (int wIndex = 0; wIndex < docSize; ++wIndex) {
                    writer.write(topic + " ");
                }
                writer.write("\n");
            }
        }
    }

    /**
     * Writes, for each topic, its {@link #topWords} highest-count words with their smoothed
     * probabilities (rounded to 6 decimals) to {@code <expName>.topWords}.
     *
     * <p>Each topic is followed by a blank line, even when the vocabulary has no more than
     * {@code topWords} words.
     */
    public void writeTopTopicalWords() throws IOException {
        try (BufferedWriter writer = this.openOutputFile(".topWords")) {
            for (int tIndex = 0; tIndex < this.numTopics; ++tIndex) {
                writer.write("Topic" + tIndex + ":");
                Map<Integer, Integer> wordCount = new TreeMap<Integer, Integer>();
                for (int wIndex = 0; wIndex < this.vocabularySize; ++wIndex) {
                    wordCount.put(wIndex, this.topicWordCount[tIndex][wIndex]);
                }
                Map<Integer, Integer> sortedWordCount = FuncUtils.sortByValueDescending(wordCount);
                double denominator = (double) this.sumTopicWordCount[tIndex] + this.betaSum;
                int count = 0;
                for (Integer index : sortedWordCount.keySet()) {
                    if (count >= this.topWords) {
                        break;
                    }
                    double pro = ((double) this.topicWordCount[tIndex][index] + this.beta) / denominator;
                    pro = (double) Math.round(pro * 1000000.0) / 1000000.0;
                    writer.write(" " + this.id2WordVocabulary.get(index) + "(" + pro + ")");
                    ++count;
                }
                writer.write("\n\n");
            }
        }
    }

    /** Writes the smoothed topic-word probabilities (one topic per row) to {@code <expName>.phi}. */
    public void writeTopicWordPros() throws IOException {
        try (BufferedWriter writer = this.openOutputFile(".phi")) {
            for (int i = 0; i < this.numTopics; ++i) {
                double denominator = (double) this.sumTopicWordCount[i] + this.betaSum;
                for (int j = 0; j < this.vocabularySize; ++j) {
                    double pro = ((double) this.topicWordCount[i][j] + this.beta) / denominator;
                    writer.write(pro + " ");
                }
                writer.write("\n");
            }
        }
    }

    /** Writes the raw topic-word counts (one topic per row) to {@code <expName>.WTcount}. */
    public void writeTopicWordCount() throws IOException {
        try (BufferedWriter writer = this.openOutputFile(".WTcount")) {
            for (int i = 0; i < this.numTopics; ++i) {
                for (int j = 0; j < this.vocabularySize; ++j) {
                    writer.write(this.topicWordCount[i][j] + " ");
                }
                writer.write("\n");
            }
        }
    }

    /**
     * Writes, per unseen document, the normalised topic distribution to {@code <expName>.theta}:
     * {@code p(t|d) ~ (docTopicCount[t] + alpha) * prod_w phi(w|t)}.
     *
     * <p>This is computed in log space so long documents yield finite values instead of
     * {@code 0/0 = NaN}. As a side effect it overwrites {@link #multiPros}.
     */
    public void writeDocTopicPros() throws IOException {
        double[] logWeights = new double[this.numTopics];
        try (BufferedWriter writer = this.openOutputFile(".theta")) {
            for (int i = 0; i < this.numDocuments; ++i) {
                List<Integer> document = this.corpus.get(i);
                for (int tIndex = 0; tIndex < this.numTopics; ++tIndex) {
                    double logWeight = Math.log((double) this.docTopicCount[tIndex] + this.alpha);
                    double denominator = (double) this.sumTopicWordCount[tIndex] + this.betaSum;
                    for (int word : document) {
                        logWeight += Math.log(((double) this.topicWordCount[tIndex][word] + this.beta) / denominator);
                    }
                    logWeights[tIndex] = logWeight;
                }
                this.setMultiProsFromLogWeights(logWeights);
                double sum = 0.0;
                for (int tIndex = 0; tIndex < this.numTopics; ++tIndex) {
                    sum += this.multiPros[tIndex];
                }
                for (int tIndex = 0; tIndex < this.numTopics; ++tIndex) {
                    writer.write(this.multiPros[tIndex] / sum + " ");
                }
                writer.write("\n");
            }
        }
    }

    /** Writes the top words, document-topic, topic-assignment and topic-word outputs. */
    public void write() throws IOException {
        this.writeTopTopicalWords();
        this.writeDocTopicPros();
        this.writeTopicAssignments();
        this.writeTopicWordPros();
    }

    public static void main(String[] args) throws Exception {
        DMM_Inf dmm = new DMM_Inf("test/testDMM.paras", "test/unseenTest.txt", 100, 20, "testDMMinf", 0);
        dmm.inference();
    }
}

