/**
 * Copyright (C) 2013-2020 Vasilis Vryniotis <bbriniotis@datumbox.com>
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.datumbox.framework.core.machinelearning.topicmodeling;

import com.datumbox.framework.common.Configuration;
import com.datumbox.framework.common.dataobjects.*;
import com.datumbox.framework.common.storage.interfaces.BigMap;
import com.datumbox.framework.common.storage.interfaces.StorageEngine;
import com.datumbox.framework.common.storage.interfaces.StorageEngine.MapType;
import com.datumbox.framework.common.storage.interfaces.StorageEngine.StorageHint;
import com.datumbox.framework.core.common.utilities.MapMethods;
import com.datumbox.framework.core.common.utilities.PHPMethods;
import com.datumbox.framework.core.common.dataobjects.Dataframe;
import com.datumbox.framework.core.common.dataobjects.Record;
import com.datumbox.framework.core.machinelearning.common.abstracts.AbstractTrainer;
import com.datumbox.framework.core.machinelearning.common.abstracts.modelers.AbstractTopicModeler;
import com.datumbox.framework.core.statistics.descriptivestatistics.Descriptives;
import com.datumbox.framework.core.statistics.sampling.SimpleRandomSampling;

import java.util.Arrays;
import java.util.List;
import java.util.Map;


/**
 * Implementation of the Latent Dirichlet Allocation algorithm.
 * 
 * References:
 * http://videolectures.net/mlss09uk_blei_tm/
 * http://www.ncbi.nlm.nih.gov/pmc/articles/PMC387300/
 * https://github.com/angeloskath/php-nlp-tools/blob/master/src/NlpTools/Models/Lda.php
 * https://gist.github.com/mblondel/542786
 * http://stats.stackexchange.com/questions/9315/topic-prediction-using-latent-dirichlet-allocation
 * http://home.uchicago.edu/~lkorsos/GibbsNGramLDA.pdf
 * http://machinelearning.wustl.edu/mlpapers/paper_files/BleiNJ03.pdf
 * http://www.cl.cam.ac.uk/teaching/1213/L101/clark_lectures/lect7.pdf
 * http://www.ics.uci.edu/~newman/pubs/fastlda.pdf
 * http://www.tnkcs.inf.elte.hu/vedes/Biro_Istvan_Tezisek_en.pdf (Limit Gibbs Sampler & unseen inference)
 * http://airweb.cse.lehigh.edu/2008/submissions/biro_2008_latent_dirichlet_allocation_spam.pdf (unseen inference)
 * http://www.cs.cmu.edu/~akyrola/10702project/kyrola10702FINAL.pdf 
 * http://stats.stackexchange.com/questions/18167/how-to-calculate-perplexity-of-a-holdout-with-latent-dirichlet-allocation
 * http://www.slideserve.com/adamdaniel/an-introduction-to-latent-dirichlet-allocation-lda
 *
 * @author Vasilis Vryniotis <bbriniotis@datumbox.com>
 */
public class LatentDirichletAllocation extends AbstractTopicModeler<LatentDirichletAllocation.ModelParameters, LatentDirichletAllocation.TrainingParameters> {
    
    /**
     * Container of everything the trainer stores after fitting: the number of
     * X columns of the training Dataframe, the number of sampling sweeps that
     * were executed and the count tables maintained by the Gibbs sampler.
     */
    public static class ModelParameters extends AbstractTopicModeler.AbstractModelParameters {
        private static final long serialVersionUID = 1L;

        // value of Dataframe.xColumnSize() of the training data
        private Integer d = 0;

        // number of sampling sweeps executed by the last training
        private int totalIterations;
        
        // Z in the graphical model: [document id, word position] -> topic id
        @BigMap(keyClass=List.class, valueClass=Integer.class, mapType=MapType.HASHMAP, storageHint=StorageHint.IN_CACHE, concurrent=false)
        private Map<List<Object>, Integer> topicAssignmentOfDocumentWord;

        // nj(d) in the papers: [document id, topic id] -> count
        @BigMap(keyClass=List.class, valueClass=Integer.class, mapType=MapType.HASHMAP, storageHint=StorageHint.IN_MEMORY, concurrent=false)
        private Map<List<Integer>, Integer> documentTopicCounts;

