/*
 * Decompiled with CFR 0.152.
 */
package org.datavec.api.transform.transform.integer;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import org.datavec.api.transform.metadata.ColumnMetaData;
import org.datavec.api.transform.metadata.IntegerMetaData;
import org.datavec.api.transform.schema.Schema;
import org.datavec.api.transform.transform.BaseTransform;
import org.datavec.api.writable.IntWritable;
import org.datavec.api.writable.Writable;
import org.nd4j.shade.jackson.annotation.JsonIgnoreProperties;
import org.nd4j.shade.jackson.annotation.JsonProperty;

/**
 * Transform that replaces a single integer column with a block of one-hot columns, one for each
 * integer in the inclusive range [{@code minValue}, {@code maxValue}]. The generated columns are
 * named {@code columnName[x]}; for a given record exactly one of them holds 1 (the column whose
 * {@code x} equals the input value) and all others hold 0.
 */
@JsonIgnoreProperties(value={"inputSchema", "columnIdx", "stateNames", "statesMap"})
public class IntegerToOneHotTransform extends BaseTransform {
    private String columnName;
    private int minValue;
    private int maxValue;
    // Index of columnName within the input schema; -1 until setInputSchema resolves it.
    private int columnIdx = -1;

    /**
     * @param columnName name of the integer column to convert
     * @param minValue   smallest value to encode (inclusive)
     * @param maxValue   largest value to encode (inclusive)
     */
    public IntegerToOneHotTransform(@JsonProperty(value="columnName") String columnName, @JsonProperty(value="minValue") int minValue, @JsonProperty(value="maxValue") int maxValue) {
        this.columnName = columnName;
        this.minValue = minValue;
        this.maxValue = maxValue;
    }

    /**
     * Stores the input schema, resolves the index of the target column and validates the
     * configuration of this transform.
     *
     * @param inputSchema schema of the records this transform will be applied to
     * @throws IllegalStateException if the target column is not an integer column, or if
     *                               {@code minValue > maxValue}
     */
    @Override
    public void setInputSchema(Schema inputSchema) {
        super.setInputSchema(inputSchema);
        this.columnIdx = inputSchema.getIndexOfColumn(this.columnName);
        ColumnMetaData meta = inputSchema.getMetaData(this.columnName);
        if (!(meta instanceof IntegerMetaData)) {
            throw new IllegalStateException("Cannot convert column \"" + this.columnName + "\" from integer to one-hot: column is not integer (is: " + meta.getColumnType() + ")");
        }
        // An empty range would silently delete the column from the output schema; fail fast instead.
        if (this.minValue > this.maxValue) {
            throw new IllegalStateException("Invalid range for column \"" + this.columnName + "\": minValue (" + this.minValue + ") must not be greater than maxValue (" + this.maxValue + ")");
        }
    }

    @Override
    public String toString() {
        return "IntegerToOneHotTransform(columnName=\"" + this.columnName + "\",minValue=" + this.minValue + ",maxValue=" + this.maxValue + ")";
    }

    /**
     * Returns a new schema in which the column at the index resolved by {@link #setInputSchema}
     * is replaced by one {@link IntegerMetaData} column (bounded 0..1) per value in
     * [minValue, maxValue], named {@code originalName[x]}. All other columns are copied unchanged.
     */
    @Override
    public Schema transform(Schema schema) {
        List<String> origNames = schema.getColumnNames();
        List<ColumnMetaData> origMeta = schema.getColumnMetaData();
        Iterator<String> namesIter = origNames.iterator();
        Iterator<ColumnMetaData> typesIter = origMeta.iterator();
        List<ColumnMetaData> newMeta = new ArrayList<>(schema.numColumns());
        int i = 0;
        while (namesIter.hasNext()) {
            String name = namesIter.next();
            ColumnMetaData meta = typesIter.next();
            if (i++ == this.columnIdx) {
                // long counter: an int counter would wrap and loop forever when maxValue == Integer.MAX_VALUE
                for (long x = this.minValue; x <= this.maxValue; ++x) {
                    newMeta.add(new IntegerMetaData(name + "[" + x + "]", 0, 1));
                }
            } else {
                newMeta.add(meta);
            }
        }
        return schema.newSchema(newMeta);
    }

