/*
 * Decompiled with CFR 0.152.
 */
package models;

import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Map;
import java.util.Random;
import java.util.Set;
import java.util.TreeMap;
import utility.FuncUtils;

public class GPU_PDMM {
    /** Number of topics K. */
    public int numTopics;
    /** Topic smoothing hyper-parameter; used by compute_pz and by the topic-set weight in inference. */
    public double alpha;
    /** Topic-word smoothing hyper-parameter; used by compute_phi and the word sampling weights. */
    public double beta;
    /** Rate of the Poisson-shaped prior on the number of topics per document (init_GSDMM, inference). */
    public double lambda;
    /** Number of Gibbs sweeps run by inference(). */
    public int numIterations;
    /** Documents as arrays of vocabulary ids, in corpus-file order (blank lines are skipped). */
    public ArrayList<int[]> Corpus = new ArrayList();
    /** Random source used for initialisation and sampling. */
    private Random rg;
    // NOTE(review): not assigned by the constructor in the reviewed code; the similarity
    // threshold is passed straight to computSchema instead.
    public double threshold;
    /** Pseudo-count added to nzw/nz for each similar word of a promoted token (see ratioCount). */
    public double weight;
    /** Number of words written per topic by writeTopTopicalWords(). */
    public int topWords;
    /** Words with more than this many similar words (counting themselves) get no schema entry. */
    public int filterSize;
    /** Token -> dense id, assigned in first-seen order while reading the corpus. */
    public HashMap<String, Integer> word2IdVocabulary;
    /** Inverse of word2IdVocabulary. */
    public HashMap<Integer, String> id2WordVocabulary;
    /** Number of distinct tokens V. */
    public int vocabularySize;
    /** Document index -> its current topic set Zd. */
    public Map<Integer, Set<Integer>> ZdMap;
    /** Document index -> |Zd|, the number of topics assigned to the document. */
    public int[] TdArray;
    /** Number of non-blank documents read from the corpus. */
    public int numDocuments;
    /** Total number of tokens in the corpus. */
    public int numWordsInCorpus;
    /** Maximum size of a document's topic set. */
    public int maxTd;
    /** Number of highest-scoring topics (by pdz) per document from which topic sets are built. */
    public int searchTopK;
    /** Document index -> word-id array (the same arrays stored in Corpus). */
    public Map<Integer, int[]> docToWordIDListMap;
    /** phi[k][w]: smoothed topic-word probabilities computed by compute_phi(). */
    public double[][] phi;
    /** pz[k]: smoothed topic proportions computed by compute_pz(). */
    private double[] pz;
    /** pdz[d][k]: per-document topic scores from compute_pzd(), normalised over k. */
    private double[][] pdz;
    // NOTE(review): allocated in the constructor but never written in the reviewed code, so it
    // stays all zeros; updateWordGPUFlag reads it.
    private double[][] topicProbabilityGivenWord;
    /** pwz[w][k]: pz[k] * phi[k][w] normalised over k (compute_pzd). */
    private double[][] pwz;
    /** Per document and token position: whether the token is currently promoted (GPU flag). */
    public ArrayList<ArrayList<Boolean>> wordGPUFlag;
    /** Document index -> current topic of each token. */
    public Map<Integer, int[]> assignmentListMap;
    /** Per document and token position: similar-word id -> pseudo-count added for that token. */
    public ArrayList<ArrayList<Map<Integer, Double>>> wordGPUInfo;
    /** nz[k]: total (possibly fractional) word count of topic k. */
    private double[] nz;
    /** nzw[k][w]: (possibly fractional) count of word w in topic k. */
    private double[][] nzw;
    /** Ck[k]: number of documents whose topic set contains k. */
    private int[] Ck;
    /** Sum of Ck over all topics. */
    private int CkSum;
    /** Word id -> (similar word id -> cosine similarity), built by computSchema. */
    private Map<Integer, Map<Integer, Double>> schemaMap;
    /** Output directory ("results/"). */
    public String folderPath;
    /** Path of the corpus file given to the constructor. */
    public String corpusPath;
    /** Path of the word-vector file given to the constructor. */
    public String pathToVector;
    /** Base name of output files. */
    public String expName = "GPU_PDMMmodel";
    /** Experiment name as given to the constructor; expName is reset to it at the end of inference(). */
    public String orgExpName = "GPU_PDMMmodel";
    // NOTE(review): not assigned by the constructors shown; their pathToTAfile argument is unused.
    public String tAssignsFilePath = "";
    // NOTE(review): stored by the constructor but not read in the reviewed code.
    public int savestep = 0;
    /** Milliseconds spent building the schema and initialising assignments. */
    public double initTime = 0.0;
    /** Milliseconds spent in the Gibbs sampling loop of inference(). */
    public double iterTime = 0.0;

    /**
     * Builds a model with the default structural settings: maxTd = 3, searchTopK = 10,
     * experiment name "GPU_PDMMmodel", no topic-assignment file and saveStep = 0.
     *
     * @throws Exception if the corpus or vector file cannot be processed
     */
    public GPU_PDMM(String pathToCorpus, String pathToVector, double inWeight, double threshold_GPU,
            int inFilterSize, int inNumTopics, double inAlpha, double inBeta, double inlambda,
            int inNumIterations, int inTopWords) throws Exception {
        this(pathToCorpus, pathToVector, inWeight, threshold_GPU,
                inFilterSize, inNumTopics, inAlpha, inBeta, inlambda,
                inNumIterations, inTopWords,
                3);
    }

    /**
     * Builds a model with an explicit maximum topic-set size and the remaining defaults:
     * searchTopK = 10, experiment name "GPU_PDMMmodel", no topic-assignment file, saveStep = 0.
     *
     * @throws Exception if the corpus or vector file cannot be processed
     */
    public GPU_PDMM(String pathToCorpus, String pathToVector, double inWeight, double threshold_GPU,
            int inFilterSize, int inNumTopics, double inAlpha, double inBeta, double inlambda,
            int inNumIterations, int inTopWords, int inMaxTd) throws Exception {
        this(pathToCorpus, pathToVector, inWeight, threshold_GPU,
                inFilterSize, inNumTopics, inAlpha, inBeta, inlambda,
                inNumIterations, inTopWords, inMaxTd,
                10);
    }

    /**
     * Builds a model with explicit maxTd and searchTopK, using the experiment name
     * "GPU_PDMMmodel", no topic-assignment file and saveStep = 0.
     *
     * @throws Exception if the corpus or vector file cannot be processed
     */
    public GPU_PDMM(String pathToCorpus, String pathToVector, double inWeight, double threshold_GPU,
            int inFilterSize, int inNumTopics, double inAlpha, double inBeta, double inlambda,
            int inNumIterations, int inTopWords, int inMaxTd, int inSearchTopK) throws Exception {
        this(pathToCorpus, pathToVector, inWeight, threshold_GPU,
                inFilterSize, inNumTopics, inAlpha, inBeta, inlambda,
                inNumIterations, inTopWords, inMaxTd, inSearchTopK,
                "GPU_PDMMmodel");
    }

