diff --git a/lit_nlp/examples/simple_pytorch_demo.py b/lit_nlp/examples/simple_pytorch_demo.py index 2a925895..cde642db 100644 --- a/lit_nlp/examples/simple_pytorch_demo.py +++ b/lit_nlp/examples/simple_pytorch_demo.py @@ -79,7 +79,7 @@ def _from_pretrained(cls, *args, **kw): return cls.from_pretrained(*args, from_tf=True, **kw) -class SimpleSentimentModel(lit_model.Model): +class SimpleSentimentModel(lit_model.BatchedModel): """Simple sentiment analysis model.""" LABELS = ["0", "1"] # negative, positive @@ -103,7 +103,7 @@ def __init__(self, model_name_or_path): ## # LIT API implementation def max_minibatch_size(self): - # This tells lit_model.Model.predict() how to batch inputs to + # This tells lit_model.BatchedModel.predict() how to batch inputs to # predict_minibatch(). # Alternately, you can just override predict() and handle batching yourself. return 32 diff --git a/lit_nlp/examples/sst_pytorch_demo.py b/lit_nlp/examples/sst_pytorch_demo.py index fd341b37..dede8a61 100644 --- a/lit_nlp/examples/sst_pytorch_demo.py +++ b/lit_nlp/examples/sst_pytorch_demo.py @@ -70,7 +70,7 @@ def _from_pretrained(cls, *args, **kw): return cls.from_pretrained(*args, from_tf=True, **kw) -class SimpleSentimentModel(lit_model.Model): +class SimpleSentimentModel(lit_model.BatchedModel): """Simple sentiment analysis model.""" LABELS = ["0", "1"] # negative, positive @@ -95,7 +95,7 @@ def __init__(self, model_name_or_path): ## # LIT API implementation def max_minibatch_size(self): - # This tells lit_model.Model.predict() how to batch inputs to + # This tells lit_model.BatchedModel.predict() how to batch inputs to # predict_minibatch(). # Alternately, you can just override predict() and handle batching yourself. return 32