tool.py
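"""Entry point for model-weight tools.

The `combine` sub-command merges a base model's weights with trained
adapter weights and saves the result as a standalone model.
"""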
import os
import sys
import argparse
from core.seq2seq.chatglm import ChatGLMSeq2Seq
from core.seq2seq.llama import LLAMASeq2Seq
from core.seq2seq.bloom import BLoomSeq2Seq
from core.seq2seq.qwen import QwenSeq2Seq
from core.seq2seq.baichuan import BaichuanSeq2Seq
from core.seq2seq.mixtral import MixtralSeq2Seq
from core.seq2seq.phi import PhiSeq2Seq
from core.seq2seq.gemma import GemmaSeq2Seq
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Inference for all.')
    subparsers = parser.add_subparsers(dest='tool', help='Tools or Modules')

    parser_combine = subparsers.add_parser('combine', help='Combine base model weights and adapter weights')
    parser_combine.add_argument('--model_type', default="llama",
                                choices=['llama', 'llama2', 'llama3', 'chatglm', 'chatglm2', 'bloom', 'qwen',
                                         'baichuan', 'mixtral', 'phi', 'phi3', 'gemma'])
    parser_combine.add_argument('--model_path', default=None, type=str)
    parser_combine.add_argument('--adapter_weights', default=None, type=str, help="The directory of the adapter weights")
    parser_combine.add_argument('--output_dir', default="combined_model/", type=str, help="The directory to save the combined model")
    parser_combine.add_argument('--max_shard_size', default="5GB", type=str, help="Max size of each shard of the combined model, e.g. 1GB, 5GB")
    args, _ = parser.parse_known_args()
    if args.tool == 'combine':
        if not args.model_path or not args.adapter_weights:
            print("[Error] Missing param: model_path or adapter_weights")
            sys.exit(-1)

        # Dispatch on model family; each Seq2Seq class handles loading
        # and merging weights for its architecture.
        if args.model_type in ("chatglm", "chatglm2"):
            llm = ChatGLMSeq2Seq()
        elif args.model_type in ("llama", "llama2", "llama3"):
            llm = LLAMASeq2Seq()
        elif args.model_type == "bloom":
            llm = BLoomSeq2Seq()
        elif args.model_type == "qwen":
            llm = QwenSeq2Seq()
        elif args.model_type == "baichuan":
            llm = BaichuanSeq2Seq()
        elif args.model_type == "mixtral":
            llm = MixtralSeq2Seq()
        elif args.model_type in ("phi", "phi3"):
            llm = PhiSeq2Seq()
        elif args.model_type == "gemma":
            llm = GemmaSeq2Seq()
        else:
            print("model_type should be one of: llama/llama2/llama3/chatglm/chatglm2/bloom/qwen/baichuan/mixtral/phi/phi3/gemma")
            sys.exit(-1)

        llm.base_model = args.model_path
        llm.model_type = args.model_type
        llm.adapter_weights = args.adapter_weights
        llm.output_dir = args.output_dir
        llm.max_shard_size = args.max_shard_size

        if not os.path.exists(llm.output_dir):
            os.makedirs(llm.output_dir)
            print("Warning: directory {} not found, created it automatically".format(llm.output_dir))

        llm.combine()
        print("[Success] Combined adapter weights; the new model is in {}".format(llm.output_dir))