Returning multiple arrays from a user-defined aggregate function (UDAF) in Apache Spark SQL

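I am trying to create a user-defined aggregate function (UDAF) in Java with Apache Spark SQL that returns multiple arrays when it completes. I have searched online and cannot find any examples or advice on how to do this.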

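I can return a single array, but I cannot figure out how to put the data into the right format in the evaluate() method so that it returns multiple arrays.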

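The UDAF itself works, since I can print the arrays inside the evaluate() method. What I cannot figure out is how to return those arrays to the calling code (shown below for reference).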

    UserDefinedAggregateFunction customUDAF = new CustomUDAF();
    DataFrame resultingDataFrame = dataFrame.groupBy()
            .agg(customUDAF.apply(dataFrame.col("long_col"), dataFrame.col("double_col")))
            .as("processed_data");

I have included the entire custom UDAF class below, but the key methods are dataType() and evaluate(), which are shown first.

Any help or advice would be greatly appreciated. Thanks.

    import java.util.ArrayList;
    import java.util.List;

    import org.apache.spark.sql.Row;
    import org.apache.spark.sql.expressions.MutableAggregationBuffer;
    import org.apache.spark.sql.expressions.UserDefinedAggregateFunction;
    import org.apache.spark.sql.types.DataType;
    import org.apache.spark.sql.types.DataTypes;
    import org.apache.spark.sql.types.StructType;

    public class CustomUDAF extends UserDefinedAggregateFunction {

        @Override
        public DataType dataType() {
            // TODO: Is this the correct way to return 2 arrays?
            return new StructType()
                    .add("longArray", DataTypes.createArrayType(DataTypes.LongType, false))
                    .add("dataArray", DataTypes.createArrayType(DataTypes.DoubleType, false));
        }

        @Override
        public Object evaluate(Row buffer) {
            // Data conversion
            List<Long> longList = new ArrayList<>(buffer.<Long>getList(0));
            List<Double> dataList = new ArrayList<>(buffer.<Double>getList(1));

            // Processing of data (omitted)

            // TODO: How to get data into format needed to return 2 arrays?
            return dataList;
        }

        @Override
        public StructType inputSchema() {
            return new StructType()
                    .add("long", DataTypes.LongType)
                    .add("data", DataTypes.DoubleType);
        }

        @Override
        public StructType bufferSchema() {
            return new StructType()
                    .add("longArray", DataTypes.createArrayType(DataTypes.LongType, false))
                    .add("dataArray", DataTypes.createArrayType(DataTypes.DoubleType, false));
        }

        @Override
        public void initialize(MutableAggregationBuffer buffer) {
            // Start each partial aggregate with two empty lists
            buffer.update(0, new ArrayList<Long>());
            buffer.update(1, new ArrayList<Double>());
        }

        @Override
        public void update(MutableAggregationBuffer buffer, Row row) {
            // Append the input row's values to copies of the buffered lists
            List<Long> longList = new ArrayList<>(buffer.<Long>getList(0));
            longList.add(row.getLong(0));
            List<Double> dataList = new ArrayList<>(buffer.<Double>getList(1));
            dataList.add(row.getDouble(1));
            buffer.update(0, longList);
            buffer.update(1, dataList);
        }

        @Override
        public void merge(MutableAggregationBuffer buffer1, Row buffer2) {
            // Concatenate the lists of two partial aggregates into buffer1
            List<Long> longList = new ArrayList<>(buffer1.<Long>getList(0));
            longList.addAll(buffer2.<Long>getList(0));
            List<Double> dataList = new ArrayList<>(buffer1.<Double>getList(1));
            dataList.addAll(buffer2.<Double>getList(1));
            buffer1.update(0, longList);
            buffer1.update(1, dataList);
        }

        @Override
        public boolean deterministic() {
            return true;
        }
    }

Update: Based on zero323's answer, I was able to return both arrays with:

 return new Tuple2(longArray, dataArray); 
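The update does not show where longArray and dataArray come from. A minimal sketch, assuming they are built from the two buffer lists that evaluate() already reads: the boxed lists are unboxed into primitive arrays with the JDK stream API. The class BufferConversion and its methods are hypothetical, and Spark is deliberately left out so the sketch runs with the JDK alone. Inside evaluate() the two arrays would then be returned together, either as the Tuple2 above or, assuming Spark's struct conversion accepts a Row for a StructType result, as RowFactory.create(longArray, dataArray), with the fields in the same order as dataType().

```java
import java.util.Arrays;
import java.util.List;

// Hypothetical helper (not part of Spark): turns the boxed lists read from the
// aggregation buffer into primitive arrays for the two ArrayType result fields.
final class BufferConversion {

    private BufferConversion() {}

    // Unboxes a List<Long> (e.g. the result of buffer.getList(0)) into a long[].
    static long[] toLongArray(List<Long> values) {
        return values.stream().mapToLong(Long::longValue).toArray();
    }

    // Unboxes a List<Double> (e.g. the result of buffer.getList(1)) into a double[].
    static double[] toDoubleArray(List<Double> values) {
        return values.stream().mapToDouble(Double::doubleValue).toArray();
    }

    public static void main(String[] args) {
        // Stand-ins for the lists evaluate() would read from the buffer.
        List<Long> longList = Arrays.asList(1L, 2L, 3L);
        List<Double> dataList = Arrays.asList(1.0, 2.0, 3.0);

        long[] longArray = toLongArray(longList);
        double[] dataArray = toDoubleArray(dataList);

        // In evaluate() these two arrays would be wrapped into one struct value
        // (a Tuple2, or a Row) in the field order declared by dataType().
        System.out.println(Arrays.toString(longArray) + " " + Arrays.toString(dataArray));
    }
}
```

Keeping the conversion in a plain static helper means the unboxing can be checked without a SparkContext; only the final wrapping step depends on Spark.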

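Getting the data back out of the result was a bit harder: it involved deconstructing the DataFrame into Java lists and then building it back into a DataFrame.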

As far as I can tell, returning a tuple should be enough. In Scala:

    import org.apache.spark.sql.expressions._
    import org.apache.spark.sql.types._
    import org.apache.spark.sql.functions.udf
    import org.apache.spark.sql.{Row, Column}

    object DummyUDAF extends UserDefinedAggregateFunction {
      def inputSchema = new StructType().add("x", StringType)
      def bufferSchema = new StructType()
        .add("buff", ArrayType(LongType))
        .add("buff2", ArrayType(DoubleType))
      // Two array fields in a struct: the tuple returned by evaluate fills them in order
      def dataType = new StructType()
        .add("xs", ArrayType(LongType))
        .add("ys", ArrayType(DoubleType))
      def deterministic = true
      def initialize(buffer: MutableAggregationBuffer) = {}
      def update(buffer: MutableAggregationBuffer, input: Row) = {}
      def merge(buffer1: MutableAggregationBuffer, buffer2: Row) = {}
      def evaluate(buffer: Row) = (Array(1L, 2L, 3L), Array(1.0, 2.0, 3.0))
    }

    // Run in spark-shell, where sc and the implicits for toDF and $"k" are in scope
    val df = sc.parallelize(Seq(("a", 1), ("b", 2))).toDF("k", "v")
    df.select(DummyUDAF($"k")).show(1, false)

    // +---------------------------------------------------+
    // |(DummyUDAF$(k),mode=Complete,isDistinct=false)     |
    // +---------------------------------------------------+
    // |[WrappedArray(1, 2, 3),WrappedArray(1.0, 2.0, 3.0)]|
    // +---------------------------------------------------+