import matplotlib.pyplot as plt
import numpy as np


def thomas(first, last, a, b, c, d):
    """トーマス法を用いて3項方程式
    ( a(i)x(i-1) + b(i)x(i) + c(i)x(i+1) = d(i) ) を解く．
    計算進行に伴って配列bにb，g，配列dにd，s，xを保持することで，
    係数aと係数cが同じ場合のthomas()で使用するスペースを節約．

        引数:
            first (array): 方程式をfirstからlastについて解く．
            last (array): 方程式をfirstからlastについて解く．
            a (array): 差分方程式のU[i-1]（i=2~N）の係数を記憶する配列．
            b (array): 初期にb, 計算進行に伴ってgを記憶する配列．
            c (array): 差分方程式のU[i+1]（i=2~N）の係数を記憶する配列．
            d (array): 初期にd, 計算進行に伴ってsさらにxを記憶する配列．
    """
    start = first + 1
    for i in range(start, last):  # 3項方程式をトーマス法を用いて解く
        p = c[i] / b[i - 1]  # 右辺のbは本文ではg
        b[i] = b[i] - p * a[i - 1]  # 左辺のbは本文ではg
        d[i] = d[i] - p * d[i - 1]  # 左辺と右辺第二項のdは本文ではs
    d[last - 1] = d[last - 1] / b[last - 1]
    for i in range(last - start, 0, -1):
        d[i] = (d[i] - a[i] * d[i + 1]) / b[i]  # 左辺と右辺第2項のdは本文ではx,
                                                # 右辺のbは本文ではg


mesh_max_x = int(input("X方向の格子数を入力してください. (e.g. 20): ")) + 1
mesh_max_y = int(input("Y方向の格子数を入力してください. (e.g. 20): ")) + 1
time_max = int(input("計算打ち切りのタイムステップ数を入力してください."\
    "(e.g. 100): "))
delta_t = float(input("時間間隔Δtを入力してください. (e.g. 0.02): "))
delta_x = 1 / (mesh_max_x - 1)
delta_y = 1 / (mesh_max_y - 1)
rx = 0.5 * delta_t / (delta_x ** 2)
ry = 0.5 * delta_t / (delta_y ** 2)
A = np.zeros(mesh_max_x)  # 差分方程式のU[i-1, j] 又はU[i, j-1] （i=2~N）の
                          # 係数を記憶する配列
B = np.zeros(mesh_max_x)  # 差分方程式のU[i, j] （i=2~N）の係数を記憶する配列
C = np.zeros(mesh_max_x)  # 差分方程式のU[i+1, j] 又はU[i, j+1] （i=2~N）の
                          # 係数を記憶する配列
D = np.zeros(mesh_max_x)  # 差分方程式の右辺の値を記憶する配列，
                          # thomas()を呼んだあとでは近似解が記憶される
U = np.zeros((mesh_max_x, mesh_max_y))  # 現時点での解を記憶する配列，0で初期化
UU = np.zeros((mesh_max_x, mesh_max_y))  # 1つ前の時間ステップでの解を
                                         # 記憶する配列，0で初期化

for t in range(0, time_max):  # 2次元熱伝導方程式をADI法を用いて解く

    for i in range(0, mesh_max_x):  # 境界条件（上と下）
        U[i, 0] = 1.0
        U[i, mesh_max_y - 1] = 0.0
        UU[i, 0] = 1.0
        UU[i, mesh_max_y - 1] = 0.0
    for j in range(0, mesh_max_y):  # 境界条件（左と右）
        U[0, j] = 0.5
        U[mesh_max_x - 1, j] = 0.0
        UU[0, j] = 0.5
        UU[mesh_max_x - 1, j] = 0.0

    for j in range(1, mesh_max_y - 1):  # x方向を陰的にして
                                        # Un+1/2 に対する方程式を解く
        for i in range(1, mesh_max_x - 1):  # 3項方程式の係数を計算
            A[i] = -rx
            B[i] = 2 * rx + 1
            C[i] = -rx
            D[i] = U[i, j] + ry * (U[i, j + 1] - 2 * U[i, j] + U[i, j - 1])

        D[1] = D[1] + rx * U[0, j]
        D[mesh_max_x - 2] = D[mesh_max_x - 2] + rx * U[mesh_max_x - 1, j]
        thomas(1, mesh_max_x - 1, A, B, C, D)

        for i in range(1, mesh_max_x - 1):
            UU[i, j] = D[i]

    for i in range(1, mesh_max_x - 1):  # y方向を陰的にして
                                        # Un+1/2 に対する方程式を解く
        for j in range(1, mesh_max_y - 1):  # 3項方程式の係数を計算
            A[j] = -ry
            B[j] = 2 * ry + 1
            C[j] = -ry
            D[j] = UU[i, j] + rx * (UU[i + 1, j] - 2 * UU[i, j] + UU[i - 1, j])

        D[1] = D[1] + ry * UU[i, 0]
        D[mesh_max_y - 2] = D[mesh_max_y - 2] + ry * UU[i, mesh_max_y - 1]
        thomas(1, mesh_max_y - 1, A, B, C, D)

        for j in range(1, mesh_max_y - 1):
            U[i, j] = D[j]

g = plt.subplot()
x_pos, y_pos = np.meshgrid(np.linspace(0, mesh_max_x - 1, mesh_max_x),
                   np.linspace(0, mesh_max_y - 1, mesh_max_y))  # グラフ描画用の
                                                                # 格子点を生成
g.contourf(y_pos, x_pos, U, alpha=0.5)
g.set_aspect('equal')
plt.show()
