/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.transferlearning;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.apache.commons.lang3.ArrayUtils;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.graph.vertex.GraphVertex;
import org.deeplearning4j.nn.graph.vertex.VertexIndices;
import org.deeplearning4j.nn.layers.FrozenLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.transferlearning.TransferLearning;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.MultiDataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;

/**
 * Helper for transfer learning that splits a network into a frozen part and an unfrozen
 * ("subset") part. Inputs can then be featurized once through the frozen part, and only the
 * unfrozen subset is trained on those features.
 *
 * <p>An instance is bound to either a {@link ComputationGraph} or a {@link MultiLayerNetwork}.
 * Calling a method that only applies to the other kind throws {@link IllegalArgumentException}.
 * After every {@code fitFeaturized} call, the parameters of the unfrozen subset are copied back
 * into the original network.
 */
public class TransferLearningHelper {
    private boolean isGraph = true;
    private boolean applyFrozen = false;
    private ComputationGraph origGraph;
    private MultiLayerNetwork origMLN;
    private int frozenTill;
    private String[] frozenOutputAt;
    private ComputationGraph unFrozenSubsetGraph;
    private MultiLayerNetwork unFrozenSubsetMLN;
    /** Frozen vertices of the original graph whose outputs feed directly into unfrozen vertices. */
    Set<String> frozenInputVertices = new HashSet<String>();
    /** Inputs of the unfrozen subset graph, sorted by name; featurize() emits features in this order. */
    List<String> graphInputs;
    /** Index of the last frozen layer in the original MultiLayerNetwork. */
    int frozenInputLayer = 0;

    /**
     * Creates a helper for a computation graph, freezing the named vertices and every vertex that
     * feeds into them.
     *
     * @param orig           the original graph; its frozen layers are replaced in place
     * @param frozenOutputAt names of the vertices up to (and including) which the graph is frozen
     * @throws IllegalArgumentException if no frozen vertex feeds an unfrozen one
     */
    public TransferLearningHelper(ComputationGraph orig, String ... frozenOutputAt) {
        this.origGraph = orig;
        this.frozenOutputAt = frozenOutputAt;
        this.applyFrozen = true;
        this.initHelperGraph();
    }

    /**
     * Creates a helper for a computation graph whose layers are already wrapped in
     * {@link FrozenLayer} where they should be frozen.
     *
     * @param orig the original graph
     * @throws IllegalArgumentException if no frozen layers are found
     */
    public TransferLearningHelper(ComputationGraph orig) {
        this.origGraph = orig;
        this.initHelperGraph();
    }

    /**
     * Creates a helper for a MultiLayerNetwork, freezing layers {@code 0..frozenTill} inclusive.
     *
     * @param orig       the original network; its layers are replaced in place
     * @param frozenTill index of the last layer to freeze
     * @throws IllegalArgumentException if {@code frozenTill} is not a valid layer index, or no
     *                                  frozen layers result
     */
    public TransferLearningHelper(MultiLayerNetwork orig, int frozenTill) {
        this.isGraph = false;
        this.frozenTill = frozenTill;
        this.applyFrozen = true;
        this.origMLN = orig;
        this.initHelperMLN();
    }

    /**
     * Creates a helper for a MultiLayerNetwork whose leading layers are already wrapped in
     * {@link FrozenLayer}. The last frozen layer marks the boundary.
     *
     * @param orig the original network
     * @throws IllegalArgumentException if no frozen layers are found
     */
    public TransferLearningHelper(MultiLayerNetwork orig) {
        this.isGraph = false;
        this.origMLN = orig;
        this.initHelperMLN();
    }

    /**
     * Always throws, with a message naming the network kind this instance was created with.
     * Used when a method for the other network kind is called.
     *
     * @throws IllegalArgumentException always
     */
    public void errorIfGraphIfMLN() {
        if (this.isGraph) {
            throw new IllegalArgumentException("This instance was initialized with a computation graph. Cannot apply methods related to MLN");
        }
        throw new IllegalArgumentException("This instance was initialized with a MultiLayerNetwork. Cannot apply methods related to computation graphs");
    }

