package cn.com.pconline.adclick.mixexperts;

import cn.com.pconline.adclick.Evaluation;
import cn.com.pconline.adclick.Trainer;
import cn.com.pconline.adclick.Utils;
import cn.com.pconline.adclick.pipeline.CategoryChoose;
import cn.com.pconline.adclick.pipeline.FlowPipeline;
import cn.com.pconline.adclick.pipeline.FlowPipelineStage;
import cn.com.pconline.adclick.pipeline.MergePipeline;
import cn.com.pconline.adclick.pipeline.RowNormalizer;
import cn.com.pconline.adclick.pipeline.RowPipelineStage;
import cn.com.pconline.adclick.pipeline.RowStandardScaler;
import cn.com.pconline.adclick.pipeline.StringOneHotEncoder;
import cn.com.pconline.adclick.pipeline.TransformToVector;
import java.io.IOException;
import java.io.ObjectOutputStream;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;
import java.util.Properties;
import java.util.Random;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.log4j.Logger;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.PairFunction;
import org.apache.spark.mllib.linalg.Matrices;
import org.apache.spark.mllib.linalg.SparseVector;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.codehaus.jackson.map.ObjectMapper;
import scala.Tuple2;

/* loaded from: input_file:cn/com/pconline/adclick/mixexperts/MixExpertsTrainer.class */
/**
 * {@link Trainer} implementation that fits a {@link MixExpertsModel}.
 *
 * <p>The class covers four steps, each driven by the trainer's {@link Properties}:
 * <ul>
 *   <li>{@link #preprocessTrain(Dataset)} fits a {@link MergePipeline} and serialises it to the
 *       path in {@code preprocess_train_model};</li>
 *   <li>{@link #preprocess(Dataset)} reloads that pipeline, turns rows into
 *       {@link LabeledPoint}s and optionally sub-samples them per label;</li>
 *   <li>{@link #train(JavaRDD)} / {@link #trainModels} iterate the routines in {@code Formula}
 *       and save the resulting model to the path in {@code moe_train_model};</li>
 *   <li>{@link #evaluate(JavaRDD)} reloads the saved model and wraps it in an
 *       {@link Evaluation}.</li>
 * </ul>
 */
public class MixExpertsTrainer extends Trainer {
    private static final long serialVersionUID = -2626943573106173812L;
    private static final Logger LOG = Logger.getLogger(MixExpertsTrainer.class);

    /** Name of the label column read from every input row. */
    private static final String LABEL_COLUMN = "click";

    /** Two labels are treated as the same class when they differ by less than this. */
    private static final double LABEL_TOLERANCE = 0.001d;

    public MixExpertsTrainer() {
    }

    /**
     * Creates a trainer configured from a properties file.
     *
     * @param str location of the properties file, as understood by {@code Utils.loadProperties}
     * @throws IOException if the properties cannot be loaded
     */
    public MixExpertsTrainer(String str) throws IOException {
        this(Utils.loadProperties(str));
    }

    /**
     * Creates a trainer from already loaded properties.
     *
     * @param properties configuration handed to the {@link Trainer} base class
     */
    public MixExpertsTrainer(Properties properties) {
        super(properties);
        // Plain concatenation (rather than properties.toString()) cannot fail on a null argument.
        LOG.debug("properties----" + properties);
    }

    /**
     * Reads the optimiser settings from the properties (falling back to the defaults below) and
     * delegates to {@link #trainModels}.
     *
     * @param javaRDD labelled training points
     * @throws Exception if training or saving the model fails
     */
    @Override
    public void train(JavaRDD<LabeledPoint> javaRDD) throws Exception {
        Properties properties = getProperties();
        int numAreas = intProperty(properties, "num_areas", "10");
        int numIter = intProperty(properties, "num_iter", "50");
        double alpha = doubleProperty(properties, "alpha", "1.0");
        double beta = doubleProperty(properties, "beta", "1.0");
        int numVec2save = intProperty(properties, "num_vec2save", "20");
        double alpha4Linesearch = doubleProperty(properties, "alpha4Linesearch", "2.0");
        double gamma4Linesearch = doubleProperty(properties, "gamma4Linesearch", "0.5");
        double c4Linesearch = doubleProperty(properties, "c4Linesearch", "1.0");
        int iternum4Linesearch = intProperty(properties, "iternum4Linesearch", "5");
        trainModels(javaRDD, numAreas, numIter, alpha, beta, numVec2save,
                alpha4Linesearch, gamma4Linesearch, c4Linesearch, iternum4Linesearch);
    }

