/*
 * Decompiled with CFR 0.152.
 */
package edu.stanford.nlp.parser.nndep;

import edu.stanford.nlp.io.IOUtils;
import edu.stanford.nlp.io.RuntimeIOException;
import edu.stanford.nlp.ling.CoreAnnotations;
import edu.stanford.nlp.ling.CoreLabel;
import edu.stanford.nlp.ling.HasTag;
import edu.stanford.nlp.ling.HasWord;
import edu.stanford.nlp.ling.IndexedWord;
import edu.stanford.nlp.ling.TaggedWord;
import edu.stanford.nlp.ling.Word;
import edu.stanford.nlp.parser.nndep.ArcStandard;
import edu.stanford.nlp.parser.nndep.Classifier;
import edu.stanford.nlp.parser.nndep.Config;
import edu.stanford.nlp.parser.nndep.Configuration;
import edu.stanford.nlp.parser.nndep.Dataset;
import edu.stanford.nlp.parser.nndep.DependencyTree;
import edu.stanford.nlp.parser.nndep.ParsingSystem;
import edu.stanford.nlp.parser.nndep.Util;
import edu.stanford.nlp.process.DocumentPreprocessor;
import edu.stanford.nlp.stats.Counters;
import edu.stanford.nlp.stats.IntCounter;
import edu.stanford.nlp.tagger.maxent.MaxentTagger;
import edu.stanford.nlp.trees.EnglishGrammaticalStructure;
import edu.stanford.nlp.trees.GrammaticalRelation;
import edu.stanford.nlp.trees.GrammaticalStructure;
import edu.stanford.nlp.trees.TreeGraphNode;
import edu.stanford.nlp.trees.TypedDependency;
import edu.stanford.nlp.util.CoreMap;
import edu.stanford.nlp.util.StringUtils;
import edu.stanford.nlp.util.Timing;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.PrintWriter;
import java.io.Writer;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Properties;
import java.util.Random;
import java.util.stream.Collectors;

public class DependencyParser {
    /** Location of the default model file loaded when no other model is specified (path string passed to the loader). */
    public static final String DEFAULT_MODEL = "edu/stanford/nlp/models/parser/nndep/PTB_Stanford_params.txt.gz";
    /** Known word strings; list position determines the word ID assigned in generateIDs(). */
    private List<String> knownWords;
    /** Known POS tags; their IDs continue numbering after the word IDs. */
    private List<String> knownPos;
    /** Known arc labels; their IDs continue numbering after the POS IDs. */
    private List<String> knownLabels;
    /** Word string -> row index in the embedding matrix. */
    private Map<String, Integer> wordIDs;
    /** POS tag -> row index in the embedding matrix. */
    private Map<String, Integer> posIDs;
    /** Arc label -> row index in the embedding matrix. */
    private Map<String, Integer> labelIDs;
    /**
     * Most frequent (feature ID, position) pairs seen in training, each encoded as
     * featureId * featureCount + position (see genTrainExamples); handed to the Classifier.
     */
    private List<Integer> preComputed;
    /** Neural network that scores transitions from a feature array. */
    private Classifier classifier;
    /** Transition system; set to an ArcStandard instance by train(). */
    private ParsingSystem system;
    /** Pre-trained embedding word -> row index into {@link #embeddings}; filled by readEmbedFile. */
    private Map<String, Integer> embedID;
    /** Pre-trained embedding vectors read by readEmbedFile; stays null when no embedding file is read. */
    private double[][] embeddings;
    /** Parser configuration built from the constructor's Properties. */
    private final Config config;
    /** Language tag attached to the GrammaticalRelations produced by predict(). */
    private final GrammaticalRelation.Language language;
    // Layout of the 48-slot feature vector: words occupy [0, 18), POS tags [18, 36), labels [36, 48).
    // NOTE(review): these constants are inlined as literals by the decompiler in this view.
    private static final int POS_OFFSET = 18;
    private static final int DEP_OFFSET = 36;
    // Child-token features start after the 3 stack + 3 buffer tokens ...
    private static final int STACK_OFFSET = 6;
    // ... and each of the two top stack elements contributes this many child slots.
    private static final int STACK_NUMBER = 6;
    // NOTE(review): not referenced in the visible part of this file; presumably used by code further down — confirm.
    private static final Map<String, Integer> numArgs = new HashMap<String, Integer>();

    /**
     * Package-private convenience constructor: every option takes its
     * {@link Config} default because an empty {@link Properties} is supplied.
     */
    DependencyParser() {
        this(new Properties());
    }

    public DependencyParser(Properties properties) {
        this.config = new Config(properties);
        switch (this.config.language) {
            case English: {
                this.language = GrammaticalRelation.Language.English;
                break;
            }
            case Chinese: {
                this.language = GrammaticalRelation.Language.Chinese;
                break;
            }
            default: {
                this.language = GrammaticalRelation.Language.Any;
            }
        }
    }

    /**
     * Looks up the embedding-row ID of a word.
     *
     * @param s the word
     * @return its ID, or the ID of {@code "-UNKNOWN-"} when the word is not in the vocabulary
     */
    public int getWordID(String s) {
        if (!wordIDs.containsKey(s)) {
            return wordIDs.get("-UNKNOWN-");
        }
        return wordIDs.get(s);
    }

    /**
     * Looks up the embedding-row ID of a POS tag.
     *
     * @param s the POS tag
     * @return its ID, or the ID of {@code "-UNKNOWN-"} when the tag was never seen
     */
    public int getPosID(String s) {
        if (!posIDs.containsKey(s)) {
            return posIDs.get("-UNKNOWN-");
        }
        return posIDs.get(s);
    }

    /**
     * Looks up the embedding-row ID of an arc label.
     * Unlike words and POS tags there is no unknown-label fallback.
     *
     * @param s the arc label
     * @return its ID
     * @throws IllegalArgumentException if the label is not in the label dictionary
     *         (previously this surfaced as an uninformative NullPointerException from auto-unboxing)
     */
    public int getLabelID(String s) {
        Integer id = this.labelIDs.get(s);
        if (id == null) {
            throw new IllegalArgumentException("Unknown dependency label: " + s);
        }
        return id;
    }

