/*
 * Decompiled with CFR 0.152.
 */
package models;

import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Map;
import java.util.Random;
import java.util.Set;
import java.util.TreeMap;
import utility.FuncUtils;

/**
 * Gibbs-sampling topic model over a corpus of short documents in which every
 * token carries a pair of latent assignments: a topic and a "long document".
 *
 * <p>NOTE(review): the class name suggests the Self-Aggregation based Topic
 * Model for short texts - confirm against the project documentation.
 *
 * <p>Typical use: construct an instance (which reads the corpus and randomly
 * initialises all assignments), then call {@link #inference()}, which runs the
 * sampler and writes the result files into {@link #folderPath}.
 *
 * <p>Instances are mutable and hold no synchronisation of their own.
 */
public class SATM {
    /** Value substituted for a zero probability so that products do not become exactly zero. */
    private static final double ZERO_SMOOTH = 1.0E-10;

    /** Smoothing added to every (topic, long document) count when sampling. */
    public double alpha;
    /** Smoothing added to every (word, topic) count when sampling and when writing results. */
    public double beta;
    public int numTopics;
    public int numIterations;
    /** Maximum number of words written per topic by {@link #writeTopTopicalWords()}. */
    public int topWords;
    /** {@code numTopics * alpha}. */
    public double alphaSum;
    /** {@code vocabularySize * beta}. */
    public double betaSum;
    /** One entry per non-blank corpus line; each entry holds the word ids of that line. */
    ArrayList<int[]> Corpus = new ArrayList<>();
    /** Number of short documents, i.e. non-blank lines read from the corpus file. */
    public int numShorDoc;
    public int numWordsInCorpus;
    public HashMap<String, Integer> word2IdVocabulary;
    public HashMap<Integer, String> id2WordVocabulary;
    public int vocabularySize;
    /** A long document is a sampling candidate for a short document only if its psd value exceeds this. */
    public double threshold;
    public int numLongDoc;
    private Random rg;
    /** {@code psd[s][l]}: score of long document l for short document s, see {@link #computePds()}. */
    public double[][] psd;
    /** Allocated with one slot per topic; not read anywhere in this class. */
    private double[] pz;
    /** {@code assignmentList.get(s).get(t)} is the {@code {topic, longDoc}} pair of token t in short document s. */
    public ArrayList<ArrayList<int[]>> assignmentList = new ArrayList<>();
    /** {@code U[topic][longDoc]}: tokens assigned to that topic and long document. */
    private int[][] U;
    /** {@code longDocCnts[longDoc]}: tokens assigned to that long document. */
    private int[] longDocCnts;
    /** {@code V[word][topic]}: occurrences of that word assigned to that topic. */
    private int[][] V;
    /** {@code topicCnts[topic]}: tokens assigned to that topic. */
    private int[] topicCnts;
    /** {@code longDocWordCnts[longDoc][word]}: occurrences of that word assigned to that long document. */
    private int[][] longDocWordCnts;
    /** Total number of tokens counted by {@link #initialize()}. */
    private int tokenSize;
    /** Output directory prefix; set to {@code "results/"} by the constructor. */
    public String folderPath;
    public String corpusPath;
    public String expName = "SATMmodel";
    public String orgExpName = "SATMmodel";
    /** Only echoed by {@link #writeParameters()}; not otherwise used in this class. */
    public String tAssignsFilePath = "";
    /** Only echoed by {@link #writeParameters()}; not otherwise used in this class. */
    public int savestep = 0;
    /** Milliseconds spent in {@link #initialize()} during construction. */
    public double initTime = 0.0;
    /** Milliseconds spent in the sampling loop of {@link #inference()}. */
    public double iterTime = 0.0;

    /** Same as the full constructor with experiment name {@code "SATMmodel"}, no assignment file and save step 0. */
    public SATM(String pathToCorpus, int inNumTopics, int num_longDoc, double threshold, double inAlpha, double inBeta, int inNumIterations, int inTopWords) throws Exception {
        this(pathToCorpus, inNumTopics, num_longDoc, threshold, inAlpha, inBeta, inNumIterations, inTopWords, "SATMmodel");
    }

    /** Same as the full constructor with no assignment file and save step 0. */
    public SATM(String pathToCorpus, int inNumTopics, int num_longDoc, double threshold, double inAlpha, double inBeta, int inNumIterations, int inTopWords, String inExpName) throws Exception {
        this(pathToCorpus, inNumTopics, num_longDoc, threshold, inAlpha, inBeta, inNumIterations, inTopWords, inExpName, "", 0);
    }

