/*
 * Decompiled with CFR 0.152.
 */
package edu.stanford.nlp.parser.dvparser;

import edu.stanford.nlp.io.IOUtils;
import edu.stanford.nlp.io.RuntimeIOException;
import edu.stanford.nlp.ling.Word;
import edu.stanford.nlp.math.ArrayMath;
import edu.stanford.nlp.optimization.QNMinimizer;
import edu.stanford.nlp.parser.dvparser.CacheParseHypotheses;
import edu.stanford.nlp.parser.dvparser.DVModel;
import edu.stanford.nlp.parser.dvparser.DVModelReranker;
import edu.stanford.nlp.parser.dvparser.DVParserCostAndGradient;
import edu.stanford.nlp.parser.lexparser.ArgUtils;
import edu.stanford.nlp.parser.lexparser.EvaluateTreebank;
import edu.stanford.nlp.parser.lexparser.LexicalizedParser;
import edu.stanford.nlp.parser.lexparser.Options;
import edu.stanford.nlp.parser.lexparser.ParserQuery;
import edu.stanford.nlp.trees.CompositeTreeTransformer;
import edu.stanford.nlp.trees.MemoryTreebank;
import edu.stanford.nlp.trees.Tree;
import edu.stanford.nlp.trees.TreeTransformer;
import edu.stanford.nlp.trees.Treebank;
import edu.stanford.nlp.trees.Trees;
import edu.stanford.nlp.util.Generics;
import edu.stanford.nlp.util.Pair;
import edu.stanford.nlp.util.ScoredObject;
import edu.stanford.nlp.util.Timing;
import java.io.FileFilter;
import java.io.FileWriter;
import java.io.IOException;
import java.text.DecimalFormat;
import java.text.NumberFormat;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.IdentityHashMap;
import java.util.List;
import java.util.Random;

public class DVParser {
    // Model whose parameters are flattened, optimized and written back during training
    // (see paramsToVector / vectorToParams in executeOneTrainingBatch).
    DVModel dvModel;
    // Underlying parser: supplies the k-best hypotheses and is the object that is
    // copied and serialized when the model is saved.
    LexicalizedParser parser;
    // Options shared with the underlying parser (both constructors take it from parser.getOp()).
    Options op;
    // Two-decimal formatter used for the F1 scores in checkpoint messages and file names.
    private static final NumberFormat NF = new DecimalFormat("0.00");
    // Zero-padded, four-digit formatter used for the checkpoint counter in intermediate model file names.
    private static final NumberFormat FILENAME = new DecimalFormat("0000");
    // NOTE(review): presumably selects the optimization strategy in executeOneTrainingBatch,
    // whose switch is hard-wired to this same value (3) - confirm before changing.
    static final int MINIMIZER = 3;
    // NOTE(review): this class does not declare java.io.Serializable in this file, so this
    // id looks unused - confirm whether serialization of DVParser itself is still intended.
    private static final long serialVersionUID = 1L;

    /**
     * Returns the options object this parser was constructed with.
     *
     * @return the {@link Options} held by this instance
     */
    public Options getOp() {
        return op;
    }

    /**
     * Returns the model currently held by this parser.
     *
     * @return the {@link DVModel} field of this instance
     */
    DVModel getDVModel() {
        return dvModel;
    }

    /**
     * Reparses the words of {@code tree} with {@code parser} and collects the k best PCFG parses.
     *
     * @param parser      parser used to build the query
     * @param dvKBest     number of PCFG parses requested from the query
     * @param tree        tree whose yield is reparsed
     * @param transformer applied to every returned parse when non-null
     * @return the (optionally transformed) parses in the order the query returned them, or
     *         {@code null} when the yield has at most one word or the parse fails
     */
    public static List<Tree> getTopParsesForOneTree(LexicalizedParser parser, int dvKBest, Tree tree, TreeTransformer transformer) {
        ParserQuery query = parser.parserQuery();
        List<Word> allWords = tree.yieldWords();
        int wordCount = allWords.size();
        if (wordCount <= 1) {
            return null;
        }
        // The last word of the yield is not handed to the parser
        // (presumably a boundary symbol added upstream - TODO confirm).
        List<Word> sentence = allWords.subList(0, wordCount - 1);
        boolean parsed = pqParse(query, sentence);
        if (!parsed) {
            System.err.println("Failed to use the given parser to reparse sentence \"" + sentence + "\"");
            return null;
        }
        ArrayList<Tree> parses = new ArrayList<Tree>();
        for (ScoredObject<Tree> scored : query.getKBestPCFGParses(dvKBest)) {
            Tree candidate = scored.object();
            parses.add(transformer == null ? candidate : transformer.transformTree(candidate));
        }
        return parses;
    }

