# a3c_tf_cartpole.py — A3C (Asynchronous Advantage Actor-Critic) on CartPole, TensorFlow 2
  1. import matplotlib
  2. from matplotlib import pyplot as plt
  3. matplotlib.rcParams['font.size'] = 18
  4. matplotlib.rcParams['figure.titlesize'] = 18
  5. matplotlib.rcParams['figure.figsize'] = [9, 7]
  6. matplotlib.rcParams['font.family'] = ['KaiTi']
  7. matplotlib.rcParams['axes.unicode_minus']=False
  8. plt.figure()
  9. import os
  10. os.environ["CUDA_VISIBLE_DEVICES"] = ""
  11. import threading
  12. import gym
  13. import multiprocessing
  14. import numpy as np
  15. from queue import Queue
  16. import matplotlib.pyplot as plt
  17. import tensorflow as tf
  18. from tensorflow import keras
  19. from tensorflow.keras import layers,optimizers,losses
  20. tf.random.set_seed(1231)
  21. np.random.seed(1231)
  22. os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
  23. assert tf.__version__.startswith('2.')
  24. class ActorCritic(keras.Model):
  25. # Actor-Critic模型
  26. def __init__(self, state_size, action_size):
  27. super(ActorCritic, self).__init__()
  28. self.state_size = state_size # 状态向量长度
  29. self.action_size = action_size # 动作数量
  30. # 策略网络Actor
  31. self.dense1 = layers.Dense(128, activation='relu')
  32. self.policy_logits = layers.Dense(action_size)
  33. # V网络Critic
  34. self.dense2 = layers.Dense(128, activation='relu')
  35. self.values = layers.Dense(1)
  36. def call(self, inputs):
  37. # 获得策略分布Pi(a|s)
  38. x = self.dense1(inputs)
  39. logits = self.policy_logits(x)
  40. # 获得v(s)
  41. v = self.dense2(inputs)
  42. values = self.values(v)
  43. return logits, values
  44. def record(episode,
  45. episode_reward,
  46. worker_idx,
  47. global_ep_reward,
  48. result_queue,
  49. total_loss,
  50. num_steps):
  51. # 统计工具函数
  52. if global_ep_reward == 0:
  53. global_ep_reward = episode_reward
  54. else:
  55. global_ep_reward = global_ep_reward * 0.99 + episode_reward * 0.01
  56. print(
  57. f"{episode} | "
  58. f"Average Reward: {int(global_ep_reward)} | "
  59. f"Episode Reward: {int(episode_reward)} | "
  60. f"Loss: {int(total_loss / float(num_steps) * 1000) / 1000} | "
  61. f"Steps: {num_steps} | "
  62. f"Worker: {worker_idx}"
  63. )
  64. result_queue.put(global_ep_reward) # 保存回报,传给主线程
  65. return global_ep_reward
  66. class Memory:
  67. def __init__(self):
  68. self.states = []
  69. self.actions = []
  70. self.rewards = []
  71. def store(self, state, action, reward):
  72. self.states.append(state)
  73. self.actions.append(action)
  74. self.rewards.append(reward)
  75. def clear(self):
  76. self.states = []
  77. self.actions = []
  78. self.rewards = []
  79. class Agent:
  80. # 智能体,包含了中央参数网络server
  81. def __init__(self):
  82. # server优化器,client不需要,直接从server拉取参数
  83. self.opt = optimizers.Adam(1e-3)
  84. # 中央模型,类似于参数服务器
  85. self.server = ActorCritic(4, 2) # 状态向量,动作数量
  86. self.server(tf.random.normal((2, 4)))
  87. def train(self):
  88. res_queue = Queue() # 共享队列
  89. # 创建各个交互环境
  90. workers = [Worker(self.server, self.opt, res_queue, i)
  91. for i in range(multiprocessing.cpu_count())]
  92. for i, worker in enumerate(workers):
  93. print("Starting worker {}".format(i))
  94. worker.start()
  95. # 统计并绘制总回报曲线
  96. returns = []
  97. while True:
  98. reward = res_queue.get()
  99. if reward is not None:
  100. returns.append(reward)
  101. else: # 结束标志
  102. break
  103. [w.join() for w in workers] # 等待线程退出
  104. print(returns)
  105. plt.figure()
  106. plt.plot(np.arange(len(returns)), returns)
  107. # plt.plot(np.arange(len(moving_average_rewards)), np.array(moving_average_rewards), 's')
  108. plt.xlabel('回合数')
  109. plt.ylabel('总回报')
  110. plt.savefig('a3c-tf-cartpole.svg')
