package adams.flow.transformer;

import adams.core.QuickInfoHelper;
import adams.core.io.PlaceholderFile;
import adams.flow.core.Token;
import adams.flow.provenance.ActorType;
import adams.flow.provenance.Provenance;
import adams.flow.provenance.ProvenanceContainer;
import adams.flow.provenance.ProvenanceInformation;
import adams.flow.provenance.ProvenanceSupporter;
import gnu.trove.list.array.TIntArrayList;
import java.io.File;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.regex.Pattern;
import weka.classifiers.CrossValidationFoldGenerator;
import weka.core.Attribute;
import weka.core.DenseInstance;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.converters.ConverterUtils;
import weka.core.tokenizers.cleaners.RemoveNonWordCharTokens;
import weka.filters.Filter;
import weka.filters.unsupervised.attribute.EquiDistance;
import weka.filters.unsupervised.attribute.Remove;

/* loaded from: input_file:adams/flow/transformer/WekaInstancesMerge.class */
public class WekaInstancesMerge extends AbstractTransformer implements ProvenanceSupporter {
    private static final long serialVersionUID = -2923715594018710295L;
    protected boolean m_UsePrefix;
    protected boolean m_AddIndex;
    protected boolean m_Remove;
    protected String m_Prefix;
    protected String m_PrefixSeparator;
    protected String m_ExcludedAttributes;
    protected boolean m_InvertMatchingSense;
    protected String m_UniqueID;
    protected boolean m_KeepOnlySingleUniqueID;
    protected boolean m_Strict;
    protected int m_AttType;
    protected List<String> m_UniqueIDAtts;

    public String globalInfo() {
        return "Merges multiple datasets, either from file or using Instances/Instance objects.\nIf no 'ID' attribute is named, then all datasets must contain the same number of rows.\nAttributes can be excluded from ending up in the final dataset via a regular expression. They can also be prefixed with name and/or index.";
    }

    public void defineOptions() {
        super.defineOptions();
        this.m_OptionManager.add("use-prefix", "usePrefix", false);
        this.m_OptionManager.add("add-index", "addIndex", false);
        this.m_OptionManager.add("remove", "remove", false);
        this.m_OptionManager.add(EquiDistance.PREFIX, EquiDistance.PREFIX, "dataset");
        this.m_OptionManager.add("prefix-separator", "prefixSeparator", "-");
        this.m_OptionManager.add("exclude-atts", "excludedAttributes", "");
        this.m_OptionManager.add(RemoveNonWordCharTokens.INVERT, "invertMatchingSense", false);
        this.m_OptionManager.add("unique-id", "uniqueID", "");
        this.m_OptionManager.add("keep-only-single-unique-id", "keepOnlySingleUniqueID", false);
        this.m_OptionManager.add("strict", "strict", false);
    }

    public void setRemove(boolean z) {
        this.m_Remove = z;
        reset();
    }

    public boolean getRemove() {
        return this.m_Remove;
    }

    public String removeTipText() {
        return "If true, only keep instances where data is available from each source.";
    }

    public void setUsePrefix(boolean z) {
        this.m_UsePrefix = z;
        reset();
    }

    public boolean getUsePrefix() {
        return this.m_UsePrefix;
    }

    public String usePrefixTipText() {
        return "Whether to prefix the attribute names of each dataset with an index and an optional string.";
    }

    public void setAddIndex(boolean z) {
        this.m_AddIndex = z;
        reset();
    }

    public boolean getAddIndex() {
        return this.m_AddIndex;
    }

    public String addIndexTipText() {
        return "Whether to add the index of the dataset to the prefix.";
    }

    public void setPrefix(String str) {
        this.m_Prefix = str;
        reset();
    }

    public String getPrefix() {
        return this.m_Prefix;
    }

    public String prefixTipText() {
        return "The optional prefix string to prefix the index number with (in case prefixes are used); '@' is a placeholder for the relation name.";
    }

    public void setPrefixSeparator(String str) {
        this.m_PrefixSeparator = str;
        reset();
    }

    public String getPrefixSeparator() {
        return this.m_PrefixSeparator;
    }

    public String prefixSeparatorTipText() {
        return "The separator string between the generated prefix and the original attribute name.";
    }

    public void setExcludedAttributes(String str) {
        this.m_ExcludedAttributes = str;
        reset();
    }

    public String getExcludedAttributes() {
        return this.m_ExcludedAttributes;
    }

