/*
 * Decompiled with CFR 0.152.
 */
package models;

import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import utility.FuncUtils;

/**
 * Biterm topic model (BTM) trained by Gibbs sampling over word pairs ("biterms").
 *
 * <p>The constructor reads a whitespace-tokenised corpus (one document per line), builds the
 * vocabulary, extracts biterms and assigns each biterm a random initial topic.
 * {@link #inference()} then resamples the topic of every biterm for {@link #numIterations}
 * sweeps and writes the result files into {@link #folderPath}.
 *
 * <p>A biterm is stored as a single {@code long}: {@code word1 * 1000000 + word2} with
 * {@code word1 < word2}. NOTE(review): this packing is only unambiguous while every word id is
 * below 1,000,000; larger vocabularies would decode to wrong word ids.
 *
 * <p>Instances are mutable and nothing here is synchronised.
 */
public class BTM {
    /** Multiplier used to pack two word ids into one {@code long} biterm key. */
    private static final long BITERM_BASE = 1000000L;

    /** Hyper-parameter added to the per-topic biterm counts when sampling. */
    public double alpha;
    /** Hyper-parameter added to the per-topic word counts when sampling. */
    public double beta;
    /** Number of topics. */
    public int numTopics;
    /** Number of Gibbs sampling sweeps performed by {@link #inference()}. */
    public int numIterations;
    /** Number of words written per topic into the {@code .topWords} file. */
    public int topWords;
    /** {@code numTopics * alpha}, computed once in the constructor. */
    public double alphaSum;
    /** {@code vocabularySize * beta}, computed once in the constructor. */
    public double betaSum;
    /** Number of non-blank lines read from the corpus. */
    public int numDocuments;
    /** Total number of tokens over all documents. */
    public int numWordsInCorpus;
    /** Word to id mapping; ids are assigned from 0 in order of first appearance. */
    public HashMap<String, Integer> word2IdVocabulary;
    /** Inverse of {@link #word2IdVocabulary}. */
    public HashMap<Integer, String> id2WordVocabulary;
    /** Number of distinct words in the corpus. */
    public int vocabularySize;
    /** Corpus as word ids: {@code wordId_of_corpus[doc][position]}. */
    int[][] wordId_of_corpus = null;
    /** Per document: packed biterm key mapped to its number of occurrences in that document. */
    public ArrayList<HashMap<Long, Integer>> biterm_of_corpus = new ArrayList<>();
    /** Per document: total number of biterm occurrences. */
    int[] doc_biterm_num;
    /** Every biterm occurrence of the corpus, in extraction order. */
    ArrayList<Long> biterms = new ArrayList<>();
    /** Current topic of each entry of {@link #biterms}. */
    int[] topic_of_biterms;
    /** {@code topic_word_num[word][topic]}: how often the word occurs in a biterm of that topic. */
    int[][] topic_word_num;
    /** Number of biterm occurrences currently assigned to each topic. */
    int[] num_of_topic_of_biterm;
    /** Cache of the per-biterm normaliser used while writing the {@code .theta} file. */
    private HashMap<Long, Double> bitermSum = new HashMap<>();
    /** Scratch array of per-topic sampling weights handed to {@code FuncUtils.nextDiscrete}. */
    public double[] multiPros;
    /** Output directory (always {@code "results/"}). */
    public String folderPath;
    /** Path of the corpus file as given to the constructor. */
    public String corpusPath;
    /** Base name of the output files. */
    public String expName = "BTMmodel";
    /** Experiment name as given to the constructor; restored into {@link #expName} after sampling. */
    public String orgExpName = "BTMmodel";
    /** Topic-assignment file path; only echoed into the {@code .paras} file by this class. */
    public String tAssignsFilePath = "";
    /** Save step; only echoed into the {@code .paras} file by this class. */
    public int savestep = 0;
    /** Milliseconds spent in the last {@link #initialize()} call. */
    public double initTime = 0.0;
    /** Milliseconds spent in the sampling loop of the last {@link #inference()} call. */
    public double iterTime = 0.0;

    /** Same as the full constructor with experiment name {@code "BTMmodel"}, no init file and save step 0. */
    public BTM(String pathToCorpus, int inNumTopics, double inAlpha, double inBeta, int inNumIterations, int inTopWords) throws Exception {
        this(pathToCorpus, inNumTopics, inAlpha, inBeta, inNumIterations, inTopWords, "BTMmodel");
    }