    /**
     * Loads the model stored under {@code moe_train_model}, clears its threshold and wraps it,
     * together with the given data, in an {@link Evaluation}.
     *
     * @param javaRDD labelled points to evaluate against
     * @return the evaluation object built from the data and the loaded model
     * @throws Exception if the model cannot be loaded
     */
    @Override
    public Evaluation evaluate(JavaRDD<LabeledPoint> javaRDD) throws Exception {
        MixExpertsModel model = (MixExpertsModel) Utils.loadModelFromHDFS(
                getProperties().getProperty("moe_train_model"),
                javaRDD.context().hadoopConfiguration());
        SparseVector[] theta = model.getTheta();
        // Rendering every theta vector is only worth the cost when the output is actually logged.
        if (LOG.isDebugEnabled()) {
            for (int k = 0; k < theta.length; k++) {
                LOG.debug("final theta[" + k + "]--- " + theta[k].toString());
            }
        }
        model.clearThreshold();
        return new Evaluation(javaRDD, model);
    }

    /**
     * Fits the model parameters on the given data and saves the resulting
     * {@link MixExpertsModel} to the path in {@code moe_train_model}.
     *
     * <p>The RDD is cached for the duration of the fit and unpersisted afterwards, also when the
     * fit fails. Parameter names are kept as they were; {@link #train(JavaRDD)} binds them to the
     * properties listed below.
     *
     * @param javaRDD labelled training points; must not be {@code null}
     * @param i       {@code num_areas}; must be positive. {@code 2 * i} parameter vectors are fitted
     * @param i2      {@code num_iter}; the update loop runs for iterations {@code 1 .. i2 - 1}
     * @param d       {@code alpha}; divided by {@code 2 * i * featureCount} before use
     * @param d2      {@code beta}; divided by {@code 2 * i * featureCount} before use
     * @param i3      {@code num_vec2save}; maximum number of stored update pairs
     * @param d3      {@code alpha4Linesearch}, forwarded to {@code Formula.backtrackingLineSearch}
     * @param d4      {@code gamma4Linesearch}, forwarded to {@code Formula.backtrackingLineSearch}
     * @param d5      {@code c4Linesearch}, forwarded to {@code Formula.backtrackingLineSearch}
     * @param i4      {@code iternum4Linesearch}, forwarded to {@code Formula.backtrackingLineSearch}
     * @throws IllegalArgumentException if {@code javaRDD} is {@code null} or {@code i} is not positive
     * @throws Exception                if fitting or saving the model fails
     */
    public void trainModels(JavaRDD<LabeledPoint> javaRDD, int i, int i2, double d, double d2, int i3, double d3, double d4, double d5, int i4) throws Exception {
        if (javaRDD == null) {
            throw new IllegalArgumentException("training RDD must not be null");
        }
        if (i <= 0) {
            throw new IllegalArgumentException("num_areas must be positive, got " + i);
        }
        SparseVector[] theta;
        javaRDD.cache();
        try {
            theta = fitTheta(javaRDD, i, i2, d, d2, i3, d3, d4, d5, i4);
        } finally {
            // Release the cached blocks even when the fit throws.
            javaRDD.unpersist();
        }
        Utils.saveModelToHDFS(new MixExpertsModel(theta, i),
                javaRDD.context().hadoopConfiguration(),
                getProperties().getProperty("moe_train_model"));
    }

