Skip to content

Commit

Permalink
Try to reuse existing chunks. Close #3793
Browse files Browse the repository at this point in the history
  • Loading branch information
yuzhichang committed Dec 12, 2024
1 parent 835fd7a commit b94bd8d
Show file tree
Hide file tree
Showing 7 changed files with 240 additions and 83 deletions.
28 changes: 18 additions & 10 deletions api/apps/document_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,19 +21,25 @@
from flask import request
from flask_login import login_required, current_user

from api.db.db_models import Task, File
from deepdoc.parser.html_parser import RAGFlowHtmlParser
from rag.nlp import search

from api.db import FileType, TaskStatus, ParserType, FileSource
from api.db.db_models import File, Task
from api.db.services.file2document_service import File2DocumentService
from api.db.services.file_service import FileService
from api.db.services.task_service import TaskService, queue_tasks
from api.db.services.task_service import queue_tasks
from api.db.services.user_service import UserTenantService
from deepdoc.parser.html_parser import RAGFlowHtmlParser
from rag.nlp import search
from api.db.services import duplicate_name
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
from api.utils import get_uuid
from api.db import FileType, TaskStatus, ParserType, FileSource
from api.db.services.task_service import TaskService
from api.db.services.document_service import DocumentService, doc_upload_and_parse
from api.utils.api_utils import (
server_error_response,
get_data_error_result,
validate_request,
)
from api.utils import get_uuid
from api import settings
from api.utils.api_utils import get_json_result
from rag.utils.storage_factory import STORAGE_IMPL
Expand Down Expand Up @@ -316,6 +322,7 @@ def rm():

b, n = File2DocumentService.get_storage_address(doc_id=doc_id)

TaskService.filter_delete([Task.doc_id == doc_id])
if not DocumentService.remove_document(doc, tenant_id):
return get_data_error_result(
message="Database error (Document removal)!")
Expand Down Expand Up @@ -361,11 +368,12 @@ def run():
e, doc = DocumentService.get_by_id(id)
if not e:
return get_data_error_result(message="Document not found!")
if settings.docStoreConn.indexExist(search.index_name(tenant_id), doc.kb_id):
settings.docStoreConn.delete({"doc_id": id}, search.index_name(tenant_id), doc.kb_id)
if req.get("delete", False):
TaskService.filter_delete([Task.doc_id == id])
if settings.docStoreConn.indexExist(search.index_name(tenant_id), doc.kb_id):
settings.docStoreConn.delete({"doc_id": id}, search.index_name(tenant_id), doc.kb_id)

if str(req["run"]) == TaskStatus.RUNNING.value:
TaskService.filter_delete([Task.doc_id == id])
e, doc = DocumentService.get_by_id(id)
doc = doc.to_dict()
doc["tenant_id"] = tenant_id
Expand Down
14 changes: 14 additions & 0 deletions api/db/db_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -855,6 +855,8 @@ class Task(DataBaseModel):
help_text="process message",
default="")
retry_count = IntegerField(default=0)
digest = TextField(null=True, help_text="task digest", default="")
chunk_ids = LongTextField(null=True, help_text="chunk ids", default="")


class Dialog(DataBaseModel):
Expand Down Expand Up @@ -1090,4 +1092,16 @@ def migrate_db():
)
except Exception:
pass
try:
migrate(
migrator.add_column("task", "digest", TextField(null=True, help_text="task digest", default=""))
)
except Exception:
pass

try:
migrate(
migrator.add_column("task", "chunk_ids", LongTextField(null=True, help_text="chunk ids", default=""))
)
except Exception:
pass
25 changes: 25 additions & 0 deletions api/db/services/document_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,31 @@ def get_embd_id(cls, doc_id):
return
return docs[0]["embd_id"]

@classmethod
@DB.connection_context()
def get_chunking_config(cls, doc_id):
configs = (
cls.model.select(
cls.model.id,
cls.model.kb_id,
cls.model.parser_id,
cls.model.parser_config,
Knowledgebase.language,
Knowledgebase.embd_id,
Tenant.id.alias("tenant_id"),
Tenant.img2txt_id,
Tenant.asr_id,
Tenant.llm_id,
)
.join(Knowledgebase, on=(cls.model.kb_id == Knowledgebase.id))
.join(Tenant, on=(Knowledgebase.tenant_id == Tenant.id))
.where(cls.model.id == doc_id)
)
configs = configs.dicts()
if not configs:
return None
return configs[0]

@classmethod
@DB.connection_context()
def get_doc_id_by_doc_name(cls, doc_name):
Expand Down
100 changes: 86 additions & 14 deletions api/db/services/task_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
#
import os
import random
import xxhash
import bisect

from api.db.db_utils import bulk_insert_into_db
from deepdoc.parser import PdfParser
Expand All @@ -29,7 +31,21 @@
from rag.settings import SVR_QUEUE_NAME
from rag.utils.storage_factory import STORAGE_IMPL
from rag.utils.redis_conn import REDIS_CONN
from api import settings
from rag.nlp import search

def trim_header_by_lines(text: str, max_length) -> str:
if len(text) <= max_length:
return text
lines = text.split("\n")
total = 0
idx = len(lines) - 1
for i in range(len(lines)-1, -1, -1):
if total + len(lines[i]) > max_length:
break
idx = i
text2 = "\n".join(lines[idx:])
return text2

