morikomorou’s blog

自分が学んだことなどの備忘録的なやつ

【python】imshowによる2次元データのヒートマップ作成方法

はじめに

2次元データをその値に応じた色で視覚的に表現するヒートマップの作成方法についてまとめます。
以前matplotlibのimshowを使ってマンデルブロ集合を描いた際に少し手間取ったのでメモがてら。
mori-memo.hateblo.jp




行列データのヒートマップ描画方法

matplotlibのimshowを使用してヒートマップをかいてみます。
2次元データ(2次元リストでも可)をimshowの引数に与えるだけです。
やってみましょう。

import numpy as np
import matplotlib.pyplot as plt

data = np.arange(0, 100).reshape(10, 10)
print(data)

# output:
# [[ 0  1  2  3  4  5  6  7  8  9]
#  [10 11 12 13 14 15 16 17 18 19]
#  [20 21 22 23 24 25 26 27 28 29]
#  [30 31 32 33 34 35 36 37 38 39]
#  [40 41 42 43 44 45 46 47 48 49]
#  [50 51 52 53 54 55 56 57 58 59]
#  [60 61 62 63 64 65 66 67 68 69]
#  [70 71 72 73 74 75 76 77 78 79]
#  [80 81 82 83 84 85 86 87 88 89]
#  [90 91 92 93 94 95 96 97 98 99]]

fig, ax = plt.subplots()
im = ax.imshow(data)
plt.show()


プロットした2次元データは0から99までの数字を順番に10×10の行列にしただけのものです。

カラーバーを表示

色と数値の対応がわかるようにカラーバーを入れてみましょう。

import numpy as np
import matplotlib.pyplot as plt

data = np.arange(0, 100).reshape(10, 10)
fig, ax = plt.subplots()
im = ax.imshow(data)
plt.colorbar(im)
plt.show()

カラーバーを表示したらよく分かりますが、2次元データのインデックス[0, 0]のデータ(値0)がヒートマップの左上に描画され、インデックス[9, 9]のデータ(値99)が右下に描画されております。
y軸をよく見ると上が0で、下が9なのが気になります。
行列をそのまま可視化する際にはそのままの方が見やすいですが、x, yの位置に対応した2次元データのヒートマップを作成したい際にはyを上下逆に表示する必要があります。

y軸を反転する

y軸の反転はimshowのパラメータとしてorigin='lower'を指定します。
defaultではorigin='upper'となっています。

import numpy as np
import matplotlib.pyplot as plt

data = np.arange(0, 100).reshape(10, 10)
fig, ax = plt.subplots()
im = ax.imshow(data, origin='lower')
plt.colorbar(im)
plt.show()


X,Y座標に対応する2次元データの可視化

原点からのマンハッタン距離の可視化をしてみます。
原点から点 (x_{i}, y_{j})までのマンハッタン距離 z_{ij}は以下の定義で表されます。

z_{ij} = |x_{i}| + |y_{j}|

格子点データの作成

まずは格子状に点のx、y座標のデータを作成します。
格子点の作成にはnumpyモジュールのmeshgridが便利です。
簡単に使用例を示します。
x = 0, 1, 2
y = 0, 1, 2
の格子点を作成するコードです。

import numpy as np
import matplotlib.pyplot as plt

x = np.linspace(0, 2, 3)
y = np.linspace(0, 2, 3)
X, Y = np.meshgrid(x, y)  # 格子点のX座標、Y座標をまとめた行列をそれぞれ返す
print(X)
# [[0. 1. 2.]
#  [0. 1. 2.]
#  [0. 1. 2.]]
print(Y)
# [[0. 0. 0.]
#  [1. 1. 1.]
#  [2. 2. 2.]]
fig, ax = plt.subplots()
ax.scatter(X, Y)
ax.axis('equal')
plt.show()


点(0,0), (0,1), ~ 点(2,2)まで格子点がちゃんとプロットされています。

格子点における値の可視化

マンハッタン距離の可視化を実装します。
origin='lower'も忘れず定義します。

import numpy as np
import matplotlib.pyplot as plt

x = np.linspace(0, 10, 11)
y = np.linspace(0, 10, 11)
X, Y = np.meshgrid(x, y)
z = np.abs(X) + np.abs(Y)
fig, ax = plt.subplots()
im = ax.imshow(z, origin='lower')
plt.colorbar(im)
plt.show()


注意点

先ほどのコードにおいて、imshowの引数にはマンハッタン距離zのみしか指定しておりません。
つまり、グラフのx軸、y軸の値はマンハッタン距離を求める際に使用したx,y座標の値ではなく、ただのインデックス番号であることに注意が必要です。
先ほどのコードでx = np.linspace(0, 20, 11), y = np.linspace(0, 20, 11)に変えても全く同じグラフが得られますので確認してみてください。
正しいx,y座標の値をx軸、y軸に設定したい場合は下記のようにextentプロパティに(格子点のx座標の最小値, 最大値, y座標の最小値, 最大値)を指定します。
以下にx = np.linspace(0, 20, 11), y = np.linspace(0, 20, 11)とした際に正しいx, y座標を設定するコードを記述します。

import numpy as np
import matplotlib.pyplot as plt

x = np.linspace(0, 20, 11)
y = np.linspace(0, 20, 11)
X, Y = np.meshgrid(x, y)
z = np.abs(X) + np.abs(Y)
fig, ax = plt.subplots()
im = ax.imshow(z, origin='lower', extent=(0, 20, 0, 20))
plt.colorbar(im)
plt.show()


x軸,y軸の値が格子点の座標と合致し、0~20になりました。

カラーバーの最大値,最小値を自動更新する方法

ヒートマップでアニメーション等動的なグラフを作ることがあるかもしれないので、カラーバーの最大値、最小値を更新する方法についても説明します。
まず、以下のコードを見てください。
先ほどと同様のデータをプロットした後に、データを2倍の値で入れ替える処理をしています。

import numpy as np
import matplotlib.pyplot as plt

x = np.linspace(0, 20, 11)
y = np.linspace(0, 20, 11)
X, Y = np.meshgrid(x, y)
z = np.abs(X) + np.abs(Y)
fig, ax = plt.subplots()
im = ax.imshow(z, origin='lower', extent=(0, 20, 0, 20))
im.set_data(z * 2)   # データを2倍の値に差し替え
plt.colorbar(im)
plt.show()

その際の出力は下記の通りになります。

明らかにカラーバーの範囲がおかしいです。
データを2倍にしたのでカラーバーも2倍にしたいです。

その際に使えるのが、autoscale()メソッドです。
これを差し替えた後に記述するといい感じでカラーバーの値を再調整してくれます。
コードは以下です。

import numpy as np
import matplotlib.pyplot as plt

x = np.linspace(0, 20, 11)
y = np.linspace(0, 20, 11)
X, Y = np.meshgrid(x, y)
z = np.abs(X) + np.abs(Y)
fig, ax = plt.subplots()
im = ax.imshow(z, origin='lower', extent=(0, 20, 0, 20))
im.set_data(z * 2)   # データを2倍の値で入れ替え
plt.colorbar(im)
im.autoscale()  # カラーバーの最大値最小値を自動更新
plt.show()


うまく更新できました。