package cn.com.pconline.adclick.pipeline;

import cn.com.pconline.adclick.Utils;
import cn.com.pconline.adclick.algorithm.FactorisationMachine;
import cn.com.pconline.adclick.algorithm.FactorisationMachineModel;
import cn.com.pconline.adclick.mixexperts.Formula;
import java.io.Serializable;
import java.util.List;
import java.util.Set;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.mllib.linalg.SparseVector;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import scala.Tuple2;

/* loaded from: input_file:cn/com/pconline/adclick/pipeline/FMDimensionReducer.class */
/**
 * Pipeline stage that fits a {@link FactorisationMachine} on the concatenation of several
 * vector-valued input columns and then maps each input vector onto dot products with the
 * fitted model's theta vectors.
 *
 * <p>For every input vector the transform emits {@code K + 1} values, computed against
 * {@code thetas[1] .. thetas[K + 1]} where {@code K} is the fitted model's
 * {@code getK()}. {@code thetas[0]} is never read by this class
 * (NOTE(review): presumably the bias/linear term - confirm against
 * {@code FactorisationMachineModel}).
 *
 * <p>One of the {@code fit(...)} overloads must be called before any {@code transform(...)}.
 */
public class FMDimensionReducer implements CombinePipelineStage, Serializable {
    private static final long serialVersionUID = 5961748574979089575L;

    /** Factor count handed to the {@link FactorisationMachine} constructor. */
    private int k = 10;
    /** Iteration count handed to the {@link FactorisationMachine} constructor. */
    private int iternum = 20;
    /** Learning rate handed to the {@link FactorisationMachine} constructor. */
    private double learnrate = 0.8d;
    /** Lambda handed to the {@link FactorisationMachine} constructor. */
    private double lambda = 0.5d;
    /** Names of the vector-valued columns read from each row. */
    private String[] inputCols;
    /** Name of the double-valued label column. */
    private String labelCol;
    /** Number of input columns; determines the width of the transformed vector. */
    private int numCols;
    /** Fitted model; {@code null} until a {@code fit(...)} overload has completed. */
    private FactorisationMachineModel fmModel;

    /**
     * Creates a reducer over the given columns.
     *
     * @param strArr names of the vector-valued input columns; must not be {@code null}
     * @param str name of the label column
     */
    public FMDimensionReducer(String[] strArr, String str) {
        this.inputCols = strArr;
        this.labelCol = str;
        this.numCols = strArr.length;
    }

    /**
     * Fits the model from a dataset: selects the label and input columns, concatenates the
     * input vectors of each row into one feature vector, and trains on the result.
     *
     * @param dataset rows containing the label column and every input column
     * @throws Exception if fitting fails
     */
    @Override // cn.com.pconline.adclick.pipeline.RowPipelineStage
    public void fit(Dataset<Row> dataset) throws Exception {
        final String[] cols = this.inputCols;
        final String label = this.labelCol;
        fit(dataset.select(label, cols).javaRDD().map(new Function<Row, LabeledPoint>() { // from class: cn.com.pconline.adclick.pipeline.FMDimensionReducer.1
            private static final long serialVersionUID = -606732813951337213L;

            public LabeledPoint call(Row row) throws Exception {
                Vector[] parts = new Vector[cols.length];
                for (int i = 0; i < parts.length; i++) {
                    parts[i] = (Vector) row.get(row.fieldIndex(cols[i]));
                }
                return new LabeledPoint(row.getDouble(row.fieldIndex(label)), Utils.vectorConcat(parts));
            }
        }));
        // NOTE(review): the result is discarded; kept from the original in case getThetas()
        // has side effects in FactorisationMachineModel - confirm and remove if it is pure.
        this.fmModel.getThetas();
    }

    /**
     * Fits the model on already-assembled labeled points using the currently configured
     * k, iteration count, learning rate and lambda.
     *
     * @param javaRDD training points
     * @throws Exception if fitting fails
     */
    public void fit(JavaRDD<LabeledPoint> javaRDD) throws Exception {
        this.fmModel = new FactorisationMachine(this.k, this.iternum, this.learnrate, this.lambda).fit(javaRDD);
    }

    /**
     * Fits the model from (label, vectors) pairs; the vectors of each pair are concatenated
     * into a single feature vector.
     *
     * @param javaPairRDD pairs of label and per-column vectors
     * @throws Exception if fitting fails
     */
    @Override // cn.com.pconline.adclick.pipeline.CombinePipelineStage
    public void fit(JavaPairRDD<Double, Vector[]> javaPairRDD) throws Exception {
        fit(javaPairRDD.map(new Function<Tuple2<Double, Vector[]>, LabeledPoint>() { // from class: cn.com.pconline.adclick.pipeline.FMDimensionReducer.2
            private static final long serialVersionUID = -1310034077100218343L;

            public LabeledPoint call(Tuple2<Double, Vector[]> tuple2) throws Exception {
                return new LabeledPoint(((Double) tuple2._1).doubleValue(), Utils.vectorConcat((Vector[]) tuple2._2));
            }
        }));
    }

    /**
     * Reads every input column from the row and transforms the resulting vectors.
     *
     * @param row row containing every input column
     * @return the transformed vector, see {@link #transform(Vector[])}
     * @throws IllegalArgumentException if an input column does not hold a {@link Vector}
     * @throws IllegalStateException if the model has not been fitted
     * @throws Exception if the transform fails
     */
    @Override // cn.com.pconline.adclick.pipeline.RowPipelineStage
    public Vector transform(Row row) throws Exception {
        Vector[] parts = new Vector[this.inputCols.length];
        for (int i = 0; i < parts.length; i++) {
            String col = this.inputCols[i];
            Object value = row.get(row.fieldIndex(col));
            if (!(value instanceof Vector)) {
                throw new IllegalArgumentException(
                        "FMDimensionReducer input error: column '" + col + "' does not hold a Vector");
            }
            parts[i] = (Vector) value;
        }
        return transform(parts);
    }

