Фильтровать по значению столбца, равному списку в spark

Я пытаюсь отфильтровать фрейм данных spark на основе того, равны ли значения в столбце списку. Я хотел бы сделать что-то вроде этого:

filtered_df = df.where(df.a == ['list','of' , 'stuff'])

здесь filtered_df содержит только строки, где значение filtered_df.a и ['list','of' , 'stuff'] и типа a is array (nullable = true).

2 ответов


обновление:

с текущими версиями вы можете использовать array литералов:

from pyspark.sql.functions import array, lit

df.where(df.a == array(*[lit(x) for x in ['list','of' , 'stuff']]))

оригинальный ответ:

ну, немного хакерский способ сделать это, который не требует пакетного задания Python, что-то вроде этого:

from pyspark.sql.functions import col, lit, size
from functools import reduce
from operator import and_

def array_equal(c, an_array):
    same_size = size(c) == len(an_array)  # Check if the same size
    # Check if all items equal
    same_items = reduce(
        and_, 
        (c.getItem(i) == an_array[i] for i in range(len(an_array)))
    )
    return and_(same_size, same_items)

быстрый тест:

df = sc.parallelize([
    (1, ['list','of' , 'stuff']),
    (2, ['foo', 'bar']),
    (3, ['foobar']),
    (4, ['list','of' , 'stuff', 'and', 'foo']),
    (5, ['a', 'list','of' , 'stuff']),
]).toDF(['id', 'a'])

df.where(array_equal(col('a'), ['list','of' , 'stuff'])).show()
## +---+-----------------+
## | id|                a|
## +---+-----------------+
## |  1|[list, of, stuff]|
## +---+-----------------+

вы можете создать udf. Например:

def test_in(x):
    return x == ['list','of' , 'stuff']

from pyspark.sql.functions import udf
f = udf(test_in, pyspark.sql.types.BooleanType())
filtered_df = df.where(f(df.a))