Skip to content

Commit

Permalink
Add classify-by-topic to cli
Browse files Browse the repository at this point in the history
  • Loading branch information
caufieldjh committed Aug 14, 2024
1 parent 6298c42 commit 18f5c90
Showing 1 changed file with 81 additions and 0 deletions.
81 changes: 81 additions & 0 deletions src/ontogpt/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
from ontogpt.engines.reasoner_engine import ReasonerEngine
from ontogpt.engines.spires_engine import SPIRESEngine
from ontogpt.engines.synonym_engine import SynonymEngine
from ontogpt.engines.topic_classifier_engine import TopicClassifierEngine
from ontogpt.evaluation.resolver import create_evaluator
from ontogpt.io.csv_wrapper import parse_yaml_predictions, write_graph, write_obj_as_csv
from ontogpt.io.html_exporter import HTMLExporter
Expand Down Expand Up @@ -1112,6 +1113,86 @@ def synonyms(
output.write(f"{line}\n")


@main.command()
@inputfile_option
@use_pdf_options
@model_option
@temperature_option
@api_base_option
@api_version_option
@model_provider_option
@system_message_option
@click.argument("topic")
def classify_by_topic(
inputfile,
model,
temperature,
api_base,
api_version,
model_provider,
system_message,
topic,
use_pdf,
):
"""Classify input text by topic.
Returns True if the input text is about the topic, False otherwise,
along with the name of the input file.
A path to a file containing input text may be passed as inputfile,
as may a directory of input files.
Example:
ontogpt classify-by-topic -i temp/30091466.txt
"clinical observations of human patients, including the diagnostic
and therapeutic procedures used during their clinical care"
"""

if not model:
model = DEFAULT_MODEL

inputdict = {}

if not inputfile or inputfile == "-":
text = sys.stdin.read()
inputdict["Input"] = text
elif inputfile and Path(inputfile).is_dir():
logging.info(f"Input file directory: {inputfile}")
inputfiles = Path(inputfile).glob("*.txt")
inputdict = {f: (open(f, "r").read()) for f in inputfiles if f.is_file()}
logging.info(f"Found {len(inputdict)} input files here.")
elif inputfile and Path(inputfile).exists():
logging.info(f"Input file: {inputfile}")
if use_pdf:
import pymupdf

doc = pymupdf.open(inputfile)
text = ""
for page in doc:
text = text + (page.get_text())
else:
text = open(inputfile, "rb").read().decode(encoding="utf-8", errors="ignore")
logging.info(f"Input text: {text}")
inputdict[inputfile] = text
elif inputfile and not Path(inputfile).exists():
raise FileNotFoundError(f"Cannot find input file {inputfile}")

ke = TopicClassifierEngine(
model=model,
temperature=temperature,
api_base=api_base,
api_version=api_version,
model_provider=model_provider,
system_message=system_message,
)

for input_entry in inputdict:
response = ke.binary_classify(topic=topic, text=inputdict[input_entry])
print(f"{input_entry}\t{response}")


@main.command()
@model_option
@api_base_option
Expand Down

0 comments on commit 18f5c90

Please sign in to comment.