    /**
     * Replaces the value of the target column with its one-hot expansion; other values are copied
     * through unchanged.
     *
     * @throws IllegalStateException if the input schema has not been set, if the number of
     *                               writables does not match the input schema, or if the value
     *                               lies outside [minValue, maxValue]
     */
    @Override
    public List<Writable> map(List<Writable> writables) {
        if (this.inputSchema == null) {
            throw new IllegalStateException("Cannot execute transform: input schema has not been set (call setInputSchema first). Transform = " + this.toString());
        }
        if (writables.size() != this.inputSchema.numColumns()) {
            throw new IllegalStateException("Cannot execute transform: input writables list length (" + writables.size() + ") does not match expected number of elements (schema: " + this.inputSchema.numColumns() + "). Transform = " + this.toString());
        }
        int idx = this.getColumnIdx();
        int n = this.maxValue - this.minValue + 1;
        List<Writable> out = new ArrayList<>(writables.size() + n);
        int i = 0;
        for (Writable w : writables) {
            if (i++ == idx) {
                int currValue = checkInRange(w.toInt());
                // long counter avoids int wrap-around when maxValue == Integer.MAX_VALUE
                for (long j = this.minValue; j <= this.maxValue; ++j) {
                    out.add(new IntWritable(j == currValue ? 1 : 0));
                }
            } else {
                out.add(w);
            }
        }
        return out;
    }

    /**
     * One-hot encodes a single numeric value.
     *
     * @param input a {@link Number}; its {@code intValue()} is encoded
     * @return a list of {@code maxValue - minValue + 1} Integers, 1 at the position of the value
     *         and 0 elsewhere
     * @throws IllegalStateException if the value lies outside [minValue, maxValue]
     */
    @Override
    public Object map(Object input) {
        int currValue = checkInRange(((Number) input).intValue());
        List<Integer> oneHot = new ArrayList<>();
        for (long j = this.minValue; j <= this.maxValue; ++j) {
            oneHot.add(j == currValue ? 1 : 0);
        }
        return oneHot;
    }

    /**
     * One-hot encodes every element of a sequence.
     *
     * @param sequence a {@link List} of {@link Number} values
     * @return a list containing the {@link #map(Object)} result for each element, in order
     */
    @Override
    public Object mapSequence(Object sequence) {
        List<?> values = (List<?>) sequence;
        List<Object> ret = new ArrayList<>(values.size());
        for (Object obj : values) {
            ret.add(this.map(obj));
        }
        return ret;
    }

    /**
     * Not supported: this transform produces more than one output column.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public String outputColumnName() {
        throw new UnsupportedOperationException("Output column name will be more than 1");
    }

    /**
     * Returns all column names of the schema obtained by applying {@link #transform(Schema)} to the
     * input schema (this includes the untouched columns, not only the one-hot ones).
     */
    @Override
    public String[] outputColumnNames() {
        List<String> l = this.transform(this.inputSchema).getColumnNames();
        return l.toArray(new String[l.size()]);
    }

    @Override
    public String[] columnNames() {
        return new String[]{this.columnName};
    }

    @Override
    public String columnName() {
        return this.columnName;
    }

    public String getColumnName() {
        return this.columnName;
    }

    public int getMinValue() {
        return this.minValue;
    }

    public int getMaxValue() {
        return this.maxValue;
    }

    public int getColumnIdx() {
        return this.columnIdx;
    }

    public void setColumnName(String columnName) {
        this.columnName = columnName;
    }

    public void setMinValue(int minValue) {
        this.minValue = minValue;
    }

    public void setMaxValue(int maxValue) {
        this.maxValue = maxValue;
    }

    public void setColumnIdx(int columnIdx) {
        this.columnIdx = columnIdx;
    }

    /** Returns {@code value} unchanged, or throws if it lies outside [minValue, maxValue]. */
    private int checkInRange(int value) {
        if (value < this.minValue || value > this.maxValue) {
            throw new IllegalStateException("Invalid value: integer value (" + value + ") is outside of valid range: must be between " + this.minValue + " and " + this.maxValue + " inclusive");
        }
        return value;
    }

    /** Equality is based on columnName, minValue and maxValue; columnIdx is not compared. */
    @Override
    public boolean equals(Object o) {
        if (o == this) {
            return true;
        }
        if (!(o instanceof IntegerToOneHotTransform)) {
            return false;
        }
        IntegerToOneHotTransform other = (IntegerToOneHotTransform) o;
        if (!other.canEqual(this)) {
            return false;
        }
        if (this.getMinValue() != other.getMinValue() || this.getMaxValue() != other.getMaxValue()) {
            return false;
        }
        String thisName = this.getColumnName();
        String otherName = other.getColumnName();
        return thisName == null ? otherName == null : thisName.equals(otherName);
    }

    @Override
    protected boolean canEqual(Object other) {
        return other instanceof IntegerToOneHotTransform;
    }

    /** Hash over the same fields as {@link #equals(Object)}. */
    @Override
    public int hashCode() {
        final int prime = 59;
        int result = 1;
        result = result * prime + this.getMinValue();
        result = result * prime + this.getMaxValue();
        String name = this.getColumnName();
        result = result * prime + (name == null ? 43 : name.hashCode());
        return result;
    }
}

