Skip to content

Commit

Permalink
LIT:Fix type hints for model initializers.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 678805235
  • Loading branch information
llcourage authored and LIT team committed Sep 25, 2024
1 parent 3dc61a2 commit 0e354cb
Showing 1 changed file with 5 additions and 1 deletion.
6 changes: 5 additions & 1 deletion lit_nlp/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import collections
from collections.abc import Callable, Iterable, Mapping, Sequence
import functools
import inspect
import math
import os
import random
Expand Down Expand Up @@ -520,7 +521,10 @@ def _create_model(
raise_for_unsupported=True,
)

return_type = get_type_hints(model_initializer)['return']
return_type = dict[str, Any]

if inspect.isfunction(model_initializer):
return_type = get_type_hints(model_initializer)['return']

if Mapping in return_type.__mro__:
model_initializer = cast(MultipleModelLoader, model_initializer)
Expand Down

0 comments on commit 0e354cb

Please sign in to comment.