はじめに
データ分析や機械学習等を行う際にたくさんあるデータの関係性を調べたいことがあるかと思います。
一つ一つデータを選んで散布図を描いて、どのデータとどのデータにどういう相関があるか調べるというのは非常に大変な作業ですが、
seabornというライブラリを使えば、データ間の関係性や、データ間の相関係数など一瞬で一覧形式で可視化できます。
こんな図がすぐかけます。
では実際にやっていきましょう。
データ間の関係を可視化する
今回は以下2つを紹介します。
- 全データ同士の散布図を一覧形式で作成
- 全データ間の相関係数をヒートマップ形式で描画
データの準備
とりあえずグラフ化用の適当なデータを用意します。
今回は、機械学習用ライブラリのscikit-learnで用意されているボストンの住宅価格のデータセットを使用します。
以下コードでpandasのデータフレームに格納できます。
import numpy as np import pandas as pd from sklearn.datasets import load_boston data = load_boston() df = pd.DataFrame(data=data.data, columns=data.feature_names) df['TARGET'] = pd.DataFrame(data=data.target) print(df.head()) # 初めの5レコードを表示
CRIM | ZN | INDUS | CHAS | NOX | RM | AGE | DIS | RAD | TAX | PTRATIO | B | LSTAT | TARGET | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | 0.00632 | 18.0 | 2.31 | 0.0 | 0.538 | 6.575 | 65.2 | 4.0900 | 1.0 | 296.0 | 15.3 | 396.90 | 4.98 | 24.0 |
1 | 0.02731 | 0.0 | 7.07 | 0.0 | 0.469 | 6.421 | 78.9 | 4.9671 | 2.0 | 242.0 | 17.8 | 396.90 | 9.14 | 21.6 |
2 | 0.02729 | 0.0 | 7.07 | 0.0 | 0.469 | 7.185 | 61.1 | 4.9671 | 2.0 | 242.0 | 17.8 | 392.83 | 4.03 | 34.7 |
3 | 0.03237 | 0.0 | 2.18 | 0.0 | 0.458 | 6.998 | 45.8 | 6.0622 | 3.0 | 222.0 | 18.7 | 394.63 | 2.94 | 33.4 |
4 | 0.06905 | 0.0 | 2.18 | 0.0 | 0.458 | 7.147 | 54.2 | 6.0622 | 3.0 | 222.0 | 18.7 | 396.90 | 5.33 | 36.2 |
各行が1つ1つの住宅のデータで、targetが目的変数となる住宅価格のデータです。
それ以外の13の説明変数はこちらのサイトにすごく詳しく説明されていましたので参照ください。
機械学習用のデータセットなので、本来はこれらの13変数を用いて住宅価格(TARGET)を予想するというものです。
住宅価格に関係が深そうな以下3つの説明変数と目的変数であるTARGETを抜き出して可視化対象のデータとします。
- RM: 住居の平均部屋数
- RAD: 高速道路へのアクセスのしやすさ
- LSTAT: 低所得者人口の割合
抜き出すためのコードは以下です。
df2 = df.loc[:,['RM', 'RAD', 'LSTAT', 'TARGET']] print(df2.head())
RM | RAD | LSTAT | TARGET | |
---|---|---|---|---|
0 | 6.575 | 1.0 | 4.98 | 24.0 |
1 | 6.421 | 2.0 | 9.14 | 21.6 |
2 | 7.185 | 2.0 | 4.03 | 34.7 |
3 | 6.998 | 3.0 | 2.94 | 33.4 |
4 | 7.147 | 3.0 | 5.33 | 36.2 |
全データ間の散布図を一覧表示
seabornのpairplotというものを使います。
pairplotの引数に散布図を一覧で書きたいデータフレームを入れるだけです。
以下数行でグラフが描けます。
import matplotlib.pyplot as plt import seaborn as sns sns.pairplot(df2) plt.show()
結果は以下。
各列と各行の変数間の散布図が表示されます。同じ変数のところはその変数のヒストグラムとなっています。
このグラフよりわかるのは、ざっと以下のことかと思います。
- RMとTARGETの間には正の相関がありそう
- LSTATとTARGETの間には負の相関がありそう
一覧表示にすることで有効そうな変数を選択できるのがいいですね。
散布図全コード
import numpy as np import pandas as pd import matplotlib.pyplot as plt import seaborn as sns from sklearn.datasets import load_boston data = load_boston() df = pd.DataFrame(data=data.data, columns=data.feature_names) df['TARGET'] = pd.DataFrame(data=data.target) df2 = df.loc[:,['RM', 'RAD', 'LSTAT', 'TARGET']] sns.pairplot(df2) plt.show()
相関係数をヒートマップで描画
続いて同様のデータで相関係数を求めて、ヒートマップ表示してみましょう。
相関係数の導出
pandasのデータフレームのメソッドにcorr()というものがありこれで変数間の相関係数行列が一発で算出できます。
corr = df2.corr()
print(corr)
RM | RAD | LSTAT | TARGET | |
---|---|---|---|---|
RM | 1.000000 | -0.209847 | -0.613808 | 0.695360 |
RAD | -0.209847 | 1.000000 | 0.488676 | -0.381626 |
LSTAT | -0.613808 | 0.488676 | 1.000000 | -0.737663 |
TARGET | 0.695360 | -0.381626 | -0.737663 | 1.000000 |
相関係数は絶対値が1に近いほど相関が強く、0に近いほど相関がありません。
ヒートマップの描画
ヒートマップにしてみましょう。
以前の記事で、imshowを使った行列の可視化を紹介しましたが、今回はseabornのheatmapというモジュールを使ってみます。
seabornのheatmapはimshowとほぼ同じですが、グラフの調整や描画がseabornのほうがかなり楽です。
sns.heatmap(corr) plt.show()
heatmapの引数にデータフレームを指定しただけなのにカラーバーまで表示されています。
ついでに相関係数は-1~1までの値なので、カラーバーの最小値、最大値を-1~1に指定して描画してみます。
sns.heatmap(corr, vmin=-1, vmax=1) plt.show()
色で表している値を数字でも表示する
annotプロパティを指定することで値(annotation)をグラフ上に表示できます。
imshowでおんなじことをやろうと思うと別でplt.text()とかで値を1つ1つプロットする必要があってなかなかにめんどくさいです。
fmtプロパティで表示する文字列のフォーマットを指定できます。
- 整数は'd'
- 少数1桁は'.1f'
- 少数2桁は'.2f'
等です。
sns.heatmap(corr, vmin=-1, vmax=1, annot=True, fmt='.2f') plt.show()
なんか値の色もいい感じで背景色にまぎれないように勝手に調整してくれていて助かります。
annotationのフォントサイズを変更する
annotationの調整はannot_kwを指定することで可能です。
以下リンクのプロパティが指定可能です。
matplotlib.axes.Axes.text — Matplotlib 3.8.2 documentation
sns.heatmap(corr, vmin=-1, vmax=1, annot=True, fmt='.2f', annot_kws={'fontsize': 20}) plt.show()
相関係数可視化全コード
import numpy as np import pandas as pd import matplotlib.pyplot as plt import seaborn as sns from sklearn.datasets import load_boston data = load_boston() df = pd.DataFrame(data=data.data, columns=data.feature_names) df['TARGET'] = pd.DataFrame(data=data.target) df2 = df.loc[:,['RM', 'RAD', 'LSTAT', 'TARGET']] corr = df2.corr() sns.heatmap(corr, vmin=-1, vmax=1, annot=True, fmt='.2f', annot_kws={'fontsize': 20}) plt.show()
おわりに
seabornを使えば簡単にデータの可視化ができることを学びました。
ただ、seabornはデータが多くなれば非常に重いです。
そして前半で紹介したpairplotは変数が多くなれば訳が分からなるという欠点はあります。。。
書籍紹介
pandasを使ってゴリゴリデータ分析できるようになる本です。
こちらは実際に使いそうなデータ分析の手法の流れを1から実戦形式で学べる本です。100問の問題形式で一通りの流れに沿って問題を解いていけば機械学習にも入門できていました。
飽きずに続けられたし、pandas力が上がりましたね。