    /** Runs the query on the given words and reports whether a parse was found. */
    private static boolean pqParse(ParserQuery query, List<Word> sentence) {
        return query.parse(sentence);
    }

    /**
     * Computes the top parses for every tree in {@code trees}, keyed by tree identity.
     * A tree for which {@code getTopParsesForOneTree} returns {@code null} is still
     * stored, mapped to {@code null}.
     *
     * @param parser        parser used to reparse each tree
     * @param op            options; {@code op.trainOptions.dvKBest} is the number of parses requested
     * @param trees         trees to process
     * @param transformer   passed through to {@code getTopParsesForOneTree}
     * @param outputUpdates when true, progress is written to stderr every 10 trees and at the end
     * @return identity map from each input tree to its list of parses
     */
    static IdentityHashMap<Tree, List<Tree>> getTopParses(LexicalizedParser parser, Options op, Collection<Tree> trees, TreeTransformer transformer, boolean outputUpdates) {
        IdentityHashMap<Tree, List<Tree>> parsesByTree = new IdentityHashMap<Tree, List<Tree>>();
        for (Tree source : trees) {
            parsesByTree.put(source, DVParser.getTopParsesForOneTree(parser, op.trainOptions.dvKBest, source, transformer));
            if (outputUpdates && parsesByTree.size() % 10 == 0) {
                System.err.println("Processed " + parsesByTree.size() + " trees");
            }
        }
        if (outputUpdates) {
            System.err.println("Finished processing " + parsesByTree.size() + " trees");
        }
        return parsesByTree;
    }

    /**
     * Convenience overload: computes top parses using this instance's parser and
     * options, with progress output disabled.
     *
     * @param trees       trees to process
     * @param transformer transformer applied to each returned parse (may be null)
     * @return identity map from each input tree to its list of parses
     */
    IdentityHashMap<Tree, List<Tree>> getTopParses(List<Tree> trees, TreeTransformer transformer) {
        boolean quiet = false;
        IdentityHashMap<Tree, List<Tree>> result = DVParser.getTopParses(parser, op, trees, transformer, quiet);
        return result;
    }

