Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added baseline UI using textual #5

Merged
merged 4 commits into from
Jan 16, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@ dependencies = [
"sentence-transformers>=3.3.1",
"transformers>=4.48.0",
"accelerate>=1.2.1",
"gnews>=0.3.9"
"gnews>=0.3.9",
"textual>=1.0.0"
]

[project.optional-dependencies]
Expand Down
21 changes: 3 additions & 18 deletions src/news_ask_ai/__init__.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,10 @@
from news_ask_ai.ask.news_ask_engine import NewsAskEngine

from news_ask_ai.ask.news_ask_ui import NewsAskUI
from news_ask_ai.utils.logger import setup_logger

logger = setup_logger()


def main() -> None:
logger.info("Initializing RAG engine...")
search_engine = NewsAskEngine(collection_name="news-collection")

print("Indicate the topic of the news you want to ask about")
topic_input = input("\nYour topic: ")
search_engine.ingest_data(topic_input)

while True:
print("What do you want to ask about?")
question_input = input("\nYour question: ")
if question_input.lower() == "exit":
break

completion = search_engine.get_completions(question_input)
print(completion)

print("Thank you for using the NewsAskAI system. Goodbye!")
app = NewsAskUI()
app.run()
120 changes: 120 additions & 0 deletions src/news_ask_ai/ask/news_ask_ui.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
from textual.app import App, ComposeResult
from textual.widgets import Footer, Header, Input, Button, Static
from textual.containers import Horizontal, Container

from textual.widget import Widget

from news_ask_ai.ask.news_ask_engine import NewsAskEngine
from news_ask_ai.utils.logger import setup_logger

logger = setup_logger()


class ConversationalContainer(Container, can_focus=True):
"""Conversational container widget."""


class MessageBox(Widget):
def __init__(self, text: str, role: str) -> None:
self.text = text
self.role = role
super().__init__()

def compose(self) -> ComposeResult:
yield Static(self.text, classes=f"message {self.role}")


class NewsAskUI(App): # type: ignore
TITLE = "NewsAskAI"
SUB_TITLE = "Ask questions about the news directly in your terminal"
CSS_PATH = "../static/css/styles.css"

def __init__(self) -> None:
super().__init__()
logger.info("Initializing RAG engine...")
self.search_engine = NewsAskEngine(collection_name="news-collection")

def compose(self) -> ComposeResult:
yield Header()

# -- Container for topic input and ingest button.
with Container(id="news_topic_container"):
yield MessageBox(
"Welcome to NewsAskAI!\n"
"Get the latest news insights with just a few clicks.\n"
"Enter a topic of interest below and press 'Ingest news' to start.\n"
"Need help? Use the commands at the bottom of your screen.",
role="info",
)

with Horizontal(id="input_box"):
yield Input(placeholder="Enter news topic...", id="news_topic_input")
yield Button(label="Ingest news", variant="success", id="news_ingest_button")

# -- Container for conversation UI, hidden by default.
with Container(id="conversation_container"):
with ConversationalContainer(id="conversation_box"):
yield MessageBox(
"You're all set to explore the news!\n"
"Type a question about the ingested news and press 'Ask' to get your answer.\n"
"Wait for the response...\n"
"Need assistance? Commands are listed at the bottom.",
role="info",
)

with Horizontal(id="input_box"):
yield Input(placeholder="Enter your question", id="conversation_input")
yield Button(label="Ask", variant="success", id="conversation_button")

yield Footer()

def on_mount(self) -> None:
"""Called when app is first mounted."""
self.query_one("#news_topic_input", Input).focus()

conversation_container = self.query_one("#conversation_container")
conversation_container.display = False

async def on_button_pressed(self, event: Button.Pressed) -> None:
"""Handle button pressed events."""
button = event.button

if button.id == "news_ingest_button":
news_topic_input = self.query_one("#news_topic_input", Input)
if not news_topic_input.value.strip():
return

conversation_container = self.query_one("#news_topic_container")
conversation_container.display = False

conversation_container = self.query_one("#conversation_container")
conversation_container.display = True

self.search_engine.ingest_data(news_topic_input.value)

elif button.id == "conversation_button":
await self.conversation()

async def conversation(self) -> None:
message_input = self.query_one("#conversation_input", Input)
if message_input.value == "":
return

conversation_box = self.query_one("#conversation_box")

message_box = MessageBox(message_input.value, "question")
conversation_box.mount(message_box)
conversation_box.scroll_end(animate=True)

# Clean up the input without triggering events
with message_input.prevent(Input.Changed):
message_input.value = ""

completion = self.search_engine.get_completions(message_box.text)

conversation_box.mount(
MessageBox(
text=completion,
role="answer",
)
)
5 changes: 5 additions & 0 deletions src/news_ask_ai/services/llm_completion_service.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import logging
from typing import cast

import torch
Expand All @@ -11,6 +12,9 @@

logger = setup_logger()

logging.getLogger("transformers").setLevel(logging.ERROR)
logging.getLogger("torch").setLevel(logging.ERROR)

# Input Format:
# The phi-4 model is best suited for prompts formatted as a chat, using special tokens.
# Each message should follow this structure:
Expand All @@ -36,6 +40,7 @@ def __init__(self, model_name: str = "microsoft/Phi-3.5-mini-instruct") -> None:

torch.random.manual_seed(0)
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

logger.info(f"Using device: {self.device}")

model_kwargs = {
Expand Down
70 changes: 70 additions & 0 deletions src/news_ask_ai/static/css/styles.css
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
Screen {
background: #212529;
}

MessageBox {
layout: horizontal;
height: auto;
align-horizontal: center;
}

.message {
width: 50%;
min-width: 25%;
max-width: 50%;
border: tall black;
padding: 1 3;
margin: 1 0;
background: #343a40;
}

.info {
width: auto;
text-align: center;
}

.question {
margin: 1 0 1 25;
}

.answer {
margin: 1 25 1 0;
}

#conversation_box {
overflow-y: auto;
height: 100%;
}

#input_box {
dock: bottom;
height: auto;
width: 100%;
margin: 0 0 2 0;
align_horizontal: center;
}

/* News style */

#news_topic_input {
width: 50%;
background: #343a40;
}

#news_ingest_button {
width: auto;
}

/* Conversation style */

#conversation_input {
width: 50%;
background: #343a40;
}


#conversation_button {
width: auto;
}


17 changes: 8 additions & 9 deletions src/news_ask_ai/utils/logger.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
import logging
from logging.handlers import RotatingFileHandler


def setup_logger(level: int = logging.INFO) -> logging.Logger:


def setup_logger(level: int = logging.INFO, log_file: str = "app.log") -> logging.Logger:
"""
Set up a logger with the specified name, log file, and level.

Expand All @@ -13,22 +16,18 @@ def setup_logger(level: int = logging.INFO) -> logging.Logger:
Returns:
logging.Logger: Configured logger.
"""
# Create a logger
logger = logging.getLogger()
if logger.hasHandlers():
logger.handlers.clear()

logger.setLevel(level)

# Create a console handler for outputting logs to the console
console_handler = logging.StreamHandler()
console_handler.setLevel(level)
file_handler = RotatingFileHandler(log_file, maxBytes=1024 * 1024, backupCount=5)
file_handler.setLevel(level)

# Define a log format
formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
console_handler.setFormatter(formatter)
file_handler.setFormatter(formatter)

# Add handlers to the logger
logger.addHandler(console_handler)
logger.addHandler(file_handler)

return logger
Loading
Loading