package featuretest;

import cn.com.pconline.adclick.pipeline.CombinePipeline;
import cn.com.pconline.adclick.pipeline.FMDimensionReducer;
import cn.com.pconline.adclick.pipeline.FlowPipeline;
import cn.com.pconline.adclick.pipeline.FlowPipelineStage;
import cn.com.pconline.adclick.pipeline.MergePipeline;
import cn.com.pconline.adclick.pipeline.RowNormalizer;
import cn.com.pconline.adclick.pipeline.RowPipelineStage;
import cn.com.pconline.adclick.pipeline.StringOneHotEncoder;
import cn.com.pconline.adclick.udf.List2VectorUDF;
import java.util.List;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.VectorUDT;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;

/* loaded from: input_file:featuretest/FeatureSparkTest.class */
/**
 * Integration tests for the feature pipeline stages, run against the Hive table
 * {@code adclick_lxz.test_feature}.
 *
 * <p>Each test fits a stage on the rows where {@code type = 'train'}, transforms every row
 * where {@code type = 'test'}, and compares the resulting vector with the expected values
 * stored in a column of the same test row ({@code onehot1}, {@code merge}, {@code fmreduce}
 * or {@code flow}).
 */
public class FeatureSparkTest {
    /** Query selecting all columns of the rows used to fit a stage. */
    private static final String TRAIN_QUERY = "select * from adclick_lxz.test_feature where type = 'train'";

    /** Query selecting all columns of the rows that carry the expected outputs. */
    private static final String TEST_QUERY = "select * from adclick_lxz.test_feature where type = 'test'";

    /** Absolute tolerance used when comparing expected and actual vector elements. */
    private static final double TOLERANCE = 0.001d;

    private SparkSession spark;

    /** Obtains (or creates) a Hive-enabled Spark session before each test. */
    @Before
    public void setUp() {
        this.spark = SparkSession.builder().appName("FeatureTest").enableHiveSupport().getOrCreate();
    }

    /** Stops the Spark session after each test. */
    @After
    public void tearDown() {
        // JUnit 4 runs @After even when @Before threw, so the session may never have been
        // assigned; guard to avoid masking the original failure with a NullPointerException.
        if (this.spark != null) {
            this.spark.stop();
            this.spark = null;
        }
    }

    /** Checks a single {@link StringOneHotEncoder} on {@code name1} against column {@code onehot1}. */
    @Test
    public void testStringOneHotEncoder() throws Exception {
        StringOneHotEncoder encoder = new StringOneHotEncoder("name1");
        Dataset<Row> train = this.spark.sql(TRAIN_QUERY);
        Dataset<Row> test = this.spark.sql(TEST_QUERY);
        encoder.fit(train);
        for (Row row : test.collectAsList()) {
            Vector actual = encoder.transform(row);
            List<Double> expected = row.getList(row.fieldIndex("onehot1"));
            assertVectorMatches("row:" + row, expected, actual);
        }
    }

    /** Checks a {@link MergePipeline} of two one-hot encoders against column {@code merge}. */
    @Test
    public void testMergePipeline() throws Exception {
        MergePipeline mergePipeline = new MergePipeline(new RowPipelineStage[]{new StringOneHotEncoder("name1"), new StringOneHotEncoder("name2")});
        Dataset<Row> train = this.spark.sql(TRAIN_QUERY);
        Dataset<Row> test = this.spark.sql(TEST_QUERY);
        mergePipeline.fit(train);
        for (Row row : test.collectAsList()) {
            Vector actual = mergePipeline.transform(row);
            List<Double> expected = row.getList(row.fieldIndex("merge"));
            assertVectorMatches("row:" + row, expected, actual);
        }
    }

