/*
 * Decompiled with CFR 0.152.
 */
package models;

import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Map;
import java.util.Random;
import java.util.Set;
import java.util.TreeMap;
import utility.FuncUtils;

/**
 * Gibbs sampler for a Dirichlet Multinomial Mixture topic model in which every
 * document carries exactly one topic and words that are close to a document's words
 * in an external embedding space can be "promoted" into that topic's word counts
 * (presumably the Generalized Polya Urn DMM, judging by the class name -- confirm
 * against the publication this code was derived from).
 *
 * <p>Typical use: construct an instance (reads the corpus and the word vectors,
 * builds the word-similarity schema and randomly initialises the topic assignments),
 * then call {@link #inference()}, which runs the sampler and writes the
 * {@code .vocabulary}, {@code .topWords}, {@code .theta}, {@code .topicAssignments},
 * {@code .phi} and {@code .paras} files under {@link #folderPath}.
 *
 * <p>All public fields are left public and mutable so existing callers keep working.
 * Nothing in this class synchronises access to them.
 */
public class GPUDMM {
    /** Smoothing constant added to per-topic counts when computing topic probabilities. */
    public double alpha;
    /** Smoothing constant added to per-topic word counts when computing word probabilities. */
    public double beta;
    /** Number of topics. */
    public int numTopics;
    /** Number of Gibbs sampling sweeps performed by {@link #inference()}. */
    public int numIterations;
    /** Number of highest-count words written per topic by {@link #writeTopTopicalWords()}. */
    public int topWords;
    /** {@code numTopics * alpha}. */
    public double alphaSum;
    /** {@code vocabularySize * beta}. */
    public double betaSum;
    /** One array of word ids per non-blank corpus line, in file order. */
    public ArrayList<int[]> Corpus = new ArrayList<>();
    private Random rg;
    /** Two words count as similar only if their cosine similarity is strictly greater than this. */
    public double threshold;
    /** Amount credited to a similar word when it is promoted (see {@link #ratioCount}). */
    public double weight;
    /**
     * A word whose above-threshold neighbourhood (which contains the word itself whenever
     * the threshold is below 1) has more than this many entries gets no neighbours at all.
     */
    public int filterSize;
    public HashMap<String, Integer> word2IdVocabulary;
    public HashMap<Integer, String> id2WordVocabulary;
    public int vocabularySize;
    /** Not used by this class. */
    public Map<Integer, Double> wordIDFMap;
    /** Not used by this class. */
    public Map<Integer, Map<Integer, Double>> docUsefulWords;
    /** Not used by this class. */
    public ArrayList<ArrayList<Integer>> topWordIDList;
    public int numDocuments;
    public int numWordsInCorpus;
    /** {@code phi[topic][word]}, refreshed by {@link #compute_phi()}. */
    public double[][] phi;
    /** {@code pz[topic]}, refreshed by {@link #compute_pz()}. */
    private double[] pz;
    /** Allocated as {@code numDocuments x numTopics}; never read or written after allocation. */
    private double[][] pdz;
    /** {@code [word][topic]}, refreshed by {@link #updateTopicProbabilityGivenWord()}. */
    private double[][] topicProbabilityGivenWord;
    /** Per document, per word position: whether that word currently promotes its similar words. */
    public ArrayList<ArrayList<Boolean>> wordGPUFlag;
    /** Current topic of every document. */
    public int[] assignmentList;
    /** Per document, per word position: similar word id -> weight recorded when it was promoted. */
    public ArrayList<ArrayList<Map<Integer, Double>>> wordGPUInfo;
    /** Number of documents currently assigned to each topic. */
    public int[] docTopicCount;
    /** {@code [topic][word]} counts, including promoted similar words. */
    public int[][] topicWordCount;
    /** Row sums of {@link #topicWordCount}, maintained incrementally. */
    public int[] sumTopicWordCount;
    /** Word id -> (similar word id -> cosine similarity); only words with at least one neighbour appear. */
    private Map<Integer, Map<Integer, Double>> schemaMap;
    /** Scratch distribution over topics handed to {@code FuncUtils.nextDiscrete}. */
    public double[] multiPros;
    public String folderPath;
    public String corpusPath;
    public String pathToVector;
    public String expName = "GPUDMMmodel";
    public String orgExpName = "GPUDMMmodel";
    /** Only reported by {@link #writeParameters()} when non-empty; no constructor sets it. */
    public String tAssignsFilePath = "";
    /** Only reported by {@link #writeParameters()} when positive; the sampler itself ignores it. */
    public int savestep = 0;
    /** Milliseconds spent building the schema and initialising assignments. */
    public double initTime = 0.0;
    /** Milliseconds spent in the sampling loop of {@link #inference()}. */
    public double iterTime = 0.0;

