/*
 * Decompiled with CFR 0.152.
 */
package edu.stanford.nlp.classify;

import edu.stanford.nlp.classify.Classifier;
import edu.stanford.nlp.classify.ClassifierFactory;
import edu.stanford.nlp.classify.GeneralDataset;
import edu.stanford.nlp.classify.LogConditionalEqConstraintFunction;
import edu.stanford.nlp.classify.LogConditionalObjectiveFunction;
import edu.stanford.nlp.classify.LogPrior;
import edu.stanford.nlp.classify.NaiveBayesClassifier;
import edu.stanford.nlp.classify.NominalDataReader;
import edu.stanford.nlp.classify.RVFDataset;
import edu.stanford.nlp.ling.RVFDatum;
import edu.stanford.nlp.optimization.QNMinimizer;
import edu.stanford.nlp.stats.ClassicCounter;
import edu.stanford.nlp.stats.Counter;
import edu.stanford.nlp.util.HashIndex;
import edu.stanford.nlp.util.Index;
import edu.stanford.nlp.util.Pair;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Set;

/**
 * Factory that trains {@link NaiveBayesClassifier}s from integer-coded feature data.
 *
 * <p>The training procedure is chosen by {@code kind}:
 * <ul>
 *   <li>{@link #JL} - smoothed log relative frequencies computed directly from counts
 *       (presumably "joint likelihood"; the name is not spelled out in this file);</li>
 *   <li>{@link #CL} - weights obtained by minimizing a
 *       {@link LogConditionalEqConstraintFunction} (reported by {@link #main} as
 *       "constrained conditional likelihood");</li>
 *   <li>{@link #UCL} - weights obtained by minimizing a
 *       {@link LogConditionalObjectiveFunction} (reported by {@link #main} as
 *       "unconstrained conditional likelihood").</li>
 * </ul>
 *
 * <p>Feature values are treated as small non-negative integers: a value {@code v} of
 * feature {@code f} is used directly as an array index.
 *
 * @param <L> label type
 * @param <F> feature type
 */