    /** Same as the full constructor with save step 0. */
    public SATM(String pathToCorpus, int inNumTopics, int num_longDoc, double threshold, double inAlpha, double inBeta, int inNumIterations, int inTopWords, String inExpName, String pathToTAfile) throws Exception {
        this(pathToCorpus, inNumTopics, num_longDoc, threshold, inAlpha, inBeta, inNumIterations, inTopWords, inExpName, pathToTAfile, 0);
    }

    /** Same as the full constructor with no assignment file. */
    public SATM(String pathToCorpus, int inNumTopics, int num_longDoc, double threshold, double inAlpha, double inBeta, int inNumIterations, int inTopWords, String inExpName, int inSaveStep) throws Exception {
        this(pathToCorpus, inNumTopics, num_longDoc, threshold, inAlpha, inBeta, inNumIterations, inTopWords, inExpName, "", inSaveStep);
    }

    /**
     * Reads the corpus, allocates all count tables and randomly initialises the
     * assignments. Also creates the {@code results/} directory if it is missing.
     *
     * @param pathToCorpus    text file with one whitespace-separated document per line; blank lines are skipped
     * @param inNumTopics     number of topics
     * @param num_longDoc     number of long documents
     * @param threshold       minimum psd value (exclusive) for a long document to be a sampling candidate
     * @param inAlpha         smoothing for (topic, long document) counts
     * @param inBeta          smoothing for (word, topic) counts
     * @param inNumIterations number of sampling sweeps run by {@link #inference()}
     * @param inTopWords      number of words per topic in the {@code .topWords} file
     * @param inExpName       prefix of every output file name
     * @param pathToTAfile    stored in {@link #tAssignsFilePath}
     * @param inSaveStep      stored in {@link #savestep}
     * @throws IOException if the corpus file cannot be opened or read
     */
    public SATM(String pathToCorpus, int inNumTopics, int num_longDoc, double threshold, double inAlpha, double inBeta, int inNumIterations, int inTopWords, String inExpName, String pathToTAfile, int inSaveStep) throws IOException {
        this.alpha = inAlpha;
        this.beta = inBeta;
        this.numTopics = inNumTopics;
        this.numIterations = inNumIterations;
        this.topWords = inTopWords;
        this.savestep = inSaveStep;
        this.orgExpName = this.expName = inExpName;
        this.corpusPath = pathToCorpus;
        this.numLongDoc = num_longDoc;
        this.rg = new Random();
        this.threshold = threshold;
        this.folderPath = "results/";
        File dir = new File(this.folderPath);
        if (!dir.exists()) {
            dir.mkdir();
        }
        System.out.println("Reading topic modeling corpus: " + pathToCorpus);
        this.word2IdVocabulary = new HashMap<>();
        this.id2WordVocabulary = new HashMap<>();
        this.numShorDoc = 0;
        this.numWordsInCorpus = 0;
        this.readCorpus(pathToCorpus);
        this.vocabularySize = this.word2IdVocabulary.size();
        this.psd = new double[this.numShorDoc][this.numLongDoc];
        this.pz = new double[this.numTopics];
        this.alphaSum = (double)this.numTopics * this.alpha;
        this.betaSum = (double)this.vocabularySize * this.beta;
        this.U = new int[this.numTopics][this.numLongDoc];
        this.longDocCnts = new int[this.numLongDoc];
        this.V = new int[this.vocabularySize][this.numTopics];
        this.topicCnts = new int[this.numTopics];
        this.longDocWordCnts = new int[this.numLongDoc][this.vocabularySize];
        System.out.println("Corpus size: " + this.numShorDoc + " docs, " + this.numWordsInCorpus + " words");
        System.out.println("Vocabuary size: " + this.vocabularySize);
        System.out.println("Number of topics: " + this.numTopics);
        System.out.println("alpha: " + this.alpha);
        System.out.println("beta: " + this.beta);
        System.out.println("Number of sampling iterations: " + this.numIterations);
        System.out.println("Number of top topical words: " + this.topWords);
        this.tAssignsFilePath = pathToTAfile;
        long startTime = System.currentTimeMillis();
        this.initialize();
        this.initTime = System.currentTimeMillis() - startTime;
    }

