diff --git a/monai/bundle/__init__.py b/monai/bundle/__init__.py index a4a2176f14..3f3c8d545e 100644 --- a/monai/bundle/__init__.py +++ b/monai/bundle/__init__.py @@ -43,4 +43,4 @@ MACRO_KEY, load_bundle_config, ) -from .workflows import BundleWorkflow, ConfigWorkflow +from .workflows import BundleWorkflow, ConfigWorkflow, PythonicWorkflow diff --git a/monai/bundle/reference_resolver.py b/monai/bundle/reference_resolver.py index df69b021e1..b55c62174b 100644 --- a/monai/bundle/reference_resolver.py +++ b/monai/bundle/reference_resolver.py @@ -192,6 +192,16 @@ def get_resolved_content(self, id: str, **kwargs: Any) -> ConfigExpression | str """ return self._resolve_one_item(id=id, **kwargs) + def remove_resolved_content(self, id: str) -> Any | None: + """ + Remove the resolved ``ConfigItem`` by id. + + Args: + id: id name of the expected item. + + """ + return self.resolved_content.pop(id) if id in self.resolved_content else None + @classmethod def normalize_id(cls, id: str | int) -> str: """ diff --git a/monai/bundle/workflows.py b/monai/bundle/workflows.py index 3ecd5dfbc5..75cf7b0b09 100644 --- a/monai/bundle/workflows.py +++ b/monai/bundle/workflows.py @@ -44,12 +44,18 @@ class BundleWorkflow(ABC): workflow_type: specifies the workflow type: "train" or "training" for a training workflow, or "infer", "inference", "eval", "evaluation" for a inference workflow, other unsupported string will raise a ValueError. - default to `train` for train workflow. + default to `None` for only using meta properties. workflow: specifies the workflow type: "train" or "training" for a training workflow, or "infer", "inference", "eval", "evaluation" for a inference workflow, other unsupported string will raise a ValueError. default to `None` for common workflow. - properties_path: the path to the JSON file of properties. + properties_path: the path to the JSON file of properties. If `workflow_type` is specified, properties will be + loaded from the file based on the provided `workflow_type` and meta. If no `workflow_type` is specified, + properties will default to loading from "meta". If `properties_path` is None, default properties + will be sourced from "monai/bundle/properties.py" based on the workflow_type: + For a training workflow, properties load from `TrainProperties` and `MetaProperties`. + For a inference workflow, properties load from `InferProperties` and `MetaProperties`. + For workflow_type = None : only `MetaProperties` will be loaded. meta_file: filepath of the metadata file, if this is a list of file paths, their contents will be merged in order. logging_file: config file for `logging` module in the program. for more details: https://docs.python.org/3/library/logging.config.html#logging.config.fileConfig. @@ -97,29 +103,50 @@ def __init__( meta_file = None workflow_type = workflow if workflow is not None else workflow_type - if workflow_type is None and properties_path is None: - self.properties = copy(MetaProperties) - self.workflow_type = None - self.meta_file = meta_file - return + if workflow_type is not None: + if workflow_type.lower() in self.supported_train_type: + workflow_type = "train" + elif workflow_type.lower() in self.supported_infer_type: + workflow_type = "infer" + else: + raise ValueError(f"Unsupported workflow type: '{workflow_type}'.") + if properties_path is not None: properties_path = Path(properties_path) if not properties_path.is_file(): raise ValueError(f"Property file {properties_path} does not exist.") with open(properties_path) as json_file: - self.properties = json.load(json_file) - self.workflow_type = None - self.meta_file = meta_file - return - if workflow_type.lower() in self.supported_train_type: # type: ignore[union-attr] - self.properties = {**TrainProperties, **MetaProperties} - self.workflow_type = "train" - elif workflow_type.lower() in self.supported_infer_type: # type: ignore[union-attr] - self.properties = {**InferProperties, **MetaProperties} - self.workflow_type = "infer" + try: + properties = json.load(json_file) + self.properties: dict = {} + if workflow_type is not None and workflow_type in properties: + self.properties = properties[workflow_type] + if "meta" in properties: + self.properties.update(properties["meta"]) + elif workflow_type is None: + if "meta" in properties: + self.properties = properties["meta"] + logger.info( + "No workflow type specified, default to load meta properties from property file." + ) + else: + logger.warning("No 'meta' key found in properties while workflow_type is None.") + except KeyError as e: + raise ValueError(f"{workflow_type} not found in property file {properties_path}") from e + except json.JSONDecodeError as e: + raise ValueError(f"Error decoding JSON from property file {properties_path}") from e else: - raise ValueError(f"Unsupported workflow type: '{workflow_type}'.") + if workflow_type == "train": + self.properties = {**TrainProperties, **MetaProperties} + elif workflow_type == "infer": + self.properties = {**InferProperties, **MetaProperties} + elif workflow_type is None: + self.properties = copy(MetaProperties) + logger.info("No workflow type and property file specified, default to 'meta' properties.") + else: + raise ValueError(f"Unsupported workflow type: '{workflow_type}'.") + self.workflow_type = workflow_type self.meta_file = meta_file @abstractmethod @@ -226,6 +253,124 @@ def check_properties(self) -> list[str] | None: return [n for n, p in self.properties.items() if p.get(BundleProperty.REQUIRED, False) and not hasattr(self, n)] +class PythonicWorkflow(BundleWorkflow): + """ + Base class for the pythonic workflow specification in bundle, it can be a training, evaluation or inference workflow. + It defines the basic interfaces for the bundle workflow behavior: `initialize`, `finalize`, etc. + This also provides the interface to get / set public properties to interact with a bundle workflow through + defined `get_` accessor methods or directly defining members of the object. + For how to set the properties, users can define the `_set_` methods or directly set the members of the object. + The `initialize` method is called to set up the workflow before running. This method sets up internal state + and prepares properties. If properties are modified after the workflow has been initialized, `self._is_initialized` + is set to `False`. Before running the workflow again, `initialize` should be called to ensure that the workflow is + properly set up with the new property values. + + Args: + workflow_type: specifies the workflow type: "train" or "training" for a training workflow, + or "infer", "inference", "eval", "evaluation" for a inference workflow, + other unsupported string will raise a ValueError. + default to `None` for only using meta properties. + workflow: specifies the workflow type: "train" or "training" for a training workflow, + or "infer", "inference", "eval", "evaluation" for a inference workflow, + other unsupported string will raise a ValueError. + default to `None` for common workflow. + properties_path: the path to the JSON file of properties. If `workflow_type` is specified, properties will be + loaded from the file based on the provided `workflow_type` and meta. If no `workflow_type` is specified, + properties will default to loading from "meta". If `properties_path` is None, default properties + will be sourced from "monai/bundle/properties.py" based on the workflow_type: + For a training workflow, properties load from `TrainProperties` and `MetaProperties`. + For a inference workflow, properties load from `InferProperties` and `MetaProperties`. + For workflow_type = None : only `MetaProperties` will be loaded. + config_file: path to the config file, typically used to store hyperparameters. + meta_file: filepath of the metadata file, if this is a list of file paths, their contents will be merged in order. + logging_file: config file for `logging` module in the program. for more details: + https://docs.python.org/3/library/logging.config.html#logging.config.fileConfig. + + """ + + supported_train_type: tuple = ("train", "training") + supported_infer_type: tuple = ("infer", "inference", "eval", "evaluation") + + def __init__( + self, + workflow_type: str | None = None, + properties_path: PathLike | None = None, + config_file: str | Sequence[str] | None = None, + meta_file: str | Sequence[str] | None = None, + logging_file: str | None = None, + **override: Any, + ): + meta_file = str(Path(os.getcwd()) / "metadata.json") if meta_file is None else meta_file + super().__init__( + workflow_type=workflow_type, properties_path=properties_path, meta_file=meta_file, logging_file=logging_file + ) + self._props_vals: dict = {} + self._set_props_vals: dict = {} + self.parser = ConfigParser() + if config_file is not None: + self.parser.read_config(f=config_file) + if self.meta_file is not None: + self.parser.read_meta(f=self.meta_file) + + # the rest key-values in the _args are to override config content + self.parser.update(pairs=override) + self._is_initialized: bool = False + + def initialize(self, *args: Any, **kwargs: Any) -> Any: + """ + Initialize the bundle workflow before running. + """ + self._props_vals = {} + self._is_initialized = True + + def _get_property(self, name: str, property: dict) -> Any: + """ + With specified property name and information, get the expected property value. + If the property is already generated, return from the bucket directly. + If user explicitly set the property, return it directly. + Otherwise, generate the expected property as a class private property with prefix "_". + + Args: + name: the name of target property. + property: other information for the target property, defined in `TrainProperties` or `InferProperties`. + """ + if not self._is_initialized: + raise RuntimeError("Please execute 'initialize' before getting any properties.") + value = None + if name in self._set_props_vals: + value = self._set_props_vals[name] + elif name in self._props_vals: + value = self._props_vals[name] + elif name in self.parser.config[self.parser.meta_key]: # type: ignore[index] + id = self.properties.get(name, None).get(BundlePropertyConfig.ID, None) + value = self.parser[id] + else: + try: + value = getattr(self, f"get_{name}")() + except AttributeError as e: + if property[BundleProperty.REQUIRED]: + raise ValueError( + f"unsupported property '{name}' is required in the bundle properties," + f"need to implement a method 'get_{name}' to provide the property." + ) from e + self._props_vals[name] = value + return value + + def _set_property(self, name: str, property: dict, value: Any) -> Any: + """ + With specified property name and information, set value for the expected property. + Stores user-reset initialized objects that should not be re-initialized and marks the workflow as not initialized. + + Args: + name: the name of target property. + property: other information for the target property, defined in `TrainProperties` or `InferProperties`. + value: value to set for the property. + + """ + self._set_props_vals[name] = value + self._is_initialized = False + + class ConfigWorkflow(BundleWorkflow): """ Specification for the config-based bundle workflow. @@ -262,7 +407,13 @@ class ConfigWorkflow(BundleWorkflow): or "infer", "inference", "eval", "evaluation" for a inference workflow, other unsupported string will raise a ValueError. default to `None` for common workflow. - properties_path: the path to the JSON file of properties. + properties_path: the path to the JSON file of properties. If `workflow_type` is specified, properties will be + loaded from the file based on the provided `workflow_type` and meta. If no `workflow_type` is specified, + properties will default to loading from "train". If `properties_path` is None, default properties + will be sourced from "monai/bundle/properties.py" based on the workflow_type: + For a training workflow, properties load from `TrainProperties` and `MetaProperties`. + For a inference workflow, properties load from `InferProperties` and `MetaProperties`. + For workflow_type = None : only `MetaProperties` will be loaded. override: id-value pairs to override or add the corresponding config content. e.g. ``--net#input_chns 42``, ``--net %/data/other.json#net_arg`` @@ -324,7 +475,6 @@ def __init__( self.parser.read_config(f=config_file) if self.meta_file is not None: self.parser.read_meta(f=self.meta_file) - # the rest key-values in the _args are to override config content self.parser.update(pairs=override) self.init_id = init_id @@ -394,8 +544,23 @@ def check_properties(self) -> list[str] | None: ret.extend(wrong_props) return ret - def _run_expr(self, id: str, **kwargs: dict) -> Any: - return self.parser.get_parsed_content(id, **kwargs) if id in self.parser else None + def _run_expr(self, id: str, **kwargs: dict) -> list[Any]: + """ + Evaluate the expression or expression list given by `id`. The resolved values from the evaluations are not stored, + allowing this to be evaluated repeatedly (eg. in streaming applications) without restarting the hosting process. + """ + ret = [] + if id in self.parser: + # suppose all the expressions are in a list, run and reset the expressions + if isinstance(self.parser[id], list): + for i in range(len(self.parser[id])): + sub_id = f"{id}{ID_SEP_KEY}{i}" + ret.append(self.parser.get_parsed_content(sub_id, **kwargs)) + self.parser.ref_resolver.remove_resolved_content(sub_id) + else: + ret.append(self.parser.get_parsed_content(id, **kwargs)) + self.parser.ref_resolver.remove_resolved_content(id) + return ret def _get_prop_id(self, name: str, property: dict) -> Any: prop_id = property[BundlePropertyConfig.ID] diff --git a/monai/networks/blocks/selfattention.py b/monai/networks/blocks/selfattention.py index ac96b077bd..86e1b1d3ae 100644 --- a/monai/networks/blocks/selfattention.py +++ b/monai/networks/blocks/selfattention.py @@ -11,7 +11,7 @@ from __future__ import annotations -from typing import Tuple, Union +from typing import Optional, Tuple, Union import torch import torch.nn as nn @@ -154,10 +154,12 @@ def __init__( ) self.input_size = input_size - def forward(self, x): + def forward(self, x, attn_mask: Optional[torch.Tensor] = None): """ Args: x (torch.Tensor): input tensor. B x (s_dim_1 * ... * s_dim_n) x C + attn_mask (torch.Tensor, optional): mask to apply to the attention matrix. + B x (s_dim_1 * ... * s_dim_n). Defaults to None. Return: torch.Tensor: B x (s_dim_1 * ... * s_dim_n) x C @@ -176,7 +178,13 @@ def forward(self, x): if self.use_flash_attention: x = F.scaled_dot_product_attention( - query=q, key=k, value=v, scale=self.scale, dropout_p=self.dropout_rate, is_causal=self.causal + query=q, + key=k, + value=v, + attn_mask=attn_mask, + scale=self.scale, + dropout_p=self.dropout_rate, + is_causal=self.causal, ) else: att_mat = torch.einsum("blxd,blyd->blxy", q, k) * self.scale @@ -186,10 +194,16 @@ def forward(self, x): att_mat = self.rel_positional_embedding(x, att_mat, q) if self.causal: + if attn_mask is not None: + raise ValueError("Causal attention does not support attention masks.") att_mat = att_mat.masked_fill(self.causal_mask[:, :, : x.shape[-2], : x.shape[-2]] == 0, float("-inf")) - att_mat = att_mat.softmax(dim=-1) + if attn_mask is not None: + attn_mask = attn_mask.unsqueeze(1).unsqueeze(2) + attn_mask = attn_mask.expand(-1, self.num_heads, -1, -1) + att_mat = att_mat.masked_fill(attn_mask == 0, float("-inf")) + att_mat = att_mat.softmax(dim=-1) if self.save_attn: # no gradients and new tensor; # https://pytorch.org/docs/stable/generated/torch.Tensor.detach.html diff --git a/monai/networks/blocks/transformerblock.py b/monai/networks/blocks/transformerblock.py index 05eb3b07ab..6f0da73e7b 100644 --- a/monai/networks/blocks/transformerblock.py +++ b/monai/networks/blocks/transformerblock.py @@ -90,8 +90,10 @@ def __init__( use_flash_attention=use_flash_attention, ) - def forward(self, x: torch.Tensor, context: Optional[torch.Tensor] = None) -> torch.Tensor: - x = x + self.attn(self.norm1(x)) + def forward( + self, x: torch.Tensor, context: Optional[torch.Tensor] = None, attn_mask: Optional[torch.Tensor] = None + ) -> torch.Tensor: + x = x + self.attn(self.norm1(x), attn_mask=attn_mask) if self.with_cross_attention: x = x + self.cross_attn(self.norm_cross_attn(x), context=context) x = x + self.mlp(self.norm2(x)) diff --git a/monai/networks/nets/swin_unetr.py b/monai/networks/nets/swin_unetr.py index 832135ad06..32b817d584 100644 --- a/monai/networks/nets/swin_unetr.py +++ b/monai/networks/nets/swin_unetr.py @@ -13,7 +13,6 @@ import itertools from collections.abc import Sequence -from typing import Final import numpy as np import torch @@ -51,8 +50,6 @@ class SwinUNETR(nn.Module): " """ - patch_size: Final[int] = 2 - @deprecated_arg( name="img_size", since="1.3", @@ -65,18 +62,24 @@ def __init__( img_size: Sequence[int] | int, in_channels: int, out_channels: int, + patch_size: int = 2, depths: Sequence[int] = (2, 2, 2, 2), num_heads: Sequence[int] = (3, 6, 12, 24), + window_size: Sequence[int] | int = 7, + qkv_bias: bool = True, + mlp_ratio: float = 4.0, feature_size: int = 24, norm_name: tuple | str = "instance", drop_rate: float = 0.0, attn_drop_rate: float = 0.0, dropout_path_rate: float = 0.0, normalize: bool = True, + norm_layer: type[LayerNorm] = nn.LayerNorm, + patch_norm: bool = True, use_checkpoint: bool = False, spatial_dims: int = 3, - downsample="merging", - use_v2=False, + downsample: str | nn.Module = "merging", + use_v2: bool = False, ) -> None: """ Args: @@ -86,14 +89,20 @@ def __init__( It will be removed in an upcoming version. in_channels: dimension of input channels. out_channels: dimension of output channels. + patch_size: size of the patch token. feature_size: dimension of network feature size. depths: number of layers in each stage. num_heads: number of attention heads. + window_size: local window size. + qkv_bias: add a learnable bias to query, key, value. + mlp_ratio: ratio of mlp hidden dim to embedding dim. norm_name: feature normalization type and arguments. drop_rate: dropout rate. attn_drop_rate: attention dropout rate. dropout_path_rate: drop path rate. normalize: normalize output intermediate features in each stage. + norm_layer: normalization layer. + patch_norm: whether to apply normalization to the patch embedding. use_checkpoint: use gradient checkpointing for reduced memory usage. spatial_dims: number of spatial dims. downsample: module used for downsampling, available options are `"mergingv2"`, `"merging"` and a @@ -116,13 +125,15 @@ def __init__( super().__init__() - img_size = ensure_tuple_rep(img_size, spatial_dims) - patch_sizes = ensure_tuple_rep(self.patch_size, spatial_dims) - window_size = ensure_tuple_rep(7, spatial_dims) - if spatial_dims not in (2, 3): raise ValueError("spatial dimension should be 2 or 3.") + self.patch_size = patch_size + + img_size = ensure_tuple_rep(img_size, spatial_dims) + patch_sizes = ensure_tuple_rep(self.patch_size, spatial_dims) + window_size = ensure_tuple_rep(window_size, spatial_dims) + self._check_input_size(img_size) if not (0 <= drop_rate <= 1): @@ -146,12 +157,13 @@ def __init__( patch_size=patch_sizes, depths=depths, num_heads=num_heads, - mlp_ratio=4.0, - qkv_bias=True, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, drop_rate=drop_rate, attn_drop_rate=attn_drop_rate, drop_path_rate=dropout_path_rate, - norm_layer=nn.LayerNorm, + norm_layer=norm_layer, + patch_norm=patch_norm, use_checkpoint=use_checkpoint, spatial_dims=spatial_dims, downsample=look_up_option(downsample, MERGING_MODE) if isinstance(downsample, str) else downsample, diff --git a/monai/utils/module.py b/monai/utils/module.py index 1ad001fc87..d3f2ff09f2 100644 --- a/monai/utils/module.py +++ b/monai/utils/module.py @@ -649,7 +649,7 @@ def compute_capabilities_after(major: int, minor: int = 0, current_ver_string: s current_ver_string: if None, the current system GPU CUDA compute capability will be used. Returns: - True if the current system GPU CUDA compute capability is greater than the specified version. + True if the current system GPU CUDA compute capability is greater than or equal to the specified version. """ if current_ver_string is None: cuda_available = torch.cuda.is_available() @@ -667,11 +667,11 @@ def compute_capabilities_after(major: int, minor: int = 0, current_ver_string: s ver, has_ver = optional_import("packaging.version", name="parse") if has_ver: - return ver(".".join((f"{major}", f"{minor}"))) < ver(f"{current_ver_string}") # type: ignore + return ver(".".join((f"{major}", f"{minor}"))) <= ver(f"{current_ver_string}") # type: ignore parts = f"{current_ver_string}".split("+", 1)[0].split(".", 2) while len(parts) < 2: parts += ["0"] c_major, c_minor = parts[:2] c_mn = int(c_major), int(c_minor) mn = int(major), int(minor) - return c_mn >= mn + return c_mn > mn diff --git a/tests/nonconfig_workflow.py b/tests/nonconfig_workflow.py index b2c44c12c6..fcfc5b2951 100644 --- a/tests/nonconfig_workflow.py +++ b/tests/nonconfig_workflow.py @@ -13,7 +13,7 @@ import torch -from monai.bundle import BundleWorkflow +from monai.bundle import BundleWorkflow, PythonicWorkflow from monai.data import DataLoader, Dataset from monai.engines import SupervisedEvaluator from monai.inferers import SlidingWindowInferer @@ -26,8 +26,9 @@ LoadImaged, SaveImaged, ScaleIntensityd, + ScaleIntensityRanged, ) -from monai.utils import BundleProperty, set_determinism +from monai.utils import BundleProperty, CommonKeys, set_determinism class NonConfigWorkflow(BundleWorkflow): @@ -176,3 +177,62 @@ def _set_property(self, name, property, value): self._numpy_version = value elif property[BundleProperty.REQUIRED]: raise ValueError(f"unsupported property '{name}' is required in the bundle properties.") + + +class PythonicWorkflowImpl(PythonicWorkflow): + """ + Test class simulates the bundle workflow defined by Python script directly. + """ + + def __init__( + self, + workflow_type: str = "inference", + config_file: str | None = None, + properties_path: str | None = None, + meta_file: str | None = None, + ): + super().__init__( + workflow_type=workflow_type, properties_path=properties_path, config_file=config_file, meta_file=meta_file + ) + self.dataflow: dict = {} + + def initialize(self): + self._props_vals = {} + self._is_initialized = True + self.net = UNet( + spatial_dims=3, + in_channels=1, + out_channels=2, + channels=(16, 32, 64, 128), + strides=(2, 2, 2), + num_res_units=2, + ).to(self.device) + preprocessing = Compose( + [ + EnsureChannelFirstd(keys=["image"]), + ScaleIntensityd(keys="image"), + ScaleIntensityRanged(keys="image", a_min=-57, a_max=164, b_min=0.0, b_max=1.0, clip=True), + ] + ) + self.dataset = Dataset(data=[self.dataflow], transform=preprocessing) + self.postprocessing = Compose([Activationsd(keys="pred", softmax=True), AsDiscreted(keys="pred", argmax=True)]) + + def run(self): + data = self.dataset[0] + inputs = data[CommonKeys.IMAGE].unsqueeze(0).to(self.device) + self.net.eval() + with torch.no_grad(): + data[CommonKeys.PRED] = self.inferer(inputs, self.net) + self.dataflow.update({CommonKeys.PRED: self.postprocessing(data)[CommonKeys.PRED]}) + + def finalize(self): + pass + + def get_bundle_root(self): + return "." + + def get_device(self): + return torch.device("cuda" if torch.cuda.is_available() else "cpu") + + def get_inferer(self): + return SlidingWindowInferer(roi_size=self.parser.roi_size, sw_batch_size=1, overlap=0) diff --git a/tests/test_bundle_trt_export.py b/tests/test_bundle_trt_export.py index 835c8e5c1d..27e1ee97a8 100644 --- a/tests/test_bundle_trt_export.py +++ b/tests/test_bundle_trt_export.py @@ -53,7 +53,7 @@ @skip_if_windows @skip_if_no_cuda @skip_if_quick -@SkipIfBeforeComputeCapabilityVersion((7, 0)) +@SkipIfBeforeComputeCapabilityVersion((7, 5)) class TestTRTExport(unittest.TestCase): def setUp(self): diff --git a/tests/test_bundle_workflow.py b/tests/test_bundle_workflow.py index 1727fcdf53..893b9dc991 100644 --- a/tests/test_bundle_workflow.py +++ b/tests/test_bundle_workflow.py @@ -13,6 +13,7 @@ import os import shutil +import sys import tempfile import unittest from copy import deepcopy @@ -22,12 +23,12 @@ import torch from parameterized import parameterized -from monai.bundle import ConfigWorkflow +from monai.bundle import ConfigWorkflow, create_workflow from monai.data import Dataset from monai.inferers import SimpleInferer, SlidingWindowInferer from monai.networks.nets import UNet -from monai.transforms import Compose, LoadImage -from tests.nonconfig_workflow import NonConfigWorkflow +from monai.transforms import Compose, LoadImage, LoadImaged, SaveImaged +from tests.nonconfig_workflow import NonConfigWorkflow, PythonicWorkflowImpl TEST_CASE_1 = [os.path.join(os.path.dirname(__file__), "testing_data", "inference.json")] @@ -35,6 +36,8 @@ TEST_CASE_3 = [os.path.join(os.path.dirname(__file__), "testing_data", "config_fl_train.json")] +TEST_CASE_4 = [os.path.join(os.path.dirname(__file__), "testing_data", "responsive_inference.json")] + TEST_CASE_NON_CONFIG_WRONG_LOG = [None, "logging.conf", "Cannot find the logging config file: logging.conf."] @@ -45,7 +48,9 @@ def setUp(self): self.expected_shape = (128, 128, 128) test_image = np.random.rand(*self.expected_shape) self.filename = os.path.join(self.data_dir, "image.nii") + self.filename1 = os.path.join(self.data_dir, "image1.nii") nib.save(nib.Nifti1Image(test_image, np.eye(4)), self.filename) + nib.save(nib.Nifti1Image(test_image, np.eye(4)), self.filename1) def tearDown(self): shutil.rmtree(self.data_dir) @@ -108,12 +113,42 @@ def test_inference_config(self, config_file): # test property path inferer = ConfigWorkflow( config_file=config_file, + workflow_type="infer", properties_path=os.path.join(os.path.dirname(__file__), "testing_data", "fl_infer_properties.json"), logging_file=os.path.join(os.path.dirname(__file__), "testing_data", "logging.conf"), **override, ) self._test_inferer(inferer) - self.assertEqual(inferer.workflow_type, None) + self.assertEqual(inferer.workflow_type, "infer") + + @parameterized.expand([TEST_CASE_4]) + def test_responsive_inference_config(self, config_file): + input_loader = LoadImaged(keys="image") + output_saver = SaveImaged(keys="pred", output_dir=self.data_dir, output_postfix="seg") + + # test standard MONAI model-zoo config workflow + inferer = ConfigWorkflow( + workflow_type="infer", + config_file=config_file, + logging_file=os.path.join(os.path.dirname(__file__), "testing_data", "logging.conf"), + ) + # FIXME: temp add the property for test, we should add it to some formal realtime infer properties + inferer.add_property(name="dataflow", required=True, config_id="dataflow") + + inferer.initialize() + inferer.dataflow.update(input_loader({"image": self.filename})) + inferer.run() + output_saver(inferer.dataflow) + self.assertTrue(os.path.exists(os.path.join(self.data_dir, "image", "image_seg.nii.gz"))) + + # bundle is instantiated and idle, just change the input for next inference + inferer.dataflow.clear() + inferer.dataflow.update(input_loader({"image": self.filename1})) + inferer.run() + output_saver(inferer.dataflow) + self.assertTrue(os.path.exists(os.path.join(self.data_dir, "image1", "image1_seg.nii.gz"))) + + inferer.finalize() @parameterized.expand([TEST_CASE_3]) def test_train_config(self, config_file): @@ -164,6 +199,72 @@ def test_non_config_wrong_log_cases(self, meta_file, logging_file, expected_erro with self.assertRaisesRegex(FileNotFoundError, expected_error): NonConfigWorkflow(self.filename, self.data_dir, meta_file, logging_file) + def test_pythonic_workflow(self): + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + config_file = {"roi_size": (64, 64, 32)} + meta_file = os.path.join(os.path.dirname(__file__), "testing_data", "metadata.json") + property_path = os.path.join(os.path.dirname(__file__), "testing_data", "python_workflow_properties.json") + workflow = PythonicWorkflowImpl( + workflow_type="infer", config_file=config_file, meta_file=meta_file, properties_path=property_path + ) + workflow.initialize() + # Load input data + input_loader = LoadImaged(keys="image") + workflow.dataflow.update(input_loader({"image": self.filename})) + self.assertEqual(workflow.bundle_root, ".") + self.assertEqual(workflow.device, device) + self.assertEqual(workflow.version, "0.1.0") + # check config override correctly + self.assertEqual(workflow.inferer.roi_size, (64, 64, 32)) + workflow.run() + # update input data and run again + workflow.dataflow.update(input_loader({"image": self.filename1})) + workflow.run() + pred = workflow.dataflow["pred"] + self.assertEqual(pred.shape[2:], self.expected_shape) + self.assertEqual(pred.meta["filename_or_obj"], self.filename1) + workflow.finalize() + + def test_create_pythonic_workflow(self): + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + config_file = {"roi_size": (64, 64, 32)} + meta_file = os.path.join(os.path.dirname(__file__), "testing_data", "metadata.json") + property_path = os.path.join(os.path.dirname(__file__), "testing_data", "python_workflow_properties.json") + sys.path.append(os.path.dirname(__file__)) + workflow = create_workflow( + "nonconfig_workflow.PythonicWorkflowImpl", + workflow_type="infer", + config_file=config_file, + meta_file=meta_file, + properties_path=property_path, + ) + # Load input data + input_loader = LoadImaged(keys="image") + workflow.dataflow.update(input_loader({"image": self.filename})) + self.assertEqual(workflow.bundle_root, ".") + self.assertEqual(workflow.device, device) + self.assertEqual(workflow.version, "0.1.0") + # check config override correctly + self.assertEqual(workflow.inferer.roi_size, (64, 64, 32)) + + # check set property override correctly + workflow.inferer = SlidingWindowInferer(roi_size=config_file["roi_size"], sw_batch_size=1, overlap=0.5) + workflow.initialize() + self.assertEqual(workflow.inferer.overlap, 0.5) + + workflow.run() + # update input data and run again + workflow.dataflow.update(input_loader({"image": self.filename1})) + workflow.run() + pred = workflow.dataflow["pred"] + self.assertEqual(pred.shape[2:], self.expected_shape) + self.assertEqual(pred.meta["filename_or_obj"], self.filename1) + + # test add properties + workflow.add_property(name="net", required=True, desc="network for the training.") + self.assertIn("net", workflow.properties) + workflow.finalize() + if __name__ == "__main__": unittest.main() diff --git a/tests/test_convert_to_trt.py b/tests/test_convert_to_trt.py index 712d887c3b..a7b1edec3c 100644 --- a/tests/test_convert_to_trt.py +++ b/tests/test_convert_to_trt.py @@ -38,7 +38,7 @@ @skip_if_windows @skip_if_no_cuda @skip_if_quick -@SkipIfBeforeComputeCapabilityVersion((7, 0)) +@SkipIfBeforeComputeCapabilityVersion((7, 5)) class TestConvertToTRT(unittest.TestCase): def setUp(self): diff --git a/tests/test_module_list.py b/tests/test_module_list.py index d21ba53b7c..833441cbca 100644 --- a/tests/test_module_list.py +++ b/tests/test_module_list.py @@ -58,13 +58,17 @@ def test_transform_api(self): continue with self.subTest(n=n): basename = n[:-1] # Transformd basename is Transform + + # remove aliases to check, do this before the assert below so that a failed assert does skip this + for postfix in ("D", "d", "Dict"): + remained.remove(f"{basename}{postfix}") + for docname in (f"{basename}", f"{basename}d"): if docname in to_exclude_docs: continue if (contents is not None) and f"`{docname}`" not in f"{contents}": self.assertTrue(False, f"please add `{docname}` to docs/source/transforms.rst") - for postfix in ("D", "d", "Dict"): - remained.remove(f"{basename}{postfix}") + self.assertFalse(remained) diff --git a/tests/test_selfattention.py b/tests/test_selfattention.py index 88919fd8b1..338f1bf840 100644 --- a/tests/test_selfattention.py +++ b/tests/test_selfattention.py @@ -122,6 +122,24 @@ def test_causal(self): # check upper triangular part of the attention matrix is zero assert torch.triu(block.att_mat, diagonal=1).sum() == 0 + def test_masked_selfattention(self): + n = 64 + block = SABlock(hidden_size=128, num_heads=1, dropout_rate=0.1, sequence_length=16, save_attn=True) + input_shape = (1, n, 128) + # generate a mask randomly with zeros and ones of shape (1, n) + mask = torch.randint(0, 2, (1, n)).bool() + block(torch.randn(input_shape), attn_mask=mask) + att_mat = block.att_mat.squeeze() + # ensure all masked columns are zeros + assert torch.allclose(att_mat[:, ~mask.squeeze(0)], torch.zeros_like(att_mat[:, ~mask.squeeze(0)])) + + def test_causal_and_mask(self): + with self.assertRaises(ValueError): + block = SABlock(hidden_size=128, num_heads=1, causal=True, sequence_length=64) + inputs = torch.randn(2, 64, 128) + mask = torch.randint(0, 2, (2, 64)).bool() + block(inputs, attn_mask=mask) + @skipUnless(has_einops, "Requires einops") def test_access_attn_matrix(self): # input format diff --git a/tests/test_trt_compile.py b/tests/test_trt_compile.py index e1323c201f..f7779fec9b 100644 --- a/tests/test_trt_compile.py +++ b/tests/test_trt_compile.py @@ -50,7 +50,7 @@ def forward(self, x: list[torch.Tensor], y: torch.Tensor, z: torch.Tensor, bs: f @skip_if_quick @unittest.skipUnless(trt_imported, "tensorrt is required") @unittest.skipUnless(polygraphy_imported, "polygraphy is required") -@SkipIfBeforeComputeCapabilityVersion((7, 0)) +@SkipIfBeforeComputeCapabilityVersion((7, 5)) class TestTRTCompile(unittest.TestCase): def setUp(self): diff --git a/tests/test_version_after.py b/tests/test_version_after.py index 34a5054974..b6cb741382 100644 --- a/tests/test_version_after.py +++ b/tests/test_version_after.py @@ -38,7 +38,7 @@ TEST_CASES_SM = [ # (major, minor, sm, expected) - (6, 1, "6.1", False), + (6, 1, "6.1", True), (6, 1, "6.0", False), (6, 0, "8.6", True), (7, 0, "8", True), diff --git a/tests/testing_data/fl_infer_properties.json b/tests/testing_data/fl_infer_properties.json index 72e97cd2c6..6b40edd2ab 100644 --- a/tests/testing_data/fl_infer_properties.json +++ b/tests/testing_data/fl_infer_properties.json @@ -1,67 +1,76 @@ { - "bundle_root": { - "description": "root path of the bundle.", - "required": true, - "id": "bundle_root" + "infer": { + "bundle_root": { + "description": "root path of the bundle.", + "required": true, + "id": "bundle_root" + }, + "device": { + "description": "target device to execute the bundle workflow.", + "required": true, + "id": "device" + }, + "dataset_dir": { + "description": "directory path of the dataset.", + "required": true, + "id": "dataset_dir" + }, + "dataset": { + "description": "PyTorch dataset object for the inference / evaluation logic.", + "required": true, + "id": "dataset" + }, + "evaluator": { + "description": "inference / evaluation workflow engine.", + "required": true, + "id": "evaluator" + }, + "network_def": { + "description": "network module for the inference.", + "required": true, + "id": "network_def" + }, + "inferer": { + "description": "MONAI Inferer object to execute the model computation in inference.", + "required": true, + "id": "inferer" + }, + "dataset_data": { + "description": "data source for the inference / evaluation dataset.", + "required": false, + "id": "dataset::data", + "refer_id": null + }, + "handlers": { + "description": "event-handlers for the inference / evaluation logic.", + "required": false, + "id": "handlers", + "refer_id": "evaluator::val_handlers" + }, + "preprocessing": { + "description": "preprocessing for the input data.", + "required": false, + "id": "preprocessing", + "refer_id": "dataset::transform" + }, + "postprocessing": { + "description": "postprocessing for the model output data.", + "required": false, + "id": "postprocessing", + "refer_id": "evaluator::postprocessing" + }, + "key_metric": { + "description": "the key metric during evaluation.", + "required": false, + "id": "key_metric", + "refer_id": "evaluator::key_val_metric" + } }, - "device": { - "description": "target device to execute the bundle workflow.", - "required": true, - "id": "device" - }, - "dataset_dir": { - "description": "directory path of the dataset.", - "required": true, - "id": "dataset_dir" - }, - "dataset": { - "description": "PyTorch dataset object for the inference / evaluation logic.", - "required": true, - "id": "dataset" - }, - "evaluator": { - "description": "inference / evaluation workflow engine.", - "required": true, - "id": "evaluator" - }, - "network_def": { - "description": "network module for the inference.", - "required": true, - "id": "network_def" - }, - "inferer": { - "description": "MONAI Inferer object to execute the model computation in inference.", - "required": true, - "id": "inferer" - }, - "dataset_data": { - "description": "data source for the inference / evaluation dataset.", - "required": false, - "id": "dataset::data", - "refer_id": null - }, - "handlers": { - "description": "event-handlers for the inference / evaluation logic.", - "required": false, - "id": "handlers", - "refer_id": "evaluator::val_handlers" - }, - "preprocessing": { - "description": "preprocessing for the input data.", - "required": false, - "id": "preprocessing", - "refer_id": "dataset::transform" - }, - "postprocessing": { - "description": "postprocessing for the model output data.", - "required": false, - "id": "postprocessing", - "refer_id": "evaluator::postprocessing" - }, - "key_metric": { - "description": "the key metric during evaluation.", - "required": false, - "id": "key_metric", - "refer_id": "evaluator::key_val_metric" + "meta": { + "version": { + "description": "version of the inference configuration.", + "required": true, + "id": "_meta_::version" + } } } diff --git a/tests/testing_data/python_workflow_properties.json b/tests/testing_data/python_workflow_properties.json new file mode 100644 index 0000000000..cd4295839a --- /dev/null +++ b/tests/testing_data/python_workflow_properties.json @@ -0,0 +1,26 @@ +{ + "infer": { + "bundle_root": { + "description": "root path of the bundle.", + "required": true, + "id": "bundle_root" + }, + "device": { + "description": "target device to execute the bundle workflow.", + "required": true, + "id": "device" + }, + "inferer": { + "description": "MONAI Inferer object to execute the model computation in inference.", + "required": true, + "id": "inferer" + } + }, + "meta": { + "version": { + "description": "version of the inference configuration.", + "required": true, + "id": "_meta_::version" + } + } +} diff --git a/tests/testing_data/responsive_inference.json b/tests/testing_data/responsive_inference.json new file mode 100644 index 0000000000..16d953d38e --- /dev/null +++ b/tests/testing_data/responsive_inference.json @@ -0,0 +1,101 @@ +{ + "imports": [ + "$from collections import defaultdict" + ], + "bundle_root": "will override", + "device": "$torch.device('cpu')", + "network_def": { + "_target_": "UNet", + "spatial_dims": 3, + "in_channels": 1, + "out_channels": 2, + "channels": [ + 2, + 2, + 4, + 8, + 4 + ], + "strides": [ + 2, + 2, + 2, + 2 + ], + "num_res_units": 2, + "norm": "batch" + }, + "network": "$@network_def.to(@device)", + "dataflow": "$defaultdict()", + "preprocessing": { + "_target_": "Compose", + "transforms": [ + { + "_target_": "EnsureChannelFirstd", + "keys": "image" + }, + { + "_target_": "ScaleIntensityd", + "keys": "image" + }, + { + "_target_": "RandRotated", + "_disabled_": true, + "keys": "image" + } + ] + }, + "dataset": { + "_target_": "Dataset", + "data": [ + "@dataflow" + ], + "transform": "@preprocessing" + }, + "dataloader": { + "_target_": "DataLoader", + "dataset": "@dataset", + "batch_size": 1, + "shuffle": false, + "num_workers": 0 + }, + "inferer": { + "_target_": "SlidingWindowInferer", + "roi_size": [ + 64, + 64, + 32 + ], + "sw_batch_size": 4, + "overlap": 0.25 + }, + "postprocessing": { + "_target_": "Compose", + "transforms": [ + { + "_target_": "Activationsd", + "keys": "pred", + "softmax": true + }, + { + "_target_": "AsDiscreted", + "keys": "pred", + "argmax": true + } + ] + }, + "evaluator": { + "_target_": "SupervisedEvaluator", + "device": "@device", + "val_data_loader": "@dataloader", + "network": "@network", + "inferer": "@inferer", + "postprocessing": "@postprocessing", + "amp": false, + "epoch_length": 1 + }, + "run": [ + "$@evaluator.run()", + "$@dataflow.update(@evaluator.state.output[0])" + ] +}