    /** Same as the full constructor with experiment name {@code "GPUDMMmodel"}, no assignment file and save step 0. */
    public GPUDMM(String pathToCorpus, String pathToVector, double inWeight, double threshold_GPU, int inFilterSize, int inNumTopics, double inAlpha, double inBeta, int inNumIterations, int inTopWords) throws Exception {
        this(pathToCorpus, pathToVector, inWeight, threshold_GPU, inFilterSize, inNumTopics, inAlpha, inBeta, inNumIterations, inTopWords, "GPUDMMmodel");
    }

    /** Same as the full constructor with no assignment file and save step 0. */
    public GPUDMM(String pathToCorpus, String pathToVector, double inWeight, double threshold_GPU, int inFilterSize, int inNumTopics, double inAlpha, double inBeta, int inNumIterations, int inTopWords, String inExpName) throws Exception {
        this(pathToCorpus, pathToVector, inWeight, threshold_GPU, inFilterSize, inNumTopics, inAlpha, inBeta, inNumIterations, inTopWords, inExpName, "", 0);
    }

    /** Same as the full constructor with save step 0. */
    public GPUDMM(String pathToCorpus, String pathToVector, double inWeight, double threshold_GPU, int inFilterSize, int inNumTopics, double inAlpha, double inBeta, int inNumIterations, int inTopWords, String inExpName, String pathToTAfile) throws Exception {
        this(pathToCorpus, pathToVector, inWeight, threshold_GPU, inFilterSize, inNumTopics, inAlpha, inBeta, inNumIterations, inTopWords, inExpName, pathToTAfile, 0);
    }

    /** Same as the full constructor with no assignment file. */
    public GPUDMM(String pathToCorpus, String pathToVector, double inWeight, double threshold_GPU, int inFilterSize, int inNumTopics, double inAlpha, double inBeta, int inNumIterations, int inTopWords, String inExpName, int inSaveStep) throws Exception {
        this(pathToCorpus, pathToVector, inWeight, threshold_GPU, inFilterSize, inNumTopics, inAlpha, inBeta, inNumIterations, inTopWords, inExpName, "", inSaveStep);
    }

