/*
 * Decompiled with CFR 0.152.
 */
package org.kramerlab.autoencoder.neuralnet;

import java.awt.Color;
import java.awt.Graphics2D;
import java.awt.RenderingHints;
import java.awt.image.BufferedImage;
import org.kramerlab.autoencoder.math.matrix.Mat;
import org.kramerlab.autoencoder.math.optimization.CG_Rasmussen2_WithTermination;
import org.kramerlab.autoencoder.math.optimization.CG_Rasmussen2_WithTermination$;
import org.kramerlab.autoencoder.math.optimization.DifferentiableErrorFunctionFactory;
import org.kramerlab.autoencoder.math.optimization.DifferentiableFunction;
import org.kramerlab.autoencoder.math.optimization.EarlyStopping;
import org.kramerlab.autoencoder.math.optimization.EarlyStopping$;
import org.kramerlab.autoencoder.math.optimization.LimitNumberOfEvaluations;
import org.kramerlab.autoencoder.math.optimization.SquareErrorFunctionFactory$;
import org.kramerlab.autoencoder.math.optimization.TerminationCriterion;
import org.kramerlab.autoencoder.math.random.package$;
import org.kramerlab.autoencoder.neuralnet.BiasedUnitLayer;
import org.kramerlab.autoencoder.neuralnet.FullBipartiteConnection;
import org.kramerlab.autoencoder.neuralnet.Layer;
import org.kramerlab.autoencoder.neuralnet.MatrixParameterizedLayer;
import org.kramerlab.autoencoder.neuralnet.NeuralNetLike;
import org.kramerlab.autoencoder.neuralnet.NeuralNetLike$;
import org.kramerlab.autoencoder.visualization.Observer;
import org.kramerlab.autoencoder.visualization.TrainingObserver;
import org.kramerlab.autoencoder.visualization.Visualizable;
import org.kramerlab.autoencoder.visualization.VisualizableIntermediateResult;
import scala.Function1;
import scala.Function2;
import scala.MatchError;
import scala.None$;
import scala.Option;
import scala.Serializable;
import scala.Some;
import scala.Tuple2;
import scala.collection.GenIterable;
import scala.collection.IterableLike;
import scala.collection.LinearSeqOptimized;
import scala.collection.TraversableLike;
import scala.collection.TraversableOnce;
import scala.collection.immutable.List;
import scala.collection.immutable.List$;
import scala.math.Numeric;
import scala.math.Ordering;
import scala.runtime.BoxesRunTime;

/**
 * Static implementations of the concrete methods of the Scala trait {@code NeuralNetLike}.
 *
 * <p>This file is CFR decompiler output (see the file header). The {@code $class} suffix, the
 * explicit {@code $this} receiver on every static method, and the {@code $init$} method follow
 * the Scala 2.x (pre-2.12) trait encoding. Each method here is the body of a trait method, and
 * the implementing instance is passed in as {@code $this}.
 *
 * <p>NOTE(review): this is not hand-written Java. Several constructs below are not valid Java,
 * so javac will not compile this file:
 * <ul>
 *   <li>anonymous {@code Serializable} subclasses that take constructor arguments;</li>
 *   <li>the undeclared type variable {@code Repr};</li>
 *   <li>{@code NeuralNetLike$.anonfun.3};</li>
 *   <li>{@code void var5_5}.</li>
 * </ul>
 * Behavioral changes belong in the original Scala source.
 */
public abstract class NeuralNetLike$class {
    /**
     * Propagates {@code input} forward through the network.
     *
     * <p>Folds over {@code layers().tail()}, i.e. every layer except the first, and calls
     * {@code Layer.propagate} on the running value. Layer 0 itself is never applied.
     * Presumably this is because {@code input} is regarded as the activity of layer 0.
     *
     * @param $this the network whose layers are applied
     * @param input the value fed into layer 1
     * @return the value produced by the last layer
     */
    public static Mat apply(NeuralNetLike $this, Mat input) {
        return (Mat)((LinearSeqOptimized)$this.layers().tail()).foldLeft((Object)input, (Function2)new Serializable($this){
            public static final long serialVersionUID = 0L;

            // Fold step: push the running value through one more layer.
            public final Mat apply(Mat in, Layer layer) {
                return layer.propagate(in);
            }
        });
    }

