/*
 * Decompiled with CFR 0.152.
 */
package edu.stanford.nlp.sentiment;

import edu.stanford.nlp.rnn.RNNCoreAnnotations;
import edu.stanford.nlp.sentiment.SentimentCostAndGradient;
import edu.stanford.nlp.sentiment.SentimentModel;
import edu.stanford.nlp.sentiment.SentimentUtils;
import edu.stanford.nlp.stats.ClassicCounter;
import edu.stanford.nlp.stats.Counter;
import edu.stanford.nlp.stats.IntCounter;
import edu.stanford.nlp.trees.Tree;
import edu.stanford.nlp.util.Generics;
import edu.stanford.nlp.util.StringUtils;
import java.text.DecimalFormat;
import java.text.NumberFormat;
import java.util.List;
import java.util.Set;
import java.util.TreeSet;

/**
 * Evaluates a {@link SentimentModel} against trees that carry gold sentiment labels.
 *
 * <p>Each evaluated tree is forward-propagated through the model. The predicted class of every
 * internal node is then compared with its gold class in three ways:
 * <ul>
 *   <li>over all internal nodes;</li>
 *   <li>at the root only;</li>
 *   <li>bucketed by the number of pre-terminals (words) the node spans.</li>
 * </ul>
 * Counts accumulate across calls to {@link #eval(Tree)} / {@link #eval(List)} until
 * {@link #reset()} is called.
 */
public class Evaluate {
    /** Used only to forward-propagate trees so that predicted classes are attached. */
    final SentimentCostAndGradient cag;
    final SentimentModel model;
    /** Internal nodes whose predicted class equals the gold class. */
    int labelsCorrect;
    /** Internal nodes whose predicted class differs from the gold class. */
    int labelsIncorrect;
    /** Counts over all internal nodes, indexed as {@code [gold][predicted]}. */
    int[][] labelConfusion;
    /** Roots whose predicted class equals the gold class. */
    int rootLabelsCorrect;
    /** Roots whose predicted class differs from the gold class. */
    int rootLabelsIncorrect;
    /** Counts over roots only, indexed as {@code [gold][predicted]}. */
    int[][] rootLabelConfusion;
    /** Correct predictions keyed by the number of words the node spans. */
    IntCounter<Integer> lengthLabelsCorrect;
    /** Incorrect predictions keyed by the number of words the node spans. */
    IntCounter<Integer> lengthLabelsIncorrect;
    private static final NumberFormat NF = new DecimalFormat("0.000000");
    /**
     * Classes grouped as "negative" by the approximate metrics. Together with
     * {@link #POS_CLASSES}, these require the model to have at least 5 classes.
     */
    private static final int[] NEG_CLASSES = new int[]{0, 1};
    /** Classes grouped as "positive" by the approximate metrics. */
    private static final int[] POS_CLASSES = new int[]{3, 4};

    /**
     * Creates an evaluator for {@code model} with all counts zeroed.
     *
     * @param model the model whose predictions are evaluated
     */
    public Evaluate(SentimentModel model) {
        this.model = model;
        // No training trees are supplied: this instance is only used for forward propagation.
        this.cag = new SentimentCostAndGradient(model, null);
        this.reset();
    }

    /** Clears all accumulated counts; confusion matrices are sized to the model's class count. */
    public void reset() {
        int numClasses = this.model.op.numClasses;
        this.labelsCorrect = 0;
        this.labelsIncorrect = 0;
        this.labelConfusion = new int[numClasses][numClasses];
        this.rootLabelsCorrect = 0;
        this.rootLabelsIncorrect = 0;
        this.rootLabelConfusion = new int[numClasses][numClasses];
        this.lengthLabelsCorrect = new IntCounter<Integer>();
        this.lengthLabelsIncorrect = new IntCounter<Integer>();
    }

    /**
     * Evaluates every tree in {@code trees}, accumulating into the current counts.
     *
     * @param trees trees annotated with gold classes
     */
    public void eval(List<Tree> trees) {
        for (Tree tree : trees) {
            this.eval(tree);
        }
    }

    /**
     * Forward-propagates {@code tree} and accumulates node, root and length statistics.
     *
     * @param tree a tree annotated with gold classes
     */
    public void eval(Tree tree) {
        this.cag.forwardPropagateTree(tree);
        this.countTree(tree);
        this.countRoot(tree);
        this.countLengthAccuracy(tree);
    }

    /**
     * Recursively tallies correct/incorrect predictions keyed by span length.
     *
     * @return the number of pre-terminals (words) spanned by {@code tree}; 0 for a leaf
     */
    private int countLengthAccuracy(Tree tree) {
        if (tree.isLeaf()) {
            return 0;
        }
        int gold = RNNCoreAnnotations.getGoldClass(tree);
        int predicted = RNNCoreAnnotations.getPredictedClass(tree);
        int length;
        if (tree.isPreTerminal()) {
            length = 1;
        } else {
            length = 0;
            for (Tree child : tree.children()) {
                length += this.countLengthAccuracy(child);
            }
        }
        // A negative gold class cannot be a real label; skip the node but still report its length.
        if (gold >= 0) {
            if (gold == predicted) {
                this.lengthLabelsCorrect.incrementCount(length);
            } else {
                this.lengthLabelsIncorrect.incrementCount(length);
            }
        }
        return length;
    }