    /**
     * Reads the corpus, builds the vocabulary and the word-similarity schema, allocates
     * all count tables and randomly assigns a topic to every document.
     *
     * @param pathToCorpus    text file with one whitespace-tokenised document per line; blank lines are skipped
     * @param pathToVector    space-separated word-vector file ({@code word v1 v2 ...} per line)
     * @param inWeight        value stored in {@link #weight}
     * @param threshold_GPU   value stored in {@link #threshold}
     * @param inFilterSize    value stored in {@link #filterSize}
     * @param inNumTopics     number of topics
     * @param inAlpha         value stored in {@link #alpha}
     * @param inBeta          value stored in {@link #beta}
     * @param inNumIterations number of sampling sweeps
     * @param inTopWords      number of top words written per topic
     * @param inExpName       prefix of every output file name
     * @param pathToTAfile    accepted for signature compatibility; not used
     * @param inSaveStep      value stored in {@link #savestep}
     * @throws IOException if the corpus cannot be read or the similarity schema cannot be built
     */
    public GPUDMM(String pathToCorpus, String pathToVector, double inWeight, double threshold_GPU, int inFilterSize, int inNumTopics, double inAlpha, double inBeta, int inNumIterations, int inTopWords, String inExpName, String pathToTAfile, int inSaveStep) throws Exception {
        this.alpha = inAlpha;
        this.beta = inBeta;
        this.numTopics = inNumTopics;
        this.numIterations = inNumIterations;
        this.topWords = inTopWords;
        this.savestep = inSaveStep;
        this.orgExpName = this.expName = inExpName;
        this.weight = inWeight;
        this.filterSize = inFilterSize;
        this.threshold = threshold_GPU;
        this.corpusPath = pathToCorpus;
        this.pathToVector = pathToVector;
        this.folderPath = "results/";
        File dir = new File(this.folderPath);
        if (!dir.exists()) {
            dir.mkdir();
        }
        System.out.println("Reading topic modeling corpus: " + pathToCorpus);
        this.word2IdVocabulary = new HashMap<>();
        this.id2WordVocabulary = new HashMap<>();
        this.wordGPUFlag = new ArrayList<>();
        this.numDocuments = 0;
        this.numWordsInCorpus = 0;
        this.rg = new Random();
        // A corpus that cannot be read is reported to the caller instead of silently
        // producing a model over an empty or truncated corpus.
        this.readCorpus(pathToCorpus);
        this.vocabularySize = this.word2IdVocabulary.size();
        this.docTopicCount = new int[this.numTopics];
        this.topicWordCount = new int[this.numTopics][this.vocabularySize];
        this.sumTopicWordCount = new int[this.numTopics];
        this.phi = new double[this.numTopics][this.vocabularySize];
        this.pz = new double[this.numTopics];
        this.topicProbabilityGivenWord = new double[this.vocabularySize][this.numTopics];
        this.pdz = new double[this.numDocuments][this.numTopics];
        this.multiPros = new double[this.numTopics];
        // Uniform distribution used for the initial random topic draw.
        for (int i = 0; i < this.numTopics; ++i) {
            this.multiPros[i] = 1.0 / (double)this.numTopics;
        }
        this.alphaSum = (double)this.numTopics * this.alpha;
        this.betaSum = (double)this.vocabularySize * this.beta;
        this.assignmentList = new int[this.numDocuments];
        this.wordGPUInfo = new ArrayList<>();
        long startTime = System.currentTimeMillis();
        this.schemaMap = this.computSchema(pathToVector);
        if (this.schemaMap == null) {
            // computSchema has already printed the underlying error; stop here rather
            // than hitting a NullPointerException in the middle of sampling.
            throw new IOException("Could not build the word similarity schema from " + pathToVector);
        }
        this.initialize();
        this.initTime = System.currentTimeMillis() - startTime;
        System.out.println("Corpus size: " + this.numDocuments + " docs, " + this.numWordsInCorpus + " words");
        System.out.println("Vocabuary size: " + this.vocabularySize);
        System.out.println("Number of topics: " + this.numTopics);
        System.out.println("alpha: " + this.alpha);
        System.out.println("beta: " + this.beta);
        System.out.println("weight: " + this.weight);
        System.out.println("filterSize: " + this.filterSize);
        System.out.println("Number of sampling iterations: " + this.numIterations);
        System.out.println("Number of top topical words: " + this.topWords);
    }

    /**
     * Reads the corpus file, assigning consecutive ids (starting at 0) to words in
     * order of first appearance and appending one id array per non-blank line to
     * {@link #Corpus}.
     */
    private void readCorpus(String pathToCorpus) throws IOException {
        int nextWordId = 0;
        try (BufferedReader reader = new BufferedReader(new FileReader(pathToCorpus))) {
            String line;
            while ((line = reader.readLine()) != null) {
                String trimmed = line.trim();
                if (trimmed.isEmpty()) {
                    continue;
                }
                String[] words = trimmed.split("\\s+");
                int[] document = new int[words.length];
                for (int i = 0; i < words.length; ++i) {
                    Integer id = this.word2IdVocabulary.get(words[i]);
                    if (id == null) {
                        id = nextWordId++;
                        this.word2IdVocabulary.put(words[i], id);
                        this.id2WordVocabulary.put(id, words[i]);
                    }
                    document[i] = id;
                }
                ++this.numDocuments;
                this.numWordsInCorpus += document.length;
                this.Corpus.add(document);
            }
        }
    }

