package cn.com.pconline.adclick.pipeline;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.sql.Column;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.api.java.UDF1;
import org.apache.spark.sql.types.DataTypes;

/* loaded from: input_file:cn/com/pconline/adclick/pipeline/StringOneHotEncoder.class */
/**
 * Pipeline stage that one-hot encodes the string form of a single input column.
 *
 * <p>{@link #fit(Dataset)} learns a label-to-index mapping from the values of the
 * input column; {@link #transform(Row)} then turns a row into a sparse vector whose
 * length is the number of learned labels, with a single {@code 1.0} at the index of
 * the row's label. A label that was not seen during fitting yields an all-zero
 * vector of the same length.
 */
public class StringOneHotEncoder implements RowPipelineStage, Serializable {
    private static final long serialVersionUID = -419504858945251891L;

    /** Name of the column whose values are encoded. */
    private final String inputCol;

    /** Learned mapping from label to output position; empty until {@link #fit(Dataset)} runs. */
    private Map<String, Integer> labelToIndex = new HashMap<String, Integer>();

    /**
     * UDF that maps a label to its learned index.
     *
     * <p>Labels that are not in the mapping are mapped to the current number of
     * labels, i.e. one past the last learned index. The mapping is looked up on the
     * enclosing encoder at call time.
     */
    public class IndexerUDF implements UDF1<String, Integer> {
        private static final long serialVersionUID = 1;

        public IndexerUDF() {
        }

        /**
         * @param str the label to look up
         * @return the learned index of {@code str}, or the number of learned labels
         *         when {@code str} is unknown
         */
        @Override
        public Integer call(String str) throws Exception {
            Map<String, Integer> mapping = StringOneHotEncoder.this.labelToIndex;
            Integer index = mapping.get(str);
            return index != null ? index : Integer.valueOf(mapping.size());
        }
    }

    /**
     * Extracts the first column of a row as a string.
     *
     * <p>Declared as a static nested class so that shipping it with a Spark job does
     * not also serialize the enclosing encoder.
     */
    private static final class FirstColumnAsString implements Function<Row, String> {
        private static final long serialVersionUID = 7604432209006340751L;

        @Override
        public String call(Row row) throws Exception {
            return row.getString(0);
        }
    }

    /**
     * Orders (label, count) entries by ascending count, breaking ties by label so
     * that the resulting order does not depend on the iteration order of the map the
     * entries came from. A {@code null} label sorts before any other label.
     */
    private static final class CountThenLabelComparator implements Comparator<Map.Entry<String, Long>> {
        @Override
        public int compare(Map.Entry<String, Long> left, Map.Entry<String, Long> right) {
            int byCount = left.getValue().compareTo(right.getValue());
            if (byCount != 0) {
                return byCount;
            }
            String leftLabel = left.getKey();
            String rightLabel = right.getKey();
            if (leftLabel == null) {
                return rightLabel == null ? 0 : -1;
            }
            if (rightLabel == null) {
                return 1;
            }
            return leftLabel.compareTo(rightLabel);
        }
    }

    /**
     * @param str name of the column to encode
     */
    public StringOneHotEncoder(String str) {
        this.inputCol = str;
    }

    /**
     * Learns the label-to-index mapping from {@code dataset}.
     *
     * <p>Rows with a null input column are dropped, the column is cast to string, and
     * the occurrences of each distinct value are counted. Labels are then numbered
     * from {@code 0} in order of ascending count (ties broken by label). Any mapping
     * from a previous call is replaced.
     *
     * @param dataset the data to learn the labels from; must contain the input column
     */
    @Override // cn.com.pconline.adclick.pipeline.RowPipelineStage
    public void fit(Dataset<Row> dataset) {
        Map<String, Long> counts = dataset.na().drop(new String[]{this.inputCol})
                .select(new Column[]{dataset.col(this.inputCol).cast(DataTypes.StringType)})
                .javaRDD()
                .map(new FirstColumnAsString())
                .countByValue();

        List<Map.Entry<String, Long>> entries =
                new ArrayList<Map.Entry<String, Long>>(counts.entrySet());
        Collections.sort(entries, new CountThenLabelComparator());

        // Build the new mapping locally and publish it in one assignment so the
        // field never holds a partially filled map.
        Map<String, Integer> mapping = new HashMap<String, Integer>();
        for (int i = 0; i < entries.size(); i++) {
            mapping.put(entries.get(i).getKey(), Integer.valueOf(i));
        }
        this.labelToIndex = mapping;
    }

    /**
     * Encodes the input column of {@code row} as a one-hot sparse vector.
     *
     * <p>NOTE(review): {@link #fit(Dataset)} casts the column to string, but this
     * method reads it with {@code Row.getString} without casting - confirm that
     * callers always pass rows whose input column is already string-typed.
     *
     * @param row a row containing the input column
     * @return a sparse vector of length {@link #resultLength()} with {@code 1.0} at
     *         the label's index, or with no non-zero entries when the label is unknown
     */
    @Override // cn.com.pconline.adclick.pipeline.RowPipelineStage
    public Vector transform(Row row) {
        String label = row.getString(row.fieldIndex(this.inputCol));
        int size = this.labelToIndex.size();
        Integer index = this.labelToIndex.get(label);
        if (index != null && index.intValue() < size) {
            return Vectors.sparse(size, new int[]{index.intValue()}, new double[]{1.0d});
        }
        // Unknown label: all-zero vector of the same length.
        return Vectors.sparse(size, new int[0], new double[0]);
    }

    /**
     * @return the length of the vectors produced by {@link #transform(Row)}, which is
     *         the number of learned labels
     */
    @Override // cn.com.pconline.adclick.pipeline.RowPipelineStage
    public int resultLength() {
        return this.labelToIndex.size();
    }

    /**
     * Describes each output position of the encoded vector.
     *
     * @return a list with one entry per output position; entry {@code i} is a
     *         single-element set holding {@code inputCol + "_" + label}, where
     *         {@code label} is the label mapped to index {@code i}
     */
    @Override // cn.com.pconline.adclick.pipeline.RowPipelineStage
    public List<Set<String>> inputsForOutputs() {
        // Invert the mapping so labels can be looked up by output position.
        Map<Integer, String> indexToLabel = new HashMap<Integer, String>();
        for (Map.Entry<String, Integer> entry : this.labelToIndex.entrySet()) {
            indexToLabel.put(entry.getValue(), entry.getKey());
        }

        int size = this.labelToIndex.size();
        List<Set<String>> outputs = new ArrayList<Set<String>>(size);
        for (int i = 0; i < size; i++) {
            Set<String> names = new HashSet<String>();
            names.add(this.inputCol + "_" + indexToLabel.get(Integer.valueOf(i)));
            outputs.add(names);
        }
        return outputs;
    }
}
