package cn.com.pconline.adclick.pipeline;

import cn.com.pconline.adclick.Constant;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.apache.spark.api.java.function.PairFunction;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import scala.Tuple2;

/* loaded from: input_file:cn/com/pconline/adclick/pipeline/CombinePipeline.class */
/**
 * A {@link RowPipelineStage} that runs several upstream row stages and feeds their output vectors,
 * in stage order, into a single {@link CombinePipelineStage} that produces the final vector.
 *
 * <p>This class is {@link Serializable} because the mapping function built in {@link #fit}
 * captures the enclosing pipeline instance.
 */
public class CombinePipeline implements RowPipelineStage, Serializable {
    private static final long serialVersionUID = -231539511978366928L;

    /** Final stage that receives one vector per upstream stage and combines them. */
    private final CombinePipelineStage last;

    /** Upstream stages; stage {@code i} produces the {@code i}-th vector passed to {@link #last}. */
    private final RowPipelineStage[] stages;

    /** Name of the label column; read with {@link Row#getDouble(int)} during {@link #fit}. */
    private final String labelCol;

    /**
     * @param last     the combining stage applied to the upstream stages' outputs
     * @param stages   the upstream row stages, in the order their outputs are passed to {@code last}
     * @param labelCol name of the (double-typed) label column used when fitting {@code last}
     */
    public CombinePipeline(CombinePipelineStage last, RowPipelineStage[] stages, String labelCol) {
        this.last = last;
        this.stages = stages;
        this.labelCol = labelCol;
    }

    /**
     * Fits every upstream stage on {@code dataset}, then fits the combining stage on
     * {@code (label, upstream vectors)} pairs computed with the just-fitted upstream stages.
     *
     * @param dataset training data; must contain the label column as a double
     * @throws Exception if any stage fails to fit or transform
     */
    @Override
    public void fit(Dataset<Row> dataset) throws Exception {
        for (RowPipelineStage stage : this.stages) {
            stage.fit(dataset);
        }
        // Anonymous class (not a lambda) keeps the original serialVersionUID for the function.
        this.last.fit(dataset.javaRDD().mapToPair(new PairFunction<Row, Double, Vector[]>() {
            private static final long serialVersionUID = 2213885750622494519L;

            @Override
            public Tuple2<Double, Vector[]> call(Row row) throws Exception {
                // Stage outputs are computed before the label is read, matching the original order.
                Vector[] vectors = transformStages(row);
                Double label = Double.valueOf(row.getDouble(row.fieldIndex(labelCol)));
                return new Tuple2<>(label, vectors);
            }
        }));
    }

    /**
     * Transforms {@code row} through every upstream stage and combines the results.
     *
     * @param row input row
     * @return the combining stage's output vector
     * @throws Exception if any stage fails to transform
     */
    @Override
    public Vector transform(Row row) throws Exception {
        return this.last.transform(transformStages(row));
    }

    /** Runs each upstream stage on {@code row}, returning the outputs in stage order. */
    private Vector[] transformStages(Row row) throws Exception {
        Vector[] vectors = new Vector[this.stages.length];
        for (int i = 0; i < this.stages.length; i++) {
            vectors[i] = this.stages[i].transform(row);
        }
        return vectors;
    }

    /** @return the length reported by the combining stage */
    @Override
    public int resultLength() {
        return this.last.resultLength();
    }

    /**
     * Maps each output position of the combining stage back to the upstream stages' inputs.
     *
     * <p>Each entry the combining stage reports has the form
     * {@code <inputCol><INPUTCOL_SEP><position>}.
     * <ul>
     *   <li>{@code inputCol} is one of {@code last.getInputCols()}, matched by index to an
     *       upstream stage.</li>
     *   <li>{@code position} indexes that stage's own {@link #inputsForOutputs()} list.</li>
     * </ul>
     * The resolved sets are unioned per output position.
     *
     * @return a new list with one set of upstream input names per output position of
     *     {@link #last}; the delegate's list is not modified
     * @throws IllegalStateException if an entry lacks the separator or names an unknown column
     */
    @Override
    public List<Set<String>> inputsForOutputs() {
        List<Set<String>> combinedRefs = this.last.inputsForOutputs();

        List<List<Set<String>>> stageInputs = new ArrayList<>(this.stages.length);
        for (RowPipelineStage stage : this.stages) {
            stageInputs.add(stage.inputsForOutputs());
        }

        // The i-th input column of the combining stage corresponds to stages[i].
        String[] inputCols = this.last.getInputCols();
        Map<String, Integer> stageIndexByCol = new HashMap<>();
        for (int i = 0; i < inputCols.length; i++) {
            stageIndexByCol.put(inputCols[i], Integer.valueOf(i));
        }

        // Build a fresh list instead of overwriting the delegate's list in place.
        List<Set<String>> result = new ArrayList<>(combinedRefs.size());
        for (Set<String> refs : combinedRefs) {
            Set<String> resolved = new HashSet<>();
            for (String ref : refs) {
                resolved.addAll(resolveReference(ref, stageIndexByCol, stageInputs));
            }
            result.add(resolved);
        }
        return result;
    }

    /** Resolves one {@code <inputCol><SEP><position>} reference to an upstream stage's input set. */
    private static Set<String> resolveReference(
            String ref, Map<String, Integer> stageIndexByCol, List<List<Set<String>>> stageInputs) {
        int sep = ref.lastIndexOf(Constant.INPUTCOL_SEP);
        if (sep < 0) {
            throw new IllegalStateException(
                    "Input reference '" + ref + "' does not contain the column separator");
        }
        String col = ref.substring(0, sep);
        Integer stageIndex = stageIndexByCol.get(col);
        if (stageIndex == null) {
            throw new IllegalStateException(
                    "Input reference '" + ref + "' names unknown input column '" + col + "'");
        }
        // NOTE(review): skips exactly one character, so this assumes INPUTCOL_SEP is a single
        // character (as the original code did) — confirm against Constant.
        int position = Integer.parseInt(ref.substring(sep + 1));
        return stageInputs.get(stageIndex.intValue()).get(position);
    }
}
