/*
 * Decompiled with CFR 0.152.
 */
package org.datavec.spark.transform;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Locale;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.FlatMapFunction;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.Function2;
import org.apache.spark.sql.Column;
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SQLContext;
import org.apache.spark.sql.functions;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.Metadata;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
import org.datavec.api.berkeley.Pair;
import org.datavec.api.transform.ColumnType;
import org.datavec.api.transform.schema.Schema;
import org.datavec.api.writable.DoubleWritable;
import org.datavec.api.writable.FloatWritable;
import org.datavec.api.writable.IntWritable;
import org.datavec.api.writable.LongWritable;
import org.datavec.api.writable.Text;
import org.datavec.api.writable.Writable;
import org.datavec.spark.transform.sparkfunction.SequenceToRows;
import org.datavec.spark.transform.sparkfunction.ToRecord;
import org.datavec.spark.transform.sparkfunction.ToRow;
import org.datavec.spark.transform.sparkfunction.sequence.DataFrameToSequenceCreateCombiner;
import org.datavec.spark.transform.sparkfunction.sequence.DataFrameToSequenceMergeCombiner;
import org.datavec.spark.transform.sparkfunction.sequence.DataFrameToSequenceMergeValue;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class DataFrames {
    public static final String SEQUENCE_UUID_COLUMN = "__SEQ_UUID";
    public static final String SEQUENCE_INDEX_COLUMN = "__SEQ_IDX";

    private DataFrames() {
    }

    public static Column std(DataFrame dataFrame, String columnName) {
        return functions.sqrt((Column)DataFrames.var(dataFrame, columnName));
    }

    public static Column var(DataFrame dataFrame, String columnName) {
        return dataFrame.groupBy(columnName, new String[0]).agg(functions.variance((String)columnName), new Column[0]).col(columnName);
    }

    public static Column min(DataFrame dataFrame, String columnName) {
        return dataFrame.groupBy(columnName, new String[0]).agg(functions.min((String)columnName), new Column[0]).col(columnName);
    }

    public static Column max(DataFrame dataFrame, String columnName) {
        return dataFrame.groupBy(columnName, new String[0]).agg(functions.max((String)columnName), new Column[0]).col(columnName);
    }

    public static Column mean(DataFrame dataFrame, String columnName) {
        return dataFrame.groupBy(columnName, new String[0]).agg(functions.avg((String)columnName), new Column[0]).col(columnName);
    }

    public static StructType fromSchema(Schema schema) {
        StructField[] structFields = new StructField[schema.numColumns()];
        block6: for (int i = 0; i < structFields.length; ++i) {
            switch ((ColumnType)schema.getColumnTypes().get(i)) {
                case Double: {
                    structFields[i] = new StructField(schema.getName(i), DataTypes.DoubleType, false, Metadata.empty());
                    continue block6;
                }
                case Integer: {
                    structFields[i] = new StructField(schema.getName(i), DataTypes.IntegerType, false, Metadata.empty());
                    continue block6;
                }
                case Long: {
                    structFields[i] = new StructField(schema.getName(i), DataTypes.LongType, false, Metadata.empty());
                    continue block6;
                }
                case Float: {
                    structFields[i] = new StructField(schema.getName(i), DataTypes.FloatType, false, Metadata.empty());
                    continue block6;
                }
                default: {
                    throw new IllegalStateException("This api should not be used with strings , binary data or ndarrays. This is only for columnar data");
                }
            }
        }
        return new StructType(structFields);
    }

    public static StructType fromSchemaSequence(Schema schema) {
        StructField[] structFields = new StructField[schema.numColumns() + 2];
        structFields[0] = new StructField(SEQUENCE_UUID_COLUMN, DataTypes.StringType, false, Metadata.empty());
        structFields[1] = new StructField(SEQUENCE_INDEX_COLUMN, DataTypes.IntegerType, false, Metadata.empty());
        block6: for (int i = 0; i < schema.numColumns(); ++i) {
            switch ((ColumnType)schema.getColumnTypes().get(i)) {
                case Double: {
                    structFields[i + 2] = new StructField(schema.getName(i), DataTypes.DoubleType, false, Metadata.empty());
                    continue block6;
                }
                case Integer: {
                    structFields[i + 2] = new StructField(schema.getName(i), DataTypes.IntegerType, false, Metadata.empty());
                    continue block6;
                }
                case Long: {
                    structFields[i + 2] = new StructField(schema.getName(i), DataTypes.LongType, false, Metadata.empty());
                    continue block6;
                }
                case Float: {
                    structFields[i + 2] = new StructField(schema.getName(i), DataTypes.FloatType, false, Metadata.empty());
                    continue block6;
                }
                default: {
                    throw new IllegalStateException("This api should not be used with strings , binary data or ndarrays. This is only for columnar data");
                }
            }
        }
        return new StructType(structFields);
    }

    public static Schema fromStructType(StructType structType) {
        Schema.Builder builder = new Schema.Builder();
        StructField[] fields = structType.fields();
        String[] fieldNames = structType.fieldNames();
        block15: for (int i = 0; i < fields.length; ++i) {
            String name;
            switch (name = fields[i].dataType().typeName().toLowerCase()) {
                case "double": {
                    builder.addColumnDouble(fieldNames[i]);
                    continue block15;
                }
                case "float": {
                    builder.addColumnFloat(fieldNames[i]);
                    continue block15;
                }
                case "long": {
                    builder.addColumnLong(fieldNames[i]);
                    continue block15;
                }
                case "int": 
                case "integer": {
                    builder.addColumnInteger(fieldNames[i]);
                    continue block15;
                }
                case "string": {
                    builder.addColumnString(fieldNames[i]);
                    continue block15;
                }
                default: {
                    throw new RuntimeException("Unknown type: " + name);
                }
            }
        }
        return builder.build();
    }

    public static Pair<Schema, JavaRDD<List<Writable>>> toRecords(DataFrame dataFrame) {
        Schema schema = DataFrames.fromStructType(dataFrame.schema());
        return new Pair((Object)schema, (Object)dataFrame.javaRDD().map((Function)new ToRecord(schema)));
    }

    public static Pair<Schema, JavaRDD<List<List<Writable>>>> toRecordsSequence(DataFrame dataFrame) {
        JavaPairRDD grouped = dataFrame.javaRDD().groupBy((Function)new Function<Row, String>(){

            public String call(Row row) throws Exception {
                return row.getString(0);
            }
        });
        Schema schema = DataFrames.fromStructType(dataFrame.schema());
        DataFrameToSequenceCreateCombiner createCombiner = new DataFrameToSequenceCreateCombiner(schema);
        DataFrameToSequenceMergeValue mergeValue = new DataFrameToSequenceMergeValue(schema);
        DataFrameToSequenceMergeCombiner mergeCombiners = new DataFrameToSequenceMergeCombiner();
        JavaRDD sequences = grouped.combineByKey((Function)createCombiner, (Function2)mergeValue, (Function2)mergeCombiners).values();
        JavaRDD out = sequences.map((Function)new Function<List<List<Writable>>, List<List<Writable>>>(){

            public List<List<Writable>> call(List<List<Writable>> v1) throws Exception {
                ArrayList<List<Writable>> out = new ArrayList<List<Writable>>(v1.size());
                for (List<Writable> l : v1) {
                    ArrayList<Writable> subset = new ArrayList<Writable>();
                    for (int i = 2; i < l.size(); ++i) {
                        subset.add(l.get(i));
                    }
                    out.add(subset);
                }
                return out;
            }
        });
        return new Pair((Object)schema, (Object)out);
    }

    public static DataFrame toDataFrame(Schema schema, JavaRDD<List<Writable>> data) {
        JavaSparkContext sc = new JavaSparkContext(data.context());
        SQLContext sqlContext = new SQLContext(sc);
        JavaRDD rows = data.map((Function)new ToRow(schema));
        DataFrame dataFrame = sqlContext.createDataFrame(rows, DataFrames.fromSchema(schema));
        return dataFrame;
    }

    public static DataFrame toDataFrameSequence(Schema schema, JavaRDD<List<List<Writable>>> data) {
        JavaSparkContext sc = new JavaSparkContext(data.context());
        SQLContext sqlContext = new SQLContext(sc);
        JavaRDD rows = data.flatMap((FlatMapFunction)new SequenceToRows(schema));
        DataFrame dataFrame = sqlContext.createDataFrame(rows, DataFrames.fromSchemaSequence(schema));
        return dataFrame;
    }

    public static List<Writable> rowToWritables(Schema schema, Row row) {
        ArrayList<Writable> ret = new ArrayList<Writable>();
        block7: for (int i = 0; i < row.size(); ++i) {
            switch (schema.getType(i)) {
                case Double: {
                    ret.add((Writable)new DoubleWritable(row.getDouble(i)));
                    continue block7;
                }
                case Float: {
                    ret.add((Writable)new FloatWritable(row.getFloat(i)));
                    continue block7;
                }
                case Integer: {
                    ret.add((Writable)new IntWritable(row.getInt(i)));
                    continue block7;
                }
                case Long: {
                    ret.add((Writable)new LongWritable(row.getLong(i)));
                    continue block7;
                }
                case String: {
                    ret.add((Writable)new Text(row.getString(i)));
                    continue block7;
                }
                default: {
                    throw new IllegalStateException("Illegal type");
                }
            }
        }
        return ret;
    }

    public static List<String> toList(String[] input) {
        ArrayList<String> ret = new ArrayList<String>();
        for (int i = 0; i < input.length; ++i) {
            ret.add(input[i]);
        }
        return ret;
    }

    public static String[] toArray(List<String> list) {
        String[] ret = new String[list.size()];
        for (int i = 0; i < ret.length; ++i) {
            ret[i] = list.get(i);
        }
        return ret;
    }

    public static INDArray toMatrix(List<Row> rows) {
        INDArray ret = Nd4j.create((int)rows.size(), (int)rows.get(0).size());
        for (int i = 0; i < ret.rows(); ++i) {
            for (int j = 0; j < ret.columns(); ++j) {
                ret.putScalar(i, j, rows.get(i).getDouble(j));
            }
        }
        return ret;
    }

    public static List<Column> toColumn(List<String> columns) {
        ArrayList<Column> ret = new ArrayList<Column>();
        for (String s : columns) {
            ret.add(functions.col((String)s));
        }
        return ret;
    }

    public static Column[] toColumns(String ... columns) {
        Column[] ret = new Column[columns.length];
        for (int i = 0; i < columns.length; ++i) {
            ret[i] = functions.col((String)columns[i]);
        }
        return ret;
    }
}