    /**
     * Trains the model over {@code sentences} for {@code dvIterations} passes, shuffling the
     * trees on every pass and handing them to {@link #executeOneTrainingBatch} in batches of
     * {@code dvBatchSize}. The final batch of a pass absorbs any leftover trees.
     *
     * <p>When {@code debugOutputSeconds} is positive, a checkpoint is taken each time that much
     * time has elapsed: the model is scored on {@code testTreebank} (if given), saved under a
     * name derived from {@code modelPath} (if given), and a status line is appended to
     * {@code resultsRecordPath} (if given). Training stops early once
     * {@code maxTrainTimeSeconds} (if positive) is exceeded.
     *
     * @param sentences         training trees
     * @param compressedParses  compressed hypothesis parses keyed by training tree identity
     * @param testTreebank      treebank scored at checkpoints; may be null
     * @param modelPath         base path for intermediate models; may be null
     * @param resultsRecordPath file to which checkpoint status lines are appended; may be null
     * @throws IOException if the results record cannot be written
     */
    public void train(List<Tree> sentences, IdentityHashMap<Tree, byte[]> compressedParses, Treebank testTreebank, String modelPath, String resultsRecordPath) throws IOException {
        Timing timing = new Timing();
        // 1000L forces the multiplication into long arithmetic so large second counts cannot overflow.
        long maxTrainTimeMillis = this.op.trainOptions.maxTrainTimeSeconds * 1000L;
        long nextDebugCycle = this.op.trainOptions.debugOutputSeconds * 1000L;
        int debugCycle = 0;
        double bestLabelF1 = 0.0;
        if (this.op.trainOptions.useContextWords) {
            for (Tree tree : sentences) {
                Trees.convertToCoreLabels(tree);
                tree.setSpans();
            }
        }
        // Running sum of squared gradients, shared across all batches (seeded with 1.0).
        double[] sumGradSquare = new double[this.dvModel.totalParamSize()];
        Arrays.fill(sumGradSquare, 1.0);
        int batchSize = this.op.trainOptions.dvBatchSize;
        // The last batch absorbs the remainder, so the number of batches is the number of
        // full batches (at least one). Using size / batchSize + 1 here would add a final
        // batch that is either empty or revisits the trailing trees.
        int numBatches = Math.max(1, sentences.size() / batchSize);
        System.err.println("Training on " + sentences.size() + " trees in " + numBatches + " batches");
        System.err.println("Times through each training batch: " + this.op.trainOptions.dvIterations);
        System.err.println("QN iterations per batch: " + this.op.trainOptions.qnIterationsPerBatch);
        for (int iter = 0; iter < this.op.trainOptions.dvIterations; ++iter) {
            ArrayList<Tree> shuffledSentences = new ArrayList<Tree>(sentences);
            Collections.shuffle(shuffledSentences, this.dvModel.rand);
            for (int batch = 0; batch < numBatches; ++batch) {
                System.err.println("======================================");
                System.err.println("Iteration " + iter + " batch " + batch);
                int startTree = batch * batchSize;
                int endTree = (batch + 1) * batchSize;
                // Not enough trees left for another full batch: fold them into this one.
                if (endTree + batchSize > shuffledSentences.size()) {
                    endTree = shuffledSentences.size();
                }
                this.executeOneTrainingBatch(shuffledSentences.subList(startTree, endTree), compressedParses, sumGradSquare);
                long totalElapsed = timing.report();
                System.err.println("Finished iteration " + iter + " batch " + batch + "; total training time " + totalElapsed + " ms");
                if (maxTrainTimeMillis > 0L && totalElapsed > maxTrainTimeMillis) {
                    // Leaves the batch loop only; the check after the loop ends training.
                    break;
                }
                if (nextDebugCycle <= 0L || totalElapsed <= nextDebugCycle) {
                    continue;
                }
                double tagF1 = 0.0;
                double labelF1 = 0.0;
                if (testTreebank != null) {
                    EvaluateTreebank evaluator = new EvaluateTreebank(this.attachModelToLexicalizedParser());
                    evaluator.testOnTreebank(testTreebank);
                    labelF1 = evaluator.getLBScore();
                    tagF1 = evaluator.getTagScore();
                    if (labelF1 > bestLabelF1) {
                        bestLabelF1 = labelF1;
                    }
                    System.err.println("Best label f1 on dev set so far: " + NF.format(bestLabelF1));
                }
                String tempName = null;
                if (modelPath != null) {
                    tempName = modelPath;
                    if (modelPath.endsWith(".ser.gz")) {
                        // Insert "-<cycle>-<labelF1>" before the 7-character ".ser.gz" suffix.
                        tempName = modelPath.substring(0, modelPath.length() - 7) + "-" + FILENAME.format(debugCycle) + "-" + NF.format(labelF1) + ".ser.gz";
                    }
                    this.saveModel(tempName);
                }
                String statusLine = "CHECKPOINT: iteration " + iter + " batch " + batch + " labelF1 " + NF.format(labelF1) + " tagF1 " + NF.format(tagF1) + " bestLabelF1 " + NF.format(bestLabelF1) + " model " + tempName + this.op.trainOptions + " word vectors: " + this.op.lexOptions.wordVectorFile + " numHid: " + this.op.lexOptions.numHid;
                System.err.println(statusLine);
                if (resultsRecordPath != null) {
                    FileWriter fout = new FileWriter(resultsRecordPath, true);
                    try {
                        fout.write(statusLine);
                        fout.write("\n");
                    } finally {
                        // Release the file handle even if a write fails.
                        fout.close();
                    }
                }
                ++debugCycle;
                nextDebugCycle = timing.report() + this.op.trainOptions.debugOutputSeconds * 1000L;
            }
            long totalElapsed = timing.report();
            if (maxTrainTimeMillis > 0L && totalElapsed > maxTrainTimeMillis) {
                System.err.println("Max training time exceeded, exiting");
                break;
            }
        }
    }