    public String excludedAttributesTipText() {
        return "The regular expression used on the attribute names, to determine whether an attribute should be excluded or not (matching sense can be inverted); leave empty to include all attributes.";
    }

    public void setInvertMatchingSense(boolean z) {
        this.m_InvertMatchingSense = z;
        reset();
    }

    public boolean getInvertMatchingSense() {
        return this.m_InvertMatchingSense;
    }

    public String invertMatchingSenseTipText() {
        return "Whether to invert the matching sense of excluding attributes, ie, the regular expression is used for including attributes.";
    }

    public void setUniqueID(String str) {
        this.m_UniqueID = str;
        reset();
    }

    public String getUniqueID() {
        return this.m_UniqueID;
    }

    public String uniqueIDTipText() {
        return "The name of the attribute (string/numeric) used for uniquely identifying rows among the datasets.";
    }

    public void setKeepOnlySingleUniqueID(boolean z) {
        this.m_KeepOnlySingleUniqueID = z;
        reset();
    }

    public boolean getKeepOnlySingleUniqueID() {
        return this.m_KeepOnlySingleUniqueID;
    }

    public String keepOnlySingleUniqueIDTipText() {
        return "If enabled, only a single instance of the unique ID attribute is kept.";
    }

    public void setStrict(boolean z) {
        this.m_Strict = z;
        reset();
    }

    public boolean getStrict() {
        return this.m_Strict;
    }

    public String strictTipText() {
        return "If enabled, ensures that IDs in unique ID column are truly unique.";
    }

    public String getQuickInfo() {
        String quickInfoHelper = QuickInfoHelper.toString(this, EquiDistance.PREFIX, this.m_Prefix, "prefix: ");
        if (quickInfoHelper == null) {
            quickInfoHelper = "";
        }
        String quickInfoHelper2 = QuickInfoHelper.toString(this, "prefixSeparator", this.m_PrefixSeparator, ", separator: ");
        if (quickInfoHelper2 != null) {
            quickInfoHelper = quickInfoHelper + quickInfoHelper2;
        }
        String quickInfoHelper3 = QuickInfoHelper.toString(this, "excludedAttributes", this.m_ExcludedAttributes, ", excluded: ");
        if (quickInfoHelper3 != null) {
            quickInfoHelper = quickInfoHelper + quickInfoHelper3;
        }
        String quickInfoHelper4 = QuickInfoHelper.toString(this, "uniqueID", this.m_UniqueID, ", unique: ");
        if (quickInfoHelper4 != null) {
            quickInfoHelper = quickInfoHelper + quickInfoHelper4;
        }
        if (quickInfoHelper.startsWith(", ")) {
            quickInfoHelper = quickInfoHelper.substring(2);
        }
        ArrayList arrayList = new ArrayList();
        QuickInfoHelper.add(arrayList, QuickInfoHelper.toString(this, "addIndex", this.m_AddIndex, "index"));
        QuickInfoHelper.add(arrayList, QuickInfoHelper.toString(this, "usePrefix", this.m_UsePrefix, EquiDistance.PREFIX));
        QuickInfoHelper.add(arrayList, QuickInfoHelper.toString(this, "invertMatchingSense", this.m_InvertMatchingSense, RemoveNonWordCharTokens.INVERT));
        QuickInfoHelper.add(arrayList, QuickInfoHelper.toString(this, "remove", this.m_Remove, "remove"));
        QuickInfoHelper.add(arrayList, QuickInfoHelper.toString(this, "keepOnlySingleUniqueID", this.m_KeepOnlySingleUniqueID, "single unique ID"));
        QuickInfoHelper.add(arrayList, QuickInfoHelper.toString(this, "strict", this.m_Strict, "strict"));
        return quickInfoHelper + QuickInfoHelper.flatten(arrayList);
    }

    public Class[] accepts() {
        return new Class[]{String[].class, File[].class, Instance[].class, Instances[].class};
    }

    public Class[] generates() {
        return new Class[]{Instances.class};
    }

    protected Instances excludeAttributes(Instances instances) {
        Instances instances2;
        StringBuilder sb = new StringBuilder();
        for (int i = 0; i < instances.numAttributes(); i++) {
            if (instances.attribute(i).name().matches(this.m_ExcludedAttributes)) {
                if (sb.length() > 0) {
                    sb.append(",");
                }
                sb.append(i + 1);
            }
        }
        try {
            Remove remove = new Remove();
            remove.setAttributeIndices(sb.toString());
            remove.setInvertSelection(this.m_InvertMatchingSense);
            remove.setInputFormat(instances);
            instances2 = Filter.useFilter(instances, remove);
        } catch (Exception e) {
            instances2 = instances;
            handleException("Error filtering data:", e);
        }
        return instances2;
    }

