{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# 正式的project文件" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Init Plugin\n", "Init Graph Optimizer\n", "Init Kernel\n" ] } ], "source": [ "from tensorflow import keras\n", "import tensorflow as tf\n", "import matplotlib.pyplot as plt\n", "import numpy as np\n", "# from icecream.icecream import ic\n", "plt.style.use('ggplot')\n", "# plt.rcParams[\"figure.figsize\"] = (12,6)\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Model construction\n", "\n", "## What should the input and output be?\n", "Input: [x, t]\n", "x as horizontal wave position, t as time\n", "\n", "Output: [u]\n", "u as vertical wave position\n", "\n", "[x, t]->y w.r.t the wave eqn, where x is the position.\n", "\n", "## IC and BC\n", "IC:\n", "u(0,x) = f(x)\n", "du/dt(0,x) = g(x)\n", "\n", "BC:\n", "u(0,t) = a(t)\n", "u(L,t) = b(t)" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[]" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXQAAAD4CAYAAAD8Zh1EAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAAAs70lEQVR4nO3de1xU573v8c8zw2UEEXEmqCBqUBHwCpJ4SWJjJEhzaWyabS+Je7fZNs32ZPuyF0/Mbpqk7Um228Q25Rw9yalu7b1Nd9qm7U4aSm4mURN1wAuIgBqjUaLcRAGFmfWcP0aJBHQGmJk1l9/79eIlw3rWmu/D4I/FM2s9j9Jaa4QQQoQ9i9kBhBBC+IcUdCGEiBBS0IUQIkJIQRdCiAghBV0IISKEFHQhhIgQMWY++YkTJwa0n8PhoKGhwc9pQpv0OTpIn6PDYPqclpZ2xW1yhi6EEBFCCroQQkQIKehCCBEhpKALIUSEkIIuhBARwuerXAzDYPXq1YwYMYLVq1f32Ka1ZvPmzZSXlxMfH8/y5cvJzMz0e1ghhBBX5vMZ+ssvv0x6enqf28rLy6mvr6ekpIQHHniAjRs3+i2gEEII3/hU0BsbG3E6nSxcuLDP7bt27WL+/PkopcjKyqKtrY3m5ma/BhUi0LTW6A8PYZT+CX1gD9rtNjuSEP3i05DLli1buO++++jo6Ohze1NTEw6Ho/ux3W6nqamJlJSUHu3KysooKysDYM2aNT326VfomJgB7xuupM+Bo893cO73W7jw7mu4P/bc7KYBNWw48dffROIX78fqGBnwHCCvc7QIVJ+9FvTdu3eTnJxMZmYmlZWVfbbpa40MpVSvrxUWFlJYWNj9eKB3SsmdZdEhGH3WHe0YJT+AQ9UwZSZq0d2o3Dz4oBbt3EbH1lI6yt/D8q0folJHBzQLyOscLQJ1p6jXgn7w4EF27dpFeXk5nZ2ddHR0UFJSwooVK7rb2O32HuEaGxt7nZ0LEWp02zmMnzwBHx5Cff07WK678ZON9mtQs+ahjx7C+PFjGE8/guVb/ws1eoxpeYXwxusY+le+8hWee+451q9fz8qVK5k6dWqPYg5QUFDA1q1b0VpTU1NDQkKCFHQR0nR7G8a678Kxw1gefLhnMb+MGjcBy3eeBLcb4+lH0CePBzmpEL4b8HXopaWllJaWApCXl0dqaiorVqzg+eefZ9myZX4LKEQg6Be3wPGjWP7Hd1Ez51y1rRozHsv//HfQGmPzs2hD3iwVoalfsy1OmTKFKVOmAFBUVNT9daWUFHERNnTNfvTWV1FFi1FTZ/m0jxo1BvWlr6M3rkO/8TJq4Z0BTilE/8mdoiKq6K5OjF+sB3sq6nNf6de+6vr5MCUP/cdfoBtPByihEAMnBV1EFf3y76H+Iyz3LUfF2/q1r1IKy33LPUMvv/q/fV7dJYSZpKCLqKHrj6NfeRE1+zOoqfkDOoZyjETddS/s2wXO7X5OKMTgSEEXUUO/8iJYLagl/zyo46iFd8LoDIy//lbO0kVIkYIuooJuOo1+703UTYtQw4YP6ljKakUV3w3HP4D9Tr/kE8IfpKCLqKD//hIA6ta7/HI8df18GOHA+Nt/+eV4QviDFHQR8fS5Vs9litfPR9lT/XJMFROLunUx1FSi6w745ZhCDJYUdBHx9Ov/DZ0XUIu+4NfjqpuKIDEJ428v+vW4QgyUFHQR0fSF8+jX/wozrkelj/XrsVW8DXXL7bDnffRHH/r12EIMhBR0EdH0jjeh7SyW4rsDcnx1yx0QF4d+/S8BOb4Q/SEFXUQ0/W4ZpI2FCTkBOb4aOgyVfwN659voCxcC8hxC+EoKuohY+uQxOFKDumFhn/Pz+4u6YSF0tKPL5UYjYS4p6CJi6XfLwGJBzbk5sE+UNRXsqehtrwX2eYTwQgq6iEja7faMn0+/DjUssHPzK4sFdUMhVO9FN3wc0OcS4mqkoIvItN8JZ5qx3ND3wub+pubdAoDe9npQnk+IvnidD72zs5PHH38cl8uF2+1mzpw5LFmypEebyspK1q5dS2qq56aN2bNnc8899wQmsRA+MLaVQVIyTC0IyvMpeypkT0dvew19xxdRFjlXEsHntaDHxsby+OOPY7PZcLlcPPbYY8ycOZOsrKwe7XJycli9enXAggrhK322FfbsRN1yOyqmX2u4DIqatxC96UdQsx+ypwfteYW4xOtphFIKm80zb7Tb7cbtdgf0igEhBkvvegfcLtS84Ay3XKLy5sKQBPSON4L6vEJc4tPpi2EYPPzww9TX17No0SImTZrUq01NTQ2rVq0iJSWFpUuXkpGR0atNWVkZZWVlAKxZswaHwzGw0DExA943XEmffde0932MMeNxzAzOcMvlzlx3IxecO7APHz6gvw7kdY4Ogeqz0v2Y0LmtrY1nnnmGr33ta4wd+8lt1O3t7VgsFmw2G06nky1btlBSUuL1eCdOnBhQaIfDQUNDw4D2DVfSZ9/os2cwvv1PqNv/Actd9wYo2VWev3wHxoansHzz+6jcvH7vL69zdBhMn9PS0q64rV/v3CQmJpKbm0tFRUWPryckJHQPy+Tn5+N2u2ltbe1/UiEGSZfvAG2g8ueZE2BKHsTb0LvlJiMRfF4LemtrK21tbYDnipd9+/aRnp7eo01LS0v3yi11dXUYhkFSUlIA4gpxdXr3NkgdDWPGm/L8Ki4eNa0AXb4dbbhNySCil9dBvubmZtavX49hGGitmTt3LrNmzaK0tBSAoqIiduzYQWlpKVarlbi4OFauXClvnIqg021n4eBe1K2Lzf35y58Hu96B2gMweap5OUTU8VrQx40bx9q1a3t9vaioqPvz4uJiiouL/ZtMiH7SFe+D242aZdJwy0Vq2ix0bBzauQ0lBV0Ekdz9ICKG3v0u2FNh3ERTcyjbEJiSj3ZuQxuGqVlEdJGCLiKCbm+DAxWo/LkhMdynZs2DliY4UmN2FBFFpKCLiKD37QKXy7yrWz5FTb8OrDGevxqECBIp6CIy7N3pmbslc7LZSQBQCYkweRp67y6zo4goIgVdhD3tdqP3O1FTZ4XUpFhqegF8/BH61MBuoBOiv0Lnp1+IgTp8ENrPeQpoCFHTPHnkLF0EixR0Efb0vp1gtcIAbrUPJJU6Gkale8b3hQgCKegi7Om9u2BirmfcOsSoaQVQsx99vsPsKCIKSEEXYU03noaPjnYPb4QaNa0AXC44sMfsKCIKSEEXYe3ScEaojZ93m5QLtiEy7CKCQgq6CGt63y5wjIRRY8yO0icVEwu5eeh9u+jHTNVCDIgUdBG2dOcFqN6DmlYQEneHXomaXuC5a/TYYbOjiAgnBV2Er4P7obMzdIdbLlJTZwFy+aIIPCnoImzpSifExcHkaWZHuSqVnALjJnryChFAUtBF2NKV5ZA1FRUbZ3YUr9SUPDh80DOJmBABIgVdhCXddBrqj6NyZpodxScqdyYYBtTsMzuKiGBeF7jo7Ozk8ccfx+Vy4Xa7mTNnDkuWLOnRRmvN5s2bKS8vJz4+nuXLl5OZmRmw0ELoqgrg4plvOMjM9qw1WlmBmjnH7DQiQnkt6LGxsTz++OPYbDZcLhePPfYYM2fOJCsrq7tNeXk59fX1lJSUUFtby8aNG3nqqacCGlxEuaoKSB4BaWPNTuITFRsLWVPRVeVmRxERzOuQi1IKm80GgNvtxu1297pEbNeuXcyfPx+lFFlZWbS1tdHc3ByYxCLqacNAH6hA5c4I6csVP03lzoRTJ9Gn682OIiKU1zN0AMMwePjhh6mvr2fRokVMmjSpx/ampiYcDkf3Y7vdTlNTEykpKT3alZWVUVZWBsCaNWt67NOv0DExA943XEmfP9F1qJqmc2dJmj2fIWH0PXHdcAuNv9tI4rE6EnL6XmtUXufoEKg++1TQLRYLTz/9NG1tbTzzzDN8+OGHjB37yZ+6fd0B19eZU2FhIYWFhd2PGxoaBpIZh8Mx4H3DlfT5E8a7bwBwLiOTtjD6nmhbIqQ4OPveO7Tn39hnG3mdo8Ng+pyWlnbFbf26yiUxMZHc3FwqKip6fN1ut/cI19jY2OvsXAh/0VUVMOZa1LDw+hlTSqFyZ0D1HrThNjuOiEBeC3prayttbZ5rZzs7O9m3bx/p6ek92hQUFLB161a01tTU1JCQkCAFXQSEvnAe6g6gpsw0O8rA5OZBext8UGd2EhGBvA65NDc3s379egzDQGvN3LlzmTVrFqWlpQAUFRWRl5eH0+lkxYoVxMXFsXz58oAHF1GqZj+4XagQW8zCVypnJlopdFU5KkTWPxWRw2tBHzduHGvXru319aKiou7PlVIsW7bMv8mE6IOuqoDYOM+0tGFIJQ2DjExPP+74ktlxRISRO0VFWNEH9sDEnLC43f9KVM50OFzjGT4Swo+koIuwoVtbPKsTZU83O8qgqOwZ4HZBbZXZUUSEkYIuwoY+6JkHJdwLOpNywRqDrpZl6YR/SUEX4ePAHhiSAOMmmp1kUFS8DTKz0NUyUZfwLynoImzo6r2e6XKtVrOjDJrKngEfHkK3nTU7ioggUtBFWNCNp+B0ffgPt1yksqeD1p5Vl4TwEynoIizo6r0AqJwZJifxk8wsiIuXcXThV1LQRXg4sAeSksNmulxvVEwsZE2RcXThV1LQRcjTWqOr96Kyp4fVdLneqOzpcPIYuqXR7CgiQkhBF6Gv/jicaYZIGW65SGV7+iNn6cJfpKCLkKcPeMaZI+UN0W4Z4yFhKMg4uvATKegi5OnqvWBPRV0zyuwofqUsVsieJmfowm+koIuQpg0DaipR2dPMjhIQavI0aDwly9IJv5CCLkLb8Q+g7SxMjrDhlovUxX5dmtZAiMGQgi5CWvf8LZMj8wydtAzP5ZhS0IUfSEEXIU1X74XUNNSIyFxEWCmFyp6Ort7b59q8QvSH1wUuGhoaWL9+PS0tLSilKCws5LbbbuvRprKykrVr15KamgrA7NmzueeeewKTWEQN7XZBbSXqupvMjhJYk6fBzrfh4xNwzTVmpxFhzGtBt1qtLF26lMzMTDo6Oli9ejXTp09nzJgxPdrl5OSwevXqgAUV0cd1uAY62j0FL4KpydPQXBxemhpZ19qL4PI65JKSkkJmZiYAQ4YMIT09naampoAHE6JzvxOI4PHzS0amwfARMo4uBs3rGfrlTp06xZEjR5g4sfd81DU1NaxatYqUlBSWLl1KRkZGrzZlZWWUlZUBsGbNGhyOgY2LxsTEDHjfcBWNfW7ZX451zHgcEyaZHSXgzkwvoHPPTqxWa9S9ztH4sx2oPivt4zsx58+f5/HHH+fuu+9m9uzZPba1t7djsViw2Ww4nU62bNlCSUmJ12OeOHFiQKEdDgcNDQ0D2jdcRVuftcuF/ua9MHcBlq88aHacgDPe+Tv6Z/8b+09+SUvCMLPjBFW0/WzD4PqclpZ2xW0+XeXicrlYt24dN910U69iDpCQkIDNZgMgPz8ft9tNa2vrgMIKAcAHtejzHd3XaUe6S8NKl4aZhBgIrwVda81zzz1Heno6d9xxR59tWlpaui+5qqurwzAMkpKS/JtURJXuG22yppobJEjUNaPAnkrnPinoYuC8jqEfPHiQrVu3MnbsWFatWgXAl7/85e4/F4qKitixYwelpaVYrVbi4uJYuXJlRE1zKoJPH9xHzPiJ6KToGX5Qk6fRuXcnyjBQFrlFRPSf14KenZ3NCy+8cNU2xcXFFBcX+y2UiG66qwvqDhC3aDEXzA4TTNnT0dteQ310FDKuNTuNCENyGiBCz+GD0NVJ7LRZZicJqkvj6JeW2xOiv6Sgi5CjD+4FZSFuykyzowSVGuHAOnqMTNQlBkwKugg5+uA+GJuJJTH63liPm5oPNZVow212FBGGpKCLkKI7L8Dhg5F/d+gVxE2bBR1t8OFhs6OIMCQFXYSWQ9XgckXsghbexE7NB2R+dDEwUtBFSNHVe8FigUm5ZkcxhTXFDqMz5I1RMSBS0EVI0Qf3wfhJKFuC2VFMoyZPg9oqtMtldhQRZqSgi5Chz3fAB7VRO35+icqeDhfOw9E6s6OIMCMFXYSOuipwu6N2/LzbxekOZNhF9JcUdBEydPU+sMbAhOgcP79EJQ2DMePljVHRb1LQRcjQ1Xvh2ixUfLzZUUynsqdD3QHPNAhC+EgKuggJuu0cfHgIlRMd0+V6o7KnQ1cnHK42O4oII1LQRWio2Q9ao7JlTU3AM45uscg4uugXKegiJOgDeyAuHjKzzI4SEtSQBBg30fN9EcJHUtBFSNDVe2FSLiom1uwoIUNlT7+4clO72VFEmJCCLkynzzTDyWOeAia6qZwZ4HZD7QGzo4gw4XWBi4aGBtavX09LSwtKKQoLC7ntttt6tNFas3nzZsrLy4mPj2f58uVkZmYGLLSILJfGiVWOjJ/3MCEbYmLR1XtQUTY3vBgYrwXdarWydOlSMjMz6ejoYPXq1UyfPp0xY8Z0tykvL6e+vp6SkhJqa2vZuHEjTz31VECDiwhSvRcSEmWVnk9RcfEwIVveGBU+8zrkkpKS0n22PWTIENLT02lqaurRZteuXcyfPx+lFFlZWbS1tdHc3ByYxCLi6AN7YPI0lMVqdpSQo7Knw7Ej6HOtZkcRYcDrGfrlTp06xZEjR5g4cWKPrzc1NeFwOLof2+12mpqaSElJ6dGurKyMsrIyANasWdNjn36FjokZ8L7hKlL77P74BA2Np0i6+z4SPtW/SO3z1Xy6z51z5tP80q9IOnkU29wFJiYLHHmd/XhcXxueP3+edevW8dWvfpWEhJ4z4Wmte7VXSvX6WmFhIYWFhd2PGxoa+pO1m8PhGPC+4SpS+2xsexOAtowJtH+qf5Ha56v5dJ/18Gsgfgit77/DuUmROceNvM79k5aWdsVtPl3l4nK5WLduHTfddBOzZ8/utd1ut/cI19jY2OvsXIg+Ve+F5BQYNcZ72yikYmIgawr6gIyjC++8FnStNc899xzp6enccccdfbYpKChg69ataK2pqakhISFBCrrwShsG+sAeVM6MPv+iEx4qZwZ8/BG66bTZUUSI8zrkcvDgQbZu3crYsWNZtWoVAF/+8pe7z8iLiorIy8vD6XSyYsUK4uLiWL58eWBTi8hw/AM4ewZyZpqdJKSp3JloQFdVoG681ew4IoR5LejZ2dm88MILV22jlGLZsmV+CyWigz5QAYDKlevPryptLCSPgKoKkIIurkLuFBWm0VUVkDYWNdxudpSQppRC5cxAH9iDNgyz44gQJgVdmEJ3dUJtFSp3ptlRwkPuTDjXCsePmJ1EhDAp6MIctVXQ1SkF3UeXpkXQVRXmBhEhTQq6MIWuqvAsN3dx/UxxdWr4CEgfJwVdXJUUdGEKfaACJmSj4m1mRwkbKmcm1FahOy+YHUWEKCnoIuj02TPw4WEZbuknlTsTXF1QV2V2FBGipKCLoLu0Co8U9H7KmgIxMTLsIq5ICroIvqoKz3S54yaYnSSsqHgbTMiRgi6uSAq6CCqtNbqyHLJnyHS5A6ByZ3qm0z0j01OL3qSgi+D66Ci0NKKm5pudJCxd+r7pSqfJSUQokoIugkrv3w2AmipLqg1IRqZndsr9UtBFb1LQRVDp/U4YMx6VIrf7D4RSCjUlH11ZjjbcZscRIUYKuggafb4d6g6gpshwy6BMzYf2c3Ck1uwkIsRIQRfBc2AvuF2ygv0gqdyZoCzdw1dCXCIFXQSN3u8E2xCYkG12lLCmEpMgM8vz/RTiMl7nQ9+wYQNOp5Pk5GTWrVvXa3tlZSVr164lNTUVgNmzZ3PPPff4P6kIa1przxll9gxUTKzZccKempqP/vNv0GfPoJKSzY4jQoTXgn7zzTdTXFzM+vXrr9gmJyeH1atX+zWYiDD1x6HpNOr2fzA7SURQU2ehX/o1urIcNedms+OIEOF1yCU3N5ehQ4cGI4uIYHrfxcsVp8j4uV+MnQBJySDj6OIyXs/QfVFTU8OqVatISUlh6dKlZGRk+OOwIoLo/bthdAbKfo3ZUSKCslhQU/LQ+51owy133QrADwX92muvZcOGDdhsNpxOJ08//TQlJSV9ti0rK6OsrAyANWvW4HA4BvScMTExA943XIVzn422c5yurSThzi+S1I8+hHOfB6o/fT5/wy2c2fEmyU2niMueFuBkgSOvsx+PO9gDJCQkdH+en5/Ppk2baG1tZdiwYb3aFhYWUlhY2P24oaFhQM/pcDgGvG+4Cuc+GzvfAZeL81nTuNCPPoRznweqP33WYyeB1UrLW6VYHKMDnCxw5HXun7S0tCtuG/Rliy0tLWitAairq8MwDJKSkgZ7WBFJ9rwHQ4dB5mSzk0QUlZAIWVPRe943O4oIEV7P0J999lmqqqo4e/YsDz74IEuWLMHlcgFQVFTEjh07KC0txWq1EhcXx8qVK1FKBTy4CA/a7Ubv242acb2M8waAmnE9+rc/RZ86gUq98pmbiA5eC/rKlSuvur24uJji4mJ/5RGRpu4AtJ9Dzbze7CQRSU2/zlPQ9+xE3XqX2XGEyeROURFQes97EBMDuXlmR4lI6ppRnsWjZdhFIAVdBJDWGl3xnufuUNsQs+NELDVjNtRWotvOmh1FmEwKugic+uNwuh414zqzk0Q0NeM6MIzum7dE9JKCLgJGV3iGAdR0GT8PqPGTYNhwkGGXqCcFXQSM3vMejM1EjYium0aCTVksnqtdKp3ori6z4wgTSUEXAaGbG+FQNSpvrtlRooLKmwMd7XCgwuwowkRS0EVAaOc2AFTBDSYniRI5M2BIInrXu2YnESaSgi4CQu9+F9LHoUaNMTtKVFAxsaiZ16P3vId2ybBLtJKCLvxOtzR51g6dJWfnwaRm3QjtbZ6l/kRUkoIu/E47t4HWqFnzzI4SXXJnwpAE9O53zE4iTCIFXfid3r3NM/d52lizo0QVFRvrudql/D30xfmWRHSRgi78Sp9phtpKeTPUJGrWDdB+Dqpl2CUaSUEXfqXLt18cbpGCboopeWAb0n2VkYguUtCFX+ld78KoMSDDLaZQsXGo6dejy7fLsEsUkoIu/EY3NUDNftR1N8qc+CZS190I585CVbnZUUSQSUEXfqPff8sz3DLnZrOjRLep+TA0Cb39DbOTiCDzusDFhg0bcDqdJCcns27dul7btdZs3ryZ8vJy4uPjWb58OZmZmQEJK0KX1tpTQCZky8o5JlMxsajrbkK//Xd0e5tnqToRFbyeod98883827/92xW3l5eXU19fT0lJCQ888AAbN270a0ARJo4dhhMfytl5iFBzFoCry3PHrogaXgt6bm4uQ4cOveL2Xbt2MX/+fJRSZGVl0dbWRnNzs19DitCnt78J1hhUwY1mRxEA12ZBahp6x5tmJxFB5HXIxZumpiYcjk+mR7Xb7TQ1NZGSktKrbVlZGWVlZQCsWbOmx379ERMTM+B9w1Uo91m7XTTsepu4ghsYPt5/w22h3OdA8WefzxXeTtuvf0qK0YU1dbRfjhkI8jr78biDPYDWutfXrnSFQ2FhIYWFhd2PGxoaBvScDodjwPuGq1Dus96/G6Olia78eX7NGMp9DhR/9llPux74KY2v/BHL7Uv8csxAkNe5f9LSrvwe1aCvcrHb7T2CNTY29nl2LiKX3v4mJAyFaQVmRxGXUY6RMCkXveONPk+8ROQZdEEvKChg69ataK2pqakhISFBCnoU0W3n0BXbUdffhIqNNTuO+BQ19xao/wgOHTA7iggCr0Muzz77LFVVVZw9e5YHH3yQJUuW4Lp4B1pRURF5eXk4nU5WrFhBXFwcy5cvD3hoETr09tehsxM1v9jsKKIP6rqb0C9sQr/1KmpirtlxRIB5LegrV6686nalFMuWLfNXHhFGtNbot/4GmZNRGdeaHUf0QdmGoOYsQL/zd/QX/xk1dJjZkUQAyZ2iYuBq9kP9cdRn5Ow8lKnPLPJck77tdbOjiACTgi4GTL/1N0hIlGvPQ5wacy1MyEa/9Td5czTCSUEXA6Jbm9HO7ah5C1Fx8WbHEV6oz3wWTp2QedIjnBR0MSD63dfA7ZI3Q8OEmjUPEpM8f1WJiCUFXfSbdrs9hWHyNNToMWbHET5QcfGoebegK3agWxrNjiMCRAq66Dft3A6Np7DccrvZUUQ/qAW3g6HRr/3V7CgiQKSgi37RWqNf/QOMTIeZs82OI/pBXTMKNWse+q1X0B3tZscRASAFXfRP9V44WocqWoyyWM1OI/pJLfo8dLSjt75qdhQRAFLQRb8Yr/4Bhg1HzV1gdhQxAGr8JMieji57Ce3qMjuO8DMp6MJn+tgRqCxHLbwTFRtndhwxQJZFd0NLE/q9rWZHEX4mBV34TL/6B4i3ea5pFuFrSh6MGY9+9Q9owzA7jfAjKejCJ7r+I/TOt1E3LUIlXnkFKxH6lFKo4i/AyWNQvt3sOMKPpKALn+iXfgWxcajP3m12FOEH6robYXQGxp9+iXa7zY4j/EQKuvBKHz2E3vUOqvBzqGEy130kUBYrlsX3Qf1HnimQRUSQgi68Mv70C0hMQhV93uwowp/y5sC1Wei//Abd1Wl2GuEHUtDFVemD+2G/E/XZe1AJiWbHEX6klMLy+aXQ1IB+8xWz4wg/8GmR6IqKCjZv3oxhGCxcuJDFixf32F5ZWcnatWtJTU0FYPbs2dxzzz1+DyuCS2uN8cefw3A7asFtZscRAaByZkDODPTLv0ffeCtqSILZkcQgeC3ohmGwadMmHn30Uex2O4888ggFBQWMGdNzUqacnBxWr14dsKAi+PSON+FQNeofH5IpciOY5Qv/hPHkt9F//jXqi7L6WDjzOuRSV1fHqFGjGDlyJDExMcybN4+dO3cGI5swkW47h/79f3qWl7uh0Ow4IoDUuImo+YvQr/0V/eFhs+OIQfB6ht7U1ITdbu9+bLfbqa2t7dWupqaGVatWkZKSwtKlS8nIyOjVpqysjLKyMgDWrFmDw+EYWOiYmAHvG66C3efW/9pMR9tZRnz/J8ReHEoLNnmdg8dYtpLGivew/u6npPz78yhL8N5ek9fZj8f11qCvJauUUj0eX3vttWzYsAGbzYbT6eTpp5+mpKSk136FhYUUFn5yttfQ0DCQzDgcjgHvG66C2Wd9pAaj9E+ohXdyJmkEmPS9ltc5uPQXvkrXf/6Y03/6NZYgLlwir3P/pKWlXXGb11/DdrudxsZPJsRvbGwkJaXntcgJCQnYbDYA8vPzcbvdtLa2DiisMJd2uTB+uQGSU1B3fcXsOCKI1JybYfI09Is/R7c2mx1HDIDXgj5hwgROnjzJqVOncLlcbNu2jYKCgh5tWlpaus/k6+rqMAyDpKSkwCQWAaX/8lv48DCWLz+AsskVD9FEKYXl3geh8wLG5hJZUDoMeR1ysVqt3H///Tz55JMYhsGCBQvIyMigtLQUgKKiInbs2EFpaSlWq5W4uDhWrlzZa1hGhD59cB/6ld+jbihE5c8zO44wgRqdgVpyP/rXz6Nf+wuq8HNmRxL94NN16Pn5+eTn5/f4WlFRUffnxcXFFBfLYsHhTJ9rxdj4I0hNQ33p62bHESZSN9+GrixHv7gFnTUVNTbT7EjCR3KnqPDcQPTz/wNnz2D5+ndQtiFmRxImUkph+acVkDgM46fPoC+cNzuS8JEUdIH+799B+Q7U3UtR4yaYHUeEAJU0DMs/fxM+/ghj049k3vQwIQU9yhk73kS/9GvU3AWoWxebHUeEEJUzA7XkfijfgX7xZ2bHET7waQxdRCZdU4n+WQlMnua5vV/eyBafohZ+Dk7Vo0v/iHHNKCw3y2pVoUwKepTSx49gbHgKHCOx/MtqVEys2ZFECFJKwReXoRs+Rv/mefSw4aj8uWbHElcgQy5RSH9Qi/HMoxAXj+VfH0Mlyj0D4sqU1YrlgVUwfhLG8/+BsfNtsyOJK5CCHmV03QGMH30PbEOwrHoKlTra7EgiDCjbECzf/D5MzEH/dB3GttfMjiT6IAU9iujyHRjPPg5Jw7H8z39HXTPK7EgijChbApYVT0D2NPTmn2C88qLcTRpipKBHAW24Mf74S8+Y+agxnjPzEdeYHUuEIRUfj+Vfv4cquBH9h59hPP8f6PPtZscSF8mbohFOtzZjbC6B/bs9t/Tf+yAqNs7sWCKMqdg4eGAVjJ+IfvHnGCeOYfnGw6j0sWZHi3pS0COU1hq97XX0C5ug8zzq3n9BfaZYLk0UfqGUQi26Gz12Asb/exrjhytRt92D+uw/oGLliimzSEGPQPr4BxgvbIIDe2BiLpZ/fAg1eoz3HYXoJ5UzA8sP1qN/uxH9l9+id72L5Utfh5wZcvJgAinoEUTXf4T+y2/QO98G2xDP8Mr84qCuPiOij0pKRn392+g5n8H41XMYP34MsqZiueteVNYUs+NFFSnoYU4bBlRVYLz1N9jzPsTGoorvRhV9HjV0mNnxRBRR0wqw/HADemsp+uUXMJ5+BCZko27+LGrWDfLeTRBIQQ9DWmv48DDauR29cyucroekZNSiz6Nu/RxqWIr3gwgRACo2DrXwDvSNt6K3/g395svoTT9G/24j6rqbUHlzIWsqymo1O2pEkoIeJtwNpzB2bIWa/eiqCmg8BRaLZx6Wxfeh8ubKm1EiZKj4eNStd6EX3gnVe9FbX0W/W4Z+42VITELlzvT87GZNRV+2CL0YHJ8KekVFBZs3b8YwDBYuXMjixYt7bNdas3nzZsrLy4mPj2f58uVkZsqk+AOh287BqRPoj0/AyePoY4fh2GEaWpo8DYYkQtYU1B1fRM2YjUqSYRURupTFArkzUbkz0RcuQKUTXb4dfWAv7HwbDZwemoROH4/KyIT0sajUNEgd7VnXVt5Y7RevBd0wDDZt2sSjjz6K3W7nkUceoaCggDFjPrlqory8nPr6ekpKSqitrWXjxo089dRTAQ0eirThBpcLXF3Q1QVdnZ6PC+fhwgU434HuaIOOdmg/B+da4ewZdGsLNDdCSyOc7/jkgBYLjM5AZc8gMXc67WnjIWM8yiJ/rorwo+LjIX8uKn+uZ9jw1El0zX5s9cfpqK1Cv/UKdHXSfe9pbBwMHwEpdlTyCBg6zPORmARDElAJCWBLgHgbxMV7PmLjPB8xsWC1gtUaVb8UvBb0uro6Ro0axciRIwGYN28eO3fu7FHQd+3axfz581FKkZWVRVtbG83NzaSk+H8sV+930vDiFtxut5eGV7olWff+VGvPg0v76D4+v/RhuC/+a4A2wO32fM3tvspzXoFtiOcHdNhwSBuLmpIHKQ7UyNGQmgbXjOp+IynR4aCjoaF/xxciRCmlYGQaamQawxwOOhsa0G63Zyjx1En0qRPQeBqaG9AtjeijhzwnQO3nuo/h8/82awxYLWCxXvxQoC4+Vng+v/SvJ5zn49LnXP64Ry+u1DmvkdoWLYYbbvW1Bz7zWtCbmpqwXzbGZbfbqa2t7dXG4XD0aNPU1NSroJeVlVFWVgbAmjVreuzjq85Rozk/bgKGL8XzSt/YPl4U9ekX8fIXUimwWDwtL/1AWKyeN3YsVs+2mBiIifV8LSbWM4YYG4eKi0fZhqDibShbAioxEZUwFEviUFRcvM/9jomJGdD3K5xJn6NDjz6PHAm5067YVrtc6LazGO1t6PZz6PY29IXzn3x0dUHnBXRXJ9rtgi6X51/DALfL80vDMC5+uD2/FLT2nJxpLv7b84Tuk/lq+jgZ7BXQt18zsfZrSAzA6+y1oPc1+c6n/4TxpQ1AYWEhhYWF3Y8bBnLG6RiN4zs/HNi+ocKtofUscNbnXRwOR3j3eQCkz9FhQH2OtUGyDZLD85df3CBe57S0tCtu83rHid1up7GxsftxY2NjrzNvu93eI1xfbYQQQgSW14I+YcIETp48yalTp3C5XGzbto2CgoIebQoKCti6dStaa2pqakhISJCCLoQQQeZ1yMVqtXL//ffz5JNPYhgGCxYsICMjg9LSUgCKiorIy8vD6XSyYsUK4uLiWL58ecCDCyGE6Mmn69Dz8/PJz8/v8bWioqLuz5VSLFu2zL/JhBBC9IvM2iSEEBFCCroQQkQIKehCCBEhpKALIUSEUFqW7RZCiIgQlmfoq1evNjtC0Emfo4P0OToEqs9hWdCFEEL0JgVdCCEiRFgW9Msn+IoW0ufoIH2ODoHqs7wpKoQQESIsz9CFEEL0JgVdCCEihE+Tc5klGhen9tbnt99+m5deegkAm83GsmXLGD9+fPCD+pG3Pl9SV1fHd7/7Xb75zW8yZ86c4Ib0M1/6XFlZyZYtnuUWk5KS+P73vx/8oH7krc/t7e2UlJTQ2NiI2+3mzjvvZMGCBeaE9YMNGzbgdDpJTk5m3bp1vbYHpH7pEOV2u/VDDz2k6+vrdVdXl/7Od76jjx071qPN7t279ZNPPqkNw9AHDx7UjzzyiElp/cOXPldXV+uzZ89qrbV2Op1R0edL7Z544gn91FNP6e3bt5uQ1H986fO5c+f0ypUr9enTp7XWWre0tJgR1W986fOLL76of/GLX2ittT5z5oz+6le/qru6usyI6xeVlZX60KFD+lvf+laf2wNRv0J2yOXyxaljYmK6F6e+3JUWpw5XvvR58uTJDB06FIBJkyb1WE0qHPnSZ4BXXnmF2bNnM2zYMBNS+pcvfX7nnXeYPXt291qbycnJZkT1G1/6rJTi/PnzaK05f/48Q4cOxWIJ2RLlVW5ubvf/1b4Eon6F7Herr8Wpm5qaerXpa3HqcOVLny/3+uuvk5eXF4xoAePr6/z+++/3mIM/nPnS55MnT3Lu3DmeeOIJHn74Yd56661gx/QrX/pcXFzMRx99xDe+8Q2+/e1v87WvfS2sC7o3gahfITuGrv24OHW46E9/9u/fzxtvvMEPfvCDQMcKKF/6vGXLFu69996I+c/tS5/dbjdHjhzhe9/7Hp2dnTz66KNMmjTpqgsEhzJf+rxnzx7GjRvHY489xscff8wPf/hDsrOzSUhICFbMoApE/QrZgh6Ni1P70meAo0eP8vzzz/PII4+QlJQUzIh+50ufDx06xE9+8hMAWltbKS8vx2KxcP311wc1q7/4+rOdlJSEzWbDZrORk5PD0aNHw7ag+9LnN954g8WLF6OUYtSoUaSmpnLixAkmTpwY7LhBEYj6FbKnPNG4OLUvfW5oaOCZZ57hoYceCtv/3Jfzpc/r16/v/pgzZw7Lli0L22IOvv9sV1dX43a7uXDhAnV1daSnp5uUePB86bPD4WDfvn0AtLS0cOLECVJTU82IGxSBqF8hfaeo0+nkZz/7Wffi1HfffXePxam11mzatIk9e/Z0L049YcIEk1MPjrc+P/fcc7z33nvdY29Wq5U1a9aYGXnQvPX5cuvXr2fWrFlhf9miL33+85//zBtvvIHFYuGWW27h9ttvNzPyoHnrc1NTExs2bOh+Y/Cuu+5i/vz5ZkYelGeffZaqqirOnj1LcnIyS5YsweVyAYGrXyFd0IUQQvguZIdchBBC9I8UdCGEiBBS0IUQIkJIQRdCiAghBV0IISKEFHQhhIgQUtCFECJC/H/pLFrB78RiZAAAAABJRU5ErkJggg==", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "def init_fn(x, mu, sigma) -> np.array:\n", " return np.exp(-(x-mu)**2 / (2*sigma**2)) / (np.sqrt(2*np.pi) * sigma)\n", " \n", "L = 1 # length\n", "T = 4 # time\n", "c = 0.94\n", "dx = 0.01\n", "dt = 0.1\n", "sigma2 = (c*dt/dx) ** 2\n", "\n", "x = np.arange(0, L+dx, dx)\n", "t = np.arange(0, T+dt, dt)\n", "Nx = x.size\n", "Nt = t.size\n", "ic = init_fn(x, 0.5, 0.1)\n", "plt.plot(x, ic)" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "def get_model(input_dim, output_dim, width, depth, regularization_param=1e-7,\n", " act_func='tanh', output_act_func=False, random_seed=None)->keras.Sequential:\n", " \"\"\"This will build a fully connected model for you\n", "\n", " Parameters\n", " ----------\n", " input_dim : int\n", " [description]\n", " output_dim : int\n", " [description]\n", " width : int\n", " [description]\n", " depth : int\n", " [description]\n", " regularization_param : float\n", " [description], by default 1e-3\n", " act_func : str, optional\n", " [description], by default 'tanh'\n", " output_act_func : bool, optional\n", " Should the output layer use activation function?, by default False\n", " random_seed : int, optional\n", " [description], by default None\n", "\n", " Returns\n", " -------\n", " keras.Sequential\n", " An un-compiled model\n", " \"\"\"\n", " \n", " initializer = keras.initializers.RandomNormal(seed=random_seed)\n", " regularizer = keras.regularizers.l2(regularization_param)\n", "\n", " if act_func == 'sin':\n", " act_func = tf.math.sin\n", "\n", " all_layers = [keras.layers.InputLayer(input_shape=[input_dim])]\n", " all_layers += [keras.layers.Dense(width, activation=act_func,\n", " kernel_initializer=initializer,\n", " bias_initializer=initializer,\n", " kernel_regularizer=regularizer,\n", " bias_regularizer=regularizer) for _ in range(depth-1)]\n", " all_layers.append(keras.layers.Dense(\n", " output_dim, kernel_initializer=initializer,\n", " activation=act_func if output_act_func else None,\n", " bias_initializer=initializer,\n", " kernel_regularizer=regularizer,\n", " bias_regularizer=regularizer))\n", " m = keras.Sequential(all_layers)\n", " return m" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Training progress" ] }, { "cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [], "source": [ "def fit(model, x_input, c, epoch) -> list:\n", " def int_loss(d2u_di2, c):\n", " \"\"\"Calculate the interior loss\n", "\n", " Parameters\n", " ----------\n", " d2u_di2 : [type]\n", " [description]\n", " c : [type]\n", " [description]\n", "\n", " Returns\n", " -------\n", " tf.Tensor\n", " The shape is [101, 101, 1]\n", " \"\"\"\n", " # d2u_dx2, d2u_dt2 = zip(*d2u_di2)\n", " d2u_dx2 = np.zeros([len(d2u_di2), len(d2u_di2)])\n", " d2u_dt2 = np.zeros([len(d2u_di2), len(d2u_di2)])\n", " for i in range(1,len(d2u_di2)):\n", " for j in range(1, len(d2u_di2[0])):\n", " d2u_dx2[i, j] = d2u_di2[i, j, 0]\n", " d2u_dt2[i, j] = d2u_di2[i, j, 1]\n", " # d2u_dx2 is [101, 101, 1]\n", " return tf.reduce_mean(\n", " (np.array(d2u_dx2)-c**2*np.array(d2u_dt2))**2)\n", "\n", " def bc_loss(u):\n", " lambda_b = 10\n", " # for all t, get x=0 -> u[:, 0, :]\n", " # for all t, get x=L -> u[:, -1, :]\n", " return tf.cast(lambda_b*tf.reduce_mean(u[:, 0, :]**2+u[:, -1, :]**2), tf.float64)\n", "\n", " def ic_loss(u, du_di):\n", " # _, du_dt = zip(*du_di)\n", " du_dt = np.zeros([len(du_di), len(du_di)])\n", " for i in range(len(du_di)):\n", " for j in range(len(du_di[0])):\n", " du_dt[i, j] = du_di[i, j, 1]\n", " # for all x, get t=0 -> u[0,:,:]\n", " lambda_c = 10\n", " # shape of ic is [101, 1]\n", " gg1 = tf.reduce_mean((u[:, 0, :]-ic)**2).numpy()\n", " gg2 = tf.reduce_mean(du_dt**2).numpy()\n", " # print(np.shape(gg1), np.shape(gg2))\n", " # print(type(gg1), type(gg2))\n", " # print(gg1, gg2)\n", " gg3 = gg1+gg2\n", " return lambda_c*gg3\n", "\n", " optimizer = keras.optimizers.Adam(learning_rate=1e-3)\n", " interior_losses = []\n", " boundary_losses = []\n", " total_losses = []\n", " for ep in range(1, epoch+1):\n", " # the first tier tape for computing the gradient of the whole network\n", " with tf.GradientTape() as tape_0:\n", " with tf.GradientTape() as tape_1:\n", " tape_1.watch(x_input)\n", " with tf.GradientTape() as tape_2:\n", " tape_2.watch(x_input)\n", " # u will be [101, 101, 1]\n", " u = model(x_input)\n", " # this will return both du_dx and du_dt\n", " # the shape of du_di is [101, 101, 2]\n", " du_di = tape_2.gradient(u, x_input)\n", " d2u_di2 = tape_1.gradient(du_di, x_input)\n", " # del tape_2, tape_1\n", " interior_loss = int_loss(d2u_di2, c)\n", " boundary_loss = bc_loss(u)\n", " initial_loss = ic_loss(u, du_di)\n", " # print(interior_loss, boundary_loss,initial_loss)\n", " total_loss = interior_loss+boundary_loss+initial_loss\n", " # the final gradient of the whole network\n", " grads = tape_0.gradient(total_loss, model.trainable_variables)\n", " # keep records of losses\n", " interior_losses.append(interior_loss)\n", " boundary_losses.append(boundary_loss)\n", " total_losses.append(total_loss)\n", " optimizer.apply_gradients(zip(grads, model.trainable_variables))\n", " print(f\"{ep}, int {interior_loss.numpy()}, bc {boundary_loss.numpy()}, ic {initial_loss}, tot {total_loss.numpy()}\")\n", " return interior_losses, boundary_losses, total_losses, u\n" ] }, { "cell_type": "code", "execution_count": 25, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "1, int 1.0925161022858864e-14, bc 0.2745721936225891, ic 30.387790204924773, tot 30.662362398547373\n", "2, int 1.0425323086650898e-15, bc 0.25982049107551575, ic 30.317552108141697, tot 30.577372599217213\n", "3, int 2.998793326065171e-14, bc 0.2455635517835617, ic 30.247950606523858, tot 30.49351415830745\n", "4, int 1.1168545877536837e-13, bc 0.23177467286586761, ic 30.178911795865865, tot 30.410686468731843\n", "5, int 2.6302296552445623e-13, bc 0.21841737627983093, ic 30.1102926341293, tot 30.328710010409395\n", "6, int 5.017590842786328e-13, bc 0.20544452965259552, ic 30.04185709938902, tot 30.247301629042116\n", "7, int 8.447436634261483e-13, bc 0.1928120106458664, ic 29.97336917330525, tot 30.166181183951963\n", "8, int 1.3064454897892904e-12, bc 0.18048225343227386, ic 29.904588074925563, tot 30.085070328359144\n", "9, int 1.8970529968461057e-12, bc 0.1684231162071228, ic 29.835287336345996, tot 30.003710452555016\n", "10, int 2.618932613397383e-12, bc 0.15660615265369415, ic 29.76522857916755, tot 29.921834731823864\n", "11, int 3.462365040670978e-12, bc 0.14500649273395538, ic 29.694149598048078, tot 29.839156090785497\n", "12, int 4.402145353352658e-12, bc 0.13360337913036346, ic 29.621783440945283, tot 29.755386820080048\n", "13, int 5.39637004181858e-12, bc 0.12238159030675888, ic 29.547860801670236, tot 29.67024239198239\n", "14, int 6.390350456603594e-12, bc 0.11133287847042084, ic 29.47208142003877, tot 29.58341429851558\n", "15, int 7.33403879210923e-12, bc 0.10045754909515381, ic 29.39416893173228, tot 29.494626480834768\n", "16, int 8.229124202746513e-12, bc 0.0897660180926323, ic 29.313813665771686, tot 29.403579683872547\n", "17, int 9.232551012125239e-12, bc 0.07928042113780975, ic 29.23073227234994, tot 29.310012693496983\n", "18, int 1.0862906953469061e-11, bc 0.06903643906116486, ic 29.14463914205688, tot 29.21367558112891\n", "19, int 1.4387386326630836e-11, bc 0.05908545106649399, ic 29.05525359514111, tot 29.114339046221993\n", "20, int 2.2515006320403657e-11, bc 0.049496740102767944, ic 28.962335688810725, tot 29.011832428936007\n", "21, int 4.0589291507346476e-11, bc 0.040359675884246826, ic 28.865671963557645, tot 28.90603163948248\n", "22, int 7.856026657249723e-11, bc 0.03178578242659569, ic 28.765123181330004, tot 28.79690896383516\n", "23, int 1.5409693928380488e-10, bc 0.02391001582145691, ic 28.660636294467455, tot 28.684546310443007\n", "24, int 2.971969214027387e-10, bc 0.01689028926193714, ic 28.552315992621157, tot 28.56920628218029\n", "25, int 5.563325195301126e-10, bc 0.010903570801019669, ic 28.440503335222772, tot 28.451406906580125\n", "26, int 1.004982236060879e-09, bc 0.006136139389127493, ic 28.325851856456108, tot 28.331987996850216\n", "27, int 1.744177503530629e-09, bc 0.002763548633083701, ic 28.209539293427493, tot 28.212302843804753\n", "28, int 2.889627049569021e-09, bc 0.000914800213649869, ic 28.09345260605957, tot 28.09436740916285\n", "29, int 4.521587247701469e-09, bc 0.0006149948458187282, ic 27.98045813315668, tot 27.981073132524088\n", "30, int 6.578004647688498e-09, bc 0.0017095066141337156, ic 27.874633024895505, tot 27.876342538087645\n", "31, int 8.737388059222395e-09, bc 0.0038013323210179806, ic 27.78103964345009, tot 27.784840984508495\n", "32, int 1.0461114881738913e-08, bc 0.006275682710111141, ic 27.704866947359477, tot 27.711142640530703\n", "33, int 1.1273917592876457e-08, bc 0.008463799953460693, ic 27.649873083433604, tot 27.658336894660984\n", "34, int 1.1026649608858988e-08, bc 0.009863297455012798, ic 27.61745003705598, tot 27.627313345537644\n", "35, int 9.908197730031042e-09, bc 0.010263372212648392, ic 27.606579552153413, tot 27.61684293427426\n", "36, int 8.281266115079021e-09, bc 0.009727941825985909, ic 27.61462638403028, tot 27.624354334137532\n", "37, int 6.509736953496178e-09, bc 0.008497134782373905, ic 27.638109183308167, tot 27.646606324600278\n", "38, int 4.8597212313068066e-09, bc 0.00687796575948596, ic 27.6733174059901, tot 27.68019537660931\n", "39, int 3.4754633717859063e-09, bc 0.0051612239331007, ic 27.71665060455907, tot 27.721811831967635\n", "40, int 2.4003594780260856e-09, bc 0.0035735382698476315, ic 27.764865780451274, tot 27.76843932112148\n", "41, int 1.6136719020041445e-09, bc 0.0022598428186029196, ic 27.815160896572277, tot 27.81742074100455\n", "42, int 1.0644167574267563e-09, bc 0.0012870440259575844, ic 27.865242056338456, tot 27.86652910142883\n", "43, int 6.948933523154055e-10, bc 0.0006596053717657924, ic 27.9133121360081, tot 27.913971742074757\n", "44, int 4.534484478671892e-10, bc 0.0003393774386495352, ic 27.958021253798133, tot 27.95836063169023\n", "45, int 2.992906022219073e-10, bc 0.0002644923806656152, ic 27.99844336957247, tot 27.998707862252427\n", "46, int 2.026964704610395e-10, bc 0.0003645956458058208, ic 28.0339812375234, tot 28.0343458333719\n", "47, int 1.4317590832938255e-10, bc 0.0005715289153158665, ic 28.064304622516993, tot 28.064876151575483\n", "48, int 1.0714786436183431e-10, bc 0.0008258171728812158, ic 28.089312267818375, tot 28.090138085098403\n", "49, int 8.585763870821587e-11, bc 0.0010798560688272119, ic 28.109034194915637, tot 28.110114051070322\n", "50, int 7.376995782035425e-11, bc 0.001298799877986312, ic 28.1236317155913, tot 28.124930515543056\n", "51, int 6.743017997688762e-11, bc 0.0014599947025999427, ic 28.133342586170155, tot 28.134802580940185\n", "52, int 6.469970553373786e-11, bc 0.001551621942780912, ic 28.138466682784294, tot 28.140018304791774\n", "53, int 6.426044203738229e-11, bc 0.0015709666768088937, ic 28.139344520685242, tot 28.140915487426312\n", "54, int 6.530114374073529e-11, bc 0.0015225517563521862, ic 28.136350079075605, tot 28.13787263089726\n", "55, int 6.732175587429724e-11, bc 0.0014163090381771326, ic 28.129883627534884, tot 28.13129993664038\n", "56, int 7.001273174599276e-11, bc 0.001265829661861062, ic 28.120359786545897, tot 28.121625616277772\n", "57, int 7.31809075778764e-11, bc 0.0010867334203794599, ic 28.108228968898846, tot 28.109315702392408\n", "58, int 7.67043793233296e-11, bc 0.0008951699128374457, ic 28.09393921868596, tot 28.0948343886755\n", "59, int 8.050496826636033e-11, bc 0.0007064783130772412, ic 28.0779719622449, tot 28.07867844063848\n", "60, int 8.453184446681916e-11, bc 0.000534027989488095, ic 28.06079669912556, tot 28.061330727199582\n", "61, int 8.875133709507503e-11, bc 0.0003882816236000508, ic 28.04292106309283, tot 28.04330934480518\n", "62, int 9.314163172665594e-11, bc 0.0002761427022051066, ic 28.024819292876995, tot 28.02509543567234\n", "63, int 9.768968884974923e-11, bc 0.00020063250849489123, ic 28.006982300070234, tot 28.007182932676418\n", "64, int 1.023899165112593e-10, bc 0.00016093318117782474, ic 27.989874757746712, tot 27.99003569103028\n", "65, int 1.0724379286623381e-10, bc 0.0001528141729068011, ic 27.973918419301448, tot 27.974071233581597\n" ] } ], "source": [ "model = get_model(2, 1, 20, 5, act_func='sin')\n", "dx = 0.01\n", "dt = 0.01\n", "time_span = 1\n", "x_end = 1\n", "\n", "x_list = np.arange(0, 1, dx)\n", "t_list = np.arange(0, 1, dt)\n", "x_input = [[np.asarray([x, t]) for x in np.arange(0, x_end+dx, dx)]\n", " for t in np.arange(0, time_span+dt, dt)]\n", "# np.shape(x_input[0])\n", "# ic(list(x_input))\n", "epoch = 9999\n", "\n", "x_input = tf.constant(x_input)\n", "ic = init_fn(x_input[0, :, 0], 0.5, 0.1).reshape([101, 1])\n", "plt.plot(x_input[0,:,0], ic)\n", "# print(ic)\n", "# np.shape(ic)\n", "\n", "losses = fit(model, x_input, c, epoch)\n" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "1, int 2.470602543412797e-13,bc 5.600417352091824e-10, ic 27.93001890403208, tot 27.93001890459237\n", "2, int 2.4425813147832335e-13,bc 2.392385738403391e-07, ic 27.93218612892408, tot 27.9321863681629\n", "3, int 2.324050076901599e-13,bc 1.5172992107181926e-07, ic 27.928302290243934, tot 27.928302441974086\n", "4, int 2.3046771450012425e-13,bc 2.0206579165460425e-07, ic 27.92803764561826, tot 27.928037847684283\n", "5, int 2.352963053842054e-13,bc 7.176923499940813e-09, ic 27.9296636603352, tot 27.92966366751236\n", "6, int 2.4025543507088934e-13,bc 8.822872388236647e-08, ic 27.93133735876812, tot 27.931337446997084\n", "7, int 2.403019532868571e-13,bc 1.4195906317127083e-07, ic 27.93168783407957, tot 27.931687976038873\n", "8, int 2.35714537272552e-13,bc 3.558021433036629e-08, ic 27.93085336904624, tot 27.93085340462669\n", "9, int 2.294745692627271e-13,bc 9.842421988537353e-09, ic 27.92959690312545, tot 27.9295969129681\n", "10, int 2.248442241127396e-13,bc 8.243357996207124e-08, ic 27.9287552855344, tot 27.928755367968204\n", "11, int 2.236033020597819e-13,bc 7.675150470731751e-08, ic 27.928805353433805, tot 27.928805430185534\n", "12, int 2.2510484645523988e-13,bc 1.2215240374757741e-08, ic 27.929549219402105, tot 27.92954923161757\n", "13, int 2.273677984759475e-13,bc 1.1239209563029817e-08, ic 27.930483820234567, tot 27.930483831474003\n", "14, int 2.281862736653623e-13,bc 5.556395876737952e-08, ic 27.931065561567703, tot 27.93106561713189\n", "15, int 2.2641405233353569e-13,bc 4.8521204831786235e-08, ic 27.930996420176008, tot 27.93099646869744\n", "16, int 2.2255037975056165e-13,bc 7.699388682169683e-09, ic 27.930400373721262, tot 27.930400381420874\n", "17, int 2.1811716663835533e-13,bc 7.372367605285035e-09, ic 27.92965650774635, tot 27.929656515118936\n", "18, int 2.1479524961267865e-13,bc 3.6240564327272295e-08, ic 27.92919159151023, tot 27.929191627751006\n", "19, int 2.1349828579447626e-13,bc 3.274632476291117e-08, ic 27.9292297384786, tot 27.92922977122514\n", "20, int 2.1392109126316796e-13,bc 5.694131388622736e-09, ic 27.9297041914489, tot 27.929704197143245\n" ] }, { "ename": "KeyboardInterrupt", "evalue": "", "output_type": "error", "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", "\u001b[0;32m/var/folders/6h/mmvp5df90fb3fsqtc5d9mxt40000gn/T/ipykernel_41481/3826703814.py\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mlosses\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mfit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mx_input\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mc\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mepoch\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", "\u001b[0;32m/var/folders/6h/mmvp5df90fb3fsqtc5d9mxt40000gn/T/ipykernel_41481/2776831954.py\u001b[0m in \u001b[0;36mfit\u001b[0;34m(model, x_input, c, epoch)\u001b[0m\n\u001b[1;32m 70\u001b[0m \u001b[0md2u_di2\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtape_1\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mgradient\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdu_di\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mx_input\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 71\u001b[0m \u001b[0;32mdel\u001b[0m \u001b[0mtape_2\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtape_1\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 72\u001b[0;31m \u001b[0minterior_loss\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mint_loss\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0md2u_di2\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mc\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 73\u001b[0m \u001b[0mboundary_loss\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mbc_loss\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mu\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 74\u001b[0m \u001b[0minitial_loss\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mic_loss\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mu\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdu_di\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m/var/folders/6h/mmvp5df90fb3fsqtc5d9mxt40000gn/T/ipykernel_41481/2776831954.py\u001b[0m in \u001b[0;36mint_loss\u001b[0;34m(d2u_di2, c)\u001b[0m\n\u001b[1;32m 21\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mi\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0md2u_di2\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 22\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mj\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0md2u_di2\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 23\u001b[0;31m \u001b[0md2u_dx2\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mi\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mj\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0md2u_di2\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mi\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mj\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 24\u001b[0m \u001b[0md2u_dt2\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mi\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mj\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0md2u_di2\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mi\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mj\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 25\u001b[0m \u001b[0;31m# d2u_dx2 is [101, 101, 1]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m~/miniforge3/envs/tf/lib/python3.8/site-packages/tensorflow/python/util/dispatch.py\u001b[0m in \u001b[0;36mwrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 204\u001b[0m \u001b[0;34m\"\"\"Call target, and fall back on dispatchers if there is a TypeError.\"\"\"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 205\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 206\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mtarget\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 207\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mTypeError\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mValueError\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 208\u001b[0m \u001b[0;31m# Note: convert_to_eager_tensor currently raises a ValueError, not a\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m~/miniforge3/envs/tf/lib/python3.8/site-packages/tensorflow/python/ops/array_ops.py\u001b[0m in \u001b[0;36m_slice_helper\u001b[0;34m(tensor, slice_spec, var)\u001b[0m\n\u001b[1;32m 1024\u001b[0m skip_on_eager=False) as name:\n\u001b[1;32m 1025\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mbegin\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1026\u001b[0;31m packed_begin, packed_end, packed_strides = (stack(begin), stack(end),\n\u001b[0m\u001b[1;32m 1027\u001b[0m stack(strides))\n\u001b[1;32m 1028\u001b[0m if (packed_begin.dtype == dtypes.int64 or\n", "\u001b[0;32m~/miniforge3/envs/tf/lib/python3.8/site-packages/tensorflow/python/util/dispatch.py\u001b[0m in \u001b[0;36mwrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 204\u001b[0m \u001b[0;34m\"\"\"Call target, and fall back on dispatchers if there is a TypeError.\"\"\"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 205\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 206\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mtarget\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 207\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mTypeError\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mValueError\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 208\u001b[0m \u001b[0;31m# Note: convert_to_eager_tensor currently raises a ValueError, not a\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m~/miniforge3/envs/tf/lib/python3.8/site-packages/tensorflow/python/ops/array_ops.py\u001b[0m in \u001b[0;36mstack\u001b[0;34m(values, axis, name)\u001b[0m\n\u001b[1;32m 1410\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1411\u001b[0m \u001b[0;31m# If the input is a constant list, it can be converted to a constant op\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1412\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mops\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mconvert_to_tensor\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mname\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mname\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1413\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mTypeError\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mValueError\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1414\u001b[0m \u001b[0;32mpass\u001b[0m \u001b[0;31m# Input list contains non-constant tensors\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m~/miniforge3/envs/tf/lib/python3.8/site-packages/tensorflow/python/profiler/trace.py\u001b[0m in \u001b[0;36mwrapped\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 161\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mTrace\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtrace_name\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mtrace_kwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 162\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mfunc\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 163\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mfunc\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 164\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 165\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mwrapped\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;31mKeyboardInterrupt\u001b[0m: " ] } ], "source": [ "losses = fit(model, x_input, c, epoch)" ] }, { "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(-0.05, 1.05, 2.490356564521786e-07, 9.623225778341294e-05)" ] }, "execution_count": 23, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXwAAAD4CAYAAADvsV2wAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAAAUj0lEQVR4nO3df2xT573H8c+JPciyhMyzIVGArSLQdkiZBDU/FmlkWaxs2k9rkyqutkkUoa1CCC3dUAkjKyvKGq1kYUggmJqbTpMmVdqkdH9sUmSBQtdsq7uQUlq1kA6x0SQLtvOjYeVHfM79495azU1CbB8nwXneLwkpx+c5z/P9OvaHw5OEWI7jOAIALHkFi10AAGBhEPgAYAgCHwAMQeADgCEIfAAwBIEPAIbwLnYBcxkYGMjqukAgoFgsluNq7m/0vPSZ1q9Ez5mqqKiY9Rx3+ABgCAIfAAxB4AOAIQh8ADAEgQ8AhiDwAcAQBD4AGILABwBDEPgAYAgCHwAMQeADgCEIfAAwBIEPAIbIyf+W2dfXp46ODtm2rbq6OoXD4Snno9GoXnjhBVmWJY/Ho127dunhhx/OxdIAgDS5DnzbttXe3q7Dhw/L7/ersbFRwWBQa9asSY2pqqpSMBiUZVm6du2a2tradPz4cbdLAwAy4HpLp7+/X+Xl5SorK5PX61V1dbWi0eiUMYWFhbIsS5J0+/bt1McAgIXj+g4/kUjI7/enjv1+v65cuTJt3CuvvKLf/va3GhsbU2Nj46zzRSIRRSIRSVJLS4sCgUBWdXm93qyvzVf0vPSZ1q9Ezzmd1+0EjuNMe2ymO/itW7dq69atevPNN/XCCy+oqalpxvlCoZBCoVDqONvf+sJvyTGDaT2b1q9Ez5ma19945ff7FY/HU8fxeFw+n2/W8Rs3btTQ0JDGx8fdLg0AyIDrwK+srNTg4KCGh4c1OTmpnp4eBYPBKWOGhoZS/xL4xz/+ocnJSZWUlLhdGgCQAddbOh6PR7t371Zzc7Ns21Ztba3Wrl2rrq4uSVJ9fb3++te/6vz58/J4PFq2bJkaGhr4wi0ALDDLmWkT/j4yMDCQ1XXs+5nBtJ5N61ei50zN6x4+ACA/EPgAYAgCHwAMQeADgCEIfAAwBIEPAIYg8AHAEAQ+ABiCwAcAQxD4AGAIAh8ADEHgA4AhCHwAMASBDwCGIPABwBAEPgAYgsAHAEMQ+ABgCAIfAAxB4AOAIQh8ADCENxeT9PX1qaOjQ7Ztq66uTuFweMr5l156SS+++KIkqbCwUHv27NEDDzyQi6UBAGlyfYdv27ba29t16NAhtbW16eWXX9b169enjFm1apWOHDmiY8eO6Vvf+pZ+9atfuV0WAJAh14Hf39+v8vJylZWVyev1qrq6WtFodMqYhx56SMXFxZKkDRs2KB6Pu10WAJAh11s6iURCfr8/dez3+3XlypVZx589e1abNm2a9XwkElEkEpEktbS0KBAIZFWX1+vN+tp8Rc9Ln2n9SvSc03ndTuA4zrTHLMuaceylS5d07tw5Pf3007POFwqFFAqFUsexWCyrugKBQNbX5it6XvpM61ei50xVVFTMes71lo7f75+yRROPx+Xz+aaNu3btms6cOaMDBw6opKTE7bIAgAy5DvzKykoNDg5qeHhYk5OT6unpUTAYnDImFovp2LFj2rdv3z3/9gEAzB/XWzoej0e7d+9Wc3OzbNtWbW2t1q5dq66uLklSfX29fve732liYkLPPfdc6pqWlha3SwMAMmA5M23C30cGBgayuo59PzOY1rNp/Ur0nKl53cMHAOQHAh8ADEHgA4AhCHwAMASBDwCGIPABwBAEPgAYgsAHAEMQ+ABgCAIfAAxB4AOAIQh8ADAEgQ8AhiDwAcAQBD4AGILABwBDEPgAYAgCHwAMQeADgCEIfAAwBIEPAIbw5mKSvr4+dXR0yLZt1dXVKRwOTzn/7rvv6tSpU7p69ap27typr3/967lYFgCQAdeBb9u22tvbdfjwYfn9fjU2NioYDGrNmjWpMcXFxXrssccUjUbdLgcAyJLrLZ3+/n6Vl5errKxMXq9X1dXV04K9tLRU69evl8fjcbscACBLrgM/kUjI7/enjv1+vxKJhNtpAQA55npLx3GcaY9ZlpX1fJFIRJFIRJLU0tKiQCCQ1Txerzfra/MVPS99pvUr0XNO53U7gd/vVzweTx3H43H5fL6s5wuFQgqFQqnjWCyW1TyBQCDra/MVPS99pvUr0XOmKioqZj3nekunsrJSg4ODGh4e1uTkpHp6ehQMBt1OCwDIMdd3+B6PR7t371Zzc7Ns21Ztba3Wrl2rrq4uSVJ9fb1GR0d18OBBvf/++7IsS3/84x/1i1/8QkVFRa4bAACkJyffh79582Zt3rx5ymP19fWpjz/+8Y/r9OnTuVgKAJAlftIWAAxB4AOAIQh8ADAEgQ8AhiDwAcAQBD4AGILABwBDEPgAYAgCHwAMQeADgCEIfAAwBIEPAIYg8AHAEAQ+ABiCwAcAQxD4AGAIAh8ADEHgA4AhCHwAMASBDwCGIPABwBAEPgAYwpuLSfr6+tTR0SHbtlVXV6dwODzlvOM46ujo0IULF7R8+XLt3btX69aty8XSAIA0ub7Dt21b7e3tOnTokNra2vTyyy/r+vXrU8ZcuHBBQ0NDOnHihL73ve/pueeec7ssACBDru/w+/v7VV5errKyMklSdXW1otGo1qxZkxrz6quvaseOHbIsSw8++KBu3rypkZER+Xw+t8vPyP7vNo16CmTfvpPdBI6T24IWguNodPly2bdvL3YlGXHk7rkeXbZcyTsL0PN98pIYXb5MyYxe1/dJ4Zn60HtwwT7H2Zinp3fsE5+Q/uvxnM/rOvATiYT8fn/q2O/368qVK9PGBAKBKWMSicSMgR+JRBSJRCRJLS0tU65LV3zoupJ37qjA1WfDcnHt4kgWFKggH/+ysrJ/rpOSPLmr5L6XtCx5MvwcWy6e30X1f3UnLUve+/l1PQ/Pr23fzSr75uI68J0ZPhH//wWWzpgPhEIhhUKh1HEsFsu8qEOtCgQC2V2bx/z0vORl0+99HJVpMfG97HPRc0VFxaznXO/h+/1+xePx1HE8Hp925+73+6cUP9MYAMD8ch34lZWVGhwc1PDwsCYnJ9XT06NgMDhlTDAY1Pnz5+U4ji5fvqyioiICHwAWmOstHY/Ho927d6u5uVm2bau2tlZr165VV1eXJKm+vl6bNm1Sb2+v9u/fr2XLlmnv3r2uCwcAZCYn34e/efNmbd68ecpj9fX1qY8ty9KePXtysRQAIEv8pC0AGILABwBDEPgAYAgCHwAMQeADgCEIfAAwBIEPAIYg8AHAEAQ+ABiCwAcAQxD4AGAIAh8ADEHgA4AhCHwAMASBDwCGIPABwBAEPgAYgsAHAEMQ+ABgCAIfAAxB4AOAIbxuLp6YmFBbW5tu3LihlStXqqGhQcXFxdPGnTp1Sr29vSotLVVra6ubJQEAWXJ1h9/Z2amqqiqdOHFCVVVV6uzsnHHc5z//eR06dMjNUgAAl1wFfjQaVU1NjSSppqZG0Wh0xnEbN26c8c4fALBwXAX+2NiYfD6fJMnn82l8fDwnRQEAcm/OPfyjR49qdHR02uM7d+6cj3oUiUQUiUQkSS0tLQoEAlnN4/V6s742X9Hz0mdavxI953TeuQY0NTXNeq60tFQjIyPy+XwaGRnRihUrXBcUCoUUCoVSx7FYLKt5AoFA1tfmK3pe+kzrV6LnTFVUVMx6ztWWTjAYVHd3tySpu7tbW7ZscTMdAGAeuQr8cDisixcvav/+/bp48aLC4bAkKZFI6JlnnkmNO378uA4fPqyBgQE9/vjjOnv2rKuiAQCZsxzHcRa7iHsZGBjI6jr+GWgG03o2rV+JnjM1b1s6AID8QeADgCEIfAAwBIEPAIYg8AHAEAQ+ABiCwAcAQxD4AGAIAh8ADEHgA4AhCHwAMASBDwCGIPABwBAEPgAYgsAHAEMQ+ABgCAIfAAxB4AOAIQh8ADAEgQ8AhiDwAcAQBD4AGMLr5uKJiQm1tbXpxo0bWrlypRoaGlRcXDxlTCwW08mTJzU6OirLshQKhfTlL3/ZVdEAgMy5CvzOzk5VVVUpHA6rs7NTnZ2d+s53vjNljMfj0Xe/+12tW7dO77//vg4ePKjPfOYzWrNmjavCAQCZcbWlE41GVVNTI0mqqalRNBqdNsbn82ndunWSpI9+9KNavXq1EomEm2UBAFlwdYc/NjYmn88n6X+DfXx8/J7jh4eHdfXqVa1fv37WMZFIRJFIRJLU0tKiQCCQVW1erzfra/MVPS99pvUr0XNO551rwNGjRzU6Ojrt8Z07d2a00K1bt9Ta2qpdu3apqKho1nGhUEihUCh1HIvFMlrnA4FAIOtr8xU9L32m9SvRc6YqKipmPTdn4Dc1Nc16rrS0VCMjI/L5fBoZGdGKFStmHDc5OanW1lZ97nOf07Zt29IoGQCQa6728IPBoLq7uyVJ3d3d2rJly7QxjuPo9OnTWr16tb761a+6WQ4A4IKrwA+Hw7p48aL279+vixcvKhwOS5ISiYSeeeYZSdLbb7+t8+fP69KlSzpw4IAOHDig3t5e14UDADJjOY7jLHYR9zIwMJDVdez7mcG0nk3rV6LnTN1rD5+ftAUAQxD4AGAIAh8ADEHgA4AhCHwAMASBDwCGIPABwBAEPgAYgsAHAEMQ+ABgCAIfAAxB4AOAIQh8ADAEgQ8AhiDwAcAQBD4AGILABwBDEPgAYAgCHwAMQeADgCEIfAAwhNfNxRMTE2pra9ONGze0cuVKNTQ0qLi4eMqYO3fu6KmnntLk5KSSyaS2b9+uRx991FXRAIDMubrD7+zsVFVVlU6cOKGqqip1dnZOG/ORj3xETz31lJ599ln9/Oc/V19fny5fvuxmWQBAFlwFfjQaVU1NjSSppqZG0Wh02hjLslRYWChJSiaTSiaTsizLzbIAgCy42tIZGxuTz+eTJPl8Po2Pj884zrZtPfnkkxoaGtIXv/hFbdiwYdY5I5GIIpGIJKmlpUWBQCCr2rxeb9bX5it6XvpM61ei55zOO9eAo0ePanR0dNrjO3fuTHuRgoICPfvss7p586aOHTumf/7zn/rkJz8549hQKKRQKJQ6jsViaa/zYYFAIOtr8xU9L32m9SvRc6YqKipmPTdn4Dc1Nc16rrS0VCMjI/L5fBoZGdGKFSvuOdfHPvYxbdy4UX19fbMGPgBgfrjaww8Gg+ru7pYkdXd3a8uWLdPGjI+P6+bNm5L+9zt2Xn/9da1evdrNsgCALLjaww+Hw2pra9PZs2cVCAT0xBNPSJISiYTOnDmjxsZGjYyM6OTJk7JtW47j6LOf/aweeeSRnBQPAEif5TiOs9hF3MvAwEBW17HvZwbTejatX4meM3WvPXx+0hYADEHgA4AhCHwAMASBDwCGIPABwBAEPgAYgsAHAEMQ+ABgCAIfAAxB4AOAIQh8ADAEgQ8Ahrjv//M0AEBuLNk7/IMHDy52CQuOnpc+0/qV6DmXlmzgAwCmIvABwBBLNvA//IvQTUHPS59p/Ur0nEt80RYADLFk7/ABAFMR+ABgCO9iF+BGX1+fOjo6ZNu26urqFA6Hp5x3HEcdHR26cOGCli9frr1792rdunWLU2yOzNXzSy+9pBdffFGSVFhYqD179uiBBx5Y+EJzaK6eP9Df368f//jHamho0Pbt2xe2yBxLp+c33nhDzz//vJLJpEpKSvTTn/504QvNobl6/s9//qMTJ04oHo8rmUzqa1/7mmpraxen2Bw4deqUent7VVpaqtbW1mnn5yW/nDyVTCadffv2OUNDQ87du3edH/3oR86//vWvKWP+/ve/O83NzY5t287bb7/tNDY2LlK1uZFOz2+99Zbz3nvvOY7jOL29vUb0/MG4I0eOOD/72c+cv/zlL4tQae6k0/PExITzgx/8wLlx44bjOI4zOjq6GKXmTDo9//73v3d+85vfOI7jOGNjY86uXbucu3fvLka5OfHGG28477zzjvPEE0/MeH4+8itvt3T6+/tVXl6usrIyeb1eVVdXKxqNThnz6quvaseOHbIsSw8++KBu3rypkZGRRarYvXR6fuihh1RcXCxJ2rBhg+Lx+GKUmjPp9CxJf/rTn7Rt2zatWLFiEarMrXR6/vOf/6xt27YpEAhIkkpLSxej1JxJp2fLsnTr1i05jqNbt26puLhYBQV5G2HauHFj6r06k/nIr7x9thKJhPx+f+rY7/crkUhMG/PBG2K2MfkknZ4/7OzZs9q0adNClDZv0v08v/LKK6qvr1/o8uZFOj0PDg5qYmJCR44c0ZNPPqnu7u6FLjOn0un5S1/6kt599119//vf1w9/+EM99thjeR34c5mP/MrbPXxnhu8mtSwr4zH5JJN+Ll26pHPnzunpp5+e77LmVTo9P//88/r2t7+9ZN786fScTCZ19epVNTU16c6dOzp8+LA2bNigioqKhSozp9Lp+bXXXtOnPvUp/eQnP9G///1vHT16VA8//LCKiooWqswFNR/5lbeB7/f7p2xXxONx+Xy+aWNisdg9x+STdHqWpGvXrunMmTNqbGxUSUnJQpaYc+n0/M477+iXv/ylJGl8fFwXLlxQQUGBtm7duqC15kq6r+2SkhIVFhaqsLBQn/70p3Xt2rW8Dfx0ej537pzC4bAsy1J5eblWrVqlgYEBrV+/fqHLXRDzkV95e0tUWVmpwcFBDQ8Pa3JyUj09PQoGg1PGBINBnT9/Xo7j6PLlyyoqKsrrwE+n51gspmPHjmnfvn15++b/sHR6PnnyZOrP9u3btWfPnrwNeyn91/Zbb72lZDKp27dvq7+/X6tXr16kit1Lp+dAIKDXX39dkjQ6OqqBgQGtWrVqMcpdEPORX3n9k7a9vb369a9/Ldu2VVtbq29+85vq6uqSJNXX18txHLW3t+u1117TsmXLtHfvXlVWVi5y1e7M1fPp06f1t7/9LbX35/F41NLSspgluzZXzx928uRJPfLII3n/bZnp9PyHP/xB586dU0FBgb7whS/oK1/5ymKW7NpcPScSCZ06dSr1hctvfOMb2rFjx2KW7Mrx48f15ptv6r333lNpaakeffRRTU5OSpq//MrrwAcApC9vt3QAAJkh8AHAEAQ+ABiCwAcAQxD4AGAIAh8ADEHgA4Ah/gfE5MhU2hdYUAAAAABJRU5ErkJggg==", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "test_predict = model(x_input[0,:,:])\n", "plt.plot(x_input[0,:,0], test_predict)\n", "plt.axis('equal')" ] }, { "cell_type": "code", "execution_count": 45, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([[3., 3., 3., 3., 3., 3., 3., 3., 3., 3.],\n", " [3., 3., 3., 3., 3., 3., 3., 3., 3., 3.],\n", " [3., 3., 3., 3., 3., 3., 3., 3., 3., 3.],\n", " [3., 3., 3., 3., 3., 3., 3., 3., 3., 3.],\n", " [3., 3., 3., 3., 3., 3., 3., 3., 3., 3.],\n", " [3., 3., 3., 3., 3., 3., 3., 3., 3., 3.],\n", " [3., 3., 3., 3., 3., 3., 3., 3., 3., 3.],\n", " [3., 3., 3., 3., 3., 3., 3., 3., 3., 3.],\n", " [3., 3., 3., 3., 3., 3., 3., 3., 3., 3.],\n", " [3., 3., 3., 3., 3., 3., 3., 3., 3., 3.]])" ] }, "execution_count": 45, "metadata": {}, "output_type": "execute_result" } ], "source": [ "test_mat = np.asarray(list([[i, 3] for _ in range(10)] for i in range(10)))\n", "first = np.zeros([10,10])\n", "second = np.zeros([10,10])\n", "for i in range(len(test_mat)):\n", " for j in range(len(test_mat[0])):\n", " first[i, j] = test_mat[i, j, 0]\n", " second[i, j] = test_mat[i, j, 1]\n", "# first = np.asarray(first).reshape(10, 10)\n", "# second = np.asarray(second).reshape(10, 10)\n", "second" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def fit_old(model, x_input, c, epoch, verbose=True):\n", "\n", " optimizer = keras.optimizers.Adam(learning_rate=1e-3)\n", " interior_losses = []\n", " boundary_losses = []\n", " total_losses = []\n", " for _ in range(1, epoch+1):\n", " # the first tier tape for computing the gradient of the whole network\n", " with tf.GradientTape() as tape_0:\n", " with tf.GradientTape(persistent=True) as tape_1:\n", " tape_1.watch(x_input)\n", " with tf.GradientTape(persistent=True) as tape_2:\n", " tape_2.watch(x_input)\n", " u = model(x_input)\n", " du_di = tape_2.gradient(u, x_input)\n", " d2u_di2 = tape_1.gradient(du_di, x_input)\n", " del tape_2\n", " del tape_1\n", " d2u_dx2, d2u_dt2 = zip(*d2u_di2)\n", " interior_loss = tf.reduce_mean(\n", " (np.array(d2u_dx2)-c**2*np.array(d2u_dt2))**2)\n", " boundary_loss = 10*(u[0]**2+(u[-1]-1)**2)\n", " loss = interior_loss + boundary_loss + \\\n", " tf.math.reduce_sum(model.losses)\n", "\n", " interior_losses.append(interior_loss)\n", " boundary_losses.append(boundary_loss)\n", " total_losses.append(loss)\n", " grads = tape_0.gradient(loss, model.trainable_variables)\n", " optimizer.apply_gradients(zip(grads, model.trainable_variables))\n", " return interior_losses, boundary_losses, total_losses, u" ] } ], "metadata": { "interpreter": { "hash": "83eb3a94852d7da6265f5c02f9c348bfbb051e60fbf28829dfd0a8aff7ca6bfd" }, "kernelspec": { "display_name": "Python 3.8.12 64-bit ('tf': conda)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.12" }, "orig_nbformat": 4 }, "nbformat": 4, "nbformat_minor": 2 }