package cn.com.pconline.adclick.pipeline;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.sql.Column;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.api.java.UDF1;
import org.apache.spark.sql.types.DataTypes;

/* loaded from: input_file:cn/com/pconline/adclick/pipeline/RowStringIndexer.class */
/**
 * Pipeline stage that maps the string value of a single input column to an integer index.
 *
 * <p>{@link #fit(Dataset)} assigns indices {@code 0..n-1} to the {@code n} distinct non-null
 * values of the column. Any value not seen during fitting (including {@code null}) maps to
 * {@code n}. {@link #transform(Row)} emits the index as a one-element dense vector.
 */
public class RowStringIndexer implements RowPipelineStage, Serializable {
    private static final long serialVersionUID = -419504858945251891L;

    // Field names and types are intentionally left unchanged: the class is Serializable with a
    // fixed serialVersionUID, so renaming a field would break deserialization of saved stages.
    private final String inputCol;
    private Map<String, Integer> labelToIndex = new HashMap<String, Integer>();
    // Number of labels learned by fit(); also the index assigned to unseen labels.
    private int max_index = 0;

    /**
     * Spark SQL UDF that applies the enclosing indexer's learned mapping to a string value.
     *
     * <p>This is an inner (non-static) class so it can read the enclosing indexer's mapping.
     * Serializing it therefore also serializes the enclosing {@link RowStringIndexer}.
     */
    public class IndexerUDF implements UDF1<String, Integer> {
        private static final long serialVersionUID = 1;

        public IndexerUDF() {
        }

        /**
         * Returns the index of {@code str}, or the unseen-label index if {@code str} was not
         * seen by {@link RowStringIndexer#fit(Dataset)}.
         */
        @Override
        public Integer call(String str) throws Exception {
            // Delegate so the UDF and transform(String) can never disagree on unseen labels.
            return Integer.valueOf(RowStringIndexer.this.transform(str));
        }
    }

    /**
     * Extracts field 0 of a row as a string.
     *
     * <p>This is a static class so the Spark task closure does not capture (and serialize) the
     * enclosing indexer, which an anonymous class created inside {@code fit} would do.
     */
    private static final class FirstStringColumn implements Function<Row, String> {
        private static final long serialVersionUID = 7604432209006340751L;

        @Override
        public String call(Row row) throws Exception {
            return row.getString(0);
        }
    }

    /**
     * Creates an indexer for the given column.
     *
     * @param str name of the input column to index
     */
    public RowStringIndexer(String str) {
        this.inputCol = str;
    }

    /**
     * Learns the label-to-index mapping from the distinct non-null values of the input column.
     *
     * <p>Rows with a null in the input column are dropped. The column is then cast to string,
     * and each distinct value is counted. Labels are ordered by ascending count, with ties broken
     * by label so that refitting on the same data is deterministic. Indices {@code 0..n-1} are
     * assigned in that order, so the least frequent label receives index 0.
     *
     * <p>NOTE(review): Spark ML's {@code StringIndexer} defaults to descending frequency. Confirm
     * that ascending order is intended here.
     *
     * @param dataset data containing the input column
     */
    @Override // cn.com.pconline.adclick.pipeline.RowPipelineStage
    public void fit(Dataset<Row> dataset) {
        Map<String, Long> counts = dataset.na().drop(new String[] {this.inputCol})
                .select(dataset.col(this.inputCol).cast(DataTypes.StringType))
                .javaRDD()
                .map(new FirstStringColumn())
                .countByValue();

        List<Map.Entry<String, Long>> entries =
                new ArrayList<Map.Entry<String, Long>>(counts.entrySet());
        Collections.sort(entries, new Comparator<Map.Entry<String, Long>>() {
            @Override
            public int compare(Map.Entry<String, Long> a, Map.Entry<String, Long> b) {
                int byCount = a.getValue().compareTo(b.getValue());
                // Keys are non-null because null rows were dropped above.
                return byCount != 0 ? byCount : a.getKey().compareTo(b.getKey());
            }
        });

        // Build the mapping fully before publishing it.
        Map<String, Integer> mapping = new HashMap<String, Integer>();
        for (int i = 0; i < entries.size(); i++) {
            mapping.put(entries.get(i).getKey(), Integer.valueOf(i));
        }
        this.labelToIndex = mapping;
        this.max_index = mapping.size();
    }

    /**
     * Looks up the input column of {@code row} and returns its index as a one-element vector.
     *
     * @param row row containing the input column
     * @return dense vector holding {@link #transform(String)} of the column's value
     */
    @Override // cn.com.pconline.adclick.pipeline.RowPipelineStage
    public Vector transform(Row row) {
        Object raw = row.get(row.fieldIndex(this.inputCol));
        // fit() indexes the column after casting it to StringType, so non-string columns
        // (e.g. integer ids) are legal. Row.getString would throw ClassCastException for them.
        // NOTE(review): toString presumably matches Spark's string cast for integral/boolean
        // values; confirm for other types (timestamps, decimals, floating point).
        String label = (raw == null) ? null : raw.toString();
        return Vectors.dense(new double[] {transform(label)});
    }

    /**
     * Returns the index learned for {@code str}.
     *
     * <p>A label not seen during {@link #fit(Dataset)}, including {@code null}, maps to the
     * number of learned labels.
     */
    public int transform(String str) {
        // The map never holds null values, so a null lookup result means "unseen label".
        Integer index = this.labelToIndex.get(str);
        return index != null ? index.intValue() : this.max_index;
    }

    /** This stage always produces exactly one output value. */
    @Override // cn.com.pconline.adclick.pipeline.RowPipelineStage
    public int resultLength() {
        return 1;
    }

    /**
     * Returns, for each output position, the set of input columns it depends on.
     *
     * @return a mutable one-element list whose only set contains the input column name
     */
    @Override // cn.com.pconline.adclick.pipeline.RowPipelineStage
    public List<Set<String>> inputsForOutputs() {
        Set<String> inputs = new HashSet<String>();
        inputs.add(this.inputCol);
        List<Set<String>> outputs = new ArrayList<Set<String>>(1);
        outputs.add(inputs);
        return outputs;
    }
}
