Commit

SD3 support.
comfyanonymous committed Jun 12, 2024
1 parent 7504509 commit 975ed7a
Showing 3 changed files with 20 additions and 9 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,7 +1,7 @@
 [project]
 name = "comfyui_tensorrt"
 description = "TensorRT Node for ComfyUI\nThis node enables the best performance on NVIDIA RTX™ Graphics Cards (GPUs) for Stable Diffusion by leveraging NVIDIA TensorRT."
-version = "0.1.0"
+version = "0.1.1"
 license = "LICENSE"
 dependencies = [
     "tensorrt>=10.0.1",
19 changes: 13 additions & 6 deletions tensorrt_convert.py
@@ -159,7 +159,17 @@ def _convert(
         comfy.model_management.load_models_gpu([model], force_patch_weights=True)
         unet = model.model.diffusion_model
 
-        if "context_dim" in model.model.model_config.unet_config:
+        context_dim = model.model.model_config.unet_config.get("context_dim", None)
+        context_len = 77
+        context_len_min = context_len
+
+        if context_dim is None: #SD3
+            context_embedder_config = model.model.model_config.unet_config.get("context_embedder_config", None)
+            if context_embedder_config is not None:
+                context_dim = context_embedder_config.get("params", {}).get("in_features", None)
+                context_len = 154 #NOTE: SD3 can have 77 or 154 depending on which text encoders are used, this is why context_len_min stays 77
+
+        if context_dim is not None:
             input_names = ["x", "timesteps", "context"]
             output_names = ["h"]
 
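For readers following the change: SD3 configs carry no top-level context_dim, so the hunk above recovers it from the context embedder's in_features and raises the opt/max context length to 154 while context_len_min stays at 77. Below is a minimal standalone sketch of that lookup with a hypothetical unet_config dict; the 16/4096/1536 values are illustrative assumptions, not numbers taken from this commit.

# Standalone sketch of the context_dim fallback above (values are illustrative).
unet_config = {
    "in_channels": 16,
    "context_embedder_config": {"params": {"in_features": 4096, "out_features": 1536}},
}

context_dim = unet_config.get("context_dim", None)
context_len = 77
context_len_min = context_len

if context_dim is None:  # SD3-style config: no top-level context_dim
    embedder = unet_config.get("context_embedder_config", None)
    if embedder is not None:
        context_dim = embedder.get("params", {}).get("in_features", None)
        context_len = 154  # opt/max token length; the minimum profile keeps 77

print(context_dim, context_len_min, context_len)  # 4096 77 154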
@@ -170,7 +180,6 @@ def _convert(
             }
 
             transformer_options = model.model_options['transformer_options'].copy()
-            context_len = 77
             if model.model.model_config.unet_config.get(
                 "use_temporal_resblock", False
             ): # SVD
@@ -194,7 +203,7 @@ def forward(self, x, timesteps, context, y):
                 svd_unet.unet = unet
                 svd_unet.transformer_options = transformer_options
                 unet = svd_unet
-                context_len = 1
+                context_len_min = context_len = 1
             else:
                 class UNET(torch.nn.Module):
                     def forward(self, x, timesteps, context, y=None):
@@ -212,12 +221,10 @@ def forward(self, x, timesteps, context, y=None):
 
             input_channels = model.model.model_config.unet_config.get("in_channels")
 
-            context_dim = model.model.model_config.unet_config.get("context_dim")
-
             inputs_shapes_min = (
                 (batch_size_min, input_channels, height_min // 8, width_min // 8),
                 (batch_size_min,),
-                (batch_size_min, context_len * context_min, context_dim),
+                (batch_size_min, context_len_min * context_min, context_dim),
             )
             inputs_shapes_opt = (
                 (batch_size_opt, input_channels, height_opt // 8, width_opt // 8),
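The net effect of context_len_min is an asymmetric context axis in the TensorRT optimization profile: the minimum shape admits 77-token contexts while the opt (and, following the pre-existing pattern, max) shapes use 154 for SD3; the opt/max context entries themselves are not shown in the hunk. A hedged worked example of the resulting tuples follows, with illustrative values (batch 1, 1024x1024, 16 latent channels, context_dim 4096) that are assumptions rather than numbers from the commit.

# Worked example of the shape profiles built above, under assumed SD3-like values.
batch_size_min = batch_size_opt = 1
height_min = width_min = height_opt = width_opt = 1024
context_min = context_opt = 1          # context multiplier inputs of the node
input_channels = 16                    # assumed SD3 latent channels
context_dim = 4096                     # assumed in_features of the context embedder
context_len_min, context_len = 77, 154

inputs_shapes_min = (
    (batch_size_min, input_channels, height_min // 8, width_min // 8),
    (batch_size_min,),
    (batch_size_min, context_len_min * context_min, context_dim),
)
inputs_shapes_opt = (
    (batch_size_opt, input_channels, height_opt // 8, width_opt // 8),
    (batch_size_opt,),
    (batch_size_opt, context_len * context_opt, context_dim),
)

print(inputs_shapes_min)  # ((1, 16, 128, 128), (1,), (1, 77, 4096))
print(inputs_shapes_opt)  # ((1, 16, 128, 128), (1,), (1, 154, 4096))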
8 changes: 6 additions & 2 deletions tensorrt_loader.py
@@ -110,7 +110,7 @@ class TensorRTLoader:
     @classmethod
     def INPUT_TYPES(s):
         return {"required": {"unet_name": (folder_paths.get_filename_list("tensorrt"), ),
-                             "model_type": (["sdxl_base", "sdxl_refiner", "sd1.x", "sd2.x-768v", "svd"], ),
+                             "model_type": (["sdxl_base", "sdxl_refiner", "sd1.x", "sd2.x-768v", "svd", "sd3"], ),
                              }}
     RETURN_TYPES = ("MODEL",)
     FUNCTION = "load_unet"
Expand Down Expand Up @@ -141,7 +141,11 @@ def load_unet(self, unet_name, model_type):
         elif model_type == "svd":
             conf = comfy.supported_models.SVD_img2vid({})
             conf.unet_config["disable_unet_model_creation"] = True
-            model = conf.get_model({})
+            model = conf.get_model({})
+        elif model_type == "sd3":
+            conf = comfy.supported_models.SD3({})
+            conf.unet_config["disable_unet_model_creation"] = True
+            model = conf.get_model({})
         model.diffusion_model = unet
         model.memory_required = lambda *args, **kwargs: 0 #always pass inputs batched up as much as possible, our TRT code will handle batch splitting
 
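On the loader side, the new "sd3" choice behaves like the existing entries: it builds a ComfyUI SD3 model config with UNet creation disabled and attaches the deserialized TensorRT engine as the diffusion model. A hedged usage sketch follows; it assumes the module is importable directly (inside ComfyUI it is loaded as a custom node package), that the node returns a tuple matching RETURN_TYPES as is conventional for ComfyUI nodes, and that "my_sd3_engine.engine" is a hypothetical file in the "tensorrt" models folder.

# Hedged usage sketch: loading an SD3 TensorRT engine through the updated node.
from tensorrt_loader import TensorRTLoader

loader = TensorRTLoader()
(model,) = loader.load_unet(unet_name="my_sd3_engine.engine", model_type="sd3")
# model is the single MODEL output declared by RETURN_TYPES; its diffusion model
# is the TensorRT engine wrapper, so it can stand in for a checkpoint's UNet
# when wired into a sampler node.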
