Медиана / квантили в пределах группы PySpark

Я хотел бы рассчитать групповые квантили на фрейме данных Spark (используя PySpark). Либо приблизительный, либо точный результат. Я предпочитаю решение, которое я могу использовать в контексте groupBy / agg, чтобы я мог смешивать его с другими агрегатными функциями PySpark. Если это невозможно по какой-то причине, другой подход также был бы прекрасным.

этот вопрос это связано, но не указывает, как использовать approxQuantile в совокупности функция.

у меня также есть доступ к percentile_approx Hive UDF, но я не знаю, как использовать его в качестве агрегатной функции.

для конкретности предположим, что у меня есть следующий фрейм данных:

from pyspark import SparkContext
import pyspark.sql.functions as f

sc = SparkContext()    

df = sc.parallelize([
    ['A', 1],
    ['A', 2],
    ['A', 3],
    ['B', 4],
    ['B', 5],
    ['B', 6],
]).toDF(('grp', 'val'))

df_grp = df.groupBy('grp').agg(f.magic_percentile('val', 0.5).alias('med_val'))
df_grp.show()

ожидаемый результат:

+----+-------+
| grp|med_val|
+----+-------+
|   A|      2|
|   B|      5|
+----+-------+

3 ответов


Я думаю, вам это больше не нужно. Но оставлю его здесь для будущих поколений (то есть меня на следующей неделе, когда я забуду).

from pyspark.sql import Window
import pyspark.sql.functions as F

grp_window = Window.partitionBy('grp')
magic_percentile = F.expr('percentile_approx(val, 0.5)')

df.withColumn('med_val', magic_percentile.over(grp_window))

или для решения именно вашего вопроса, это тоже работает:

df.groupBy('gpr').agg(magic_percentile.alias('med_val'))

и в качестве бонуса вы можете передать массив процентилей:

quantiles = F.expr('percentile_approx(val, array(0.25, 0.5, 0.75))')

и вы получите список.


Так как у вас есть доступ к percentile_approx, одним из простых решений было бы использовать его в :

from pyspark.sql import SQLContext
sqlContext = SQLContext(sc)

df.registerTempTable("df")
df2 = sqlContext.sql("select grp, percentile_approx(val, 0.5) as med_val from df group by grp")

к сожалению, и насколько мне известно, кажется, что это невозможно сделать с помощью" чистых " команд PySpark (решение Shaido предоставляет обходной путь с SQL), и причина очень элементарна: в отличие от других агрегатных функций, таких как mean, approxQuantile не возвращает Column тип, но список.

давайте посмотрим быстрый пример с вашими данными образца:

spark.version
# u'2.2.0'

import pyspark.sql.functions as func
from pyspark.sql import DataFrameStatFunctions as statFunc

# aggregate with mean works OK:
df_grp_mean = df.groupBy('grp').agg(func.mean(df['val']).alias('mean_val'))
df_grp_mean.show()
# +---+--------+ 
# |grp|mean_val|
# +---+--------+
# |  B|     5.0|
# |  A|     2.0|
# +---+--------+

# try aggregating by median:
df_grp_med = df.groupBy('grp').agg(statFunc(df).approxQuantile('val', [0.5], 0.1))
# AssertionError: all exprs should be Column

# mean aggregation is a Column, but median is a list:

type(func.mean(df['val']))
# pyspark.sql.column.Column

type(statFunc(df).approxQuantile('val', [0.5], 0.1))
# list

Я сомневаюсь, что оконный подход сделает любая разница, поскольку, как я уже сказал, основная причина очень элементарна.

см. также мой ответ здесь для получения дополнительной информации.