Skip to content

Commit

Permalink
RAG langchain_rag_chroma_preload
Browse files Browse the repository at this point in the history
  • Loading branch information
weitsung50110 committed Dec 3, 2024
1 parent ebfddc9 commit 0305161
Show file tree
Hide file tree
Showing 7 changed files with 214 additions and 0 deletions.
100 changes: 100 additions & 0 deletions RAG/langchain_rag_chroma_manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
import os
import shutil
from langchain_chroma import Chroma
from langchain.schema import Document
from langchain_community.embeddings import OllamaEmbeddings

# 向量資料庫存儲目錄
persist_directory = "chroma_vectorstore_nba"

# 初始化嵌入模型
embeddings = OllamaEmbeddings(model="kenneth85/llama-3-taiwan:8b-instruct")

# 初始化向量數據庫
if os.path.exists(persist_directory):
vectorstore = Chroma(persist_directory=persist_directory, embedding_function=embeddings)
else:
os.makedirs(persist_directory)
vectorstore = Chroma(persist_directory=persist_directory, embedding_function=embeddings)

# 查看所有數據
def view_all():
documents = vectorstore._collection.get(include=["metadatas", "documents"])
if not documents["documents"]:
print("資料庫為空!")
else:
for doc, metadata in zip(documents["documents"], documents["metadatas"]):
print(f"內容: {doc}, 元數據: {metadata}")

# 新增數據
def add_data():
print("輸入新增的數據 (格式: 玩家名, 場均得分, 助攻, 球隊)")
data = input("輸入數據: ").split(",")
if len(data) != 4:
print("輸入格式錯誤!請重新嘗試。")
return
player, points_per_game, assists_per_game, team = data
document = Document(
page_content=player.strip(),
metadata={
"player": player.strip(),
"points_per_game": float(points_per_game.strip()),
"assists_per_game": float(assists_per_game.strip()),
"team": team.strip(),
},
)
vectorstore.add_documents([document])
print(f"新增數據成功!玩家: {player.strip()}")

# 刪除所有數據
def delete_all():
confirm = input("確定要刪除所有數據嗎?(yes/no): ")
if confirm.lower() == "yes":
shutil.rmtree(persist_directory)
os.makedirs(persist_directory)
global vectorstore
vectorstore = Chroma(persist_directory=persist_directory, embedding_function=embeddings)
print("所有數據已刪除!")
else:
print("操作已取消。")

# 查詢相關數據
def query_data():
query = input("輸入查詢內容: ")
retriever = vectorstore.as_retriever()
results = retriever.get_relevant_documents(query)
if results:
print("---檢索到的相關數據:")
for res in results:
print(f"球員: {res.metadata['player']}, 場均得分: {res.metadata['points_per_game']}, 助攻: {res.metadata['assists_per_game']}, 球隊: {res.metadata['team']}")
else:
print("未找到相關數據!")

# 主程序
def main():
print("Chroma 向量數據庫管理系統")
print("功能列表:")
print("1. 查看所有數據 (view_all)")
print("2. 新增數據 (add_data)")
print("3. 刪除所有數據 (delete_all)")
print("4. 查詢相關數據 (query_data)")
print("輸入 'exit' 退出系統")

while True:
command = input("輸入功能指令: ").strip()
if command == "exit":
print("退出系統。")
break
elif command == "view_all":
view_all()
elif command == "add_data":
add_data()
elif command == "delete_all":
delete_all()
elif command == "query_data":
query_data()
else:
print("未知指令!請輸入正確的功能指令。")

if __name__ == "__main__":
main()
114 changes: 114 additions & 0 deletions RAG/langchain_rag_chroma_preload.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
import hashlib
import os
from langchain_community.llms import Ollama
from langchain_community.embeddings import OllamaEmbeddings
from langchain_chroma import Chroma
from langchain.chains import RetrievalQA
from langchain.callbacks.manager import CallbackManager
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.schema import Document

# 初始化 Callback Manager
callback_manager = CallbackManager([StreamingStdOutCallbackHandler()])

# 初始化 Ollama LLM
llm = Ollama(
model="kenneth85/llama-3-taiwan:8b-instruct",
callback_manager=callback_manager
)

# 初始化 Ollama Embeddings
embeddings = OllamaEmbeddings(model="kenneth85/llama-3-taiwan:8b-instruct")

# 向量資料庫存儲目錄
persist_directory = "chroma_vectorstore_nba"

# 範例 NBA 數據
nba_data = [
{"player": "好崴寶Weibert", "points_per_game": 30.1, "assists_per_game": 5.7, "team": "Weibert"},
{"player": "孟孟", "points_per_game": 29.7, "assists_per_game": 8.7, "team": "Mengbert"},
{"player": "崴崴Weiberson", "points_per_game": 25.1, "assists_per_game": 10.5, "team": "Weiberson"}
]

# 計算數據哈希
def calculate_data_hash(data):
hasher = hashlib.md5()
hasher.update(str(data).encode('utf-8'))
return hasher.hexdigest()

# 哈希文件存儲路徑
hash_file_path = os.path.join(persist_directory, "data_hash.txt")

# 檢查是否需要更新數據庫
def needs_update(data, hash_file_path):
new_hash = calculate_data_hash(data)
if not os.path.exists(hash_file_path):
return True, new_hash
with open(hash_file_path, "r") as f:
existing_hash = f.read().strip()
return new_hash != existing_hash, new_hash

# 檢查是否需要更新向量數據庫
update_required, new_hash = needs_update(nba_data, hash_file_path)

if os.path.exists(persist_directory) and not update_required:
print("---正在加載現有的向量數據庫...")
vectorstore = Chroma(persist_directory=persist_directory, embedding_function=embeddings)
else:
print("---數據已更新,正在重新生成數據庫...")

# 清理舊數據庫
if os.path.exists(persist_directory):
print("---正在刪除舊的向量數據庫...")
import shutil
shutil.rmtree(persist_directory)

# 步驟 1:生成嵌入並存儲數據
print("---正在生成向量數據庫...")

# 將數據轉換為 Document 對象
documents = [
Document(page_content=data["player"], metadata=data) for data in nba_data
]

# 創建 Chroma 向量數據庫
vectorstore = Chroma.from_documents(
documents=documents,
embedding=embeddings,
persist_directory=persist_directory # 啟用持久化
)

# 保存新的數據哈希值
with open(hash_file_path, "w") as f:
f.write(new_hash)
print("---向量數據庫已保存!")

# 步驟 2:建立檢索問答鏈
retriever = vectorstore.as_retriever()
qa_chain = RetrievalQA.from_chain_type(llm=llm, chain_type="stuff", retriever=retriever)

# 提供查詢功能
def query_system():
while True:
query = input("請輸入查詢內容(輸入 'bye' 退出):")
if query.lower() == "bye":
print("已退出查詢系統。")
break

# 檢索相關數據
results = retriever.get_relevant_documents(query)
if results:
print("---檢索到的相關數據:")
# 整理上下文
context = "\n".join(
[f"球員: {res.metadata['player']}, 場均得分: {res.metadata['points_per_game']}, 助攻: {res.metadata['assists_per_game']}, 球隊: {res.metadata['team']}" for res in results]
)
print(context) # 檢查上下文
# 傳遞給 LLM
prompt = f"以下是相關數據:\n{context}\n\n根據以上數據,回答以下問題:{query}"
response = qa_chain.invoke(prompt)
print(f"回答:{response}")
else:
print("未找到相關數據,請嘗試其他查詢。")

query_system()
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.

0 comments on commit 0305161

Please sign in to comment.