Multimodal retrieval-augmented generation (RAG) is transforming how AI applications handle complex information by merging retrieval and generation capabilities across diverse data types, such as text, images, and video.
Unlike traditional RAG, which typically focuses on text-based retrieval and generation, multimodal RAG systems can pull in relevant content from both text and visual sources to generate more contextually rich, comprehensive responses.
This article, the seventh installment in our Building Multimodal RAG Applications series, dives into building multimodal RAG systems with LangChain.
We will wrap the modules built in the previous articles into LangChain chains using the RunnableParallel, RunnablePassthrough, and RunnableLambda primitives.
This is the seventh article in the ongoing Building Multimodal RAG Applications series:
Multimodal RAG with Multimodal LangChain (You are here!)
Putting it All Together! Building Multimodal RAG Application (Coming soon!)
You can find the code and datasets used in this series in this GitHub Repo.
Table of Contents:
Setting Up Working Environment
Invoke the Multimodal RAG System with a Query
Multimodal RAG System Showing Retrieved Image/Frame
1. Setting Up Working Environment
The first step, as usual, is to set up our working environment. We will start by installing predictionguard, through which we will access our Large Vision Language Model (LVLM).
!pip install predictionguard
Next, we will define the get_prediction_guard_api_key function to get the API key for the Prediction Guard API.
import os
from dotenv import load_dotenv

def get_prediction_guard_api_key():
    # Load variables from a local .env file into the environment, if one exists
    load_dotenv()
    PREDICTION_GUARD_API_KEY = os.getenv("PREDICTION_GUARD_API_KEY", None)
    # Fall back to prompting the user when the key is not set
    if PREDICTION_GUARD_API_KEY is None:
        PREDICTION_GUARD_API_KEY = input("Please enter your Prediction Guard API Key: ")
    return PREDICTION_GUARD_API_KEY
Then we are going to import and use the following primitives from LangChain:
The RunnableParallel primitive is essentially a dict whose values are runnables (or things that can be coerced to runnables, like functions). It runs all of its values in parallel, and each value is called with the overall input of the RunnableParallel. The final return value is a dict with the results of each value under its appropriate key.
The RunnablePassthrough allows you to pass inputs unchanged. It is typically used in conjunction with RunnableParallel to pass data through to a new key in the map.
The RunnableLambda converts a Python function into a Runnable. Wrapping a function in a RunnableLambda makes the function usable within either a sync or async context.
import lancedb
from PIL import Image
from langchain_core.runnables import (
    RunnableParallel,
    RunnablePassthrough,
    RunnableLambda
)
Finally, we will combine the modules from the previous articles into a chain to create a Multimodal RAG system.
# combine all the modules into a chain
# to create Multimodal RAG system
mm_rag_chain = (
    RunnableParallel({
        "retrieved_results": retriever_module,
        "user_query": RunnablePassthrough()
    })
    | prompt_processing_module
    | lvlm_inference_module
)
Now we are ready to invoke our multimodal RAG system with queries and get answers.