Skip to content

Commit

Permalink
Combined ChatWaifuCN.py and ChatWaifuJP.py
Browse files Browse the repository at this point in the history
  • Loading branch information
kevl0317 committed Dec 21, 2022
1 parent d515995 commit 19eb30e
Showing 1 changed file with 217 additions and 0 deletions.
217 changes: 217 additions & 0 deletions ChatWaifu.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,217 @@
from scipy.io.wavfile import write
from mel_processing import spectrogram_torch
from text import text_to_sequence, _clean_text
from models import SynthesizerTrn
import utils
import commons
import sys
import re
from torch import no_grad, LongTensor
import logging
from winsound import PlaySound

chinese_model_path = ".\model\CN\model.pth"
chinese_config_path = ".\model\CN\config.json"
japanese_model_path = ".\model\H_excluded.pth"
japanese_config_path = ".\model\config.json"

####################################
#CHATGPT INITIALIZE
from pyChatGPT import ChatGPT
import json

modelmessage = """ID Model
0 Chinese
1 Japanese
"""

idmessage = """ID Speaker
0 綾地寧々
1 在原七海
2 小茸
3 唐乐吟
"""
speakerID = 0

def get_input():
# prompt for input
print("You:")
user_input = input()
return user_input

def get_input_jp():
# prompt for input
print("You:")
user_input = input() +" 使用日本语"
return user_input

def get_token():
token = input("Copy your token from ChatGPT and press Enter \n")
return token


################################################


logging.getLogger('numba').setLevel(logging.WARNING)


def ex_print(text, escape=False):
if escape:
print(text.encode('unicode_escape').decode())
else:
print(text)


def get_text(text, hps, cleaned=False):
if cleaned:
text_norm = text_to_sequence(text, hps.symbols, [])
else:
text_norm = text_to_sequence(text, hps.symbols, hps.data.text_cleaners)
if hps.data.add_blank:
text_norm = commons.intersperse(text_norm, 0)
text_norm = LongTensor(text_norm)
return text_norm


def ask_if_continue():
while True:
answer = input('Continue? (y/n): ')
if answer == 'y':
break
elif answer == 'n':
sys.exit(0)


def print_speakers(speakers, escape=False):
if len(speakers) > 100:
return
print('ID\tSpeaker')
for id, name in enumerate(speakers):
ex_print(str(id) + '\t' + name, escape)


def get_speaker_id(message):
speaker_id = input(message)
try:
speaker_id = int(speaker_id)
except:
print(str(speaker_id) + ' is not a valid ID!')
sys.exit(1)
return speaker_id


def get_label_value(text, label, default, warning_name='value'):
value = re.search(rf'\[{label}=(.+?)\]', text)
if value:
try:
text = re.sub(rf'\[{label}=(.+?)\]', '', text, 1)
value = float(value.group(1))
except:
print(f'Invalid {warning_name}!')
sys.exit(1)
else:
value = default
return value, text


def get_label(text, label):
if f'[{label}]' in text:
return True, text.replace(f'[{label}]', '')
else:
return False, text



def generateSound(inputString, id, model_id):
if '--escape' in sys.argv:
escape = True
else:
escape = False

#model = input('0: Chinese')
#config = input('Path of a config file: ')
if model_id == 0:
model = chinese_model_path
config = chinese_config_path
elif model_id == 1:
model = japanese_model_path
config = japanese_config_path


hps_ms = utils.get_hparams_from_file(config)
n_speakers = hps_ms.data.n_speakers if 'n_speakers' in hps_ms.data.keys() else 0
n_symbols = len(hps_ms.symbols) if 'symbols' in hps_ms.keys() else 0
emotion_embedding = hps_ms.data.emotion_embedding if 'emotion_embedding' in hps_ms.data.keys() else False

net_g_ms = SynthesizerTrn(
n_symbols,
hps_ms.data.filter_length // 2 + 1,
hps_ms.train.segment_size // hps_ms.data.hop_length,
n_speakers=n_speakers,
emotion_embedding=emotion_embedding,
**hps_ms.model)
_ = net_g_ms.eval()
utils.load_checkpoint(model, net_g_ms)

if n_symbols != 0:
if not emotion_embedding:
#while True:
if(1 == 1):
choice = 't'
if choice == 't':
text = inputString
if text == '[ADVANCED]':
text = "我不会说"

length_scale, text = get_label_value(
text, 'LENGTH', 1, 'length scale')
noise_scale, text = get_label_value(
text, 'NOISE', 0.667, 'noise scale')
noise_scale_w, text = get_label_value(
text, 'NOISEW', 0.8, 'deviation of noise')
cleaned, text = get_label(text, 'CLEANED')

stn_tst = get_text(text, hps_ms, cleaned=cleaned)

speaker_id = id
out_path = "output.wav"

with no_grad():
x_tst = stn_tst.unsqueeze(0)
x_tst_lengths = LongTensor([stn_tst.size(0)])
sid = LongTensor([speaker_id])
audio = net_g_ms.infer(x_tst, x_tst_lengths, sid=sid, noise_scale=noise_scale,
noise_scale_w=noise_scale_w, length_scale=length_scale)[0][0, 0].data.cpu().float().numpy()

write(out_path, hps_ms.data.sampling_rate, audio)
print('Successfully saved!')

if __name__ == "__main__":
session_token = get_token()
api = ChatGPT(session_token)
print(modelmessage)
model_id = int(input('选择回复语言: '))
print("\n" + idmessage)
id = get_speaker_id('选择角色: ')
print()
while True:

if model_id == 0:
resp = api.send_message(get_input())
if(resp == "quit()"):
break
answer = resp["message"].replace('\n','')
print("ChatGPT:")
print(answer)
generateSound("[ZH]"+answer+"[ZH]", id, model_id)
PlaySound(r'.\output.wav', flags=1)
elif model_id == 1:
resp = api.send_message(get_input_jp())
if(resp == "quit()"):
break
answer = resp["message"].replace('\n','')
print("ChatGPT:")
print(answer)
generateSound(answer, id, model_id)
PlaySound(r'.\output.wav', flags=1)

0 comments on commit 19eb30e

Please sign in to comment.