diff --git a/Cargo.lock b/Cargo.lock index 63c841c801ed..0a897e04c2c8 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1931,9 +1931,9 @@ dependencies = [ [[package]] name = "num-traits" -version = "0.2.15" +version = "0.2.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "578ede34cf02f8924ab9447f50c28075b4d3e5b269972345e7e0372b38c6cdcd" +checksum = "39e3200413f237f41ab11ad6d161bc7239c84dcb631773ccd7de3dfe4b5c267c" dependencies = [ "autocfg", ] @@ -2504,14 +2504,14 @@ dependencies = [ [[package]] name = "regex" -version = "1.9.5" +version = "1.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "697061221ea1b4a94a624f67d0ae2bfe4e22b8a17b6a192afb11046542cc8c47" +checksum = "d119d7c7ca818f8a53c300863d4f87566aac09943aef5b355bb83969dae75d87" dependencies = [ "aho-corasick 1.0.1", "memchr", - "regex-automata 0.3.8", - "regex-syntax 0.7.5", + "regex-automata 0.4.1", + "regex-syntax 0.8.1", ] [[package]] @@ -2525,13 +2525,13 @@ dependencies = [ [[package]] name = "regex-automata" -version = "0.3.8" +version = "0.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c2f401f4955220693b56f8ec66ee9c78abffd8d1c4f23dc41a23839eb88f0795" +checksum = "465c6fc0621e4abc4187a2bda0937bfd4f722c2730b29562e19689ea796c9a4b" dependencies = [ "aho-corasick 1.0.1", "memchr", - "regex-syntax 0.7.5", + "regex-syntax 0.8.1", ] [[package]] @@ -2546,6 +2546,12 @@ version = "0.7.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "dbb5fb1acd8a1a18b3dd5be62d25485eb770e05afb408a9627d14d451bae12da" +[[package]] +name = "regex-syntax" +version = "0.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" 
+checksum = "56d84fdd47036b038fc80dd333d10b6aab10d5d31f4a366e20014def75328d33" + [[package]] name = "reqwest" version = "0.11.18" @@ -3107,6 +3113,7 @@ dependencies = [ "nvml-wrapper", "opentelemetry", "opentelemetry-otlp", + "regex", "rust-embed 8.0.0", "serde", "serde_json", @@ -3119,6 +3126,7 @@ dependencies = [ "tabby-inference", "tabby-scheduler", "tantivy", + "textdistance", "tokio", "tower", "tower-http 0.4.0", @@ -3383,6 +3391,12 @@ dependencies = [ "winapi-util", ] +[[package]] +name = "textdistance" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d321c8576c2b47e43953e9cce236550d4cd6af0a6ce518fe084340082ca6037b" + [[package]] name = "thiserror" version = "1.0.45" diff --git a/Cargo.toml b/Cargo.toml index dc1c021f10dd..22aeb19a50f2 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -36,3 +36,4 @@ derive_builder = "0.12.0" tokenizers = "0.13.4-rc3" futures = "0.3.28" async-stream = "0.3.5" +regex = "1.10.0" \ No newline at end of file diff --git a/crates/tabby-common/src/events.rs b/crates/tabby-common/src/events.rs index 996cf313134b..88e683b93312 100644 --- a/crates/tabby-common/src/events.rs +++ b/crates/tabby-common/src/events.rs @@ -56,10 +56,16 @@ pub enum Event<'a> { completion_id: &'a str, language: &'a str, prompt: &'a str, + segments: &'a Segments<'a>, choices: Vec<Choice<'a>>, user: Option<&'a str>, }, } +#[derive(Serialize)] +pub struct Segments<'a> { + pub prefix: &'a str, + pub suffix: Option<&'a str>, +} #[derive(Serialize)] struct Log<'a> { diff --git a/crates/tabby-inference/Cargo.toml b/crates/tabby-inference/Cargo.toml index 0e9f4a157844..97464be96165 100644 --- a/crates/tabby-inference/Cargo.toml +++ b/crates/tabby-inference/Cargo.toml @@ -11,5 +11,5 @@ async-trait = { workspace = true } dashmap = "5.5.3" derive_builder = "0.12.0" futures = { workspace = true } -regex = "1.9.5" +regex.workspace = true tokenizers.workspace = true diff --git 
a/crates/tabby/Cargo.toml b/crates/tabby/Cargo.toml index 04bd73d81247..a828e2e48e98 100644 --- a/crates/tabby/Cargo.toml +++ b/crates/tabby/Cargo.toml @@ -40,6 +40,8 @@ futures = { workspace = true } async-stream = { workspace = true } axum-streams = { version = "0.9.1", features = ["json"] } minijinja = { version = "1.0.8", features = ["loader"] } +textdistance = "1.0.2" +regex.workspace = true [target.'cfg(all(target_os="macos", target_arch="aarch64"))'.dependencies] llama-cpp-bindings = { path = "../llama-cpp-bindings" } diff --git a/crates/tabby/src/serve/completions.rs b/crates/tabby/src/serve/completions.rs index 5c55fde0a8aa..bc1b83b57d70 100644 --- a/crates/tabby/src/serve/completions.rs +++ b/crates/tabby/src/serve/completions.rs @@ -23,6 +23,7 @@ use super::search::IndexServer; } }))] pub struct CompletionRequest { + #[deprecated] #[schema(example = "def fib(n):")] prompt: Option<String>, @@ -37,6 +38,14 @@ pub struct CompletionRequest { // A unique identifier representing your end-user, which can help Tabby to monitor & generating // reports. user: Option<String>, + + debug: Option<DebugRequest>, +} + +#[derive(Serialize, ToSchema, Deserialize, Clone, Debug)] +pub struct DebugRequest { + // When true, returns debug_data in completion response. 
+ enabled: bool, } #[derive(Serialize, Deserialize, ToSchema, Clone, Debug)] @@ -55,9 +64,31 @@ pub struct Choice { } #[derive(Serialize, Deserialize, ToSchema, Clone, Debug)] +pub struct Snippet { + filepath: String, + body: String, + score: f32, +} + +#[derive(Serialize, Deserialize, ToSchema, Clone, Debug)] +#[schema(example=json!({ + "id": "string", + "choices": [ { "index": 0, "text": "string" } ] +}))] pub struct CompletionResponse { id: String, choices: Vec<Choice>, + + #[serde(skip_serializing_if = "Option::is_none")] + debug_data: Option<DebugData>, +} + +#[derive(Serialize, Deserialize, ToSchema, Clone, Debug)] +pub struct DebugData { + #[serde(skip_serializing_if = "Vec::is_empty")] + snippets: Vec<Snippet>, + + prompt: String, } #[utoipa::path( @@ -97,7 +128,10 @@ pub async fn completions( }; debug!("PREFIX: {}, SUFFIX: {:?}", segments.prefix, segments.suffix); - let prompt = state.prompt_builder.build(&language, segments); + let snippets = state.prompt_builder.collect(&language, &segments); + let prompt = state + .prompt_builder + .build(&language, segments.clone(), &snippets); debug!("PROMPT: {}", prompt); let completion_id = format!("cmpl-{}", uuid::Uuid::new_v4()); let text = state.engine.generate(&prompt, options).await; @@ -106,6 +140,10 @@ pub async fn completions( completion_id: &completion_id, language: &language, prompt: &prompt, + segments: &tabby_common::events::Segments { + prefix: &segments.prefix, + suffix: segments.suffix.as_deref(), + }, choices: vec![events::Choice { index: 0, text: &text, @@ -114,9 +152,16 @@ } .log(); + let debug_data = DebugData { snippets, prompt }; + Ok(Json(CompletionResponse { id: completion_id, choices: vec![Choice { index: 0, text }], + debug_data: if request.debug.is_some_and(|x| x.enabled) { + Some(debug_data) + } else { + None + }, })) } diff --git a/crates/tabby/src/serve/completions/prompt.rs b/crates/tabby/src/serve/completions/prompt.rs index d97d188e3c26..1825f4b2b60d 100644 --- 
a/crates/tabby/src/serve/completions/prompt.rs +++ b/crates/tabby/src/serve/completions/prompt.rs @@ -1,14 +1,17 @@ use std::sync::Arc; +use lazy_static::lazy_static; +use regex::Regex; use strfmt::strfmt; +use textdistance::Algorithm; use tracing::warn; -use super::Segments; +use super::{Segments, Snippet}; use crate::serve::{completions::languages::get_language, search::IndexServer}; static MAX_SNIPPETS_TO_FETCH: usize = 20; -static MAX_SNIPPET_CHARS_IN_PROMPT: usize = 512; -static SNIPPET_SCORE_THRESHOLD: f32 = 5.0; +static MAX_SNIPPET_CHARS_IN_PROMPT: usize = 768; +static MAX_SIMILARITY_THRESHOLD: f32 = 0.9; pub struct PromptBuilder { prompt_template: Option<String>, @@ -23,47 +26,36 @@ impl PromptBuilder { } } - fn build_prompt(&self, prefix: String, suffix: String) -> String { - if let Some(prompt_template) = &self.prompt_template { - strfmt!(prompt_template, prefix => prefix, suffix => suffix).unwrap() - } else { - prefix - } - } + fn build_prompt(&self, prefix: String, suffix: Option<String>) -> String { + let Some(suffix) = suffix else { + return prefix; + }; + + let Some(prompt_template) = &self.prompt_template else { + return prefix; + }; - pub fn build(&self, language: &str, segments: Segments) -> String { - let segments = self.rewrite(language, segments); - self.build_prompt(segments.prefix, get_default_suffix(segments.suffix)) + strfmt!(prompt_template, prefix => prefix, suffix => suffix).unwrap() } - fn rewrite(&self, language: &str, segments: Segments) -> Segments { + pub fn collect(&self, language: &str, segments: &Segments) -> Vec<Snippet> { if let Some(index_server) = &self.index_server { - rewrite_with_index(index_server, language, segments) + collect_snippets(index_server, language, &segments.prefix) } else { - segments + vec![] } } -} -fn get_default_suffix(suffix: Option<String>) -> String { - if suffix.is_none() { - return "\n".to_owned(); - } - - let suffix = suffix.unwrap(); - if suffix.is_empty() { - "\n".to_owned() - } else { - suffix + pub fn build(&self, language: 
&str, segments: Segments, snippets: &[Snippet]) -> String { + let segments = rewrite_with_snippets(language, segments, snippets); + self.build_prompt( + segments.prefix, + segments.suffix.filter(|x| !x.trim_end().is_empty()), + ) } } -fn rewrite_with_index( - index_server: &Arc, - language: &str, - segments: Segments, -) -> Segments { - let snippets = collect_snippets(index_server, language, &segments.prefix); +fn rewrite_with_snippets(language: &str, segments: Segments, snippets: &[Snippet]) -> Segments { if snippets.is_empty() { segments } else { @@ -72,35 +64,23 @@ fn rewrite_with_index( } } -fn build_prefix(language: &str, prefix: &str, snippets: Vec) -> String { +fn build_prefix(language: &str, prefix: &str, snippets: &[Snippet]) -> String { if snippets.is_empty() { return prefix.to_owned(); } let comment_char = get_language(language).line_comment; - let mut lines: Vec = vec![ - format!( - "Below are some relevant {} snippets found in the repository:", - language - ), - "".to_owned(), - ]; + let mut lines: Vec = vec![]; - let mut count_characters = 0; for (i, snippet) in snippets.iter().enumerate() { - if count_characters + snippet.len() > MAX_SNIPPET_CHARS_IN_PROMPT { - break; - } - - lines.push(format!("== Snippet {} ==", i + 1)); - for line in snippet.lines() { + lines.push(format!("Path: {}", snippet.filepath)); + for line in snippet.body.lines() { lines.push(line.to_owned()); } if i < snippets.len() - 1 { lines.push("".to_owned()); } - count_characters += snippet.len(); } let commented_lines: Vec = lines @@ -117,9 +97,10 @@ fn build_prefix(language: &str, prefix: &str, snippets: Vec) -> String { format!("{}\n{}", comments, prefix) } -fn collect_snippets(index_server: &IndexServer, language: &str, text: &str) -> Vec { +fn collect_snippets(index_server: &IndexServer, language: &str, text: &str) -> Vec { let mut ret = Vec::new(); - let sanitized_text = sanitize_text(text); + let tokens = tokenize_text(text); + let sanitized_text = tokens.join(" "); if 
sanitized_text.is_empty() { return ret; } @@ -134,35 +115,49 @@ fn collect_snippets(index_server: &IndexServer, language: &str, text: &str) -> V } }; + let mut count_characters = 0; for hit in serp.hits { - if hit.score < SNIPPET_SCORE_THRESHOLD { + let body = hit.doc.body; + let body_tokens = tokenize_text(&body); + + if count_characters + body.len() > MAX_SNIPPET_CHARS_IN_PROMPT { break; } - let body = hit.doc.body; + let similarity = if body_tokens.len() > tokens.len() { + 0.0 + } else { + let distance = textdistance::LCSSeq::default() + .for_iter(tokens.iter(), body_tokens.iter()) + .val() as f32; + distance / body_tokens.len() as f32 + }; - if text.contains(&body) { - // Exclude snippets already in the context window. + if similarity > MAX_SIMILARITY_THRESHOLD { + // Exclude snippets presents in context window. continue; } - ret.push(body.to_owned()); + count_characters += body.len(); + ret.push(Snippet { + filepath: hit.doc.filepath, + body, + score: hit.score, + }); } ret } -fn sanitize_text(text: &str) -> String { - // Only keep [a-zA-Z0-9-_] - let x = text.replace( - |c: char| !c.is_ascii_digit() && !c.is_alphabetic() && c != '_' && c != '-', - " ", - ); - let tokens: Vec<&str> = x - .split(' ') - .filter(|x| *x != "AND" && *x != "NOT" && *x != "OR" && x.len() > 5) - .collect(); - tokens.join(" ") +lazy_static! { + static ref TOKENIZER: Regex = Regex::new(r"[^\w]").unwrap(); +} + +fn tokenize_text(text: &str) -> Vec<&str> { + TOKENIZER + .split(text) + .filter(|s| *s != "AND" && *s != "OR" && *s != "NOT") + .collect() } #[cfg(test)] @@ -187,6 +182,7 @@ mod tests { // Rewrite disabled, so the language doesn't matter. let language = "python"; + let snippets = &vec![]; // Test w/ prefix, w/ suffix. { @@ -195,7 +191,7 @@ mod tests { suffix: Some("this is some suffix".into()), }; assert_eq!( - pb.build(language, segments), + pb.build(language, segments, snippets), "
 this is some prefix this is some suffix "
             );
         }
@@ -207,8 +203,8 @@ mod tests {
                 suffix: None,
             };
             assert_eq!(
-                pb.build(language, segments),
-                "
 this is some prefix \n "
+                pb.build(language, segments, snippets),
+                "this is some prefix"
             );
         }
 
@@ -219,8 +215,8 @@ mod tests {
                 suffix: Some("".into()),
             };
             assert_eq!(
-                pb.build(language, segments),
-                "
 this is some prefix \n "
+                pb.build(language, segments, snippets),
+                "this is some prefix"
             );
         }
 
@@ -231,7 +227,7 @@ mod tests {
                 suffix: Some("this is some suffix".into()),
             };
             assert_eq!(
-                pb.build(language, segments),
+                pb.build(language, segments, snippets),
                 "
  this is some suffix "
             );
         }
@@ -242,7 +238,7 @@ mod tests {
                 prefix: "".into(),
                 suffix: None,
             };
-            assert_eq!(pb.build(language, segments), "
  \n ");
+            assert_eq!(pb.build(language, segments, snippets), "");
         }
 
         // Test w/ emtpy prefix, w/ empty suffix.
@@ -251,7 +247,7 @@ mod tests {
                 prefix: "".into(),
                 suffix: Some("".into()),
             };
-            assert_eq!(pb.build(language, segments), "
  \n ");
+            assert_eq!(pb.build(language, segments, snippets), "");
         }
     }
 
@@ -261,6 +257,7 @@ mod tests {
 
         // Rewrite disabled, so the language doesn't matter.
         let language = "python";
+        let snippets = &vec![];
 
         // Test w/ prefix, w/ suffix.
         {
@@ -268,7 +265,10 @@ mod tests {
                 prefix: "this is some prefix".into(),
                 suffix: Some("this is some suffix".into()),
             };
-            assert_eq!(pb.build(language, segments), "this is some prefix");
+            assert_eq!(
+                pb.build(language, segments, snippets),
+                "this is some prefix"
+            );
         }
 
         // Test w/ prefix, w/o suffix.
@@ -277,7 +277,10 @@ mod tests {
                 prefix: "this is some prefix".into(),
                 suffix: None,
             };
-            assert_eq!(pb.build(language, segments), "this is some prefix");
+            assert_eq!(
+                pb.build(language, segments, snippets),
+                "this is some prefix"
+            );
         }
 
         // Test w/ prefix, w/ empty suffix.
@@ -286,7 +289,10 @@ mod tests {
                 prefix: "this is some prefix".into(),
                 suffix: Some("".into()),
             };
-            assert_eq!(pb.build(language, segments), "this is some prefix");
+            assert_eq!(
+                pb.build(language, segments, snippets),
+                "this is some prefix"
+            );
         }
 
         // Test w/ empty prefix, w/ suffix.
@@ -295,7 +301,7 @@ mod tests {
                 prefix: "".into(),
                 suffix: Some("this is some suffix".into()),
             };
-            assert_eq!(pb.build(language, segments), "");
+            assert_eq!(pb.build(language, segments, snippets), "");
         }
 
         // Test w/ empty prefix, w/o suffix.
@@ -304,7 +310,7 @@ mod tests {
                 prefix: "".into(),
                 suffix: None,
             };
-            assert_eq!(pb.build(language, segments), "");
+            assert_eq!(pb.build(language, segments, snippets), "");
         }
 
         // Test w/ empty prefix, w/ empty suffix.
@@ -313,18 +319,28 @@ mod tests {
                 prefix: "".into(),
                 suffix: Some("".into()),
             };
-            assert_eq!(pb.build(language, segments), "");
+            assert_eq!(pb.build(language, segments, snippets), "");
         }
     }
 
     #[test]
     fn test_build_prefix_readable() {
         let snippets = vec![
-            "res_1 = invoke_function_1(n)".to_string(),
-            "res_2 = invoke_function_2(n)".to_string(),
-            "res_3 = invoke_function_3(n)".to_string(),
-            "res_4 = invoke_function_4(n)".to_string(),
-            "res_5 = invoke_function_5(n)".to_string(),
+            Snippet {
+                filepath: "a1.py".to_owned(),
+                body: "res_1 = invoke_function_1(n)".to_owned(),
+                score: 1.0,
+            },
+            Snippet {
+                filepath: "a2.py".to_owned(),
+                body: "res_2 = invoke_function_2(n)".to_owned(),
+                score: 1.0,
+            },
+            Snippet {
+                filepath: "a3.py".to_owned(),
+                body: "res_3 = invoke_function_3(n)".to_owned(),
+                score: 1.0,
+            },
         ];
 
         let prefix = "\
@@ -334,53 +350,22 @@ Use some invoke_function to do some job.
 def this_is_prefix():\n";
 
         let expected_built_prefix = "\
-# Below are some relevant python snippets found in the repository:
-#
-# == Snippet 1 ==
+# Path: a1.py
 # res_1 = invoke_function_1(n)
 #
-# == Snippet 2 ==
+# Path: a2.py
 # res_2 = invoke_function_2(n)
 #
-# == Snippet 3 ==
+# Path: a3.py
 # res_3 = invoke_function_3(n)
-#
-# == Snippet 4 ==
-# res_4 = invoke_function_4(n)
-#
-# == Snippet 5 ==
-# res_5 = invoke_function_5(n)
 '''
 Use some invoke_function to do some job.
 '''
 def this_is_prefix():\n";
 
         assert_eq!(
-            build_prefix("python", prefix, snippets),
+            build_prefix("python", prefix, &snippets),
             expected_built_prefix
         );
     }