    /**
     * Builds a model with an explicit experiment name, no topic-assignment file and saveStep = 0.
     *
     * @throws Exception if the corpus or vector file cannot be processed
     */
    public GPU_PDMM(String pathToCorpus, String pathToVector, double inWeight, double threshold_GPU,
            int inFilterSize, int inNumTopics, double inAlpha, double inBeta, double inlambda,
            int inNumIterations, int inTopWords, int inMaxTd, int inSearchTopK,
            String inExpName) throws Exception {
        this(pathToCorpus, pathToVector, inWeight, threshold_GPU,
                inFilterSize, inNumTopics, inAlpha, inBeta, inlambda,
                inNumIterations, inTopWords, inMaxTd, inSearchTopK,
                inExpName, "", 0);
    }

    /**
     * Builds a model with an explicit experiment name and topic-assignment file path; saveStep = 0.
     *
     * @throws Exception if the corpus or vector file cannot be processed
     */
    public GPU_PDMM(String pathToCorpus, String pathToVector, double inWeight, double threshold_GPU,
            int inFilterSize, int inNumTopics, double inAlpha, double inBeta, double inlambda,
            int inNumIterations, int inTopWords, int inMaxTd, int inSearchTopK,
            String inExpName, String pathToTAfile) throws Exception {
        this(pathToCorpus, pathToVector, inWeight, threshold_GPU,
                inFilterSize, inNumTopics, inAlpha, inBeta, inlambda,
                inNumIterations, inTopWords, inMaxTd, inSearchTopK,
                inExpName, pathToTAfile, 0);
    }

    /**
     * Builds a model with an explicit experiment name and save step, and no topic-assignment file.
     *
     * @throws Exception if the corpus or vector file cannot be processed
     */
    public GPU_PDMM(String pathToCorpus, String pathToVector, double inWeight, double threshold_GPU,
            int inFilterSize, int inNumTopics, double inAlpha, double inBeta, double inlambda,
            int inNumIterations, int inTopWords, int inMaxTd, int inSearchTopK,
            String inExpName, int inSaveStep) throws Exception {
        this(pathToCorpus, pathToVector, inWeight, threshold_GPU,
                inFilterSize, inNumTopics, inAlpha, inBeta, inlambda,
                inNumIterations, inTopWords, inMaxTd, inSearchTopK,
                inExpName, "", inSaveStep);
    }

    /**
     * Primary constructor: stores the hyper-parameters, reads the corpus, builds the
     * word-similarity schema from the vector file and randomly initialises all assignments.
     *
     * @param pathToCorpus    corpus file; every non-blank line is one document of whitespace-separated tokens
     * @param pathToVector    word-vector file read by computSchema
     * @param inWeight        pseudo-count added to each similar word of a promoted token
     * @param threshold_GPU   words are similar when their cosine similarity is strictly above this value
     * @param inFilterSize    words with more than this many similar words (counting themselves) get no schema entry
     * @param inNumTopics     number of topics K
     * @param inAlpha         topic smoothing hyper-parameter
     * @param inBeta          topic-word smoothing hyper-parameter
     * @param inlambda        Poisson rate for the number of topics per document
     * @param inNumIterations number of Gibbs sweeps run by inference()
     * @param inTopWords      number of words written per topic
     * @param inMaxTd         maximum number of topics per document
     * @param inSearchTopK    number of candidate topics per document; clamped to inNumTopics
     * @param inExpName       base name of the output files under results/
     * @param pathToTAfile    currently not used by this constructor
     * @param inSaveStep      stored in savestep
     * @throws IOException    if the corpus cannot be read or the similarity schema cannot be built
     * @throws Exception      propagated from initialisation
     */
    public GPU_PDMM(String pathToCorpus, String pathToVector, double inWeight, double threshold_GPU, int inFilterSize, int inNumTopics, double inAlpha, double inBeta, double inlambda, int inNumIterations, int inTopWords, int inMaxTd, int inSearchTopK, String inExpName, String pathToTAfile, int inSaveStep) throws Exception {
        this.alpha = inAlpha;
        this.beta = inBeta;
        this.lambda = inlambda;
        this.numTopics = inNumTopics;
        this.numIterations = inNumIterations;
        this.topWords = inTopWords;
        this.weight = inWeight;
        this.filterSize = inFilterSize;
        this.maxTd = inMaxTd;
        // A document cannot have more candidate topics than there are topics: extra slots in the
        // per-document top-K array would stay 0 and yield topic settings with a duplicated topic 0.
        this.searchTopK = Math.min(inSearchTopK, inNumTopics);
        this.savestep = inSaveStep;
        this.orgExpName = this.expName = inExpName;
        this.corpusPath = pathToCorpus;
        this.pathToVector = pathToVector;
        this.folderPath = "results/";
        File dir = new File(this.folderPath);
        if (!dir.exists()) {
            dir.mkdir();
        }
        System.out.println("Reading topic modeling corpus: " + pathToCorpus);
        this.word2IdVocabulary = new HashMap<String, Integer>();
        this.id2WordVocabulary = new HashMap<Integer, String>();
        this.wordGPUFlag = new ArrayList<ArrayList<Boolean>>();
        this.ZdMap = new HashMap<Integer, Set<Integer>>();
        this.assignmentListMap = new HashMap<Integer, int[]>();
        this.docToWordIDListMap = new HashMap<Integer, int[]>();
        this.numDocuments = 0;
        this.numWordsInCorpus = 0;
        this.rg = new Random();
        // Each non-blank line is one document; tokens get dense ids in first-seen order.
        // Read failures propagate instead of silently producing an empty model.
        try (BufferedReader br = new BufferedReader(new FileReader(pathToCorpus))) {
            String line;
            while ((line = br.readLine()) != null) {
                String trimmed = line.trim();
                if (trimmed.isEmpty()) {
                    continue;
                }
                String[] words = trimmed.split("\\s+");
                int[] document = new int[words.length];
                for (int ind = 0; ind < words.length; ++ind) {
                    String word = words[ind];
                    Integer id = this.word2IdVocabulary.get(word);
                    if (id == null) {
                        id = this.word2IdVocabulary.size();
                        this.word2IdVocabulary.put(word, id);
                        this.id2WordVocabulary.put(id, word);
                    }
                    document[ind] = id;
                }
                ++this.numDocuments;
                this.numWordsInCorpus += document.length;
                this.Corpus.add(document);
            }
        }
        this.vocabularySize = this.word2IdVocabulary.size();
        this.phi = new double[this.numTopics][this.vocabularySize];
        this.pz = new double[this.numTopics];
        this.pwz = new double[this.vocabularySize][this.numTopics];
        this.TdArray = new int[this.Corpus.size()];
        this.topicProbabilityGivenWord = new double[this.vocabularySize][this.numTopics];
        this.pdz = new double[this.numDocuments][this.numTopics];
        this.wordGPUInfo = new ArrayList<ArrayList<Map<Integer, Double>>>();
        this.nz = new double[this.numTopics];
        this.nzw = new double[this.numTopics][this.vocabularySize];
        this.Ck = new int[this.numTopics];
        this.CkSum = 0;
        long startTime = System.currentTimeMillis();
        this.schemaMap = this.computSchema(pathToVector, threshold_GPU);
        if (this.schemaMap == null) {
            // computSchema has already reported the cause; without a schema, ratioCount would
            // fail later with a NullPointerException during sampling.
            throw new IOException("Could not build the word-similarity schema from " + pathToVector);
        }
        this.initialize();
        this.initTime = System.currentTimeMillis() - startTime;
        System.out.println("Corpus size: " + this.numDocuments + " docs, " + this.numWordsInCorpus + " words");
        System.out.println("Vocabuary size: " + this.vocabularySize);
        System.out.println("Number of topics: " + this.numTopics);
        System.out.println("alpha: " + this.alpha);
        System.out.println("beta: " + this.beta);
        System.out.println("Number of sampling iterations: " + this.numIterations);
        System.out.println("Number of top topical words: " + this.topWords);
    }

