Skip to content

Commit

Permalink
[iOS] support for multimodal (#524)
Browse files Browse the repository at this point in the history
This PR introduces multimodality for iOS. Specifically, below is a demo
of running MiniGPT on iOS.



Changes:
* standalone `image_embed.cc` and `image_module.py` for
image-module-related functionalities
* support uploading images or photo taking in iOS
* `prefillImage` function in `LLMChat.mm` handling conversion from
UIImage* to void* to tvm::runtime::NDArray
* add a image pre-processing module in `relax_model`

Update:
* did not add minigpt model to the `app-config.json` file cuz it would
affect users. let's add it in a followup pr after we upload the tuned
minigpt model to HF
  • Loading branch information
Kathryn-cat authored Jul 17, 2023
1 parent 0358e5a commit 6cf8d4f
Show file tree
Hide file tree
Showing 13 changed files with 792 additions and 56 deletions.
212 changes: 212 additions & 0 deletions cpp/image_embed.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,212 @@
/*!
* Copyright (c) 2023 by Contributors
* \file image_embed.cc
* \brief Implementation of image embedding module in support of multimodality in LLM.
*/
#define PICOJSON_USE_INT64
#define __STDC_FORMAT_MACROS

#include "image_embed.h"

#include <picojson.h>
#include <tvm/runtime/module.h>
#include <tvm/runtime/ndarray.h>
#include <tvm/runtime/registry.h>
#include <tvm/runtime/relax_vm/memory_manager.h>

#include <cctype>
#include <chrono>
#include <filesystem>
#include <fstream>
#include <iomanip>
#include <list>
#include <memory>
#include <optional>
#include <random>
#include <string>
#include <unordered_set>

namespace mlc {
namespace llm {

using tvm::Device;
using namespace tvm::runtime;

//------------------------------
// Image embedding module
//------------------------------
class LLMImageModule;

/*!
* \brief Implements the image embedding module wrapper
*/
class LLMImage {
friend class LLMImageModule;

public:
explicit LLMImage(DLDevice device) : device_(device) {}

/*!
* \brief Reload the image model from the specified model path.
* \param executable The module to reload.
* \param model_path The path to search for models.
*/
void Reload(tvm::runtime::Module executable, String model_path) {
// Step 1. Initialize vm, we use the packed function mechanism
// so there is no explicit abi dependency on these extra
// classes other than basic tvm runtime.
auto fload_exec = executable->GetFunction("vm_load_executable");
ICHECK(fload_exec.defined()) << "TVM runtime cannot find vm_load_executable";
vm_ = fload_exec();
vm_->GetFunction("vm_initialization")(static_cast<int>(device_.device_type), device_.device_id,
static_cast<int>(relax_vm::AllocatorType::kPooled),
static_cast<int>(kDLCPU), 0,
static_cast<int>(relax_vm::AllocatorType::kPooled));

embed_func_ = vm_->GetFunction("embed");

// Step 2. Load params in nd-array cache.
const PackedFunc* fload_cache = tvm::runtime::Registry::Get("vm.builtin.ndarray_cache.load");
ICHECK(fload_cache) << "TVM runtime cannot find vm.builtin.ndarray_cache.load";
(*fload_cache)(model_path, static_cast<int32_t>(device_.device_type), device_.device_id);

const PackedFunc* fload_params =
tvm::runtime::Registry::Get("vm.builtin.param_array_from_cache");
ICHECK(fload_params) << "Cannot find env function vm.builtin.param_array_from_cache";
params_ = (*fload_params)("param", -1);

// after we get params, it is safe to simply clear the cached version
// as these params are referenced by params_
const PackedFunc* fclear_ndarray_cache =
tvm::runtime::Registry::Get("vm.builtin.ndarray_cache.clear");
ICHECK(fclear_ndarray_cache) << "Cannot find env function vm.builtin.ndarray_cache.clear";
(*fclear_ndarray_cache)();

this->Reset();
}

void Reset() { this->ResetRuntimeStats(); }

/*! \brief reset the runtime stats. */
void ResetRuntimeStats() { this->embed_total_time = 0; }

/*!
* \brief Given the input image, generate the embedding of the image.
* \param image The input image in type DLTensor*.
* \return The embedding of the input image.
*/
NDArray EmbedStep(NDArray image) {
CHECK(embed_func_.defined());
auto tstart = std::chrono::high_resolution_clock::now();

NDArray embedding = embed_func_(image, params_);

auto tend = std::chrono::high_resolution_clock::now();
this->embed_total_time += static_cast<double>((tend - tstart).count()) / 1e9;

return embedding;
}

/*!
* \return Text describing runtime stats.
*/
std::string RuntimeStatsText() {
std::ostringstream os;
os << "image embed: " << std::setprecision(1) << std::fixed << this->embed_total_time << " s";
return os.str();
}

//----------------------------
// Statistics
//----------------------------
double embed_total_time = 0;
//----------------------------
// TVM related states
//----------------------------
// runtime device
Device device_;
// The vm module
Module vm_;
// embedding function
PackedFunc embed_func_;
// local params
Array<NDArray> params_;
};

/*!
* \brief An image module implementation that exposes
* the functions as tvm::runtime::Module.
*
* We do it so that the module is accessible to any image module in LLM
* that tvm runtime can access.
*/
class LLMImageModule : public ModuleNode {
public:
// overrides
PackedFunc GetFunction(const String& name, const ObjectPtr<Object>& sptr_to_self) final {
if (name == "reload") {
return PackedFunc([this, sptr_to_self](TVMArgs args, TVMRetValue* rv) {
image_mod_ = nullptr;
// we do not call ClearGlobalMemoryManager() here, please make sure to call reload image
// model after reload LLM, since ClearGlobalMemoryManager() will be called there
image_mod_ = std::make_unique<LLMImage>(LLMImage(device_));
ICHECK_EQ(args.size(), 2);
image_mod_->Reload(args[0], args[1]);
});
} else if (name == "unload") {
return PackedFunc([this, sptr_to_self](TVMArgs args, TVMRetValue* rv) {
// we do not call ClearGlobalMemoryManager() here, please make sure to call unload image
// model before unload LLM, since ClearGlobalMemoryManager() will be called there
image_mod_ = nullptr;
});
} else if (name == "embed") {
return PackedFunc([this, sptr_to_self](TVMArgs args, TVMRetValue* rv) {
ICHECK_EQ(args.size(), 1);
*rv = GetImageModule()->EmbedStep(args[0]);
});
} else if (name == "reset") {
return PackedFunc([this, sptr_to_self](TVMArgs args, TVMRetValue* rv) {
ICHECK_EQ(args.size(), 0);
GetImageModule()->Reset();
});
} else if (name == "runtime_stats_text") {
return PackedFunc([this, sptr_to_self](TVMArgs args, TVMRetValue* rv) {
*rv = GetImageModule()->RuntimeStatsText();
});
} else if (name == "reset_runtime_stats") {
return PackedFunc([this, sptr_to_self](TVMArgs args, TVMRetValue* rv) {
GetImageModule()->ResetRuntimeStats();
});
} else {
return PackedFunc(nullptr);
}
}

void Init(DLDevice device) { device_ = device; }

LLMImage* GetImageModule() {
ICHECK(image_mod_ != nullptr) << "Image embedding module is not initialized via reload";
return image_mod_.get();
}

const char* type_key() const final { return "mlc.image_embed"; }

private:
std::unique_ptr<LLMImage> image_mod_ = nullptr;
DLDevice device_;
};

tvm::runtime::Module CreateImageModule(DLDevice device) {
ObjectPtr<LLMImageModule> n = make_object<LLMImageModule>();
n->Init(device);
return Module(n);
}

// register as a system function that can be queried
TVM_REGISTER_GLOBAL("mlc.llm_image_module_create")
.set_body_typed([](int device_type, int device_id) {
return CreateImageModule(DLDevice{static_cast<DLDeviceType>(device_type), device_id});
});

} // namespace llm
} // namespace mlc
28 changes: 28 additions & 0 deletions cpp/image_embed.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
/*!
* Copyright (c) 2023 by Contributors
* \file image_embed.h
* \brief Implementation of image embedding pipeline.
*/
#include <tvm/runtime/container/string.h>
#include <tvm/runtime/module.h>

#ifndef MLC_LLM_DLL
#ifdef _WIN32
#ifdef MLC_LLM_EXPORTS
#define MLC_LLM_DLL __declspec(dllexport)
#else
#define MLC_LLM_DLL __declspec(dllimport)
#endif
#else
#define MLC_LLM_DLL __attribute__((visibility("default")))
#endif
#endif

namespace mlc {
namespace llm {

// explicit export via TVM_DLL
MLC_LLM_DLL tvm::runtime::Module CreateImageModule(DLDevice device);

} // namespace llm
} // namespace mlc
6 changes: 6 additions & 0 deletions ios/MLCChat.xcodeproj/project.pbxproj
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
1453A4D12A1354B9001B909F /* StartState.swift in Sources */ = {isa = PBXBuildFile; fileRef = 1453A4CC2A1354B9001B909F /* StartState.swift */; };
1453A4D22A1354B9001B909F /* ModelConfig.swift in Sources */ = {isa = PBXBuildFile; fileRef = 1453A4CD2A1354B9001B909F /* ModelConfig.swift */; };
1453A4D32A1354B9001B909F /* ModelState.swift in Sources */ = {isa = PBXBuildFile; fileRef = 1453A4CE2A1354B9001B909F /* ModelState.swift */; };
A773CC652A5DC98200467BFE /* ImageProcessing.swift in Sources */ = {isa = PBXBuildFile; fileRef = A773CC642A5DC98200467BFE /* ImageProcessing.swift */; };
C06A74F229F9A78800BC4BE6 /* dist in CopyFiles */ = {isa = PBXBuildFile; fileRef = C06A74E029F99C9F00BC4BE6 /* dist */; };
C09834192A16F4E000A05B51 /* app-config.json in CopyFiles */ = {isa = PBXBuildFile; fileRef = C09834182A16F4CB00A05B51 /* app-config.json */; };
C0D643B329F99A7F004DDAA4 /* MLCChatApp.swift in Sources */ = {isa = PBXBuildFile; fileRef = C0D643B229F99A7F004DDAA4 /* MLCChatApp.swift */; };
Expand Down Expand Up @@ -53,6 +54,7 @@
1453A4CC2A1354B9001B909F /* StartState.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = StartState.swift; sourceTree = "<group>"; };
1453A4CD2A1354B9001B909F /* ModelConfig.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = ModelConfig.swift; sourceTree = "<group>"; };
1453A4CE2A1354B9001B909F /* ModelState.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = ModelState.swift; sourceTree = "<group>"; };
A773CC642A5DC98200467BFE /* ImageProcessing.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ImageProcessing.swift; sourceTree = "<group>"; };
C06A74E029F99C9F00BC4BE6 /* dist */ = {isa = PBXFileReference; lastKnownFileType = folder; path = dist; sourceTree = "<group>"; };
C06A74E629F9A1DF00BC4BE6 /* MLCChat.entitlements */ = {isa = PBXFileReference; lastKnownFileType = text.plist.entitlements; path = MLCChat.entitlements; sourceTree = "<group>"; };
C09834182A16F4CB00A05B51 /* app-config.json */ = {isa = PBXFileReference; lastKnownFileType = text.json; path = "app-config.json"; sourceTree = "<group>"; };
Expand Down Expand Up @@ -111,6 +113,7 @@
C0D643C029F99B07004DDAA4 /* ChatState.swift */,
C0D643C229F99B07004DDAA4 /* ChatView.swift */,
C0D643B229F99A7F004DDAA4 /* MLCChatApp.swift */,
A773CC642A5DC98200467BFE /* ImageProcessing.swift */,
C0D643B629F99A80004DDAA4 /* Assets.xcassets */,
C0D643B829F99A80004DDAA4 /* Preview Content */,
);
Expand Down Expand Up @@ -216,6 +219,7 @@
isa = PBXSourcesBuildPhase;
buildActionMask = 2147483647;
files = (
A773CC652A5DC98200467BFE /* ImageProcessing.swift in Sources */,
1453A4D12A1354B9001B909F /* StartState.swift in Sources */,
C0D643B329F99A7F004DDAA4 /* MLCChatApp.swift in Sources */,
C0DDBDF62A39103F00E9D060 /* ChatState.swift in Sources */,
Expand Down Expand Up @@ -363,6 +367,7 @@
"HEADER_SEARCH_PATHS[arch=*]" = "";
INFOPLIST_FILE = MLCChat/Info.plist;
INFOPLIST_KEY_LSApplicationCategoryType = "public.app-category.education";
INFOPLIST_KEY_NSCameraUsageDescription = "This app requires usage of camera to function properly.";
INFOPLIST_KEY_UIApplicationSceneManifest_Generation = YES;
INFOPLIST_KEY_UIApplicationSupportsIndirectInputEvents = YES;
INFOPLIST_KEY_UILaunchScreen_Generation = YES;
Expand Down Expand Up @@ -416,6 +421,7 @@
"HEADER_SEARCH_PATHS[arch=*]" = "";
INFOPLIST_FILE = MLCChat/Info.plist;
INFOPLIST_KEY_LSApplicationCategoryType = "public.app-category.education";
INFOPLIST_KEY_NSCameraUsageDescription = "This app requires usage of camera to function properly.";
INFOPLIST_KEY_UIApplicationSceneManifest_Generation = YES;
INFOPLIST_KEY_UIApplicationSupportsIndirectInputEvents = YES;
INFOPLIST_KEY_UILaunchScreen_Generation = YES;
Expand Down
Loading

0 comments on commit 6cf8d4f

Please sign in to comment.