class Worker(threading.Thread):
    """A3C worker thread.

    Owns a private environment and a private copy of the network; after
    each episode it computes gradients locally, applies them to the
    shared server model, then pulls the server's latest weights.
    """

    def __init__(self, server, opt, result_queue, idx):
        super(Worker, self).__init__()
        self.result_queue = result_queue  # shared queue for episode rewards
        self.server = server  # central (parameter-server) model
        self.opt = opt  # central optimizer, shared across workers
        self.client = ActorCritic(4, 2)  # thread-private network
        self.worker_idx = idx  # worker/thread id
        # .unwrapped strips the TimeLimit wrapper, so `done` only fires on
        # pole failure; the 500-step cap is enforced manually in run().
        self.env = gym.make('CartPole-v1').unwrapped
        self.ep_loss = 0.0

    def run(self):
        """Play 500 episodes, pushing one server update per episode."""
        mem = Memory()  # each worker maintains its own rollout buffer
        for epi_counter in range(500):  # fixed maximum number of episodes
            current_state = self.env.reset()  # reset this worker's env
            mem.clear()
            ep_reward = 0.
            ep_steps = 0
            done = False
            while not done:
                # Policy logits Pi(a|s), pre-softmax.
                logits, _ = self.client(tf.constant(current_state[None, :],
                                                    dtype=tf.float32))
                probs = tf.nn.softmax(logits)
                # Sample an action from the policy distribution.
                action = np.random.choice(2, p=probs.numpy()[0])
                new_state, reward, done, _ = self.env.step(action)  # interact
                ep_reward += reward  # accumulate episode reward
                mem.store(current_state, action, reward)  # record transition
                ep_steps += 1  # count episode steps
                current_state = new_state  # advance the state
                if ep_steps >= 500 or done:  # cap episodes at 500 steps
                    # Compute the loss on this worker's private network.
                    with tf.GradientTape() as tape:
                        total_loss = self.compute_loss(done, new_state, mem)
                    grads = tape.gradient(total_loss, self.client.trainable_weights)
                    # Apply the locally computed gradients to the SERVER
                    # weights (asynchronous update).
                    self.opt.apply_gradients(zip(grads,
                                                 self.server.trainable_weights))
                    # Pull the latest weights back from the server.
                    self.client.set_weights(self.server.get_weights())
                    mem.clear()  # discard the consumed rollout
                    # Report this episode's return to the main thread.
                    self.result_queue.put(ep_reward)
                    print(self.worker_idx, ep_reward)
                    break
        self.result_queue.put(None)  # sentinel: this worker is finished

    def compute_loss(self,
                     done,
                     new_state,
                     memory,
                     gamma=0.99):
        """Return the combined actor-critic loss over the stored episode.

        Args:
            done: whether the episode ended in a terminal state.
            new_state: the state reached after the last stored transition.
            memory: Memory with the episode's states/actions/rewards.
            gamma: discount factor for the returns.
        """
        if done:
            reward_sum = 0.  # terminal state has value v=0
        else:
            # Bootstrap from the critic's value estimate of the last state
            # (index [-1] of the (logits, values) tuple).
            reward_sum = self.client(tf.constant(new_state[None, :],
                                                 dtype=tf.float32))[-1].numpy()[0]
        # Discounted returns, accumulated backwards through the episode.
        discounted_rewards = []
        for reward in memory.rewards[::-1]:  # reverse buffer r
            reward_sum = reward + gamma * reward_sum
            discounted_rewards.append(reward_sum)
        discounted_rewards.reverse()
        # Pi(a|s) logits and v(s) for every stored state.
        logits, values = self.client(tf.constant(np.vstack(memory.states),
                                                 dtype=tf.float32))
        # advantage = R - v(s)
        advantage = tf.constant(np.array(discounted_rewards)[:, None],
                                dtype=tf.float32) - values
        # Critic loss: squared advantage.
        value_loss = advantage ** 2
        # Actor (policy) loss.
        policy = tf.nn.softmax(logits)
        policy_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=memory.actions, logits=logits)
        # stop_gradient keeps the policy loss from backpropagating into
        # the critic through the advantage term.
        policy_loss = policy_loss * tf.stop_gradient(advantage)
        # Entropy bonus: cross-entropy of the policy with itself equals
        # its entropy; subtracting it encourages exploration.
        entropy = tf.nn.softmax_cross_entropy_with_logits(labels=policy,
                                                          logits=logits)
        policy_loss = policy_loss - 0.01 * entropy
        # Aggregate the losses (critic weighted by 0.5).
        total_loss = tf.reduce_mean((0.5 * value_loss + policy_loss))
        return total_loss
  195. if __name__ == '__main__':
  196. master = Agent()
  197. master.train()