    /**
     * Allocates the per-document bookkeeping and then draws random initial assignments.
     *
     * <p>For every document this creates a zeroed assignment array, one GPU flag (false) and one
     * empty GPU-info map per token, and registers the document's word ids; it then calls init_GSDMM().
     *
     * @throws IOException declared for API compatibility; nothing in this method throws it
     */
    public void initialize() throws IOException {
        System.out.println("Randomly initializing topic assignments ...");
        for (int d = 0; d < this.numDocuments; ++d) {
            int[] termIDArray = this.Corpus.get(d);
            int length = termIDArray.length;
            this.assignmentListMap.put(d, new int[length]);
            ArrayList<Boolean> docWordGPUFlag = new ArrayList<Boolean>(length);
            ArrayList<Map<Integer, Double>> docWordGPUInfo = new ArrayList<Map<Integer, Double>>(length);
            for (int t = 0; t < length; ++t) {
                docWordGPUFlag.add(false);
                docWordGPUInfo.add(new HashMap<Integer, Double>());
            }
            this.wordGPUFlag.add(docWordGPUFlag);
            this.wordGPUInfo.add(docWordGPUInfo);
            this.docToWordIDListMap.put(d, termIDArray);
        }
        this.init_GSDMM();
        System.out.println("finish init_GPU-PDMM!");
    }

    /**
     * Adds {@code flag} to the count of {@code word} in {@code topic} and to the topic total.
     *
     * @param topic topic index (unboxed)
     * @param word  word id
     * @param flag  signed amount to add, typically +1 or -1
     */
    public void normalCountWord(Integer topic, int word, Integer flag) {
        int k = topic;
        int delta = flag;
        this.nzw[k][word] += delta;
        this.nz[k] += delta;
    }

    /**
     * Adds {@code flag} to the document count Ck of every topic in {@code Zd}, and to CkSum once per topic.
     *
     * @param Zd   a document's topic set
     * @param flag signed amount to add, typically +1 or -1
     */
    public void normalCountZd(Set<Integer> Zd, Integer flag) {
        for (Integer topic : Zd) {
            this.Ck[topic] += flag;
            this.CkSum += flag;
        }
    }

    /**
     * Draws the initial sampler state for every document:
     * <ol>
     *   <li>a topic-set size td from lambda^td * e^-lambda / td!, truncated to 1..min(maxTd, numTopics);</li>
     *   <li>a set Zd of td distinct topics drawn uniformly at random;</li>
     *   <li>for each token, a topic drawn uniformly from Zd.</li>
     * </ol>
     * Updates TdArray, ZdMap, assignmentListMap, Ck/CkSum and nzw/nz accordingly.
     *
     * @throws IllegalStateException if maxTd or numTopics is smaller than 1
     */
    public void init_GSDMM() {
        // td can never exceed the number of distinct topics; without this cap the
        // "draw until |Zd| == td" loop below would never terminate when maxTd > numTopics.
        int maxSetSize = Math.min(this.maxTd, this.numTopics);
        if (maxSetSize < 1) {
            throw new IllegalStateException("maxTd and numTopics must both be >= 1 (maxTd=" + this.maxTd + ", numTopics=" + this.numTopics + ")");
        }
        // ptd[i] = weight of td = i + 1, then turned into a running (cumulative) sum.
        double[] ptd = new double[maxSetSize];
        double factorial = 1.0;
        for (int i = 0; i < maxSetSize; ++i) {
            factorial *= (double)(i + 1);
            ptd[i] = Math.pow(this.lambda, i + 1) * Math.exp(-this.lambda) / factorial;
        }
        for (int i = 1; i < maxSetSize; ++i) {
            ptd[i] += ptd[i - 1];
        }
        for (int d = 0; d < this.numDocuments; ++d) {
            // Inverse-CDF draw of td over the truncated cumulative weights.
            double u = this.rg.nextDouble() * ptd[maxSetSize - 1];
            int td = -1;
            for (int i = 0; i < maxSetSize; ++i) {
                if (Double.compare(ptd[i], u) >= 0) {
                    td = i + 1;
                    break;
                }
            }
            assert (td >= 1);
            this.TdArray[d] = td;
            HashSet<Integer> Zd = new HashSet<Integer>();
            while (Zd.size() != td) {
                Zd.add(this.rg.nextInt(this.numTopics));
            }
            this.ZdMap.put(d, Zd);
            this.normalCountZd(Zd, 1);
            Integer[] ZdList = Zd.toArray(new Integer[0]);
            int[] termIDArray = this.docToWordIDListMap.get(d);
            int[] assignments = this.assignmentListMap.get(d);
            for (int w = 0; w < termIDArray.length; ++w) {
                int topic = ZdList[this.rg.nextInt(td)];
                assignments[w] = topic;
                this.normalCountWord(topic, termIDArray[w], 1);
            }
        }
        System.out.println("finish init_GPU_PDMM!");
    }

    /**
     * Returns the similarity between words {@code i} and {@code j}.
     *
     * <p>The result is 1.0 when i == j, 0.0 when either word has no vector in {@code wordMap},
     * and otherwise the value of FuncUtils.ComputeCosineSimilarity on the two vectors.
     */
    public double computeSis(HashMap<Integer, float[]> wordMap, int i, int j) {
        if (i == j) {
            return 1.0;
        }
        boolean bothEmbedded = wordMap.containsKey(i) && wordMap.containsKey(j);
        if (!bothEmbedded) {
            return 0.0;
        }
        float cosine = FuncUtils.ComputeCosineSimilarity(wordMap.get(i), wordMap.get(j));
        return cosine;
    }

    /**
     * Reads word vectors and builds the similar-word map used for promotion.
     *
     * <p>Each line of the vector file is a token followed by its vector components, separated by
     * whitespace. Lines whose token is not in the corpus vocabulary are skipped. For each
     * vocabulary word i with a vector, the words j whose computeSis(i, j) is strictly greater
     * than {@code threshold} are collected. Because computeSis(i, i) is 1.0, i itself is counted
     * when that exceeds the threshold. If more than filterSize words are collected, the entry is
     * dropped. Otherwise i is removed and a non-empty result is stored.
     *
     * @param pathToVector word-vector text file
     * @param threshold    similarity threshold (exclusive)
     * @return word id -> (similar word id -> similarity), or null if reading or computing failed
     */
    public Map<Integer, Map<Integer, Double>> computSchema(String pathToVector, double threshold) {
        HashMap<Integer, Map<Integer, Double>> schemaMap = new HashMap<Integer, Map<Integer, Double>>();
        HashMap<Integer, float[]> wordMap = new HashMap<Integer, float[]>();
        try {
            try (BufferedReader br1 = new BufferedReader(new FileReader(pathToVector))) {
                String line;
                while ((line = br1.readLine()) != null) {
                    // Split on any whitespace run so tabs or repeated spaces do not yield empty tokens.
                    String[] tokens = line.trim().split("\\s+");
                    Integer id = this.word2IdVocabulary.get(tokens[0]);
                    if (id == null) {
                        continue;
                    }
                    float[] vec = new float[tokens.length - 1];
                    for (int i = 1; i < tokens.length; ++i) {
                        vec[i - 1] = Float.parseFloat(tokens[i]);
                    }
                    wordMap.put(id, vec);
                }
            }
            double count = 0.0;
            for (int i = 0; i < this.vocabularySize; ++i) {
                if (!wordMap.containsKey(i)) {
                    continue;
                }
                HashMap<Integer, Double> tmpMap = new HashMap<Integer, Double>();
                for (int j = 0; j < this.vocabularySize; ++j) {
                    if (!wordMap.containsKey(j)) {
                        continue;
                    }
                    double v = this.computeSis(wordMap, i, j);
                    if (Double.compare(v, threshold) > 0) {
                        tmpMap.put(j, v);
                    }
                }
                // Overly "popular" words (the size check includes i itself) get no entry at all.
                if (tmpMap.size() > this.filterSize) {
                    tmpMap.clear();
                }
                tmpMap.remove(i);
                if (tmpMap.isEmpty()) {
                    continue;
                }
                count += (double)tmpMap.size();
                schemaMap.put(i, tmpMap);
            }
            double average = schemaMap.isEmpty() ? 0.0 : count / (double)schemaMap.size();
            System.out.println("finish read schema, the avrage number of value is " + average);
            return schemaMap;
        }
        catch (Exception e) {
            System.out.println("Error while reading other file:" + e.getMessage());
            e.printStackTrace();
            return null;
        }
    }

