package adams.flow.transformer;

import adams.core.ObjectCopyHelper;
import adams.core.QuickInfoHelper;
import adams.core.Range;
import adams.core.Utils;
import adams.core.base.BaseRegExp;
import adams.core.base.BaseString;
import adams.core.option.OptionUtils;
import adams.data.weka.WekaAttributeIndex;
import adams.data.weka.WekaAttributeRange;
import adams.flow.container.CNTKMultiFilterResultContainer;
import adams.flow.core.Token;
import com.google.gson.GsonBuilder;
import com.google.gson.JsonObject;
import gnu.trove.list.TIntList;
import gnu.trove.list.array.TIntArrayList;
import gnu.trove.set.TIntSet;
import gnu.trove.set.hash.TIntHashSet;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.regex.Pattern;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.converters.CNTKSaver;
import weka.filters.AllFilter;
import weka.filters.Filter;
import weka.filters.MultiFilter;
import weka.filters.unsupervised.attribute.AddID;
import weka.filters.unsupervised.attribute.PartitionedMultiFilter2;
import weka.filters.unsupervised.attribute.Remove;

/* loaded from: input_file:adams/flow/transformer/CNTKMultiFilter.class */
public class CNTKMultiFilter extends AbstractTransformer {
    private static final long serialVersionUID = 9077096252192331835L;
    public static final String PREFIX_TARGETS = "targets";
    public static final String PREFIX_FILTERED = "filtered";
    protected String m_DomainName;
    protected String m_DomainType;
    protected Filter[] m_Filters;
    protected BaseRegExp[] m_RegExps;
    protected BaseString[] m_Prefixes;
    protected WekaAttributeRange m_Targets;
    protected WekaAttributeIndex m_InputIDAttribute;
    protected String m_OutputIDAttribute;

    public String globalInfo() {
        return "Applies the filters to the incoming data (also adds a numeric ID column) and outputs this new dataset alongside Python code for CNTK.";
    }

    public void defineOptions() {
        super.defineOptions();
        this.m_OptionManager.add("domain-name", "domainName", "");
        this.m_OptionManager.add("domain-type", "domainType", "");
        this.m_OptionManager.add("filter", "filters", new Filter[0]);
        this.m_OptionManager.add("regexp", "regExps", new BaseRegExp[0]);
        this.m_OptionManager.add("prefix", "prefixes", new BaseString[0]);
        this.m_OptionManager.add(PREFIX_TARGETS, PREFIX_TARGETS, new WekaAttributeRange("last"));
        this.m_OptionManager.add("input-id-att", "inputIDAttribute", new WekaAttributeIndex("first"));
        this.m_OptionManager.add("output-id-att", "outputIDAttribute", "ID");
    }

    public void setDomainName(String str) {
        this.m_DomainName = str;
        reset();
    }

    public String getDomainName() {
        return this.m_DomainName;
    }

    public String domainNameTipText() {
        return "The name for the domain.";
    }

    public void setDomainType(String str) {
        this.m_DomainType = str;
        reset();
    }

    public String getDomainType() {
        return this.m_DomainType;
    }

    public String domainTypeTipText() {
        return "The type for the domain.";
    }

    public void setFilters(Filter[] filterArr) {
        this.m_Filters = filterArr;
        this.m_Prefixes = (BaseString[]) Utils.adjustArray(this.m_Prefixes, this.m_Filters.length, new BaseString());
        this.m_RegExps = (BaseRegExp[]) Utils.adjustArray(this.m_RegExps, this.m_Filters.length, new BaseRegExp());
        reset();
    }

    public Filter[] getFilters() {
        return this.m_Filters;
    }

    public String filtersTipText() {
        return "The filters to apply individually to the data (excluding targets and sample ID).";
    }

    public void setRegExps(BaseRegExp[] baseRegExpArr) {
        this.m_RegExps = baseRegExpArr;
        this.m_Filters = (Filter[]) Utils.adjustArray(this.m_Filters, this.m_RegExps.length, new AllFilter());
        this.m_Prefixes = (BaseString[]) Utils.adjustArray(this.m_Prefixes, this.m_RegExps.length, new BaseString());
        reset();
    }