    /**
     * Checks {@link FMDimensionReducer} on the pre-encoded {@code onehot1}/{@code onehot2}
     * columns (converted to vectors by the {@code list2Vec} UDF) against column
     * {@code fmreduce}.
     */
    @Test
    public void testFMDimensionReducer() throws Exception {
        this.spark.udf().register("list2Vec", new List2VectorUDF(), new VectorUDT());
        FMDimensionReducer reducer = new FMDimensionReducer(new String[]{"onehot1", "onehot2"}, "click");
        reducer.setK(2);
        reducer.setLambda(0.0d);
        reducer.setIternum(1);
        reducer.setLearnrate(1.0d);
        Dataset<Row> train = this.spark.sql("select list2Vec(onehot1) as onehot1, list2Vec(onehot2) as onehot2, click from adclick_lxz.test_feature where type = 'train'");
        Dataset<Row> test = this.spark.sql("select list2Vec(onehot1) as onehot1, list2Vec(onehot2) as onehot2, fmreduce from adclick_lxz.test_feature where type = 'test'");
        reducer.fit(train);
        for (Row row : test.collectAsList()) {
            Vector actual = reducer.transform(row);
            List<Double> expected = row.getList(row.fieldIndex("fmreduce"));
            String context = "row:" + row.get(row.fieldIndex("onehot1")) + row.get(row.fieldIndex("onehot2"));
            assertVectorMatches(context, expected, actual);
        }
    }

    /**
     * Checks a {@link FlowPipeline} (merged one-hot encoders followed by a
     * {@link RowNormalizer}) against column {@code flow}.
     */
    @Test
    public void testFlowPipeline() throws Exception {
        FlowPipeline flowPipeline = new FlowPipeline(new MergePipeline(new RowPipelineStage[]{new StringOneHotEncoder("name1"), new StringOneHotEncoder("name2")}), new FlowPipelineStage[]{new RowNormalizer("name1")}, "click");
        Dataset<Row> train = this.spark.sql(TRAIN_QUERY);
        Dataset<Row> test = this.spark.sql(TEST_QUERY);
        flowPipeline.fit(train);
        List<Row> rows = test.collectAsList();
        for (int i = 0; i < rows.size(); i++) {
            Row row = rows.get(i);
            Vector actual = flowPipeline.transform(row);
            List<Double> expected = row.getList(row.fieldIndex("flow"));
            // Row index is reported so a failure can be traced back to the offending test row.
            assertVectorMatches("i:" + i, expected, actual);
        }
    }

    /**
     * Checks a {@link CombinePipeline} (two one-hot encoders combined with an
     * {@link FMDimensionReducer}) against column {@code fmreduce}.
     */
    @Test
    public void testCombinePipeline() throws Exception {
        StringOneHotEncoder encoder1 = new StringOneHotEncoder("name1");
        StringOneHotEncoder encoder2 = new StringOneHotEncoder("name2");
        FMDimensionReducer reducer = new FMDimensionReducer(new String[]{"name1", "name2"}, "click");
        reducer.setK(2);
        reducer.setLambda(0.0d);
        reducer.setIternum(1);
        reducer.setLearnrate(1.0d);
        CombinePipeline combinePipeline = new CombinePipeline(reducer, new RowPipelineStage[]{encoder1, encoder2}, "click");
        Dataset<Row> train = this.spark.sql(TRAIN_QUERY);
        Dataset<Row> test = this.spark.sql(TEST_QUERY);
        combinePipeline.fit(train);
        for (Row row : test.collectAsList()) {
            Vector actual = combinePipeline.transform(row);
            List<Double> expected = row.getList(row.fieldIndex("fmreduce"));
            assertVectorMatches("row:" + row, expected, actual);
        }
    }

    /**
     * Asserts that {@code actual} has the same length as {@code expected} and that every
     * element agrees within {@link #TOLERANCE}.
     *
     * <p>The length check is deliberate: comparing only the first {@code actual.size()}
     * elements would let an empty or truncated transform result pass without any element
     * ever being compared.
     *
     * @param context  text prepended to every failure message to identify the row under test
     * @param expected expected element values read from the test row
     * @param actual   vector produced by the stage under test
     */
    private static void assertVectorMatches(String context, List<Double> expected, Vector actual) {
        Assert.assertNotNull(context + ": expected values missing", expected);
        Assert.assertNotNull(context + ": transform returned null", actual);
        Assert.assertEquals(context + ": vector size", expected.size(), actual.size());
        for (int j = 0; j < actual.size(); j++) {
            Assert.assertEquals(context + ";j:" + j, expected.get(j).doubleValue(), actual.apply(j), TOLERANCE);
        }
    }
}
