/*
 * Decompiled with CFR 0.152.
 */
package edu.stanford.nlp.neural;

import edu.stanford.nlp.neural.NeuralUtils;
import java.io.Serializable;
import java.util.Arrays;
import java.util.Iterator;
import java.util.NoSuchElementException;
import java.util.Random;
import org.ejml.simple.SimpleBase;
import org.ejml.simple.SimpleMatrix;

/**
 * A three-dimensional tensor of doubles, stored as an array of {@code numSlices} matrices
 * that all share the dimensions {@code numRows x numCols}.
 *
 * <p>Instances are mutable: {@link #set(double)} and {@link #setSlice(int, SimpleMatrix)}
 * modify the tensor in place, and {@link #getSlice(int)} / {@link #iteratorSimpleMatrix()}
 * hand out the stored slice matrices themselves rather than copies.
 */
public class SimpleTensor
implements Serializable {
    /** The slice matrices; each entry is expected to be {@code numRows x numCols}. */
    private SimpleMatrix[] slices;
    final int numRows;
    final int numCols;
    final int numSlices;
    private static final long serialVersionUID = 1L;

    /**
     * Creates a tensor made of {@code numSlices} newly allocated {@code numRows x numCols}
     * matrices.
     *
     * @param numRows rows in every slice
     * @param numCols columns in every slice
     * @param numSlices number of slices
     */
    public SimpleTensor(int numRows, int numCols, int numSlices) {
        this.slices = new SimpleMatrix[numSlices];
        for (int i = 0; i < numSlices; ++i) {
            this.slices[i] = new SimpleMatrix(numRows, numCols);
        }
        this.numRows = numRows;
        this.numCols = numCols;
        this.numSlices = numSlices;
    }

    /**
     * Creates a tensor from copies of the given matrices. The dimensions are taken from the
     * first matrix, and every other matrix must match them.
     *
     * @param slices the slice matrices; each one is copied
     * @throws IllegalArgumentException if {@code slices} is empty (the dimensions cannot be
     *     inferred) or if any slice's dimensions differ from those of the first slice
     */
    public SimpleTensor(SimpleMatrix[] slices) {
        if (slices.length == 0) {
            // Without at least one slice there is no way to know numRows/numCols.
            throw new IllegalArgumentException("Cannot infer tensor dimensions from an empty array of slices");
        }
        this.numRows = slices[0].numRows();
        this.numCols = slices[0].numCols();
        this.numSlices = slices.length;
        this.slices = new SimpleMatrix[slices.length];
        for (int i = 0; i < this.numSlices; ++i) {
            if (slices[i].numRows() != this.numRows || slices[i].numCols() != this.numCols) {
                throw new IllegalArgumentException("Slice " + i + " has matrix dimensions " + slices[i].numRows() + "," + slices[i].numCols() + ", expected " + this.numRows + "," + this.numCols);
            }
            // Defensive copy so later changes to the caller's matrices do not leak in.
            this.slices[i] = new SimpleMatrix(slices[i]);
        }
    }

    /**
     * Creates a tensor whose slices are each generated by
     * {@code SimpleMatrix.random(numRows, numCols, minValue, maxValue, rand)}.
     *
     * @param numRows rows in every slice
     * @param numCols columns in every slice
     * @param numSlices number of slices
     * @param minValue lower bound passed to {@code SimpleMatrix.random}
     * @param maxValue upper bound passed to {@code SimpleMatrix.random}
     * @param rand source of randomness, shared across all slices
     * @return the new random tensor
     */
    public static SimpleTensor random(int numRows, int numCols, int numSlices, double minValue, double maxValue, Random rand) {
        SimpleTensor tensor = new SimpleTensor(numRows, numCols, numSlices);
        for (int i = 0; i < numSlices; ++i) {
            tensor.slices[i] = SimpleMatrix.random(numRows, numCols, minValue, maxValue, rand);
        }
        return tensor;
    }

    /** @return the number of rows in each slice */
    public int numRows() {
        return this.numRows;
    }

    /** @return the number of columns in each slice */
    public int numCols() {
        return this.numCols;
    }

    /** @return the number of slices */
    public int numSlices() {
        return this.numSlices;
    }

    /** @return the total number of entries, {@code numRows * numCols * numSlices} */
    public int getNumElements() {
        return this.numRows * this.numCols * this.numSlices;
    }

    /**
     * Sets every entry of every slice to {@code value}, in place.
     *
     * @param value the value to store
     */
    public void set(double value) {
        for (SimpleMatrix slice : this.slices) {
            slice.set(value);
        }
    }

    /**
     * Returns a new tensor whose slices are this tensor's slices multiplied by {@code scaling}.
     *
     * @param scaling the scalar multiplier
     * @return the scaled tensor
     */
    public SimpleTensor scale(double scaling) {
        SimpleTensor result = new SimpleTensor(this.numRows, this.numCols, this.numSlices);
        for (int slice = 0; slice < this.numSlices; ++slice) {
            result.slices[slice] = this.slices[slice].scale(scaling);
        }
        return result;
    }

    /**
     * Returns a new tensor holding the slice-wise sum of this tensor and {@code other}.
     *
     * @param other a tensor with exactly the same dimensions
     * @return the sum
     * @throws IllegalArgumentException if the dimensions differ
     */
    public SimpleTensor plus(SimpleTensor other) {
        this.checkSameSize(other);
        SimpleTensor result = new SimpleTensor(this.numRows, this.numCols, this.numSlices);
        for (int i = 0; i < this.numSlices; ++i) {
            result.slices[i] = this.slices[i].plus(other.slices[i]);
        }
        return result;
    }

    /**
     * Returns a new tensor holding the slice-wise element-by-element product of this tensor
     * and {@code other}.
     *
     * @param other a tensor with exactly the same dimensions
     * @return the element-wise product
     * @throws IllegalArgumentException if the dimensions differ
     */
    public SimpleTensor elementMult(SimpleTensor other) {
        this.checkSameSize(other);
        SimpleTensor result = new SimpleTensor(this.numRows, this.numCols, this.numSlices);
        for (int i = 0; i < this.numSlices; ++i) {
            result.slices[i] = this.slices[i].elementMult(other.slices[i]);
        }
        return result;
    }

    /** Throws if {@code other} does not have exactly this tensor's dimensions. */
    private void checkSameSize(SimpleTensor other) {
        if (other.numRows != this.numRows || other.numCols != this.numCols || other.numSlices != this.numSlices) {
            throw new IllegalArgumentException("Sizes of tensors do not match.  Our size: " + this.numRows + "," + this.numCols + "," + this.numSlices + "; other size " + other.numRows + "," + other.numCols + "," + other.numSlices);
        }
    }

    /** @return the sum of all entries across all slices */
    public double elementSum() {
        double sum = 0.0;
        for (SimpleMatrix slice : this.slices) {
            sum += slice.elementSum();
        }
        return sum;
    }

    /**
     * Replaces slice {@code slice} with {@code matrix}. The matrix is stored as-is, not copied.
     *
     * @param slice index of the slice to replace
     * @param matrix the new slice; must be {@code numRows x numCols}
     * @throws IllegalArgumentException if the index is out of range or the matrix has the
     *     wrong dimensions
     */
    public void setSlice(int slice, SimpleMatrix matrix) {
        if (slice < 0 || slice >= this.numSlices) {
            throw new IllegalArgumentException("Unexpected slice number " + slice + " for tensor with " + this.numSlices + " slices");
        }
        if (matrix.numCols() != this.numCols) {
            throw new IllegalArgumentException("Incompatible matrix size.  Has " + matrix.numCols() + " columns, tensor has " + this.numCols);
        }
        if (matrix.numRows() != this.numRows) {
            // Fixed: this check is about rows, but the message previously said "columns".
            throw new IllegalArgumentException("Incompatible matrix size.  Has " + matrix.numRows() + " rows, tensor has " + this.numRows);
        }
        this.slices[slice] = matrix;
    }

    /**
     * Returns the stored matrix for slice {@code slice}. This is not a copy, so changes to it
     * change this tensor.
     *
     * @param slice index of the slice
     * @return the slice matrix
     * @throws IllegalArgumentException if the index is out of range
     */
    public SimpleMatrix getSlice(int slice) {
        if (slice < 0 || slice >= this.numSlices) {
            throw new IllegalArgumentException("Unexpected slice number " + slice + " for tensor with " + this.numSlices + " slices");
        }
        return this.slices[slice];
    }

    /**
     * Computes the bilinear form {@code in^T * S_k * in} for every slice {@code S_k}.
     *
     * @param in a column vector with {@code numCols} rows
     * @return a {@code numSlices x 1} column vector whose k-th entry is the product for slice k
     * @throws AssertionError if {@code in} is not a column vector, if its length does not match
     *     {@code numCols}, or if the slices are not square
     */
    public SimpleMatrix bilinearProducts(SimpleMatrix in) {
        if (in.numCols() != 1) {
            throw new AssertionError("Expected a column vector");
        }
        if (in.numRows() != this.numCols) {
            throw new AssertionError("Number of rows in the input does not match number of columns in tensor");
        }
        if (this.numRows != this.numCols) {
            throw new AssertionError("Can only perform this operation on a SimpleTensor with square slices");
        }
        SimpleMatrix inT = in.transpose();
        SimpleMatrix out = new SimpleMatrix(this.numSlices, 1);
        for (int slice = 0; slice < this.numSlices; ++slice) {
            // (1 x n) * (n x n) * (n x 1) yields a 1 x 1 matrix; take its single entry.
            double result = inT.mult(this.slices[slice]).mult(in).get(0);
            out.set(slice, result);
        }
        return out;
    }

    /**
     * @return {@code true} if {@code NeuralUtils.isZero} holds for every slice (vacuously
     *     {@code true} for a tensor with no slices)
     */
    public boolean isZero() {
        for (SimpleMatrix slice : this.slices) {
            if (!NeuralUtils.isZero(slice)) {
                return false;
            }
        }
        return true;
    }

    /**
     * Returns an iterator over this tensor's slice matrices, in index order. The matrices are
     * the stored instances, not copies. {@code remove()} is unsupported because the iterator
     * comes from a fixed-size {@code Arrays.asList} view.
     *
     * @return an iterator over the slices
     */
    public Iterator<SimpleMatrix> iteratorSimpleMatrix() {
        return Arrays.asList(this.slices).iterator();
    }

    /**
     * Flattens a sequence of tensors into a single iterator over all of their slices. Tensors
     * with no slices are skipped.
     *
     * @param tensors the tensors to walk, consumed lazily
     * @return an iterator over every slice of every tensor, in order
     */
    public static Iterator<SimpleMatrix> iteratorSimpleMatrix(Iterator<SimpleTensor> tensors) {
        return new SimpleMatrixIteratorWrapper(tensors);
    }

    /**
     * Iterator that chains the slice iterators of a sequence of tensors.
     *
     * <p>Invariant: {@code currentIterator} is either {@code null} (no more tensors with slices)
     * or it is an iterator that has at least one remaining element, or one that has just been
     * drained and will be advanced on the next {@code hasNext}/{@code next} call.
     */
    private static class SimpleMatrixIteratorWrapper
    implements Iterator<SimpleMatrix> {
        private final Iterator<SimpleTensor> tensors;
        private Iterator<SimpleMatrix> currentIterator;

        public SimpleMatrixIteratorWrapper(Iterator<SimpleTensor> tensors) {
            this.tensors = tensors;
            this.advanceIterator();
        }

        @Override
        public boolean hasNext() {
            if (this.currentIterator == null) {
                return false;
            }
            if (this.currentIterator.hasNext()) {
                return true;
            }
            this.advanceIterator();
            // advanceIterator leaves either null or an iterator with a next element.
            return this.currentIterator != null;
        }

        @Override
        public SimpleMatrix next() {
            if (!this.hasNext()) {
                throw new NoSuchElementException();
            }
            return this.currentIterator.next();
        }

        /**
         * Moves to the next tensor that has at least one slice, unless the current iterator
         * still has elements. Sets {@code currentIterator} to {@code null} once the tensors
         * are exhausted.
         */
        private void advanceIterator() {
            if (this.currentIterator != null && this.currentIterator.hasNext()) {
                return;
            }
            while (this.tensors.hasNext()) {
                this.currentIterator = this.tensors.next().iteratorSimpleMatrix();
                if (this.currentIterator.hasNext()) {
                    return;
                }
            }
            this.currentIterator = null;
        }

        @Override
        public void remove() {
            throw new UnsupportedOperationException();
        }
    }
}

