dataset.py 5.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153
  1. import multiprocessing
  2. import tensorflow as tf
  3. def make_anime_dataset(img_paths, batch_size, resize=64, drop_remainder=True, shuffle=True, repeat=1):
  4. # @tf.function
  5. def _map_fn(img):
  6. img = tf.image.resize(img, [resize, resize])
  7. # img = tf.image.random_crop(img,[resize, resize])
  8. # img = tf.image.random_flip_left_right(img)
  9. # img = tf.image.random_flip_up_down(img)
  10. img = tf.clip_by_value(img, 0, 255)
  11. img = img / 127.5 - 1 #-1~1
  12. return img
  13. dataset = disk_image_batch_dataset(img_paths,
  14. batch_size,
  15. drop_remainder=drop_remainder,
  16. map_fn=_map_fn,
  17. shuffle=shuffle,
  18. repeat=repeat)
  19. img_shape = (resize, resize, 3)
  20. len_dataset = len(img_paths) // batch_size
  21. return dataset, img_shape, len_dataset
  22. def batch_dataset(dataset,
  23. batch_size,
  24. drop_remainder=True,
  25. n_prefetch_batch=1,
  26. filter_fn=None,
  27. map_fn=None,
  28. n_map_threads=None,
  29. filter_after_map=False,
  30. shuffle=True,
  31. shuffle_buffer_size=None,
  32. repeat=None):
  33. # set defaults
  34. if n_map_threads is None:
  35. n_map_threads = multiprocessing.cpu_count()
  36. if shuffle and shuffle_buffer_size is None:
  37. shuffle_buffer_size = max(batch_size * 128, 2048) # set the minimum buffer size as 2048
  38. # [*] it is efficient to conduct `shuffle` before `map`/`filter` because `map`/`filter` is sometimes costly
  39. if shuffle:
  40. dataset = dataset.shuffle(shuffle_buffer_size)
  41. if not filter_after_map:
  42. if filter_fn:
  43. dataset = dataset.filter(filter_fn)
  44. if map_fn:
  45. dataset = dataset.map(map_fn, num_parallel_calls=n_map_threads)
  46. else: # [*] this is slower
  47. if map_fn:
  48. dataset = dataset.map(map_fn, num_parallel_calls=n_map_threads)
  49. if filter_fn:
  50. dataset = dataset.filter(filter_fn)
  51. dataset = dataset.batch(batch_size, drop_remainder=drop_remainder)
  52. dataset = dataset.repeat(repeat).prefetch(n_prefetch_batch)
  53. return dataset
  54. def memory_data_batch_dataset(memory_data,
  55. batch_size,
  56. drop_remainder=True,
  57. n_prefetch_batch=1,
  58. filter_fn=None,
  59. map_fn=None,
  60. n_map_threads=None,
  61. filter_after_map=False,
  62. shuffle=True,
  63. shuffle_buffer_size=None,
  64. repeat=None):
  65. """Batch dataset of memory data.
  66. Parameters
  67. ----------
  68. memory_data : nested structure of tensors/ndarrays/lists
  69. """
  70. dataset = tf.data.Dataset.from_tensor_slices(memory_data)
  71. dataset = batch_dataset(dataset,
  72. batch_size,
  73. drop_remainder=drop_remainder,
  74. n_prefetch_batch=n_prefetch_batch,
  75. filter_fn=filter_fn,
  76. map_fn=map_fn,
  77. n_map_threads=n_map_threads,
  78. filter_after_map=filter_after_map,
  79. shuffle=shuffle,
  80. shuffle_buffer_size=shuffle_buffer_size,
  81. repeat=repeat)
  82. return dataset
  83. def disk_image_batch_dataset(img_paths,
  84. batch_size,
  85. labels=None,
  86. drop_remainder=True,
  87. n_prefetch_batch=1,
  88. filter_fn=None,
  89. map_fn=None,
  90. n_map_threads=None,
  91. filter_after_map=False,
  92. shuffle=True,
  93. shuffle_buffer_size=None,
  94. repeat=None):
  95. """Batch dataset of disk image for PNG and JPEG.
  96. Parameters
  97. ----------
  98. img_paths : 1d-tensor/ndarray/list of str
  99. labels : nested structure of tensors/ndarrays/lists
  100. """
  101. if labels is None:
  102. memory_data = img_paths
  103. else:
  104. memory_data = (img_paths, labels)
  105. def parse_fn(path, *label):
  106. img = tf.io.read_file(path)
  107. img = tf.image.decode_jpeg(img, channels=3) # fix channels to 3
  108. return (img,) + label
  109. if map_fn: # fuse `map_fn` and `parse_fn`
  110. def map_fn_(*args):
  111. return map_fn(*parse_fn(*args))
  112. else:
  113. map_fn_ = parse_fn
  114. dataset = memory_data_batch_dataset(memory_data,
  115. batch_size,
  116. drop_remainder=drop_remainder,
  117. n_prefetch_batch=n_prefetch_batch,
  118. filter_fn=filter_fn,
  119. map_fn=map_fn_,
  120. n_map_threads=n_map_threads,
  121. filter_after_map=filter_after_map,
  122. shuffle=shuffle,
  123. shuffle_buffer_size=shuffle_buffer_size,
  124. repeat=repeat)
  125. return dataset