/*
 * Decompiled with CFR 0.152.
 */
package edu.stanford.nlp.classify;

import edu.stanford.nlp.classify.ClassifierFactory;
import edu.stanford.nlp.classify.CrossValidator;
import edu.stanford.nlp.classify.GeneralDataset;
import edu.stanford.nlp.classify.LinearClassifier;
import edu.stanford.nlp.classify.LinearClassifierFactory;
import edu.stanford.nlp.classify.LogPrior;
import edu.stanford.nlp.classify.RVFDataset;
import edu.stanford.nlp.classify.SVMLightClassifier;
import edu.stanford.nlp.ling.Datum;
import edu.stanford.nlp.ling.RVFDatum;
import edu.stanford.nlp.optimization.GoldenSectionLineSearch;
import edu.stanford.nlp.optimization.LineSearcher;
import edu.stanford.nlp.stats.ClassicCounter;
import edu.stanford.nlp.stats.Counter;
import edu.stanford.nlp.stats.Counters;
import edu.stanford.nlp.stats.MultiClassAccuracyStats;
import edu.stanford.nlp.stats.Scorer;
import edu.stanford.nlp.util.Index;
import edu.stanford.nlp.util.Pair;
import edu.stanford.nlp.util.SystemUtils;
import edu.stanford.nlp.util.Triple;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.io.PrintWriter;
import java.text.NumberFormat;
import java.util.ArrayList;
import java.util.List;
import java.util.function.Function;
import java.util.regex.Pattern;

/**
 * Trains {@link SVMLightClassifier}s by running an external learner as a subprocess and reading
 * the resulting model file back into a linear weight vector. Problems with at most two classes
 * use the SVM-light learner; problems with more than two classes use the SVM-multiclass learner.
 *
 * <p>Optionally, the regularization constant {@code C} is tuned before the final model is
 * trained, either on a held-out split ({@link #setTuneHeldOut}) or by cross-validation
 * ({@link #setTuneCV}), and a Platt sigmoid can be fit on top of the SVM scores
 * ({@link #setUseSigmoid}).
 *
 * <p>Training mutates instance fields ({@code C}, {@code alphaFile}, {@code useSigmoid}, ...)
 * without any synchronization, so one factory should not train from several threads at once.
 *
 * @param <L> the label type
 * @param <F> the feature type
 */
