/*
 * Decompiled with CFR 0.152.
 */
package edu.stanford.nlp.sentiment;

import edu.stanford.nlp.io.IOUtils;
import edu.stanford.nlp.io.RuntimeIOException;
import edu.stanford.nlp.neural.Embedding;
import edu.stanford.nlp.neural.NeuralUtils;
import edu.stanford.nlp.neural.SimpleTensor;
import edu.stanford.nlp.sentiment.RNNOptions;
import edu.stanford.nlp.trees.Tree;
import edu.stanford.nlp.util.Generics;
import edu.stanford.nlp.util.Pair;
import edu.stanford.nlp.util.TwoDimensionalMap;
import edu.stanford.nlp.util.TwoDimensionalSet;
import java.io.IOException;
import java.io.Serializable;
import java.util.Iterator;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Random;
import java.util.Set;
import java.util.TreeMap;
import org.ejml.simple.SimpleBase;
import org.ejml.simple.SimpleMatrix;

/**
 * Parameters of the recursive sentiment model.
 * <p>
 * Binary transform matrices, binary tensors and binary classification matrices are keyed by the
 * basic categories of a node's left and right children. Unary classification matrices are keyed
 * by a single basic category. Word vectors are keyed by vocabulary word, including
 * {@link #UNKNOWN_WORD}. When {@code op.simplifiedModel} is set, {@link #basicCategory(String)}
 * maps every category to {@code ""}, so each category-keyed map holds a single shared entry.
 * <p>
 * {@link #paramsToVector()} and {@link #vectorToParams(double[])} hand the maps to
 * {@link NeuralUtils} in this order: binary transforms, binary classifications, binary tensors,
 * unary classifications, word vectors. {@link #printParamInformation(int)} walks them in the same
 * order. It assumes NeuralUtils lays them out contiguously in that order (TODO confirm).
 */
