package cn.com.pconline.adclick.algorithm;

import cn.com.pconline.adclick.mixexperts.Formula;
import java.io.Serializable;
import java.util.Date;
import java.util.Random;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.Function2;
import org.apache.spark.mllib.linalg.Matrices;
import org.apache.spark.mllib.linalg.SparseVector;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.regression.LabeledPoint;
import scala.collection.Iterator;

/* loaded from: input_file:cn/com/pconline/adclick/algorithm/FactorisationMachine.class */
/**
 * Trainer for a factorisation-machine model fitted by full-batch gradient descent over a Spark RDD.
 *
 * <p>The parameters are kept as an array of {@code k + 2} sparse vectors ("thetas"), all sized to
 * the feature count:
 * <ul>
 *   <li>index 0 - the bias, stored as the single entry at position 0;</li>
 *   <li>index 1 - the linear weights;</li>
 *   <li>indices 2 .. k + 1 - one row of pairwise-interaction factors per latent dimension.</li>
 * </ul>
 *
 * <p>The per-sample gradient is driven by {@code label - sigmoid(f(x))}, so labels are presumably
 * expected to be 0 or 1 - confirm against the callers.
 */
public class FactorisationMachine implements Serializable {
    private static final long serialVersionUID = 770678417280843324L;

    /** Values used by the no-argument constructor. */
    private static final int DEFAULT_K = 10;
    private static final int DEFAULT_ITERNUM = 20;
    private static final double DEFAULT_LEARNRATE = 0.8d;
    private static final double DEFAULT_LAMBDA = 0.5d;
    private static final long DEFAULT_SEED = 42L;

    /** Number of latent factor rows. */
    private int k;
    /** Number of gradient-descent iterations run by {@link #fit}. */
    private int iternum;
    /** Step size applied to the gradient on every iteration. */
    private double learnrate;
    /** Coefficient applied to the current thetas when the gradient is assembled. */
    private double lambda;
    /** Seed of the random generator used by {@link #initThetas}. */
    private long seed;

    /** Creates a trainer with k = 10, 20 iterations, learning rate 0.8, lambda 0.5 and seed 42. */
    public FactorisationMachine() {
        this(DEFAULT_K, DEFAULT_ITERNUM, DEFAULT_LEARNRATE, DEFAULT_LAMBDA, DEFAULT_SEED);
    }

    /**
     * Creates a trainer seeded with the current wall-clock time in milliseconds.
     *
     * @param k         number of latent factor rows
     * @param iternum   number of gradient-descent iterations
     * @param learnrate gradient step size
     * @param lambda    coefficient applied to the current thetas in the gradient
     */
    public FactorisationMachine(int k, int iternum, double learnrate, double lambda) {
        this(k, iternum, learnrate, lambda, System.currentTimeMillis());
    }

    /**
     * Creates a fully specified trainer.
     *
     * @param k         number of latent factor rows
     * @param iternum   number of gradient-descent iterations
     * @param learnrate gradient step size
     * @param lambda    coefficient applied to the current thetas in the gradient
     * @param seed      seed for the random parameter initialisation
     */
    public FactorisationMachine(int k, int iternum, double learnrate, double lambda, long seed) {
        this.k = k;
        this.iternum = iternum;
        this.learnrate = learnrate;
        this.lambda = lambda;
        this.seed = seed;
    }

    /**
     * Fits a model: initialises the thetas from the feature count of the first sample, then runs
     * {@code iternum} updates of the form {@code theta = theta - learnrate * gradient}
     * (assuming {@code Formula.plus(a, s, b, t)} computes {@code s*a + t*b} - confirm in Formula).
     *
     * <p>The RDD is traversed once per iteration, so callers will normally want it cached.
     *
     * @param javaRDD training samples
     * @return the model wrapping the trained thetas
     * @throws Exception propagated from the {@code Formula} helpers or from Spark
     */
    public FactorisationMachineModel fit(JavaRDD<LabeledPoint> javaRDD) throws Exception {
        int featureCount = javaRDD.first().features().size();
        // The sample count cannot change between iterations; count once instead of
        // launching an extra Spark job on every pass.
        long sampleCount = javaRDD.count();
        SparseVector[] thetas = initThetas(this.k, featureCount);
        for (int iteration = 0; iteration < this.iternum; iteration++) {
            SparseVector[] gradient = lossDerivative(javaRDD, thetas, sampleCount);
            for (int t = 0; t < thetas.length; t++) {
                thetas[t] = Formula.plus(thetas[t], 1.0d, gradient[t], -this.learnrate);
            }
        }
        return new FactorisationMachineModel(thetas);
    }

