diff --git a/.gitignore b/.gitignore
index e49c2b381..0d1cd6f1e 100644
--- a/.gitignore
+++ b/.gitignore
@@ -153,4 +153,9 @@ app/lib/env/prod_env.g.dart
 /backend/scripts/research/*.md
 /backend/scripts/research/*.json
 /backend/scripts/research/*.csv
-/backend/scripts/research/users
\ No newline at end of file
+/backend/scripts/research/users
+/backend/scripts/stt/_temp
+/backend/scripts/stt/_temp2
+/backend/scripts/stt/pretrained_models
+/backend/scripts/stt/results
+/backend/scripts/stt/diarization.json
diff --git a/backend/main.py b/backend/main.py
index ecbc421ff..f5f2c57f0 100644
--- a/backend/main.py
+++ b/backend/main.py
@@ -6,8 +6,7 @@
 from modal import Image, App, asgi_app, Secret, Cron
 from routers import workflow, chat, firmware, plugins, memories, transcribe, notifications, speech_profile, \
-    agents, facts, users, postprocessing, processing_memories, trends,sdcard
-
+    agents, facts, users, postprocessing, processing_memories, trends, sdcard
 from utils.other.notifications import start_cron_job
 
 if os.environ.get('SERVICE_ACCOUNT_JSON'):
@@ -54,7 +53,7 @@
     memory=(512, 1024),
     cpu=2,
     allow_concurrent_inputs=10,
-    timeout=60 * 5,
+    timeout=60 * 19,
 )
 @asgi_app()
 def api():
@@ -70,3 +69,29 @@ def api():
 @modal_app.function(image=image, schedule=Cron('* * * * *'))
 async def notifications_cronjob():
     await start_cron_job()
+
+
+@app.post('/webhook')
+async def webhook(data: dict):
+    # Merge consecutive segments from the same speaker into a single turn.
+    diarization = data['output']['diarization']
+    joined = []
+    for speaker in diarization:
+        if joined and speaker['speaker'] == joined[-1]['speaker']:
+            joined[-1]['end'] = speaker['end']
+        else:
+            joined.append(speaker)
+
+    print(data['jobId'], json.dumps(joined))
+    # Open scripts/stt/diarization.json, look up the memory id stored under this jobId,
+    # then replace the jobId entry with memory_id -> merged segments.
+    with open('scripts/stt/diarization.json', 'r') as f:
+        diarization_data = json.loads(f.read())
+
+    memory_id = diarization_data.get(data['jobId'])
+    if memory_id:
+        diarization_data[memory_id] = joined
+        del diarization_data[data['jobId']]
+        with open('scripts/stt/diarization.json', 'w') as f:
+            json.dump(diarization_data, f, indent=2)
+    return 'ok'
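Note: the merge step in the webhook above collapses consecutive segments from the same speaker into one turn. A minimal standalone sketch of that behavior, assuming the {'speaker', 'start', 'end'} segment shape the handler reads (the shape is inferred from the code, not from pyannote.ai docs):

# Hypothetical payload fragment; field names inferred from the handler above.
segments = [
    {'speaker': 'SPEAKER_00', 'start': 0.0, 'end': 1.2},
    {'speaker': 'SPEAKER_00', 'start': 1.3, 'end': 2.5},
    {'speaker': 'SPEAKER_01', 'start': 2.6, 'end': 4.0},
]

joined = []
for segment in segments:
    if joined and segment['speaker'] == joined[-1]['speaker']:
        joined[-1]['end'] = segment['end']  # extend the previous speaker turn
    else:
        joined.append(segment)

assert joined == [
    {'speaker': 'SPEAKER_00', 'start': 0.0, 'end': 2.5},
    {'speaker': 'SPEAKER_01', 'start': 2.6, 'end': 4.0},
]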
diff --git a/backend/scripts/stt/k_compare_transcripts_performance.py b/backend/scripts/stt/k_compare_transcripts_performance.py
index b65d0f074..3f193cc97 100644
--- a/backend/scripts/stt/k_compare_transcripts_performance.py
+++ b/backend/scripts/stt/k_compare_transcripts_performance.py
@@ -17,6 +17,7 @@
 from typing import Dict, List
 
 import firebase_admin
+import requests
 from dotenv import load_dotenv
 from pydub import AudioSegment
 from tabulate import tabulate
@@ -63,7 +64,7 @@ def execute_groq(file_path: str):
         transcription = client.audio.transcriptions.create(
             file=(file_path, file.read()),
             model="whisper-large-v3",
-            response_format="verbose_json",
+            response_format="text",
             language="en",
             temperature=0.0
         )
@@ -370,13 +371,9 @@ def compute_wer():
 
 
 def regex_fix(text: str):
-    # Define the regular expression
+    """Fix some of the stored JSON in results/$id.json from the Groq API."""
     pattern = r'(?<=transcription\(text=["\'])(.*?)(?=["\'],\s*task=)'
-
-    # Search for the pattern in the data
     match = re.search(pattern, text)
-
-    # If a match is found, extract and print the text
     if match:
         extracted_text = match.group(0)
         return extracted_text
@@ -385,20 +382,200 @@
+
+
+def pyannote_diarize(file_path: str):
+    memory_id = file_path.split('/')[-1].split('.')[0]
+    with open('diarization.json', 'r') as f:
+        diarization = json.loads(f.read())
+
+    if memory_id in diarization:
+        print('Already diarized', memory_id)
+        return
+
+    url = "https://api.pyannote.ai/v1/diarize"
+    headers = {"Authorization": f"Bearer {os.getenv('PYANNOTE_API_KEY')}"}
+    webhook = 'https://camel-lucky-reliably.ngrok-free.app/webhook'
+    signed_url = upload_postprocessing_audio(file_path)
+    data = {'webhook': webhook, 'url': signed_url}
+    response = requests.post(url, headers=headers, json=data)
+    job_id = response.json()['jobId']
+    print(memory_id, job_id)
+    # Record jobId -> memoryId so the webhook can map the result back.
+    diarization[job_id] = memory_id
+    with open('diarization.json', 'w') as f:
+        json.dump(diarization, f, indent=2)
+
+
+def generate_diarizations():
+    uids = os.listdir('_temp2')
+    for uid in uids:
+        memories = os.listdir(f'_temp2/{uid}')
+        memories = [f'_temp2/{uid}/{memory}' for memory in memories]
+        for memory in memories:
+            memory_id = memory.split('/')[-1].split('.')[0]
+            if os.path.exists(f'results/{memory_id}.json'):
+                pyannote_diarize(memory)
+            else:
+                print('Skipping', memory_id)
+
+
+from pyannote.metrics.diarization import DiarizationErrorRate
+from pyannote.core import Annotation, Segment
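Before compute_der() below, a self-contained toy example of how these pyannote.metrics objects are used (timings are made up; compute_der() applies the same pattern to the stored results):

from pyannote.core import Annotation, Segment
from pyannote.metrics.diarization import DiarizationErrorRate

reference = Annotation()
reference[Segment(0.0, 2.5)] = 'SPEAKER_00'
reference[Segment(2.5, 4.0)] = 'SPEAKER_01'

hypothesis = Annotation()
hypothesis[Segment(0.0, 2.0)] = 'SPEAKER_00'  # last 0.5s attributed to the wrong speaker
hypothesis[Segment(2.0, 4.0)] = 'SPEAKER_01'

metric = DiarizationErrorRate()
print(f"DER: {metric(reference, hypothesis):.2%}")  # 0.5s confusion / 4.0s total -> 12.50%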
+
+
+def compute_der():
+    """
+    Computes the Diarization Error Rate (DER) for each model across all JSON files in the 'results/' directory.
+    Outputs a summary table and rankings to 'der_report.txt'.
+    """
+    dir_path = 'results/'  # Directory containing per-memory result JSON files
+    output_file = os.path.join(dir_path, 'der_report.txt')  # Output report file
+    excluded_model = 'whisper-large-v3'  # Model to exclude from analysis
+
+    # Initialize the DER metric
+    der_metric = DiarizationErrorRate()
+
+    # Check that the results directory exists
+    if not os.path.isdir(dir_path):
+        print(f"Directory '{dir_path}' does not exist.")
+        return
+
+    # Load reference diarization data (kept in the working directory, not in 'results/')
+    diarization_path = 'diarization.json'
+    with open(diarization_path, 'r', encoding='utf-8') as f:
+        try:
+            diarization = json.load(f)
+        except json.JSONDecodeError:
+            print("Error decoding JSON in 'diarization.json'.")
+            return
+
+    # Prepare to collect DER results
+    der_results = []  # List of [Memory ID, Model, DER] rows
+    model_der_accumulator = defaultdict(list)  # To calculate average DER per model
+
+    # Iterate through all JSON files in the results directory
+    for file in os.listdir(dir_path):
+        if not file.endswith('.json') or file == 'diarization.json':
+            continue  # Skip non-JSON files (and any stray copy of 'diarization.json')
+
+        memory_id = file.split('.')[0]  # Extract memory ID from filename
+
+        if memory_id not in diarization:
+            print(f"Memory ID '{memory_id}' not found in 'diarization.json'. Skipping file: {file}")
+            continue
+
+        # Build the reference annotation for this memory
+        ref_segments = diarization[memory_id]
+        ref_annotation = Annotation()
+        for seg in ref_segments:
+            speaker, start, end = seg['speaker'], seg['start'], seg['end']
+            ref_annotation[Segment(start, end)] = speaker
+
+        # Load hypothesis segments from the result JSON file
+        file_path = os.path.join(dir_path, file)
+        with open(file_path, 'r', encoding='utf-8') as f:
+            try:
+                data = json.load(f)
+            except json.JSONDecodeError:
+                print(f"Error decoding JSON in file: {file}. Skipping.")
+                continue
+
+        # Compute DER for each model's segments in the result
+        for model, segments in data.items():
+            if model == excluded_model:
+                continue
+
+            hyp_annotation = Annotation()
+            for seg in segments:
+                speaker, start, end = seg['speaker'], seg['start'], seg['end']
+                # Normalize single-digit speaker labels (SPEAKER_0 -> SPEAKER_00, ...)
+                if re.fullmatch(r'SPEAKER_\d', speaker):
+                    speaker = speaker.replace('SPEAKER_', 'SPEAKER_0')
+                hyp_annotation[Segment(start, end)] = speaker
+
+            # Compute DER between reference and hypothesis
+            der = der_metric(ref_annotation, hyp_annotation)
+
+            der_results.append([memory_id, model, f"{der:.2%}"])
+            model_der_accumulator[model].append(der)
+
+    # Generate the detailed DER table
+    der_table = tabulate(der_results, headers=["Memory ID", "Model", "DER"], tablefmt="grid", stralign="left")
+
+    # Calculate average DER per model
+    average_der = []
+    for model, ders in model_der_accumulator.items():
+        avg = sum(ders) / len(ders)
+        average_der.append([model, f"{avg:.2%}"])
+
+    # Sort models by average DER ascending (lower is better)
+    average_der_sorted = sorted(average_der, key=lambda x: float(x[1].strip('%')))
+
+    # The winner is the model with the lowest average DER
+    winner = average_der_sorted[0][0] if average_der_sorted else "N/A"
+
+    # Prepare competition-style rankings: models with equal average DER share a rank
+    rankings = []
+    rank = 1
+    previous_der = None
+    previous_rank = None
+    for model, avg in average_der_sorted:
+        current_der = float(avg.strip('%'))
+        if previous_der is not None and current_der == previous_der:
+            current_rank = previous_rank  # Tie: share the previous model's rank
+        else:
+            current_rank = rank
+        rankings.append([current_rank, model, avg])
+        previous_der = current_der
+        previous_rank = current_rank
+        rank += 1
+
+    # Generate the rankings table
+    ranking_table = tabulate(rankings, headers=["Rank", "Model", "Average DER"], tablefmt="grid", stralign="left")
+
+    # Write all results to the output file
+    with open(output_file, 'w', encoding='utf-8') as out_f:
+        out_f.write("Diarization Error Rate (DER) Analysis Report\n")
+        out_f.write("=" * 50 + "\n\n")
+        out_f.write("Detailed DER Results:\n")
+        out_f.write(der_table + "\n\n")
+        out_f.write("Average DER per Model:\n")
+        out_f.write(
+            tabulate(average_der_sorted, headers=["Model", "Average DER"], tablefmt="grid", stralign="left") + "\n\n")
+        out_f.write("Model Rankings Based on Average DER:\n")
+        out_f.write(ranking_table + "\n\n")
+        out_f.write(f"Winner: {winner}\n")
+
+    print(f"Diarization Error Rate (DER) analysis completed. Report saved to '{output_file}'.")
+
+    # Also print the tables to the console
+    if der_results:
+        print("\nDetailed DER Results:")
+        print(der_table)
+    if average_der_sorted:
+        print("\nAverage DER per Model:")
+        print(tabulate(average_der_sorted, headers=["Model", "Average DER"], tablefmt="grid", stralign="left"))
+    if rankings:
+        print("\nModel Rankings Based on Average DER:")
+        print(ranking_table)
+    print(f"\nWinner: {winner}")
+
+
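The rankings above use competition ranking: models with equal average DER share a rank, and the next distinct value skips ahead. A standalone sketch of that logic with hypothetical model names and values:

def rank_models(average_der_sorted):
    # average_der_sorted: [(model, 'xx.xx%'), ...] sorted ascending by DER
    rankings, rank = [], 1
    previous_der = previous_rank = None
    for model, avg in average_der_sorted:
        der = float(avg.strip('%'))
        current_rank = previous_rank if der == previous_der else rank
        rankings.append([current_rank, model, avg])
        previous_der, previous_rank = der, current_rank
        rank += 1
    return rankings

print(rank_models([('model_a', '10.00%'), ('model_b', '10.00%'), ('model_c', '12.50%')]))
# -> [[1, 'model_a', '10.00%'], [1, 'model_b', '10.00%'], [3, 'model_c', '12.50%']]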
 if __name__ == '__main__':
     # asyncio.run(process_memories_audio_files())
-    compute_wer()
-    # client = Groq(api_key=os.getenv('GROQ_API_KEY'))
-    # file_path = '_temp2/DX8n89KAmUaG9O7Qvj8xTi81Zu12/0bce5547-675b-4dea-b9fe-cfb69740100b.wav'
-
-    # with open(file_path, "rb") as file:
-    #     transcription = client.audio.transcriptions.create(
-    #         file=(file_path, file.read()),
-    #         model="whisper-large-v3",
-    #         response_format="verbose_json",
-    #         language="en",
-    #         temperature=0.0
-    #     )
-    #     print(transcription)
-    #     for segment in transcription.segments:
-    #         print(segment['start'], segment['end'], segment['text'])
+    # generate_diarizations()
+
+    # compute_wer()
+    compute_der()
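Suggested run order, inferred from the code: deploy the backend so POST /webhook is reachable at the ngrok URL hard-coded in pyannote_diarize(), call generate_diarizations() to submit pyannote.ai jobs (which records jobId -> memory_id in diarization.json), wait for the webhooks to swap those entries to memory_id -> merged segments, then run compute_der() to produce results/der_report.txt.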