    /**
     * Returns n! (1 for n <= 0).
     *
     * @throws ArithmeticException if the result does not fit in an int (n > 12), instead of
     *         silently wrapping around
     */
    private static int factorial(int n) {
        int value = 1;
        for (int k = n; k > 0; --k) {
            if (value > Integer.MAX_VALUE / k) {
                throw new ArithmeticException("factorial(" + n + ") overflows int");
            }
            value *= k;
        }
        return value;
    }

    /**
     * Allocates one row per candidate topic set: for each size s = 1..maxTd there are
     * C(searchTopK, s) rows of length s, grouped by increasing size. Row contents are left
     * zero; they are filled per document by enumerateTopicSetting.
     *
     * @return jagged array of candidate topic-set slots
     */
    private int[][] ZdSearchSize() {
        // boundaries[s - 1] = number of candidate sets of size <= s
        int[] boundaries = new int[this.maxTd];
        int total = 0;
        for (int size = 1; size <= this.maxTd; ++size) {
            int fallingProduct = 1;
            for (int f = 0; f < size; ++f) {
                fallingProduct *= this.searchTopK - f;
            }
            total += fallingProduct / GPU_PDMM.factorial(size);
            boundaries[size - 1] = total;
        }
        int[][] settings = new int[total][];
        for (int row = 0; row < total; ++row) {
            int bucket = 0;
            while (bucket < boundaries.length && row >= boundaries[bucket]) {
                ++bucket;
            }
            if (bucket < boundaries.length) {
                settings[row] = new int[bucket + 1];
            }
        }
        return settings;
    }

    /**
     * Recomputes phi[k][w] = (nzw[k][w] + beta) / (sum over w of nzw[k][w] + V * beta) for every topic k.
     */
    public void compute_phi() {
        for (int topic = 0; topic < this.numTopics; ++topic) {
            double[] counts = this.nzw[topic];
            double total = 0.0;
            for (int w = 0; w < this.vocabularySize; ++w) {
                total += counts[w];
            }
            double denominator = total + (double) this.vocabularySize * this.beta;
            double[] row = this.phi[topic];
            for (int w = 0; w < this.vocabularySize; ++w) {
                row[w] = (counts[w] + this.beta) / denominator;
            }
        }
    }

    /**
     * Recomputes pz[k] = (nz[k] + alpha) / (sum over k of nz[k] + K * alpha).
     */
    public void compute_pz() {
        double total = 0.0;
        for (int topic = 0; topic < this.numTopics; ++topic) {
            total += this.nz[topic];
        }
        double denominator = total + (double) this.numTopics * this.alpha;
        for (int topic = 0; topic < this.numTopics; ++topic) {
            this.pz[topic] = (this.nz[topic] + this.alpha) / denominator;
        }
    }

    /**
     * Recomputes per-word and per-document topic scores.
     *
     * <p>First pwz[w][k] = pz[k] * phi[k][w], normalised over k for each word. Then pdz[d][k] is
     * the sum of pwz[w][k] over the tokens of document d, normalised over k.
     * Expects phi and pz to be current (compute_phi / compute_pz).
     */
    public void compute_pzd() {
        for (int w = 0; w < this.vocabularySize; ++w) {
            double[] wordRow = this.pwz[w];
            double norm = 0.0;
            for (int k = 0; k < this.numTopics; ++k) {
                wordRow[k] = this.pz[k] * this.phi[k][w];
                norm += wordRow[k];
            }
            for (int k = 0; k < this.numTopics; ++k) {
                wordRow[k] /= norm;
            }
        }
        for (int d = 0; d < this.numDocuments; ++d) {
            int[] tokens = this.docToWordIDListMap.get(d);
            double[] docRow = this.pdz[d];
            double norm = 0.0;
            for (int k = 0; k < this.numTopics; ++k) {
                double acc = 0.0;
                for (int wordID : tokens) {
                    acc += this.pwz[wordID][k];
                }
                docRow[k] = acc;
                norm += acc;
            }
            for (int k = 0; k < this.numTopics; ++k) {
                docRow[k] /= norm;
            }
        }
    }

    /**
     * For every document, writes into {@code docTopKTopics[d]} the searchTopK topics with the
     * highest pdz[d][k]; ties do not displace an existing candidate.
     *
     * <p>The order within each row follows the candidate HashSet's iteration order.
     *
     * @param docTopKTopics output array with at least searchTopK columns per document
     * @return the same array, filled
     */
    public int[][] getTopKTopics(int[][] docTopKTopics) {
        HashSet<Integer> candidates = new HashSet<Integer>();
        for (int d = 0; d < this.numDocuments; ++d) {
            double weakestScore = 2.0;
            int weakestTopic = -1;
            candidates.clear();
            for (int k = 0; k < this.numTopics; ++k) {
                double score = this.pdz[d][k];
                if (candidates.size() < this.searchTopK) {
                    // Still filling the candidate set: accept k and track the weakest member.
                    candidates.add(k);
                    if (Double.compare(weakestScore, score) > 0) {
                        weakestScore = score;
                        weakestTopic = k;
                    }
                } else if (Double.compare(weakestScore, score) < 0) {
                    // k beats the weakest candidate: swap it in and find the new weakest.
                    candidates.remove(weakestTopic);
                    candidates.add(k);
                    weakestTopic = this.minPDZTopicIndex(d, candidates);
                    weakestScore = this.pdz[d][weakestTopic];
                }
            }
            int slot = 0;
            for (Integer topic : candidates) {
                docTopKTopics[d][slot++] = topic;
            }
        }
        return docTopKTopics;
    }

    /**
     * Returns the topic in {@code topics} with the smallest pdz[doc][topic], keeping the first
     * one seen on ties; returns -1 if no score is below the initial sentinel 2.0 (e.g. an empty set).
     */
    private int minPDZTopicIndex(int doc, Set<Integer> topics) {
        int best = -1;
        double bestScore = 2.0;
        for (Integer topic : topics) {
            double score = this.pdz[doc][topic];
            if (Double.compare(bestScore, score) > 0) {
                bestScore = score;
                best = topic;
            }
        }
        return best;
    }