    protected String createPrefix(Instances instances, int i) {
        String relationName = this.m_Prefix.equals(CrossValidationFoldGenerator.PLACEHOLDER_ORIGINAL) ? instances.relationName() : this.m_Prefix;
        if (this.m_AddIndex) {
            relationName = relationName + ((relationName.isEmpty() || relationName.endsWith(this.m_PrefixSeparator)) ? "" : this.m_PrefixSeparator) + (i + 1);
        }
        return relationName + this.m_PrefixSeparator;
    }

    protected Instances prefixAttributes(Instances instances, int i) {
        String createPrefix = createPrefix(instances, i);
        ArrayList arrayList = new ArrayList();
        for (int i2 = 0; i2 < instances.numAttributes(); i2++) {
            arrayList.add(instances.attribute(i2).copy(createPrefix + instances.attribute(i2).name()));
        }
        Instances instances2 = new Instances(instances.relationName(), arrayList, instances.numInstances());
        instances2.setClassIndex(instances.classIndex());
        for (int i3 = 0; i3 < instances.numInstances(); i3++) {
            instances2.add((Instance) instances.instance(i3).copy());
        }
        return instances2;
    }

    protected Instances prepareData(Instances instances, int i) {
        Instances instances2 = instances;
        if (this.m_KeepOnlySingleUniqueID && !this.m_UniqueID.isEmpty() && instances.attribute(this.m_UniqueID) != null && i > 0) {
            this.m_UniqueIDAtts.add(createPrefix(instances, i) + this.m_UniqueID);
        }
        if (this.m_ExcludedAttributes.length() > 0) {
            instances2 = excludeAttributes(instances2);
        }
        if (this.m_UsePrefix) {
            instances2 = prefixAttributes(instances, i);
        }
        return instances2;
    }

    protected void updateIDs(int i, Instances instances, HashSet hashSet) {
        Attribute attribute = instances.attribute(this.m_UniqueID);
        if (attribute == null) {
            throw new IllegalStateException("Attribute '" + this.m_UniqueID + "' not found in relation '" + instances.relationName() + "' (#" + (i + 1) + ")!");
        }
        if (this.m_AttType == -1) {
            if (attribute.type() != 0 && attribute.type() != 2) {
                throw new IllegalStateException("Attribute '" + this.m_UniqueID + "' must be either NUMERIC or STRING (#" + (i + 1) + ")!");
            }
            this.m_AttType = attribute.type();
        } else if (this.m_AttType != attribute.type()) {
            throw new IllegalStateException("Attribute '" + this.m_UniqueID + "' must have same attribute type in all the datasets (#" + (i + 1) + ")!");
        }
        boolean z = this.m_AttType == 0;
        HashSet hashSet2 = new HashSet();
        for (int i2 = 0; i2 < instances.numInstances(); i2++) {
            Object valueOf = z ? Double.valueOf(instances.instance(i2).value(attribute)) : instances.instance(i2).stringValue(attribute);
            if (this.m_Strict && hashSet2.contains(valueOf)) {
                throw new IllegalStateException("ID '" + valueOf + "' is not unique in dataset #" + (i + 1) + "!");
            }
            hashSet2.add(valueOf);
        }
        hashSet.addAll(hashSet2);
    }

