Skip to content

Commit

Permalink
iOS downloader integration (#183)
Browse files Browse the repository at this point in the history
Co-authored-by: Yaxing Cai <[email protected]>
  • Loading branch information
tqchen and cyx-6 authored May 19, 2023
1 parent 058cbbf commit 5351693
Show file tree
Hide file tree
Showing 15 changed files with 440 additions and 230 deletions.
36 changes: 25 additions & 11 deletions cpp/llm_chat.cc
Original file line number Diff line number Diff line change
Expand Up @@ -477,12 +477,6 @@ class LLMChat {
}

void Reload(tvm::runtime::Module executable, String model_path) {
// Step 0. Clear the previously allocated memory.
const PackedFunc* fclear_memory_manager =
tvm::runtime::Registry::Get("vm.builtin.memory_manager.clear");
ICHECK(fclear_memory_manager) << "Cannot find env function vm.builtin.memory_manager.clear";
(*fclear_memory_manager)();

// Step 1. Set tokenizer.
this->tokenizer_ = TokenizerFromPath(model_path);

Expand Down Expand Up @@ -744,7 +738,6 @@ class LLMChat {

std::vector<int32_t> prompt_tokens = this->GetPromptTokens();
int64_t token_len = static_cast<int64_t>(prompt_tokens.size());

tvm::runtime::NDArray input_data = this->GetInputTokenNDArray(prompt_tokens);

total_seq_len_ += token_len;
Expand Down Expand Up @@ -820,7 +813,7 @@ class LLMChat {
return encounter_stop_str_ || total_seq_len_ >= max_window_size_;
}

size_t FindEffectiveUTF8Pos(const std::string& s, size_t start_pos) {
size_t FindEffectiveUTF8Pos(const std::string& s) {
int pos = s.size() - 1;
for (; pos >= 0; pos--) {
if ((s[pos] & 0x80) == 0x00) {
Expand All @@ -840,8 +833,16 @@ class LLMChat {

// Returns the text of output_message_ cropped for display.
// NOTE(review): this span comes from a diff view; the first
// `cropped_message` statement below appears to be the pre-change form and
// the statements after it the replacement — confirm against the real file.
std::string GetMessage() {
// remove non-utf8 characters
std::string cropped_message =
output_message_.substr(0, FindEffectiveUTF8Pos(output_message_, 0));
// End of the usable text, as reported by FindEffectiveUTF8Pos.
size_t effective_end = FindEffectiveUTF8Pos(output_message_);
// Move the end backwards past any trailing '\n' characters.
while (effective_end > 0 && output_message_[effective_end - 1] == '\n') {
--effective_end;
}
// Move the start forwards past any leading ' ' characters.
size_t effective_begin = 0;
while (effective_begin < effective_end && output_message_[effective_begin] == ' ') {
++effective_begin;
}
// Keep only the half-open range [effective_begin, effective_end).
std::string cropped_message = output_message_.substr(
effective_begin, effective_end - effective_begin);
return cropped_message;
}

Expand Down Expand Up @@ -1109,17 +1110,30 @@ class LLMChat {

class LLMChatModule : public ModuleNode {
public:
// clear global memory manager
static void ClearGlobalMemoryManager() {
// Step 0. Clear the previously allocated memory.
const PackedFunc* fclear_memory_manager =
tvm::runtime::Registry::Get("vm.builtin.memory_manager.clear");
ICHECK(fclear_memory_manager) << "Cannot find env function vm.builtin.memory_manager.clear";
(*fclear_memory_manager)();
}

// overrides
PackedFunc GetFunction(const std::string& name, const ObjectPtr<Object>& sptr_to_self) final {
if (name == "reload") {
return PackedFunc([this, sptr_to_self](TVMArgs args, TVMRetValue* rv) {
ICHECK_EQ(args.size(), 2);
chat_ = nullptr;
ClearGlobalMemoryManager();
chat_ = std::make_unique<LLMChat>(LLMChat(device_));
chat_->Reload(args[0], args[1]);
});
} else if (name == "unload") {
return PackedFunc([this, sptr_to_self](TVMArgs args, TVMRetValue* rv) { chat_ = nullptr; });
return PackedFunc([this, sptr_to_self](TVMArgs args, TVMRetValue* rv) {
chat_ = nullptr;
ClearGlobalMemoryManager();
});
} else if (name == "evaluate") {
return PackedFunc(
[this, sptr_to_self](TVMArgs args, TVMRetValue* rv) { GetChat()->Evaluate(); });
Expand Down
28 changes: 14 additions & 14 deletions ios/MLCChat.xcodeproj/project.pbxproj
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
objects = {

/* Begin PBXBuildFile section */
1453A4C92A1353F1001B909F /* ModelConfig in CopyFiles */ = {isa = PBXBuildFile; fileRef = 1453A4C72A1353D7001B909F /* ModelConfig */; };
1453A4CF2A1354B9001B909F /* StartView.swift in Sources */ = {isa = PBXBuildFile; fileRef = 1453A4CA2A1354B9001B909F /* StartView.swift */; };
1453A4D02A1354B9001B909F /* ModelView.swift in Sources */ = {isa = PBXBuildFile; fileRef = 1453A4CB2A1354B9001B909F /* ModelView.swift */; };
1453A4D12A1354B9001B909F /* StartState.swift in Sources */ = {isa = PBXBuildFile; fileRef = 1453A4CC2A1354B9001B909F /* StartState.swift */; };
Expand All @@ -16,6 +15,7 @@
C06A74E429F99E5500BC4BE6 /* LLMChat.mm in Sources */ = {isa = PBXBuildFile; fileRef = C06A74E329F99E5500BC4BE6 /* LLMChat.mm */; };
C06A74F229F9A78800BC4BE6 /* dist in CopyFiles */ = {isa = PBXBuildFile; fileRef = C06A74E029F99C9F00BC4BE6 /* dist */; };
C06A74F429F9BE7A00BC4BE6 /* ThreadWorker.swift in Sources */ = {isa = PBXBuildFile; fileRef = C06A74F329F9BE7A00BC4BE6 /* ThreadWorker.swift */; };
C09834192A16F4E000A05B51 /* app-config.json in CopyFiles */ = {isa = PBXBuildFile; fileRef = C09834182A16F4CB00A05B51 /* app-config.json */; };
C0D643B329F99A7F004DDAA4 /* MLCChatApp.swift in Sources */ = {isa = PBXBuildFile; fileRef = C0D643B229F99A7F004DDAA4 /* MLCChatApp.swift */; };
C0D643B729F99A80004DDAA4 /* Assets.xcassets in Resources */ = {isa = PBXBuildFile; fileRef = C0D643B629F99A80004DDAA4 /* Assets.xcassets */; };
C0D643BA29F99A80004DDAA4 /* Preview Assets.xcassets in Resources */ = {isa = PBXBuildFile; fileRef = C0D643B929F99A80004DDAA4 /* Preview Assets.xcassets */; };
Expand All @@ -31,7 +31,7 @@
dstPath = "";
dstSubfolderSpec = 7;
files = (
1453A4C92A1353F1001B909F /* ModelConfig in CopyFiles */,
C09834192A16F4E000A05B51 /* app-config.json in CopyFiles */,
C06A74F229F9A78800BC4BE6 /* dist in CopyFiles */,
);
runOnlyForDeploymentPostprocessing = 0;
Expand All @@ -49,17 +49,17 @@
/* End PBXCopyFilesBuildPhase section */

/* Begin PBXFileReference section */
1453A4C72A1353D7001B909F /* ModelConfig */ = {isa = PBXFileReference; lastKnownFileType = folder; name = ModelConfig; path = MLCChat/ModelConfig; sourceTree = "<group>"; };
1453A4CA2A1354B9001B909F /* StartView.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; name = StartView.swift; path = MLCChat/StartView.swift; sourceTree = "<group>"; };
1453A4CB2A1354B9001B909F /* ModelView.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; name = ModelView.swift; path = MLCChat/ModelView.swift; sourceTree = "<group>"; };
1453A4CC2A1354B9001B909F /* StartState.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; name = StartState.swift; path = MLCChat/StartState.swift; sourceTree = "<group>"; };
1453A4CD2A1354B9001B909F /* ModelConfig.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; name = ModelConfig.swift; path = MLCChat/ModelConfig.swift; sourceTree = "<group>"; };
1453A4CE2A1354B9001B909F /* ModelState.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; name = ModelState.swift; path = MLCChat/ModelState.swift; sourceTree = "<group>"; };
1453A4CA2A1354B9001B909F /* StartView.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = StartView.swift; sourceTree = "<group>"; };
1453A4CB2A1354B9001B909F /* ModelView.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = ModelView.swift; sourceTree = "<group>"; };
1453A4CC2A1354B9001B909F /* StartState.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = StartState.swift; sourceTree = "<group>"; };
1453A4CD2A1354B9001B909F /* ModelConfig.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = ModelConfig.swift; sourceTree = "<group>"; };
1453A4CE2A1354B9001B909F /* ModelState.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = ModelState.swift; sourceTree = "<group>"; };
C06A74E029F99C9F00BC4BE6 /* dist */ = {isa = PBXFileReference; lastKnownFileType = folder; path = dist; sourceTree = "<group>"; };
C06A74E229F99E5500BC4BE6 /* MLCChat-Bridging-Header.h */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.c.h; path = "MLCChat-Bridging-Header.h"; sourceTree = "<group>"; };
C06A74E329F99E5500BC4BE6 /* LLMChat.mm */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.objcpp; path = LLMChat.mm; sourceTree = "<group>"; };
C06A74E629F9A1DF00BC4BE6 /* MLCChat.entitlements */ = {isa = PBXFileReference; lastKnownFileType = text.plist.entitlements; path = MLCChat.entitlements; sourceTree = "<group>"; };
C06A74F329F9BE7A00BC4BE6 /* ThreadWorker.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ThreadWorker.swift; sourceTree = "<group>"; };
C09834182A16F4CB00A05B51 /* app-config.json */ = {isa = PBXFileReference; lastKnownFileType = text.json; path = "app-config.json"; sourceTree = "<group>"; };
C0D643AF29F99A7F004DDAA4 /* MLCChat.app */ = {isa = PBXFileReference; explicitFileType = wrapper.application; includeInIndex = 0; path = MLCChat.app; sourceTree = BUILT_PRODUCTS_DIR; };
C0D643B229F99A7F004DDAA4 /* MLCChatApp.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = MLCChatApp.swift; sourceTree = "<group>"; };
C0D643B629F99A80004DDAA4 /* Assets.xcassets */ = {isa = PBXFileReference; lastKnownFileType = folder.assetcatalog; path = Assets.xcassets; sourceTree = "<group>"; };
Expand All @@ -83,12 +83,6 @@
C0D643A629F99A7F004DDAA4 = {
isa = PBXGroup;
children = (
1453A4CD2A1354B9001B909F /* ModelConfig.swift */,
1453A4CE2A1354B9001B909F /* ModelState.swift */,
1453A4CB2A1354B9001B909F /* ModelView.swift */,
1453A4CC2A1354B9001B909F /* StartState.swift */,
1453A4CA2A1354B9001B909F /* StartView.swift */,
1453A4C72A1353D7001B909F /* ModelConfig */,
C06A74E029F99C9F00BC4BE6 /* dist */,
C0D643B129F99A7F004DDAA4 /* MLCChat */,
C0D643B029F99A7F004DDAA4 /* Products */,
Expand All @@ -107,6 +101,12 @@
C0D643B129F99A7F004DDAA4 /* MLCChat */ = {
isa = PBXGroup;
children = (
C09834182A16F4CB00A05B51 /* app-config.json */,
1453A4CD2A1354B9001B909F /* ModelConfig.swift */,
1453A4CE2A1354B9001B909F /* ModelState.swift */,
1453A4CB2A1354B9001B909F /* ModelView.swift */,
1453A4CC2A1354B9001B909F /* StartState.swift */,
1453A4CA2A1354B9001B909F /* StartView.swift */,
C06A74E629F9A1DF00BC4BE6 /* MLCChat.entitlements */,
C0D643C729F99B34004DDAA4 /* MessageView.swift */,
C06A74E329F99E5500BC4BE6 /* LLMChat.mm */,
Expand Down
54 changes: 37 additions & 17 deletions ios/MLCChat/ChatState.swift
Original file line number Diff line number Diff line change
Expand Up @@ -29,16 +29,13 @@ class ChatState : ObservableObject {
private var stopLock = NSLock();
private var requestedReset = false;
private var stopRequested = false;
private var gpuVRAMDetectionPass = false;

private var reloadReady = false;
private var modelLib = "";
private var modelPath = "";

// Sets the worker thread's quality of service and starts it.
// NOTE(review): this span comes from a diff view; the hard-coded model
// reload below (from the TODO onward) appears to be the part this commit
// removes — confirm against the real file.
init() {
// Give the worker user-interactive priority before it starts running.
threadWorker.qualityOfService = QualityOfService.userInteractive;
threadWorker.start()
// TODO(change to state dependent)
let modelLib = "RedPajama-INCITE-Chat-3B-v1-q4f16_0";
let modelPath = Bundle.main.bundlePath + "/dist/RedPajama-INCITE-Chat-3B-v1-q4f16_0";
self.mainReload(modelName: "RedPajama-3B", modelLib: modelLib, modelPath: modelPath)
}

// reset all chat state
Expand All @@ -50,14 +47,27 @@ class ChatState : ObservableObject {
self.requestedReset = false;
}

func mainReload(modelName: String, modelLib: String, modelPath: String, estimatedMemReq : Int64 = 4000000000) {
if (self.inProgress) {
func mainReload(modelName: String, modelLib: String, modelPath: String, estimatedMemReq : Int64) {
if (self.reloadReady &&
self.modelLib == modelLib &&
self.modelPath == modelPath &&
self.modelName == modelName) {
return;
}
self.mainResetChat()
self.gpuVRAMDetectionPass = false;
// request stop regardless of the state
// so that the previous action can finish soon
if (self.inProgress) {
self.stopLock.lock()
self.stopRequested = true;
self.stopLock.unlock()
}
self.mainResetChat();
// we are not reload ready
self.reloadReady = false;
self.inProgress = true;
self.modelName = modelName;
self.modelLib = modelLib;
self.modelPath = modelPath;

threadWorker.push {[self] in
self.updateReply(role: MessageRole.bot, message: "[System] Initalize...")
Expand All @@ -71,13 +81,15 @@ class ChatState : ObservableObject {
"Sorry, the system do not have" + reqMem + " memory as requested, " +
"so we cannot initialize this model on this device."
)
self.messages.append(MessageData(role: MessageRole.bot, message: errMsg))
self.gpuVRAMDetectionPass = false
self.inProgress = true
DispatchQueue.main.sync {
self.messages.append(MessageData(role: MessageRole.bot, message: errMsg))
self.reloadReady = false
self.inProgress = true
}
return
}
self.gpuVRAMDetectionPass = true
backend.reload(modelLib, modelPath: modelPath)
self.reloadReady = true
self.updateReply(role: MessageRole.bot, message: "[System] Ready to chat")
self.commitReply()
self.markFinish()
Expand Down Expand Up @@ -124,6 +136,13 @@ class ChatState : ObservableObject {
let needStop = self.stopRequested;
self.stopLock.unlock()
if (needStop) {
let forceStop = !self.reloadReady;
// if we are not reload ready
// this means we were force-stopped during reload
// do not do anything to refresh UX
if (forceStop) {
return
}
break;
}
}
Expand All @@ -133,12 +152,13 @@ class ChatState : ObservableObject {
DispatchQueue.main.sync { [runtimeText] in
self.infoText = runtimeText;
}

self.markFinish()
};
}

func generate(prompt: String) {
if (!self.gpuVRAMDetectionPass) {
if (!self.reloadReady) {
return
}
self.inProgress = true
Expand All @@ -147,7 +167,7 @@ class ChatState : ObservableObject {
}

func requestStop() {
if (!self.gpuVRAMDetectionPass) {
if (!self.reloadReady) {
return
}
if (self.inProgress) {
Expand All @@ -158,7 +178,7 @@ class ChatState : ObservableObject {
}

func resetChat() {
if (!self.gpuVRAMDetectionPass) {
if (!self.reloadReady) {
return
}
if (self.inProgress) {
Expand Down
98 changes: 47 additions & 51 deletions ios/MLCChat/ChatView.swift
Original file line number Diff line number Diff line change
Expand Up @@ -10,62 +10,58 @@ struct ChatView: View {
@EnvironmentObject var state: ChatState
@Namespace var bottomID;
@Namespace var infoID;

init() {
// UIScrollView.appearance().bounces = false
}

init() {}

var body: some View {
NavigationView {
VStack {
Text(state.infoText)
.multilineTextAlignment(.center)
.opacity(0.5)
.listRowSeparator(.hidden)
ScrollViewReader { proxy in
ScrollView {
// Hack:rotate and reverse the inner view
// then rotate inverse the scroll
// so the result auto-scrolls
//
// This works more smoothly than scrollTo
// when there are a lot of text
if (state.unfinishedRespondMessage != "") {
MessageView(
role: state.unfinishedRespondRole,
message: state.unfinishedRespondMessage
).rotationEffect(.radians(.pi))
VStack {
Text(state.infoText)
.multilineTextAlignment(.center)
.opacity(0.5)
.listRowSeparator(.hidden)
ScrollViewReader { proxy in
ScrollView {
// Hack:rotate and reverse the inner view
// then rotate inverse the scroll
// so the result auto-scrolls
//
// This works more smoothly than scrollTo
// when there are a lot of text
if (state.unfinishedRespondMessage != "") {
MessageView(
role: state.unfinishedRespondRole,
message: state.unfinishedRespondMessage
).rotationEffect(.radians(.pi))
.scaleEffect(x: -1, y: 1, anchor: .center)
}
ForEach(state.messages.reversed()) { msg in
MessageView(role: msg.role, message: msg.message)
.rotationEffect(.radians(.pi))
.scaleEffect(x: -1, y: 1, anchor: .center)
}
ForEach(state.messages.reversed()) { msg in
MessageView(role: msg.role, message: msg.message)
.rotationEffect(.radians(.pi))
.scaleEffect(x: -1, y: 1, anchor: .center)
}
}.rotationEffect(.radians(.pi))
.scaleEffect(x: -1, y: 1, anchor: .center)
}

HStack {
TextField("Inputs...", text: $inputMessage, axis: .vertical)
.textFieldStyle(RoundedBorderTextFieldStyle())
.frame(minHeight: CGFloat(30))
.focused($inputIsFocused)
Button("Send") {
self.inputIsFocused = false
generateMessage()
}.bold().opacity(state.inProgress ? 0.5 : 1)
}.frame(minHeight: CGFloat(70)).padding()
}
.navigationBarTitle("MLC Chat: " + state.modelName, displayMode: .inline)
.toolbar{
ToolbarItem(placement: .navigationBarLeading) {
Button("Reset") {
resetChat()
}
.opacity(0.9)
.padding()
}.rotationEffect(.radians(.pi))
.scaleEffect(x: -1, y: 1, anchor: .center)
}

HStack {
TextField("Inputs...", text: $inputMessage, axis: .vertical)
.textFieldStyle(RoundedBorderTextFieldStyle())
.frame(minHeight: CGFloat(30))
.focused($inputIsFocused)
Button("Send") {
self.inputIsFocused = false
generateMessage()
}.bold().opacity(state.inProgress ? 0.5 : 1)
}.frame(minHeight: CGFloat(70)).padding()
}
.navigationBarTitle(state.modelName, displayMode: .inline)
.toolbar{
ToolbarItem(placement: .navigationBarTrailing) {
Button("Reset") {
resetChat()
}
.opacity(0.9)
.padding()
}
}
}
Expand Down
Loading

0 comments on commit 5351693

Please sign in to comment.