    /**
     * Returns the unfrozen subset of the original computation graph.
     *
     * @throws IllegalArgumentException if this instance wraps a MultiLayerNetwork
     */
    public ComputationGraph unfrozenGraph() {
        if (!this.isGraph) {
            this.errorIfGraphIfMLN();
        }
        return this.unFrozenSubsetGraph;
    }

    /**
     * Returns the unfrozen subset of the original MultiLayerNetwork.
     *
     * @throws IllegalArgumentException if this instance wraps a computation graph
     */
    public MultiLayerNetwork unfrozenMLN() {
        if (this.isGraph) {
            this.errorIfGraphIfMLN();
        }
        return this.unFrozenSubsetMLN;
    }

    /**
     * Runs already-featurized inputs through the unfrozen subset graph.
     *
     * @param input featurized inputs, in the order produced by {@link #featurize(MultiDataSet)}
     * @return the subset graph's outputs
     * @throws IllegalArgumentException if this instance wraps a MultiLayerNetwork
     */
    public INDArray[] outputFromFeaturized(INDArray[] input) {
        if (!this.isGraph) {
            this.errorIfGraphIfMLN();
        }
        return this.unFrozenSubsetGraph.output(input);
    }

    /**
     * Runs a single featurized input through the unfrozen subset (graph or MLN).
     *
     * @param input the featurized input
     * @return the subset's single output
     * @throws IllegalArgumentException if the subset graph has more than one output
     */
    public INDArray outputFromFeaturized(INDArray input) {
        if (this.isGraph) {
            if (this.unFrozenSubsetGraph.getNumOutputArrays() > 1) {
                throw new IllegalArgumentException("Graph has more than one output. Expecting an input array with outputFromFeaturized method call");
            }
            return this.unFrozenSubsetGraph.output(input)[0];
        }
        return this.unFrozenSubsetMLN.output(input);
    }

    /** Freezes the original graph as requested and builds the unfrozen subset graph. */
    private void initHelperGraph() {
        GraphVertex[] vertices = this.origGraph.getVertices();
        int[] backPropOrder = this.origGraph.topologicalSortOrder().clone();
        ArrayUtils.reverse(backPropOrder);

        Set<String> allFrozen = this.collectFrozenVertexNames(vertices, backPropOrder);
        this.collectFrozenBoundaryVertices(vertices, backPropOrder, allFrozen);
        // Fail fast, before building a subset graph that would be discarded anyway.
        if (this.frozenInputVertices.isEmpty()) {
            throw new IllegalArgumentException("No frozen layers found");
        }
        this.unFrozenSubsetGraph = this.buildUnfrozenSubsetGraph(allFrozen);
        this.copyOrigParamsToSubsetGraph();
    }

    /**
     * Walks vertices in reverse topological order and returns the names of all frozen vertices.
     * A vertex is frozen if it was requested (or is an ancestor of a requested vertex) when
     * {@code applyFrozen} is set, or if it already holds a {@link FrozenLayer}. Requested layers
     * are frozen in place.
     */
    private Set<String> collectFrozenVertexNames(GraphVertex[] vertices, int[] backPropOrder) {
        Set<String> allFrozen = new HashSet<String>();
        if (this.applyFrozen) {
            Collections.addAll(allFrozen, this.frozenOutputAt);
        }
        for (int idx : backPropOrder) {
            GraphVertex gv = vertices[idx];
            if (this.applyFrozen && allFrozen.contains(gv.getVertexName())) {
                // Skip layers that are already frozen to avoid wrapping them twice.
                if (gv.hasLayer() && !(gv.getLayer() instanceof FrozenLayer)) {
                    this.freezeLayerVertex(gv);
                }
                this.addInputVertexNames(vertices, gv, allFrozen);
            } else if (gv.hasLayer() && gv.getLayer() instanceof FrozenLayer) {
                allFrozen.add(gv.getVertexName());
                this.addInputVertexNames(vertices, gv, allFrozen);
            }
        }
        return allFrozen;
    }

    /**
     * Freezes a layer vertex and swaps the new frozen layer into the graph's layer array.
     * Assumes {@code getLayers()} returns the graph's live backing array (the original code
     * relies on this too).
     */
    private void freezeLayerVertex(GraphVertex gv) {
        Layer original = gv.getLayer();
        gv.setLayerAsFrozen();
        Layer[] layers = this.origGraph.getLayers();
        for (int j = 0; j < layers.length; ++j) {
            if (layers[j] == original) {
                layers[j] = gv.getLayer();
                break;
            }
        }
    }