    /**
     * Transforms a single vector into {@code K + 1} dot products against
     * {@code thetas[1] .. thetas[K + 1]}.
     *
     * @param vector the input vector
     * @return a sparse vector of length {@code K + 1}
     * @throws IllegalStateException if the model has not been fitted
     * @throws Exception if the transform fails
     */
    public Vector transform(Vector vector) throws Exception {
        FactorisationMachineModel model = requireModel();
        SparseVector[] thetas = model.getThetas();
        int modelK = model.getK();
        SparseVector sparse = vector.toSparse();
        // K + 1 outputs, matching transform(Vector[]); the original allocated only K slots
        // while looping K + 1 times and therefore always overflowed the array.
        double[] out = new double[modelK + 1];
        for (int i = 0; i < out.length; i++) {
            out[i] = Formula.dot(sparse, thetas[i + 1]);
        }
        return Vectors.dense(out).toSparse();
    }

    /**
     * Transforms one vector per input column. For each vector, {@code K + 1} dot products
     * against {@code thetas[1] .. thetas[K + 1]} are written to consecutive output slots;
     * the running sum of the preceding vectors' sizes is passed to {@code Formula.dot} as
     * the third argument (NOTE(review): presumably the offset of this vector inside the
     * concatenated theta - confirm against {@code Formula.dot}).
     *
     * @param vectorArr per-column input vectors, at most as many as there are input columns
     * @return a sparse vector of length {@code numCols * (K + 1)}
     * @throws IllegalArgumentException if more vectors are supplied than input columns
     * @throws IllegalStateException if the model has not been fitted
     * @throws Exception if the transform fails
     */
    @Override // cn.com.pconline.adclick.pipeline.CombinePipelineStage
    public Vector transform(Vector[] vectorArr) throws Exception {
        FactorisationMachineModel model = requireModel();
        if (vectorArr.length > this.numCols) {
            throw new IllegalArgumentException("FMDimensionReducer input error: got " + vectorArr.length
                    + " vectors but only " + this.numCols + " input columns are configured");
        }
        SparseVector[] thetas = model.getThetas();
        int perColumn = model.getK() + 1;
        double[] out = new double[this.numCols * perColumn];
        int offset = 0;
        for (int col = 0; col < vectorArr.length; col++) {
            SparseVector sparse = vectorArr[col].toSparse();
            int base = col * perColumn;
            for (int t = 1; t <= perColumn; t++) {
                out[base + t - 1] = Formula.dot(thetas[t], sparse, offset);
            }
            offset += vectorArr[col].size();
        }
        return Vectors.dense(out).toSparse();
    }

    /** Returns the fitted model or fails with a clear message when fit(...) has not run yet. */
    private FactorisationMachineModel requireModel() {
        if (this.fmModel == null) {
            throw new IllegalStateException(
                    "FMDimensionReducer has not been fitted; call fit(...) before transform(...)");
        }
        return this.fmModel;
    }

    /** @return the names of the input columns */
    @Override // cn.com.pconline.adclick.pipeline.CombinePipelineStage
    public String[] getInputCols() {
        return this.inputCols;
    }

    /**
     * Replaces the input columns and keeps the column count used to size the transformed
     * vector in sync with them.
     *
     * @param strArr new input column names
     */
    public void setInputCols(String[] strArr) {
        this.inputCols = strArr;
        if (strArr != null) {
            // Without this, transform(Row) and transform(Vector[]) disagree on the width.
            this.numCols = strArr.length;
        }
    }

    /** @return the factor count used for the next fit */
    public int getK() {
        return this.k;
    }

    /** @param i the factor count to use for the next fit */
    public void setK(int i) {
        this.k = i;
    }

    /** @return the iteration count used for the next fit */
    public int getIternum() {
        return this.iternum;
    }

    /** @param i the iteration count to use for the next fit */
    public void setIternum(int i) {
        this.iternum = i;
    }

    /** @return the learning rate used for the next fit */
    public double getLearnrate() {
        return this.learnrate;
    }

    /** @param d the learning rate to use for the next fit */
    public void setLearnrate(double d) {
        this.learnrate = d;
    }

    /** @return the lambda used for the next fit */
    public double getLambda() {
        return this.lambda;
    }

    /** @param d the lambda to use for the next fit */
    public void setLambda(double d) {
        this.lambda = d;
    }

    /**
     * Length of the vector produced by {@link #transform(Vector[])}. Once a model has been
     * fitted its own K is used, so the reported length matches what transform emits even if
     * {@link #setK(int)} was called afterwards; before fitting the configured k is used.
     *
     * @return {@code (K + 1) * numCols}
     */
    @Override // cn.com.pconline.adclick.pipeline.RowPipelineStage
    public int resultLength() {
        int effectiveK = (this.fmModel != null) ? this.fmModel.getK() : this.k;
        return (effectiveK + 1) * this.numCols;
    }

    /**
     * This stage does not report which inputs feed which outputs.
     *
     * @return always {@code null}
     */
    @Override // cn.com.pconline.adclick.pipeline.RowPipelineStage
    public List<Set<String>> inputsForOutputs() {
        return null;
    }
}