    /**
     * Computes the gradient used by {@link #fit} for the given thetas over the whole RDD.
     *
     * <p>The per-sample terms are summed and combined as
     * {@code Formula.plus(sum, -1.0 / count, thetas, lambda)}, presumably
     * {@code (-1/count) * sum + lambda * thetas} - confirm in Formula.
     *
     * @param javaRDD         training samples
     * @param sparseVectorArr current thetas (bias, linear weights, factor rows)
     * @return one gradient vector per theta
     * @throws Exception propagated from the {@code Formula} helpers or from Spark
     */
    public SparseVector[] lossDerivative(JavaRDD<LabeledPoint> javaRDD, final SparseVector[] sparseVectorArr) throws Exception {
        return lossDerivative(javaRDD, sparseVectorArr, javaRDD.count());
    }

    /** Same as the public overload, but with the sample count supplied by the caller. */
    private SparseVector[] lossDerivative(JavaRDD<LabeledPoint> javaRDD, SparseVector[] thetas, long sampleCount) throws Exception {
        SparseVector[] summed = javaRDD.map(new SampleGradient(thetas)).reduce(new GradientSum());
        return Formula.plus(summed, -1.0d / sampleCount, thetas, this.lambda);
    }

    /**
     * Evaluates the raw (pre-sigmoid) model score for one sample:
     * {@code bias + <w, x> + 0.5 * sum_f ((sum_j v_fj x_j)^2 - sum_j v_fj^2 x_j^2)}.
     *
     * @param sparseVector    feature vector of the sample
     * @param sparseVectorArr thetas (bias, linear weights, factor rows)
     * @return the raw score
     * @throws Exception propagated from {@code Formula.dot}
     */
    public static double f(SparseVector sparseVector, SparseVector[] sparseVectorArr) throws Exception {
        double score = sparseVectorArr[0].apply(0) + Formula.dot(sparseVector, sparseVectorArr[1]);
        // Only the stored entries of the sample can contribute to the interaction sums,
        // so walk those instead of every feature position.
        int[] indices = sparseVector.indices();
        double[] values = sparseVector.values();
        for (int factor = 2; factor < sparseVectorArr.length; factor++) {
            SparseVector factorRow = sparseVectorArr[factor];
            double sum = 0.0d;
            double sumOfSquares = 0.0d;
            for (int p = 0; p < indices.length; p++) {
                double weight = factorRow.apply(indices[p]);
                double value = values[p];
                sum += weight * value;
                sumOfSquares += (weight * weight) * (value * value);
            }
            score += 0.5d * (sum * sum - sumOfSquares);
        }
        return score;
    }

    /**
     * Builds the initial thetas: a standard-normal bias at position 0 of the first vector, followed
     * by the {@code k + 1} rows of a random sparse matrix of density 0.5, each passed through
     * {@code Formula.multiply(row, 2.0, -1.0)} (presumably rescaling the uniform [0, 1) entries to
     * [-1, 1) - confirm in Formula).
     *
     * @param factorCount  number of latent factor rows (k); must not be negative
     * @param featureCount size of every theta vector
     * @return array of {@code factorCount + 2} vectors
     * @throws IllegalArgumentException if {@code factorCount} is negative
     */
    public SparseVector[] initThetas(int factorCount, int featureCount) {
        if (factorCount < 0) {
            throw new IllegalArgumentException("Number of factors must not be negative: " + factorCount);
        }
        SparseVector[] thetas = new SparseVector[factorCount + 2];
        Random random = new Random(this.seed);
        thetas[0] = new SparseVector(featureCount, new int[]{0}, new double[]{random.nextGaussian()});
        Iterator<Vector> rows = Matrices.sprand(factorCount + 1, featureCount, 0.5d, random).rowIter();
        for (int row = 1; rows.hasNext(); row++) {
            thetas[row] = Formula.multiply(rows.next().toSparse(), 2.0d, -1.0d);
        }
        return thetas;
    }

