package cn.com.pconline.adclick.mixexperts;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Date;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;
import org.apache.log4j.Logger;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.Function2;
import org.apache.spark.mllib.linalg.SparseVector;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.mllib.regression.LabeledPoint;

/* loaded from: input_file:cn/com/pconline/adclick/mixexperts/Formula.class */
/**
 * Static numeric helpers for training a mixture-of-experts binary classifier over Spark RDDs.
 *
 * <p>A model {@code theta} is a {@code SparseVector[]} with {@code 2 * m} rows. Rows {@code [0, m)}
 * are combined through a softmax (gating weights) and rows {@code [m, 2m)} through a logistic sigmoid
 * (expert weights); see {@link #p1}, {@link #logloss(SparseVector, double, SparseVector[], int)} and
 * {@link #loss_derivative(SparseVector, double, SparseVector[], int)}. The regularized objective is
 * {@code loss + lambda21 * norm21(theta) + lambda1 * norm1(theta)} ({@link #targetFun}).
 *
 * <p>All "is (near) zero" and "same sign" tests use the fixed tolerance {@code 0.001}.
 */
public class Formula {
    // NOTE(review): registered under MixExpertsTrainer rather than Formula; kept so that existing
    // log4j configuration for that category keeps capturing these messages. Confirm intent.
    private static final Logger LOG = Logger.getLogger(MixExpertsTrainer.class);

    /** Tolerance for every "is (near) zero" / "same sign" comparison in this class. */
    private static final double EPS = 0.001d;

    /**
     * Projects each row of {@code x} onto the orthant given by {@code orthant}.
     *
     * <p>An active entry of {@code x[i]} is kept only when its sign equals the sign of the entry at
     * the same position in {@code orthant[i]}. Every other entry is dropped (it becomes an implicit
     * zero).
     *
     * @param x rows to project
     * @param orthant reference rows supplying the required signs, aligned with {@code x}
     * @return projected rows, each with the same size as the corresponding row of {@code x}
     */
    public static SparseVector[] piFun(SparseVector[] x, SparseVector[] orthant) {
        SparseVector[] projected = new SparseVector[x.length];
        for (int row = 0; row < x.length; row++) {
            int[] indices = x[row].indices();
            double[] values = x[row].values();
            List<Integer> kept = new ArrayList<>();
            for (int k = 0; k < values.length; k++) {
                if (sameSign(values[k], orthant[row].apply(indices[k]))) {
                    kept.add(k);
                }
            }
            projected[row] = keepEntries(x[row], kept);
        }
        return projected;
    }

    /**
     * Orthant projection where the orthant comes from {@code theta} and falls back to
     * {@code fallback} wherever {@code theta} is (near) zero.
     *
     * <p>An active entry of {@code x[i]} at position {@code j} is kept when its sign equals the sign
     * of {@code theta[i][j]}. If {@code |theta[i][j]| < 0.001}, the sign of {@code fallback[i][j]}
     * is used instead. Other entries are dropped.
     *
     * @param x rows to project
     * @param theta primary sign reference, aligned with {@code x}
     * @param fallback sign reference used where {@code theta} is (near) zero, aligned with {@code x}
     * @return projected rows, each with the same size as the corresponding row of {@code x}
     */
    public static SparseVector[] piFun(SparseVector[] x, SparseVector[] theta, SparseVector[] fallback) {
        SparseVector[] projected = new SparseVector[x.length];
        for (int row = 0; row < x.length; row++) {
            int[] indices = x[row].indices();
            double[] values = x[row].values();
            List<Integer> kept = new ArrayList<>();
            for (int k = 0; k < values.length; k++) {
                double t = theta[row].apply(indices[k]);
                // A (near) zero theta carries no sign, so take the orthant from the fallback instead.
                double reference = Math.abs(t) < EPS ? fallback[row].apply(indices[k]) : t;
                if (sameSign(values[k], reference)) {
                    kept.add(k);
                }
            }
            projected[row] = keepEntries(x[row], kept);
        }
        return projected;
    }

    /** Returns true when {@code a} and {@code b} have the same {@link Math#signum} (false if either is NaN). */
    private static boolean sameSign(double a, double b) {
        return Math.abs(Math.signum(a) - Math.signum(b)) < EPS;
    }

    /** Returns a vector of {@code v}'s size holding only the active entries at the given storage positions. */
    private static SparseVector keepEntries(SparseVector v, List<Integer> positions) {
        int[] indices = v.indices();
        double[] values = v.values();
        int[] keptIndices = new int[positions.size()];
        double[] keptValues = new double[positions.size()];
        for (int n = 0; n < positions.size(); n++) {
            int p = positions.get(n);
            keptIndices[n] = indices[p];
            keptValues[n] = values[p];
        }
        return new SparseVector(v.size(), keptIndices, keptValues);
    }