    /**
     * Similarity between two word ids.
     *
     * @return 1 when the ids are equal, 0 when either id has no vector in {@code wordMap},
     *         otherwise the result of {@code FuncUtils.ComputeCosineSimilarity} on the two vectors
     */
    public double computeSis(HashMap<Integer, float[]> wordMap, int i, int j) {
        if (i == j) {
            return 1.0;
        }
        if (!wordMap.containsKey(i) || !wordMap.containsKey(j)) {
            return 0.0;
        }
        float sis = FuncUtils.ComputeCosineSimilarity(wordMap.get(i), wordMap.get(j));
        return sis;
    }

    /**
     * Builds, for every vocabulary word, the set of other words whose similarity to it is
     * strictly above {@link #threshold}. A word whose neighbourhood is larger than
     * {@link #filterSize} is treated as having no neighbours. Words without neighbours
     * are absent from the result. Every pair of vocabulary words is compared.
     *
     * @param pathToVector space-separated word-vector file; lines for words outside the vocabulary are ignored
     * @return word id -> (similar word id -> similarity), or {@code null} if reading or
     *         parsing the vector file failed (the error is printed)
     */
    public Map<Integer, Map<Integer, Double>> computSchema(String pathToVector) {
        try {
            HashMap<Integer, float[]> wordMap = this.readWordVectors(pathToVector);
            HashMap<Integer, Map<Integer, Double>> schemaMap = new HashMap<>();
            double count = 0.0;
            for (int i = 0; i < this.vocabularySize; ++i) {
                HashMap<Integer, Double> tmpMap = new HashMap<>();
                for (int j = 0; j < this.vocabularySize; ++j) {
                    double v = this.computeSis(wordMap, i, j);
                    if (Double.compare(v, this.threshold) > 0) {
                        tmpMap.put(j, v);
                    }
                }
                // The size test runs before the word itself is removed, so the word
                // counts towards its own neighbourhood size here.
                if (tmpMap.size() > this.filterSize) {
                    tmpMap.clear();
                }
                tmpMap.remove(i);
                if (tmpMap.isEmpty()) {
                    continue;
                }
                count += (double)tmpMap.size();
                schemaMap.put(i, tmpMap);
            }
            System.out.println("finish read schema, the avrage number of value is " + count / (double)schemaMap.size());
            return schemaMap;
        }
        catch (Exception e) {
            System.out.println("Error while reading other file:" + e.getMessage());
            e.printStackTrace();
            return null;
        }
    }

    /** Loads the vectors of all in-vocabulary words, keyed by word id; the reader is always closed. */
    private HashMap<Integer, float[]> readWordVectors(String pathToVector) throws IOException {
        HashMap<Integer, float[]> vectors = new HashMap<>();
        try (BufferedReader reader = new BufferedReader(new FileReader(pathToVector))) {
            String line;
            while ((line = reader.readLine()) != null) {
                String[] fields = line.split(" ");
                Integer id = this.word2IdVocabulary.get(fields[0]);
                if (id == null) {
                    continue;
                }
                float[] vec = new float[fields.length - 1];
                for (int i = 1; i < fields.length; ++i) {
                    vec[i - 1] = Float.parseFloat(fields[i]);
                }
                vectors.put(id, vec);
            }
        }
        return vectors;
    }