    /** Same as the full constructor with no init file and save step 0. */
    public BTM(String pathToCorpus, int inNumTopics, double inAlpha, double inBeta, int inNumIterations, int inTopWords, String inExpName) throws Exception {
        this(pathToCorpus, inNumTopics, inAlpha, inBeta, inNumIterations, inTopWords, inExpName, "", 0);
    }

    /** Same as the full constructor with save step 0. */
    public BTM(String pathToCorpus, int inNumTopics, double inAlpha, double inBeta, int inNumIterations, int inTopWords, String inExpName, String pathToTAfile) throws Exception {
        this(pathToCorpus, inNumTopics, inAlpha, inBeta, inNumIterations, inTopWords, inExpName, pathToTAfile, 0);
    }

    /** Same as the full constructor with no init file. */
    public BTM(String pathToCorpus, int inNumTopics, double inAlpha, double inBeta, int inNumIterations, int inTopWords, String inExpName, int inSaveStep) throws Exception {
        this(pathToCorpus, inNumTopics, inAlpha, inBeta, inNumIterations, inTopWords, inExpName, "", inSaveStep);
    }

    /**
     * Reads the corpus, builds the vocabulary and randomly initialises the topic assignments.
     *
     * <p>Each non-blank line of the corpus is one document; tokens are separated by whitespace.
     * A failure while reading the corpus is printed and the model continues with whatever was
     * read up to that point (possibly nothing).
     *
     * @param pathToCorpus    path of the corpus file
     * @param inNumTopics     number of topics
     * @param inAlpha         hyper-parameter alpha
     * @param inBeta          hyper-parameter beta
     * @param inNumIterations number of sampling sweeps
     * @param inTopWords      number of top words to report per topic
     * @param inExpName       base name of the output files
     * @param pathToTAfile    topic-assignment file path; it is recorded but not read here,
     *                        initialisation is always random ({@code null} is treated as "")
     * @param inSaveStep      save step, recorded for the parameter file only
     * @throws IOException declared for compatibility with {@link #initialize()}
     */
    public BTM(String pathToCorpus, int inNumTopics, double inAlpha, double inBeta, int inNumIterations, int inTopWords, String inExpName, String pathToTAfile, int inSaveStep) throws IOException {
        this.alpha = inAlpha;
        this.beta = inBeta;
        this.numTopics = inNumTopics;
        this.numIterations = inNumIterations;
        this.topWords = inTopWords;
        this.savestep = inSaveStep;
        this.expName = inExpName;
        this.orgExpName = inExpName;
        this.corpusPath = pathToCorpus;
        System.out.println("Reading topic modeling corpus: " + pathToCorpus);
        this.folderPath = "results/";
        File dir = new File(this.folderPath);
        if (!dir.exists()) {
            dir.mkdir();
        }
        this.word2IdVocabulary = new HashMap<>();
        this.id2WordVocabulary = new HashMap<>();
        ArrayList<int[]> tmpCorpus = new ArrayList<>();
        this.numDocuments = 0;
        this.numWordsInCorpus = 0;
        // try-with-resources so the reader is closed even when reading fails half way.
        try (BufferedReader br = new BufferedReader(new FileReader(pathToCorpus))) {
            int nextWordId = 0;
            String doc;
            while ((doc = br.readLine()) != null) {
                String trimmed = doc.trim();
                if (trimmed.isEmpty()) {
                    continue;
                }
                String[] words = trimmed.split("\\s+");
                int[] document = new int[words.length];
                for (int pos = 0; pos < words.length; ++pos) {
                    Integer wordId = this.word2IdVocabulary.get(words[pos]);
                    if (wordId == null) {
                        // First occurrence: hand out the next free id.
                        wordId = nextWordId++;
                        this.word2IdVocabulary.put(words[pos], wordId);
                        this.id2WordVocabulary.put(wordId, words[pos]);
                    }
                    document[pos] = wordId;
                }
                ++this.numDocuments;
                this.numWordsInCorpus += document.length;
                tmpCorpus.add(document);
            }
        }
        catch (Exception e) {
            // Existing best-effort behaviour: report the problem and carry on with what was read.
            e.printStackTrace();
        }
        this.vocabularySize = this.word2IdVocabulary.size();
        // Uniform weights for the initial random topic assignment.
        this.multiPros = new double[this.numTopics];
        for (int i = 0; i < this.numTopics; ++i) {
            this.multiPros[i] = 1.0 / (double)this.numTopics;
        }
        this.alphaSum = (double)this.numTopics * this.alpha;
        this.betaSum = (double)this.vocabularySize * this.beta;
        this.doc_biterm_num = new int[tmpCorpus.size()];
        this.wordId_of_corpus = new int[tmpCorpus.size()][];
        for (int docIndex = 0; docIndex < this.wordId_of_corpus.length; ++docIndex) {
            this.wordId_of_corpus[docIndex] = tmpCorpus.get(docIndex);
        }
        System.out.println("Corpus size: " + this.numDocuments + " docs, " + this.numWordsInCorpus + " words");
        System.out.println("Vocabuary size: " + this.vocabularySize);
        System.out.println("Number of topics: " + this.numTopics);
        System.out.println("alpha: " + this.alpha);
        System.out.println("beta: " + this.beta);
        System.out.println("Number of sampling iterations: " + this.numIterations);
        System.out.println("Number of top topical words: " + this.topWords);
        // The path is only recorded; initialisation is random whether or not a file is named.
        this.tAssignsFilePath = pathToTAfile == null ? "" : pathToTAfile;
        this.initialize();
    }

