Нажатие сортировки Radix (и python) до ее пределов

я был очень разочарован многими реализациями python radix сортировать там в интернете.

они последовательно используют радиус 10 и получают цифры чисел, которые они перебирают, деля на степень 10 или принимая log10 числа. Это невероятно неэффективно, так как log10 не является особенно быстрой операцией по сравнению с Бит-сдвигом, который почти в 100 раз быстрее!

гораздо более эффективная реализация использует radix 256 и сортирует число байт за байтом. Это позволяет выполнять все "получение байтов" с помощью смехотворно быстрых битовых операторов. К сожалению, кажется, что абсолютно никто не реализовал сортировку radix в python, которая использует битовые операторы вместо логарифмов.

Итак, я взял дело в свои руки и придумал этого зверя, который работает примерно с половиной скорости сортировки на небольших массивах и работает почти так же быстро на больших (например,len вокруг 10,000,000):

import itertools

def radix_sort(unsorted):
    "Fast implementation of radix sort for any size num."
    maximum, minimum = max(unsorted), min(unsorted)

    max_bits = maximum.bit_length()
    highest_byte = max_bits // 8 if max_bits % 8 == 0 else (max_bits // 8) + 1

    min_bits = minimum.bit_length()
    lowest_byte = min_bits // 8 if min_bits % 8 == 0 else (min_bits // 8) + 1

    sorted_list = unsorted
    for offset in xrange(lowest_byte, highest_byte):
        sorted_list = radix_sort_offset(sorted_list, offset)

    return sorted_list

def radix_sort_offset(unsorted, offset):
    "Helper function for radix sort, sorts each offset."
    byte_check = (0xFF << offset*8)

    buckets = [[] for _ in xrange(256)]

    for num in unsorted:
        byte_at_offset = (num & byte_check) >> offset*8
        buckets[byte_at_offset].append(num)

    return list(itertools.chain.from_iterable(buckets))

эта версия сортировки radix работает, находя, какие байты она должна сортировать (если вы передадите ей только целые числа ниже 256, она будет сортировать только один байт и т. д.) затем сортировка каждого байта из LSB вверх, сбрасывая их в ведра, а затем просто связывая ведра вместе. Повторите это для каждого байта, который нужно отсортировать, и у вас есть хороший отсортированный массив за O(n) время.

однако это не так быстро, как могло бы быть, и я хотел бы сделать это быстрее прежде чем я напишу об этом как о лучшей разновидности radix, чем все остальные виды radix.

под управлением cProfile на это говорит мне, что много времени тратится на append метод для списков, что заставляет меня думать, что этот блок:

    for num in unsorted:
        byte_at_offset = (num & byte_check) >> offset*8
        buckets[byte_at_offset].append(num)

на radix_sort_offset ест много времени. Это также блок, который, если вы действительно посмотрите на него, делает 90% работы для всего сорта. Этот код выглядит так, как будто это может быть numpy - ized, что, я думаю, приведет к довольно повышение производительности. К сожалению, я не очень хорошо с numpyболее сложные функции, поэтому не смогли понять это. Помощь будет очень признательна.

в настоящее время я использую itertools.chain.from_iterable разогнуть buckets, но если у кого-то есть более быстрое предложение, я уверен, что это также поможет.

первоначально у меня был get_byte функция, которая вернула nTh байт числа, но встраивание кода дало мне огромное ускорение скорости, поэтому я сделал это.

любой также приветствуются другие замечания по реализации или способам выжать больше производительности. Я хочу услышать все, что у тебя есть.

3 ответов


вы уже поняли, что

for num in unsorted:
    byte_at_offset = (num & byte_check) >> offset*8
    buckets[byte_at_offset].append(num)

где большую часть времени идет - хорошо ;-)

есть два стандартных трюка для ускорения такого рода вещей, оба из которых связаны с перемещением инвариантов из циклов:

  1. вычислить "смещение*8" вне цикла. Сохраните его в локальной переменной. Сохраните умножение на итерацию.
  2. добавить bucketappender = [bucket.append for bucket in buckets] вне цикла. Сохраняет поиск метода на итерацию.

объединить их, и петля выглядит так:

for num in unsorted:
    bucketappender[(num & byte_check) >> ofs8](num)

сворачивание его в один оператор также сохраняет пару локальных опкодов vrbl store/fetch на итерацию.

но, на более высоком уровне, стандартный способ ускорить сортировку radix-использовать больший radix. Что магического около 256? Ничего, кроме того, что это удобно для бит-сдвига. Как и 512, 1024, 2048 ... это классический компромисс времени и пространства.

PS: Для очень длинных чисел,

(num >> offset*8) & 0xff

будет работать быстрее. Это потому что ваш num & byte_check занимает время, пропорциональное log(num) - он обычно должен создавать целое число размером примерно с num.


вы можете просто использовать одну из существующих реализаций C или C++, например например, integer_sort с импульс.Вроде или u4_sort С usort. Удивительно легко вызвать собственный код C или C++ из Python, см. как сортировать массив целых чисел быстрее, чем quicksort?

Я полностью понимаю ваше разочарование. Хотя прошло уже более 2 лет, numpy по-прежнему не имеет сортировки radix. Я дам знать разработчикам NumPy что они могут просто захватить одну из существующих реализаций; лицензирование не должно быть проблемой.


Это старый поток, но я наткнулся на это, когда искал radix сортировать массив положительных целых чисел. Я пытался увидеть, могу ли я сделать что-то лучше, чем уже злобно быстрый timsort (шляпы вам снова, Тим Питерс), который реализует встроенный python сортируется и сортируется! Либо я не понимаю некоторые аспекты вышеуказанного кода, либо, если я это делаю, код, представленный выше, имеет некоторые проблемы IMHO.

  1. он сортирует только байты, начиная с самого высокого байта самый маленький элемент и заканчивающийся самым высоким байтом самого большого элемента. Это может быть нормально в некоторых случаях специальных данных. Но в целом подход не позволяет дифференцировать элементы, которые отличаются из-за более низких битов. Например:

    arr=[65535,65534]
    radix_sort(arr)
    

    производит неправильный выход:

    [65535, 65534]
    
  2. диапазон, используемый для цикла над вспомогательной функцией, неверен. Я имею в виду, что если lowest_byte и highest_byte одинаковы, выполнение вспомогательной функции пропустить. Кстати, мне пришлось изменить xrange на диапазон в 2 местах.

  3. С изменениями для решения вышеуказанных 2 пунктов, я получил его для работы. Но это занимает 10-20 раз время сборки python сортируется или сортируется! Я знаю, что timsort очень эффективен и использует преимущества уже отсортированных запусков в данных. Но я пытался понять, могу ли я использовать предыдущие знания о том, что все мои данные являются целыми положительными числами, для некоторой пользы при сортировке. Почему radix делает так плохо по сравнению с timsort? Размеры массива, которые я использовал, находятся в порядке 80K элементов. Это потому, что реализация timsort в дополнение к своей алгоритмической эффективности имеет и другие эффективности, вытекающие из возможного использования библиотек низкого уровня? Или я что-то упускаю? Измененный код, который я использовал ниже:

    import itertools
    
    def radix_sort(unsorted):
        "Fast implementation of radix sort for any size num."
        maximum, minimum = max(unsorted), min(unsorted)
    
        max_bits = maximum.bit_length()
        highest_byte = max_bits // 8 if max_bits % 8 == 0 else (max_bits // 8) + 1
    
    #    min_bits = minimum.bit_length()
    #    lowest_byte = min_bits // 8 if min_bits % 8 == 0 else (min_bits // 8) + 1
    
        sorted_list = unsorted
    #    xrange changed to range, lowest_byte deleted from the arguments
        for offset in range(highest_byte):
            sorted_list = radix_sort_offset(sorted_list, offset)
    
        return sorted_list
    
    def radix_sort_offset(unsorted, offset):
        "Helper function for radix sort, sorts each offset."
        byte_check = (0xFF << offset*8)
    
    #    xrange changed to range
        buckets = [[] for _ in range(256)]
    
        for num in unsorted:
            byte_at_offset = (num & byte_check) >> offset*8
            buckets[byte_at_offset].append(num)
    
        return list(itertools.chain.from_iterable(buckets))