    /**
     * Like {@code apply}, but keeps every intermediate value instead of only the last one.
     *
     * <p>Uses {@code scanLeft} over {@code layers().tail()}. The returned list therefore starts
     * with {@code input} itself, followed by the output of each later layer in order. For a
     * non-empty layer list its length equals {@code layers().size()}, which lets
     * {@code toImage} zip it one-to-one with {@code layers()}.
     *
     * @param $this the network
     * @param input the value fed into layer 1
     * @return list of per-layer activities, beginning with {@code input}
     */
    public static List activities(NeuralNetLike $this, Mat input) {
        return (List)((TraversableLike)$this.layers().tail()).scanLeft((Object)input, (Function2)new Serializable($this){
            public static final long serialVersionUID = 0L;

            // Scan step: same propagation as apply(...), but every partial result is kept.
            public final Mat apply(Mat in, Layer layer) {
                return layer.propagate(in);
            }
        }, List$.MODULE$.canBuildFrom());
    }

    /**
     * Propagates {@code output} backwards through the network.
     *
     * <p>Folds over {@code layers().reverse().tail()}, i.e. all layers except the last, going
     * from the second-to-last layer down to layer 0. Each step calls
     * {@code Layer.reversePropagate}. This mirrors {@code apply}: there layer 0 is skipped,
     * here the last layer is skipped.
     *
     * @param $this the network
     * @param output a value at the level of the last layer
     * @return the value obtained after reverse-propagating through layer 0
     */
    public static Mat reverse(NeuralNetLike $this, Mat output) {
        return (Mat)((LinearSeqOptimized)$this.layers().reverse().tail()).foldLeft((Object)output, (Function2)new Serializable($this){
            public static final long serialVersionUID = 0L;

            // Fold step: move the running value one layer back towards the input.
            public final Mat apply(Mat out, Layer layer) {
                return layer.reversePropagate(out);
            }
        });
    }