    /**
     * Backtracking (Armijo) line search along {@code direction}.
     *
     * <p>Builds the candidates {@code theta + alpha_k * direction} with
     * {@code alpha_k = initAlpha * gamma^(k-1)} for {@code k = 1..maxSteps}. It evaluates the loss of
     * {@code theta} and of every candidate in a single pass over {@code data}. It returns the first
     * candidate whose objective satisfies {@code F(candidate) <= F(theta) + alpha_k * c * D}, where
     * {@code D} is {@link #directionalDerivative} at {@code theta}. If none does, the last candidate is
     * returned ({@code theta} itself when {@code maxSteps == 0}).
     *
     * @param data training examples
     * @param theta current parameters
     * @param direction search direction
     * @param gradient loss gradient at {@code theta}
     * @param lambda21 coefficient of the L2,1 regularizer
     * @param lambda1 coefficient of the L1 regularizer
     * @param maxSteps number of candidate step sizes to try
     * @param initAlpha first step size; must be greater than 0
     * @param gamma step shrink factor; must be in (0, 1)
     * @param c sufficient-decrease constant
     * @param count normalizer forwarded to {@link #logloss(JavaRDD, SparseVector[][], long)}
     *     (presumably the number of examples)
     * @return the accepted parameters
     * @throws Exception if {@code gamma} or {@code initAlpha} is out of range, or on size mismatches
     */
    public static SparseVector[] backtrackingLineSearch(JavaRDD<LabeledPoint> data, SparseVector[] theta, SparseVector[] direction, SparseVector[] gradient, double lambda21, double lambda1, int maxSteps, double initAlpha, double gamma, double c, long count) throws Exception {
        if (gamma >= 1.0d || gamma <= 0.0d) {
            throw new Exception("backtrackingLineSearch gamma must be in (0,1)");
        }
        if (initAlpha <= 0.0d) {
            throw new Exception("backtrackingLineSearch initAlpha must be greater than 0");
        }
        double slope = c * directionalDerivative(direction, theta, gradient, lambda21, lambda1);

        // candidates[0] is theta itself; candidates[k] uses step initAlpha * gamma^(k-1).
        SparseVector[][] candidates = new SparseVector[maxSteps + 1][];
        candidates[0] = theta;
        double alpha = initAlpha;
        for (int k = 1; k <= maxSteps; k++) {
            candidates[k] = new SparseVector[theta.length];
            for (int r = 0; r < theta.length; r++) {
                candidates[k][r] = plus(theta[r], 1.0d, direction[r], alpha);
            }
            alpha *= gamma;
        }

        // A single pass over the data evaluates the loss of every candidate at once.
        Double[] losses = logloss(data, candidates, count);
        System.out.println("backtracking linesearch loglosslist--- " + Arrays.toString(losses));
        System.out.println("backtracking linesearch last loss-- " + losses[0]);
        double baseline = targetFun(theta, losses[0].doubleValue(), lambda21, lambda1);
        alpha = initAlpha;
        for (int k = 1; k <= maxSteps; k++) {
            if (targetFun(candidates[k], losses[k].doubleValue(), lambda21, lambda1) <= baseline + (alpha * slope)) {
                LOG.debug("logoss  [" + k + "]----" + losses[k]);
                LOG.debug("backtracking linesearch satisfied loss [" + k + "]--- " + losses[k]);
                return candidates[k];
            }
            alpha *= gamma;
        }
        LOG.debug("logoss [" + maxSteps + "]----" + losses[maxSteps]);
        return candidates[maxSteps];
    }

    /**
     * Computes the regularized objective {@code loss + lambda21 * norm21(theta) + lambda1 * norm1(theta)}.
     *
     * @param theta parameters
     * @param loss data loss of {@code theta}
     * @param lambda21 coefficient of the column-wise L2,1 norm
     * @param lambda1 coefficient of the L1 norm
     * @return objective value
     */
    public static double targetFun(SparseVector[] theta, double loss, double lambda21, double lambda1) {
        return loss + (lambda21 * norm21(theta)) + (lambda1 * norm1(theta));
    }

    /**
     * Computes the directional derivative of the regularized objective at {@code theta} along
     * {@code direction}.
     *
     * <p>The result is {@code direction . gradient} plus two regularizer terms:
     * <ul>
     *   <li>L1 term, per entry: {@code lambda1 * sign(theta_ij) * p_ij} when
     *       {@code |theta_ij| > 0.001}, else {@code lambda1 * |p_ij|}.</li>
     *   <li>L2,1 term, per column j: {@code lambda21 * (theta_j . p_j) / ||theta_j||} when
     *       {@code ||theta_j||^2 > 0.001}, else {@code lambda21 * ||p_j||}.</li>
     * </ul>
     * Suspicious intermediate values are printed to stdout for diagnosis.
     *
     * @param direction search direction p
     * @param theta current parameters
     * @param gradient loss gradient at {@code theta}
     * @param lambda21 coefficient of the L2,1 regularizer
     * @param lambda1 coefficient of the L1 regularizer
     * @return directional derivative
     * @throws Exception if row counts or sizes of {@code direction} and {@code gradient} differ
     */
    public static double directionalDerivative(SparseVector[] direction, SparseVector[] theta, SparseVector[] gradient, double lambda21, double lambda1) throws Exception {
        double result = dot(direction, gradient);
        if (result < -1.0d || result == 0.0d || Double.isNaN(result)) {
            System.out.println("directionalDerivative dot:" + result);
            System.out.println("directionalDerivative p:" + Arrays.toString(direction));
        }
        int size = direction[0].size();
        for (int j = 0; j < size; j++) {
            double thetaNormSq = 0.0d;
            double pNormSq = 0.0d;
            double thetaDotP = 0.0d;
            for (int r = 0; r < theta.length; r++) {
                double t = theta[r].apply(j);
                double p = direction[r].apply(j);
                thetaNormSq += Math.pow(t, 2.0d);
                pNormSq += Math.pow(p, 2.0d);
                // Column inner product theta_j . p_j must accumulate over every row.
                thetaDotP += t * p;
                double l1Term = Math.abs(t) > EPS ? lambda1 * Math.signum(t) * p : lambda1 * Math.abs(p);
                result += l1Term;
                if (l1Term < -1.0d || Double.isNaN(l1Term)) {
                    System.out.println("directionalDerivative p_" + j + "_" + r + ":" + p + ";theta_ij:" + t);
                }
            }
            double l21Term = Math.abs(thetaNormSq) > EPS ? (lambda21 * thetaDotP) / Math.sqrt(thetaNormSq) : lambda21 * Math.sqrt(pNormSq);
            result += l21Term;
            if (l21Term < -1.0d || Double.isNaN(l21Term)) {
                System.out.println("directionalDerivative theta_p:" + thetaDotP + ";theta_norm:" + thetaNormSq + ";p_norm:" + pNormSq);
            }
        }
        return result;
    }

