package moetest;

import cn.com.pconline.adclick.mixexperts.Formula;
import java.util.ArrayList;
import java.util.List;
import org.apache.spark.mllib.linalg.SparseVector;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.Vectors;
import org.junit.Assert;
import org.junit.Test;

/**
 * JUnit 4 tests for the static helpers in {@link Formula}.
 *
 * <p>Covered: both {@code piFun} overloads, {@code updateProducts} (size, values and symmetry of
 * the square table of {@code Double}s it fills), agreement between {@code vlbfgs} and
 * {@code lbfgs}, {@code direction}, {@code loss_derivative}, {@code p1}, and the sparse-vector
 * helpers {@code dot}, {@code multiply}, {@code plus} and {@code norm2}.
 *
 * <p>Recovered from the compiled class {@code moetest/FormulaTest.class}; fixture values are
 * reproduced exactly from that class. Raw collection types are kept on purpose because the
 * generic signatures of the {@code Formula} methods that receive them are not visible here.
 */
public class FormulaTest {

    /** Returns its arguments as an index array; keeps the fixture literals below compact. */
    private static int[] indices(int... positions) {
        return positions;
    }

    /** Builds a {@link SparseVector} of length {@code size} holding {@code values} at {@code positions}. */
    private static SparseVector sparse(int size, int[] positions, double... values) {
        return new SparseVector(size, positions, values);
    }

    /**
     * Fresh copy of the five length-8 vectors used by the product tests: element 0 stores a single
     * explicit 0.0 at position 0, and element {@code k} (1..4) holds 1.0 at positions 0..k-1.
     */
    private static SparseVector[] prefixOnesFixture() {
        return new SparseVector[] {
            sparse(8, indices(0), 0.0),
            sparse(8, indices(0), 1.0),
            sparse(8, indices(0, 1), 1.0, 1.0),
            sparse(8, indices(0, 1, 2), 1.0, 1.0, 1.0),
            sparse(8, indices(0, 1, 2, 3), 1.0, 1.0, 1.0, 1.0)
        };
    }

    /** Fresh copy of the four length-8 indicator vectors shared by the loss-derivative and p1 tests. */
    private static SparseVector[] componentFixture() {
        return new SparseVector[] {
            sparse(8, indices(0, 4, 5), 1.0, 1.0, 1.0),
            sparse(8, indices(0, 1, 4), 1.0, 1.0, 1.0),
            sparse(8, indices(5, 6, 7), 1.0, 1.0, 1.0),
            sparse(8, indices(1, 6, 7), 1.0, 1.0, 1.0)
        };
    }

    /**
     * Appends {@code entry} to both histories; once {@code first} holds more than {@code window}
     * entries, drops the oldest entry from both.
     */
    @SuppressWarnings({"rawtypes", "unchecked"})
    private static void pushBounded(ArrayList first, ArrayList second, Object entry, int window) {
        first.add(entry);
        second.add(entry);
        if (first.size() > window) {
            first.remove(0);
            second.remove(0);
        }
    }

    /**
     * Compares {@code actual[v]} with {@code expected[v]} at every position of {@code expected[v]},
     * within 1.0E-4. Failures carry the label {@code "i:<v>;j:<position>"} when {@code labelled}
     * is true and no message otherwise.
     */
    private static void assertEntriesClose(SparseVector[] expected, SparseVector[] actual, boolean labelled) {
        for (int v = 0; v < expected.length; v++) {
            SparseVector want = expected[v];
            for (int pos = 0; pos < want.size(); pos++) {
                String label = labelled ? "i:" + v + ";j:" + pos : null;
                Assert.assertEquals(label, want.apply(pos), actual[v].apply(pos), 1.0E-4);
            }
        }
    }

    @Test
    public void testPiFunSparseVectorArraySparseVectorArray() {
        SparseVector[] expected = {
            sparse(5, indices(1, 3), 3.0, 0.0),
            sparse(5, indices(1, 3), -3.0, 0.0)
        };
        SparseVector[] first = {
            sparse(5, indices(1, 2, 3), 3.0, -3.0, -3.0),
            sparse(5, indices(1, 2, 3), -3.0, 3.0, 3.0)
        };
        SparseVector[] second = {
            sparse(5, indices(0, 1, 3), 1.0, 1.0, 1.0),
            sparse(5, indices(0, 1, 3), -1.0, -1.0, -1.0)
        };
        Assert.assertArrayEquals(expected, Formula.piFun(first, second));
    }