    /**
     * Trains a copy of the network with conjugate gradients, using a held-out validation set
     * for early stopping.
     *
     * <p>Steps:
     * <ol>
     *   <li>Shuffle the rows of {@code input}/{@code output} (on clones, with the same
     *       permutation).</li>
     *   <li>Split off the first {@code (int)(height * relativeValidationSetSize)} rows as the
     *       validation set.</li>
     *   <li>Minimize the training error with {@code CG_Rasmussen2_WithTermination} until the
     *       evaluation budget is exhausted or early stopping triggers.</li>
     * </ol>
     *
     * @param $this the network whose current parameters are the starting point
     * @param input training inputs, one example per row (rows are permuted together with
     *     {@code output}, so they must correspond)
     * @param output target values, one row per input row
     * @param errorFunctionFactory builds the error function from a target matrix; the Scala
     *     default is the square error (see {@code optimize$default$3})
     * @param relativeValidationSetSize fraction of rows held out for validation; must not be
     *     exactly {@code 0.0}
     * @param maxEvals maximum number of function evaluations allowed by the termination
     *     criterion
     * @param trainingObservers observers notified during training; each is wrapped so that it
     *     receives a {@code VisualizableIntermediateResult} of the candidate net
     * @return the net extracted from the parameter vector returned by the minimizer
     * @throws Error (scala.Error / java.lang.Error) if {@code relativeValidationSetSize == 0.0}
     */
    public static NeuralNetLike optimize(NeuralNetLike $this, Mat input, Mat output, DifferentiableErrorFunctionFactory errorFunctionFactory, double relativeValidationSetSize, int maxEvals, List trainingObservers) {
        // Only the exact value 0.0 is rejected.
        // NOTE(review): a small positive ratio can still truncate validationHeight to 0 below,
        // and negative or >= 1 ratios are not checked at all.
        if (relativeValidationSetSize == 0.0) {
            throw new Error("Training without a validation set is currently not supported, but can be added easily.");
        }
        // Permute cloned copies so the caller's matrices are left untouched. Inputs and
        // outputs use the same permutation so that example pairs stay aligned.
        // (Presumably package$.permutation returns a random permutation; it lives in the
        // math.random package.)
        int[] permutation = package$.MODULE$.permutation(input.height());
        Mat shuffledInput = input.clone();
        Mat shuffledOutput = output.clone();
        shuffledInput.permutateRows(permutation);
        shuffledOutput.permutateRows(permutation);
        // Number of held-out rows. The int cast truncates toward zero.
        int validationHeight = (int)((double)input.height() * relativeValidationSetSize);
        // Validation = shuffled rows selected by (0 ::: validationHeight); training = rows
        // selected by (validationHeight ::: end); all columns in both cases.
        // NOTE(review): assumes the ':::' selector is the half-open range [a, b), so the two
        // sets are disjoint and cover every row. Confirm in math.matrix.package.
        int n = 0;
        Mat validationInput = shuffledInput.apply(org.kramerlab.autoencoder.math.matrix.package$.MODULE$.integerToConstantIndexSelector(validationHeight).$colon$colon$colon(n), org.kramerlab.autoencoder.math.matrix.package$.MODULE$.$colon$colon$colon());
        int n2 = 0;
        Mat validationOutput = shuffledOutput.apply(org.kramerlab.autoencoder.math.matrix.package$.MODULE$.integerToConstantIndexSelector(validationHeight).$colon$colon$colon(n2), org.kramerlab.autoencoder.math.matrix.package$.MODULE$.$colon$colon$colon());
        int n3 = validationHeight;
        Mat trainingInput = shuffledInput.apply(org.kramerlab.autoencoder.math.matrix.package$.MODULE$.end().$colon$colon$colon(n3), org.kramerlab.autoencoder.math.matrix.package$.MODULE$.$colon$colon$colon());
        int n4 = validationHeight;
        Mat trainingOutput = shuffledOutput.apply(org.kramerlab.autoencoder.math.matrix.package$.MODULE$.end().$colon$colon$colon(n4), org.kramerlab.autoencoder.math.matrix.package$.MODULE$.$colon$colon$colon());
        // differentiableComposition presumably composes the net's forward pass on the given
        // input with the error function. The result is a function of the net's parameters: its
        // argument type is ParameterVector, as seen in the early-stopping closure below.
        DifferentiableFunction<Mat> errorFunction = errorFunctionFactory.apply(trainingOutput);
        DifferentiableFunction minimizableFunction = NeuralNetLike$.MODULE$.differentiableComposition(trainingInput, errorFunction);
        DifferentiableFunction errorOnValidationSet = NeuralNetLike$.MODULE$.differentiableComposition(validationInput, errorFunctionFactory.apply(validationOutput));
        // Early stopping watches the NEGATED validation error (presumably EarlyStopping tracks
        // a value to maximize). The constants 64, 16, 512 and 128 are EarlyStopping tuning
        // parameters whose meaning is defined in EarlyStopping. Its remaining constructor
        // parameters take their Scala defaults, and the monitored value is ordered as a double.
        EarlyStopping earlyStopping = new EarlyStopping(new Serializable($this, errorOnValidationSet){
            public static final long serialVersionUID = 0L;
            private final DifferentiableFunction errorOnValidationSet$1;

            public final double apply(NeuralNetLike.ParameterVector<Repr> x) {
                return -this.errorOnValidationSet$1.apply(x);
            }
            {
                this.errorOnValidationSet$1 = errorOnValidationSet$1;
            }
        }, 64, 16, 512, 128, EarlyStopping$.MODULE$.$lessinit$greater$default$6(), EarlyStopping$.MODULE$.$lessinit$greater$default$7(), EarlyStopping$.MODULE$.$lessinit$greater$default$8(), Ordering.Double$.MODULE$);
        // Combined termination criterion: the evaluation budget '|' early stopping.
        // Presumably training stops as soon as either one fires.
        LimitNumberOfEvaluations evalsLimit = new LimitNumberOfEvaluations(maxEvals);
        TerminationCriterion terminationCriterion = evalsLimit.$bar(earlyStopping);
        // Wrap every TrainingObserver in an Observer of parameter vectors. Each notification
        // attaches a small data sample to the candidate net before forwarding it, so the net's
        // toImage can also draw activities.
        List adjustedObservers = (List)trainingObservers.map((Function1)new Serializable($this, input){
            public static final long serialVersionUID = 0L;
            public final Mat input$1;

            public final Object apply(TrainingObserver obs) {
                return new Observer<NeuralNetLike.ParameterVector<Repr>>(this, obs){
                    private final /* synthetic */ NeuralNetLike$.anonfun.3 $outer;
                    private final TrainingObserver obs$1;

                    public void notify(NeuralNetLike.ParameterVector<Repr> r, boolean important) {
                        // Take up to 100 leading rows of the ORIGINAL (unshuffled) input,
                        // shuffle them, then keep up to 10 rows as the visualization sample.
                        // NOTE(review): if Mat.apply returns a view rather than a copy,
                        // shuffleRows() would reorder the caller's input. Confirm.
                        int n = 0;
                        int n2 = 0;
                        Mat firstHundredLines = this.$outer.input$1.apply(org.kramerlab.autoencoder.math.matrix.package$.MODULE$.integerToConstantIndexSelector(scala.math.package$.MODULE$.min(this.$outer.input$1.height(), 100)).$colon$colon$colon(n), org.kramerlab.autoencoder.math.matrix.package$.MODULE$.end().$colon$colon$colon(n2));
                        firstHundredLines.shuffleRows();
                        int n3 = 0;
                        int n4 = 0;
                        Mat sample2 = firstHundredLines.apply(org.kramerlab.autoencoder.math.matrix.package$.MODULE$.integerToConstantIndexSelector(scala.math.package$.MODULE$.min(firstHundredLines.height(), 10)).$colon$colon$colon(n3), org.kramerlab.autoencoder.math.matrix.package$.MODULE$.end().$colon$colon$colon(n4));
                        // Side effect: mutates the candidate net's dataSample field.
                        r.net().dataSample_$eq((Option<Mat>)new Some((Object)sample2));
                        this.obs$1.notify(new VisualizableIntermediateResult((Visualizable)r.net()), important);
                    }
                    {
                        if ($outer == null) {
                            throw new NullPointerException();
                        }
                        this.$outer = $outer;
                        this.obs$1 = obs$1;
                    }
                };
            }
            {
                this.input$1 = input$1;
            }
        }, List$.MODULE$.canBuildFrom());
        // Conjugate-gradient minimizer, constructed with all of its default parameters.
        // The name suggests a port of Rasmussen's 'minimize'; see CG_Rasmussen2_WithTermination.
        CG_Rasmussen2_WithTermination minimizer = new CG_Rasmussen2_WithTermination(CG_Rasmussen2_WithTermination$.MODULE$.$lessinit$greater$default$1(), CG_Rasmussen2_WithTermination$.MODULE$.$lessinit$greater$default$2(), CG_Rasmussen2_WithTermination$.MODULE$.$lessinit$greater$default$3(), CG_Rasmussen2_WithTermination$.MODULE$.$lessinit$greater$default$4(), CG_Rasmussen2_WithTermination$.MODULE$.$lessinit$greater$default$5(), CG_Rasmussen2_WithTermination$.MODULE$.$lessinit$greater$default$6());
        // Start from $this's current parameters. earlyStopping is passed twice: once inside
        // terminationCriterion and once as a separate argument. What the minimizer does with
        // the second reference is defined in CG_Rasmussen2_WithTermination.
        return minimizer.minimize(minimizableFunction, new NeuralNetLike.ParameterVector<NeuralNetLike>($this), terminationCriterion, earlyStopping, adjustedObservers, Ordering.Double$.MODULE$).net();
    }