    /**
     * Computes an element-wise (sub)gradient of the regularized objective, returned densified as
     * sparse rows.
     *
     * <p>Where {@code |theta_ij| > 0.001} the value is
     * {@code g_ij + 2 * lambda2 * theta_ij + lambda1 * sign(theta_ij)}. Otherwise it is
     * {@code -max(|v| - lambda1, 0) * sign(v)} with {@code v = -g_ij - 2 * lambda2 * theta_ij}
     * (soft-thresholding).
     *
     * @param theta parameters
     * @param lossGradient loss gradient g, aligned with {@code theta}
     * @param lambda2 coefficient multiplying {@code 2 * theta}
     * @param lambda1 L1 coefficient
     * @return per-row (sub)gradient
     * @throws Exception declared for signature compatibility
     */
    public static SparseVector[] targetFun_derivative(SparseVector[] theta, SparseVector[] lossGradient, double lambda2, double lambda1) throws Exception {
        SparseVector[] result = new SparseVector[theta.length];
        for (int i = 0; i < theta.length; i++) {
            int size = theta[i].size();
            double[] values = new double[size];
            for (int j = 0; j < size; j++) {
                double t = theta[i].apply(j);
                double g = lossGradient[i].apply(j);
                if (Math.abs(t) > EPS) {
                    values[j] = g + (2.0d * lambda2 * t) + (lambda1 * Math.signum(t));
                } else {
                    double v = (-g) - ((2.0d * lambda2) * t);
                    values[j] = (-Math.max(Math.abs(v) - lambda1, 0.0d)) * Math.signum(v);
                }
            }
            result[i] = Vectors.dense(values).toSparse();
        }
        return result;
    }

    /**
     * Updates, in place, the matrix of pairwise dot products used by {@link #vlbfgs}.
     *
     * <p>Rows and columns are laid out as {@code [s_0..s_{n-1}, y_0..y_{n-1}, g]}, where
     * {@code n = min(m, k)}. The new s, y and g are dotted against every vector of {@code sHistory},
     * then {@code yHistory}, then {@code g}, and inserted as the last s row/column, the last y
     * row/column and the g row/column. Once {@code k > m}, the oldest s and y and the stale g are
     * evicted first; otherwise only the stale g row/column (if any) is replaced.
     *
     * <p>NOTE(review): the resulting row lengths are consistent only if {@code sHistory} and
     * {@code yHistory} already contain the new {@code s} and {@code y} (n entries each). Confirm
     * against the caller.
     *
     * @param products dot-product matrix, modified in place
     * @param m maximum history size
     * @param k current iteration count (history size when {@code k <= m})
     * @param s newest parameter difference
     * @param y newest gradient difference
     * @param g current gradient
     * @param sHistory s vectors, in basis order
     * @param yHistory y vectors, in basis order
     * @throws Exception on row-count or size mismatches in {@link #dot(SparseVector[], SparseVector[])}
     */
    public static void updateProducts(List<List<Double>> products, int m, int k, SparseVector[] s, SparseVector[] y, SparseVector[] g, List<SparseVector[]> sHistory, List<SparseVector[]> yHistory) throws Exception {
        // Dot products of the three new vectors with every basis vector, in basis order.
        List<Double> sRow = new ArrayList<>();
        List<Double> yRow = new ArrayList<>();
        List<Double> gRow = new ArrayList<>();
        for (SparseVector[] b : sHistory) {
            sRow.add(dot(s, b));
            yRow.add(dot(y, b));
            gRow.add(dot(g, b));
        }
        for (SparseVector[] b : yHistory) {
            sRow.add(dot(s, b));
            yRow.add(dot(y, b));
            gRow.add(dot(g, b));
        }
        sRow.add(dot(s, g));
        yRow.add(dot(y, g));
        gRow.add(dot(g, g));
        if (k > m) {
            // History full: evict oldest s, oldest y and stale g (as rows and as columns), then append the new ones.
            removeOldestAndGradient(products, m);
            insertNewest(products, m, sRow, yRow, gRow);
            for (int r = 0; r < m - 1; r++) {
                List<Double> oldSRow = products.get(r);
                removeOldestAndGradient(oldSRow, m);
                insertNewest(oldSRow, m, sRow.get(r), yRow.get(r), gRow.get(r));
                List<Double> oldYRow = products.get(m + r);
                removeOldestAndGradient(oldYRow, m);
                insertNewest(oldYRow, m, sRow.get(m + r), yRow.get(m + r), gRow.get(m + r));
            }
        } else {
            // History still growing: only the stale g row/column (present once k > 1) is replaced.
            if (k > 1) {
                products.remove((k - 1) * 2);
            }
            insertNewest(products, k, sRow, yRow, gRow);
            for (int r = 0; r < k - 1; r++) {
                List<Double> oldSRow = products.get(r);
                oldSRow.remove((k - 1) * 2);
                insertNewest(oldSRow, k, sRow.get(r), yRow.get(r), gRow.get(r));
                List<Double> oldYRow = products.get(k + r);
                oldYRow.remove((k - 1) * 2);
                insertNewest(oldYRow, k, sRow.get(k + r), yRow.get(k + r), gRow.get(k + r));
            }
        }
        List<Integer> rowSizes = new ArrayList<>();
        for (List<Double> row : products) {
            rowSizes.add(row.size());
        }
        LOG.debug("products row size----" + rowSizes.toString());
    }

