Skip to content

Commit

Permalink
hugging_dialog
Browse files Browse the repository at this point in the history
  • Loading branch information
weitsung50110 committed Nov 7, 2024
1 parent eb38e93 commit b17f579
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 0 deletions.
31 changes: 31 additions & 0 deletions hugging_chinese_dialog.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
"""Ask the Llama3-8B-Chinese-Chat model a math word problem and print the reply.

Loads the model from a local directory, formats a single-turn conversation
with the tokenizer's chat template, samples one response, and prints only
the newly generated text (the prompt tokens are sliced off).
"""
from transformers import AutoTokenizer, AutoModelForCausalLM

# Local path to the converted checkpoint; original model card:
# https://huggingface.co/shenzhi-wang/Llama3-8B-Chinese-Chat
model_id = "hugging_convert/Llama3-8B-Chinese-Chat"

tokenizer = AutoTokenizer.from_pretrained(model_id)
# torch_dtype="auto" keeps the dtype stored in the checkpoint;
# device_map="auto" spreads layers across available GPUs/CPU.
model = AutoModelForCausalLM.from_pretrained(
    model_id, torch_dtype="auto", device_map="auto"
)

messages = [
    {"role": "user", "content": "7年前,妈妈年龄是儿子的6倍,儿子今年12岁,妈妈今年多少岁?"},
]

# apply_chat_template inserts the model's prompt markers; the generation
# prompt appends the assistant header so decoding starts at the answer.
input_ids = tokenizer.apply_chat_template(
    messages, add_generation_prompt=True, return_tensors="pt"
).to(model.device)

outputs = model.generate(
    input_ids,
    max_new_tokens=8192,
    do_sample=True,  # sample (temperature/top-p) rather than decode greedily
    temperature=0.6,
    top_p=0.9,
)
# Keep only the tokens generated after the prompt, then decode.
response = outputs[0][input_ids.shape[-1]:]
print(tokenizer.decode(response, skip_special_tokens=True))

# https://huggingface.co/shenzhi-wang/Llama3-8B-Chinese-Chat
27 changes: 27 additions & 0 deletions hugging_taiwan_dialog.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
"""Chat once with Llama-3-Taiwan-8B-Instruct and print the model's reply."""
from transformers import AutoTokenizer, AutoModelForCausalLM

model_id = "Llama-3-Taiwan-8B-Instruct"

# Load tokenizer and weights; dtype and device placement are picked
# automatically from the checkpoint and the available hardware.
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id, torch_dtype="auto", device_map="auto"
)

# One system persona turn followed by a single user question.
conversation = [
    {"role": "system", "content": "你是一位來自可愛國的公主, 請用公主的語氣回答以下問題:"},
    {"role": "user", "content": "請問你是誰?"},
]

# Render the conversation with the model's chat template and move the
# resulting token ids to whichever device the model was placed on.
prompt_ids = tokenizer.apply_chat_template(
    conversation, add_generation_prompt=True, return_tensors="pt"
).to(model.device)

generated = model.generate(
    prompt_ids,
    max_new_tokens=8192,
    do_sample=True,
    temperature=0.6,
    top_p=0.9,
)

# Strip the prompt tokens from the output and decode the continuation only.
reply_tokens = generated[0][prompt_ids.shape[-1]:]
print(tokenizer.decode(reply_tokens, skip_special_tokens=True))

0 comments on commit b17f579

Please sign in to comment.