"""server.py: FastAPI service exposing /generate and /download endpoints backed by a PopMusicTransformer model."""
import os
import uuid
from datetime import datetime
from pathlib import Path
from typing import List, Optional

import s3fs
import tensorflow as tf
from fastapi import BackgroundTasks, FastAPI, HTTPException, Query
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse, JSONResponse, ORJSONResponse, Response
from starlette.background import BackgroundTask

import utils
from model import PopMusicTransformer
# Local directory the model checkpoint files are stored in.
CHECKPOINT = 'tmp/model'

# Scratch directory for MIDI files written while serving requests.
# exist_ok=True already tolerates an existing directory, so no prior
# existence check is needed.
os.makedirs('tmp/responses', exist_ok=True)

# Download the checkpoint from S3 only when there is no local copy yet.
# NOTE(review): an existing but incomplete directory (e.g. after an
# interrupted download) is not re-fetched — confirm this is acceptable.
if not os.path.exists(CHECKPOINT):
    os.makedirs(CHECKPOINT, exist_ok=True)
    print('Loading S3FileSystem...')
    fs = s3fs.S3FileSystem()
    bucket = 's3://popgen-model/model/'
    files = fs.ls(bucket)
    for f in files:
        print('Downloading ' + f)
        # Keep only the last path component as the local file name.
        name = f.split('/')[-1]
        fs.download(f, f"{CHECKPOINT}/{name}")
# ASGI application object; the route handlers in this module are registered on it.
app = FastAPI()

# CORS: accept requests from any origin, with any method and any header.
# NOTE(review): wildcard origins combined with allow_credentials=True is very
# permissive — confirm that credentialed cross-origin access is intended.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Disable eager execution and reset the default graph before the model is
# constructed; keep these calls ahead of the PopMusicTransformer(...) call.
tf.compat.v1.disable_eager_execution()
tf.compat.v1.reset_default_graph()

# Single module-level model instance, built from the local checkpoint
# directory with is_training=False and used by the request handlers.
model = PopMusicTransformer(
    checkpoint=CHECKPOINT,
    is_training=False)
@app.get('/generate')
async def generate(
    n_target_bar: int = 5,
    temperature: float = 1.2,
    topk: int = 5,
    key: Optional[str] = None,
    with_chords: bool = True,
):
    """Generate music with the model and return it as MIDI bytes plus event strings.

    Query parameters:
        n_target_bar: forwarded to ``model.generate`` as its first argument.
        temperature: forwarded to ``model.generate`` as its second argument.
        topk: forwarded to ``model.generate`` as its third argument.
        key: optional value forwarded to ``model.generate`` as ``key_prompt``.
        with_chords: when true, the events are passed through
            ``utils.add_played_chords`` and the word list is rebuilt from
            the resulting events.

    Returns:
        A ``JSONResponse`` whose body has two keys: ``buffer``, the written
        MIDI file as a list of byte values, and ``events``, one
        ``"<name>_<value>"`` string per event.
    """
    # NOTE(review): model.generate is called synchronously inside an async
    # handler; presumably it blocks the event loop while running — confirm.
    words = model.generate(n_target_bar, temperature, topk, key_prompt=key)
    events = utils.word_to_event(words, model.word2event)
    if with_chords:
        events = utils.add_played_chords(events)
        # The event list changed, so map it back to words before writing.
        words = [model.event2word[f'{e.name}_{e.value}'] for e in events]

    # A random name cannot collide between requests and avoids the spaces and
    # colons a timestamp would put into the file name.
    path = f'tmp/responses/{uuid.uuid4().hex}.mid'
    try:
        utils.write_midi(words, model.word2event, path)
        with open(path, 'rb') as midi_file:
            midi_bytes = midi_file.read()
    finally:
        # The content is either fully in memory or writing failed; in both
        # cases the scratch file is no longer needed and must not be leaked.
        if os.path.exists(path):
            os.remove(path)

    return JSONResponse({
        # Iterating over bytes yields ints, so list() gives the byte values.
        'buffer': list(midi_bytes),
        'events': [f'{e.name}_{e.value}' for e in events],
    })
@app.get('/download')
async def download(events: List[str] = Query(...)):
    """Render the given event strings to a MIDI file and return it.

    Query parameters:
        events: required, repeated parameter; each item must have the form
            ``"<name>_<value>"`` with exactly one underscore.

    Returns:
        A ``FileResponse`` serving the written ``.mid`` file. The file is
        removed by a background task once the response has been sent.

    Raises:
        HTTPException: status 422 when an item does not split into exactly
            a name and a value.
    """
    parsed_events = []
    for event_string in events:
        parts = event_string.split('_')
        if len(parts) != 2:
            # Previously this surfaced as an unhandled ValueError (HTTP 500).
            raise HTTPException(
                status_code=422,
                detail=f'Malformed event {event_string!r}; expected "<name>_<value>".',
            )
        name, value = parts
        parsed_events.append(utils.Event(name, None, value, None))

    # A random name cannot collide between requests and avoids the spaces and
    # colons a timestamp would put into the file name.
    path = f'tmp/responses/{uuid.uuid4().hex}.mid'
    try:
        utils.write_midi(None, None, path, parsed_events)
    except Exception:
        # Do not leave a partially written file behind when rendering fails.
        if os.path.exists(path):
            os.remove(path)
        raise
    return FileResponse(path, background=BackgroundTask(lambda: os.remove(path)))