Skip to content

Commit

Permalink
Feature: configuration based rest proxy support (#1073)
Browse files Browse the repository at this point in the history
  • Loading branch information
jmartin-tech committed Jan 16, 2025
2 parents c2ef0c0 + 9ac2b71 commit 7ed485c
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 3 deletions.
1 change: 1 addition & 0 deletions docs/source/garak.generators.rest.rst
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ Uses the following options from ``_config.plugins.generators["rest.RestGenerator
* ``req_template_json_object`` - (optional) the request template as a Python object, to be serialised as a JSON string before replacements
* ``method`` - a string describing the HTTP method, to be passed to the requests module; default "post".
* ``headers`` - dict describing HTTP headers to be sent with the request
* ``proxies`` - dict passed to ``requests`` method call. See `required format<https://requests.readthedocs.io/en/latest/user/advanced/#proxies">`_.
* ``response_json`` - Is the response in JSON format? (bool)
* ``response_json_field`` - (optional) Which field of the response JSON should be used as the output string? Default ``text``. Can also be a JSONPath value, and ``response_json_field`` is used as such if it starts with ``$``.
* ``request_timeout`` - How many seconds should we wait before timing out? Default 20
Expand Down
13 changes: 12 additions & 1 deletion garak/generators/rest.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from jsonpath_ng.exceptions import JsonPathParserError

from garak import _config
from garak.exception import APIKeyMissingError, RateLimitHit
from garak.exception import APIKeyMissingError, BadGeneratorException, RateLimitHit
from garak.generators.base import Generator


Expand All @@ -35,6 +35,7 @@ class RestGenerator(Generator):
"response_json_field": None,
"req_template": "$INPUT",
"request_timeout": 20,
"proxies": None,
}

ENV_VAR = "REST_API_KEY"
Expand All @@ -59,6 +60,7 @@ class RestGenerator(Generator):
"skip_codes",
"temperature",
"top_k",
"proxies",
)

def __init__(self, uri=None, config_root=_config):
Expand Down Expand Up @@ -118,6 +120,14 @@ def __init__(self, uri=None, config_root=_config):
self.method = "post"
self.http_function = getattr(requests, self.method)

# validate proxies formatting
# sanity check only leave actual parsing of values to the `requests` library on call.
if hasattr(self, "proxies") and self.proxies is not None:
if not isinstance(self.proxies, dict):
raise BadGeneratorException(
"`proxies` value provided is not in the required format. See documentation from the `requests` package for details on expected format. https://requests.readthedocs.io/en/latest/user/advanced/#proxies"
)

# validate jsonpath
if self.response_json and self.response_json_field:
try:
Expand Down Expand Up @@ -193,6 +203,7 @@ def _call_model(
data_kw: request_data,
"headers": request_headers,
"timeout": self.request_timeout,
"proxies": self.proxies,
}
resp = self.http_function(self.uri, **req_kArgs)

Expand Down
50 changes: 48 additions & 2 deletions tests/generators/test_rest.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
import json
import pytest
import requests_mock
from sympy import is_increasing

from garak import _config, _plugins

Expand Down Expand Up @@ -122,3 +120,51 @@ def test_rest_skip_code(requests_mock):
)
output = generator._call_model("Who is Enabran Tain's son?")
assert output == [None]


@pytest.mark.usefixtures("set_rest_config")
def test_rest_valid_proxy(mocker, requests_mock):
test_proxies = {
"http": "http://localhost:8080",
"https": "https://localhost:8443",
}
_config.plugins.generators["rest"]["RestGenerator"]["proxies"] = test_proxies
generator = _plugins.load_plugin(
"generators.rest.RestGenerator", config_root=_config
)
requests_mock.post(
DEFAULT_URI,
text=json.dumps(
{
"choices": [
{
"index": 0,
"message": {
"role": "assistant",
"content": DEFAULT_TEXT_RESPONSE,
},
}
]
}
),
)
mock_http_function = mocker.patch.object(
generator, "http_function", wraps=generator.http_function
)
generator._call_model("Who is Enabran Tain's son?")
mock_http_function.assert_called_once()
assert mock_http_function.call_args_list[0].kwargs["proxies"] == test_proxies


@pytest.mark.usefixtures("set_rest_config")
def test_rest_invalid_proxy(requests_mock):
from garak.exception import GarakException

test_proxies = [
"http://localhost:8080",
"https://localhost:8443",
]
_config.plugins.generators["rest"]["RestGenerator"]["proxies"] = test_proxies
with pytest.raises(GarakException) as exc_info:
_plugins.load_plugin("generators.rest.RestGenerator", config_root=_config)
assert "not in the required format" in str(exc_info.value)

0 comments on commit 7ed485c

Please sign in to comment.