    /**
     * Draws an initial topic for every document from {@link #multiPros} and adds the
     * document and its words to that topic's counts. Every word starts with its
     * promotion flag off and no promoted neighbours.
     *
     * <p>{@link #docTopicCount} is incremented once per document only. It used to be
     * incremented once more for every word, which {@link #ratioCount} never undoes, so
     * the initial random assignment permanently inflated the topic prior.
     */
    public void initialize() throws IOException {
        System.out.println("V2: Randomly initializing topic assignments ...");
        for (int d = 0; d < this.numDocuments; ++d) {
            int[] termIDArray = this.Corpus.get(d);
            ArrayList<Boolean> docWordGPUFlag = new ArrayList<>();
            ArrayList<Map<Integer, Double>> docWordGPUInfo = new ArrayList<>();
            int topic = FuncUtils.nextDiscrete(this.multiPros);
            this.assignmentList[d] = topic;
            this.docTopicCount[topic] += 1;
            for (int termID : termIDArray) {
                this.topicWordCount[topic][termID] += 1;
                this.sumTopicWordCount[topic] += 1;
                docWordGPUFlag.add(false);
                docWordGPUInfo.add(new HashMap<Integer, Double>());
            }
            this.wordGPUFlag.add(docWordGPUFlag);
            this.wordGPUInfo.add(docWordGPUInfo);
        }
        System.out.println("finish init_GPU!");
    }

    /** Recomputes {@link #phi} as the beta-smoothed, row-normalised {@link #topicWordCount}. */
    public void compute_phi() {
        for (int topic = 0; topic < this.numTopics; ++topic) {
            double sum = 0.0;
            for (int word = 0; word < this.vocabularySize; ++word) {
                sum += (double)this.topicWordCount[topic][word];
            }
            for (int word = 0; word < this.vocabularySize; ++word) {
                this.phi[topic][word] = ((double)this.topicWordCount[topic][word] + this.beta) / (sum + this.betaSum);
            }
        }
    }

    /**
     * Recomputes the topic distribution as the alpha-smoothed, normalised
     * {@link #sumTopicWordCount} (word mass per topic, not document counts).
     */
    public void compute_pz() {
        double sum = 0.0;
        for (int topic = 0; topic < this.numTopics; ++topic) {
            sum += (double)this.sumTopicWordCount[topic];
        }
        for (int topic = 0; topic < this.numTopics; ++topic) {
            this.pz[topic] = ((double)this.sumTopicWordCount[topic] + this.alpha) / (sum + this.alphaSum);
        }
    }

    /**
     * Refreshes the topic and word distributions and, from them, the normalised
     * topic distribution of every vocabulary word ({@code pz[topic] * phi[topic][word]},
     * normalised over topics).
     */
    public void updateTopicProbabilityGivenWord() {
        this.compute_pz();
        this.compute_phi();
        for (int word = 0; word < this.vocabularySize; ++word) {
            double[] row = this.topicProbabilityGivenWord[word];
            double rowSum = 0.0;
            for (int topic = 0; topic < this.numTopics; ++topic) {
                row[topic] = this.pz[topic] * this.phi[topic][word];
                rowSum += row[topic];
            }
            for (int topic = 0; topic < this.numTopics; ++topic) {
                row[topic] = row[topic] / rowSum;
            }
        }
    }

    /** @return the largest per-topic probability stored for {@code wordID}, or -1 when there are no topics */
    public double findTopicMaxProbabilityGivenWord(int wordID) {
        double max = -1.0;
        for (int topic = 0; topic < this.numTopics; ++topic) {
            double candidate = this.topicProbabilityGivenWord[wordID][topic];
            if (Double.compare(candidate, max) > 0) {
                max = candidate;
            }
        }
        return max;
    }

    /** @return the stored probability of {@code topic} for {@code wordID} */
    public double getTopicProbabilityGivenWord(int topic, int wordID) {
        return this.topicProbabilityGivenWord[wordID][topic];
    }