        // nj(w) in the papers: [topic id, word] -> count
        @BigMap(keyClass=List.class, valueClass=Integer.class, mapType=MapType.HASHMAP, storageHint=StorageHint.IN_CACHE, concurrent=false)
        private Map<List<Object>, Integer> topicWordCounts;
        
        // n.(d) in the papers: document id -> number of words
        @BigMap(keyClass=Integer.class, valueClass=Integer.class, mapType=MapType.HASHMAP, storageHint=StorageHint.IN_MEMORY, concurrent=false)
        private Map<Integer, Integer> documentWordCounts;
        
        // nj(.) in the papers: topic id -> number of words assigned to it
        @BigMap(keyClass=Integer.class, valueClass=Integer.class, mapType=MapType.HASHMAP, storageHint=StorageHint.IN_MEMORY, concurrent=false)
        private Map<Integer, Integer> topicCounts;
        
        /** 
         * Builds the parameter object on top of the given storage engine.
         *
         * @param storageEngine the engine forwarded to the parent constructor
         * @see AbstractTrainer.AbstractModelParameters#AbstractModelParameters(StorageEngine)
         */
        protected ModelParameters(StorageEngine storageEngine) {
            super(storageEngine);
        }

        // ------------------------------------------------------------------
        // Accessors
        // ------------------------------------------------------------------

        /**
         * Returns the dimension of the dataset that was used in training.
         *
         * @return the stored dimension
         */
        public Integer getD() {
            return this.d;
        }

        /**
         * Returns how many iterations the training performed.
         * 
         * @return the number of iterations
         */
        public int getTotalIterations() {
            return this.totalIterations;
        }

        /**
         * Returns the topic assigned to every word occurrence.
         * Each key is the pair {@literal [document id, word position]}; the
         * position is kept as an Object because that is how the Record exposes
         * its X keys. Each value is the id of the topic the word currently
         * belongs to.
         * 
         * @return the word-to-topic assignment table
         */
        public Map<List<Object>, Integer> getTopicAssignmentOfDocumentWord() {
            return this.topicAssignmentOfDocumentWord;
        }

        /**
         * Returns the per-document topic counts.
         * Each key is the pair {@literal [document id, topic id]} and each value
         * is the number of words of that document that are assigned to that
         * topic.
         * 
         * @return the document-topic count table
         */
        public Map<List<Integer>, Integer> getDocumentTopicCounts() {
            return this.documentTopicCounts;
        }

        /**
         * Returns the per-topic word counts.
         * Each key is the pair {@literal [topic id, word]}, where the word is
         * the X value of the Record (normally a String, stored as an Object).
         * Each value is the number of times the word is assigned to the topic.
         * 
         * @return the topic-word count table
         */
        public Map<List<Object>, Integer> getTopicWordCounts() {
            return this.topicWordCounts;
        }

        /**
         * Returns the number of words of every document, keyed by document id.
         * The sizes are stored here so that they remain available without
         * access to the original Records.
         * 
         * @return the document length table
         */
        public Map<Integer, Integer> getDocumentWordCounts() {
            return this.documentWordCounts;
        }

        /**
         * Returns the number of words assigned to every topic, keyed by topic id.
         * 
         * @return the topic total table
         */
        public Map<Integer, Integer> getTopicCounts() {
            return this.topicCounts;
        }

        // ------------------------------------------------------------------
        // Mutators
        // ------------------------------------------------------------------

        /**
         * Stores the dimension of the dataset that was used in training.
         *
         * @param d the dimension to store
         */
        protected void setD(Integer d) {
            this.d = d;
        }

        /**
         * Stores how many iterations the training performed.
         * 
         * @param totalIterations the number of iterations
         */
        protected void setTotalIterations(int totalIterations) {
            this.totalIterations = totalIterations;
        }

        /**
         * Replaces the word-to-topic assignment table.
         * 
         * @param topicAssignmentOfDocumentWord the new table
         */
        protected void setTopicAssignmentOfDocumentWord(Map<List<Object>, Integer> topicAssignmentOfDocumentWord) {
            this.topicAssignmentOfDocumentWord = topicAssignmentOfDocumentWord;
        }

        /**
         * Replaces the document-topic count table.
         * 
         * @param documentTopicCounts the new table
         */
        protected void setDocumentTopicCounts(Map<List<Integer>, Integer> documentTopicCounts) {
            this.documentTopicCounts = documentTopicCounts;
        }