class TaskService(CommonService):
model = Task
Expand Down Expand Up @@ -87,6 +103,30 @@ def get_task(cls, task_id):

return docs[0]

@classmethod
@DB.connection_context()
def get_tasks(cls, doc_id: str):
fields = [
cls.model.id,
cls.model.from_page,
cls.model.progress,
cls.model.digest,
cls.model.chunk_ids,
]
tasks = (
cls.model.select(*fields).order_by(cls.model.from_page.asc(), cls.model.create_time.desc())
.where(cls.model.doc_id == doc_id)
)
tasks = list(tasks.dicts())
if not tasks:
return None
return tasks

@classmethod
@DB.connection_context()
def update_chunk_ids(cls, id: str, chunk_ids: str):
cls.model.update(chunk_ids=chunk_ids).where(cls.model.id == id).execute()

@classmethod
@DB.connection_context()
def get_ongoing_doc_name(cls):
Expand Down Expand Up @@ -133,22 +173,18 @@ def get_ongoing_doc_name(cls):
@classmethod
@DB.connection_context()
def do_cancel(cls, id):
try:
task = cls.model.get_by_id(id)
_, doc = DocumentService.get_by_id(task.doc_id)
return doc.run == TaskStatus.CANCEL.value or doc.progress < 0
except Exception:
pass
return False
task = cls.model.get_by_id(id)
_, doc = DocumentService.get_by_id(task.doc_id)
return doc.run == TaskStatus.CANCEL.value or doc.progress < 0

@classmethod
@DB.connection_context()
def update_progress(cls, id, info):
if os.environ.get("MACOS"):
if info["progress_msg"]:
cls.model.update(
progress_msg=cls.model.progress_msg + "\n" + info["progress_msg"]
).where(cls.model.id == id).execute()
task = cls.model.get_by_id(id)
progress_msg = trim_header_by_lines(task.progress_msg + "\n" + info["progress_msg"], 10000)
cls.model.update(progress_msg=progress_msg).where(cls.model.id == id).execute()
if "progress" in info:
cls.model.update(progress=info["progress"]).where(
cls.model.id == id
Expand All @@ -157,9 +193,9 @@ def update_progress(cls, id, info):

with DB.lock("update_progress", -1):
if info["progress_msg"]:
cls.model.update(
progress_msg=cls.model.progress_msg + "\n" + info["progress_msg"]
).where(cls.model.id == id).execute()
task = cls.model.get_by_id(id)
progress_msg = trim_header_by_lines(task.progress_msg + "\n" + info["progress_msg"], 10000)
cls.model.update(progress_msg=progress_msg).where(cls.model.id == id).execute()
if "progress" in info:
cls.model.update(progress=info["progress"]).where(
cls.model.id == id
Expand All @@ -168,7 +204,7 @@ def update_progress(cls, id, info):

def queue_tasks(doc: dict, bucket: str, name: str):
def new_task():
return {"id": get_uuid(), "doc_id": doc["id"]}
return {"id": get_uuid(), "doc_id": doc["id"], "progress": 0.0}

tsks = []

Expand Down Expand Up @@ -203,10 +239,46 @@ def new_task():
else:
tsks.append(new_task())

chunking_config = DocumentService.get_chunking_config(doc["id"])
for task in tsks:
hasher = xxhash.xxh64()
for field in sorted(chunking_config.keys()):
hasher.update(str(chunking_config[field]).encode("utf-8"))
for field in ["doc_id", "from_page", "to_page"]:
hasher.update(str(task.get(field, "")).encode("utf-8"))
task_digest = hasher.hexdigest()
task["digest"] = task_digest
task["progress"] = 0.0

prev_tasks = TaskService.get_tasks(doc["id"])
if prev_tasks:
for task in tsks:
reuse_prev_task_chunks(task, prev_tasks, chunking_config)
TaskService.filter_delete([Task.doc_id == doc["id"]])
chunk_ids = []
for task in prev_tasks:
if task["chunk_ids"]:
chunk_ids.extend(task["chunk_ids"].split())
if chunk_ids:
settings.docStoreConn.delete({"id": chunk_ids}, search.index_name(chunking_config["tenant_id"]), chunking_config["kb_id"])

bulk_insert_into_db(Task, tsks, True)
DocumentService.begin2parse(doc["id"])

tsks = [task for task in tsks if task["progress"] < 1.0]
for t in tsks:
assert REDIS_CONN.queue_product(
SVR_QUEUE_NAME, message=t
), "Can't access Redis. Please check the Redis' status."

def reuse_prev_task_chunks(task: dict, prev_tasks: list[dict], chunking_config: dict):
idx = bisect.bisect_left(prev_tasks, task["from_page"], key=lambda x: x["from_page"])
if idx >= len(prev_tasks):
return
prev_task = prev_tasks[idx]
if prev_task["progress"] < 1.0 or prev_task["digest"] != task["digest"] or not prev_task["chunk_ids"]:
return
task["chunk_ids"] = prev_task["chunk_ids"]
task["progress"] = 1.0
task["progress_msg"] = f"Page({task['from_page']}~{task['to_page']}): reused previous task's chunks"
prev_task["chunk_ids"] = ""
Loading

0 comments on commit b94bd8d

Please sign in to comment.