morikomorou’s blog

自分が学んだことなどの備忘録的なやつ

【python】多数の説明変数間の相関をインタラクティブに可視化する方法

はじめに

前回の記事で、seabornのheatmapやpairplotを使用して複数のデータ間の相関係数をヒートマップで表示したり、散布図で一覧表示したりする方法について触れました。

前回は4つのデータ間の関係性だけに絞って可視化しましたが、説明変数が多くなるとかなり図がごちゃごちゃしてしまうし、描画も遅いし、見づらいしで使い勝手はあまりよくありません。
一行で実装できるのですごく手軽ではあるんですが残念です。

見づらくてつらいつらいの図:

そこで、グラフをインタラクティブに操作できるようにして見やすくしようと思います。


どういうものを作るかというと、図の左側に変数間の相関係数行列をヒートマップで表示します。
ヒートマップをマウスでクリックすると、クリックした箇所の変数間の関係を散布図で図の右側に表示するというものです。

matplotlibとseabornだけで作れるので、サクッとデータ分析する際に便利かなと思います。
解説していきます。



コード解説

今までさんざんやってきたmatplotlibのマウスイベント処理を使えば容易に実装できます。

データの準備

とりあえずグラフ化用のデータは前回と同じものを用意します。
機械学習用ライブラリのscikit-learnで用意されているボストンの住宅価格のデータセットです。

import numpy as np
import pandas as pd
from sklearn.datasets import load_boston

data = load_boston()
df = pd.DataFrame(data=data.data, columns=data.feature_names)
df['TARGET'] = pd.DataFrame(data=data.target)
print(df.head())

CRIM ZN INDUS CHAS NOX RM AGE DIS RAD TAX PTRATIO B LSTAT TARGET
0 0.00632 18.0 2.31 0.0 0.538 6.575 65.2 4.0900 1.0 296.0 15.3 396.90 4.98 24.0
1 0.02731 0.0 7.07 0.0 0.469 6.421 78.9 4.9671 2.0 242.0 17.8 396.90 9.14 21.6
2 0.02729 0.0 7.07 0.0 0.469 7.185 61.1 4.9671 2.0 242.0 17.8 392.83 4.03 34.7
3 0.03237 0.0 2.18 0.0 0.458 6.998 45.8 6.0622 3.0 222.0 18.7 394.63 2.94 33.4
4 0.06905 0.0 2.18 0.0 0.458 7.147 54.2 6.0622 3.0 222.0 18.7 396.90 5.33 36.2

今回は変数を絞らず14変数全部使って可視化しますのでこのまま使います。

前回同様、相関係数も求めておきます。

corr = df.corr()
print(corr)

CRIM ZN INDUS CHAS NOX RM AGE DIS RAD TAX PTRATIO B LSTAT TARGET
CRIM 1.000000 -0.200469 0.406583 -0.055892 0.420972 -0.219247 0.352734 -0.379670 0.625505 0.582764 0.289946 -0.385064 0.455621 -0.388305
ZN -0.200469 1.000000 -0.533828 -0.042697 -0.516604 0.311991 -0.569537 0.664408 -0.311948 -0.314563 -0.391679 0.175520 -0.412995 0.360445
INDUS 0.406583 -0.533828 1.000000 0.062938 0.763651 -0.391676 0.644779 -0.708027 0.595129 0.720760 0.383248 -0.356977 0.603800 -0.483725
CHAS -0.055892 -0.042697 0.062938 1.000000 0.091203 0.091251 0.086518 -0.099176 -0.007368 -0.035587 -0.121515 0.048788 -0.053929 0.175260
NOX 0.420972 -0.516604 0.763651 0.091203 1.000000 -0.302188 0.731470 -0.769230 0.611441 0.668023 0.188933 -0.380051 0.590879 -0.427321
RM -0.219247 0.311991 -0.391676 0.091251 -0.302188 1.000000 -0.240265 0.205246 -0.209847 -0.292048 -0.355501 0.128069 -0.613808 0.695360
AGE 0.352734 -0.569537 0.644779 0.086518 0.731470 -0.240265 1.000000 -0.747881 0.456022 0.506456 0.261515 -0.273534 0.602339 -0.376955
DIS -0.379670 0.664408 -0.708027 -0.099176 -0.769230 0.205246 -0.747881 1.000000 -0.494588 -0.534432 -0.232471 0.291512 -0.496996 0.249929
RAD 0.625505 -0.311948 0.595129 -0.007368 0.611441 -0.209847 0.456022 -0.494588 1.000000 0.910228 0.464741 -0.444413 0.488676 -0.381626
TAX 0.582764 -0.314563 0.720760 -0.035587 0.668023 -0.292048 0.506456 -0.534432 0.910228 1.000000 0.460853 -0.441808 0.543993 -0.468536
PTRATIO 0.289946 -0.391679 0.383248 -0.121515 0.188933 -0.355501 0.261515 -0.232471 0.464741 0.460853 1.000000 -0.177383 0.374044 -0.507787
B -0.385064 0.175520 -0.356977 0.048788 -0.380051 0.128069 -0.273534 0.291512 -0.444413 -0.441808 -0.177383 1.000000 -0.366087 0.333461
LSTAT 0.455621 -0.412995 0.603800 -0.053929 0.590879 -0.613808 0.602339 -0.496996 0.488676 0.543993 0.374044 -0.366087 1.000000 -0.737663
TARGET -0.388305 0.360445 -0.483725 0.175260 -0.427321 0.695360 -0.376955 0.249929 -0.381626 -0.468536 -0.507787 0.333461 -0.737663 1.000000

