package cn.com.pconline.adclick.pipeline;

import cn.com.pconline.adclick.mixexperts.Formula;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import org.apache.spark.mllib.linalg.SparseVector;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;

/* loaded from: input_file:cn/com/pconline/adclick/pipeline/CumulativePipeline.class */
/**
 * A {@link RowPipelineStage} that runs several stages over the same row and combines their
 * output vectors with {@link Formula#plus}, giving every stage a weight of {@code 1.0}.
 *
 * <p>All stages must report the same {@link #resultLength()}; this is verified in {@link #fit}.
 * NOTE(review): presumably {@code Formula.plus} is an element-wise weighted sum of the vectors —
 * confirm against {@code Formula}.
 */
public class CumulativePipeline implements RowPipelineStage, Serializable {
    private static final long serialVersionUID = 1046306309859685303L;

    /** Stages whose outputs are combined, in the order supplied to the constructor. */
    private final RowPipelineStage[] stages;

    /** Common output length of all stages; 0 until {@link #fit} has completed successfully. */
    private int resultLength = 0;

    /**
     * Creates a pipeline over the given stages.
     *
     * @param rowPipelineStageArr the stages to combine; the array is stored directly, not copied
     */
    public CumulativePipeline(RowPipelineStage[] rowPipelineStageArr) {
        this.stages = rowPipelineStageArr;
    }

    /**
     * Fits every stage on {@code dataset} in order, then checks that all stages report the same
     * result length and records it as this pipeline's result length.
     *
     * @param dataset the data each stage is fitted on
     * @throws IllegalStateException if the pipeline has no stages
     * @throws Exception if any stage fails to fit, or if the stages report different result lengths
     */
    @Override // cn.com.pconline.adclick.pipeline.RowPipelineStage
    public void fit(Dataset<Row> dataset) throws Exception {
        if (stages.length == 0) {
            // Without this guard, stages[0] below fails with an opaque index exception.
            throw new IllegalStateException("cumulative pipeline has no stages");
        }
        for (RowPipelineStage stage : stages) {
            stage.fit(dataset);
        }
        int expected = stages[0].resultLength();
        for (int i = 1; i < stages.length; i++) {
            int actual = stages[i].resultLength();
            if (actual != expected) {
                throw new Exception("cumulative pipeline different length of stages result: stage 0 has "
                        + expected + ", stage " + i + " has " + actual);
            }
        }
        // Publish the length only once every stage is known to agree on it.
        this.resultLength = expected;
    }

    /**
     * Transforms {@code row} with every stage, converts each output to a sparse vector, and
     * combines them via {@link Formula#plus} with a weight of {@code 1.0} per stage.
     *
     * @param row the input row
     * @return the combined vector produced by {@code Formula.plus}
     * @throws Exception if any stage fails to transform the row
     */
    @Override // cn.com.pconline.adclick.pipeline.RowPipelineStage
    public Vector transform(Row row) throws Exception {
        SparseVector[] outputs = new SparseVector[stages.length];
        double[] weights = new double[stages.length];
        for (int i = 0; i < stages.length; i++) {
            outputs[i] = stages[i].transform(row).toSparse();
            weights[i] = 1.0d;
        }
        return Formula.plus(outputs, weights);
    }

    /**
     * Returns the common output length of all stages as recorded by the last successful
     * {@link #fit}, or 0 if the pipeline has not been fitted.
     */
    @Override // cn.com.pconline.adclick.pipeline.RowPipelineStage
    public int resultLength() {
        return this.resultLength;
    }

    /**
     * For each output index, returns the union of the input-name sets that every stage reports
     * for that same index.
     *
     * @return a new mutable list of {@link #resultLength()} mutable sets; empty before {@link #fit}
     */
    @Override // cn.com.pconline.adclick.pipeline.RowPipelineStage
    public List<Set<String>> inputsForOutputs() {
        List<Set<String>> merged = new ArrayList<>(resultLength);
        for (int i = 0; i < resultLength; i++) {
            merged.add(new HashSet<>());
        }
        for (RowPipelineStage stage : stages) {
            List<Set<String>> stageInputs = stage.inputsForOutputs();
            // Relies on each stage's list having at least resultLength entries (checked via fit).
            for (int i = 0; i < resultLength; i++) {
                merged.get(i).addAll(stageInputs.get(i));
            }
        }
        return merged;
    }
}