    /**
     * Scala default value for the third parameter of {@code optimize}
     * ({@code errorFunctionFactory}): the square-error function factory.
     *
     * @param $this the network (unused)
     * @return {@code SquareErrorFunctionFactory$.MODULE$}
     */
    public static DifferentiableErrorFunctionFactory optimize$default$3(NeuralNetLike $this) {
        return SquareErrorFunctionFactory$.MODULE$;
    }

    /**
     * Renders the network as a single image, with one row per layer.
     *
     * <p>Each row shows the layer's own image ({@code Layer.toImage}) on the left. To its right
     * is that layer's activity image ({@code Layer.visualizeActivity}) for the current
     * {@code dataSample}. If {@code dataSample} is {@code None}, a 1x1 placeholder is used
     * instead. Both images sit on a dark-gray rounded rectangle over a black background. Rows
     * are stacked top to bottom in layer order, separated by {@code singlePadding} pixels.
     *
     * <p>NOTE(review):
     * <ul>
     *   <li>The {@code Graphics2D} obtained from the image is never {@code dispose()}d.</li>
     *   <li>With an empty layer list, the {@code max} over row widths is taken on an empty
     *       collection, which Scala rejects with an exception.</li>
     * </ul>
     *
     * @param $this the network to draw
     * @return a new RGB image (BufferedImage type 1, i.e. TYPE_INT_RGB)
     * @throws MatchError if {@code dataSample()} is neither a {@code Some} nor {@code None}
     *     (e.g. {@code null})
     */
    public static BufferedImage toImage(NeuralNetLike $this) {
        Option<Mat> option;
        block4: {
            List list;
            List layerImages;
            block3: {
                block2: {
                    // One image per layer, in layer order.
                    layerImages = (List)$this.layers().map((Function1)new Serializable($this){
                        public static final long serialVersionUID = 0L;

                        public final BufferedImage apply(Layer x$9) {
                            return x$9.toImage();
                        }
                    }, List$.MODULE$.canBuildFrom());
                    // Decompiled Scala match on dataSample: Some(d) => real activities.
                    option = $this.dataSample();
                    if (!(option instanceof Some)) break block2;
                    Some some = (Some)option;
                    Mat d = (Mat)some.x();
                    // activities(d) has one entry per layer, so zipping with layers() pairs each
                    // activity with the layer it belongs to.
                    List<Mat> neuralActivities = $this.activities(d);
                    list = (List)((TraversableLike)neuralActivities.zip($this.layers(), List$.MODULE$.canBuildFrom())).map((Function1)new Serializable($this){
                        public static final long serialVersionUID = 0L;

                        public final BufferedImage apply(Tuple2<Mat, Layer> x0$1) {
                            Tuple2<Mat, Layer> tuple2 = x0$1;
                            if (tuple2 != null) {
                                Mat activity = (Mat)tuple2._1();
                                Layer layer = (Layer)tuple2._2();
                                BufferedImage bufferedImage = layer.visualizeActivity(activity);
                                return bufferedImage;
                            }
                            throw new MatchError(tuple2);
                        }
                    }, List$.MODULE$.canBuildFrom());
                    break block3;
                }
                // case None => one 1x1 RGB placeholder per layer. Any other value falls through
                // to the MatchError at the end of the method.
                None$ none$ = None$.MODULE$;
                Option<Mat> option2 = option;
                if (none$ != null ? !none$.equals(option2) : option2 != null) break block4;
                list = (List)layerImages.map((Function1)new Serializable($this){
                    public static final long serialVersionUID = 0L;

                    public final BufferedImage apply(BufferedImage x) {
                        return new BufferedImage(1, 1, 1);
                    }
                }, List$.MODULE$.canBuildFrom());
            }
            List activityImages = list;
            // Row height = max(layer image height, activity image height), per layer.
            List layerHeights = (List)layerImages.map((Function1)new Serializable($this){
                public static final long serialVersionUID = 0L;

                public final int apply(BufferedImage x$10) {
                    return x$10.getHeight();
                }
            }, List$.MODULE$.canBuildFrom());
            List activityHeights = (List)activityImages.map((Function1)new Serializable($this){
                public static final long serialVersionUID = 0L;

                public final int apply(BufferedImage x$11) {
                    return x$11.getHeight();
                }
            }, List$.MODULE$.canBuildFrom());
            List heights = (List)((TraversableLike)layerHeights.zip((GenIterable)activityHeights, List$.MODULE$.canBuildFrom())).withFilter((Function1)new Serializable($this){
                public static final long serialVersionUID = 0L;

                public final boolean apply(Tuple2<Object, Object> check$ifrefutable$1) {
                    Tuple2<Object, Object> tuple2 = check$ifrefutable$1;
                    boolean bl = tuple2 != null;
                    return bl;
                }
            }).map((Function1)new Serializable($this){
                public static final long serialVersionUID = 0L;

                public final int apply(Tuple2<Object, Object> x$12) {
                    Tuple2<Object, Object> tuple2 = x$12;
                    if (tuple2 != null) {
                        int lh = tuple2._1$mcI$sp();
                        int ah = tuple2._2$mcI$sp();
                        int n = scala.math.package$.MODULE$.max(lh, ah);
                        return n;
                    }
                    throw new MatchError(tuple2);
                }
            }, List$.MODULE$.canBuildFrom());
            // Layout metrics:
            //   unitHeight    - sum of all row heights
            //   singlePadding - gap between rows (at least 12 px; otherwise about 10% of
            //                   unitHeight spread over layers+1 gaps)
            //   borderWidth   - a third of the padding; used as the margin of the rounded
            //                   rectangle, the gap between the two images, and the corner arc
            int unitHeight = BoxesRunTime.unboxToInt((Object)heights.sum((Numeric)Numeric.IntIsIntegral$.MODULE$));
            int totalPadding = unitHeight / 10;
            int singlePadding = scala.math.package$.MODULE$.max(12, totalPadding / ($this.layers().size() + 1));
            int borderWidth = singlePadding / 3;
            int h = singlePadding * ($this.layers().size() + 1) + unitHeight;
            // Widest row: layer image + activity image + the gap between them.
            int maxLayerWidth = BoxesRunTime.unboxToInt((Object)((TraversableOnce)((TraversableLike)layerImages.zip((GenIterable)activityImages, List$.MODULE$.canBuildFrom())).withFilter((Function1)new Serializable($this){
                public static final long serialVersionUID = 0L;

                public final boolean apply(Tuple2<BufferedImage, BufferedImage> check$ifrefutable$2) {
                    Tuple2<BufferedImage, BufferedImage> tuple2 = check$ifrefutable$2;
                    boolean bl = tuple2 != null;
                    return bl;
                }
            }).map((Function1)new Serializable($this, borderWidth){
                public static final long serialVersionUID = 0L;
                private final int borderWidth$1;

                public final int apply(Tuple2<BufferedImage, BufferedImage> x$13) {
                    Tuple2<BufferedImage, BufferedImage> tuple2 = x$13;
                    if (tuple2 != null) {
                        BufferedImage l = (BufferedImage)tuple2._1();
                        BufferedImage a = (BufferedImage)tuple2._2();
                        int n = l.getWidth() + a.getWidth() + this.borderWidth$1;
                        return n;
                    }
                    throw new MatchError(tuple2);
                }
                {
                    this.borderWidth$1 = borderWidth$1;
                }
            }, List$.MODULE$.canBuildFrom())).max((Ordering)Ordering.Int$.MODULE$));
            int w = 3 * singlePadding + maxLayerWidth;
            // Top y-coordinate of each row. The first row starts at singlePadding, and each next
            // row starts at previous + rowHeight + singlePadding. scanLeft yields one extra
            // trailing entry, which the zip below simply drops.
            List offsets = (List)heights.scanLeft((Object)BoxesRunTime.boxToInteger((int)singlePadding), (Function2)new Serializable($this, singlePadding){
                public static final long serialVersionUID = 0L;
                private final int singlePadding$1;

                public final int apply(int x$14, int x$15) {
                    return this.apply$mcIII$sp(x$14, x$15);
                }

                public int apply$mcIII$sp(int x$14, int x$15) {
                    return x$14 + x$15 + this.singlePadding$1;
                }
                {
                    this.singlePadding$1 = singlePadding$1;
                }
            }, List$.MODULE$.canBuildFrom());
            // Canvas (type 1 = TYPE_INT_RGB) with a black background; row boxes are dark gray.
            BufferedImage img = new BufferedImage(w, h, 1);
            Graphics2D g = (Graphics2D)img.getGraphics();
            g.setRenderingHint(RenderingHints.KEY_ANTIALIASING, RenderingHints.VALUE_ANTIALIAS_ON);
            g.setColor(Color.BLACK);
            g.fillRect(0, 0, w, h);
            g.setColor(Color.DARK_GRAY);
            // For each ((layerImg, activityImg), offset): center the pair horizontally within
            // maxLayerWidth and draw a rounded box behind it. Then draw the layer image and,
            // borderWidth pixels to its right, the activity image.
            ((TraversableLike)((IterableLike)layerImages.zip((GenIterable)activityImages, List$.MODULE$.canBuildFrom())).zip((GenIterable)offsets, List$.MODULE$.canBuildFrom())).withFilter((Function1)new Serializable($this){
                public static final long serialVersionUID = 0L;

                public final boolean apply(Tuple2<Tuple2<BufferedImage, BufferedImage>, Object> check$ifrefutable$3) {
                    Tuple2 tuple2;
                    Tuple2<Tuple2<BufferedImage, BufferedImage>, Object> tuple22 = check$ifrefutable$3;
                    boolean bl = tuple22 != null && (tuple2 = (Tuple2)tuple22._1()) != null;
                    return bl;
                }
            }).foreach((Function1)new Serializable($this, singlePadding, borderWidth, maxLayerWidth, g){
                public static final long serialVersionUID = 0L;
                private final int singlePadding$1;
                private final int borderWidth$1;
                private final int maxLayerWidth$1;
                private final Graphics2D g$1;

                public final boolean apply(Tuple2<Tuple2<BufferedImage, BufferedImage>, Object> x$16) {
                    Tuple2<Tuple2<BufferedImage, BufferedImage>, Object> tuple2 = x$16;
                    if (tuple2 != null) {
                        Tuple2 tuple22 = (Tuple2)tuple2._1();
                        int offset = tuple2._2$mcI$sp();
                        if (tuple22 != null) {
                            BufferedImage layerImg = (BufferedImage)tuple22._1();
                            BufferedImage activityImg = (BufferedImage)tuple22._2();
                            int leftImageStartX = this.singlePadding$1 + (this.maxLayerWidth$1 - layerImg.getWidth() - activityImg.getWidth() - this.borderWidth$1) / 2;
                            this.g$1.fillRoundRect(leftImageStartX - this.borderWidth$1, offset - this.borderWidth$1, layerImg.getWidth() + activityImg.getWidth() + 3 * this.borderWidth$1, scala.math.package$.MODULE$.max(layerImg.getHeight(), activityImg.getHeight()) + 2 * this.borderWidth$1, this.borderWidth$1, this.borderWidth$1);
                            this.g$1.drawImage(layerImg, leftImageStartX, offset, layerImg.getWidth(), layerImg.getHeight(), null);
                            boolean bl = this.g$1.drawImage(activityImg, leftImageStartX + layerImg.getWidth() + this.borderWidth$1, offset, activityImg.getWidth(), activityImg.getHeight(), null);
                            return bl;
                        }
                    }
                    throw new MatchError(tuple2);
                }
                {
                    // NOTE(review): decompiler artifact, not valid Java. The fifth captured
                    // constructor argument (the Graphics2D 'g') is what ends up in g$1.
                    void var5_5;
                    this.singlePadding$1 = singlePadding$1;
                    this.borderWidth$1 = borderWidth$1;
                    this.maxLayerWidth$1 = maxLayerWidth$1;
                    this.g$1 = var5_5;
                }
            });
            return img;
        }
        // dataSample was neither Some nor None (e.g. null): non-exhaustive match.
        throw new MatchError(option);
    }

