Skip to content

Commit

Permalink
Add a pure python wrapper to pybindings.portable_lib (#3137) (#3218)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #3137

When installed as a pip wheel, we must import `torch` before trying to import the pybindings shared library extension. This will load libtorch.so and related libs, ensuring that the pybindings lib can resolve those runtime dependencies.

So, add a pure python wrapper that lets us do this when users say `import executorch.extension.pybindings.portable_lib`

We only need this for OSS, so don't bother doing this for other pybindings targets.

Reviewed By: orionr, mikekgfb

Differential Revision: D56317150

fbshipit-source-id: 920382636732aa276c25a76163afb7d28b1846d0
(cherry picked from commit 969aa96)

Co-authored-by: Dave Bort <dbort@meta.com>
  • Loading branch information
pytorchbot and dbort authored Apr 22, 2024
1 parent 773da4d commit 67d0dd7
Show file tree
Hide file tree
Showing 5 changed files with 59 additions and 10 deletions.
5 changes: 4 additions & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -548,8 +548,11 @@ if(EXECUTORCH_BUILD_PYBIND)

# pybind portable_lib
pybind11_add_module(portable_lib extension/pybindings/pybindings.cpp)
# The actual output file needs a leading underscore so it can coexist with
# portable_lib.py in the same python package.
set_target_properties(portable_lib PROPERTIES OUTPUT_NAME "_portable_lib")
target_compile_definitions(portable_lib
PUBLIC EXECUTORCH_PYTHON_MODULE_NAME=portable_lib)
PUBLIC EXECUTORCH_PYTHON_MODULE_NAME=_portable_lib)
target_include_directories(portable_lib PRIVATE ${TORCH_INCLUDE_DIRS})
target_compile_options(portable_lib PUBLIC ${_pybind_compile_options})
target_link_libraries(
Expand Down
16 changes: 12 additions & 4 deletions extension/pybindings/TARGETS
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,9 @@ runtime.genrule(
srcs = [":pybinding_types"],
outs = {
"aten_lib.pyi": ["aten_lib.pyi"],
"portable_lib.pyi": ["portable_lib.pyi"],
"_portable_lib.pyi": ["_portable_lib.pyi"],
},
cmd = "cp $(location :pybinding_types)/* $OUT/portable_lib.pyi && cp $(location :pybinding_types)/* $OUT/aten_lib.pyi",
cmd = "cp $(location :pybinding_types)/* $OUT/_portable_lib.pyi && cp $(location :pybinding_types)/* $OUT/aten_lib.pyi",
visibility = ["//executorch/extension/pybindings/..."],
)

Expand All @@ -46,8 +46,9 @@ executorch_pybindings(
executorch_pybindings(
compiler_flags = ["-std=c++17"],
cppdeps = PORTABLE_MODULE_DEPS + MODELS_ATEN_OPS_LEAN_MODE_GENERATED_LIB,
python_module_name = "portable_lib",
types = ["//executorch/extension/pybindings:pybindings_types_gen[portable_lib.pyi]"],
# Give this an underscore prefix because it has a pure python wrapper.
python_module_name = "_portable_lib",
types = ["//executorch/extension/pybindings:pybindings_types_gen[_portable_lib.pyi]"],
visibility = ["PUBLIC"],
)

Expand All @@ -58,3 +59,10 @@ executorch_pybindings(
types = ["//executorch/extension/pybindings:pybindings_types_gen[aten_lib.pyi]"],
visibility = ["PUBLIC"],
)

runtime.python_library(
name = "portable_lib",
srcs = ["portable_lib.py"],
visibility = ["@EXECUTORCH_CLIENTS"],
deps = [":_portable_lib"],
)
34 changes: 34 additions & 0 deletions extension/pybindings/portable_lib.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

# When installed as a pip wheel, we must import `torch` before trying to import
# the pybindings shared library extension. This will load libtorch.so and
# related libs, ensuring that the pybindings lib can resolve those runtime
# dependencies.
import torch as _torch

# Let users import everything from the C++ _portable_lib extension as if this
# python file defined them. Although we could import these dynamically, it
# wouldn't preserve the static type annotations.
from executorch.extension.pybindings._portable_lib import ( # noqa: F401
# Disable "imported but unused" (F401) checks.
_create_profile_block, # noqa: F401
_dump_profile_results, # noqa: F401
_get_operator_names, # noqa: F401
_load_bundled_program_from_buffer, # noqa: F401
_load_for_executorch, # noqa: F401
_load_for_executorch_from_buffer, # noqa: F401
_load_for_executorch_from_bundled_program, # noqa: F401
_reset_profile_results, # noqa: F401
BundledModule, # noqa: F401
ExecuTorchModule, # noqa: F401
)

# Clean up so that `dir(portable_lib)` is the same as `dir(_portable_lib)`
# (apart from some __dunder__ names).
del _torch
12 changes: 8 additions & 4 deletions extension/pybindings/pybindings.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,12 @@
# pyre-strict
from typing import Any, Dict, List, Sequence, Tuple

class ExecutorchModule:
class ExecuTorchModule:
# pyre-ignore[2, 3]: "Any" in parameter and return type annotations.
def __call__(self, inputs: Any) -> List[Any]: ...
# pyre-ignore[2, 3]: "Any" in parameter and return type annotations.
def run_method(self, method_name: str, inputs: Sequence[Any]) -> List[Any]: ...
# pyre-ignore[2, 3]: "Any" in parameter and return type annotations.
def forward(self, inputs: Sequence[Any]) -> List[Any]: ...
# Bundled program methods.
def load_bundled_input(
Expand All @@ -30,16 +33,17 @@ class BundledModule: ...

def _load_for_executorch(
path: str, enable_etdump: bool = False
) -> ExecutorchModule: ...
) -> ExecuTorchModule: ...
def _load_for_executorch_from_buffer(
buffer: bytes, enable_etdump: bool = False
) -> ExecutorchModule: ...
) -> ExecuTorchModule: ...
def _load_for_executorch_from_bundled_program(
module: BundledModule, enable_etdump: bool = False
) -> ExecutorchModule: ...
) -> ExecuTorchModule: ...
def _load_bundled_program_from_buffer(
buffer: bytes, non_const_pool_size: int = ...
) -> BundledModule: ...
def _get_operator_names() -> List[str]: ...
def _create_profile_block(name: str) -> None: ...
def _dump_profile_results() -> bytes: ...
def _reset_profile_results() -> None: ...
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -435,7 +435,7 @@ def get_ext_modules() -> list[Extension]:
# portable kernels, and a selection of backends. This lets users
# load and execute .pte files from python.
BuiltExtension(
"portable_lib.*", "executorch.extension.pybindings.portable_lib"
"_portable_lib.*", "executorch.extension.pybindings._portable_lib"
)
)

Expand Down

0 comments on commit 67d0dd7

Please sign in to comment.