public class SentimentModel
implements Serializable {
    /** Composition matrices keyed by (left basic category, right basic category). */
    public TwoDimensionalMap<String, String, SimpleMatrix> binaryTransform;
    /** Composition tensors, same keys as {@link #binaryTransform}; only filled when {@code op.useTensors}. */
    public TwoDimensionalMap<String, String, SimpleTensor> binaryTensors;
    /** Per-production classification matrices; left empty when {@code op.combineClassification}. */
    public TwoDimensionalMap<String, String, SimpleMatrix> binaryClassification;
    /** Unary classification matrices keyed by basic category; entry {@code ""} serves every node when classification is combined. */
    public Map<String, SimpleMatrix> unaryClassification;
    /** Word vectors keyed by vocabulary word. */
    public Map<String, SimpleMatrix> wordVectors;
    public final int numClasses;
    /** Hidden size: {@code op.numHid} when positive, otherwise the element count of the first word vector. */
    public final int numHid;
    public final int numBinaryMatrices;
    /** Parameters per binary transform matrix: {@code numHid * (2 * numHid + 1)}. */
    public final int binaryTransformSize;
    /** Parameters per binary tensor ({@code 4 * numHid^3}), or 0 when tensors are disabled. */
    public final int binaryTensorSize;
    /** Parameters per binary classification matrix, or 0 when classification is combined. */
    public final int binaryClassificationSize;
    public final int numUnaryMatrices;
    /** Parameters per unary classification matrix: {@code numClasses * (numHid + 1)}. */
    public final int unaryClassificationSize;
    /** numHid x numHid identity; transient, so it is null after Java deserialization and rebuilt on demand. */
    transient SimpleMatrix identity;
    final Random rand;
    static final String UNKNOWN_WORD = "*UNK*";
    final RNNOptions op;
    private static final long serialVersionUID = 1L;

    /**
     * Builds a simplified, combined-classification model from explicit matrices.
     *
     * @param W transform matrix stored under key ("", "")
     * @param Wcat classification matrix stored as the single unary classification ""
     * @param Wt tensor stored under key ("", "")
     * @param wordVectors word vectors, used as-is (not copied)
     * @param op options; must have {@code combineClassification} and {@code simplifiedModel} set
     * @throws IllegalArgumentException if either required option is off
     */
    static SentimentModel modelFromMatrices(SimpleMatrix W, SimpleMatrix Wcat, SimpleTensor Wt, Map<String, SimpleMatrix> wordVectors, RNNOptions op) {
        if (!op.combineClassification || !op.simplifiedModel) {
            throw new IllegalArgumentException("Can only create a model using this method if combineClassification and simplifiedModel are turned on");
        }
        TwoDimensionalMap<String, String, SimpleMatrix> binaryTransform = TwoDimensionalMap.treeMap();
        binaryTransform.put("", "", W);
        // NOTE(review): Wt is stored even when op.useTensors is false, yet binaryTensorSize is then 0;
        // confirm callers only pass a tensor when tensors are enabled.
        TwoDimensionalMap<String, String, SimpleTensor> binaryTensors = TwoDimensionalMap.treeMap();
        binaryTensors.put("", "", Wt);
        // Classification is combined, so there are no per-production binary classifiers.
        TwoDimensionalMap<String, String, SimpleMatrix> binaryClassification = TwoDimensionalMap.treeMap();
        Map<String, SimpleMatrix> unaryClassification = Generics.newTreeMap();
        unaryClassification.put("", Wcat);
        return new SentimentModel(binaryTransform, binaryTensors, binaryClassification, unaryClassification, wordVectors, op);
    }

    private SentimentModel(TwoDimensionalMap<String, String, SimpleMatrix> binaryTransform, TwoDimensionalMap<String, String, SimpleTensor> binaryTensors, TwoDimensionalMap<String, String, SimpleMatrix> binaryClassification, Map<String, SimpleMatrix> unaryClassification, Map<String, SimpleMatrix> wordVectors, RNNOptions op) {
        this.op = op;
        this.binaryTransform = binaryTransform;
        this.binaryTensors = binaryTensors;
        this.binaryClassification = binaryClassification;
        this.unaryClassification = unaryClassification;
        this.wordVectors = wordVectors;
        this.numClasses = op.numClasses;
        this.numHid = op.numHid > 0 ? op.numHid : firstVectorSize(wordVectors);
        this.numBinaryMatrices = binaryTransform.size();
        this.binaryTransformSize = this.numHid * (2 * this.numHid + 1);
        this.binaryTensorSize = op.useTensors ? this.numHid * this.numHid * this.numHid * 4 : 0;
        this.binaryClassificationSize = op.combineClassification ? 0 : this.numClasses * (this.numHid + 1);
        this.numUnaryMatrices = unaryClassification.size();
        this.unaryClassificationSize = this.numClasses * (this.numHid + 1);
        this.rand = new Random(op.randomSeed);
        this.identity = SimpleMatrix.identity(this.numHid);
    }

    /**
     * Creates a randomly initialized model.
     *
     * @param op model options
     * @param trainingTrees trees whose leaves define the vocabulary when {@code op.randomWordVectors}
     *        is set; otherwise word vectors are read from {@code op.wordVectors}
     * @throws UnsupportedOperationException unless {@code op.simplifiedModel} is set
     */
    public SentimentModel(RNNOptions op, List<Tree> trainingTrees) {
        this.op = op;
        this.rand = new Random(op.randomSeed);
        if (op.randomWordVectors) {
            this.initRandomWordVectors(trainingTrees);
        } else {
            this.readWordVectors();
        }
        this.numHid = op.numHid > 0 ? op.numHid : firstVectorSize(this.wordVectors);
        // Only the simplified model is implemented: one binary production ("", "") and one unary "".
        if (!op.simplifiedModel) {
            throw new UnsupportedOperationException("Not yet implemented");
        }
        TwoDimensionalSet<String, String> binaryProductions = TwoDimensionalSet.hashSet();
        binaryProductions.add("", "");
        Set<String> unaryProductions = Generics.newHashSet();
        unaryProductions.add("");

        // numClasses and identity must be set before the random* helpers below read them.
        this.numClasses = op.numClasses;
        this.identity = SimpleMatrix.identity(this.numHid);
        this.binaryTransform = TwoDimensionalMap.treeMap();
        this.binaryTensors = TwoDimensionalMap.treeMap();
        this.binaryClassification = TwoDimensionalMap.treeMap();
        for (Pair<String, String> binary : binaryProductions) {
            String left = this.basicCategory(binary.first);
            String right = this.basicCategory(binary.second);
            if (this.binaryTransform.contains(left, right)) {
                continue;
            }
            this.binaryTransform.put(left, right, this.randomTransformMatrix());
            if (op.useTensors) {
                this.binaryTensors.put(left, right, this.randomBinaryTensor());
            }
            if (!op.combineClassification) {
                this.binaryClassification.put(left, right, this.randomClassificationMatrix());
            }
        }
        this.numBinaryMatrices = this.binaryTransform.size();
        this.binaryTransformSize = this.numHid * (2 * this.numHid + 1);
        this.binaryTensorSize = op.useTensors ? this.numHid * this.numHid * this.numHid * 4 : 0;
        this.binaryClassificationSize = op.combineClassification ? 0 : this.numClasses * (this.numHid + 1);

        this.unaryClassification = Generics.newTreeMap();
        for (String production : unaryProductions) {
            String unary = this.basicCategory(production);
            if (!this.unaryClassification.containsKey(unary)) {
                this.unaryClassification.put(unary, this.randomClassificationMatrix());
            }
        }
        this.numUnaryMatrices = this.unaryClassification.size();
        this.unaryClassificationSize = this.numClasses * (this.numHid + 1);
    }

    /** Returns the element count of the first vector in {@code vectors}, or 0 when there is none. */
    private static int firstVectorSize(Map<String, SimpleMatrix> vectors) {
        Iterator<SimpleMatrix> it = vectors.values().iterator();
        return it.hasNext() ? it.next().getNumElements() : 0;
    }

    /** Random tensor with dimensions (2*numHid, 2*numHid, numHid), uniform in +/- 1/(4*numHid), then scaled. */
    SimpleTensor randomBinaryTensor() {
        double range = 1.0 / (4.0 * (double)this.numHid);
        SimpleTensor tensor = SimpleTensor.random(this.numHid * 2, this.numHid * 2, this.numHid, -range, range, this.rand);
        return tensor.scale(this.op.trainOptions.scalingForInit);
    }

    /**
     * Random numHid x (2*numHid + 1) transform. Columns [0, numHid) and [numHid, 2*numHid) each get an
     * independent {@link #randomTransformBlock()}; the last column (presumably the bias) stays zero.
     */
    SimpleMatrix randomTransformMatrix() {
        SimpleMatrix binary = new SimpleMatrix(this.numHid, this.numHid * 2 + 1);
        binary.insertIntoThis(0, 0, this.randomTransformBlock());
        binary.insertIntoThis(0, this.numHid, this.randomTransformBlock());
        return binary.scale(this.op.trainOptions.scalingForInit);
    }

    /** Identity plus uniform noise in +/- 1/(2*sqrt(numHid)), numHid x numHid. */
    SimpleMatrix randomTransformBlock() {
        if (this.identity == null) {
            // identity is transient, so a model restored via loadSerialized has none yet.
            this.identity = SimpleMatrix.identity(this.numHid);
        }
        double range = 1.0 / (Math.sqrt(this.numHid) * 2.0);
        return SimpleMatrix.random(this.numHid, this.numHid, -range, range, this.rand).plus(this.identity);
    }

    /** Random numClasses x (numHid + 1) classifier; the last column stays zero before scaling. */
    SimpleMatrix randomClassificationMatrix() {
        SimpleMatrix score = new SimpleMatrix(this.numClasses, this.numHid + 1);
        double range = 1.0 / Math.sqrt(this.numHid);
        score.insertIntoThis(0, 0, SimpleMatrix.random(this.numClasses, this.numHid, -range, range, this.rand));
        return score.scale(this.op.trainOptions.scalingForInit);
    }

    /**
     * Random Gaussian word vector of the model's hidden size. Uses {@code op.numHid} only while
     * {@link #numHid} is not yet assigned (during construction), so models whose size was derived
     * from the word-vector file still get correctly sized vectors.
     */
    SimpleMatrix randomWordVector() {
        int size = this.numHid > 0 ? this.numHid : this.op.numHid;
        return SentimentModel.randomWordVector(size, this.rand);
    }

    static SimpleMatrix randomWordVector(int size, Random rand) {
        return NeuralUtils.randomGaussian(size, 1, rand);
    }

    /**
     * Assigns a random vector to {@link #UNKNOWN_WORD} and to every leaf word of the training trees
     * (lowercased when {@code op.lowercaseWordVectors}).
     *
     * @throws RuntimeException if {@code op.numHid} is not positive
     */
    void initRandomWordVectors(List<Tree> trainingTrees) {
        if (this.op.numHid <= 0) {
            throw new RuntimeException("Cannot create random word vectors for an unknown numHid");
        }
        Set<String> words = Generics.newHashSet();
        words.add(UNKNOWN_WORD);
        for (Tree tree : trainingTrees) {
            List<Tree> leaves = tree.getLeaves();
            for (Tree leaf : leaves) {
                String word = leaf.label().value();
                if (this.op.lowercaseWordVectors) {
                    // Locale.ROOT keeps casing independent of the JVM's default locale.
                    word = word.toLowerCase(Locale.ROOT);
                }
                words.add(word);
            }
        }
        this.wordVectors = Generics.newTreeMap();
        for (String word : words) {
            this.wordVectors.put(word, this.randomWordVector());
        }
    }

    /**
     * Loads word vectors from {@code op.wordVectors} and aliases {@code op.unkWord}'s vector as
     * {@link #UNKNOWN_WORD}.
     *
     * @throws RuntimeException if the file has no vector for {@code op.unkWord}
     */
    void readWordVectors() {
        Embedding embedding = new Embedding(this.op.wordVectors, this.op.numHid);
        this.wordVectors = Generics.newTreeMap();
        for (String word : embedding.keySet()) {
            this.wordVectors.put(word, embedding.get(word));
        }
        SimpleMatrix unknownWordVector = this.wordVectors.get(this.op.unkWord);
        // Validate before inserting so UNKNOWN_WORD is never mapped to null.
        if (unknownWordVector == null) {
            throw new RuntimeException("Unknown word vector not specified in the word vector file");
        }
        this.wordVectors.put(UNKNOWN_WORD, unknownWordVector);
    }

    /** Total number of scalar parameters across all matrices, tensors and word vectors. */
    public int totalParamSize() {
        int perBinaryProduction = this.binaryTransformSize + this.binaryClassificationSize + this.binaryTensorSize;
        return this.numBinaryMatrices * perBinaryProduction
                + this.numUnaryMatrices * this.unaryClassificationSize
                + this.wordVectors.size() * this.numHid;
    }

    /** Flattens all parameters into a new array of length {@link #totalParamSize()}. */
    public double[] paramsToVector() {
        int totalSize = this.totalParamSize();
        return NeuralUtils.paramsToVector(totalSize, this.binaryTransform.valueIterator(), this.binaryClassification.valueIterator(), SimpleTensor.iteratorSimpleMatrix(this.binaryTensors.valueIterator()), this.unaryClassification.values().iterator(), this.wordVectors.values().iterator());
    }

    /** Writes {@code theta} back into the parameter matrices, in the same order as {@link #paramsToVector()}. */
    public void vectorToParams(double[] theta) {
        NeuralUtils.vectorToParams(theta, this.binaryTransform.valueIterator(), this.binaryClassification.valueIterator(), SimpleTensor.iteratorSimpleMatrix(this.binaryTensors.valueIterator()), this.unaryClassification.values().iterator(), this.wordVectors.values().iterator());
    }

    /** Throws unless {@code node} has exactly two children; no unary transforms exist. */
    private static void checkBinaryNode(Tree node) {
        int numChildren = node.children().length;
        if (numChildren == 2) {
            return;
        }
        if (numChildren == 1) {
            throw new AssertionError("No unary transform matrices, only unary classification");
        }
        throw new AssertionError("Unexpected tree children size of " + numChildren);
    }

    /** Transform matrix for a binary node, chosen by its children's labels. */
    public SimpleMatrix getWForNode(Tree node) {
        checkBinaryNode(node);
        Tree[] children = node.children();
        return this.getBinaryTransform(children[0].value(), children[1].value());
    }

    /** Tensor for a binary node, chosen by its children's labels. */
    public SimpleTensor getTensorForNode(Tree node) {
        if (!this.op.useTensors) {
            throw new AssertionError("Not using tensors");
        }
        checkBinaryNode(node);
        Tree[] children = node.children();
        return this.getBinaryTensor(children[0].value(), children[1].value());
    }

    /** Classification matrix for {@code node}; the shared "" matrix when classification is combined. */
    public SimpleMatrix getClassWForNode(Tree node) {
        if (this.op.combineClassification) {
            return this.unaryClassification.get("");
        }
        Tree[] children = node.children();
        if (children.length == 2) {
            return this.getBinaryClassification(children[0].value(), children[1].value());
        }
        if (children.length == 1) {
            return this.getUnaryClassification(children[0].value());
        }
        throw new AssertionError("Unexpected tree children size of " + children.length);
    }

    /** Vector for {@code word}, falling back to the unknown-word vector. */
    public SimpleMatrix getWordVector(String word) {
        return this.wordVectors.get(this.getVocabWord(word));
    }

    /** Normalizes {@code word} to a vocabulary key, or {@link #UNKNOWN_WORD} if it is absent. */
    public String getVocabWord(String word) {
        if (this.op.lowercaseWordVectors) {
            // Must match the casing used in initRandomWordVectors.
            word = word.toLowerCase(Locale.ROOT);
        }
        return this.wordVectors.containsKey(word) ? word : UNKNOWN_WORD;
    }

    /**
     * Maps a label to the key used in the parameter maps: {@code ""} in the simplified model,
     * otherwise the language pack's basic category with a leading '@' stripped.
     */
    public String basicCategory(String category) {
        if (this.op.simplifiedModel) {
            return "";
        }
        String basic = this.op.langpack.basicCategory(category);
        if (basic.startsWith("@")) {
            basic = basic.substring(1);
        }
        return basic;
    }

    public SimpleMatrix getUnaryClassification(String category) {
        return this.unaryClassification.get(this.basicCategory(category));
    }

    public SimpleMatrix getBinaryClassification(String left, String right) {
        if (this.op.combineClassification) {
            return this.unaryClassification.get("");
        }
        return this.binaryClassification.get(this.basicCategory(left), this.basicCategory(right));
    }

    public SimpleMatrix getBinaryTransform(String left, String right) {
        return this.binaryTransform.get(this.basicCategory(left), this.basicCategory(right));
    }

    public SimpleTensor getBinaryTensor(String left, String right) {
        return this.binaryTensors.get(this.basicCategory(left), this.basicCategory(right));
    }

    /** Java-serializes this model to {@code path}; I/O failures become {@link RuntimeIOException}. */
    public void saveSerialized(String path) {
        try {
            IOUtils.writeObjectToFile(this, path);
        }
        catch (IOException e) {
            throw new RuntimeIOException(e);
        }
    }

    /**
     * Loads a Java-serialized model. {@link #identity} is transient and is rebuilt lazily.
     * NOTE(review): Java deserialization is unsafe for untrusted input; only load trusted models.
     */
    public static SentimentModel loadSerialized(String path) {
        try {
            return (SentimentModel)IOUtils.readObjectFromURLOrClasspathOrFileSystem(path);
        }
        catch (IOException e) {
            throw new RuntimeIOException(e);
        }
        catch (ClassNotFoundException e) {
            throw new RuntimeIOException(e);
        }
    }

    /**
     * Prints to stderr which parameter a flattened index refers to. The maps are walked in the
     * order used by {@link #paramsToVector()}.
     */
    public void printParamInformation(int index) {
        int curIndex = 0;
        for (TwoDimensionalMap.Entry<String, String, SimpleMatrix> entry : this.binaryTransform) {
            int size = entry.getValue().getNumElements();
            if (curIndex <= index && curIndex + size > index) {
                System.err.println("Index " + index + " is element " + (index - curIndex) + " of binaryTransform \"" + entry.getFirstKey() + "," + entry.getSecondKey() + "\"");
                return;
            }
            curIndex += size;
        }
        for (TwoDimensionalMap.Entry<String, String, SimpleMatrix> entry : this.binaryClassification) {
            int size = entry.getValue().getNumElements();
            if (curIndex <= index && curIndex + size > index) {
                System.err.println("Index " + index + " is element " + (index - curIndex) + " of binaryClassification \"" + entry.getFirstKey() + "," + entry.getSecondKey() + "\"");
                return;
            }
            curIndex += size;
        }
        for (TwoDimensionalMap.Entry<String, String, SimpleTensor> entry : this.binaryTensors) {
            int size = entry.getValue().getNumElements();
            if (curIndex <= index && curIndex + size > index) {
                System.err.println("Index " + index + " is element " + (index - curIndex) + " of binaryTensor \"" + entry.getFirstKey() + "," + entry.getSecondKey() + "\"");
                return;
            }
            curIndex += size;
        }
        for (Map.Entry<String, SimpleMatrix> entry : this.unaryClassification.entrySet()) {
            int size = entry.getValue().getNumElements();
            if (curIndex <= index && curIndex + size > index) {
                System.err.println("Index " + index + " is element " + (index - curIndex) + " of unaryClassification \"" + entry.getKey() + "\"");
                return;
            }
            curIndex += size;
        }
        for (Map.Entry<String, SimpleMatrix> entry : this.wordVectors.entrySet()) {
            int size = entry.getValue().getNumElements();
            if (curIndex <= index && curIndex + size > index) {
                System.err.println("Index " + index + " is element " + (index - curIndex) + " of wordVector \"" + entry.getKey() + "\"");
                return;
            }
            curIndex += size;
        }
        System.err.println("Index " + index + " is beyond the length of the parameters; total parameter space was " + this.totalParamSize());
    }
}