    /** Logistic function, evaluated in the form that avoids exponentiating a large positive number. */
    private static double sigmoid(double z) {
        if (z > 0.0d) {
            return 1.0d / (1.0d + Math.exp(-z));
        }
        double e = Math.exp(z);
        return e / (1.0d + e);
    }

    /**
     * Per-sample gradient term. Static so that only the thetas, not the enclosing trainer,
     * are serialised with the Spark task.
     */
    private static final class SampleGradient implements Function<LabeledPoint, SparseVector[]> {
        private static final long serialVersionUID = 7514026040151400631L;

        private final SparseVector[] thetas;

        SampleGradient(SparseVector[] thetas) {
            this.thetas = thetas;
        }

        @Override
        public SparseVector[] call(LabeledPoint labeledPoint) throws Exception {
            SparseVector x = labeledPoint.features().toSparse();
            double residual = labeledPoint.label() - sigmoid(FactorisationMachine.f(x, this.thetas));
            int[] indices = x.indices();
            double[] values = x.values();

            SparseVector[] terms = new SparseVector[this.thetas.length];
            terms[0] = new SparseVector(x.size(), new int[]{0}, new double[]{residual});
            terms[1] = Formula.multiply(x, residual);
            for (int factor = 2; factor < terms.length; factor++) {
                SparseVector factorRow = this.thetas[factor];
                // v_j * x_j^2 at every stored position of the sample.
                double[] selfTerms = new double[indices.length];
                for (int p = 0; p < indices.length; p++) {
                    selfTerms[p] = factorRow.apply(indices[p]) * values[p] * values[p];
                }
                // d f / d v_j = x_j * <v, x> - v_j * x_j^2
                SparseVector scoreDerivative = Formula.plus(
                        x, Formula.dot(factorRow, x),
                        new SparseVector(x.size(), indices, selfTerms), -1.0d);
                terms[factor] = Formula.multiply(scoreDerivative, residual);
            }
            return terms;
        }
    }

    /** Element-wise sum of two per-sample gradient arrays. */
    private static final class GradientSum implements Function2<SparseVector[], SparseVector[], SparseVector[]> {
        private static final long serialVersionUID = -9187017667781255225L;

        @Override
        public SparseVector[] call(SparseVector[] left, SparseVector[] right) throws Exception {
            SparseVector[] sum = new SparseVector[left.length];
            for (int t = 0; t < sum.length; t++) {
                sum[t] = Formula.plus(left[t], 1.0d, right[t], 1.0d);
            }
            return sum;
        }
    }

    /** @return the number of latent factor rows */
    public int getK() {
        return this.k;
    }

    /** @param k the number of latent factor rows */
    public void setK(int k) {
        this.k = k;
    }

    /** @return the number of gradient-descent iterations */
    public int getIternum() {
        return this.iternum;
    }

    /** @param iternum the number of gradient-descent iterations */
    public void setIternum(int iternum) {
        this.iternum = iternum;
    }

    /** @return the gradient step size */
    public double getLearnrate() {
        return this.learnrate;
    }

    /** @param learnrate the gradient step size */
    public void setLearnrate(double learnrate) {
        this.learnrate = learnrate;
    }

    /** @return the coefficient applied to the current thetas in the gradient */
    public double getLambda() {
        return this.lambda;
    }

    /** @param lambda the coefficient applied to the current thetas in the gradient */
    public void setLambda(double lambda) {
        this.lambda = lambda;
    }

    /** @return the seed used for random parameter initialisation */
    public long getSeed() {
        return this.seed;
    }

    /** @param seed the seed used for random parameter initialisation */
    public void setSeed(long seed) {
        this.seed = seed;
    }
}
