main.py
"""Serve the model as a fastapi app with gradio client"""
import json
from logging import getLogger
from pathlib import Path
import coloredlogs
import gradio as gr
import librosa
import torch
import torchvision.transforms as transforms
from fastapi import FastAPI
from moviepy.video.io.bindings import mplfig_to_npimage
from PIL import Image
from utils.preprocessing import plot_mel
coloredlogs.install()
logger = getLogger(__name__)
app = FastAPI()
CUSTOM_PATH = "/"
# @app.get("/")
# def read_main():
# """Return a friendly HTTP greeting."""
# return {"message": "This is your main app"}
# @app.get("/gradio")
def predict(file_path):
    """Predict the emotion of the audio file."""
    logger.info(file_path)
    audio, rate = librosa.load(file_path)
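    # Render a mel-spectrogram of the recording and convert it to a PIL image
    # so it can be fed to the image-classification model.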
    fig = plot_mel(audio, rate)
    numpy_image = mplfig_to_npimage(fig)
    image = Image.fromarray(numpy_image)
    resize = transforms.Compose([transforms.Resize((256, 256)), transforms.ToTensor()])
    input_tensor = resize(image).unsqueeze(0)
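    # data/responses.json is assumed to map stringified class indices to
    # emotion labels, e.g. {"0": "angry", "1": "happy", ...} (hypothetical values).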
    with open(Path(Path().resolve(), "data/responses.json"), encoding="utf-8") as f:
        labels = json.load(f)
    with torch.no_grad():
        outputs = MODEL(input_tensor)
    _, y_hat = outputs.max(1)
    prediction = labels[str(y_hat.item())]
    return prediction
io = gr.Interface(
    fn=predict,
    inputs=gr.Audio(source="microphone", type="filepath"),
    outputs="text",
)
# Load the trained model once at startup and switch it to inference mode.
MODEL = torch.load("model.pt", map_location="cpu")
MODEL.eval()
gradio_app = gr.routes.App.create_app(io)
app.mount(CUSTOM_PATH, gradio_app)
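# To serve the app locally (assuming this file is saved as main.py), one option is:
#   uvicorn main:app --host 0.0.0.0 --port 8000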