morikomorou’s blog

自分が学んだことなどの備忘録的なやつ

【python】seabornで変数間の関係性をお手軽に可視化

はじめに

データ分析や機械学習等を行う際にたくさんあるデータの関係性を調べたいことがあるかと思います。

一つ一つデータを選んで散布図を描いて、どのデータとどのデータにどういう相関があるか調べるというのは非常に大変な作業ですが、
seabornというライブラリを使えば、データ間の関係性や、データ間の相関係数など一瞬で一覧形式で可視化できます。

こんな図がすぐかけます。

では実際にやっていきましょう。




データ間の関係を可視化する

今回は以下2つを紹介します。

  • 全データ同士の散布図を一覧形式で作成
  • 全データ間の相関係数をヒートマップ形式で描画

データの準備

とりあえずグラフ化用の適当なデータを用意します。
今回は、機械学習用ライブラリのscikit-learnで用意されているボストンの住宅価格のデータセットを使用します。
以下コードでpandasのデータフレームに格納できます。

import numpy as np
import pandas as pd
from sklearn.datasets import load_boston

data = load_boston()
df = pd.DataFrame(data=data.data, columns=data.feature_names)
df['TARGET'] = pd.DataFrame(data=data.target)
print(df.head())  # 初めの5レコードを表示

CRIM ZN INDUS CHAS NOX RM AGE DIS RAD TAX PTRATIO B LSTAT TARGET
0 0.00632 18.0 2.31 0.0 0.538 6.575 65.2 4.0900 1.0 296.0 15.3 396.90 4.98 24.0
1 0.02731 0.0 7.07 0.0 0.469 6.421 78.9 4.9671 2.0 242.0 17.8 396.90 9.14 21.6
2 0.02729 0.0 7.07 0.0 0.469 7.185 61.1 4.9671 2.0 242.0 17.8 392.83 4.03 34.7
3 0.03237 0.0 2.18 0.0 0.458 6.998 45.8 6.0622 3.0 222.0 18.7 394.63 2.94 33.4
4 0.06905 0.0 2.18 0.0 0.458 7.147 54.2 6.0622 3.0 222.0 18.7 396.90 5.33 36.2

各行が1つ1つの住宅のデータで、targetが目的変数となる住宅価格のデータです。
それ以外の13の説明変数はこちらのサイトにすごく詳しく説明されていましたので参照ください。

機械学習用のデータセットなので、本来はこれらの13変数を用いて住宅価格(TARGET)を予想するというものです。

住宅価格に関係が深そうな以下3つの説明変数と目的変数であるTARGETを抜き出して可視化対象のデータとします。

  • RM: 住居の平均部屋数
  • RAD: 高速道路へのアクセスのしやすさ
  • LSTAT: 低所得者人口の割合

抜き出すためのコードは以下です。

df2 = df.loc[:,['RM', 'RAD', 'LSTAT', 'TARGET']]
print(df2.head())

RM RAD LSTAT TARGET
0 6.575 1.0 4.98 24.0
1 6.421 2.0 9.14 21.6
2 7.185 2.0 4.03 34.7
3 6.998 3.0 2.94 33.4
4 7.147 3.0 5.33 36.2

全データ間の散布図を一覧表示

seabornのpairplotというものを使います。
pairplotの引数に散布図を一覧で書きたいデータフレームを入れるだけです。
以下数行でグラフが描けます。

import matplotlib.pyplot as plt
import seaborn as sns

sns.pairplot(df2)
plt.show()

結果は以下。

各列と各行の変数間の散布図が表示されます。同じ変数のところはその変数のヒストグラムとなっています。
このグラフよりわかるのは、ざっと以下のことかと思います。

  • RMとTARGETの間には正の相関がありそう
  • LSTATとTARGETの間には負の相関がありそう

一覧表示にすることで有効そうな変数を選択できるのがいいですね。

散布図全コード
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.datasets import load_boston

data = load_boston()
df = pd.DataFrame(data=data.data, columns=data.feature_names)
df['TARGET'] = pd.DataFrame(data=data.target)
df2 = df.loc[:,['RM', 'RAD', 'LSTAT', 'TARGET']]

sns.pairplot(df2)
plt.show()


相関係数をヒートマップで描画

続いて同様のデータで相関係数を求めて、ヒートマップ表示してみましょう。

相関係数の導出

pandasのデータフレームのメソッドにcorr()というものがありこれで変数間の相関係数行列が一発で算出できます。

corr = df2.corr()
print(corr)

RM RAD LSTAT TARGET
RM 1.000000 -0.209847 -0.613808 0.695360
RAD -0.209847 1.000000 0.488676 -0.381626
LSTAT -0.613808 0.488676 1.000000 -0.737663
TARGET 0.695360 -0.381626 -0.737663 1.000000

相関係数は絶対値が1に近いほど相関が強く、0に近いほど相関がありません。

ヒートマップの描画

ヒートマップにしてみましょう。
以前の記事で、imshowを使った行列の可視化を紹介しましたが、今回はseabornのheatmapというモジュールを使ってみます。

seabornのheatmapはimshowとほぼ同じですが、グラフの調整や描画がseabornのほうがかなり楽です。

sns.heatmap(corr)
plt.show()

heatmapの引数にデータフレームを指定しただけなのにカラーバーまで表示されています。
ついでに相関係数は-1~1までの値なので、カラーバーの最小値、最大値を-1~1に指定して描画してみます。

sns.heatmap(corr, vmin=-1, vmax=1)
plt.show()


色で表している値を数字でも表示する

annotプロパティを指定することで値(annotation)をグラフ上に表示できます。
imshowでおんなじことをやろうと思うと別でplt.text()とかで値を1つ1つプロットする必要があってなかなかにめんどくさいです。

fmtプロパティで表示する文字列のフォーマットを指定できます。

  • 整数は'd'
  • 少数1桁は'.1f'
  • 少数2桁は'.2f'

等です。

sns.heatmap(corr, vmin=-1, vmax=1, annot=True, fmt='.2f')
plt.show()

なんか値の色もいい感じで背景色にまぎれないように勝手に調整してくれていて助かります。

annotationのフォントサイズを変更する

annotationの調整はannot_kwを指定することで可能です。
以下リンクのプロパティが指定可能です。
matplotlib.axes.Axes.text — Matplotlib 3.8.2 documentation

sns.heatmap(corr, vmin=-1, vmax=1, annot=True, fmt='.2f', annot_kws={'fontsize': 20})
plt.show()


相関係数可視化全コード
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.datasets import load_boston

data = load_boston()
df = pd.DataFrame(data=data.data, columns=data.feature_names)
df['TARGET'] = pd.DataFrame(data=data.target)
df2 = df.loc[:,['RM', 'RAD', 'LSTAT', 'TARGET']]
corr = df2.corr()

sns.heatmap(corr, vmin=-1, vmax=1, annot=True, fmt='.2f', annot_kws={'fontsize': 20})
plt.show()

おわりに

seabornを使えば簡単にデータの可視化ができることを学びました。
ただ、seabornはデータが多くなれば非常に重いです。
そして前半で紹介したpairplotは変数が多くなれば訳が分からなるという欠点はあります。。。

書籍紹介

pandasを使ってゴリゴリデータ分析できるようになる本です。


こちらは実際に使いそうなデータ分析の手法の流れを1から実戦形式で学べる本です。100問の問題形式で一通りの流れに沿って問題を解いていけば機械学習にも入門できていました。
飽きずに続けられたし、pandas力が上がりましたね。