Skip to content

Commit

Permalink
Add model weight variant in iOS (#187)
Browse files Browse the repository at this point in the history
Co-authored-by: Yaxing Cai <[email protected]>
  • Loading branch information
tqchen and cyx-6 authored May 19, 2023
1 parent 4ebfba9 commit da63f2c
Show file tree
Hide file tree
Showing 8 changed files with 153 additions and 64 deletions.
8 changes: 4 additions & 4 deletions ios/MLCChat/ChatState.swift
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ class ChatState : ObservableObject {
self.requestedReset = false;
}

func mainReload(modelName: String, modelLib: String, modelPath: String, estimatedMemReq : Int64) {
func mainReload(modelName: String, modelLib: String, modelPath: String, estimatedVRAMReq : Int64) {
if (self.reloadReady &&
self.modelLib == modelLib &&
self.modelPath == modelPath &&
Expand All @@ -73,12 +73,12 @@ class ChatState : ObservableObject {
self.updateReply(role: MessageRole.bot, message: "[System] Initalize...")
backend.unload();
let vram = os_proc_available_memory()
if (vram < estimatedMemReq) {
if (vram < estimatedVRAMReq) {
let reqMem = String (
format: "%.1fGB", Double(estimatedMemReq) / Double(1 << 20)
format: "%.1fGB", Double(estimatedVRAMReq) / Double(1 << 20)
)
let errMsg = (
"Sorry, the system do not have" + reqMem + " memory as requested, " +
"Sorry, the system cannot provide " + reqMem + " VRAM as requested to the app, " +
"so we cannot initialize this model on this device."
)
DispatchQueue.main.sync {
Expand Down
12 changes: 11 additions & 1 deletion ios/MLCChat/ChatView.swift
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ struct ChatView: View {
@EnvironmentObject var state: ChatState
@Namespace var bottomID;
@Namespace var infoID;
@Environment(\.dismiss) private var dismiss

init() {}

Expand Down Expand Up @@ -54,8 +55,17 @@ struct ChatView: View {
}.bold().opacity(state.inProgress ? 0.5 : 1)
}.frame(minHeight: CGFloat(70)).padding()
}
.navigationBarTitle(state.modelName, displayMode: .inline)
.navigationBarTitle("MLC Chat: " + state.modelName, displayMode: .inline)
.navigationBarBackButtonHidden()
.toolbar{
ToolbarItem(placement: .navigationBarLeading) {
Button() {
dismiss()
} label: {
Image(systemName: "chevron.backward")
}
.buttonStyle(.borderless)
}
ToolbarItem(placement: .navigationBarTrailing) {
Button("Reset") {
resetChat()
Expand Down
17 changes: 0 additions & 17 deletions ios/MLCChat/MLCChatApp.swift
Original file line number Diff line number Diff line change
Expand Up @@ -7,23 +7,6 @@

import SwiftUI

//@main
//struct MLCChatApp: App {
// @StateObject private var state = ChatState()
//
// init() {
// UITableView.appearance().separatorStyle = .none
// UITableView.appearance().tableFooterView = UIView()
// }
//
// var body: some Scene {
// WindowGroup {
// ChatView()
// .environmentObject(state)
// }
// }
//}

@main
struct MLCChatApp: App {
@StateObject private var state = StartState()
Expand Down
9 changes: 8 additions & 1 deletion ios/MLCChat/ModelConfig.swift
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ struct ModelConfig: Codable, Hashable {
let local_id: String
let tokenizer_files: [String]
let display_name: String!
let estimated_memory_req: Int64!
let estimated_vram_req: Int64!
}

struct ParamsRecord: Codable, Hashable {
Expand All @@ -29,4 +29,11 @@ struct ModelRecord: Codable, Hashable {
struct AppConfig: Codable, Hashable {
let model_libs: [String]
var model_list: [ModelRecord]
let add_model_samples: [ModelRecord]
}

struct ExampleModelUrl: Hashable, Identifiable {
let id = UUID()
let model_url: String
let local_id: String
}
12 changes: 6 additions & 6 deletions ios/MLCChat/ModelState.swift
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,10 @@
// ModelState.swift
// MLCChat
//
// Created by Yaxing Cai on 5/15/23.
//

import Foundation

enum ModelInitState{
enum ModelInitState {
case Initializing
case Indexing
case Paused
Expand All @@ -24,7 +22,7 @@ struct DownloadTask: Hashable {
}


class ModelState : ObservableObject, Identifiable{
class ModelState : ObservableObject, Identifiable {
@Published var modelConfig: ModelConfig!
@Published var modelInitState: ModelInitState = .Initializing
@Published var progress: Int = 0
Expand All @@ -46,13 +44,13 @@ class ModelState : ObservableObject, Identifiable{

func reloadChatStateWithThisModel() {
// TODO(tvm-team) consider log optional model name
let estimatedMemReq = modelConfig.estimated_memory_req ?? 4000000000;
let estimatedVRAMReq = modelConfig.estimated_vram_req ?? 4000000000;
let modelName = modelConfig.display_name ?? modelConfig.local_id.components(separatedBy: "-")[0];
self.chatState.mainReload(
modelName: modelName,
modelLib: modelConfig.model_lib,
modelPath: modelDirUrl.path(),
estimatedMemReq: estimatedMemReq)
estimatedVRAMReq: estimatedVRAMReq)
}

func switchToInitializing(modelConfig: ModelConfig, modelUrl: URL?, modelDirUrl: URL) {
Expand Down Expand Up @@ -91,6 +89,7 @@ class ModelState : ObservableObject, Identifiable{
urlOrNil, responseOrNil, errorOrNil in
guard let fileUrl = urlOrNil else { return }
do {
try? self.fileManager.removeItem(at: paramsConfigUrl)
try self.fileManager.moveItem(at: fileUrl, to: paramsConfigUrl)
DispatchQueue.main.async {
self.loadParamsConfig()
Expand Down Expand Up @@ -176,6 +175,7 @@ class ModelState : ObservableObject, Identifiable{

do {
try self.fileManager.createDirectory(at: downloadTask.localUrl.deletingLastPathComponent(), withIntermediateDirectories: true)
try? self.fileManager.removeItem(at: downloadTask.localUrl)
try self.fileManager.moveItem(at: fileUrl, to: downloadTask.localUrl)
} catch {
print(error.localizedDescription)
Expand Down
92 changes: 60 additions & 32 deletions ios/MLCChat/StartState.swift
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,17 @@ import Foundation

class StartState : ObservableObject {
@Published var models = [ModelState]()
@Published var exampleModelUrls = [ExampleModelUrl]()
@Published var alertMessage = ""
@Published var alertDisplayed = false
private var appConfig: AppConfig!
private let chatState = ChatState()
private var cacheDirUrl: URL!
private let fileManager: FileManager = FileManager.default
private let decoder = JSONDecoder()
private let encoder = JSONEncoder()
private var prebuiltLocals = Set<String>();

init() {
let bundleUrl = Bundle.main.bundleURL
// models in dist
Expand All @@ -40,44 +43,59 @@ class StartState : ObservableObject {
} catch {
print(error.localizedDescription)
}

// models in cache to download
do {
cacheDirUrl = fileManager.urls(for: .cachesDirectory, in: .userDomainMask)[0]
let appConfigUrl = bundleUrl.appending(path: "app-config.json")
var appConfigUrl = cacheDirUrl.appending(path: "app-config.json")
if !fileManager.fileExists(atPath: appConfigUrl.path()) {
appConfigUrl = bundleUrl.appending(path: "app-config.json")
}
assert(fileManager.fileExists(atPath: appConfigUrl.path()))
let fileHandle = try FileHandle(forReadingFrom: appConfigUrl)
let data = fileHandle.readDataToEndOfFile()
self.appConfig = try decoder.decode(AppConfig.self, from: data)

for model in self.appConfig.model_list {
if self.prebuiltLocals.contains(model.local_id) {
continue
}
let configUrl = cacheDirUrl
.appending(path: model.local_id)
.appending(path: "mlc-chat-config.json")

if fileManager.fileExists(atPath: configUrl.path()) {
loadConfig(modelRecord: model)
} else {
downloadConfig(modelUrl: URL(string: model.model_url)!, newRecord: false)
}
}
for rec in self.appConfig.add_model_samples {
self.exampleModelUrls.append(ExampleModelUrl(model_url: rec.model_url, local_id: rec.local_id))
}
} catch {
print(error.localizedDescription)
self.showAlert(message: String(
format: "Init error %s", error.localizedDescription))
}
}



func addModel(modelRemoteBaseUrl: String) {
for model in self.appConfig.model_list {
if model.model_url == modelRemoteBaseUrl {
self.showAlert(message: "Model URL already added")
return
}
}
downloadConfig(modelUrl: URL(string: modelRemoteBaseUrl)!, newRecord: true)
}


func showAlert(message: String) {
self.alertMessage = message
self.alertDisplayed = true
}


func downloadConfig(modelUrl: URL, newRecord: Bool) {
let downloadTask = URLSession.shared.downloadTask(with: modelUrl.appending(path: "mlc-chat-config.json")) {
urlOrNil, responseOrNil, errorOrNil in
Expand All @@ -87,45 +105,58 @@ class StartState : ObservableObject {
let data = fileHandle.readDataToEndOfFile()
let modelConfig = try self.decoder.decode(ModelConfig.self, from: data)
let modelBaseUrl = self.cacheDirUrl.appending(path: modelConfig.local_id)

let record = ModelRecord(model_url: modelUrl.absoluteString, local_id: modelConfig.local_id)

if (newRecord) {
// TODO(tvm-team) add error message about duolicate
if self.prebuiltLocals.contains(modelConfig.local_id) {
return;
var localIdExist = self.prebuiltLocals.contains(modelConfig.local_id)

if (!localIdExist) {
for model in self.appConfig.model_list {
if (model.local_id == modelConfig.local_id) {
localIdExist = true
break
}
}
}
for model in self.appConfig.model_list {
if (model.local_id == modelConfig.local_id) {
return
if (localIdExist) {
DispatchQueue.main.async {
self.showAlert(message: String(
format: "local_id %s already exists",
record.local_id
))
}
return
}
self.appConfig.model_list.append(record)
self.commitUpdate()
}

try self.fileManager.createDirectory(at: modelBaseUrl, withIntermediateDirectories: true)
try self.fileManager.moveItem(at: fileUrl, to: modelBaseUrl.appending(path: "mlc-chat-config.json"))
let dst = modelBaseUrl.appending(path: "mlc-chat-config.json")
try? self.fileManager.removeItem(at: dst)
try self.fileManager.moveItem(at: fileUrl, to: dst)
DispatchQueue.main.async {
let record = ModelRecord(model_url: modelUrl.absoluteString, local_id: modelConfig.local_id)
self.loadConfig(modelRecord: record)
if (!newRecord) {
self.appConfig.model_list.append(record)
self.commitUpdate()
}

}
} catch {
print(error.localizedDescription)
DispatchQueue.main.sync {
self.showAlert(message: "Cannot download model config from the given url")
}
}
}
downloadTask.resume()
}

func loadConfig(modelRecord: ModelRecord) {
// local-id dir should exist
let modelBaseUrl = cacheDirUrl.appending(path: modelRecord.local_id)
assert(fileManager.fileExists(atPath: modelBaseUrl.path()))

// mlc-chat-config.json should exist
let modelConfigUrl = modelBaseUrl.appending(path: "mlc-chat-config.json")
assert(fileManager.fileExists(atPath: modelConfigUrl.path()))

do {
let fileHandle = try FileHandle(forReadingFrom: modelConfigUrl)
let data = fileHandle.readDataToEndOfFile()
Expand All @@ -140,15 +171,12 @@ class StartState : ObservableObject {
print(error.localizedDescription)
}
}

func commitUpdate(){
// TODO(tvm-team): atomic switch over
let appConfigUrl = Bundle.main.bundleURL.appending(path: "app-config.json")
assert(fileManager.fileExists(atPath: appConfigUrl.path()))
let appConfigUrl = cacheDirUrl.appending(path: "app-config.json")
do {
let fileHandle = try FileHandle(forWritingTo: appConfigUrl)
let data = try encoder.encode(appConfig)
try fileHandle.write(contentsOf: data)
try data.write(to: appConfigUrl, options: Data.WritingOptions.atomic)
} catch {
print(error.localizedDescription)
}
Expand Down
Loading

0 comments on commit da63f2c

Please sign in to comment.