package moetest;

import cn.com.pconline.adclick.mixexperts.Formula;
import java.util.ArrayList;
import java.util.List;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.linalg.SparseVector;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.sql.SparkSession;
import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Ignore;
import org.junit.Test;

/**
 * Spark-backed tests for {@link Formula}.
 *
 * <p>Each test builds a small 4-row RDD (two samples, each repeated twice). It checks that the
 * RDD-based aggregate ({@code loss_derivative}, {@code logloss}) agrees with the same quantity
 * assembled from Formula's per-sample overloads.
 */
public class FormulaSparkTest {
    /** Dimension of every feature and weight vector used by these tests. */
    private static final int DIM = 8;

    /**
     * Trailing integer argument passed to Formula's per-sample overloads
     * ({@code loss_derivative}, {@code p1}).
     * NOTE(review): its meaning is defined inside Formula, not visible here — confirm.
     */
    private static final int K = 2;

    /** Absolute tolerance for element-wise loss-derivative comparisons. */
    private static final double DERIVATIVE_TOLERANCE = 1.0E-4d;

    /** Absolute tolerance for log-loss comparisons. */
    private static final double LOGLOSS_TOLERANCE = 0.001d;

    JavaSparkContext jsc;

    @Before
    public void setUp() {
        SparkSession spark = SparkSession.builder()
                .appName("MoeFormulaTest")
                .enableHiveSupport()
                .getOrCreate();
        this.jsc = new JavaSparkContext(spark.sparkContext());
    }

    @After
    public void tearDown() {
        // Guard against setUp() having failed; an NPE here would hide the real failure.
        if (this.jsc != null) {
            this.jsc.close();
            this.jsc = null;
        }
    }

    /** Placeholder for a backtracking line-search test; reported by JUnit as skipped. */
    @Test
    @Ignore("not implemented yet")
    public void testBacktrackingLineSearch() throws Exception {
    }

    /**
     * Checks that the RDD-aggregated loss derivative matches a combination of the two
     * per-sample derivatives.
     */
    @Test
    public void testLoss_derivative() throws Exception {
        LabeledPoint positive = new LabeledPoint(1.0d, positiveFeatures());
        LabeledPoint negative = new LabeledPoint(0.0d, negativeFeatures());
        JavaRDD<LabeledPoint> data = parallelizeEachTwice(positive, negative);
        long count = data.count();
        SparseVector[] weights = weightsA();

        SparseVector[] actual = Formula.loss_derivative(data, weights, count);

        // Each sample makes up half of the 4-row RDD, so the expected value combines the two
        // per-sample derivatives with coefficient -0.5 each.
        // NOTE(review): the negative sign encodes Formula's sign convention for the aggregate — confirm.
        SparseVector[] expected = Formula.plus(
                Formula.loss_derivative(positive.features().toSparse(), positive.label(), weights, K), -0.5d,
                Formula.loss_derivative(negative.features().toSparse(), negative.label(), weights, K), -0.5d);

        Assert.assertEquals("number of derivative vectors", expected.length, actual.length);
        for (int v = 0; v < expected.length; v++) {
            Assert.assertEquals("size of derivative vector " + v, expected[v].size(), actual[v].size());
            for (int j = 0; j < expected[v].size(); j++) {
                Assert.assertEquals("derivative[" + v + "][" + j + "]",
                        expected[v].apply(j), actual[v].apply(j), DERIVATIVE_TOLERANCE);
            }
        }
    }

    /**
     * Checks that the RDD-based log loss, evaluated for two weight sets at once, matches the mean
     * log loss computed from per-sample {@code Formula.p1} values.
     */
    @Test
    public void testLogloss() throws Exception {
        SparseVector positiveX = positiveFeatures();
        SparseVector negativeX = negativeFeatures();
        JavaRDD<LabeledPoint> data = parallelizeEachTwice(
                new LabeledPoint(1.0d, positiveX), new LabeledPoint(0.0d, negativeX));
        long count = data.count();

        SparseVector[] weightsA = weightsA();
        // Same four vectors as weightsA, reordered.
        SparseVector[] weightsB = {weightsA[2], weightsA[3], weightsA[1], weightsA[0]};
        SparseVector[][] weightSets = {weightsA, weightsB};

        Double[] actual = Formula.logloss(data, weightSets, count);

        Assert.assertEquals("number of logloss values", weightSets.length, actual.length);
        for (int i = 0; i < weightSets.length; i++) {
            double expected = expectedLogloss(positiveX, negativeX, weightSets[i]);
            Assert.assertEquals("logloss for weight set " + i,
                    expected, actual[i].doubleValue(), LOGLOSS_TOLERANCE);
        }
    }

    /** Feature vector of the positive (label 1.0) sample. */
    private static SparseVector positiveFeatures() {
        return new SparseVector(DIM, new int[] {0, 4, 6, 7}, new double[] {2.0d, 5.0d, 0.75d, 0.25d});
    }

    /** Feature vector of the negative (label 0.0) sample. */
    private static SparseVector negativeFeatures() {
        return new SparseVector(DIM, new int[] {1, 3, 4, 6}, new double[] {4.0d, 3.0d, 0.8d, 0.5d});
    }

    /** The four weight vectors shared by both tests, in their baseline order. */
    private static SparseVector[] weightsA() {
        return new SparseVector[] {
            new SparseVector(DIM, new int[] {1, 3, 5}, new double[] {0.5d, 1.0d, 1.0d}),
            new SparseVector(DIM, new int[] {1, 3, 4, 7}, new double[] {-2.0d, -1.0d, 2.0d, 1.0d}),
            new SparseVector(DIM, new int[] {1, 5, 6}, new double[] {2.0d, 1.0d, 1.0d}),
            new SparseVector(DIM, new int[] {0, 1, 6}, new double[] {1.0d, 2.0d, 3.0d})
        };
    }

    /** Builds a 4-row RDD holding {@code first} twice followed by {@code second} twice. */
    private JavaRDD<LabeledPoint> parallelizeEachTwice(LabeledPoint first, LabeledPoint second) {
        List<LabeledPoint> rows = new ArrayList<>(4);
        rows.add(first);
        rows.add(first);
        rows.add(second);
        rows.add(second);
        return this.jsc.parallelize(rows);
    }

    /**
     * Computes the mean log loss over the dataset built by {@link #parallelizeEachTwice}.
     *
     * <p>The positive sample contributes {@code -log(p1)} twice and the negative sample
     * contributes {@code -log(1 - p1)} twice; the sum is divided by the 4 rows.
     *
     * @param positiveX features of the label-1.0 sample
     * @param negativeX features of the label-0.0 sample
     * @param weights   weight set passed to {@code Formula.p1}
     * @return the expected mean log loss
     */
    private static double expectedLogloss(SparseVector positiveX, SparseVector negativeX, SparseVector[] weights) {
        double positiveTerm = -Math.log(Formula.p1(positiveX, weights, K));
        double negativeTerm = -Math.log(1.0d - Formula.p1(negativeX, weights, K));
        return (2.0d * positiveTerm + 2.0d * negativeTerm) / 4.0d;
    }
}