    /**
     * Extracts the 48 feature IDs for configuration {@code c} as a mutable list.
     * The layout is 18 word features, then 18 POS features, then 12 label features.
     *
     * <p>This previously duplicated the body of getFeatureArray line for line, with
     * identical slot ordering. It now delegates so the two extractors cannot drift apart.
     *
     * @param c the parser configuration to featurize
     * @return a new list of feature IDs
     */
    public List<Integer> getFeatures(Configuration c) {
        int[] featureArray = this.getFeatureArray(c);
        List<Integer> features = new ArrayList<>(featureArray.length);
        for (int id : featureArray) {
            features.add(id);
        }
        return features;
    }

    /**
     * Fills the 48-slot feature vector for configuration {@code c}.
     * Word IDs occupy [0, POS_OFFSET), POS IDs [POS_OFFSET, DEP_OFFSET) and label IDs the rest.
     *
     * @param c the parser configuration to featurize
     * @return feature IDs in the order expected by the classifier
     */
    private int[] getFeatureArray(Configuration c) {
        int[] feature = new int[48];
        // Third, second and first stack elements (in that order) in slots 0..2.
        for (int slot = 0; slot < 3; ++slot) {
            int stackToken = c.getStack(2 - slot);
            feature[slot] = getWordID(c.getWord(stackToken));
            feature[POS_OFFSET + slot] = getPosID(c.getPOS(stackToken));
        }
        // First three buffer elements in slots 3..5.
        for (int slot = 0; slot < 3; ++slot) {
            int bufferToken = c.getBuffer(slot);
            feature[3 + slot] = getWordID(c.getWord(bufferToken));
            feature[POS_OFFSET + 3 + slot] = getPosID(c.getPOS(bufferToken));
        }
        // Six child-derived tokens for each of the top two stack elements.
        for (int s = 0; s < 2; ++s) {
            int head = c.getStack(s);
            int[] children = {
                c.getLeftChild(head),
                c.getRightChild(head),
                c.getLeftChild(head, 2),
                c.getRightChild(head, 2),
                c.getLeftChild(c.getLeftChild(head)),
                c.getRightChild(c.getRightChild(head))
            };
            int base = s * STACK_NUMBER;
            for (int slot = 0; slot < children.length; ++slot) {
                int child = children[slot];
                feature[STACK_OFFSET + base + slot] = getWordID(c.getWord(child));
                feature[POS_OFFSET + STACK_OFFSET + base + slot] = getPosID(c.getPOS(child));
                feature[DEP_OFFSET + base + slot] = getLabelID(c.getLabel(child));
            }
        }
        return feature;
    }

    /**
     * Builds the training dataset by replaying the oracle transition sequence of every
     * projective gold tree. It also records the most frequent (feature, position) pairs
     * into {@code preComputed}.
     *
     * <p>For each step the label vector holds 1 for the oracle transition, 0 for other
     * applicable transitions and -1 for transitions that cannot be applied.
     *
     * @param sents training sentences, parallel to {@code trees}
     * @param trees gold dependency trees
     * @return the generated examples
     */
    public Dataset genTrainExamples(List<CoreMap> sents, List<DependencyTree> trees) {
        int numTrans = this.system.transitions.size();
        Dataset ret = new Dataset(48, numTrans);
        IntCounter<Integer> tokPosCount = new IntCounter<>();
        System.err.println("###################");
        System.err.println("Generate training examples...");
        for (int i = 0; i < sents.size(); ++i) {
            if (i > 0) {
                if (i % 1000 == 0) {
                    System.err.print(i + " ");
                }
                if (i % 10000 == 0 || i == sents.size() - 1) {
                    System.err.println();
                }
            }
            DependencyTree tree = trees.get(i);
            // Non-projective trees are skipped (presumably the oracle cannot derive them).
            if (!tree.isProjective()) {
                continue;
            }
            Configuration c = this.system.initialConfiguration(sents.get(i));
            while (!this.system.isTerminal(c)) {
                String oracle = this.system.getOracle(c, tree);
                List<Integer> feature = this.getFeatures(c);
                List<Integer> label = new ArrayList<>(numTrans);
                for (int j = 0; j < numTrans; ++j) {
                    String transition = this.system.transitions.get(j);
                    if (transition.equals(oracle)) {
                        label.add(1);
                    } else if (this.system.canApply(c, transition)) {
                        label.add(0);
                    } else {
                        label.add(-1);
                    }
                }
                ret.addExample(feature, label);
                int numFeatures = feature.size();
                for (int j = 0; j < numFeatures; ++j) {
                    // Encode (feature ID, slot) as a single integer key.
                    tokPosCount.incrementCount(feature.get(j) * numFeatures + j);
                }
                this.system.apply(c, oracle);
            }
        }
        System.err.println("#Train Examples: " + ret.n);
        List<Integer> sortedTokens = Counters.toSortedList(tokPosCount, false);
        // Clamp to [0, size] so a negative or oversized numPreComputed cannot break subList.
        int keep = Math.max(0, Math.min(this.config.numPreComputed, sortedTokens.size()));
        this.preComputed = new ArrayList<>(sortedTokens.subList(0, keep));
        return ret;
    }

    /**
     * Rebuilds the word, POS and label ID maps from the known-token lists.
     * IDs form one contiguous range: words first, then POS tags, then labels.
     * This matches the row layout of the embedding matrix.
     */
    private void generateIDs() {
        wordIDs = new HashMap<>();
        posIDs = new HashMap<>();
        labelIDs = new HashMap<>();
        int nextId = assignSequentialIds(knownWords, wordIDs, 0);
        nextId = assignSequentialIds(knownPos, posIDs, nextId);
        assignSequentialIds(knownLabels, labelIDs, nextId);
    }