public class NaiveBayesClassifierFactory<L, F>
implements ClassifierFactory<L, F, NaiveBayesClassifier<L, F>> {
    private static final long serialVersionUID = -8164165428834534041L;

    /** Training kind: count-based estimation, see {@link #trainWeightsJL}. */
    public static final int JL = 0;
    /** Training kind: optimization of a {@link LogConditionalEqConstraintFunction}. */
    public static final int CL = 1;
    /** Training kind: optimization of a {@link LogConditionalObjectiveFunction}. */
    public static final int UCL = 2;

    /** One of {@link #JL}, {@link #CL} or {@link #UCL}. */
    int kind = JL;
    /** Additive smoothing constant applied to class counts in {@link #JL} training. */
    double alphaClass;
    /** Additive smoothing constant applied to feature-value counts in {@link #JL} training. */
    double alphaFeature;
    /** Passed through to the objective functions used by {@link #CL} and {@link #UCL}. */
    double sigma;
    /** Prior type (as an ordinal of {@code LogPrior.LogPriorType}) passed to the objective functions. */
    int prior = LogPrior.LogPriorType.NULL.ordinal();
    /** Label index built by the most recent list-based {@code trainClassifier} call. */
    Index<L> labelIndex;
    /** Feature index built by the most recent list-based {@code trainClassifier} call. */
    Index<F> featureIndex;

    /**
     * Creates a factory with {@code kind == JL}, zero smoothing constants, zero sigma
     * and the {@code NULL} prior type.
     */
    public NaiveBayesClassifierFactory() {
    }

    /**
     * Creates a fully configured factory.
     *
     * @param alphaC smoothing constant for class counts
     * @param alphaF smoothing constant for feature-value counts
     * @param sigma value handed to the objective functions
     * @param prior prior type ordinal handed to the objective functions
     * @param kind one of {@link #JL}, {@link #CL}, {@link #UCL}; it is not validated here,
     *        an unsupported value is rejected when training starts
     */
    public NaiveBayesClassifierFactory(double alphaC, double alphaF, double sigma, int prior, int kind) {
        this.alphaClass = alphaC;
        this.alphaFeature = alphaF;
        this.sigma = sigma;
        this.prior = prior;
        this.kind = kind;
    }

    /**
     * Trains weights on the integer-coded data and packages them into a classifier.
     *
     * @param data one row per example, one column per feature, cell = feature value
     * @param labels class index of each example
     * @param numFeatures number of feature columns
     * @param numClasses number of classes
     * @param labelIndex maps class indices back to labels
     * @param featureIndex maps feature columns back to features
     * @return classifier built from the trained priors and per-(label, feature, value) weights
     */
    @SuppressWarnings({"unchecked", "rawtypes"})
    private NaiveBayesClassifier<L, F> trainClassifier(int[][] data, int[] labels, int numFeatures, int numClasses, Index<L> labelIndex, Index<F> featureIndex) {
        NBWeights nbWeights = this.trainWeights(data, labels, numFeatures, numClasses);

        // Re-key the class priors by label object.
        HashSet<L> labelSet = new HashSet<L>();
        ClassicCounter<L> priors = new ClassicCounter<L>();
        double[] pr = nbWeights.priors;
        for (int i = 0; i < pr.length; ++i) {
            L label = labelIndex.get(i);
            priors.incrementCount(label, pr[i]);
            labelSet.add(label);
        }

        // Re-key the weights by ((label, feature), value).
        // Raw counter type is kept on purpose: the key type expected by the
        // NaiveBayesClassifier constructor is not visible from this file.
        ClassicCounter weightsCounter = new ClassicCounter();
        double[][][] wts = nbWeights.weights;
        for (int c = 0; c < numClasses; ++c) {
            L label = labelIndex.get(c);
            for (int f = 0; f < numFeatures; ++f) {
                Pair<L, F> labelAndFeature = new Pair<L, F>(label, featureIndex.get(f));
                double[] perValue = wts[c][f];
                for (int val = 0; val < perValue.length; ++val) {
                    Pair<Pair<L, F>, Integer> key = new Pair<Pair<L, F>, Integer>(labelAndFeature, val);
                    weightsCounter.incrementCount(key, perValue[val]);
                }
            }
        }
        return new NaiveBayesClassifier(weightsCounter, priors, labelSet);
    }

    /**
     * Trains a classifier from a list of data, indexing features as they are encountered.
     *
     * <p>The number of feature columns is taken from the first datum, so {@code examples}
     * must be non-empty and later data must not introduce features beyond that count.
     * Replaces this factory's {@code labelIndex} and {@code featureIndex}.
     *
     * @param examples training data; feature counts are truncated to {@code int}
     * @return the trained classifier
     * @deprecated flagged as deprecated in this class; the intended replacement is not
     *             identifiable from this file
     */
    @Override
    @Deprecated
    public NaiveBayesClassifier<L, F> trainClassifier(List<RVFDatum<L, F>> examples) {
        RVFDatum<L, F> d0 = examples.get(0);
        int numFeatures = d0.asFeatures().size();
        int numExamples = examples.size();
        int[][] data = new int[numExamples][numFeatures];
        int[] labels = new int[numExamples];
        this.labelIndex = new HashIndex<L>();
        this.featureIndex = new HashIndex<F>();
        for (int d = 0; d < numExamples; ++d) {
            RVFDatum<L, F> datum = examples.get(d);
            Counter<F> c = datum.asFeaturesCounter();
            for (F feature : c.keySet()) {
                // Do not branch on add()'s return value: if it only reports "newly added",
                // features already indexed by an earlier datum would be silently left at 0.
                this.featureIndex.add(feature);
                int fNo = this.featureIndex.indexOf(feature);
                if (fNo < 0) {
                    continue; // feature could not be indexed
                }
                data[d][fNo] = (int)c.getCount(feature);
            }
            this.labelIndex.add(datum.label());
            labels[d] = this.labelIndex.indexOf(datum.label());
        }
        int numClasses = this.labelIndex.size();
        return this.trainClassifier(data, labels, numFeatures, numClasses, this.labelIndex, this.featureIndex);
    }

    /**
     * Trains a classifier from a list of data restricted to a fixed feature set.
     *
     * <p>Features of a datum that are not part of {@code featureSet} are ignored.
     * Replaces this factory's {@code labelIndex} and {@code featureIndex}.
     *
     * @param examples training data; feature counts are truncated to {@code int}
     * @param featureSet the features to use, one data column per element
     * @return the trained classifier
     */
    public NaiveBayesClassifier<L, F> trainClassifier(List<RVFDatum<L, F>> examples, Set<F> featureSet) {
        int numFeatures = featureSet.size();
        int numExamples = examples.size();
        int[][] data = new int[numExamples][numFeatures];
        int[] labels = new int[numExamples];
        this.labelIndex = new HashIndex<L>();
        this.featureIndex = new HashIndex<F>();
        for (F feat : featureSet) {
            this.featureIndex.add(feat);
        }
        for (int d = 0; d < numExamples; ++d) {
            RVFDatum<L, F> datum = examples.get(d);
            Counter<F> c = datum.asFeaturesCounter();
            for (F feature : c.keySet()) {
                int fNo = this.featureIndex.indexOf(feature);
                if (fNo < 0 || fNo >= numFeatures) {
                    continue; // feature is outside featureSet: no column to store it in
                }
                data[d][fNo] = (int)c.getCount(feature);
            }
            this.labelIndex.add(datum.label());
            labels[d] = this.labelIndex.indexOf(datum.label());
        }
        int numClasses = this.labelIndex.size();
        return this.trainClassifier(data, labels, numFeatures, numClasses, this.labelIndex, this.featureIndex);
    }

    /**
     * Dispatches to the training routine selected by {@code kind}.
     *
     * @throws IllegalStateException if {@code kind} is none of {@link #JL}, {@link #CL}, {@link #UCL}
     */
    private NBWeights trainWeights(int[][] data, int[] labels, int numFeatures, int numClasses) {
        switch (this.kind) {
            case JL:
                return this.trainWeightsJL(data, labels, numFeatures, numClasses);
            case UCL:
                return this.trainWeightsUCL(data, labels, numFeatures, numClasses);
            case CL:
                return this.trainWeightsCL(data, labels, numFeatures, numClasses);
            default:
                // Fail with a clear message instead of returning null and letting the
                // caller trip over a NullPointerException.
                throw new IllegalStateException("Unknown training kind " + this.kind + "; expected JL (" + JL + "), CL (" + CL + ") or UCL (" + UCL + ")");
        }
    }

    /**
     * Estimates log priors and log feature-value weights from additively smoothed counts.
     *
     * @return weights indexed {@code [class][feature][value]} plus per-class log priors
     */
    private NBWeights trainWeightsJL(int[][] data, int[] labels, int numFeatures, int numClasses) {
        int[] numValues = NaiveBayesClassifierFactory.numberValues(data, numFeatures);
        double[] priors = new double[numClasses];
        double[][][] weights = new double[numClasses][numFeatures][];
        for (int cl = 0; cl < numClasses; ++cl) {
            for (int fno = 0; fno < numFeatures; ++fno) {
                weights[cl][fno] = new double[numValues[fno]];
            }
        }

        // Accumulate raw counts: examples per class, and (class, feature, value) co-occurrences.
        for (int i = 0; i < data.length; ++i) {
            int cl = labels[i];
            priors[cl] += 1.0;
            for (int fno = 0; fno < numFeatures; ++fno) {
                weights[cl][fno][data[i][fno]] += 1.0;
            }
        }

        // Turn counts into smoothed log probabilities. The feature weights must be
        // normalized first because they still need the raw class count in priors[cl].
        for (int cl = 0; cl < numClasses; ++cl) {
            double classCount = priors[cl];
            for (int fno = 0; fno < numFeatures; ++fno) {
                double denominator = classCount + this.alphaFeature * (double)numValues[fno];
                for (int val = 0; val < numValues[fno]; ++val) {
                    weights[cl][fno][val] = Math.log((weights[cl][fno][val] + this.alphaFeature) / denominator);
                }
            }
            priors[cl] = Math.log((classCount + this.alphaClass) / ((double)data.length + this.alphaClass * (double)numClasses));
        }
        return new NBWeights(priors, weights);
    }

    /**
     * Trains weights by minimizing a {@link LogConditionalObjectiveFunction} over a
     * flattened encoding of the data.
     *
     * <p>Each (feature, value) pair is mapped to its own code; code 0 is present in every
     * example and its weights are later read back as the class priors by
     * {@link NBWeights#NBWeights(double[][], int[])}.
     */
    private NBWeights trainWeightsUCL(int[][] data, int[] labels, int numFeatures, int numClasses) {
        int[] numValues = NaiveBayesClassifierFactory.numberValues(data, numFeatures);

        // offsets[j] = number of values belonging to features 0..j-1. A running sum also
        // yields the grand total without indexing [numFeatures - 1], which would fail
        // when there are no features at all.
        int[] offsets = new int[numFeatures];
        int totalValues = 0;
        for (int j = 0; j < numFeatures; ++j) {
            offsets[j] = totalValues;
            totalValues += numValues[j];
        }

        int[][] newdata = new int[data.length][numFeatures + 1];
        for (int i = 0; i < data.length; ++i) {
            newdata[i][0] = 0;
            for (int j = 0; j < numFeatures; ++j) {
                newdata[i][j + 1] = offsets[j] + data[i][j] + 1;
            }
        }

        int totalFeatures = totalValues + 1;
        System.err.println("total feats " + totalFeatures);
        LogConditionalObjectiveFunction objective = new LogConditionalObjectiveFunction(totalFeatures, numClasses, newdata, labels, this.prior, this.sigma, 0.0);
        QNMinimizer min = new QNMinimizer();
        double[] argmin = min.minimize(objective, 1.0E-4, objective.initial());
        double[][] wts = objective.to2D(argmin);
        System.out.println("weights have dimension " + wts.length);
        return new NBWeights(wts, numValues);
    }

    /**
     * Trains weights by minimizing a {@link LogConditionalEqConstraintFunction} and reading
     * the priors and 3-D weights back from the optimum.
     */
    private NBWeights trainWeightsCL(int[][] data, int[] labels, int numFeatures, int numClasses) {
        LogConditionalEqConstraintFunction objective = new LogConditionalEqConstraintFunction(numFeatures, numClasses, data, labels, this.prior, this.sigma, 0.0);
        QNMinimizer min = new QNMinimizer();
        double[] argmin = min.minimize(objective, 1.0E-4, objective.initial());
        double[][][] wts = objective.to3D(argmin);
        double[] priors = objective.priors(argmin);
        return new NBWeights(priors, wts);
    }

    /**
     * Computes, for every feature column, one more than the largest value seen in that
     * column (never less than 0).
     *
     * @param data one row per example; every row must have at most {@code numFeatures} columns
     * @param numFeatures length of the returned array
     * @return per-column value counts
     */
    static int[] numberValues(int[][] data, int numFeatures) {
        int[] numValues = new int[numFeatures];
        for (int[] row : data) {
            for (int j = 0; j < row.length; ++j) {
                numValues[j] = Math.max(numValues[j], row[j] + 1);
            }
        }
        return numValues;
    }

    /**
     * Command-line driver: trains on {@code args[0]}, evaluates on {@code args[1]}, first
     * with {@link #CL} and then with {@link #UCL}, 100 repetitions each.
     *
     * @param args training file path followed by test file path
     */
    public static void main(String[] args) {
        if (args == null || args.length < 2) {
            System.err.println("Usage: NaiveBayesClassifierFactory <trainFile> <testFile>");
            return;
        }
        String trainFile = args[0];
        String testFile = args[1];
        HashMap<Integer, Index<String>> indices = new HashMap<Integer, Index<String>>();
        ArrayList<RVFDatum<String, Integer>> train = NominalDataReader.readData(trainFile, indices);
        ArrayList<RVFDatum<String, Integer>> test = NominalDataReader.readData(testFile, indices);
        NaiveBayesClassifierFactory.runExperiment("Constrained conditional likelihood no prior :", CL, train, test);
        NaiveBayesClassifierFactory.runExperiment("Unconstrained conditional likelihood no prior :", UCL, train, test);
    }

    /** Prints the heading, then 100 times trains with the given kind and reports train/test accuracy. */
    @SuppressWarnings({"unchecked", "rawtypes", "deprecation"})
    private static void runExperiment(String heading, int kind, ArrayList<RVFDatum<String, Integer>> train, ArrayList<RVFDatum<String, Integer>> test) {
        System.out.println(heading);
        for (int j = 0; j < 100; ++j) {
            Classifier classifier = new NaiveBayesClassifierFactory(0.1, 0.01, 0.6, LogPrior.LogPriorType.NULL.ordinal(), kind).trainClassifier(train);
            ((NaiveBayesClassifier)classifier).print();
            float accTrain = ((NaiveBayesClassifier)classifier).accuracy(train.iterator());
            System.err.println("training accuracy " + accTrain);
            float accTest = ((NaiveBayesClassifier)classifier).accuracy(test.iterator());
            System.err.println("test accuracy " + accTest);
        }
    }

    /**
     * Trains a classifier from a pre-indexed dataset.
     *
     * @param dataset the training data; must not be an {@link RVFDataset}
     * @return the trained classifier
     * @throws RuntimeException if {@code dataset} is an {@link RVFDataset}
     */
    @Override
    public NaiveBayesClassifier<L, F> trainClassifier(GeneralDataset<L, F> dataset) {
        if (dataset instanceof RVFDataset) {
            throw new RuntimeException("Not sure if RVFDataset runs correctly in this method. Please update this code if it does.");
        }
        return this.trainClassifier(dataset.getDataArray(), dataset.labels, dataset.numFeatures(), dataset.numClasses(), dataset.labelIndex, dataset.featureIndex);
    }

    /** Plain holder for trained parameters: per-class priors and {@code [class][feature][value]} weights. */
    static class NBWeights {
        double[] priors;
        double[][][] weights;

        /** Wraps already-structured priors and weights without copying. */
        NBWeights(double[] priors, double[][][] weights) {
            this.priors = priors;
            this.weights = weights;
        }

        /**
         * Unflattens a {@code [code][class]} weight matrix.
         *
         * <p>Row 0 supplies the class priors; the weight of value {@code val} of feature
         * {@code fno} is read from row {@code offset(fno) + val + 1}, where
         * {@code offset(fno)} is the sum of {@code numValues[0..fno-1]}.
         *
         * @param wts flattened weights, first index = code, second index = class
         * @param numValues number of distinct values of each feature
         */
        NBWeights(double[][] wts, int[] numValues) {
            int numClasses = wts[0].length;
            int numFeatures = numValues.length;
            this.priors = new double[numClasses];
            System.arraycopy(wts[0], 0, this.priors, 0, numClasses);
            this.weights = new double[numClasses][numFeatures][];
            int offset = 0; // values consumed by the preceding features
            for (int fno = 0; fno < numFeatures; ++fno) {
                for (int c = 0; c < numClasses; ++c) {
                    this.weights[c][fno] = new double[numValues[fno]];
                }
                for (int val = 0; val < numValues[fno]; ++val) {
                    int code = offset + val + 1;
                    for (int cls = 0; cls < numClasses; ++cls) {
                        this.weights[cls][fno][val] = wts[code][cls];
                    }
                }
                offset += numValues[fno];
            }
        }
    }
}

