Skip to content

Commit

Permalink
Merge pull request collabora#163 from makaveli10/upgrade_faster_whisper
Browse files Browse the repository at this point in the history
Upgrade faster whisper==1.0.1
  • Loading branch information
makaveli10 authored Mar 4, 2024
2 parents acd4902 + 8a06ba8 commit a17f404
Showing 3 changed files with 196 additions and 31 deletions.
2 changes: 1 addition & 1 deletion docker/Dockerfile.gpu
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
FROM nvidia/cuda:11.8.0-cudnn8-runtime-ubuntu22.04
FROM nvidia/cuda:12.2.2-cudnn8-runtime-ubuntu22.04
ARG DEBIAN_FRONTEND=noninteractive

# Remove any third-party apt sources to avoid issues with expiring keys.
2 changes: 1 addition & 1 deletion requirements/server.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
faster-whisper==0.10.0
faster-whisper==1.0.1
torch
websockets
onnxruntime==1.16.0
223 changes: 194 additions & 29 deletions whisper_live/transcriber.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,22 @@
# original https://github.com/guillaumekln/faster-whisper/blob/master/faster_whisper/transcribe.py

import itertools
import json
import logging
import os
import zlib
import json
from inspect import signature

from inspect import signature
from typing import BinaryIO, Iterable, List, NamedTuple, Optional, Tuple, Union

import ctranslate2
import numpy as np
import tokenizers