    /** Removes, by index, the g entry (2m), the oldest y entry (m) and the oldest s entry (0), in that order. */
    private static <T> void removeOldestAndGradient(List<T> list, int m) {
        list.remove(m * 2);
        list.remove(m);
        list.remove(0);
    }

    /** Inserts new entries at the end of the s section (n-1), the end of the y section (2n-1) and the g slot (2n). */
    private static <T> void insertNewest(List<T> list, int n, T sEntry, T yEntry, T gEntry) {
        list.add(n - 1, sEntry);
        list.add((n * 2) - 1, yEntry);
        list.add(n * 2, gEntry);
    }

    /**
     * Vector-free L-BFGS: performs the two-loop recursion using only the dot-product matrix
     * {@code products} among the basis vectors, then materializes {@code sum_j delta_j * basis_j}.
     *
     * <p>The basis is assumed to be laid out as {@code [s_0..s_{n-1}, y_0..y_{n-1}, g]} with
     * {@code n = min(m, k)} and {@code s_{n-1}, y_{n-1}} the newest pair (the layout built by
     * {@link #updateProducts}). The recursion starts from {@code +g}, so the result approximates
     * {@code H * g}, not its negation.
     *
     * @param basis basis vectors, {@code 2n + 1} of them
     * @param products pairwise dot products of {@code basis}
     * @param m maximum history size
     * @param k current iteration count
     * @return approximation of {@code H * g}
     * @throws Exception if {@code products} is too small (also logged) or on size mismatches
     */
    public static SparseVector[] vlbfgs(List<SparseVector[]> basis, List<List<Double>> products, int m, int k) throws Exception {
        int n = Math.min(m, k);
        double[] delta = new double[basis.size()];
        // Start from the gradient, which is the last basis vector.
        delta[basis.size() - 1] = 1.0d;
        double[] alpha = new double[n];
        for (int i = n - 1; i >= 0; i--) {
            List<Double> sRow = products.get(i);
            double sDotQ = 0.0d;
            for (int j = 0; j < delta.length; j++) {
                try {
                    sDotQ += delta[j] * sRow.get(j).doubleValue();
                } catch (IndexOutOfBoundsException e) {
                    LOG.error("vlbfgs IndexOutOfBoundsException i:" + i + ";j:" + j + ";products:" + products.size() + ";products_row:" + sRow.size());
                    throw e;
                }
            }
            alpha[i] = sDotQ / sRow.get(n + i).doubleValue();
            delta[n + i] -= alpha[i];
        }
        // Initial Hessian scaling s_{n-1}.y_{n-1} / y_{n-1}.y_{n-1} from the newest pair
        // (0-based rows/columns n-1 and 2n-1), the same scaling lbfgs uses.
        int newestS = n - 1;
        int newestY = (2 * n) - 1;
        double scale = products.get(newestS).get(newestY).doubleValue() / products.get(newestY).get(newestY).doubleValue();
        for (int j = 0; j < delta.length; j++) {
            delta[j] = scale * delta[j];
        }
        for (int i = 0; i < n; i++) {
            List<Double> yRow = products.get(n + i);
            double yDotR = 0.0d;
            for (int j = 0; j < delta.length; j++) {
                yDotR += delta[j] * yRow.get(j).doubleValue();
            }
            delta[i] = (delta[i] + alpha[i]) - (yDotR / yRow.get(i).doubleValue());
        }
        return plus(basis, delta);
    }

    /**
     * Classic L-BFGS two-loop recursion: returns an approximation of {@code H * gradient}, where H is
     * the inverse-Hessian approximation built from the (s, y) history.
     *
     * @param gradient vector to multiply
     * @param sHistory parameter differences; assumed oldest first, and the last element is used for the
     *     initial scaling
     * @param yHistory gradient differences, aligned with {@code sHistory}
     * @return approximation of {@code H * gradient}
     * @throws Exception on size mismatches
     */
    public static SparseVector[] lbfgs(SparseVector[] gradient, List<SparseVector[]> sHistory, List<SparseVector[]> yHistory) throws Exception {
        int n = sHistory.size();
        double[] alpha = new double[n];
        SparseVector[] q = gradient;
        for (int i = n - 1; i >= 0; i--) {
            alpha[i] = dot(sHistory.get(i), q) / dot(yHistory.get(i), sHistory.get(i));
            // q <- q - alpha_i * y_i
            q = plus(q, 1.0d, yHistory.get(i), -alpha[i]);
        }
        SparseVector[] sLast = sHistory.get(n - 1);
        SparseVector[] yLast = yHistory.get(yHistory.size() - 1);
        // Scale q by gamma = s.y / y.y of the newest pair (the zero-weighted second operand adds nothing).
        SparseVector[] r = plus(q, dot(sLast, yLast) / dot(yLast, yLast), q, 0.0d);
        for (int i = 0; i < n; i++) {
            r = plus(r, 1.0d, sHistory.get(i), alpha[i] - (dot(yHistory.get(i), r) / dot(yHistory.get(i), sHistory.get(i))));
        }
        return r;
    }