    /**
     * Runs the iterative parameter update built from the {@code Formula} routines.
     *
     * <p>Raw {@link ArrayList}s are used for the collections handed to {@code Formula} because
     * the element types its methods declare are not visible from this class.
     *
     * @return the parameter vectors after the last iteration
     */
    @SuppressWarnings({"rawtypes", "unchecked"})
    private SparseVector[] fitTheta(JavaRDD<LabeledPoint> data, int numAreas, int numIter,
            double alpha, double beta, int historySize,
            double lsAlpha, double lsGamma, double lsC, int lsIterations) throws Exception {
        final int featureCount = data.first().features().size();
        final long sampleCount = data.count();
        // Computed in double so that a large numAreas * featureCount cannot overflow int.
        final double scale = 2.0d * numAreas * featureCount;
        final double scaledAlpha = alpha / scale;
        final double scaledBeta = beta / scale;

        SparseVector[] previousTheta = initTheta(numAreas * 2, featureCount);
        SparseVector[] lossDerivative = Formula.loss_derivative(data, previousTheta, sampleCount);
        SparseVector[] direction = Formula.direction(lossDerivative, previousTheta, scaledAlpha, scaledBeta);
        SparseVector[] theta = Formula.piFun(
                Formula.backtrackingLineSearch(data, previousTheta, direction, lossDerivative,
                        scaledAlpha, scaledBeta, lsIterations, lsAlpha, lsGamma, lsC, sampleCount),
                previousTheta, direction);

        if (LOG.isDebugEnabled()) {
            for (int k = 0; k < theta.length; k++) {
                LOG.debug("k=0 theta[" + k + "]----" + previousTheta[k].toString());
            }
        }

        ArrayList thetaDeltas = new ArrayList();
        ArrayList directionDeltas = new ArrayList();
        ArrayList products = new ArrayList();

        for (int iter = 1; iter < numIter; iter++) {
            LOG.debug("loop " + iter + " begin!");
            SparseVector[] currentDerivative = Formula.loss_derivative(data, theta, sampleCount);
            LOG.debug("loss_derivative end!");
            SparseVector[] previousDirection = direction;
            direction = Formula.direction(currentDerivative, theta, scaledAlpha, scaledBeta);
            LOG.debug("direction end!");

            // Presumably theta - previousTheta (Formula.plus looks like a weighted sum) - confirm.
            SparseVector[] thetaDelta = new SparseVector[theta.length];
            for (int k = 0; k < theta.length; k++) {
                thetaDelta[k] = Formula.plus(theta[k], 1.0d, previousTheta[k], -1.0d);
            }
            thetaDeltas.add(thetaDelta);

            // Presumably previousDirection - direction, by the same reading of Formula.plus.
            SparseVector[] directionDelta = new SparseVector[direction.length];
            for (int k = 0; k < direction.length; k++) {
                directionDelta[k] = Formula.plus(direction[k], -1.0d, previousDirection[k], 1.0d);
            }
            directionDeltas.add(directionDelta);
            LOG.debug("update s and y end!");

            // Keep at most historySize pairs; the oldest pair is dropped first.
            if (thetaDeltas.size() > historySize) {
                thetaDeltas.remove(0);
                directionDeltas.remove(0);
            }
            Formula.updateProducts(products, historySize, iter, thetaDelta, directionDelta,
                    direction, thetaDeltas, directionDeltas);
            LOG.debug("update products end!");

            // The stored pairs are only used when the newest pair has a positive dot product;
            // otherwise the plain direction is searched.
            SparseVector[] searchDirection = direction;
            if (Formula.dot(thetaDelta, directionDelta) > 0.0d) {
                ArrayList basis = new ArrayList();
                basis.addAll(thetaDeltas);
                basis.addAll(directionDeltas);
                basis.add(direction);
                searchDirection = Formula.piFun(Formula.vlbfgs(basis, products, iter, historySize), direction);
            }
            LOG.debug("calculate p end!");

            SparseVector[] step = Formula.backtrackingLineSearch(data, theta, searchDirection,
                    currentDerivative, scaledAlpha, scaledBeta, lsIterations, lsAlpha, lsGamma, lsC,
                    sampleCount);
            LOG.debug("backtrackingLineSearch end!");
            previousTheta = theta;
            theta = Formula.piFun(step, previousTheta, searchDirection);
        }
        return theta;
    }