        /**
         * Replaces the topic-word count table.
         * 
         * @param topicWordCounts the new table
         */
        protected void setTopicWordCounts(Map<List<Object>, Integer> topicWordCounts) {
            this.topicWordCounts = topicWordCounts;
        }

        /**
         * Replaces the document length table.
         * 
         * @param documentWordCounts the new table
         */
        protected void setDocumentWordCounts(Map<Integer, Integer> documentWordCounts) {
            this.documentWordCounts = documentWordCounts;
        }

        /**
         * Replaces the topic total table.
         * 
         * @param topicCounts the new table
         */
        protected void setTopicCounts(Map<Integer, Integer> topicCounts) {
            this.topicCounts = topicCounts;
        }
        
    }  
    
    /**
     * User-tunable settings of the algorithm: number of topics, iteration cap
     * and the two Dirichlet hyperparameters.
     */
    public static class TrainingParameters extends AbstractTopicModeler.AbstractTrainingParameters {  
        private static final long serialVersionUID = 1L;
        
        // number of topics; defaults to 2
        private int k = 2;

        // iteration cap, applied both when training and when predicting; defaults to 50
        private int maxIterations = 50;
        
        // Dirichlet priors; setting both to 1.0/k is a good choice
        // alpha: prior of the document-topic distribution
        private double alpha = 1.0;
        // beta: prior of the word-topic distribution
        private double beta = 1.0;
        
        /**
         * Returns the number of topics k.
         * 
         * @return the number of topics
         */
        public int getK() {
            return this.k;
        }
        
        /**
         * Changes the number of topics k.
         * 
         * @param k the number of topics
         */
        public void setK(int k) {
            this.k = k;
        }
        
        /**
         * Returns the hyperparameter of the Dirichlet prior of the document
         * topic distribution.
         * 
         * @return the value of alpha
         */
        public double getAlpha() {
            return this.alpha;
        }
        
        /**
         * Changes the hyperparameter of the Dirichlet prior of the document
         * topic distribution.
         * 
         * @param alpha the value of alpha
         */
        public void setAlpha(double alpha) {
            this.alpha = alpha;
        }
        
        /**
         * Returns the hyperparameter of the Dirichlet prior of the word topic
         * distribution.
         * 
         * @return the value of beta
         */
        public double getBeta() {
            return this.beta;
        }
        
        /**
         * Changes the hyperparameter of the Dirichlet prior of the word topic
         * distribution.
         * 
         * @param beta the value of beta
         */
        public void setBeta(double beta) {
            this.beta = beta;
        }
        
        /**
         * Returns the maximum number of iterations that is permitted.
         * 
         * @return the iteration cap
         */
        public int getMaxIterations() {
            return this.maxIterations;
        }
        
        /**
         * Changes the maximum number of iterations that is permitted.
         * 
         * @param maxIterations the iteration cap
         */
        public void setMaxIterations(int maxIterations) {
            this.maxIterations = maxIterations;
        }
        
    }

    /**
     * Constructor that receives the training parameters; both arguments are
     * passed unchanged to the superclass constructor.
     *
     * @param trainingParameters the training parameters forwarded to the superclass
     * @param configuration the configuration forwarded to the superclass
     * @see AbstractTrainer#AbstractTrainer(AbstractTrainingParameters, Configuration)
     */
    protected LatentDirichletAllocation(TrainingParameters trainingParameters, Configuration configuration) {
        super(trainingParameters, configuration);
    }

    /**
     * Constructor that receives a storage name; both arguments are passed
     * unchanged to the superclass constructor.
     *
     * @param storageName the storage name forwarded to the superclass
     * @param configuration the configuration forwarded to the superclass
     * @see AbstractTrainer#AbstractTrainer(String, Configuration)
     */
    protected LatentDirichletAllocation(String storageName, Configuration configuration) {
        super(storageName, configuration);
    }
    