public class SVMLightClassifierFactory<L, F>
        implements ClassifierFactory<L, F, SVMLightClassifier<L, F>> {
    private static final long serialVersionUID = 1L;

    /** Regularization constant passed to the learner as {@code -c}; values {@code <= 0} omit the flag. */
    protected double C = -1.0;
    /** Whether to fit a Platt sigmoid on the SVM scores after training. */
    private boolean useSigmoid = false;
    /** Whether to log commands and tuning progress. */
    protected boolean verbose = true;

    // Default (site-specific) locations of the external SVM binaries. Each value may also contain
    // extra whitespace-separated arguments; it is split on whitespace when the command is built.
    private String svmLightLearn = "/u/nlp/packages/svm_light/svm_learn";
    private String svmStructLearn = "/u/nlp/packages/svm_multiclass/svm_multiclass_learn";
    private String svmLightClassify = "/u/nlp/packages/svm_light/svm_classify";
    private String svmStructClassify = "/u/nlp/packages/svm_multiclass/svm_multiclass_classify";

    /**
     * Set only while tuning C. When true, each training run writes its alphas to a fresh temp file
     * ({@code -a}) and, if a previous alpha file exists, passes that file with {@code -y}
     * (presumably to warm-start the solver; see the svm_light documentation).
     */
    private boolean useAlphaFile = false;
    /** Alpha file written by the most recent training run made while {@code useAlphaFile} was set. */
    protected File alphaFile;
    /** Whether temp files (data, model, alphas, predictions) are scheduled for deletion on JVM exit. */
    private boolean deleteTempFilesOnExit = true;
    /** Verbosity level passed to the learner as {@code -v}. */
    private int svmLightVerbosity = 0;
    /** When true, also run the external classifier on the training data and dump our own scores to temp files. */
    private boolean doEval = false;
    /** Tune C on a held-out split before final training (checked before {@code tuneCV}). */
    private boolean tuneHeldOut = false;
    /** Tune C by cross-validation before final training. */
    private boolean tuneCV = false;
    /** Scorer used when tuning C. */
    private Scorer<L> scorer = new MultiClassAccuracyStats<>();
    /** One-dimensional minimizer used to search over C when tuning. */
    private LineSearcher tuneMinimizer = new GoldenSectionLineSearch(true);
    /** Number of folds used when {@code tuneCV} is set. */
    private int folds;
    /** Value passed to {@link GeneralDataset#split(double)} when {@code tuneHeldOut} is set (fraction vs. percent: TODO confirm). */
    private double heldOutPercent;
    Pattern whitespacePattern = Pattern.compile("\\s+");

    /**
     * Creates a factory that uses the given learner binaries.
     *
     * @param svmLightLearn  the binary-class learner command (e.g. a path to {@code svm_learn})
     * @param svmStructLearn the multiclass learner command (e.g. a path to {@code svm_multiclass_learn})
     */
    public SVMLightClassifierFactory(String svmLightLearn, String svmStructLearn) {
        this.svmLightLearn = svmLightLearn;
        this.svmStructLearn = svmStructLearn;
    }

    /** Creates a factory that uses the default learner locations. */
    public SVMLightClassifierFactory() {
    }

    /** Sets the regularization constant C; values {@code <= 0} let the learner use its own default. */
    public void setC(double C) {
        this.C = C;
    }

    /** Returns the current regularization constant C. */
    public double getC() {
        return this.C;
    }

    /** Sets whether a Platt sigmoid is fit on the SVM scores after training. */
    public void setUseSigmoid(boolean useSigmoid) {
        this.useSigmoid = useSigmoid;
    }

    /** Returns whether a Platt sigmoid is fit after training (the name is historical; it reports {@code useSigmoid}). */
    public boolean getUseSigma() {
        return this.useSigmoid;
    }

    /** Returns whether temp files are scheduled for deletion on JVM exit. */
    public boolean getDeleteTempFilesOnExitFlag() {
        return this.deleteTempFilesOnExit;
    }

    /** Sets whether temp files are scheduled for deletion on JVM exit. */
    public void setDeleteTempFilesOnExitFlag(boolean deleteTempFilesOnExit) {
        this.deleteTempFilesOnExit = deleteTempFilesOnExit;
    }

    /**
     * Reads a model file written by the external learner.
     *
     * <p>A fixed number of header lines is skipped. The first token of the next line is the
     * threshold. Every following non-blank line is a support vector: {@code alpha idx:val idx:val
     * ... # comment}. {@code qid} entries are ignored, and parsing of a line stops at {@code #}.
     *
     * @param modelFile  the model file written by the learner
     * @param multiclass whether the file came from the multiclass learner (longer header)
     * @return a pair of (threshold, weight vector keyed by the learner's feature index), where the
     *         weight vector is the alpha-weighted sum of the support vectors
     * @throws RuntimeException if the file cannot be read or parsed; the message names the line
     *         reached and the cause is preserved
     */
    private static Pair<Double, ClassicCounter<Integer>> readModel(File modelFile, boolean multiclass) {
        int modelLineCount = 0;
        // Number of header lines preceding the threshold line (presumably fixed by each model format).
        int numLinesToSkip = multiclass ? 13 : 10;
        String stopToken = "#";
        try (BufferedReader in = new BufferedReader(new FileReader(modelFile))) {
            for (int i = 0; i < numLinesToSkip; ++i) {
                in.readLine();
                ++modelLineCount;
            }
            String thresholdLine = in.readLine();
            ++modelLineCount;
            if (thresholdLine == null) {
                throw new IOException("model file ends before the threshold line");
            }
            double threshold = Double.parseDouble(thresholdLine.trim().split("\\s+")[0]);

            List<Pair<Double, ClassicCounter<Integer>>> supportVectors = new ArrayList<>();
            String svLine;
            // readLine() == null is the reliable end-of-file test; ready() only reports buffered input.
            while ((svLine = in.readLine()) != null) {
                ++modelLineCount;
                String trimmed = svLine.trim();
                if (trimmed.isEmpty()) {
                    continue;
                }
                String[] pieces = trimmed.split("\\s+");
                double alpha = Double.parseDouble(pieces[0]);
                ClassicCounter<Integer> supportVector = new ClassicCounter<>();
                for (int i = 1; i < pieces.length; ++i) {
                    String piece = pieces[i];
                    if (piece.equals(stopToken)) {
                        break; // the rest of the line is a comment
                    }
                    String[] indexNum = piece.split(":");
                    String featureIndex = indexNum[0];
                    if (featureIndex.equals("qid")) {
                        continue;
                    }
                    double count = Double.parseDouble(indexNum[1]);
                    supportVector.incrementCount(Integer.valueOf(featureIndex), count);
                }
                supportVectors.add(new Pair<>(alpha, supportVector));
            }
            return new Pair<>(threshold, getWeights(supportVectors));
        } catch (Exception e) {
            throw new RuntimeException("Error reading SVM model (line " + modelLineCount + " in file " + modelFile.getAbsolutePath() + ")", e);
        }
    }

    /**
     * Collapses support vectors into one weight vector: the sum over support vectors of
     * {@code alpha * vector}.
     */
    private static ClassicCounter<Integer> getWeights(List<Pair<Double, ClassicCounter<Integer>>> supportVectors) {
        ClassicCounter<Integer> weights = new ClassicCounter<>();
        for (Pair<Double, ClassicCounter<Integer>> sv : supportVectors) {
            // Copy so the scaling does not modify the stored support vector.
            ClassicCounter<Integer> c = new ClassicCounter<>(sv.second());
            Counters.multiplyInPlace(c, sv.first());
            Counters.addInPlace(weights, c);
        }
        return weights;
    }

    /** Maps learner feature indices back to (feature, label) weights, dispatching on the model type. */
    private ClassicCounter<Pair<F, L>> convertWeights(ClassicCounter<Integer> weights, Index<F> featureIndex, Index<L> labelIndex, boolean multiclass) {
        return multiclass ? this.convertSVMStructWeights(weights, featureIndex, labelIndex) : this.convertSVMLightWeights(weights, featureIndex, labelIndex);
    }

    /**
     * Converts binary SVM-light weights. Learner feature {@code i} (1-based) is
     * {@code featureIndex.get(i - 1)}. Its weight {@code w} is assigned to the first label, and
     * {@code -w} to the second.
     */
    private ClassicCounter<Pair<F, L>> convertSVMLightWeights(ClassicCounter<Integer> weights, Index<F> featureIndex, Index<L> labelIndex) {
        ClassicCounter<Pair<F, L>> newWeights = new ClassicCounter<>();
        for (int i : weights.keySet()) {
            F f = featureIndex.get(i - 1);
            double w = weights.getCount(i);
            newWeights.incrementCount(new Pair<>(f, labelIndex.get(0)), w);
            newWeights.incrementCount(new Pair<>(f, labelIndex.get(1)), -w);
        }
        return newWeights;
    }

    /**
     * Converts SVM-multiclass weights. The joint 1-based index {@code i} decodes to label
     * {@code (i - 1) / numFeatures} and feature {@code (i - 1) % numFeatures}.
     */
    private ClassicCounter<Pair<F, L>> convertSVMStructWeights(ClassicCounter<Integer> weights, Index<F> featureIndex, Index<L> labelIndex) {
        int numFeatures = featureIndex.size();
        ClassicCounter<Pair<F, L>> newWeights = new ClassicCounter<>();
        for (int i : weights.keySet()) {
            L l = labelIndex.get((i - 1) / numFeatures);
            F f = featureIndex.get((i - 1) % numFeatures);
            double w = weights.getCount(i);
            newWeights.incrementCount(new Pair<>(f, l), w);
        }
        return newWeights;
    }

    /**
     * Fits a Platt sigmoid: a linear classifier whose features are the SVM's per-label scores on
     * {@code dataset}, plus a constant feature under the {@code null} key (presumably a bias),
     * trained with no prior.
     */
    private LinearClassifier<L, L> fitSigmoid(SVMLightClassifier<L, F> classifier, GeneralDataset<L, F> dataset) {
        RVFDataset<L, L> plattDataset = new RVFDataset<>();
        for (int i = 0; i < dataset.size(); ++i) {
            RVFDatum<L, F> d = dataset.getRVFDatum(i);
            // Explicit upcast: selects the Datum overload of scoresOf.
            Counter<L> scores = classifier.scoresOf((Datum<L, F>) d);
            scores.incrementCount(null);
            plattDataset.add(new RVFDatum<>(scores, d.label()));
        }
        LinearClassifierFactory<L, L> factory = new LinearClassifierFactory<>();
        factory.setPrior(new LogPrior(LogPrior.LogPriorType.NULL));
        return factory.trainClassifier(plattDataset);
    }

    /**
     * Chooses C by minimizing the negated average cross-validation score.
     *
     * <p>During the search, alpha files are threaded through each fold's saved state so that
     * successive trainings on the same fold reuse them, and the sigmoid is disabled. Both settings
     * are restored afterwards, even if training throws. On return, {@code C} holds the chosen
     * value.
     */
    public void crossValidateSetC(GeneralDataset<L, F> dataset, int numFolds, Scorer<L> scorer, LineSearcher minimizer) {
        System.out.println("in Cross Validate");
        this.useAlphaFile = true;
        boolean oldUseSigmoid = this.useSigmoid;
        this.useSigmoid = false;
        try {
            CrossValidator<L, F> crossValidator = new CrossValidator<>(dataset, numFolds);
            Function<Triple<GeneralDataset<L, F>, GeneralDataset<L, F>, CrossValidator.SavedState>, Double> score = fold -> {
                GeneralDataset<L, F> trainSet = fold.first();
                GeneralDataset<L, F> devSet = fold.second();
                // Restore this fold's alpha file, train, then save the new one back to the fold.
                this.alphaFile = (File) fold.third().state;
                SVMLightClassifier<L, F> classifier = this.trainClassifierBasic(trainSet);
                fold.third().state = this.alphaFile;
                return scorer.score(classifier, devSet);
            };
            Function<Double, Double> negativeScorer = cToTry -> {
                this.C = cToTry;
                if (this.verbose) {
                    System.out.print("C = " + cToTry + " ");
                }
                double averageScore = crossValidator.computeAverage(score);
                if (this.verbose) {
                    System.out.println(" -> average Score: " + averageScore);
                }
                return -averageScore;
            };
            this.C = minimizer.minimize(negativeScorer);
        } finally {
            this.useAlphaFile = false;
            this.useSigmoid = oldUseSigmoid;
        }
    }

    /**
     * Splits {@code train} with {@link GeneralDataset#split(double)} and tunes C on the two parts.
     *
     * @see #heldOutSetC(GeneralDataset, GeneralDataset, Scorer, LineSearcher)
     */
    public void heldOutSetC(GeneralDataset<L, F> train, double percentHeldOut, Scorer<L> scorer, LineSearcher minimizer) {
        Pair<GeneralDataset<L, F>, GeneralDataset<L, F>> data = train.split(percentHeldOut);
        this.heldOutSetC(data.first(), data.second(), scorer, minimizer);
    }

    /**
     * Chooses C by training on {@code trainSet} and minimizing the negated score on
     * {@code devSet}. Successive trainings reuse each other's alpha files, and the sigmoid is
     * disabled during the search. Both settings are restored afterwards, even if training throws.
     * On return, {@code C} holds the chosen value.
     */
    public void heldOutSetC(GeneralDataset<L, F> trainSet, GeneralDataset<L, F> devSet, Scorer<L> scorer, LineSearcher minimizer) {
        // Any leftover alpha file belongs to an earlier, unrelated training run; do not pass it to -y.
        this.alphaFile = null;
        this.useAlphaFile = true;
        boolean oldUseSigmoid = this.useSigmoid;
        this.useSigmoid = false;
        try {
            Function<Double, Double> negativeScorer = cToTry -> {
                this.C = cToTry;
                SVMLightClassifier<L, F> classifier = this.trainClassifierBasic(trainSet);
                double score = scorer.score(classifier, devSet);
                return -score;
            };
            this.C = minimizer.minimize(negativeScorer);
        } finally {
            this.useAlphaFile = false;
            this.useSigmoid = oldUseSigmoid;
        }
    }

    /**
     * Not implemented; always returns {@code null}.
     *
     * @deprecated use {@link #trainClassifier(GeneralDataset)} instead
     */
    @Override
    @Deprecated
    public SVMLightClassifier<L, F> trainClassifier(List<RVFDatum<L, F>> examples) {
        return null;
    }

    /** Returns the held-out split value used when {@code tuneHeldOut} is set. */
    public double getHeldOutPercent() {
        return this.heldOutPercent;
    }

    /** Sets the held-out split value used when {@code tuneHeldOut} is set. */
    public void setHeldOutPercent(double heldOutPercent) {
        this.heldOutPercent = heldOutPercent;
    }

    /** Returns the number of cross-validation folds used when {@code tuneCV} is set. */
    public int getFolds() {
        return this.folds;
    }

    /** Sets the number of cross-validation folds used when {@code tuneCV} is set. */
    public void setFolds(int folds) {
        this.folds = folds;
    }

    /** Returns the line searcher used to tune C. */
    public LineSearcher getTuneMinimizer() {
        return this.tuneMinimizer;
    }

    /** Sets the line searcher used to tune C. */
    public void setTuneMinimizer(LineSearcher minimizer) {
        this.tuneMinimizer = minimizer;
    }

    /** Returns the scorer used to tune C. */
    public Scorer<L> getScorer() {
        return this.scorer;
    }

    /** Sets the scorer used to tune C. */
    public void setScorer(Scorer<L> scorer) {
        this.scorer = scorer;
    }

    /** Returns whether C is tuned by cross-validation. */
    public boolean getTuneCV() {
        return this.tuneCV;
    }

    /** Sets whether C is tuned by cross-validation (ignored when held-out tuning is enabled). */
    public void setTuneCV(boolean tuneCV) {
        this.tuneCV = tuneCV;
    }

    /** Returns whether C is tuned on a held-out split. */
    public boolean getTuneHeldOut() {
        return this.tuneHeldOut;
    }

    /** Sets whether C is tuned on a held-out split (takes precedence over cross-validation). */
    public void setTuneHeldOut(boolean tuneHeldOut) {
        this.tuneHeldOut = tuneHeldOut;
    }

    /** Returns the verbosity level passed to the learner as {@code -v}. */
    public int getSvmLightVerbosity() {
        return this.svmLightVerbosity;
    }

    /** Sets the verbosity level passed to the learner as {@code -v}. */
    public void setSvmLightVerbosity(int svmLightVerbosity) {
        this.svmLightVerbosity = svmLightVerbosity;
    }

    /**
     * Optionally tunes C (held-out first, otherwise cross-validation, according to the flags),
     * then trains the final classifier on the whole {@code dataset}.
     */
    @Override
    public SVMLightClassifier<L, F> trainClassifier(GeneralDataset<L, F> dataset) {
        if (this.tuneHeldOut) {
            this.heldOutSetC(dataset, this.heldOutPercent, this.scorer, this.tuneMinimizer);
        } else if (this.tuneCV) {
            this.crossValidateSetC(dataset, this.folds, this.scorer, this.tuneMinimizer);
        }
        return this.trainClassifierBasic(dataset);
    }

    /**
     * Trains one classifier with the current settings and no C tuning.
     *
     * <p>The dataset is written in SVM-light format to a temp file. The learner is run on it, and
     * the resulting model is converted to (feature, label) weights. For binary problems the
     * threshold becomes {@code -threshold} for the first label and {@code +threshold} for the
     * second. If enabled, a Platt sigmoid is then fit.
     *
     * @throws RuntimeException wrapping any failure (I/O, subprocess, model parsing)
     */
    public SVMLightClassifier<L, F> trainClassifierBasic(GeneralDataset<L, F> dataset) {
        Index<L> labelIndex = dataset.labelIndex();
        Index<F> featureIndex = dataset.featureIndex;
        boolean multiclass = dataset.numClasses() > 2;
        try {
            File modelFile = this.newTempFile(".model");
            File dataFile = this.newTempFile(".data");
            try (PrintWriter pw = new PrintWriter(new FileWriter(dataFile))) {
                dataset.printSVMLightFormat(pw);
            }

            // Build argv as a list so that file paths containing spaces stay single arguments.
            List<String> learnCmd = this.splitExecutable(multiclass ? this.svmStructLearn : this.svmLightLearn);
            learnCmd.add("-v");
            learnCmd.add(String.valueOf(this.svmLightVerbosity));
            learnCmd.add("-m");
            learnCmd.add("400");
            if (this.C > 0.0) {
                learnCmd.add("-c");
                learnCmd.add(String.valueOf(this.C));
            }
            if (this.useAlphaFile) {
                File newAlphaFile = this.newTempFile(".alphas");
                learnCmd.add("-a");
                learnCmd.add(newAlphaFile.getAbsolutePath());
                if (this.alphaFile != null) {
                    learnCmd.add("-y");
                    learnCmd.add(this.alphaFile.getAbsolutePath());
                }
                this.alphaFile = newAlphaFile;
            }
            learnCmd.add(dataFile.getAbsolutePath());
            learnCmd.add(modelFile.getAbsolutePath());
            if (this.verbose) {
                System.err.println("<< " + String.join(" ", learnCmd) + " >>");
            }
            SystemUtils.run(new ProcessBuilder(learnCmd), new PrintWriter(System.err), new PrintWriter(System.err));

            if (this.doEval) {
                File predictFile = this.newTempFile(".pred");
                List<String> evalCmd = this.splitExecutable(multiclass ? this.svmStructClassify : this.svmLightClassify);
                evalCmd.add(dataFile.getAbsolutePath());
                evalCmd.add(modelFile.getAbsolutePath());
                evalCmd.add(predictFile.getAbsolutePath());
                if (this.verbose) {
                    System.err.println("<< " + String.join(" ", evalCmd) + " >>");
                }
                SystemUtils.run(new ProcessBuilder(evalCmd), new PrintWriter(System.err), new PrintWriter(System.err));
            }

            Pair<Double, ClassicCounter<Integer>> weightsAndThresh = readModel(modelFile, multiclass);
            double threshold = weightsAndThresh.first();
            ClassicCounter<Pair<F, L>> weights = this.convertWeights(weightsAndThresh.second(), featureIndex, labelIndex, multiclass);
            ClassicCounter<L> thresholds = new ClassicCounter<>();
            if (!multiclass) {
                thresholds.setCount(labelIndex.get(0), -threshold);
                thresholds.setCount(labelIndex.get(1), threshold);
            }
            SVMLightClassifier<L, F> classifier = new SVMLightClassifier<>(weights, thresholds);

            if (this.doEval) {
                // Dump our own scores so they can be compared with the external classifier's output.
                File predictFile = this.newTempFile(".pred2");
                NumberFormat nf = NumberFormat.getNumberInstance();
                nf.setMaximumFractionDigits(5);
                try (PrintWriter pw2 = new PrintWriter(predictFile)) {
                    for (Datum<L, F> datum : dataset) {
                        Counter<L> scores = classifier.scoresOf(datum);
                        pw2.println(Counters.toString(scores, nf));
                    }
                }
            }

            if (this.useSigmoid) {
                if (this.verbose) {
                    System.out.print("fitting sigmoid...");
                }
                classifier.setPlatt(this.fitSigmoid(classifier, dataset));
                if (this.verbose) {
                    System.out.println("done");
                }
            }
            return classifier;
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    /** Creates a temp file named {@code svm-*<suffix>}, scheduling it for deletion on exit if configured. */
    private File newTempFile(String suffix) throws IOException {
        File file = File.createTempFile("svm-", suffix);
        if (this.deleteTempFilesOnExit) {
            file.deleteOnExit();
        }
        return file;
    }

    /**
     * Splits a configured executable string on whitespace into a mutable argv prefix. Empty tokens
     * are dropped. Splitting keeps compatibility with settings that embed extra arguments.
     */
    private List<String> splitExecutable(String executable) {
        List<String> tokens = new ArrayList<>();
        for (String token : this.whitespacePattern.split(executable)) {
            if (!token.isEmpty()) {
                tokens.add(token);
            }
        }
        return tokens;
    }
}

