|
|
@@ -0,0 +1,290 @@
|
|
|
+{
|
|
|
+ "nbformat": 4,
|
|
|
+ "nbformat_minor": 0,
|
|
|
+ "metadata": {
|
|
|
+ "kernelspec": {
|
|
|
+ "display_name": "Python 3",
|
|
|
+ "language": "python",
|
|
|
+ "name": "python3"
|
|
|
+ },
|
|
|
+ "language_info": {
|
|
|
+ "codemirror_mode": {
|
|
|
+ "name": "ipython",
|
|
|
+ "version": 3
|
|
|
+ },
|
|
|
+ "file_extension": ".py",
|
|
|
+ "mimetype": "text/x-python",
|
|
|
+ "name": "python",
|
|
|
+ "nbconvert_exporter": "python",
|
|
|
+ "pygments_lexer": "ipython3",
|
|
|
+ "version": "3.7.7"
|
|
|
+ },
|
|
|
+ "colab": {
|
|
|
+ "name": "Custom_training_demo.ipynb",
|
|
|
+ "provenance": [],
|
|
|
+ "collapsed_sections": []
|
|
|
+ }
|
|
|
+ },
|
|
|
+ "cells": [
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "metadata": {
|
|
|
+ "scrolled": true,
|
|
|
+ "id": "_G5CWvypQMXr"
|
|
|
+ },
|
|
|
+ "source": [
|
|
|
+ "import numpy as np\n",
|
|
|
+ "import tensorflow as tf\n",
|
|
|
+ "from tensorflow import keras\n",
|
|
|
+ "%matplotlib inline \n",
|
|
|
+ "import matplotlib.pyplot as plt\n",
|
|
|
+ "tf.random.set_seed(1)\n",
|
|
|
+ "np.random.seed(1)"
|
|
|
+ ],
|
|
|
+ "execution_count": null,
|
|
|
+ "outputs": []
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "markdown",
|
|
|
+ "metadata": {
|
|
|
+ "id": "eikJH7uSQMXu"
|
|
|
+ },
|
|
|
+ "source": [
|
|
|
+ "## Create data"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "metadata": {
|
|
|
+ "scrolled": true,
|
|
|
+ "id": "BeW6_bSVQMXx"
|
|
|
+ },
|
|
|
+ "source": [
|
|
|
+    "def my_func(x):\n",
+    "  \"\"\"Target function: a Gaussian bump centred at 0.25 minus one centred at 0.75.\"\"\"\n",
+    "  bump_left = tf.exp(-200*(x-0.25)**2)\n",
+    "  bump_right = tf.exp(-200*(x-0.75)**2)\n",
+    "  return bump_left - bump_right\n",
|
|
|
+ "\n",
|
|
|
+    "N = 100  # number of training samples\n",
+    "\n",
+    "# Column vector of shape (N, 1): one sample per row, a single input feature\n",
+    "x_train = tf.linspace(0,1,N)[:, tf.newaxis]\n",
+    "y_train = my_func(x_train)\n",
+    "\n",
+    "# Inspect the shape and dtype of the training tensors\n",
+    "for data in (x_train, y_train):\n",
+    "  print(data.shape,data.dtype)\n",
+    "\n",
+    "# Visualise the target function at the sample points\n",
+    "plt.figure()\n",
+    "plt.plot(x_train,y_train,'-o')\n",
+    "plt.show()"
|
|
|
+ ],
|
|
|
+ "execution_count": null,
|
|
|
+ "outputs": []
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "markdown",
|
|
|
+ "metadata": {
|
|
|
+ "id": "Zt8VCsvnQMXz"
|
|
|
+ },
|
|
|
+ "source": [
|
|
|
+ "## Define basic network architecture"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "metadata": {
|
|
|
+ "scrolled": true,
|
|
|
+ "id": "fe7ZzCh4QMXz"
|
|
|
+ },
|
|
|
+ "source": [
|
|
|
+    "def MLP(Input_Dim=1,Output_Dim=1,Width=15,Depth=8,Reg_Param=1e-5,Act_Func=tf.math.tanh):\n",
+    "  \"\"\"Build a fully connected network made of `Depth` Dense layers.\n",
+    "\n",
+    "  Args:\n",
+    "    Input_Dim: number of input features.\n",
+    "    Output_Dim: number of units in the (linear) output layer.\n",
+    "    Width: number of units in every hidden layer.\n",
+    "    Depth: total number of Dense layers (hidden layers plus output layer); must be > 1.\n",
+    "    Reg_Param: coefficient of the L2 kernel regularizer attached to every layer.\n",
+    "    Act_Func: activation applied by the hidden layers.\n",
+    "\n",
+    "  Returns:\n",
+    "    An uncompiled keras Sequential model.\n",
+    "\n",
+    "  Raises:\n",
+    "    AssertionError: if Depth is not greater than 1.\n",
+    "  \"\"\"\n",
+    "  assert Depth > 1, 'Depth of network must be greater than 1'\n",
+    "\n",
+    "  def dense(units, activation, **kwargs):\n",
+    "    # Every layer shares the same initializers and L2 kernel penalty\n",
+    "    return keras.layers.Dense(units, activation=activation,\n",
+    "             kernel_initializer='RandomNormal', bias_initializer='RandomNormal',\n",
+    "             kernel_regularizer=keras.regularizers.l2(Reg_Param), **kwargs)\n",
+    "\n",
+    "  model = keras.Sequential()\n",
+    "\n",
+    "  # First hidden layer also fixes the input dimension\n",
+    "  model.add(dense(Width, Act_Func, input_shape=(Input_Dim,)))\n",
+    "\n",
+    "  # Remaining Depth-2 hidden layers (the range is empty when Depth == 2)\n",
+    "  for _ in range(Depth - 2):\n",
+    "    model.add(dense(Width, Act_Func))\n",
+    "\n",
+    "  # Output layer without activation\n",
+    "  model.add(dense(Output_Dim, None))\n",
+    "\n",
+    "  return model"
|
|
|
+ ],
|
|
|
+ "execution_count": null,
|
|
|
+ "outputs": []
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "markdown",
|
|
|
+ "metadata": {
|
|
|
+ "id": "_LERV1EKQMX0"
|
|
|
+ },
|
|
|
+ "source": [
|
|
|
+ "## Define loss function"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "metadata": {
|
|
|
+ "scrolled": true,
|
|
|
+ "id": "01Tl5E_uQMX0"
|
|
|
+ },
|
|
|
+ "source": [
|
|
|
+    "# In-built Keras mean-squared-error loss object\n",
+    "mse = keras.losses.MeanSquaredError()\n",
+    "\n",
+    "# Hand-written counterpart of the in-built loss\n",
+    "def my_mse(real,gen):\n",
+    "  \"\"\"Return the mean of the element-wise squared difference between real and gen.\"\"\"\n",
+    "  # Cast the target to the dtype of the prediction so both tensors match\n",
+    "  target = tf.cast(real, gen.dtype)\n",
+    "  return tf.reduce_mean(tf.math.squared_difference(target, gen))\n"
|
|
|
+ ],
|
|
|
+ "execution_count": null,
|
|
|
+ "outputs": []
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "metadata": {
|
|
|
+ "id": "f_Gc_-7Xa7S0"
|
|
|
+ },
|
|
|
+ "source": [
|
|
|
+    "# Sanity check: both losses should agree on a small hand-checkable example\n",
+    "y1 = tf.constant([1.0,2.0,3.0])\n",
+    "y2 = tf.constant([3.0,-4.0,-1.0])\n",
+    "\n",
+    "print(mse(y1,y2))\n",
+    "print(my_mse(y1,y2))\n",
+    "\n",
+    "# Squared errors are 4, 36 and 16, so the expected mean is 56/3\n",
+    "np.testing.assert_allclose(my_mse(y1,y2).numpy(), 56/3, rtol=1e-6)\n",
+    "np.testing.assert_allclose(mse(y1,y2).numpy(), my_mse(y1,y2).numpy(), rtol=1e-6)"
|
|
|
+ ],
|
|
|
+ "execution_count": null,
|
|
|
+ "outputs": []
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "markdown",
|
|
|
+ "metadata": {
|
|
|
+ "id": "3H2hrl6hQMX1"
|
|
|
+ },
|
|
|
+ "source": [
|
|
|
+    "## Custom training loop\n",
+    "For details, see the TensorFlow guide [Writing a training loop from scratch](https://www.tensorflow.org/guide/keras/writing_a_training_loop_from_scratch)."
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "metadata": {
|
|
|
+ "id": "1cualRc2RNJ6"
|
|
|
+ },
|
|
|
+ "source": [
|
|
|
+    "# Reset Keras' global state, then build a fresh network and optimizer\n",
+    "tf.keras.backend.clear_session()\n",
+    "model = MLP(1,1,15,8)\n",
+    "optimizer = keras.optimizers.Adam(learning_rate=1e-3)\n",
+    "max_epoch = 3000\n",
+    "\n",
+    "for epoch in range(1,max_epoch+1):\n",
+    "\n",
+    "  # Record the forward pass so the loss can be differentiated\n",
+    "  with tf.GradientTape() as tape:\n",
+    "    gen_out = model(x_train)\n",
+    "    # Data misfit (in-built mse) plus the regularization losses of the layers\n",
+    "    loss_val = mse(y_train, gen_out) + sum(model.losses)\n",
+    "\n",
+    "  # One optimizer step over all trainable variables, using the whole training set\n",
+    "  weights = model.trainable_variables\n",
+    "  grads = tape.gradient(loss_val, weights)\n",
+    "  optimizer.apply_gradients(zip(grads, weights))  # zip pairs each gradient with its variable\n",
+    "\n",
+    "  # Report progress every 100 epochs and at the final epoch\n",
+    "  if epoch==max_epoch or epoch % 100 == 0:\n",
+    "    print(f\"Epoch: {epoch}, loss: {loss_val:.2e}\")"
|
|
|
+ ],
|
|
|
+ "execution_count": null,
|
|
|
+ "outputs": []
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "metadata": {
|
|
|
+ "scrolled": true,
|
|
|
+ "id": "y5tyN2oPQMX3"
|
|
|
+ },
|
|
|
+ "source": [
|
|
|
+    "# Compare the network trained with the in-built mse against the target values\n",
+    "plt.figure(figsize=(10,10))\n",
+    "plt.plot(x_train,model.predict(x_train),'-',label='NN',linewidth=4)\n",
+    "plt.plot(x_train,y_train,'--',label='True',linewidth=4)\n",
+    "plt.xlabel('x',fontsize=15)\n",
+    "plt.ylabel('y',fontsize=15)\n",
+    "plt.title('Fit obtained with the in-built mse loss',fontsize=15)\n",
+    "plt.legend(fontsize=15)\n",
+    "plt.show()"
|
|
|
+ ],
|
|
|
+ "execution_count": null,
|
|
|
+ "outputs": []
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "metadata": {
|
|
|
+ "scrolled": true,
|
|
|
+ "id": "GRpQP8GuQMX4"
|
|
|
+ },
|
|
|
+ "source": [
|
|
|
+    "# Reset Keras' global state, then build a fresh network and optimizer\n",
+    "tf.keras.backend.clear_session()\n",
+    "model = MLP(1,1,15,8)\n",
+    "optimizer = keras.optimizers.Adam(learning_rate=1e-3)\n",
+    "max_epoch = 3000\n",
+    "\n",
+    "for epoch in range(1,max_epoch+1):\n",
+    "\n",
+    "  # Record the forward pass so the loss can be differentiated\n",
+    "  with tf.GradientTape() as tape:\n",
+    "    gen_out = model(x_train)\n",
+    "    # Data misfit (custom mse) plus the regularization losses of the layers\n",
+    "    loss_val = my_mse(y_train, gen_out) + sum(model.losses)\n",
+    "\n",
+    "  # One optimizer step over all trainable variables, using the whole training set\n",
+    "  weights = model.trainable_variables\n",
+    "  grads = tape.gradient(loss_val, weights)\n",
+    "  optimizer.apply_gradients(zip(grads, weights))  # zip pairs each gradient with its variable\n",
+    "\n",
+    "  # Report progress every 100 epochs and at the final epoch\n",
+    "  if epoch==max_epoch or epoch % 100 == 0:\n",
+    "    print(f\"Epoch: {epoch}, loss: {loss_val:.2e}\")"
|
|
|
+ ],
|
|
|
+ "execution_count": null,
|
|
|
+ "outputs": []
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "metadata": {
|
|
|
+ "id": "B6S8GHU5fs1B"
|
|
|
+ },
|
|
|
+ "source": [
|
|
|
+    "# Compare the network trained with the custom mse against the target values\n",
+    "plt.figure(figsize=(10,10))\n",
+    "plt.plot(x_train,model.predict(x_train),'-',label='NN',linewidth=4)\n",
+    "plt.plot(x_train,y_train,'--',label='True',linewidth=4)\n",
+    "plt.xlabel('x',fontsize=15)\n",
+    "plt.ylabel('y',fontsize=15)\n",
+    "plt.title('Fit obtained with the custom mse loss',fontsize=15)\n",
+    "plt.legend(fontsize=15)\n",
+    "plt.show()"
|
|
|
+ ],
|
|
|
+ "execution_count": null,
|
|
|
+ "outputs": []
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "metadata": {
|
|
|
+ "id": "rI_V9lF3f0Rv"
|
|
|
+ },
|
|
|
+ "source": [
|
|
|
+ ""
|
|
|
+ ],
|
|
|
+ "execution_count": null,
|
|
|
+ "outputs": []
|
|
|
+ }
|
|
|
+ ]
|
|
|
+}
|