    /**
     * Fits the preprocessing pipeline on the given rows and serialises it to the path in
     * {@code preprocess_train_model}.
     *
     * <p>The pipeline merges a scaled vector of six numeric columns, one-hot encodings of six
     * string columns, and two normalised {@link CategoryChoose} stages that are given a label
     * query assembled from the {@code labels_column}, {@code labels_table}, {@code traindate}
     * and {@code labelnum} properties.
     *
     * @param dataset rows used to fit the pipeline
     * @throws Exception if fitting or writing the pipeline fails
     */
    @Override
    public void preprocessTrain(Dataset<Row> dataset) throws Exception {
        LOG.debug("preprocess train begin");
        Properties properties = getProperties();
        FlowPipeline numericPipeline = new FlowPipeline(
                new TransformToVector(new String[]{"ad_price", "content_price", "ad_shownum", "ad_clicknum", "location_shownum", "location_clicknum"}),
                new FlowPipelineStage[]{new RowStandardScaler("scale")},
                LABEL_COLUMN);
        // The query text comes from the trainer's own configuration, not from user input.
        String labelQuery = "select " + properties.getProperty("labels_column")
                + " from " + properties.getProperty("labels_table")
                + " where ds = " + properties.getProperty("traindate")
                + " limit " + properties.getProperty("labelnum", "1000");
        MergePipeline mergePipeline = new MergePipeline(new RowPipelineStage[]{
                numericPipeline,
                new StringOneHotEncoder("hour"),
                new StringOneHotEncoder("province_id"),
                new StringOneHotEncoder("content_category_id"),
                new StringOneHotEncoder("content_brand_id"),
                new StringOneHotEncoder("ad_category_id"),
                new StringOneHotEncoder("ad_brand_id"),
                new FlowPipeline(new CategoryChoose("category_uv", labelQuery),
                        new FlowPipelineStage[]{new RowNormalizer("category_uv_norm")}, LABEL_COLUMN),
                new FlowPipeline(new CategoryChoose("interests", labelQuery),
                        new FlowPipelineStage[]{new RowNormalizer("interests_norm")}, LABEL_COLUMN)});
        mergePipeline.fit(dataset);

        Configuration hadoopConfiguration = dataset.sparkSession().sparkContext().hadoopConfiguration();
        Path path = new Path(properties.getProperty("preprocess_train_model"));
        // FileSystem.get() normally hands out a JVM-wide cached instance, and closing that one
        // would break every other user of it in this application. newInstance() returns a private
        // instance that is safe to close here.
        try (FileSystem fileSystem = FileSystem.newInstance(hadoopConfiguration);
             java.io.OutputStream rawOut = fileSystem.create(path);
             ObjectOutputStream objectOut = new ObjectOutputStream(rawOut)) {
            objectOut.writeObject(mergePipeline);
        }
        LOG.debug("preprocess train end");
    }

    /**
     * Converts rows into labelled points using the pipeline saved by
     * {@link #preprocessTrain(Dataset)}.
     *
     * <p>When the {@code sample_rate} property is present the points are sub-sampled per label
     * without replacement; see {@link #sampleFractions(Map)} for how the fractions are derived.
     *
     * @param dataset rows containing a double {@code click} column plus the pipeline's inputs
     * @return one {@link LabeledPoint} per retained row
     * @throws IllegalArgumentException if {@code sample_rate} is set but {@code maxsample} is
     *                                  missing, or a matched sample rate is not a non-negative number
     * @throws Exception                if the pipeline cannot be loaded or the sample rates cannot be parsed
     */
    @Override
    public JavaRDD<LabeledPoint> preprocess(Dataset<Row> dataset) throws Exception {
        final MergePipeline mergePipeline = (MergePipeline) Utils.loadModelFromHDFS(
                getProperties().getProperty("preprocess_train_model"),
                dataset.sparkSession().sparkContext().hadoopConfiguration());
        JavaPairRDD<Double, Vector> labelled = dataset.javaRDD().mapToPair(new PairFunction<Row, Double, Vector>() {
            private static final long serialVersionUID = 7888067592621820511L;

            @Override
            public Tuple2<Double, Vector> call(Row row) throws Exception {
                return new Tuple2<>(Double.valueOf(row.getDouble(row.fieldIndex("click"))), mergePipeline.transform(row));
            }
        });
        if (getProperties().containsKey("sample_rate")) {
            // Check the configuration before paying for the countByKey job.
            if (getProperties().getProperty("maxsample") == null) {
                throw new IllegalArgumentException("property 'maxsample' is required when 'sample_rate' is set");
            }
            labelled = labelled.sampleByKey(false, sampleFractions(labelled.countByKey()));
            // countByKey is a full Spark job; only run it when the line will actually be logged.
            if (LOG.isInfoEnabled()) {
                LOG.info("sample size---" + labelled.countByKey().toString());
            }
        }
        return labelled.map(new Function<Tuple2<Double, Vector>, LabeledPoint>() {
            private static final long serialVersionUID = 1;

            @Override
            public LabeledPoint call(Tuple2<Double, Vector> tuple2) throws Exception {
                return new LabeledPoint(tuple2._1().doubleValue(), tuple2._2());
            }
        });
    }

