Skip to content

Commit

Permalink
Add kwargs to create_gradio_component
Browse files Browse the repository at this point in the history
  • Loading branch information
mattt committed Nov 8, 2024
1 parent 79db0e5 commit 3336731
Showing 1 changed file with 24 additions and 1 deletion.
25 changes: 24 additions & 1 deletion src/hype/gui/gradio.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@
import gradio as gr


def create_gradio_component(name: str, field_info: FieldInfo) -> "gr.Component":
def create_gradio_component(
name: str, field_info: FieldInfo, **kwargs: Any
) -> "gr.Component":
import gradio as gr

label = field_info.alias or name
Expand Down Expand Up @@ -50,6 +52,7 @@ def create_gradio_component(name: str, field_info: FieldInfo) -> "gr.Component":
label=label,
info=field_info.description,
placeholder=f"Enter valid {field_type.__name__}",
**kwargs,
)

# Handle lists/sequences
Expand All @@ -64,12 +67,14 @@ def create_gradio_component(name: str, field_info: FieldInfo) -> "gr.Component":
value=field_info.default
if field_info.default_factory is None
else None,
**kwargs,
)

# Handle dictionaries
if get_origin(field_type) is dict:
return gr.JSON(
label=label,
**kwargs,
)

# Handle datetime types
Expand All @@ -78,6 +83,7 @@ def create_gradio_component(name: str, field_info: FieldInfo) -> "gr.Component":
label=label,
value=default,
info=field_info.description,
**kwargs,
)

# Handle file paths and URLs
Expand All @@ -88,20 +94,23 @@ def create_gradio_component(name: str, field_info: FieldInfo) -> "gr.Component":
):
return gr.File(
label=label,
**kwargs,
)

# Handle HTML content
if field_type is str and json_schema_extra.get("format") == "html":
return gr.HTML(
value=default,
label=label,
**kwargs,
)

# Handle markdown content
if field_type is str and json_schema_extra.get("format") == "markdown":
return gr.Markdown(
value=default,
label=label,
**kwargs,
)

# Handle enums - use Dropdown for long enums, Radio for short ones
Expand All @@ -113,12 +122,14 @@ def create_gradio_component(name: str, field_info: FieldInfo) -> "gr.Component":
label=label,
value=field_info.default.value if field_info.default else None,
info=field_info.description,
**kwargs,
)
return gr.Radio(
choices=choices,
label=label,
value=field_info.default.value if field_info.default else None,
info=field_info.description,
**kwargs,
)

# Handle number types with constraints
Expand Down Expand Up @@ -163,6 +174,7 @@ def create_gradio_component(name: str, field_info: FieldInfo) -> "gr.Component":
if field_info.default_factory is None
else None,
info=field_info.description,
**kwargs,
)

return gr.Number(
Expand All @@ -179,6 +191,7 @@ def create_gradio_component(name: str, field_info: FieldInfo) -> "gr.Component":
),
None,
),
**kwargs,
)

# Handle ByteSize
Expand All @@ -187,6 +200,7 @@ def create_gradio_component(name: str, field_info: FieldInfo) -> "gr.Component":
label=label,
info=field_info.description,
placeholder="e.g., 1GB, 500MB, 1024B",
**kwargs,
)

# Handle Decimal with precision
Expand All @@ -195,6 +209,7 @@ def create_gradio_component(name: str, field_info: FieldInfo) -> "gr.Component":
label=label,
precision=getattr(field_info, "decimal_places", None),
info=field_info.description,
**kwargs,
)

# Handle boolean types
Expand All @@ -203,6 +218,7 @@ def create_gradio_component(name: str, field_info: FieldInfo) -> "gr.Component":
label=label,
value=default,
info=field_info.description,
**kwargs,
)

# Handle color inputs
Expand All @@ -211,6 +227,7 @@ def create_gradio_component(name: str, field_info: FieldInfo) -> "gr.Component":
label=label,
value=default,
info=field_info.description,
**kwargs,
)

# Handle date/time inputs
Expand All @@ -221,6 +238,7 @@ def create_gradio_component(name: str, field_info: FieldInfo) -> "gr.Component":
label=label,
value=default,
info=field_info.description,
**kwargs,
)

# Handle Path types
Expand All @@ -237,6 +255,7 @@ def create_gradio_component(name: str, field_info: FieldInfo) -> "gr.Component":
return gr.File(
label=label,
file_count="directory",
**kwargs,
)

# Handle specific file types from json schema
Expand Down Expand Up @@ -267,12 +286,14 @@ def create_gradio_component(name: str, field_info: FieldInfo) -> "gr.Component":
label=label,
file_types=file_types if file_types else None,
file_count=file_count,
**kwargs,
)

# Handle file paths
if field_type is str and json_schema_extra.get("format") == "file-path":
return gr.File(
label=label,
**kwargs,
)

# Fallback to textbox
Expand Down Expand Up @@ -303,6 +324,7 @@ def create_gradio_component(name: str, field_info: FieldInfo) -> "gr.Component":
value=default,
max_lines=10,
info=field_info.description,
**kwargs,
)
else:
return gr.Textbox(
Expand All @@ -311,6 +333,7 @@ def create_gradio_component(name: str, field_info: FieldInfo) -> "gr.Component":
value=default,
info=field_info.description,
max_lines=1,
**kwargs,
)


Expand Down

0 comments on commit 3336731

Please sign in to comment.