    /**
     * Builds the word distribution of every topic from the learned counts.
     * <p>
     * For each {@literal [topic, word]} entry of the topic-word count table the
     * smoothed estimate {@code (njw + beta) / (nj + beta*d)} is stored, where
     * {@code njw} is the count of the pair, {@code nj} the total count of the
     * topic and {@code d} the dimension kept in the model parameters. Words
     * without an entry for a topic are not listed under that topic.
     * 
     * @return a two-level array: topic id, then word, then probability; the
     *         words of each topic are sorted by descending probability
     */
    public AssociativeArray2D getWordProbabilitiesPerTopic() {
        ModelParameters modelParameters = knowledgeBase.getModelParameters();
        TrainingParameters trainingParameters = knowledgeBase.getTrainingParameters();

        final int numberOfTopics = trainingParameters.getK();

        //start with an empty word list for each topic
        AssociativeArray2D wordProbabilities = new AssociativeArray2D();
        for(int topicId=0; topicId<numberOfTopics; ++topicId) {
            wordProbabilities.put(topicId, new AssociativeArray());
        }

        final int d = modelParameters.getD();
        final double beta = trainingParameters.getBeta();
        final double smoothingMass = beta*d; //identical for every entry

        Map<Integer, Integer> topicTotals = modelParameters.getTopicCounts();
        for(Map.Entry<List<Object>, Integer> countEntry : modelParameters.getTopicWordCounts().entrySet()) {
            //the key is the pair [topic id, word]
            List<Object> topicWordPair = countEntry.getKey();
            Integer topicId = (Integer)topicWordPair.get(0);

            int wordCountInTopic = countEntry.getValue();
            int topicTotal = topicTotals.get(topicId);

            double probability = (wordCountInTopic+beta)/(topicTotal+smoothingMass);
            wordProbabilities.get(topicId).put(topicWordPair.get(1), probability);
        }

        //replace the word list of each topic with its sorted version
        for(int topicId=0; topicId<numberOfTopics; ++topicId) {
            AssociativeArray unsorted = wordProbabilities.get(topicId);
            wordProbabilities.put(topicId, MapMethods.sortAssociativeArrayByValueDescending(unsorted));
        }

        return wordProbabilities;
    }
    