from faster_whisper.audio import decode_audio
from faster_whisper.audio import decode_audio, pad_or_trim
from faster_whisper.feature_extractor import FeatureExtractor
from faster_whisper.tokenizer import _LANGUAGE_CODES, Tokenizer
from faster_whisper.utils import download_model, format_timestamp, get_logger
from faster_whisper.utils import download_model, format_timestamp, get_end, get_logger
from faster_whisper.vad import (
SpeechTimestampsMap,
VadOptions,
@@ -68,6 +68,9 @@ class TranscriptionOptions(NamedTuple):
word_timestamps: bool
prepend_punctuations: str
append_punctuations: str
max_new_tokens: Optional[int]
clip_timestamps: Union[str, List[float]]
hallucination_silence_threshold: Optional[float]


class TranscriptionInfo(NamedTuple):
@@ -96,8 +99,8 @@ def __init__(
Args:
model_size_or_path: Size of the model to use (tiny, tiny.en, base, base.en,
small, small.en, medium, medium.en, large-v1, large-v2, large-v3, or large), a path to a converted
model directory, or a CTranslate2-converted Whisper model ID from the Hugging Face Hub.
small, small.en, medium, medium.en, large-v1, large-v2, large-v3, or large), a path to a
converted model directory, or a CTranslate2-converted Whisper model ID from the HF Hub.
When a size or a model ID is configured, the converted model is downloaded
from the Hugging Face Hub.
device: Device to use for computation ("cpu", "cuda", "auto").
@@ -215,6 +218,10 @@ def transcribe( # noqa:
append_punctuations: str = "\"'.。,,!!??::”)]}、",
vad_filter: bool = False,
vad_parameters: Optional[Union[dict, VadOptions]] = None,
max_new_tokens: Optional[int] = None,
chunk_length: Optional[int] = None,
clip_timestamps: Union[str, List[float]] = "0",
hallucination_silence_threshold: Optional[float] = None,
) -> Tuple[Iterable[Segment], TranscriptionInfo]:
"""Transcribes an input file.
@@ -266,6 +273,16 @@ def transcribe( # noqa:
https://github.com/snakers4/silero-vad.
vad_parameters: Dictionary of Silero VAD parameters or VadOptions class (see available
parameters and default values in the class `VadOptions`).
max_new_tokens: Maximum number of new tokens to generate per-chunk. If not set,
the maximum will be set by the default max_length.
chunk_length: The length of audio segments. If it is not None, it will overwrite the
default chunk_length of the FeatureExtractor.
clip_timestamps: Union[str, List[float]]
Comma-separated list start,end,start,end,... timestamps (in seconds) of clips to
process. The last end timestamp defaults to the end of the file.
hallucination_silence_threshold: Optional[float]
When word_timestamps is True, skip silent periods longer than this threshold
(in seconds) when a possible hallucination is detected
Returns:
A tuple with:
@@ -318,7 +335,7 @@ def transcribe( # noqa:
if audio.shape[0] == 0:
return None, None

features = self.feature_extractor(audio)
features = self.feature_extractor(audio, chunk_length=chunk_length)

encoder_output = None
all_language_probs = None
@@ -384,6 +401,9 @@ def transcribe( # noqa:
word_timestamps=word_timestamps,
prepend_punctuations=prepend_punctuations,
append_punctuations=append_punctuations,
max_new_tokens=max_new_tokens,
clip_timestamps=clip_timestamps,
hallucination_silence_threshold=hallucination_silence_threshold,
)

segments = self.generate_segments(features, tokenizer, options, encoder_output)
@@ -403,16 +423,41 @@ def transcribe( # noqa:

return segments, info

def generate_segments( # noqa: C901
def generate_segments(
self,
features: np.ndarray,
tokenizer: Tokenizer,
options: TranscriptionOptions,
encoder_output: Optional[ctranslate2.StorageView] = None,
) -> Iterable[Segment]:
content_frames = features.shape[-1] - self.feature_extractor.nb_max_frames
content_duration = float(content_frames * self.feature_extractor.time_per_frame)

if isinstance(options.clip_timestamps, str):
TranscriptionOptions.clip_timestamps = [
float(ts)
for ts in (
options.clip_timestamps.split(",")
if options.clip_timestamps
else []
)
]
seek_points: List[int] = [
round(ts * self.frames_per_second) for ts in options.clip_timestamps
]
if len(seek_points) == 0:
seek_points.append(0)
if len(seek_points) % 2 == 1:
seek_points.append(content_frames)
seek_clips: List[Tuple[int, int]] = list(
zip(seek_points[::2], seek_points[1::2])
)

punctuation = "\"'“¿([{-\"'.。,,!!??::”)]}、"

idx = 0
seek = 0
clip_idx = 0
seek = seek_clips[clip_idx][0]
all_tokens = []
prompt_reset_since = 0

@@ -426,13 +471,34 @@ def generate_segments( # noqa:

last_speech_timestamp = 0.0
all_segments = []
while seek < content_frames:
# NOTE: This loop is obscurely flattened to make the diff readable.
# A later commit should turn this into a simpler nested loop.
# for seek_clip_start, seek_clip_end in seek_clips:
# while seek < seek_clip_end
while clip_idx < len(seek_clips):
seek_clip_start, seek_clip_end = seek_clips[clip_idx]
if seek_clip_end > content_frames:
seek_clip_end = content_frames
if seek < seek_clip_start:
seek = seek_clip_start
if seek >= seek_clip_end:
clip_idx += 1
if clip_idx < len(seek_clips):
seek = seek_clips[clip_idx][0]
continue
time_offset = seek * self.feature_extractor.time_per_frame
segment = features[:, seek:seek + self.feature_extractor.nb_max_frames]
window_end_time = float(
(seek + self.feature_extractor.nb_max_frames)
* self.feature_extractor.time_per_frame
)
segment_size = min(
self.feature_extractor.nb_max_frames, content_frames - seek
self.feature_extractor.nb_max_frames,
content_frames - seek,
seek_clip_end - seek,
)
segment = features[:, seek : seek + segment_size]
segment_duration = segment_size * self.feature_extractor.time_per_frame
segment = pad_or_trim(segment, self.feature_extractor.nb_max_frames)

if self.logger.isEnabledFor(logging.DEBUG):
self.logger.debug(
@@ -484,10 +550,33 @@ def generate_segments( # noqa:
previous_seek = seek
current_segments = []

# anomalous words are very long/short/improbable
def word_anomaly_score(word: dict) -> float:
probability = word.get("probability", 0.0)
duration = word["end"] - word["start"]
score = 0.0
if probability < 0.15:
score += 1.0
if duration < 0.133:
score += (0.133 - duration) * 15
if duration > 2.0:
score += duration - 2.0
return score

def is_segment_anomaly(segment: Optional[dict]) -> bool:
if segment is None or not segment["words"]:
return False
words = [w for w in segment["words"] if w["word"] not in punctuation]
words = words[:8]
score = sum(word_anomaly_score(w) for w in words)
return score >= 3 or score + 0.01 >= len(words)

def next_words_segment(segments: List[dict]) -> Optional[dict]:
return next((s for s in segments if s["words"]), None)

single_timestamp_ending = (
len(tokens) >= 2
and tokens[-2] < tokenizer.timestamp_begin
and tokens[-1] >= tokenizer.timestamp_begin
and tokens[-2] < tokenizer.timestamp_begin <= tokens[-1]
)

consecutive_timestamps = [
@@ -570,18 +659,62 @@ def generate_segments( # noqa:
last_speech_timestamp=last_speech_timestamp,
)

word_end_timestamps = [
w["end"] for s in current_segments for w in s["words"]
]
if len(word_end_timestamps) > 0:
last_speech_timestamp = word_end_timestamps[-1]
if not single_timestamp_ending and len(word_end_timestamps) > 0:
seek_shift = round(
(word_end_timestamps[-1] - time_offset) * self.frames_per_second
)

if seek_shift > 0:
seek = previous_seek + seek_shift
if not single_timestamp_ending:
last_word_end = get_end(current_segments)
if last_word_end is not None and last_word_end > time_offset:
seek = round(last_word_end * self.frames_per_second)

# skip silence before possible hallucinations
if options.hallucination_silence_threshold is not None:
threshold = options.hallucination_silence_threshold

# if first segment might be a hallucination, skip leading silence
first_segment = next_words_segment(current_segments)
if first_segment is not None and is_segment_anomaly(first_segment):
gap = first_segment["start"] - time_offset
if gap > threshold:
seek = previous_seek + round(gap * self.frames_per_second)
continue

# skip silence before any possible hallucination that is surrounded
# by silence or more hallucinations
hal_last_end = last_speech_timestamp
for si in range(len(current_segments)):
segment = current_segments[si]
if not segment["words"]:
continue
if is_segment_anomaly(segment):
next_segment = next_words_segment(
current_segments[si + 1 :]
)
if next_segment is not None:
hal_next_start = next_segment["words"][0]["start"]
else:
hal_next_start = time_offset + segment_duration
silence_before = (
segment["start"] - hal_last_end > threshold
or segment["start"] < threshold
or segment["start"] - time_offset < 2.0
)
silence_after = (
hal_next_start - segment["end"] > threshold
or is_segment_anomaly(next_segment)
or window_end_time - segment["end"] < 2.0
)
if silence_before and silence_after:
seek = round(
max(time_offset + 1, segment["start"])
* self.frames_per_second
)
if content_duration - segment["end"] < threshold:
seek = content_frames
current_segments[si:] = []
break
hal_last_end = segment["end"]

last_word_end = get_end(current_segments)
if last_word_end is not None:
last_speech_timestamp = last_word_end

for segment in current_segments:
tokens = segment["tokens"]
@@ -608,7 +741,7 @@ def generate_segments( # noqa:
[Word(**word) for word in segment["words"]]
if options.word_timestamps
else None
),
),
))

if (
@@ -649,6 +782,21 @@ def generate_with_fallback(
max_initial_timestamp_index = int(
round(options.max_initial_timestamp / self.time_precision)
)
if options.max_new_tokens is not None:
max_length = len(prompt) + options.max_new_tokens
else:
max_length = self.max_length

if max_length > self.max_length:
raise ValueError(
f"The length of the prompt is {len(prompt)}, and the `max_new_tokens` "
f"{max_length - len(prompt)}. Thus, the combined length of the prompt "
f"and `max_new_tokens` is: {max_length}. This exceeds the "
f"`max_length` of the Whisper model: {self.max_length}. "
"You should either reduce the length of your prompt, or "
"reduce the value of `max_new_tokens`, "
f"so that their combined length is less that {self.max_length}."
)

for temperature in options.temperatures:
if temperature > 0:
@@ -670,7 +818,7 @@ def generate_with_fallback(
length_penalty=options.length_penalty,
repetition_penalty=options.repetition_penalty,
no_repeat_ngram_size=options.no_repeat_ngram_size,
max_length=self.max_length,
max_length=max_length,
return_scores=True,
return_no_speech_prob=True,
suppress_blank=options.suppress_blank,
@@ -728,6 +876,8 @@ def generate_with_fallback(
if (
options.no_speech_threshold is not None
and result.no_speech_prob > options.no_speech_threshold
and options.log_prob_threshold is not None
and avg_logprob < options.log_prob_threshold
):
needs_fallback = False # silence

@@ -738,6 +888,13 @@ def generate_with_fallback(
decode_result = max(
below_cr_threshold_results or all_results, key=lambda x: x[1]
)
# to pass final temperature for prompt_reset_on_temperature
decode_result = (
decode_result[0],
decode_result[1],
temperature,
decode_result[3],
)

return decode_result

@@ -752,7 +909,7 @@ def get_prompt(

if previous_tokens:
prompt.append(tokenizer.sot_prev)
prompt.extend(previous_tokens[-(self.max_length // 2 - 1):])
prompt.extend(previous_tokens[-(self.max_length // 2 - 1) :])

prompt.extend(tokenizer.sot_sequence)

@@ -794,6 +951,7 @@ def add_word_timestamps( # no
word_durations = np.array([word["end"] - word["start"] for word in alignment])
word_durations = word_durations[word_durations.nonzero()]
median_duration = np.median(word_durations) if len(word_durations) > 0 else 0.0
median_duration = min(0.7, float(median_duration))
max_duration = median_duration * 2

# hack: truncate long words at sentence boundaries.
@@ -915,6 +1073,13 @@ def find_alignment(
words, word_tokens = tokenizer.split_to_word_tokens(
text_tokens + [tokenizer.eot]
)
if len(word_tokens) <= 1:
# return on eot only
# >>> np.pad([], (1, 0))
# array([0.])
# This results in crashes when we lookup jump_times with float, like
# IndexError: arrays used as indices must be of integer (or boolean) type
return []
word_boundaries = np.pad(np.cumsum([len(t) for t in word_tokens[:-1]]), (1, 0))
if len(word_boundaries) <= 1:
return []

0 comments on commit a17f404

Please sign in to comment.