diff --git a/gpt4all-bindings/python/gpt4all/_pyllmodel.py b/gpt4all-bindings/python/gpt4all/_pyllmodel.py
index 0beae6630eea..fa879cde15e7 100644
--- a/gpt4all-bindings/python/gpt4all/_pyllmodel.py
+++ b/gpt4all-bindings/python/gpt4all/_pyllmodel.py
@@ -28,6 +28,27 @@
 EmbeddingsType = TypeVar('EmbeddingsType', bound='list[Any]')
 
 
+# Find CUDA libraries from the official packages
+cuda_found = False
+if platform.system() in ('Linux', 'Windows'):
+    try:
+        from nvidia import cuda_runtime, cublas
+    except ImportError:
+        pass  # CUDA is optional
+    else:
+        if platform.system() == 'Linux':
+            cudalib   = 'lib/libcudart.so.12'
+            cublaslib = 'lib/libcublas.so.12'
+        else:  # Windows
+            cudalib   = r'bin\cudart64_12.dll'
+            cublaslib = r'bin\cublas64_12.dll'
+
+        # preload the CUDA libs so the backend can find them
+        ctypes.CDLL(os.path.join(cuda_runtime.__path__[0], cudalib), mode=ctypes.RTLD_GLOBAL)
+        ctypes.CDLL(os.path.join(cublas.__path__[0], cublaslib), mode=ctypes.RTLD_GLOBAL)
+        cuda_found = True
+
+
 # TODO: provide a config file to make this more robust
 MODEL_LIB_PATH = importlib_resources.files("gpt4all") / "llmodel_DO_NOT_MODIFY" / "build"
 
@@ -218,7 +239,16 @@ def __init__(self, model_path: str, n_ctx: int, ngl: int, backend: str):
         model = llmodel.llmodel_model_create2(self.model_path, backend.encode(), ctypes.byref(err))
         if model is None:
             s = err.value
-            raise RuntimeError(f"Unable to instantiate model: {'null' if s is None else s.decode()}")
+            errmsg = 'null' if s is None else s.decode()
+
+            if (
+                backend == 'cuda'
+                and not cuda_found
+                and errmsg.startswith('Could not find any implementations for backend')
+            ):
+                print('WARNING: CUDA runtime libraries not found. Try `pip install "gpt4all[cuda]"`\n', file=sys.stderr)
+
+            raise RuntimeError(f"Unable to instantiate model: {errmsg}")
         self.model: ctypes.c_void_p | None = model
 
     def __del__(self, llmodel=llmodel):
diff --git a/gpt4all-bindings/python/setup.py b/gpt4all-bindings/python/setup.py
index 19e9d4bf9e9f..3aab696d152c 100644
--- a/gpt4all-bindings/python/setup.py
+++ b/gpt4all-bindings/python/setup.py
@@ -93,7 +93,15 @@ def get_long_description():
         'typing-extensions>=4.3.0; python_version >= "3.9" and python_version < "3.11"',
     ],
     extras_require={
+        'cuda': [
+            'nvidia-cuda-runtime-cu12',
+            'nvidia-cublas-cu12',
+        ],
+        'all': [
+            'gpt4all[cuda]; platform_system == "Windows" or platform_system == "Linux"',
+        ],
         'dev': [
+            'gpt4all[all]',
             'pytest',
             'twine',
             'wheel',
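
For reference, the patch works by preloading the CUDA libraries that ship inside the NVIDIA wheels with `RTLD_GLOBAL`, so their symbols are visible to the native backend the bindings load later; if the wheels are missing and the CUDA backend fails to load, it prints a hint to install the new `gpt4all[cuda]` extra. A minimal standalone sketch of the same preload pattern, assuming the `nvidia-cuda-runtime-cu12` wheel is installed on Linux (library paths as used in the patch above):

```python
import ctypes
import os

# The nvidia-cuda-runtime-cu12 wheel exposes an `nvidia.cuda_runtime` package
# that contains lib/libcudart.so.12 (same path the patch uses).
from nvidia import cuda_runtime

cudart_path = os.path.join(cuda_runtime.__path__[0], 'lib', 'libcudart.so.12')

# RTLD_GLOBAL makes libcudart's symbols available to any shared library
# loaded afterwards, e.g. the CUDA implementation loaded by the bindings.
ctypes.CDLL(cudart_path, mode=ctypes.RTLD_GLOBAL)
```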