    /**
     * Loads the corpus file into {@link #Corpus}, assigning word ids in order of
     * first appearance (0, 1, 2, ...) and filling both vocabulary maps.
     *
     * <p>Read failures are propagated instead of being printed and ignored, so a
     * missing or unreadable corpus can no longer yield a silently empty model.
     * The reader is always closed.
     */
    private void readCorpus(String pathToCorpus) throws IOException {
        try (BufferedReader br = new BufferedReader(new FileReader(pathToCorpus))) {
            String line;
            while ((line = br.readLine()) != null) {
                String trimmed = line.trim();
                if (trimmed.isEmpty()) {
                    continue;
                }
                String[] words = trimmed.split("\\s+");
                int[] document = new int[words.length];
                for (int i = 0; i < words.length; ++i) {
                    Integer id = this.word2IdVocabulary.get(words[i]);
                    if (id == null) {
                        // The next free id equals the number of words seen so far.
                        id = this.word2IdVocabulary.size();
                        this.word2IdVocabulary.put(words[i], id);
                        this.id2WordVocabulary.put(id, words[i]);
                    }
                    document[i] = id;
                }
                ++this.numShorDoc;
                this.numWordsInCorpus += document.length;
                this.Corpus.add(document);
            }
        }
    }

    /**
     * Adds {@code delta} to the five count tables for one token with the given
     * word id, topic and long document.
     */
    private void updateCounts(int termID, int topic, int longDoc, int delta) {
        this.U[topic][longDoc] += delta;
        this.V[termID][topic] += delta;
        this.longDocWordCnts[longDoc][termID] += delta;
        this.longDocCnts[longDoc] += delta;
        this.topicCnts[topic] += delta;
    }

    /**
     * Draws a uniformly random topic and long document for every token and
     * records them in the count tables and in {@link #assignmentList}.
     *
     * <p>Called once by the constructor. Counts and assignments are added to the
     * existing state without being reset first, so calling this again on the
     * same instance accumulates on top of the previous initialisation.
     */
    public void initialize() throws IOException {
        System.out.println("Randomly initializing topic assignments ...");
        for (int d = 0; d < this.numShorDoc; ++d) {
            int[] termIDArray = this.Corpus.get(d);
            ArrayList<int[]> docAssignments = new ArrayList<int[]>();
            for (int termID : termIDArray) {
                // Topic is drawn before the long document; keep this order.
                int topic = this.rg.nextInt(this.numTopics);
                int longDoc = this.rg.nextInt(this.numLongDoc);
                this.updateCounts(termID, topic, longDoc, 1);
                ++this.tokenSize;
                docAssignments.add(new int[]{topic, longDoc});
            }
            this.assignmentList.add(docAssignments);
        }
        System.out.println("finish init_SATM!");
    }

    /**
     * Recomputes {@link #psd} from the current counts.
     *
     * <p>For short document s and long document l the raw score is
     * {@code longDocCnts[l] / tokenSize} multiplied, for every token of s, by
     * {@code longDocWordCnts[l][word] / longDocCnts[l]}; zero factors and empty
     * long documents are replaced by {@link #ZERO_SMOOTH}. Each row is then passed
     * through {@code FuncUtils.L1NormWithReusable}; if that returns {@code null}
     * the row becomes all zeros.
     */
    public void computePds() {
        for (int s = 0; s < this.numShorDoc; ++s) {
            int[] termIDArray = this.Corpus.get(s);
            for (int l = 0; l < this.numLongDoc; ++l) {
                int longDocTokens = this.longDocCnts[l];
                if (longDocTokens == 0) {
                    this.psd[s][l] = ZERO_SMOOTH;
                    continue;
                }
                double score = (double)longDocTokens / (double)this.tokenSize;
                for (int termID : termIDArray) {
                    double pdw = (double)this.longDocWordCnts[l][termID] / (double)longDocTokens;
                    if (Double.compare(pdw, 0.0) == 0) {
                        pdw = ZERO_SMOOTH;
                    }
                    score *= pdw;
                }
                this.psd[s][l] = score;
            }
            double[] normalised = FuncUtils.L1NormWithReusable(this.psd[s]);
            this.psd[s] = normalised != null ? normalised : new double[this.numLongDoc];
        }
    }

    /**
     * Samples one cell of {@code dist} with probability proportional to its value.
     *
     * @param dist unnormalised weights, indexed {@code [row][column]}
     * @param sum  total of all entries of {@code dist}
     * @return {@code {column, row}} of the first cell at which the running total
     *         reaches the random target; {@code {0, 0}} if no cell does
     */
    public int[] joint_sample(double[][] dist, double sum) {
        double u = this.rg.nextDouble() * sum;
        double cumulative = 0.0;
        for (int l = 0; l < dist.length; ++l) {
            double[] row = dist[l];
            for (int z = 0; z < row.length; ++z) {
                cumulative += row[z];
                if (Double.compare(cumulative, u) >= 0) {
                    return new int[]{z, l};
                }
            }
        }
        return new int[2];
    }

