# gpu_accelerate.py
# Times (1 x n) @ (n x 1) matrix products on the CPU and on the GPU with
# TensorFlow and plots the average run time against n.

  1. import numpy as np
  2. import matplotlib
  3. from matplotlib import pyplot as plt
  4. # Default parameters for plots
  5. matplotlib.rcParams['font.size'] = 20
  6. matplotlib.rcParams['figure.titlesize'] = 20
  7. matplotlib.rcParams['figure.figsize'] = [9, 7]
  8. matplotlib.rcParams['font.family'] = ['STKaiti']
  9. matplotlib.rcParams['axes.unicode_minus']=False
  10. import tensorflow as tf
  11. import timeit
  12. cpu_data = []
  13. gpu_data = []
  14. for n in range(9):
  15. n = 10**n
  16. # 创建在CPU上运算的2个矩阵
  17. with tf.device('/cpu:0'):
  18. cpu_a = tf.random.normal([1, n])
  19. cpu_b = tf.random.normal([n, 1])
  20. print(cpu_a.device, cpu_b.device)
  21. # 创建使用GPU运算的2个矩阵
  22. with tf.device('/gpu:0'):
  23. gpu_a = tf.random.normal([1, n])
  24. gpu_b = tf.random.normal([n, 1])
  25. print(gpu_a.device, gpu_b.device)
  26. def cpu_run():
  27. with tf.device('/cpu:0'):
  28. c = tf.matmul(cpu_a, cpu_b)
  29. return c
  30. def gpu_run():
  31. with tf.device('/gpu:0'):
  32. c = tf.matmul(gpu_a, gpu_b)
  33. return c
  34. # 第一次计算需要热身,避免将初始化阶段时间结算在内
  35. cpu_time = timeit.timeit(cpu_run, number=10)
  36. gpu_time = timeit.timeit(gpu_run, number=10)
  37. print('warmup:', cpu_time, gpu_time)
  38. # 正式计算10次,取平均时间
  39. cpu_time = timeit.timeit(cpu_run, number=10)
  40. gpu_time = timeit.timeit(gpu_run, number=10)
  41. print('run time:', cpu_time, gpu_time)
  42. cpu_data.append(cpu_time/10)
  43. gpu_data.append(gpu_time/10)
  44. del cpu_a,cpu_b,gpu_a,gpu_b
  45. x = [10**i for i in range(9)]
  46. cpu_data = [1000*i for i in cpu_data]
  47. gpu_data = [1000*i for i in gpu_data]
  48. plt.plot(x, cpu_data, 'C1')
  49. plt.plot(x, cpu_data, color='C1', marker='s', label='CPU')
  50. plt.plot(x, gpu_data,'C0')
  51. plt.plot(x, gpu_data, color='C0', marker='^', label='GPU')
  52. plt.gca().set_xscale('log')
  53. plt.gca().set_yscale('log')
  54. plt.ylim([0,100])
  55. plt.xlabel('矩阵大小n:(1xn)@(nx1)')
  56. plt.ylabel('运算时间(ms)')
  57. plt.legend()
  58. plt.savefig('gpu-time.svg')