Try to reuse existing chunks #3983

Merged · 1 commit · Dec 12, 2024
28 changes: 18 additions & 10 deletions api/apps/document_app.py
@@ -21,19 +21,25 @@
from flask import request
from flask_login import login_required, current_user

from api.db.db_models import Task, File
from deepdoc.parser.html_parser import RAGFlowHtmlParser
from rag.nlp import search

from api.db import FileType, TaskStatus, ParserType, FileSource
from api.db.db_models import File, Task
from api.db.services.file2document_service import File2DocumentService
from api.db.services.file_service import FileService
from api.db.services.task_service import TaskService, queue_tasks
from api.db.services.task_service import queue_tasks
from api.db.services.user_service import UserTenantService
from deepdoc.parser.html_parser import RAGFlowHtmlParser
from rag.nlp import search
from api.db.services import duplicate_name
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
from api.utils import get_uuid
from api.db import FileType, TaskStatus, ParserType, FileSource
from api.db.services.task_service import TaskService
from api.db.services.document_service import DocumentService, doc_upload_and_parse
from api.utils.api_utils import (
server_error_response,
get_data_error_result,
validate_request,
)
from api.utils import get_uuid
from api import settings
from api.utils.api_utils import get_json_result
from rag.utils.storage_factory import STORAGE_IMPL
@@ -316,6 +322,7 @@ def rm():

b, n = File2DocumentService.get_storage_address(doc_id=doc_id)

TaskService.filter_delete([Task.doc_id == doc_id])
if not DocumentService.remove_document(doc, tenant_id):
return get_data_error_result(
message="Database error (Document removal)!")
@@ -361,11 +368,12 @@ def run():
e, doc = DocumentService.get_by_id(id)
if not e:
return get_data_error_result(message="Document not found!")
if settings.docStoreConn.indexExist(search.index_name(tenant_id), doc.kb_id):
settings.docStoreConn.delete({"doc_id": id}, search.index_name(tenant_id), doc.kb_id)
if req.get("delete", False):
TaskService.filter_delete([Task.doc_id == id])
if settings.docStoreConn.indexExist(search.index_name(tenant_id), doc.kb_id):
settings.docStoreConn.delete({"doc_id": id}, search.index_name(tenant_id), doc.kb_id)

if str(req["run"]) == TaskStatus.RUNNING.value:
TaskService.filter_delete([Task.doc_id == id])
e, doc = DocumentService.get_by_id(id)
doc = doc.to_dict()
doc["tenant_id"] = tenant_id
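With the change above, `run()` only drops previous tasks and indexed chunks when the request explicitly asks for it (`delete: true`); otherwise earlier tasks stay in the database so their chunks can be reused. A hedged sketch of triggering a re-parse through this handler — the route path and the `doc_ids` field are assumptions from the surrounding app; only `run` and `delete` are visible in the hunk:

```python
# Hedged sketch: re-run parsing without discarding reusable chunks.
# The URL "http://localhost:9380/v1/document/run" and the "doc_ids" field are
# assumptions; they are not shown in this diff.
import requests

resp = requests.post(
    "http://localhost:9380/v1/document/run",
    headers={"Authorization": "Bearer <api-key>"},  # placeholder credentials
    json={
        "doc_ids": ["<doc-id>"],
        "run": "1",       # TaskStatus.RUNNING: start (re)parsing
        "delete": False,  # keep previous tasks so unchanged pages can be reused
    },
)
print(resp.status_code, resp.json())
```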
14 changes: 14 additions & 0 deletions api/db/db_models.py
@@ -855,6 +855,8 @@ class Task(DataBaseModel):
help_text="process message",
default="")
retry_count = IntegerField(default=0)
digest = TextField(null=True, help_text="task digest", default="")
chunk_ids = LongTextField(null=True, help_text="chunk ids", default="")


class Dialog(DataBaseModel):
@@ -1090,4 +1092,16 @@ def migrate_db():
)
except Exception:
pass
try:
migrate(
migrator.add_column("task", "digest", TextField(null=True, help_text="task digest", default=""))
)
except Exception:
pass

try:
migrate(
migrator.add_column("task", "chunk_ids", LongTextField(null=True, help_text="chunk ids", default=""))
)
except Exception:
pass
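For context, both new columns are plain strings: `digest` stores a hex xxhash64 of the document's chunking configuration plus the task's page range, and `chunk_ids` stores the ids of the chunks a task wrote to the doc store as one space-separated string. A minimal sketch of how `queue_tasks()` (later in this PR) reads them back:

```python
# Minimal sketch mirroring queue_tasks() below: collect chunk ids recorded by
# earlier tasks of the same document so stale ones can be deleted from the doc store.
prev_tasks = TaskService.get_tasks(doc["id"]) or []  # rows include "digest" and "chunk_ids"
stale_chunk_ids = []
for prev in prev_tasks:
    if prev["chunk_ids"]:                             # space-separated string, may be empty
        stale_chunk_ids.extend(prev["chunk_ids"].split())
```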
25 changes: 25 additions & 0 deletions api/db/services/document_service.py
@@ -282,6 +282,31 @@ def get_embd_id(cls, doc_id):
return
return docs[0]["embd_id"]

@classmethod
@DB.connection_context()
def get_chunking_config(cls, doc_id):
configs = (
cls.model.select(
cls.model.id,
cls.model.kb_id,
cls.model.parser_id,
cls.model.parser_config,
Knowledgebase.language,
Knowledgebase.embd_id,
Tenant.id.alias("tenant_id"),
Tenant.img2txt_id,
Tenant.asr_id,
Tenant.llm_id,
)
.join(Knowledgebase, on=(cls.model.kb_id == Knowledgebase.id))
.join(Tenant, on=(Knowledgebase.tenant_id == Tenant.id))
.where(cls.model.id == doc_id)
)
configs = configs.dicts()
if not configs:
return None
return configs[0]

@classmethod
@DB.connection_context()
def get_doc_id_by_doc_name(cls, doc_name):
100 changes: 86 additions & 14 deletions api/db/services/task_service.py
@@ -15,6 +15,8 @@
#
import os
import random
import xxhash
import bisect

from api.db.db_utils import bulk_insert_into_db
from deepdoc.parser import PdfParser
@@ -29,7 +31,21 @@
from rag.settings import SVR_QUEUE_NAME
from rag.utils.storage_factory import STORAGE_IMPL
from rag.utils.redis_conn import REDIS_CONN
from api import settings
from rag.nlp import search

def trim_header_by_lines(text: str, max_length) -> str:
    # Keep only the newest (last) lines so the accumulated text stays under max_length.
    if len(text) <= max_length:
        return text
    lines = text.split("\n")
    total = 0
    idx = len(lines) - 1
    for i in range(len(lines) - 1, -1, -1):
        if total + len(lines[i]) > max_length:
            break
        idx = i
        total += len(lines[i])
    return "\n".join(lines[idx:])

class TaskService(CommonService):
model = Task
@@ -87,6 +103,30 @@ def get_task(cls, task_id):

return docs[0]

@classmethod
@DB.connection_context()
def get_tasks(cls, doc_id: str):
fields = [
cls.model.id,
cls.model.from_page,
cls.model.progress,
cls.model.digest,
cls.model.chunk_ids,
]
tasks = (
cls.model.select(*fields).order_by(cls.model.from_page.asc(), cls.model.create_time.desc())
.where(cls.model.doc_id == doc_id)
)
tasks = list(tasks.dicts())
if not tasks:
return None
return tasks

@classmethod
@DB.connection_context()
def update_chunk_ids(cls, id: str, chunk_ids: str):
cls.model.update(chunk_ids=chunk_ids).where(cls.model.id == id).execute()

@classmethod
@DB.connection_context()
def get_ongoing_doc_name(cls):
@@ -133,22 +173,18 @@ def get_ongoing_doc_name(cls):
@classmethod
@DB.connection_context()
def do_cancel(cls, id):
try:
task = cls.model.get_by_id(id)
_, doc = DocumentService.get_by_id(task.doc_id)
return doc.run == TaskStatus.CANCEL.value or doc.progress < 0
except Exception:
pass
return False
task = cls.model.get_by_id(id)
_, doc = DocumentService.get_by_id(task.doc_id)
return doc.run == TaskStatus.CANCEL.value or doc.progress < 0

@classmethod
@DB.connection_context()
def update_progress(cls, id, info):
if os.environ.get("MACOS"):
if info["progress_msg"]:
cls.model.update(
progress_msg=cls.model.progress_msg + "\n" + info["progress_msg"]
).where(cls.model.id == id).execute()
task = cls.model.get_by_id(id)
progress_msg = trim_header_by_lines(task.progress_msg + "\n" + info["progress_msg"], 10000)
cls.model.update(progress_msg=progress_msg).where(cls.model.id == id).execute()
if "progress" in info:
cls.model.update(progress=info["progress"]).where(
cls.model.id == id
@@ -157,9 +193,9 @@ def update_progress(cls, id, info):

with DB.lock("update_progress", -1):
if info["progress_msg"]:
cls.model.update(
progress_msg=cls.model.progress_msg + "\n" + info["progress_msg"]
).where(cls.model.id == id).execute()
task = cls.model.get_by_id(id)
progress_msg = trim_header_by_lines(task.progress_msg + "\n" + info["progress_msg"], 10000)
cls.model.update(progress_msg=progress_msg).where(cls.model.id == id).execute()
if "progress" in info:
cls.model.update(progress=info["progress"]).where(
cls.model.id == id
@@ -168,7 +204,7 @@

def queue_tasks(doc: dict, bucket: str, name: str):
def new_task():
return {"id": get_uuid(), "doc_id": doc["id"]}
return {"id": get_uuid(), "doc_id": doc["id"], "progress": 0.0}

tsks = []

@@ -203,10 +239,46 @@ def new_task():
else:
tsks.append(new_task())

chunking_config = DocumentService.get_chunking_config(doc["id"])
for task in tsks:
hasher = xxhash.xxh64()
for field in sorted(chunking_config.keys()):
hasher.update(str(chunking_config[field]).encode("utf-8"))
for field in ["doc_id", "from_page", "to_page"]:
hasher.update(str(task.get(field, "")).encode("utf-8"))
task_digest = hasher.hexdigest()
task["digest"] = task_digest
task["progress"] = 0.0

prev_tasks = TaskService.get_tasks(doc["id"])
if prev_tasks:
for task in tsks:
reuse_prev_task_chunks(task, prev_tasks, chunking_config)
TaskService.filter_delete([Task.doc_id == doc["id"]])
chunk_ids = []
for task in prev_tasks:
if task["chunk_ids"]:
chunk_ids.extend(task["chunk_ids"].split())
if chunk_ids:
settings.docStoreConn.delete({"id": chunk_ids}, search.index_name(chunking_config["tenant_id"]), chunking_config["kb_id"])

bulk_insert_into_db(Task, tsks, True)
DocumentService.begin2parse(doc["id"])

tsks = [task for task in tsks if task["progress"] < 1.0]
for t in tsks:
assert REDIS_CONN.queue_product(
SVR_QUEUE_NAME, message=t
), "Can't access Redis. Please check the Redis' status."

def reuse_prev_task_chunks(task: dict, prev_tasks: list[dict], chunking_config: dict):
idx = bisect.bisect_left(prev_tasks, task["from_page"], key=lambda x: x["from_page"])
if idx >= len(prev_tasks):
return
prev_task = prev_tasks[idx]
if prev_task["progress"] < 1.0 or prev_task["digest"] != task["digest"] or not prev_task["chunk_ids"]:
return
task["chunk_ids"] = prev_task["chunk_ids"]
task["progress"] = 1.0
task["progress_msg"] = f"Page({task['from_page']}~{task['to_page']}): reused previous task's chunks"
prev_task["chunk_ids"] = ""