
Commit

remove END
qooba committed May 30, 2023
1 parent 82bac87 commit 3d1d6db
Showing 1 changed file with 21 additions and 24 deletions.
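
This commit drops the <<END>> sentinel string: the inference thread now forwards llm::InferenceResponse values over the callback channel unchanged and signals completion with llm::InferenceResponse::EotToken, so the chat handler can match on response variants instead of comparing strings.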
45 changes: 21 additions & 24 deletions src/main.rs
@@ -41,10 +41,8 @@ impl Args {
     }
 }

-const END: &str = "<<END>>";
 static TX_INFER: OnceCell<Arc<Mutex<SyncSender<String>>>> = OnceCell::new();

-static RX_CALLBACK: OnceCell<Arc<Mutex<Receiver<String>>>> = OnceCell::new();
+static RX_CALLBACK: OnceCell<Arc<Mutex<Receiver<llm::InferenceResponse>>>> = OnceCell::new();

 #[derive(Deserialize, Debug, Clone)]
 pub struct ChatRequest {
@@ -63,13 +61,18 @@ async fn chat(chat_request: web::Query<ChatRequest>) -> Result<impl Responder, B
     let stream_tasks = async_stream::stream! {
         let mut bytes = BytesMut::new();
         while let Some(msg) = &rx_callback.recv().await {
-            if msg.to_string() == END {
-                break
-            }
-            bytes.extend_from_slice(msg.as_bytes());
-            let byte = bytes.split().freeze();
-            yield Ok::<Bytes, Box<dyn Error>>(byte);
+            match msg {
+                llm::InferenceResponse::PromptToken(t) | llm::InferenceResponse::InferredToken(t) => {
+                    bytes.extend_from_slice(t.as_bytes());
+                    let byte = bytes.split().freeze();
+                    yield Ok::<Bytes, Box<dyn Error>>(byte);
+                }
+                _ => break
+            }
         }
     };

@@ -81,7 +84,7 @@ async fn chat(chat_request: web::Query<ChatRequest>) -> Result<impl Responder, B
 fn infer(
     args: &Args,
     rx_infer: std::sync::mpsc::Receiver<String>,
-    tx_callback: tokio::sync::mpsc::Sender<String>,
+    tx_callback: tokio::sync::mpsc::Sender<llm::InferenceResponse>,
 ) -> Result<()> {
     let vocabulary_source = args.to_vocabulary_source();
     let model_architecture = args.model_architecture;
@@ -118,17 +121,13 @@ fn infer(
             maximum_token_count: None,
         },
         &mut Default::default(),
-        |r| match r {
-            llm::InferenceResponse::PromptToken(t)
-            | llm::InferenceResponse::InferredToken(t) => {
-                tx_callback.blocking_send(t.to_string());
-                Ok(llm::InferenceFeedback::Continue)
-            }
-            _ => Ok(llm::InferenceFeedback::Continue),
+        |r| {
+            tx_callback.blocking_send(r);
+            Ok(llm::InferenceFeedback::Continue)
         },
     );

-    tx_callback.blocking_send(END.to_string());
+    tx_callback.blocking_send(llm::InferenceResponse::EotToken);
     println!("INFER END");
 }

@@ -138,22 +137,20 @@ fn infer(
 #[actix_web::main]
 async fn main() -> std::io::Result<()> {
     let args = Args::parse();
     println!("{args:#?}");

     let (tx_infer, rx_infer) = sync_channel::<String>(3);
-    let (tx_callback, rx_callback) = channel::<String>(3);
+    let (tx_callback, rx_callback) = channel::<llm::InferenceResponse>(3);

     TX_INFER.set(Arc::new(Mutex::new(tx_infer))).unwrap();
     RX_CALLBACK.set(Arc::new(Mutex::new(rx_callback))).unwrap();

-    //"/home/jovyan/rust-src/llm-ui/models/ggml-model-q4_0.binA
-    let c_args = args.clone();
+    let host = args.host.to_string();
+    let port: u16 = args.port.clone();

     thread::spawn(move || {
         infer(&args, rx_infer, tx_callback);
     });

-    let host = c_args.host.to_string();
-    let port: u16 = c_args.port;

     HttpServer::new(|| {
         App::new()
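
For reference, here is a minimal, self-contained sketch of the pattern this commit moves to: a worker thread pushes typed responses into a bounded tokio channel with blocking_send, and the async consumer stops on an end-of-text variant instead of comparing against a sentinel string. The InferenceResponse enum and the token loop below are illustrative stand-ins, not the llm crate's actual types.

use tokio::sync::mpsc;

// Local stand-in for llm::InferenceResponse (an assumption for this sketch,
// not the crate's real type).
enum InferenceResponse {
    InferredToken(String),
    EotToken,
}

#[tokio::main]
async fn main() {
    // Bounded channel with capacity 3, as in the commit: blocking_send
    // blocks the worker when the consumer falls behind.
    let (tx, mut rx) = mpsc::channel::<InferenceResponse>(3);

    // Stand-in for the inference thread: emit tokens, then the end marker.
    let worker = std::thread::spawn(move || {
        for t in ["Hello", ", ", "world"] {
            // blocking_send is the sync-side API of a tokio mpsc Sender,
            // safe to call from a plain OS thread.
            tx.blocking_send(InferenceResponse::InferredToken(t.to_string()))
                .expect("receiver dropped");
        }
        tx.blocking_send(InferenceResponse::EotToken)
            .expect("receiver dropped");
    });

    // Consumer side: stream tokens until a non-token response arrives,
    // mirroring the match in the rewritten chat handler.
    while let Some(msg) = rx.recv().await {
        match msg {
            InferenceResponse::InferredToken(t) => print!("{t}"),
            _ => break,
        }
    }
    println!();
    worker.join().unwrap();
}

Because the channel is bounded, the model thread gets natural backpressure when the HTTP client reads slowly, the same behavior the capacity-3 channels in main provide.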