    /**
     * Runs one optimization step sequence over a single batch and writes the updated
     * parameters back into the model.
     *
     * <p>The strategy is chosen by the compile-time constant {@code MINIMIZER}:
     * 1 = {@link QNMinimizer}, 2 = plain gradient descent, 3 = per-parameter scaled
     * gradient descent using the accumulated squared gradients (AdaGrad-style).
     *
     * @param trainingBatch    trees in this batch
     * @param compressedParses compressed hypothesis parses keyed by tree identity
     * @param sumGradSquare    running per-parameter sum of squared gradients; updated in
     *                         place by strategy 3 so it carries over between batches
     * @throws IllegalArgumentException if {@code MINIMIZER} names no known strategy
     */
    public void executeOneTrainingBatch(List<Tree> trainingBatch, IdentityHashMap<Tree, byte[]> compressedParses, double[] sumGradSquare) {
        Timing convertTiming = new Timing();
        convertTiming.doing("Converting trees");
        IdentityHashMap<Tree, List<Tree>> topParses = CacheParseHypotheses.convertToTrees(trainingBatch, compressedParses, this.op.trainOptions.trainingThreads);
        convertTiming.done();
        DVParserCostAndGradient gcFunc = new DVParserCostAndGradient(trainingBatch, topParses, this.dvModel, this.op);
        double[] theta = this.dvModel.paramsToVector();
        // Every case ends in break: without it, selecting strategy 1 or 2 would also run
        // the strategies listed after it.
        switch (MINIMIZER) {
            case 1: {
                QNMinimizer qn = new QNMinimizer(this.op.trainOptions.qnEstimates, true);
                qn.useMinPackSearch();
                qn.useDiagonalScaling();
                qn.terminateOnAverageImprovement(true);
                qn.terminateOnNumericalZero(true);
                qn.terminateOnRelativeNorm(true);
                theta = qn.minimize(gcFunc, this.op.trainOptions.qnTolerance, theta, this.op.trainOptions.qnIterationsPerBatch);
                break;
            }
            case 2: {
                for (int i = 0; i < this.op.trainOptions.qnIterationsPerBatch; ++i) {
                    double[] grad = gcFunc.derivativeAt(theta);
                    double currCost = gcFunc.valueAt(theta);
                    System.err.println("batch cost: " + currCost);
                    ArrayMath.addMultInPlace(theta, grad, -1.0 * this.op.trainOptions.learningRate);
                }
                break;
            }
            case 3: {
                // eps keeps the denominator away from zero.
                double eps = 0.001;
                for (int i = 0; i < this.op.trainOptions.qnIterationsPerBatch; ++i) {
                    double[] gradf = gcFunc.derivativeAt(theta);
                    double currCost = gcFunc.valueAt(theta);
                    System.err.println("batch cost: " + currCost);
                    for (int feature = 0; feature < gradf.length; ++feature) {
                        sumGradSquare[feature] = sumGradSquare[feature] + gradf[feature] * gradf[feature];
                        theta[feature] = theta[feature] - this.op.trainOptions.learningRate * gradf[feature] / (Math.sqrt(sumGradSquare[feature]) + eps);
                    }
                }
                break;
            }
            default: {
                throw new IllegalArgumentException("Unsupported minimizer " + MINIMIZER);
            }
        }
        this.dvModel.vectorToParams(theta);
    }

    /**
     * Wraps an existing model together with the parser it belongs to. The options are
     * taken from the parser.
     *
     * @param model  model to use as-is
     * @param parser underlying parser; also the source of the options
     */
    public DVParser(DVModel model, LexicalizedParser parser) {
        Options parserOptions = parser.getOp();
        this.dvModel = model;
        this.parser = parser;
        this.op = parserOptions;
    }

    public DVParser(LexicalizedParser parser) {
        this.parser = parser;
        this.op = parser.getOp();
        if (this.op.trainOptions.dvSeed == 0L) {
            this.op.trainOptions.dvSeed = new Random().nextLong();
            System.err.println("Random seed not set, using randomly chosen seed of " + this.op.trainOptions.dvSeed);
        } else {
            System.err.println("Random seed set to " + this.op.trainOptions.dvSeed);
        }
        System.err.println("Word vector file: " + this.op.lexOptions.wordVectorFile);
        System.err.println("Size of word vectors: " + this.op.lexOptions.numHid);
        System.err.println("Number of hypothesis trees to train against: " + this.op.trainOptions.dvKBest);
        System.err.println("Number of trees in one batch: " + this.op.trainOptions.dvBatchSize);
        System.err.println("Number of iterations of trees: " + this.op.trainOptions.dvIterations);
        System.err.println("Number of qn iterations per batch: " + this.op.trainOptions.qnIterationsPerBatch);
        System.err.println("Learning rate: " + this.op.trainOptions.learningRate);
        System.err.println("Delta margin: " + this.op.trainOptions.deltaMargin);
        System.err.println("regCost: " + this.op.trainOptions.regCost);
        System.err.println("Using unknown word vector for numbers: " + this.op.trainOptions.unknownNumberVector);
        System.err.println("Using unknown dashed word vector heuristics: " + this.op.trainOptions.unknownDashedWordVectors);
        System.err.println("Using unknown word vector for capitalized words: " + this.op.trainOptions.unknownCapsVector);
        System.err.println("Using unknown number vector for Chinese words: " + this.op.trainOptions.unknownChineseNumberVector);
        System.err.println("Using unknown year vector for Chinese words: " + this.op.trainOptions.unknownChineseYearVector);
        System.err.println("Using unknown percent vector for Chinese words: " + this.op.trainOptions.unknownChinesePercentVector);
        System.err.println("Initial matrices scaled by: " + this.op.trainOptions.scalingForInit);
        System.err.println("Training will use " + this.op.trainOptions.trainingThreads + " thread(s)");
        System.err.println("Context words are " + (this.op.trainOptions.useContextWords ? "on" : "off"));
        System.err.println("Model will " + (this.op.trainOptions.dvSimplifiedModel ? "" : "not ") + "be simplified");
        this.dvModel = new DVModel(this.op, parser.stateIndex, parser.ug, parser.bg);
        if (this.dvModel.unaryTransform.size() != this.dvModel.unaryScore.size()) {
            throw new AssertionError((Object)"Unary transform and score size not the same");
        }
        if (this.dvModel.binaryTransform.size() != this.dvModel.binaryScore.size()) {
            throw new AssertionError((Object)"Binary transform and score size not the same");
        }
    }

