Skip to content

Commit

Permalink
test: add tests for output mapping functions
Browse files Browse the repository at this point in the history
  • Loading branch information
Pouyanpi committed Feb 3, 2025
1 parent aa04453 commit 4bcd6c2
Show file tree
Hide file tree
Showing 2 changed files with 183 additions and 1 deletion.
106 changes: 106 additions & 0 deletions tests/test_actions_output_mapping.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import pytest

from nemoguardrails.actions import action
from nemoguardrails.actions.output_mapping import (
default_output_mapping,
should_block_output,
)

# Tests for default_output_mapping


def test_default_output_mapping_boolean_true():
# For booleans, the mapping returns the negation (block if result is False).
# If result is True, not True == False, so output is not blocked.
assert default_output_mapping(True) is False


def test_default_output_mapping_boolean_false():
# If result is False, then not False == True, so it is blocked.
assert default_output_mapping(False) is True


def test_default_output_mapping_numeric_below_threshold():
# For numeric values, block if the value is less than 0.5.
assert default_output_mapping(0.4) is True


def test_default_output_mapping_numeric_above_threshold():
# For numeric values greater than or equal to 0.5, do not block.
assert default_output_mapping(0.5) is False
assert default_output_mapping(0.6) is False


def test_default_output_mapping_non_numeric_non_boolean():
# For other types (e.g., strings), default mapping returns False (allowed).
assert default_output_mapping("anything") is False


# Tests for should_block_output


# Create a dummy action function with an attached mapping in its metadata.
def dummy_action_output_mapping(val):
# For testing, block if the value equals "block", otherwise do not block.
return val == "block"


@action(output_mapping=dummy_action_output_mapping)
def dummy_action(result):
return result


def test_should_block_output_with_tuple_result_and_mapping():
# Test should_block_output when the result is a tuple and the dummy mapping is used.
# When the first element equals "block", we expect True.
result = ("block",)
assert should_block_output(result, dummy_action) is True

# When the result is not "block", we expect False.
result = ("allow",)
assert should_block_output(result, dummy_action) is False


def test_should_block_output_with_non_tuple_result_and_mapping():
# Test should_block_output when the result is not a tuple.
# The function should wrap it into a tuple.
result = "block"
assert should_block_output(result, dummy_action) is True

result = "allow"
assert should_block_output(result, dummy_action) is False


def test_should_block_output_without_action_meta():
# Test should_block_output when the action function does not have an "action_meta" attribute.
# In this case, default_output_mapping should be used.
def action_without_meta(res):
return res

# Ensure there is no action_meta attribute.
if hasattr(action_without_meta, "action_meta"):
del action_without_meta.action_meta

# Test with a boolean: default_output_mapping for True is False and for False is True.
assert should_block_output(True, action_without_meta) is False
assert should_block_output(False, action_without_meta) is True

# Test with a numeric value: block if < 0.5.
assert should_block_output(0.4, action_without_meta) is True
assert should_block_output(0.6, action_without_meta) is False
78 changes: 77 additions & 1 deletion tests/test_llmrails.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional
from typing import Any, Dict, List, Optional, Union

import pytest

from nemoguardrails import LLMRails, RailsConfig
from nemoguardrails.rails.llm.llmrails import _get_action_details_from_flow_id
from tests.utils import FakeLLM, clean_events, event_sequence_conforms


Expand Down Expand Up @@ -621,3 +622,78 @@ async def compute(what: Optional[str] = "2 + 3"):
"role": "assistant",
"content": "The answer is 5\nAre you happy with the result?",
}


# get_action_details_from_flow_id used in llmrails.py


@pytest.fixture
def dummy_flows() -> List[Union[Dict, Any]]:
return [
{
"id": "test_flow",
"elements": [
{
"_type": "run_action",
"_source_mapping": {
"filename": "flows.v1.co",
"line_text": "execute something",
},
"action_name": "test_action",
"action_params": {"param1": "value1"},
}
],
},
# Additional flow that should match on a prefix
{
"id": "other_flow is prefix",
"elements": [
{
"_type": "run_action",
"_source_mapping": {
"filename": "flows.v1.co",
"line_text": "execute something else",
},
"action_name": "other_action",
"action_params": {"param2": "value2"},
}
],
},
]


def test_get_action_details_exact_match(dummy_flows):
action_name, action_params = _get_action_details_from_flow_id(
"test_flow", dummy_flows
)
assert action_name == "test_action"
assert action_params == {"param1": "value1"}


def test_get_action_details_prefix_match(dummy_flows):
# For a flow_id that starts with the prefix "other_flow",
# we expect to retrieve the action details from the flow whose id starts with that prefix.
# we expect a result since we are passing the prefixes argument.
action_name, action_params = _get_action_details_from_flow_id(
"other_flow", dummy_flows, prefixes=["other_flow"]
)
assert action_name == "other_action"
assert action_params == {"param2": "value2"}


def test_get_action_details_prefix_match_unsupported_prefix(dummy_flows):
# For a flow_id that starts with the prefix "other_flow",
# we expect to retrieve the action details from the flow whose id starts with that prefix.
# but as the prefix is not supported, we expect a ValueError.

with pytest.raises(ValueError) as exc_info:
_get_action_details_from_flow_id("other_flow", dummy_flows)

assert "No action found for flow_id" in str(exc_info.value)


def test_get_action_details_no_match(dummy_flows):
# Tests that a non matching flow_id raises a ValueError
with pytest.raises(ValueError) as exc_info:
_get_action_details_from_flow_id("non_existing_flow", dummy_flows)
assert "No action found for flow_id" in str(exc_info.value)

0 comments on commit 4bcd6c2

Please sign in to comment.