    @Test
    public void testPiFunSparseVectorArraySparseVectorArraySparseVectorArray() {
        SparseVector[] expected = {
            sparse(6, indices(1, 3, 5), 3.0, 0.0, 3.0),
            sparse(6, indices(1, 3), -3.0, 0.0)
        };
        SparseVector[] first = {
            sparse(6, indices(1, 2, 3, 5), 3.0, -3.0, -3.0, 3.0),
            sparse(6, indices(1, 2, 3, 5), -3.0, 3.0, 3.0, -3.0)
        };
        SparseVector[] second = {
            sparse(6, indices(0, 1, 3), 1.0, 1.0, 1.0),
            sparse(6, indices(0, 1, 3), -1.0, -1.0, -1.0)
        };
        SparseVector[] third = {
            sparse(6, indices(5), 1.0),
            sparse(6, indices(5), 1.0)
        };
        Assert.assertArrayEquals(expected, Formula.piFun(first, second, third));
    }

    @Test
    @SuppressWarnings({"rawtypes", "unchecked"})
    public void testUpdateProducts() throws Exception {
        // History length cap; the same value is passed as the second argument of updateProducts.
        final int window = 2;
        SparseVector[] fixture = prefixOnesFixture();
        ArrayList firstHistory = new ArrayList();
        ArrayList secondHistory = new ArrayList();
        ArrayList products = new ArrayList();
        for (int k = 1; k < fixture.length; k++) {
            SparseVector[] pair = {fixture[k], fixture[0]};
            pushBounded(firstHistory, secondHistory, pair, window);
            SparseVector[] doubled = {Formula.multiply(fixture[k], 2.0), fixture[0]};
            Formula.updateProducts(products, window, k, pair, pair, doubled, firstHistory, secondHistory);

            // Number of pairs left in the histories after this step (the window fills at k == 2).
            int filled = Math.min(k, window);
            int side = 2 * filled + 1;
            int base = Math.max(k - filled, 0) + 1;
            Assert.assertEquals("wrong size", side, products.size());
            for (int r = 0; r < products.size() - 1; r++) {
                List row = (List) products.get(r);
                Assert.assertEquals("wrong size", side, row.size());
                // Upper triangle without the last column: check the value and its mirror entry.
                for (int c = r; c < row.size() - 1; c++) {
                    Assert.assertEquals("k=" + k + ",i=" + r + ",j=" + c,
                            Math.min(r % filled, c % filled) + base,
                            ((Double) row.get(c)).doubleValue(), 0.001);
                    Assert.assertEquals("asymmetric", row.get(c), ((List) products.get(c)).get(r));
                }
                // The last column is checked separately: its expected value is twice as large.
                int last = row.size() - 1;
                Assert.assertEquals("k=" + k + ",i=" + r + ",j=" + last,
                        2 * (r % filled + base),
                        ((Double) row.get(last)).doubleValue(), 0.001);
                Assert.assertEquals("asymmetric", row.get(last), ((List) products.get(last)).get(r));
            }
            int corner = products.size() - 1;
            Assert.assertEquals("k=" + k + ",i=" + corner + ",j=" + corner,
                    4 * k,
                    ((Double) ((List) products.get(corner)).get(corner)).doubleValue(), 0.001);
        }
    }

