/*
 * Decompiled with CFR 0.152.
 */
package edu.stanford.nlp.ie.crf;

import edu.stanford.nlp.ie.crf.CRFCliqueTree;
import edu.stanford.nlp.ie.crf.CRFLabel;
import edu.stanford.nlp.ie.crf.CliquePotentialFunction;
import edu.stanford.nlp.ie.crf.HasCliquePotentialFunction;
import edu.stanford.nlp.ie.crf.LinearCliquePotentialFunction;
import edu.stanford.nlp.math.ArrayMath;
import edu.stanford.nlp.optimization.AbstractStochasticCachingDiffUpdateFunction;
import edu.stanford.nlp.util.Index;
import java.util.Arrays;
import java.util.List;

/**
 * Training objective for a CRF whose cliques span up to {@code window} consecutive positions:
 * the negated sum over documents of {@code log P(label_i | previous labels)}, optionally plus a
 * prior (regularization) penalty, together with its gradient.
 *
 * <p>Weights are exchanged with optimizers as a flat {@code double[]} of length
 * {@link #domainDimension()}. Internally they use a ragged {@code double[][]} with one row per
 * feature index {@code i}, whose length is {@code labelIndices.get(map[i]).size()}. See
 * {@link #to2D(double[])} and {@link #to1D(double[][])}.
 *
 * <p>The empirical feature counts ({@code Ehat}) are accumulated once over all documents in the
 * constructor; the model-expected counts are recomputed on every evaluation.
 */