    /**
     * Derives the per-label sampling fraction from the {@code sample_rate} JSON object (label
     * text mapped to a relative rate) and the {@code maxsample} property.
     *
     * <p>A base size is taken as the smallest of {@code maxsample} and {@code count / rate} over
     * all labels that have a configured rate. Each such label is then sampled with fraction
     * {@code rate * base / count}; labels without a configured rate get fraction {@code 0}.
     *
     * @param countByLabel number of rows per label
     * @return sampling fraction per label, each at most {@code 1.0}
     * @throws IOException if the {@code sample_rate} property is not valid JSON
     */
    private Map<Double, Double> sampleFractions(Map<Double, Long> countByLabel) throws IOException {
        Map<?, ?> configuredRates = new ObjectMapper().readValue(getProperties().getProperty("sample_rate"), Map.class);
        double base = Double.parseDouble(getProperties().getProperty("maxsample"));

        for (Map.Entry<Double, Long> counted : countByLabel.entrySet()) {
            for (Map.Entry<?, ?> configured : configuredRates.entrySet()) {
                if (sameLabel(counted.getKey().doubleValue(), configured.getKey())) {
                    double candidate = counted.getValue().longValue() / rateOf(configured);
                    if (candidate < base) {
                        base = candidate;
                    }
                }
            }
        }

        Map<Double, Double> fractions = new HashMap<>();
        for (Map.Entry<Double, Long> counted : countByLabel.entrySet()) {
            double fraction = 0.0d;
            for (Map.Entry<?, ?> configured : configuredRates.entrySet()) {
                if (sameLabel(counted.getKey().doubleValue(), configured.getKey())) {
                    // Mathematically <= 1 by construction of base; the clamp only absorbs
                    // floating-point rounding.
                    fraction = Math.min(1.0d, (rateOf(configured) * base) / counted.getValue().longValue());
                    break;
                }
            }
            fractions.put(counted.getKey(), Double.valueOf(fraction));
        }
        return fractions;
    }

    /** Returns whether a configured label key denotes the given label, within {@link #LABEL_TOLERANCE}. */
    private static boolean sameLabel(double label, Object configuredKey) {
        return Math.abs(label - Double.parseDouble(String.valueOf(configuredKey))) < LABEL_TOLERANCE;
    }

    /**
     * Reads a configured sample rate as a double. JSON integers arrive as {@link Integer} and
     * decimals as {@link Double}, so the value is read through {@link Number}.
     */
    private static double rateOf(Map.Entry<?, ?> configured) {
        Object value = configured.getValue();
        if (!(value instanceof Number)) {
            throw new IllegalArgumentException("sample_rate for label '" + configured.getKey() + "' must be a number, got " + value);
        }
        double rate = ((Number) value).doubleValue();
        if (rate < 0.0d) {
            throw new IllegalArgumentException("sample_rate for label '" + configured.getKey() + "' must not be negative, got " + rate);
        }
        return rate;
    }

