/*
 * Decompiled with CFR 0.152.
 */
package org.datavec.api.transform.transform.column;

import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Set;
import org.datavec.api.transform.metadata.ColumnMetaData;
import org.datavec.api.transform.schema.Schema;
import org.datavec.api.transform.transform.BaseTransform;
import org.datavec.api.writable.Writable;

/**
 * Transform that removes every column except for a named set of columns to keep.
 * <p>
 * Both the output {@link Schema} (see {@link #transform(Schema)}) and the output records
 * (see {@link #map(List)}) keep the retained columns in the order in which they appear in the
 * input schema, not in the order the names were passed to the constructor. Supplying the same
 * name more than once does not duplicate the column in the output.
 */
@JsonIgnoreProperties(value={"inputSchema", "columnsToKeepIdx", "indicesToKeep"})
public class RemoveAllColumnsExceptForTransform
extends BaseTransform {
    /** Input-schema index of each entry of {@link #columnsToKeep}, in constructor order; set by {@link #setInputSchema}. */
    private int[] columnsToKeepIdx;
    /** Names of the columns to keep, exactly as supplied to the constructor. */
    private String[] columnsToKeep;
    /** Input-schema indices to keep; set by {@link #setInputSchema} and consulted by {@link #map}. */
    private Set<Integer> indicesToKeep;

    /**
     * @param columnsToKeep names of the columns to retain; every other column is removed
     */
    public RemoveAllColumnsExceptForTransform(String ... columnsToKeep) {
        this.columnsToKeep = columnsToKeep;
    }

    /**
     * Records the input schema and resolves each column name to keep into its index in that schema.
     *
     * @param schema the schema of the records this transform will receive
     * @throws IllegalArgumentException if a column to keep is reported as absent (negative index)
     */
    @Override
    public void setInputSchema(Schema schema) {
        super.setInputSchema(schema);
        this.indicesToKeep = new HashSet<Integer>();
        this.columnsToKeepIdx = new int[this.columnsToKeep.length];
        for (int i = 0; i < this.columnsToKeep.length; i++) {
            String name = this.columnsToKeep[i];
            int idx = schema.getIndexOfColumn(name);
            if (idx < 0) {
                throw new IllegalArgumentException("Column \"" + name + "\" not found");
            }
            this.columnsToKeepIdx[i] = idx;
            this.indicesToKeep.add(idx);
        }
    }

    /**
     * Builds the output schema containing only the kept columns, in input-schema order.
     * Names in {@link #columnsToKeep} that do not occur in {@code schema} are silently ignored here.
     *
     * @param schema the input schema
     * @return a new schema (via {@link Schema#newSchema}) with only the retained column metadata
     */
    @Override
    public Schema transform(Schema schema) {
        List<String> origNames = schema.getColumnNames();
        List<ColumnMetaData> origMeta = schema.getColumnMetaData();
        Set<String> keepSet = new HashSet<String>();
        Collections.addAll(keepSet, this.columnsToKeep);
        List<ColumnMetaData> newMeta = new ArrayList<ColumnMetaData>(this.columnsToKeep.length);
        // Walk names and metadata in lockstep so the output keeps input-schema column order.
        Iterator<String> namesIter = origNames.iterator();
        Iterator<ColumnMetaData> metaIter = origMeta.iterator();
        while (namesIter.hasNext()) {
            String name = namesIter.next();
            ColumnMetaData meta = metaIter.next();
            if (keepSet.contains(name)) {
                newMeta.add(meta);
            }
        }
        return schema.newSchema(newMeta);
    }

    /**
     * Returns a new list holding only the writables at the kept column indices, in input order.
     *
     * @param writables one record; its size must equal the input schema's column count
     * @return a new list containing only the retained values
     * @throws IllegalStateException if {@link #setInputSchema} has not been called, or if the
     *                               record length does not match the input schema
     */
    @Override
    public List<Writable> map(List<Writable> writables) {
        // The index set is derived state (and ignored by Jackson), so fail clearly instead of NPE-ing.
        if (this.inputSchema == null || this.indicesToKeep == null) {
            throw new IllegalStateException("Cannot execute transform: input schema has not been set (call setInputSchema first). Transform = " + this.toString());
        }
        if (writables.size() != this.inputSchema.numColumns()) {
            throw new IllegalStateException("Cannot execute transform: input writables list length (" + writables.size() + ") does not match expected number of elements (schema: " + this.inputSchema.numColumns() + "). Transform = " + this.toString());
        }
        List<Writable> outList = new ArrayList<Writable>(this.columnsToKeep.length);
        int i = 0;
        for (Writable w : writables) {
            if (this.indicesToKeep.contains(i)) {
                outList.add(w);
            }
            i++;
        }
        return outList;
    }

    @Override
    public String toString() {
        return "RemoveAllColumnsExceptForTransform(" + Arrays.toString(this.columnsToKeep) + ")";
    }

    /** Two instances are equal when they were constructed with the same column names in the same order. */
    @Override
    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        if (o == null || this.getClass() != o.getClass()) {
            return false;
        }
        RemoveAllColumnsExceptForTransform o2 = (RemoveAllColumnsExceptForTransform)o;
        return Arrays.equals(this.columnsToKeep, o2.columnsToKeep);
    }

    /** Consistent with {@link #equals(Object)}: hashes only the configured column names. */
    @Override
    public int hashCode() {
        return Arrays.hashCode(this.columnsToKeep);
    }
}

