Commit

Initialize chat module from config JSON string (#166)
This PR supports initializing the chat module from a JSON-structured
string.
* The CLI is updated to use the new JSON-based initialization.
* The iOS/Android app sides keep using the previous way of
initialization. The previous function is renamed to
`init_chat_legacy`, which we will remove after updating the app and web
sides accordingly.
* This PR reverts part of the previous metadata PR, since that PR made a
wrong assumption about the conversation template.
* This PR removes `stream_interval` from the chat module, as it is unused
there. `stream_interval` is now kept only on the CLI side.
MasterJH5574 authored May 17, 2023
1 parent 0078c52 commit 47a7a11
Showing 9 changed files with 70 additions and 29 deletions.
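
To make the new flow concrete, below is a minimal sketch of host-side initialization pieced together from the build.py and cpp/cli_main.cc changes in this commit. The helper name, the includes, and how `chat_mod` is created are illustrative assumptions; the config fields and the `init_chat` call follow the diff.

#include <fstream>
#include <sstream>
#include <string>

#include <tvm/runtime/container/string.h>
#include <tvm/runtime/module.h>

// Illustrative only: load the JSON emitted by dump_default_mlc_llm_config in
// build.py, e.g.
//   {"conv_template": "vicuna_v1.1", "temperature": 0.7, "top_p": 0.95,
//    "mean_gen_len": 128, "shift_fill_factor": 0.3}
// and hand the whole document to the chat module in a single call.
void InitChatFromConfig(tvm::runtime::Module chat_mod, const std::string& config_path) {
  // Slurp params/mlc-chat-config.json into one string, as the CLI now does.
  std::ifstream config_istream(config_path);
  std::ostringstream config_ostream;
  config_ostream << config_istream.rdbuf();
  std::string config_str = config_ostream.str();

  // The new entry point parses sampling options and the conversation template
  // from this string inside llm_chat.cc; model metadata (max_window_size,
  // stop_tokens, ...) still comes from the module's own metadata function.
  chat_mod.GetFunction("init_chat")(tvm::runtime::String(config_str));
}
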
4 changes: 2 additions & 2 deletions android/MLCChat/app/src/main/java/ai/mlc/mlcchat/LLMChat.java
@@ -50,12 +50,12 @@ public void Init() {
assert stopped_func_ != null;
assert runtime_stats_text_func_ != null;

String conv_template = "vicuna_v1.1";
double temperature = 0.7;
double top_p = 0.95;
int stream_interval = 1;
int mean_gen_len = 128;
double shift_fill_factor = 0.2;
llm_chat_.getFunction("init_chat").pushArg(temperature).pushArg(top_p).pushArg(stream_interval).pushArg(mean_gen_len).pushArg(shift_fill_factor).invoke();
llm_chat_.getFunction("init_chat_legacy").pushArg(conv_template).pushArg(temperature).pushArg(top_p).pushArg(mean_gen_len).pushArg(shift_fill_factor).invoke();

systemlib_func.release();
lib.release();
3 changes: 1 addition & 2 deletions build.py
@@ -229,13 +229,12 @@ def dump_default_mlc_llm_config(args):
config["conv_template"] = args.conv_template
config["temperature"] = 0.7
config["top_p"] = 0.95
config["stream_interval"] = 2
config["mean_gen_len"] = 128
config["shift_fill_factor"] = 0.3
dump_path = os.path.join(args.artifact_path, "params", "mlc-chat-config.json")
with open(dump_path, "w") as outfile:
json.dump(config, outfile, indent=4)
print(f"Finish exporting mlc_llm_config to {dump_path}")
print(f"Finish exporting chat config to {dump_path}")


def build(mod_deploy: tvm.IRModule, args: argparse.Namespace) -> None:
15 changes: 9 additions & 6 deletions cpp/cli_main.cc
@@ -14,6 +14,7 @@
#include <bitset>
#include <cstdio>
#include <filesystem>
#include <fstream>
#include <optional>
#include <string>
#include <vector>
@@ -147,12 +148,9 @@ void PrintSpecialCommands() {
* \param temperature The temperature to use for sampling.
* \param top_p The top_p to use for sampling.
*/
void Chat(tvm::runtime::Module chat_mod, double temperature = 0.7, double top_p = 0.95,
int64_t stream_interval = 2, int max_window_size = 768, int mean_gen_len = 128,
double shift_fill_factor = 0.3) {
void Chat(tvm::runtime::Module chat_mod, std::string config_str, int stream_interval = 2) {
// initialize chat context
chat_mod.GetFunction("init_chat")(temperature, top_p, stream_interval, mean_gen_len,
shift_fill_factor);
chat_mod.GetFunction("init_chat")(tvm::String(config_str));
auto f_stop = chat_mod.GetFunction("stopped");
auto f_encode = chat_mod.GetFunction("encode");
auto f_decode = chat_mod.GetFunction("decode");
@@ -286,6 +284,11 @@ int main(int argc, char* argv[]) {
return 1;
}
std::cout << "Use config " << config_path_opt.value().string() << std::endl;
std::ifstream config_istream(config_path_opt.value().c_str());
std::ostringstream config_ostream;
assert(config_istream);
config_ostream << config_istream.rdbuf();
std::string config_str = config_ostream.str();
std::filesystem::path model_path = config_path_opt.value().parent_path();

// Locate the library.
@@ -329,7 +332,7 @@
if (args.get<bool>("--evaluate")) {
chat_mod.GetFunction("evaluate")();
} else {
Chat(chat_mod);
Chat(chat_mod, config_str);
}
} catch (const std::runtime_error& err) {
// catch exception so error message
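
With `stream_interval` removed from the chat module, only the CLI's decode loop in `Chat` still uses it. That loop body is outside the hunks shown here, so the following is only a rough sketch of the cadence it controls; the `get_message` getter name and the `encode`/`decode` signatures are assumptions rather than something this diff confirms.

#include <iostream>
#include <string>

#include <tvm/runtime/module.h>

// Hypothetical helper: stream one reply, flushing partial output only every
// `stream_interval` decode steps. The chat module itself no longer knows about
// stream_interval; it is purely a CLI-side printing cadence.
void StreamReply(tvm::runtime::Module chat_mod, const std::string& prompt,
                 int stream_interval = 2) {
  auto f_encode = chat_mod.GetFunction("encode");
  auto f_decode = chat_mod.GetFunction("decode");
  auto f_stopped = chat_mod.GetFunction("stopped");
  auto f_get_message = chat_mod.GetFunction("get_message");  // assumed name

  f_encode(prompt);
  std::string printed;
  for (int step = 0; !static_cast<bool>(f_stopped()); ++step) {
    f_decode();
    if (step % stream_interval == 0 || static_cast<bool>(f_stopped())) {
      // Print only the newly generated suffix of the running message.
      std::string msg = f_get_message();
      std::cout << msg.substr(printed.size()) << std::flush;
      printed = msg;
    }
  }
  std::cout << std::endl;
}
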
64 changes: 55 additions & 9 deletions cpp/llm_chat.cc
@@ -409,21 +409,69 @@ class LLMChatModule : public ModuleNode {
[this, sptr_to_self](TVMArgs args, TVMRetValue* rv) { this->DecodeStep(); });
} else if (name == "init_chat") {
return PackedFunc([this, sptr_to_self](TVMArgs args, TVMRetValue* rv) {
// Process metadata
std::string metadata_str = this->GetMetadata();
picojson::value metadata_info;
picojson::parse(metadata_info, metadata_str);
auto metadata = metadata_info.get<picojson::object>();
ICHECK(metadata["model_name"].is<std::string>());
ICHECK(metadata["conv_template"].is<std::string>());
ICHECK(metadata["max_window_size"].is<int64_t>());
ICHECK(metadata["add_prefix_space"].is<bool>());
ICHECK(metadata["stop_tokens"].is<picojson::array>());

this->model_name_ = metadata["model_name"].get<std::string>();
std::string conv_template = metadata["conv_template"].get<std::string>();
this->max_window_size_ = metadata["max_window_size"].get<int64_t>();
this->add_prefix_space_ = metadata["add_prefix_space"].get<bool>();
auto stop_tokens = metadata["stop_tokens"].get<picojson::array>();
this->stop_tokens_.reserve(stop_tokens.size());
for (const picojson::value& stop_token : stop_tokens) {
ICHECK(stop_token.is<int64_t>());
this->stop_tokens_.push_back(static_cast<int32_t>(stop_token.get<int64_t>()));
}

// Process config json string
ICHECK_EQ(args.size(), 1);
ObjectRef config_obj = args[0];
std::string config_str = std::string(Downcast<String>(config_obj));
picojson::value config_info;
picojson::parse(config_info, config_str);
auto config = config_info.get<picojson::object>();
ICHECK(config["conv_template"].is<std::string>());
ICHECK(config["temperature"].is<double>());
ICHECK(config["top_p"].is<double>());
ICHECK(config["mean_gen_len"].is<int64_t>());
ICHECK(config["shift_fill_factor"].is<double>());
std::string conv_template = config["conv_template"].get<std::string>();
this->temperature_ = config["temperature"].get<double>();
this->top_p_ = config["top_p"].get<double>();
this->mean_gen_len_ = config["mean_gen_len"].get<int64_t>();
this->shift_fill_factor_ = config["shift_fill_factor"].get<double>();

this->conversation_ = Conversation::Create(conv_template);
this->ClearKVCache();
this->total_seq_len_ = 0;
this->start_pos_ = 0;
this->cur_pos_ = 0;
this->add_bos_ = true;
this->stop_str_ =
this->conversation_.separator_style == Conversation::SeparatorStyle::kSingle
? this->conversation_.sep
: this->conversation_.sep2;
});
} else if (name == "init_chat_legacy") {
// TODO: remove the legacy initialization func after updating app and web sides.
return PackedFunc([this, sptr_to_self](TVMArgs args, TVMRetValue* rv) {
// Process metadata
std::string metadata_str = this->GetMetadata();
picojson::value metadata_info;
picojson::parse(metadata_info, metadata_str);
auto metadata = metadata_info.get<picojson::object>();
ICHECK(metadata["model_name"].is<std::string>());
ICHECK(metadata["max_window_size"].is<int64_t>());
ICHECK(metadata["add_prefix_space"].is<bool>());
ICHECK(metadata["stop_tokens"].is<picojson::array>());
this->model_name_ = metadata["model_name"].get<std::string>();
this->max_window_size_ = metadata["max_window_size"].get<int64_t>();
this->add_prefix_space_ = metadata["add_prefix_space"].get<bool>();
auto stop_tokens = metadata["stop_tokens"].get<picojson::array>();
this->stop_tokens_.reserve(stop_tokens.size());
for (const picojson::value& stop_token : stop_tokens) {
@@ -432,12 +480,12 @@ class LLMChatModule : public ModuleNode {
}

ICHECK_EQ(args.size(), 5);
this->conversation_ = Conversation::Create(conv_template);
this->temperature_ = args[0];
this->top_p_ = args[1];
this->stream_interval_ = args[2];
this->conversation_ = Conversation::Create(args[0]);
this->temperature_ = args[1];
this->top_p_ = args[2];
this->mean_gen_len_ = args[3];
this->shift_fill_factor_ = args[4];

this->ClearKVCache();
this->total_seq_len_ = 0;
this->start_pos_ = 0;
@@ -931,8 +979,6 @@ class LLMChatModule : public ModuleNode {
double temperature_{0.8};
// top_p
double top_p_{0.95};
// stream interval
int64_t stream_interval_{1};
// next_token
int32_t next_token_{0};
// output ids till now (refresh after encoding step)
6 changes: 3 additions & 3 deletions ios/MLCChat/LLMChat.mm
@@ -45,13 +45,13 @@
ICHECK(stopped_func_ != nullptr);
ICHECK(runtime_stats_text_func_ != nullptr);

std::string conv_template = "vicuna_v1.1";
double temperature = 0.7;
double top_p = 0.95;
int stream_interval = 1;
int mean_gen_len = 128;
double shift_fill_factor = 0.2;
llm_chat_->GetFunction("init_chat")(temperature, top_p, stream_interval, mean_gen_len,
shift_fill_factor);
llm_chat_->GetFunction("init_chat_legacy")(conv_template, temperature, top_p, mean_gen_len,
shift_fill_factor);
}

void Evaluate() {
2 changes: 0 additions & 2 deletions mlc_llm/relax_model/commons.py
@@ -7,15 +7,13 @@
def create_metadata_func(
bb: relax.BlockBuilder,
model_name: str,
conv_template: str,
max_window_size: int,
stop_tokens: List[int],
add_prefix_space: bool,
):
metadata = json.dumps(
{
"model_name": model_name,
"conv_template": conv_template,
"max_window_size": max_window_size,
"stop_tokens": stop_tokens,
"add_prefix_space": add_prefix_space,
3 changes: 0 additions & 3 deletions mlc_llm/relax_model/gpt_neox.py
@@ -587,10 +587,8 @@ def get_model(model_name: str, model_path: str, dtype: str, hf_config):
from transformers import AutoModelForCausalLM # type: ignore[import]

if model_name.startswith("dolly"):
conv_template = "dolly"
stop_tokens = [2]
elif model_name.startswith("stablelm"):
conv_template = "stablelm"
stop_tokens = [50278, 50279, 50277, 1, 0]
else:
raise ValueError(f"Unsupported model {model_name}")
@@ -646,7 +644,6 @@ def get_model(model_name: str, model_path: str, dtype: str, hf_config):
create_metadata_func(
bb,
model_name=model_name,
conv_template=conv_template,
max_window_size=config.max_sequence_length,
stop_tokens=stop_tokens,
add_prefix_space=False,
1 change: 0 additions & 1 deletion mlc_llm/relax_model/llama.py
@@ -682,7 +682,6 @@ def get_model(args, hf_config):
create_metadata_func(
bb,
model_name=model_name,
conv_template="vicuna_v1.1" if model_name.startswith("vicuna") else "",
max_window_size=config.max_sequence_length,
stop_tokens=[2],
add_prefix_space=False,
1 change: 0 additions & 1 deletion mlc_llm/relax_model/moss.py
@@ -617,7 +617,6 @@ def get_model(args, hf_config):
create_metadata_func(
bb,
model_name=model_name,
conv_template="moss",
max_window_size=config.max_sequence_length,
stop_tokens=[106068],
add_prefix_space=True,