-
-    #[test]
-    fn test_build_prefix_count_chars() {
-        let snippets_expected = 4;
-        let snippet_payload = "a".repeat(MAX_SNIPPET_CHARS_IN_PROMPT / snippets_expected);
-        let mut snippets = vec![];
-        for _ in 0..snippets_expected + 1 {
-            snippets.push(snippet_payload.clone());
-        }
-
-        let prefix = "def this_is_prefix():\n";
-
-        let generated_prompt = build_prefix("python", prefix, snippets);
-
-        for i in 0..snippets_expected + 1 {
-            let st = format!("# == Snippet {} ==", i + 1);
-            if i < snippets_expected {
-                assert!(generated_prompt.contains(&st));
-            } else {
-                assert!(!generated_prompt.contains(&st));
-            }
-        }
-    }
 }
diff --git a/crates/tabby/src/serve/mod.rs b/crates/tabby/src/serve/mod.rs
index 382564037e70..655d395094df 100644
--- a/crates/tabby/src/serve/mod.rs
+++ b/crates/tabby/src/serve/mod.rs
@@ -57,6 +57,9 @@ Install following IDE / Editor extensions to get started with [Tabby](https://gi
         completions::CompletionResponse,
         completions::Segments,
         completions::Choice,
+        completions::DebugRequest,
+        completions::DebugData,
+        completions::Snippet,
         chat::ChatCompletionRequest,
         chat::Message,
         chat::ChatCompletionChunk,
diff --git a/experimental/scheduler/completion.py b/experimental/scheduler/completion.py
index c3743c73757a..62fcc6d5449e 100644
--- a/experimental/scheduler/completion.py
+++ b/experimental/scheduler/completion.py
@@ -8,13 +8,21 @@
 
 language = st.text_input("Language", "rust")
 
-query = st.text_area("Query", "get")
-tokens = re.findall(r"\w+", query)
-tokens = [x for x in tokens if x != "AND" and x != "OR" and x != "NOT"]
-query = "(" + " ".join(tokens) + ")" + " "  + "AND language:" + language
+query = st.text_area("Query", "to_owned")
 
 if query:
-    r = requests.get("http://localhost:8080/v1beta/search", params=dict(q=query))
-    hits = r.json()["hits"]
-    for x in hits:
-        st.write(x)
\ No newline at end of file
+    r = requests.post("http://localhost:8080/v1/completions", json=dict(segments=dict(prefix=query), language=language, debug=dict(enabled=True)))
+    json = r.json()
+    debug = json["debug_data"]
+    snippets = debug.get("snippets", [])
+
+    st.write("Prompt")
+    st.code(debug["prompt"])
+
+    st.write("Completion")
+    st.code(json["choices"][0]["text"])
+
+    for x in snippets:
+        st.write(f"**{x['filepath']}**: {x['score']}")
+        st.write(f"Length: {len(x['body'])}")
+        st.code(x['body'])
\ No newline at end of file