Skip to content

Commit

Permalink
Merge pull request collabora#223 from peldszus/single-model-mode
Browse files Browse the repository at this point in the history
Single model mode
  • Loading branch information
makaveli10 authored Jun 7, 2024
2 parents ee13251 + 1407731 commit 5b9bc2b
Show file tree
Hide file tree
Showing 3 changed files with 94 additions and 20 deletions.
8 changes: 8 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,14 @@ python3 run_server.py --port 9090 \
--omp_num_threads 4
```

#### Single model mode
By default, when running the server without specifying a model, the server will instantiate a new whisper model for every client connection. This has the advantage, that the server can use different model sizes, based on the client's requested model size. On the other hand, it also means you have to wait for the model to be loaded upon client connection and you will have increased (V)RAM usage.

When serving a custom TensorRT model using the `-trt` or a custom faster_whisper model using the `-fw` option, the server will instead only instantiate the custom model once and then reuse it for all client connections.

If you don't want this, set `--no_single_model`.


### Running the Client
- Initializing the client with below parameters:
- `lang`: Language of the input audio, applicable only if using a multilingual model.
Expand Down
6 changes: 5 additions & 1 deletion run_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@
type=int,
default=1,
help="Number of threads to use for OpenMP")
parser.add_argument('--no_single_model', '-nsm',
action='store_true',
help='Set this if every connection should instantiate its own model. Only relevant for custom model, passed using -trt or -fw.')
args = parser.parse_args()

if args.backend == "tensorrt":
Expand All @@ -42,5 +45,6 @@
backend=args.backend,
faster_whisper_custom_model_path=args.faster_whisper_custom_model_path,
whisper_tensorrt_path=args.trt_model_path,
trt_multilingual=args.trt_multilingual
trt_multilingual=args.trt_multilingual,
single_model=not args.no_single_model,
)
100 changes: 81 additions & 19 deletions whisper_live/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@ def __init__(self):
self.client_manager = ClientManager()
self.no_voice_activity_chunks = 0
self.use_vad = True
self.single_model = False

def initialize_client(
self, websocket, options, faster_whisper_custom_model_path,
Expand All @@ -141,7 +142,8 @@ def initialize_client(
language=options["language"],
task=options["task"],
client_uid=options["uid"],
model=whisper_tensorrt_path
model=whisper_tensorrt_path,
single_model=self.single_model,
)
logging.info("Running TensorRT backend.")
except Exception as e:
Expand All @@ -168,6 +170,7 @@ def initialize_client(
initial_prompt=options.get("initial_prompt"),
vad_parameters=options.get("vad_parameters"),
use_vad=self.use_vad,
single_model=self.single_model,
)
logging.info("Running faster_whisper backend.")

Expand Down Expand Up @@ -288,14 +291,26 @@ def run(self,
backend="tensorrt",
faster_whisper_custom_model_path=None,
whisper_tensorrt_path=None,
trt_multilingual=False):
trt_multilingual=False,
single_model=False):
"""
Run the transcription server.
Args:
host (str): The host address to bind the server.
port (int): The port number to bind the server.
"""
if faster_whisper_custom_model_path is not None and not os.path.exists(faster_whisper_custom_model_path):
raise ValueError(f"Custom faster_whisper model '{faster_whisper_custom_model_path}' is not a valid path.")
if whisper_tensorrt_path is not None and not os.path.exists(whisper_tensorrt_path):
raise ValueError(f"TensorRT model '{whisper_tensorrt_path}' is not a valid path.")
if single_model:
if faster_whisper_custom_model_path or whisper_tensorrt_path:
logging.info("Custom model option was provided. Switching to single model mode.")
self.single_model = True
# TODO: load model initially
else:
logging.info("Single model mode currently only works with custom models.")
with serve(
functools.partial(
self.recv_audio,
Expand Down Expand Up @@ -532,7 +547,11 @@ def cleanup(self):


class ServeClientTensorRT(ServeClientBase):
def __init__(self, websocket, task="transcribe", multilingual=False, language=None, client_uid=None, model=None):

SINGLE_MODEL = None
SINGLE_MODEL_LOCK = threading.Lock()

def __init__(self, websocket, task="transcribe", multilingual=False, language=None, client_uid=None, model=None, single_model=False):
"""
Initialize a ServeClient instance.
The Whisper model is initialized based on the client's language and device availability.
Expand All @@ -546,21 +565,22 @@ def __init__(self, websocket, task="transcribe", multilingual=False, language=No
multilingual (bool, optional): Whether the client supports multilingual transcription. Defaults to False.
language (str, optional): The language for transcription. Defaults to None.
client_uid (str, optional): A unique identifier for the client. Defaults to None.
single_model (bool, optional): Whether to instantiate a new model for each client connection. Defaults to False.
"""
super().__init__(client_uid, websocket)
self.language = language if multilingual else "en"
self.task = task
self.eos = False
self.transcriber = WhisperTRTLLM(
model,
assets_dir="assets",
device="cuda",
is_multilingual=multilingual,
language=self.language,
task=self.task
)
self.warmup()

if single_model:
if ServeClientTensorRT.SINGLE_MODEL is None:
self.create_model(model, multilingual)
ServeClientTensorRT.SINGLE_MODEL = self.transcriber
else:
self.transcriber = ServeClientTensorRT.SINGLE_MODEL
else:
self.create_model(model, multilingual)

# threading
self.trans_thread = threading.Thread(target=self.speech_to_text)
Expand All @@ -572,6 +592,21 @@ def __init__(self, websocket, task="transcribe", multilingual=False, language=No
"backend": "tensorrt"
}))

def create_model(self, model, multilingual, warmup=True):
"""
Instantiates a new model, sets it as the transcriber and does warmup if desired.
"""
self.transcriber = WhisperTRTLLM(
model,
assets_dir="assets",
device="cuda",
is_multilingual=multilingual,
language=self.language,
task=self.task
)
if warmup:
self.warmup()

def warmup(self, warmup_steps=10):
"""
Warmup TensorRT since first few inferences are slow.
Expand Down Expand Up @@ -616,12 +651,16 @@ def transcribe_audio(self, input_bytes):
Args:
input_bytes (np.array): The audio chunk to transcribe.
"""
if ServeClientTensorRT.SINGLE_MODEL:
ServeClientTensorRT.SINGLE_MODEL_LOCK.acquire()
logging.info(f"[WhisperTensorRT:] Processing audio with duration: {input_bytes.shape[0] / self.RATE}")
mel, duration = self.transcriber.log_mel_spectrogram(input_bytes)
last_segment = self.transcriber.transcribe(
mel,
text_prefix=f"<|startoftranscript|><|{self.language}|><|{self.task}|><|notimestamps|>"
)
if ServeClientTensorRT.SINGLE_MODEL:
ServeClientTensorRT.SINGLE_MODEL_LOCK.release()
if last_segment:
self.handle_transcription_output(last_segment, duration)

Expand Down Expand Up @@ -681,8 +720,12 @@ def speech_to_text(self):


class ServeClientFasterWhisper(ServeClientBase):

SINGLE_MODEL = None
SINGLE_MODEL_LOCK = threading.Lock()

def __init__(self, websocket, task="transcribe", device=None, language=None, client_uid=None, model="small.en",
initial_prompt=None, vad_parameters=None, use_vad=True):
initial_prompt=None, vad_parameters=None, use_vad=True, single_model=False):
"""
Initialize a ServeClient instance.
The Whisper model is initialized based on the client's language and device availability.
Expand All @@ -697,6 +740,7 @@ def __init__(self, websocket, task="transcribe", device=None, language=None, cli
client_uid (str, optional): A unique identifier for the client. Defaults to None.
model (str, optional): The whisper model size. Defaults to 'small.en'
initial_prompt (str, optional): Prompt for whisper inference. Defaults to None.
single_model (bool, optional): Whether to instantiate a new model for each client connection. Defaults to False.
"""
super().__init__(client_uid, websocket)
self.model_sizes = [
Expand All @@ -718,12 +762,15 @@ def __init__(self, websocket, task="transcribe", device=None, language=None, cli
if self.model_size_or_path is None:
return

self.transcriber = WhisperModel(
self.model_size_or_path,
device=device,
compute_type="int8" if device == "cpu" else "float16",
local_files_only=False,
)
if single_model:
if ServeClientFasterWhisper.SINGLE_MODEL is None:
self.create_model(device)
ServeClientFasterWhisper.SINGLE_MODEL = self.transcriber
else:
self.transcriber = ServeClientFasterWhisper.SINGLE_MODEL
else:
self.create_model(device)

self.use_vad = use_vad

# threading
Expand All @@ -739,6 +786,17 @@ def __init__(self, websocket, task="transcribe", device=None, language=None, cli
)
)

def create_model(self, device):
"""
Instantiates a new model, sets it as the transcriber.
"""
self.transcriber = WhisperModel(
self.model_size_or_path,
device=device,
compute_type="int8" if device == "cpu" else "float16",
local_files_only=False,
)

def check_valid_model(self, model_size):
"""
Check if it's a valid whisper model size.
Expand Down Expand Up @@ -794,13 +852,17 @@ def transcribe_audio(self, input_sample):
depends on the implementation of the `transcriber.transcribe` method but typically
includes the transcribed text.
"""
if ServeClientFasterWhisper.SINGLE_MODEL:
ServeClientFasterWhisper.SINGLE_MODEL_LOCK.acquire()
result, info = self.transcriber.transcribe(
input_sample,
initial_prompt=self.initial_prompt,
language=self.language,
task=self.task,
vad_filter=self.use_vad,
vad_parameters=self.vad_parameters if self.use_vad else None)
if ServeClientFasterWhisper.SINGLE_MODEL:
ServeClientFasterWhisper.SINGLE_MODEL_LOCK.release()

if self.language is None and info is not None:
self.set_language(info)
Expand Down

0 comments on commit 5b9bc2b

Please sign in to comment.