Skip to content

Commit

Permalink
refactor attempt to utilized property annotations (#1027)
Browse files Browse the repository at this point in the history
  • Loading branch information
jmartin-tech committed Dec 2, 2024
2 parents ff3d58c + ca1a665 commit 3ee53a8
Showing 1 changed file with 78 additions and 84 deletions.
162 changes: 78 additions & 84 deletions garak/attempt.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,95 +105,89 @@ def as_dict(self) -> dict:
"messages": self.messages,
}

def __getattribute__(self, name: str) -> Any:
"""override prompt and outputs access to take from history"""
if name == "prompt":
if len(self.messages) == 0: # nothing set
return None
if isinstance(self.messages[0], dict): # only initial prompt set
return self.messages[0]["content"]
if isinstance(
self.messages, list
): # there's initial prompt plus some history
return self.messages[0][0]["content"]
else:
raise ValueError(
"Message history of attempt uuid %s in unexpected state, sorry: "
% str(self.uuid)
+ repr(self.messages)
)
@property
def prompt(self):
if len(self.messages) == 0: # nothing set
return None
if isinstance(self.messages[0], dict): # only initial prompt set
return self.messages[0]["content"]
if isinstance(self.messages, list): # there's initial prompt plus some history
return self.messages[0][0]["content"]
else:
raise ValueError(
"Message history of attempt uuid %s in unexpected state, sorry: "
% str(self.uuid)
+ repr(self.messages)
)

elif name == "outputs":
if len(self.messages) and isinstance(self.messages[0], list):
# work out last_output_turn that was assistant
assistant_turns = [
@property
def outputs(self):
if len(self.messages) and isinstance(self.messages[0], list):
# work out last_output_turn that was assistant
assistant_turns = [
idx
for idx, val in enumerate(self.messages[0])
if val["role"] == "assistant"
]
if assistant_turns == []:
return []
last_output_turn = max(assistant_turns)
# return these (via list compr)
return [m[last_output_turn]["content"] for m in self.messages]
else:
return []

@property
def latest_prompts(self):
if len(self.messages[0]) > 1:
# work out last_output_turn that was user
last_output_turn = max(
[
idx
for idx, val in enumerate(self.messages[0])
if val["role"] == "assistant"
if val["role"] == "user"
]
if assistant_turns == []:
return []
last_output_turn = max(assistant_turns)
# return these (via list compr)
return [m[last_output_turn]["content"] for m in self.messages]
else:
return []

elif name == "latest_prompts":
if len(self.messages[0]) > 1:
# work out last_output_turn that was user
last_output_turn = max(
[
idx
for idx, val in enumerate(self.messages[0])
if val["role"] == "user"
]
)
# return these (via list compr)
return [m[last_output_turn]["content"] for m in self.messages]
else:
return (
self.prompt
) # returning a string instead of a list tips us off that generation count is not yet known

elif name == "all_outputs":
all_outputs = []
if len(self.messages) and not isinstance(self.messages[0], dict):
for thread in self.messages:
for turn in thread:
if turn["role"] == "assistant":
all_outputs.append(turn["content"])
return all_outputs

else:
return super().__getattribute__(name)

def __setattr__(self, name: str, value: Any) -> None:
"""override prompt and outputs access to take from history NB. output elements need to be able to be None"""

if name == "prompt":
if value is None:
raise TypeError("'None' prompts are not valid")
self._add_first_turn("user", value)

elif name == "outputs":
if not (isinstance(value, list) or isinstance(value, GeneratorType)):
raise TypeError("Value for attempt.outputs must be a list or generator")
value = list(value)
if len(self.messages) == 0:
raise TypeError("A prompt must be set before outputs are given")
# do we have only the initial prompt? in which case, let's flesh out messages a bit
elif len(self.messages) == 1 and isinstance(self.messages[0], dict):
self._expand_prompt_to_histories(len(value))
# append each list item to each history, with role:assistant
self._add_turn("assistant", value)

elif name == "latest_prompts":
assert isinstance(value, list)
self._add_turn("user", value)

)
# return these (via list compr)
return [m[last_output_turn]["content"] for m in self.messages]
else:
return super().__setattr__(name, value)
return (
self.prompt
) # returning a string instead of a list tips us off that generation count is not yet known

@property
def all_outputs(self):
all_outputs = []
if len(self.messages) and not isinstance(self.messages[0], dict):
for thread in self.messages:
for turn in thread:
if turn["role"] == "assistant":
all_outputs.append(turn["content"])
return all_outputs

@prompt.setter
def prompt(self, value):
if value is None:
raise TypeError("'None' prompts are not valid")
self._add_first_turn("user", value)

@outputs.setter
def outputs(self, value):
if not (isinstance(value, list) or isinstance(value, GeneratorType)):
raise TypeError("Value for attempt.outputs must be a list or generator")
value = list(value)
if len(self.messages) == 0:
raise TypeError("A prompt must be set before outputs are given")
# do we have only the initial prompt? in which case, let's flesh out messages a bit
elif len(self.messages) == 1 and isinstance(self.messages[0], dict):
self._expand_prompt_to_histories(len(value))
# append each list item to each history, with role:assistant
self._add_turn("assistant", value)

@latest_prompts.setter
def latest_prompts(self, value):
assert isinstance(value, list)
self._add_turn("user", value)

def _expand_prompt_to_histories(self, breadth):
"""expand a prompt-only message history to many threads"""
Expand Down

0 comments on commit 3ee53a8

Please sign in to comment.