    /**
     * Writes the vocabulary file, runs {@link #numIterations} sampling sweeps
     * over the corpus and then writes all result files via {@link #write()}.
     *
     * @throws IOException if any output file cannot be written
     */
    public void inference() throws IOException {
        this.writeDictionary();
        System.out.println("Running Gibbs sampling inference: ");
        long startTime = System.currentTimeMillis();
        // Inclusive upper bound: the previous "iter < numIterations" performed one
        // sweep fewer than requested (and none at all for numIterations == 1).
        for (int iter = 1; iter <= this.numIterations; ++iter) {
            if (iter % 50 == 0) {
                System.out.print(" " + iter);
            }
            this.computePds();
            for (int s = 0; s < this.Corpus.size(); ++s) {
                this.resampleShortDoc(s);
            }
        }
        this.iterTime = System.currentTimeMillis() - startTime;
        this.expName = this.orgExpName;
        System.out.println();
        System.out.println("Writing output from the last sample ...");
        this.write();
        System.out.println("Sampling completed for SATM!");
    }

    /**
     * Resamples the (topic, long document) pair of every token of short document
     * {@code s}. Only long documents whose psd value is a number strictly above
     * {@link #threshold} are candidates; if there are none the document keeps
     * its current assignments.
     */
    private void resampleShortDoc(int s) {
        double[] docScores = this.psd[s];
        ArrayList<Integer> validLongDocs = new ArrayList<Integer>();
        for (int d = 0; d < docScores.length; ++d) {
            if (!Double.isNaN(docScores[d]) && Double.compare(docScores[d], this.threshold) > 0) {
                validLongDocs.add(d);
            }
        }
        if (validLongDocs.isEmpty()) {
            return;
        }
        int[] termIDArray = this.Corpus.get(s);
        ArrayList<int[]> assignments = this.assignmentList.get(s);
        double[][] pdzMat = new double[validLongDocs.size()][this.numTopics];
        for (int t = 0; t < termIDArray.length; ++t) {
            int termID = termIDArray[t];
            int[] assignment = assignments.get(t);
            // Take the token out of the counts before computing its conditional weights.
            this.updateCounts(termID, assignment[0], assignment[1], -1);
            double distSum = 0.0;
            for (int d = 0; d < validLongDocs.size(); ++d) {
                int longDocID = validLongDocs.get(d);
                for (int z = 0; z < this.numTopics; ++z) {
                    double pdz = ((double)this.U[z][longDocID] + this.alpha) / ((double)this.longDocCnts[longDocID] + this.alphaSum);
                    double pzw = ((double)this.V[termID][z] + this.beta) / ((double)this.topicCnts[z] + this.betaSum);
                    double weight = docScores[longDocID] * pdz * pzw;
                    pdzMat[d][z] = weight;
                    distSum += weight;
                }
            }
            int[] sampled = this.joint_sample(pdzMat, distSum);
            int newTopic = sampled[0];
            // joint_sample returns a row index into the candidate list, not a long document id.
            int newLongDoc = validLongDocs.get(sampled[1]);
            this.updateCounts(termID, newTopic, newLongDoc, 1);
            assignment[0] = newTopic;
            assignment[1] = newLongDoc;
        }
    }

    /** Writes the {@code .topWords}, {@code .theta}, {@code .phi} and {@code .paras} files. */
    public void write() throws IOException {
        this.writeTopTopicalWords();
        this.writeDocTopicPros();
        this.writeTopicWordPros();
        this.writeParameters();
    }

    /**
     * Writes one line per short document to {@code <folderPath><expName>.theta}
     * containing {@link #numTopics} space-separated values.
     *
     * <p>For each topic the value is the product over the document's tokens of
     * {@code (V[word][topic] + beta) / (topicCnts[topic] + betaSum)}, and the row
     * is then passed through {@code FuncUtils.L1NormWithReusable}. A {@code null}
     * result is written as a row of zeros, mirroring {@link #computePds()}.
     */
    public void writeDocTopicPros() throws IOException {
        try (BufferedWriter writer = new BufferedWriter(new FileWriter(this.folderPath + this.expName + ".theta"))) {
            for (int d = 0; d < this.numShorDoc; ++d) {
                int[] doc = this.Corpus.get(d);
                double[] multiTopic = new double[this.numTopics];
                for (int k = 0; k < this.numTopics; ++k) {
                    double likelihood = 1.0;
                    for (int word : doc) {
                        likelihood *= ((double)this.V[word][k] + this.beta) / ((double)this.topicCnts[k] + this.betaSum);
                    }
                    multiTopic[k] = likelihood;
                }
                multiTopic = FuncUtils.L1NormWithReusable(multiTopic);
                if (multiTopic == null) {
                    multiTopic = new double[this.numTopics];
                }
                for (int k = 0; k < this.numTopics; ++k) {
                    writer.write(multiTopic[k] + " ");
                }
                writer.write("\n");
            }
        }
    }

