Skip to content

Commit

Permalink
[Server] Proactively use the user code context manager
Browse files Browse the repository at this point in the history
This is important to ensure that we handle imports properly in user
code (e.g., inside user-defined functions).
  • Loading branch information
geoffxy committed Jan 9, 2020
1 parent 2f5ccc0 commit 7f75467
Show file tree
Hide file tree
Showing 4 changed files with 79 additions and 35 deletions.
58 changes: 42 additions & 16 deletions cli/skyline/analysis/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import numpy as np

import skyline.protocol_gen.innpv_pb2 as pm
from skyline.exceptions import AnalysisError, exceptions_as_analysis_errors
from skyline.exceptions import AnalysisError
from skyline.profiler.iteration import IterationProfiler
from skyline.tracking.memory import track_memory_usage
from skyline.tracking.report import MiscSizeType
Expand All @@ -24,12 +24,14 @@ class AnalysisSession:
def __init__(
self,
project_root,
path_to_entry_point_dir,
model_provider,
input_provider,
iteration_provider,
batch_size
):
self._project_root = project_root
self._path_to_entry_point_dir = path_to_entry_point_dir
self._model_provider = model_provider
self._input_provider = input_provider
self._iteration_provider = iteration_provider
Expand All @@ -38,9 +40,17 @@ def __init__(

@classmethod
def new_from(cls, project_root, entry_point):
path_to_entry_point = os.path.join(project_root, entry_point)
# Note: This is not necessarily the same as project_root because the
# entry_point could be in a subdirectory.
path_to_entry_point_dir = os.path.dirname(path_to_entry_point)

# 1. Run the entry point file to "load" the model
try:
scope = _run_entry_point(project_root, entry_point)
scope = _run_entry_point(
path_to_entry_point,
path_to_entry_point_dir,
)
except SyntaxError as ex:
raise AnalysisError(
"Syntax error on line {} column {} in {}.".format(
Expand Down Expand Up @@ -77,10 +87,15 @@ def new_from(cls, project_root, entry_point):
iteration_provider = scope[ITERATION_PROVIDER_NAME]

batch_size = _validate_providers(
model_provider, input_provider, iteration_provider)
model_provider,
input_provider,
iteration_provider,
path_to_entry_point_dir,
)

return cls(
project_root,
path_to_entry_point_dir,
model_provider,
input_provider,
iteration_provider,
Expand All @@ -92,6 +107,7 @@ def measure_memory_usage(self, nvml):
self._model_provider,
self._input_provider,
self._iteration_provider,
self._path_to_entry_point_dir,
)

memory_usage = pm.MemoryUsageResponse()
Expand Down Expand Up @@ -125,6 +141,7 @@ def measure_throughput(self):
self._model_provider,
self._input_provider,
self._iteration_provider,
self._path_to_entry_point_dir,
)
num_samples = 3
samples = profiler.sample_run_time_ms_by_batch_size(
Expand Down Expand Up @@ -171,15 +188,11 @@ def measure_throughput(self):
return throughput


def _run_entry_point(path_to_entry_point, path_to_entry_point_dir):
    """Execute the user's entry point file and return its global scope.

    path_to_entry_point is the file to run; path_to_entry_point_dir is the
    directory containing it (not necessarily the project root, since the
    entry point may live in a subdirectory).
    """
    with open(path_to_entry_point) as entry_file:
        source = entry_file.read()
    code_obj = compile(source, path_to_entry_point, mode="exec")
    # Execute inside the user code environment so that imports performed by
    # the user's code (e.g., inside user-defined functions) resolve relative
    # to the entry point's directory.
    with user_code_environment(path_to_entry_point_dir):
        scope = {}
        exec(code_obj, scope, scope)
    return scope
Expand All @@ -195,7 +208,12 @@ def _set_file_context(message, project_root, entry):
relative_file_path.split(os.sep))


def _validate_providers(model_provider, input_provider, iteration_provider):
def _validate_providers(
model_provider,
input_provider,
iteration_provider,
path_to_entry_point_dir,
):
model_sig = inspect.signature(model_provider)
if len(model_sig.parameters) != 0:
raise AnalysisError(
Expand All @@ -221,18 +239,26 @@ def _validate_providers(model_provider, input_provider, iteration_provider):
)

err = _validate_provider_return_values(
model_provider, input_provider, iteration_provider)
model_provider,
input_provider,
iteration_provider,
path_to_entry_point_dir,
)
if err is not None:
raise err

return batch_size


