/*
 * Decompiled with CFR 0.152.
 */
package org.kramerlab.autoencoder.experiments;

import org.kramerlab.autoencoder.experiments.ClassificationResult;
import org.kramerlab.autoencoder.math.matrix.Mat;
import org.kramerlab.autoencoder.math.matrix.Mat$;
import scala.Array$;
import scala.Function1;
import scala.Predef$;
import scala.Serializable;
import scala.collection.Seq;
import scala.collection.mutable.StringBuilder;
import scala.math.Numeric;
import scala.math.package$;
import scala.reflect.ClassTag$;
import scala.runtime.BoxesRunTime;

public final class ErrorMeasures$ {
    public static final ErrorMeasures$ MODULE$;

    static {
        new ErrorMeasures$();
    }

    public int reconstructionError(Mat binaryData, Mat binaryReconstruction) {
        Mat binaryDifference = binaryReconstruction.$minus(binaryData);
        return (int)binaryDifference.normSq();
    }

    public double averageBalancedAccuracy(Mat binaryData, Mat binaryReconstruction) {
        int w = binaryData.width();
        int h = binaryData.height();
        Mat row = Mat$.MODULE$.ones(1, h);
        Mat falsePositives = binaryReconstruction.$minus(binaryData).map((Function1<Object, Object>)new Serializable(){
            public static final long serialVersionUID = 0L;

            public final double apply(double x) {
                return this.apply$mcDD$sp(x);
            }

            public double apply$mcDD$sp(double x) {
                return package$.MODULE$.max(x, 0.0);
            }
        });
        Mat falseNegatives = binaryData.$minus(binaryReconstruction).map((Function1<Object, Object>)new Serializable(){
            public static final long serialVersionUID = 0L;

            public final double apply(double x) {
                return this.apply$mcDD$sp(x);
            }

            public double apply$mcDD$sp(double x) {
                return package$.MODULE$.max(x, 0.0);
            }
        });
        Mat falsePositivesForColumns = row.$times(falsePositives);
        Mat falseNegativesForColumns = row.$times(falseNegatives);
        Mat positivesForColumns = row.$times(binaryReconstruction);
        Mat negativesForColumns = Mat$.MODULE$.ones(1, w).$times(h).$minus(positivesForColumns);
        Mat truePositivesForColumns = positivesForColumns.$minus(falsePositivesForColumns);
        Mat trueNegativesForColumns = negativesForColumns.$minus(falseNegativesForColumns);
        ClassificationResult[] classificationResults = (ClassificationResult[])Array$.MODULE$.tabulate(w, (Function1)new Serializable(falsePositivesForColumns, falseNegativesForColumns, truePositivesForColumns, trueNegativesForColumns){
            public static final long serialVersionUID = 0L;
            private final Mat falsePositivesForColumns$1;
            private final Mat falseNegativesForColumns$1;
            private final Mat truePositivesForColumns$1;
            private final Mat trueNegativesForColumns$1;

            public final ClassificationResult apply(int i) {
                return new ClassificationResult((int)this.truePositivesForColumns$1.apply(0, i), (int)this.trueNegativesForColumns$1.apply(0, i), (int)this.falsePositivesForColumns$1.apply(0, i), (int)this.falseNegativesForColumns$1.apply(0, i));
            }
            {
                this.falsePositivesForColumns$1 = falsePositivesForColumns$1;
                this.falseNegativesForColumns$1 = falseNegativesForColumns$1;
                this.truePositivesForColumns$1 = truePositivesForColumns$1;
                this.trueNegativesForColumns$1 = trueNegativesForColumns$1;
            }
        }, ClassTag$.MODULE$.apply(ClassificationResult.class));
        return BoxesRunTime.unboxToDouble((Object)Predef$.MODULE$.doubleArrayOps((double[])Predef$.MODULE$.refArrayOps((Object[])classificationResults).map((Function1)new Serializable(){
            public static final long serialVersionUID = 0L;

            public final double apply(ClassificationResult x$1) {
                return x$1.balancedAccuracy();
            }
        }, Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.Double()))).sum((Numeric)Numeric.DoubleIsFractional$.MODULE$)) / (double)w;
    }

    public void main(String[] args) {
        ClassificationResult coin = new ClassificationResult(45, 5, 45, 5);
        ClassificationResult classificationResult = new ClassificationResult(90, 0, 10, 0);
        ClassificationResult good = new ClassificationResult(85, 9, 5, 1);
        ClassificationResult veryGood = new ClassificationResult(89, 9, 1, 1);
        ClassificationResult perfect = new ClassificationResult(90, 10, 0, 0);
        Predef$.MODULE$.println((Object)BoxesRunTime.boxToDouble((double)coin.balancedAccuracy()));
        Predef$.MODULE$.println((Object)BoxesRunTime.boxToDouble((double)classificationResult.balancedAccuracy()));
        Predef$.MODULE$.println((Object)BoxesRunTime.boxToDouble((double)good.balancedAccuracy()));
        Predef$.MODULE$.println((Object)BoxesRunTime.boxToDouble((double)veryGood.balancedAccuracy()));
        Predef$.MODULE$.println((Object)BoxesRunTime.boxToDouble((double)perfect.balancedAccuracy()));
        Mat reality = Mat$.MODULE$.apply(6, 3).apply((Seq<Object>)Predef$.MODULE$.wrapDoubleArray(new double[]{1.0, 0.0, 1.0, 0.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0, 1.0, 0.0, 0.0, 1.0}));
        Mat prediction = Mat$.MODULE$.apply(6, 3).apply((Seq<Object>)Predef$.MODULE$.wrapDoubleArray(new double[]{0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0}));
        double aba = this.averageBalancedAccuracy(reality, prediction);
        Predef$.MODULE$.println((Object)new StringBuilder().append((Object)"manually calculated: ").append((Object)BoxesRunTime.boxToDouble((double)0.7166666666666667)).toString());
        Predef$.MODULE$.println((Object)new StringBuilder().append((Object)"result: ").append((Object)BoxesRunTime.boxToDouble((double)aba)).toString());
    }

    private ErrorMeasures$() {
        MODULE$ = this;
    }
}