    /**
     * Computes a descent direction for the L2,1 + L1 regularized objective, row by row and entry by
     * entry.
     *
     * <p>With g the loss gradient and {@code ||theta_j||} the L2 norm of column j of {@code theta}:
     * <ul>
     *   <li>if {@code |theta_ij| > 0.001}:
     *       {@code -g_ij - lambda21 * theta_ij / ||theta_j|| - lambda1 * sign(theta_ij)};</li>
     *   <li>else if {@code ||theta_j|| > 0.001}: soft-threshold {@code v = -g_ij - lambda21 *
     *       theta_ij / ||theta_j||} by {@code lambda1};</li>
     *   <li>else: group soft-threshold,
     *       {@code max(0, 1 - lambda21 / ||v_j||) * max(0, |g_ij| - lambda1) * sign(-g_ij)} where
     *       {@code v_j} holds the per-row {@code max(0, |g_rj| - lambda1)}.</li>
     * </ul>
     * Infinite entries are reported on stdout.
     *
     * @param lossGradient loss gradient, aligned with {@code theta}
     * @param theta current parameters
     * @param lambda21 coefficient of the L2,1 regularizer
     * @param lambda1 coefficient of the L1 regularizer
     * @return direction rows
     * @throws Exception declared for signature compatibility
     */
    public static SparseVector[] direction(SparseVector[] lossGradient, SparseVector[] theta, double lambda21, double lambda1) throws Exception {
        int width = 0;
        for (SparseVector row : theta) {
            width = Math.max(width, row.size());
        }
        // Column norms of theta are the same for every row, so compute them once.
        double[] thetaNorms = columnNorms(theta, width);
        SparseVector[] result = new SparseVector[theta.length];
        for (int i = 0; i < theta.length; i++) {
            int size = theta[i].size();
            double[] values = new double[size];
            for (int j = 0; j < size; j++) {
                double thetaNorm = thetaNorms[j];
                double t = theta[i].apply(j);
                double g = lossGradient[i].apply(j);
                if (Math.abs(t) > EPS) {
                    values[j] = ((-g) - ((lambda21 * t) / thetaNorm)) - (lambda1 * Math.signum(t));
                    if (Double.isInfinite(values[j])) {
                        System.out.println("direction i:" + i + ";j:" + j + " thetai!=0 loss_derivative:" + g + ";thetai:" + t + ";theta_norm:" + thetaNorm);
                    }
                } else if (Math.abs(thetaNorm) > EPS) {
                    double v = (-g) - ((lambda21 * t) / thetaNorm);
                    values[j] = Math.max(Math.abs(v) - lambda1, 0.0d) * Math.signum(v);
                    if (Double.isInfinite(values[j])) {
                        System.out.println("direction i:" + i + ";j:" + j + " theta_normi!=0 loss_derivative:" + g + ";thetai:" + t + ";theta_norm:" + thetaNorm);
                    }
                } else {
                    double vNormSq = 0.0d;
                    for (int r = 0; r < theta.length; r++) {
                        vNormSq += Math.pow(Math.max(0.0d, Math.abs(lossGradient[r].apply(j)) - lambda1), 2.0d);
                    }
                    double vNorm = Math.sqrt(vNormSq);
                    values[j] = Math.max(0.0d, 1.0d - (lambda21 / vNorm)) * Math.max(0.0d, Math.abs(g) - lambda1) * Math.signum(-g);
                    if (Double.isInfinite(values[j])) {
                        System.out.println("direction i:" + i + ";j:" + j + " theta_normi==0 loss_derivative:" + g + ";v_norm:" + vNorm);
                    }
                }
            }
            result[i] = Vectors.dense(values).toSparse();
        }
        return result;
    }

    /** Returns the Euclidean norm of each column {@code j < width} across all rows (active entries only). */
    private static double[] columnNorms(SparseVector[] rows, int width) {
        double[] norms = new double[width];
        for (SparseVector row : rows) {
            int[] indices = row.indices();
            double[] values = row.values();
            for (int k = 0; k < indices.length; k++) {
                norms[indices[k]] += Math.pow(values[k], 2.0d);
            }
        }
        for (int j = 0; j < width; j++) {
            norms[j] = Math.sqrt(norms[j]);
        }
        return norms;
    }

    /**
     * Computes the gradient of the mean negative log-likelihood over {@code data}.
     *
     * <p>It sums the per-example log-likelihood gradients (see
     * {@link #loss_derivative(SparseVector, double, SparseVector[], int)}) across the RDD, then
     * scales them by {@code -1 / count}.
     *
     * @param data training examples; features are expected to be {@link SparseVector}s
     * @param theta parameters with an even number of rows (gating rows then expert rows)
     * @param count normalizer (presumably the number of examples)
     * @return gradient rows, aligned with {@code theta}
     * @throws Exception if {@code theta} has an odd number of rows, or on size mismatches
     */
    public static SparseVector[] loss_derivative(JavaRDD<LabeledPoint> data, final SparseVector[] theta, long count) throws Exception {
        if (theta.length % 2 == 1) {
            throw new Exception("loss_derivative The number of theta's rows is non-even");
        }
        final int m = theta.length / 2;
        SparseVector[] total = data.map(new Function<LabeledPoint, SparseVector[]>() { // from class: cn.com.pconline.adclick.mixexperts.Formula.1
            private static final long serialVersionUID = 973245730135457112L;

            @Override
            public SparseVector[] call(LabeledPoint point) throws Exception {
                return Formula.loss_derivative((SparseVector) point.features(), point.label(), theta, m);
            }
        }).reduce(new Function2<SparseVector[], SparseVector[], SparseVector[]>() { // from class: cn.com.pconline.adclick.mixexperts.Formula.2
            private static final long serialVersionUID = 1023698359150881832L;

            @Override
            public SparseVector[] call(SparseVector[] left, SparseVector[] right) throws Exception {
                SparseVector[] sum = new SparseVector[left.length];
                for (int i = 0; i < left.length; i++) {
                    sum[i] = Formula.plus(left[i], 1.0d, right[i], 1.0d);
                }
                return sum;
            }
        });
        SparseVector[] gradient = new SparseVector[total.length];
        for (int i = 0; i < total.length; i++) {
            gradient[i] = multiply(total[i], (-1.0d) / count);
        }
        return gradient;
    }

