This repository contains the training code used for my master's thesis, titled Joint Multi-Modal Query-Document Representation Learning, which you can read here.
Model training is configured by a config.yaml file containing the training parameters.
Model training is done on Google Cloud Platform using Vertex AI Training with a custom container image. Configs, datasets, and models are stored in Google Cloud Storage, so gcsfuse is required.
The training flow is as follows:
- Create a local training config, local_config.yaml.
- Create the training Docker image using the dedicated script, pointing the script to local_config.yaml and to your GCP project. The script will output your ${image_name}.
- Publish the Docker image to GCR: docker push ${image_name}
- Upload the training config to GCS.
- Run the GCP training script with the correct CONTAINER_IMAGE_URI and the GCS path of config.yaml.
A model checkpoint is saved after each epoch; the model from the last epoch is saved separately.
Evaluation metrics (recall@k and mrr@k) are saved and can be visualized in TensorBoard.
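As a rough illustration of the two metrics, here is a minimal Python sketch. It assumes one relevant document per query and that the 1-based rank of that document in the candidate pool is already known. The function names recall_at_k and mrr_at_k are hypothetical, not taken from this repository; a relevant document ranked below k contributes 0 to both metrics.

```python
from typing import Sequence


def recall_at_k(ranks: Sequence[int], k: int) -> float:
    """Fraction of queries whose relevant document is ranked within the top k.

    ranks[i] is the 1-based position of query i's relevant document
    in the ranked candidate pool (hypothetical input format).
    """
    if not ranks:
        return 0.0
    return sum(1 for r in ranks if r <= k) / len(ranks)


def mrr_at_k(ranks: Sequence[int], k: int) -> float:
    """Mean reciprocal rank, counting 0 for relevant documents ranked below k."""
    if not ranks:
        return 0.0
    return sum(1.0 / r for r in ranks if r <= k) / len(ranks)


if __name__ == "__main__":
    ranks = [1, 3, 12, 2]
    print(recall_at_k(ranks, 10))  # 0.75
    print(mrr_at_k(ranks, 10))     # (1 + 1/3 + 0 + 1/2) / 4
```

Truncating at k keeps both metrics focused on the top of the ranking, which is what a retrieval system actually shows.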
Embedding visualization can optionally be turned on if you want to explore the embeddings in the TensorBoard Projector.
There are three required datasets for training and evaluation (details are in the thesis):
- Training dataset: pairs of (query, relevant document)
- Evaluation queries dataset (recall_validation_queries_dataset): pairs of (query, relevant document id)
- Evaluation documents dataset (recall_validation_items_dataset): the candidate pool for evaluation
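The evaluation setup described above can be sketched as follows: each evaluation query is scored against the whole candidate pool, and the rank of its relevant document id is recorded; those ranks are what recall@k and mrr@k are computed from. This is a minimal, hypothetical sketch, not the repository's code. It assumes queries and items have already been encoded into vectors and that similarity is a plain dot product (the thesis may use a different scoring function).

```python
from typing import Dict, Sequence


def dot(a: Sequence[float], b: Sequence[float]) -> float:
    """Dot-product similarity between two embedding vectors."""
    return sum(x * y for x, y in zip(a, b))


def rank_of_relevant(
    query_vec: Sequence[float],
    item_vecs: Dict[str, Sequence[float]],
    relevant_id: str,
) -> int:
    """Return the 1-based rank of relevant_id among all candidate items,
    ordered by descending similarity to the query."""
    target = dot(query_vec, item_vecs[relevant_id])
    # Count candidates scoring strictly higher than the relevant one;
    # ties are resolved in the relevant document's favour (an assumption).
    higher = sum(1 for v in item_vecs.values() if dot(query_vec, v) > target)
    return higher + 1


if __name__ == "__main__":
    # Hypothetical candidate pool of three item embeddings.
    items = {"d1": [1.0, 0.0], "d2": [0.0, 1.0], "d3": [0.7, 0.7]}
    print(rank_of_relevant([1.0, 0.0], items, "d3"))  # 2
```

Counting strictly higher scores avoids sorting the whole pool for every query; a real implementation over large pools would typically batch this as a matrix product.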
Supported training parameters
- task_id
- run_id
- num_epochs
- dataset_dir
- batch_size
- learning_rate
- reuse_epoch
- dataloader_workers
- dataset - structure of training features
- loss - can be batch_softmax or triplet
- text_vectorizer - path to the token dictionary and tokenization config (word_unigram, word_bigram, char_trigram + oov)
- model - can be SimpleTextEncoder, TwoTower, or MultiModalTwoTower
- recall_validation - the values of k for which validation should be run, and whether to generate a dataset with typos
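To illustrate the feature types named under text_vectorizer, here is a hypothetical Python sketch that extracts word unigrams, word bigrams and character trigrams, then maps them to ids, hashing out-of-vocabulary tokens into extra buckets. The '#' word-boundary padding, the crc32 hashing and the function names are assumptions for illustration only; the actual token dictionary format and OOV handling are defined by the repository's config.

```python
import zlib
from typing import Dict, List


def extract_tokens(text: str) -> List[str]:
    """Word unigrams, word bigrams and character trigrams of a lowercased text.

    Character trigrams are taken over each word padded with '#' boundary
    markers (an assumption; the repository's scheme may differ).
    """
    words = text.lower().split()
    bigrams = [f"{a} {b}" for a, b in zip(words, words[1:])]
    trigrams: List[str] = []
    for w in words:
        padded = f"#{w}#"
        trigrams.extend(padded[i:i + 3] for i in range(len(padded) - 2))
    return words + bigrams + trigrams


def hash_bucket(token: str, num_buckets: int) -> int:
    """Deterministically assign a token to one of num_buckets OOV buckets."""
    return zlib.crc32(token.encode("utf-8")) % num_buckets


def vectorize(text: str, vocab: Dict[str, int], num_oov_buckets: int) -> List[int]:
    """Map tokens to ids; unknown tokens get ids placed after the vocabulary."""
    ids = []
    for tok in extract_tokens(text):
        if tok in vocab:
            ids.append(vocab[tok])
        else:
            ids.append(len(vocab) + hash_bucket(tok, num_oov_buckets))
    return ids


if __name__ == "__main__":
    print(extract_tokens("red shoes"))
    # ['red', 'shoes', 'red shoes', '#re', 'red', 'ed#', '#sh', 'sho', 'hoe', 'oes', 'es#']
    print(vectorize("red shoes", {"red": 0, "shoes": 1}, num_oov_buckets=10))
```

zlib.crc32 is used instead of Python's built-in hash(), which is randomized per process for strings, so the same OOV token lands in the same bucket across training and evaluation runs. Note that in this sketch a word and a trigram with the same spelling (e.g. "red") share one id.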
References to the research papers can be found in the thesis. For the code, special thanks go to:
- https://github.com/adambielski/siamese-triplet
- https://gist.github.com/danmelton/183313
- https://stackoverflow.com/a/58144658/7073537
Best engineering practices do not apply to a master's thesis, sorry.
qdrl stands for Query Document Representation Learning
No.