Skip to content

Commit

Permalink
Merge pull request #305 from Carreau/clean-trans
Browse files Browse the repository at this point in the history
Use slightly more recent IPython API.
  • Loading branch information
joerick authored Jul 31, 2024
2 parents 2b48948 + 41d749d commit 9fa0054
Showing 1 changed file with 30 additions and 7 deletions.
37 changes: 30 additions & 7 deletions pyinstrument/magic/magic.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,15 @@
import threading
import urllib.parse
from ast import parse
from textwrap import dedent

import IPython
from IPython import get_ipython # type: ignore
from IPython.core.magic import Magics, line_cell_magic, magics_class, no_var_expand
from IPython.core.magic_arguments import argument, magic_arguments, parse_argstring
from IPython.display import IFrame, display

from .. import Profiler
from ._utils import PrePostAstTransformer

_active_profiler = None

Expand All @@ -34,12 +35,31 @@ def _get_active_profiler():
class PyinstrumentMagic(Magics):
def __init__(self, shell):
super().__init__(shell)
# This will leak _get_active_profiler into the users space until we can magle it
self.pre = parse(
"\nfrom pyinstrument.magic.magic import _get_active_profiler; _get_active_profiler().start()\n"
)
self.post = parse("\n_get_active_profiler().stop()")
self._transformer = PrePostAstTransformer(self.pre, self.post)
if IPython.version_info < (8, 15): # type: ignore
from ._utils import PrePostAstTransformer

# This will leak _get_active_profiler into the users space until we can magle it
pre = parse(
"\nfrom pyinstrument.magic.magic import _get_active_profiler; _get_active_profiler().start()\n"
)
post = parse("\n_get_active_profiler().stop()")
self._transformer = PrePostAstTransformer(pre, post)
else:
from IPython.core.magics.ast_mod import ReplaceCodeTransformer # type: ignore

self._transformer = ReplaceCodeTransformer.from_string(
dedent(
"""
from pyinstrument.magic.magic import _get_active_profiler as ___get_prof
___get_prof().start()
try:
__code__
finally:
___get_prof().stop()
__ret__
"""
)
)

@magic_arguments()
@argument(
Expand Down Expand Up @@ -109,6 +129,9 @@ def pyinstrument(self, line, cell=None):
cell_result = ip.run_cell(code)
else:
cell_result = self.run_cell_async(ip, code)
mangled_keys = [k for k in ip.user_ns.keys() if "-" in k]
for k in mangled_keys:
del ip.user_ns[k]
ip.ast_transformers.remove(self._transformer)

if (
Expand Down

0 comments on commit 9fa0054

Please sign in to comment.