    /**
     * Expands the compressed parses for {@code sentences} and runs the cost function's
     * gradient check against the model's current parameters.
     *
     * @param sentences        trees to check against
     * @param compressedParses compressed hypothesis parses keyed by tree identity
     * @return the result reported by {@code DVParserCostAndGradient.gradientCheck}
     */
    public boolean runGradientCheck(List<Tree> sentences, IdentityHashMap<Tree, byte[]> compressedParses) {
        int treeCount = sentences.size();
        System.err.println("Gradient check: converting " + treeCount + " compressed trees");
        IdentityHashMap<Tree, List<Tree>> expanded = CacheParseHypotheses.convertToTrees(sentences, compressedParses, this.op.trainOptions.trainingThreads);
        System.err.println("Done converting trees");

        DVParserCostAndGradient costFunction = new DVParserCostAndGradient(sentences, expanded, this.dvModel, this.op);
        double[] currentParams = this.dvModel.paramsToVector();
        boolean outcome = costFunction.gradientCheck(1000, 50, currentParams);
        return outcome;
    }

    /**
     * Builds the training tree transformer by delegating to
     * {@code LexicalizedParser.buildTrainTransformer}.
     *
     * @param op options handed to the delegate
     * @return the transformer produced by the delegate
     */
    public static TreeTransformer buildTrainTransformer(Options op) {
        return LexicalizedParser.buildTrainTransformer(op);
    }

    /**
     * Produces a copy of the underlying parser whose reranker is a
     * {@link DVModelReranker} wrapping this instance's model.
     *
     * @return the copied parser with the reranker attached
     */
    public LexicalizedParser attachModelToLexicalizedParser() {
        LexicalizedParser copy = LexicalizedParser.copyLexicalizedParser(parser);
        copy.reranker = new DVModelReranker(dvModel);
        return copy;
    }

    /**
     * Serializes a copy of the underlying parser, with this model attached as its
     * reranker, to {@code filename}. Progress is logged to stderr.
     *
     * @param filename destination path for the serialized parser
     */
    public void saveModel(String filename) {
        System.err.println("Saving serialized model to " + filename);
        attachModelToLexicalizedParser().saveParserToSerialized(filename);
        System.err.println("... done");
    }

    /**
     * Loads a model previously written by {@link #saveModel(String)}.
     *
     * <p>{@code saveModel} serializes a {@link LexicalizedParser} carrying a
     * {@link DVModelReranker}, not a {@code DVParser}, and {@code DVParser} does not
     * declare {@code Serializable}. Deserializing the file and casting it to
     * {@code DVParser} therefore cannot succeed. Instead the parser is loaded and the
     * {@code DVParser} is rebuilt around its attached model, the same way {@code main}
     * does for {@code -continueTraining}.
     *
     * @param filename path of the serialized parser
     * @param args     extra option flags handed to {@code LexicalizedParser.loadModel}
     *                 (assumed to apply them to the loaded options, as the call sites in
     *                 {@code main} rely on - TODO confirm)
     * @return a parser wrapping the loaded model
     * @throws RuntimeIOException       if no parser could be loaded from {@code filename}
     * @throws IllegalArgumentException if the loaded parser has no DVModel reranker
     */
    public static DVParser loadModel(String filename, String[] args) {
        System.err.println("Loading serialized model from " + filename);
        LexicalizedParser lexparser = LexicalizedParser.loadModel(filename, args);
        if (lexparser == null) {
            // Keep reporting load failures with the exception type callers already expect.
            throw new RuntimeIOException(new IOException("Unable to load a parser from " + filename));
        }
        DVModel model = DVParser.getModelFromLexicalizedParser(lexparser);
        DVParser dvparser = new DVParser(model, lexparser);
        System.err.println("... done");
        return dvparser;
    }

    /**
     * Extracts the {@link DVModel} from a parser whose reranker is a
     * {@link DVModelReranker}.
     *
     * @param parser parser to inspect
     * @return the model held by the parser's reranker
     * @throws IllegalArgumentException if the parser's reranker is not a DVModelReranker
     */
    public static DVModel getModelFromLexicalizedParser(LexicalizedParser parser) {
        if (parser.reranker instanceof DVModelReranker) {
            return ((DVModelReranker)parser.reranker).getModel();
        }
        throw new IllegalArgumentException("This parser does not contain a DVModel reranker");
    }

    /**
     * Prints the command-line usage summary to stderr, one line per option.
     */
    public static void help() {
        // An empty entry produces the blank separator line between the two sections.
        String[] usageLines = {
            "Options supplied by this file:",
            "  -model <name>: When training, the name of the model to save.  Otherwise, the name of the model to load.",
            "  -parser <name>: When training, the LexicalizedParser to use as the base model.",
            "  -cachedTrees <name>: The name of the file containing a treebank with cached parses.  See CacheParseHypotheses.java",
            "  -treebank <name> [filter]: A treebank to use instead of cachedTrees.  Trees will be reparsed.  Slow.",
            "  -testTreebank <name> [filter]: A treebank for testing the model.",
            "  -train: Run training over the treebank, testing on the testTreebank.",
            "  -continueTraining <name>: The name of a file to continue training.",
            "  -nofilter: Rules for the parser will not be filtered based on the training treebank.",
            "  -runGradientCheck: Run a gradient check.",
            "  -resultsRecord: A file for recording info on intermediate results",
            "",
            "Options overlapping the parser:",
            "  -trainingThreads <int>: How many threads to use when training.",
            "  -dvKBest <int>: How many hypotheses to use from the underlying parser.",
            "  -dvIterations <int>: When training, how many times to go through the train set.",
            "  -regCost <double>: How large of a cost to put on regularization.",
            "  -dvBatchSize <int>: How many trees to use in each batch of the training.",
            "  -qnIterationsPerBatch <int>: How many steps to take per batch.",
            "  -qnEstimates <int>: Parameter for qn optimization.",
            "  -qnTolerance <double>: Tolerance for early exit when optimizing a batch.",
            "  -debugOutputSeconds <int>: How frequently to score a model when training and write out intermediate models.",
            "  -maxTrainTimeSeconds <int>: How long to train before terminating.",
            "  -dvSeed <long>: A starting point for the random number generator.  Setting this should lead to repeatable results, even taking into account randomness.  Otherwise, a new random seed will be picked.",
            "  -wordVectorFile <name>: A filename to load word vectors from.",
            "  -numHid: The size of the matrices.  In most circumstances, should be set to the size of the word vectors.",
            "  -learningRate: The rate of optimization when training",
            "  -deltaMargin: How much we punish trees for being incorrect when training",
            "  -(no)unknownNumberVector: Whether or not to use a word vector for unknown numbers",
            "  -(no)unknownDashedWordVectors: Whether or not to split unknown dashed words",
            "  -(no)unknownCapsVector: Whether or not to use a word vector for unknown words with capitals",
            "  -dvSimplifiedModel: Use a greatly dumbed down DVModel",
            "  -scalingForInit: How much to scale matrices when creating a new DVModel",
            "  -lpWeight: A weight to give the original LexicalizedParser when testing (0.2 seems to work well)",
            "  -unkWord: The vector representing unknown word in the word vectors file",
            "  -transformMatrixType: A couple different methods for initializing transform matrices",
        };
        for (String usageLine : usageLines) {
            System.err.println(usageLine);
        }
    }

    /**
     * Command-line entry point: trains, gradient-checks and/or evaluates a DVParser.
     * See {@link #help()} for the recognized flags. Flags not consumed here are passed
     * on to {@code LexicalizedParser.loadModel}. A block of default flags is prepended
     * to the user's arguments, so flags given by the user are processed later.
     *
     * @param args command-line arguments
     * @throws IOException              if a cached tree file or results record cannot be read/written
     * @throws ClassNotFoundException   if a cached tree file cannot be deserialized
     * @throws IllegalArgumentException if required flags are missing or a flag lacks its value
     */
    public static void main(String[] args) throws IOException, ClassNotFoundException {
        if (args.length == 0) {
            DVParser.help();
            System.exit(2);
        }
        System.err.println("Running DVParser with arguments:");
        for (String arg : args) {
            System.err.print("  " + arg);
        }
        System.err.println();
        String parserPath = null;
        String trainTreebankPath = null;
        FileFilter trainTreebankFilter = null;
        String cachedTrainTreesPath = null;
        boolean runGradientCheck = false;
        boolean runTraining = false;
        String testTreebankPath = null;
        FileFilter testTreebankFilter = null;
        String initialModelPath = null;
        String modelPath = null;
        boolean filter = true;
        String resultsRecordPath = null;
        List<String> unusedArgs = new ArrayList<String>();
        List<String> argsWithDefaults = new ArrayList<String>(Arrays.asList("-wordVectorFile", "/scr/nlp/deeplearning/datasets/turian/embeddings-scaled.EMBEDDING_SIZE=25.txt", "-dvKBest", Integer.toString(100), "-dvBatchSize", Integer.toString(500), "-dvIterations", Integer.toString(20), "-qnIterationsPerBatch", Integer.toString(1), "-regCost", Double.toString(1.0E-4), "-learningRate", Double.toString(0.1), "-deltaMargin", Double.toString(0.1), "-unknownNumberVector", "-unknownDashedWordVectors", "-unknownCapsVector", "-unknownchinesepercentvector", "-unknownchinesenumbervector", "-unknownchineseyearvector", "-unkWord", "UNK", "-transformMatrixType", "DIAGONAL", "-scalingForInit", Double.toString(0.5)));
        argsWithDefaults.addAll(Arrays.asList(args));
        args = argsWithDefaults.toArray(new String[argsWithDefaults.size()]);
        int argIndex = 0;
        while (argIndex < args.length) {
            String flag = args[argIndex];
            if (flag.equalsIgnoreCase("-parser")) {
                parserPath = DVParser.requireArgValue(args, argIndex);
                argIndex += 2;
            } else if (flag.equalsIgnoreCase("-testTreebank")) {
                Pair<String, FileFilter> treebankDescription = ArgUtils.getTreebankDescription(args, argIndex, "-testTreebank");
                argIndex = argIndex + ArgUtils.numSubArgs(args, argIndex) + 1;
                testTreebankPath = treebankDescription.first();
                testTreebankFilter = treebankDescription.second();
            } else if (flag.equalsIgnoreCase("-treebank")) {
                Pair<String, FileFilter> treebankDescription = ArgUtils.getTreebankDescription(args, argIndex, "-treebank");
                argIndex = argIndex + ArgUtils.numSubArgs(args, argIndex) + 1;
                trainTreebankPath = treebankDescription.first();
                trainTreebankFilter = treebankDescription.second();
            } else if (flag.equalsIgnoreCase("-cachedTrees")) {
                cachedTrainTreesPath = DVParser.requireArgValue(args, argIndex);
                argIndex += 2;
            } else if (flag.equalsIgnoreCase("-runGradientCheck")) {
                runGradientCheck = true;
                ++argIndex;
            } else if (flag.equalsIgnoreCase("-train")) {
                runTraining = true;
                ++argIndex;
            } else if (flag.equalsIgnoreCase("-model")) {
                modelPath = DVParser.requireArgValue(args, argIndex);
                argIndex += 2;
            } else if (flag.equalsIgnoreCase("-nofilter")) {
                filter = false;
                ++argIndex;
            } else if (flag.equalsIgnoreCase("-continueTraining")) {
                // Continuing implies training and reuses the rules of the loaded model.
                runTraining = true;
                filter = false;
                initialModelPath = DVParser.requireArgValue(args, argIndex);
                argIndex += 2;
            } else if (flag.equalsIgnoreCase("-resultsRecord")) {
                resultsRecordPath = DVParser.requireArgValue(args, argIndex);
                argIndex += 2;
            } else {
                // Not ours: hand it to the underlying parser's option handling.
                unusedArgs.add(flag);
                ++argIndex;
            }
        }
        if (parserPath == null && modelPath == null) {
            throw new IllegalArgumentException("Must supply either a base parser model with -parser or a serialized DVParser with -model");
        }
        if (!runTraining && modelPath == null && !runGradientCheck) {
            throw new IllegalArgumentException("Need to either train a new model, run the gradient check or specify a model to load with -model");
        }
        String[] newArgs = unusedArgs.toArray(new String[unusedArgs.size()]);
        DVParser dvparser;
        if (initialModelPath != null) {
            LexicalizedParser lexparser = LexicalizedParser.loadModel(initialModelPath, newArgs);
            dvparser = new DVParser(DVParser.getModelFromLexicalizedParser(lexparser), lexparser);
        } else if (runTraining || runGradientCheck) {
            if (parserPath == null) {
                throw new IllegalArgumentException("Must supply a base parser model with -parser when training a new model or running the gradient check");
            }
            LexicalizedParser lexparser = LexicalizedParser.loadModel(parserPath, newArgs);
            dvparser = new DVParser(lexparser);
        } else {
            // The checks above guarantee modelPath != null here. Load the saved model
            // itself; initialModelPath is always null in this branch.
            LexicalizedParser lexparser = LexicalizedParser.loadModel(modelPath, newArgs);
            dvparser = new DVParser(DVParser.getModelFromLexicalizedParser(lexparser), lexparser);
        }
        List<Tree> trainSentences = new ArrayList<Tree>();
        IdentityHashMap<Tree, byte[]> trainCompressedParses = Generics.newIdentityHashMap();
        if (cachedTrainTreesPath != null) {
            for (String path : cachedTrainTreesPath.split(",")) {
                // Each cache file is read as a list of (tree, compressed parses) pairs.
                @SuppressWarnings("unchecked")
                List<Pair<Tree, byte[]>> cache = (List<Pair<Tree, byte[]>>)IOUtils.readObjectFromFile(path);
                for (Pair<Tree, byte[]> pair : cache) {
                    trainSentences.add(pair.first());
                    trainCompressedParses.put(pair.first(), pair.second());
                }
                System.err.println("Read in " + cache.size() + " trees from " + path);
            }
        }
        if (trainTreebankPath != null) {
            TreeTransformer transformer = DVParser.buildTrainTransformer(dvparser.getOp());
            Treebank treebank = dvparser.getOp().tlpParams.memoryTreebank();
            treebank.loadPath(trainTreebankPath, trainTreebankFilter);
            treebank = treebank.transform(transformer);
            System.err.println("Read in " + treebank.size() + " trees from " + trainTreebankPath);
            for (Tree tree : treebank) {
                trainSentences.add(tree);
            }
            IdentityHashMap<Tree, List<Tree>> trainParses = dvparser.getTopParses(trainSentences, transformer);
            CacheParseHypotheses cacher = new CacheParseHypotheses(dvparser.parser);
            trainCompressedParses.putAll(cacher.convertToBytes(trainParses));
            System.err.println("Finished parsing " + treebank.size() + " trees, getting " + dvparser.op.trainOptions.dvKBest + " hypotheses each");
        }
        if ((runTraining || runGradientCheck) && filter) {
            System.err.println("Filtering rules for the given training set");
            dvparser.dvModel.setRulesForTrainingSet(trainSentences, trainCompressedParses);
            System.err.println("Done filtering rules; " + dvparser.dvModel.numBinaryMatrices + " binary matrices, " + dvparser.dvModel.numUnaryMatrices + " unary matrices, " + dvparser.dvModel.wordVectors.size() + " word vectors");
        }
        MemoryTreebank testTreebank = null;
        if (testTreebankPath != null) {
            System.err.println("Reading in trees from " + testTreebankPath);
            if (testTreebankFilter != null) {
                System.err.println("Filtering on " + testTreebankFilter);
            }
            testTreebank = dvparser.getOp().tlpParams.memoryTreebank();
            testTreebank.loadPath(testTreebankPath, testTreebankFilter);
            System.err.println("Read in " + testTreebank.size() + " trees for testing");
        }
        if (runGradientCheck) {
            System.err.println("Running gradient check on " + trainSentences.size() + " trees");
            dvparser.runGradientCheck(trainSentences, trainCompressedParses);
        }
        if (runTraining) {
            System.err.println("Training the RNN parser");
            dvparser.train(trainSentences, trainCompressedParses, testTreebank, modelPath, resultsRecordPath);
            if (modelPath != null) {
                dvparser.saveModel(modelPath);
            }
        }
        if (testTreebankPath != null) {
            EvaluateTreebank evaluator = new EvaluateTreebank(dvparser.attachModelToLexicalizedParser());
            evaluator.testOnTreebank(testTreebank);
        }
        System.err.println("Successfully ran DVParser");
    }

    /**
     * Returns the value that follows the flag at {@code flagIndex}, failing with a
     * readable message instead of an ArrayIndexOutOfBoundsException when the flag is
     * the last argument.
     */
    private static String requireArgValue(String[] args, int flagIndex) {
        if (flagIndex + 1 >= args.length) {
            throw new IllegalArgumentException("Option " + args[flagIndex] + " requires a value");
        }
        return args[flagIndex + 1];
    }
}