    /**
     * Returns the largest topicProbabilityGivenWord[wordID][k] over all topics (compared with
     * Double.compare), or -1.0 when there are no topics.
     */
    public double findTopicMaxProbabilityGivenWord(int wordID) {
        double best = -1.0;
        for (int topic = 0; topic < this.numTopics; ++topic) {
            double candidate = this.topicProbabilityGivenWord[wordID][topic];
            if (Double.compare(candidate, best) > 0) {
                best = candidate;
            }
        }
        return best;
    }

    /**
     * Returns the stored value topicProbabilityGivenWord[wordID][topic].
     */
    public double getTopicProbabilityGivenWord(int topic, int wordID) {
        double[] perTopic = this.topicProbabilityGivenWord[wordID];
        return perTopic[topic];
    }

    /**
     * Randomly decides whether token {@code index} of document {@code docID} is promoted.
     *
     * <p>It is promoted when p(newTopic | word) / max_k p(k | word) compares greater (by
     * Double.compare) than a fresh uniform draw.
     *
     * <p>NOTE(review): Double.compare orders NaN above every value. In the reviewed code
     * topicProbabilityGivenWord is never written, so the ratio is 0/0 = NaN and the flag is
     * always set to true. Confirm whether that table is meant to be filled elsewhere.
     */
    public void updateWordGPUFlag(int docID, int word, int index, int newTopic) {
        double best = this.findTopicMaxProbabilityGivenWord(word);
        double ratio = this.getTopicProbabilityGivenWord(newTopic, word) / best;
        boolean promote = Double.compare(ratio, this.rg.nextDouble()) > 0;
        this.wordGPUFlag.get(docID).set(index, promote);
    }

    /**
     * Adds (flag > 0) or removes (flag < 0) one token's contribution to the topic-word counts.
     *
     * <p>On add: the base count nzw[topic][word] and nz[topic] change by {@code flag}. The token's
     * GPU flag is re-drawn, and if it is set and the word has similar words in the schema, each
     * similar word receives {@code weight} in the same topic. The amounts added are recorded in
     * wordGPUInfo.
     *
     * <p>On remove: the base count changes by {@code flag}, the recorded amounts are subtracted,
     * and the record is cleared so a repeated removal cannot subtract them twice.
     *
     * <p>flag == 0 leaves all counts unchanged.
     *
     * @param topic topic the token is (being) assigned to
     * @param docID document index
     * @param word  word id of the token
     * @param index position of the token inside the document
     * @param flag  signed change for the base count (callers pass +1 or -1)
     */
    public void ratioCount(Integer topic, Integer docID, int word, int index, int flag) {
        int k = topic;
        this.nzw[k][word] += (double)flag;
        this.nz[k] += (double)flag;
        if (flag > 0) {
            this.updateWordGPUFlag(docID, word, index, k);
            boolean gpuFlag = this.wordGPUFlag.get(docID).get(index);
            HashMap<Integer, Double> gpuInfo = new HashMap<Integer, Double>();
            Map<Integer, Double> similar = (gpuFlag && this.schemaMap != null) ? this.schemaMap.get(word) : null;
            if (similar != null) {
                double v = this.weight;
                for (Integer similarWordID : similar.keySet()) {
                    this.nzw[k][similarWordID] += v;
                    this.nz[k] += v;
                    gpuInfo.put(similarWordID, v);
                }
            }
            this.wordGPUInfo.get(docID).set(index, gpuInfo);
        } else if (flag < 0) {
            // Undo exactly what the matching add recorded, then forget it.
            Map<Integer, Double> gpuInfo = this.wordGPUInfo.get(docID).get(index);
            for (Map.Entry<Integer, Double> entry : gpuInfo.entrySet()) {
                double v = entry.getValue();
                this.nzw[k][entry.getKey()] -= v;
                this.nz[k] -= v;
            }
            gpuInfo.clear();
        }
    }

    /**
     * Writes every single topic of {@code topKTopics} into consecutive rows starting at
     * {@code index} (column 0).
     *
     * @return the first row index after the written rows
     */
    private static int enumerateOneTopicSetting(int[][] topicSettingArray, int[] topKTopics, int index) {
        int next = index;
        for (int topic : topKTopics) {
            topicSettingArray[next++][0] = topic;
        }
        return next;
    }

    /**
     * Writes every pair (a < b positions) of {@code topKTopics}, in lexicographic order, into
     * consecutive rows starting at {@code index}.
     *
     * @return the first row index after the written rows
     */
    private static int enumerateTwoTopicSetting(int[][] topicSettingArray, int[] topKTopics, int index) {
        int next = index;
        int n = topKTopics.length;
        for (int a = 0; a < n; ++a) {
            for (int b = a + 1; b < n; ++b) {
                int[] row = topicSettingArray[next++];
                row[0] = topKTopics[a];
                row[1] = topKTopics[b];
            }
        }
        return next;
    }

    /**
     * Writes every triple (a < b < c positions) of {@code topKTopics}, in lexicographic order,
     * into consecutive rows starting at {@code index}.
     *
     * @return the first row index after the written rows
     */
    private static int enumerateThreeTopicSetting(int[][] topicSettingArray, int[] topKTopics, int index) {
        int next = index;
        int n = topKTopics.length;
        for (int a = 0; a < n; ++a) {
            for (int b = a + 1; b < n; ++b) {
                for (int c = b + 1; c < n; ++c) {
                    int[] row = topicSettingArray[next++];
                    row[0] = topKTopics[a];
                    row[1] = topKTopics[b];
                    row[2] = topKTopics[c];
                }
            }
        }
        return next;
    }

    /**
     * Writes every quadruple (a < b < c < e positions) of {@code topKTopics}, in lexicographic
     * order, into consecutive rows starting at {@code index}.
     *
     * @return the first row index after the written rows
     */
    private static int enumerateFourTopicSetting(int[][] topicSettingArray, int[] topKTopics, int index) {
        int next = index;
        int n = topKTopics.length;
        for (int a = 0; a < n; ++a) {
            for (int b = a + 1; b < n; ++b) {
                for (int c = b + 1; c < n; ++c) {
                    for (int e = c + 1; e < n; ++e) {
                        int[] row = topicSettingArray[next++];
                        row[0] = topKTopics[a];
                        row[1] = topKTopics[b];
                        row[2] = topKTopics[c];
                        row[3] = topKTopics[e];
                    }
                }
            }
        }
        return next;
    }

    /**
     * Fills {@code topicSettingArray} with every subset of {@code topKTopics} of size
     * 1..maxTd, grouped by increasing size and in lexicographic position order within a size.
     * This matches the row layout allocated by ZdSearchSize.
     *
     * <p>Previously only sizes up to 4 were filled, so for maxTd > 4 the remaining rows stayed
     * all zeros. Any maxTd is now supported, and the order for sizes 1..4 is unchanged.
     *
     * @param topicSettingArray rows pre-sized per subset size
     * @param topKTopics        candidate topics of the current document
     * @param maxTd             largest subset size to enumerate
     * @return the filled {@code topicSettingArray}
     */
    private static int[][] enumerateTopicSetting(int[][] topicSettingArray, int[] topKTopics, int maxTd) {
        int[] chosen = new int[Math.max(maxTd, 0)];
        int index = 0;
        for (int size = 1; size <= maxTd; ++size) {
            index = GPU_PDMM.fillTopicCombinations(topicSettingArray, topKTopics, chosen, size, 0, 0, index);
        }
        return topicSettingArray;
    }