    /**
     * Maps each token to consecutive IDs starting at {@code firstId}.
     * A repeated token keeps its last ID, but the counter still advances.
     *
     * @return the first ID not used by this call
     */
    private static int assignSequentialIds(List<String> tokens, Map<String, Integer> ids, int firstId) {
        int next = firstId;
        for (String token : tokens) {
            ids.put(token, next);
            next++;
        }
        return next;
    }

    /**
     * Builds the word, POS and label vocabularies from the training data, then assigns IDs.
     *
     * <p>Resulting layouts:
     * <ul>
     *   <li>words: -UNKNOWN-, -NULL-, -ROOT-, then words passing {@code config.wordCutOff}</li>
     *   <li>POS: -UNKNOWN-, -NULL-, -ROOT-, then observed tags</li>
     *   <li>labels: -NULL-, the root label, then non-root labels</li>
     * </ul>
     *
     * @param sents training sentences (their TokensAnnotation supplies words and tags)
     * @param trees gold trees (supply arc labels)
     * @throws IllegalArgumentException if no tree has a token attached to the root, since the
     *         root label would otherwise be silently stored as {@code null}
     */
    private void genDictionaries(List<CoreMap> sents, List<DependencyTree> trees) {
        List<String> word = new ArrayList<>();
        List<String> pos = new ArrayList<>();
        List<String> label = new ArrayList<>();
        for (CoreMap sentence : sents) {
            List<CoreLabel> tokens = sentence.get(CoreAnnotations.TokensAnnotation.class);
            for (CoreLabel token : tokens) {
                word.add(token.word());
                pos.add(token.tag());
            }
        }
        String rootLabel = null;
        for (DependencyTree tree : trees) {
            for (int k = 1; k <= tree.n; ++k) {
                if (tree.getHead(k) == 0) {
                    // Root arcs are kept out of the regular label list; the last one seen wins.
                    rootLabel = tree.getLabel(k);
                } else {
                    label.add(tree.getLabel(k));
                }
            }
        }
        if (rootLabel == null) {
            throw new IllegalArgumentException("Training trees contain no root-attached token; cannot determine the root label");
        }
        this.knownWords = Util.generateDict(word, this.config.wordCutOff);
        this.knownPos = Util.generateDict(pos);
        this.knownLabels = Util.generateDict(label);
        this.knownLabels.add(0, rootLabel);
        this.knownWords.add(0, "-UNKNOWN-");
        this.knownWords.add(1, "-NULL-");
        this.knownWords.add(2, "-ROOT-");
        this.knownPos.add(0, "-UNKNOWN-");
        this.knownPos.add(1, "-NULL-");
        this.knownPos.add(2, "-ROOT-");
        this.knownLabels.add(0, "-NULL-");
        this.generateIDs();
        // Status output goes to stderr like the rest of the training progress messages.
        System.err.println("###################");
        System.err.println("#Word: " + this.knownWords.size());
        System.err.println("#POS:" + this.knownPos.size());
        System.err.println("#Label: " + this.knownLabels.size());
    }

    /**
     * Serializes the trained model as text.
     *
     * <p>File layout, in order:
     * <ol>
     *   <li>Seven header lines: dict, pos, label, embeddingSize, hiddenSize, numTokens, preComputed.</li>
     *   <li>One "token v1 ... vd" line per word, POS tag and label (rows of E).</li>
     *   <li>W1 column by column, one line per column.</li>
     *   <li>b1 on a single line.</li>
     *   <li>W2 column by column.</li>
     *   <li>The precomputed IDs, 100 per line.</li>
     * </ol>
     *
     * <p>The writer is always closed. I/O failures are reported as {@link RuntimeIOException},
     * consistent with loadModelFile and readEmbedFile. Previously they were printed to stdout
     * and swallowed, which left training looking successful without a usable model file.
     *
     * @param modelFile destination passed to IOUtils.getPrintWriter
     * @throws RuntimeIOException if the file cannot be opened or written
     */
    public void writeModelFile(String modelFile) {
        double[][] W1 = this.classifier.getW1();
        double[] b1 = this.classifier.getb1();
        double[][] W2 = this.classifier.getW2();
        double[][] E = this.classifier.getE();
        try (PrintWriter output = IOUtils.getPrintWriter(modelFile)) {
            output.write("dict=" + this.knownWords.size() + "\n");
            output.write("pos=" + this.knownPos.size() + "\n");
            output.write("label=" + this.knownLabels.size() + "\n");
            output.write("embeddingSize=" + E[0].length + "\n");
            output.write("hiddenSize=" + b1.length + "\n");
            output.write("numTokens=" + W1[0].length / E[0].length + "\n");
            output.write("preComputed=" + this.preComputed.size() + "\n");
            // Embedding rows are stored in ID order: words, then POS tags, then labels.
            int row = writeEmbeddingRows(output, this.knownWords, E, 0);
            row = writeEmbeddingRows(output, this.knownPos, E, row);
            writeEmbeddingRows(output, this.knownLabels, E, row);
            writeMatrixColumns(output, W1);
            for (int i = 0; i < b1.length; ++i) {
                output.write(String.valueOf(b1[i]));
                output.write(i == b1.length - 1 ? "\n" : " ");
            }
            writeMatrixColumns(output, W2);
            int nPre = this.preComputed.size();
            for (int i = 0; i < nPre; ++i) {
                output.write(String.valueOf(this.preComputed.get(i)));
                output.write((i + 1) % 100 == 0 || i == nPre - 1 ? "\n" : " ");
            }
            // PrintWriter never throws on write; surface any buffered failure explicitly.
            if (output.checkError()) {
                throw new IOException("Error while writing model file " + modelFile);
            }
        } catch (IOException e) {
            throw new RuntimeIOException(e);
        }
    }

    /**
     * Writes one "token v1 v2 ..." line per token, using consecutive rows of {@code E}
     * starting at {@code firstRow}.
     *
     * @return the next unused row index
     */
    private static int writeEmbeddingRows(PrintWriter output, List<String> tokens, double[][] E, int firstRow) {
        int row = firstRow;
        for (String token : tokens) {
            output.write(token);
            for (double value : E[row]) {
                output.write(" " + value);
            }
            output.write("\n");
            ++row;
        }
        return row;
    }