public class CRFLogConditionalObjectiveFunction
extends AbstractStochasticCachingDiffUpdateFunction
implements HasCliquePotentialFunction {
    /** Prior type: no penalty is added. */
    public static final int NO_PRIOR = 0;
    /** Prior type: quadratic penalty {@code w^2 / (2 sigma^2)} per weight. */
    public static final int QUADRATIC_PRIOR = 1;
    /** Prior type: Huber penalty, quadratic for {@code |w| < epsilon} and linear beyond it. */
    public static final int HUBER_PRIOR = 2;
    /** Prior type: quartic penalty {@code w^4 / (2 sigma^4)} per weight. */
    public static final int QUARTIC_PRIOR = 3;
    /** One of the {@code *_PRIOR} constants above. */
    private final int prior;
    /** Scale parameter of the prior; larger values mean a weaker penalty. */
    private final double sigma;
    /** Half-width of the quadratic region of the Huber prior. */
    private final double epsilon = 0.1;
    /** {@code labelIndices.get(j)} indexes the label sequences of cliques spanning {@code j + 1} positions. */
    private final List<Index<CRFLabel>> labelIndices;
    private final Index<String> classIndex;
    /** Empirical feature counts over all training documents, in the 2D weight layout. */
    private final double[][] Ehat;
    /** Number of positions spanned by the largest clique. */
    private final int window;
    private final int numClasses;
    /** {@code map[i]} selects the entry of {@code labelIndices} that sizes the weight row of feature {@code i}. */
    private final int[] map;
    /** {@code data[doc][position][j][n]} is a feature index active for the size-{@code j+1} clique at that position. */
    private final int[][][][] data;
    /** Optional real-valued feature values parallel to {@code data} (used only for {@code j == 0}); may be null. */
    private final double[][][][] featureVal;
    /** Gold labels per document, as indices into {@code classIndex}. */
    private final int[][] labels;
    private final int domainDimension;
    /** Scratch buffers reused by {@link #calculateStochasticUpdate}. */
    private double[][] eHat4Update;
    private double[][] e4Update;
    /** Lazily built map from 2D weight coordinates to flat indices. */
    private int[][] weightIndices;
    private final String backgroundSymbol;
    /** When true, per-position probabilities and per-weight derivatives are printed to stderr. */
    public static boolean VERBOSE = false;

    /**
     * Maps a prior name (case-insensitive) to one of the {@code *_PRIOR} constants.
     *
     * @param priorTypeStr "QUADRATIC", "HUBER", "QUARTIC" or "NONE"; {@code null} selects QUADRATIC
     * @return the matching prior constant
     * @throws IllegalArgumentException if the name is not recognized
     */
    public static int getPriorType(String priorTypeStr) {
        if (priorTypeStr == null || "QUADRATIC".equalsIgnoreCase(priorTypeStr)) {
            return QUADRATIC_PRIOR;
        }
        if ("HUBER".equalsIgnoreCase(priorTypeStr)) {
            return HUBER_PRIOR;
        }
        if ("QUARTIC".equalsIgnoreCase(priorTypeStr)) {
            return QUARTIC_PRIOR;
        }
        if ("NONE".equalsIgnoreCase(priorTypeStr)) {
            return NO_PRIOR;
        }
        throw new IllegalArgumentException("Unknown prior type: " + priorTypeStr);
    }

    /** Creates the objective with a QUADRATIC prior, {@code sigma = 1.0} and no feature values. */
    CRFLogConditionalObjectiveFunction(int[][][][] data, int[][] labels, int window, Index<String> classIndex, List<Index<CRFLabel>> labelIndices, int[] map, String backgroundSymbol) {
        this(data, labels, window, classIndex, labelIndices, map, "QUADRATIC", backgroundSymbol);
    }

    /** Creates the objective with the named prior, {@code sigma = 1.0} and no feature values. */
    CRFLogConditionalObjectiveFunction(int[][][][] data, int[][] labels, int window, Index<String> classIndex, List<Index<CRFLabel>> labelIndices, int[] map, String priorType, String backgroundSymbol) {
        this(data, labels, window, classIndex, labelIndices, map, priorType, backgroundSymbol, 1.0, null);
    }

    /** Creates the objective with a QUADRATIC prior of the given {@code sigma}. */
    CRFLogConditionalObjectiveFunction(int[][][][] data, int[][] labels, int window, Index<String> classIndex, List<Index<CRFLabel>> labelIndices, int[] map, String backgroundSymbol, double sigma, double[][][][] featureVal) {
        this(data, labels, window, classIndex, labelIndices, map, "QUADRATIC", backgroundSymbol, sigma, featureVal);
    }

    /**
     * Creates the objective and accumulates the empirical counts over all documents.
     *
     * @param data feature indices, {@code data[doc][position][j][n]}
     * @param labels gold label indices per document
     * @param window number of positions spanned by the largest clique
     * @param classIndex index of class names; must contain {@code backgroundSymbol}
     * @param labelIndices per-clique-size indices of label sequences
     * @param map per-feature selector into {@code labelIndices}
     * @param priorType prior name accepted by {@link #getPriorType(String)}
     * @param backgroundSymbol class used to pad label history before the document start
     * @param sigma prior scale
     * @param featureVal optional real feature values parallel to {@code data}, or null for all 1.0
     * @throws IllegalArgumentException if {@code priorType} is not recognized
     */
    CRFLogConditionalObjectiveFunction(int[][][][] data, int[][] labels, int window, Index<String> classIndex, List<Index<CRFLabel>> labelIndices, int[] map, String priorType, String backgroundSymbol, double sigma, double[][][][] featureVal) {
        this.window = window;
        this.classIndex = classIndex;
        this.numClasses = classIndex.size();
        this.labelIndices = labelIndices;
        this.map = map;
        this.data = data;
        this.featureVal = featureVal;
        this.labels = labels;
        this.prior = CRFLogConditionalObjectiveFunction.getPriorType(priorType);
        this.backgroundSymbol = backgroundSymbol;
        this.sigma = sigma;
        this.Ehat = this.empty2D();
        this.empiricalCounts(this.Ehat);
        int myDomainDimension = 0;
        for (int dim : map) {
            myDomainDimension += labelIndices.get(dim).size();
        }
        this.domainDimension = myDomainDimension;
    }

    /** @return the total number of weights, i.e. the summed lengths of all 2D weight rows */
    @Override
    public int domainDimension() {
        return this.domainDimension;
    }

    /**
     * Splits a flat weight vector into rows, row {@code i} having length
     * {@code labelIndices.get(map[i]).size()}. The input is copied, not aliased.
     */
    public static double[][] to2D(double[] weights, List<Index<CRFLabel>> labelIndices, int[] map) {
        double[][] newWeights = new double[map.length][];
        int index = 0;
        for (int i = 0; i < map.length; ++i) {
            int rowLength = labelIndices.get(map[i]).size();
            newWeights[i] = new double[rowLength];
            System.arraycopy(weights, index, newWeights[i], 0, rowLength);
            index += rowLength;
        }
        return newWeights;
    }

    /** Splits a flat weight vector into this objective's 2D layout. */
    public double[][] to2D(double[] weights) {
        return CRFLogConditionalObjectiveFunction.to2D(weights, this.labelIndices, this.map);
    }

    /**
     * Scales {@code weights} by {@code wscale} and then splits them into the 2D layout.
     *
     * <p>NOTE(review): the scaling is applied IN PLACE to the caller's array. Callers such as
     * {@link #calculateStochasticUpdate} and {@link #valueAt} therefore leave {@code x} scaled
     * by {@code xscale}. Confirm that the optimizer driving this function expects that.
     */
    public double[][] to2D(double[] weights, double wscale) {
        for (int i = 0; i < weights.length; ++i) {
            weights[i] *= wscale;
        }
        return CRFLogConditionalObjectiveFunction.to2D(weights, this.labelIndices, this.map);
    }

    /** Concatenates the rows of {@code weights} into a new array of length {@code domainDimension}. */
    public static double[] to1D(double[][] weights, int domainDimension) {
        double[] newWeights = new double[domainDimension];
        int index = 0;
        for (int i = 0; i < weights.length; ++i) {
            System.arraycopy(weights[i], 0, newWeights, index, weights[i].length);
            index += weights[i].length;
        }
        return newWeights;
    }

    /** Concatenates 2D weights into a flat vector of length {@link #domainDimension()}. */
    public double[] to1D(double[][] weights) {
        return CRFLogConditionalObjectiveFunction.to1D(weights, this.domainDimension());
    }

    /**
     * Returns (building it on first use) a table whose entry {@code [i][j]} is the flat index of
     * 2D weight {@code [i][j]}.
     */
    public int[][] getWeightIndices() {
        if (this.weightIndices == null) {
            this.weightIndices = new int[this.map.length][];
            int index = 0;
            for (int i = 0; i < this.map.length; ++i) {
                int rowLength = this.labelIndices.get(this.map[i]).size();
                this.weightIndices[i] = new int[rowLength];
                for (int j = 0; j < rowLength; ++j) {
                    this.weightIndices[i][j] = index++;
                }
            }
        }
        return this.weightIndices;
    }

    /** Allocates a zero-filled array in the 2D weight layout. */
    private double[][] empty2D() {
        double[][] d = new double[this.map.length][];
        for (int i = 0; i < this.map.length; ++i) {
            d[i] = new double[this.labelIndices.get(this.map[i]).size()];
        }
        return d;
    }

    /** Adds the empirical counts of every document into {@code eHat}. */
    private void empiricalCounts(double[][] eHat) {
        for (int m = 0; m < this.data.length; ++m) {
            this.empiricalCountsForADoc(eHat, m);
        }
    }

    /**
     * Adds the empirical (gold-label) feature counts of one document into {@code eHat}.
     *
     * <p>For each position, the gold labels of the last {@code j + 1} positions form the label of
     * the size-{@code j+1} clique. Each active feature then adds its value (1.0, or the
     * {@code featureVal} entry for {@code j == 0}) to {@code eHat[feature][cliqueLabelIndex]}.
     * History before the document start is padded with the background class.
     */
    private void empiricalCountsForADoc(double[][] eHat, int docIndex) {
        int[][][] docData = this.data[docIndex];
        int[] docLabels = this.labels[docIndex];
        int[] windowLabels = new int[this.window];
        Arrays.fill(windowLabels, this.classIndex.indexOf(this.backgroundSymbol));
        double[][][] featureValArr = null;
        if (this.featureVal != null) {
            featureValArr = this.featureVal[docIndex];
        }
        if (docLabels.length > docData.length) {
            // Extra leading labels are used as history instead of background padding.
            // NOTE(review): this copies `window` labels, while expectedCountsAndValueForADoc copies
            // `window - 1`. After the first shift below, the oldest copied label is dropped. Confirm
            // the intended prefix length for this case.
            System.arraycopy(docLabels, 0, windowLabels, 0, windowLabels.length);
            int[] newDocLabels = new int[docData.length];
            System.arraycopy(docLabels, docLabels.length - newDocLabels.length, newDocLabels, 0, newDocLabels.length);
            docLabels = newDocLabels;
        }
        for (int i = 0; i < docData.length; ++i) {
            // Slide the label window one position to the left and append the current gold label.
            System.arraycopy(windowLabels, 1, windowLabels, 0, this.window - 1);
            windowLabels[this.window - 1] = docLabels[i];
            for (int j = 0; j < docData[i].length; ++j) {
                int[] cliqueLabel = new int[j + 1];
                System.arraycopy(windowLabels, this.window - 1 - j, cliqueLabel, 0, j + 1);
                int labelIndex = this.labelIndices.get(j).indexOf(new CRFLabel(cliqueLabel));
                for (int n = 0; n < docData[i][j].length; ++n) {
                    double fVal = 1.0;
                    if (featureValArr != null && j == 0) {
                        fVal = featureValArr[i][j][n];
                    }
                    eHat[docData[i][j][n]][labelIndex] += fVal;
                }
            }
        }
    }

    /** Returns the log-likelihood of one document's gold labels without computing expected counts. */
    public double valueForADoc(double[][] weights, int docIndex) {
        return this.expectedCountsAndValueForADoc(weights, null, docIndex, true);
    }

    /** Returns one document's log-likelihood and adds its expected counts into {@code E}. */
    private double expectedCountsAndValueForADoc(double[][] weights, double[][] E, int docIndex) {
        return this.expectedCountsAndValueForADoc(weights, E, docIndex, false);
    }

    /** Builds a linear clique potential function from flat weights {@code x}. */
    @Override
    public CliquePotentialFunction getCliquePotentialFunction(double[] x) {
        double[][] weights = this.to2D(x);
        return new LinearCliquePotentialFunction(weights);
    }

    /**
     * Computes the sum over positions of {@code log P(label_i | previous window-1 labels)} for one
     * document. Unless {@code skipExpectedCountCal} is set, it also adds the model-expected
     * feature counts into {@code E}.
     *
     * @param weights 2D weights
     * @param E accumulator in the 2D layout; may be null when {@code skipExpectedCountCal} is true
     * @param docIndex document to evaluate
     * @param skipExpectedCountCal if true, only the log-likelihood is computed
     * @return the document's conditional log-likelihood
     */
    private double expectedCountsAndValueForADoc(double[][] weights, double[][] E, int docIndex, boolean skipExpectedCountCal) {
        double prob = 0.0;
        int[][][] docData = this.data[docIndex];
        int[] docLabels = this.labels[docIndex];
        double[][][] featureVal3DArr = null;
        if (this.featureVal != null) {
            featureVal3DArr = this.featureVal[docIndex];
        }
        LinearCliquePotentialFunction cliquePotentialFunc = new LinearCliquePotentialFunction(weights);
        CRFCliqueTree<String> cliqueTree = CRFCliqueTree.getCalibratedCliqueTree(docData, this.labelIndices, this.numClasses, this.classIndex, this.backgroundSymbol, cliquePotentialFunc, featureVal3DArr);
        int[] given = new int[this.window - 1];
        Arrays.fill(given, this.classIndex.indexOf(this.backgroundSymbol));
        if (docLabels.length > docData.length) {
            // Extra leading labels are used as history instead of background padding.
            System.arraycopy(docLabels, 0, given, 0, given.length);
            int[] newDocLabels = new int[docData.length];
            System.arraycopy(docLabels, docLabels.length - newDocLabels.length, newDocLabels, 0, newDocLabels.length);
            docLabels = newDocLabels;
        }
        for (int i = 0; i < docData.length; ++i) {
            int label = docLabels[i];
            double p = cliqueTree.condLogProbGivenPrevious(i, label, given);
            if (VERBOSE) {
                System.err.println("P(" + label + "|" + ArrayMath.toString(given) + ")=" + p);
            }
            prob += p;
            // With window == 1 there is no label history to shift. arraycopy with length -1 and
            // given[-1] would otherwise throw here.
            if (given.length > 0) {
                System.arraycopy(given, 1, given, 0, given.length - 1);
                given[given.length - 1] = label;
            }
        }
        if (!skipExpectedCountCal) {
            for (int i = 0; i < docData.length; ++i) {
                for (int j = 0; j < docData[i].length; ++j) {
                    Index<CRFLabel> labelIndex = this.labelIndices.get(j);
                    for (int k = 0; k < labelIndex.size(); ++k) {
                        int[] cliqueLabel = labelIndex.get(k).getLabel();
                        double p = cliqueTree.prob(i, cliqueLabel);
                        for (int n = 0; n < docData[i][j].length; ++n) {
                            double fVal = 1.0;
                            if (j == 0 && featureVal3DArr != null) {
                                fVal = featureVal3DArr[i][j][n];
                            }
                            E[docData[i][j][n]][k] += p * fVal;
                        }
                    }
                }
            }
        }
        return prob;
    }

    /**
     * Computes the full-batch objective and gradient at {@code x}: {@code value} is the negated
     * total log-likelihood plus the prior, and {@code derivative} is the expected counts minus the
     * empirical counts plus the prior gradient.
     *
     * @throws RuntimeException if the log-likelihood is NaN
     */
    @Override
    public void calculate(double[] x) {
        double prob = 0.0;
        double[][] weights = this.to2D(x);
        double[][] E = this.empty2D();
        for (int m = 0; m < this.data.length; ++m) {
            prob += this.expectedCountsAndValueForADoc(weights, E, m);
        }
        if (Double.isNaN(prob)) {
            throw new RuntimeException("Got NaN for prob in CRFLogConditionalObjectiveFunction.calculate() - this may well indicate numeric underflow due to overly long documents.");
        }
        this.value = -prob;
        if (VERBOSE) {
            System.err.println("value is " + this.value);
        }
        this.setDerivativeFromCounts(E, 1.0);
        this.applyPrior(x, 1.0);
    }

    /**
     * Writes {@code E[i][j] - ehatScale * Ehat[i][j]} into the flat {@code derivative} array in
     * row-major order.
     */
    private void setDerivativeFromCounts(double[][] E, double ehatScale) {
        int index = 0;
        for (int i = 0; i < E.length; ++i) {
            for (int j = 0; j < E[i].length; ++j) {
                this.derivative[index] = E[i][j] - ehatScale * this.Ehat[i][j];
                if (VERBOSE) {
                    System.err.println("deriv(" + i + "," + j + ") = " + E[i][j] + " - " + this.Ehat[i][j] + " = " + this.derivative[index]);
                }
                ++index;
            }
        }
    }

    /**
     * Adds the prior penalty, scaled by {@code batchScale}, to {@code value} and its gradient to
     * {@code derivative}.
     *
     * <ul>
     *   <li>QUADRATIC: {@code w^2 / (2 sigma^2)}, gradient {@code w / sigma^2}</li>
     *   <li>HUBER: {@code w^2 / (2 epsilon sigma^2)} for {@code |w| < epsilon}, otherwise
     *       {@code (|w| - epsilon/2) / sigma^2}</li>
     *   <li>QUARTIC: {@code w^4 / (2 sigma^4)}, gradient {@code 2 w^3 / sigma^4}</li>
     *   <li>NO_PRIOR: nothing</li>
     * </ul>
     */
    private void applyPrior(double[] x, double batchScale) {
        switch (this.prior) {
            case QUADRATIC_PRIOR: {
                double sigmaSq = this.sigma * this.sigma;
                double lambda = 0.5 / sigmaSq;
                for (int i = 0; i < x.length; ++i) {
                    double w = x[i];
                    this.value += batchScale * w * w * lambda;
                    this.derivative[i] += batchScale * w / sigmaSq;
                }
                break;
            }
            case HUBER_PRIOR: {
                double sigmaSq = this.sigma * this.sigma;
                for (int i = 0; i < x.length; ++i) {
                    double w = x[i];
                    double wabs = Math.abs(w);
                    if (wabs < this.epsilon) {
                        this.value += batchScale * w * w / 2.0 / this.epsilon / sigmaSq;
                        this.derivative[i] += batchScale * w / this.epsilon / sigmaSq;
                    } else {
                        this.value += batchScale * (wabs - this.epsilon / 2.0) / sigmaSq;
                        this.derivative[i] += batchScale * (w < 0.0 ? -1.0 : 1.0) / sigmaSq;
                    }
                }
                break;
            }
            case QUARTIC_PRIOR: {
                double sigmaQu = this.sigma * this.sigma * this.sigma * this.sigma;
                double lambda = 0.5 / sigmaQu;
                for (int i = 0; i < x.length; ++i) {
                    double w = x[i];
                    this.value += batchScale * w * w * w * w * lambda;
                    // d/dw [lambda * w^4] = 4 * lambda * w^3 = 2 * w^3 / sigma^4. The previous code
                    // added w / sigma^4, which does not match the value term above.
                    this.derivative[i] += batchScale * 2.0 * w * w * w / sigmaQu;
                }
                break;
            }
            default:
                // NO_PRIOR: no penalty.
                break;
        }
    }

    /** Computes the stochastic value and gradient over {@code batch}; {@code v} is ignored. */
    @Override
    public void calculateStochastic(double[] x, double[] v, int[] batch) {
        this.calculateStochasticGradientOnly(x, batch);
    }

    /** @return the number of training documents */
    @Override
    public int dataDimension() {
        return this.data.length;
    }

    /**
     * Computes the objective and gradient restricted to the documents in {@code batch}.
     *
     * <p>Despite the name, this also sets {@code value}. The empirical counts (over all documents)
     * and the prior are scaled by {@code batch.length / dataDimension()}.
     *
     * @throws RuntimeException if the log-likelihood is NaN
     */
    public void calculateStochasticGradientOnly(double[] x, int[] batch) {
        double prob = 0.0;
        double[][] weights = this.to2D(x);
        double batchScale = (double)batch.length / (double)this.dataDimension();
        double[][] E = this.empty2D();
        for (int ind : batch) {
            prob += this.expectedCountsAndValueForADoc(weights, E, ind);
        }
        if (Double.isNaN(prob)) {
            throw new RuntimeException("Got NaN for prob in CRFLogConditionalObjectiveFunction.calculateStochasticGradientOnly()");
        }
        this.value = -prob;
        this.setDerivativeFromCounts(E, batchScale);
        this.applyPrior(x, batchScale);
    }

    /** Zero-fills the reusable update buffers. */
    private void clearUpdateEs() {
        for (double[] row : this.eHat4Update) {
            Arrays.fill(row, 0.0);
        }
        for (double[] row : this.e4Update) {
            Arrays.fill(row, 0.0);
        }
    }

    /**
     * Performs one in-place stochastic step on {@code x}. It computes the batch's empirical and
     * expected counts and adds {@code gscale * (empirical - expected)} to each weight.
     * No prior is applied here.
     *
     * <p>NOTE(review): {@link #to2D(double[], double)} first scales {@code x} in place by
     * {@code xscale}, so the step is applied to the scaled weights.
     *
     * @return the negated batch log-likelihood, evaluated before the update
     * @throws RuntimeException if the log-likelihood is NaN
     */
    @Override
    public double calculateStochasticUpdate(double[] x, double xscale, int[] batch, double gscale) {
        double prob = 0.0;
        double[][] weights = this.to2D(x, xscale);
        if (this.eHat4Update == null) {
            this.eHat4Update = this.empty2D();
            this.e4Update = this.empty2D();
        } else {
            this.clearUpdateEs();
        }
        for (int ind : batch) {
            this.empiricalCountsForADoc(this.eHat4Update, ind);
            prob += this.expectedCountsAndValueForADoc(weights, this.e4Update, ind);
        }
        if (Double.isNaN(prob)) {
            throw new RuntimeException("Got NaN for prob in CRFLogConditionalObjectiveFunction.calculateStochasticUpdate()");
        }
        this.value = -prob;
        int index = 0;
        for (int i = 0; i < this.e4Update.length; ++i) {
            for (int j = 0; j < this.e4Update[i].length; ++j) {
                x[index++] += (this.eHat4Update[i][j] - this.e4Update[i][j]) * gscale;
            }
        }
        return this.value;
    }

    /**
     * Returns the negated log-likelihood of the documents in {@code batch} at weights
     * {@code x * xscale}, without the prior.
     *
     * <p>NOTE(review): {@code x} is scaled in place by {@code xscale} (see
     * {@link #to2D(double[], double)}).
     *
     * @throws RuntimeException if the log-likelihood is NaN
     */
    @Override
    public double valueAt(double[] x, double xscale, int[] batch) {
        double prob = 0.0;
        double[][] weights = this.to2D(x, xscale);
        for (int ind : batch) {
            prob += this.valueForADoc(weights, ind);
        }
        if (Double.isNaN(prob)) {
            throw new RuntimeException("Got NaN for prob in CRFLogConditionalObjectiveFunction.valueAt()");
        }
        this.value = -prob;
        return this.value;
    }
}