    protected Instances merge(Instances[] instancesArr, Instances[] instancesArr2, HashSet hashSet) {
        if (isLoggingEnabled()) {
            getLogger().info("Creating merged header...");
        }
        ArrayList arrayList = new ArrayList();
        String str = "";
        int[] iArr = new int[instancesArr2.length];
        for (int i = 0; i < instancesArr2.length; i++) {
            iArr[i] = arrayList.size();
            for (int i2 = 0; i2 < instancesArr2[i].numAttributes(); i2++) {
                arrayList.add((Attribute) instancesArr2[i].attribute(i2).copy());
            }
            if (i > 0) {
                str = str + "_";
            }
            str = str + instancesArr2[i].relationName();
        }
        Instances instances = new Instances(str, arrayList, hashSet.size());
        if (isLoggingEnabled()) {
            getLogger().info("Filling with missing values...");
        }
        for (int i3 = 0; i3 < hashSet.size(); i3++) {
            if (isStopped()) {
                return null;
            }
            if (isLoggingEnabled() && (i3 + 1) % 1000 == 0) {
                getLogger().info("" + (i3 + 1));
            }
            instances.add(new DenseInstance(instances.numAttributes()));
        }
        if (isLoggingEnabled()) {
            getLogger().info("Sorting indices...");
        }
        ArrayList arrayList2 = new ArrayList(hashSet);
        Collections.sort(arrayList2);
        HashMap hashMap = new HashMap();
        for (int i4 = 0; i4 < instancesArr2.length; i4++) {
            if (isStopped()) {
                return null;
            }
            if (isLoggingEnabled()) {
                getLogger().info("Adding file #" + (i4 + 1));
            }
            Attribute attribute = instancesArr[i4].attribute(this.m_UniqueID);
            for (int i5 = 0; i5 < instancesArr2[i4].numInstances(); i5++) {
                if (isLoggingEnabled() && (i5 + 1) % 1000 == 0) {
                    getLogger().info("" + (i5 + 1));
                }
                int binarySearch = this.m_AttType == 0 ? Collections.binarySearch(arrayList2, Double.valueOf(instancesArr2[i4].instance(i5).value(attribute))) : Collections.binarySearch(arrayList2, instancesArr2[i4].instance(i5).stringValue(attribute));
                if (binarySearch < 0) {
                    throw new IllegalStateException("Failed to determine index for row #" + (i5 + 1) + " of dataset #" + (i4 + 1) + "!");
                }
                if (!hashMap.containsKey(Integer.valueOf(binarySearch))) {
                    hashMap.put(Integer.valueOf(binarySearch), 0);
                }
                hashMap.put(Integer.valueOf(binarySearch), Integer.valueOf(((Integer) hashMap.get(Integer.valueOf(binarySearch))).intValue() + 1));
                double[] doubleArray = instances.instance(binarySearch).toDoubleArray();
                for (int i6 = 0; i6 < instancesArr2[i4].numAttributes(); i6++) {
                    if (!instancesArr2[i4].instance(i5).isMissing(i6)) {
                        switch (instancesArr2[i4].attribute(i6).type()) {
                            case 0:
                            case 1:
                            case 3:
                                doubleArray[iArr[i4] + i6] = instancesArr2[i4].instance(i5).value(i6);
                                break;
                            case 2:
                                doubleArray[iArr[i4] + i6] = instances.attribute(iArr[i4] + i6).addStringValue(instancesArr2[i4].instance(i5).stringValue(i6));
                                break;
                            case 4:
                                doubleArray[iArr[i4] + i6] = instances.attribute(iArr[i4] + i6).addRelation(instancesArr2[i4].instance(i5).relationalValue(i6));
                                break;
                            default:
                                throw new IllegalStateException("Unhandled attribute type: " + instancesArr2[i4].attribute(i6).type());
                        }
                    }
                }
                instances.set(binarySearch, new DenseInstance(1.0d, doubleArray));
            }
        }
        if (getRemove()) {
            HashSet hashSet2 = new HashSet();
            for (Integer num : hashMap.keySet()) {
                if (((Integer) hashMap.get(num)).intValue() != instancesArr2.length) {
                    hashSet2.add(instances.get(num.intValue()));
                }
            }
            instances.removeAll(hashSet2);
        }
        return instances;
    }

