diff --git a/sleap/gui/app.py b/sleap/gui/app.py index de6ce9fbf..065563e66 100644 --- a/sleap/gui/app.py +++ b/sleap/gui/app.py @@ -49,6 +49,8 @@ import platform import random import re +import traceback +from logging import getLogger from pathlib import Path from typing import Callable, List, Optional, Tuple @@ -85,6 +87,9 @@ from sleap.util import parse_uri_path +logger = getLogger(__name__) + + class MainWindow(QMainWindow): """The SLEAP GUI application. @@ -101,6 +106,7 @@ class MainWindow(QMainWindow): def __init__( self, labels_path: Optional[str] = None, + labels: Optional[Labels] = None, reset: bool = False, no_usage_data: bool = False, *args, @@ -118,7 +124,7 @@ def __init__( self.setAcceptDrops(True) self.state = GuiState() - self.labels = Labels() + self.labels = labels or Labels() self.commands = CommandContext( state=self.state, app=self, update_callback=self.on_data_update @@ -175,8 +181,10 @@ def __init__( print("Restoring GUI state...") self.restoreState(prefs["window state"]) - if labels_path: + if labels_path is not None: self.commands.loadProjectFile(filename=labels_path) + elif labels is not None: + self.commands.loadLabelsObject(labels=labels) else: self.state["project_loaded"] = False @@ -1594,8 +1602,7 @@ def _show_keyboard_shortcuts_window(self): ShortcutDialog().exec_() -def main(args: Optional[list] = None): - """Starts new instance of app.""" +def create_parser(): import argparse @@ -1635,6 +1642,13 @@ def main(args: Optional[list] = None): default=False, ) + return parser + + +def main(args: Optional[list] = None, labels: Optional[Labels] = None): + """Starts new instance of app.""" + + parser = create_parser() args = parser.parse_args(args) if args.nonnative: @@ -1651,12 +1665,23 @@ def main(args: Optional[list] = None): app.setWindowIcon(QtGui.QIcon(sleap.util.get_package_file("gui/icon.png"))) window = MainWindow( - labels_path=args.labels_path, reset=args.reset, no_usage_data=args.no_usage_data + labels_path=args.labels_path, + labels=labels, + reset=args.reset, + no_usage_data=args.no_usage_data, ) window.showMaximized() # Disable GPU in GUI process. This does not affect subprocesses. - sleap.use_cpu_only() + try: + sleap.use_cpu_only() + except RuntimeError: # Visible devices cannot be modified after being initialized + logger.warning( + "Running processes on the GPU. Restarting your GUI should allow switching " + "back to CPU-only mode.\n" + "Received the following error when trying to switch back to CPU-only mode:" + ) + traceback.print_exc() # Print versions. print() diff --git a/sleap/gui/commands.py b/sleap/gui/commands.py index 78a8c2a31..8ac4d87fb 100644 --- a/sleap/gui/commands.py +++ b/sleap/gui/commands.py @@ -36,7 +36,7 @@ class which inherits from `AppCommand` (or a more specialized class such as from enum import Enum from glob import glob from pathlib import Path, PurePath -from typing import Callable, Dict, Iterator, List, Optional, Tuple, Type +from typing import Callable, Dict, Iterator, List, Optional, Tuple, Type, Union import attr import cv2 @@ -260,16 +260,15 @@ def loadLabelsObject(self, labels: Labels, filename: Optional[str] = None): """ self.execute(LoadLabelsObject, labels=labels, filename=filename) - def loadProjectFile(self, filename: str): + def loadProjectFile(self, filename: Union[str, Labels]): """Loads given labels file into GUI. Args: - filename: The path to the saved labels dataset. If None, - then don't do anything. + filename: The path to the saved labels dataset or the `Labels` object. + If None, then don't do anything. Returns: None - """ self.execute(LoadProjectFile, filename=filename) @@ -647,9 +646,8 @@ def do_action(context: "CommandContext", params: dict): Returns: None. - """ - filename = params["filename"] + filename = params.get("filename", None) # If called with just a Labels object labels: Labels = params["labels"] context.state["labels"] = labels @@ -669,7 +667,9 @@ def do_action(context: "CommandContext", params: dict): context.state["video"] = labels.videos[0] context.state["project_loaded"] = True - context.state["has_changes"] = params.get("changed_on_load", False) + context.state["has_changes"] = params.get("changed_on_load", False) or ( + filename is None + ) # This is not listed as an edit command since we want a clean changestack context.app.on_data_update([UpdateTopic.project, UpdateTopic.all]) @@ -683,17 +683,16 @@ def ask(context: "CommandContext", params: dict): if len(filename) == 0: return - gui_video_callback = Labels.make_gui_video_callback( - search_paths=[os.path.dirname(filename)], context=params - ) - has_loaded = False labels = None - if type(filename) == Labels: + if isinstance(filename, Labels): labels = filename filename = None has_loaded = True else: + gui_video_callback = Labels.make_gui_video_callback( + search_paths=[os.path.dirname(filename)], context=params + ) try: labels = Labels.load_file(filename, video_search=gui_video_callback) has_loaded = True