How can I classify multiple images "simultaneously"?

I am dealing with a multiclass image classification problem with N classes. What makes it unusual is that a single instance is NOT a single image, as you would normally expect, but rather a set of multiple images that belong together.

I could theoretically build a model that looks at each image of an instance one after another and classifies them separately (I have actually already done this), but I want to create a model that benefits from the interdependence between the images of a single instance.

The reason is that there are classes like ATop and ABottom, and in my plain image classification setup it is barely possible to distinguish between the two.

But: in each instance (set of images), there is, for example, always exactly one image of class ATop and one image of class ABottom. By looking at all images of an instance at once, I hope to capture interdependencies between those images that result in a better score on these hard-to-distinguish classes.

My question is: how can I implement this idea (network structure, dataset structure, ...) in PyTorch?

It might be worth pointing out that instances can have between 3 and 10 images.


EXAMPLE:

Consider this instance which consists of 4 images:

instance = {image_A, image_B, image_C, image_A'}

I want a model that takes this instance as an input and maps it to

{A, B, C, A'}

Theoretically, I could split this instance into:

instance1 = {image_A}
instance2 = {image_B}
instance3 = {image_C}
instance4 = {image_A'} 

and map them separately to

{A}, {B}, {C}, {A'}

but then, I would lose interdependencies between these images.

A few side notes: as pointed out before, a single instance does not always have exactly 4 images; it can have anything between 3 and 10, and there are also more than these 4 classes (certain classes can also appear more than once in a single instance). Also, the images themselves are not really a sequence: the ordering does not matter. So the following would be considered identical from a human perspective (see the dataset sketch after this example):

instance1 = {image_A, image_B, image_C, image_A'}
instance2 = {image_A, image_A', image_C, image_B}
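For concreteness, this is roughly how I imagine the dataset structure in PyTorch. This is just a sketch; the class and variable names are made up, and I pad the sets to a common size within each batch:

import torch
from torch.utils.data import Dataset, DataLoader

class InstanceDataset(Dataset):
    """One item = one instance: a variable-sized set of images plus one label per image."""
    def __init__(self, instances):
        # instances: list of (images, labels) pairs where
        #   images has shape (num_images, C, H, W) with num_images between 3 and 10
        #   labels has shape (num_images,) and holds class indices
        self.instances = instances

    def __len__(self):
        return len(self.instances)

    def __getitem__(self, idx):
        return self.instances[idx]

def collate_instances(batch):
    # Pad every instance to the largest set size in the batch and build a mask
    # so the model can ignore the padded slots.
    max_n = max(images.shape[0] for images, _ in batch)
    padded_images, padded_labels, mask = [], [], []
    for images, labels in batch:
        n, c, h, w = images.shape
        pad = max_n - n
        padded_images.append(torch.cat([images, images.new_zeros(pad, c, h, w)]))
        padded_labels.append(torch.cat([labels, labels.new_full((pad,), -100)]))  # -100 = ignore_index
        mask.append(torch.cat([torch.ones(n, dtype=torch.bool), torch.zeros(pad, dtype=torch.bool)]))
    return torch.stack(padded_images), torch.stack(padded_labels), torch.stack(mask)

# loader = DataLoader(InstanceDataset(my_instances), batch_size=4, collate_fn=collate_instances)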


I finally know how to solve this.

The plainest solution would be an RNN-like structure that simply acts as a seq2seq model. We could use attention, bidirectional chains, etc. to improve performance. However, this solution has the clear disadvantage that we pretend there is a sequential relationship between the images, which is not actually there: as pointed out in the question, the ordering of the images in a sample does not matter at all.
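A minimal sketch of this RNN baseline could look like the following (the CNN encoder and all layer sizes are placeholders, not a recommendation):

import torch
import torch.nn as nn

class RNNSetClassifier(nn.Module):
    def __init__(self, num_classes, embed_dim=128, hidden_dim=128):
        super().__init__()
        # Small shared CNN that turns each image into one embedding (placeholder architecture).
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, stride=2, padding=1), nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1), nn.ReLU(),
            nn.AdaptiveAvgPool2d(1), nn.Flatten(),
            nn.Linear(64, embed_dim),
        )
        # A bidirectional GRU lets every image "see" the others, but it imposes an
        # artificial ordering on the set.
        self.rnn = nn.GRU(embed_dim, hidden_dim, batch_first=True, bidirectional=True)
        self.head = nn.Linear(2 * hidden_dim, num_classes)

    def forward(self, images):
        # images: (batch, num_images, C, H, W)
        b, n = images.shape[:2]
        feats = self.encoder(images.flatten(0, 1)).view(b, n, -1)
        feats, _ = self.rnn(feats)
        return self.head(feats)  # (batch, num_images, num_classes): one prediction per image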

A better solution is to use a Transformer model without positional encoding. As we know, attention is all you need, so we can do away with the RNN-like structure entirely. Transformer models usually do exploit positional information in common application areas such as language modeling, where the ordering matters. In our case it does not, so we simply drop the positional encoding (a preprocessing step in the Transformer) and obtain a reasonable architecture that captures the interdependencies within arbitrary sets of images.
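To illustrate, a sketch of that architecture could look like this (again, the encoder and hyperparameters are placeholders; the only important part is that no positional encoding is added before the Transformer encoder):

import torch
import torch.nn as nn

class SetTransformerClassifier(nn.Module):
    def __init__(self, num_classes, embed_dim=128, num_heads=4, num_layers=2):
        super().__init__()
        # Same kind of shared per-image CNN encoder as above (placeholder architecture).
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, stride=2, padding=1), nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1), nn.ReLU(),
            nn.AdaptiveAvgPool2d(1), nn.Flatten(),
            nn.Linear(64, embed_dim),
        )
        layer = nn.TransformerEncoderLayer(d_model=embed_dim, nhead=num_heads, batch_first=True)
        # No positional encoding is added, so the model treats the images as an unordered set.
        self.transformer = nn.TransformerEncoder(layer, num_layers=num_layers)
        self.head = nn.Linear(embed_dim, num_classes)

    def forward(self, images, mask=None):
        # images: (batch, num_images, C, H, W); mask: (batch, num_images), True = real image
        b, n = images.shape[:2]
        feats = self.encoder(images.flatten(0, 1)).view(b, n, -1)
        pad_mask = None if mask is None else ~mask  # True marks padded slots to be ignored
        feats = self.transformer(feats, src_key_padding_mask=pad_mask)
        return self.head(feats)  # one class prediction per image in the set

If the sets are padded to a common length per batch, training reduces to an ordinary per-image cross-entropy, e.g. nn.CrossEntropyLoss(ignore_index=-100) on the flattened logits and labels, so that padded slots do not contribute to the loss.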


This is sometimes called multi-output.

In PyTorch, a module's forward can take multiple inputs and return multiple outputs, for example:

import torch
import torch.nn as nn

class NeuralNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        # Two separate heads; the in/out sizes here are placeholders.
        self.linear1 = nn.Linear(128, 10)
        self.linear2 = nn.Linear(128, 10)

    def forward(self, x):
        # Each head produces its own output for the same input x.
        output1 = self.linear1(x)
        output2 = self.linear2(x)
        return output1, output2
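
For example, a forward pass (with made-up dimensions) would look like this:

model = NeuralNetwork()
x = torch.randn(4, 128)        # batch of 4 inputs with 128 features each
out1, out2 = model(x)          # two separate predictions for the same input
print(out1.shape, out2.shape)  # torch.Size([4, 10]) torch.Size([4, 10])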
