/*
 * Decompiled with CFR 0.152.
 */
package edu.stanford.nlp.optimization;

import edu.stanford.nlp.math.ArrayMath;
import edu.stanford.nlp.optimization.AbstractCachingDiffFunction;
import edu.stanford.nlp.optimization.AbstractStochasticCachingDiffFunction;
import edu.stanford.nlp.optimization.AbstractStochasticCachingDiffUpdateFunction;
import edu.stanford.nlp.optimization.Evaluator;
import edu.stanford.nlp.optimization.Function;
import edu.stanford.nlp.optimization.HasEvaluators;
import edu.stanford.nlp.optimization.HasFeatureGrouping;
import edu.stanford.nlp.optimization.HasRegularizerParamRange;
import edu.stanford.nlp.optimization.Minimizer;
import edu.stanford.nlp.util.Timing;
import java.text.DecimalFormat;
import java.text.NumberFormat;
import java.util.Arrays;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Random;
import java.util.Set;

public class SGDWithAdaGradAndFOBOS<T extends Function>
implements Minimizer<T>,
HasEvaluators {
    protected double[] x;
    protected double initRate;
    protected double lambda;
    protected double alpha = 1.0;
    protected boolean quiet = false;
    private static final int DEFAULT_NUM_PASSES = 50;
    protected final int numPasses;
    protected int bSize = 1;
    private static final int DEFAULT_TUNING_SAMPLES = Integer.MAX_VALUE;
    private static final int DEFAULT_BATCH_SIZE = 1000;
    private final double eps = 0.001;
    protected Random gen = new Random(1L);
    protected long maxTime = Long.MAX_VALUE;
    private int evaluateIters = 0;
    private Evaluator[] evaluators;
    private Prior prior = Prior.LASSO;
    private boolean useEvalImprovement = false;
    private boolean suppressTestPrompt = false;
    private int terminateOnEvalImprovementNumOfEpoch = 1;
    private double bestEvalSoFar = Double.NEGATIVE_INFINITY;
    private double[] xBest;
    private int noImproveItrCount = 0;
    private static final NumberFormat nf = new DecimalFormat("0.000E0");

    public void terminateOnEvalImprovement(boolean toTerminate) {
        this.useEvalImprovement = toTerminate;
    }

    public void suppressTestPrompt(boolean suppressTestPrompt) {
        this.suppressTestPrompt = suppressTestPrompt;
    }

    public void setTerminateOnEvalImprovementNumOfEpoch(int terminateOnEvalImprovementNumOfEpoch) {
        this.terminateOnEvalImprovementNumOfEpoch = terminateOnEvalImprovementNumOfEpoch;
    }

    public boolean toContinue(double[] x, double currEval) {
        if (currEval >= this.bestEvalSoFar) {
            this.bestEvalSoFar = currEval;
            this.noImproveItrCount = 0;
            if (this.xBest == null) {
                this.xBest = Arrays.copyOf(x, x.length);
            } else {
                System.arraycopy(x, 0, this.xBest, 0, x.length);
            }
            return true;
        }
        ++this.noImproveItrCount;
        return this.noImproveItrCount <= this.terminateOnEvalImprovementNumOfEpoch;
    }

    private static Prior getPrior(String priorType) {
        if (priorType.equals("none")) {
            return Prior.NONE;
        }
        if (priorType.equals("lasso")) {
            return Prior.LASSO;
        }
        if (priorType.equals("ridge")) {
            return Prior.RIDGE;
        }
        if (priorType.equals("gaussian")) {
            return Prior.GAUSSIAN;
        }
        if (priorType.equals("ae-lasso")) {
            return Prior.aeLASSO;
        }
        if (priorType.equals("g-lasso")) {
            return Prior.gLASSO;
        }
        if (priorType.equals("sg-lasso")) {
            return Prior.sgLASSO;
        }
        throw new IllegalArgumentException("prior type " + priorType + " not recognized; supported priors " + "are: lasso, ridge, gaussian, ae-lasso, g-lasso, and sg-lasso");
    }

    public SGDWithAdaGradAndFOBOS(double initRate, double lambda, int numPasses) {
        this(initRate, lambda, numPasses, -1);
    }

    public SGDWithAdaGradAndFOBOS(double initRate, double lambda, int numPasses, int batchSize) {
        this(initRate, lambda, numPasses, batchSize, "lasso", 1.0);
    }

    public SGDWithAdaGradAndFOBOS(double initRate, double lambda, int numPasses, int batchSize, String priorType, double alpha) {
        this.initRate = initRate;
        this.prior = SGDWithAdaGradAndFOBOS.getPrior(priorType);
        this.bSize = batchSize;
        this.lambda = lambda;
        this.alpha = alpha;
        if (numPasses >= 0) {
            this.numPasses = numPasses;
        } else {
            this.numPasses = 50;
            this.sayln("  SGDWithAdaGradAndFOBOS: numPasses=" + numPasses + ", defaulting to " + this.numPasses);
        }
    }

    public void shutUp() {
        this.quiet = true;
    }

    protected String getName() {
        return "SGDWithAdaGradAndFOBOS" + this.bSize + "_lambda" + nf.format(this.lambda) + "_alpha" + nf.format(this.alpha);
    }

    @Override
    public void setEvaluators(int iters, Evaluator[] evaluators) {
        this.evaluateIters = iters;
        this.evaluators = evaluators;
    }

    private static double getNorm(double[] w) {
        double norm = 0.0;
        for (int i = 0; i < w.length; ++i) {
            norm += w[i] * w[i];
        }
        return Math.sqrt(norm);
    }

    private double doEvaluation(double[] x) {
        if (this.evaluators == null) {
            return Double.NEGATIVE_INFINITY;
        }
        double score = Double.NEGATIVE_INFINITY;
        for (Evaluator eval : this.evaluators) {
            double aScore;
            if (!this.suppressTestPrompt) {
                this.sayln("  Evaluating: " + eval.toString());
            }
            if ((aScore = eval.evaluate(x)) == Double.NEGATIVE_INFINITY) continue;
            score = aScore;
        }
        return score;
    }

    private static double pospart(double number) {
        return number > 0.0 ? number : 0.0;
    }

    @Override
    public double[] minimize(Function function, double functionTolerance, double[] initial) {
        return this.minimize(function, functionTolerance, initial, -1);
    }

    @Override
    public double[] minimize(Function f, double functionTolerance, double[] initial, int maxIterations) {
        boolean have_max;
        int totalSamples = 0;
        this.sayln("Using lambda=" + this.lambda);
        if (f instanceof AbstractStochasticCachingDiffUpdateFunction) {
            AbstractStochasticCachingDiffUpdateFunction func = (AbstractStochasticCachingDiffUpdateFunction)f;
            func.sampleMethod = AbstractStochasticCachingDiffFunction.SamplingMethod.Shuffled;
            totalSamples = func.dataDimension();
            if (this.bSize > totalSamples) {
                System.err.println("WARNING: Total number of samples=" + totalSamples + " is smaller than requested batch size=" + this.bSize + "!!!");
                this.bSize = totalSamples;
                this.sayln("Using batch size=" + this.bSize);
            }
            if (this.bSize <= 0) {
                System.err.println("WARNING: Requested batch size=" + this.bSize + " <= 0 !!!");
                this.bSize = totalSamples;
                this.sayln("Using batch size=" + this.bSize);
            }
        }
        this.x = new double[initial.length];
        double[] testUpdateCache = null;
        double[] currentRateCache = null;
        double[] bCache = null;
        double[] sumGradSquare = new double[initial.length];
        int[][] featureGrouping = null;
        if (this.prior != Prior.LASSO && this.prior != Prior.NONE) {
            testUpdateCache = new double[initial.length];
            currentRateCache = new double[initial.length];
        }
        if (this.prior != Prior.LASSO && this.prior != Prior.RIDGE && this.prior != Prior.GAUSSIAN) {
            if (!(f instanceof HasFeatureGrouping)) {
                throw new UnsupportedOperationException("prior is specified to be ae-lasso or g-lasso, but function does not support feature grouping");
            }
            featureGrouping = ((HasFeatureGrouping)((Object)f)).getFeatureGrouping();
        }
        if (this.prior == Prior.sgLASSO) {
            bCache = new double[initial.length];
        }
        System.arraycopy(initial, 0, this.x, 0, this.x.length);
        int numBatches = 1;
        if (f instanceof AbstractStochasticCachingDiffUpdateFunction) {
            numBatches = totalSamples / this.bSize;
        }
        boolean bl = have_max = maxIterations > 0 || this.numPasses > 0;
        if (!have_max) {
            throw new UnsupportedOperationException("No maximum number of iterations has been specified.");
        }
        maxIterations = Math.max(maxIterations, this.numPasses) * numBatches;
        this.sayln("       Batch size of: " + this.bSize);
        this.sayln("       Data dimension of: " + totalSamples);
        this.sayln("       Batches per pass through data:  " + numBatches);
        this.sayln("       Number of passes is = " + this.numPasses);
        this.sayln("       Max iterations is = " + maxIterations);
        Timing total = new Timing();
        Timing current = new Timing();
        total.start();
        current.start();
        int iters = 0;
        double gValue = 0.0;
        double wValue = 0.0;
        double sgsValue = 0.0;
        double currentRate = 0.0;
        double testUpdate = 0.0;
        double realUpdate = 0.0;
        for (int pass = 0; pass < this.numPasses; ++pass) {
            boolean doEval = pass > 0 && this.evaluateIters > 0 && pass % this.evaluateIters == 0;
            double evalScore = Double.NEGATIVE_INFINITY;
            if (doEval) {
                evalScore = this.doEvaluation(this.x);
                if (this.useEvalImprovement && !this.toContinue(this.x, evalScore)) break;
            }
            double objVal = 0.0;
            this.say("Iter: " + iters + " pass " + pass + " batch 1 ... ");
            int numOfNonZero = 0;
            int numOfNonZeroGroup = 0;
            String gSizeStr = "";
            for (int batch = 0; batch < numBatches; ++batch) {
                ++iters;
                double[] gradients = null;
                if (f instanceof AbstractStochasticCachingDiffUpdateFunction) {
                    AbstractStochasticCachingDiffUpdateFunction func = (AbstractStochasticCachingDiffUpdateFunction)f;
                    if (this.bSize == totalSamples) {
                        objVal = func.valueAt(this.x);
                        gradients = func.getDerivative();
                    } else {
                        func.calculateStochasticGradient(this.x, this.bSize);
                        gradients = func.getDerivative();
                    }
                } else if (f instanceof AbstractCachingDiffFunction) {
                    AbstractCachingDiffFunction func = (AbstractCachingDiffFunction)f;
                    gradients = func.derivativeAt(this.x);
                }
                if (this.prior == Prior.NONE) {
                    Set<Object> paramRange = null;
                    if (f instanceof HasRegularizerParamRange) {
                        paramRange = ((HasRegularizerParamRange)((Object)f)).getRegularizerParamRange(this.x);
                    } else {
                        paramRange = new HashSet();
                        for (int i = 0; i < this.x.length; ++i) {
                            paramRange.add(i);
                        }
                    }
                    Iterator<Object> i$ = paramRange.iterator();
                    while (i$.hasNext()) {
                        int index = (Integer)i$.next();
                        gValue = gradients[index];
                        sgsValue = gValue * gValue;
                        int n = index;
                        sumGradSquare[n] = sumGradSquare[n] + sgsValue;
                        wValue = this.x[index];
                        currentRate = this.initRate / (Math.sqrt(sumGradSquare[index]) + 0.001);
                        this.x[index] = testUpdate = wValue - currentRate * gValue;
                    }
                    continue;
                }
                if (this.prior == Prior.LASSO || this.prior == Prior.RIDGE || this.prior == Prior.GAUSSIAN) {
                    double testUpdateSquaredSum = 0.0;
                    Set<Object> paramRange = null;
                    if (f instanceof HasRegularizerParamRange) {
                        paramRange = ((HasRegularizerParamRange)((Object)f)).getRegularizerParamRange(this.x);
                    } else {
                        paramRange = new HashSet();
                        for (int i = 0; i < this.x.length; ++i) {
                            paramRange.add(i);
                        }
                    }
                    Iterator<Object> i$ = paramRange.iterator();
                    while (i$.hasNext()) {
                        int index = (Integer)i$.next();
                        gValue = gradients[index];
                        sgsValue = gValue * gValue;
                        int n = index;
                        sumGradSquare[n] = sumGradSquare[n] + sgsValue;
                        wValue = this.x[index];
                        currentRate = this.initRate / (Math.sqrt(sumGradSquare[index]) + 0.001);
                        testUpdate = wValue - currentRate * gValue;
                        double currentLambda = currentRate * this.lambda;
                        if (this.prior == Prior.LASSO) {
                            this.x[index] = realUpdate = Math.signum(testUpdate) * SGDWithAdaGradAndFOBOS.pospart(Math.abs(testUpdate) - currentLambda);
                            if (realUpdate == 0.0) continue;
                            ++numOfNonZero;
                            continue;
                        }
                        if (this.prior == Prior.RIDGE) {
                            testUpdateSquaredSum += testUpdate * testUpdate;
                            testUpdateCache[index] = testUpdate;
                            currentRateCache[index] = currentRate;
                            continue;
                        }
                        if (this.prior != Prior.GAUSSIAN) continue;
                        this.x[index] = realUpdate = testUpdate / (1.0 + currentLambda);
                        objVal += currentLambda * wValue * wValue;
                    }
                    if (this.prior != Prior.RIDGE) continue;
                    double testUpdateNorm = Math.sqrt(testUpdateSquaredSum);
                    for (int index = 0; index < testUpdateCache.length; ++index) {
                        this.x[index] = realUpdate = testUpdateCache[index] * SGDWithAdaGradAndFOBOS.pospart(1.0 - currentRateCache[index] * this.lambda / testUpdateNorm);
                        if (realUpdate == 0.0) continue;
                        ++numOfNonZero;
                    }
                    continue;
                }
                for (int gIndex = 0; gIndex < featureGrouping.length; ++gIndex) {
                    int[] gFeatureIndices = featureGrouping[gIndex];
                    double testUpdateSquaredSum = 0.0;
                    double testUpdateAbsSum = 0.0;
                    double M = gFeatureIndices.length;
                    double dm = Math.log(M);
                    for (int index : gFeatureIndices) {
                        gValue = gradients[index];
                        sgsValue = gValue * gValue;
                        int n = index;
                        sumGradSquare[n] = sumGradSquare[n] + sgsValue;
                        wValue = this.x[index];
                        currentRate = this.initRate / (Math.sqrt(sumGradSquare[index]) + 0.001);
                        testUpdate = wValue - currentRate * gValue;
                        testUpdateSquaredSum += testUpdate * testUpdate;
                        testUpdateAbsSum += Math.abs(testUpdate);
                        testUpdateCache[index] = testUpdate;
                        currentRateCache[index] = currentRate;
                    }
                    if (this.prior == Prior.gLASSO) {
                        double testUpdateNorm = Math.sqrt(testUpdateSquaredSum);
                        boolean groupHasNonZero = false;
                        for (int index : gFeatureIndices) {
                            this.x[index] = realUpdate = testUpdateCache[index] * SGDWithAdaGradAndFOBOS.pospart(1.0 - currentRateCache[index] * this.lambda * dm / testUpdateNorm);
                            if (realUpdate == 0.0) continue;
                            ++numOfNonZero;
                            groupHasNonZero = true;
                        }
                        if (!groupHasNonZero) continue;
                        ++numOfNonZeroGroup;
                        continue;
                    }
                    if (this.prior == Prior.aeLASSO) {
                        int nonZeroCount = 0;
                        boolean groupHasNonZero = false;
                        for (int index : gFeatureIndices) {
                            double tau = currentRateCache[index] * this.lambda / (1.0 + currentRateCache[index] * this.lambda * M) * testUpdateAbsSum;
                            this.x[index] = realUpdate = Math.signum(testUpdateCache[index]) * SGDWithAdaGradAndFOBOS.pospart(Math.abs(testUpdateCache[index]) - tau);
                            if (realUpdate == 0.0) continue;
                            ++numOfNonZero;
                            ++nonZeroCount;
                            groupHasNonZero = true;
                        }
                        if (!groupHasNonZero) continue;
                        ++numOfNonZeroGroup;
                        continue;
                    }
                    if (this.prior != Prior.sgLASSO) continue;
                    double bSquaredSum = 0.0;
                    double b = 0.0;
                    for (int index : gFeatureIndices) {
                        bCache[index] = b = Math.signum(testUpdateCache[index]) * SGDWithAdaGradAndFOBOS.pospart(Math.abs(testUpdateCache[index]) - currentRateCache[index] * this.alpha * this.lambda);
                        bSquaredSum += b * b;
                    }
                    double bNorm = Math.sqrt(bSquaredSum);
                    int nonZeroCount = 0;
                    boolean groupHasNonZero = false;
                    for (int index : gFeatureIndices) {
                        this.x[index] = realUpdate = bCache[index] * SGDWithAdaGradAndFOBOS.pospart(1.0 - currentRateCache[index] * (1.0 - this.alpha) * this.lambda * dm / bNorm);
                        if (realUpdate == 0.0) continue;
                        ++numOfNonZero;
                        ++nonZeroCount;
                        groupHasNonZero = true;
                    }
                    if (!groupHasNonZero) continue;
                    ++numOfNonZeroGroup;
                }
            }
            try {
                ArrayMath.assertFinite(this.x, "x");
            }
            catch (ArrayMath.InvalidElementException e) {
                System.err.println(e.toString());
                for (int i = 0; i < this.x.length; ++i) {
                    this.x[i] = Double.NaN;
                }
                break;
            }
            this.sayln(String.valueOf(numBatches) + ", n0-fCount:" + numOfNonZero + (this.prior != Prior.LASSO && this.prior != Prior.RIDGE ? ", n0-gCount:" + numOfNonZeroGroup : "") + (evalScore != Double.NEGATIVE_INFINITY ? ", evalScore:" + evalScore : "") + ", obj_val:" + nf.format(objVal));
            if (iters >= maxIterations) {
                this.sayln("Online Optimization complete.  Stopped after max iterations");
                break;
            }
            if (total.report() < this.maxTime) continue;
            this.sayln("Online Optimization complete.  Stopped after max time");
            break;
        }
        if (this.evaluateIters > 0) {
            double evalScore = this.useEvalImprovement ? this.doEvaluation(this.xBest) : this.doEvaluation(this.x);
            this.sayln("final evalScore is: " + evalScore);
        }
        this.sayln("Completed in: " + Timing.toSecondsString(total.report()) + " s");
        return this.useEvalImprovement ? this.xBest : this.x;
    }

    protected void sayln(String s) {
        if (!this.quiet) {
            System.err.println(s);
        }
    }

    protected void say(String s) {
        if (!this.quiet) {
            System.err.print(s);
        }
    }

    public static enum Prior {
        LASSO,
        RIDGE,
        GAUSSIAN,
        aeLASSO,
        gLASSO,
        sgLASSO,
        NONE;

    }
}

