diff --git a/Cargo.lock b/Cargo.lock index 3a0e5c32f28e..83d01df668d1 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -39,6 +39,17 @@ dependencies = [ "version_check", ] +[[package]] +name = "ahash" +version = "0.8.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2c99f64d1e06488f620f932677e24bc6e2897582980441ae90a671415bd7ec2f" +dependencies = [ + "cfg-if", + "once_cell", + "version_check", +] + [[package]] name = "aho-corasick" version = "0.7.20" @@ -57,6 +68,12 @@ dependencies = [ "memchr", ] +[[package]] +name = "allocator-api2" +version = "0.2.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0942ffc6dcaadf03badf6e6a2d0228460359d5e34b57ccdc720b7382dfbd5ec5" + [[package]] name = "android-tzdata" version = "0.1.1" @@ -342,6 +359,24 @@ dependencies = [ "pkg-config", ] +[[package]] +name = "cached" +version = "0.46.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8cead8ece0da6b744b2ad8ef9c58a4cdc7ef2921e60a6ddfb9eaaa86839b5fc5" +dependencies = [ + "ahash 0.8.3", + "async-trait", + "cached_proc_macro", + "cached_proc_macro_types", + "futures", + "hashbrown 0.14.0", + "instant", + "once_cell", + "thiserror", + "tokio", +] + [[package]] name = "cached-path" version = "0.6.1" @@ -364,6 +399,24 @@ dependencies = [ "zip", ] +[[package]] +name = "cached_proc_macro" +version = "0.18.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7da8245dd5f576a41c3b76247b54c15b0e43139ceeb4f732033e15be7c005176" +dependencies = [ + "darling 0.14.4", + "proc-macro2", + "quote", + "syn 1.0.109", +] + +[[package]] +name = "cached_proc_macro_types" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3a4f925191b4367301851c6d99b09890311d74b0d43f274c0b34c86d308a3663" + [[package]] name = "cc" version = "1.0.79" @@ -1177,7 +1230,7 @@ version = "0.12.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8a9ee70c43aaf417c914396645a0fa852624801b24ebb7ae78fe8272889ac888" dependencies = [ - "ahash", + "ahash 0.7.6", ] [[package]] @@ -1185,6 +1238,10 @@ name = "hashbrown" version = "0.14.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2c6201b9ff9fd90a5a3bac2e56a830d0caa509576f0e503818ee82c181b3437a" +dependencies = [ + "ahash 0.8.3", + "allocator-api2", +] [[package]] name = "heck" @@ -2702,9 +2759,9 @@ dependencies = [ [[package]] name = "serde_json" -version = "1.0.105" +version = "1.0.107" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "693151e1ac27563d6dbcec9dee9fbd5da8539b20fa14ad3752b2e6d363ace360" +checksum = "6b420ce6e3d8bd882e9b243c6eed35dbc9a6110c9769e74b584e0d68d1f20c65" dependencies = [ "itoa", "ryu", @@ -3011,14 +3068,18 @@ name = "tabby-download" version = "0.1.0" dependencies = [ "anyhow", + "async-trait", + "cached", "futures-util", "indicatif 0.17.3", "reqwest", "serde", + "serde_json", "serdeconv", "tabby-common", "tokio-retry", "tracing", + "urlencoding", ] [[package]] @@ -3875,6 +3936,12 @@ dependencies = [ "percent-encoding", ] +[[package]] +name = "urlencoding" +version = "2.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "daf8dba3b7eb870caf1ddeed7bc9d2a049f3cfdfae7cb521b087cc33ae4c49da" + [[package]] name = "utf8-ranges" version = "1.0.5" diff --git a/Cargo.toml b/Cargo.toml index 63fca28d1d98..8b4df43ad522 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -22,6 +22,7 @@ homepage = "https://github.com/TabbyML/tabby" [workspace.dependencies] lazy_static = "1.4.0" serde = { version = "1.0", features = ["derive"] } +serde_json = "1" serdeconv = "0.4.1" tokio = "1.28" tokio-util = "0.7" diff --git a/crates/http-api-bindings/Cargo.toml b/crates/http-api-bindings/Cargo.toml index 492e0461974f..92ab7250d2a9 100644 --- a/crates/http-api-bindings/Cargo.toml +++ b/crates/http-api-bindings/Cargo.toml @@ -7,7 +7,7 @@ edition = "2021" async-trait.workspace = true reqwest = { workspace = true, features = ["json"] } serde = { workspace = true, features = ["derive"] } -serde_json = "1.0.105" +serde_json = { workspace = true } tabby-inference = { version = "0.1.0", path = "../tabby-inference" } [dev-dependencies] diff --git a/crates/tabby-download/Cargo.toml b/crates/tabby-download/Cargo.toml index 94a64126c728..eed343639702 100644 --- a/crates/tabby-download/Cargo.toml +++ b/crates/tabby-download/Cargo.toml @@ -13,3 +13,7 @@ serde = { workspace = true } serdeconv = { workspace = true } tracing = { workspace = true } tokio-retry = "0.3.0" +urlencoding = "2.1.3" +serde_json = { workspace = true } +cached = { version = "0.46.0", features = ["async", "proc_macro"] } +async-trait = { workspace = true } diff --git a/crates/tabby-download/src/cache_info.rs b/crates/tabby-download/src/cache_info.rs index 8b308a0a3c55..c82ffef6fa57 100644 --- a/crates/tabby-download/src/cache_info.rs +++ b/crates/tabby-download/src/cache_info.rs @@ -1,6 +1,6 @@ use std::{collections::HashMap, fs, path::Path}; -use anyhow::{anyhow, Result}; +use anyhow::Result; use serde::{Deserialize, Serialize}; use tabby_common::path::ModelDir; @@ -33,15 +33,6 @@ impl CacheInfo { self.etags.get(path).map(|x| x.as_str()) } - pub fn remote_cache_key(res: &reqwest::Response) -> Result<&str> { - let key = res - .headers() - .get("etag") - .ok_or(anyhow!("etag key missing"))? - .to_str()?; - Ok(key) - } - pub async fn set_local_cache_key(&mut self, path: &str, cache_key: &str) { self.etags.insert(path.to_string(), cache_key.to_string()); } diff --git a/crates/tabby-download/src/lib.rs b/crates/tabby-download/src/lib.rs index c4608c725069..f77aac26b2e8 100644 --- a/crates/tabby-download/src/lib.rs +++ b/crates/tabby-download/src/lib.rs @@ -1,4 +1,5 @@ mod cache_info; +mod registry; use std::{cmp, fs, io::Write, path::Path}; @@ -6,111 +7,114 @@ use anyhow::{anyhow, Result}; use cache_info::CacheInfo; use futures_util::StreamExt; use indicatif::{ProgressBar, ProgressStyle}; +use registry::{create_registry, Registry}; use tabby_common::path::ModelDir; use tokio_retry::{ strategy::{jitter, ExponentialBackoff}, Retry, }; -use tracing::info; - -impl CacheInfo { - async fn download( - &mut self, - model_id: &str, - path: &str, - prefer_local_file: bool, - is_optional: bool, - ) -> Result<()> { - // Create url. - let url = format!("https://huggingface.co/{}/resolve/main/{}", model_id, path); - - // Get cache key. - let local_cache_key = self.local_cache_key(path); - - // Create destination path. - let filepath = ModelDir::new(model_id).path_string(path); - - // Cache hit. - let local_file_ready = if prefer_local_file { - if let Some(local_cache_key) = local_cache_key { - if local_cache_key == "404" { - true - } else { - fs::metadata(&filepath).is_ok() - } - } else { - false - } - } else { - false - }; - if !local_file_ready { - let strategy = ExponentialBackoff::from_millis(100).map(jitter).take(3); - let etag = Retry::spawn(strategy, || { - download_file(&url, &filepath, local_cache_key, is_optional) - }) - .await?; - self.set_local_cache_key(path, &etag).await; - } - Ok(()) - } +pub struct Downloader { + model_id: String, + prefer_local_file: bool, + registry: Box, } -pub async fn download_model( - model_id: &str, - download_ctranslate2_files: bool, - download_ggml_files: bool, - prefer_local_file: bool, -) -> Result<()> { - if fs::metadata(model_id).is_ok() { - // Local path, no need for downloading. - return Ok(()); +impl Downloader { + pub fn new(model_id: &str, prefer_local_file: bool) -> Self { + Self { + model_id: model_id.to_owned(), + prefer_local_file, + registry: create_registry(), + } } - info!("Start downloading model `{}`", model_id); - - let mut cache_info = CacheInfo::from(model_id).await; - - let mut optional_files = vec![]; - if download_ctranslate2_files { - optional_files.push("ctranslate2/vocabulary.txt"); - optional_files.push("ctranslate2/shared_vocabulary.txt"); - optional_files.push("ctranslate2/vocabulary.json"); - optional_files.push("ctranslate2/shared_vocabulary.json"); + pub async fn download_ctranslate2_files(&self) -> Result<()> { + let files = vec![ + ("tabby.json", true), + ("tokenizer.json", true), + ("ctranslate2/vocabulary.txt", false), + ("ctranslate2/shared_vocabulary.txt", false), + ("ctranslate2/vocabulary.json", false), + ("ctranslate2/shared_vocabulary.json", false), + ("ctranslate2/config.json", true), + ("ctranslate2/model.bin", true), + ]; + + self.download_files(&files).await } - if download_ggml_files { - optional_files.push("ggml/q8_0.gguf"); + pub async fn download_ggml_files(&self) -> Result<()> { + let files = vec![ + ("tabby.json", true), + ("tokenizer.json", true), + ("ggml/q8_0.gguf", true), + ]; + self.download_files(&files).await } - for path in optional_files { - cache_info - .download( - model_id, + async fn download_files(&self, files: &[(&str, bool)]) -> Result<()> { + // Local path, no need for downloading. + if fs::metadata(&self.model_id).is_ok() { + return Ok(()); + } + + let mut cache_info = CacheInfo::from(&self.model_id).await; + for (path, required) in files { + download_model_file( + self.registry.as_ref(), + &mut cache_info, + &self.model_id, path, - prefer_local_file, - /* is_optional */ true, + self.prefer_local_file, + *required, ) .await?; + } + Ok(()) } +} - let mut required_files = vec!["tabby.json", "tokenizer.json"]; +async fn download_model_file( + registry: &dyn Registry, + cache_info: &mut CacheInfo, + model_id: &str, + path: &str, + prefer_local_file: bool, + required: bool, +) -> Result<()> { + // Create url. + let url = registry.build_url(model_id, path); - if download_ctranslate2_files { - required_files.push("ctranslate2/config.json"); - required_files.push("ctranslate2/model.bin"); - } + // Get cache key. + let local_cache_key = cache_info.local_cache_key(path); - for path in required_files { - cache_info - .download( - model_id, - path, - prefer_local_file, - /* required= */ false, - ) - .await?; + // Create destination path. + let filepath = ModelDir::new(model_id).path_string(path); + + // Cache hit. + let local_file_ready = if prefer_local_file { + if let Some(local_cache_key) = local_cache_key { + if local_cache_key == "404" { + true + } else { + fs::metadata(&filepath).is_ok() + } + } else { + false + } + } else { + false + }; + + if !local_file_ready { + let strategy = ExponentialBackoff::from_millis(100).map(jitter).take(2); + let etag = Retry::spawn(strategy, || { + download_file(registry, &url, &filepath, local_cache_key, !required) + }) + .await?; + + cache_info.set_local_cache_key(path, &etag).await; } cache_info.save(model_id)?; @@ -118,6 +122,7 @@ pub async fn download_model( } async fn download_file( + registry: &dyn Registry, url: &str, path: &str, local_cache_key: Option<&str>, @@ -137,7 +142,7 @@ async fn download_file( return Err(anyhow!(format!("Invalid url: {}", url))); } - let remote_cache_key = CacheInfo::remote_cache_key(&res)?.to_string(); + let remote_cache_key = registry.build_cache_key(url).await?; if let Some(local_cache_key) = local_cache_key { if local_cache_key == remote_cache_key { return Ok(remote_cache_key); diff --git a/crates/tabby-download/src/registry/huggingface.rs b/crates/tabby-download/src/registry/huggingface.rs new file mode 100644 index 000000000000..4fef766c8234 --- /dev/null +++ b/crates/tabby-download/src/registry/huggingface.rs @@ -0,0 +1,24 @@ +use anyhow::{anyhow, Result}; +use async_trait::async_trait; + +use crate::Registry; + +#[derive(Default)] +pub struct HuggingFaceRegistry {} + +#[async_trait] +impl Registry for HuggingFaceRegistry { + fn build_url(&self, model_id: &str, path: &str) -> String { + format!("https://huggingface.co/{}/resolve/main/{}", model_id, path) + } + + async fn build_cache_key(&self, url: &str) -> Result { + let res = reqwest::get(url).await?; + let cache_key = res + .headers() + .get("etag") + .ok_or(anyhow!("etag key missing"))? + .to_str()?; + Ok(cache_key.to_owned()) + } +} diff --git a/crates/tabby-download/src/registry/mod.rs b/crates/tabby-download/src/registry/mod.rs new file mode 100644 index 000000000000..45181fc88dff --- /dev/null +++ b/crates/tabby-download/src/registry/mod.rs @@ -0,0 +1,25 @@ +mod huggingface; +mod modelscope; + +use anyhow::Result; +use async_trait::async_trait; +use huggingface::HuggingFaceRegistry; + +use self::modelscope::ModelScopeRegistry; + +#[async_trait] +pub trait Registry { + fn build_url(&self, model_id: &str, path: &str) -> String; + async fn build_cache_key(&self, url: &str) -> Result; +} + +pub fn create_registry() -> Box { + let registry = std::env::var("TABBY_REGISTRY").unwrap_or("huggingface".to_owned()); + if registry == "huggingface" { + Box::::default() + } else if registry == "modelscope" { + Box::::default() + } else { + panic!("Unsupported registry {}", registry); + } +} diff --git a/crates/tabby-download/src/registry/modelscope.rs b/crates/tabby-download/src/registry/modelscope.rs new file mode 100644 index 000000000000..fd39ae5eb1b2 --- /dev/null +++ b/crates/tabby-download/src/registry/modelscope.rs @@ -0,0 +1,79 @@ +use std::collections::HashMap; + +use anyhow::{anyhow, Result}; +use async_trait::async_trait; +use cached::proc_macro::cached; +use reqwest::Url; +use serde::Deserialize; + +use crate::Registry; + +#[derive(Default)] +pub struct ModelScopeRegistry {} + +#[async_trait] +impl Registry for ModelScopeRegistry { + fn build_url(&self, model_id: &str, path: &str) -> String { + format!( + "https://modelscope.cn/api/v1/models/{}/repo?FilePath={}", + model_id, + urlencoding::encode(path) + ) + } + + async fn build_cache_key(&self, url: &str) -> Result { + let url = Url::parse(url)?; + let model_id = url + .path() + .strip_prefix("/api/v1/models/") + .ok_or(anyhow!("Invalid url"))? + .strip_suffix("/repo") + .ok_or(anyhow!("Invalid url"))?; + + let query: HashMap<_, _> = url.query_pairs().into_owned().collect(); + let path = query + .get("FilePath") + .ok_or(anyhow!("Failed to extract FilePath"))?; + + let revision_map = fetch_revision_map(model_id.to_owned()).await?; + for x in revision_map.data.files { + if x.path == *path { + return Ok(x.sha256); + } + } + + Err(anyhow!("Failed to find {} in revisions", path)) + } +} + +#[cached(size = 1, result = true)] +async fn fetch_revision_map(model_id: String) -> Result { + let url = format!( + "https://modelscope.cn/api/v1/models/{}/repo/files?Recursive=true", + model_id + ); + let resp = reqwest::get(url) + .await? + .json::() + .await?; + Ok(resp) +} + +#[derive(Deserialize, Clone)] +#[serde(rename_all = "PascalCase")] +struct ModelScopeRevision { + data: ModelScopeRevisionData, +} + +#[derive(Deserialize, Clone)] +#[serde(rename_all = "PascalCase")] +struct ModelScopeRevisionData { + files: Vec, +} + +#[derive(Deserialize, Clone)] +#[serde(rename_all = "PascalCase")] +struct ModelScopeRevisionFile { + path: String, + sha256: String, +} diff --git a/crates/tabby/Cargo.toml b/crates/tabby/Cargo.toml index d9e34c087391..59762066e267 100644 --- a/crates/tabby/Cargo.toml +++ b/crates/tabby/Cargo.toml @@ -17,7 +17,7 @@ utoipa = { version = "3.3", features = ["axum_extras", "preserve_order"] } utoipa-swagger-ui = { version = "3.1", features = ["axum"] } serde = { workspace = true } serdeconv = { workspace = true } -serde_json = "1.0" +serde_json = { workspace = true } tower-http = { version = "0.4.0", features = ["cors"] } clap = { version = "4.3.0", features = ["derive"] } lazy_static = { workspace = true } diff --git a/crates/tabby/src/download.rs b/crates/tabby/src/download.rs index 722ea499238a..20e872148576 100644 --- a/crates/tabby/src/download.rs +++ b/crates/tabby/src/download.rs @@ -1,4 +1,5 @@ use clap::Args; +use tabby_download::Downloader; use tracing::info; use crate::fatal; @@ -15,19 +16,18 @@ pub struct DownloadArgs { } pub async fn main(args: &DownloadArgs) { - tabby_download::download_model( - &args.model, - /* download_ctranslate2_files= */ true, - /* download_ggml_files= */ true, - args.prefer_local_file, - ) - .await - .unwrap_or_else(|err| { - fatal!( - "Failed to fetch model due to '{}', is '{}' a valid model id?", - err, - args.model - ) - }); + let downloader = Downloader::new(&args.model, args.prefer_local_file); + + let handler = |err| fatal!("Failed to fetch model '{}' due to '{}'", args.model, err,); + + downloader + .download_ctranslate2_files() + .await + .unwrap_or_else(handler); + downloader + .download_ggml_files() + .await + .unwrap_or_else(handler); + info!("model '{}' is ready", args.model); } diff --git a/crates/tabby/src/serve/mod.rs b/crates/tabby/src/serve/mod.rs index 943ca7a12b94..9691952ed1ea 100644 --- a/crates/tabby/src/serve/mod.rs +++ b/crates/tabby/src/serve/mod.rs @@ -12,6 +12,7 @@ use axum::{routing, Router, Server}; use axum_tracing_opentelemetry::opentelemetry_tracing_layer; use clap::Args; use tabby_common::{config::Config, usage}; +use tabby_download::Downloader; use tokio::time::sleep; use tower_http::cors::CorsLayer; use tracing::{info, warn}; @@ -136,25 +137,16 @@ fn should_download_ggml_files(device: &Device) -> bool { pub async fn main(config: &Config, args: &ServeArgs) { valid_args(args); + let downloader = Downloader::new(&args.model, /* prefer_local_file= */ true); if args.device != Device::ExperimentalHttp { - let download_ctranslate2_files = !should_download_ggml_files(&args.device); - let download_ggml_files = should_download_ggml_files(&args.device); - - // Ensure model exists. - tabby_download::download_model( - &args.model, - download_ctranslate2_files, - download_ggml_files, - /* prefer_local_file= */ true, - ) - .await - .unwrap_or_else(|err| { - fatal!( - "Failed to fetch model due to '{}', is '{}' a valid model id?", - err, - args.model - ) - }); + let handler = |err| fatal!("Failed to fetch model '{}' due to '{}'", args.model, err,); + let download_result = if should_download_ggml_files(&args.device) { + downloader.download_ggml_files().await + } else { + downloader.download_ctranslate2_files().await + }; + + download_result.unwrap_or_else(handler); } else { warn!("HTTP device is unstable and does not comply with semver expectations.") } diff --git a/experimental/copy-to-modelscope/.gitignore b/experimental/copy-to-modelscope/.gitignore new file mode 100644 index 000000000000..05f85898c239 --- /dev/null +++ b/experimental/copy-to-modelscope/.gitignore @@ -0,0 +1,2 @@ +hf_model +ms_model diff --git a/experimental/copy-to-modelscope/README.md b/experimental/copy-to-modelscope/README.md new file mode 100644 index 000000000000..c7dc94b38e79 --- /dev/null +++ b/experimental/copy-to-modelscope/README.md @@ -0,0 +1,3 @@ +# copy-to-modelscope + +Scripts to copy huggingface model to modelscope diff --git a/experimental/copy-to-modelscope/main.sh b/experimental/copy-to-modelscope/main.sh new file mode 100755 index 000000000000..a1437368703b --- /dev/null +++ b/experimental/copy-to-modelscope/main.sh @@ -0,0 +1,54 @@ +#!/bin/bash +set -e + +MODEL_ID=$1 +ACCESS_TOKEN=$2 + +usage() { + echo "Usage: $0 " + exit 1 +} + +if [ -z "${MODEL_ID}" ]; then + usage +fi + +git clone https://huggingface.co/$MODEL_ID hf_model +git clone https://oauth2:${ACCESS_TOKEN}@www.modelscope.cn/$MODEL_ID.git ms_model + +echo "Sync directory" +rsync -a --exclude '.git' hf_model/ ms_model/ + +echo "Create README.md" +cat <ms_model/README.md +--- +license: other +tasks: +- text-generation +--- + +# ${MODEL_ID} + +This is an mirror of [${MODEL_ID}](https://huggingface.co/${MODEL_ID}). +EOF + +echo "Create configuration.json" +cat <ms_model/configuration.json +{ + "framework": "pytorch", + "task": "text-generation", + "pipeline": { + "type": "text-generation-pipeline" + } +} +EOF + +set -x +cd ms_model +git add . +git commit -m "sync with upstream" +git push origin + +echo "Success!" +rm -rf hf_model +rm -rf ms_model