forked from cjyaddone/ChatWaifu
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Combined ChatWaifuCN.py and ChatWaifuJP.py
- Loading branch information
Showing
1 changed file
with
217 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,217 @@ | ||
from scipy.io.wavfile import write | ||
from mel_processing import spectrogram_torch | ||
from text import text_to_sequence, _clean_text | ||
from models import SynthesizerTrn | ||
import utils | ||
import commons | ||
import sys | ||
import re | ||
from torch import no_grad, LongTensor | ||
import logging | ||
from winsound import PlaySound | ||
|
||
chinese_model_path = ".\model\CN\model.pth" | ||
chinese_config_path = ".\model\CN\config.json" | ||
japanese_model_path = ".\model\H_excluded.pth" | ||
japanese_config_path = ".\model\config.json" | ||
|
||
#################################### | ||
#CHATGPT INITIALIZE | ||
from pyChatGPT import ChatGPT | ||
import json | ||
|
||
modelmessage = """ID Model | ||
0 Chinese | ||
1 Japanese | ||
""" | ||
|
||
idmessage = """ID Speaker | ||
0 綾地寧々 | ||
1 在原七海 | ||
2 小茸 | ||
3 唐乐吟 | ||
""" | ||
speakerID = 0 | ||
|
||
def get_input(): | ||
# prompt for input | ||
print("You:") | ||
user_input = input() | ||
return user_input | ||
|
||
def get_input_jp(): | ||
# prompt for input | ||
print("You:") | ||
user_input = input() +" 使用日本语" | ||
return user_input | ||
|
||
def get_token(): | ||
token = input("Copy your token from ChatGPT and press Enter \n") | ||
return token | ||
|
||
|
||
################################################ | ||
|
||
|
||
logging.getLogger('numba').setLevel(logging.WARNING) | ||
|
||
|
||
def ex_print(text, escape=False): | ||
if escape: | ||
print(text.encode('unicode_escape').decode()) | ||
else: | ||
print(text) | ||
|
||
|
||
def get_text(text, hps, cleaned=False): | ||
if cleaned: | ||
text_norm = text_to_sequence(text, hps.symbols, []) | ||
else: | ||
text_norm = text_to_sequence(text, hps.symbols, hps.data.text_cleaners) | ||
if hps.data.add_blank: | ||
text_norm = commons.intersperse(text_norm, 0) | ||
text_norm = LongTensor(text_norm) | ||
return text_norm | ||
|
||
|
||
def ask_if_continue(): | ||
while True: | ||
answer = input('Continue? (y/n): ') | ||
if answer == 'y': | ||
break | ||
elif answer == 'n': | ||
sys.exit(0) | ||
|
||
|
||
def print_speakers(speakers, escape=False): | ||
if len(speakers) > 100: | ||
return | ||
print('ID\tSpeaker') | ||
for id, name in enumerate(speakers): | ||
ex_print(str(id) + '\t' + name, escape) | ||
|
||
|
||
def get_speaker_id(message): | ||
speaker_id = input(message) | ||
try: | ||
speaker_id = int(speaker_id) | ||
except: | ||
print(str(speaker_id) + ' is not a valid ID!') | ||
sys.exit(1) | ||
return speaker_id | ||
|
||
|
||
def get_label_value(text, label, default, warning_name='value'): | ||
value = re.search(rf'\[{label}=(.+?)\]', text) | ||
if value: | ||
try: | ||
text = re.sub(rf'\[{label}=(.+?)\]', '', text, 1) | ||
value = float(value.group(1)) | ||
except: | ||
print(f'Invalid {warning_name}!') | ||
sys.exit(1) | ||
else: | ||
value = default | ||
return value, text | ||
|
||
|
||
def get_label(text, label): | ||
if f'[{label}]' in text: | ||
return True, text.replace(f'[{label}]', '') | ||
else: | ||
return False, text | ||
|
||
|
||
|
||
def generateSound(inputString, id, model_id): | ||
if '--escape' in sys.argv: | ||
escape = True | ||
else: | ||
escape = False | ||
|
||
#model = input('0: Chinese') | ||
#config = input('Path of a config file: ') | ||
if model_id == 0: | ||
model = chinese_model_path | ||
config = chinese_config_path | ||
elif model_id == 1: | ||
model = japanese_model_path | ||
config = japanese_config_path | ||
|
||
|
||
hps_ms = utils.get_hparams_from_file(config) | ||
n_speakers = hps_ms.data.n_speakers if 'n_speakers' in hps_ms.data.keys() else 0 | ||
n_symbols = len(hps_ms.symbols) if 'symbols' in hps_ms.keys() else 0 | ||
emotion_embedding = hps_ms.data.emotion_embedding if 'emotion_embedding' in hps_ms.data.keys() else False | ||
|
||
net_g_ms = SynthesizerTrn( | ||
n_symbols, | ||
hps_ms.data.filter_length // 2 + 1, | ||
hps_ms.train.segment_size // hps_ms.data.hop_length, | ||
n_speakers=n_speakers, | ||
emotion_embedding=emotion_embedding, | ||
**hps_ms.model) | ||
_ = net_g_ms.eval() | ||
utils.load_checkpoint(model, net_g_ms) | ||
|
||
if n_symbols != 0: | ||
if not emotion_embedding: | ||
#while True: | ||
if(1 == 1): | ||
choice = 't' | ||
if choice == 't': | ||
text = inputString | ||
if text == '[ADVANCED]': | ||
text = "我不会说" | ||
|
||
length_scale, text = get_label_value( | ||
text, 'LENGTH', 1, 'length scale') | ||
noise_scale, text = get_label_value( | ||
text, 'NOISE', 0.667, 'noise scale') | ||
noise_scale_w, text = get_label_value( | ||
text, 'NOISEW', 0.8, 'deviation of noise') | ||
cleaned, text = get_label(text, 'CLEANED') | ||
|
||
stn_tst = get_text(text, hps_ms, cleaned=cleaned) | ||
|
||
speaker_id = id | ||
out_path = "output.wav" | ||
|
||
with no_grad(): | ||
x_tst = stn_tst.unsqueeze(0) | ||
x_tst_lengths = LongTensor([stn_tst.size(0)]) | ||
sid = LongTensor([speaker_id]) | ||
audio = net_g_ms.infer(x_tst, x_tst_lengths, sid=sid, noise_scale=noise_scale, | ||
noise_scale_w=noise_scale_w, length_scale=length_scale)[0][0, 0].data.cpu().float().numpy() | ||
|
||
write(out_path, hps_ms.data.sampling_rate, audio) | ||
print('Successfully saved!') | ||
|
||
if __name__ == "__main__": | ||
session_token = get_token() | ||
api = ChatGPT(session_token) | ||
print(modelmessage) | ||
model_id = int(input('选择回复语言: ')) | ||
print("\n" + idmessage) | ||
id = get_speaker_id('选择角色: ') | ||
print() | ||
while True: | ||
|
||
if model_id == 0: | ||
resp = api.send_message(get_input()) | ||
if(resp == "quit()"): | ||
break | ||
answer = resp["message"].replace('\n','') | ||
print("ChatGPT:") | ||
print(answer) | ||
generateSound("[ZH]"+answer+"[ZH]", id, model_id) | ||
PlaySound(r'.\output.wav', flags=1) | ||
elif model_id == 1: | ||
resp = api.send_message(get_input_jp()) | ||
if(resp == "quit()"): | ||
break | ||
answer = resp["message"].replace('\n','') | ||
print("ChatGPT:") | ||
print(answer) | ||
generateSound(answer, id, model_id) | ||
PlaySound(r'.\output.wav', flags=1) |