TensorFlow training with batches of shape (1, None, features), but model expects an extra dimension

I've built the autoencoder below to accept variable-length inputs. It works for a single sample if I call model.fit(np.expand_dims(x, axis = 0)), but it fails when I pass in an entire dataset. What's the simplest approach in this case?

import numpy as np
# Use the public tensorflow.keras namespace; tensorflow.python is a private module path
import tensorflow.keras.backend as K
from tensorflow.keras.layers import Input, LSTM, Lambda
from tensorflow.keras.models import Model


def repeat(x):
    step_matrix = K.ones_like(x[0][:, :, :1])
    latent_matrix = K.expand_dims(x[1], axis = 1)
    return K.batch_dot(step_matrix, latent_matrix)

timesteps = None
features = 2
latent_dim = 10

inputs = Input(shape = (timesteps, features))
encoded = LSTM(latent_dim, name = "encoded")(inputs)
decoded = Lambda(repeat)([inputs, encoded])
outputs = LSTM(features, return_sequences = True)(decoded)
autoenc = Model(inputs = inputs, outputs = outputs)
autoenc.compile(optimizer = "adam", loss = "mse")
encoder = Model(
    inputs = autoenc.input, outputs = autoenc.get_layer("encoded").output
)

x1 = np.ones((20, 2))
x2 = np.ones((30, 2))
x3 = np.ones((40, 2))
X_train = np.array((x1, x2, x3))

autoenc.fit(x = X_train, y = X_train, epochs = 10, batch_size = 1)



I solved the problem with a generator that adds a batch dimension to each sample, so every batch has shape (1, timesteps, 2), which is declared to TensorFlow as (1, None, 2).

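import numpy as np
import tensorflow as tf


class SingleBatchGenerator:
    def __init__(self, X):
        self.X = X

    def __call__(self):
        for i in range(len(self.X)):
            # Add a leading batch axis: (timesteps, 2) -> (1, timesteps, 2)
            xi = np.expand_dims(self.X[i], axis=0)
            # Autoencoder: the target is the input itself
            yield xi, xi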

X = [np.ones((np.random.randint(1, 100), 2)) for _ in range(100)]
gen = SingleBatchGenerator(X)

ds = tf.data.Dataset.from_generator(
    generator = gen,
    output_types=(tf.float64, tf.float64),
    output_shapes=((1, None, 2), (1, None, 2)),
)

autoenc.fit(ds.repeat(), steps_per_epoch=len(X), epochs=500)
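
In newer TensorFlow releases the output_types and output_shapes arguments are deprecated in favour of output_signature. The sketch below shows the same one-sample-per-batch idea written that way. It assumes TensorFlow 2.4 or later; the helper name make_single_sample_dataset, the FEATURES constant, and the cast to float32 are illustrative choices of mine, not part of the code above.

```python
import numpy as np
import tensorflow as tf

FEATURES = 2  # assumption: matches `features` in the question


def make_single_sample_dataset(samples):
    """Wrap variable-length samples as batches of shape (1, timesteps, FEATURES)."""

    def gen():
        for sample in samples:
            # Add a leading batch axis: (timesteps, FEATURES) -> (1, timesteps, FEATURES)
            batch = np.expand_dims(sample, axis=0).astype(np.float32)
            # Autoencoder: the target is the input itself
            yield batch, batch

    # None marks the variable-length time axis
    spec = tf.TensorSpec(shape=(1, None, FEATURES), dtype=tf.float32)
    return tf.data.Dataset.from_generator(gen, output_signature=(spec, spec))


samples = [np.ones((t, FEATURES)) for t in (20, 30, 40)]
ds = make_single_sample_dataset(samples)

# Each element keeps its own sequence length
shapes = [tuple(x.shape) for x, _ in ds]
```

The resulting dataset should be usable the same way as before, for example autoenc.fit(ds.repeat(), steps_per_epoch=len(samples), epochs=10).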
