开发者

Java开发Spark应用程序自定义PipeLineStage详解

目录
  • 引言
  • 背景知识介绍
  • 定义一个Transformer
    • 1. 场景介绍
    • 2. 代码实现
      • 2.1 定义并封装成员变量
      • 2.2 实现抽象方法
    • 3. Pipeline的存储文件
    • 小结

      引言

      在Spark中使用Pipeline进行数据建模是一种非常高效的手段。作为Pipeline中基本数据加工处理单元——PipelineStage,Spark提供了用户自定义的抽象子类Transformer和Estimator。

      关于自定义PipelineStage的详细方法,大部分的资料和介绍都是基于Scala的。少数基于Java的介绍都极不完整,有些可能还存在一定的误导。所以接下来我们将系统的介绍用Java开发Spark如何自定义PipelineStage。

      本文使用环境:Spark-2.3.0,Java 8。

      背景知识介绍

      在spark中构建一条Pipeline需要串联多个PipelineStage,每个PipelineStage单独处理一个数据加工环节,如数据清洗、特征提取、特征选择、预估等。PipelineStage按是否有训练训练方法分为Transformer和Estimator两个抽象子类。其中Estimator可以进行训练,有fit抽象方法要实现。

      用于分类回归等任务的Predictor都继承于Estimator;而Transformer无需训练,没有fit方法,一般的数据转换器如VectorAssembler、StopWordsRemover等都是Transformer的子类。值得注意的是,所有由Estimator训练得到的Model类也都是Transformer的子类。

      自定义PipelineStage需要继承Transformer或Estimator并实现他们的方法。除此之外,我们自定义的PipelineStage要能同其他官方定义的PipelineStage一样按照统一的读写流程进行存储和加载。PipelineStage读写基于Param对象,PipelineStage中的成员变量需要用Param类进行封装,然后用PipelineStage类中已实现的Params接口方法对封装后的成员变量进行统一访问和处理。由于PipelineStage具有以上特性,我们自定义PipelineStage编程至少需要以下几个步骤:

      • 继承Transformer或Estimator抽象类;
      • 定义由Param封装的成员变量,并通过调用由PipelineStage类实现的Params接口方法,定义对成员变量进行操作的方法;
      • 实现Transformer或Estimator中的fit、transform等核心抽象方法;
      • 定义或实现读写方法用于存储和加载对象实例。

      下面我们将自定义一个Transformer,并对其中的一些细节与要点进行详述。

      定义一个Transformer

      1. 场景介绍

      本例定义一个名为SeqAssembler的Transformer,用于提取用户最近n次(n>=0,包括本单)下单的序列特征。 输入Dataset包括以下字段:user_id, buy_rn, feat1, feat2, feat3。经过SeqAssembler后输出:user_id, buy_rn, feat1, feat2, feat3, features。 其中features为数组类型,shape (3n, ) :

      输入:

      Java开发Spark应用程序自定义PipeLineStage详解

      输出:

      Java开发Spark应用程序自定义PipeLineStage详解

      2. 代码实现

      2.1 定义并封装成员变量

      SeqAssembler中要定义如下成员变量:

      private String idCol;
      private String rnCol;
      private String[] featCols;
      private String outputCol;
      private Integer limitRn;
      

      使用org.apache.spark.ml.param.Param对成员变量进行封装,String[] 用StringArrayParam封装, Integer成员变量采用String类型的Param进行封装,方便保存的时候进行json化,封装后的成员变量如下:

      private Param<String> idCol;
      private Param<String> rnCol;
      private StringArrayParam featCols;
      private Param<String> outputCol;
      private Param<String> limitRn; //Integer成员变量需要用String类型Param封装,由于保存时要调用JsonEncoder方法,JsonEncoder仅支持String、数组等类型的数据。
      

      此外我们还需要定义一个名为uid成员变量,用于识别SeqAssembler对象,并定义至少两个构造器,需要注意的细节如下:

      • uid不用声明成静态的,同一Spark进程下初始化多个SeqAssembler对象,每个SeqAssembler对象都要有自己的uid,不用全局唯一。
      • uid的初始化需要在Param成员变量初始化之前,有了uid之后才能进行Param成员变量的初始化。
      • 需要至少定义两个构造器,其中一个是无参构造器,另一个是需要传入唯一参数String uid的有参构造器,有参构造器用于load过程中构造SeqAssembler对象。各成员变量需要在构造器中完成初始化。

      至此,SeqAssembler类中定义内容如下:

      public class SeqAssembler extends Transformer {
          private String uid;
          private Param<String> idCol;
          private Param<String> rnCol;
          private StringArrayParam featCols;
          private Param<String> outputCol;
          private Param<String> limitRn;
          /**
           * 定义一个辅助Param初始化的方法,在构造器中对各Param成员变量进行初始化
           */
          public void initParam(){
              idCol = new Param<String>(this,"idCol","Column name for id");
              rnCol = new Param<String>(this,"rnCol","Column name for sequential rn");
              featCols = new StringArrayParam(this,"featCols","Column names of features");
              outputCol = new Param<String>(this,"outputCol","Column name of output");
              limitRn = new Param<String>(this,"limitRn","Column name of limitRn");
          }
          public SeqAssembler() {
              uid = Identifiable$.MODULE$.randomUID("SeqAssembler"); //uid初始化在Param类型成员变量前
              initParam();
          }
          public SeqAssembler(String value){
              uid = value; //uid初始化在Param类型成员变量前
              initParam();
          }
          @Override
          public Dataset<Row> transform(Dataset<?> dataset) {
              return null;
          }
          @Override
          public StructType transformSchema(StructType schema) {
              return null;
          }
          @Override
          public Transformer copy(ParamMap extra) {
              return null;
          }
          @Override
          public String uid() {
              return null;
          }
      }
      

      接着定义get、set方法,调用PipelineStage类中已实现的Params接口下的$()和set()方法,方便对Param封装后的成员变量进行赋值取值操作。

         /**
           * 定义对Param成员变量进行操作的get/set方法, 通过调用PipelineStage类中已实现的Params的$()、set()方法对
           * Param成员变量进行操作。
           * $()、set()对Param进行操作前会调用shouldOwn(),验证被操作的Param成员变量是否已经被维护到params数组中
           */
          public String getIdCol() {
              return this.$(idCol);
          }
          public SeqAssembler setIdCol(String value) {
              return (SeqAssembler) this.<String>set(idCol,value);
          }
          public String getRnCol() {
              return this.$(rnCol);
          }
          public SeqAssembler setRnCol(String value) {
              return (SeqAssembler) this.<String>set(rnCol,value);
          }
          public String[] getFeatCols() {
              return this.$(featCols);
          }
          public SeqAssembler setFeatCols(String[] value) {
              return (SeqAssembler) this.<String[]>set(featCols,value);
          }
          public String getOutputCol() {
              return this.$(outputCol);
          }
          public SeqAssembler setOutputCol(String value) {
              return (SeqAssembler) this.<String>set(outputCol,value);
          }
          public Integer getLimitRn() {
              return Integer.parseInt(this.$(limitRn));
          }
          public SeqAssembler setLimitRn(Integer value) {
              return (SeqAssembler) this.<String>set(limitRn,value.toString());
          }
      

      此外,我们还需要为每个Param定义一个public方法,因为Params接口会延迟加载并生成一个名为params数组。延迟加载时通过反射扫描一遍public方法, 将作为返回值的Param成员变量维护进params数组中。

      Params源码中通过反射延迟加载params数组:

      Java开发Spark应用程序自定义PipeLineStage详解

      DefaultParamsReader的load方法中通过params数组反射构造对象:

      Java开发Spark应用程序自定义PipeLineStage详解

      如果Param封装的字段缺乏作用域pubic、无参、返回类型为对应Param的方法,在load过程中通过反射构造出的对象会出现成员变量缺失,用读取的metadata装配时会出错。 因此我们需要为每个Param定义如下方法:

          /**
           * 需要为每个Param定义一个public方法, 因为Params会延迟加载并生成一个Param[] params数组,
           * params的生成方式是通过反射扫描一遍public方法, 将作为返回值的Param成员变量维护进params数组中。
           *
           * org.apache.spark.ml.param.shared下的所有接口都有一个以Param类型为返回值的方法,也是为了方便子类
           * 通过实现org.apache.spark.ml.param.shared接口,达到将Param成员变量维护进params数组的目的。
           */
          public Param<String> idCol(){
              return idCol;
          }
          public Param<String> rnCol(){
      js        return rnCol;
          }
          public StringArrayParam featCols(){
              return featCols;
          }
          public Param<String> outputCol(){
              return outputCol;
          }
          pubhttp://www.devze.comlic Param<String> limitRn(){
              return limitRn;
          }
      

      如果研究spark ml的源码不难发现,官方的各个Transformer子类都实现org.apache.spark.ml.param.shared包下HasInputCols、HasOutputCol等接口,这些接口下都有一个满足以上3要素(public、无参、Param类型返回)的方法,用途与我们上面为每个Param定义的方法类似。

      2.2 实现抽象方法

      接下来,我们需要实现从Transformer类中继承来各个抽象方法,包括transform、transformSchema、copy、uid。

      transform方法中包含的是整个数据处理的逻辑,该方法定义的原则是不改变原数据的字段与条数,只在原数据基础上新增字段。下面实现的transform方法用于本例中最近几次下单 特征的提取。

      @Override
      public Dataset<Row> transform(Dataset<?> dataset) {
          Dataset<Row> df = dataset.toDF();
          String idColName = getIdCol();
          String rnColName = getRnCol();
          String[] featCols = getFeatCols();
          String outputCol = getOutputCol();
          Integer limitRnValue = getLimitRn();
          // 获取原始数据中rn字段下最大值
          Integer maxRN = (Integer) df.groupBy().max(rnColName).first().get(0);
          // 限制设置的limitRN不得大于maxRn。
          if(limitRnValue>maxRN){
              throw new ValueException(String.format( "the value of limitRn %d is larger than max value of rnCol %d, choose a smaller limitRn instead",
                      limitRnValue,maxRN));
          }
      //        定义一个备用的Dataset df_c
          Dataset<Row> df_c = df.select(idColName,rnColName);
          df_c = df_c.withColumnRenamed(idColName,idColName+"_c")
                  .withColumnRenamed(rnColName, rnColName+"_c");
      //        将df与df_c进行连接,连接条件df.idCol==df_c.idCol && df.rnCol<=df_c.rnCol
          Column joinExpr = df.col(idColName).equalTo(df_c.col(idColName+"_c")).and(df.col(rnColName).leq(df_c.col(rnColName+"_c")));
          Dataset<Row> joinedDf = df_c.join(df,joinExpr,"left");
      //        打上一列rnCol_p = df_c.rnCol - df.rnCol 最近购买次序列,当前次的值0
          String pivotRnColName = rnColName+"_p";
          joinedDf = joinedDf.withColumn(pivotRnColName,joinedDf.col(rnColName+"_c").minus(joinedDf.col(rnColName)));
      //        表格透视前的准备工作,定义一些map和array,用于记录表格透视计算规则和透视后的列名
          Map<String, String> featAggMap = new HashMap<>();
          Integer featNums = featCols.length;
          String[] pivotColNames = new String[maxRN*featNums-1];
          String firstPivotColName = "0_min"+"("+featCols[0]+")";
          int n = 0;
          for(int i=0; i<maxRN; i++){
              for(String feat:featCols){
                  if(i==0){
                      featAggMap.put(feat,"min");
                  }
                  if(n>0){
                      pivotColNames[n - 1] = String.valueOf(i) + "_min" + "(" + feat + ")";
                  }
                  n++;
              }
          }
      //        对表格进行透视、特征字段合并、得到outputCol
          Dataset<Row> transformed = joinedDf.groupBy(joinedDf.col(idColName+"_c"), joinedDf.col(rnColName+"_c")).pivot(pivotRnColName).agg(featAggMap);
          transformed = transformed.withColumn(outputCol, functions.array(firstPivotColName,pivotColNames));
      //        将outputCol连到原df上,保证经过transform后的df只在原数据基础上新增一列
          Column joinExprT = df.col(idColName).equalTo(transformed.col(idColName+"_c")).and(df.col(rnColName).equalTo(transformed.col(rnColName+"_c")));
          df = df.join(transformed.select(idColName+"_c",rnColName+"_c",outputCol),joinExprT,"left").drop(idColName+"_c",rnColName+"_c");
          return df;
      }
      

      实现transformSchema方法,通常在其中定义输入数据类型判断的逻辑,并返回一个与transform方法输出的Dataset相对应的schema:

          @Override
          /**
           * transformSchema中定义输入数据类型判断的逻辑,并返回一个与transform方法输出的Dataset相对应的schema
           */
          public StructType transformSchema(StructType schema) {
              HashSet<String> featColSet = new HashSet<String>(Arrays.asList(getFeatCols()));
              StructField[] fields = schema.fields();
              for(StructField field:fields){
                  if(featColSet.contains(field.name())){
                      if(!field.dataType().sameType(DoubleType)&&!field.dataType().sameType(IntegerType)){
                          throw new TypeConstraintException(String.format("featCol DataType need DoubleType or IntegerType, " +
                                  "but column %s is a %s." ,field.name(),field.dataType().typeName()));
                      }
                  }
              }
              StructType addedSchema = schema.add(getOutputCol(), new VectorUDT(), true);
              return addedSchema;
          }
      

      实现uid与copy方法:

          @Override
          public Transformer copy(ParamMap extra) {
              return this.<SeqAssembler>defaultCopy(extra);
          }
          @Override
          public String uid() {
              return uid;
          }
      

      最后,我们需要实现和定义读写方法,其中用于写的两个方法write()、save()通过实现 DefaultParamsWritable接口来实现;用于读的两个方法read()、load()直接自定义实现,需要声明为静态方法。

          /**
           * 调用DefaultParamsWriter和DefaultParamsReader实现write()/save(), read()/load()方法.
           */
          @Override
          public MLWriter write() {
              MLWriter defaultParamsWriter = new DefaultParamsWriter(this);
              return defaultParamsWriter;
          }
          @Override
          public void save(String path) throws IOException {
              write().saveImpl(path);
          }
          public static MLReader read() {
              MLReader defaultParamsReader = new DefaultParamsReader();
              return defaultParamsReader;
          }
          public static SeqAssembler load(String path) {
              return (SeqAssembler) read().load(path);
          }
      

      最终完整的SeqAssembler类如下:

      import jdk.nashorn.internal.runtime.regexp.joni.exception.ValueException;
      import org.apache.spark.ml.Transformer;
      import org.apache.spark.ml.linalg.VectorUDT;
      import org.apache.spark.ml.param.Param;
      import org.apache.spark.ml.param.ParamMap;
      import org.apache.spark.ml.param.StringArrayParam;
      import org.apache.spark.ml.util.*;
      import org.apache.spark.sql.Column;
      import org.apache.spark.sql.Dataset;
      import org.apache.spark.sql.Row;
      import org.apache.spark.sql.functions;
      import static org.apache.spark.sql.types.DataTypes.DoubleType;
      import static org.apache.spark.sql.types.DataTypes.IntegerType;
      import org.apache.spark.sql.types.StructField;
      import org.apache.spark.sql.types.StructType;
      import scala.actors.threadpool.Arrays;
      import javax.XML.bind.TypeConstraintException;
      import java.io.IOException;
      import java.util.HashMap;
      import java.util.HashSet;
      import java.util.Map;
      /**
       * @author wangjiahui
       * @create 2021-03-12-21:00
       */
      public class SeqAssembler extends Transformer implements DefaultParamsWritable {
          private String uid;
          private Param<String> idCol;
          private Param<String> rnCol;
          private StringArrayParam featCols;
          private Param<String> outputCol;
          private Param<String> limitRn;
          /**
           * 定义一个辅助Param初始化的方法,在构造器中对各Param成员变量进行初始化
           */
          public void initParam(){
              idCol = new Param<String>(this,"idCol","Column name for id");
              rnCol = new Param<String>(this,"rnCol","Column name for sequential rn");
              featCols = new StringArrayParam(this,"featCols","Column names of features");
              outputCol = new Param<String>(this,"outputCol","Column name of output");
              limitRn = new Param<String>(this,"limitRn","Column name of limitRn");
          }
          public SeqAssembler() {
              uid = Identifiable$.MODULE$.randomUID("SeqAssembler"); //uid初始化在Param类型成员变量前
              initParam();
          }
          public SeqAssembler(String value){
              uid = value; //uid初始化在Param类型成员变量前
              initParam();
          }
          /**
           * 需要为每个Param定义一个public方法, 因为Params会延迟加载并生成一个Param[] params数组,
           * params的生成方式是通过反射扫描一遍public方法, 将作为返回值的Param成员变量维护进params数组中。
           *
           * org.apache.spark.ml.param.shared下的所有接口都有一个以Param类型为返回值的方法,也是为了方便子类
           * 通过实现org.apache.spark.ml.param.shared接口,达到将Param成员变量维护进params数组的目的。
           */
          public Param<String> idCol(){
              return idCol;
          }
          public Param<String> rnCol(){
              return rnCol;
          }
          public StringArrayParam featCols(){
              return featCols;
          }
          public Param<String> outputCol(){
              return outputCol;
          }
          public Param<String> limitRn(){
              return limitRn;
          }
          /**
           * 定义对Param成员变量进行操作的get/set方法, 通过调用PipelineStage类中已实现的Params的$()、set()方法对
           * Param成员变量进行操作。
           * $()、set()对Param进行操作前会调用shouldOwn(),验证被操作的Param成员变量是否已经被维护到params数组中
           */
          public String getIdCol() {
              return this.$(idCol);
          }
          public SeqAssembler setIdCol(String value) {
              return (SeqAssembler) this.<String>set(idCol,value);
          }
          public String getRnCol() {
              return this.$(rnCol);
          }
          public SeqAssembler setRnCol(String value) {
              return (SeqAssembler) this.<String>set(rnCol,value);
          }
          public String[] getFeatCols() {
              return this.$(featCols);
          }
          public SeqAssembler setFeatCols(String[] value) {
              return (SeqAssembler) this.<String[]>set(featCols,value);
          }
          public String getOutputCol() {
              return this.$(outputCol);
          }
          public SeqAssembler setOutputCol(String value) {
              return (SeqAssembler) this.<String>set(outputCol,value);
          }
          public Integer getLimitRn() {
              return Integer.parseInt(this.$(limitRn));
          }
          public SeqAssembler setLimitRn(Integer value) {
              return (SeqAssembler) this.<String>set(limitRn,value.toString());
          }
          @Override
          public Dataset<Row> transform(Dataset<?> dataset) {
              Dataset<Row> df = dataset.toDF();
              transformSchema(dataset.schema());
              String idColName = getIdCol();
              String rnColName = getRnCol();
              String[] featCols = getFeatCols();
              String outputCol = getOutputCol();
              Integer limitRnValue = getLimitRn();
              // 获取原始数据中rn字段下最大值
              Integer maxRN = (Integer) df.groupBy().max(rnColName).first().get(0);
              // 限制设置的limitRN不得大于maxRn。
              if(limitRnValue>maxRN){
                  throw new ValueException(String.format( "the value of limitRn %d is larger than max value of rnCol %d, choose a smaller limitRn instead",
                          limitRnValue,maxRN));
              }
      //        定义一个备用的Dataset df_c
              Dataset<Row> df_c = df.select(idColName,rnColName);
              df_c = df_c.withColumnRenamed(idColName,idColName+"_c")
                      .withColumnRenamed(rnColName, rnColName+"_c");
      //        将df与df_c进行连接,连接条件df.idCol==df_c.idCol && df.rnCol<=df_c.rnCol
              Column joinExpr = df.col(idColName).equalTo(df_c.col(idColName+"_c")).and(df.col(rnColName).leq(df_c.col(rnColName+"_c")));
              Dataset<Row> joinedDf = df_c.join(df,joinExpr,"left");
      //        打上一列rnCol_p = df_c.rnCol - df.rnCol 最近购买次序列,当前次的值0
              String pivotRnColName = rnColName+"_p";
              joinedDf = joinedDf.withColumn(pivotRnColName,joinedDf.col(rnColName+"_c").minus(joinedDf.col(rnColName)));
      //        表格透视前的准备工作,定义一些map和array,用于记录表格透视计算规则和透视后的列名
              Map<String, String> featAggMap = new HashMap<>();
              Integer featNums = featCols.length;
              String[] pivotColNames = new String[maxRN*featNums-1];
              String firstPivotColName = "0_min"+"("+featCols[0]+")";
              int n = 0;
              for(int i=0; i<maxRN; i++){
                  for(String feat:featCols){
                      if(i==0){
                          featAggMap.put(feat,"min");
                      }
                      if(n>0){
                          pivotColNames[n - 1] = String.valueOf(i) + "_min" + "(" + feat + ")";
                      }
                      n++;
                  }
              }
      //        对表格进行透视、特征字段合并、得到outputCol
              Dataset<Row> transformed = joinedDf.groupBy(joinedDf.col(idColName+"_c"), joinedDf.col(rnColName+"_c")).pivot(pivotRnColName).agg(featAggMap);
              transformed = transformed.withColumn(outputCol, functions.array(firstPivotColName,pivotColNames));
      //        将outputCol连到原df上,保证经过transform后的df只在原数据基础上新增一列
              Column joinExprT = df.col(idColName).equalTo(transformed.col(idColName+"_c")).and(df.col(rnColName).equalTo(transformed.col(rnColName+"_c")));
              df = df.join(transformed.select(idColName+"_c",rnColName+"_c",outputCol),joinExprT,"left").drop(idColName+"_c",rnColName+"_c");
              return df;
          }
          @Override
          /**
           * transformSchema中定义输入数据类型判断的逻辑,并返回一个与transform方法输出的Dataset相对应的schema
           */
          public StructType transformSchema(StructType schema) {
              HashSet<String> featColSet = new HashSet<String>(Arrays.asList(getFeatCols()));
              StructField[] fields = schema.fields();
              for(StructField field:fields){
                  if(featColSet.contains(field.name())){
                      if(!field.dataType().sameType(DoubleType)&&!field.dataType().sameType(IntegerType)){
                          throw new TypeConstraintException(String.format("featCol DataType need DoubleType or IntegerType, " +
                                  "but column %s is a %s.",field.name(),field.dataType().typeName()));
                      }
                  }
              }
              StructType addedSchema = schema.add(getOutputCol(), new VectorUDT(), true);
              return addedSchema;
          }
          @Override
          public Transformer copy(ParamMap extra) {
              return this.<SeqAssembler>defaultCopy(extra);
          }
          @Override
          public String uid() {
              return uid;
          }
          /**
           * 调用DefaultParamsWriter和DefaultParamsReader实现write()/save(), read()/load()方法.
           */
          @Override
          public MLWriter write() {
              MLWriter defaultParamsWriter = new DefaultParamsWriter(this);
              return defaultParamsWriter;
          }
          @Override
          public void save(String path) throws IOException {
              write().saveImpl(path);
          }
          public static MLReader read() {
              MLReader defaultParamsReader = new DefaultParamsReader();
              return defaultParamsReader;
          }
          public static SeqAssembler load(String path) {
              return (SeqAssembler) read().load(path);
          }
      }
      

      单元测试代码如下:

      import org.apache.spark.ml.Pipeline;
      import org.apache.jsspark.ml.PipelineStage;
      import org.apache.spark.ml.Transformer;
      import org.apache.spark.sql.Dataset;
      import org.apache.spark.sql.Row;
      import org.apache.spark.sql.RowFactory;
      import org.apache.spark.sql.SparkSession;
      import org.apache.spark.sql.types.DataTypes;
      import org.apache.spark.sql.types.StructField;
      import org.apache.spark.sql.types.StructType;
      import org.junit.Test;
      import java.io.IOException;
      import java.util.ArrayList;
      import java.util.List;
      import static org.apache.spark.sql.types.DataTypes.*;
      import static org.apache.spark.sql.types.DataTypes.IntegerType;
      /**
       * @author wangjiahui
       * @create 2023-01-25-20:51
       */
      public class TestClient {
          @Test
          public void testSeqAssembler(){
              // 配置自己的SparkSession
               SparkSession spark = LocalSparkSession.getSpark();
              // 定义一个测试用的DataSet
              List<Row> rows = new ArrayList<>();
              Row row1 = RowFactory.create("a",1, 2.1, 1, 1);
              Row row2 = RowFactory.create("b",1, 2.0, 3, 2);
              Row row3 = RowFactory.create("b",2, 2.3, 4, 1);
              Row row4 = RowFactory.create("c",1, 3.1, 3, 3);
              Row row5 = RowFactory.create("c",2, 1.5, 3, 7);
              Row row6 = RowFactory.create("c",3, 4.2, 4, 2);
         js     rows.add(row1);
              rows.add(row2);
              rows.add(row3);
              rows.add(row4);
              rows.add(row5);
              rows.add(row6);
              List<StructField> fields = new ArrayList<StructField>();
              StructField col1 = DataTypes.createStructField("user_id", StringType, true);
              StructField col2 = DataTypes.createStructField("buy_rn", IntegerType, true);
              StructField col3 = DataTypes.createStructField("feat_1", DoubleType, true);
              StructField col4 = DataTypes.createStructField("feat_2", IntegerType, true);
              StructField col5= DataTypes.createStructField("feat_3", IntegerType, true);
              fields.add(col1);
              fields.add(col2);
              fields.add(col3);
              fields.add(col4);
              fields.add(col5);
              StructType schema = DataTypes.createStructType(fields);
              Dataset dfr = spark.createDataFrame(rows,schema);
              Dataset<Row> df = dfr.toDF();
              df = df.persist();
              System.out.println("in:");
              df.show();
              df.printSchema();
              // 定义两个seqAssembler
              String[] featCols = new String[] {"feat_1", "feat_2", "feat_3"};
              SeqAssembler seqAssembler1 = new SeqAssembler()
                      .setIdCol("user_id")
                      .setRnCol("buy_rn")
                      .setLimitRn(3)
                      .setFeatCols(featCols)
                      .setOutputCol("features");
              SeqAssembler seqAssembler2 = new SeqAssembler()
                      .setIdCol("user_id")
                      .setRnCol("buy_rn")
                      .setLimitRn(2)
                      .setFeatCols(featCols)
                      .setOutputCol("features");
      //        Dataset<Row> transformed = seqAssembler.transform(df);
              // 定义pipeline
              List<PipelineStage> pipelineStages = new ArrayList<>();
              pipelineStages.add(seqAssembler1);
              pipelineStages.add(seqAssembler2);
              Pipeline pipeline = new Pipeline();
              pipeline.setStages(pipelineStages.toArray(new PipelineStage[pipelineStages.size()]));
              // 写入
              try {
                  pipeline.write().overwrite().save("oss://<自己的路径>");
              } catch (IOException e) {
                  e.printStackTrace();
              }
              // 读取
              Pipeline loadedPipeline = Pipeline.load("oss://<自己的路径>");
              Transformer seqAssemblerLoad = (Transformer) loadedPipeline.getStages()[0];
              // 使用
              Dataset<Row> transformed = seqAssemblerLoad.transform(df);
              System.out.println(开发者_JS培训"out: ");
              transformed.show(false);
              transformed.printSchema();
              spark.close();
          }
      }
      

      3. Pipeline的存储文件

      在oss/hdfs上找到上面单元测试中pipeline的存储路径,并将存储文件夹下载到本地,pipeline存储文件夹中包含metadata, stages两个目录,metadata中存放的是pipeline的信息,包括pipeline的uid、对应stage的uid等。pipeline metadata文件如下:

      Java开发Spark应用程序自定义PipeLineStage详解

      stages目录中存放的是我们定义的两个SeqAssembler的metadata,SeqAssembler的metadata中的文件内容与pipeline的metadata中的文件内容类似,记录了SeqAssembler相关信息与Param数据:

      Java开发Spark应用程序自定义PipeLineStage详解

      小结

      在这篇文章中我们介绍了使用java开发spark如何自定义PipelineStage,并用一个SeqAssembler的例子对自定义PipelineStage中的一些注意事项进行了说明。相信这篇文章对不少java的spark开发者有一定的帮助。

      以上就是Java开发Spark应用程序自定义PipeLineStage详解的详细内容,更多关于Java Spark自定义PipeLineStage的资料请关注我们其它相关文章!

      0

      上一篇:

      下一篇:

      精彩评论

      暂无评论...
      验证码 换一张
      取 消

      最新开发

      开发排行榜