Skip to content

Commit

Permalink
fix: make sure timeout is properly applied for everything except run/…
Browse files Browse the repository at this point in the history
…stream (#202)
  • Loading branch information
isidentical authored May 2, 2024
1 parent 61161e5 commit 9a04111
Showing 1 changed file with 18 additions and 7 deletions.
25 changes: 18 additions & 7 deletions projects/fal_client/src/fal_client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ async def get(self) -> AnyJSON:
@dataclass(frozen=True)
class AsyncClient:
key: str | None = field(default=None, repr=False)
default_timeout: float = 60.0
default_timeout: float = 120.0

@cached_property
def _client(self) -> httpx.AsyncClient:
Expand All @@ -181,7 +181,8 @@ def _client(self) -> httpx.AsyncClient:
headers={
"Authorization": f"Key {key}",
"User-Agent": USER_AGENT,
}
},
timeout=self.default_timeout,
)

async def run(
Expand All @@ -190,6 +191,7 @@ async def run(
arguments: AnyJSON,
*,
path: str = "",
timeout: float | None = None,
) -> AnyJSON:
"""Run an application with the given arguments (which will be JSON serialized). The path parameter can be used to
specify a subpath when applicable. This method will return the result of the inference call directly.
Expand All @@ -202,7 +204,7 @@ async def run(
response = await self._client.post(
url,
json=arguments,
timeout=self.default_timeout,
timeout=timeout,
)
response.raise_for_status()
return response.json()
Expand All @@ -225,7 +227,6 @@ async def submit(
response = await self._client.post(
url,
json=arguments,
timeout=self.default_timeout,
)
response.raise_for_status()

Expand All @@ -244,6 +245,7 @@ async def stream(
arguments: AnyJSON,
*,
path: str = "/stream",
timeout: float | None = None,
) -> AsyncIterator[dict[str, Any]]:
"""Stream the output of an application with the given arguments (which will be JSON serialized). This is only supported
at a few select applications at the moment, so be sure to first consult with the documentation of individual applications
Expand All @@ -256,7 +258,13 @@ async def stream(
if path:
url += "/" + path.lstrip("/")

async with aconnect_sse(self._client, "POST", url, json=arguments) as events:
async with aconnect_sse(
self._client,
"POST",
url,
json=arguments,
timeout=timeout,
) as events:
async for event in events.aiter_sse():
yield event.json()

Expand Down Expand Up @@ -294,7 +302,7 @@ async def upload_image(self, image: Image.Image, format: str = "jpeg") -> str:
@dataclass(frozen=True)
class SyncClient:
key: str | None = field(default=None, repr=False)
default_timeout: float = 60.0
default_timeout: float = 120.0

@cached_property
def _client(self) -> httpx.Client:
Expand Down Expand Up @@ -371,6 +379,7 @@ def stream(
arguments: AnyJSON,
*,
path: str = "/stream",
timeout: float | None = None,
) -> Iterator[dict[str, Any]]:
"""Stream the output of an application with the given arguments (which will be JSON serialized). This is only supported
at a few select applications at the moment, so be sure to first consult with the documentation of individual applications
Expand All @@ -383,7 +392,9 @@ def stream(
if path:
url += "/" + path.lstrip("/")

with connect_sse(self._client, "POST", url, json=arguments) as events:
with connect_sse(
self._client, "POST", url, json=arguments, timeout=timeout
) as events:
for event in events.iter_sse():
yield event.json()

Expand Down

0 comments on commit 9a04111

Please sign in to comment.