| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153 |
- import multiprocessing
- import tensorflow as tf
def make_anime_dataset(img_paths, batch_size, resize=64, drop_remainder=True, shuffle=True, repeat=1):
    """Build a batched image dataset with pixels rescaled to [-1, 1].

    Parameters
    ----------
    img_paths : list of str
        Paths of the image files; forwarded to `disk_image_batch_dataset`.
    batch_size : int
        Number of images per batch.
    resize : int
        Every image is resized to `resize` x `resize`.
    drop_remainder : bool
        Whether a final batch smaller than `batch_size` is dropped.
    shuffle : bool
        Forwarded to `disk_image_batch_dataset`.
    repeat : int or None
        Forwarded to `disk_image_batch_dataset`.

    Returns
    -------
    dataset
        The dataset returned by `disk_image_batch_dataset`.
    img_shape : tuple
        `(resize, resize, 3)`; images are decoded with 3 channels.
    len_dataset : int
        Number of batches in one pass over `img_paths` (`repeat` is not
        taken into account).
    """

    def _map_fn(img):
        # No augmentation (random crop / flips) is applied; only resize + rescale.
        img = tf.image.resize(img, [resize, resize])
        img = tf.clip_by_value(img, 0, 255)  # clamp to the 0..255 pixel range
        img = img / 127.5 - 1  # rescale 0..255 to -1..1
        return img

    dataset = disk_image_batch_dataset(img_paths,
                                       batch_size,
                                       drop_remainder=drop_remainder,
                                       map_fn=_map_fn,
                                       shuffle=shuffle,
                                       repeat=repeat)
    img_shape = (resize, resize, 3)

    # When the remainder is kept, the trailing partial batch counts as a batch
    # too, so the batch count is the ceiling rather than the floor.
    n_full_batches, n_leftover = divmod(len(img_paths), batch_size)
    if drop_remainder or n_leftover == 0:
        len_dataset = n_full_batches
    else:
        len_dataset = n_full_batches + 1

    return dataset, img_shape, len_dataset
def batch_dataset(dataset,
                  batch_size,
                  drop_remainder=True,
                  n_prefetch_batch=1,
                  filter_fn=None,
                  map_fn=None,
                  n_map_threads=None,
                  filter_after_map=False,
                  shuffle=True,
                  shuffle_buffer_size=None,
                  repeat=None):
    """Apply shuffle, filter/map, batch, repeat and prefetch to `dataset`.

    Parameters
    ----------
    dataset : object exposing `shuffle`, `filter`, `map`, `batch`, `repeat`
        and `prefetch` (e.g. a `tf.data.Dataset`)
    batch_size : int
    drop_remainder : bool
        Forwarded to `dataset.batch`.
    n_prefetch_batch : int
        Forwarded to `dataset.prefetch`.
    filter_fn, map_fn : callable or None
        Optional element-wise predicate / transform; skipped when falsy.
    n_map_threads : int or None
        `num_parallel_calls` for `map`; defaults to the CPU count.
    filter_after_map : bool
        If True, `map` runs before `filter` (slower, since every element is
        mapped); otherwise `filter` runs first.
    shuffle : bool
    shuffle_buffer_size : int or None
        Defaults to `max(batch_size * 128, 2048)` when shuffling.
    repeat : int or None
        Forwarded as-is to `dataset.repeat` (for tf.data, None repeats
        indefinitely).
    """
    if n_map_threads is None:
        n_map_threads = multiprocessing.cpu_count()

    # Shuffle first: it is cheaper to shuffle raw elements than ones that
    # already went through a possibly costly `map`/`filter`.
    if shuffle:
        if shuffle_buffer_size is None:
            # never use a buffer smaller than 2048 elements
            shuffle_buffer_size = max(batch_size * 128, 2048)
        dataset = dataset.shuffle(shuffle_buffer_size)

    def _filter_stage(ds):
        return ds.filter(filter_fn) if filter_fn else ds

    def _map_stage(ds):
        return ds.map(map_fn, num_parallel_calls=n_map_threads) if map_fn else ds

    if filter_after_map:
        stages = (_map_stage, _filter_stage)
    else:
        stages = (_filter_stage, _map_stage)
    for stage in stages:
        dataset = stage(dataset)

    batched = dataset.batch(batch_size, drop_remainder=drop_remainder)
    repeated = batched.repeat(repeat)
    return repeated.prefetch(n_prefetch_batch)
def memory_data_batch_dataset(memory_data,
                              batch_size,
                              drop_remainder=True,
                              n_prefetch_batch=1,
                              filter_fn=None,
                              map_fn=None,
                              n_map_threads=None,
                              filter_after_map=False,
                              shuffle=True,
                              shuffle_buffer_size=None,
                              repeat=None):
    """Batch dataset of memory data.

    Slices `memory_data` with `tf.data.Dataset.from_tensor_slices` and hands
    the result, with all remaining options unchanged, to `batch_dataset`.

    Parameters
    ----------
    memory_data : nested structure of tensors/ndarrays/lists
    """
    sliced = tf.data.Dataset.from_tensor_slices(memory_data)
    pipeline_options = dict(
        drop_remainder=drop_remainder,
        n_prefetch_batch=n_prefetch_batch,
        filter_fn=filter_fn,
        map_fn=map_fn,
        n_map_threads=n_map_threads,
        filter_after_map=filter_after_map,
        shuffle=shuffle,
        shuffle_buffer_size=shuffle_buffer_size,
        repeat=repeat,
    )
    return batch_dataset(sliced, batch_size, **pipeline_options)
def disk_image_batch_dataset(img_paths,
                             batch_size,
                             labels=None,
                             drop_remainder=True,
                             n_prefetch_batch=1,
                             filter_fn=None,
                             map_fn=None,
                             n_map_threads=None,
                             filter_after_map=False,
                             shuffle=True,
                             shuffle_buffer_size=None,
                             repeat=None):
    """Batch dataset of disk image for PNG and JPEG.

    Each path is read with `tf.io.read_file` and decoded to 3 channels with
    `tf.image.decode_jpeg`; the optional `map_fn` is applied to the decoded
    image (and labels, if any) within the same map step. Everything else is
    delegated to `memory_data_batch_dataset`.

    Parameters
    ----------
    img_paths : 1d-tensor/ndarray/list of str
    labels : nested structure of tensors/ndarrays/lists
    """
    # Paths alone, or (paths, labels) pairs when labels are supplied.
    source = img_paths if labels is None else (img_paths, labels)

    def _load_image(path, *label):
        raw_bytes = tf.io.read_file(path)
        decoded = tf.image.decode_jpeg(raw_bytes, channels=3)  # always 3 channels
        return (decoded,) + label

    if map_fn:
        # Fuse decoding and the user transform into a single map call.
        def fused_fn(*args):
            return map_fn(*_load_image(*args))
    else:
        fused_fn = _load_image

    return memory_data_batch_dataset(source,
                                     batch_size,
                                     drop_remainder=drop_remainder,
                                     n_prefetch_batch=n_prefetch_batch,
                                     filter_fn=filter_fn,
                                     map_fn=fused_fn,
                                     n_map_threads=n_map_threads,
                                     filter_after_map=filter_after_map,
                                     shuffle=shuffle,
                                     shuffle_buffer_size=shuffle_buffer_size,
                                     repeat=repeat)
|