    /**
     * Re-draws the promotion flag of every word of a document: a word is flagged when
     * the ratio of its probability under {@code newTopic} to its best per-topic
     * probability exceeds a fresh uniform random number.
     */
    public void updateWordGPUFlag(int docID, int newTopic) {
        int[] termIDArray = this.Corpus.get(docID);
        ArrayList<Boolean> docWordGPUFlag = new ArrayList<>();
        for (int termID : termIDArray) {
            double maxProbability = this.findTopicMaxProbabilityGivenWord(termID);
            double ratio = this.getTopicProbabilityGivenWord(newTopic, termID) / maxProbability;
            docWordGPUFlag.add(Double.compare(ratio, this.rg.nextDouble()) > 0);
        }
        this.wordGPUFlag.set(docID, docWordGPUFlag);
    }

    /**
     * Adds a document to ({@code flag > 0}) or removes it from (otherwise) a topic.
     *
     * <p>In both cases the document count and the counts of the document's own words
     * change by {@code flag}. When adding, promotion flags are re-drawn and every
     * flagged word that has neighbours in the schema credits each neighbour in this
     * topic; what was credited is recorded in {@link #wordGPUInfo}. When removing, the
     * recorded credits are taken back.
     *
     * <p>Crediting and taking back now use the same integer amount, namely the
     * recorded weight truncated to {@code int}. Previously a fractional weight was
     * truncated away when added but cost a whole count when removed, so every
     * reassignment eroded real counts.
     *
     * <p>NOTE(review): because the count tables are {@code int}, a weight below 1 still
     * contributes nothing. Honouring fractional weights needs {@code double} count
     * tables, which would change the type of public fields.
     *
     * @param topic       topic the document joins or leaves
     * @param docID       index of the document in {@link #Corpus}
     * @param termIDArray the document's word ids
     * @param flag        +1 to add, -1 to remove
     */
    public void ratioCount(Integer topic, Integer docID, int[] termIDArray, int flag) {
        int k = topic;
        int d = docID;
        int[] wordCounts = this.topicWordCount[k];
        this.docTopicCount[k] += flag;
        for (int wordID : termIDArray) {
            wordCounts[wordID] += flag;
            this.sumTopicWordCount[k] += flag;
        }
        if (flag > 0) {
            this.updateWordGPUFlag(d, k);
            ArrayList<Boolean> flags = this.wordGPUFlag.get(d);
            ArrayList<Map<Integer, Double>> infos = this.wordGPUInfo.get(d);
            double recordedWeight = this.weight;
            int credit = (int)recordedWeight;
            for (int t = 0; t < termIDArray.length; ++t) {
                int wordID = termIDArray[t];
                HashMap<Integer, Double> gpuInfo = new HashMap<>();
                if (flags.get(t) && this.schemaMap.containsKey(wordID)) {
                    for (Integer similarWordID : this.schemaMap.get(wordID).keySet()) {
                        wordCounts[similarWordID] += credit;
                        this.sumTopicWordCount[k] += credit;
                        gpuInfo.put(similarWordID, recordedWeight);
                    }
                }
                infos.set(t, gpuInfo);
            }
        } else {
            ArrayList<Map<Integer, Double>> infos = this.wordGPUInfo.get(d);
            for (int t = 0; t < termIDArray.length; ++t) {
                for (Map.Entry<Integer, Double> promoted : infos.get(t).entrySet()) {
                    // Use the recorded weight, not the current field, so the amount taken
                    // back always equals the amount that was credited.
                    int credit = (int)promoted.getValue().doubleValue();
                    wordCounts[promoted.getKey()] -= credit;
                    this.sumTopicWordCount[k] -= credit;
                }
            }
        }
    }