    /** Adds the names of all direct input vertices of {@code gv} to {@code names}. */
    private void addInputVertexNames(GraphVertex[] vertices, GraphVertex gv, Set<String> names) {
        VertexIndices[] inputs = gv.getInputVertices();
        if (inputs == null) {
            return;
        }
        for (VertexIndices input : inputs) {
            names.add(vertices[input.getVertexIndex()].getVertexName());
        }
    }

    /** Records every frozen vertex that directly feeds an unfrozen, non-input vertex. */
    private void collectFrozenBoundaryVertices(GraphVertex[] vertices, int[] backPropOrder, Set<String> allFrozen) {
        for (int idx : backPropOrder) {
            GraphVertex gv = vertices[idx];
            if (allFrozen.contains(gv.getVertexName()) || gv.isInputVertex()) {
                continue;
            }
            VertexIndices[] inputs = gv.getInputVertices();
            if (inputs == null) {
                continue;
            }
            for (VertexIndices input : inputs) {
                String inputName = vertices[input.getVertexIndex()].getVertexName();
                if (allFrozen.contains(inputName)) {
                    this.frozenInputVertices.add(inputName);
                }
            }
        }
    }

    /**
     * Builds the subset graph by removing all frozen vertices. Boundary vertices and unfrozen
     * network inputs become the new inputs, in sorted order. Also sets {@code graphInputs}.
     */
    private ComputationGraph buildUnfrozenSubsetGraph(Set<String> allFrozen) {
        TransferLearning.GraphBuilder builder = new TransferLearning.GraphBuilder(this.origGraph);
        for (String toRemove : allFrozen) {
            if (this.frozenInputVertices.contains(toRemove)) {
                builder.removeVertexKeepConnections(toRemove);
            } else {
                builder.removeVertexAndConnections(toRemove);
            }
        }
        Set<String> newInputs = new HashSet<String>(this.origGraph.getConfiguration().getNetworkInputs());
        newInputs.removeAll(allFrozen);
        // Remove surviving network inputs so that all inputs can be re-added in one predictable order.
        for (String existingInput : newInputs) {
            builder.removeVertexKeepConnections(existingInput);
        }
        newInputs.addAll(this.frozenInputVertices);
        this.graphInputs = new ArrayList<String>(newInputs);
        Collections.sort(this.graphInputs);
        // Add inputs in the same sorted order that featurize() uses for its feature arrays.
        for (String asInput : this.graphInputs) {
            builder.addInputs(asInput);
        }
        return builder.build();
    }

    /** Freezes layers as requested and builds the unfrozen subset MultiLayerNetwork. */
    private void initHelperMLN() {
        if (this.applyFrozen) {
            Layer[] layers = this.origMLN.getLayers();
            if (this.frozenTill >= layers.length) {
                throw new IllegalArgumentException("frozenTill (" + this.frozenTill + ") must be less than the number of layers (" + layers.length + ")");
            }
            for (int i = this.frozenTill; i >= 0; --i) {
                if (!(layers[i] instanceof FrozenLayer)) {
                    layers[i] = new FrozenLayer<Layer>(layers[i]);
                }
            }
            this.origMLN.setLayers(layers);
        }
        int nLayers = this.origMLN.getnLayers();
        int lastFrozen = -1;
        for (int i = 0; i < nLayers; ++i) {
            if (this.origMLN.getLayer(i) instanceof FrozenLayer) {
                lastFrozen = i;
            }
        }
        // Without this check, layer 0 would silently be treated as frozen.
        if (lastFrozen < 0) {
            throw new IllegalArgumentException("No frozen layers found");
        }
        this.frozenInputLayer = lastFrozen;

        int firstUnfrozen = this.frozenInputLayer + 1;
        List<NeuralNetConfiguration> allConfs = new ArrayList<NeuralNetConfiguration>();
        for (int i = firstUnfrozen; i < nLayers; ++i) {
            allConfs.add(this.origMLN.getLayer(i).conf());
        }
        MultiLayerConfiguration c = this.origMLN.getLayerWiseConfigurations();
        this.unFrozenSubsetMLN = new MultiLayerNetwork(new MultiLayerConfiguration.Builder()
                .backprop(c.isBackprop())
                // Preprocessors are keyed by layer index; shift them to the subset's indices.
                .inputPreProcessors(shiftIndexKeys(c.getInputPreProcessors(), firstUnfrozen))
                .pretrain(c.isPretrain())
                .backpropType(c.getBackpropType())
                .tBPTTForwardLength(c.getTbpttFwdLength())
                .tBPTTBackwardLength(c.getTbpttBackLength())
                .confs(allConfs)
                .build());
        this.unFrozenSubsetMLN.init();
        this.copyOrigParamsToSubsetMLN();
    }

