package cn.com.pconline.adclick.feature;

import org.apache.spark.ml.attribute.Attribute;
import org.apache.spark.ml.attribute.AttributeGroup;
import org.apache.spark.ml.attribute.BinaryAttribute;
import org.apache.spark.ml.feature.OneHotEncoder;
import org.apache.spark.ml.linalg.Vector;
import org.apache.spark.ml.linalg.VectorUDT;
import org.apache.spark.ml.linalg.Vectors;
import org.apache.spark.sql.Column;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.UDFRegistration;
import org.apache.spark.sql.api.java.UDF1;
import org.apache.spark.sql.functions;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.Metadata;

/* loaded from: input_file:cn/com/pconline/adclick/feature/OneHotEncoder4Vec.class */
/**
 * {@link OneHotEncoder} variant that keeps the output vector size in a field.
 *
 * <p>The size is determined on the first call to {@link #transform(Dataset)} and is
 * reused by later calls on the same instance instead of being recomputed. Once the
 * size is known, {@link #transform(int)} can encode a single index without a dataset.
 */
public class OneHotEncoder4Vec extends OneHotEncoder {
    private static final long serialVersionUID = 1L;

    /** Name under which the encoding UDF is registered on the dataset's session. */
    private static final String UDF_NAME = "ohencoder";

    /** Length of the produced vectors; negative until {@link #transform(Dataset)} has set it. */
    private int size = -1;

    /**
     * UDF that one-hot encodes a single category index.
     *
     * <p>This is a non-static inner class: it reads {@code size} from the enclosing
     * encoder instance each time it is called.
     */
    public class OneHotEncoderUDF implements UDF1<Double, Vector> {
        private static final long serialVersionUID = 1L;

        public OneHotEncoderUDF() {
        }

        /**
         * Encodes {@code d} as a sparse vector of length {@code size}.
         *
         * @param d category index; the fractional part is truncated to obtain the position.
         *          A {@code null} value is not handled and fails when unboxed.
         * @return a vector with a single 1.0 at the index when {@code d < size},
         *         otherwise an all-zero vector
         */
        public Vector call(Double d) throws Exception {
            return d.doubleValue() < ((double) OneHotEncoder4Vec.this.size)
                    ? OneHotEncoder4Vec.this.oneHot(d.intValue())
                    : OneHotEncoder4Vec.this.allZeros();
        }
    }

    /**
     * Encodes a single category index using the size established by a previous
     * {@link #transform(Dataset)} call.
     *
     * @param i category index
     * @return a sparse vector with a single 1.0 at position {@code i} when {@code i < size},
     *         otherwise an all-zero vector
     */
    public Vector transform(int i) {
        return i < this.size ? oneHot(i) : allZeros();
    }

    /**
     * Appends the one-hot encoded output column to {@code dataset}.
     *
     * <p>The vector size is resolved as follows:
     * <ul>
     *   <li>a size already stored on this instance is reused;</li>
     *   <li>otherwise, if the output attribute group derived from the schema has a known
     *       size, that size is adopted;</li>
     *   <li>otherwise it is inferred from the maximum value of the input column.</li>
     * </ul>
     *
     * <p>Side effect: registers (or replaces) a UDF named {@code "ohencoder"} on the
     * dataset's Spark session.
     *
     * @param dataset input data containing the configured input column
     * @return all original columns plus the encoded output column carrying attribute metadata
     * @throws IllegalArgumentException if the size has to be inferred but the input column
     *         contains no non-null values
     */
    public Dataset<Row> transform(Dataset<?> dataset) {
        String inputCol = super.getInputCol();
        String outputCol = super.getOutputCol();
        AttributeGroup outputGroup = AttributeGroup.fromStructField(
                super.transformSchema(dataset.schema()).apply(outputCol));

        if (outputGroup.size() < 0) {
            // Schema gives no attribute count: fall back to the cached or inferred size.
            if (this.size < 0) {
                this.size = inferSize(dataset, inputCol, super.getDropLast());
            }
            outputGroup = new AttributeGroup(outputCol, binaryAttributes(this.size));
        } else if (this.size < 0) {
            // Schema already knows the output size. Without this assignment the field would
            // stay at -1 and the UDF would try to build vectors of negative length.
            this.size = outputGroup.size();
        }
        Metadata metadata = outputGroup.toMetadata();

        dataset.sparkSession().udf().register(UDF_NAME, new OneHotEncoderUDF(), new VectorUDT());
        Column encoded = functions
                .callUDF(UDF_NAME, new Column[]{dataset.col(inputCol).cast(DataTypes.DoubleType)})
                .as(outputCol, metadata);
        return dataset.select(new Column[]{dataset.col("*"), encoded});
    }

    /**
     * Returns the output vector size.
     *
     * @return the size, or a negative value if {@link #transform(Dataset)} has not set it yet
     */
    public int size() {
        return this.size;
    }

    /** Derives the vector size from the maximum value found in the input column. */
    private static int inferSize(Dataset<?> dataset, String inputCol, boolean dropLast) {
        Row maxRow = dataset.groupBy(new Column[0])
                .agg(functions.max(dataset.col(inputCol).cast(DataTypes.DoubleType)), new Column[0])
                .first();
        if (maxRow.isNullAt(0)) {
            // The aggregate is null when there is nothing to take a maximum of.
            throw new IllegalArgumentException(
                    "Cannot infer one-hot vector size: column '" + inputCol
                            + "' has no non-null numeric values");
        }
        int maxIndex = (int) maxRow.getDouble(0);
        // With dropLast the highest index gets no slot of its own, so size == max index.
        return dropLast ? maxIndex : maxIndex + 1;
    }

    /** Builds {@code count} binary attributes named "0", "1", ... for the output metadata. */
    private static Attribute[] binaryAttributes(int count) {
        Attribute[] attributes = new Attribute[count];
        for (int idx = 0; idx < count; idx++) {
            attributes[idx] = BinaryAttribute.defaultAttr().withName(String.valueOf(idx));
        }
        return attributes;
    }

    /** Sparse vector of length {@code size} with a single 1.0 at {@code index}. */
    private Vector oneHot(int index) {
        return Vectors.sparse(this.size, new int[]{index}, new double[]{1.0d});
    }

    /** Sparse vector of length {@code size} with no active entries. */
    private Vector allZeros() {
        return Vectors.sparse(this.size, new int[0], new double[0]);
    }
}