    /** Writes {@code matrix} transposed: one line per column, entries separated by single spaces. */
    private static void writeMatrixColumns(PrintWriter output, double[][] matrix) {
        for (int col = 0; col < matrix[0].length; ++col) {
            for (int row = 0; row < matrix.length; ++row) {
                output.write(String.valueOf(matrix[row][col]));
                output.write(row == matrix.length - 1 ? "\n" : " ");
            }
        }
    }

    /**
     * Creates a parser with default configuration and loads {@code modelFile} into it quietly.
     *
     * @param modelFile model location, as accepted by the two-argument overload
     * @return the loaded parser
     */
    public static DependencyParser loadFromModelFile(String modelFile) {
        Properties noExtraProperties = null;
        return loadFromModelFile(modelFile, noExtraProperties);
    }

    /**
     * Creates a parser, optionally configured by {@code extraProperties}, and loads
     * {@code modelFile} into it without echoing the model header.
     *
     * @param modelFile model location
     * @param extraProperties configuration overrides, or {@code null} for defaults
     * @return the loaded parser
     */
    public static DependencyParser loadFromModelFile(String modelFile, Properties extraProperties) {
        DependencyParser parser;
        if (extraProperties != null) {
            parser = new DependencyParser(extraProperties);
        } else {
            parser = new DependencyParser();
        }
        parser.loadModelFile(modelFile, false);
        return parser;
    }

    /**
     * Loads {@code modelFile} into this parser, echoing each header line to stderr.
     *
     * @param modelFile model location
     */
    public void loadModelFile(String modelFile) {
        final boolean verbose = true;
        loadModelFile(modelFile, verbose);
    }

    /**
     * Reads a text model in the layout produced by writeModelFile, builds the classifier,
     * then initializes the parser.
     *
     * <p>The reader is now closed even when parsing fails. A file that ends early raises an
     * IOException (wrapped in {@link RuntimeIOException}) naming the file, instead of a
     * NullPointerException.
     *
     * @param modelFile model location, passed to IOUtils.readerFromString
     * @param verbose whether to echo the seven header lines to stderr
     * @throws RuntimeIOException on I/O failure or premature end of file
     */
    private void loadModelFile(String modelFile, boolean verbose) {
        Timing t = new Timing();
        System.err.println("Loading depparse model file: " + modelFile + " ... ");
        try (BufferedReader input = IOUtils.readerFromString(modelFile)) {
            // Header: seven "name=value" lines in a fixed order.
            int[] header = new int[7];
            for (int k = 0; k < header.length; ++k) {
                String s = readModelLine(input, modelFile);
                if (verbose) {
                    System.err.println(s);
                }
                header[k] = Integer.parseInt(s.substring(s.indexOf('=') + 1));
            }
            int nDict = header[0];
            int nPOS = header[1];
            int nLabel = header[2];
            int eSize = header[3];
            int hSize = header[4];
            int nTokens = header[5];
            int nPreComputed = header[6];

            this.knownWords = new ArrayList<>();
            this.knownPos = new ArrayList<>();
            this.knownLabels = new ArrayList<>();
            double[][] E = new double[nDict + nPOS + nLabel][eSize];
            int row = readEmbeddingRows(input, modelFile, this.knownWords, nDict, E, 0);
            row = readEmbeddingRows(input, modelFile, this.knownPos, nPOS, E, row);
            readEmbeddingRows(input, modelFile, this.knownLabels, nLabel, E, row);
            this.generateIDs();

            double[][] W1 = new double[hSize][eSize * nTokens];
            readMatrixColumns(input, modelFile, W1);

            double[] b1 = new double[hSize];
            String[] splits = readModelLine(input, modelFile).split(" ");
            for (int i = 0; i < b1.length; ++i) {
                b1[i] = Double.parseDouble(splits[i]);
            }

            // Output layer size mirrors the training setup: 2 * nLabel - 1 transitions.
            double[][] W2 = new double[nLabel * 2 - 1][hSize];
            readMatrixColumns(input, modelFile, W2);

            this.preComputed = new ArrayList<>();
            while (this.preComputed.size() < nPreComputed) {
                for (String split : readModelLine(input, modelFile).split(" ")) {
                    this.preComputed.add(Integer.parseInt(split));
                }
            }
            this.classifier = new Classifier(this.config, E, W1, b1, W2, this.preComputed);
        } catch (IOException e) {
            throw new RuntimeIOException(e);
        }
        this.initialize(verbose);
        t.done("Initializing dependency parser");
    }

    /** Returns the next line of the model file, failing if the file ends prematurely. */
    private static String readModelLine(BufferedReader input, String modelFile) throws IOException {
        String line = input.readLine();
        if (line == null) {
            throw new IOException("Unexpected end of model file: " + modelFile);
        }
        return line;
    }

    /**
     * Reads {@code count} "token v1 ... vd" lines. Each token is appended to {@code tokens}
     * and its vector is stored in consecutive rows of {@code E} starting at {@code firstRow}.
     *
     * @return the next unused row index
     */
    private static int readEmbeddingRows(BufferedReader input, String modelFile, List<String> tokens,
                                         int count, double[][] E, int firstRow) throws IOException {
        int row = firstRow;
        for (int k = 0; k < count; ++k) {
            String[] splits = readModelLine(input, modelFile).split(" ");
            tokens.add(splits[0]);
            for (int d = 0; d < E[row].length; ++d) {
                E[row][d] = Double.parseDouble(splits[d + 1]);
            }
            ++row;
        }
        return row;
    }

    /** Fills {@code matrix} from a transposed layout: one line per column, space-separated. */
    private static void readMatrixColumns(BufferedReader input, String modelFile, double[][] matrix) throws IOException {
        for (int col = 0; col < matrix[0].length; ++col) {
            String[] splits = readModelLine(input, modelFile).split(" ");
            for (int row = 0; row < matrix.length; ++row) {
                matrix[row][col] = Double.parseDouble(splits[row]);
            }
        }
    }