    /** Recursively tallies every internal node of {@code tree} into the node-level statistics. */
    private void countTree(Tree tree) {
        if (tree.isLeaf()) {
            return;
        }
        for (Tree child : tree.children()) {
            this.countTree(child);
        }
        int gold = RNNCoreAnnotations.getGoldClass(tree);
        int predicted = RNNCoreAnnotations.getPredictedClass(tree);
        // NOTE(review): presumably a negative gold class marks an unlabeled node; such a value
        // cannot index the confusion matrix, so the node is skipped rather than crashing.
        if (gold < 0) {
            return;
        }
        if (gold == predicted) {
            ++this.labelsCorrect;
        } else {
            ++this.labelsIncorrect;
        }
        ++this.labelConfusion[gold][predicted];
    }

    /** Tallies only the root of {@code tree} into the root-level statistics. */
    private void countRoot(Tree tree) {
        int gold = RNNCoreAnnotations.getGoldClass(tree);
        int predicted = RNNCoreAnnotations.getPredictedClass(tree);
        // Same handling as countTree: a negative gold class cannot index the confusion matrix.
        if (gold < 0) {
            return;
        }
        if (gold == predicted) {
            ++this.rootLabelsCorrect;
        } else {
            ++this.rootLabelsIncorrect;
        }
        ++this.rootLabelConfusion[gold][predicted];
    }

    /** Returns the fraction of internal nodes labeled exactly right (NaN if none were counted). */
    public double exactNodeAccuracy() {
        return (double) this.labelsCorrect / (double) (this.labelsCorrect + this.labelsIncorrect);
    }

    /** Returns the fraction of roots labeled exactly right (NaN if none were counted). */
    public double exactRootAccuracy() {
        return (double) this.rootLabelsCorrect / (double) (this.rootLabelsCorrect + this.rootLabelsIncorrect);
    }

    /**
     * Returns, for every span length seen, the fraction of nodes of that length labeled correctly.
     *
     * @return a counter mapping span length to accuracy in [0, 1]
     */
    public Counter<Integer> lengthAccuracies() {
        Set<Integer> keys = Generics.newHashSet();
        keys.addAll(this.lengthLabelsCorrect.keySet());
        keys.addAll(this.lengthLabelsIncorrect.keySet());
        ClassicCounter<Integer> results = new ClassicCounter<Integer>();
        for (Integer key : keys) {
            double correct = this.lengthLabelsCorrect.getCount(key);
            double incorrect = this.lengthLabelsIncorrect.getCount(key);
            results.setCount(key, correct / (correct + incorrect));
        }
        return results;
    }

    /** Prints {@link #lengthAccuracies()} to stderr in increasing order of span length. */
    public void printLengthAccuracies() {
        Counter<Integer> accuracies = this.lengthAccuracies();
        TreeSet<Integer> keys = Generics.newTreeSet();
        keys.addAll(accuracies.keySet());
        System.err.println("Label accuracy at various lengths:");
        for (Integer key : keys) {
            System.err.println(StringUtils.padLeft(Integer.toString(key), 4) + ": " + NF.format(accuracies.getCount(key)));
        }
    }

    /**
     * Returns per-group node accuracy, {@code [negative, positive]}, ignoring classes outside
     * both groups.
     *
     * @throws ArrayIndexOutOfBoundsException if the model has fewer than 5 classes
     */
    public double[] approxNegPosAccuracy() {
        return Evaluate.approxAccuracy(this.labelConfusion, NEG_CLASSES, POS_CLASSES);
    }

    /**
     * Returns node accuracy with the negative and positive groups combined.
     *
     * @throws ArrayIndexOutOfBoundsException if the model has fewer than 5 classes
     */
    public double approxNegPosCombinedAccuracy() {
        return Evaluate.approxCombinedAccuracy(this.labelConfusion, NEG_CLASSES, POS_CLASSES);
    }

    /**
     * Returns per-group root accuracy, {@code [negative, positive]}.
     *
     * @throws ArrayIndexOutOfBoundsException if the model has fewer than 5 classes
     */
    public double[] approxRootNegPosAccuracy() {
        return Evaluate.approxAccuracy(this.rootLabelConfusion, NEG_CLASSES, POS_CLASSES);
    }

    /**
     * Returns root accuracy with the negative and positive groups combined.
     *
     * @throws ArrayIndexOutOfBoundsException if the model has fewer than 5 classes
     */
    public double approxRootNegPosCombinedAccuracy() {
        return Evaluate.approxCombinedAccuracy(this.rootLabelConfusion, NEG_CLASSES, POS_CLASSES);
    }

    /** Prints {@code confusion} to stderr, one gold-label row per line. */
    private static void printConfusionMatrix(String name, int[][] confusion) {
        System.err.println(name + " confusion matrix: rows are gold label, columns predicted label");
        for (int[] row : confusion) {
            for (int cell : row) {
                System.err.print(StringUtils.padLeft(cell, 10));
            }
            System.err.println();
        }
    }