def _validate_provider_return_values(
model_provider, input_provider, iteration_provider):
with exceptions_as_analysis_errors():
model_provider,
input_provider,
iteration_provider,
path_to_entry_point_dir,
):
with user_code_environment(path_to_entry_point_dir):
# We return exceptions instead of raising them here to prevent
# them from being caught by the exception context manager.
# them from being caught by the code environment context manager.
model = model_provider()
if not callable(model):
return AnalysisError(
Expand Down
25 changes: 17 additions & 8 deletions cli/skyline/profiler/iteration.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@

import torch

from skyline.exceptions import AnalysisError, exceptions_as_analysis_errors
from skyline.exceptions import AnalysisError
from skyline.user_code_utils import user_code_environment

logger = logging.getLogger(__name__)

Expand All @@ -12,17 +13,25 @@


class IterationProfiler:
def __init__(self, iteration, input_provider, path_to_entry_point_dir):
    """Create a profiler for a single training iteration.

    iteration: callable invoked as iteration(*inputs) to run one step.
    input_provider: callable that produces the iteration's inputs.
    path_to_entry_point_dir: directory of the user's entry point, used to
        run user code inside the proper environment.
    """
    # CUDA events used to time the iteration on the GPU.
    self._start_event = torch.cuda.Event(enable_timing=True)
    self._end_event = torch.cuda.Event(enable_timing=True)
    self._iteration = iteration
    self._input_provider = input_provider
    self._path_to_entry_point_dir = path_to_entry_point_dir

@classmethod
def new_from(cls, model_provider, input_provider, iteration_provider,
             path_to_entry_point_dir):
    """Build an IterationProfiler from the user-supplied provider callables.

    The model and iteration are constructed inside the user code
    environment so that imports performed by user-defined code resolve
    correctly.
    """
    with user_code_environment(path_to_entry_point_dir):
        model = model_provider()
        iteration = iteration_provider(model)
    return cls(iteration, input_provider, path_to_entry_point_dir)

def measure_run_time_ms(self, batch_size, initial_repetitions=None):
"""
Expand All @@ -31,15 +40,15 @@ def measure_run_time_ms(self, batch_size, initial_repetitions=None):
NOTE: This method will raise a RuntimeError if there is not enough GPU
memory to run the iteration.
"""
with exceptions_as_analysis_errors():
with user_code_environment(self._path_to_entry_point_dir):
inputs = self._input_provider(batch_size=batch_size)
# Warm up
self._iteration(*inputs)

torch.cuda.synchronize()

def measure(iterations):
with exceptions_as_analysis_errors():
with user_code_environment(self._path_to_entry_point_dir):
self._start_event.record()
for _ in range(iterations):
self._iteration(*inputs)
Expand Down
13 changes: 8 additions & 5 deletions cli/skyline/tracking/activations.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@

import torch

from skyline.exceptions import exceptions_as_analysis_errors
from skyline.tracking.base import TrackerBase
from skyline.tracking.call_stack import CallStack
from skyline.tracking.hook_manager import HookManager
from skyline.user_code_utils import user_code_environment

OperationContext = collections.namedtuple(
'OperationContext',
Expand All @@ -24,11 +24,12 @@ class ActivationsTracker:
def __init__(self):
    # Entries describing tracked activations; populated later during
    # memory-usage tracking.
    self._activations = []

def track_memory_usage(self, model, input_provider):
def track_memory_usage(self, model, input_provider, user_code_path):
# 1. Run the forward pass of the model with the given inputs. We keep
# track of all the operations that contribute to the autograd graph.
model_output, grad_function_contexts = \
self._get_grad_function_contexts(model, input_provider)
self._get_grad_function_contexts(
model, input_provider, user_code_path)

# 2. Traverse the autograd graph and get a topological ordering. Filter
# the function contexts by the gradient functions in our topological
Expand Down Expand Up @@ -70,9 +71,11 @@ def populate_report(self, report_builder):
stack_context=entry.stack,
)

def _get_grad_function_contexts(
        self, model, input_provider, user_code_path):
    """Run the model's forward pass while recording grad function contexts.

    The forward pass executes inside the user code environment (rooted at
    user_code_path) so that imports in user-defined code resolve correctly.
    Returns the model output and the tracker's recorded contexts.
    """
    tracker = GradFunctionTracker()
    with tracker.track(), user_code_environment(user_code_path):
        output = model(*input_provider())
    return output, tracker.grad_function_contexts

Expand Down
18 changes: 12 additions & 6 deletions cli/skyline/tracking/memory.py
Original file line number Diff line number Diff line change
@@ -1,32 +1,38 @@
import torch

from skyline.exceptions import exceptions_as_analysis_errors
from skyline.tracking.activations import ActivationsTracker
from skyline.tracking.report import TrackerReportBuilder, MiscSizeType
from skyline.tracking.weights import WeightsTracker
from skyline.user_code_utils import user_code_environment


def track_memory_usage(
model_provider, input_provider, iteration_provider, report_file=None):
model_provider,
input_provider,
iteration_provider,
user_code_path,
report_file=None,
):
_ensure_cuda_initialization()

# Track and record memory usage associated with model creation
weight_tracker = WeightsTracker()
with weight_tracker.track(), exceptions_as_analysis_errors():
with weight_tracker.track(), user_code_environment(user_code_path):
model = model_provider()

with exceptions_as_analysis_errors():
with user_code_environment(user_code_path):
iteration = iteration_provider(model)
# Run one iteration to initialize the gradients
iteration(*input_provider())

# Track and record memory usage associated with stored activations
activations_tracker = ActivationsTracker()
activations_tracker.track_memory_usage(model, input_provider)
activations_tracker.track_memory_usage(
model, input_provider, user_code_path)

# Record peak memory usage
torch.cuda.reset_max_memory_allocated()
with exceptions_as_analysis_errors():
with user_code_environment(user_code_path):
iteration(*input_provider())
peak_usage_bytes = torch.cuda.max_memory_allocated()

Expand Down

0 comments on commit 7f75467

Please sign in to comment.