    /**
     * Writes the vocabulary, runs {@link #numIterations} sampling sweeps and writes
     * all result files from the final state.
     *
     * <p>Each sweep first refreshes the per-word topic distributions, then for every
     * document removes it from its topic, scores every topic and draws a new one via
     * {@code FuncUtils.nextDiscrete}.
     *
     * @throws IOException if an output file cannot be written
     */
    public void inference() throws IOException {
        this.writeDictionary();
        System.out.println("Running Gibbs sampling inference: ");
        long startTime = System.currentTimeMillis();
        double priorDenominator = (double)(this.numDocuments - 1) + this.alphaSum;
        for (int iter = 1; iter <= this.numIterations; ++iter) {
            if (iter % 50 == 0) {
                System.out.print(" " + iter);
            }
            this.updateTopicProbabilityGivenWord();
            for (int s = 0; s < this.Corpus.size(); ++s) {
                int[] termIDArray = this.Corpus.get(s);
                this.ratioCount(this.assignmentList[s], s, termIDArray, -1);
                for (int topic = 0; topic < this.numTopics; ++topic) {
                    double topicPrior = ((double)this.docTopicCount[topic] + this.alpha) / priorDenominator;
                    double likelihood = 1.0;
                    for (int t = 0; t < termIDArray.length; ++t) {
                        int termID = termIDArray[t];
                        // The denominator grows by one for each word position already scored.
                        likelihood *= ((double)this.topicWordCount[topic][termID] + this.beta) / ((double)this.sumTopicWordCount[topic] + this.betaSum + (double)t);
                    }
                    this.multiPros[topic] = topicPrior * likelihood;
                }
                int newTopic = FuncUtils.nextDiscrete(this.multiPros);
                this.assignmentList[s] = newTopic;
                this.ratioCount(newTopic, s, termIDArray, 1);
            }
        }
        this.expName = this.orgExpName;
        this.iterTime = System.currentTimeMillis() - startTime;
        System.out.println();
        System.out.println("Writing output from the last sample ...");
        this.write();
        System.out.println("Sampling completed!");
    }

    /** Writes the run configuration and timings to {@code <folderPath><expName>.paras}. */
    public void writeParameters() throws IOException {
        try (BufferedWriter writer = new BufferedWriter(new FileWriter(this.folderPath + this.expName + ".paras"))) {
            writer.write("-model\tGPUDMM");
            writer.write("\n-corpus\t" + this.corpusPath);
            writer.write("\n-ntopics\t" + this.numTopics);
            writer.write("\n-alpha\t" + this.alpha);
            writer.write("\n-beta\t" + this.beta);
            writer.write("\n-threshold\t" + this.threshold);
            writer.write("\n-weight\t" + this.weight);
            writer.write("\n-filterSize\t" + this.filterSize);
            writer.write("\n-niters\t" + this.numIterations);
            writer.write("\n-twords\t" + this.topWords);
            writer.write("\n-name\t" + this.expName);
            if (this.tAssignsFilePath.length() > 0) {
                writer.write("\n-initFile\t" + this.tAssignsFilePath);
            }
            if (this.savestep > 0) {
                writer.write("\n-sstep\t" + this.savestep);
            }
            writer.write("\n-initiation time\t" + this.initTime);
            writer.write("\n-one iteration time\t" + this.iterTime / (double)this.numIterations);
            writer.write("\n-total time\t" + (this.initTime + this.iterTime));
        }
    }

    /** Writes one {@code word id} line per vocabulary entry to {@code <folderPath><expName>.vocabulary}. */
    public void writeDictionary() throws IOException {
        try (BufferedWriter writer = new BufferedWriter(new FileWriter(this.folderPath + this.expName + ".vocabulary"))) {
            for (int id = 0; id < this.vocabularySize; ++id) {
                writer.write(this.id2WordVocabulary.get(id) + " " + id + "\n");
            }
        }
    }

