/*
 * Decompiled with CFR 0.152.
 */
package edu.stanford.nlp.parser.lexparser;

import edu.stanford.nlp.ling.TaggedWord;
import edu.stanford.nlp.stats.ClassicCounter;
import edu.stanford.nlp.trees.Tree;
import edu.stanford.nlp.util.Pair;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Set;

/**
 * Collects word/tag statistics from tagged training data and, in {@link #finishTraining()},
 * derives a per-tag log-probability estimate for unseen words.
 *
 * <p>For each tag the estimate is {@code log(r1(tag) / (tagCount(tag) * r0(tag)))}, where
 * {@code r1(tag)} is the number of word types seen with the tag with a count of exactly 1.0, and
 * {@code r0(tag)} is the number of seen word types never observed with the tag. This is a
 * Good-Turing-style estimate: the singleton mass for the tag is spread evenly over the word types
 * never seen with it. Results are stored in {@link #unknownGT}.
 */
public class UnknownGTTrainer {
    /** Weighted count of every observed (word, tag) pair. */
    ClassicCounter<Pair<String, String>> wtCount = new ClassicCounter<Pair<String, String>>();
    /** Weighted count of every observed tag. */
    ClassicCounter<String> tagCount = new ClassicCounter<String>();
    /** Per tag: number of word types whose (word, tag) count is exactly 1.0; filled by finishTraining. */
    ClassicCounter<String> r1 = new ClassicCounter<String>();
    /** Per tag: number of seen word types never observed with that tag; filled by finishTraining. */
    ClassicCounter<String> r0 = new ClassicCounter<String>();
    /** Every word type observed during training, regardless of tag. */
    Set<String> seenWords = new HashSet<String>();
    /** Total weight of all training tokens. */
    double tokens = 0.0;
    /** Per tag: natural-log unseen-word probability estimate; filled by finishTraining. */
    HashMap<String, Float> unknownGT = new HashMap<String, Float>();

    /**
     * Accumulates statistics from each tree with weight 1.0.
     *
     * @param trees training trees
     */
    public void train(Collection<Tree> trees) {
        this.train(trees, 1.0);
    }

    /**
     * Accumulates statistics from each tree with the given weight.
     *
     * @param trees training trees
     * @param weight weight added per token
     */
    public void train(Collection<Tree> trees, double weight) {
        for (Tree t : trees) {
            this.train(t, weight);
        }
    }

    /**
     * Accumulates statistics from every tagged word in the tree's tagged yield.
     *
     * @param tree a training tree
     * @param weight weight added per token
     */
    public void train(Tree tree, double weight) {
        for (TaggedWord word : tree.taggedYield()) {
            this.train(word, weight);
        }
    }

    /**
     * Records a single tagged token.
     *
     * <p>Adds {@code weight} to the token total, to the (word, tag) count and to the tag count,
     * and records the word as seen.
     *
     * @param tw the tagged token
     * @param weight weight of this token
     */
    public void train(TaggedWord tw, double weight) {
        this.tokens += weight;
        String word = tw.word();
        String tag = tw.tag();
        Pair<String, String> wt = new Pair<String, String>(word, tag);
        this.wtCount.incrementCount(wt, weight);
        this.tagCount.incrementCount(tag, weight);
        this.seenWords.add(word);
    }

    /**
     * Prints summary counts to stderr, then recomputes {@link #r1}, {@link #r0} and
     * {@link #unknownGT} from the statistics gathered so far.
     *
     * <p>Safe to call repeatedly: {@code r1} and {@code r0} are rebuilt from scratch each time,
     * and every tag's entry in {@code unknownGT} is overwritten.
     */
    public void finishTraining() {
        System.err.println("Total tokens: " + this.tokens);
        System.err.println("Total WordTag types: " + this.wtCount.keySet().size());
        System.err.println("Total tag types: " + this.tagCount.keySet().size());
        System.err.println("Total word types: " + this.seenWords.size());
        // Start from empty counters so a repeated call does not double-count.
        this.r1.clear();
        this.r0.clear();
        this.countSingletons();
        this.countUnseenWordTypes();
        this.computeLogProbs();
    }

    /** Fills r1: for each tag, the number of word types whose (word, tag) count is exactly 1.0. */
    private void countSingletons() {
        for (Pair<String, String> wt : this.wtCount.keySet()) {
            if (this.wtCount.getCount(wt) == 1.0) {
                this.r1.incrementCount(wt.second());
            }
        }
    }

    /** Fills r0: for each tag, the number of seen word types never observed with that tag. */
    private void countUnseenWordTypes() {
        // Each wtCount key is a distinct (word, tag) pair whose word train() also put in
        // seenWords. So the keys per tag give the distinct words seen with that tag, and the
        // rest of the vocabulary was never seen with it. This avoids testing every
        // (tag, word) combination.
        ClassicCounter<String> wordTypesPerTag = new ClassicCounter<String>();
        for (Pair<String, String> wt : this.wtCount.keySet()) {
            wordTypesPerTag.incrementCount(wt.second());
        }
        int vocabSize = this.seenWords.size();
        for (String tag : this.tagCount.keySet()) {
            double unseen = vocabSize - wordTypesPerTag.getCount(tag);
            // Only record a positive count, so a tag seen with every word has no r0 entry.
            if (unseen > 0.0) {
                this.r0.incrementCount(tag, unseen);
            }
        }
    }

    /** Stores (float) log(r1(tag) / (tagCount(tag) * r0(tag))) in unknownGT for every tag. */
    private void computeLogProbs() {
        for (String tag : this.tagCount.keySet()) {
            // NOTE(review): zero values are not special-cased and are stored as-is.
            //   r1 == 0 with a non-zero denominator gives -Infinity.
            //   A zero denominator (r0 == 0 or zero tag weight) gives +Infinity, or NaN if r1 == 0.
            double estimate =
                    this.r1.getCount(tag) / (this.tagCount.getCount(tag) * this.r0.getCount(tag));
            float logprob = (float) Math.log(estimate);
            this.unknownGT.put(tag, Float.valueOf(logprob));
        }
    }
}