    /**
     * Estimates the model with a collapsed Gibbs sampler.
     * <p>
     * Every word of every record is first assigned to a topic drawn with
     * {@code PHPMethods.mt_rand(0, k-1)} and the count tables of the model
     * parameters are updated. Then, for at most {@code maxIterations} sweeps,
     * each word is removed from the counts, a new topic is sampled from its
     * conditional distribution and the word is added back. After a record is
     * processed, its predicted class becomes the topic that holds the largest
     * share of its words and the per-topic shares are stored as its predicted
     * probabilities. Sampling stops early when a sweep does not change the
     * predicted topic of any record. The number of sweeps that ran is stored
     * in the model parameters.
     *
     * @param trainingData the records to learn from; their predictions are
     *                     overwritten during sampling
     */
    @Override
    protected void _fit(Dataframe trainingData) {
        ModelParameters modelParameters = knowledgeBase.getModelParameters();
        modelParameters.setD(trainingData.xColumnSize());

        int d = modelParameters.getD();
        
        TrainingParameters trainingParameters = knowledgeBase.getTrainingParameters();

        
        //get model parameters
        int k = trainingParameters.getK(); //number of topics
        Map<List<Object>, Integer> topicAssignmentOfDocumentWord = modelParameters.getTopicAssignmentOfDocumentWord();
        Map<List<Integer>, Integer> documentTopicCounts = modelParameters.getDocumentTopicCounts();
        Map<List<Object>, Integer> topicWordCounts = modelParameters.getTopicWordCounts();
        Map<Integer, Integer> documentWordCounts = modelParameters.getDocumentWordCounts();
        Map<Integer, Integer> topicCounts = modelParameters.getTopicCounts();
        
        //initialize topic assignments of each word randomly and update the counters
        for(Map.Entry<Integer, Record> e : trainingData.entries()) {
            Integer documentId = e.getKey();
            Record r = e.getValue();
            
            documentWordCounts.put(documentId, r.getX().size());
            
            for(Map.Entry<Object, Object> entry : r.getX().entrySet()) {
                Object wordPosition = entry.getKey();
                Object word = entry.getValue();
                
                //sample a topic
                Integer topic = PHPMethods.mt_rand(0,k-1);
                
                increase(topicCounts, topic);
                topicAssignmentOfDocumentWord.put(Arrays.asList(documentId, wordPosition), topic);
                increase(documentTopicCounts, Arrays.asList(documentId, topic));
                increase(topicWordCounts, Arrays.asList(topic, word));
            }
        }
        
        
        double alpha = trainingParameters.getAlpha();
        double beta = trainingParameters.getBeta();
        
        int maxIterations = trainingParameters.getMaxIterations();
        
        int iteration=0;
        while(iteration<maxIterations) {
            
            logger.debug("Iteration {}", iteration);
            
            int changedCounter = 0;
            //collapsed gibbs sampler
            for(Map.Entry<Integer, Record> e : trainingData.entries()) {
                Integer rId = e.getKey();
                Record r = e.getValue();
                Integer documentId = rId;
                
                //share of the document's words that end up in each topic during this sweep
                AssociativeArray topicAssignments = new AssociativeArray();
                for(int j=0;j<k;++j) {
                    topicAssignments.put(j, 0.0);
                }
                
                int totalWords = r.getX().size();
                
                for(Map.Entry<Object, Object> entry : r.getX().entrySet()) {
                    Object wordPosition = entry.getKey();
                    Object word = entry.getValue();
                    
                    //remove the word from the dataset; the document length is
                    //deliberately left untouched because it does not affect the sampling
                    Integer topic = topicAssignmentOfDocumentWord.get(Arrays.asList(documentId, wordPosition));
                    decrease(topicCounts, topic);
                    decrease(documentTopicCounts, Arrays.asList(documentId, topic));
                    decrease(topicWordCounts, Arrays.asList(topic, word));
                    
                    //compute the posteriors of the topics and sample from it
                    AssociativeArray topicProbabilities = new AssociativeArray();
                    for(int j=0;j<k;++j) {
                        Integer njw = topicWordCounts.get(Arrays.asList(j,word));
                        double enumerator = (njw != null) ? (njw + beta) : beta;
                        
                        Integer njd = documentTopicCounts.get(Arrays.asList(documentId, j));
                        enumerator *= (njd != null) ? (njd + alpha) : alpha;
                        
                        //a topic that was never drawn during the random initialization
                        //has no entry in topicCounts; treat it as an empty topic instead
                        //of failing while unboxing null
                        double denominator = topicCounts.getOrDefault(j, 0)+beta*d;
                        //the factor (words of document + alpha*k) is omitted because it is
                        //the same for all topics and the sampler accepts unnormalized weights
                        
                        topicProbabilities.put(j, enumerator/denominator);
                    }
                    
                    //sample from these (unnormalized) probabilities
                    Integer newTopic = (Integer) SimpleRandomSampling.weightedSampling(topicProbabilities, 1, true).iterator().next();
                    topic = newTopic; //new topic assignment
                    
                    //add back the word in the dataset
                    topicAssignmentOfDocumentWord.put(Arrays.asList(documentId, wordPosition), topic);
                    increase(topicCounts, topic);
                    increase(documentTopicCounts, Arrays.asList(documentId, topic));
                    increase(topicWordCounts, Arrays.asList(topic, word));
                    
                    topicAssignments.put(topic, TypeInference.toDouble(topicAssignments.get(topic))+1.0/totalWords);
                }
                
                Object mainTopic=MapMethods.selectMaxKeyValue(topicAssignments).getKey();
                
                if(!mainTopic.equals(r.getYPredicted())) {
                    ++changedCounter;
                }
                trainingData._unsafe_set(rId, new Record(r.getX(), r.getY(), mainTopic, topicAssignments));
            }
            ++iteration;
            
            logger.debug("Reassigned Records {}", changedCounter);
            
            //converged: no record changed its dominant topic
            if(changedCounter==0) {
                break;
            }
        }
        
        modelParameters.setTotalIterations(iteration);
        
    }
    
    /**
     * Adds 1 to the counter stored under the given key; a key without a
     * mapping is counted from 0.
     * 
     * @param <K> the type of the keys of the map
     * @param map the counter table to update
     * @param key the key whose counter is incremented
     */
    private <K> void increase(Map<K, Integer> map, K key) {
        int incremented = map.getOrDefault(key, 0) + 1;
        map.put(key, incremented);
    }

    /**
     * Subtracts 1 from the counter stored under the given key; a key without
     * a mapping is counted from 0. The entry is kept even when it reaches 0.
     *
     * @param <K> the type of the keys of the map
     * @param map the counter table to update
     * @param key the key whose counter is decremented
     */
    private <K> void decrease(Map<K, Integer> map, K key) {
        int decremented = map.getOrDefault(key, 0) - 1;
        map.put(key, decremented);
    }

