-
Notifications
You must be signed in to change notification settings - Fork 0
/
web_interface.py
102 lines (85 loc) · 3.62 KB
/
web_interface.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
import argparse
import os
import gradio as gr
import shutil
import uuid
from db_image_ingestion import add_to_chroma, clear_database, process_images
from document_preprocessor import load_documents
from text_chunker import split_documents
from query_processor import query_rag
from config import DATA_PATH, CHROMA_PATH
from langchain.schema import Document
class WebInterface:
    """Gradio front end with a "Train" tab and a "Test" tab for the RAG system.

    The train tab hands uploaded files to :meth:`train_model`; the test tab
    sends a free-text query to :meth:`test_model`.
    """

    def __init__(self):
        """Bind the data directory and Chroma path taken from ``config``."""
        self.data_path = DATA_PATH
        self.chroma_path = CHROMA_PATH

    def reset_database(self):
        """Announce the reset on stdout and delegate to ``clear_database``."""
        print("✨ Clearing Database")
        clear_database()

    def _stage_uploads(self, file_paths):
        """Copy uploaded files into ``self.data_path`` and return the copies.

        Args:
            file_paths: ``None``/empty when nothing was uploaded, a single
                path, or a list/tuple of paths. Items that are not path-like
                are expected to expose their location as ``.name``.

        Returns:
            List of destination paths inside ``self.data_path``.
        """
        if not file_paths:
            return []
        if not isinstance(file_paths, (list, tuple)):
            file_paths = [file_paths]

        staged = []
        for item in file_paths:
            # Path-like values are used directly; anything else is treated as
            # a file wrapper carrying its path in ``.name``.
            if isinstance(item, (str, os.PathLike)):
                source = os.fspath(item)
            else:
                source = item.name
            target = os.path.join(self.data_path, os.path.basename(source))
            # Skip the copy when the upload already lives in the data
            # directory; shutil would raise SameFileError otherwise.
            if os.path.abspath(source) != os.path.abspath(target):
                shutil.copy2(source, target)
            staged.append(target)
        return staged

    def train_model(self, file_paths):
        """Stage the uploads, then index every document in ``self.data_path``.

        Previously ``file_paths`` was ignored, so files uploaded through the
        UI never reached the data directory that ``load_documents`` reads.

        Args:
            file_paths: Value delivered by the upload widget (``None``, one
                path, or a list of paths). With no upload, whatever already
                sits in ``self.data_path`` is indexed.

        Returns:
            A status message. Failures are reported in the returned text
            instead of being raised, so the UI always has something to show.
        """
        try:
            os.makedirs(self.data_path, exist_ok=True)
            self._stage_uploads(file_paths)

            documents = load_documents(self.data_path)

            # Give every document an id before splitting so the chunks can
            # carry it along in their metadata.
            for doc in documents:
                if "id" not in doc.metadata:
                    doc.metadata["id"] = str(uuid.uuid4())

            chunks = split_documents(documents)

            if add_to_chroma(chunks):
                return "Training completed successfully."
            return "Error occurred during database addition."
        except Exception as e:
            return f"An error occurred during training: {e}"

    def test_model(self, query):
        """Run ``query`` through ``query_rag`` and format the answer for display.

        Args:
            query: Text entered in the query box.

        Returns:
            ``"Response: ...\\n\\nSources: ..."`` on success, or an error
            message if anything raised.
        """
        try:
            result = query_rag(query)
            # NOTE(review): "Response" is capitalised while "sources" is not;
            # confirm both keys against what query_rag actually returns.
            response_text = result.get("Response", "No response generated.")
            sources = result.get("sources", [])
            return f"Response: {response_text}\n\nSources: {sources}"
        except Exception as e:
            return f"An error occurred during testing: {e}"

    def create_interface(self):
        """Build the two Gradio pages and return them as a tabbed interface."""
        # Train page: file upload + button on the left, status box on the right.
        with gr.Blocks() as train_interface:
            gr.Markdown("# Train Model")
            with gr.Row():
                with gr.Column():
                    file_upload = gr.File(
                        label="Upload Training Data",
                        file_count="multiple",
                        file_types=[".pdf", ".png", ".jpg", ".jpeg"],
                        type="filepath",
                    )
                    train_button = gr.Button("Train")
                with gr.Column():
                    train_response = gr.Textbox(label="Training Response", lines=5)
            train_button.click(
                self.train_model, inputs=file_upload, outputs=train_response
            )

        # Test page: query box + button on the left, answer box on the right.
        with gr.Blocks() as test_interface:
            gr.Markdown("# Test Model")
            with gr.Row():
                with gr.Column():
                    query_input = gr.Textbox(label="Enter Query", lines=2)
                    test_button = gr.Button("Test")
                with gr.Column():
                    test_response = gr.Textbox(label="Test Response", lines=5)
            test_button.click(
                self.test_model, inputs=query_input, outputs=test_response
            )

        return gr.TabbedInterface([train_interface, test_interface], ["Train", "Test"])
def main():
    """Command-line entry point for the web UI.

    Accepts an optional ``--reset`` flag; when given, the database is
    cleared before the Gradio app is built and launched.
    """
    web_ui = WebInterface()

    arg_parser = argparse.ArgumentParser(description="RAG System Interface")
    arg_parser.add_argument("--reset", action="store_true", help="Reset the database.")
    cli_args = arg_parser.parse_args()

    if cli_args.reset:
        web_ui.reset_database()

    web_ui.create_interface().launch()


if __name__ == "__main__":
    main()