    /** Recursively writes all size-{@code size} combinations of topKTopics[start..] into rows from {@code index}; returns the next free row. */
    private static int fillTopicCombinations(int[][] out, int[] topKTopics, int[] chosen, int size, int depth, int start, int index) {
        if (depth == size) {
            System.arraycopy(chosen, 0, out[index], 0, size);
            return index + 1;
        }
        // Leave enough remaining candidates to complete the combination.
        int last = topKTopics.length - (size - depth);
        for (int i = start; i <= last; ++i) {
            chosen[depth] = topKTopics[i];
            index = GPU_PDMM.fillTopicCombinations(out, topKTopics, chosen, size, depth + 1, i + 1, index);
        }
        return index;
    }

    /** Returns the current wall-clock time in milliseconds since the epoch. */
    private static long getCurrTime() {
        long nowMillis = System.currentTimeMillis();
        return nowMillis;
    }

    /**
     * Runs numIterations Gibbs sweeps, records the sampling time in iterTime and writes the
     * final sample via write().
     *
     * <p>Each sweep first recomputes phi, pz and pdz and picks each document's searchTopK
     * highest-scoring topics. Then, for every document:
     * <ol>
     *   <li>Its topic set and token assignments are removed from the counts.</li>
     *   <li>For every candidate topic set (rows of topicSettingArray, filled by
     *       enumerateTopicSetting from the document's top-K topics), tokens are sampled among
     *       that set's topics into mediateSamples and the set is scored as p1 * p2 * p3.</li>
     *   <li>One set is drawn proportionally to its score, and its sampled assignments are
     *       committed.</li>
     * </ol>
     *
     * <p>NOTE(review): p2 and p3 are plain products of many factors and may overflow or underflow
     * for long documents. If every score becomes 0, the first candidate is always chosen, and
     * Infinity/Infinity gives NaN. A log-space computation may be needed; TODO confirm on real data.
     *
     * @throws IOException propagated from writeDictionary() or write()
     */
    public void inference() throws IOException {
        this.writeDictionary();
        long _s = GPU_PDMM.getCurrTime();
        // One row per candidate topic set; each row's length is the set size.
        int[][] topicSettingArray = this.ZdSearchSize();
        int[][] docTopKTopics = new int[this.numDocuments][this.searchTopK];
        double[] Ptd_Zd = new double[topicSettingArray.length];
        int[] termIDArray = null;
        int[][] mediateSamples = null;
        // Per-document scratch: mediateSamples[round][w] = topic sampled for token w under candidate set `round`.
        // NOTE(review): this allocates (#candidate sets x #tokens) ints per document up front.
        HashMap<Integer, int[][]> mediateSampleMap = new HashMap<Integer, int[][]>();
        for (int i = 0; i < this.numDocuments; ++i) {
            termIDArray = this.docToWordIDListMap.get(i);
            mediateSamples = new int[topicSettingArray.length][termIDArray.length];
            mediateSampleMap.put(i, mediateSamples);
        }
        System.out.println("Running Gibbs sampling inference: ");
        for (int iter = 1; iter <= this.numIterations; ++iter) {
            if (iter % 50 == 0) {
                System.out.print(" " + iter);
            }
            // Refresh the per-document topic scores used to choose candidate topics.
            this.compute_phi();
            this.compute_pz();
            this.compute_pzd();
            docTopKTopics = this.getTopKTopics(docTopKTopics);
            for (int s = 0; s < this.numDocuments; ++s) {
                int w;
                int length_topicSetting;
                int[] topicSetting;
                int round;
                termIDArray = this.docToWordIDListMap.get(s);
                int num_word = termIDArray.length;
                // Remove document s from the topic-set counts (Ck / CkSum) ...
                Set<Integer> preZd = this.ZdMap.get(s);
                this.normalCountZd(preZd, -1);
                mediateSamples = (int[][])mediateSampleMap.get(s);
                // ... and its tokens (plus any promoted similar-word counts) from nzw / nz.
                for (int w2 = 0; w2 < num_word; ++w2) {
                    this.ratioCount(this.assignmentListMap.get(s)[w2], s, termIDArray[w2], w2, -1);
                }
                topicSettingArray = GPU_PDMM.enumerateTopicSetting(topicSettingArray, docTopKTopics[s], this.maxTd);
                int length_topicSettingArray = topicSettingArray.length;
                // Pass 1: for each candidate set, sample every token's topic restricted to that set.
                for (round = 0; round < length_topicSettingArray; ++round) {
                    int wordID;
                    topicSetting = topicSettingArray[round];
                    length_topicSetting = topicSetting.length;
                    for (w = 0; w < num_word; ++w) {
                        wordID = termIDArray[w];
                        double[] pzDist = new double[length_topicSetting];
                        for (int index = 0; index < length_topicSetting; ++index) {
                            double pz;
                            int topic = topicSetting[index];
                            pzDist[index] = pz = 1.0 * (this.nzw[topic][wordID] + this.beta) / (this.nz[topic] + this.beta * (double)this.vocabularySize);
                        }
                        // Inverse-CDF draw over the cumulative weights.
                        for (int i = 1; i < length_topicSetting; ++i) {
                            int n = i;
                            pzDist[n] = pzDist[n] + pzDist[i - 1];
                        }
                        double u = this.rg.nextDouble() * pzDist[length_topicSetting - 1];
                        int newTopic = -1;
                        for (int i = 0; i < length_topicSetting; ++i) {
                            if (Double.compare(pzDist[i], u) < 0) continue;
                            newTopic = topicSetting[i];
                            break;
                        }
                        mediateSamples[round][w] = newTopic;
                        // Temporarily count the sample so later tokens of this document see it.
                        this.ratioCount(newTopic, s, wordID, w, 1);
                    }
                    // Remove this round's temporary counts before trying the next candidate set.
                    for (w = 0; w < num_word; ++w) {
                        wordID = termIDArray[w];
                        this.ratioCount(mediateSamples[round][w], s, wordID, w, -1);
                    }
                }
                // Pass 2: score each candidate set given its sampled assignments.
                for (round = 0; round < length_topicSettingArray; ++round) {
                    topicSetting = topicSettingArray[round];
                    length_topicSetting = topicSetting.length;
                    // p1 = lambda^|Zd| * exp(-lambda).
                    // NOTE(review): unlike ptd in init_GSDMM there is no 1/|Zd|! term here; confirm this is intended.
                    double p1 = Math.pow(this.lambda, topicSetting.length) * Math.exp(-this.lambda);
                    double p21 = 1.0;
                    double p22 = 1.0;
                    // p2 = prod_k (alpha + Ck[k]) / prod_{i < |Zd|} (CkSum + K*alpha - i)
                    for (int k : topicSetting) {
                        p21 *= this.alpha + (double)this.Ck[k];
                    }
                    for (int i = 0; i < length_topicSetting; ++i) {
                        p22 *= (double)this.CkSum + (double)this.numTopics * this.alpha - (double)i;
                    }
                    double p2 = p21 / p22;
                    double p31 = 1.0;
                    double p32 = 1.0;
                    Map<Integer, Map<Integer, Integer>> Ndkt = this.getNdkt_Zd(s, topicSetting, mediateSamples[round]);
                    Map<Integer, Integer> Ndk = this.getNdk_Zd(s, topicSetting, mediateSamples[round]);
                    // p3: word-likelihood ratio of the sampled assignments.
                    // NOTE(review): the factors run from (beta + nzw + N) down to (beta + nzw + 1), and
                    // likewise for p32; confirm the intended offset against the PDMM collapsed-sampler formula.
                    for (int k : topicSetting) {
                        Set<Integer> dk = this.getdk_Zd(s, mediateSamples[round], k);
                        for (int t : dk) {
                            for (int i = 0; i < Ndkt.get(k).get(t); ++i) {
                                p31 *= this.beta + this.nzw[k][t] + (double)Ndkt.get(k).get(t).intValue() - (double)i;
                            }
                        }
                        for (int j = 0; j < Ndk.get(k); ++j) {
                            p32 *= this.nz[k] + this.beta * (double)this.vocabularySize + (double)Ndk.get(k).intValue() - (double)j;
                        }
                        dk.clear();
                    }
                    Ndkt.clear();
                    Ndk.clear();
                    double p3 = p31 / p32;
                    Ptd_Zd[round] = p1 * p2 * p3;
                }
                // Draw one candidate set proportionally to its score (inverse CDF).
                for (int i = 1; i < length_topicSettingArray; ++i) {
                    int n = i;
                    Ptd_Zd[n] = Ptd_Zd[n] + Ptd_Zd[i - 1];
                }
                double u_ptdzd = this.rg.nextDouble() * Ptd_Zd[length_topicSettingArray - 1];
                int new_index = -1;
                for (int i = 0; i < length_topicSettingArray; ++i) {
                    if (Double.compare(Ptd_Zd[i], u_ptdzd) < 0) continue;
                    new_index = i;
                    break;
                }
                // Commit: the new topic set replaces Zd in place, and its sampled assignments are counted.
                this.TdArray[s] = topicSettingArray[new_index].length;
                preZd.clear();
                for (int k : topicSettingArray[new_index]) {
                    preZd.add(k);
                }
                this.normalCountZd(preZd, 1);
                System.arraycopy(mediateSamples[new_index], 0, this.assignmentListMap.get(s), 0, mediateSamples[new_index].length);
                for (w = 0; w < termIDArray.length; ++w) {
                    this.ratioCount(mediateSamples[new_index][w], s, termIDArray[w], w, 1);
                }
            }
        }
        this.expName = this.orgExpName;
        long _e = GPU_PDMM.getCurrTime();
        this.iterTime = _e - _s;
        System.out.println();
        System.out.println("Writing output from the last sample ...");
        this.write();
        System.out.println("Sampling completed!");
    }

