/*
 * Decompiled with CFR 0.152.
 */
package edu.stanford.nlp.parser.lexparser;

import edu.stanford.nlp.ling.StringLabelFactory;
import edu.stanford.nlp.stats.ClassicCounter;
import edu.stanford.nlp.stats.Counters;
import edu.stanford.nlp.trees.BobChrisTreeNormalizer;
import edu.stanford.nlp.trees.DiskTreebank;
import edu.stanford.nlp.trees.LabeledScoredTreeFactory;
import edu.stanford.nlp.trees.PennTreeReader;
import edu.stanford.nlp.trees.Tree;
import edu.stanford.nlp.trees.TreeVisitor;
import edu.stanford.nlp.trees.Treebank;
import edu.stanford.nlp.util.Pair;
import java.text.NumberFormat;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;

/**
 * Tree visitor that gathers statistics on how informative a node's sister
 * labels are about the way the node expands.
 *
 * <p>For every visited non-leaf node that has a parent and is not labeled
 * {@code ROOT}, the sequence of its child labels (its "rewrite") is counted
 * once for the node's label alone, and once for each label occurring among its
 * left sisters and among its right sisters. {@link #printStats()} then reports,
 * for every (label, side, sister label) combination, the KL divergence between
 * the sister-conditioned rewrite counts and the unconditioned ones, weighted by
 * the support of the conditioned counter, and prints Java source declaring
 * {@code sisterSplitN} arrays binned by {@link #CUTOFFS}.
 */
public class SisterAnnotationStats implements TreeVisitor {
    /**
     * NOTE(review): not referenced anywhere in this class; presumably meant to
     * toggle counting of preterminal (tag) nodes, which are currently always
     * counted. Confirm before relying on it.
     */
    public static final boolean DO_TAGS = true;

    /** Node label -> counts of that node's rewrites (child-label sequences). */
    private final Map<String, ClassicCounter<List<String>>> nodeRules = new HashMap<>();
    /** Node label -> left-sister label -> counts of the node's rewrites. */
    private final Map<String, Map<String, ClassicCounter<List<String>>>> leftRules = new HashMap<>();
    /** Node label -> right-sister label -> counts of the node's rewrites. */
    private final Map<String, Map<String, ClassicCounter<List<String>>>> rightRules = new HashMap<>();

    /**
     * Ascending thresholds on (KL divergence * support) used by
     * {@link #printStats()} to bin annotated labels into the generated
     * {@code sisterSplit1..N} arrays.
     */
    public static final double[] CUTOFFS = new double[]{250.0, 500.0, 1000.0, 1500.0};

    /** Not referenced within this class. */
    public static final double SUPPCUTOFF = 100.0;

    /**
     * Accumulates statistics for every eligible node of the given tree.
     *
     * @param t root of the tree to visit
     */
    @Override
    public void visitTree(Tree t) {
        this.recurse(t, null);
    }

    /**
     * Walks {@code t} depth-first, recording sister statistics for each
     * non-leaf node that has a parent and whose label value is not
     * {@code "ROOT"}.
     *
     * @param t current node
     * @param p parent of {@code t}, or {@code null} when {@code t} is the root
     */
    public void recurse(Tree t, Tree p) {
        // Leaves have no rewrite of their own; preterminals are processed like
        // any other internal node.
        if (t.isLeaf()) {
            return;
        }
        if (p != null && !t.label().value().equals("ROOT")) {
            this.sisterCounters(t, p);
        }
        for (Tree kid : t.children()) {
            this.recurse(kid, t);
        }
    }

    /**
     * Returns the label values of the children of {@code p} that precede
     * {@code t}, ordered from the sister nearest to {@code t} outwards.
     *
     * <p>{@code t} is located by reference identity first. Only when {@code t}
     * is not itself a child of {@code p} does the search fall back to
     * {@code equals}, stopping at the leftmost equal child. If no child
     * matches at all, every child is treated as a left sister.
     *
     * @param t the node whose left sisters are wanted
     * @param p the parent of {@code t}; may be {@code null}
     * @return a new mutable list, empty when {@code p} is {@code null}
     */
    public static List<String> leftSisterLabels(Tree t, Tree p) {
        List<String> labels = new ArrayList<>();
        if (p == null) {
            return labels;
        }
        Tree[] kids = p.children();
        // NOTE(review): Tree.equals appears to compare structurally (confirm
        // against the Tree documentation); trying identity first keeps an
        // earlier, equal-looking sibling from truncating the result.
        int stop = identityIndex(kids, t);
        if (stop < 0) {
            stop = 0;
            while (stop < kids.length && !kids[stop].equals(t)) {
                ++stop;
            }
        }
        for (int i = stop - 1; i >= 0; --i) {
            labels.add(kids[i].label().value());
        }
        return labels;
    }

