From f81b4b45bf43f86ed286410a90b705f4417b30d9 Mon Sep 17 00:00:00 2001 From: cebtenzzre Date: Wed, 11 Oct 2023 14:12:40 -0400 Subject: [PATCH] python: support Path in GPT4All.__init__ (#1462) --- gpt4all-bindings/python/gpt4all/gpt4all.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/gpt4all-bindings/python/gpt4all/gpt4all.py b/gpt4all-bindings/python/gpt4all/gpt4all.py index c6d5c9baa13f..6821bcf48e61 100644 --- a/gpt4all-bindings/python/gpt4all/gpt4all.py +++ b/gpt4all-bindings/python/gpt4all/gpt4all.py @@ -1,6 +1,8 @@ """ Python only API for running all GPT4All models. """ +from __future__ import annotations + import os import sys import time @@ -60,7 +62,7 @@ class GPT4All: def __init__( self, model_name: str, - model_path: Optional[str] = None, + model_path: Optional[Union[str, os.PathLike[str]]] = None, model_type: Optional[str] = None, allow_download: bool = True, n_threads: Optional[int] = None, @@ -115,7 +117,7 @@ def list_models() -> List[ConfigType]: @staticmethod def retrieve_model( model_name: str, - model_path: Optional[str] = None, + model_path: Optional[Union[str, os.PathLike[str]]] = None, allow_download: bool = True, verbose: bool = True, ) -> ConfigType: @@ -160,7 +162,7 @@ def retrieve_model( ) model_path = DEFAULT_MODEL_DIRECTORY else: - model_path = model_path.replace("\\", "\\\\") + model_path = str(model_path).replace("\\", "\\\\") if not os.path.exists(model_path): raise ValueError(f"Invalid model directory: {model_path}") @@ -185,7 +187,7 @@ def retrieve_model( @staticmethod def download_model( model_filename: str, - model_path: str, + model_path: Union[str, os.PathLike[str]], verbose: bool = True, url: Optional[str] = None, ) -> str: