Skip to content

Commit

Permalink
update for new model
Browse files Browse the repository at this point in the history
  • Loading branch information
zyddnys committed Jan 24, 2022
1 parent b3c6d2e commit ab826fc
Show file tree
Hide file tree
Showing 4 changed files with 117 additions and 123 deletions.
12 changes: 6 additions & 6 deletions text_rendering/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,12 @@
from . import text_render
from textblockdetector.textblock import TextBlock

async def dispatch(img_canvas: np.ndarray, text_mag_ratio: np.integer, translated_sentences: List[str], textlines: List[Quadrilateral], text_regions: List[Quadrilateral], force_horizontal: bool) -> np.ndarray :
async def dispatch(img_canvas: np.ndarray, text_mag_ratio: np.integer, translated_sentences: List[str], textlines: List[Quadrilateral], text_regions: List[Quadrilateral], text_direction_overwrite: str) -> np.ndarray :
for ridx, (trans_text, region) in enumerate(zip(translated_sentences, text_regions)) :
if not trans_text :
continue
if force_horizontal :
region.majority_dir = 'h'
if text_direction_overwrite and text_direction_overwrite in ['h', 'v'] :
region.majority_dir = text_direction_overwrite
print(region.text)
print(trans_text)
#print(region.majority_dir, region.pts)
Expand Down Expand Up @@ -113,13 +113,13 @@ async def dispatch(img_canvas: np.ndarray, text_mag_ratio: np.integer, translate
return img_canvas


async def dispatch_ctd_render(img_canvas: np.ndarray, text_mag_ratio: np.integer, translated_sentences: List[str], text_regions: List[TextBlock], force_horizontal: bool) -> np.ndarray :
async def dispatch_ctd_render(img_canvas: np.ndarray, text_mag_ratio: np.integer, translated_sentences: List[str], text_regions: List[TextBlock], text_direction_overwrite: str) -> np.ndarray :
for ridx, (trans_text, region) in enumerate(zip(translated_sentences, text_regions)) :
print(f'text: {region.get_text()} \n trans: {trans_text}')
if not trans_text :
continue
if force_horizontal :
majority_dir = 'h'
if text_direction_overwrite and text_direction_overwrite in ['h', 'v'] :
majority_dir = text_direction_overwrite
else:
majority_dir = 'v' if region.vertical else 'h'

Expand Down
78 changes: 41 additions & 37 deletions translate_demo.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,21 @@

import asyncio
import torch
import einops
import argparse
import cv2
import numpy as np
import requests
import os
from oscrypto import util as crypto_utils
import asyncio

from detection import dispatch as dispatch_detection, load_model as load_detection_model
from ocr import dispatch as dispatch_ocr, load_model as load_ocr_model
from inpainting import dispatch as dispatch_inpainting, load_model as load_inpainting_model
from text_mask import dispatch as dispatch_mask_refinement
from textline_merge import dispatch as dispatch_textline_merge
from text_rendering import dispatch as dispatch_rendering
from textblockdetector import dispatch as dispatch_ctd_detection
from textblockdetector.textblock import visualize_textblocks

parser = argparse.ArgumentParser(description='Generate text bboxes given a image file')
parser.add_argument('--mode', default='demo', type=str, help='Run demo in either single image demo mode (demo), web service mode (web) or batch translation mode (batch)')
Expand Down Expand Up @@ -47,32 +56,25 @@ def update_state(task_id, nonce, state) :
def get_task(nonce) :
try :
rjson = requests.get(f'http://127.0.0.1:5003/task-internal?nonce={nonce}').json()
if 'task_id' in rjson :
return rjson['task_id']
if 'task_id' in rjson and 'data' in rjson :
return rjson['task_id'], rjson['data']
else :
return None
return None, None
except :
return None

from detection import dispatch as dispatch_detection, load_model as load_detection_model
from ocr import dispatch as dispatch_ocr, load_model as load_ocr_model
from inpainting import dispatch as dispatch_inpainting, load_model as load_inpainting_model
from text_mask import dispatch as dispatch_mask_refinement
from textline_merge import dispatch as dispatch_textline_merge
from text_rendering import dispatch as dispatch_rendering
from textblockdetector import dispatch as dispatch_ctd_detection
from textblockdetector.textblock import visualize_textblocks
return None, None

async def infer(
img,
mode,
nonce,
options = None,
task_id = '',
dst_image_name = ''
) :
options = options or {}
img_detect_size = args.size
if task_id and len(task_id) != 32 :
size_ind = task_id[-1]
if 'size' in options :
size_ind = options['size']
if size_ind == 'S' :
img_detect_size = 1024
elif size_ind == 'M' :
Expand All @@ -81,20 +83,29 @@ async def infer(
img_detect_size = 2048
elif size_ind == 'X' :
img_detect_size = 2560
print(f' -- Detection size {size_ind}, resolution {img_detect_size}')
print(f' -- Detection resolution {img_detect_size}')
detector = 'ctd' if args.use_ctd else 'default'
if 'detector' in options :
detector = options['detector']
print(f' -- Detector using {detector}')
render_text_direction_overwrite = 'h' if args.force_horizontal else ''
if 'direction' in options :
if options['direction'] == 'horizontal' :
render_text_direction_overwrite = 'h'
print(f' -- Render text direction is {render_text_direction_overwrite or "auto"}')
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

if mode == 'web' and task_id :
update_state(task_id, nonce, 'detection')

if args.use_ctd:
if detector == 'ctd' :
mask, final_mask, textlines = await dispatch_ctd_detection(img, args.use_cuda)
text_regions = textlines
else:
textlines, mask = await dispatch_detection(img, img_detect_size, args.use_cuda, args, verbose = args.verbose)

if args.verbose :
if args.use_ctd:
if detector == 'ctd' :
bboxes = visualize_textblocks(cv2.cvtColor(img,cv2.COLOR_BGR2RGB), textlines)
cv2.imwrite(f'result/{task_id}/bboxes.png', bboxes)
cv2.imwrite(f'result/{task_id}/mask_raw.png', mask)
Expand All @@ -110,7 +121,7 @@ async def infer(
update_state(task_id, nonce, 'ocr')
textlines = await dispatch_ocr(img, textlines, args.use_cuda, args)

if not args.use_ctd:
if detector == 'default' :
text_regions, textlines = await dispatch_textline_merge(textlines, img.shape[1], img.shape[0], verbose = args.verbose)
if args.verbose :
img_bbox = np.copy(img)
Expand Down Expand Up @@ -155,7 +166,7 @@ async def infer(
print(' -- Translating')
# try:
from translators import dispatch as run_translation
if args.use_ctd:
if detector == 'ctd' :
translated_sentences = await run_translation(args.translator, 'auto', args.target_lang, [r.get_text() for r in text_regions])
else:
translated_sentences = await run_translation(args.translator, 'auto', args.target_lang, [r.text for r in text_regions])
Expand All @@ -177,11 +188,11 @@ async def infer(
if mode == 'web' and task_id :
update_state(task_id, nonce, 'render')
# render translated texts
if args.use_ctd:
if detector == 'ctd' :
from text_rendering import dispatch_ctd_render
output = await dispatch_ctd_render(np.copy(img_inpainted), args.text_mag_ratio, translated_sentences, text_regions, args.force_horizontal)
output = await dispatch_ctd_render(np.copy(img_inpainted), args.text_mag_ratio, translated_sentences, text_regions, render_text_direction_overwrite)
else:
output = await dispatch_rendering(np.copy(img_inpainted), args.text_mag_ratio, translated_sentences, textlines, text_regions, args.force_horizontal)
output = await dispatch_rendering(np.copy(img_inpainted), args.text_mag_ratio, translated_sentences, textlines, text_regions, render_text_direction_overwrite)

print(' -- Saving results')
if dst_image_name :
Expand All @@ -192,28 +203,21 @@ async def infer(
if mode == 'web' and task_id :
update_state(task_id, nonce, 'finished')

from PIL import Image
import time
import asyncio

def replace_prefix(s: str, old: str, new: str) :
if s.startswith(old) :
s = new + s[len(old):]
return s

async def main(mode = 'demo') :
print(' -- Loading models')
import os
os.makedirs('result', exist_ok = True)
text_render.prepare_renderer()
with open('alphabet-all-v5.txt', 'r', encoding = 'utf-8') as fp :
dictionary = [s[:-1] for s in fp.readlines()]
load_ocr_model(dictionary, args.use_cuda)
if args.use_ctd:
from textblockdetector import load_model as load_ctd_model
load_ctd_model(args.use_cuda)
else:
load_detection_model(args.use_cuda)
from textblockdetector import load_model as load_ctd_model
load_ctd_model(args.use_cuda)
load_detection_model(args.use_cuda)
load_inpainting_model(args.use_cuda)

if mode == 'demo' :
Expand All @@ -232,12 +236,12 @@ async def main(mode = 'demo') :
import sys
subprocess.Popen([sys.executable, 'web_main.py', nonce, '5003'])
while True :
task_id = get_task(nonce)
task_id, options = get_task(nonce)
if task_id :
print(f' -- Processing task {task_id}')
img = cv2.imread(f'result/{task_id}/input.png')
try :
infer_task = asyncio.create_task(infer(img, mode, nonce, task_id))
infer_task = asyncio.create_task(infer(img, mode, nonce, options, task_id))
asyncio.gather(infer_task)
except :
import traceback
Expand Down
30 changes: 28 additions & 2 deletions ui.html
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,8 @@
event.preventDefault();
var detect_res = document.querySelector('input[name="detect-size"]:checked').value;
var translator = document.querySelector('input[name="translator-sel"]:checked').value;
var dir = document.querySelector('input[name="dir-sel"]:checked').value;
var detector = document.querySelector('input[name="detector-sel"]:checked').value;
var tgt_lang = document.getElementById("target-language").value;
console.log(document.getElementById("target-language"));
var files = document.getElementById("image-file").files;
Expand All @@ -135,6 +137,8 @@
formData.append('size', detect_res);
formData.append('translator', translator);
formData.append('tgt_lang', tgt_lang);
formData.append('dir', dir);
formData.append('detector', detector);
XHR.open('POST', BASE_URI + "submit", true);
XHR.onload = async function () {
if (XHR.status == 200) {
Expand Down Expand Up @@ -185,7 +189,7 @@ <h3>Image/Manga translator</h3>
<div>
<form id="image-form" action="#" onsubmit="upload_image(event);" method="post" enctype="multipart/form-data">
<div style="padding-bottom: 0.5rem;">
<span>Detection resolution: </span>
<span>Detection resolution </span>
<label>
<input type="radio" id="detect-res-S" name="detect-size" value="S">
<span>1024px</span>
Expand All @@ -204,7 +208,18 @@ <h3>Image/Manga translator</h3>
</label>
</div>
<div style="padding-bottom: 0.5rem;">
<span>Translator: </span>
<span>Text detector </span>
<label>
<input type="radio" id="detector-default" name="detector-sel" value="auto" checked>
<span>Default</span>
</label>
<label>
<input type="radio" id="detector-ctd" name="detector-sel" value="ctd">
<span>CTD</span>
</label>
</div>
<div style="padding-bottom: 0.5rem;">
<span>Translator </span>
<label>
<input type="radio" id="translator-youdao" name="translator-sel" value="youdao" checked>
<span>Youdao</span>
Expand All @@ -222,6 +237,17 @@ <h3>Image/Manga translator</h3>
<span>DeepL</span>
</label>
</div>
<div style="padding-bottom: 0.5rem;">
<span>Render text direction </span>
<label>
<input type="radio" id="dir-auto" name="dir-sel" value="auto" checked>
<span>Auto</span>
</label>
<label>
<input type="radio" id="dir-horizontal" name="dir-sel" value="horizontal">
<span>Horizontal</span>
</label>
</div>
<div style="padding-bottom: 0.5rem;">
<span>Target language:</span>
<select name="tgt_lang" id="target-language">
Expand Down
Loading

0 comments on commit ab826fc

Please sign in to comment.