    /**
     * Writes {@code folderPath + expName + ".topWords"}: one line per topic listing up to
     * topWords words. Words are ordered by FuncUtils.sortByValueDescending over their nzw
     * counts, and each word is followed by a single space.
     *
     * <p>Every topic line is newline-terminated. Previously no newline was written when the
     * vocabulary had at most topWords words, which merged topics onto one line.
     *
     * @throws IOException if the file cannot be written
     */
    public void writeTopTopicalWords() throws IOException {
        try (BufferedWriter writer = new BufferedWriter(new FileWriter(this.folderPath + this.expName + ".topWords"))) {
            for (int tIndex = 0; tIndex < this.numTopics; ++tIndex) {
                Map<Integer, Double> wordCount = new TreeMap<Integer, Double>();
                for (int wIndex = 0; wIndex < this.vocabularySize; ++wIndex) {
                    wordCount.put(wIndex, this.nzw[tIndex][wIndex]);
                }
                Map<?, ?> sorted = FuncUtils.sortByValueDescending(wordCount);
                int count = 0;
                for (Object wordId : sorted.keySet()) {
                    if (count >= this.topWords) {
                        break;
                    }
                    writer.write(this.id2WordVocabulary.get(wordId) + " ");
                    ++count;
                }
                writer.write("\n");
            }
        }
    }

    /**
     * Computes and writes per-document topic proportions to {@code <folderPath><expName>.theta}.
     *
     * <p>First, for every vocabulary word {@code w}, fills {@code pwz[w][z]} with
     * {@code pz[z] * phi[z][w]} normalized over topics. Then, for every document, fills
     * {@code pdz[d][z]} with the sum of {@code pwz[w][z]} over the document's word ids (from
     * {@code docToWordIDListMap}), normalized over topics. Each document row is written as
     * space-separated values followed by a newline.
     *
     * <p>If a row's total mass is zero (e.g. a document with no word ids), the row is set to the
     * uniform distribution {@code 1/numTopics} instead of producing {@code NaN} from {@code 0/0}.
     *
     * @throws IOException if the output file cannot be created or written
     */
    public void writeDocTopicPros() throws IOException {
        try (BufferedWriter writer = new BufferedWriter(new FileWriter(this.folderPath + this.expName + ".theta"))) {
            for (int w = 0; w < this.vocabularySize; ++w) {
                double[] wordRow = this.pwz[w];
                for (int z = 0; z < this.numTopics; ++z) {
                    wordRow[z] = this.pz[z] * this.phi[z][w];
                }
                normalizeTopicRowOrUniform(wordRow, this.numTopics);
            }
            for (int d = 0; d < this.numDocuments; ++d) {
                int[] docWordIds = this.docToWordIDListMap.get(d);
                double[] docRow = this.pdz[d];
                for (int z = 0; z < this.numTopics; ++z) {
                    double mass = 0.0;
                    for (int wordID : docWordIds) {
                        mass += this.pwz[wordID][z];
                    }
                    docRow[z] = mass;
                }
                normalizeTopicRowOrUniform(docRow, this.numTopics);
                for (int z = 0; z < this.numTopics; ++z) {
                    writer.write(docRow[z] + " ");
                }
                writer.write("\n");
            }
        }
    }

    /**
     * Divides the first {@code size} entries of {@code row} by their sum; when the sum is zero,
     * assigns {@code 1/size} to each entry instead of producing NaN.
     */
    private static void normalizeTopicRowOrUniform(double[] row, int size) {
        double sum = 0.0;
        for (int i = 0; i < size; ++i) {
            sum += row[i];
        }
        if (sum == 0.0) {
            for (int i = 0; i < size; ++i) {
                row[i] = 1.0 / size;
            }
            return;
        }
        for (int i = 0; i < size; ++i) {
            row[i] = row[i] / sum;
        }
    }

    /**
     * Writes the current per-token topic assignment of every document to
     * {@code <folderPath><expName>.topicAssignments}, one line per document with topic ids
     * separated by spaces.
     *
     * <p>The number of tokens written per document is taken from {@code Corpus.get(d).length};
     * {@code assignmentListMap.get(d)} is assumed to be at least that long.
     *
     * @throws IOException if the output file cannot be created or written
     */
    public void writeTopicAssignments() throws IOException {
        try (BufferedWriter writer = new BufferedWriter(new FileWriter(this.folderPath + this.expName + ".topicAssignments"))) {
            for (int dIndex = 0; dIndex < this.numDocuments; ++dIndex) {
                int docSize = this.Corpus.get(dIndex).length;
                int[] assignment = this.assignmentListMap.get(dIndex);
                for (int wIndex = 0; wIndex < docSize; ++wIndex) {
                    writer.write(assignment[wIndex] + " ");
                }
                writer.write("\n");
            }
        }
    }

