/*
 * Decompiled with CFR 0.152.
 */
package edu.stanford.nlp.parser.dvparser;

import edu.stanford.nlp.ling.Label;
import edu.stanford.nlp.math.ArrayMath;
import edu.stanford.nlp.neural.NeuralUtils;
import edu.stanford.nlp.optimization.AbstractCachingDiffFunction;
import edu.stanford.nlp.parser.common.NoSuchParseException;
import edu.stanford.nlp.parser.dvparser.DVModel;
import edu.stanford.nlp.parser.lexparser.Options;
import edu.stanford.nlp.parser.metrics.TreeSpanScoring;
import edu.stanford.nlp.trees.DeepTree;
import edu.stanford.nlp.trees.Tree;
import edu.stanford.nlp.util.Generics;
import edu.stanford.nlp.util.IntPair;
import edu.stanford.nlp.util.Pair;
import edu.stanford.nlp.util.Timing;
import edu.stanford.nlp.util.TwoDimensionalMap;
import edu.stanford.nlp.util.concurrent.MulticoreWrapper;
import edu.stanford.nlp.util.concurrent.ThreadsafeProcessor;
import java.util.ArrayList;
import java.util.Formatter;
import java.util.IdentityHashMap;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;
import org.ejml.simple.SimpleBase;
import org.ejml.simple.SimpleMatrix;

