From 8497fb1372beda685150bbcbeed55f1c76fc5b0b Mon Sep 17 00:00:00 2001
From: Meng Zhang
Date: Fri, 6 Oct 2023 11:54:12 -0700
Subject: [PATCH] feat: implement /v1beta/search interface (#516)

* feat: implement /v1beta/search interface

* update

* update

* improve debugger
---
 Cargo.lock                           |   1 +
 crates/tabby-common/Cargo.toml       |   1 +
 crates/tabby-common/src/index.rs     |  20 ++++
 crates/tabby-common/src/lib.rs       |   1 +
 crates/tabby-scheduler/src/index.rs  |  15 +--
 crates/tabby/Cargo.toml              |   4 +-
 crates/tabby/src/main.rs             |   2 -
 crates/tabby/src/serve/mod.rs        |  17 +++-
 crates/tabby/src/serve/search.rs     | 144 +++++++++++++++++++++++++++
 experimental/scheduler/completion.py |  38 +++++++
 experimental/scheduler/search.py     |  20 +---
 11 files changed, 232 insertions(+), 31 deletions(-)
 create mode 100644 crates/tabby-common/src/index.rs
 create mode 100644 crates/tabby/src/serve/search.rs
 create mode 100644 experimental/scheduler/completion.py

diff --git a/Cargo.lock b/Cargo.lock
index b9bcb223c06a..3e3434b7ce02 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -3142,6 +3142,7 @@ dependencies = [
  "serde",
  "serde-jsonlines",
  "serdeconv",
+ "tantivy",
  "tokio",
  "uuid 1.4.1",
 ]
diff --git a/crates/tabby-common/Cargo.toml b/crates/tabby-common/Cargo.toml
index 794e3d958c87..0fb6329f342e 100644
--- a/crates/tabby-common/Cargo.toml
+++ b/crates/tabby-common/Cargo.toml
@@ -13,6 +13,7 @@ serde-jsonlines = { workspace = true }
 reqwest = { workspace = true, features = [ "json" ] }
 tokio = { workspace = true, features = ["rt", "macros"] }
 uuid = { version = "1.4.1", features = ["v4"] }
+tantivy.workspace = true
 
 [features]
 testutils = []
diff --git a/crates/tabby-common/src/index.rs b/crates/tabby-common/src/index.rs
new file mode 100644
index 000000000000..1d726b30ded1
--- /dev/null
+++ b/crates/tabby-common/src/index.rs
@@ -0,0 +1,20 @@
+use tantivy::{
+    tokenizer::{RegexTokenizer, RemoveLongFilter, TextAnalyzer},
+    Index,
+};
+
+pub trait IndexExt {
+    fn register_tokenizer(&self);
+}
+
+pub static CODE_TOKENIZER: &str = "code";
+
+impl IndexExt for Index {
+    fn register_tokenizer(&self) {
+        let code_tokenizer = TextAnalyzer::builder(RegexTokenizer::new(r"(?:\w+)").unwrap())
+            .filter(RemoveLongFilter::limit(128))
+            .build();
+
+        self.tokenizers().register(CODE_TOKENIZER, code_tokenizer);
+    }
+}
diff --git a/crates/tabby-common/src/lib.rs b/crates/tabby-common/src/lib.rs
index b338145fe9c2..ae52e5a171b1 100644
--- a/crates/tabby-common/src/lib.rs
+++ b/crates/tabby-common/src/lib.rs
@@ -1,5 +1,6 @@
 pub mod config;
 pub mod events;
+pub mod index;
 pub mod path;
 pub mod usage;
 
diff --git a/crates/tabby-scheduler/src/index.rs b/crates/tabby-scheduler/src/index.rs
index 4719fe57114e..feed15751713 100644
--- a/crates/tabby-scheduler/src/index.rs
+++ b/crates/tabby-scheduler/src/index.rs
@@ -1,12 +1,16 @@
 use std::fs;
 
 use anyhow::Result;
-use tabby_common::{config::Config, path::index_dir, SourceFile};
+use tabby_common::{
+    config::Config,
+    index::{IndexExt, CODE_TOKENIZER},
+    path::index_dir,
+    SourceFile,
+};
 use tantivy::{
     directory::MmapDirectory,
     doc,
     schema::{Schema, TextFieldIndexing, TextOptions, STORED, STRING},
-    tokenizer::{RegexTokenizer, RemoveLongFilter, TextAnalyzer},
     Index,
 };
 
@@ -18,7 +22,7 @@ pub fn index_repositories(_config: &Config) -> Result<()> {
     let mut builder = Schema::builder();
 
     let code_indexing_options = TextFieldIndexing::default()
-        .set_tokenizer("code")
+        .set_tokenizer(CODE_TOKENIZER)
         .set_index_option(tantivy::schema::IndexRecordOption::WithFreqsAndPositions);
     let code_options = TextOptions::default()
         .set_indexing_options(code_indexing_options)
@@ -36,11 +40,8 @@ pub fn index_repositories(_config: &Config) -> Result<()> {
     fs::create_dir_all(index_dir())?;
     let directory = MmapDirectory::open(index_dir())?;
     let index = Index::open_or_create(directory, schema)?;
-    let code_tokenizer = TextAnalyzer::builder(RegexTokenizer::new(r"(?:\w*)").unwrap())
-        .filter(RemoveLongFilter::limit(40))
-        .build();
+    index.register_tokenizer();
 
-    index.tokenizers().register("code", code_tokenizer);
     let mut writer = index.writer(10_000_000)?;
     writer.delete_all_documents()?;
 
diff --git a/crates/tabby/Cargo.toml b/crates/tabby/Cargo.toml
index ed2391d4b54b..b13bdabd245e 100644
--- a/crates/tabby/Cargo.toml
+++ b/crates/tabby/Cargo.toml
@@ -6,7 +6,7 @@ edition = "2021"
 [dependencies]
 ctranslate2-bindings = { path = "../ctranslate2-bindings" }
 tabby-common = { path = "../tabby-common" }
-tabby-scheduler = { path = "../tabby-scheduler", optional = true }
+tabby-scheduler = { path = "../tabby-scheduler" }
 tabby-download = { path = "../tabby-download" }
 tabby-inference = { path = "../tabby-inference" }
 axum = "0.6"
@@ -53,9 +53,7 @@ features = [
 ]
 
 [features]
-default = ["scheduler"]
 link_shared = ["ctranslate2-bindings/link_shared"]
-scheduler = ["tabby-scheduler"]
 
 [build-dependencies]
 vergen = { version = "8.0.0", features = ["build", "git", "gitcl"] }
diff --git a/crates/tabby/src/main.rs b/crates/tabby/src/main.rs
index 6fd37a2aad91..d5be128bda1b 100644
--- a/crates/tabby/src/main.rs
+++ b/crates/tabby/src/main.rs
@@ -32,7 +32,6 @@ pub enum Commands {
     Download(download::DownloadArgs),
 
     /// Run scheduler progress for cron jobs integrating external code repositories.
-    #[cfg(feature = "scheduler")]
     Scheduler(SchedulerArgs),
 }
 
@@ -53,7 +52,6 @@ async fn main() {
     match &cli.command {
         Commands::Serve(args) => serve::main(&config, args).await,
         Commands::Download(args) => download::main(args).await,
-        #[cfg(feature = "scheduler")]
         Commands::Scheduler(args) => tabby_scheduler::scheduler(args.now)
             .await
             .unwrap_or_else(|err| fatal!("Scheduler failed due to '{}'", err)),
diff --git a/crates/tabby/src/serve/mod.rs b/crates/tabby/src/serve/mod.rs
index 9e5696a1bef1..896bab7bbd96 100644
--- a/crates/tabby/src/serve/mod.rs
+++ b/crates/tabby/src/serve/mod.rs
@@ -4,6 +4,7 @@ mod engine;
 mod events;
 mod health;
 mod playground;
+mod search;
 
 use std::{
     net::{Ipv4Addr, SocketAddr},
@@ -28,6 +29,7 @@ use utoipa_swagger_ui::SwaggerUi;
 use self::{
     engine::{create_engine, EngineInfo},
     health::HealthState,
+    search::IndexServer,
 };
 use crate::fatal;
 
@@ -48,7 +50,7 @@ Install following IDE / Editor extensions to get started with [Tabby](https://gi
     servers(
         (url = "/", description = "Server"),
     ),
-    paths(events::log_event, completions::completions, chat::completions, health::health),
+    paths(events::log_event, completions::completions, chat::completions, health::health, search::search),
     components(schemas(
         events::LogEventRequest,
         completions::CompletionRequest,
@@ -90,6 +92,10 @@ pub struct ServeArgs {
     #[clap(long)]
     chat_model: Option<String>,
 
+    /// When set to `true`, the search API route will be enabled.
+    #[clap(long, default_value_t = false)]
+    enable_search: bool,
+
     #[clap(long, default_value_t = 8080)]
     port: u16,
 
@@ -195,6 +201,15 @@ fn api_router(args: &ServeArgs, config: &Config) -> Router {
             routing::post(completions::completions).with_state(completion_state),
         );
 
+    let router = if args.enable_search {
+        router.route(
+            "/v1beta/search",
+            routing::get(search::search).with_state(Arc::new(IndexServer::new())),
+        )
+    } else {
+        router
+    };
+
     let router = if let Some(chat_state) = chat_state {
         router.route(
             "/v1beta/chat/completions",
diff --git a/crates/tabby/src/serve/search.rs b/crates/tabby/src/serve/search.rs
new file mode 100644
index 000000000000..4aed1a74f2ca
--- /dev/null
+++ b/crates/tabby/src/serve/search.rs
@@ -0,0 +1,144 @@
+use std::sync::Arc;
+
+use anyhow::Result;
+use axum::{
+    extract::{Query, State},
+    Json,
+};
+use hyper::StatusCode;
+use serde::{Deserialize, Serialize};
+use tabby_common::{index::IndexExt, path};
+use tantivy::{
+    collector::{Count, TopDocs},
+    query::QueryParser,
+    schema::{Field, FieldType, NamedFieldDocument, Schema},
+    DocAddress, Document, Index, IndexReader, Score,
+};
+use tracing::instrument;
+use utoipa::IntoParams;
+
+#[derive(Deserialize, IntoParams)]
+pub struct SearchQuery {
+    #[param(default = "get")]
+    q: String,
+
+    #[param(default = 20)]
+    limit: Option<usize>,
+
+    #[param(default = 0)]
+    offset: Option<usize>,
+}
+
+#[derive(Serialize)]
+pub struct SearchResponse {
+    q: String,
+    num_hits: usize,
+    hits: Vec<Hit>,
+}
+
+#[derive(Serialize)]
+pub struct Hit {
+    score: Score,
+    doc: NamedFieldDocument,
+    id: u32,
+}
+
+#[utoipa::path(
+    get,
+    params(SearchQuery),
+    path = "/v1beta/search",
+    operation_id = "search",
+    tag = "v1beta",
+    responses(
+        (status = 200, description = "Success" , content_type = "application/json"),
+        (status = 405, description = "When code search is not enabled, the endpoint will returns 405 Method Not Allowed"),
+    )
+)]
+#[instrument(skip(state, query))]
+pub async fn search(
+    State(state): State<Arc<IndexServer>>,
+    query: Query<SearchQuery>,
+) -> Result<Json<SearchResponse>, StatusCode> {
+    let Ok(serp) = state.search(
+        &query.q,
+        query.limit.unwrap_or(20),
+        query.offset.unwrap_or(0),
+    ) else {
+        return Err(StatusCode::INTERNAL_SERVER_ERROR);
+    };
+
+    Ok(Json(serp))
+}
+
+pub struct IndexServer {
+    reader: IndexReader,
+    query_parser: QueryParser,
+    schema: Schema,
+}
+
+impl IndexServer {
+    pub fn new() -> Self {
+        Self::load().expect("Failed to load code state")
+    }
+
+    fn load() -> Result<Self> {
+        let index = Index::open_in_dir(path::index_dir())?;
+        index.register_tokenizer();
+
+        let schema = index.schema();
+        let default_fields: Vec<Field> = schema
+            .fields()
+            .filter(|&(_, field_entry)| match field_entry.field_type() {
+                FieldType::Str(ref text_field_options) => {
+                    text_field_options.get_indexing_options().is_some()
+                }
+                _ => false,
+            })
+            .map(|(field, _)| field)
+            .collect();
+        let query_parser =
+            QueryParser::new(schema.clone(), default_fields, index.tokenizers().clone());
+        let reader = index.reader()?;
+        Ok(Self {
+            reader,
+            query_parser,
+            schema,
+        })
+    }
+
+    fn search(&self, q: &str, limit: usize, offset: usize) -> tantivy::Result<SearchResponse> {
+        let query = self
+            .query_parser
+            .parse_query(q)
+            .expect("Parsing the query failed");
+        let searcher = self.reader.searcher();
+        let (top_docs, num_hits) = {
+            searcher.search(
+                &query,
+                &(TopDocs::with_limit(limit).and_offset(offset), Count),
+            )?
+ }; + let hits: Vec = { + top_docs + .iter() + .map(|(score, doc_address)| { + let doc = searcher.doc(*doc_address).unwrap(); + self.create_hit(*score, doc, *doc_address) + }) + .collect() + }; + Ok(SearchResponse { + q: q.to_owned(), + num_hits, + hits, + }) + } + + fn create_hit(&self, score: Score, doc: Document, doc_address: DocAddress) -> Hit { + Hit { + score, + doc: self.schema.to_named_doc(&doc), + id: doc_address.doc_id, + } + } +} diff --git a/experimental/scheduler/completion.py b/experimental/scheduler/completion.py new file mode 100644 index 000000000000..1c00e0c03132 --- /dev/null +++ b/experimental/scheduler/completion.py @@ -0,0 +1,38 @@ +import re +import requests +import streamlit as st +from typing import NamedTuple + +class Doc(NamedTuple): + name: str + body: str + score: float + filepath: str + + @staticmethod + def from_json(json: dict): + doc = json["doc"] + return Doc( + name=doc["name"][0], + body=doc["body"][0], + score=json["score"], + filepath=doc["filepath"][0], + ) + +# force wide mode +st.set_page_config(layout="wide") + +language = st.text_input("Language", "rust") +query = st.text_area("Query", "get") +tokens = re.findall(r"\w+", query) +tokens = [x for x in tokens if x != "AND" and x != "OR" and x != "NOT"] + +query = "(" + " ".join(tokens) + ")" + " " + "AND language:" + language + +if query: + r = requests.get("http://localhost:8080/v1beta/search", params=dict(q=query)) + hits = r.json()["hits"] + for x in hits: + doc = Doc.from_json(x) + st.write(doc.name + "@" + doc.filepath + " : " + str(doc.score)) + st.code(doc.body) diff --git a/experimental/scheduler/search.py b/experimental/scheduler/search.py index dbc0ca323afd..be89563c2b62 100644 --- a/experimental/scheduler/search.py +++ b/experimental/scheduler/search.py @@ -2,29 +2,13 @@ import streamlit as st from typing import NamedTuple -class Doc(NamedTuple): - name: str - body: str - score: float - - @staticmethod - def from_json(json: dict): - doc = json["doc"] - return Doc( - name=doc["name"][0], - body=doc["body"][0], - score=json["score"] - ) - # force wide mode st.set_page_config(layout="wide") query = st.text_input("Query") if query: - r = requests.get("http://localhost:3000/api", params=dict(q=query)) + r = requests.get("http://localhost:8080/v1beta/search", params=dict(q=query)) hits = r.json()["hits"] for x in hits: - doc = Doc.from_json(x) - st.write(doc.name + " : " + str(doc.score)) - st.code(doc.body) \ No newline at end of file + st.write(x)