    public BaseRegExp[] getRegExps() {
        return this.m_RegExps;
    }

    public String regExpsTipText() {
        return "The regular expression to apply to the attribute names to identify numeric attributes to use for a filter.";
    }

    public void setPrefixes(BaseString[] baseStringArr) {
        this.m_Prefixes = baseStringArr;
        this.m_Filters = (Filter[]) Utils.adjustArray(this.m_Filters, this.m_Prefixes.length, new AllFilter());
        this.m_RegExps = (BaseRegExp[]) Utils.adjustArray(this.m_RegExps, this.m_Prefixes.length, new BaseRegExp());
        reset();
    }

    public BaseString[] getPrefixes() {
        return this.m_Prefixes;
    }

    public String prefixesTipText() {
        return "The prefixes for the attributes to use (- gets added automatically).";
    }

    public void setTargets(WekaAttributeRange wekaAttributeRange) {
        this.m_Targets = wekaAttributeRange;
        reset();
    }

    public WekaAttributeRange getTargets() {
        return this.m_Targets;
    }

    public String targetsTipText() {
        return "The attributes in the dataset that are considered targets.";
    }

    public void setInputIDAttribute(WekaAttributeIndex wekaAttributeIndex) {
        this.m_InputIDAttribute = wekaAttributeIndex;
        reset();
    }

    public WekaAttributeIndex getInputIDAttribute() {
        return this.m_InputIDAttribute;
    }

    public String inputIDAttributeTipText() {
        return "The attribute index in the input dataset that contains the unique ID for which to generate the ID mapping.";
    }

    public void setOutputIDAttribute(String str) {
        this.m_OutputIDAttribute = str;
        reset();
    }

    public String getOutputIDAttribute() {
        return this.m_OutputIDAttribute;
    }

    public String outputIDAttributeTipText() {
        return "The attribute name in the output dataset that contains the numeric unique ID for which the ID mapping was generated.";
    }

    public String getQuickInfo() {
        return ((((QuickInfoHelper.toString(this, "domainName", this.m_DomainName.isEmpty() ? "-none-" : this.m_DomainName, "name: ") + QuickInfoHelper.toString(this, "domainType", this.m_DomainType.isEmpty() ? "-none-" : this.m_DomainType, ", type: ")) + QuickInfoHelper.toString(this, "filters", this.m_Filters.length + " filters", "")) + QuickInfoHelper.toString(this, PREFIX_TARGETS, this.m_Targets, ", targets: ")) + QuickInfoHelper.toString(this, "inputIDAttribute", this.m_InputIDAttribute, ", input ID: ")) + QuickInfoHelper.toString(this, "outputIDAttribute", this.m_OutputIDAttribute, ", output ID: ");
    }

    public Class[] accepts() {
        return new Class[]{Instances.class};
    }

    public Class[] generates() {
        return new Class[]{CNTKMultiFilterResultContainer.class};
    }

    protected String extractTarget(String str) {
        return str.replaceAll("targets-[0-9]+-", "");
    }

    protected TIntSet generateAttributeBlacklist(Instances instances) {
        this.m_Targets.setData(instances);
        TIntHashSet tIntHashSet = new TIntHashSet(this.m_Targets.getIntIndices());
        TIntArrayList tIntArrayList = new TIntArrayList();
        for (int i = 0; i < instances.numAttributes(); i++) {
            if (!instances.attribute(i).isNumeric()) {
                tIntArrayList.add(i);
            }
        }
        tIntHashSet.addAll(tIntArrayList);
        return tIntHashSet;
    }