    protected String doExecute() {
        String str = null;
        File[] fileArr = null;
        Instances[] instancesArr = null;
        if (this.m_InputToken.getPayload() instanceof String[]) {
            String[] strArr = (String[]) this.m_InputToken.getPayload();
            fileArr = new File[strArr.length];
            for (int i = 0; i < strArr.length; i++) {
                fileArr[i] = new PlaceholderFile(strArr[i]);
            }
        } else if (this.m_InputToken.getPayload() instanceof File[]) {
            fileArr = (File[]) this.m_InputToken.getPayload();
        } else if (this.m_InputToken.getPayload() instanceof Instance[]) {
            Instance[] instanceArr = (Instance[]) this.m_InputToken.getPayload();
            instancesArr = new Instances[instanceArr.length];
            for (int i2 = 0; i2 < instanceArr.length; i2++) {
                instancesArr[i2] = new Instances(instanceArr[i2].dataset(), 1);
                instancesArr[i2].add((Instance) instanceArr[i2].copy());
            }
        } else {
            if (!(this.m_InputToken.getPayload() instanceof Instances[])) {
                throw new IllegalStateException("Unhandled input type: " + this.m_InputToken.getPayload().getClass());
            }
            instancesArr = (Instances[]) this.m_InputToken.getPayload();
        }
        try {
            Instances instances = null;
            if (this.m_UniqueID.length() != 0) {
                this.m_AttType = -1;
                int i3 = 0;
                this.m_UniqueIDAtts = new ArrayList();
                if (fileArr != null) {
                    instancesArr = new Instances[fileArr.length];
                    for (int i4 = 0; i4 < fileArr.length && !isStopped(); i4++) {
                        if (isLoggingEnabled()) {
                            getLogger().info("Loading file #" + (i4 + 1) + ": " + fileArr[i4]);
                        }
                        instancesArr[i4] = ConverterUtils.DataSource.read(fileArr[i4].getAbsolutePath());
                        i3 = Math.max(i3, instancesArr[i4].numInstances());
                    }
                } else if (instancesArr != null) {
                    for (Instances instances2 : instancesArr) {
                        i3 = Math.max(i3, instances2.numInstances());
                    }
                }
                Instances[] instancesArr2 = new Instances[instancesArr.length];
                HashSet hashSet = new HashSet(i3);
                for (int i5 = 0; i5 < instancesArr.length && !isStopped(); i5++) {
                    if (isLoggingEnabled()) {
                        getLogger().info("Updating IDs #" + (i5 + 1));
                    }
                    updateIDs(i5, instancesArr[i5], hashSet);
                    if (isLoggingEnabled()) {
                        getLogger().info("Preparing dataset #" + (i5 + 1));
                    }
                    instancesArr2[i5] = prepareData(instancesArr[i5], i5);
                }
                instances = merge(instancesArr, instancesArr2, hashSet);
                if (this.m_KeepOnlySingleUniqueID) {
                    TIntArrayList tIntArrayList = new TIntArrayList();
                    Iterator<String> it = this.m_UniqueIDAtts.iterator();
                    while (it.hasNext()) {
                        tIntArrayList.add(instances.attribute(it.next()).index());
                    }
                    if (tIntArrayList.size() > 0) {
                        if (isLoggingEnabled()) {
                            getLogger().info("Removing duplicate unique ID attributes: " + this.m_UniqueIDAtts);
                        }
                        Remove remove = new Remove();
                        remove.setAttributeIndicesArray(tIntArrayList.toArray());
                        remove.setInputFormat(instances);
                        instances = Filter.useFilter(instances, remove);
                    }
                }
            } else if (fileArr != null) {
                Instances[] instancesArr3 = new Instances[1];
                for (int i6 = 0; i6 < fileArr.length && !isStopped(); i6++) {
                    instancesArr3[0] = ConverterUtils.DataSource.read(fileArr[i6].getAbsolutePath());
                    instancesArr3[0] = prepareData(instancesArr3[0], i6);
                    if (i6 == 0) {
                        instances = instancesArr3[0];
                    } else {
                        if (isLoggingEnabled()) {
                            getLogger().info("Merging with file #" + (i6 + 1) + ": " + fileArr[i6]);
                        }
                        instances = Instances.mergeInstances(instances, instancesArr3[0]);
                    }
                }
            } else if (instancesArr != null) {
                Instances[] instancesArr4 = new Instances[1];
                for (int i7 = 0; i7 < instancesArr.length && !isStopped(); i7++) {
                    instancesArr4[0] = prepareData(instancesArr[i7], i7);
                    if (i7 == 0) {
                        instances = instancesArr4[0];
                    } else {
                        if (isLoggingEnabled()) {
                            getLogger().info("Merging with dataset #" + (i7 + 1) + ": " + instancesArr[i7].relationName());
                        }
                        instances = Instances.mergeInstances(instances, instancesArr4[0]);
                    }
                }
            }
            if (!isStopped()) {
                this.m_OutputToken = new Token(instances);
                updateProvenance(this.m_OutputToken);
            }
        } catch (Exception e) {
            str = handleException("Failed to merge: ", e);
        }
        return str;
    }

    public void updateProvenance(ProvenanceContainer provenanceContainer) {
        if (Provenance.getSingleton().isEnabled()) {
            provenanceContainer.addProvenance(new ProvenanceInformation(ActorType.DATAGENERATOR, this.m_InputToken.getPayload().getClass(), this, this.m_OutputToken.getPayload().getClass()));
        }
    }
}
