How to get an element by index in a Spark RDD (Java)

I know of the method rdd.first(), which gives me the first element of an RDD.

There is also the method rdd.take(num), which gives me the first "num" elements.

But isn't it possible to get an element by index?

Thanks.

This should be possible by first indexing the RDD. The transformation zipWithIndex provides a stable index, numbering each element in its original order.

Given: rdd = (a,b,c)

 val withIndex = rdd.zipWithIndex // ((a,0),(b,1),(c,2)) 

To look up an element by index, this form is not useful. First we need to use the index as the key:

 val indexKey = withIndex.map{case (k,v) => (v,k)} //((0,a),(1,b),(2,c)) 

Now it is possible to use the lookup action on the PairRDD to find an element by key:

 val b = indexKey.lookup(1) // Array(b) 

If you expect to use lookup often on the same RDD, I would recommend caching the indexKey RDD to improve performance.

How to do this using the Java API is left as an exercise for the reader.
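For what it's worth, here is a minimal sketch of that exercise, assuming Spark's Java API (JavaSparkContext, zipWithIndex, mapToPair, lookup) and Java 8 lambdas; the class and variable names are mine, and an anonymous PairFunction would work on older Java versions:

 import java.util.Arrays;
 import java.util.List;

 import org.apache.spark.api.java.JavaPairRDD;
 import org.apache.spark.api.java.JavaRDD;
 import org.apache.spark.api.java.JavaSparkContext;

 import scala.Tuple2;

 public class LookupByIndex {
     public static void main(String[] args) {
         JavaSparkContext sc = new JavaSparkContext("local", "lookup-by-index");
         JavaRDD<String> rdd = sc.parallelize(Arrays.asList("a", "b", "c"));

         // zipWithIndex yields (element, index); swap so the index becomes the key
         JavaPairRDD<Long, String> indexKey = rdd
                 .zipWithIndex()
                 .mapToPair(t -> new Tuple2<>(t._2, t._1))
                 .cache(); // cache here if you will call lookup repeatedly

         List<String> b = indexKey.lookup(1L); // [b]
         System.out.println(b);
         sc.stop();
     }
 }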

I tried this class to fetch an item by index. First, when you construct new IndexedFetcher(rdd, itemClass), it counts the number of elements in each partition of the RDD. Then, when you call indexedFetcher.get(n), it runs a job on only the partition that contains that index.

Note that I needed to compile this using Java 1.7 rather than 1.8; as of Spark 1.1.0, the bundled org.objectweb.asm inside com.esotericsoftware.reflectasm cannot read Java 1.8 classes yet (it throws an IllegalStateException when you try to runJob a Java 1.8 function).

 import java.io.Serializable;

 import org.apache.spark.SparkContext;
 import org.apache.spark.TaskContext;
 import org.apache.spark.rdd.RDD;

 import scala.reflect.ClassTag;

 public static class IndexedFetcher<E> implements Serializable {
     private static final long serialVersionUID = 1L;
     public final RDD<E> rdd;
     public Integer[] elementsPerPartitions;
     private Class<?> clazz;

     public IndexedFetcher(RDD<E> rdd, Class<?> clazz) {
         this.rdd = rdd;
         this.clazz = clazz;
         SparkContext context = this.rdd.context();
         ClassTag<Integer> intClassTag = scala.reflect.ClassTag$.MODULE$.apply(Integer.class);
         // One job over all partitions, recording how many elements each one holds
         elementsPerPartitions = (Integer[]) context.runJob(rdd, IndexedFetcher.<E>countFunction(), intClassTag);
     }

     public static class IteratorCountFunction<E>
             extends scala.runtime.AbstractFunction2<TaskContext, scala.collection.Iterator<E>, Integer>
             implements Serializable {
         private static final long serialVersionUID = 1L;

         @Override
         public Integer apply(TaskContext taskContext, scala.collection.Iterator<E> iterator) {
             int count = 0;
             while (iterator.hasNext()) {
                 count++;
                 iterator.next();
             }
             return count;
         }
     }

     static <E> scala.Function2<TaskContext, scala.collection.Iterator<E>, Integer> countFunction() {
         scala.Function2<TaskContext, scala.collection.Iterator<E>, Integer> function = new IteratorCountFunction<E>();
         return function;
     }

     public E get(long index) {
         // Walk the per-partition counts until we reach the partition containing `index`
         long remaining = index;
         long totalCount = 0;
         for (int partition = 0; partition < elementsPerPartitions.length; partition++) {
             if (remaining < elementsPerPartitions[partition]) {
                 return getWithinPartition(partition, remaining);
             }
             remaining -= elementsPerPartitions[partition];
             totalCount += elementsPerPartitions[partition];
         }
         throw new IllegalArgumentException(String.format(
                 "Get %d within RDD that has only %d elements", index, totalCount));
     }

     public static class FetchWithinPartitionFunction<E>
             extends scala.runtime.AbstractFunction2<TaskContext, scala.collection.Iterator<E>, E>
             implements Serializable {
         private static final long serialVersionUID = 1L;
         private final long indexWithinPartition;

         public FetchWithinPartitionFunction(long indexWithinPartition) {
             this.indexWithinPartition = indexWithinPartition;
         }

         @Override
         public E apply(TaskContext taskContext, scala.collection.Iterator<E> iterator) {
             int count = 0;
             while (iterator.hasNext()) {
                 E element = iterator.next();
                 if (count == indexWithinPartition)
                     return element;
                 count++;
             }
             throw new IllegalArgumentException(String.format(
                     "Fetch %d within partition that has only %d elements", indexWithinPartition, count));
         }
     }

     public E getWithinPartition(int partition, long indexWithinPartition) {
         System.out.format("getWithinPartition(%d, %d)%n", partition, indexWithinPartition);
         SparkContext context = rdd.context();
         scala.Function2<TaskContext, scala.collection.Iterator<E>, E> function =
                 new FetchWithinPartitionFunction<E>(indexWithinPartition);
         // Run the job on just the one partition that holds the requested index
         scala.collection.Seq<Object> partitions = new scala.collection.mutable.WrappedArray.ofInt(new int[] {partition});
         ClassTag<E> classTag = scala.reflect.ClassTag$.MODULE$.apply(this.clazz);
         E[] result = (E[]) context.runJob(rdd, function, partitions, true, classTag);
         return result[0];
     }
 }
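For context, a hypothetical usage sketch (the variable names are mine, and it assumes an existing JavaRDD<String> named javaRdd):

 // Count elements per partition once up front, then fetch single elements
 // by global index with one small job each.
 RDD<String> rdd = javaRdd.rdd(); // unwrap the underlying Scala RDD
 IndexedFetcher<String> fetcher = new IndexedFetcher<>(rdd, String.class);
 String third = fetcher.get(2); // runs a job only on the partition holding index 2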

I also got stuck on this for a while, so to expand on Maasg's answer, here is how to look up a range of values by index from Java (the 4 variables at the top need to be defined):

 DataFrame df;
 SQLContext sqlContext;
 Long start;
 Long end;

 JavaPairRDD<Row, Long> indexedRDD = df.toJavaRDD().zipWithIndex();
 JavaRDD<Row> filteredRDD = indexedRDD
         .filter((Tuple2<Row, Long> v1) -> v1._2 >= start && v1._2 < end)
         .map(v1 -> v1._1); // drop the index, keeping only the Row
 DataFrame filteredDataFrame = sqlContext.createDataFrame(filteredRDD, df.schema());

Keep in mind that when you run this code, your cluster will need to have Java 8 (since a lambda expression is in use).

Also, note that zipWithIndex can be expensive!
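If you only need a single element once and its index is small, a hedged alternative building on the take(num) method from the question avoids zipWithIndex entirely, at the cost of shipping all preceding elements to the driver:

 // Sketch: take(n) returns the first n elements to the driver, so this is
 // only sensible for small indices. `rdd` is assumed to be a JavaRDD<String>.
 long index = 2;
 String element = rdd.take((int) index + 1).get((int) index);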