In the previous post, you learned how to build a simple retrieval-augmented generation (RAG) system. RAG is a powerful approach for enhancing large language models with external knowledge, and there are many ways to make it work better. In this post, you will see some advanced features and techniques that improve the performance of your RAG system. In particular, you will learn:
- How to improve the prompt used in RAG
- How to use hybrid retrieval to improve the quality of the retrieved documents
- How to use multi-stage retrieval with re-ranking to improve the quality of the generated responses
Let’s get started.
Advanced Techniques to Build Your RAG System
Overview
This post is divided into three parts; they are:
- Query Expansion and Reformulation
- Hybrid Retrieval: Dense and Sparse Methods
- Multi-Stage Retrieval with Re-ranking
Query Expansion and Reformulation
One of the challenges in RAG systems is that the user's query might not match the terminology used in the knowledge base. A good embedding model mitigates this because it captures the meaning of the query rather than its exact wording, but you can never be sure this holds for any particular query.
Query expansion and reformulation help bridge this gap by generating multiple versions of the query. The assumption is that, among several variations of the same query, at least one will retrieve the documents most relevant for RAG.
To do query expansion, you need a model that can generate variations of the input. BART is an example. Let’s see how you can use it for query expansion:
```python
from transformers import BartForConditionalGeneration, BartTokenizer

# Load BART model and tokenizer
tokenizer = BartTokenizer.from_pretrained("facebook/bart-large")
model = BartForConditionalGeneration.from_pretrained("facebook/bart-large")

def reformulate_query(query, n=2):
    inputs = tokenizer(query, return_tensors="pt")
    outputs = model.generate(
        **inputs,
        max_length=64,
        num_beams=10,
        num_return_sequences=n,
        temperature=1.5,  # High temperature for diversity
        top_k=50,
        do_sample=True
    )
    # Decode the outputs one by one
    reformulations = [tokenizer.decode(output, skip_special_tokens=True)
                      for output in outputs]
    all_queries = [query] + reformulations
    return all_queries

# Generate reformulations from an example query
query = "How do transformer-based systems process natural language?"
reformulated_queries = reformulate_query(query)
print(f"Original Query: {query}")
print("Reformulated Queries:")
for i, q in enumerate(reformulated_queries[1:], 1):
    print(f"{i}. {q}")
```
In this code, you load a pre-trained BART model and tokenizer. The model is created as a `BartForConditionalGeneration` object, which is a sequence-to-sequence model for text generation. As with any model in the Hugging Face transformers library, you tokenize the input and pass it to the model, here inside the function `reformulate_query()`. You ask the model to generate `n` outputs for a single input.
To create more variation, you set the temperature slightly above 1, and you may try even higher values. Generation with BART essentially asks the model to read your input, compress its meaning into a hidden state, and then decode that hidden state back into text, with possible variations. The multiple candidates are produced using beam search combined with sampling, and you may add more generation parameters if you prefer.
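If sampling with a high temperature does not give you enough variety, one alternative worth trying is diverse beam search. The sketch below is illustrative rather than tuned; `num_beam_groups` and `diversity_penalty` are standard `generate()` parameters, but the values shown are arbitrary.

```python
# Diverse beam search: the beams are split into groups that are penalized for
# producing similar outputs, which encourages distinct reformulations
outputs = model.generate(
    **tokenizer(query, return_tensors="pt"),
    max_length=64,
    num_beams=10,
    num_beam_groups=5,      # must evenly divide num_beams
    diversity_penalty=1.0,  # higher values push the groups further apart
    num_return_sequences=2,
    do_sample=False,        # group beam search is used without sampling
)
```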
Back in the example, the multiple outputs are decoded into text one by one using the tokenizer and printed at the end of the code. If you run it, you may see:
```
Original Query: How do transformer-based systems process natural language?
Reformulated Queries:
1. How do transformer-based systems process natural language?
2. How do transformer-based systems work in natural language?
```
The more ambiguous your original query, the more varied the reformulations will be.
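To put the reformulations to work in a RAG pipeline, you can retrieve with each variation and merge the results. The sketch below assumes a retrieval function `retrieve(query, k)` that returns `(document, score)` pairs — for example, the `hybrid_retrieval()` function built in the next section — and simply keeps the best score for documents found by more than one variation.

```python
def expanded_retrieval(query, k=3):
    """Retrieve with every reformulation of the query and merge the results"""
    best_scores = {}
    for q in reformulate_query(query):
        for doc, score in retrieve(q, k):  # retrieve() is assumed, e.g. hybrid_retrieval()
            best_scores[doc] = max(score, best_scores.get(doc, float("-inf")))
    # Keep the overall top-k documents across all query variations
    ranked = sorted(best_scores.items(), key=lambda item: item[1], reverse=True)
    return ranked[:k]
```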
Hybrid Retrieval: Dense and Sparse Methods
The idea of RAG is to supplement the context of the query with the most relevant documents from the knowledge base. This additional information can help the model generate a better response. You can use different methods to find the relevant documents.
Dense vector retrieval represents each document in your knowledge base as a high-dimensional vector. Every dimension of this vector carries information, but there is no concrete interpretation of what each dimension represents. Usually, a dense vector is a vector of floating-point numbers that looks random.
A sparse vector, however, has many zeros. It usually has far more dimensions and holds integer values. One example is a binary bag-of-words vector, in which each position represents a word in the vocabulary, and the value is 1 only if that word is present in the document.
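As a toy illustration (not part of the retrieval pipeline below), here is what such a sparse vector looks like for a made-up five-word vocabulary:

```python
import numpy as np

# A binary bag-of-words vector over a tiny, hypothetical vocabulary
vocabulary = ["attention", "transformer", "recurrent", "sequence", "memory"]
document = "the transformer uses attention over the input sequence"

words = set(document.split())
sparse_vector = np.array([1 if word in words else 0 for word in vocabulary])
print(sparse_vector)  # [1 1 0 1 0] -- mostly zeros once the vocabulary gets large
```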
Neither dense nor sparse vectors are universally better. Dense vectors produced by an embedding model are good at capturing semantic similarity, while sparse vectors are usually better at capturing exact keywords. Operating on sparse vectors can consume a lot of memory, but scoring techniques like Okapi BM25 keep the workload manageable.
In the code below, you will need to install the library to compute the BM25 score. You can do this using pip:
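```
pip install rank_bm25
```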
Let’s see how you can combine both sparse and dense vectors to build a retrieval system:
```python
from rank_bm25 import BM25Okapi
from transformers import AutoTokenizer, AutoModel
import faiss
import numpy as np
import torch

dense_tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
dense_model = AutoModel.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")

def generate_embedding(text):
    """Generate dense vector using mean pooling"""
    inputs = dense_tokenizer(text, padding=True, truncation=True,
                             return_tensors="pt", max_length=512)
    with torch.no_grad():
        outputs = dense_model(**inputs)

    attention_mask = inputs["attention_mask"]
    embeddings = outputs.last_hidden_state

    expanded_mask = attention_mask.unsqueeze(-1).expand(embeddings.shape).float()
    sum_embeddings = torch.sum(embeddings * expanded_mask, axis=1)
    sum_mask = torch.clamp(expanded_mask.sum(axis=1), min=1e-9)
    mean_embeddings = sum_embeddings / sum_mask
    return mean_embeddings.cpu().numpy()

# Sample document collection
documents = [
    "Transformers use self-attention mechanisms to process input sequences in "
    "parallel, making them efficient for long sequences.",
    "The attention mechanism in transformers allows the model to focus on different "
    "parts of the input sequence when generating each output element.",
    "Transformer models have a fixed context length determined by the positional "
    "encoding and self-attention mechanisms.",
    "To handle sequences longer than the context length, transformers can use "
    "techniques like sliding windows or hierarchical processing.",
    "Recurrent Neural Networks (RNNs) process sequences sequentially, which can be "
    "inefficient for long sequences.",
    "Long Short-Term Memory (LSTM) networks are a type of RNN designed to handle "
    "long-term dependencies in sequences.",
    "The Transformer architecture was introduced in the paper 'Attention Is All "
    "You Need' by Vaswani et al.",
    "BERT (Bidirectional Encoder Representations from Transformers) is a "
    "transformer-based model designed for understanding the context of words.",
    "GPT (Generative Pre-trained Transformer) is a transformer-based model designed "
    "for natural language generation.",
    "Transformer-XL extends the context length of transformers by using a "
    "segment-level recurrence mechanism."
]

# Prepare for sparse retrieval (BM25)
tokenized_corpus = [doc.lower().split() for doc in documents]
bm25 = BM25Okapi(tokenized_corpus)

# Prepare for dense retrieval (FAISS)
document_embeddings = generate_embedding(documents)
dimension = document_embeddings.shape[1]
index = faiss.IndexFlatL2(dimension)
index.add(document_embeddings)

def hybrid_retrieval(query, k=3, alpha=0.5):
    """Hybrid retrieval: Use both the BM25 and L2 index on FAISS"""
    # Sparse score of each document with BM25
    tokenized_query = query.lower().split()
    bm25_scores = bm25.get_scores(tokenized_query)

    # Normalize BM25 scores to [0,1] unless all elements are zero
    if max(bm25_scores) > 0:
        bm25_scores = bm25_scores / max(bm25_scores)

    # Rank all documents according to L2 distance to the query
    query_embedding = generate_embedding(query)
    distances, indices = index.search(query_embedding, len(documents))

    # Dense score: 1/distance as similarity metric, then normalize to [0,1]
    # FAISS returns results sorted by distance, so map them back to document
    # order before mixing with the BM25 scores (which are in document order)
    eps = 1e-5  # a small value to prevent division by zero
    dense_scores = np.zeros(len(documents))
    dense_scores[indices[0]] = 1 / (eps + distances[0])
    dense_scores = dense_scores / max(dense_scores)

    # Combined score = affine combination of sparse and dense scores
    combined_scores = alpha * dense_scores + (1 - alpha) * bm25_scores

    # Get top-k documents
    top_indices = np.argsort(combined_scores)[::-1][:k]
    results = [(documents[idx], combined_scores[idx]) for idx in top_indices]
    return results

# Retrieve documents using hybrid retrieval
query = "How do transformers handle long sequences?"
results = hybrid_retrieval(query)
print(f"Query: {query}")
for i, (doc, score) in enumerate(results):
    print(f"Document {i+1} (Score: {score:.4f}):")
    print(doc)
    print()
```
When you run this, you will see:
```
Query: How do transformers handle long sequences?
Document 1 (Score: 0.7924):
Transformers use self-attention mechanisms to process input sequences in parallel, making them efficient for long sequences.

Document 2 (Score: 0.7458):
Long Short-Term Memory (LSTM) networks are a type of RNN designed to handle long-term dependencies in sequences.

Document 3 (Score: 0.7131):
To handle sequences longer than the context length, transformers can use techniques like sliding windows or hierarchical processing.
```
At the beginning, you create an Okapi BM25 index over all documents in your collection. Okapi BM25 is a TF-IDF-based scoring method, meaning it compares two texts by checking which exact words they share. Capitalization should not matter for this comparison, so you convert the documents to lowercase before building the BM25 index.
Then, you generate the dense vectors for your document collection using a pre-trained sentence transformer model and store them in a FAISS index for efficient similarity search using L2 distance.
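Before combining the two signals, you can sanity-check each one on its own. This snippet only reuses the `bm25` object, the `generate_embedding()` function, and the FAISS `index` defined above:

```python
# Sparse side: raw BM25 scores, one per document (higher means a better keyword match)
tokenized_query = "How do transformers handle long sequences?".lower().split()
print(bm25.get_scores(tokenized_query))

# Dense side: the three nearest documents by L2 distance (smaller means closer)
query_embedding = generate_embedding("How do transformers handle long sequences?")
distances, indices = index.search(query_embedding, 3)
print(indices[0], distances[0])
```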
The key part of this code is the function `hybrid_retrieval()`. With the Okapi BM25 and FAISS indexes prepared, you score every document against your query string. The BM25 score is a TF-IDF-style score for each document. You also compute the L2 distance from FAISS for each document and convert this distance into a score so that, as with BM25, a higher value means a better match. To keep the two methods comparable, you normalize both sets of scores to the range [0, 1].

Depending on your needs, you can put more emphasis on dense retrieval or sparse retrieval by changing the parameter `alpha`. The combined score is then used to select the top-k documents to return, as you can see from the output above.
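For example, you can rerun the retrieval with different values of `alpha` to see how the balance between the two methods shifts; the values below are only for comparison:

```python
# alpha=1.0 uses only the dense scores, alpha=0.0 only BM25, and 0.5 is the balanced default
for alpha in (1.0, 0.5, 0.0):
    results = hybrid_retrieval("How do transformers handle long sequences?", k=3, alpha=alpha)
    print(f"alpha={alpha}:")
    for doc, score in results:
        print(f"  {score:.4f}  {doc[:60]}...")
```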
This hybrid approach often outperforms either method alone, especially for complex queries where both semantic understanding and specific terminology are important.
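As an aside, a weighted sum of normalized scores is not the only way to fuse the two rankings. Reciprocal rank fusion (RRF) is a common alternative that ignores the raw scores and combines ranks instead; it is not part of this post's code, but a minimal sketch under the same setup could look like this (the constant 60 is the conventional default):

```python
def rrf_combine(bm25_scores, dense_scores, k=3, c=60):
    """Fuse two score lists (in document order) by reciprocal rank"""
    rrf_scores = np.zeros(len(documents))
    for scores in (bm25_scores, dense_scores):
        ranking = np.argsort(scores)[::-1]  # document indices from best to worst
        for rank, idx in enumerate(ranking):
            rrf_scores[idx] += 1.0 / (c + rank + 1)
    top_indices = np.argsort(rrf_scores)[::-1][:k]
    return [(documents[idx], rrf_scores[idx]) for idx in top_indices]
```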
Multi-Stage Retrieval with Re-ranking
If you had a perfect model to score the relevance of documents to your query, a simple retrieval system would be enough. However, no model is perfect, and higher-quality models are usually more computationally expensive. This is where multi-stage retrieval comes in.
Hybrid retrieval is good at picking documents quickly: with a fast model, you can easily score a large number of documents. The picks, however, are not always good. You can therefore use a slower but more accurate model to recompute the scores, this time considering only the documents shortlisted by the hybrid retrieval. As long as the first-stage model is roughly correct, the more computationally intensive model in the second stage will give you an accurate final selection.
This is what the multi-stage retrieval technique is about. Let’s see how you can implement this:
```python
...

# The re-ranker needs this import in addition to those above
from transformers import AutoModelForSequenceClassification

# Load pre-trained model and tokenizer for re-ranking
reranker_tokenizer = AutoTokenizer.from_pretrained("cross-encoder/ms-marco-MiniLM-L-6-v2")
reranker_model = AutoModelForSequenceClassification.from_pretrained("cross-encoder/ms-marco-MiniLM-L-6-v2")

def rerank(query, documents, top_k=3):
    """Sort documents by the reranker model and select top-k"""
    # Prepare inputs for the re-ranker
    pairs = [[query, doc] for doc in documents]
    features = reranker_tokenizer(pairs, padding=True, truncation=True, return_tensors="pt")
    # Get re-ranking scores
    with torch.no_grad():
        scores = reranker_model(**features).logits.squeeze(-1).cpu().numpy()
    # Sort documents by score, then pick top-k
    ranked_indices = np.argsort(scores)[::-1][:top_k]
    reranked_docs = [(documents[idx], float(scores[idx])) for idx in ranked_indices]
    return reranked_docs

def multi_stage_retrieval(query, documents, initial_k=5, final_k=3):
    """Multi-stage retrieval: Hybrid retrieval to shortlist documents, then pick with a reranker"""
    # Stage 1: Initial retrieval using hybrid method
    initial_results = hybrid_retrieval(query, k=initial_k)
    initial_docs = [doc for doc, _ in initial_results]
    # Stage 2: Re-ranking
    reranked_results = rerank(query, initial_docs, top_k=final_k)
    return reranked_results

# Example query
query = "How do transformers handle long sequences?"
results = multi_stage_retrieval(query, documents)
print(f"Query: {query}")
print("Re-ranked Results:")
for i, (doc, score) in enumerate(results):
    print(f"Document {i+1} (Score: {score:.4f}):")
    print(doc)
    print()
```
This code builds on top of the previous section and uses the same `hybrid_retrieval()` function as before. In the function `multi_stage_retrieval()`, you first use hybrid retrieval to get a shortlist of documents, and then you use the re-ranking model to re-rank them.
The re-ranking model is a cross-encoder, a type of transformer model that can be used for ranking tasks. In simple terms, it takes two sequences as input, concatenated in the format `[CLS] query [SEP] document [SEP]`. The model's output is a score indicating the document's relevance to the query. It is slower than L2 distance or BM25, but more accurate.
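To make the input format concrete, here is how a single query-document pair could be scored on its own, reusing the `reranker_tokenizer` and `reranker_model` loaded above (the choice of `documents[3]` is arbitrary):

```python
# Score one (query, document) pair with the cross-encoder
pair = [["How do transformers handle long sequences?", documents[3]]]
features = reranker_tokenizer(pair, padding=True, truncation=True, return_tensors="pt")
with torch.no_grad():
    score = reranker_model(**features).logits.squeeze(-1).item()
print(f"Relevance score: {score:.4f}")
```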
In the function `rerank()`, you run the re-ranking model on the query and the documents shortlisted by the hybrid retrieval, then pick the top-k documents based on the re-ranker's scores. The parameters `initial_k` and `final_k` in `multi_stage_retrieval()` let you control the trade-off between recall (retrieving all relevant documents) and precision (ensuring retrieved documents are relevant). A larger `initial_k` increases recall but requires more re-ranking computation, while a smaller `final_k` focuses on the most relevant documents.
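For example, you could enlarge the first-stage shortlist to favor recall; the numbers below are only illustrative:

```python
# A bigger shortlist costs more re-ranking compute but is less likely to miss a relevant document
results = multi_stage_retrieval(query, documents, initial_k=8, final_k=3)
```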
Further Reading
Below are some further readings that you may find useful:
Summary
In this tutorial, you’ve explored several advanced techniques for enhancing RAG systems. For a given generator model, the success of a RAG system largely depends on whether you can provide useful context as well as accurately describe your expected output in the prompt. You learned how to improve the retriever and create a better prompt. In particular, you learned:
- Use query expansion and reformulation to try different phrasings of the same query
- Use hybrid retrieval to combine dense and sparse retrieval so that you can retrieve more relevant documents
- Use multi-stage retrieval with re-ranking to improve the quality of the retrieved documents
These advanced features can significantly improve the performance and capabilities of RAG systems, making them more effective for a wide range of applications.