    /**
     * Returns a copy of {@code byIndex} with every key reduced by {@code offset}. Entries whose
     * shifted key would be negative are dropped. A null map yields an empty map.
     */
    private static <V> Map<Integer, V> shiftIndexKeys(Map<Integer, V> byIndex, int offset) {
        Map<Integer, V> shifted = new HashMap<Integer, V>();
        if (byIndex == null) {
            return shifted;
        }
        for (Map.Entry<Integer, V> entry : byIndex.entrySet()) {
            int newIndex = entry.getKey() - offset;
            if (newIndex >= 0) {
                shifted.put(newIndex, entry.getValue());
            }
        }
        return shifted;
    }

    /**
     * Runs the features through the frozen part of the graph. The result holds the inputs of the
     * unfrozen subset graph, with the labels and label masks unchanged.
     *
     * @param input the raw data set
     * @return the featurized data set, with features ordered as the subset graph's inputs
     * @throws IllegalArgumentException if this instance wraps a MultiLayerNetwork, or the input
     *                                  has feature masks
     */
    public MultiDataSet featurize(MultiDataSet input) {
        if (!this.isGraph) {
            throw new IllegalArgumentException("Cannot use multidatasets with MultiLayerNetworks.");
        }
        if (input.getFeaturesMaskArrays() != null) {
            throw new IllegalArgumentException("Currently cannot support featurizing datasets with feature masks");
        }
        INDArray[] features = input.getFeatures();
        Map<String, INDArray> activations = this.origGraph.feedForward(features, false);
        List<String> networkInputs = this.origGraph.getConfiguration().getNetworkInputs();
        INDArray[] featuresNow = new INDArray[this.graphInputs.size()];
        for (int i = 0; i < featuresNow.length; ++i) {
            String name = this.graphInputs.get(i);
            if (this.origGraph.getVertex(name).isInputVertex()) {
                // Unfrozen network inputs are passed through unchanged.
                featuresNow[i] = features[networkInputs.indexOf(name)];
            } else {
                featuresNow[i] = activations.get(name);
            }
        }
        return new MultiDataSet(featuresNow, input.getLabels(), null, input.getLabelsMaskArrays());
    }

    /**
     * Runs the features through the frozen part of the network. The labels and label mask are
     * kept unchanged.
     *
     * @param input the raw data set
     * @return the featurized data set
     * @throws IllegalArgumentException      for graphs with more than one input or output (in
     *                                       either the original graph or its unfrozen subset),
     *                                       or for graph inputs with feature masks
     * @throws UnsupportedOperationException for MLN inputs with feature masks
     */
    public DataSet featurize(DataSet input) {
        if (this.isGraph) {
            if (this.origGraph.getNumInputArrays() > 1 || this.origGraph.getNumOutputArrays() > 1 || this.graphInputs.size() > 1) {
                throw new IllegalArgumentException("Input or output size to a computation graph is greater than one. Requires use of a MultiDataSet.");
            }
            if (input.getFeaturesMaskArray() != null) {
                throw new IllegalArgumentException("Currently cannot support featurizing datasets with feature masks");
            }
            MultiDataSet wrapped = new MultiDataSet(new INDArray[]{input.getFeatures()}, new INDArray[]{input.getLabels()}, null, new INDArray[]{input.getLabelsMaskArray()});
            MultiDataSet featurized = this.featurize(wrapped);
            // Feature masks are rejected above, so the featurized set has no features mask.
            return new DataSet(featurized.getFeatures()[0], input.getLabels(), null, input.getLabelsMaskArray());
        }
        if (input.getFeaturesMaskArray() != null) {
            throw new UnsupportedOperationException("Feature masks not supported with featurizing currently");
        }
        // The activation list is indexed so that entry k+1 is the output of layer k.
        INDArray featurized = this.origMLN.feedForwardToLayer(this.frozenInputLayer + 1, input.getFeatures(), false).get(this.frozenInputLayer + 1);
        return new DataSet(featurized, input.getLabels(), null, input.getLabelsMaskArray());
    }