    /**
     * Writes one line per topic to {@code <folderPath><expName>.topWords} with up
     * to {@link #topWords} words, in the order produced by
     * {@code FuncUtils.sortByValueDescending} on the topic's word counts.
     *
     * <p>Every topic line is terminated by a newline, including when the
     * vocabulary holds no more than {@code topWords} words (previously such
     * topics were written without a line break and ran together).
     */
    public void writeTopTopicalWords() throws IOException {
        try (BufferedWriter writer = new BufferedWriter(new FileWriter(this.folderPath + this.expName + ".topWords"))) {
            for (int tIndex = 0; tIndex < this.numTopics; ++tIndex) {
                Map<Integer, Integer> wordCount = new TreeMap<Integer, Integer>();
                for (int wIndex = 0; wIndex < this.vocabularySize; ++wIndex) {
                    wordCount.put(wIndex, this.V[wIndex][tIndex]);
                }
                wordCount = FuncUtils.sortByValueDescending(wordCount);
                int count = 0;
                for (Integer index : wordCount.keySet()) {
                    if (count >= this.topWords) {
                        break;
                    }
                    writer.write(this.id2WordVocabulary.get(index) + " ");
                    ++count;
                }
                writer.write("\n");
            }
        }
    }

    /**
     * Writes one line per topic to {@code <folderPath><expName>.phi}; each line
     * holds, for every word id in order,
     * {@code (V[word][topic] + beta) / (topicCnts[topic] + betaSum)}.
     */
    public void writeTopicWordPros() throws IOException {
        try (BufferedWriter writer = new BufferedWriter(new FileWriter(this.folderPath + this.expName + ".phi"))) {
            for (int i = 0; i < this.numTopics; ++i) {
                for (int j = 0; j < this.vocabularySize; ++j) {
                    double pro = ((double)this.V[j][i] + this.beta) / ((double)this.topicCnts[i] + this.betaSum);
                    writer.write(pro + " ");
                }
                writer.write("\n");
            }
        }
    }

    /**
     * Writes the run configuration and timings, one tab-separated
     * {@code -key value} pair per line, to {@code <folderPath><expName>.paras}.
     * The init-file and save-step entries are written only when set.
     */
    public void writeParameters() throws IOException {
        try (BufferedWriter writer = new BufferedWriter(new FileWriter(this.folderPath + this.expName + ".paras"))) {
            writer.write("-model\tSATM");
            writer.write("\n-corpus\t" + this.corpusPath);
            writer.write("\n-ntopics\t" + this.numTopics);
            writer.write("\n-nlongdoc\t" + this.numLongDoc);
            writer.write("\n-threshold\t" + this.threshold);
            writer.write("\n-alpha\t" + this.alpha);
            writer.write("\n-beta\t" + this.beta);
            writer.write("\n-niters\t" + this.numIterations);
            writer.write("\n-twords\t" + this.topWords);
            writer.write("\n-name\t" + this.expName);
            // The path comes straight from a constructor argument and may be null.
            if (this.tAssignsFilePath != null && this.tAssignsFilePath.length() > 0) {
                writer.write("\n-initFile\t" + this.tAssignsFilePath);
            }
            if (this.savestep > 0) {
                writer.write("\n-sstep\t" + this.savestep);
            }
            writer.write("\n-initiation time\t" + this.initTime);
            writer.write("\n-one iteration time\t" + this.iterTime / (double)this.numIterations);
            writer.write("\n-total time\t" + (this.initTime + this.iterTime));
        }
    }

    /**
     * Writes one {@code <word> <id>} line per vocabulary entry, in id order, to
     * {@code <folderPath><expName>.vocabulary}.
     */
    public void writeDictionary() throws IOException {
        try (BufferedWriter writer = new BufferedWriter(new FileWriter(this.folderPath + this.expName + ".vocabulary"))) {
            for (int id = 0; id < this.vocabularySize; ++id) {
                writer.write(this.id2WordVocabulary.get(id) + " " + id + "\n");
            }
        }
    }

    /** Runs the model on {@code dataset/GoogleNews.txt} with fixed example settings. */
    public static void main(String[] args) throws Exception {
        SATM satm = new SATM("dataset/GoogleNews.txt", 152, 1000, 0.001, 0.1, 0.1, 1000, 10, "GoogleNewsSATM");
        satm.inference();
    }
}

