# acc_topk.py

  1. import tensorflow as tf
  2. import os
  3. os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
  4. tf.random.set_seed(2467)
  5. def accuracy(output, target, topk=(1,)):
  6. maxk = max(topk)
  7. batch_size = target.shape[0]
  8. pred = tf.math.top_k(output, maxk).indices
  9. pred = tf.transpose(pred, perm=[1, 0])
  10. target_ = tf.broadcast_to(target, pred.shape)
  11. # [10, b]
  12. correct = tf.equal(pred, target_)
  13. res = []
  14. for k in topk:
  15. correct_k = tf.cast(tf.reshape(correct[:k], [-1]), dtype=tf.float32)
  16. correct_k = tf.reduce_sum(correct_k)
  17. acc = float(correct_k* (100.0 / batch_size) )
  18. res.append(acc)
  19. return res
  20. output = tf.random.normal([10, 6])
  21. output = tf.math.softmax(output, axis=1)
  22. target = tf.random.uniform([10], maxval=6, dtype=tf.int32)
  23. print('prob:', output.numpy())
  24. pred = tf.argmax(output, axis=1)
  25. print('pred:', pred.numpy())
  26. print('label:', target.numpy())
  27. acc = accuracy(output, target, topk=(1,2,3,4,5,6))
  28. print('top-1-6 acc:', acc)