Performance Metrics #38

Merged: 56 commits, Jan 10, 2025
Commits (56)
5a2e0be
adds timeit functionality to chat and query functions
jeisenman23 Nov 14, 2024
86e2a51
reformatting time metrics
jeisenman23 Nov 14, 2024
93a7d97
trying to fix lint
jeisenman23 Nov 15, 2024
4e4dc4b
fixing lint
jeisenman23 Nov 15, 2024
fab34d9
reducing complexity
jeisenman23 Nov 15, 2024
785ec4d
removing white spaces
jeisenman23 Nov 15, 2024
b0429a4
fixing lint
jeisenman23 Nov 15, 2024
a41b030
fixing return statement
jeisenman23 Nov 15, 2024
dba79f8
fixing docs
jeisenman23 Nov 15, 2024
dc4579a
removing timeit as optional argument
jeisenman23 Nov 21, 2024
bfdd13f
changing performance metrics according to stream mode:
jeisenman23 Nov 21, 2024
96be266
update test
jeisenman23 Nov 21, 2024
6d37186
change test to fit performance metrics
jeisenman23 Dec 10, 2024
4aa6fb2
change test to fit performance metrics
jeisenman23 Dec 10, 2024
0ac1ca9
adding back query into chat
jeisenman23 Dec 10, 2024
f8c3d1a
removing await
jeisenman23 Dec 10, 2024
4018c59
fixing test
jeisenman23 Dec 20, 2024
4ca2970
fixing test
jeisenman23 Dec 20, 2024
34c58c0
removing whitespace
jeisenman23 Dec 20, 2024
abd7000
fixing space
jeisenman23 Dec 20, 2024
d94403c
finicky flake8 error fix
jeisenman23 Dec 20, 2024
63070ef
fixing elm tests
jeisenman23 Dec 20, 2024
20fb602
ensuring test cases
jeisenman23 Dec 20, 2024
15c417a
reversing - statement
jeisenman23 Dec 20, 2024
0186fdf
removing whitespace
jeisenman23 Dec 20, 2024
d4a9bf0
removing whitespace
jeisenman23 Dec 20, 2024
fc88d7a
fixing line issue
jeisenman23 Dec 27, 2024
878e4f0
fixing osti bug
jeisenman23 Jan 2, 2025
b496352
adding spaces for engineer query
jeisenman23 Jan 2, 2025
991d3e3
adding spaces for chat function
jeisenman23 Jan 2, 2025
af86612
remove trailing whitespaces
jeisenman23 Jan 2, 2025
ae37420
Merge branch 'main' into time
jeisenman23 Jan 2, 2025
3232c02
fixing OSTI bug
jeisenman23 Jan 6, 2025
5334535
removing comments for flake
jeisenman23 Jan 6, 2025
8f4cdcc
making line shorter
jeisenman23 Jan 6, 2025
8d15faa
line too long
jeisenman23 Jan 6, 2025
8348af3
adding blank line
jeisenman23 Jan 6, 2025
667fe05
rerun of actions
jeisenman23 Jan 6, 2025
80be899
changing first
jeisenman23 Jan 6, 2025
564f5d4
attempting to fix osti
jeisenman23 Jan 6, 2025
bc196f1
attempt to fix OSTI in multiple envs
jeisenman23 Jan 6, 2025
022f734
removing test and fixing test
jeisenman23 Jan 6, 2025
371dcc2
inputting local change that works
jeisenman23 Jan 6, 2025
6549134
fixing lint
jeisenman23 Jan 6, 2025
ab7449c
debug statement
jeisenman23 Jan 6, 2025
96c3b87
attempt to fix escape sequence
jeisenman23 Jan 6, 2025
24b41c9
attempting to fix str
jeisenman23 Jan 6, 2025
7f55760
fixing escape
jeisenman23 Jan 6, 2025
4eac5dc
getting get pages to work
jeisenman23 Jan 7, 2025
2ed61b9
clean code
jeisenman23 Jan 7, 2025
7819a13
fixing linter
jeisenman23 Jan 7, 2025
80c92b6
fixing linter
jeisenman23 Jan 7, 2025
0e6b9f9
fixing linter
jeisenman23 Jan 7, 2025
18a4465
fixing over indent
jeisenman23 Jan 7, 2025
ac3e3c5
cleaned up docstrings and removed unnecessary debug kwarg from wizard…
grantbuster Jan 9, 2025
28680a0
Merge pull request #44 from NREL/gb/nodebugkw
jeisenman23 Jan 10, 2025
Files changed
3 changes: 2 additions & 1 deletion .gitignore
@@ -122,6 +122,7 @@ examples/research_hub/pdfs/
examples/research_hub/embed/
examples/research_hub/txt/
examples/research_hub/meta.csv

*ignore*.py

# pixi environments
@@ -130,4 +131,4 @@ examples/research_hub/meta.csv
pixi*

# Scratch
*scratch*/
*scratch*/
75 changes: 39 additions & 36 deletions elm/web/osti.py
@@ -2,6 +2,7 @@
"""
Utilities for retrieving data from OSTI.
"""
import re
import copy
import requests
import json
@@ -28,7 +29,9 @@ def __init__(self, record):

@staticmethod
def strip_nested_brackets(text):
"""Remove text between brackets/parentheses for cleaning OSTI text"""
"""
Remove text between brackets/parentheses for cleaning OSTI text
"""
ret = ''
skip1c = 0
skip2c = 0
@@ -183,54 +186,54 @@ def __init__(self, url, n_pages=1):
super().__init__(records)

def _get_first(self):
"""Get the first page of OSTI records

Returns
-------
list
"""
"""Get the first page of OSTI records"""
self._response = self._session.get(self.url)

if not self._response.ok:
msg = ('OSTI API Request got error {}: "{}"'
.format(self._response.status_code,
self._response.reason))
msg = f'''OSTI API Request got error {self._response.status_code}:
"{self._response.reason}"'''
raise RuntimeError(msg)
first_page = self._response.json()

try:
raw_text = self._response.text
if raw_text.endswith('}\r\n]'):
raw_text = raw_text[:-1]
first_page = json.loads(raw_text)
except (json.JSONDecodeError, UnicodeError) as e:
logger.error(f"JSON decode error: {str(e)}\nRaw text: {raw_text[:500]}...")
raise

self._n_pages = 1
if 'last' in self._response.links:
url = self._response.links['last']['url']
self._n_pages = int(url.split('page=')[-1])

logger.debug('Found approximately {} records.'
.format(self._n_pages * len(first_page)))

logger.debug(f'Found approximately {self._n_pages * len(first_page)} records.')
return first_page

def _get_pages(self, n_pages):
"""Get response pages up to n_pages from OSTI.

Parameters
----------
n_pages : int
Number of pages to retrieve

Returns
-------
next_pages : list
This function will return a generator of next pages, each of which
is a list of OSTI records
"""
if n_pages > 1:
for page in range(2, self._n_pages + 1):
if page <= n_pages:
next_page = self._session.get(self.url,
params={'page': page})
next_page = next_page.json()
yield next_page
else:
break
"""Get response pages up to n_pages from OSTI.

Parameters
----------
n_pages : int
Number of pages to retrieve

Returns
-------
next_pages : list
This function will return a generator of next pages, each of which
is a list of OSTI records
"""
if n_pages > 1:
for page in range(2, self._n_pages + 1):
if page <= n_pages:
next_page = self._session.get(self.url,
params={'page': page})
next_page = next_page.json()
yield next_page
else:
break

def _get_all(self, n_pages):
"""Get all pages of records up to n_pages.
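
For reference, the retrieval flow in _get_first and _get_pages above reduces to a short pattern: fetch page one, read the total page count from the "last" relation of the Link response header, then request the remaining pages with a "page" query parameter. A minimal standalone sketch of that pattern (the function name and error text here are illustrative, not part of this PR):

import requests

def fetch_pages(url, n_pages=1):
    """Yield one list of records per page, up to n_pages."""
    session = requests.Session()
    response = session.get(url)
    if not response.ok:
        raise RuntimeError(
            f'Request got error {response.status_code}: "{response.reason}"')
    yield response.json()

    # The API advertises the final page via the "last" Link relation,
    # mirroring the parsing done in _get_first above.
    total_pages = 1
    if 'last' in response.links:
        total_pages = int(response.links['last']['url'].split('page=')[-1])

    # Remaining pages are requested with a "page" query parameter,
    # mirroring _get_pages above.
    for page in range(2, min(n_pages, total_pages) + 1):
        yield session.get(url, params={'page': page}).json()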
55 changes: 38 additions & 17 deletions elm/wizard.py
@@ -3,6 +3,7 @@
ELM energy wizard
"""
from abc import ABC, abstractmethod
from time import perf_counter
import copy
import os
import json
@@ -61,8 +62,12 @@ def query_vector_db(self, query, limit=100):
ranked strings/scores outputs.
"""

def engineer_query(self, query, token_budget=None, new_info_threshold=0.7,
convo=False):
def engineer_query(self,
query,
token_budget=None,
new_info_threshold=0.7,
convo=False
):
"""Engineer a query for GPT using the corpus of information

Parameters
@@ -79,6 +84,7 @@ def engineer_query(self, query, token_budget=None, new_info_threshold=0.7,
Flag to perform semantic search with full conversation history
(True) or just the single query (False). Call EnergyWizard.clear()
to reset the chat history.

Returns
-------
message : str
@@ -87,6 +93,11 @@ def engineer_query(self, query, token_budget=None, new_info_threshold=0.7,
references : list
The list of references (strs) used in the engineered prompt is
returned here
vector_query_time : float
Time in seconds spent querying the vector database
used_index : list
Indices of the documents used in building the engineered query
"""

self.messages.append({"role": "user", "content": query})
@@ -99,9 +110,10 @@ def engineer_query(self, query, token_budget=None, new_info_threshold=0.7,
query = '\n\n'.join(query)

token_budget = token_budget or self.token_budget

start_time = perf_counter()
strings, _, idx = self.query_vector_db(query)

end_time = perf_counter()
vector_query_time = end_time - start_time
message = copy.deepcopy(self.MODEL_INSTRUCTION)
question = f"\n\nQuestion: {query}"
used_index = []
@@ -125,8 +137,7 @@ def engineer_query(self, query, token_budget=None, new_info_threshold=0.7,
message = message + question
used_index = np.array(used_index)
references = self.make_ref_list(used_index)

return message, references, used_index
return message, references, used_index, vector_query_time

@abstractmethod
def make_ref_list(self, idx):
@@ -144,15 +155,17 @@ def make_ref_list(self, idx):
["{ref_title} ({ref_url})"]
"""

def chat(self, query,
def chat(self,
query,
debug=True,
stream=True,
temperature=0,
convo=False,
token_budget=None,
new_info_threshold=0.7,
print_references=False,
return_chat_obj=False):
return_chat_obj=False
):
"""Answers a query by doing a semantic search of relevant text with
embeddings and then sending engineered query to the LLM.

@@ -195,12 +208,15 @@ def chat(self, query,
references : list
If debug is True, the list of references (strs) used in the
engineered prompt is returned here
performance : dict
Dictionary with keys 'total_chat_time',
'chat_completion_time' and 'vectordb_query_time'.
"""

start_chat_time = perf_counter()
out = self.engineer_query(query, token_budget=token_budget,
new_info_threshold=new_info_threshold,
convo=convo)
query, references, _ = out
query, references, _, vector_query_time = out

messages = [{"role": "system", "content": self.MODEL_ROLE},
{"role": "user", "content": query}]
@@ -209,20 +225,20 @@ def chat(self, query,
messages=messages,
temperature=temperature,
stream=stream)
start_completion_time = perf_counter()

response = self._client.chat.completions.create(**kwargs)

if return_chat_obj:
return response, query, references

if stream:
for chunk in response:
chunk_msg = chunk.choices[0].delta.content or ""
response_message += chunk_msg
print(chunk_msg, end='')

else:
response_message = response.choices[0].message.content
finish_completion_time = perf_counter()
chat_completion_time = finish_completion_time - start_completion_time

self.messages.append({'role': 'assistant',
'content': response_message})
@@ -234,11 +250,16 @@ def chat(self, query,
response_message += ref_msg
if stream:
print(ref_msg)

end_time = perf_counter()
total_chat_time = end_time - start_chat_time
performance = {
"total_chat_time": total_chat_time,
"chat_completion_time": chat_completion_time,
"vectordb_query_time": vector_query_time
}
if debug:
return response_message, query, references
else:
return response_message
return response_message, query, references, performance
return response_message, query, performance


class EnergyWizard(EnergyWizardBase):
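
With these changes, engineer_query returns four values and chat returns the performance dict in both debug and non-debug modes. A hedged usage sketch, assuming an EnergyWizard instance named wizard built as in the tests (the question text is illustrative):

# Unpack the new four-value return of engineer_query.
message, references, used_index, vector_query_time = wizard.engineer_query(
    'What time is it?')

# In debug mode, chat now returns four values as well.
response_message, query, references, performance = wizard.chat(
    'What time is it?', debug=True, stream=False)

# The performance dict carries the three timings documented above.
print(performance['total_chat_time'])
print(performance['chat_completion_time'])
print(performance['vectordb_query_time'])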
9 changes: 4 additions & 5 deletions tests/test_wizard.py
@@ -70,7 +70,6 @@ def test_chunk_and_embed(mocker):

Note that embedding api is mocked here and not actually tested.
"""

corpus = make_corpus(mocker)
wizard = EnergyWizard(pd.DataFrame(corpus), token_budget=1000,
ref_col='ref')
@@ -81,12 +80,12 @@ def test_chunk_and_embed(mocker):
question = 'What time is it?'
out = wizard.chat(question, debug=True, stream=False,
print_references=True)
msg, query, ref = out

assert msg.startswith('hello!')
response_message, query, references, performance = out
assert response_message.startswith('hello!')
assert query.startswith(EnergyWizard.MODEL_INSTRUCTION)
assert query.endswith(question)
assert 'source0' in ref
assert 'source0' in references
assert isinstance(performance, dict)


def test_convo_query(mocker):
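
If tighter coverage is wanted, the test could also pin the exact timing keys. A possible follow-up assertion (a sketch, not part of this PR):

assert set(performance) == {'total_chat_time',
                            'chat_completion_time',
                            'vectordb_query_time'}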