    /**
     * Extracts the biterms of every document and assigns each one a random topic.
     *
     * <p>For each document every ordered pair of token positions whose word ids satisfy
     * {@code word1 < word2} yields one biterm occurrence. Topics are drawn with
     * {@code FuncUtils.nextDiscrete} from the current {@link #multiPros}, which the constructor
     * fills uniformly.
     *
     * <p>All biterm and count state is rebuilt from scratch, so calling this method again does
     * not duplicate biterms.
     *
     * @throws IOException never thrown by this implementation; kept for interface compatibility
     */
    public void initialize() throws IOException {
        System.out.println("Randomly initializing topic assignments using BTM");
        long startTime = System.currentTimeMillis();
        // Drop anything left over from a previous initialisation.
        this.biterms.clear();
        this.biterm_of_corpus.clear();
        this.bitermSum.clear();
        this.doc_biterm_num = new int[this.wordId_of_corpus.length];
        for (int docIndex = 0; docIndex < this.wordId_of_corpus.length; ++docIndex) {
            int[] doc = this.wordId_of_corpus[docIndex];
            HashMap<Long, Integer> oneCop = new HashMap<>();
            for (int word1 : doc) {
                for (int word2 : doc) {
                    if (word1 >= word2) {
                        continue;
                    }
                    Long key = (long)word1 * BITERM_BASE + (long)word2;
                    oneCop.merge(key, 1, Integer::sum);
                    this.biterms.add(key);
                    this.doc_biterm_num[docIndex]++;
                }
            }
            this.biterm_of_corpus.add(oneCop);
        }
        int bitermCount = this.biterms.size();
        this.topic_of_biterms = new int[bitermCount];
        this.topic_word_num = new int[this.vocabularySize][this.numTopics];
        this.num_of_topic_of_biterm = new int[this.numTopics];
        for (int bitermIndex = 0; bitermIndex < bitermCount; ++bitermIndex) {
            int topicId = FuncUtils.nextDiscrete(this.multiPros);
            long key = this.biterms.get(bitermIndex);
            this.shiftCounts((int)(key / BITERM_BASE), (int)(key % BITERM_BASE), topicId, 1);
            this.topic_of_biterms[bitermIndex] = topicId;
        }
        this.initTime = System.currentTimeMillis() - startTime;
    }

    /**
     * Runs the Gibbs sampler and writes all output files.
     *
     * <p>The vocabulary file is written first. Then, for {@link #numIterations} sweeps, every
     * biterm is removed from the counts, its per-topic weights are recomputed, a new topic is
     * drawn with {@code FuncUtils.nextDiscrete}, and the counts are updated. Finally
     * {@link #write()} stores the state of the last sweep.
     *
     * @throws IOException if an output file cannot be written
     */
    public void inference() throws IOException {
        this.writeDictionary();
        System.out.println("Running Gibbs sampling inference: ");
        long startTime = System.currentTimeMillis();
        for (int iter = 0; iter < this.numIterations; ++iter) {
            if (iter % 50 == 0) {
                System.out.print(" " + iter);
            }
            for (int bitermIndex = 0; bitermIndex < this.topic_of_biterms.length; ++bitermIndex) {
                long key = this.biterms.get(bitermIndex);
                int word1 = (int)(key / BITERM_BASE);
                int word2 = (int)(key % BITERM_BASE);
                // Take the biterm out of the counts before computing its conditional weights.
                this.shiftCounts(word1, word2, this.topic_of_biterms[bitermIndex], -1);
                for (int k = 0; k < this.numTopics; ++k) {
                    this.multiPros[k] = this.topicWeight(k, word1, word2);
                }
                int newTopicId = FuncUtils.nextDiscrete(this.multiPros);
                this.shiftCounts(word1, word2, newTopicId, 1);
                this.topic_of_biterms[bitermIndex] = newTopicId;
            }
        }
        this.iterTime = System.currentTimeMillis() - startTime;
        this.expName = this.orgExpName;
        System.out.println();
        System.out.println("Writing output from the last sample ...");
        this.write();
        System.out.println("Sampling completed for BTM!");
    }