    /**
     * Fits the unfrozen subset graph on featurized data, then copies its parameters back.
     *
     * @throws IllegalArgumentException if this instance wraps a MultiLayerNetwork
     */
    public void fitFeaturized(MultiDataSetIterator iter) {
        if (!this.isGraph) {
            this.errorIfGraphIfMLN();
        }
        this.unFrozenSubsetGraph.fit(iter);
        this.copyParamsFromSubsetGraphToOrig();
    }

    /**
     * Fits the unfrozen subset graph on one featurized data set, then copies its parameters back.
     *
     * @throws IllegalArgumentException if this instance wraps a MultiLayerNetwork
     */
    public void fitFeaturized(MultiDataSet input) {
        if (!this.isGraph) {
            this.errorIfGraphIfMLN();
        }
        this.unFrozenSubsetGraph.fit((org.nd4j.linalg.dataset.api.MultiDataSet)input);
        this.copyParamsFromSubsetGraphToOrig();
    }

    /** Fits the unfrozen subset on one featurized data set, then copies its parameters back. */
    public void fitFeaturized(DataSet input) {
        if (this.isGraph) {
            this.unFrozenSubsetGraph.fit((org.nd4j.linalg.dataset.api.DataSet)input);
            this.copyParamsFromSubsetGraphToOrig();
        } else {
            this.unFrozenSubsetMLN.fit((org.nd4j.linalg.dataset.api.DataSet)input);
            this.copyParamsFromSubsetMLNToOrig();
        }
    }

    /** Fits the unfrozen subset on featurized data, then copies its parameters back. */
    public void fitFeaturized(DataSetIterator iter) {
        if (this.isGraph) {
            this.unFrozenSubsetGraph.fit(iter);
            this.copyParamsFromSubsetGraphToOrig();
        } else {
            this.unFrozenSubsetMLN.fit(iter);
            this.copyParamsFromSubsetMLNToOrig();
        }
    }

    /** Copies each subset-graph layer's parameters into the same-named layer of the original graph. */
    private void copyParamsFromSubsetGraphToOrig() {
        for (GraphVertex aVertex : this.unFrozenSubsetGraph.getVertices()) {
            if (!aVertex.hasLayer()) continue;
            this.origGraph.getVertex(aVertex.getVertexName()).getLayer().setParams(aVertex.getLayer().params());
        }
    }

    /** Copies each original-graph layer's parameters into the same-named layer of the subset graph. */
    private void copyOrigParamsToSubsetGraph() {
        for (GraphVertex aVertex : this.unFrozenSubsetGraph.getVertices()) {
            if (!aVertex.hasLayer()) continue;
            aVertex.getLayer().setParams(this.origGraph.getLayer(aVertex.getVertexName()).params());
        }
    }

    /** Copies the unfrozen original-MLN layers' parameters into the corresponding subset layers. */
    private void copyOrigParamsToSubsetMLN() {
        for (int i = this.frozenInputLayer + 1; i < this.origMLN.getnLayers(); ++i) {
            this.unFrozenSubsetMLN.getLayer(i - this.frozenInputLayer - 1).setParams(this.origMLN.getLayer(i).params());
        }
    }

    /** Copies the subset MLN's layer parameters back into the unfrozen layers of the original. */
    private void copyParamsFromSubsetMLNToOrig() {
        for (int i = this.frozenInputLayer + 1; i < this.origMLN.getnLayers(); ++i) {
            this.origMLN.getLayer(i).setParams(this.unFrozenSubsetMLN.getLayer(i - this.frozenInputLayer - 1).params());
        }
    }
}