    /**
     * Folds an affine input transformation into the first connection and unit layer.
     *
     * <p>Given {@code factor} and {@code offset}, it builds a new net:
     * <ul>
     *   <li>the first connection's weights become {@code factor * W};</li>
     *   <li>the second layer's bias becomes {@code b + propagate(offset)}, where
     *       {@code propagate(offset)} is {@code offset} pushed through the original
     *       connection;</li>
     *   <li>layer 0 and all layers from index 3 on are reused unchanged.</li>
     * </ul>
     * Assuming the connection's {@code propagate} computes {@code x * W} for row vectors, the
     * new net applied to {@code x} behaves like this net applied to {@code x * factor + offset}:
     * {@code (x*factor + offset)*W + b = x*(factor*W) + (offset*W + b)}.
     * NOTE(review): that assumption is not visible here; confirm in FullBipartiteConnection.
     *
     * <p>{@code $this} is not modified. The result comes from {@code $this.build(...)} with a
     * new layer list.
     *
     * @param $this the network to adapt; layer 1 must be a {@code FullBipartiteConnection} and
     *     layer 2 a {@code BiasedUnitLayer}, otherwise the casts throw
     *     {@code ClassCastException}
     * @param factor matrix left-multiplied onto the first connection's weights
     * @param offset value propagated through the first connection and added to the bias
     * @return the adapted network
     */
    public static NeuralNetLike prependAffineLinearTransformation(NeuralNetLike $this, Mat factor, Mat offset) {
        Layer zerothIrrelevantLayer = (Layer)$this.layers().apply(0);
        FullBipartiteConnection firstConnectionLayer = (FullBipartiteConnection)$this.layers().apply(1);
        BiasedUnitLayer secondUnitLayer = (BiasedUnitLayer)$this.layers().apply(2);
        // The offset contribution is computed with the ORIGINAL (unscaled) weights.
        Mat transformedOffset = firstConnectionLayer.propagate(offset);
        Mat modifiedBias = secondUnitLayer.parameters().$plus(transformedOffset);
        Mat modifiedWeights = factor.$times(firstConnectionLayer.parameters());
        FullBipartiteConnection newConnectionLayer = firstConnectionLayer.build(modifiedWeights);
        MatrixParameterizedLayer newUnitLayer = secondUnitLayer.build(modifiedBias);
        Layer layer = zerothIrrelevantLayer;
        FullBipartiteConnection fullBipartiteConnection = newConnectionLayer;
        MatrixParameterizedLayer matrixParameterizedLayer = newUnitLayer;
        // Rebuild the list as layer0 :: newConnection :: newUnit :: layers.drop(3). The
        // decompiled '::' calls prepend in reverse order.
        List newLayers = $this.layers().drop(3).$colon$colon((Object)matrixParameterizedLayer).$colon$colon((Object)fullBipartiteConnection).$colon$colon((Object)layer);
        return $this.build((List<Layer>)newLayers);
    }

    /**
     * Trait field initializer, run when an implementing instance is constructed.
     *
     * <p>Sets {@code dataSample} to {@code None}. Until a sample is set (for example by the
     * observers that {@code optimize} installs), {@code toImage} draws 1x1 placeholders
     * instead of activities.
     *
     * @param $this the instance being initialized
     */
    public static void $init$(NeuralNetLike $this) {
        $this.dataSample_$eq((Option<Mat>)None$.MODULE$);
    }
}

