diff --git a/httpx/_client.py b/httpx/_client.py index d95877e8be..2710ad4162 100644 --- a/httpx/_client.py +++ b/httpx/_client.py @@ -164,6 +164,8 @@ class BaseClient: def __init__( self, *, + request_class: type[Request] = Request, + response_class: type[Response] = Response, auth: AuthTypes | None = None, params: QueryParamTypes | None = None, headers: HeaderTypes | None = None, @@ -178,6 +180,9 @@ def __init__( ) -> None: event_hooks = {} if event_hooks is None else event_hooks + self._request_class = request_class + self._response_class = response_class + self._base_url = self._enforce_trailing_slash(URL(base_url)) self._auth = self._build_auth(auth) @@ -195,6 +200,14 @@ def __init__( self._default_encoding = default_encoding self._state = ClientState.UNOPENED + @property + def request_class(self) -> type[Request]: + return self._request_class + + @property + def response_class(self) -> type[Response]: + return self._response_class + @property def is_closed(self) -> bool: """ @@ -356,7 +369,7 @@ def build_request( else Timeout(timeout) ) extensions = dict(**extensions, timeout=timeout.as_dict()) - return Request( + return self.request_class( method, url, content=content, @@ -463,7 +476,7 @@ def _build_redirect_request(self, request: Request, response: Response) -> Reque headers = self._redirect_headers(request, url, method) stream = self._redirect_stream(request, method) cookies = Cookies(self.cookies) - return Request( + return self.request_class( method=method, url=url, headers=headers, @@ -629,6 +642,8 @@ class Client(BaseClient): def __init__( self, *, + request_class: type[Request] = Request, + response_class: type[Response] = Response, auth: AuthTypes | None = None, params: QueryParamTypes | None = None, headers: HeaderTypes | None = None, @@ -652,6 +667,8 @@ def __init__( default_encoding: str | typing.Callable[[bytes], str] = "utf-8", ) -> None: super().__init__( + request_class=request_class, + response_class=response_class, auth=auth, params=params, headers=headers, @@ -748,6 +765,7 @@ def _init_transport( http2=http2, limits=limits, trust_env=trust_env, + response_class=self.response_class, ) def _init_proxy_transport( @@ -768,6 +786,7 @@ def _init_proxy_transport( limits=limits, trust_env=trust_env, proxy=proxy, + response_class=self.response_class, ) def _transport_for_url(self, url: URL) -> BaseTransport: @@ -1376,6 +1395,8 @@ class AsyncClient(BaseClient): def __init__( self, *, + request_class: type[Request] = Request, + response_class: type[Response] = Response, auth: AuthTypes | None = None, params: QueryParamTypes | None = None, headers: HeaderTypes | None = None, @@ -1399,6 +1420,8 @@ def __init__( default_encoding: str | typing.Callable[[bytes], str] = "utf-8", ) -> None: super().__init__( + request_class=request_class, + response_class=response_class, auth=auth, params=params, headers=headers, @@ -1495,6 +1518,7 @@ def _init_transport( http2=http2, limits=limits, trust_env=trust_env, + response_class=self.response_class, ) def _init_proxy_transport( @@ -1515,6 +1539,7 @@ def _init_proxy_transport( limits=limits, trust_env=trust_env, proxy=proxy, + response_class=self.response_class, ) def _transport_for_url(self, url: URL) -> AsyncBaseTransport: diff --git a/httpx/_models.py b/httpx/_models.py index 01d9583bc5..f5eb057e4f 100644 --- a/httpx/_models.py +++ b/httpx/_models.py @@ -446,6 +446,9 @@ def __setstate__(self, state: dict[str, typing.Any]) -> None: self.stream = UnattachedStream() +_ResponseT = typing.TypeVar("_ResponseT", bound="Response") + + class Response: def __init__( self, @@ -725,7 +728,7 @@ def has_redirect_location(self) -> bool: and "Location" in self.headers ) - def raise_for_status(self) -> Response: + def raise_for_status(self: _ResponseT) -> _ResponseT: """ Raise the `HTTPStatusError` if one occurred. """ diff --git a/httpx/_transports/default.py b/httpx/_transports/default.py index 33db416dd1..9aa304fd8c 100644 --- a/httpx/_transports/default.py +++ b/httpx/_transports/default.py @@ -135,6 +135,7 @@ def __init__( local_address: str | None = None, retries: int = 0, socket_options: typing.Iterable[SOCKET_OPTION] | None = None, + response_class: type[Response] = Response, ) -> None: ssl_context = create_ssl_context(verify=verify, cert=cert, trust_env=trust_env) proxy = Proxy(url=proxy) if isinstance(proxy, (str, URL)) else proxy @@ -201,6 +202,8 @@ def __init__( f" but got {proxy.url.scheme!r}." ) + self._response_class = response_class + def __enter__(self: T) -> T: # Use generics for subclass support. self._pool.__enter__() return self @@ -237,7 +240,7 @@ def handle_request( assert isinstance(resp.stream, typing.Iterable) - return Response( + return self._response_class( status_code=resp.status, headers=resp.headers, stream=ResponseStream(resp.stream), @@ -276,6 +279,7 @@ def __init__( local_address: str | None = None, retries: int = 0, socket_options: typing.Iterable[SOCKET_OPTION] | None = None, + response_class: type[Response] = Response, ) -> None: ssl_context = create_ssl_context(verify=verify, cert=cert, trust_env=trust_env) proxy = Proxy(url=proxy) if isinstance(proxy, (str, URL)) else proxy @@ -342,6 +346,8 @@ def __init__( " but got {proxy.url.scheme!r}." ) + self._response_class = response_class + async def __aenter__(self: A) -> A: # Use generics for subclass support. await self._pool.__aenter__() return self @@ -378,7 +384,7 @@ async def handle_async_request( assert isinstance(resp.stream, typing.AsyncIterable) - return Response( + return self._response_class( status_code=resp.status, headers=resp.headers, stream=AsyncResponseStream(resp.stream), diff --git a/tests/client/test_client.py b/tests/client/test_client.py index 657839018a..cecee778ad 100644 --- a/tests/client/test_client.py +++ b/tests/client/test_client.py @@ -460,3 +460,63 @@ def cp1252_but_no_content_type(request): assert response.reason_phrase == "OK" assert response.encoding == "ISO-8859-1" assert response.text == text + + +def test_client_request_class(): + class Request(httpx.Request): + def __init__(self, *args, **kwargs): + kwargs["content"] = "foobar" + super().__init__(*args, **kwargs) + + class Client(httpx.Client): + request_class = Request + + class AsyncClient(httpx.AsyncClient): + request_class = Request + + request = Client().build_request("GET", "http://www.example.com/") + assert isinstance(request, Request) + assert request.content == b"foobar" + + request = AsyncClient().build_request("GET", "http://www.example.com/") + assert isinstance(request, Request) + assert request.content == b"foobar" + + with httpx.Client(request_class=Request) as client: + request = client.build_request("GET", "http://www.example.com/") + assert isinstance(request, Request) + assert request.content == b"foobar" + + +@pytest.mark.anyio +async def test_client_response_class(server): + class Response(httpx.Response): + def iter_bytes(self, chunk_size: int | None = None) -> typing.Iterator[bytes]: + yield b"foobar" + + class Client(httpx.Client): + response_class = Response + + class AsyncResponse(httpx.Response): + async def aiter_bytes( + self, chunk_size: int | None = None + ) -> typing.AsyncIterator[bytes]: + yield b"foobar" + + class AsyncClient(httpx.AsyncClient): + response_class = AsyncResponse + + with Client() as client: + response = client.get(server.url) + assert isinstance(response, Response) + assert response.read() == b"foobar" + + async with AsyncClient() as async_client: + response = await async_client.get(server.url) + assert isinstance(response, AsyncResponse) + assert await response.aread() == b"foobar" + + with httpx.Client(response_class=Response) as httpx_client: + response = httpx_client.get(server.url) + assert isinstance(response, Response) + assert response.read() == b"foobar"