/*
 * Decompiled with CFR 0.152.
 */
package edu.stanford.nlp.classify;

import edu.stanford.nlp.classify.AbstractLinearClassifierFactory;
import edu.stanford.nlp.classify.AdaptedGaussianPriorObjectiveFunction;
import edu.stanford.nlp.classify.BiasedLogConditionalObjectiveFunction;
import edu.stanford.nlp.classify.Classifier;
import edu.stanford.nlp.classify.ClassifierCreator;
import edu.stanford.nlp.classify.CrossValidator;
import edu.stanford.nlp.classify.GeneralDataset;
import edu.stanford.nlp.classify.GeneralizedExpectationObjectiveFunction;
import edu.stanford.nlp.classify.LinearClassifier;
import edu.stanford.nlp.classify.LogConditionalObjectiveFunction;
import edu.stanford.nlp.classify.LogPrior;
import edu.stanford.nlp.classify.ProbabilisticClassifier;
import edu.stanford.nlp.classify.ProbabilisticClassifierCreator;
import edu.stanford.nlp.classify.RVFDataset;
import edu.stanford.nlp.classify.SemiSupervisedLogConditionalObjectiveFunction;
import edu.stanford.nlp.io.IOUtils;
import edu.stanford.nlp.io.RuntimeIOException;
import edu.stanford.nlp.ling.Datum;
import edu.stanford.nlp.math.ArrayMath;
import edu.stanford.nlp.optimization.CGMinimizer;
import edu.stanford.nlp.optimization.DiffFunction;
import edu.stanford.nlp.optimization.Evaluator;
import edu.stanford.nlp.optimization.GoldenSectionLineSearch;
import edu.stanford.nlp.optimization.HasEvaluators;
import edu.stanford.nlp.optimization.HybridMinimizer;
import edu.stanford.nlp.optimization.InefficientSGDMinimizer;
import edu.stanford.nlp.optimization.LineSearcher;
import edu.stanford.nlp.optimization.Minimizer;
import edu.stanford.nlp.optimization.QNMinimizer;
import edu.stanford.nlp.optimization.SGDMinimizer;
import edu.stanford.nlp.optimization.SGDToQNMinimizer;
import edu.stanford.nlp.optimization.SMDMinimizer;
import edu.stanford.nlp.optimization.SQNMinimizer;
import edu.stanford.nlp.optimization.StochasticCalculateMethods;
import edu.stanford.nlp.stats.ClassicCounter;
import edu.stanford.nlp.stats.Counters;
import edu.stanford.nlp.stats.MultiClassAccuracyStats;
import edu.stanford.nlp.stats.Scorer;
import edu.stanford.nlp.util.ArrayUtils;
import edu.stanford.nlp.util.Factory;
import edu.stanford.nlp.util.Generics;
import edu.stanford.nlp.util.HashIndex;
import edu.stanford.nlp.util.Index;
import edu.stanford.nlp.util.Pair;
import edu.stanford.nlp.util.Timing;
import edu.stanford.nlp.util.Triple;
import edu.stanford.nlp.util.logging.Redwood;
import java.io.BufferedReader;
import java.util.List;
import java.util.function.DoubleUnaryOperator;
import java.util.function.ToDoubleFunction;

/**
 * Factory that trains {@link LinearClassifier}s by minimizing a log conditional
 * objective (with a {@link LogPrior}) using a pluggable {@link Minimizer}.
 *
 * <p>The minimizer is obtained from a {@link Factory} on every training call, so
 * the various {@code use...()} methods only swap the factory; settings such as
 * {@code mem} and {@code verbose} are read when the minimizer is created, not
 * when the {@code use...()} method is called.
 *
 * <p>Sigma (the prior's width) can optionally be tuned before the final training
 * run, either on a held-out split ({@link #setTuneSigmaHeldOut()}) or by
 * cross-validation ({@link #setTuneSigmaCV(int)}).
 *
 * @param <L> label type
 * @param <F> feature type
 */
