Update transcribe.py
Add encoder output caching.
BBC-Esq authored Oct 4, 2024
1 parent 1c138df commit 3184507
Showing 1 changed file with 64 additions and 43 deletions.
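
As a rough usage sketch (not part of this commit's diff), the new constructor argument could be passed as shown below; the model size, device, and audio path are placeholders:

from faster_whisper import WhisperModel

# Hypothetical usage of the cache-size knob added by this commit
# (its default is 100 entries).
model = WhisperModel(
    "base.en",                   # placeholder model size
    device="cpu",                # placeholder device
    max_encoder_cache_items=50,  # cap the new encoder-output cache
)

segments, info = model.transcribe("audio.wav")  # placeholder audio path
for segment in segments:
    print(segment.start, segment.end, segment.text)
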
107 changes: 64 additions & 43 deletions faster_whisper/transcribe.py
@@ -592,10 +592,11 @@ def __init__(
download_root: Optional[str] = None,
local_files_only: bool = False,
files: dict = None,
max_encoder_cache_items: int = 100,
**model_kwargs,
):
"""Initializes the Whisper model.
Args:
model_size_or_path: Size of the model to use (tiny, tiny.en, base, base.en,
small, small.en, distil-small.en, medium, medium.en, distil-medium.en, large-v1,
@@ -623,9 +624,10 @@ def __init__(
files: Load model files from the memory. This argument is a dictionary mapping file names
to file contents as file-like or bytes objects. If this is set, model_path acts as an
identifier for this model.
max_encoder_cache_items: Maximum number of encoder outputs to cache.
"""
self.logger = get_logger()

tokenizer_bytes, preprocessor_bytes = None, None
if files:
model_path = model_size_or_path
@@ -652,7 +654,7 @@ def __init__(
files=files,
**model_kwargs,
)

tokenizer_file = os.path.join(model_path, "tokenizer.json")
if tokenizer_bytes:
self.hf_tokenizer = tokenizers.Tokenizer.from_buffer(tokenizer_bytes)
@@ -678,6 +680,8 @@ def __init__(
)
self.time_precision = 0.02
self.max_length = 448
self.encoder_cache = {}
self.max_encoder_cache_items = max_encoder_cache_items

@property
def supported_languages(self) -> List[str]:
@@ -1115,7 +1119,7 @@ def generate_segments(
) -> Iterable[Segment]:
content_frames = features.shape[-1] - self.feature_extractor.nb_max_frames
content_duration = float(content_frames * self.feature_extractor.time_per_frame)

if isinstance(options.clip_timestamps, str):
options = options._replace(
clip_timestamps=[
@@ -1137,24 +1141,29 @@ def generate_segments(
seek_clips: List[Tuple[int, int]] = list(
zip(seek_points[::2], seek_points[1::2])
)

punctuation = "\"'“¿([{-\"'.。,,!!??::”)]}、"
idx = 0
clip_idx = 0
seek = seek_clips[clip_idx][0]
all_tokens = []
prompt_reset_since = 0

if options.initial_prompt is not None:
if isinstance(options.initial_prompt, str):
initial_prompt = " " + options.initial_prompt.strip()
initial_prompt_tokens = tokenizer.encode(initial_prompt)
all_tokens.extend(initial_prompt_tokens)
else:
all_tokens.extend(options.initial_prompt)

last_speech_timestamp = 0.0

# Cache the initial encoder output if not provided
if encoder_output is None:
encoder_output = self.encode(features)

# NOTE: This loop is obscurely flattened to make the diff readable.
# A later commit should turn this into a simpler nested loop.
# for seek_clip_start, seek_clip_end in seek_clips:
@@ -1183,17 +1192,18 @@ def generate_segments(
segment = features[:, seek : seek + segment_size]
segment_duration = segment_size * self.feature_extractor.time_per_frame
segment = pad_or_trim(segment, self.feature_extractor.nb_max_frames)

if self.logger.isEnabledFor(logging.DEBUG):
self.logger.debug(
"Processing segment at %s", format_timestamp(time_offset)
)

previous_tokens = all_tokens[prompt_reset_since:]

if encoder_output is None:

# Use the cached encoder output or encode the segment if necessary
if seek > 0:
encoder_output = self.encode(segment)

# Perform language detection at every segment to update task based on output language,
# if the language is english, task is transcribe,
# else the task is translate to english (default)
@@ -1206,7 +1216,7 @@ def generate_segments(
task = "translate"
else:
task = "transcribe"

# Update tokenizer based on task and language
tokenizer.task = tokenizer.tokenizer.token_to_id(f"<|{task}|>")
tokenizer.language = tokenizer.tokenizer.token_to_id(language_token)
@@ -1219,35 +1229,32 @@ def generate_segments(
prefix=options.prefix if seek == 0 else None,
hotwords=options.hotwords,
)

if seek > 0 or encoder_output is None:
encoder_output = self.encode(segment)


(
result,
avg_logprob,
temperature,
compression_ratio,
) = self.generate_with_fallback(encoder_output, prompt, tokenizer, options)

if options.no_speech_threshold is not None:
# no voice activity check
should_skip = result.no_speech_prob > options.no_speech_threshold

if (
options.log_prob_threshold is not None
and avg_logprob > options.log_prob_threshold
):
# don't skip if the logprob is high enough, despite the no_speech_prob
should_skip = False

if should_skip:
self.logger.debug(
"No speech threshold is met (%f > %f)",
result.no_speech_prob,
options.no_speech_threshold,
)

# Skip if the logprob is very low (below the threshold value),
# despite no_speech_prob being low (ex: Too ambiguous outputs)
if options.log_prob_low_threshold:
@@ -1258,16 +1265,16 @@ def generate_segments(
avg_logprob,
options.log_prob_low_threshold,
)

if should_skip:
# fast-forward to the next segment boundary
seek += segment_size
continue

tokens = result.sequences_ids[0]

previous_seek = seek

# anomalous words are very long/short/improbable
def word_anomaly_score(word: dict) -> float:
probability = word.get("probability", 0.0)
@@ -1280,18 +1287,18 @@ def word_anomaly_score(word: dict) -> float:
if duration > 2.0:
score += duration - 2.0
return score

def is_segment_anomaly(segment: Optional[dict]) -> bool:
if segment is None or not segment["words"]:
return False
words = [w for w in segment["words"] if w["word"] not in punctuation]
words = words[:8]
score = sum(word_anomaly_score(w) for w in words)
return score >= 3 or score + 0.01 >= len(words)

def next_words_segment(segments: List[dict]) -> Optional[dict]:
return next((s for s in segments if s["words"]), None)

(
current_segments,
seek,
@@ -1304,7 +1311,7 @@ def next_words_segment(segments: List[dict]) -> Optional[dict]:
segment_duration=segment_duration,
seek=seek,
)

if options.word_timestamps:
self.add_word_timestamps(
[current_segments],
@@ -1319,19 +1326,19 @@ def next_words_segment(segments: List[dict]) -> Optional[dict]:
last_word_end = get_end(current_segments)
if last_word_end is not None and last_word_end > time_offset:
seek = round(last_word_end * self.frames_per_second)

# skip silence before possible hallucinations
if options.hallucination_silence_threshold is not None:
threshold = options.hallucination_silence_threshold

# if first segment might be a hallucination, skip leading silence
first_segment = next_words_segment(current_segments)
if first_segment is not None and is_segment_anomaly(first_segment):
gap = first_segment["start"] - time_offset
if gap > threshold:
seek = previous_seek + round(gap * self.frames_per_second)
continue

# skip silence before any possible hallucination that is surrounded
# by silence or more hallucinations
hal_last_end = last_speech_timestamp
@@ -1367,20 +1374,20 @@ def next_words_segment(segments: List[dict]) -> Optional[dict]:
current_segments[si:] = []
break
hal_last_end = segment["end"]

last_word_end = get_end(current_segments)
if last_word_end is not None:
last_speech_timestamp = last_word_end
for segment in current_segments:
tokens = segment["tokens"]
text = tokenizer.decode(tokens)

if segment["start"] == segment["end"] or not text.strip():
continue

all_tokens.extend(tokens)
idx += 1

yield Segment(
id=idx,
seek=seek,
@@ -1398,7 +1405,7 @@ def next_words_segment(segments: List[dict]) -> Optional[dict]:
else None
),
)

if (
not options.condition_on_previous_text
or temperature > options.prompt_reset_on_temperature
@@ -1409,19 +1416,33 @@ def next_words_segment(segments: List[dict]) -> Optional[dict]:
temperature,
options.prompt_reset_on_temperature,
)

prompt_reset_since = len(all_tokens)

def encode(self, features: torch.Tensor) -> ctranslate2.StorageView:
# Generate a unique key for the features
feature_key = hash(features.cpu().numpy().tobytes())

if feature_key in self.encoder_cache:
return self.encoder_cache[feature_key]

# When the model is running on multiple GPUs, the encoder output should be moved
# to the CPU since we don't know which GPU will handle the next job.
to_cpu = self.model.device == "cuda" and len(self.model.device_index) > 1

if features.ndim == 2:
features = features.unsqueeze(0)
features = get_ctranslate2_storage(features)

return self.model.encode(features, to_cpu=to_cpu)

encoder_output = self.model.encode(features, to_cpu=to_cpu)

# Cache the result
if len(self.encoder_cache) >= self.max_encoder_cache_items:
# Remove the oldest item if cache is full
self.encoder_cache.pop(next(iter(self.encoder_cache)))
self.encoder_cache[feature_key] = encoder_output

return encoder_output

def generate_with_fallback(
self,
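
One note on the eviction strategy in the new encode(): popping next(iter(self.encoder_cache)) removes the oldest-inserted key, so the cache behaves as a FIFO rather than an LRU (entries that are hit repeatedly are still evicted in insertion order). A minimal standalone sketch of that behavior, with illustrative names only:

# FIFO eviction as used by the new encoder cache; dicts preserve
# insertion order in Python 3.7+, so next(iter(...)) yields the oldest key.
cache = {}
max_items = 3

def put(key, value):
    if len(cache) >= max_items:
        cache.pop(next(iter(cache)))  # drop the oldest-inserted entry
    cache[key] = value

for i in range(5):
    put(f"feature-{i}", f"encoder-output-{i}")

print(list(cache))  # ['feature-2', 'feature-3', 'feature-4']
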