    @Test
    @SuppressWarnings({"rawtypes", "unchecked"})
    public void testVlbfgs() throws Exception {
        final int window = 2;
        SparseVector[] fixture = prefixOnesFixture();
        ArrayList firstHistory = new ArrayList();
        ArrayList secondHistory = new ArrayList();
        ArrayList products = new ArrayList();
        for (int k = 1; k < fixture.length; k++) {
            SparseVector[] pair = {fixture[k], fixture[0]};
            pushBounded(firstHistory, secondHistory, pair, window);
            SparseVector[] doubled = {Formula.multiply(fixture[k], 2.0), fixture[0]};
            Formula.updateProducts(products, window, k, pair, pair, doubled, firstHistory, secondHistory);

            // vlbfgs receives both histories followed by the current doubled pair.
            ArrayList combined = new ArrayList(firstHistory);
            combined.addAll(secondHistory);
            combined.add(doubled);
            Assert.assertArrayEquals(
                    Formula.lbfgs(doubled, firstHistory, secondHistory),
                    Formula.vlbfgs(combined, products, k, window));
        }
    }

    @Test
    public void testDirection() throws Exception {
        double root11 = Math.sqrt(11.0);
        double root6 = Math.sqrt(6.0);
        Vector[] expected = {
            Vectors.dense(0.0, -1.0833333333333333, 0.0, -1.7666666666666666, 0.0, -2.7666666666666666, 0.0, 0.0),
            Vectors.dense(0.0, (root11 - 1.0) / root11, 0.0, (3.0 - 3.0 * root11) / root11,
                    (root11 - 1.0) / root11, 0.0, 0.0, 0.0),
            Vectors.dense(0.0, -2.5 - 2.0 / root6, 1.0, 0.0, 0.0,
                    -1.0 - 1.0 / root6, -1.0 - 1.0 / root6, 0.0),
            Vectors.dense(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0)
        };
        SparseVector[] first = {
            sparse(8, indices(1, 3, 5), -0.25, 0.1, 1.1),
            sparse(8, indices(1, 3, 4), -2.0, 4.0, -2.0),
            sparse(8, indices(1, 2, 7), 1.5, -2.0, 1.0),
            sparse(8, indices(1, 2, 3), -1.0, 1.0, 2.0)
        };
        SparseVector[] second = {
            sparse(8, indices(1, 3, 5), 0.5, 1.0, 1.0),
            sparse(8, indices(0), 0.0),
            sparse(8, indices(1, 5, 6), 2.0, 1.0, 1.0),
            sparse(8, indices(0), 0.0)
        };
        Assert.assertArrayEquals(expected, Formula.direction(first, second, 1.0, 1.0));
    }

    @Test
    public void testLoss_derivative() throws Exception {
        SparseVector input = sparse(8, indices(0, 1, 4, 6, 7), 500.0, 1.0, 500.0, -500.0, -500.0);
        SparseVector[] components = componentFixture();

        // With 0.0 as the second argument every returned vector is expected to be all zeros.
        SparseVector[] allZero = new SparseVector[4];
        for (int v = 0; v < allZero.length; v++) {
            allZero[v] = Vectors.zeros(8).toSparse();
        }
        SparseVector[] atZero = Formula.loss_derivative(input, 0.0, components, 2);
        assertEntriesClose(allZero, atZero, false);

        // With 1.0 the result is expected to be `input` scaled by these per-component factors.
        double e2 = Math.exp(2.0);
        SparseVector[] expectedAtOne = {
            Formula.multiply(input, 1.0 / (1.0 + e2) - 0.2689414213699951),
            Formula.multiply(input, e2 / (1.0 + e2) - 0.7310585786300049),
            Formula.multiply(input, 1.0 / (1.0 + e2)),
            Formula.multiply(input, e2 / (1.0 + e2))
        };
        SparseVector[] atOne = Formula.loss_derivative(input, 1.0, components, 2);
        assertEntriesClose(expectedAtOne, atOne, true);
    }

    @Test
    public void testP1() throws Exception {
        SparseVector negativeTail = sparse(8, indices(0, 1, 4, 6, 7), 500.0, 1.0, 500.0, -500.0, -500.0);
        SparseVector[] components = componentFixture();
        Assert.assertEquals(0.0, Formula.p1(negativeTail, components, 2), 0.001);
        SparseVector positiveTail = sparse(8, indices(0, 1, 4, 6, 7), 500.0, 1.0, 500.0, 500.0, 500.0);
        Assert.assertEquals(1.0, Formula.p1(positiveTail, components, 2), 0.001);
    }

