forked from zyddnys/manga-image-translator
-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
6 changed files
with
139 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
from manga_translator.moeflow import async_detection | ||
from .cache_async import cache_async | ||
from pydantic import BaseModel | ||
|
||
class TranslateTaskDef(BaseModel): | ||
image_file: str | ||
|
||
@cache_async | ||
async def start_translate_task(task_def: TranslateTaskDef): | ||
detection_result = await async_detection(task_def.image_file, detector_key="craft") | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
from functools import _make_key, lru_cache, _lru_cache_wrapper | ||
import streamlit as st | ||
from typing import TypeVar, overload, Callable, Any, Coroutine | ||
import asyncio | ||
import functools | ||
import inspect | ||
import logging | ||
|
||
logger = logging.getLogger(__name__) | ||
logger.setLevel(logging.DEBUG) | ||
|
||
R = TypeVar("R") | ||
|
||
|
||
def _get_session_cache() -> dict[str, Callable]: | ||
if "#cached_async_functions" not in st.session_state: | ||
st.session_state["#cached_async_functions"] = {} | ||
return st.session_state["#cached_async_functions"] | ||
|
||
|
||
def _hash_function(f: Callable) -> str: | ||
logger.debug("module: %s", f.__module__) | ||
logger.debug("file: %s", inspect.getabsfile(f)) | ||
logger.debug("srclines: %s", inspect.getsourcelines(f)) | ||
lines, lineno = inspect.getsourcelines(f) | ||
h = hash( | ||
( | ||
f.__module__, | ||
inspect.getabsfile(f), | ||
lineno, | ||
tuple(lines), | ||
# inspect.getsource(f), | ||
# inspect.getsourcefile(f) | ||
) | ||
) | ||
return f"{f.__module__}:{inspect.getabsfile(f)}:{lineno} ${h}" | ||
|
||
|
||
def cache_async( | ||
f: Callable[..., Coroutine[Any, Any, R]], cache_size: int = 128 | ||
) -> Callable[..., asyncio.Task[R]]: | ||
"""_summary_ | ||
Args: | ||
f (Callable[..., Coroutine[Any, Any, R]]): an async functoin | ||
Returns: | ||
Callable[..., asyncio.Task[R]]: a per (session, callee) | ||
callee is identified by , therefore mu | ||
""" | ||
f_id = _hash_function(f) | ||
|
||
def wrapped(*args, **kwargs): | ||
"""the outmost function | ||
Returns: | ||
asyncio.Task[R]: maybe-cached asyncio.Task, created by caching calls to f() | ||
""" | ||
cache = _get_session_cache() | ||
if f_id not in cache: | ||
logger.debug("creating lru cache for %s", f_id) | ||
|
||
def run_for_task(*args, **kwargs): | ||
return asyncio.create_task(f(*args, **kwargs)) | ||
|
||
cache[f_id] = functools.lru_cache(maxsize=cache_size, typed=True)( | ||
run_for_task | ||
) | ||
return cache[f_id](*args, **kwargs) | ||
# nested levels: wrapped >> (lru-ed run_for_task) >> run_for_task >> f | ||
|
||
return wrapped |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
import streamlit as st | ||
from streamlit.runtime.uploaded_file_manager import UploadedFile | ||
import logging | ||
import tempfile | ||
logger = logging.getLogger(__name__) | ||
logger.setLevel(logging.DEBUG) | ||
|
||
uploaded: list[UploadedFile] = st.file_uploader("pick a file", type=['png', 'jpg'], accept_multiple_files=True) | ||
|
||
logger.debug('uploaded: %s', uploaded) | ||
|
||
def start_save(): | ||
with tempfile.TemporaryDirectory() as tmpdirname: | ||
for file in uploaded: | ||
with open(f'{tmpdirname}/{file.name}', 'wb') as f: | ||
f.write(file.getvalue()) | ||
logger.info("saved %/%s", tmpdirname, file.name) | ||
|
||
st.button("save", on_click=start_save) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
import streamlit as st | ||
import streamlit_pydantic as sp | ||
|
||
from manga_translator.streamlit import start_translate_task, TranslateTaskDef | ||
|
||
task_def = sp.pydantic_form(key="single_input_file", model=TranslateTaskDef) | ||
|
||
if task_def: | ||
sp.pydantic_output(task_def) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
import streamlit as st | ||
|
||
st.markdown("TODO") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
import logging | ||
import streamlit as st | ||
import src as _src | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
# logger.setLevel(logging.DEBUG) | ||
logger.debug("page1") | ||
|
||
st.markdown("page1") | ||
|
||
a = st.number_input("a", value=1) | ||
b = st.number_input("b", value=2) | ||
|
||
def on_slide_change(): | ||
count = st.session_state.get("slide_change_count", 0) | ||
count += 1 | ||
st.session_state["slide_change_count"] = count | ||
st.write(f"slide changed {count}") | ||
|
||
c=st.slider("c",1,10,3, on_change=on_slide_change) | ||
|
||
st.write(_src.cached_compute(a, c)) | ||
|
||
st.session_state |