    /**
     * Returns the label values of the children of {@code p} that follow
     * {@code t}, ordered from the last child of {@code p} inwards.
     *
     * <p>{@code t} is located by reference identity first. Only when {@code t}
     * is not itself a child of {@code p} does the search fall back to
     * {@code equals}, stopping at the rightmost equal child. If no child
     * matches at all, every child is treated as a right sister.
     *
     * @param t the node whose right sisters are wanted
     * @param p the parent of {@code t}; may be {@code null}
     * @return a new mutable list, empty when {@code p} is {@code null}
     */
    public static List<String> rightSisterLabels(Tree t, Tree p) {
        List<String> labels = new ArrayList<>();
        if (p == null) {
            return labels;
        }
        Tree[] kids = p.children();
        // See leftSisterLabels: identity first, equals only as a fallback.
        int stop = identityIndex(kids, t);
        if (stop < 0) {
            stop = kids.length - 1;
            while (stop >= 0 && !kids[stop].equals(t)) {
                --stop;
            }
        }
        for (int i = kids.length - 1; i > stop; --i) {
            labels.add(kids[i].label().value());
        }
        return labels;
    }

    /** Returns the index of the element that is the very same object as {@code t}, or -1. */
    private static int identityIndex(Tree[] kids, Tree t) {
        for (int i = 0; i < kids.length; ++i) {
            if (kids[i] == t) {
                return i;
            }
        }
        return -1;
    }

    /**
     * Returns the label values of the children of {@code t}, left to right.
     *
     * @param t the node whose children are listed
     * @return a new mutable list with one entry per child
     */
    public static List<String> kidLabels(Tree t) {
        Tree[] kids = t.children();
        List<String> labels = new ArrayList<>(kids.length);
        for (Tree kid : kids) {
            labels.add(kid.label().value());
        }
        return labels;
    }

    /**
     * Records the rewrite of {@code t} in the unconditioned counter for its
     * label and in the counters of every left and right sister label.
     *
     * @param t the node being counted
     * @param p the parent of {@code t}
     */
    protected void sisterCounters(Tree t, Tree p) {
        List<String> rewrite = SisterAnnotationStats.kidLabels(t);
        List<String> left = SisterAnnotationStats.leftSisterLabels(t, p);
        List<String> right = SisterAnnotationStats.rightSisterLabels(t, p);
        String label = t.label().value();

        this.nodeRules.computeIfAbsent(label, k -> new ClassicCounter<>()).incrementCount(rewrite);
        // printStats looks up both side maps for every label in nodeRules, so
        // make sure they exist even when the node has no sisters on a side.
        this.rightRules.computeIfAbsent(label, k -> new HashMap<>());
        this.leftRules.computeIfAbsent(label, k -> new HashMap<>());

        this.sideCounters(label, rewrite, left, this.leftRules);
        this.sideCounters(label, rewrite, right, this.rightRules);
    }

    /**
     * Increments, for each sister label in {@code sideSisters}, the count of
     * {@code rewrite} under {@code label} in {@code sideRules}. A sister label
     * that occurs several times in the list is incremented once per occurrence.
     *
     * @param label       label of the node being counted
     * @param rewrite     child-label sequence of that node
     * @param sideSisters labels of the node's sisters on one side
     * @param sideRules   label -> sister label -> rewrite counter, updated in place;
     *                    missing entries are created on demand
     */
    protected void sideCounters(String label, List<String> rewrite, List<String> sideSisters,
            Map<String, Map<String, ClassicCounter<List<String>>>> sideRules) {
        Map<String, ClassicCounter<List<String>>> bySister =
                sideRules.computeIfAbsent(label, k -> new HashMap<>());
        for (String sis : sideSisters) {
            bySister.computeIfAbsent(sis, k -> new ClassicCounter<>()).incrementCount(rewrite);
        }
    }

    /**
     * Prints the collected statistics to standard output.
     *
     * <p>For each node label: its support, the KL divergence and support of
     * every sister-annotated variant ({@code label=l=sister} /
     * {@code label=r=sister}), and those variants sorted by descending
     * (KL * support). Afterwards prints all variants across labels in the same
     * order, followed by generated Java source: one {@code sisterSplitN} array
     * per entry of {@link #CUTOFFS} and a nested conditional expression
     * selecting among them.
     */
    public void printStats() {
        NumberFormat nf = NumberFormat.getNumberInstance();
        nf.setMaximumFractionDigits(2);

        List<Pair<String, Double>> topScores = new ArrayList<>();
        for (Map.Entry<String, ClassicCounter<List<String>>> node : this.nodeRules.entrySet()) {
            String label = node.getKey();
            ClassicCounter<List<String>> baseline = node.getValue();
            System.out.println("Node " + label + " support is " + baseline.totalCount());

            List<Pair<String, Double>> answers = new ArrayList<>();
            scoreSide(label, "=l=", baseline, this.leftRules.get(label), nf, answers);
            scoreSide(label, "=r=", baseline, this.rightRules.get(label), nf, answers);
            // Copy before sorting so the global list keeps discovery order
            // (the later sort is stable, so ties stay in that order).
            topScores.addAll(answers);

            System.out.println("----");
            System.out.println("Sorted descending support * KL");
            sortByScoreDescending(answers);
            printScores(answers, nf);
            System.out.println();
        }

        // All enriched categories, sorted by score.
        sortByScoreDescending(topScores);
        printScores(topScores, nf);
        System.out.println();

        System.out.println("  // Automatically generated by SisterAnnotationStats -- preferably don't edit");
        printGeneratedArrays(topScores);
        printGeneratedSelector();
    }

