Парный график с тепловыми картами (возможно, логарифмическими)?

Как создать парный график в Python, как показано ниже: enter image description here но с тепловые карты вместо точек (или вместо графика "шестнадцатеричного Бина")? Наличие возможности вместо отображения логарифмических отсчетов тепловой карты было бы дополнительным бонусом. (Гистограммы по диагонали были бы в полном порядке.)

под "тепловой картой" я имею в виду 2D-гистограмму отсчетов, отображаемую как Сиборн или Википедии тепло карты:

enter image description here

использование панд, seaborn или matplotlib было бы здорово (возможно, сюжет.ly).

Я пробовал наивные вариации следующего, но безрезультатно:

pairplot = sns.PairGrid(data)  # sns = seaborn
pairplot.map_offdiag(sns.kdeplot)  # Off-diagnoal heat map wanted instead!
pairplot.map_diag(plt.hist)  # plt = matplotlib.pyplot

(выше используется оценка плотности ядра, которую я не хочу; шестнадцатеричная сетка также может быть получена с пандами, но я ищу вместо этого "квадратную" 2D-гистограмму и Matplotlib hist2d() не работает).

2 ответов


ключом к вашему ответу является функция matplotlib plt.hist2d, который рассчитывает графики в прямоугольных ячейках, используя цветовую шкалу ("тепловая карта"). Его API почти совместим с PairGrid, но не совсем, потому что он не знает, как справиться с color= kwarg. Это легко решить, написав функцию тонкой оболочки. Также, если вы хотите, чтобы цветовая карта логарифмически отображала подсчеты, это легко сделать с помощью matplotlib LogNorm:

import seaborn as sns
import matplotlib.pyplot as plt
from matplotlib.colors import LogNorm
sns.set_style("white")
iris = sns.load_dataset("iris")    

g = sns.PairGrid(iris)
g.map_diag(plt.hist, bins=20)

def pairgrid_heatmap(x, y, **kws):
    cmap = sns.light_palette(kws.pop("color"), as_cmap=True)
    plt.hist2d(x, y, cmap=cmap, cmin=1, **kws)

g.map_offdiag(pairgrid_heatmap, bins=20, norm=LogNorm())

enter image description here


Приготовление:

%matplotlib inline #for jupyter notebook

import matplotlib.pyplot as plt
import seaborn as sns
iris = sns.load_dataset("iris")

ответ:

g = sns.PairGrid(iris)
g = g.map_upper(plt.scatter,marker='+')
g = g.map_lower(sns.kdeplot, cmap="hot",shade=True)
g = g.map_diag(sns.kdeplot, shade=True)
sns.plt.show()

enter image description here

ответ:

g = sns.PairGrid(iris)
g = g.map_upper(plt.scatter)
g = g.map_lower(sns.kdeplot, cmap="hot",shade=True)
g = g.map_diag(plt.hist)
sns.plt.show()

enter image description here