# chatbot.py
import os
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from langchain_chroma import Chroma
from langchain.memory import ConversationBufferMemory
from langchain.prompts import ChatPromptTemplate
from langchain.chains import ConversationalRetrievalChain
from langchain.retrievers import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import LLMChainExtractor
# Get the API key from the environment variable
openai_api_key = os.getenv("OPENAI_API_KEY")
if not openai_api_key:
    raise ValueError("OPENAI_API_KEY environment variable is not set")
# Initialize the OpenAI chat model
llm = ChatOpenAI(
    model_name="gpt-4o-mini",
    temperature=0.7,
    max_tokens=500,
    api_key=openai_api_key,  # Explicitly pass the API key here
)
# Initialize memory for conversation
memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)
# Create embedding model
embeddings = OpenAIEmbeddings(api_key=openai_api_key) # Also pass the API key here
# Initialize Chroma (vector database) with local persistence
vector_db = Chroma(
    embedding_function=embeddings,
    collection_name="my_collection",
    persist_directory="./my_chroma_db",
)
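# Note: retrieval only returns results once the collection holds documents.
# A minimal ingestion sketch (the sample texts below are placeholders for
# illustration, not part of the original script):
# vector_db.add_texts([
#     "LangChain is a framework for building applications with LLMs.",
#     "Chroma is an open-source embedding database.",
# ])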
# Create a ContextualCompressionRetriever: the LLMChainExtractor uses the LLM
# to keep only the passages of each retrieved document relevant to the query
retriever = ContextualCompressionRetriever(
    base_compressor=LLMChainExtractor.from_llm(llm),
    base_retriever=vector_db.as_retriever(),
)
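# Optional sanity check for the retriever on its own (a sketch; assumes the
# collection above has already been populated):
# relevant_docs = retriever.invoke("What is LangChain?")
# print(relevant_docs)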
# Prompt template that combines the retrieved context, the chat history,
# and the user's question
prompt_template = ChatPromptTemplate.from_template("""
Context: {context}
Chat History: {chat_history}
Human: {question}
AI: Please provide a relevant answer based on the context and chat history.
""")
# Create the ConversationalRetrievalChain
conversation_chain = ConversationalRetrievalChain.from_llm(
    llm=llm,
    retriever=retriever,
    memory=memory,
    combine_docs_chain_kwargs={"prompt": prompt_template},
)
def chatbot_response(user_input):
    # Run the chain on the user's question and return the generated answer
    return conversation_chain.invoke({"question": user_input})["answer"]
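# Example usage: a minimal interactive loop (a sketch, not part of the
# original script; assumes the vector store has been populated).
if __name__ == "__main__":
    print("Type 'quit' to exit.")
    while True:
        question = input("You: ")
        if question.strip().lower() == "quit":
            break
        print("Bot:", chatbot_response(question))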