    /**
     * Prints KL divergence and support for each sister label on one side of
     * {@code label} and appends the (annotated label, KL * support) pairs to
     * {@code scores}.
     */
    private static void scoreSide(String label, String marker, ClassicCounter<List<String>> baseline,
            Map<String, ClassicCounter<List<String>>> bySister, NumberFormat nf,
            List<Pair<String, Double>> scores) {
        if (bySister == null) {
            return;
        }
        for (Map.Entry<String, ClassicCounter<List<String>>> entry : bySister.entrySet()) {
            String sis = entry.getKey();
            ClassicCounter<List<String>> conditioned = entry.getValue();
            double sisSupport = conditioned.totalCount();
            double kl = Counters.klDivergence(conditioned, baseline);
            String annotatedLabel = label + marker + sis;
            System.out.println("KL(" + annotatedLabel + "||" + label + ") = " + nf.format(kl) + "\t" + "support(" + sis + ") = " + sisSupport);
            scores.add(new Pair<>(annotatedLabel, Double.valueOf(kl * sisSupport)));
        }
    }

    /** Stable sort by the pair's second component, largest first. */
    private static void sortByScoreDescending(List<Pair<String, Double>> scores) {
        Collections.sort(scores, (a, b) -> b.second().compareTo(a.second()));
    }

    /** Prints one "name: score" line per pair, in list order. */
    private static void printScores(List<Pair<String, Double>> scores, NumberFormat nf) {
        for (Pair<String, Double> scored : scores) {
            double score = scored.second();
            System.out.println(scored.first() + ": " + nf.format(score));
        }
    }

    /**
     * Prints one {@code sisterSplitN} array declaration per cutoff. Array
     * {@code N} receives the labels whose score is at least
     * {@code CUTOFFS[N-1]} but below the next higher cutoff; labels below the
     * lowest cutoff are dropped.
     *
     * @param sortedScores scored labels, already sorted by descending score
     */
    private static void printGeneratedArrays(List<Pair<String, Double>> sortedScores) {
        List<List<String>> bins = new ArrayList<>(CUTOFFS.length);
        for (int i = 0; i < CUTOFFS.length; ++i) {
            bins.add(new ArrayList<>());
        }
        int k = CUTOFFS.length - 1;
        for (Pair<String, Double> scored : sortedScores) {
            double score = scored.second();
            // Scores only decrease from here on, so the bin index never goes back up.
            while (k > 0 && score < CUTOFFS[k]) {
                --k;
            }
            if (score < CUTOFFS[k]) {
                break;
            }
            bins.get(k).add("\"" + scored.first() + "\"");
        }
        for (int i = 0; i < CUTOFFS.length; ++i) {
            // Joining the quoted entries keeps the last closing quote intact
            // and yields a valid "{}" initializer for an empty bin.
            System.out.println("  private static String[] sisterSplit" + (i + 1) + " = new String[] {"
                    + String.join(",", bins.get(i)) + "};");
        }
    }

    /** Prints the nested conditional that picks one of the generated arrays. */
    private static void printGeneratedSelector() {
        StringBuilder expr = new StringBuilder("  public static String[] sisterSplit = ");
        for (int i = CUTOFFS.length; i > 1; --i) {
            expr.append("selectiveSisterSplit").append(i).append(" ? sisterSplit").append(i).append(" : (");
        }
        expr.append("sisterSplit1");
        // Close exactly as many parentheses as were opened above.
        for (int i = CUTOFFS.length; i > 1; --i) {
            expr.append(')');
        }
        expr.append(';');
        System.out.println(expr);
    }

    /**
     * Command-line entry point: loads the treebank at {@code args[0]}, reading
     * it with the encoding in {@code args[1]} (default {@code UTF-8}), collects
     * statistics over all of its trees and prints them.
     *
     * @param args treebank path, optionally followed by a character encoding
     */
    public static void main(String[] args) {
        // Quick sanity output: divergence of a counter from itself.
        ClassicCounter<String> c = new ClassicCounter<>();
        c.setCount("A", 0.0);
        c.setCount("B", 1.0);
        double d = Counters.klDivergence(c, c);
        System.out.println("KL Divergence: " + d);

        if (args.length < 1) {
            System.out.println("Usage: SisterAnnotationStats treebankPath [encoding]");
            return;
        }
        String encoding = args.length > 1 ? args[1] : "UTF-8";

        SisterAnnotationStats pas = new SisterAnnotationStats();
        Treebank treebank = new DiskTreebank(in -> new PennTreeReader(in, new LabeledScoredTreeFactory(new StringLabelFactory()), new BobChrisTreeNormalizer()), encoding);
        treebank.loadPath(args[0]);
        treebank.apply(pas);
        pas.printStats();
    }
}

