# gradient_clip.py — MNIST MLP training demo with global-norm gradient clipping.
  1. import tensorflow as tf
  2. from tensorflow import keras
  3. from tensorflow.keras import datasets, layers, optimizers
  4. import os
  5. os.environ['TF_CPP_MIN_LOG_LEVEL']='2'
  6. print(tf.__version__)
  7. (x, y), _ = datasets.mnist.load_data()
  8. x = tf.convert_to_tensor(x, dtype=tf.float32) / 50.
  9. y = tf.convert_to_tensor(y)
  10. y = tf.one_hot(y, depth=10)
  11. print('x:', x.shape, 'y:', y.shape)
  12. train_db = tf.data.Dataset.from_tensor_slices((x,y)).batch(128).repeat(30)
  13. x,y = next(iter(train_db))
  14. print('sample:', x.shape, y.shape)
  15. # print(x[0], y[0])
  16. def main():
  17. # 784 => 512
  18. w1, b1 = tf.Variable(tf.random.truncated_normal([784, 512], stddev=0.1)), tf.Variable(tf.zeros([512]))
  19. # 512 => 256
  20. w2, b2 = tf.Variable(tf.random.truncated_normal([512, 256], stddev=0.1)), tf.Variable(tf.zeros([256]))
  21. # 256 => 10
  22. w3, b3 = tf.Variable(tf.random.truncated_normal([256, 10], stddev=0.1)), tf.Variable(tf.zeros([10]))
  23. optimizer = optimizers.SGD(lr=0.01)
  24. for step, (x,y) in enumerate(train_db):
  25. # [b, 28, 28] => [b, 784]
  26. x = tf.reshape(x, (-1, 784))
  27. with tf.GradientTape() as tape:
  28. # layer1.
  29. h1 = x @ w1 + b1
  30. h1 = tf.nn.relu(h1)
  31. # layer2
  32. h2 = h1 @ w2 + b2
  33. h2 = tf.nn.relu(h2)
  34. # output
  35. out = h2 @ w3 + b3
  36. # out = tf.nn.relu(out)
  37. # compute loss
  38. # [b, 10] - [b, 10]
  39. loss = tf.square(y-out)
  40. # [b, 10] => [b]
  41. loss = tf.reduce_mean(loss, axis=1)
  42. # [b] => scalar
  43. loss = tf.reduce_mean(loss)
  44. # compute gradient
  45. grads = tape.gradient(loss, [w1, b1, w2, b2, w3, b3])
  46. # print('==before==')
  47. # for g in grads:
  48. # print(tf.norm(g))
  49. grads, _ = tf.clip_by_global_norm(grads, 15)
  50. # print('==after==')
  51. # for g in grads:
  52. # print(tf.norm(g))
  53. # update w' = w - lr*grad
  54. optimizer.apply_gradients(zip(grads, [w1, b1, w2, b2, w3, b3]))
  55. if step % 100 == 0:
  56. print(step, 'loss:', float(loss))
  57. if __name__ == '__main__':
  58. main()