    /**
     * Loads pre-trained word embeddings from a whitespace-separated "word v1 ... vd" file
     * into {@code embedID} and {@code embeddings}. With a {@code null} file name, only an
     * empty {@code embedID} is created.
     *
     * <p>The dimension is taken from the first entry. Blank lines and leading or trailing
     * whitespace are now ignored; previously a blank line crashed the parse and leading
     * whitespace shifted tokens. An empty file now yields no embeddings instead of an
     * IndexOutOfBoundsException.
     *
     * @param embedFile embedding file location, or {@code null}
     * @throws RuntimeIOException on I/O failure
     * @throws IllegalArgumentException if an entry has fewer values than the first one
     */
    private void readEmbedFile(String embedFile) {
        this.embedID = new HashMap<>();
        if (embedFile == null) {
            return;
        }
        BufferedReader input = null;
        try {
            input = IOUtils.readerFromString(embedFile);
            List<String> lines = new ArrayList<>();
            String s;
            while ((s = input.readLine()) != null) {
                String trimmed = s.trim();
                if (!trimmed.isEmpty()) {
                    lines.add(trimmed);
                }
            }
            int nWords = lines.size();
            int dim = nWords == 0 ? 0 : lines.get(0).split("\\s+").length - 1;
            this.embeddings = new double[nWords][dim];
            System.err.println("Embedding File " + embedFile + ": #Words = " + nWords + ", dim = " + dim);
            if (dim != this.config.embeddingSize) {
                System.err.println("ERROR: embedding dimension mismatch");
            }
            for (int i = 0; i < nWords; ++i) {
                String[] splits = lines.get(i).split("\\s+");
                if (splits.length - 1 < dim) {
                    throw new IllegalArgumentException("Embedding file " + embedFile + ": entry for '" + splits[0]
                            + "' has " + (splits.length - 1) + " values, expected " + dim);
                }
                this.embedID.put(splits[0], i);
                for (int j = 0; j < dim; ++j) {
                    this.embeddings[i][j] = Double.parseDouble(splits[j + 1]);
                }
            }
        } catch (IOException e) {
            throw new RuntimeIOException(e);
        } finally {
            IOUtils.closeIgnoringExceptions(input);
        }
    }

    /**
     * Trains a parser on a CoNLL file and writes the model to {@code modelFile}.
     *
     * <p>When a dev file is given, UAS is evaluated every {@code config.evalPerIter} iterations.
     * With {@code config.saveIntermediate}, every improvement is saved. At the end, the final
     * model is saved if it beats the best intermediate score.
     *
     * <p>The final model is now also saved whenever nothing was saved before. Previously a
     * UAS that never exceeded 0 (e.g. NaN from an empty dev set) meant no model file was ever
     * written.
     *
     * @param trainFile CoNLL training file
     * @param devFile CoNLL dev file, or {@code null}
     * @param modelFile output model location
     * @param embedFile pre-trained embedding file, or {@code null}
     */
    public void train(String trainFile, String devFile, String modelFile, String embedFile) {
        System.err.println("Train File: " + trainFile);
        System.err.println("Dev File: " + devFile);
        System.err.println("Model File: " + modelFile);
        System.err.println("Embedding File: " + embedFile);
        List<CoreMap> trainSents = new ArrayList<>();
        List<DependencyTree> trainTrees = new ArrayList<>();
        Util.loadConllFile(trainFile, trainSents, trainTrees);
        Util.printTreeStats("Train", trainTrees);
        List<CoreMap> devSents = new ArrayList<>();
        List<DependencyTree> devTrees = new ArrayList<>();
        if (devFile != null) {
            Util.loadConllFile(devFile, devSents, devTrees);
            Util.printTreeStats("Dev", devTrees);
        }
        this.genDictionaries(trainSents, trainTrees);
        // The transition system gets the label list without the leading "-NULL-" entry.
        List<String> lDict = new ArrayList<>(this.knownLabels);
        lDict.remove(0);
        this.system = new ArcStandard(this.config.tlp, lDict, true);
        this.setupClassifierForTraining(trainSents, trainTrees, embedFile);
        System.err.println("###################");
        this.config.printParameters();
        long startTime = System.currentTimeMillis();
        double bestUAS = 0.0;
        boolean modelSaved = false;
        for (int iter = 0; iter < this.config.maxIter; ++iter) {
            System.err.println("##### Iteration " + iter);
            Classifier.Cost cost = this.classifier.computeCostFunction(this.config.batchSize, this.config.regParameter, this.config.dropProb);
            System.err.println("Cost = " + cost.getCost() + ", Correct(%) = " + cost.getPercentCorrect());
            this.classifier.takeAdaGradientStep(cost, this.config.adaAlpha, this.config.adaEps);
            System.err.println("Elapsed Time: " + (double) (System.currentTimeMillis() - startTime) / 1000.0 + " (s)");
            if (devFile != null && iter % this.config.evalPerIter == 0) {
                this.classifier.preCompute();
                double uas = computeDevUAS(devSents, devTrees);
                System.err.println("UAS: " + uas);
                if (this.config.saveIntermediate && uas > bestUAS) {
                    System.err.printf("Exceeds best previous UAS of %f. Saving model file..%n", bestUAS);
                    bestUAS = uas;
                    this.writeModelFile(modelFile);
                    modelSaved = true;
                }
            }
            if (this.config.clearGradientsPerIter > 0 && iter % this.config.clearGradientsPerIter == 0) {
                System.err.println("Clearing gradient histories..");
                this.classifier.clearGradientHistories();
            }
        }
        this.classifier.finalizeTraining();
        if (devFile != null) {
            double uas = computeDevUAS(devSents, devTrees);
            System.err.printf("Final model UAS: %f%n", uas);
            if (uas > bestUAS) {
                System.err.printf("Exceeds best previous UAS of %f. Saving model file..%n", bestUAS);
                this.writeModelFile(modelFile);
            } else if (!modelSaved) {
                // Never leave training without a model file on disk.
                System.err.println("No intermediate model was saved. Saving final model file..");
                this.writeModelFile(modelFile);
            }
        } else {
            this.writeModelFile(modelFile);
        }
    }