    protected TIntList[] generateAttributeIndices(Instances instances, TIntSet tIntSet) throws Exception {
        TIntList[] tIntListArr = new TIntList[this.m_Filters.length];
        for (int i = 0; i < tIntListArr.length; i++) {
            tIntListArr[i] = new TIntArrayList();
            for (int i2 = 0; i2 < instances.numAttributes(); i2++) {
                if (!tIntSet.contains(i2) && this.m_RegExps[i].isMatch(instances.attribute(i2).name())) {
                    tIntListArr[i].add(i2);
                }
            }
        }
        return tIntListArr;
    }

    protected Filter generateFilter(Instances instances, TIntList[] tIntListArr) throws Exception {
        ArrayList arrayList = new ArrayList();
        for (int i = 0; i < this.m_Filters.length; i++) {
            arrayList.add(ObjectCopyHelper.copyObject(this.m_Filters[i]));
        }
        arrayList.add(new AllFilter());
        ArrayList arrayList2 = new ArrayList();
        for (int i2 = 0; i2 < this.m_Filters.length; i2++) {
            Range range = new Range();
            range.setMax(instances.numAttributes());
            range.setIndices(tIntListArr[i2].toArray());
            arrayList2.add(new weka.core.Range(range.toExplicitRange()));
        }
        arrayList2.add(new weka.core.Range(this.m_Targets.toExplicitRange()));
        ArrayList arrayList3 = new ArrayList();
        for (int i3 = 0; i3 < this.m_Prefixes.length; i3++) {
            arrayList3.add(new BaseString(this.m_Prefixes[i3].getValue()));
        }
        arrayList3.add(new BaseString(PREFIX_TARGETS));
        Filter partitionedMultiFilter2 = new PartitionedMultiFilter2();
        partitionedMultiFilter2.setFilters((Filter[]) arrayList.toArray(new Filter[0]));
        partitionedMultiFilter2.setRanges((weka.core.Range[]) arrayList2.toArray(new weka.core.Range[0]));
        partitionedMultiFilter2.setPrefixes((BaseString[]) arrayList3.toArray(new BaseString[0]));
        partitionedMultiFilter2.setRemoveUnused(true);
        Filter addID = new AddID();
        addID.setAttributeName(this.m_OutputIDAttribute);
        MultiFilter multiFilter = new MultiFilter();
        multiFilter.setFilters(new Filter[]{partitionedMultiFilter2, addID});
        if (isLoggingEnabled()) {
            getLogger().info("MultiFilter: " + OptionUtils.getCommandLine(multiFilter));
        }
        return multiFilter;
    }

    protected Instances filter(Filter filter, Instances instances) throws Exception {
        filter.setInputFormat(instances);
        return Filter.useFilter(instances, filter);
    }

    protected String toString(JsonObject jsonObject) {
        return new GsonBuilder().setPrettyPrinting().create().toJson(jsonObject);
    }

    protected String generateIDs(Instances instances) throws Exception {
        Instances instances2 = new Instances(instances);
        instances2.setClassIndex(-1);
        this.m_InputIDAttribute.setData(instances2);
        int intIndex = this.m_InputIDAttribute.getIntIndex();
        if (intIndex == -1) {
            throw new IllegalStateException("Failed to locate unique ID attribute in input data: " + this.m_InputIDAttribute);
        }
        Filter remove = new Remove();
        remove.setAttributeIndicesArray(new int[]{intIndex});
        remove.setInvertSelection(true);
        MultiFilter multiFilter = new MultiFilter();
        multiFilter.setFilters(new Filter[]{remove, new AddID()});
        multiFilter.setInputFormat(instances2);
        Instances useFilter = Filter.useFilter(instances2, multiFilter);
        JsonObject jsonObject = new JsonObject();
        Iterator it = useFilter.iterator();
        while (it.hasNext()) {
            Instance instance = (Instance) it.next();
            jsonObject.addProperty("" + ((int) instance.value(0)), instance.stringValue(1));
        }
        return toString(jsonObject);
    }