    /**
     * Computes the per-example gradient of the log-likelihood (not negated) with respect to every
     * row of {@code theta}.
     *
     * <p>The label is mapped to {@code y = -1} when it is (near) 0 and {@code y = +1} otherwise. With
     * gate scores {@code a_i = x . theta_i}, expert margins {@code z_i = y * x . theta_{m+i}}, gate
     * probabilities {@code softmax(a)_i} and responsibilities
     * {@code h_i ∝ exp(a_i) * sigmoid(z_i)}, the result is:
     * <ul>
     *   <li>row i: {@code (h_i - softmax(a)_i) * x};</li>
     *   <li>row m+i: {@code y * h_i / (1 + exp(z_i)) * x}.</li>
     * </ul>
     * Max-subtraction keeps the exponentials from overflowing.
     *
     * @param features example features x
     * @param label example label
     * @param theta parameters with {@code 2 * m} rows
     * @param m number of mixture components
     * @return {@code 2 * m} gradient rows
     * @throws Exception on size mismatches
     */
    public static SparseVector[] loss_derivative(SparseVector features, double label, SparseVector[] theta, int m) throws Exception {
        double y = Math.abs(label) < EPS ? -1.0d : 1.0d;
        double[] gate = new double[m];
        double[] margin = new double[m];
        double[] negAbsMargin = new double[m];
        double[] logJoint = new double[m];
        double maxGate = -Double.MAX_VALUE;
        double maxJoint = -Double.MAX_VALUE;
        for (int i = 0; i < m; i++) {
            gate[i] = dot(features, theta[i]);
            margin[i] = y * dot(features, theta[m + i]);
            if (gate[i] > maxGate) {
                maxGate = gate[i];
            }
            // Stable split: log(exp(gate) * sigmoid(margin)) = logJoint - log(1 + exp(negAbsMargin)).
            if (margin[i] > 0.0d) {
                logJoint[i] = gate[i];
                negAbsMargin[i] = -margin[i];
            } else {
                logJoint[i] = gate[i] + margin[i];
                negAbsMargin[i] = margin[i];
            }
            if (logJoint[i] > maxJoint) {
                maxJoint = logJoint[i];
            }
        }
        double[] gateWeight = new double[m];
        double[] jointWeight = new double[m];
        double gateSum = 0.0d;
        double jointSum = 0.0d;
        for (int i = 0; i < m; i++) {
            gateWeight[i] = Math.exp(gate[i] - maxGate);
            gateSum += gateWeight[i];
            jointWeight[i] = Math.exp(logJoint[i] - maxJoint) / (1.0d + Math.exp(negAbsMargin[i]));
            jointSum += jointWeight[i];
        }
        SparseVector[] gradient = new SparseVector[2 * m];
        for (int i = 0; i < m; i++) {
            double responsibility = jointWeight[i] / jointSum;
            double gateCoef = responsibility - (gateWeight[i] / gateSum);
            double expertCoef = y * responsibility * (1.0d / (1.0d + Math.exp(margin[i])));
            gradient[i] = multiply(features, gateCoef);
            gradient[m + i] = multiply(features, expertCoef);
        }
        return gradient;
    }

    /**
     * Computes {@code sum_i softmax(x . theta_[0..m))_i * sigmoid(x . theta_{m+i})}. This is the
     * model probability of a non-zero label, consistent with the conventions of {@code logloss}.
     *
     * @param features example features x
     * @param theta parameters with {@code 2 * m} rows
     * @param m number of mixture components
     * @return mixture probability
     * @throws Exception on size mismatches
     */
    public static double p1(SparseVector features, SparseVector[] theta, int m) throws Exception {
        double[] gate = new double[m];
        double maxGate = -Double.MAX_VALUE;
        for (int i = 0; i < m; i++) {
            gate[i] = dot(features, theta[i]);
            if (gate[i] > maxGate) {
                maxGate = gate[i];
            }
        }
        double[] gateWeight = new double[m];
        double gateSum = 0.0d;
        for (int i = 0; i < m; i++) {
            gateWeight[i] = Math.exp(gate[i] - maxGate);
            gateSum += gateWeight[i];
        }
        double probability = 0.0d;
        for (int i = 0; i < m; i++) {
            probability += (gateWeight[i] / gateSum) / (Math.exp(-dot(features, theta[m + i])) + 1.0d);
        }
        return probability;
    }