    /** Parses every dev sentence with the current classifier and returns the UAS against the gold trees. */
    private double computeDevUAS(List<CoreMap> devSents, List<DependencyTree> devTrees) {
        List<DependencyTree> predicted = devSents.stream().map(this::predictInner).collect(Collectors.toList());
        return this.system.getUASScore(devSents, predicted, devTrees);
    }

    /**
     * Trains with randomly initialized word embeddings (no pre-trained embedding file).
     *
     * @param trainFile CoNLL training file
     * @param devFile CoNLL dev file, or {@code null}
     * @param modelFile output model location
     */
    public void train(String trainFile, String devFile, String modelFile) {
        String noEmbeddings = null;
        train(trainFile, devFile, modelFile, noEmbeddings);
    }

    /**
     * Trains without a dev set and without pre-trained embeddings.
     *
     * @param trainFile CoNLL training file
     * @param modelFile output model location
     */
    public void train(String trainFile, String modelFile) {
        String noDevFile = null;
        train(trainFile, noDevFile, modelFile);
    }

    /**
     * Allocates and initializes the network parameters, generates the training examples and
     * creates the classifier.
     *
     * <p>W1, b1, W2 and all embeddings without a pre-trained vector are drawn uniformly
     * from [-initRange, initRange). Draws happen in the order W1, b1, W2, E.
     *
     * <p>A word's pre-trained vector is used when the exact word, or failing that its
     * lowercase form, occurs in the embedding file. If the pre-trained dimension is smaller
     * than {@code config.embeddingSize}, the available values are copied and the rest is
     * randomly initialized. Previously this threw ArrayIndexOutOfBoundsException, even though
     * readEmbedFile only warned about the mismatch.
     *
     * @param trainSents training sentences
     * @param trainTrees gold trees
     * @param embedFile pre-trained embedding file, or {@code null}
     */
    private void setupClassifierForTraining(List<CoreMap> trainSents, List<DependencyTree> trainTrees, String embedFile) {
        double initRange = this.config.initRange;
        double[][] E = new double[this.knownWords.size() + this.knownPos.size() + this.knownLabels.size()][this.config.embeddingSize];
        double[][] W1 = new double[this.config.hiddenSize][this.config.embeddingSize * 48];
        double[] b1 = new double[this.config.hiddenSize];
        double[][] W2 = new double[this.knownLabels.size() * 2 - 1][this.config.hiddenSize];
        Random random = Util.getRandom();
        for (double[] row : W1) {
            fillUniform(row, random, initRange);
        }
        fillUniform(b1, random, initRange);
        for (double[] row : W2) {
            fillUniform(row, random, initRange);
        }
        this.readEmbedFile(embedFile);
        int foundEmbed = 0;
        for (int i = 0; i < E.length; ++i) {
            int index = -1;
            // Only word rows (the first knownWords.size() rows) can take pre-trained vectors.
            if (i < this.knownWords.size()) {
                String str = this.knownWords.get(i);
                if (this.embedID.containsKey(str)) {
                    index = this.embedID.get(str);
                } else if (this.embedID.containsKey(str.toLowerCase())) {
                    index = this.embedID.get(str.toLowerCase());
                }
            }
            if (index >= 0) {
                ++foundEmbed;
                double[] pretrained = this.embeddings[index];
                int copied = Math.min(pretrained.length, E[i].length);
                System.arraycopy(pretrained, 0, E[i], 0, copied);
                for (int d = copied; d < E[i].length; ++d) {
                    E[i][d] = random.nextDouble() * 2.0 * initRange - initRange;
                }
            } else {
                fillUniform(E[i], random, initRange);
            }
        }
        System.err.println("Found embeddings: " + foundEmbed + " / " + this.knownWords.size());
        Dataset trainSet = this.genTrainExamples(trainSents, trainTrees);
        this.classifier = new Classifier(this.config, trainSet, E, W1, b1, W2, this.preComputed);
    }

    /** Overwrites {@code values} with successive draws from U[-range, range) taken from {@code random}. */
    private static void fillUniform(double[] values, Random random, double range) {
        for (int i = 0; i < values.length; ++i) {
            values[i] = random.nextDouble() * 2.0 * range - range;
        }
    }

    /**
     * Greedily parses {@code sentence}. At each step it applies the highest-scoring transition
     * that the transition system allows.
     *
     * <p>If no applicable transition scores above negative infinity (e.g. NaN scores), the
     * first applicable transition is used instead. Previously {@code apply(c, null)} was
     * reached in that case.
     *
     * @param sentence sentence with tokens (and tags) already annotated
     * @return the predicted dependency tree
     * @throws IllegalStateException if a non-terminal configuration has no applicable transition
     */
    private DependencyTree predictInner(CoreMap sentence) {
        List<String> transitions = this.system.transitions;
        int numTrans = transitions.size();
        Configuration c = this.system.initialConfiguration(sentence);
        while (!this.system.isTerminal(c)) {
            double[] scores = this.classifier.computeScores(this.getFeatureArray(c));
            double optScore = Double.NEGATIVE_INFINITY;
            String optTrans = null;
            for (int j = 0; j < numTrans; ++j) {
                String trans = transitions.get(j);
                if (scores[j] > optScore && this.system.canApply(c, trans)) {
                    optScore = scores[j];
                    optTrans = trans;
                }
            }
            if (optTrans == null) {
                for (String trans : transitions) {
                    if (this.system.canApply(c, trans)) {
                        optTrans = trans;
                        break;
                    }
                }
                if (optTrans == null) {
                    throw new IllegalStateException("No applicable transition in a non-terminal parser configuration");
                }
            }
            this.system.apply(c, optTrans);
        }
        return c.tree;
    }

