diff --git a/pypesto/sample/pymc.py b/pypesto/sample/pymc.py index c6b37e3fa..fc0d08d84 100644 --- a/pypesto/sample/pymc.py +++ b/pypesto/sample/pymc.py @@ -3,9 +3,8 @@ from __future__ import annotations import logging -from typing import Union +from typing import TYPE_CHECKING, Union -import arviz as az import numpy as np import pymc import pytensor.tensor as pt @@ -16,6 +15,9 @@ from ..result import McmcPtResult from .sampler import Sampler, SamplerImportError +if TYPE_CHECKING: + import arviz as az + logger = logging.getLogger(__name__) # implementation based on: