/*
 * Decompiled with CFR 0.152.
 */
package edu.stanford.nlp.sentiment;

import edu.stanford.nlp.sentiment.Evaluate;
import edu.stanford.nlp.sentiment.RNNOptions;
import edu.stanford.nlp.sentiment.SentimentCostAndGradient;
import edu.stanford.nlp.sentiment.SentimentModel;
import edu.stanford.nlp.sentiment.SentimentUtils;
import edu.stanford.nlp.trees.Tree;
import edu.stanford.nlp.util.Generics;
import edu.stanford.nlp.util.Timing;
import java.text.DecimalFormat;
import java.text.NumberFormat;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;

/**
 * Command-line entry point and helpers for training a {@link SentimentModel}.
 *
 * <p>Training runs for a configured number of epochs. Each epoch shuffles the training trees and
 * splits them into mini-batches, and each batch is applied with an AdaGrad-style update (see
 * {@link #executeOneTrainingBatch}). Optionally, the model is periodically evaluated on a dev set
 * and intermediate checkpoints are saved.
 */
public class SentimentTraining {
    /** Formats the dev-set score embedded in checkpoint filenames (two decimals). */
    private static final NumberFormat NF = new DecimalFormat("0.00");
    /** Formats the debug-cycle counter embedded in checkpoint filenames (zero-padded to 4 digits). */
    private static final NumberFormat FILENAME = new DecimalFormat("0000");

    /**
     * Runs one AdaGrad step on the model using the given batch of trees.
     *
     * <p>Each parameter's squared gradient is added to {@code sumGradSquare}. The parameter is then
     * moved by {@code learningRate * grad / (sqrt(sumGradSquare) + eps)}. The updated parameter
     * vector is written back into {@code model}.
     *
     * @param model the model whose parameters are updated in place
     * @param trainingBatch the trees making up this batch
     * @param sumGradSquare running per-parameter sum of squared gradients; updated in place.
     *     Must have at least as many entries as the model's gradient vector.
     */
    public static void executeOneTrainingBatch(SentimentModel model, List<Tree> trainingBatch, double[] sumGradSquare) {
        SentimentCostAndGradient gcFunc = new SentimentCostAndGradient(model, trainingBatch);
        double[] theta = model.paramsToVector();
        // Keeps the AdaGrad denominator away from zero for parameters with no gradient history.
        final double eps = 0.001;
        // Call order (derivative, then value) is kept exactly as in the original implementation.
        double[] gradf = gcFunc.derivativeAt(theta);
        double currCost = gcFunc.valueAt(theta);
        System.err.println("batch cost: " + currCost);
        double learningRate = model.op.trainOptions.learningRate;
        for (int feature = 0; feature < gradf.length; ++feature) {
            sumGradSquare[feature] += gradf[feature] * gradf[feature];
            theta[feature] -= learningRate * gradf[feature] / (Math.sqrt(sumGradSquare[feature]) + eps);
        }
        model.vectorToParams(theta);
    }

    /**
     * Trains {@code model} on {@code trainingTrees} according to {@code model.op.trainOptions}.
     *
     * <p>Each epoch shuffles the trees (using {@code model.rand}) and processes them in batches of
     * {@code batchSize}. Trees left over after the last full batch are folded into that final batch,
     * so no tree is seen twice per epoch. AdaGrad accumulators are reset every
     * {@code adagradResetFrequency} epochs when that option is positive. Training stops early once
     * {@code maxTrainTimeSeconds} (if positive) has elapsed.
     *
     * <p>Every {@code debugOutputSeconds} (if positive), the model is evaluated on
     * {@code devTrees}, if provided. When {@code modelPath} is non-null, a checkpoint is also saved
     * whose name includes the debug-cycle number and the dev score.
     *
     * @param model the model to train; its parameters are updated in place
     * @param modelPath base path used to name intermediate checkpoints, or {@code null} to skip them
     * @param trainingTrees the training data
     * @param devTrees the dev set used for periodic evaluation, or {@code null} to skip evaluation
     * @throws IllegalArgumentException if the configured batch size is not positive
     */
    public static void train(SentimentModel model, String modelPath, List<Tree> trainingTrees, List<Tree> devTrees) {
        Timing timing = new Timing();
        // Multiply by a long literal so large second counts cannot overflow int arithmetic.
        long maxTrainTimeMillis = model.op.trainOptions.maxTrainTimeSeconds * 1000L;
        long debugOutputMillis = model.op.trainOptions.debugOutputSeconds * 1000L;
        long nextDebugCycle = debugOutputMillis;
        int debugCycle = 0;
        double[] sumGradSquare = new double[model.totalParamSize()];
        Arrays.fill(sumGradSquare, model.op.trainOptions.initialAdagradWeight);
        int batchSize = model.op.trainOptions.batchSize;
        int numBatches = countBatches(trainingTrees.size(), batchSize);
        System.err.println("Training on " + trainingTrees.size() + " trees in " + numBatches + " batches");
        System.err.println("Times through each training batch: " + model.op.trainOptions.epochs);
        for (int epoch = 0; epoch < model.op.trainOptions.epochs; ++epoch) {
            System.err.println("======================================");
            System.err.println("Starting epoch " + epoch);
            if (epoch > 0 && model.op.trainOptions.adagradResetFrequency > 0 && epoch % model.op.trainOptions.adagradResetFrequency == 0) {
                System.err.println("Resetting adagrad weights to " + model.op.trainOptions.initialAdagradWeight);
                Arrays.fill(sumGradSquare, model.op.trainOptions.initialAdagradWeight);
            }
            List<Tree> shuffledSentences = Generics.newArrayList(trainingTrees);
            Collections.shuffle(shuffledSentences, model.rand);
            for (int batch = 0; batch < numBatches; ++batch) {
                System.err.println("======================================");
                System.err.println("Epoch " + epoch + " batch " + batch);
                int startTree = batch * batchSize;
                // The final batch absorbs any leftover trees that would not fill a whole batch.
                int endTree = (batch == numBatches - 1) ? shuffledSentences.size() : startTree + batchSize;
                executeOneTrainingBatch(model, shuffledSentences.subList(startTree, endTree), sumGradSquare);
                long totalElapsed = timing.report();
                System.err.println("Finished epoch " + epoch + " batch " + batch + "; total training time " + totalElapsed + " ms");
                if (maxTrainTimeMillis > 0L && totalElapsed > maxTrainTimeMillis) {
                    break;
                }
                if (nextDebugCycle > 0L && totalElapsed > nextDebugCycle) {
                    if (devTrees != null) {
                        evaluateAndCheckpoint(model, modelPath, devTrees, debugCycle);
                    }
                    ++debugCycle;
                    nextDebugCycle = timing.report() + debugOutputMillis;
                }
            }
            long totalElapsed = timing.report();
            if (maxTrainTimeMillis > 0L && totalElapsed > maxTrainTimeMillis) {
                System.err.println("Max training time exceeded, exiting");
                break;
            }
        }
    }

    /**
     * Returns how many batches {@code numTrees} trees are split into.
     *
     * <p>There is one batch per full group of {@code batchSize} trees, with any remainder merged
     * into the last batch. The result is at least one batch when there is any tree and zero when
     * there are none.
     */
    private static int countBatches(int numTrees, int batchSize) {
        if (batchSize <= 0) {
            throw new IllegalArgumentException("batchSize must be positive, got " + batchSize);
        }
        if (numTrees == 0) {
            return 0;
        }
        return Math.max(1, numTrees / batchSize);
    }

    /**
     * Evaluates {@code model} on {@code devTrees} and prints a summary. If {@code modelPath} is
     * non-null, also saves a checkpoint tagged with the debug cycle and the dev score.
     */
    private static void evaluateAndCheckpoint(SentimentModel model, String modelPath, List<Tree> devTrees, int debugCycle) {
        Evaluate eval = new Evaluate(model);
        eval.eval(devTrees);
        eval.printSummary();
        double score = eval.exactNodeAccuracy() * 100.0;
        if (modelPath != null) {
            model.saveSerialized(checkpointPath(modelPath, debugCycle, score));
        }
    }

    /**
     * Builds the path for an intermediate checkpoint.
     *
     * <p>The tag {@code -CCCC-S.SS} (debug cycle, then score) is inserted before a trailing
     * {@code .ser.gz} or {@code .gz} extension. For any other path the tag is appended, so no
     * characters of the original path are dropped.
     */
    private static String checkpointPath(String modelPath, int debugCycle, double score) {
        String tag = "-" + FILENAME.format(debugCycle) + "-" + NF.format(score);
        if (modelPath.endsWith(".ser.gz")) {
            return modelPath.substring(0, modelPath.length() - ".ser.gz".length()) + tag + ".ser.gz";
        }
        if (modelPath.endsWith(".gz")) {
            return modelPath.substring(0, modelPath.length() - ".gz".length()) + tag + ".gz";
        }
        return modelPath + tag;
    }

    /**
     * Runs a numerical gradient check of the model's cost function on {@code trees}.
     *
     * @return the result reported by {@link SentimentCostAndGradient#gradientCheck}
     */
    public static boolean runGradientCheck(SentimentModel model, List<Tree> trees) {
        SentimentCostAndGradient gcFunc = new SentimentCostAndGradient(model, trees);
        return gcFunc.gradientCheck(model.totalParamSize(), 50, model.paramsToVector());
    }

    /**
     * Returns the value following the flag at {@code argIndex}.
     *
     * @throws IllegalArgumentException if the flag is the last argument
     */
    private static String argValue(String[] args, int argIndex) {
        if (argIndex + 1 >= args.length) {
            throw new IllegalArgumentException("Missing value for argument " + args[argIndex]);
        }
        return args[argIndex + 1];
    }

    /**
     * Command-line entry point.
     *
     * <p>Recognized flags are {@code -train}, {@code -gradientcheck}, {@code -trainpath <path>},
     * {@code -devpath <path>} and {@code -model <path>}. Any other argument is passed to
     * {@link RNNOptions#setOption}.
     *
     * @throws IllegalArgumentException on an unknown argument or a flag missing its value
     */
    public static void main(String[] args) {
        RNNOptions op = new RNNOptions();
        String trainPath = "sentimentTreesDebug.txt";
        String devPath = null;
        boolean runGradientCheck = false;
        boolean runTraining = false;
        String modelPath = null;
        int argIndex = 0;
        while (argIndex < args.length) {
            String arg = args[argIndex];
            if (arg.equalsIgnoreCase("-train")) {
                runTraining = true;
                ++argIndex;
            } else if (arg.equalsIgnoreCase("-gradientcheck")) {
                runGradientCheck = true;
                ++argIndex;
            } else if (arg.equalsIgnoreCase("-trainpath")) {
                trainPath = argValue(args, argIndex);
                argIndex += 2;
            } else if (arg.equalsIgnoreCase("-devpath")) {
                devPath = argValue(args, argIndex);
                argIndex += 2;
            } else if (arg.equalsIgnoreCase("-model")) {
                modelPath = argValue(args, argIndex);
                argIndex += 2;
            } else {
                int newArgIndex = op.setOption(args, argIndex);
                if (newArgIndex == argIndex) {
                    throw new IllegalArgumentException("Unknown argument " + args[argIndex]);
                }
                argIndex = newArgIndex;
            }
        }
        List<Tree> trainingTrees = SentimentUtils.readTreesWithGoldLabels(trainPath);
        // The dev set is optional; only read it when a path was supplied.
        List<Tree> devTrees = null;
        if (devPath != null) {
            devTrees = SentimentUtils.readTreesWithGoldLabels(devPath);
        }
        System.err.println("Sentiment model options:\n" + op);
        SentimentModel model = new SentimentModel(op, trainingTrees);
        if (runGradientCheck) {
            boolean passed = SentimentTraining.runGradientCheck(model, trainingTrees);
            System.err.println("Gradient check " + (passed ? "passed" : "failed"));
        }
        if (runTraining) {
            SentimentTraining.train(model, modelPath, trainingTrees, devTrees);
            if (modelPath != null) {
                model.saveSerialized(modelPath);
            } else {
                System.err.println("No -model path given; trained model was not saved");
            }
        }
    }
}