    /**
     * Parses an annotated sentence and wraps the result as a GrammaticalStructure.
     * Tokens attached to head 0 get the ROOT relation and a synthetic ROOT word (index 0).
     * All other arcs get a relation named by the predicted label.
     *
     * @param sentence sentence whose TokensAnnotation holds POS-tagged tokens
     * @return the typed-dependency structure
     * @throws IllegalStateException if no model has been loaded or trained
     */
    public GrammaticalStructure predict(CoreMap sentence) {
        if (this.system == null) {
            throw new IllegalStateException("Parser has not been  loaded and initialized; first load a model.");
        }
        DependencyTree result = this.predictInner(sentence);
        List<CoreLabel> tokens = sentence.get(CoreAnnotations.TokensAnnotation.class);
        List<TypedDependency> dependencies = new ArrayList<>();
        IndexedWord root = new IndexedWord(new Word("ROOT"));
        root.set(CoreAnnotations.IndexAnnotation.class, 0);
        for (int i = 1; i <= result.n; ++i) {
            int head = result.getHead(i);
            String label = result.getLabel(i);
            // Tree positions are 1-based; the token list is 0-based.
            IndexedWord thisWord = new IndexedWord(tokens.get(i - 1));
            IndexedWord headWord;
            GrammaticalRelation relation;
            if (head == 0) {
                headWord = root;
                relation = GrammaticalRelation.ROOT;
            } else {
                headWord = new IndexedWord(tokens.get(head - 1));
                relation = new GrammaticalRelation(this.language, label, null, GrammaticalRelation.DEPENDENT);
            }
            dependencies.add(new TypedDependency(relation, headWord, thisWord));
        }
        TreeGraphNode rootNode = new TreeGraphNode(root);
        return new EnglishGrammaticalStructure(dependencies, rootNode);
    }

    /**
     * Parses a list of POS-tagged words.
     *
     * <p>CoreLabel inputs are used directly; they receive a 1-based index, which mutates the
     * caller's objects. Other {@link HasTag} words are copied into fresh CoreLabels.
     *
     * <p>A missing tag is now rejected for every input kind. Previously only CoreLabels were
     * checked, and other HasTag words with a {@code null} tag slipped through.
     *
     * @param sentence the words to parse, in order
     * @return the typed-dependency structure
     * @throws IllegalArgumentException if a word carries no POS tag
     */
    public GrammaticalStructure predict(List<? extends HasWord> sentence) {
        CoreLabel sentenceLabel = new CoreLabel();
        List<CoreLabel> tokens = new ArrayList<>();
        int i = 1;
        for (HasWord hasWord : sentence) {
            CoreLabel label;
            if (hasWord instanceof CoreLabel) {
                label = (CoreLabel) hasWord;
            } else if (hasWord instanceof HasTag) {
                label = new CoreLabel();
                label.setValue(hasWord.word());
                label.setWord(hasWord.word());
                label.setTag(((HasTag) hasWord).tag());
            } else {
                throw new IllegalArgumentException("Parser requires words with part-of-speech tag annotations");
            }
            if (label.tag() == null) {
                throw new IllegalArgumentException("Parser requires words with part-of-speech tag annotations");
            }
            label.setIndex(i);
            ++i;
            tokens.add(label);
        }
        sentenceLabel.set(CoreAnnotations.TokensAnnotation.class, tokens);
        return this.predict(sentenceLabel);
    }

    /**
     * Parses every sentence of a CoNLL-format test file and scores the result against the gold
     * trees read from the same file.
     *
     * <p>Prints UAS/LAS (both excluding punctuation) and parsing throughput to standard error.
     * Throughput is measured from the start of this method, so it includes file loading and
     * evaluation.
     *
     * @param testFile path of the CoNLL file containing gold dependencies
     * @param outFile if non-null, path to which the predicted trees are written in CoNLL format
     * @return the {@code "LASwoPunc"} entry of the evaluation result
     */
    public double testCoNLL(String testFile, String outFile) {
        System.err.println("Test File: " + testFile);
        Timing timer = new Timing();
        ArrayList<CoreMap> testSents = new ArrayList<CoreMap>();
        ArrayList<DependencyTree> testTrees = new ArrayList<DependencyTree>();
        Util.loadConllFile(testFile, testSents, testTrees);

        // Count sentences and tokens for the throughput report.
        int numWords = 0;
        int numSentences = 0;
        for (CoreMap testSent : testSents) {
            ++numSentences;
            List<CoreLabel> sentTokens = testSent.get(CoreAnnotations.TokensAnnotation.class);
            numWords += sentTokens.size();
        }

        List<DependencyTree> predicted = testSents.stream().map(this::predictInner).collect(Collectors.toList());
        Map<String, Double> result = this.system.evaluate(testSents, predicted, testTrees);
        double lasNoPunc = result.get("LASwoPunc");
        System.err.printf("UAS = %.4f%n", result.get("UASwoPunc"));
        System.err.printf("LAS = %.4f%n", lasNoPunc);

        long millis = timer.stop();
        double seconds = (double) millis / 1000.0;
        double wordspersec = (double) numWords / seconds;
        double sentspersec = (double) numSentences / seconds;
        // This method parses (it does not tag), so report it as such.
        System.err.printf("%s parsed %d words in %d sentences in %.1fs at %.1f w/s, %.1f sent/s.%n", StringUtils.getShortClassName(this), numWords, numSentences, seconds, wordspersec, sentspersec);
        if (outFile != null) {
            Util.writeConllFile(outFile, testSents, predicted);
        }
        return lasNoPunc;
    }

