5.3. Language – Vision#

In this notebook, we will learn how to train a multimodal network. We will explore a language–vision model which matches texts to images. In particular, we implement a simple version of CLIP (Contrastive Language-Image Pre-Training).

CLIP is a neural network trained on various (image, text) pairs. Given an image, it can be instructed in natural language to predict the most relevant text snippet without being directly optimised for that task, similar to the zero-shot capabilities of GPT-2 and GPT-3.

0. Preparation#

In order to run this notebook, we need to perform some preparations.

Packages#

Let’s start with all the necessary packages to implement this tutorial.

  • numpy is the main package for scientific computing with Python. It’s often imported with the np shortcut.

  • os provides a portable way of using operating system-dependent functionality, e.g., modifying files/folders.

  • pandas provides convenient routines for working with tabular data structures.

  • argparse is a module making it easy to write user-friendly command-line interfaces.

  • matplotlib is a library to plot graphs in Python.

  • torch is a deep learning framework that allows us to define networks, handle datasets, optimise a loss function, etc.

  • transformers provides pretrained models to perform tasks on different modalities such as text, vision, and audio.

import numpy as np
import os
import pandas as pd
import argparse
import itertools

import matplotlib.pyplot as plt
from PIL import Image as pil_image

import torch
from torch import nn
import torchvision

from transformers import DistilBertModel, DistilBertTokenizer

device#

Choosing CPU or GPU based on the availability of the hardware.

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

arguments#

We use the argparse module to define a set of parameters that we use throughout this notebook:

  • The argparse module is particularly useful when writing Python scripts, allowing you to run the same script with different parameters (e.g., for doing different experiments).

  • In notebooks, using argparse is not necessarily beneficial: we could have hard-coded those values directly in variables, but here we use argparse for learning purposes.

parser = argparse.ArgumentParser()
parser.add_argument("--epochs", type=int, default=5, help="number of training epochs")
parser.add_argument("--batch_size", type=int, default=8, help="size of the batches")
parser.add_argument("--num_workers", type=int, default=0, help="number of CPU workers")

parser.add_argument("--head_lr", type=float, default=1e-3, help="Head learning rate")
parser.add_argument("--image_encoder_lr", type=float, default=1e-4, help="Image encoder learning rate")
parser.add_argument("--text_encoder_lr", type=float, default=1e-5, help="Text encoder learning rate")
parser.add_argument("--weight_decay", type=float, default=1e-3, help="Weight decay")
parser.add_argument("--projection_dim", type=int, default=256, help="Size of the projection space")

parser.add_argument("--model_vision", type=str, default='resnet50', help="Image encoder")
parser.add_argument("--model_text", type=str, default='distilbert-base-uncased', help="Text encoder")

parser.add_argument("--img_size", type=int, default=224, help="size of each image dimension")
parser.add_argument("--log_frequency", type=int, default=100, help="interval log prints")
parser.add_argument("--out_dir", type=str, default="./out/clip_out/", help="the output directory")
parser.add_argument("--data_dir", type=str, default="./data/", help="the dataset directory")

def set_args(*args):
    # we can pass arguments to the parse_args function to change the default values. 
    opt = parser.parse_args([*args])
    # creating the output directory
    os.makedirs(opt.out_dir, exist_ok=True)
    return opt
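
The set_args helper works because parse_args accepts an explicit list of strings in place of the command line. The snippet below is a minimal, self-contained sketch of that behaviour; the name demo_parser is illustrative only and its two options mirror a subset of the arguments defined above.

```python
import argparse

# Illustrative parser mirroring two of the arguments defined above.
demo_parser = argparse.ArgumentParser()
demo_parser.add_argument("--epochs", type=int, default=5, help="number of training epochs")
demo_parser.add_argument("--batch_size", type=int, default=8, help="size of the batches")

# An empty list keeps every default value.
defaults = demo_parser.parse_args([])
# Flags and values are passed as separate strings, exactly as on the command line.
custom = demo_parser.parse_args(["--epochs", "10"])

print(defaults.epochs, defaults.batch_size)  # 5 8
print(custom.epochs, custom.batch_size)      # 10 8
```

Values are given as strings and converted by the type argument of each option, which is why "10" becomes the integer 10.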

1. Datasets#

We have created the TinyFlicker dataset, which is a subset of the Flickr8k dataset. TinyFlicker contains:

  • 1000 images,

  • 5 captions per image.

Download and extract#

The download_and_extract_db function downloads and extracts the dataset if it’s not already stored in the provided directory. We use two functions already implemented in torchvision.datasets.utils:

  • download_url

  • extract_archive

def download_and_extract_db(data_dir):
    db_url = "https://dl.dropboxusercontent.com/s/zi46giyvvch0k8q/TinyFlicker.tar.gz"
    torchvision.datasets.utils.download_url(db_url, data_dir)

    data_dir = torchvision.datasets.utils.extract_archive(
        "%s/TinyFlicker.tar.gz" % data_dir, "%s/TinyFlicker/" % data_dir
    )
    return data_dir

Text Tokeniser#

We could have created our own text tokeniser similar to the text classification notebook. But in this tutorial, we use DistilBertTokenizer from the transformers package.

text_tokeniser = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")

Dataset class#

The TinyFlicker class reads the dataset file and returns image-caption pairs, together with the tokenised caption.

class TinyFlicker(torch.utils.data.Dataset):
    def __init__(self, root, data_frame, tokeniser, transforms=None, max_tokens=50):
        self.root = root
        self.captions = data_frame['caption'].values
        self.images = data_frame['image'].values
        self.tokeniser = tokeniser
        self.transforms = transforms
        self.max_tokens = max_tokens

    def __getitem__(self, idx):
        # loading the image
        img = pil_image.open("%s/imgs/%s" % (self.root, self.images[idx]))
        # performing the transformation functions
        if self.transforms:
            img = self.transforms(img)
        caption = self.captions[idx]

        # tokenising the text
        tout = self.tokeniser(caption, truncation=True, max_length=self.max_tokens)
        input_ids = torch.tensor(tout['input_ids'])
        attention_mask = torch.tensor(tout['attention_mask'])
        # all elements in each batch should have the same length, therefore we pad
        # the tensors into an identical length
        input_ids = nn.functional.pad(input_ids, (0, self.max_tokens - len(input_ids)), value=0)
        attention_mask = nn.functional.pad(attention_mask, (0, self.max_tokens - len(attention_mask)), value=0)

        item = {
            'image': img,
            'caption': caption,
            'input_ids': input_ids,
            'attention_mask': attention_mask
        }
        return item

    def __len__(self):
        return len(self.captions)
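
The padding step in __getitem__ can be pictured without any tensors. The following is a minimal sketch in plain Python; pad_to_length is a hypothetical helper (not part of the notebook) that right-pads a list of token ids with zeros and builds the matching attention mask.

```python
def pad_to_length(token_ids, max_tokens, pad_id=0):
    """Hypothetical helper: right-pad token ids and build the attention mask."""
    # Naive truncation for illustration; the real tokeniser keeps the final [SEP] token.
    token_ids = list(token_ids)[:max_tokens]
    n_pad = max_tokens - len(token_ids)
    input_ids = token_ids + [pad_id] * n_pad
    # 1 marks a real token, 0 marks a padded position.
    attention_mask = [1] * len(token_ids) + [0] * n_pad
    return input_ids, attention_mask

# 101 and 102 are the [CLS] and [SEP] ids in the BERT uncased vocabulary.
ids, mask = pad_to_length([101, 1037, 2775, 102], max_tokens=8)
print(ids)   # [101, 1037, 2775, 102, 0, 0, 0, 0]
print(mask)  # [1, 1, 1, 1, 0, 0, 0, 0]
```

Because every item has the same length after padding, the default collate function of the DataLoader can stack them into a single batch tensor.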

Transform functions#

In this tutorial, we use the same set of transform functions for training and testing. Essentially, we only resize the input images and convert them to a normalised tensor. One could add data augmentation to the training images such as random cropping or horizontal flipping. We leave this as an exercise for interested readers.

def get_transforms(target_size, for_network=True):
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]
    transform_funs = [
        torchvision.transforms.Resize(target_size),
        torchvision.transforms.CenterCrop(target_size)
    ]
    # if for_network is False, we skip the conversion to tensor and the normalisation.
    # This makes visualisation easier.
    if for_network:
        transform_funs.extend([
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Normalize(mean, std)
        ])
    return torchvision.transforms.Compose(transform_funs)
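
To make the normalisation step concrete, here is a small NumPy sketch of the arithmetic that Normalize applies per channel, (x - mean) / std, and of how to undo it. The random array is only a stand-in for an image that has already been converted to a tensor; it is not taken from the dataset.

```python
import numpy as np

# Per-channel statistics, reshaped so they broadcast over height and width.
mean = np.array([0.485, 0.456, 0.406]).reshape(3, 1, 1)
std = np.array([0.229, 0.224, 0.225]).reshape(3, 1, 1)

# Dummy image in (channels, height, width) layout with values in [0, 1],
# the layout and range that ToTensor produces for an 8-bit RGB image.
rng = np.random.default_rng(0)
image = rng.random((3, 224, 224))

# Normalisation: subtract the channel mean and divide by the channel std.
normalised = (image - mean) / std

# Undoing the normalisation recovers the original pixel values, which is
# useful when visualising an image that has already been normalised.
restored = normalised * std + mean
print(np.allclose(restored, image))  # True
```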

Sample items#

Let’s create an instance of our dataset and explore its content. The first time you execute the following cell, it downloads the dataset and extracts its content into the data_dir directory.

data_dir = './data/'
data_dir = download_and_extract_db(data_dir)
data_frame = pd.read_csv("%s/captions.csv" % data_dir)
sample_db = TinyFlicker(
    root=data_dir,
    data_frame=data_frame,
    tokeniser=text_tokeniser,
    transforms=get_transforms(224, for_network=False),
)
Using downloaded and verified file: ./data/TinyFlicker.tar.gz
print('The number of samples in the dataset: %d' % len(sample_db))
The number of samples in the dataset: 5000

Each sample is a dict of four elements:

  • image: the input image

  • caption: the raw image caption

  • input_ids: the processed caption by text_tokeniser

  • attention_mask: an array of 1s and 0s corresponding to actual text or padded zeros.

sample_item = sample_db[0]
print(sample_item.keys())
dict_keys(['image', 'caption', 'input_ids', 'attention_mask'])

If we print input_ids, each element is the index of the corresponding word (token) in the tokeniser's vocabulary, and the trailing zeros are padding. Please check the text classification notebook for more information.

sample_item['input_ids']
tensor([ 101, 1037, 2775, 1999, 1037, 5061, 4377, 2003, 8218, 2039, 1037, 2275,
        1997, 5108, 1999, 2019, 4443, 2126, 1012,  102,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0])

The attention_mask tensor has the same length as input_ids:

  • Elements with value 1 correspond to actual words (tokens).

  • Elements with value 0 correspond to padding.

print('Size of input_ids:', sample_item['input_ids'].shape)
print('Size of attention_mask:', sample_item['attention_mask'].shape)
sample_item['attention_mask']
Size of input_ids: torch.Size([50])
Size of attention_mask: torch.Size([50])
tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0])

Let’s plot two captions for the same image:

fig = plt.figure(figsize=(20, 6))
for i in range(2):
    sample_item = sample_db[i]
    ax = fig.add_subplot(1, 3, i+1)
    ax.imshow(sample_item['image'])
    ax.axis('off')
    ax.set_title(sample_item['caption'], wrap=True)
../_images/34ceedefd230a37c6b96dbc4c51dc9297c9ba2a5cd02d8994d009300fb54ef2f.png

Dataloaders#

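We make a standard 80-20% train/test split. The split is made over images rather than captions, so all five captions of an image end up in the same set. The code assumes that captions.csv lists the five captions of each image consecutively.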

def make_train_valid_dfs(data_dir, val_percent=0.2):
    dataframe = pd.read_csv("%s/captions.csv" % data_dir)
    dataframe['id'] = [id_ for id_ in range(dataframe.shape[0] // 5) for _ in range(5)]
    max_id = dataframe["id"].max() + 1
    image_ids = np.arange(0, max_id)
    valid_ids = np.random.choice(image_ids, size=int(val_percent * len(image_ids)), replace=False)
    train_ids = [id_ for id_ in image_ids if id_ not in valid_ids]
    train_dataframe = dataframe[dataframe["id"].isin(train_ids)].reset_index(drop=True)
    valid_dataframe = dataframe[dataframe["id"].isin(valid_ids)].reset_index(drop=True)
    return train_dataframe, valid_dataframe
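
Note that make_train_valid_dfs draws its validation ids with np.random.choice without fixing a seed, so each call returns a different split. The sketch below shows one possible way to make the split over image ids reproducible; split_image_ids and its seed parameter are assumptions for illustration and are not part of the notebook.

```python
import numpy as np

def split_image_ids(num_images, val_percent=0.2, seed=0):
    """Hypothetical helper: reproducible train/validation split over image ids."""
    rng = np.random.default_rng(seed)
    shuffled = rng.permutation(num_images)
    num_valid = int(val_percent * num_images)
    valid_ids = np.sort(shuffled[:num_valid])
    train_ids = np.sort(shuffled[num_valid:])
    return train_ids, valid_ids

train_ids, valid_ids = split_image_ids(1000)
print(len(train_ids), len(valid_ids))            # 800 200
# No image id appears on both sides of the split.
print(np.intersect1d(train_ids, valid_ids).size)  # 0
```

Calling the helper twice with the same seed returns the same ids, so the validation set used for query matching can be made identical to the one held out during training.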

def build_loaders(args, data_frame, mode):
    transforms = get_transforms(args.img_size)
    dataset = TinyFlicker(
        root=args.data_dir,
        data_frame=data_frame,
        tokeniser=text_tokeniser,
        transforms=transforms,
    )
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        shuffle=(mode == "train"),
    )
    return dataloader

2. Network#

The CLIP Network consists of three main parts:

  • Image encoder that encodes the visual signal. In theory, it can be any architecture (pretrained network). The only requirement is to have the output as a flattened vector.

  • Text encoder that encodes the text data. Similar to the image encoder, the text encoder can be any architecture.

  • CLIPNet projects the image/text embeddings into a common embedding space and computes their similarity using matrix multiplication.

Vision#

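We have hard-coded our ImageEncoder to a ResNet50 pretrained on ImageNet with its final classification layer removed, so the model_name argument is currently ignored. The encoder outputs the globally average-pooled features, which in our case is a vector of size 2048 per image.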

Interested readers can change the code to support different image encoders.

# Encoder for processing images
class ImageEncoder(nn.Module):
    def __init__(self, model_name):
        super().__init__()
        # Load a pre-trained ResNet-50 model (using ImageNet weights) and remove the final classification layer
        pretrained = torchvision.models.resnet50(weights=torchvision.models.ResNet50_Weights.IMAGENET1K_V2)
        self.model = nn.Sequential(*list(pretrained.children())[:-1])
        # Set the dimensionality of the output embeddings from the encoder
        self.embedding_dim = 2048

    def forward(self, x):
        # Pass the image through the ResNet-50 model
        x = self.model(x)
        # Flatten the output to a 2D tensor (batch_size, embedding_dim)
        return torch.flatten(x, start_dim=1)

Language#

We have hard-coded our TextEncoder to the last hidden layer of a DistilBertModel, a transformer-based model for natural language processing (NLP) tasks, and we take the embedding of the [CLS] token as the representation of the whole caption. We use the implementation from the transformers package rather than torchtext, because torchtext does not offer many powerful pretrained text models. In our case, the output of TextEncoder is a vector of size 768.

Interested readers can change the code to support different text encoders.

# Encoder for processing text
class TextEncoder(nn.Module):
    def __init__(self, model_name):
        super().__init__()
        # Load a pre-trained DistilBERT model from the HuggingFace Transformers library
        self.model = DistilBertModel.from_pretrained(model_name)
        # Set the dimensionality of the output embeddings from the encoder
        self.embedding_dim = 768
        # Use the [CLS] token (index 0) as the representation of the whole input text
        self.target_token_idx = 0

    def forward(self, input_ids, attention_mask):
        # Pass the text through the DistilBERT model
        output = self.model(input_ids=input_ids, attention_mask=attention_mask)
        last_hidden_state = output.last_hidden_state
        # Extract the embedding of the [CLS] token
        return last_hidden_state[:, self.target_token_idx, :]

CLIP Network#

The ProjectionHead brings the image/text embedding vectors from their corresponding dimensions (in our case 2048 and 768 respectively) to a common projection space (in our case 256).

The CLIPNet combines all the above-mentioned building blocks into one network:

  1. Encoding text.

  2. Encoding image.

  3. Projecting text features into the common space.

  4. Projecting image features into the common space.

  5. Computing the similarity by doing matrix multiplication (@ operation) between the projected image/text embeddings.

# A projection head to map encoder outputs to a shared embedding space
class ProjectionHead(nn.Module):
    def __init__(self, embedding_dim, projection_dim):
        super().__init__()
        # Linear layer to project the encoder's output to a lower-dimensional space
        self.projection = nn.Linear(embedding_dim, projection_dim)
        self.gelu = nn.GELU()  # Activation function
        self.fc = nn.Linear(projection_dim, projection_dim)  # Fully connected layer
        self.layer_norm = nn.LayerNorm(projection_dim)  # Normalize embeddings

    def forward(self, x):
        # Project the input embeddings
        projected = self.projection(x)
        # Apply activation and a second linear layer
        x = self.gelu(projected)
        x = self.fc(x)
        # Add skip connection and normalize
        x = x + projected
        x = self.layer_norm(x)
        return x
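
Before combining everything, it can help to check the tensor shapes involved in steps 3 to 5 above. The following is a minimal sketch in which random tensors stand in for the encoder outputs and plain linear layers stand in for the projection heads; no pretrained weights are loaded and the variable names are illustrative.

```python
import torch
from torch import nn

batch_size, projection_dim = 4, 256
# Random stand-ins for the encoder outputs (2048-d image and 768-d text features).
image_features = torch.randn(batch_size, 2048)
text_features = torch.randn(batch_size, 768)

# Plain linear layers stand in for the two projection heads.
image_projection = nn.Linear(2048, projection_dim)
text_projection = nn.Linear(768, projection_dim)

image_embeddings = image_projection(image_features)
text_embeddings = text_projection(text_features)

# Entry (i, j) scores caption i against image j.
logits = text_embeddings @ image_embeddings.T

print(image_embeddings.shape, text_embeddings.shape, logits.shape)
```

Both modalities end up in the same 256-dimensional space, so the similarity matrix has one row per caption and one column per image in the batch.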
# The main CLIP model combining image and text encoders
class CLIPNet(nn.Module):
    def __init__(self, vision_model_name, text_model_name, projection_dim):
        super().__init__()
        # Initialize the image and text encoders
        self.image_encoder = ImageEncoder(vision_model_name)
        self.text_encoder = TextEncoder(text_model_name)
        # Add projection heads to map encoder outputs to a common embedding space
        self.image_projection = ProjectionHead(self.image_encoder.embedding_dim, projection_dim)
        self.text_projection = ProjectionHead(self.text_encoder.embedding_dim, projection_dim)

    def forward(self, batch):
        # Process text inputs to obtain text features
        text_features = self.text_encoder(
            input_ids=batch["input_ids"], attention_mask=batch["attention_mask"]
        )
        # Process image inputs to obtain image features
        image_features = self.image_encoder(batch["image"])
        # Project features into the shared embedding space
        text_embeddings = self.text_projection(text_features)
        image_embeddings = self.image_projection(image_features)

        # Compute similarity scores between text and image embeddings
        logits = text_embeddings @ image_embeddings.T
        # Compute self-similarity matrices for images and texts
        image_similarity = image_embeddings @ image_embeddings.T
        text_similarity = text_embeddings @ text_embeddings.T

        # Generate targets by averaging image and text similarities
        targets = torch.softmax((image_similarity + text_similarity) / 2, dim=-1)

        # Compute cross-entropy loss for both directions
        text_loss = nn.functional.cross_entropy(logits, targets, reduction='none')
        image_loss = nn.functional.cross_entropy(logits.T, targets.T, reduction='none')

        # Average the two losses to ensure bidirectional alignment
        loss = (image_loss + text_loss) / 2.0  # shape: (batch_size)

        # Return the mean loss over the batch
        return loss.mean()
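
The loss computed in CLIPNet.forward can be reproduced on toy data to see how it behaves. The snippet below is a NumPy re-implementation written for illustration under the assumption that cross-entropy with probability targets equals the negative sum of target times log-softmax; the helper names are hypothetical and the random embeddings are not model outputs.

```python
import numpy as np

def log_softmax(x):
    """Row-wise log-softmax, shifted by the row maximum for numerical stability."""
    shifted = x - x.max(axis=-1, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))

def soft_cross_entropy(logits, targets):
    """Per-row cross-entropy between softmax(logits) and probability targets."""
    return -(targets * log_softmax(logits)).sum(axis=-1)

def clip_loss(text_embeddings, image_embeddings):
    logits = text_embeddings @ image_embeddings.T
    image_similarity = image_embeddings @ image_embeddings.T
    text_similarity = text_embeddings @ text_embeddings.T
    # Soft targets: the softmax of the averaged self-similarities.
    targets = np.exp(log_softmax((image_similarity + text_similarity) / 2))
    text_loss = soft_cross_entropy(logits, targets)
    image_loss = soft_cross_entropy(logits.T, targets.T)
    return float(((image_loss + text_loss) / 2).mean())

rng = np.random.default_rng(0)
embeddings = rng.standard_normal((4, 64))
# Perfectly aligned pairs: caption i and image i share the same embedding.
aligned = clip_loss(embeddings, embeddings)
# Pairing every caption with the wrong image should give a higher loss.
mismatched = clip_loss(embeddings, np.roll(embeddings, 1, axis=0))
print(aligned < mismatched)  # True
```

Because the targets are derived from the embeddings themselves rather than from fixed one-hot labels, near-duplicate items in a batch (for example two captions of the same image) share the probability mass instead of being treated as negatives.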

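In the best theoretical scenario, the text_embeddings and image_embeddings matrices should be the same (or highly correlated) because they describe similar things. If this happens, what would the logits matrix look like? Let’s see with a simple example!

The target (matched image-text pairs) becomes approximately the identity matrix. With 256-dimensional random vectors, the dot product of each vector with itself (about 256) is far larger than its dot product with any other vector (typically a few multiples of 16 at most), so the softmax puts nearly all the probability mass on the diagonal.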

batch_size = 4
dim = 256
embeddings = torch.randn(batch_size, dim)
out = embeddings @ embeddings.T
target = torch.softmax(out, dim=-1)
print(target)
tensor([[1., 0., 0., 0.],
        [0., 1., 0., 0.],
        [0., 0., 1., 0.],
        [0., 0., 0., 1.]])

3. Train/test routines#

The following routines are very similar (close to identical) to what we previously saw in the image classification problem.

def epoch_loop(model, dataloader, optimiser, log_frequency=100):
    # the train and test loops largely overlap, so a single function handles both:
    # passing optimiser=None runs the evaluation.
    is_train = optimiser is not None

    # the model should be in train/eval mode accordingly
    model.train() if is_train else model.eval()
    
    losses = []
    with torch.set_grad_enabled(is_train):
        for batch_ind, batch in enumerate(dataloader):
            batch = {k: v.to(device) for k, v in batch.items() if k != "caption"}
            loss = model(batch)
            losses.extend([loss.item() for _ in range(batch["image"].size(0))])

            if batch_ind % log_frequency == 0 and batch_ind > 0:
                print(
                    '%s batches [%.5d/%.5d] \tloss=%.4f' % (
                        'training' if is_train else 'testing', batch_ind,
                        len(dataloader), np.mean(losses)
                    )
                )

            if is_train:
                optimiser.zero_grad()
                loss.backward()
                optimiser.step()
    return losses
def main(args):
    args.data_dir = download_and_extract_db(args.data_dir)
    train_df, valid_df = make_train_valid_dfs(args.data_dir)
    train_loader = build_loaders(args, train_df, mode="train")
    valid_loader = build_loaders(args, valid_df, mode="valid")

    model = CLIPNet(args.model_vision, args.model_text, args.projection_dim)
    model = model.to(device)
    params = [
        {"params": model.image_encoder.parameters(), "lr": args.image_encoder_lr},
        {"params": model.text_encoder.parameters(), "lr": args.text_encoder_lr},
        {"params": itertools.chain(
            model.image_projection.parameters(), model.text_projection.parameters()
        ), "lr": args.head_lr, "weight_decay": args.weight_decay}
    ]
    optimizer = torch.optim.AdamW(params, weight_decay=0.)

    logs = {'train': [], 'val': []}
    for epoch in range(args.epochs):
        print("Epoch: [%.3d/%.3d]" % (epoch, args.epochs))
        model.train()
        train_loss = epoch_loop(model, train_loader, optimizer, args.log_frequency)
        logs['train'].append(np.mean(train_loss))
        valid_loss = epoch_loop(model, valid_loader, None)
        logs['val'].append(np.mean(valid_loss))

        print('Train     loss=%.4f     Test     loss=%.4f' % 
              (np.mean(train_loss), np.mean(valid_loss)))
        # saving the checkpoint
        torch.save(model.state_dict(), "%s/checkpoint.pth.tar" % args.out_dir)
    return logs
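
In main, each part of the network gets its own learning rate through optimiser parameter groups, and weight decay is only applied to the projection heads. The sketch below illustrates this mechanism with three small linear layers standing in for the encoder and the two heads; the layer sizes and names are illustrative assumptions.

```python
import itertools
import torch
from torch import nn

# Small stand-ins for an encoder and the two projection heads.
encoder = nn.Linear(8, 8)
head_a = nn.Linear(8, 4)
head_b = nn.Linear(8, 4)

params = [
    # This group only sets lr, so it inherits the optimiser-level weight decay.
    {"params": encoder.parameters(), "lr": 1e-5},
    # This group overrides both lr and weight decay.
    {"params": itertools.chain(head_a.parameters(), head_b.parameters()),
     "lr": 1e-3, "weight_decay": 1e-3},
]
optimiser = torch.optim.AdamW(params, weight_decay=0.0)

for group in optimiser.param_groups:
    print(group["lr"], group["weight_decay"], len(group["params"]))
```

Options given inside a group take precedence over the keyword arguments passed to the optimiser, which act as defaults for groups that do not specify them.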

Let’s train our network for ten epochs:

args = set_args("--epochs", "10", "--data_dir", "./data/")
print(args)
logs = main(args)
Namespace(epochs=10, batch_size=8, num_workers=0, head_lr=0.001, image_encoder_lr=0.0001, text_encoder_lr=1e-05, weight_decay=0.001, projection_dim=256, model_vision='resnet50', model_text='distilbert-base-uncased', img_size=224, log_frequency=100, out_dir='./out/clip_out/', data_dir='./data/')
Using downloaded and verified file: ./data/TinyFlicker.tar.gz
Epoch: [000/010]
training batches [00100/00625] 	loss=4.4483
training batches [00200/00625] 	loss=3.0828
training batches [00300/00625] 	loss=2.4970
training batches [00400/00625] 	loss=2.1279
training batches [00500/00625] 	loss=1.8743
training batches [00600/00625] 	loss=1.6705
testing batches [00100/00125] 	loss=1.7055
Train     loss=1.6304     Test     loss=1.6853
Epoch: [001/010]
training batches [00100/00625] 	loss=0.4974
training batches [00200/00625] 	loss=0.4904
training batches [00300/00625] 	loss=0.4870
training batches [00400/00625] 	loss=0.4528
training batches [00500/00625] 	loss=0.4569
training batches [00600/00625] 	loss=0.4458
testing batches [00100/00125] 	loss=1.5938
Train     loss=0.4388     Test     loss=1.5842
Epoch: [002/010]
training batches [00100/00625] 	loss=0.2457
training batches [00200/00625] 	loss=0.2380
training batches [00300/00625] 	loss=0.2512
training batches [00400/00625] 	loss=0.2662
training batches [00500/00625] 	loss=0.2629
training batches [00600/00625] 	loss=0.2625
testing batches [00100/00125] 	loss=1.6866
Train     loss=0.2675     Test     loss=1.6693
Epoch: [003/010]
training batches [00100/00625] 	loss=0.2246
training batches [00200/00625] 	loss=0.2256
training batches [00300/00625] 	loss=0.2126
training batches [00400/00625] 	loss=0.2097
training batches [00500/00625] 	loss=0.2076
training batches [00600/00625] 	loss=0.2052
testing batches [00100/00125] 	loss=1.6065
Train     loss=0.2036     Test     loss=1.6124
Epoch: [004/010]
training batches [00100/00625] 	loss=0.2084
training batches [00200/00625] 	loss=0.1841
training batches [00300/00625] 	loss=0.1906
training batches [00400/00625] 	loss=0.1936
training batches [00500/00625] 	loss=0.2045
training batches [00600/00625] 	loss=0.2002
testing batches [00100/00125] 	loss=1.6642
Train     loss=0.2001     Test     loss=1.6516
Epoch: [005/010]
training batches [00100/00625] 	loss=0.1418
training batches [00200/00625] 	loss=0.1478
training batches [00300/00625] 	loss=0.1571
training batches [00400/00625] 	loss=0.1493
training batches [00500/00625] 	loss=0.1562
training batches [00600/00625] 	loss=0.1827
testing batches [00100/00125] 	loss=1.6275
Train     loss=0.1817     Test     loss=1.6241
Epoch: [006/010]
training batches [00100/00625] 	loss=0.1259
training batches [00200/00625] 	loss=0.1453
training batches [00300/00625] 	loss=0.1402
training batches [00400/00625] 	loss=0.1421
training batches [00500/00625] 	loss=0.1480
training batches [00600/00625] 	loss=0.1507
testing batches [00100/00125] 	loss=1.6815
Train     loss=0.1496     Test     loss=1.6759
Epoch: [007/010]
training batches [00100/00625] 	loss=0.1502
training batches [00200/00625] 	loss=0.1629
training batches [00300/00625] 	loss=0.1629
training batches [00400/00625] 	loss=0.1503
training batches [00500/00625] 	loss=0.1487
training batches [00600/00625] 	loss=0.1472
testing batches [00100/00125] 	loss=1.6893
Train     loss=0.1465     Test     loss=1.6839
Epoch: [008/010]
training batches [00100/00625] 	loss=0.1398
training batches [00200/00625] 	loss=0.1192
training batches [00300/00625] 	loss=0.1338
training batches [00400/00625] 	loss=0.1354
training batches [00500/00625] 	loss=0.1361
training batches [00600/00625] 	loss=0.1351
testing batches [00100/00125] 	loss=1.6247
Train     loss=0.1373     Test     loss=1.6101
Epoch: [009/010]
training batches [00100/00625] 	loss=0.1043
training batches [00200/00625] 	loss=0.1097
training batches [00300/00625] 	loss=0.1204
training batches [00400/00625] 	loss=0.1182
training batches [00500/00625] 	loss=0.1214
training batches [00600/00625] 	loss=0.1241
testing batches [00100/00125] 	loss=1.6235
Train     loss=0.1261     Test     loss=1.6347

Training progress#

Let’s plot the evolution of loss for both train and test sets.

fig = plt.figure(figsize=(7, 4))
fig.suptitle('CLIP image-text pairing')
ax = fig.add_subplot(1, 1, 1)
ax.plot(np.array(logs['train']), '-x', label="Train")
ax.plot(np.array(logs['val']), '-x', label="Test")
ax.set_xlabel("Epochs")
ax.set_ylabel("Loss")
ax.legend()
../_images/06ffc5c2c525a635450af1394c2f74dc3fbd0b2678315c3598d6899774bcf477.png

Prediction#

Now we have a model that matches image-text pairs. We can use it for zero-shot evaluation in several applications. In this tutorial, we use it for query matching.

Query matching#

The get_image_embeddings function:

  • Loads the model we have trained.

  • Computes the image_embeddings for all images in the validation set and concatenates them into a single tensor.

  • Returns the model together with the image_embeddings, which we use later on for query matching.

def get_image_embeddings(valid_df, model_path):
    valid_loader = build_loaders(args, valid_df, mode="valid")

    model = CLIPNet(args.model_vision, args.model_text, args.projection_dim).to(device)
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()

    valid_image_embeddings = []
    with torch.no_grad():
        for batch in valid_loader:
            image_features = model.image_encoder(batch["image"].to(device))
            image_embeddings = model.image_projection(image_features)
            valid_image_embeddings.append(image_embeddings)
    return model, torch.cat(valid_image_embeddings)
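# Note: make_train_valid_dfs draws a new random split on every call, so unless a
# random seed is fixed, this validation set differs from the one used in main().
_, valid_df = make_train_valid_dfs(args.data_dir)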
model, image_embeddings = get_image_embeddings(valid_df, "%s/checkpoint.pth.tar" % args.out_dir)
Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertModel: ['vocab_layer_norm.weight', 'vocab_projector.bias', 'vocab_layer_norm.bias', 'vocab_transform.bias', 'vocab_transform.weight']
- This IS expected if you are initializing DistilBertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).

The find_matches function searches all image_embeddings and finds those that best match the passed query argument.

def find_matches(model, image_embeddings, query, image_filenames, n=3):
    encoded_query = text_tokeniser([query])
    batch = {
        key: torch.tensor(values).to(device)
        for key, values in encoded_query.items()
    }
    with torch.no_grad():
        text_features = model.text_encoder(
            input_ids=batch["input_ids"], attention_mask=batch["attention_mask"]
        )
        text_embeddings = model.text_projection(text_features)

    image_embeddings_n = nn.functional.normalize(image_embeddings, p=2, dim=-1)
    text_embeddings_n = nn.functional.normalize(text_embeddings, p=2, dim=-1)
    dot_similarity = text_embeddings_n @ image_embeddings_n.T

    values, indices = torch.topk(dot_similarity.squeeze(0), n * 5)
    matches = [image_filenames[idx] for idx in indices[::5]]

    transforms = get_transforms(args.img_size, for_network=False)
    fig, axes = plt.subplots(1, n, figsize=(3 * n, 3), squeeze=False)
    for match, ax in zip(matches, axes.flatten()):
        image = pil_image.open("%s/imgs/%s" % (args.data_dir, match))
        image = transforms(image)
        ax.imshow(image)
        ax.axis("off")
    fig.suptitle('Query: %s' % query)
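
find_matches takes every fifth entry of the ranking, which relies on each image appearing five times in a row (once per caption). An alternative that does not depend on that count is to walk down the ranking and keep the first hit per filename. The following is a NumPy sketch of that idea on toy data; top_unique_matches is a hypothetical helper and the embeddings are made up.

```python
import numpy as np

def top_unique_matches(text_embedding, image_embeddings, image_filenames, n=3):
    """Hypothetical helper: rank images by cosine similarity, one hit per filename."""
    image_n = image_embeddings / np.linalg.norm(image_embeddings, axis=-1, keepdims=True)
    text_n = text_embedding / np.linalg.norm(text_embedding)
    similarity = image_n @ text_n
    matches = []
    for idx in np.argsort(-similarity):  # best match first
        name = image_filenames[idx]
        if name not in matches:
            matches.append(name)
        if len(matches) == n:
            break
    return matches

# Toy data: three images, each listed twice, as happens when a data frame
# holds several captions per image.
image_embeddings = np.array(
    [[1.0, 0.0], [1.0, 0.0], [0.0, 1.0], [0.0, 1.0], [-1.0, 0.0], [-1.0, 0.0]]
)
image_filenames = ["a.jpg", "a.jpg", "b.jpg", "b.jpg", "c.jpg", "c.jpg"]
query = np.array([0.9, 0.1])
print(top_unique_matches(query, image_embeddings, image_filenames, n=2))
# ['a.jpg', 'b.jpg']
```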

Let’s try our network with three similar phrases to evaluate how well it can distinguish subtle differences:

  • “A man on the mountains.”

  • “A man next to another human.”

  • “A man on the road.”

find_matches(
    model, image_embeddings, image_filenames=valid_df['image'].values,
    query="A man on the mountains."
)
../_images/1dd32ef1c2364fd492c31135e8200519597bdc2b7744170ea321ee9dcda8414b.png
find_matches(
    model, image_embeddings, image_filenames=valid_df['image'].values,
    query="A man next to another human."
)
../_images/f74e4a106b272d09fef8a83d570560e0f6db0ae6cd748f402a77bab455bbba49.png
find_matches(
    model, image_embeddings, image_filenames=valid_df['image'].values,
    query="A man on the road."
)
../_images/6b04cbc2c8dd2f19e6e206332086299fcfba59d21201b8c54365a83af7e9c1a9.png

Exercises#

Below is a list of exercises to practice what we have learnt in this notebook:

  1. Change the vision encoder from ResNet to another network, e.g. ViT.

  2. Plot the query matching results before any training and after each epoch. How fast do the results become qualitatively acceptable?

  3. Train the network without using the pretrained weights.

  4. Add data augmentation to the training pipeline (both for images and captions).

References#

The following sources inspired the materials in this notebook:

  1. OpenAI CLIP

  2. Simple CLIP in PyTorch