-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathinf.py
63 lines (51 loc) · 1.65 KB
/
inf.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import argparse
from llm_tw_word.const import TRANSLATOR_TYPE
from llm_tw_word.const import DEFAULT_LLAMA_MODEL
from llm_tw_word.const import DEFAULT_OPENAI_MODEL
from llm_tw_word.translate import LlamaTranslator
from llm_tw_word.translate import OpenAITranslator
def parse_args(argv=None):
    """Parse command-line arguments for the translation inference script.

    Args:
        argv: Optional list of argument strings to parse instead of the
            process command line. When ``None`` (the default), argparse reads
            ``sys.argv[1:]``, which is exactly the previous behaviour. Passing
            an explicit list lets other code or tests drive the parser.

    Returns:
        argparse.Namespace with attributes ``text`` (the input string),
        ``translator`` (one of ``TRANSLATOR_TYPE.LLAMA`` /
        ``TRANSLATOR_TYPE.OPENAI``) and ``model`` (a model name string, or
        ``None`` when ``--model`` was not given).
    """
    parser = argparse.ArgumentParser(
        description="Script for model inference",
        # Appends "(default: ...)" to each option's help text.
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "text",
        type=str,
        help="Text to be translated",
    )
    parser.add_argument(
        "translator",
        type=str,
        # argparse rejects any value outside this tuple with a usage error.
        choices=(
            TRANSLATOR_TYPE.LLAMA,
            TRANSLATOR_TYPE.OPENAI,
        ),
        help="Translator type",
    )
    parser.add_argument(
        "--model",
        type=str,
        # None signals "use the translator's default model" to main().
        default=None,
        help="Specified model name for the translator. If not provided, there"
        " will be a default model",
    )
    return parser.parse_args(argv)
def main(args):
    """Translate ``args.text`` with the selected translator and print a report.

    Args:
        args: Namespace as returned by ``parse_args``. Reads ``text``,
            ``translator`` and ``model``. A falsy ``model`` (``None`` or an
            empty string) selects the package default for that translator.

    Raises:
        ValueError: if ``args.translator`` is neither
            ``TRANSLATOR_TYPE.LLAMA`` nor ``TRANSLATOR_TYPE.OPENAI``.
    """
    text_trad = args.text
    translator_name = args.translator

    # Dispatch explicitly on both supported names. The previous bare `else`
    # routed *any* unrecognised name to the OpenAI translator. When main() is
    # called programmatically (bypassing argparse's `choices` check), a typo
    # would silently trigger an unintended OpenAI request.
    if translator_name == TRANSLATOR_TYPE.LLAMA:
        model_name = args.model or DEFAULT_LLAMA_MODEL
        translator = LlamaTranslator(model_name=model_name)
    elif translator_name == TRANSLATOR_TYPE.OPENAI:
        model_name = args.model or DEFAULT_OPENAI_MODEL
        translator = OpenAITranslator(model_name=model_name)
    else:
        raise ValueError(
            f"Unsupported translator {translator_name!r}; expected "
            f"{TRANSLATOR_TYPE.LLAMA!r} or {TRANSLATOR_TYPE.OPENAI!r}"
        )

    # translate() is called with a one-element list; take the first result.
    pred = translator.translate([text_trad])[0]

    print(f"Translator: {translator_name}")
    print(f"Model: {model_name}")
    print(f"Input Text: {text_trad}")
    print(f"Output Text: {pred}")
if __name__ == "__main__":
    # Script entry point: read the CLI arguments, then run one translation.
    cli_args = parse_args()
    main(cli_args)