    /**
     * Infers the topics of unseen records with a collapsed Gibbs sampler.
     * <p>
     * The approach mirrors the training, but the learned count tables are only
     * read, never modified. The counts of the new data are kept in temporary
     * maps obtained from the storage engine and are added to the learned counts
     * whenever a conditional topic probability is computed. The temporary maps
     * are dropped before the method returns, even if sampling fails.
     * <p>
     * After a record is processed, its predicted class becomes the topic that
     * holds the largest share of its words and the per-topic shares are stored
     * as its predicted probabilities. Sampling runs for at most
     * {@code maxIterations} sweeps and stops early when a sweep does not change
     * the predicted topic of any record.
     *
     * @param newData the records to predict; their predictions are overwritten
     */
    @Override
    protected void _predict(Dataframe newData) {
        ModelParameters modelParameters = knowledgeBase.getModelParameters();
        TrainingParameters trainingParameters = knowledgeBase.getTrainingParameters();

        //get model parameters
        int d = modelParameters.getD();
        int k = trainingParameters.getK(); //number of topics
        
        //counts learned during training; read-only in this method
        Map<List<Object>, Integer> topicWordCounts = modelParameters.getTopicWordCounts();
        Map<Integer, Integer> topicCounts = modelParameters.getTopicCounts();
        
        StorageEngine storageEngine = knowledgeBase.getStorageEngine();
        
        //we create temporary maps for the prediction sets to avoid modifying the maps that we already learned
        Map<List<Object>, Integer> tmpTopicAssignmentOfDocumentWord = storageEngine.getBigMap("tmp_topicAssignmentOfDocumentWord", (Class<List<Object>>)(Class<?>)List.class, Integer.class, MapType.HASHMAP, StorageHint.IN_CACHE, false, true);
        Map<List<Integer>, Integer> tmpDocumentTopicCounts = storageEngine.getBigMap("tmp_documentTopicCounts", (Class<List<Integer>>)(Class<?>)List.class, Integer.class, MapType.HASHMAP, StorageHint.IN_MEMORY, false, true);
        Map<List<Object>, Integer> tmpTopicWordCounts = storageEngine.getBigMap("tmp_topicWordCounts", (Class<List<Object>>)(Class<?>)List.class, Integer.class, MapType.HASHMAP, StorageHint.IN_CACHE, false, true);
        Map<Integer, Integer> tmpTopicCounts = storageEngine.getBigMap("tmp_topicCounts", Integer.class, Integer.class, MapType.HASHMAP, StorageHint.IN_MEMORY, false, true);
        
        try {
            //initialize topic assignments of each word randomly and update the counters
            for(Map.Entry<Integer, Record> e : newData.entries()) {
                Integer documentId = e.getKey();
                Record r = e.getValue();
                
                for(Map.Entry<Object, Object> entry : r.getX().entrySet()) {
                    Object wordPosition = entry.getKey();
                    Object word = entry.getValue();

                    //sample a topic
                    Integer topic = PHPMethods.mt_rand(0,k-1);
                    
                    increase(tmpTopicCounts, topic);
                    tmpTopicAssignmentOfDocumentWord.put(Arrays.asList(documentId, wordPosition), topic);
                    increase(tmpDocumentTopicCounts, Arrays.asList(documentId, topic));
                    increase(tmpTopicWordCounts, Arrays.asList(topic, word));
                }
            }
            
            
            double alpha = trainingParameters.getAlpha();
            double beta = trainingParameters.getBeta();
            
            int maxIterations = trainingParameters.getMaxIterations();

            for(int iteration=0;iteration<maxIterations;++iteration) {
                
                logger.debug("Iteration {}", iteration);
                
                
                //collapsed gibbs sampler
                int changedCounter = 0;
                double perplexity = 0.0;
                double totalDatasetWords = 0.0;
                for(Map.Entry<Integer, Record> e : newData.entries()) {
                    Integer rId = e.getKey();
                    Record r = e.getValue();
                    Integer documentId = rId;
                    
                    //share of the document's words that end up in each topic during this sweep
                    AssociativeArray topicAssignments = new AssociativeArray();
                    for(int j=0;j<k;++j) {
                        topicAssignments.put(j, 0.0);
                    }
                    
                    int totalDocumentWords = r.getX().size();
                    totalDatasetWords+=totalDocumentWords;
                    
                    //length of the document without the word that is being resampled
                    int numberOfDocumentWords = totalDocumentWords-1;
                    
                    for(Map.Entry<Object, Object> entry : r.getX().entrySet()) {
                        Object wordPosition = entry.getKey();
                        Object word = entry.getValue();
                        
                        //remove the word from the dataset
                        Integer topic = tmpTopicAssignmentOfDocumentWord.get(Arrays.asList(documentId, wordPosition));
                        decrease(tmpTopicCounts, topic);
                        decrease(tmpDocumentTopicCounts, Arrays.asList(documentId, topic));
                        decrease(tmpTopicWordCounts, Arrays.asList(topic, word));
                        
                        //compute the posteriors of the topics and sample from it
                        AssociativeArray topicProbabilities = new AssociativeArray();
                        for(int j=0;j<k;++j) {
                            //get the counts from the current testing data
                            List<Object> topicWordKey = Arrays.asList(j,word);
                            Integer njw = tmpTopicWordCounts.get(topicWordKey);
                            double enumerator = (njw != null) ? (njw + beta) : beta;
                            
                            //get also the counts from the training data
                            Integer njwOriginal = topicWordCounts.get(topicWordKey);
                            if(njwOriginal != null) {
                                enumerator += njwOriginal;
                            }
                            
                            Integer njd = tmpDocumentTopicCounts.get(Arrays.asList(documentId, j));
                            enumerator *= (njd != null) ? (njd + alpha) : alpha;
                            
                            //add the counts from testing data. The current word was already
                            //removed by the decrease() calls above, so nothing more is
                            //subtracted here. A topic that was never drawn has no entry
                            //and is treated as an empty topic.
                            double denominator = tmpTopicCounts.getOrDefault(j, 0)+beta*d;
                            //and the ones from training data
                            denominator += topicCounts.getOrDefault(j, 0);
                            denominator *= numberOfDocumentWords+alpha*k;
                            
                            topicProbabilities.put(j, enumerator/denominator);
                        }
                        
                        //the sum over the topics is the predictive probability of the word
                        perplexity += Math.log(Descriptives.sum(topicProbabilities.toFlatDataCollection()));
                        
                        //sample from these (unnormalized) probabilities
                        Integer newTopic = (Integer)SimpleRandomSampling.weightedSampling(topicProbabilities, 1, true).iterator().next();
                        topic = newTopic; //new topic assignment
                        
                        
                        //add back the word in the dataset
                        tmpTopicAssignmentOfDocumentWord.put(Arrays.asList(documentId, wordPosition), topic);
                        increase(tmpTopicCounts, topic);
                        increase(tmpDocumentTopicCounts, Arrays.asList(documentId, topic));
                        increase(tmpTopicWordCounts, Arrays.asList(topic, word));
                        
                        topicAssignments.put(topic, TypeInference.toDouble(topicAssignments.get(topic))+1.0/totalDocumentWords);
                    }
                    
                    Object mainTopic=MapMethods.selectMaxKeyValue(topicAssignments).getKey();
                    
                    if(!mainTopic.equals(r.getYPredicted())) {
                        ++changedCounter;
                    }                
                    newData._unsafe_set(rId, new Record(r.getX(), r.getY(), mainTopic, topicAssignments));
                }

                perplexity=Math.exp(-perplexity/totalDatasetWords);
                
                logger.debug("Reassigned Records {} - Perplexity: {}", changedCounter, perplexity);
                
                //converged: no record changed its dominant topic
                if(changedCounter==0) {
                    break;
                }            
            }
        }
        finally {
            //Drop the temporary Collections, also when the sampling failed
            storageEngine.dropBigMap("tmp_topicAssignmentOfDocumentWord", tmpTopicAssignmentOfDocumentWord);
            storageEngine.dropBigMap("tmp_documentTopicCounts", tmpDocumentTopicCounts);
            storageEngine.dropBigMap("tmp_topicWordCounts", tmpTopicWordCounts);
            storageEngine.dropBigMap("tmp_topicCounts", tmpTopicCounts);
        }
    }
}
