{
 "nbformat": 4,
 "nbformat_minor": 0,
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.7"
  },
  "colab": {
   "name": "Custom_training_demo.ipynb",
   "provenance": [],
   "collapsed_sections": []
  }
 },
 "cells": [
  {
   "cell_type": "code",
   "metadata": {
    "scrolled": true,
    "id": "_G5CWvypQMXr"
   },
   "source": [
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "from tensorflow import keras\n",
    "%matplotlib inline\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# Seed both RNGs up front so every run is reproducible.\n",
    "tf.random.set_seed(1)\n",
    "np.random.seed(1)"
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "eikJH7uSQMXu"
   },
   "source": [
    "## Create data"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "scrolled": true,
    "id": "BeW6_bSVQMXx"
   },
   "source": [
    "def my_func(x):\n",
    "    \"\"\"Two opposite-signed Gaussian bumps centred at x = 0.25 and x = 0.75.\"\"\"\n",
    "    return tf.exp(-200 * (x - 0.25) ** 2) - tf.exp(-200 * (x - 0.75) ** 2)\n",
    "\n",
    "N = 100\n",
    "# Use float endpoints and cast to float32 so the data matches the (float32)\n",
    "# network weights -- integer endpoints give a non-float32 tensor and trigger\n",
    "# implicit casting (or an error on older TF versions).\n",
    "# Reshape so each row is one sample (input_dim = 1).\n",
    "x_train = tf.reshape(tf.cast(tf.linspace(0.0, 1.0, N), tf.float32), (-1, 1))\n",
    "y_train = my_func(x_train)\n",
    "\n",
    "print(x_train.shape, x_train.dtype)\n",
    "print(y_train.shape, y_train.dtype)\n",
    "\n",
    "plt.figure()\n",
    "plt.plot(x_train, y_train, '-o')\n",
    "plt.show()"
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Zt8VCsvnQMXz"
   },
   "source": [
    "## Define basic network architecture"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "scrolled": true,
    "id": "fe7ZzCh4QMXz"
   },
   "source": [
    "def MLP(Input_Dim=1, Output_Dim=1, Width=15, Depth=8):\n",
    "    \"\"\"Build a fully-connected tanh network.\n",
    "\n",
    "    Args:\n",
    "        Input_Dim: number of input features.\n",
    "        Output_Dim: number of outputs (linear output layer).\n",
    "        Width: units per hidden layer.\n",
    "        Depth: total number of Dense layers (hidden + output); must be > 1.\n",
    "\n",
    "    Returns:\n",
    "        An (uncompiled) tf.keras.Sequential model with L2-regularized kernels\n",
    "        and RandomNormal-initialized weights and biases.\n",
    "    \"\"\"\n",
    "    Reg_Func = keras.regularizers.l2\n",
    "    Reg_Param = 1e-5\n",
    "    Act_Func = tf.math.tanh\n",
    "\n",
    "    assert Depth > 1, 'Depth of generator must be greater than 1'\n",
    "\n",
    "    model = tf.keras.Sequential()\n",
    "\n",
    "    # First hidden layer (also fixes the input shape).\n",
    "    model.add(keras.layers.Dense(Width, input_shape=(Input_Dim,), activation=Act_Func,\n",
    "                                 kernel_initializer='RandomNormal', bias_initializer='RandomNormal',\n",
    "                                 kernel_regularizer=Reg_Func(Reg_Param)))\n",
    "\n",
    "    # Remaining hidden layers (range is empty when Depth == 2, so no guard needed).\n",
    "    for _ in range(Depth - 2):\n",
    "        # model.add(keras.layers.BatchNormalization())\n",
    "        model.add(keras.layers.Dense(Width, activation=Act_Func,\n",
    "                                     kernel_initializer='RandomNormal', bias_initializer='RandomNormal',\n",
    "                                     kernel_regularizer=Reg_Func(Reg_Param)))\n",
    "\n",
    "    # Linear output layer.\n",
    "    model.add(keras.layers.Dense(Output_Dim, activation=None,\n",
    "                                 kernel_initializer='RandomNormal', bias_initializer='RandomNormal',\n",
    "                                 kernel_regularizer=Reg_Func(Reg_Param)))\n",
    "\n",
    "    return model"
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "_LERV1EKQMX0"
   },
   "source": [
    "## Define loss function"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "scrolled": true,
    "id": "01Tl5E_uQMX0"
   },
   "source": [
    "mse = keras.losses.MeanSquaredError()  # In-built loss\n",
    "\n",
    "def my_mse(real, gen):\n",
    "    \"\"\"Custom mean-squared error.\n",
    "\n",
    "    Casts `real` to `gen`'s dtype first so both tensors have the same type.\n",
    "    \"\"\"\n",
    "    real = tf.cast(real, gen.dtype)\n",
    "    return tf.reduce_mean(tf.math.squared_difference(real, gen))"
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "metadata": {
    "id": "f_Gc_-7Xa7S0"
   },
   "source": [
    "# Sanity check: the custom loss should agree with the in-built one.\n",
    "y1 = tf.constant([1.0, 2.0, 3.0])\n",
    "y2 = tf.constant([3.0, -4.0, -1.0])\n",
    "\n",
    "print(mse(y1, y2))\n",
    "print(my_mse(y1, y2))"
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "3H2hrl6hQMX1"
   },
   "source": [
    "## Have a look at details here\n",
    "### https://www.tensorflow.org/guide/keras/writing_a_training_loop_from_scratch"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "id": "TrainHelpers01"
   },
   "source": [
    "def train(model, loss_fn, x, y, max_epoch=3000, learning_rate=1e-3, log_every=100):\n",
    "    \"\"\"Custom training loop: full-batch gradient descent with Adam.\n",
    "\n",
    "    Args:\n",
    "        model: a Keras model whose trainable variables are optimized.\n",
    "        loss_fn: callable loss, loss_fn(y_true, y_pred) -> scalar tensor.\n",
    "        x, y: training inputs and targets (used as one full batch).\n",
    "        max_epoch: number of gradient steps.\n",
    "        learning_rate: Adam learning rate.\n",
    "        log_every: print the loss every `log_every` epochs (and at the end).\n",
    "    \"\"\"\n",
    "    optimizer = keras.optimizers.Adam(learning_rate=learning_rate)\n",
    "\n",
    "    for epoch in range(1, max_epoch + 1):\n",
    "        with tf.GradientTape() as tape:\n",
    "            gen_out = model(x, training=True)\n",
    "            loss_val = loss_fn(y, gen_out)\n",
    "            # Add the model's regularization losses to the data-fit loss.\n",
    "            loss_val += sum(model.losses)\n",
    "\n",
    "        grads = tape.gradient(loss_val, model.trainable_variables)\n",
    "        # zip creates an iterator over (gradient, variable) tuples.\n",
    "        optimizer.apply_gradients(zip(grads, model.trainable_variables))\n",
    "\n",
    "        if epoch % log_every == 0 or epoch == max_epoch:\n",
    "            # .numpy(): an EagerTensor does not support format specs like :.2e\n",
    "            print(f\"Epoch: {epoch}, loss: {loss_val.numpy():.2e}\")\n",
    "\n",
    "\n",
    "def plot_comparison(model, x, y):\n",
    "    \"\"\"Overlay the network prediction on the true function.\"\"\"\n",
    "    plt.figure(figsize=(10, 10))\n",
    "    plt.plot(x, model.predict(x), '-', label='NN', linewidth=4)\n",
    "    plt.plot(x, y, '--', label='True', linewidth=4)\n",
    "    plt.legend(fontsize=15)\n",
    "    plt.show()"
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "metadata": {
    "id": "1cualRc2RNJ6"
   },
   "source": [
    "# Train using the in-built MSE loss.\n",
    "tf.keras.backend.clear_session()\n",
    "model = MLP(1, 1, 15, 8)\n",
    "train(model, mse, x_train, y_train)"
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "metadata": {
    "scrolled": true,
    "id": "y5tyN2oPQMX3"
   },
   "source": [
    "plot_comparison(model, x_train, y_train)"
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "metadata": {
    "scrolled": true,
    "id": "GRpQP8GuQMX4"
   },
   "source": [
    "# Train using the custom MSE loss -- results should match the run above.\n",
    "tf.keras.backend.clear_session()\n",
    "model = MLP(1, 1, 15, 8)\n",
    "train(model, my_mse, x_train, y_train)"
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "metadata": {
    "id": "B6S8GHU5fs1B"
   },
   "source": [
    "plot_comparison(model, x_train, y_train)"
   ],
   "execution_count": null,
   "outputs": []
  }
 ]
}