package cn.com.pconline.adclick.pipeline;

import cn.com.pconline.adclick.Constant;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;

/* loaded from: input_file:cn/com/pconline/adclick/pipeline/VectorSimilarity.class */
/**
 * Pipeline stage that combines two vector-valued inputs into a single one-element vector.
 *
 * <p>NOTE(review): the two {@code transform} overloads do not compute the same measure.
 * {@link #transform(Row)} emits the Euclidean distance between the two column vectors, whereas
 * {@link #transform(Vector[])} emits their cosine similarity. The asymmetry is kept as-is
 * because downstream consumers may depend on it — confirm whether it is intended.
 */
public class VectorSimilarity implements CombinePipelineStage, Serializable {
    private static final long serialVersionUID = -2648068081004185147L;

    /** Name of the first input column. */
    private final String inputCol1;
    /** Name of the second input column. */
    private final String inputCol2;
    /** Dimension of the input vectors as learned by {@code fit}; 0 until fitted. */
    private int vectorLength = 0;

    /**
     * Creates a stage reading its two input vectors from the given columns.
     *
     * @param str name of the first vector column
     * @param str2 name of the second vector column
     */
    public VectorSimilarity(String str, String str2) {
        this.inputCol1 = str;
        this.inputCol2 = str2;
    }

    /**
     * Records the vector dimension from the {@code inputCol1} value of the dataset's first row.
     *
     * @param dataset rows that contain {@code inputCol1}
     * @throws Exception an {@link IllegalArgumentException} if that value is not a {@link Vector}
     */
    @Override
    public void fit(Dataset<Row> dataset) throws Exception {
        Row row = dataset.first();
        Object value = row.get(row.fieldIndex(this.inputCol1));
        if (!(value instanceof Vector)) {
            throw new IllegalArgumentException("VectorSimilarity fit input error");
        }
        this.vectorLength = ((Vector) value).size();
    }

    /**
     * Computes the Euclidean distance between the two input columns of {@code row}.
     *
     * @param row a row containing both input columns
     * @return a one-element dense vector holding the distance
     * @throws Exception an {@link IllegalArgumentException} if either column is not a
     *     {@link Vector} or the two vectors differ in size
     */
    @Override
    public Vector transform(Row row) throws Exception {
        Object first = row.get(row.fieldIndex(this.inputCol1));
        Object second = row.get(row.fieldIndex(this.inputCol2));
        if (!(first instanceof Vector) || !(second instanceof Vector)) {
            throw new IllegalArgumentException("VectorSimilarity input error");
        }
        return Vectors.dense(new double[] {euclideanDistance((Vector) first, (Vector) second)});
    }

    /** This stage always produces exactly one output value. */
    @Override
    public int resultLength() {
        return 1;
    }

    /**
     * Lists, for the single output value, the names of every input component it depends on.
     *
     * <p>Each component name is the column name, {@code Constant.INPUTCOL_SEP}, and the index.
     * Both columns are assumed to share the dimension learned by {@code fit}.
     *
     * @return a one-element list holding the set of component names of both input columns
     */
    @Override
    public List<Set<String>> inputsForOutputs() {
        Set<String> inputs = new HashSet<>();
        for (int i = 0; i < this.vectorLength; i++) {
            inputs.add(this.inputCol1 + Constant.INPUTCOL_SEP + i);
            inputs.add(this.inputCol2 + Constant.INPUTCOL_SEP + i);
        }
        List<Set<String>> result = new ArrayList<>();
        result.add(inputs);
        return result;
    }

    /**
     * Records the vector dimension from the first vector of the first record's value array.
     *
     * @param javaPairRDD pairs whose value arrays hold the input vectors
     */
    @Override
    public void fit(JavaPairRDD<Double, Vector[]> javaPairRDD) throws Exception {
        this.vectorLength = javaPairRDD.first()._2()[0].size();
    }

    /**
     * Computes the cosine similarity between {@code vectorArr[0]} and {@code vectorArr[1]}.
     *
     * @param vectorArr an array holding at least two vectors of equal size
     * @return a one-element dense vector holding the similarity (0.0 if either vector is zero)
     * @throws Exception an {@link IllegalArgumentException} if the two vectors differ in size
     */
    @Override
    public Vector transform(Vector[] vectorArr) throws Exception {
        return Vectors.dense(new double[] {cosineSimilarity(vectorArr[0], vectorArr[1])});
    }

    /** Euclidean (L2) distance between two equally sized vectors. */
    private static double euclideanDistance(Vector a, Vector b) {
        requireSameSize(a, b);
        double sumOfSquares = 0.0d;
        for (int i = 0; i < a.size(); i++) {
            double diff = a.apply(i) - b.apply(i);
            sumOfSquares += diff * diff;
        }
        return Math.sqrt(sumOfSquares);
    }

    /**
     * Cosine similarity {@code a·b / (|a| |b|)} of two equally sized vectors.
     *
     * <p>Returns 0.0 when either vector has zero norm, since the ratio is undefined there and
     * NaN would otherwise propagate into downstream features.
     */
    private static double cosineSimilarity(Vector a, Vector b) {
        requireSameSize(a, b);
        double dot = 0.0d;
        double normSqA = 0.0d;
        double normSqB = 0.0d;
        for (int i = 0; i < a.size(); i++) {
            double x = a.apply(i);
            double y = b.apply(i);
            dot += x * y;
            normSqA += x * x;
            normSqB += y * y;
        }
        // Take each root separately so the product of two large squared norms cannot overflow.
        double denominator = Math.sqrt(normSqA) * Math.sqrt(normSqB);
        return denominator == 0.0d ? 0.0d : dot / denominator;
    }

    /** Rejects vector pairs whose dimensions differ instead of silently truncating. */
    private static void requireSameSize(Vector a, Vector b) {
        if (a.size() != b.size()) {
            throw new IllegalArgumentException(
                    "VectorSimilarity vector size mismatch: " + a.size() + " vs " + b.size());
        }
    }

    /** Returns the two input column names in order. */
    @Override
    public String[] getInputCols() {
        return new String[] {this.inputCol1, this.inputCol2};
    }
}