相関係数のヒートマップを表示

先ほど求めた相関係数をヒートマップにしてプロットします。
前回と同様seabornのheatmapを使用します。

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.datasets import load_boston

data = load_boston()
df = pd.DataFrame(data=data.data, columns=data.feature_names)
df['TARGET'] = pd.DataFrame(data=data.target)
cols = df.columns
rows = cols
corr = df.corr()

fig, ax = plt.subplots(1, 2, figsize=(12, 6))
sns.heatmap(corr, ax=ax[0], vmin=-1, vmax=1, annot=True, fmt="3.2f", annot_kws={"size": 5})
ax[0].set_aspect('equal')
plt.show()

今回は図の左側に描画したいので、1×2のサブプロットにしてます。
seabornのheatmapで描画するaxesを指定するのはaxプロパティを使用します。




ヒートマップにマウスクリック時の処理を紐づける

ボタンクリック時の処理なのでfigureに'button_press_event'を紐づけます。

マウスイベントについては下記を参照ください。

fig.canvas.mpl_connect('button_press_event', onclick)

今回は図の左側のグラフだけクリック操作を有効にしたいです。
そのためにはevent.inaxesという変数を利用します。
これはクリックされた点がどのaxesにいたかを返します。

今回はax[0]とax[1]の2つだけなのでevent.inaxesがax[1]なら何もしないという形でonclick関数を定義します。

以下に全コードを示します。

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.datasets import load_boston

data = load_boston()
df = pd.DataFrame(data=data.data, columns=data.feature_names)
df['TARGET'] = pd.DataFrame(data=data.target)
cols = df.columns
rows = cols
corr = df.corr()

def onclick(event):
    x = event.xdata
    y = event.ydata
    axes = event.inaxes
    # axis外クリック時は無視
    if x == None or y == None:
        return
    # ヒートマップ以外をクリック時は無視
    if axes != ax[0]:
        return

    # x、yの小数点切り捨ての値がインデックス番号と一致する
    col = cols[int(x)]
    row = rows[int(y)]

    # 散布図の範囲指定用に描くデータの最大値-最小値を計算しておく
    xwid = df[col].max() - df[col].min()
    ywid = df[row].max() - df[row].min()

    # 右の散布図の更新
    ax[1].set_title('Correlation coefficient ' + col + ' - ' + row + ': ' \
                    + str(round(corr.loc[row, col], 3)))
    ax[1].set_xlabel(col)
    ax[1].set_ylabel(row)
    ln.set_data(df[col], df[row])
    # 散布図の範囲はmax,min値より少し大きめにする
    ax[1].set_xlim(df[col].min() - xwid * 0.1, df[col].max() + xwid * 0.1)
    ax[1].set_ylim(df[row].min() - ywid * 0.1, df[row].max() + ywid * 0.1)
    plt.draw()

fig, ax = plt.subplots(1, 2, figsize=(12, 6))
sns.heatmap(corr, ax=ax[0], vmin=-1, vmax=1, annot=True, fmt="3.2f", annot_kws={"size": 5})
ln, = ax[1].plot([np.nan], [np.nan], 'o')  # 散布図プロット
ax[0].set_aspect('equal')
fig.canvas.mpl_connect('button_press_event', onclick)
plt.show()

散布図のプロットは、scatterでやろうと思ったんですが、エラー出て心折れました。
plotでもちゃんと書けてるので許してください。
クリックした箇所のインデックスの取得ですが、event.xdata、event.ydataの整数部分がインデックス番号になりますので、それを使っています。
heatmapのX軸Y軸は行列のインデックスを表してますので整数部分がインデックス番号として得られるわけです。
実行結果は下記のとおりです。


pick_eventではだめでした

なぜ'pick_event'を使わなかったかですが、pick_eventでヒートマップの要素を取得する際は、取得したい四角形の左下頂点をクリックしないと要素が取得できません。
ユーザーの気分的に四角形の中であればどこでも選択できるようにしておきたかったので、pick_eventではなく、button_press_eventで実装しました。

おわりに

今までのmatplotlibのイベント処理では用途不明のものばかり作ってきましたが、今回初めて使えそうなものじゃないでしょうか??笑

書籍紹介

pandasを使ってゴリゴリデータ分析できるようになる本です。


こちらは実際に使いそうなデータ分析の手法の流れを1から実戦形式で学べる本です。100問の問題形式で一通りの流れに沿って問題を解いていけば機械学習にも入門できていました。
飽きずに続けられたし、pandas力が上がりましたね。