    /**
     * Evaluates the mean log loss of several parameter sets in one pass over {@code data}.
     *
     * @param data training examples; features are expected to be {@link SparseVector}s
     * @param thetas parameter sets, all with the same {@code 2 * m} rows
     * @param count normalizer (presumably the number of examples)
     * @return one summed-then-divided loss per entry of {@code thetas}
     * @throws Exception on size mismatches
     */
    public static Double[] logloss(JavaRDD<LabeledPoint> data, final SparseVector[][] thetas, long count) throws Exception {
        final int m = thetas[0].length / 2;
        Double[] totals = data.map(new Function<LabeledPoint, Double[]>() { // from class: cn.com.pconline.adclick.mixexperts.Formula.3
            private static final long serialVersionUID = 7641692403240750543L;

            @Override
            public Double[] call(LabeledPoint point) throws Exception {
                Double[] losses = new Double[thetas.length];
                SparseVector features = (SparseVector) point.features();
                double label = point.label();
                for (int i = 0; i < thetas.length; i++) {
                    losses[i] = Double.valueOf(Formula.logloss(features, label, thetas[i], m));
                }
                return losses;
            }
        }).reduce(new Function2<Double[], Double[], Double[]>() { // from class: cn.com.pconline.adclick.mixexperts.Formula.4
            private static final long serialVersionUID = 6048287195955865496L;

            @Override
            public Double[] call(Double[] left, Double[] right) throws Exception {
                Double[] sum = new Double[left.length];
                for (int i = 0; i < left.length; i++) {
                    sum[i] = Double.valueOf(left[i].doubleValue() + right[i].doubleValue());
                }
                return sum;
            }
        });
        System.out.println("logloss reduce finish " + new Date().toString());
        for (int i = 0; i < totals.length; i++) {
            totals[i] = Double.valueOf(totals[i].doubleValue() / count);
        }
        return totals;
    }

    /**
     * Computes the per-example negative log-likelihood of the mixture model.
     *
     * <p>With {@code sign = +1} for a (near) zero label and {@code -1} otherwise, the result is
     * {@code -log(sum_i softmax(x . theta_[0..m))_i * sigmoid(-sign * x . theta_{m+i}))}. It is
     * evaluated in a max-shifted form so the exponentials do not overflow.
     *
     * @param features example features x
     * @param label example label
     * @param theta parameters with {@code 2 * m} rows
     * @param m number of mixture components
     * @return negative log-likelihood
     * @throws Exception on size mismatches
     */
    public static double logloss(SparseVector features, double label, SparseVector[] theta, int m) throws Exception {
        double sign = Math.abs(label) < EPS ? 1.0d : -1.0d;
        double[] gate = new double[m];
        double[] softplusArg = new double[m];
        double[] logJoint = new double[m];
        double maxGate = -Double.MAX_VALUE;
        double maxJoint = -Double.MAX_VALUE;
        for (int i = 0; i < m; i++) {
            gate[i] = dot(features, theta[i]);
            double z = sign * dot(features, theta[m + i]);
            if (z > 0.0d) {
                logJoint[i] = gate[i] - z;
                softplusArg[i] = -z;
            } else {
                logJoint[i] = gate[i];
                softplusArg[i] = z;
            }
            if (gate[i] > maxGate) {
                maxGate = gate[i];
            }
            if (logJoint[i] > maxJoint) {
                maxJoint = logJoint[i];
            }
        }
        double jointSum = 0.0d;
        double gateSum = 0.0d;
        for (int i = 0; i < m; i++) {
            jointSum += Math.exp(logJoint[i] - maxJoint) / (1.0d + Math.exp(softplusArg[i]));
            gateSum += Math.exp(gate[i] - maxGate);
        }
        return -(((Math.log(jointSum) - Math.log(gateSum)) + maxJoint) - maxGate);
    }

    /**
     * Computes the sum of row-wise dot products of two equally long arrays of vectors.
     *
     * @throws Exception if the arrays differ in length or paired rows differ in size
     */
    public static double dot(SparseVector[] a, SparseVector[] b) throws Exception {
        if (a.length != b.length) {
            throw new Exception("dot The length of a is not equal to the size of b");
        }
        double sum = 0.0d;
        for (int i = 0; i < a.length; i++) {
            sum += dot(a[i], b[i]);
        }
        return sum;
    }

    /**
     * Computes the dot product of two vectors of equal size, iterating over the active entries of
     * {@code a}.
     *
     * @throws Exception if the sizes differ
     */
    public static double dot(SparseVector a, SparseVector b) throws Exception {
        if (a.size() != b.size()) {
            throw new Exception("dot The size of a(" + a.size() + ") is not equal to the size of b(" + b.size() + ")");
        }
        double sum = 0.0d;
        int[] indices = a.indices();
        double[] values = a.values();
        for (int k = 0; k < indices.length; k++) {
            sum += values[k] * b.apply(indices[k]);
        }
        return sum;
    }

    /**
     * Computes the dot product of {@code shortVec} with the slice of {@code longVec} that starts at
     * {@code begin}.
     *
     * @throws Exception if {@code longVec} is shorter than {@code begin + shortVec.size()}
     */
    public static double dot(SparseVector longVec, SparseVector shortVec, int begin) throws Exception {
        if (longVec.size() < begin + shortVec.size()) {
            throw new Exception("dot The size of longVec(" + longVec.size() + ") is less than the size of shortVec(" + shortVec.size() + ") + begin(" + begin + ")");
        }
        double sum = 0.0d;
        int[] indices = shortVec.indices();
        double[] values = shortVec.values();
        for (int k = 0; k < indices.length; k++) {
            sum += values[k] * longVec.apply(begin + indices[k]);
        }
        return sum;
    }

    /** Returns {@code v * scale}, keeping the same active indices. */
    public static SparseVector multiply(SparseVector v, double scale) {
        int[] indices = v.indices();
        double[] values = v.values();
        double[] scaled = new double[values.length];
        for (int k = 0; k < indices.length; k++) {
            scaled[k] = values[k] * scale;
        }
        return new SparseVector(v.size(), indices, scaled);
    }

    /**
     * Returns {@code v_k * scale + offset} for each ACTIVE entry of {@code v}. Implicit zeros stay
     * zero, so this is not an affine map of the full vector.
     */
    public static SparseVector multiply(SparseVector v, double scale, double offset) {
        int[] indices = v.indices();
        double[] values = v.values();
        double[] mapped = new double[values.length];
        for (int k = 0; k < indices.length; k++) {
            mapped[k] = (values[k] * scale) + offset;
        }
        return new SparseVector(v.size(), indices, mapped);
    }