    /** Adds {@code delta} to the counts of both words of a biterm and of the topic itself. */
    private void shiftCounts(int word1, int word2, int topic, int delta) {
        this.topic_word_num[word1][topic] += delta;
        this.topic_word_num[word2][topic] += delta;
        this.num_of_topic_of_biterm[topic] += delta;
    }

    /**
     * Unnormalised weight of {@code topic} for the biterm {@code (word1, word2)}:
     * {@code (n_k + alpha) * (n_w1k + beta) * (n_w2k + beta) / (2 * n_k + V * beta)^2}.
     */
    private double topicWeight(int topic, int word1, int word2) {
        return ((double)this.num_of_topic_of_biterm[topic] + this.alpha)
                * ((double)this.topic_word_num[word1][topic] + this.beta)
                * ((double)this.topic_word_num[word2][topic] + this.beta)
                / Math.pow((double)(this.num_of_topic_of_biterm[topic] * 2) + (double)this.vocabularySize * this.beta, 2.0);
    }

    /**
     * Writes the run parameters and timings to {@code <folderPath><expName>.paras}.
     *
     * @throws IOException if the file cannot be written
     */
    public void writeParameters() throws IOException {
        try (BufferedWriter writer = new BufferedWriter(new FileWriter(this.folderPath + this.expName + ".paras"))) {
            writer.write("-model\tBTM");
            writer.write("\n-corpus\t" + this.corpusPath);
            writer.write("\n-ntopics\t" + this.numTopics);
            writer.write("\n-alpha\t" + this.alpha);
            writer.write("\n-beta\t" + this.beta);
            writer.write("\n-niters\t" + this.numIterations);
            writer.write("\n-twords\t" + this.topWords);
            writer.write("\n-name\t" + this.expName);
            String initFile = this.tAssignsFilePath;
            if (initFile != null && initFile.length() > 0) {
                writer.write("\n-initFile\t" + initFile);
            }
            if (this.savestep > 0) {
                writer.write("\n-sstep\t" + this.savestep);
            }
            writer.write("\n-initiation time\t" + this.initTime);
            writer.write("\n-one iteration time\t" + this.iterTime / (double)this.numIterations);
            writer.write("\n-total time\t" + (this.initTime + this.iterTime));
        }
    }

    /**
     * Writes one {@code "<word> <id>"} line per vocabulary entry, in id order, to
     * {@code <folderPath><expName>.vocabulary}.
     *
     * @throws IOException if the file cannot be written
     */
    public void writeDictionary() throws IOException {
        try (BufferedWriter writer = new BufferedWriter(new FileWriter(this.folderPath + this.expName + ".vocabulary"))) {
            for (int id = 0; id < this.vocabularySize; ++id) {
                writer.write(this.id2WordVocabulary.get(id) + " " + id + "\n");
            }
        }
    }

    /**
     * Writes, for each topic, its highest-scoring words (space separated, one topic per line)
     * to {@code <folderPath><expName>.topWords}. A word's score is its count in the topic
     * divided by twice the topic's biterm count.
     *
     * @throws IOException if the file cannot be written
     */
    private void writeTopTopicalWords() throws IOException {
        try (BufferedWriter writer = new BufferedWriter(new FileWriter(this.folderPath + this.expName + ".topWords"))) {
            for (int topic_id = 0; topic_id < this.numTopics; ++topic_id) {
                HashMap<Integer, Double> oneLine = new HashMap<>();
                for (int word_id = 0; word_id < this.vocabularySize; ++word_id) {
                    oneLine.put(word_id, (double)this.topic_word_num[word_id][topic_id] / (double)this.num_of_topic_of_biterm[topic_id] / 2.0);
                }
                ArrayList<Map.Entry<Integer, Double>> ranked = new ArrayList<>(oneLine.entrySet());
                // Descending by score; the sort is stable, so ties keep the map's iteration order.
                Collections.sort(ranked, (left, right) -> right.getValue().compareTo(left.getValue()));
                int count = 0;
                for (Map.Entry<Integer, Double> entry : ranked) {
                    writer.write(this.id2WordVocabulary.get(entry.getKey()) + " ");
                    if (++count >= this.topWords) {
                        break;
                    }
                }
                writer.write("\n");
            }
        }
    }