    /**
     * Writes the smoothed topic-word distribution to {@code <folderPath><expName>.phi}: one line
     * per topic, holding {@code (nzw[t][w] + beta) / (nz[t] + vocabularySize * beta)} for every
     * word id {@code w}, separated by spaces.
     *
     * @throws IOException if the output file cannot be created or written
     */
    public void writeTopicWordPros() throws IOException {
        try (BufferedWriter writer = new BufferedWriter(new FileWriter(this.folderPath + this.expName + ".phi"))) {
            for (int i = 0; i < this.numTopics; ++i) {
                // The denominator depends only on the topic, so compute it once per row.
                double denominator = this.nz[i] + (double) this.vocabularySize * this.beta;
                for (int j = 0; j < this.vocabularySize; ++j) {
                    double pro = (this.nzw[i][j] + this.beta) / denominator;
                    writer.write(pro + " ");
                }
                writer.write("\n");
            }
        }
    }

    /**
     * Dumps every output artifact for the current model state, in this order: top topical
     * words, document-topic proportions, per-token topic assignments, topic-word probabilities,
     * and the run parameters.
     *
     * @throws IOException if any of the output files cannot be written
     */
    public void write() throws IOException {
        writeTopTopicalWords();
        writeDocTopicPros();
        writeTopicAssignments();
        writeTopicWordPros();
        writeParameters();
    }

    /**
     * Collects the distinct word ids of document {@code docID} whose token position is assigned
     * to {@code topic}.
     *
     * @param docID      document whose word ids are looked up in {@code docToWordIDListMap}
     * @param assignment per-position topic assignment; position {@code i} pairs with word id {@code i}
     * @param topic      topic to filter on
     * @return a new mutable set of word ids (empty if no position carries {@code topic})
     */
    public Set<Integer> getdk_Zd(int docID, int[] assignment, int topic) {
        int[] words = this.docToWordIDListMap.get(docID);
        Set<Integer> matching = new HashSet<Integer>();
        int pos = 0;
        while (pos < assignment.length) {
            if (assignment[pos] == topic) {
                matching.add(words[pos]);
            }
            pos++;
        }
        return matching;
    }

    /**
     * Counts, for each topic in {@code ZdList}, how often each word id of document {@code docID}
     * is assigned to that topic.
     *
     * @param docID      document whose word ids are looked up in {@code docToWordIDListMap}
     * @param ZdList     topics to pre-seed with empty count maps
     * @param assignment per-position topic assignment, at least as long as the document
     * @return map topic -> (word id -> count); topics in {@code ZdList} with no tokens map to an
     *         empty map
     * @throws NullPointerException if {@code assignment} contains a topic not in {@code ZdList}
     */
    public Map<Integer, Map<Integer, Integer>> getNdkt_Zd(int docID, int[] ZdList, int[] assignment) {
        HashMap<Integer, Map<Integer, Integer>> Ndkt = new HashMap<Integer, Map<Integer, Integer>>();
        for (int k : ZdList) {
            Ndkt.put(k, new HashMap<Integer, Integer>());
        }
        int[] termIDArray = this.docToWordIDListMap.get(docID);
        for (int i = 0; i < termIDArray.length; ++i) {
            int word = termIDArray[i];
            Map<Integer, Integer> wordCounts = Ndkt.get(assignment[i]);
            Integer previous = wordCounts.get(word);
            wordCounts.put(word, previous == null ? 1 : previous + 1);
        }
        return Ndkt;
    }

    /**
     * Counts how many token positions of document {@code docID} are assigned to each topic in
     * {@code ZdList}.
     *
     * @param docID      document whose length is taken from {@code docToWordIDListMap}
     * @param ZdList     topics to pre-seed with a zero count
     * @param assignment per-position topic assignment, at least as long as the document
     * @return map topic -> number of positions assigned to it
     * @throws NullPointerException if {@code assignment} contains a topic not in {@code ZdList}
     */
    public Map<Integer, Integer> getNdk_Zd(int docID, int[] ZdList, int[] assignment) {
        HashMap<Integer, Integer> Ndk = new HashMap<Integer, Integer>();
        for (int k : ZdList) {
            Ndk.put(k, 0);
        }
        // Only the document length is needed; word ids themselves do not affect the counts.
        int length = this.docToWordIDListMap.get(docID).length;
        for (int i = 0; i < length; ++i) {
            int topic = assignment[i];
            Ndk.put(topic, Ndk.get(topic) + 1);
        }
        return Ndk;
    }

    /**
     * Writes the model's hyperparameters and timing information to
     * {@code <folderPath><expName>.paras}, one {@code -key\tvalue} pair per line.
     *
     * <p>The per-iteration time is reported as 0.0 when {@code numIterations} is not positive,
     * rather than NaN or Infinity.
     *
     * @throws IOException if the output file cannot be created or written
     */
    public void writeParameters() throws IOException {
        try (BufferedWriter writer = new BufferedWriter(new FileWriter(this.folderPath + this.expName + ".paras"))) {
            writer.write("-model\tGPU_PDMM");
            writer.write("\n-corpus\t" + this.corpusPath);
            writer.write("\n-ntopics\t" + this.numTopics);
            writer.write("\n-alpha\t" + this.alpha);
            writer.write("\n-beta\t" + this.beta);
            writer.write("\n-lambda\t" + this.lambda);
            writer.write("\n-threshold\t" + this.threshold);
            writer.write("\n-weight\t" + this.weight);
            writer.write("\n-filterSize\t" + this.filterSize);
            writer.write("\n-niters\t" + this.numIterations);
            writer.write("\n-twords\t" + this.topWords);
            // Previously a stray "-\t" line repeated topWords under an empty key; record the
            // remaining PDMM hyperparameters instead.
            writer.write("\n-maxTd\t" + this.maxTd);
            writer.write("\n-searchTopK\t" + this.searchTopK);
            writer.write("\n-name\t" + this.expName);
            if (this.tAssignsFilePath.length() > 0) {
                writer.write("\n-initFile\t" + this.tAssignsFilePath);
            }
            if (this.savestep > 0) {
                writer.write("\n-sstep\t" + this.savestep);
            }
            double perIterationTime = this.numIterations > 0 ? this.iterTime / (double) this.numIterations : 0.0;
            writer.write("\n-initiation time\t" + this.initTime);
            writer.write("\n-one iteration time\t" + perIterationTime);
            writer.write("\n-total time\t" + (this.initTime + this.iterTime));
        }
    }

    /**
     * Writes the vocabulary to {@code <folderPath><expName>.vocabulary}, one {@code word id} pair
     * per line, for word ids {@code 0 .. vocabularySize-1}.
     *
     * @throws IOException if the output file cannot be created or written
     */
    public void writeDictionary() throws IOException {
        try (BufferedWriter writer = new BufferedWriter(new FileWriter(this.folderPath + this.expName + ".vocabulary"))) {
            for (int id = 0; id < this.vocabularySize; ++id) {
                writer.write(this.id2WordVocabulary.get(id) + " " + id + "\n");
            }
        }
    }

    /**
     * Demo entry point: builds a model over {@code dataset/Tweet.txt} using the word vectors in
     * {@code dataset/glove.6B.200d.txt}, then runs inference. The argument order follows a
     * 14-argument constructor overload that is declared outside this chunk.
     */
    public static void main(String[] args) throws Exception {
        GPU_PDMM model = new GPU_PDMM(
                "dataset/Tweet.txt", "dataset/glove.6B.200d.txt",
                0.7, 0.1, 20, 100, 0.1, 0.1, 1.5, 500, 10, 3, 10,
                "TweetGPU_PDMM");
        model.inference();
    }
}