    /**
     * Computes the linear combination {@code sum_k coefs[k] * vecs[k]}.
     *
     * <p>The result's active indices are the union of the inputs' active indices, in ascending order.
     * Entries that cancel out are kept as explicit zeros.
     *
     * @throws Exception if {@code vecs} is empty, the lengths differ, or the vector sizes differ
     */
    public static SparseVector plus(SparseVector[] vecs, double[] coefs) throws Exception {
        if (vecs == null || vecs.length == 0) {
            throw new Exception("plus empty vecs");
        }
        if (coefs.length != vecs.length) {
            throw new Exception("plus the length of vecs not equal to the length of coefs");
        }
        int size = vecs[0].size();
        for (int i = 1; i < vecs.length; i++) {
            if (vecs[i].size() != size) {
                throw new Exception("plus vecs different size");
            }
        }
        TreeMap<Integer, Double> accumulated = new TreeMap<>();
        for (int i = 0; i < vecs.length; i++) {
            int[] indices = vecs[i].indices();
            double[] values = vecs[i].values();
            for (int k = 0; k < values.length; k++) {
                double contribution = coefs[i] * values[k];
                Double previous = accumulated.get(indices[k]);
                accumulated.put(indices[k], previous == null ? contribution : previous.doubleValue() + contribution);
            }
        }
        int[] resultIndices = new int[accumulated.size()];
        double[] resultValues = new double[accumulated.size()];
        int n = 0;
        for (Map.Entry<Integer, Double> entry : accumulated.entrySet()) {
            resultIndices[n] = entry.getKey().intValue();
            resultValues[n] = entry.getValue().doubleValue();
            n++;
        }
        return new SparseVector(size, resultIndices, resultValues);
    }

    /**
     * Computes the row-wise linear combination {@code sum_k coefs[k] * list[k]} of equally shaped
     * vector arrays.
     *
     * @throws Exception if {@code list} is empty, its size differs from {@code coefs.length}, or rows
     *     differ in size
     */
    public static SparseVector[] plus(List<SparseVector[]> list, double[] coefs) throws Exception {
        if (list == null || list.isEmpty()) {
            throw new Exception("plus empty vecs");
        }
        if (list.size() != coefs.length) {
            throw new Exception("plus the length of vecs not equal to the length of coefs");
        }
        SparseVector[] result = new SparseVector[list.get(0).length];
        for (int row = 0; row < result.length; row++) {
            SparseVector[] column = new SparseVector[coefs.length];
            for (int k = 0; k < list.size(); k++) {
                column[k] = list.get(k)[row];
            }
            result[row] = plus(column, coefs);
        }
        return result;
    }

    /** Returns {@code a * ca + b * cb}. */
    public static SparseVector plus(SparseVector a, double ca, SparseVector b, double cb) throws Exception {
        return plus(new SparseVector[]{a, b}, new double[]{ca, cb});
    }

    /** Returns the row-wise combination {@code a * ca + b * cb}. */
    public static SparseVector[] plus(SparseVector[] a, double ca, SparseVector[] b, double cb) throws Exception {
        List<SparseVector[]> pair = new ArrayList<>(2);
        pair.add(a);
        pair.add(b);
        return plus(pair, new double[]{ca, cb});
    }

    /** Returns the element-wise {@link Math#signum} of the active entries of {@code v}. */
    public static SparseVector sign(SparseVector v) {
        double[] values = v.values();
        double[] signs = new double[values.length];
        for (int k = 0; k < values.length; k++) {
            signs[k] = Math.signum(values[k]);
        }
        return new SparseVector(v.size(), v.indices(), signs);
    }

    /** Returns the sum of the L1 norms of all rows. */
    public static double norm1(SparseVector[] rows) {
        double sum = 0.0d;
        for (SparseVector row : rows) {
            sum += norm1(row);
        }
        return sum;
    }

    /** Returns the L1 norm of {@code v}. */
    public static double norm1(SparseVector v) {
        double sum = 0.0d;
        for (double value : v.values()) {
            sum += Math.abs(value);
        }
        return sum;
    }

    /** Returns the L2,1 norm: the sum over columns (of {@code rows[0]}'s size) of each column's L2 norm. */
    public static double norm21(SparseVector[] rows) {
        int size = rows[0].size();
        double sum = 0.0d;
        for (int j = 0; j < size; j++) {
            double squares = 0.0d;
            for (SparseVector row : rows) {
                squares += Math.pow(row.apply(j), 2.0d);
            }
            sum += Math.sqrt(squares);
        }
        return sum;
    }

    /** Returns the per-column L2 norms across all rows as a sparse vector of {@code rows[0]}'s size. */
    public static SparseVector norm2(SparseVector[] rows) {
        return Vectors.dense(columnNorms(rows, rows[0].size())).toSparse();
    }

    /** Returns the L2 norm of {@code v}. */
    public static double norm2(SparseVector v) {
        return Math.sqrt(sumOfSquare(v));
    }

    /** Returns the sum of squared entries over all rows. */
    public static double sumOfSquare(SparseVector[] rows) {
        double sum = 0.0d;
        for (SparseVector row : rows) {
            sum += sumOfSquare(row);
        }
        return sum;
    }

    /** Returns the sum of squared entries of {@code v}. */
    public static double sumOfSquare(SparseVector v) {
        double sum = 0.0d;
        for (double value : v.values()) {
            sum += Math.pow(value, 2.0d);
        }
        return sum;
    }

    /** No-op entry point. */
    public static void main(String[] args) throws Exception {
    }
}