public class LinearClassifierFactory<L, F>
extends AbstractLinearClassifierFactory<L, F> {
    private static final long serialVersionUID = 7893768984379107397L;
    /** Convergence tolerance handed to the minimizer. */
    private double TOL;
    /** Memory setting passed to the quasi-Newton minimizers. */
    private int mem = 15;
    /** When false, freshly created minimizers are silenced via {@code shutUp()}. */
    private boolean verbose = false;
    private LogPrior logPrior;
    private boolean tuneSigmaHeldOut = false;
    private boolean tuneSigmaCV = false;
    /** Number of folds used when {@link #tuneSigmaCV} is set. */
    private int folds;
    /** Lower bound of the sigma search interval. */
    private double min = 0.1;
    /** Upper bound of the sigma search interval. */
    private double max = 10.0;
    private boolean retrainFromScratchAfterSigmaTuning = false;
    private Factory<Minimizer<DiffFunction>> minimizerCreator = null;
    private int evalIters = -1;
    private Evaluator[] evaluators;
    private static final Redwood.RedwoodChannels logger = Redwood.channels(LinearClassifierFactory.class);
    protected static final double[] sigmasToTry = new double[]{0.5, 1.0, 2.0, 4.0, 10.0, 20.0, 100.0};
    /** Optional line searcher for held-out sigma tuning; a golden-section search is used when null. */
    private LineSearcher heldOutSearcher;

    // ------------------------------------------------------------------
    // Constructors. All of them funnel into one of the two LogPrior-taking
    // constructors at the end of this group. Defaults: tol = 1e-4,
    // sigma = 1.0, quadratic prior, epsilon = 0.0. The useSum argument is
    // accepted for compatibility but is never read.
    // ------------------------------------------------------------------

    /** Creates a factory with the default quasi-Newton minimizer and default settings. */
    public LinearClassifierFactory() {
        this((Factory<Minimizer<DiffFunction>>)null);
    }

    /** Creates a factory that always hands out the given minimizer instance. */
    public LinearClassifierFactory(Minimizer<DiffFunction> min) {
        this(min, 1.0E-4, false);
    }

    /** Creates a factory using the given minimizer factory (quasi-Newton when {@code null}). */
    public LinearClassifierFactory(Factory<Minimizer<DiffFunction>> min) {
        this(min, 1.0E-4, false);
    }

    public LinearClassifierFactory(Minimizer<DiffFunction> min, double tol, boolean useSum) {
        this(min, tol, useSum, 1.0);
    }

    public LinearClassifierFactory(Factory<Minimizer<DiffFunction>> min, double tol, boolean useSum) {
        this(min, tol, useSum, 1.0);
    }

    public LinearClassifierFactory(double tol, boolean useSum, double sigma) {
        this((Factory<Minimizer<DiffFunction>>)null, tol, useSum, sigma);
    }

    public LinearClassifierFactory(Minimizer<DiffFunction> min, double tol, boolean useSum, double sigma) {
        this(min, tol, useSum, LogPrior.LogPriorType.QUADRATIC.ordinal(), sigma);
    }

    public LinearClassifierFactory(Factory<Minimizer<DiffFunction>> min, double tol, boolean useSum, double sigma) {
        this(min, tol, useSum, LogPrior.LogPriorType.QUADRATIC.ordinal(), sigma);
    }

    public LinearClassifierFactory(Minimizer<DiffFunction> min, double tol, boolean useSum, int prior, double sigma) {
        this(min, tol, useSum, prior, sigma, 0.0);
    }

    public LinearClassifierFactory(Factory<Minimizer<DiffFunction>> min, double tol, boolean useSum, int prior, double sigma) {
        this(min, tol, useSum, prior, sigma, 0.0);
    }

    public LinearClassifierFactory(double tol, boolean useSum, int prior, double sigma, double epsilon) {
        this((Factory<Minimizer<DiffFunction>>)null, tol, useSum, new LogPrior(prior, sigma, epsilon));
    }

    /** As the five-argument form, additionally setting the quasi-Newton memory. */
    public LinearClassifierFactory(double tol, boolean useSum, int prior, double sigma, double epsilon, int mem) {
        this((Factory<Minimizer<DiffFunction>>)null, tol, useSum, new LogPrior(prior, sigma, epsilon));
        // The default QNFactory reads mem lazily, so setting it after delegation is sufficient.
        this.mem = mem;
    }

    public LinearClassifierFactory(Minimizer<DiffFunction> min, double tol, boolean useSum, int prior, double sigma, double epsilon) {
        this(min, tol, useSum, new LogPrior(prior, sigma, epsilon));
    }

    public LinearClassifierFactory(Factory<Minimizer<DiffFunction>> min, double tol, boolean useSum, int prior, double sigma, double epsilon) {
        this(min, tol, useSum, new LogPrior(prior, sigma, epsilon));
    }

    /**
     * Creates a factory that returns the same {@code min} instance from every
     * minimizer request.
     *
     * @param min minimizer instance to reuse
     * @param tol convergence tolerance
     * @param useSum ignored
     * @param logPrior prior applied to the objective
     */
    public LinearClassifierFactory(final Minimizer<DiffFunction> min, double tol, boolean useSum, LogPrior logPrior) {
        this.minimizerCreator = new Factory<Minimizer<DiffFunction>>(){
            private static final long serialVersionUID = -6439748445540743949L;

            @Override
            public Minimizer<DiffFunction> create() {
                return min;
            }
        };
        this.TOL = tol;
        this.logPrior = logPrior;
    }

    /**
     * Creates a factory that obtains a minimizer from {@code minimizerCreator}
     * for each training run.
     *
     * @param minimizerCreator minimizer factory; a quasi-Newton factory is used when {@code null}
     * @param tol convergence tolerance
     * @param useSum ignored
     * @param logPrior prior applied to the objective
     */
    public LinearClassifierFactory(Factory<Minimizer<DiffFunction>> minimizerCreator, double tol, boolean useSum, LogPrior logPrior) {
        this.minimizerCreator = minimizerCreator == null ? new QNFactory() : minimizerCreator;
        this.TOL = tol;
        this.logPrior = logPrior;
    }

    // ------------------------------------------------------------------
    // Simple configuration
    // ------------------------------------------------------------------

    /** Sets the convergence tolerance passed to the minimizer. */
    public void setTol(double tol) {
        this.TOL = tol;
    }

    /** Replaces the prior used by the training objectives. */
    public void setPrior(LogPrior logPrior) {
        this.logPrior = logPrior;
    }

    /** Controls whether newly created minimizers are left verbose. */
    public void setVerbose(boolean verbose) {
        this.verbose = verbose;
    }

    /** Replaces the minimizer factory; unlike the constructors, {@code null} is not defaulted here. */
    public void setMinimizerCreator(Factory<Minimizer<DiffFunction>> minimizerCreator) {
        this.minimizerCreator = minimizerCreator;
    }

    /** Sets epsilon on the current prior. */
    public void setEpsilon(double eps) {
        this.logPrior.setEpsilon(eps);
    }

    /** Sets sigma on the current prior. */
    public void setSigma(double sigma) {
        this.logPrior.setSigma(sigma);
    }

    /** Returns sigma of the current prior. */
    public double getSigma() {
        return this.logPrior.getSigma();
    }

    // ------------------------------------------------------------------
    // Minimizer selection. NOTE: the anonymous classes below carry explicit
    // serialVersionUIDs; keep them anonymous and in this order so their
    // generated class names stay stable.
    // ------------------------------------------------------------------

    /** Selects the default quasi-Newton minimizer. */
    public void useQuasiNewton() {
        this.minimizerCreator = new QNFactory();
    }

    /** Selects a quasi-Newton minimizer constructed with the given {@code useRobust} flag. */
    public void useQuasiNewton(final boolean useRobust) {
        this.minimizerCreator = new Factory<Minimizer<DiffFunction>>(){
            private static final long serialVersionUID = -9108222058357693242L;

            @Override
            public Minimizer<DiffFunction> create() {
                QNMinimizer qnMinimizer = new QNMinimizer(LinearClassifierFactory.this.mem, useRobust);
                if (!LinearClassifierFactory.this.verbose) {
                    qnMinimizer.shutUp();
                }
                return qnMinimizer;
            }
        };
    }

    /** Selects a stochastic quasi-Newton minimizer. */
    public void useStochasticQN(final double initialSMDGain, final int stochasticBatchSize) {
        this.minimizerCreator = new Factory<Minimizer<DiffFunction>>(){
            private static final long serialVersionUID = -7760753348350678588L;

            @Override
            public Minimizer<DiffFunction> create() {
                SQNMinimizer<DiffFunction> sqnMinimizer = new SQNMinimizer<>(LinearClassifierFactory.this.mem, initialSMDGain, stochasticBatchSize, false);
                if (!LinearClassifierFactory.this.verbose) {
                    sqnMinimizer.shutUp();
                }
                return sqnMinimizer;
            }
        };
    }

    /** Selects stochastic meta descent with gain 0.1, batch size 15, external finite differences and 20 passes. */
    public void useStochasticMetaDescent() {
        this.useStochasticMetaDescent(0.1, 15, StochasticCalculateMethods.ExternalFiniteDifference, 20);
    }

    /** Selects stochastic meta descent with the given settings. */
    public void useStochasticMetaDescent(final double initialSMDGain, final int stochasticBatchSize, final StochasticCalculateMethods stochasticMethod, final int passes) {
        this.minimizerCreator = new Factory<Minimizer<DiffFunction>>(){
            private static final long serialVersionUID = 6860437108371914482L;

            @Override
            public Minimizer<DiffFunction> create() {
                SMDMinimizer<DiffFunction> smdMinimizer = new SMDMinimizer<>(initialSMDGain, stochasticBatchSize, stochasticMethod, passes);
                if (!LinearClassifierFactory.this.verbose) {
                    smdMinimizer.shutUp();
                }
                return smdMinimizer;
            }
        };
    }

    /** Selects stochastic gradient descent with gain 0.1 and batch size 15. */
    public void useStochasticGradientDescent() {
        this.useStochasticGradientDescent(0.1, 15);
    }

    /** Selects stochastic gradient descent ({@link InefficientSGDMinimizer}) with the given settings. */
    public void useStochasticGradientDescent(final double gainSGD, final int stochasticBatchSize) {
        this.minimizerCreator = new Factory<Minimizer<DiffFunction>>(){
            private static final long serialVersionUID = 2564615420955196299L;

            @Override
            public Minimizer<DiffFunction> create() {
                InefficientSGDMinimizer<DiffFunction> sgdMinimizer = new InefficientSGDMinimizer<>(gainSGD, stochasticBatchSize);
                if (!LinearClassifierFactory.this.verbose) {
                    sgdMinimizer.shutUp();
                }
                return sgdMinimizer;
            }
        };
    }

    /** Selects in-place SGD with passes = -1, tune sample size = -1 and sigma = 1.0. */
    public void useInPlaceStochasticGradientDescent() {
        this.useInPlaceStochasticGradientDescent(-1, -1, 1.0);
    }

    /** Selects in-place SGD ({@link SGDMinimizer}) with the given settings. */
    public void useInPlaceStochasticGradientDescent(final int SGDPasses, final int tuneSampleSize, final double sigma) {
        this.minimizerCreator = new Factory<Minimizer<DiffFunction>>(){
            private static final long serialVersionUID = -5319225231759162616L;

            @Override
            public Minimizer<DiffFunction> create() {
                SGDMinimizer<DiffFunction> sgdMinimizer = new SGDMinimizer<>(sigma, SGDPasses, tuneSampleSize);
                if (!LinearClassifierFactory.this.verbose) {
                    sgdMinimizer.shutUp();
                }
                return sgdMinimizer;
            }
        };
    }

    /** Selects a hybrid minimizer combining in-place SGD with quasi-Newton. */
    public void useHybridMinimizerWithInPlaceSGD(final int SGDPasses, final int tuneSampleSize, final double sigma) {
        this.minimizerCreator = new Factory<Minimizer<DiffFunction>>(){
            private static final long serialVersionUID = -3042400543337763144L;

            @Override
            public Minimizer<DiffFunction> create() {
                SGDMinimizer<DiffFunction> firstMinimizer = new SGDMinimizer<>(sigma, SGDPasses, tuneSampleSize);
                QNMinimizer secondMinimizer = new QNMinimizer(LinearClassifierFactory.this.mem);
                if (!LinearClassifierFactory.this.verbose) {
                    firstMinimizer.shutUp();
                    secondMinimizer.shutUp();
                }
                return new HybridMinimizer(firstMinimizer, secondMinimizer, SGDPasses);
            }
        };
    }

    /** Selects the {@link SGDToQNMinimizer} with the given settings. */
    public void useStochasticGradientDescentToQuasiNewton(final double SGDGain, final int batchSize, final int sgdPasses, final int qnPasses, final int hessSamples, final int QNMem, final boolean outputToFile) {
        this.minimizerCreator = new Factory<Minimizer<DiffFunction>>(){
            private static final long serialVersionUID = 5823852936137599566L;

            @Override
            public Minimizer<DiffFunction> create() {
                SGDToQNMinimizer sgdToQNMinimizer = new SGDToQNMinimizer(SGDGain, batchSize, sgdPasses, qnPasses, hessSamples, QNMem, outputToFile);
                if (!LinearClassifierFactory.this.verbose) {
                    sgdToQNMinimizer.shutUp();
                }
                return sgdToQNMinimizer;
            }
        };
    }

    /** Selects the SMD/quasi-Newton hybrid with gain 0.1, batch size 15 and cutoff iteration 0. */
    public void useHybridMinimizer() {
        this.useHybridMinimizer(0.1, 15, StochasticCalculateMethods.ExternalFiniteDifference, 0);
    }

    /** Selects a hybrid minimizer combining stochastic meta descent with quasi-Newton. */
    public void useHybridMinimizer(double initialSMDGain, int stochasticBatchSize, StochasticCalculateMethods stochasticMethod, int cutoffIteration) {
        this.minimizerCreator = () -> {
            SMDMinimizer<DiffFunction> firstMinimizer = new SMDMinimizer<>(initialSMDGain, stochasticBatchSize, stochasticMethod, cutoffIteration);
            QNMinimizer secondMinimizer = new QNMinimizer(this.mem);
            if (!this.verbose) {
                firstMinimizer.shutUp();
                secondMinimizer.shutUp();
            }
            return new HybridMinimizer(firstMinimizer, secondMinimizer, cutoffIteration);
        };
    }

    /** Sets the memory value used when quasi-Newton minimizers are created. */
    public void setMem(int mem) {
        this.mem = mem;
    }

    /** Sets verbosity, then selects the conjugate gradient minimizer. */
    public void useConjugateGradientAscent(boolean verbose) {
        this.verbose = verbose;
        this.useConjugateGradientAscent();
    }

    /** Selects the conjugate gradient minimizer. */
    public void useConjugateGradientAscent() {
        this.minimizerCreator = new Factory<Minimizer<DiffFunction>>(){
            private static final long serialVersionUID = -561168861131879990L;

            @Override
            public Minimizer<DiffFunction> create() {
                return new CGMinimizer(!LinearClassifierFactory.this.verbose);
            }
        };
    }

    /** No-op; the argument is ignored. */
    public void setUseSum(boolean useSum) {
    }

    /** Creates a minimizer from the current factory and, if it supports them, installs the configured evaluators. */
    private Minimizer<DiffFunction> getMinimizer() {
        Minimizer<DiffFunction> minimizer = this.minimizerCreator.create();
        if (minimizer instanceof HasEvaluators) {
            ((HasEvaluators) minimizer).setEvaluators(this.evalIters, this.evaluators);
        }
        return minimizer;
    }

    // ------------------------------------------------------------------
    // Weight training
    // ------------------------------------------------------------------

    /**
     * Adapts existing weights to {@code adaptDataset} using an
     * {@link AdaptedGaussianPriorObjectiveFunction}.
     *
     * <p>A new weight matrix sized by the adaptation dataset's feature and label
     * indices is allocated and the rows of {@code origWeights} are copied into
     * its leading rows (row references are copied, not row contents).
     *
     * @param origWeights previously trained weights; must not have more rows than the adaptation dataset has features
     * @param adaptDataset dataset to adapt to
     * @return the adapted weights as a 2D array
     */
    public double[][] adaptWeights(double[][] origWeights, GeneralDataset<L, F> adaptDataset) {
        Minimizer<DiffFunction> minimizer = this.getMinimizer();
        logger.info("adaptWeights in LinearClassifierFactory. increase weight dim only");
        double[][] newWeights = new double[adaptDataset.featureIndex.size()][adaptDataset.labelIndex.size()];
        // Only the copy is guarded. Holding the System.class monitor across the
        // whole optimization (as the decompiled form did) would needlessly block
        // every other thread synchronizing on it for the duration of training.
        synchronized (System.class) {
            System.arraycopy(origWeights, 0, newWeights, 0, origWeights.length);
        }
        AdaptedGaussianPriorObjectiveFunction<L, F> objective = new AdaptedGaussianPriorObjectiveFunction<>(adaptDataset, this.logPrior, newWeights);
        double[] initial = objective.initial();
        double[] weights = minimizer.minimize(objective, this.TOL, initial);
        return objective.to2D(weights);
    }

    /** Trains weights on {@code dataset} starting from the objective's default initial point. */
    @Override
    public double[][] trainWeights(GeneralDataset<L, F> dataset) {
        return this.trainWeights(dataset, null);
    }

    /** Trains weights on {@code dataset}, starting from {@code initial} when it is non-null. */
    public double[][] trainWeights(GeneralDataset<L, F> dataset, double[] initial) {
        return this.trainWeights(dataset, initial, false);
    }

    /**
     * Trains weights on {@code dataset}.
     *
     * @param dataset training data
     * @param initial starting point, or {@code null} to use the tuned interim weights (if any) or the objective's default
     * @param bypassTuneSigma when true, skips sigma tuning even if it is enabled
     * @return the trained weights as a 2D array
     */
    public double[][] trainWeights(GeneralDataset<L, F> dataset, double[] initial, boolean bypassTuneSigma) {
        Minimizer<DiffFunction> minimizer = this.getMinimizer();
        if (dataset instanceof RVFDataset) {
            ((RVFDataset<?, ?>) dataset).ensureRealValues();
        }
        double[] interimWeights = null;
        if (!bypassTuneSigma) {
            if (this.tuneSigmaHeldOut) {
                interimWeights = this.heldOutSetSigma(dataset);
            } else if (this.tuneSigmaCV) {
                this.crossValidateSetSigma(dataset, this.folds);
            }
        }
        // The objective is built after tuning so it sees the tuned prior.
        LogConditionalObjectiveFunction<L, F> objective = new LogConditionalObjectiveFunction<>(dataset, this.logPrior);
        if (initial == null && interimWeights != null && !this.retrainFromScratchAfterSigmaTuning) {
            initial = interimWeights;
        }
        if (initial == null) {
            initial = objective.initial();
        }
        double[] weights = minimizer.minimize(objective, this.TOL, initial);
        return objective.to2D(weights);
    }

    /** Trains semi-supervised weights and wraps them in a classifier over {@code data}'s indices. */
    public Classifier<L, F> trainClassifierSemiSup(GeneralDataset<L, F> data, GeneralDataset<L, F> biasedData, double[][] confusionMatrix, double[] initial) {
        double[][] weights = this.trainWeightsSemiSup(data, biasedData, confusionMatrix, initial);
        return new LinearClassifier<>(weights, data.featureIndex(), data.labelIndex());
    }

    /**
     * Trains weights on the combination of a plain objective over {@code data}
     * and a biased objective over {@code biasedData}. Both component objectives
     * use a NULL prior; this factory's prior is applied to the combined objective.
     *
     * @param initial starting point, or {@code null} for the plain objective's default
     */
    public double[][] trainWeightsSemiSup(GeneralDataset<L, F> data, GeneralDataset<L, F> biasedData, double[][] confusionMatrix, double[] initial) {
        Minimizer<DiffFunction> minimizer = this.getMinimizer();
        LogConditionalObjectiveFunction<L, F> objective = new LogConditionalObjectiveFunction<>(data, new LogPrior(LogPrior.LogPriorType.NULL));
        BiasedLogConditionalObjectiveFunction biasedObjective = new BiasedLogConditionalObjectiveFunction(biasedData, confusionMatrix, new LogPrior(LogPrior.LogPriorType.NULL));
        SemiSupervisedLogConditionalObjectiveFunction semiSupObjective = new SemiSupervisedLogConditionalObjectiveFunction(objective, biasedObjective, this.logPrior);
        if (initial == null) {
            initial = objective.initial();
        }
        double[] weights = minimizer.minimize(semiSupObjective, this.TOL, initial);
        return objective.to2D(weights);
    }

    /**
     * Trains a classifier from labeled data plus a generalized-expectation
     * objective over unlabeled data.
     *
     * @param GEFeatures features passed to the generalized-expectation objective
     * @param convexComboCoeff coefficient passed to the combined objective
     */
    public LinearClassifier<L, F> trainSemiSupGE(GeneralDataset<L, F> labeledDataset, List<? extends Datum<L, F>> unlabeledDataList, List<F> GEFeatures, double convexComboCoeff) {
        Minimizer<DiffFunction> minimizer = this.getMinimizer();
        LogConditionalObjectiveFunction<L, F> objective = new LogConditionalObjectiveFunction<>(labeledDataset, new LogPrior(LogPrior.LogPriorType.NULL));
        GeneralizedExpectationObjectiveFunction<L, F> geObjective = new GeneralizedExpectationObjectiveFunction<>(labeledDataset, unlabeledDataList, GEFeatures);
        SemiSupervisedLogConditionalObjectiveFunction semiSupObjective = new SemiSupervisedLogConditionalObjectiveFunction(objective, geObjective, null, convexComboCoeff);
        double[] initial = objective.initial();
        double[] weights = minimizer.minimize(semiSupObjective, this.TOL, initial);
        return new LinearClassifier<>(objective.to2D(weights), labeledDataset.featureIndex(), labeledDataset.labelIndex());
    }

    /** As the four-argument form, with automatically selected features and coefficient 0.5. */
    public LinearClassifier<L, F> trainSemiSupGE(GeneralDataset<L, F> labeledDataset, List<? extends Datum<L, F>> unlabeledDataList) {
        List<F> GEFeatures = this.getHighPrecisionFeatures(labeledDataset, 0.9, 10);
        return this.trainSemiSupGE(labeledDataset, unlabeledDataList, GEFeatures, 0.5);
    }

    /** As the four-argument form, with automatically selected features. */
    public LinearClassifier<L, F> trainSemiSupGE(GeneralDataset<L, F> labeledDataset, List<? extends Datum<L, F>> unlabeledDataList, double convexComboCoeff) {
        List<F> GEFeatures = this.getHighPrecisionFeatures(labeledDataset, 0.9, 10);
        return this.trainSemiSupGE(labeledDataset, unlabeledDataList, GEFeatures, convexComboCoeff);
    }

    /**
     * Selects features whose most frequent label accounts for at least
     * {@code minPrecision} of their occurrences, keeping at most
     * {@code maxNumFeatures} of them ranked by total occurrence count.
     */
    private List<F> getHighPrecisionFeatures(GeneralDataset<L, F> dataset, double minPrecision, int maxNumFeatures) {
        // feature2label[f][l] = number of times feature f occurs in a datum labeled l.
        int[][] feature2label = new int[dataset.numFeatures()][dataset.numClasses()];
        int[][] data = dataset.data;
        int[] labels = dataset.labels;
        for (int d = 0; d < data.length; ++d) {
            int label = labels[d];
            if (data[d] == null) {
                continue;
            }
            for (int featureId : data[d]) {
                feature2label[featureId][label]++;
            }
        }
        ClassicCounter<F> feature2freq = new ClassicCounter<>();
        for (int f = 0; f < dataset.numFeatures(); ++f) {
            int maxF = ArrayMath.max(feature2label[f]);
            int total = ArrayMath.sum(feature2label[f]);
            // A feature that never occurs gives 0/0 = NaN, which fails the >= test and is skipped.
            double precision = (double)maxF / (double)total;
            F feature = dataset.featureIndex.get(f);
            if (precision >= minPrecision) {
                feature2freq.incrementCount(feature, total);
            }
        }
        if (feature2freq.size() > maxNumFeatures) {
            Counters.retainTop(feature2freq, maxNumFeatures);
        }
        return Counters.toSortedList(feature2freq);
    }

    /**
     * Tunes sigma on an explicit validation set within {@code [min, max]}, then
     * trains on {@code train}.
     *
     * @param accuracy ignored
     */
    public LinearClassifier<L, F> trainClassifierV(GeneralDataset<L, F> train, GeneralDataset<L, F> validation, double min, double max, boolean accuracy) {
        this.labelIndex = train.labelIndex();
        this.featureIndex = train.featureIndex();
        this.min = min;
        this.max = max;
        this.heldOutSetSigma(train, validation);
        double[][] weights = this.trainWeights(train);
        return new LinearClassifier<>(weights, train.featureIndex(), train.labelIndex());
    }

    /**
     * Enables held-out sigma tuning, tunes sigma on a split of {@code train}
     * within {@code [min, max]}, then trains on {@code train}.
     *
     * <p>Because held-out tuning is left enabled, the final
     * {@link #trainWeights(GeneralDataset)} call performs its own tuning pass as well.
     *
     * @param accuracy ignored
     */
    public LinearClassifier<L, F> trainClassifierV(GeneralDataset<L, F> train, double min, double max, boolean accuracy) {
        this.labelIndex = train.labelIndex();
        this.featureIndex = train.featureIndex();
        this.tuneSigmaHeldOut = true;
        this.min = min;
        this.max = max;
        this.heldOutSetSigma(train);
        double[][] weights = this.trainWeights(train);
        return new LinearClassifier<>(weights, train.featureIndex(), train.labelIndex());
    }

    // ------------------------------------------------------------------
    // Sigma tuning
    // ------------------------------------------------------------------

    /** Enables held-out sigma tuning (and disables cross-validated tuning). */
    public void setTuneSigmaHeldOut() {
        this.tuneSigmaHeldOut = true;
        this.tuneSigmaCV = false;
    }

    /** Enables cross-validated sigma tuning with the given fold count (and disables held-out tuning). */
    public void setTuneSigmaCV(int folds) {
        this.tuneSigmaCV = true;
        this.tuneSigmaHeldOut = false;
        this.folds = folds;
    }

    /** No-op. */
    public void resetWeight() {
    }

    /** Cross-validates sigma with 5 folds. */
    public void crossValidateSetSigma(GeneralDataset<L, F> dataset) {
        this.crossValidateSetSigma(dataset, 5);
    }

    /** Cross-validates sigma using the default scorer and a golden-section search over {@code [min, max]}. */
    public void crossValidateSetSigma(GeneralDataset<L, F> dataset, int kfold) {
        logger.info("##you are here.");
        this.crossValidateSetSigma(dataset, kfold, new MultiClassAccuracyStats<L>(2), new GoldenSectionLineSearch(true, 0.01, this.min, this.max));
    }

    /** Cross-validates sigma using the given scorer and a golden-section search over {@code [min, max]}. */
    public void crossValidateSetSigma(GeneralDataset<L, F> dataset, int kfold, Scorer<L> scorer) {
        this.crossValidateSetSigma(dataset, kfold, scorer, new GoldenSectionLineSearch(true, 0.01, this.min, this.max));
    }

    /** Cross-validates sigma using the default scorer and the given line searcher. */
    public void crossValidateSetSigma(GeneralDataset<L, F> dataset, int kfold, LineSearcher minimizer) {
        this.crossValidateSetSigma(dataset, kfold, new MultiClassAccuracyStats<L>(2), minimizer);
    }

    /**
     * Chooses sigma by minimizing the negated average cross-validation score and
     * installs the best value on the prior.
     *
     * @param dataset data to split into folds
     * @param kfold number of folds
     * @param scorer scores a classifier on a fold's dev set; higher is better
     * @param minimizer line searcher that proposes sigma values
     */
    public void crossValidateSetSigma(GeneralDataset<L, F> dataset, int kfold, Scorer<L> scorer, LineSearcher minimizer) {
        logger.info("##in Cross Validate, folds = " + kfold);
        logger.info("##Scorer is " + scorer);
        this.featureIndex = dataset.featureIndex;
        this.labelIndex = dataset.labelIndex;
        CrossValidator<L, F> crossValidator = new CrossValidator<>(dataset, kfold);
        ToDoubleFunction<Triple<GeneralDataset<L, F>, GeneralDataset<L, F>, CrossValidator.SavedState>> scoreFn = fold -> {
            GeneralDataset<L, F> trainSet = fold.first();
            GeneralDataset<L, F> devSet = fold.second();
            // Each fold's saved state carries its weights from the previous sigma as a warm start.
            double[] weights = (double[]) fold.third().state;
            // Sigma tuning must be bypassed here, otherwise tuning would recurse.
            double[][] weights2D = this.trainWeights(trainSet, weights, true);
            fold.third().state = ArrayUtils.flatten(weights2D);
            LinearClassifier<L, F> classifier = new LinearClassifier<>(weights2D, trainSet.featureIndex, trainSet.labelIndex);
            double score = scorer.score(classifier, devSet);
            System.out.print(".");
            return score;
        };
        DoubleUnaryOperator negativeScorer = sigmaToTry -> {
            this.setSigma(sigmaToTry);
            double averageScore = crossValidator.computeAverage(scoreFn);
            logger.info("##sigma = " + this.getSigma() + " -> average Score: " + averageScore);
            return -averageScore;
        };
        double bestSigma = minimizer.minimize(negativeScorer);
        logger.info("##best sigma: " + bestSigma);
        this.setSigma(bestSigma);
    }

    /** Sets the line searcher used by {@link #heldOutSetSigma(GeneralDataset, GeneralDataset)}. */
    public void setHeldOutSearcher(LineSearcher heldOutSearcher) {
        this.heldOutSearcher = heldOutSearcher;
    }

    /** Splits {@code train} with {@code split(0.3)} and tunes sigma on the resulting pair. */
    public double[] heldOutSetSigma(GeneralDataset<L, F> train) {
        Pair<GeneralDataset<L, F>, GeneralDataset<L, F>> data = train.split(0.3);
        return this.heldOutSetSigma(data.first(), data.second());
    }

    /** Splits {@code train} with {@code split(0.3)} and tunes sigma on the resulting pair using {@code scorer}. */
    public double[] heldOutSetSigma(GeneralDataset<L, F> train, Scorer<L> scorer) {
        Pair<GeneralDataset<L, F>, GeneralDataset<L, F>> data = train.split(0.3);
        return this.heldOutSetSigma(data.first(), data.second(), scorer);
    }

    /** Tunes sigma with the default scorer and the configured held-out searcher (golden-section when unset). */
    public double[] heldOutSetSigma(GeneralDataset<L, F> train, GeneralDataset<L, F> dev) {
        LineSearcher searcher = this.heldOutSearcher == null
                ? new GoldenSectionLineSearch(true, 0.01, this.min, this.max)
                : this.heldOutSearcher;
        return this.heldOutSetSigma(train, dev, new MultiClassAccuracyStats<L>(2), searcher);
    }

    /** Tunes sigma with the given scorer and a golden-section search over {@code [min, max]}. */
    public double[] heldOutSetSigma(GeneralDataset<L, F> train, GeneralDataset<L, F> dev, Scorer<L> scorer) {
        return this.heldOutSetSigma(train, dev, scorer, new GoldenSectionLineSearch(true, 0.01, this.min, this.max));
    }

    /** Tunes sigma with the default scorer and the given line searcher. */
    public double[] heldOutSetSigma(GeneralDataset<L, F> train, GeneralDataset<L, F> dev, LineSearcher minimizer) {
        return this.heldOutSetSigma(train, dev, new MultiClassAccuracyStats<L>(2), minimizer);
    }

    /**
     * Chooses sigma by training on {@code trainSet} and scoring on {@code devSet},
     * installs the best value on the prior, and retrains on {@code trainSet}.
     *
     * @return flattened weights trained on {@code trainSet} with the best sigma
     */
    public double[] heldOutSetSigma(GeneralDataset<L, F> trainSet, GeneralDataset<L, F> devSet, Scorer<L> scorer, LineSearcher minimizer) {
        this.featureIndex = trainSet.featureIndex;
        this.labelIndex = trainSet.labelIndex;
        Timing timer = new Timing();
        NegativeScorer negativeScorer = new NegativeScorer(trainSet, devSet, scorer, timer);
        timer.start();
        double bestSigma = minimizer.minimize(negativeScorer);
        logger.info("##best sigma: " + bestSigma);
        this.setSigma(bestSigma);
        // Warm-start from the weights of the last sigma the searcher evaluated.
        return ArrayUtils.flatten(this.trainWeights(trainSet, negativeScorer.weights, true));
    }

    /** When true, weights found during held-out sigma tuning are not reused as the starting point. */
    public void setRetrainFromScratchAfterSigmaTuning(boolean retrainFromScratchAfterSigmaTuning) {
        this.retrainFromScratchAfterSigmaTuning = retrainFromScratchAfterSigmaTuning;
    }

    // ------------------------------------------------------------------
    // Classifier training
    // ------------------------------------------------------------------

    /**
     * Trains a classifier directly from an iterable of datums. The iterable is
     * traversed once here to build the indices and is then handed to the objective.
     */
    @Override
    public Classifier<L, F> trainClassifier(Iterable<Datum<L, F>> dataIterable) {
        Minimizer<DiffFunction> minimizer = this.getMinimizer();
        Index<F> featureIndex = Generics.newIndex();
        Index<L> labelIndex = Generics.newIndex();
        for (Datum<L, F> d : dataIterable) {
            labelIndex.add(d.label());
            featureIndex.addAll(d.asFeatures());
        }
        logger.info(String.format("Training linear classifier with %d features and %d labels", featureIndex.size(), labelIndex.size()));
        LogConditionalObjectiveFunction<L, F> objective = new LogConditionalObjectiveFunction<>(dataIterable, this.logPrior, featureIndex, labelIndex);
        double[] initial = objective.initial();
        double[] weights = minimizer.minimize(objective, this.TOL, initial);
        return new LinearClassifier<>(objective.to2D(weights), featureIndex, labelIndex);
    }

    /**
     * Trains a classifier with per-datum weights and an explicit prior. The
     * factory's own prior and sigma tuning are not used.
     */
    public Classifier<L, F> trainClassifier(GeneralDataset<L, F> dataset, float[] dataWeights, LogPrior prior) {
        Minimizer<DiffFunction> minimizer = this.getMinimizer();
        if (dataset instanceof RVFDataset) {
            ((RVFDataset<?, ?>) dataset).ensureRealValues();
        }
        LogConditionalObjectiveFunction<L, F> objective = new LogConditionalObjectiveFunction<>(dataset, dataWeights, prior);
        double[] initial = objective.initial();
        double[] weights = minimizer.minimize(objective, this.TOL, initial);
        return new LinearClassifier<>(objective.to2D(weights), dataset.featureIndex(), dataset.labelIndex());
    }

    /** Trains a classifier on {@code dataset} from the default starting point. */
    @Override
    public LinearClassifier<L, F> trainClassifier(GeneralDataset<L, F> dataset) {
        return this.trainClassifier(dataset, null);
    }

    /**
     * Trains a classifier on {@code dataset}.
     *
     * @param initial flattened starting weights, or {@code null}
     * @throws IllegalArgumentException if {@code initial} contains NaN or an infinite value
     */
    public LinearClassifier<L, F> trainClassifier(GeneralDataset<L, F> dataset, double[] initial) {
        if (dataset instanceof RVFDataset) {
            ((RVFDataset<?, ?>) dataset).ensureRealValues();
        }
        if (initial != null) {
            for (double weight : initial) {
                if (Double.isNaN(weight) || Double.isInfinite(weight)) {
                    throw new IllegalArgumentException("Initial weights are invalid!");
                }
            }
        }
        double[][] weights = this.trainWeights(dataset, initial, false);
        return new LinearClassifier<>(weights, dataset.featureIndex(), dataset.labelIndex());
    }

    /** Trains a classifier starting from the given 2D weights (flattened), or from the default when {@code null}. */
    public LinearClassifier<L, F> trainClassifierWithInitialWeights(GeneralDataset<L, F> dataset, double[][] initialWeights2D) {
        double[] initialWeights = initialWeights2D != null ? ArrayUtils.flatten(initialWeights2D) : null;
        return this.trainClassifier(dataset, initialWeights);
    }

    /** Trains a classifier starting from another classifier's weights, or from the default when {@code null}. */
    public LinearClassifier<L, F> trainClassifierWithInitialWeights(GeneralDataset<L, F> dataset, LinearClassifier<L, F> initialClassifier) {
        double[][] initialWeights2D = initialClassifier != null ? initialClassifier.weights() : null;
        return this.trainClassifierWithInitialWeights(dataset, initialWeights2D);
    }

    /**
     * Loads a classifier from a text source.
     *
     * <p>Expected layout: the label index, the feature index, then one
     * tab-separated {@code feature<TAB>label<TAB>weight} line per weight (with
     * integer feature and label ids) terminated by an empty line, then a
     * threshold count followed by one threshold per line. The thresholds are
     * parsed but are not part of the returned classifier.
     *
     * @param file location understood by {@code IOUtils.readerFromString}
     * @return the loaded classifier
     * @throws RuntimeIOException if reading or parsing fails for any reason
     */
    public static LinearClassifier<String, String> loadFromFilename(String file) {
        // try-with-resources so the reader is released on parse failures too.
        try (BufferedReader in = IOUtils.readerFromString(file)) {
            Index<String> labelIndex = HashIndex.loadFromReader(in);
            Index<String> featureIndex = HashIndex.loadFromReader(in);
            double[][] weights = new double[featureIndex.size()][labelIndex.size()];
            int currLine = 1;
            String line = in.readLine();
            while (line != null && line.length() > 0) {
                String[] tuples = line.split("\t");
                if (tuples.length != 3) {
                    throw new Exception("Error: incorrect number of tokens in weight specifier, line=" + currLine + " in file " + file);
                }
                ++currLine;
                int feature = Integer.parseInt(tuples[0]);
                int label = Integer.parseInt(tuples[1]);
                weights[feature][label] = Double.parseDouble(tuples[2]);
                line = in.readLine();
            }
            int numThresholds = Integer.parseInt(in.readLine());
            double[] thresholds = new double[numThresholds];
            int curr = 0;
            while ((line = in.readLine()) != null) {
                // Still parsed so a malformed or over-long trailer is rejected.
                thresholds[curr++] = Double.parseDouble(line.trim());
            }
            return new LinearClassifier<>(weights, featureIndex, labelIndex);
        }
        catch (Exception e) {
            throw new RuntimeIOException("Error in LinearClassifierFactory, loading from file=" + file, e);
        }
    }

    /** Sets the evaluators (and their iteration setting) installed on minimizers that implement {@link HasEvaluators}. */
    public void setEvaluators(int iters, Evaluator[] evaluators) {
        this.evalIters = iters;
        this.evaluators = evaluators;
    }

    /** Returns a creator that builds classifiers over {@code dataset}'s feature and label indices. */
    public LinearClassifierCreator<L, F> getClassifierCreator(GeneralDataset<L, F> dataset) {
        return new LinearClassifierCreator<>(dataset.featureIndex, dataset.labelIndex);
    }

    /**
     * Turns flat weight vectors into {@link LinearClassifier}s over fixed
     * feature and label indices.
     */
    public static class LinearClassifierCreator<L, F>
    implements ClassifierCreator,
    ProbabilisticClassifierCreator {
        /** Optional objective used to reshape weights; may be null. */
        LogConditionalObjectiveFunction objective;
        Index<F> featureIndex;
        Index<L> labelIndex;

        public LinearClassifierCreator(LogConditionalObjectiveFunction objective, Index<F> featureIndex, Index<L> labelIndex) {
            this.objective = objective;
            this.featureIndex = featureIndex;
            this.labelIndex = labelIndex;
        }

        public LinearClassifierCreator(Index<F> featureIndex, Index<L> labelIndex) {
            this.featureIndex = featureIndex;
            this.labelIndex = labelIndex;
        }

        /**
         * Builds a classifier from flat weights, reshaping them via the objective
         * when one was supplied and otherwise as features x labels.
         */
        public LinearClassifier<L, F> createLinearClassifier(double[] weights) {
            double[][] weights2D;
            if (this.objective != null) {
                weights2D = this.objective.to2D(weights);
            } else {
                weights2D = ArrayUtils.to2D(weights, this.featureIndex.size(), this.labelIndex.size());
            }
            return new LinearClassifier<>(weights2D, this.featureIndex, this.labelIndex);
        }

        public Classifier createClassifier(double[] weights) {
            return this.createLinearClassifier(weights);
        }

        public ProbabilisticClassifier createProbabilisticClassifier(double[] weights) {
            return this.createLinearClassifier(weights);
        }
    }

    /**
     * Objective for held-out sigma search: sets sigma, trains on the training
     * set, and returns the negated dev-set score. The weights from each
     * evaluation are kept in {@link #weights} and used as the starting point of
     * the next one.
     */
    class NegativeScorer
    implements DoubleUnaryOperator {
        public double[] weights;
        GeneralDataset<L, F> trainSet;
        GeneralDataset<L, F> devSet;
        Scorer<L> scorer;
        Timing timer;

        public NegativeScorer(GeneralDataset<L, F> trainSet, GeneralDataset<L, F> devSet, Scorer<L> scorer, Timing timer) {
            this.trainSet = trainSet;
            this.devSet = devSet;
            this.scorer = scorer;
            this.timer = timer;
        }

        @Override
        public double applyAsDouble(double sigmaToTry) {
            LinearClassifierFactory.this.setSigma(sigmaToTry);
            // Sigma tuning is bypassed so this evaluation does not trigger tuning again.
            double[][] weights2D = LinearClassifierFactory.this.trainWeights(this.trainSet, this.weights, true);
            this.weights = ArrayUtils.flatten(weights2D);
            LinearClassifier<L, F> classifier = new LinearClassifier<>(weights2D, this.trainSet.featureIndex, this.trainSet.labelIndex);
            double score = this.scorer.score(classifier, this.devSet);
            logger.info("##sigma = " + LinearClassifierFactory.this.getSigma() + " -> average Score: " + score);
            logger.info("##time elapsed: " + this.timer.stop() + " milliseconds.");
            this.timer.restart();
            return -score;
        }
    }

    /** Default minimizer factory: a quasi-Newton minimizer using the enclosing factory's current {@code mem} and {@code verbose}. */
    private class QNFactory
    implements Factory<Minimizer<DiffFunction>> {
        private static final long serialVersionUID = 9028306475652690036L;

        private QNFactory() {
        }

        @Override
        public Minimizer<DiffFunction> create() {
            QNMinimizer qnMinimizer = new QNMinimizer(LinearClassifierFactory.this.mem);
            if (!LinearClassifierFactory.this.verbose) {
                qnMinimizer.shutUp();
            }
            return qnMinimizer;
        }
    }
}

