gitignore update + Diarization error rate compute
josancamon19 committed Sep 22, 2024
1 parent bff3387 commit 1e65fd9
Showing 3 changed files with 233 additions and 25 deletions.
7 changes: 6 additions & 1 deletion .gitignore
@@ -153,4 +153,9 @@ app/lib/env/prod_env.g.dart
/backend/scripts/research/*.md
/backend/scripts/research/*.json
/backend/scripts/research/*.csv
-/backend/scripts/research/users
+/backend/scripts/research/users
/backend/scripts/stt/_temp
/backend/scripts/stt/_temp2
/backend/scripts/stt/pretrained_models
/backend/scripts/stt/results
/backend/scripts/stt/diarization.json
32 changes: 29 additions & 3 deletions backend/main.py
@@ -6,8 +6,7 @@

from modal import Image, App, asgi_app, Secret, Cron
from routers import workflow, chat, firmware, plugins, memories, transcribe, notifications, speech_profile, \
-    agents, facts, users, postprocessing, processing_memories, trends,sdcard
-
+    agents, facts, users, postprocessing, processing_memories, trends, sdcard
from utils.other.notifications import start_cron_job

if os.environ.get('SERVICE_ACCOUNT_JSON'):
@@ -54,7 +53,7 @@
    memory=(512, 1024),
    cpu=2,
    allow_concurrent_inputs=10,
-    timeout=60 * 5,
+    timeout=60 * 19,
)
@asgi_app()
def api():
@@ -70,3 +69,30 @@ def api():
@modal_app.function(image=image, schedule=Cron('* * * * *'))
async def notifications_cronjob():
    await start_cron_job()


@app.post('/webhook')
async def webhook(data: dict):
    diarization = data['output']['diarization']
    joined = []
    for speaker in diarization:
        if not joined:
            joined.append(speaker)
        else:
            if speaker['speaker'] == joined[-1]['speaker']:
                joined[-1]['end'] = speaker['end']
            else:
                joined.append(speaker)

    print(data['jobId'], json.dumps(joined))
    # open scripts/stt/diarization.json, look up the memory id stored under jobId,
    # then replace the jobId entry with memory_id=joined
    with open('scripts/stt/diarization.json', 'r') as f:
        diarization_data = json.loads(f.read())

    memory_id = diarization_data.get(data['jobId'])
    if memory_id:
        diarization_data[memory_id] = joined
        del diarization_data[data['jobId']]
        with open('scripts/stt/diarization.json', 'w') as f:
            json.dump(diarization_data, f, indent=2)
    return 'ok'
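
The handler relies on just two fields of the callback: jobId and output.diarization, a list of {speaker, start, end} segments. A sketch of a payload (values invented for illustration) and what the merge loop reduces it to:

payload = {
    'jobId': 'job-123',  # hypothetical id
    'output': {
        'diarization': [
            {'speaker': 'SPEAKER_00', 'start': 0.0, 'end': 1.2},
            {'speaker': 'SPEAKER_00', 'start': 1.2, 'end': 3.4},  # same speaker: folded into the previous turn
            {'speaker': 'SPEAKER_01', 'start': 3.4, 'end': 5.0},
        ]
    },
}
# The merge loop collapses consecutive same-speaker segments, so joined becomes:
# [{'speaker': 'SPEAKER_00', 'start': 0.0, 'end': 3.4},
#  {'speaker': 'SPEAKER_01', 'start': 3.4, 'end': 5.0}]
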
219 changes: 198 additions & 21 deletions backend/scripts/stt/k_compare_transcripts_performance.py
@@ -17,6 +17,7 @@
from typing import Dict, List

import firebase_admin
import requests
from dotenv import load_dotenv
from pydub import AudioSegment
from tabulate import tabulate
@@ -63,7 +64,7 @@ def execute_groq(file_path: str):
        transcription = client.audio.transcriptions.create(
            file=(file_path, file.read()),
            model="whisper-large-v3",
-            response_format="verbose_json",
+            response_format="text",
            language="en",
            temperature=0.0
        )
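
Switching response_format from "verbose_json" to "text" should make the call return the transcript as a plain string, assuming Groq's SDK follows the OpenAI client convention here, so the stored results no longer contain object reprs of the kind regex_fix() below unwraps; it also drops per-segment timestamps, which is presumably why compute_der() further down excludes whisper-large-v3. A minimal sketch reusing this function's client:

with open(file_path, "rb") as file:
    transcription = client.audio.transcriptions.create(
        file=(file_path, file.read()),
        model="whisper-large-v3",
        response_format="text",
        language="en",
        temperature=0.0,
    )
print(transcription)  # plain str in this mode: nothing like .segments to iterate
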
@@ -370,13 +371,9 @@ def compute_wer():


def regex_fix(text: str):
-    # Define the regular expression
+    """Fix some of the stored JSON in results/$id.json from the Groq API."""
    pattern = r'(?<=transcription\(text=["\'])(.*?)(?=["\'],\s*task=)'
-
-    # Search for the pattern in the data
    match = re.search(pattern, text)
-
-    # If a match is found, extract and print the text
    if match:
        extracted_text = match.group(0)
        return extracted_text
@@ -385,20 +382,200 @@ def regex_fix(text: str):
    return text
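
For illustration, given a stored repr of the shape the pattern targets (example string invented), regex_fix() recovers just the text field:

print(regex_fix("transcription(text='Hello there, how are you?', task='transcribe')"))
# -> Hello there, how are you?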


def pyannote_diarize(file_path: str):
    memory_id = file_path.split('/')[-1].split('.')[0]
    with open('diarization.json', 'r') as f:
        results = json.loads(f.read())

    if memory_id in results:
        print('Already diarized', memory_id)
        return

    url = "https://api.pyannote.ai/v1/diarize"
    headers = {"Authorization": f"Bearer {os.getenv('PYANNOTE_API_KEY')}"}
    webhook = 'https://camel-lucky-reliably.ngrok-free.app/webhook'
    signed_url = upload_postprocessing_audio(file_path)
    data = {'webhook': webhook, 'url': signed_url}
    response = requests.post(url, headers=headers, json=data)
    print(memory_id, response.json()['jobId'])
    # update diarization.json, and set jobId=memoryId
    with open('diarization.json', 'r') as f:
        diarization = json.loads(f.read())

    diarization[response.json()['jobId']] = memory_id
    with open('diarization.json', 'w') as f:
        json.dump(diarization, f, indent=2)
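
diarization.json therefore passes through two states: pyannote_diarize() records a pending jobId -> memory id entry, and the /webhook route in backend/main.py later swaps it for memory id -> merged reference turns. Roughly (job id invented):

# right after pyannote_diarize() submits a job:
#   {"job-123": "0bce5547-675b-4dea-b9fe-cfb69740100b"}
# after the /webhook callback fires:
#   {"0bce5547-675b-4dea-b9fe-cfb69740100b": [
#       {"speaker": "SPEAKER_00", "start": 0.0, "end": 3.4}, ...]}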


def generate_diarizations():
    uids = os.listdir('_temp2')
    for uid in uids:
        memories = os.listdir(f'_temp2/{uid}')
        memories = [f'_temp2/{uid}/{memory}' for memory in memories]
        for memory in memories:
            memory_id = memory.split('/')[-1].split('.')[0]
            if os.path.exists(f'results/{memory_id}.json'):
                pyannote_diarize(memory)
            else:
                print('Skipping', memory_id)


from pyannote.metrics.diarization import DiarizationErrorRate
from pyannote.core import Annotation, Segment

der_metric = DiarizationErrorRate()
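
DiarizationErrorRate scores missed speech, false alarms, and speaker confusion against total reference speech, after optimally mapping hypothesis labels onto reference labels. A toy check of the API as used below (times invented):

reference = Annotation()
reference[Segment(0.0, 6.0)] = 'SPEAKER_00'
reference[Segment(6.0, 10.0)] = 'SPEAKER_01'

hypothesis = Annotation()
hypothesis[Segment(0.0, 8.0)] = 'SPEAKER_00'  # 6-8 s confused with SPEAKER_00
hypothesis[Segment(8.0, 10.0)] = 'SPEAKER_01'

print(f"{der_metric(reference, hypothesis):.2%}")  # 2 s wrong / 10 s speech -> 20.00%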


def compute_der():
    """
    Computes the Diarization Error Rate (DER) for each model across all JSON files in the 'results/' directory.
    Outputs a summary table and rankings to 'der_report.txt'.
    """
    dir_path = 'results/'  # Directory containing the per-memory result JSON files
    output_file = os.path.join(dir_path, 'der_report.txt')  # Output report file
    excluded_model = 'whisper-large-v3'  # Model to exclude from analysis

    # Initialize DER metric
    der_metric = DiarizationErrorRate()

    # Check if the directory exists
    if not os.path.isdir(dir_path):
        print(f"Directory '{dir_path}' does not exist.")
        return

    # Path to the reference 'diarization.json' (kept in the working directory)
    diarization_path = 'diarization.json'

    # Load reference diarization data
    with open(diarization_path, 'r', encoding='utf-8') as f:
        try:
            diarization = json.load(f)
        except json.JSONDecodeError:
            print("Error decoding JSON in 'diarization.json'.")
            return

    # Prepare to collect DER results
    der_results = []  # List to store [Memory ID, Model, DER]
    model_der_accumulator = defaultdict(list)  # To calculate average DER per model

    # Iterate through all JSON files in 'results/' directory
    for file in os.listdir(dir_path):
        if not file.endswith('.json') or file == 'diarization.json':
            continue  # Skip non-JSON files and 'diarization.json' itself

        memory_id = file.split('.')[0]  # Extract memory ID from filename

        # Check if memory_id exists in 'diarization.json'
        if memory_id not in diarization:
            print(f"Memory ID '{memory_id}' not found in 'diarization.json'. Skipping file: {file}")
            continue

        # Load reference segments for the current memory_id
        ref_segments = diarization[memory_id]
        ref_annotation = Annotation()
        for seg in ref_segments:
            speaker, start, end = seg['speaker'], seg['start'], seg['end']
            ref_annotation[Segment(start, end)] = speaker

        # Load hypothesis segments from the result JSON file
        file_path = os.path.join(dir_path, file)
        with open(file_path, 'r', encoding='utf-8') as f:
            try:
                data = json.load(f)
            except json.JSONDecodeError:
                print(f"Error decoding JSON in file: {file}. Skipping.")
                continue

        # Iterate through each model's segments in the result
        for model, segments in data.items():
            if model == excluded_model:
                continue  # Skip the excluded model

            hyp_annotation = Annotation()
            for seg in segments:
                speaker, start, end = seg['speaker'], seg['start'], seg['end']
                # Normalize single-digit speaker labels to the two-digit convention
                speaker = {
                    'SPEAKER_0': 'SPEAKER_00',
                    'SPEAKER_1': 'SPEAKER_01',
                    'SPEAKER_2': 'SPEAKER_02',
                    'SPEAKER_3': 'SPEAKER_03',
                }.get(speaker, speaker)
                hyp_annotation[Segment(start, end)] = speaker

            # Compute DER between reference and hypothesis
            der = der_metric(ref_annotation, hyp_annotation)

            # Store the result
            der_results.append([memory_id, model, f"{der:.2%}"])
            model_der_accumulator[model].append(der)

    # Generate the detailed DER table
    der_table = tabulate(der_results, headers=["Memory ID", "Model", "DER"], tablefmt="grid", stralign="left")

    # Calculate average DER per model
    average_der = []
    for model, ders in model_der_accumulator.items():
        avg = sum(ders) / len(ders)
        average_der.append([model, f"{avg:.2%}"])

    # Sort models by average DER ascending (lower is better)
    average_der_sorted = sorted(average_der, key=lambda x: float(x[1].strip('%')))

    # Determine the winner (model with the lowest average DER)
    winner = average_der_sorted[0][0] if average_der_sorted else "N/A"

    # Prepare rankings (1st, 2nd, ...), giving models with equal average DER the same rank
    rankings = []
    previous_der = None
    previous_rank = 0
    for position, (model, avg) in enumerate(average_der_sorted, start=1):
        current_der = float(avg.strip('%'))
        if previous_der is not None and current_der == previous_der:
            current_rank = previous_rank  # tie: share the previous model's rank
        else:
            current_rank = position  # strictly worse DER gets its list position as rank
        rankings.append([current_rank, model, avg])
        previous_der = current_der
        previous_rank = current_rank

    # Generate the rankings table
    ranking_table = tabulate(rankings, headers=["Rank", "Model", "Average DER"], tablefmt="grid", stralign="left")

    # Write all results to the output file
    with open(output_file, 'w', encoding='utf-8') as out_f:
        out_f.write("Diarization Error Rate (DER) Analysis Report\n")
        out_f.write("=" * 50 + "\n\n")
        out_f.write("Detailed DER Results:\n")
        out_f.write(der_table + "\n\n")
        out_f.write("Average DER per Model:\n")
        out_f.write(
            tabulate(average_der_sorted, headers=["Model", "Average DER"], tablefmt="grid", stralign="left") + "\n\n")
        out_f.write("Model Rankings Based on Average DER:\n")
        out_f.write(ranking_table + "\n\n")
        out_f.write(f"Winner: {winner}\n")

    # Print a confirmation message
    print(f"Diarization Error Rate (DER) analysis completed. Report saved to '{output_file}'.")

    # Optionally, print the tables to the console as well
    if der_results:
        print("\nDetailed DER Results:")
        print(der_table)
    if average_der_sorted:
        print("\nAverage DER per Model:")
        print(tabulate(average_der_sorted, headers=["Model", "Average DER"], tablefmt="grid", stralign="left"))
    if rankings:
        print("\nModel Rankings Based on Average DER:")
        print(ranking_table)
    print(f"\nWinner: {winner}")


if __name__ == '__main__':
    # asyncio.run(process_memories_audio_files())
    compute_wer()
    # client = Groq(api_key=os.getenv('GROQ_API_KEY'))
    # file_path = '_temp2/DX8n89KAmUaG9O7Qvj8xTi81Zu12/0bce5547-675b-4dea-b9fe-cfb69740100b.wav'

    # with open(file_path, "rb") as file:
    #     transcription = client.audio.transcriptions.create(
    #         file=(file_path, file.read()),
    #         model="whisper-large-v3",
    #         response_format="verbose_json",
    #         language="en",
    #         temperature=0.0
    #     )
    # print(transcription)
    # for segment in transcription.segments:
    #     print(segment['start'], segment['end'], segment['text'])
    # generate_diarizations()

    # compute_wer()
    compute_der()
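
The intended order of operations, inferred from this commit rather than documented:

# compute_wer()             # 1. transcribe the audio under _temp2/ into results/
# generate_diarizations()   # 2. submit each transcribed file to pyannote.ai
#                           #    (the /webhook route fills scripts/stt/diarization.json)
# compute_der()             # 3. score every model against those reference turns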
