Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

correct directory handling for tasks that are imported as local (non-package) modules #715

Merged
merged 2 commits into from
Oct 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
- Requirements: require semver>=3.0.0
- Added `delimiter` option to `csv_dataset()` (defaults to ",")
- Open log files in binary mode when reading headers (fixes ijson deprecation warning).
- Correct directory handling for tasks that are imported as local (non-package) modules.
- Call tools sequentially when they have opted out of parallel calling.
- Bugfix: strip protocol prefix when resolving eval event content
- Bugfix: switch to run directory when running multiple tasks with the same run directory.
Expand Down
12 changes: 12 additions & 0 deletions src/inspect_ai/_eval/registry.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import inspect
import logging
from copy import deepcopy
from pathlib import Path
from typing import Any, Callable, TypeVar, cast, overload

from inspect_ai._util.error import PrerequisiteError
from inspect_ai._util.package import get_installed_package_name
from inspect_ai._util.registry import (
RegistryInfo,
registry_add,
Expand All @@ -16,6 +18,7 @@
from inspect_ai.model import ModelName

from .task import Task
from .task.constants import TASK_FILE_ATTR, TASK_RUN_DIR_ATTR

MODEL_PARAM = "model"

Expand Down Expand Up @@ -139,6 +142,15 @@ def wrapper(*w_args: Any, **w_kwargs: Any) -> Task:
**w_kwargs,
)

# if its not from an installed package then it is a "local"
# module import, so set its task file and run dir
if get_installed_package_name(task_type) is None:
module = inspect.getmodule(task_type)
if module and module.__file__:
file = Path(module.__file__)
setattr(task_instance, TASK_FILE_ATTR, file.as_posix())
setattr(task_instance, TASK_RUN_DIR_ATTR, file.parent.as_posix())

# Return the task instance
return task_instance

Expand Down
71 changes: 71 additions & 0 deletions src/inspect_ai/_util/package.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
import importlib.util
import inspect
import json
import site
import sys
from functools import lru_cache
from importlib.metadata import Distribution, PackageNotFoundError
from typing import Any


def get_installed_package_name(obj: Any) -> str | None:
# get the module of the object
module = inspect.getmodule(obj)
if module is None:
return None

# find the origin (install path) for the module
module_name = module.__name__
try:
spec = importlib.util.find_spec(module_name)
except (ImportError, AttributeError):
return None
if spec is None or spec.origin is None:
return None

# check if this is a package (either in library or installed editable)
package_name = module_name.split(".")[0]
if package_path_is_in_site_packages(spec.origin):
return package_name
if package_is_installed_editable(package_name):
return package_name
else:
return None


@lru_cache(maxsize=None)
def package_path_is_in_site_packages(path: str) -> bool:
path = path.lower()
return (
any(path.startswith(p.lower()) for p in site.getsitepackages())
or path.startswith(site.getusersitepackages().lower())
or any(
"site-packages" in p.lower() and path.startswith(p.lower())
for p in sys.path
)
)


@lru_cache(maxsize=None)
def package_is_installed_editable(package: str) -> bool:
# get the distribution
try:
distribution = Distribution.from_name(package)
except (ValueError, PackageNotFoundError):
return False

# read the direct_url json
direct_url_json = distribution.read_text("direct_url.json")
if not direct_url_json:
return False

# parse the json
try:
direct_url = json.loads(direct_url_json)
if not isinstance(direct_url, dict):
return False
except json.JSONDecodeError:
return False

# read the editable property
return direct_url.get("dir_info", {}).get("editable", False) is not False
22 changes: 4 additions & 18 deletions src/inspect_ai/_util/registry.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import inspect
from importlib import import_module
from inspect import get_annotations, getmodule, isclass
from inspect import get_annotations, isclass
from typing import Any, Callable, Literal, TypedDict, TypeGuard, cast

from pydantic import BaseModel, Field
from pydantic_core import to_jsonable_python

from inspect_ai._util.package import get_installed_package_name

from .constants import PKG_NAME
from .entrypoints import ensure_entry_points

Expand Down Expand Up @@ -116,7 +117,7 @@ def registry_name(o: object, name: str) -> str:
This function checks whether the passed object is in a package,
and if it is, prepends the package name as a namespace
"""
package = get_package_name(o)
package = get_installed_package_name(o)
return f"{package}/{name}" if package else name


Expand Down Expand Up @@ -366,21 +367,6 @@ def registry_key(type: RegistryType, name: str) -> str:
_registry: dict[str, object] = {}


def get_package_name(o: object) -> str | None:
module = getmodule(o)
package = str(getattr(module, "__package__", ""))
if package:
package = package.split(".")[0]
if package != "None":
package_module = import_module(package)
if package_module:
package_path = getattr(package_module, "__path__", None)
if package_path:
return package

return None


class RegistryDict(TypedDict):
type: RegistryType
name: str
Expand Down
6 changes: 6 additions & 0 deletions tests/test_helpers/tasks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from inspect_ai import Task, task


@task
def empty_task() -> Task:
return Task()
9 changes: 9 additions & 0 deletions tests/test_task_attr.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from test_helpers.tasks import empty_task

from inspect_ai._eval.task.constants import TASK_FILE_ATTR, TASK_RUN_DIR_ATTR


def test_local_module_attr():
task = empty_task()
assert getattr(task, TASK_FILE_ATTR, None)
assert getattr(task, TASK_RUN_DIR_ATTR, None)
56 changes: 56 additions & 0 deletions tests/util/test_package.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
import os

import httpx
import numpy as np
import pytest
from test_helpers.tools import addition, list_files

import inspect_ai
from inspect_ai._util.package import get_installed_package_name


def test_numpy_package():
assert get_installed_package_name(np.array) == "numpy"
assert get_installed_package_name(np.random.rand) == "numpy"


def test_httpx_package():
assert get_installed_package_name(httpx.get) == "httpx"
assert get_installed_package_name(httpx.Client) == "httpx"


def test_builtin_module():
assert get_installed_package_name(os.path.join) is None
assert get_installed_package_name(list.append) is None


def test_inspect_ai_package():
assert get_installed_package_name(inspect_ai.eval) == "inspect_ai"


def test_local_module():
assert get_installed_package_name(addition) is None
assert get_installed_package_name(list_files) is None


def test_local_function():
def local_func():
pass

assert get_installed_package_name(local_func) is None


def test_local_class():
class LocalClass:
pass

assert get_installed_package_name(LocalClass) is None


def test_none_input():
assert get_installed_package_name(None) is None


@pytest.mark.parametrize("value", [42, "string", [1, 2, 3]])
def test_builtin_types(value):
assert get_installed_package_name(value) is None
Loading