    /**
     * Flattens every row into a (label, feature vector) pair.
     *
     * <p>The {@code click} column supplies the label. Of the remaining columns, {@link Double}
     * and {@link Integer} values contribute one feature each and ML vectors contribute all their
     * entries; values of any other type are skipped. A constant {@code 1.0} is appended as the
     * last feature.
     *
     * @param dataset rows containing a numeric {@code click} column
     * @return label paired with the sparse form of the assembled feature vector
     */
    public JavaPairRDD<Double, Vector> features2Vector(Dataset<Row> dataset) {
        return dataset.javaRDD().mapToPair(new PairFunction<Row, Double, Vector>() {
            private static final long serialVersionUID = 1;

            @Override
            public Tuple2<Double, Vector> call(Row row) throws Exception {
                int columnCount = row.size();
                int labelIndex = row.fieldIndex("click");
                double label = 0.0d;
                ArrayList<Double> features = new ArrayList<>();
                for (int column = 0; column < columnCount; column++) {
                    Object value = row.apply(column);
                    if (column == labelIndex) {
                        // Accept any numeric label type instead of failing on a non-Double column.
                        if (!(value instanceof Number)) {
                            throw new IllegalArgumentException("column 'click' must hold a number, got " + value);
                        }
                        label = ((Number) value).doubleValue();
                    } else if (value instanceof Double) {
                        features.add((Double) value);
                    } else if (value instanceof Integer) {
                        features.add(Double.valueOf(((Integer) value).doubleValue()));
                    } else if (value instanceof org.apache.spark.ml.linalg.Vector) {
                        Vector converted = Vectors.fromML((org.apache.spark.ml.linalg.Vector) value);
                        for (int k = 0; k < converted.size(); k++) {
                            features.add(Double.valueOf(converted.apply(k)));
                        }
                    }
                }
                features.add(Double.valueOf(1.0d));
                double[] dense = new double[features.size()];
                for (int k = 0; k < dense.length; k++) {
                    dense[k] = features.get(k).doubleValue();
                }
                return new Tuple2<>(Double.valueOf(label), Vectors.dense(dense).toSparse());
            }
        });
    }

    /**
     * Projects the dataset onto the {@code click} column followed by a fixed list of feature
     * columns.
     *
     * @param dataset rows containing all of the selected columns
     * @return the projected dataset
     */
    public Dataset<Row> selectFeatures(Dataset<Row> dataset) {
        String[] featureColumns = {
            "scaled", "itemrs", "consumption_level", "interests_normalized", "category_uv_normalized",
            "hour_encoded", "province_id_encoded", "content_category_id_encoded",
            "content_brand_id_encoded", "ad_category_id_encoded", "ad_brand_id_encoded"};
        return dataset.select("click", featureColumns);
    }

    /**
     * Builds the starting parameter vectors from a seeded (42) random matrix with density 1.0,
     * one vector per matrix row, each passed through {@code Formula.multiply(row, 2.0, -1.0)}.
     *
     * @param rowCount    number of parameter vectors
     * @param columnCount length of each vector
     * @return the initial parameter vectors
     */
    private SparseVector[] initTheta(int rowCount, int columnCount) {
        SparseVector[] theta = new SparseVector[rowCount];
        scala.collection.Iterator<Vector> rows =
                Matrices.sprand(rowCount, columnCount, 1.0d, new Random(42L)).rowIter();
        for (int row = 0; rows.hasNext(); row++) {
            // Presumably rescales each entry from [0, 1) to [-1, 1) - confirm against Formula.multiply.
            theta[row] = Formula.multiply(rows.next().toSparse(), 2.0d, -1.0d);
        }
        return theta;
    }

    /**
     * Development entry point with a hard-coded configuration path.
     *
     * <p>NOTE(review): this passes a {@code null} RDD, so {@link #trainModels} rejects the call;
     * it cannot train anything as written and looks like a leftover stub.
     */
    public static void main(String[] strArr) throws Exception {
        new MixExpertsTrainer("/data/dev_test/lxz/adclick/config.properties").trainModels(null, 1, 10, 0.0d, 0.0d, 5, 1.0d, 0.5d, 0.5d, 5);
    }

    /** Reads an integer property, tolerating surrounding whitespace in the configured value. */
    private static int intProperty(Properties properties, String key, String defaultValue) {
        return Integer.parseInt(properties.getProperty(key, defaultValue).trim());
    }

    /** Reads a double property, tolerating surrounding whitespace in the configured value. */
    private static double doubleProperty(Properties properties, String key, String defaultValue) {
        return Double.parseDouble(properties.getProperty(key, defaultValue).trim());
    }
}
