Как pyspark функция mapPartitions работать?

поэтому я пытаюсь изучить Spark с помощью Python (Pyspark). Я хочу знать, как функция mapPartitions работа. Это то, что он принимает и что дает. Я не смог найти ни одного подходящего примера из интернета. Допустим, у меня есть объект RDD, содержащий списки, такие как ниже.

[ [1, 2, 3], [3, 2, 4], [5, 2, 7] ] 

и я хочу удалить элемент 2 из всех списков, как бы я этого достиг, используя mapPartitions.

3 ответов


mapPartition следует рассматривать как операцию карты над разделами, а не над элементами раздела. Его вход-это набор текущих разделов, его выход будет другим набором разделов.

функция, которую вы передаете map, должна принимать отдельный элемент вашего RDD

функция, которую вы передаете mapPartition, должна принимать итерацию вашего типа RDD и возвращать и итерацию другого или того же типа.

в вашем случае вы, вероятно, просто хотите сделать что-то вроде

def filterOut2(line):
    return [x for x in line if x != 2]

filtered_lists = data.map(filterOut2)

Если вы хотите использовать mapPartition, это будет

def filterOut2FromPartion(list_of_lists):
  final_iterator = []
  for sub_list in list_of_lists:
    final_iterator.append( [x for x in sub_list if x != 2])
  return iter(final_iterator)

filtered_lists = data.mapPartition(filterOut2FromPartion)

проще использовать mapPartitions с функцией генератора, используя yield синтаксис:

def filter_out_2(partition):
    for element in partition:
        if element != 2:
            yield element

filtered_lists = data.mapPartitions(filter_out_2)

нужен последний Iter

def filter_out_2(partition):
for element in partition:
    sec_iterator = []
    for i in element:
        if i!= 2:
            sec_iterator.append(i)
    yield sec_iterator

filtered_lists = data.mapPartitions(filter_out_2)
for i in filtered_lists.collect(): print(i)