    @Test
    public void testDotSparseVectorArraySparseVectorArray() throws Exception {
        SparseVector[] left = {
            sparse(6, indices(1, 3, 5), 0.5, 1.5, 1.5),
            sparse(6, indices(1, 3, 5), 0.5, 1.5, 1.5)
        };
        SparseVector[] right = {
            sparse(6, indices(1, 2, 3), -2.0, -1.0, 1.0),
            sparse(6, indices(1, 2, 3), -2.0, -1.0, 1.0)
        };
        // The same data laid end to end as single length-12 vectors (second half shifted by 6).
        SparseVector leftJoined = sparse(12, indices(1, 3, 5, 7, 9, 11), 0.5, 1.5, 1.5, 0.5, 1.5, 1.5);
        SparseVector rightJoined = sparse(12, indices(1, 2, 3, 7, 8, 9), -2.0, -1.0, 1.0, -2.0, -1.0, 1.0);
        Assert.assertEquals(Formula.dot(leftJoined, rightJoined), Formula.dot(left, right), 1.0E-4);
    }

    @Test
    public void testDotSparseVectorSparseVector() throws Exception {
        SparseVector a = sparse(6, indices(1, 3, 5), 0.5, 1.5, 1.5);
        SparseVector b = sparse(6, indices(1, 2, 3), -2.0, -1.0, 1.0);
        Assert.assertEquals(0.5, Formula.dot(a, b), 1.0E-4);
    }

    @Test
    public void testMultiply() {
        SparseVector expected = sparse(6, indices(1, 3, 5), 1.0, 2.0, 3.0);
        SparseVector source = sparse(6, indices(1, 3, 5), 0.5, 1.0, 1.5);
        Assert.assertEquals(expected, Formula.multiply(source, 2.0));
    }

    @Test
    public void testPlusSparseVectorArrayDoubleArray() throws Exception {
        SparseVector expected = sparse(6, indices(1, 2, 3, 5), -1.0, -3.0, -1.0, 2.0);
        SparseVector[] terms = {
            sparse(6, indices(1, 3, 5), 1.0, 1.0, 1.0),
            sparse(6, indices(1, 2, 3), -1.0, -1.0, -1.0)
        };
        Assert.assertEquals(expected, Formula.plus(terms, new double[] {2.0, 3.0}));
    }

    @Test
    @SuppressWarnings({"rawtypes", "unchecked"})
    public void testPlusListOfSparseVectorDoubleArray() throws Exception {
        ArrayList terms = new ArrayList();
        terms.add(new SparseVector[] {
            sparse(6, indices(1, 3, 5), 1.0, 1.0, 1.0),
            sparse(6, indices(1, 3, 5), 1.0, 1.0, 1.0)
        });
        terms.add(new SparseVector[] {
            sparse(6, indices(1, 2, 3), -1.0, -1.0, -1.0),
            sparse(6, indices(1, 2, 3), -1.0, -1.0, -1.0)
        });
        SparseVector[] expected = {
            sparse(6, indices(1, 2, 3, 5), -1.0, -3.0, -1.0, 2.0),
            sparse(6, indices(1, 2, 3, 5), -1.0, -3.0, -1.0, 2.0)
        };
        Assert.assertArrayEquals(expected, Formula.plus(terms, new double[] {2.0, 3.0}));
    }

    @Test
    public void testPlusSparseVectorDoubleSparseVectorDouble() throws Exception {
        SparseVector expected = sparse(6, indices(1, 2, 3, 5), -1.0, -3.0, -1.0, 2.0);
        SparseVector a = sparse(6, indices(1, 3, 5), 1.0, 1.0, 1.0);
        SparseVector b = sparse(6, indices(1, 2, 3), -1.0, -1.0, -1.0);
        Assert.assertEquals(expected, Formula.plus(a, 2.0, b, 3.0));
    }

    @Test
    public void testNorm2() {
        SparseVector unitEntries = sparse(6, indices(1, 2, 3, 5), 1.0, -1.0, -1.0, 1.0);
        Assert.assertEquals(2.0, Formula.norm2(unitEntries), 1.0E-4);
    }
}