    /** Returns true if every class index used by the approximate metrics fits in {@code confusion}. */
    private static boolean supportsApproxClasses(int[][] confusion) {
        for (int[] group : new int[][]{NEG_CLASSES, POS_CLASSES}) {
            for (int c : group) {
                if (c < 0 || c >= confusion.length) {
                    return false;
                }
            }
        }
        return true;
    }

    /**
     * Tallies, per class group, how many gold labels in that group were predicted inside the same
     * group (row 0 of the result) or inside some other group (row 1). Predictions falling outside
     * every group are not counted at all.
     */
    private static int[][] approxCounts(int[][] confusion, int[]... classes) {
        int[] correct = new int[classes.length];
        int[] incorrect = new int[classes.length];
        for (int i = 0; i < classes.length; ++i) {
            for (int other = 0; other < classes.length; ++other) {
                int total = 0;
                for (int gold : classes[i]) {
                    for (int predicted : classes[other]) {
                        total += confusion[gold][predicted];
                    }
                }
                if (other == i) {
                    correct[i] += total;
                } else {
                    incorrect[i] += total;
                }
            }
        }
        return new int[][]{correct, incorrect};
    }

    /** Returns, for each class group, correct / (correct + incorrect) as defined by approxCounts. */
    private static double[] approxAccuracy(int[][] confusion, int[]... classes) {
        int[][] counts = approxCounts(confusion, classes);
        double[] results = new double[classes.length];
        for (int i = 0; i < classes.length; ++i) {
            results[i] = (double) counts[0][i] / (double) (counts[0][i] + counts[1][i]);
        }
        return results;
    }

    /** Returns the group-level accuracy with the counts of all groups summed together. */
    private static double approxCombinedAccuracy(int[][] confusion, int[]... classes) {
        int[][] counts = approxCounts(confusion, classes);
        int correct = 0;
        int incorrect = 0;
        for (int i = 0; i < classes.length; ++i) {
            correct += counts[0][i];
            incorrect += counts[1][i];
        }
        return (double) correct / (double) (correct + incorrect);
    }

    /**
     * Prints node and root counts, exact accuracies and confusion matrices to stderr. When the model
     * has enough classes, it also prints the approximate negative/positive accuracies.
     */
    public void printSummary() {
        System.err.println("EVALUATION SUMMARY");
        System.err.println("Tested " + (this.labelsCorrect + this.labelsIncorrect) + " labels");
        System.err.println("  " + this.labelsCorrect + " correct");
        System.err.println("  " + this.labelsIncorrect + " incorrect");
        System.err.println("  " + NF.format(this.exactNodeAccuracy()) + " accuracy");
        System.err.println("Tested " + (this.rootLabelsCorrect + this.rootLabelsIncorrect) + " roots");
        System.err.println("  " + this.rootLabelsCorrect + " correct");
        System.err.println("  " + this.rootLabelsIncorrect + " incorrect");
        System.err.println("  " + NF.format(this.exactRootAccuracy()) + " accuracy");
        Evaluate.printConfusionMatrix("Label", this.labelConfusion);
        Evaluate.printConfusionMatrix("Root label", this.rootLabelConfusion);
        // The neg/pos groups use hard-coded class indices; indexing them on a smaller model would crash.
        if (!supportsApproxClasses(this.labelConfusion)) {
            System.err.println("Approximate negative/positive accuracies skipped: model has only "
                    + this.labelConfusion.length + " classes");
            return;
        }
        double[] approxLabelAccuracy = this.approxNegPosAccuracy();
        System.err.println("Approximate negative label accuracy: " + NF.format(approxLabelAccuracy[0]));
        System.err.println("Approximate positive label accuracy: " + NF.format(approxLabelAccuracy[1]));
        System.err.println("Combined approximate label accuracy: " + NF.format(this.approxNegPosCombinedAccuracy()));
        double[] approxRootLabelAccuracy = this.approxRootNegPosAccuracy();
        System.err.println("Approximate negative root label accuracy: " + NF.format(approxRootLabelAccuracy[0]));
        System.err.println("Approximate positive root label accuracy: " + NF.format(approxRootLabelAccuracy[1]));
        System.err.println("Combined approximate root label accuracy: " + NF.format(this.approxRootNegPosCombinedAccuracy()));
    }

    /**
     * Command-line entry point.
     *
     * @param args {@code args[0]} is the serialized model path, {@code args[1]} the gold tree file
     */
    public static void main(String[] args) {
        if (args.length < 2) {
            System.err.println("Usage: java edu.stanford.nlp.sentiment.Evaluate <modelPath> <treePath>");
            System.exit(1);
        }
        String modelPath = args[0];
        String treePath = args[1];
        List<Tree> trees = SentimentUtils.readTreesWithGoldLabels(treePath);
        SentimentModel model = SentimentModel.loadSerialized(modelPath);
        Evaluate eval = new Evaluate(model);
        eval.eval(trees);
        eval.printSummary();
    }
}

