# pokemon.py (4.7 KB)
  1. import os, glob
  2. import random, csv
  3. import tensorflow as tf
  4. def load_csv(root, filename, name2label):
  5. # 从csv文件返回images,labels列表
  6. # root:数据集根目录,filename:csv文件名, name2label:类别名编码表
  7. if not os.path.exists(os.path.join(root, filename)):
  8. # 如果csv文件不存在,则创建
  9. images = []
  10. for name in name2label.keys(): # 遍历所有子目录,获得所有的图片
  11. # 只考虑后缀为png,jpg,jpeg的图片:'pokemon\\mewtwo\\00001.png
  12. images += glob.glob(os.path.join(root, name, '*.png'))
  13. images += glob.glob(os.path.join(root, name, '*.jpg'))
  14. images += glob.glob(os.path.join(root, name, '*.jpeg'))
  15. # 打印数据集信息:1167, 'pokemon\\bulbasaur\\00000000.png'
  16. print(len(images), images)
  17. random.shuffle(images) # 随机打散顺序
  18. # 创建csv文件,并存储图片路径及其label信息
  19. with open(os.path.join(root, filename), mode='w', newline='') as f:
  20. writer = csv.writer(f)
  21. for img in images: # 'pokemon\\bulbasaur\\00000000.png'
  22. name = img.split(os.sep)[-2]
  23. label = name2label[name]
  24. # 'pokemon\\bulbasaur\\00000000.png', 0
  25. writer.writerow([img, label])
  26. print('written into csv file:', filename)
  27. # 此时已经有csv文件,直接读取
  28. images, labels = [], []
  29. with open(os.path.join(root, filename)) as f:
  30. reader = csv.reader(f)
  31. for row in reader:
  32. # 'pokemon\\bulbasaur\\00000000.png', 0
  33. img, label = row
  34. label = int(label)
  35. images.append(img)
  36. labels.append(label)
  37. # 返回图片路径list和标签list
  38. return images, labels
  39. def load_pokemon(root, mode='train'):
  40. # 创建数字编码表
  41. name2label = {} # "sq...":0
  42. # 遍历根目录下的子文件夹,并排序,保证映射关系固定
  43. for name in sorted(os.listdir(os.path.join(root))):
  44. # 跳过非文件夹
  45. if not os.path.isdir(os.path.join(root, name)):
  46. continue
  47. # 给每个类别编码一个数字
  48. name2label[name] = len(name2label.keys())
  49. # 读取Label信息
  50. # [file1,file2,], [3,1]
  51. images, labels = load_csv(root, 'images.csv', name2label)
  52. if mode == 'train': # 60%
  53. images = images[:int(0.6 * len(images))]
  54. labels = labels[:int(0.6 * len(labels))]
  55. elif mode == 'val': # 20% = 60%->80%
  56. images = images[int(0.6 * len(images)):int(0.8 * len(images))]
  57. labels = labels[int(0.6 * len(labels)):int(0.8 * len(labels))]
  58. else: # 20% = 80%->100%
  59. images = images[int(0.8 * len(images)):]
  60. labels = labels[int(0.8 * len(labels)):]
  61. return images, labels, name2label
  62. # 这里的mean和std根据真实的数据计算获得,比如ImageNet
  63. img_mean = tf.constant([0.485, 0.456, 0.406])
  64. img_std = tf.constant([0.229, 0.224, 0.225])
  65. def normalize(x, mean=img_mean, std=img_std):
  66. # 标准化
  67. # x: [224, 224, 3]
  68. # mean: [224, 224, 3], std: [3]
  69. x = (x - mean)/std
  70. return x
  71. def denormalize(x, mean=img_mean, std=img_std):
  72. # 标准化的逆过程
  73. x = x * std + mean
  74. return x
  75. def preprocess(x,y):
  76. # x: 图片的路径List,y:图片的数字编码List
  77. x = tf.io.read_file(x) # 根据路径读取图片
  78. x = tf.image.decode_jpeg(x, channels=3) # 图片解码
  79. x = tf.image.resize(x, [244, 244]) # 图片缩放
  80. # 数据增强
  81. # x = tf.image.random_flip_up_down(x)
  82. x= tf.image.random_flip_left_right(x) # 左右镜像
  83. x = tf.image.random_crop(x, [224, 224, 3]) # 随机裁剪
  84. # 转换成张量
  85. # x: [0,255]=> 0~1
  86. x = tf.cast(x, dtype=tf.float32) / 255.
  87. # 0~1 => D(0,1)
  88. x = normalize(x) # 标准化
  89. y = tf.convert_to_tensor(y) # 转换成张量
  90. return x, y
  91. def main():
  92. import time
  93. # 加载pokemon数据集,指定加载训练集
  94. images, labels, table = load_pokemon('pokemon', 'train')
  95. print('images:', len(images), images)
  96. print('labels:', len(labels), labels)
  97. print('table:', table)
  98. # images: string path
  99. # labels: number
  100. db = tf.data.Dataset.from_tensor_slices((images, labels))
  101. db = db.shuffle(1000).map(preprocess).batch(32)
  102. # 创建TensorBoard对象
  103. writter = tf.summary.create_file_writer('logs')
  104. for step, (x,y) in enumerate(db):
  105. # x: [32, 224, 224, 3]
  106. # y: [32]
  107. with writter.as_default():
  108. x = denormalize(x) # 反向normalize,方便可视化
  109. # 写入图片数据
  110. tf.summary.image('img',x,step=step,max_outputs=9)
  111. time.sleep(5)
  112. if __name__ == '__main__':
  113. main()