public class DVParserCostAndGradient
extends AbstractCachingDiffFunction {
    List<Tree> trainingBatch;
    IdentityHashMap<Tree, List<Tree>> topParses;
    DVModel dvModel;
    Options op;
    static final double TRAIN_LAMBDA = 1.0;

    public DVParserCostAndGradient(List<Tree> trainingBatch, IdentityHashMap<Tree, List<Tree>> topParses, DVModel dvModel, Options op) {
        this.trainingBatch = trainingBatch;
        this.topParses = topParses;
        this.dvModel = dvModel;
        this.op = op;
    }

    private List<String> getContextWords(Tree tree) {
        ArrayList<String> words = null;
        if (this.op.trainOptions.useContextWords) {
            words = Generics.newArrayList();
            ArrayList<Label> leaves = tree.yield();
            for (Label word : leaves) {
                words.add(word.value());
            }
        }
        return words;
    }

    private SimpleMatrix concatenateContextWords(SimpleMatrix childVec, IntPair span, List<String> words) {
        SimpleMatrix left = span.getSource() < 0 ? this.dvModel.getStartWordVector() : this.dvModel.getWordVector(words.get(span.getSource()));
        SimpleMatrix right = span.getTarget() >= words.size() ? this.dvModel.getEndWordVector() : this.dvModel.getWordVector(words.get(span.getTarget()));
        return NeuralUtils.concatenate(childVec, left, right);
    }

    public static void outputSpans(Tree tree) {
        System.err.print(tree.getSpan() + " ");
        for (Tree child : tree.children()) {
            DVParserCostAndGradient.outputSpans(child);
        }
    }

    public double score(Tree tree, IdentityHashMap<Tree, SimpleMatrix> nodeVectors) {
        List<String> words = this.getContextWords(tree);
        IdentityHashMap<Tree, Double> scores = new IdentityHashMap<Tree, Double>();
        try {
            this.forwardPropagateTree(tree, words, nodeVectors, scores);
        }
        catch (AssertionError e) {
            System.err.println("Failed to correctly process tree " + tree);
            throw e;
        }
        double score = 0.0;
        for (Tree node : scores.keySet()) {
            score += scores.get(node).doubleValue();
        }
        return score;
    }

    private void forwardPropagateTree(Tree tree, List<String> words, IdentityHashMap<Tree, SimpleMatrix> nodeVectors, IdentityHashMap<Tree, Double> scores) {
        SimpleMatrix W;
        if (tree.isLeaf()) {
            return;
        }
        if (tree.isPreTerminal()) {
            Tree wordNode = tree.children()[0];
            String word = wordNode.label().value();
            SimpleMatrix wordVector = this.dvModel.getWordVector(word);
            wordVector = NeuralUtils.elementwiseApplyTanh(wordVector);
            nodeVectors.put(tree, wordVector);
            return;
        }
        for (Tree child : tree.children()) {
            this.forwardPropagateTree(child, words, nodeVectors, scores);
        }
        SimpleMatrix childVec = tree.children().length == 2 ? NeuralUtils.concatenateWithBias(nodeVectors.get(tree.children()[0]), nodeVectors.get(tree.children()[1])) : NeuralUtils.concatenateWithBias(nodeVectors.get(tree.children()[0]));
        if (this.op.trainOptions.useContextWords) {
            childVec = this.concatenateContextWords(childVec, tree.getSpan(), words);
        }
        if ((W = this.dvModel.getWForNode(tree)) == null) {
            String error = "Could not find W for tree " + tree;
            if (this.op.testOptions.verbose) {
                System.err.println(error);
            }
            throw new NoSuchParseException(error);
        }
        SimpleMatrix currentVector = (SimpleMatrix)W.mult((SimpleBase)childVec);
        currentVector = NeuralUtils.elementwiseApplyTanh(currentVector);
        nodeVectors.put(tree, currentVector);
        SimpleMatrix scoreW = this.dvModel.getScoreWForNode(tree);
        if (scoreW == null) {
            String error = "Could not find scoreW for tree " + tree;
            if (this.op.testOptions.verbose) {
                System.err.println(error);
            }
            throw new NoSuchParseException(error);
        }
        double score = scoreW.dot((SimpleBase)currentVector);
        scores.put(tree, score);
    }

    @Override
    public int domainDimension() {
        return this.dvModel.totalParamSize();
    }

    public List<DeepTree> getAllHighestScoringTreesTest(List<Tree> trees) {
        ArrayList<DeepTree> allBestTrees = new ArrayList<DeepTree>();
        for (Tree tree : trees) {
            allBestTrees.add(this.getHighestScoringTree(tree, 0.0));
        }
        return allBestTrees;
    }

    public DeepTree getHighestScoringTree(Tree tree, double lambda) {
        List<Tree> hypotheses = this.topParses.get(tree);
        if (hypotheses == null || hypotheses.size() == 0) {
            throw new AssertionError((Object)("Failed to get any hypothesis trees for " + tree));
        }
        double bestScore = Double.NEGATIVE_INFINITY;
        Tree bestTree = null;
        IdentityHashMap<Tree, SimpleMatrix> bestVectors = null;
        for (Tree hypothesis : hypotheses) {
            IdentityHashMap<Tree, SimpleMatrix> nodeVectors = new IdentityHashMap<Tree, SimpleMatrix>();
            double scoreHyp = this.score(hypothesis, nodeVectors);
            double deltaMargin = 0.0;
            if (lambda != 0.0) {
                deltaMargin = this.op.trainOptions.deltaMargin * lambda * this.getMargin(tree, hypothesis);
            }
            scoreHyp += deltaMargin;
            if (bestTree != null && !(scoreHyp > bestScore)) continue;
            bestTree = hypothesis;
            bestScore = scoreHyp;
            bestVectors = nodeVectors;
        }
        DeepTree returnTree = new DeepTree(bestTree, bestVectors, bestScore);
        return returnTree;
    }

    /*
     * WARNING - void declaration
     */
    @Override
    public void calculate(double[] theta) {
        double[] localDerivativeB;
        double[] localDerivativeGood;
        int numCols;
        int numRows;
        this.dvModel.vectorToParams(theta);
        double localValue = 0.0;
        double[] localDerivative = new double[theta.length];
        TwoDimensionalMap<String, String, SimpleMatrix> binaryW_dfsG = TwoDimensionalMap.treeMap();
        TwoDimensionalMap<String, String, SimpleMatrix> binaryW_dfsB = TwoDimensionalMap.treeMap();
        TwoDimensionalMap<String, String, SimpleMatrix> binaryScoreDerivativesG = TwoDimensionalMap.treeMap();
        TwoDimensionalMap<String, String, SimpleMatrix> binaryScoreDerivativesB = TwoDimensionalMap.treeMap();
        TreeMap<String, SimpleMatrix> unaryW_dfsG = new TreeMap<String, SimpleMatrix>();
        TreeMap<String, SimpleMatrix> unaryW_dfsB = new TreeMap<String, SimpleMatrix>();
        TreeMap<String, SimpleMatrix> unaryScoreDerivativesG = new TreeMap<String, SimpleMatrix>();
        TreeMap<String, SimpleMatrix> unaryScoreDerivativesB = new TreeMap<String, SimpleMatrix>();
        TreeMap<String, SimpleMatrix> wordVectorDerivativesG = new TreeMap<String, SimpleMatrix>();
        TreeMap<String, SimpleMatrix> wordVectorDerivativesB = new TreeMap<String, SimpleMatrix>();
        for (TwoDimensionalMap.Entry<String, String, SimpleMatrix> entry : this.dvModel.binaryTransform) {
            numRows = entry.getValue().numRows();
            numCols = entry.getValue().numCols();
            binaryW_dfsG.put(entry.getFirstKey(), entry.getSecondKey(), new SimpleMatrix(numRows, numCols));
            binaryW_dfsB.put(entry.getFirstKey(), entry.getSecondKey(), new SimpleMatrix(numRows, numCols));
            binaryScoreDerivativesG.put(entry.getFirstKey(), entry.getSecondKey(), new SimpleMatrix(1, numRows));
            binaryScoreDerivativesB.put(entry.getFirstKey(), entry.getSecondKey(), new SimpleMatrix(1, numRows));
        }
        for (Map.Entry entry : this.dvModel.unaryTransform.entrySet()) {
            numRows = ((SimpleMatrix)entry.getValue()).numRows();
            numCols = ((SimpleMatrix)entry.getValue()).numCols();
            unaryW_dfsG.put((String)entry.getKey(), new SimpleMatrix(numRows, numCols));
            unaryW_dfsB.put((String)entry.getKey(), new SimpleMatrix(numRows, numCols));
            unaryScoreDerivativesG.put((String)entry.getKey(), new SimpleMatrix(1, numRows));
            unaryScoreDerivativesB.put((String)entry.getKey(), new SimpleMatrix(1, numRows));
        }
        if (this.op.trainOptions.trainWordVectors) {
            for (Map.Entry entry : this.dvModel.wordVectors.entrySet()) {
                numRows = ((SimpleMatrix)entry.getValue()).numRows();
                numCols = ((SimpleMatrix)entry.getValue()).numCols();
                wordVectorDerivativesG.put((String)entry.getKey(), new SimpleMatrix(numRows, numCols));
                wordVectorDerivativesB.put((String)entry.getKey(), new SimpleMatrix(numRows, numCols));
            }
        }
        Timing scoreTiming = new Timing();
        scoreTiming.doing("Scoring trees");
        boolean bl = false;
        MulticoreWrapper<Tree, Pair<DeepTree, DeepTree>> wrapper = new MulticoreWrapper<Tree, Pair<DeepTree, DeepTree>>(this.op.trainOptions.trainingThreads, new ScoringProcessor());
        for (Tree tree : this.trainingBatch) {
            wrapper.put(tree);
        }
        wrapper.join();
        scoreTiming.done();
        while (wrapper.peek()) {
            void var16_21;
            Pair<DeepTree, DeepTree> result = wrapper.poll();
            DeepTree goldTree = (DeepTree)result.first;
            DeepTree bestTree = (DeepTree)result.second;
            StringBuilder treeDebugLine = new StringBuilder();
            Formatter formatter = new Formatter(treeDebugLine);
            boolean isDone = Math.abs(bestTree.getScore() - goldTree.getScore()) <= 1.0E-5 || goldTree.getScore() > bestTree.getScore();
            String done = isDone ? "done" : "";
            formatter.format("Tree %6d Highest tree: %12.4f Correct tree: %12.4f %s", (int)var16_21, bestTree.getScore(), goldTree.getScore(), done);
            System.err.println(treeDebugLine.toString());
            if (!isDone) {
                double valueDelta = bestTree.getScore() - goldTree.getScore();
                localValue += valueDelta;
                List<String> words = this.getContextWords(goldTree.getTree());
                this.backpropDerivative(goldTree.getTree(), words, goldTree.getVectors(), binaryW_dfsG, unaryW_dfsG, binaryScoreDerivativesG, unaryScoreDerivativesG, wordVectorDerivativesG);
                this.backpropDerivative(bestTree.getTree(), words, bestTree.getVectors(), binaryW_dfsB, unaryW_dfsB, binaryScoreDerivativesB, unaryScoreDerivativesB, wordVectorDerivativesB);
            }
            ++var16_21;
        }
        if (this.op.trainOptions.trainWordVectors) {
            localDerivativeGood = NeuralUtils.paramsToVector(theta.length, binaryW_dfsG.valueIterator(), unaryW_dfsG.values().iterator(), binaryScoreDerivativesG.valueIterator(), unaryScoreDerivativesG.values().iterator(), wordVectorDerivativesG.values().iterator());
            localDerivativeB = NeuralUtils.paramsToVector(theta.length, binaryW_dfsB.valueIterator(), unaryW_dfsB.values().iterator(), binaryScoreDerivativesB.valueIterator(), unaryScoreDerivativesB.values().iterator(), wordVectorDerivativesB.values().iterator());
        } else {
            localDerivativeGood = NeuralUtils.paramsToVector(theta.length, binaryW_dfsG.valueIterator(), unaryW_dfsG.values().iterator(), binaryScoreDerivativesG.valueIterator(), unaryScoreDerivativesG.values().iterator());
            localDerivativeB = NeuralUtils.paramsToVector(theta.length, binaryW_dfsB.valueIterator(), unaryW_dfsB.values().iterator(), binaryScoreDerivativesB.valueIterator(), unaryScoreDerivativesB.values().iterator());
        }
        for (int i = 0; i < localDerivativeGood.length; ++i) {
            localDerivative[i] = localDerivativeB[i] - localDerivativeGood[i];
        }
        this.value = localValue;
        this.derivative = localDerivative;
        this.value = 1.0 / (double)this.trainingBatch.size() * this.value;
        ArrayMath.multiplyInPlace(this.derivative, 1.0 / (double)this.trainingBatch.size());
        double[] currentParams = this.dvModel.paramsToVector();
        double regCost = 0.0;
        for (int i = 0; i < currentParams.length; ++i) {
            regCost += currentParams[i] * currentParams[i];
        }
        regCost = this.op.trainOptions.regCost * 0.5 * regCost;
        this.value += regCost;
        ArrayMath.multiplyInPlace(currentParams, this.op.trainOptions.regCost);
        ArrayMath.pairwiseAddInPlace(this.derivative, currentParams);
    }

    public double getMargin(Tree goldTree, Tree bestHypothesis) {
        return TreeSpanScoring.countSpanErrors(this.op.langpack(), goldTree, bestHypothesis);
    }

    public void backpropDerivative(Tree tree, List<String> words, IdentityHashMap<Tree, SimpleMatrix> nodeVectors, TwoDimensionalMap<String, String, SimpleMatrix> binaryW_dfs, Map<String, SimpleMatrix> unaryW_dfs, TwoDimensionalMap<String, String, SimpleMatrix> binaryScoreDerivatives, Map<String, SimpleMatrix> unaryScoreDerivatives, Map<String, SimpleMatrix> wordVectorDerivatives) {
        SimpleMatrix delta = new SimpleMatrix(this.op.lexOptions.numHid, 1);
        this.backpropDerivative(tree, words, nodeVectors, binaryW_dfs, unaryW_dfs, binaryScoreDerivatives, unaryScoreDerivatives, wordVectorDerivatives, delta);
    }

    public void backpropDerivative(Tree tree, List<String> words, IdentityHashMap<Tree, SimpleMatrix> nodeVectors, TwoDimensionalMap<String, String, SimpleMatrix> binaryW_dfs, Map<String, SimpleMatrix> unaryW_dfs, TwoDimensionalMap<String, String, SimpleMatrix> binaryScoreDerivatives, Map<String, SimpleMatrix> unaryScoreDerivatives, Map<String, SimpleMatrix> wordVectorDerivatives, SimpleMatrix deltaUp) {
        if (tree.isLeaf()) {
            return;
        }
        if (tree.isPreTerminal()) {
            if (this.op.trainOptions.trainWordVectors) {
                String word = tree.children()[0].label().value();
                word = this.dvModel.getVocabWord(word);
                SimpleMatrix derivative = deltaUp;
                wordVectorDerivatives.put(word, (SimpleMatrix)wordVectorDerivatives.get(word).plus((SimpleBase)derivative));
            }
            return;
        }
        SimpleMatrix currentVector = nodeVectors.get(tree);
        SimpleMatrix currentVectorDerivative = NeuralUtils.elementwiseApplyTanhDerivative(currentVector);
        SimpleMatrix scoreW = this.dvModel.getScoreWForNode(tree);
        currentVectorDerivative = (SimpleMatrix)currentVectorDerivative.elementMult(scoreW.transpose());
        SimpleMatrix deltaCurrent = (SimpleMatrix)deltaUp.plus((SimpleBase)currentVectorDerivative);
        SimpleMatrix W = this.dvModel.getWForNode(tree);
        SimpleMatrix WTdelta = (SimpleMatrix)((SimpleMatrix)W.transpose()).mult((SimpleBase)deltaCurrent);
        if (tree.children().length == 2) {
            String leftLabel = this.dvModel.basicCategory(tree.children()[0].label().value());
            String rightLabel = this.dvModel.basicCategory(tree.children()[1].label().value());
            binaryScoreDerivatives.put(leftLabel, rightLabel, (SimpleMatrix)binaryScoreDerivatives.get(leftLabel, rightLabel).plus(currentVector.transpose()));
            SimpleMatrix leftVector = nodeVectors.get(tree.children()[0]);
            SimpleMatrix rightVector = nodeVectors.get(tree.children()[1]);
            SimpleMatrix childrenVector = NeuralUtils.concatenateWithBias(leftVector, rightVector);
            if (this.op.trainOptions.useContextWords) {
                childrenVector = this.concatenateContextWords(childrenVector, tree.getSpan(), words);
            }
            SimpleMatrix W_df = (SimpleMatrix)deltaCurrent.mult(childrenVector.transpose());
            binaryW_dfs.put(leftLabel, rightLabel, (SimpleMatrix)binaryW_dfs.get(leftLabel, rightLabel).plus((SimpleBase)W_df));
            SimpleMatrix leftDerivative = NeuralUtils.elementwiseApplyTanhDerivative(leftVector);
            SimpleMatrix rightDerivative = NeuralUtils.elementwiseApplyTanhDerivative(rightVector);
            SimpleMatrix leftWTDelta = (SimpleMatrix)WTdelta.extractMatrix(0, deltaCurrent.numRows(), 0, 1);
            SimpleMatrix rightWTDelta = (SimpleMatrix)WTdelta.extractMatrix(deltaCurrent.numRows(), deltaCurrent.numRows() * 2, 0, 1);
            this.backpropDerivative(tree.children()[0], words, nodeVectors, binaryW_dfs, unaryW_dfs, binaryScoreDerivatives, unaryScoreDerivatives, wordVectorDerivatives, (SimpleMatrix)leftDerivative.elementMult((SimpleBase)leftWTDelta));
            this.backpropDerivative(tree.children()[1], words, nodeVectors, binaryW_dfs, unaryW_dfs, binaryScoreDerivatives, unaryScoreDerivatives, wordVectorDerivatives, (SimpleMatrix)rightDerivative.elementMult((SimpleBase)rightWTDelta));
        } else if (tree.children().length == 1) {
            String childLabel = this.dvModel.basicCategory(tree.children()[0].label().value());
            unaryScoreDerivatives.put(childLabel, (SimpleMatrix)unaryScoreDerivatives.get(childLabel).plus(currentVector.transpose()));
            SimpleMatrix childVector = nodeVectors.get(tree.children()[0]);
            SimpleMatrix childVectorWithBias = NeuralUtils.concatenateWithBias(childVector);
            if (this.op.trainOptions.useContextWords) {
                childVectorWithBias = this.concatenateContextWords(childVectorWithBias, tree.getSpan(), words);
            }
            SimpleMatrix W_df = (SimpleMatrix)deltaCurrent.mult(childVectorWithBias.transpose());
            unaryW_dfs.put(childLabel, (SimpleMatrix)unaryW_dfs.get(childLabel).plus((SimpleBase)W_df));
            SimpleMatrix childDerivative = NeuralUtils.elementwiseApplyTanhDerivative(childVector);
            SimpleMatrix childWTDelta = (SimpleMatrix)WTdelta.extractMatrix(0, deltaCurrent.numRows(), 0, 1);
            this.backpropDerivative(tree.children()[0], words, nodeVectors, binaryW_dfs, unaryW_dfs, binaryScoreDerivatives, unaryScoreDerivatives, wordVectorDerivatives, (SimpleMatrix)childDerivative.elementMult((SimpleBase)childWTDelta));
        }
    }

    class ScoringProcessor
    implements ThreadsafeProcessor<Tree, Pair<DeepTree, DeepTree>> {
        ScoringProcessor() {
        }

        @Override
        public Pair<DeepTree, DeepTree> process(Tree tree) {
            IdentityHashMap<Tree, SimpleMatrix> goldVectors = new IdentityHashMap<Tree, SimpleMatrix>();
            double scoreGold = DVParserCostAndGradient.this.score(tree, goldVectors);
            DeepTree bestTree = DVParserCostAndGradient.this.getHighestScoringTree(tree, 1.0);
            DeepTree goldTree = new DeepTree(tree, goldVectors, scoreGold);
            return Pair.makePair(goldTree, bestTree);
        }

        @Override
        public ThreadsafeProcessor<Tree, Pair<DeepTree, DeepTree>> newInstance() {
            return this;
        }
    }
}

