How do I predict with a Keras model on a dataset formatted like the output of preprocessing.image_dataset_from_directory?

I was wondering if anyone has written code that reads a directory of images (without subfolders representing classes) and then runs model.predict() on it. I don't want to set up a subfolder, because subfolders are usually named after classes, and this folder will contain unseen, unlabelled data. Here is my attempt, which does not work:

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

model = tf.keras.models.load_model('Classification_model')
data_augmentation = keras.Sequential(
    [
        layers.experimental.preprocessing.Rescaling(1./255)
    ]
)

dataset = tf.data.Dataset.list_files('test/*.JPG', shuffle=False)  # read in a bunch of JPEGs

def decode_img(img):
  img = tf.image.decode_jpeg(img, channels=3)  # color images
  # convert the uint8 tensor to floats in the [0, 1] range
  img = tf.image.convert_image_dtype(img, tf.float32)
  return img

def decode_jpeg_and_label(filename):
  bits = tf.io.read_file(filename)
  image = decode_img(bits)
  label = 1 # fake label 
  return image, label

dataset = dataset.map(decode_jpeg_and_label)

augmented_test_ds = dataset.map(
    lambda x, y: (data_augmentation(x, training=False), y))

probs = model.predict(augmented_test_ds, verbose = 1)

However, the error I get is:

ValueError: Input 0 of layer stem_conv is incompatible with the layer: : expected min_ndim=4, found ndim=3. Full shape received: [None, None, None]

which I assume means I am not formatting my dataset correctly for prediction. What should I do? Thank you!



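TensorFlow/Keras models expect input in batches, i.e. a 4-D tensor of shape (batch_size, height, width, channels) such as (32, 224, 224, 3), where 32 is the number of images in the batch. Your dataset yields one image at a time with shape (height, width, 3), which is why the error reports ndim=3. For a single-image prediction, the batch shape should look like (1, 224, 224, 3) (224 is only an example; use the input size your model was trained on).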
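A minimal sketch of the two usual ways to get that leading batch dimension, using dummy zero tensors and assuming 224x224 RGB inputs: `tf.expand_dims` for a single image, or `Dataset.batch()` for a stream of images.

```python
import tensorflow as tf

# Stand-in for one decoded RGB image: (height, width, channels), ndim=3.
image = tf.zeros((224, 224, 3))

# A single image needs a leading batch axis of size 1 before predict().
single = tf.expand_dims(image, axis=0)
print(single.shape)  # (1, 224, 224, 3)

# For many images, Dataset.batch() adds the batch axis for you.
images = tf.data.Dataset.from_tensor_slices(tf.zeros((5, 224, 224, 3)))
batch_shapes = [tuple(b.shape) for b in images.batch(2)]
print(batch_shapes)  # [(2, 224, 224, 3), (2, 224, 224, 3), (1, 224, 224, 3)]
```

The last batch is smaller because 5 images do not divide evenly into batches of 2; `predict()` handles such a partial batch without any extra code.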

def decode_jpeg_and_label(filename):
  bits = tf.io.read_file(filename)
  image = decode_img(bits)
  # resize to a fixed size (assumed 224x224 here), then add a batch axis:
  # (224, 224, 3) -> (1, 224, 224, 3)
  image = tf.image.resize(image, (224, 224))
  image = tf.expand_dims(image, axis=0)
  label = 1 # fake label
  return image, label

Once you add the code above, each dataset element has four dimensions and the ndim=3 error should disappear.
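A more common pattern than adding a size-1 batch axis per image is to resize every image to one fixed size and let `Dataset.batch()` group them. The sketch below assumes a 224x224 RGB model; so that it runs on its own, it writes a few random JPEGs to a temporary folder and uses a tiny hypothetical model in place of `Classification_model`. Swap in your own `test` folder and `tf.keras.models.load_model('Classification_model')`. Note that `convert_image_dtype` already maps uint8 pixels to [0, 1], so applying the question's `Rescaling(1./255)` on top of it scales the values a second time; keep only the scaling your model saw during training.

```python
import os
import tempfile

import tensorflow as tf

IMG_SIZE = (224, 224)  # assumption: the model was trained on 224x224 RGB inputs
BATCH_SIZE = 32

# Hypothetical setup: write a few random JPEGs of different sizes into a temp
# folder so the sketch is self-contained (replace with your real folder).
test_dir = tempfile.mkdtemp()
for i, (h, w) in enumerate([(300, 400), (224, 224), (500, 250)]):
    pixels = tf.random.uniform((h, w, 3), maxval=256, dtype=tf.int32)
    jpeg = tf.io.encode_jpeg(tf.cast(pixels, tf.uint8))
    tf.io.write_file(os.path.join(test_dir, f"img_{i}.JPG"), jpeg)

def load_image(filename):
    bits = tf.io.read_file(filename)
    image = tf.image.decode_jpeg(bits, channels=3)           # uint8, (h, w, 3)
    image = tf.image.convert_image_dtype(image, tf.float32)  # float32 in [0, 1]
    return tf.image.resize(image, IMG_SIZE)                  # (224, 224, 3)

# No labels needed: the dataset yields images only.
files = tf.data.Dataset.list_files(os.path.join(test_dir, "*.JPG"), shuffle=False)
test_ds = files.map(load_image).batch(BATCH_SIZE)  # elements: (batch, 224, 224, 3)

# Hypothetical stand-in for the loaded 'Classification_model'.
model = tf.keras.Sequential([
    tf.keras.Input(shape=IMG_SIZE + (3,)),
    tf.keras.layers.Conv2D(8, 3, activation="relu"),
    tf.keras.layers.GlobalAveragePooling2D(),
    tf.keras.layers.Dense(2, activation="softmax"),
])

probs = model.predict(test_ds, verbose=0)
print(probs.shape)  # (3, 2): one row of class probabilities per image
```

Because `shuffle=False`, the rows of `probs` follow the order of the listed files. As a possible alternative, recent TensorFlow versions document `labels=None` for `tf.keras.utils.image_dataset_from_directory`, which should yield unlabelled, resized and batched images without class subfolders; check the documentation for your version before relying on it.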