    /**
     * Writes the topic-word matrix to {@code <folderPath><expName>.phi}: one line per topic,
     * one value per word id, each {@code (n_wk + beta) / (2 * n_k + V * beta)}.
     *
     * @throws IOException if the file cannot be written
     */
    public void writeTopicWordPros() throws IOException {
        try (BufferedWriter writer = new BufferedWriter(new FileWriter(this.folderPath + this.expName + ".phi"))) {
            for (int topic_id = 0; topic_id < this.numTopics; ++topic_id) {
                for (int word_id = 0; word_id < this.vocabularySize; ++word_id) {
                    writer.write(((double)this.topic_word_num[word_id][topic_id] + this.beta) / ((double)(this.num_of_topic_of_biterm[topic_id] * 2) + (double)this.vocabularySize * this.beta) + " ");
                }
                writer.write("\n");
            }
        }
    }

    /**
     * Returns the sum over all topics of the unnormalised weight of {@code biterm}.
     * Results are cached in {@link #bitermSum}; the cache is only valid while the topic counts
     * do not change, so callers must clear it after resampling.
     */
    private double getSum(Long biterm) {
        Double cached = this.bitermSum.get(biterm);
        if (cached != null) {
            return cached;
        }
        int word1 = (int)(biterm / BITERM_BASE);
        int word2 = (int)(biterm % BITERM_BASE);
        double sum = 0.0;
        for (int topic_id = 0; topic_id < this.numTopics; ++topic_id) {
            sum += this.topicWeight(topic_id, word1, word2);
        }
        this.bitermSum.put(biterm, sum);
        return sum;
    }

    /**
     * Writes the document-topic matrix to {@code <folderPath><expName>.theta}: one line per
     * document, one value per topic. Each value is the sum, over the document's distinct
     * biterms, of the biterm's share of the document times the normalised weight of the topic
     * for that biterm. Documents without biterms get {@code 0.0} for every topic.
     *
     * @throws IOException if the file cannot be written
     */
    public void writeDocTopicPros() throws IOException {
        // Normalisers cached by an earlier call belong to older counts; start fresh.
        this.bitermSum.clear();
        try (BufferedWriter writer = new BufferedWriter(new FileWriter(this.folderPath + this.expName + ".theta"))) {
            int docIndex = 0;
            for (HashMap<Long, Integer> line : this.biterm_of_corpus) {
                for (int topic_id = 0; topic_id < this.numTopics; ++topic_id) {
                    double oneSum = 0.0;
                    for (Map.Entry<Long, Integer> occurrence : line.entrySet()) {
                        Long biterm = occurrence.getKey();
                        int word1 = (int)(biterm / BITERM_BASE);
                        int word2 = (int)(biterm % BITERM_BASE);
                        oneSum += (double)occurrence.getValue().intValue() / (double)this.doc_biterm_num[docIndex] * (this.topicWeight(topic_id, word1, word2) / this.getSum(biterm));
                    }
                    writer.write(oneSum + " ");
                }
                writer.write("\n");
                ++docIndex;
            }
        }
    }

    /**
     * Writes the {@code .topWords}, {@code .theta}, {@code .phi} and {@code .paras} files.
     *
     * @throws IOException if any file cannot be written
     */
    public void write() throws IOException {
        this.writeTopTopicalWords();
        this.writeDocTopicPros();
        this.writeTopicWordPros();
        this.writeParameters();
    }

    /** Trains a 20-topic model on {@code dataset/Pascal_Flickr.txt} with 500 sweeps. */
    public static void main(String[] args) throws Exception {
        BTM btm = new BTM("dataset/Pascal_Flickr.txt", 20, 0.1, 0.1, 500, 10, "Pascal_FlickrBTM");
        btm.inference();
    }
}

