Fine-tuning: Adapting embeddings to your specific domain

Introduction

In this chapter, we are going to explore fine-tuning of embedding models and briefly mention embedding adapters. Fine-tuning adapts embedding models to domain-specific data, optimizing both query and document representations for better alignment. Embedding adapters, on the other hand, adjust query embeddings dynamically, offering a cost-efficient alternative without requiring re-indexing of the document corpus.

Fine-Tuning Embedding Models

Fine-tuning offers flexible options for adapting an embedding model to domain-specific tasks or retrieval objectives. You can retrain the entire model to deeply integrate domain nuances, or focus on specific layers, often the final ones or newly added layers, much as is done with convolutional neural networks. This allows tailored optimization of the embedding space, ensuring that queries and documents align more closely with the desired relevance criteria while keeping computational cost under control.
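
For a rough sense of the partial option, the sketch below freezes everything except the last encoder layer of a BERT-style embedding model using sentence-transformers. The BAAI/bge-small-en checkpoint matches the example later in this chapter, but the exact attribute path to the encoder layers is an assumption that depends on the underlying architecture.

from sentence_transformers import SentenceTransformer

# Load the base embedding model (same checkpoint as in the example below).
model = SentenceTransformer("BAAI/bge-small-en")
transformer = model[0].auto_model  # the underlying Hugging Face transformer

# Freeze all weights, then unfreeze only the last encoder layer.
for param in transformer.parameters():
    param.requires_grad = False
for param in transformer.encoder.layer[-1].parameters():
    param.requires_grad = True

trainable = sum(p.numel() for p in transformer.parameters() if p.requires_grad)
print(f"Trainable parameters: {trainable}")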

How It Works

  1. Data Preparation: A labeled dataset of query-document pairs with relevance scores is required. This dataset can be created manually or generated synthetically using models like LLMs.
  2. Training Process: The model learns to adjust the embedding space so that relevant query-document pairs end up closer together and irrelevant pairs farther apart (see the training sketch after this list). This involves:
    • Forward propagation to compute similarity scores for query-document pairs.
    • Loss computation, such as contrastive loss or triplet loss, to optimize distances in the embedding space.
    • Backpropagation to update the model's weights across all layers.
  3. Inference: Once fine-tuned, both query and document embeddings are re-encoded, ensuring that future retrievals reflect the improved alignment. This step requires reprocessing your entire corpus to generate updated document embeddings.
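
To make the training step concrete, here is a minimal sketch of a contrastive setup using sentence-transformers with in-batch negatives (MultipleNegativesRankingLoss). The query/passage pairs and hyperparameters are hypothetical placeholders, not the dataset built later in this chapter.

from torch.utils.data import DataLoader
from sentence_transformers import SentenceTransformer, InputExample, losses

model = SentenceTransformer("BAAI/bge-small-en")

# Hypothetical (query, relevant passage) pairs; other passages in the batch act as negatives.
train_examples = [
    InputExample(texts=["What was Lyft's 2021 revenue?", "Lyft reported revenue of ..."]),
    InputExample(texts=["How does Uber define Gross Bookings?", "Gross Bookings represent ..."]),
]
train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=16)

# Contrastive objective: pull relevant pairs together, push in-batch negatives apart.
train_loss = losses.MultipleNegativesRankingLoss(model)

# fit() handles forward propagation, loss computation, and backpropagation.
model.fit(train_objectives=[(train_dataloader, train_loss)], epochs=1, warmup_steps=10)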

Key Considerations

  • Advantages:
    • Provides deep integration with domain-specific semantics.
    • Ideal for retrieval tasks requiring high precision in specialized fields like legal, medical, or technical documentation.
  • Drawbacks:
    • Requires significant labeled data and computational resources.
    • Necessitates re-embedding the entire corpus post-training, which can be costly for large datasets.

Code Example

Code adapted from our friends at LlamaIndex 💙: Finetuning an Adapter on Top of any Black-Box Embedding Model - LlamaIndex

Download data

We start by downloading Uber and Lyft annual reports from 2021.

#!pip install llama-index-embeddings-huggingface llama-index-finetuning llama-index-readers-file
#!mkdir -p 'data/10k/'
#!wget 'https://raw.githubusercontent.com/run-llama/llama_index/main/docs/docs/examples/data/10k/uber_2021.pdf' -O 'data/10k/uber_2021.pdf'
#!wget 'https://raw.githubusercontent.com/run-llama/llama_index/main/docs/docs/examples/data/10k/lyft_2021.pdf' -O 'data/10k/lyft_2021.pdf'
TRAIN_FILES = ["./data/10k/lyft_2021.pdf"]
VAL_FILES = ["./data/10k/uber_2021.pdf"]

TRAIN_CORPUS_FPATH = "./data/train_corpus.json"
VAL_CORPUS_FPATH = "./data/val_corpus.json"

Next, we define a function for loading the corpus:

import json
from llama_index.core import SimpleDirectoryReader
from llama_index.core.node_parser import SentenceSplitter
from llama_index.core.schema import MetadataMode

def load_corpus(files, verbose=False):
    if verbose:
        print(f"Loading files {files}")

    reader = SimpleDirectoryReader(input_files=files)
    docs = reader.load_data()
    if verbose:
        print(f"Loaded {len(docs)} docs")

    parser = SentenceSplitter()
    nodes = parser.get_nodes_from_documents(docs, show_progress=verbose)

    if verbose:
        print(f"Parsed {len(nodes)} nodes")

    return nodes

We do a very naive train/val split by having the Lyft corpus as the train dataset, and the Uber corpus as the val dataset.

train_nodes = load_corpus(TRAIN_FILES, verbose=True)
val_nodes = load_corpus(VAL_FILES, verbose=True)

Generate synthetic queries

Now, we use an LLM (gpt-4o) to generate questions using each text chunk in the corpus as context.

Each pair of (generated question, text chunk used as context) becomes a datapoint in the finetuning dataset (either for training or evaluation).

Set up an LLM:

from llama_index.llms.openai import OpenAI
import getpass
import os

# Prompt the user to enter the OpenAI API key securely
api_key = getpass.getpass(prompt='Enter your OpenAI API key: ')

# Set the OpenAI API key as an environment variable
os.environ["OPENAI_API_KEY"] = api_key

# Initialize the LLM with GPT-4o
llm = OpenAI(model="gpt-4o", temperature=0)

Create queries:

from llama_index.finetuning import generate_qa_embedding_pairs
from llama_index.core.evaluation import EmbeddingQAFinetuneDataset

train_dataset = generate_qa_embedding_pairs(train_nodes, llm)
val_dataset = generate_qa_embedding_pairs(val_nodes, llm)

train_dataset.save_json("train_dataset.json")
val_dataset.save_json("val_dataset.json")
# [Optional] Load
train_dataset = EmbeddingQAFinetuneDataset.from_json("train_dataset.json")
val_dataset = EmbeddingQAFinetuneDataset.from_json("val_dataset.json")

Fine-tuning

This section demonstrates a basic setup for fine-tuning an embedding model using the EmbeddingAdapterFinetuneEngine. It begins by resolving a base embedding model (BGE-small), which serves as the starting point for training. The finetuning engine is initialized with the training dataset, specifying parameters like the number of epochs (12), an optimizer (Adam), and a learning rate (0.001).

The finetune() method executes the fine-tuning process, and the resulting fine-tuned model is retrieved for use. While this setup is functional, it is simplistic and serves as a starting point. For optimal results, you would perform hyperparameter optimization—experimenting with learning rates, batch sizes, and other training configurations to maximize performance.

from llama_index.finetuning import EmbeddingAdapterFinetuneEngine
from llama_index.core.embeddings import resolve_embed_model
import torch

# Resolve base embedding model (BGE-small)
base_embed_model = resolve_embed_model("local:BAAI/bge-small-en")

# Initialize and configure the finetuning engine
finetune_engine = EmbeddingAdapterFinetuneEngine(
    dataset=train_dataset,
    embed_model=base_embed_model,
    model_output_path="model_output_test",
    epochs=12,  # Adjust as needed
    verbose=True,
    optimizer_class=torch.optim.Adam,  # Optional customization
    optimizer_params={"lr": 0.001}     # Optional customization
)

# Run fine-tuning
finetune_engine.finetune()

# Retrieve the fine-tuned embedding model
embed_model = finetune_engine.get_finetuned_model()

Helper functions

Below are helper functions provided by LlamaIndex that are used for analyzing the results of the fine-tuned embedding model. The evaluate function tests the retrieval performance of the model by comparing its output to expected results and calculating metrics like hit rate and Mean Reciprocal Rank (MRR). The display_results function aggregates and visualizes these metrics for multiple retrievers, providing a clear comparison of their performance. These functions will help you assess the effectiveness of your fine-tuned model.

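The notebook's helper code is not reproduced here, so below is a minimal sketch of what evaluate and display_results can look like. It assumes an in-memory VectorStoreIndex over the validation corpus, one relevant chunk per query, and pandas for the summary table; details may differ from the LlamaIndex notebook.

import pandas as pd
from tqdm import tqdm
from llama_index.core import VectorStoreIndex
from llama_index.core.embeddings import resolve_embed_model
from llama_index.core.schema import TextNode

def evaluate(dataset, embed_model, top_k=5, verbose=False):
    """Retrieve top_k chunks for every query and record hit / reciprocal rank."""
    if isinstance(embed_model, str):
        embed_model = resolve_embed_model(embed_model)  # e.g. "local:BAAI/bge-small-en"

    corpus = dataset.corpus
    queries = dataset.queries
    relevant_docs = dataset.relevant_docs

    nodes = [TextNode(id_=id_, text=text) for id_, text in corpus.items()]
    index = VectorStoreIndex(nodes, embed_model=embed_model, show_progress=verbose)
    retriever = index.as_retriever(similarity_top_k=top_k)

    eval_results = []
    for query_id, query in tqdm(queries.items()):
        retrieved_ids = [n.node.node_id for n in retriever.retrieve(query)]
        expected_id = relevant_docs[query_id][0]  # assume one relevant chunk per query
        is_hit = expected_id in retrieved_ids
        rank = retrieved_ids.index(expected_id) + 1 if is_hit else None
        eval_results.append({"is_hit": is_hit, "mrr": 1.0 / rank if is_hit else 0.0})
    return eval_results

def display_results(names, results_arr):
    """Aggregate hit rate and MRR per retriever into a single comparison table."""
    rows = []
    for name, results in zip(names, results_arr):
        df = pd.DataFrame(results)
        rows.append({"retriever": name, "hit_rate": df["is_hit"].mean(), "mrr": df["mrr"].mean()})
    return pd.DataFrame(rows)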

Getting the results

This section evaluates the performance of both the base embedding model (bge) and the fine-tuned model (ft) on the validation dataset. The evaluate function is used to compute retrieval metrics like hit rate and MRR for each model. Finally, the display_results function presents a side-by-side comparison of the two models, allowing you to see the impact of fine-tuning on retrieval performance. This is where you validate whether the fine-tuned model has improved over the base model.

ft_val_results = evaluate(val_dataset, embed_model)
bge = "local:BAAI/bge-small-en"
bge_val_results = evaluate(val_dataset, bge)
display_results(
    ["bge", "ft"], [bge_val_results, ft_val_results]
)

Results

Retriever    Hit Rate    MRR
bge          79.6%       0.624989
ft           96.1%       0.818652

The results compare the performance of the base embedding model (bge) with the fine-tuned model (ft) on the validation dataset:

  1. Hit Rate: Represents how often the correct document is included in the top-k results.
    • bge: 79.6%
    • ft: 96.1%
    • Interpretation: The fine-tuned model exhibits a significant improvement, indicating much more consistent retrieval of relevant documents.
  2. MRR (Mean Reciprocal Rank): Indicates how highly the correct document is ranked in the results (a small worked example follows this list).
    • bge: 0.624989
    • ft: 0.818652
    • Interpretation: The fine-tuned model shows a substantial improvement, ranking the correct documents considerably higher in the results.
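
For reference, MRR is simply the average of 1/rank of the correct chunk across queries; the ranks below are made-up numbers, not taken from the evaluation above.

# MRR = mean of 1/rank of the correct document per query (0 if it is not retrieved at all).
ranks = [1, 2, 1, 5]  # hypothetical ranks of the correct chunk for four queries
mrr = sum(1.0 / r for r in ranks) / len(ranks)
print(mrr)  # 0.675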

Observations:

  • The fine-tuned model outperforms the base model in both hit rate and MRR by a large margin, underscoring the value of fine-tuning for retrieval tasks.
  • Even with a basic training setup and limited examples, the fine-tuned model demonstrates significant gains, suggesting strong potential for further improvements with additional optimizations.

Potential Next Steps:

  • Validate the results through statistical testing (e.g., paired t-tests or bootstrapping) to confirm the significance of the observed differences (a quick bootstrap sketch follows this list).
  • Experiment with larger datasets, better sampling strategies, and hyperparameter tuning to further enhance the fine-tuned model's performance.
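
As a rough illustration of the bootstrapping idea, the sketch below resamples per-query hit indicators (randomly generated placeholders here; in practice, use the is_hit values returned by evaluate) to get a confidence interval on the hit-rate improvement.

import numpy as np

rng = np.random.default_rng(0)

# Hypothetical per-query hit indicators (1 = correct chunk retrieved).
bge_hits = rng.binomial(1, 0.796, size=500)
ft_hits = rng.binomial(1, 0.961, size=500)

# Paired bootstrap: resample the same query indices for both models.
diffs = []
for _ in range(10_000):
    idx = rng.integers(0, len(bge_hits), size=len(bge_hits))
    diffs.append(ft_hits[idx].mean() - bge_hits[idx].mean())

low, high = np.percentile(diffs, [2.5, 97.5])
print(f"95% CI for hit-rate improvement: [{low:.3f}, {high:.3f}]")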

Embedding Adapters

Embedding adapters are a promising alternative to fine-tuning, offering a lightweight method to adjust embeddings dynamically during retrieval. Unlike the fine-tuning approach discussed earlier, which updates the embedding model itself (whether through full retraining or additional trainable layers) and therefore requires re-indexing the entire vector database, these adapters operate exclusively on the query side. This eliminates the need to re-embed and re-index the document corpus, resulting in significant cost and computational savings. Instead of modifying the embedding model or document embeddings, query embeddings are transformed through a learned matrix, adapting them to specific relevance criteria on the fly.

The core idea behind embedding adapters is to apply a transformation matrix with the same dimensions as the embeddings themselves. This matrix amplifies relevant dimensions of the embedding space while suppressing less relevant ones, dynamically tailoring the query embedding to the task at hand. While this approach has seen limited exploration in the broader AI community, it holds considerable promise for scenarios requiring adaptive retrieval without the overhead of full re-indexing. By refining query embeddings in real-time, embedding adapters offer a flexible and cost-efficient way to enhance retrieval accuracy in domain-specific RAG pipelines.
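
As a minimal sketch of the idea (not LlamaIndex's implementation), a linear adapter is just a learned square matrix applied to the query embedding before the similarity search; the dimension of 384 below assumes the bge-small-en model used earlier.

import torch
import torch.nn as nn

class LinearQueryAdapter(nn.Module):
    """Learned square matrix applied only to query embeddings; document embeddings stay untouched."""

    def __init__(self, dim: int):
        super().__init__()
        self.linear = nn.Linear(dim, dim, bias=False)
        nn.init.eye_(self.linear.weight)  # start at the identity: an untrained adapter is a no-op

    def forward(self, query_embedding: torch.Tensor) -> torch.Tensor:
        return self.linear(query_embedding)

# Usage: transform the query vector on the fly, then search the existing index with it.
adapter = LinearQueryAdapter(dim=384)  # 384 = bge-small-en embedding size
query_vec = torch.randn(384)
adapted_query = adapter(query_vec)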

So what?

I highly recommend fine-tuning for your specific use case. In his amazing Stanford lecture, Douwe Kiela described the traditional RAG setup as a "Frankenstein monster" – a collection of components stitched together without much focus on optimization. By fine-tuning embedding models or incorporating embedding adapters, you take a significant step toward building RAG systems that are carefully tailored and optimized for your unique requirements. It is definitely worth exploring!

Conclusion

We demonstrated how fine-tuning embedding models, even with a basic setup, can improve retrieval accuracy by better aligning embeddings to the training data. Additionally, we introduced embedding adapters as a lightweight option for refining query embeddings, which can enhance retrieval without reprocessing the corpus. These techniques are practical tools for increasing the likelihood of retrieving relevant information in domain-specific contexts.

In the next chapter, we will cover Multimodal RAG! This is a new and exciting field: modern databases (cough cough Deeplake) allow you to embed images the same way you embed text, and then you can query both at the same time! 🤯

Jupyter: Google Colab