    protected CNTKSaver generateSaver(Instances instances) {
        ArrayList arrayList = new ArrayList();
        ArrayList arrayList2 = new ArrayList();
        this.m_InputIDAttribute.setData(instances);
        arrayList.add(new Range("" + (instances.attribute(this.m_OutputIDAttribute).index() + 1)));
        arrayList2.add(new BaseString(this.m_OutputIDAttribute));
        for (BaseString baseString : this.m_Prefixes) {
            arrayList2.add(new BaseString(baseString.getValue()));
        }
        for (int i = 0; i < this.m_Filters.length; i++) {
            TIntArrayList tIntArrayList = new TIntArrayList();
            String value = this.m_Prefixes[i].getValue();
            if (value.trim().isEmpty()) {
                value = PREFIX_FILTERED;
            }
            String str = value + "-" + i + "-";
            for (int i2 = 0; i2 < instances.numAttributes(); i2++) {
                if (instances.attribute(i2).name().startsWith(str)) {
                    tIntArrayList.add(i2);
                }
            }
            Range range = new Range();
            range.setMax(instances.numAttributes());
            range.setIndices(tIntArrayList.toArray());
            arrayList.add(range);
        }
        for (int i3 = 0; i3 < instances.numAttributes(); i3++) {
            if (instances.attribute(i3).name().startsWith(PREFIX_TARGETS)) {
                Range range2 = new Range();
                range2.setMax(instances.numAttributes());
                range2.setIndices(new int[]{i3});
                arrayList.add(range2);
                arrayList2.add(new BaseString(extractTarget(instances.attribute(i3).name())));
            }
        }
        CNTKSaver cNTKSaver = new CNTKSaver();
        cNTKSaver.setInputs((Range[]) arrayList.toArray(new Range[0]));
        cNTKSaver.setInputNames((BaseString[]) arrayList2.toArray(new BaseString[0]));
        return cNTKSaver;
    }

    protected String generateDefinition(CNTKSaver cNTKSaver, Filter filter, Instances instances, TIntList[] tIntListArr) {
        JsonObject jsonObject = new JsonObject();
        JsonObject jsonObject2 = new JsonObject();
        jsonObject2.addProperty("Name", this.m_DomainName);
        jsonObject2.addProperty("Type", this.m_DomainType);
        jsonObject.add("Domain", jsonObject2);
        jsonObject.addProperty("UniqueID", this.m_OutputIDAttribute);
        JsonObject jsonObject3 = new JsonObject();
        for (int i = 0; i < cNTKSaver.getInputNames().length; i++) {
            Range clone = cNTKSaver.getInputs()[i].getClone();
            clone.setMax(instances.numAttributes());
            jsonObject3.addProperty(cNTKSaver.getInputNames()[i].getValue(), Integer.valueOf(clone.getIntIndices().length));
        }
        jsonObject.add("Inputs", jsonObject3);
        jsonObject.addProperty("Filter", OptionUtils.getCommandLine(filter));
        JsonObject jsonObject4 = new JsonObject();
        for (int i2 = 0; i2 < this.m_Filters.length; i2++) {
            Range range = new Range();
            range.setMax(instances.numAttributes());
            range.setIndices(tIntListArr[i2].toArray());
            jsonObject4.addProperty(this.m_Prefixes[i2].getValue(), range.toExplicitRange());
        }
        jsonObject.add("AttributeRanges", jsonObject4);
        return toString(jsonObject);
    }

    protected String doExecute() {
        String str = null;
        Instances instances = (Instances) this.m_InputToken.getPayload(Instances.class);
        try {
            TIntList[] generateAttributeIndices = generateAttributeIndices(instances, generateAttributeBlacklist(instances));
            Filter generateFilter = generateFilter(instances, generateAttributeIndices);
            Instances filter = filter(generateFilter, instances);
            String generateIDs = generateIDs(instances);
            CNTKSaver generateSaver = generateSaver(filter);
            this.m_OutputToken = new Token(new CNTKMultiFilterResultContainer(filter, generateIDs, OptionUtils.getCommandLine(generateSaver), generateDefinition(generateSaver, generateFilter, filter, generateAttributeIndices)));
        } catch (Exception e) {
            str = handleException("Failed to filter data!", e);
        }
        return str;
    }
}