    /**
     * Tokenizes, POS-tags and dependency-parses raw text.
     *
     * <p>Sentence splitting and tokenization are configured from {@code config.tlp},
     * {@code config.escaper} and {@code config.sentenceDelimiter}. Tagging uses a
     * {@link MaxentTagger} loaded from {@code config.tagger}. Every sentence is tagged first, then
     * each one is parsed. Its typed dependencies are written one per line to {@code output},
     * followed by a blank line. Timing summaries go to standard error. The reported tagging time
     * includes loading the tagger model.
     *
     * <p>The caller keeps ownership of both streams. This method flushes {@code output} but does
     * not close either stream.
     *
     * @param input raw text to parse
     * @param output destination for the typed dependencies
     */
    private void parseTextFile(BufferedReader input, PrintWriter output) {
        DocumentPreprocessor preprocessor = new DocumentPreprocessor(input);
        preprocessor.setSentenceFinalPuncWords(this.config.tlp.sentenceFinalPunctuationWords());
        preprocessor.setEscaper(this.config.escaper);
        preprocessor.setSentenceDelimiter(this.config.sentenceDelimiter);
        preprocessor.setTokenizerFactory(this.config.tlp.getTokenizerFactory());

        Timing timer = new Timing();
        MaxentTagger tagger = new MaxentTagger(this.config.tagger);
        ArrayList<List<TaggedWord>> tagged = new ArrayList<List<TaggedWord>>();
        for (List<HasWord> sentence : preprocessor) {
            tagged.add(tagger.tagSentence(sentence));
        }
        System.err.printf("Tagging completed in %.2f sec.%n", (double) timer.stop() / 1000.0);

        timer.start();
        int numSentences = 0;
        for (List<TaggedWord> taggedSentence : tagged) {
            GrammaticalStructure parse = this.predict(taggedSentence);
            Collection<TypedDependency> deps = parse.typedDependencies();
            for (TypedDependency dep : deps) {
                output.println(dep);
            }
            output.println();
            ++numSentences;
        }
        // PrintWriter swallows IOExceptions. checkError() flushes pending output and reports
        // whether any write failed, so lost output is not silent.
        if (output.checkError()) {
            System.err.println("Warning: an error occurred while writing parser output");
        }
        long millis = timer.stop();
        double seconds = (double) millis / 1000.0;
        System.err.printf("Parsed %d sentences in %.2f seconds (%.2f sents/sec).%n", numSentences, seconds, (double) numSentences / seconds);
    }

    /**
     * Builds the transition system from the loaded label dictionary. If
     * {@code config.numPreComputed} is positive, it also asks the classifier to precompute.
     *
     * <p>Every known label except the first one is handed to {@link ArcStandard}. The first entry
     * is presumably a placeholder/null label; confirm this against the code that builds
     * {@code knownLabels}.
     *
     * @param verbose forwarded to the {@code ArcStandard} constructor
     * @throws IllegalStateException if no model has been loaded or trained yet
     */
    private void initialize(boolean verbose) {
        List<String> labels = this.knownLabels;
        if (labels == null) {
            throw new IllegalStateException("Model has not been loaded or trained");
        }
        List<String> systemLabels = new ArrayList<String>(labels);
        systemLabels.remove(0);
        this.system = new ArcStandard(this.config.tlp, systemLabels, verbose);

        boolean shouldPreCompute = this.config.numPreComputed > 0;
        if (shouldPreCompute) {
            this.classifier.preCompute();
        }
    }

    /**
     * Command-line entry point.
     *
     * <p>The actions run in this order, each only if its property is present:
     * <ul>
     *   <li>{@code -trainFile}: train a model using {@code -devFile}, {@code -model} and
     *       {@code -embedFile}.</li>
     *   <li>{@code -testFile}: load {@code -model} and evaluate on the CoNLL test file, writing
     *       predictions to {@code -outFile} if given.</li>
     *   <li>{@code -textFile}: load {@code -model} (unless already loaded) and parse raw text.
     *       {@code -} means standard input. Dependencies go to {@code -outFile}; if it is absent
     *       or {@code -}, they go to standard output.</li>
     * </ul>
     *
     * @param args command-line flags, parsed with {@link StringUtils#argsToProperties}
     * @throws RuntimeIOException if the text input or output cannot be opened
     */
    public static void main(String[] args) {
        Properties props = StringUtils.argsToProperties(args, numArgs);
        DependencyParser parser = new DependencyParser(props);
        if (props.containsKey("trainFile")) {
            parser.train(props.getProperty("trainFile"), props.getProperty("devFile"), props.getProperty("model"), props.getProperty("embedFile"));
        }
        boolean loaded = false;
        if (props.containsKey("testFile")) {
            parser.loadModelFile(props.getProperty("model"));
            loaded = true;
            parser.testCoNLL(props.getProperty("testFile"), props.getProperty("outFile"));
        }
        if (props.containsKey("textFile")) {
            if (!loaded) {
                parser.loadModelFile(props.getProperty("model"));
            }
            String encoding = parser.config.tlp.getEncoding();

            String inputFilename = props.getProperty("textFile");
            boolean fromStdin = inputFilename.equals("-");
            BufferedReader input;
            try {
                input = fromStdin ? IOUtils.readerFromStdin(encoding) : IOUtils.readerFromString(inputFilename, encoding);
            }
            catch (IOException e) {
                throw new RuntimeIOException("Error opening input file: " + inputFilename, e);
            }

            String outputFilename = props.getProperty("outFile");
            boolean toStdout = outputFilename == null || outputFilename.equals("-");
            PrintWriter output;
            try {
                output = toStdout ? IOUtils.encodedOutputStreamPrintWriter(System.out, encoding, true) : IOUtils.getPrintWriter(outputFilename, encoding);
            }
            catch (IOException e) {
                throw new RuntimeIOException("Error opening output file", e);
            }

            try {
                parser.parseTextFile(input, output);
            }
            finally {
                // Close file-backed streams so file output is written and handles are released.
                // Standard output/input are only flushed or left open, because they belong to the JVM.
                if (toStdout) {
                    output.flush();
                } else {
                    output.close();
                }
                if (!fromStdin) {
                    try {
                        input.close();
                    }
                    catch (IOException ignored) {
                        // Input has already been consumed; failing to close it cannot affect the results.
                    }
                }
            }
        }
    }

    // Command-line flags that take exactly one value, registered (in this order) with numArgs for
    // StringUtils.argsToProperties in main.
    static {
        for (String singleValuedFlag : new String[]{"textFile", "outFile"}) {
            numArgs.put(singleValuedFlag, 1);
        }
    }
}