    /**
     * Writes, one line per topic, up to {@link #topWords} words in the order returned by
     * {@code FuncUtils.sortByValueDescending} over the topic's word counts, to
     * {@code <folderPath><expName>.topWords}.
     *
     * <p>Every topic line is terminated, including when the vocabulary holds no more
     * than {@link #topWords} words; previously such topics ran together on one line.
     */
    public void writeTopTopicalWords() throws IOException {
        try (BufferedWriter writer = new BufferedWriter(new FileWriter(this.folderPath + this.expName + ".topWords"))) {
            for (int tIndex = 0; tIndex < this.numTopics; ++tIndex) {
                Map<Integer, Integer> wordCount = new TreeMap<>();
                for (int wIndex = 0; wIndex < this.vocabularySize; ++wIndex) {
                    wordCount.put(wIndex, this.topicWordCount[tIndex][wIndex]);
                }
                wordCount = FuncUtils.sortByValueDescending(wordCount);
                int written = 0;
                for (Integer index : wordCount.keySet()) {
                    if (written >= this.topWords) {
                        break;
                    }
                    writer.write(this.id2WordVocabulary.get(index) + " ");
                    ++written;
                }
                writer.write("\n");
            }
        }
    }

    /**
     * Writes each document's normalised topic distribution (one line per document) to
     * {@code <folderPath><expName>.theta}. The unnormalised score of a topic is its
     * smoothed document count times the product of the smoothed word probabilities of
     * the document's words. Uses {@link #multiPros} as scratch space.
     */
    public void writeDocTopicPros() throws IOException {
        try (BufferedWriter writer = new BufferedWriter(new FileWriter(this.folderPath + this.expName + ".theta"))) {
            for (int i = 0; i < this.numDocuments; ++i) {
                int[] document = this.Corpus.get(i);
                double sum = 0.0;
                for (int tIndex = 0; tIndex < this.numTopics; ++tIndex) {
                    this.multiPros[tIndex] = (double)this.docTopicCount[tIndex] + this.alpha;
                    for (int word : document) {
                        this.multiPros[tIndex] *= ((double)this.topicWordCount[tIndex][word] + this.beta) / ((double)this.sumTopicWordCount[tIndex] + this.betaSum);
                    }
                    sum += this.multiPros[tIndex];
                }
                for (int tIndex = 0; tIndex < this.numTopics; ++tIndex) {
                    writer.write(this.multiPros[tIndex] / sum + " ");
                }
                writer.write("\n");
            }
        }
    }

    /**
     * Writes, one line per document, the document's topic repeated once per word to
     * {@code <folderPath><expName>.topicAssignments}.
     */
    public void writeTopicAssignments() throws IOException {
        try (BufferedWriter writer = new BufferedWriter(new FileWriter(this.folderPath + this.expName + ".topicAssignments"))) {
            for (int dIndex = 0; dIndex < this.numDocuments; ++dIndex) {
                int docSize = this.Corpus.get(dIndex).length;
                int topic = this.assignmentList[dIndex];
                for (int wIndex = 0; wIndex < docSize; ++wIndex) {
                    writer.write(topic + " ");
                }
                writer.write("\n");
            }
        }
    }

    /** Writes the smoothed topic-word probabilities, one line per topic, to {@code <folderPath><expName>.phi}. */
    public void writeTopicWordPros() throws IOException {
        try (BufferedWriter writer = new BufferedWriter(new FileWriter(this.folderPath + this.expName + ".phi"))) {
            for (int i = 0; i < this.numTopics; ++i) {
                for (int j = 0; j < this.vocabularySize; ++j) {
                    double pro = ((double)this.topicWordCount[i][j] + this.beta) / ((double)this.sumTopicWordCount[i] + this.betaSum);
                    writer.write(pro + " ");
                }
                writer.write("\n");
            }
        }
    }

    /** Writes every result file except the vocabulary, which {@link #inference()} writes up front. */
    public void write() throws IOException {
        this.writeTopTopicalWords();
        this.writeDocTopicPros();
        this.writeTopicAssignments();
        this.writeTopicWordPros();
        this.writeParameters();
    }

    /** Runs the model on a fixed example data set with fixed settings. */
    public static void main(String[] args) throws Exception {
        GPUDMM gpudmm = new GPUDMM("dataset/Pascal_Flickr.txt", "dataset/glove.6B.200d.txt", 0.7, 0.1, 10, 20, 0.1, 0.1, 500, 10, "Pascal_FlickrGPUDMM");
        gpudmm.inference();
    }
}

