-
Notifications
You must be signed in to change notification settings - Fork 10
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add Help PopUp thanks to Ryanontheinside, Kjnodes and mtb; add mask features to all audio nodes
- Loading branch information
Showing
16 changed files
with
1,265 additions
and
193 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,17 +1,81 @@ | ||
from .nodes.audio.Audio_Reactive_IPAdapter_YVANN import Audio_Reactive_IPAdapter_YVANN | ||
from .nodes.audio.AudioAnalysis_Advanced_YVANN import AudioAnalysis_Advanced_YVANN | ||
from .nodes.audio.AudioFrequencyAnalysis_YVANN import AudioFrequencyAnalysis_YVANN | ||
from .node_configs import CombinedMeta | ||
from collections import OrderedDict | ||
|
||
# credit to RyanOnTheInside, KJNodes, MTB, Akatz, their works helped me a lot | ||
|
||
|
||
#allows for central management and inheritance of class variables for help documentation | ||
class Yvann(metaclass=CombinedMeta):
    """Base class that assembles Markdown help text from class attributes.

    The text is built from the optional class attributes ``DESCRIPTION``,
    ``TOP_DESCRIPTION``, ``BASE_DESCRIPTION``, ``ADDITIONAL_INFO`` and
    ``BOTTOM_DESCRIPTION``.
    """

    @classmethod
    def get_description(cls):
        """Return the Markdown help text for ``cls``.

        The text starts with a title derived from the display name and ends
        with a fixed footer. If ``DESCRIPTION`` is present it is used as the
        whole body; otherwise the body is assembled from the other optional
        attributes, with ``ADDITIONAL_INFO`` collected from base classes
        first and the most derived class last.

        Returns:
            str: The assembled Markdown text.
        """
        display_name = NODE_DISPLAY_NAME_MAPPINGS.get(cls.__name__, cls.__name__)
        display_name = display_name.replace(" | Yvann", "")

        # NOTE(review): the footer links point at RyanOnTheInside's pages,
        # presumably carried over from the code this was adapted from --
        # confirm whether they should reference this project instead.
        footer = (
            "For more information, visit [RyanOnTheInside GitHub](https://github.com/ryanontheinside).\n\n"
            "For tutorials and example workflows visit [RyanOnTheInside Civitai](https://civitai.com/user/ryanontheinside).\n\n"
        )

        desc = f"# {display_name}\n\n"

        # An explicit DESCRIPTION short-circuits the piecewise assembly.
        if hasattr(cls, 'DESCRIPTION'):
            desc += f"{cls.DESCRIPTION}\n\n{footer}"
            return desc

        if hasattr(cls, 'TOP_DESCRIPTION'):
            desc += f"### {cls.TOP_DESCRIPTION}\n\n"

        if hasattr(cls, "BASE_DESCRIPTION"):
            desc += cls.BASE_DESCRIPTION + "\n\n"

        # Read each class's *own* namespace: hasattr() is also true for
        # classes that merely inherit ADDITIONAL_INFO, which would append the
        # same ancestor text once per descendant in the MRO.
        additional_info = []
        for klass in reversed(cls.mro()):
            info = vars(klass).get('ADDITIONAL_INFO')
            if info is not None:
                additional_info.append(info.strip())

        if additional_info:
            desc += "\n\n".join(additional_info) + "\n\n"

        if hasattr(cls, 'BOTTOM_DESCRIPTION'):
            desc += f"{cls.BOTTOM_DESCRIPTION}\n\n"

        desc += footer
        return desc
|
||
from .nodes.audio.Audio_Drums_Analysis_Yvann import Audio_Drums_Analysis_Yvann | ||
from .nodes.audio.Audio_Vocals_Analysis_Yvann import Audio_Vocals_Analysis_Yvann | ||
from .nodes.audio.Audio_Analysis_Yvann import Audio_Analysis_Yvann | ||
|
||
# Node id -> node class.
NODE_CLASS_MAPPINGS = {
    "Audio Reactive IPAdapter | YVANN": Audio_Reactive_IPAdapter_YVANN,
    "Audio Analysis Advanced | YVANN": AudioAnalysis_Advanced_YVANN,
    "Audio Frequency Analysis | YVANN": AudioFrequencyAnalysis_YVANN,
    "Audio Drums Analysis | Yvann": Audio_Drums_Analysis_Yvann,
    "Audio Vocals Analysis | Yvann": Audio_Vocals_Analysis_Yvann,
    "Audio Analysis | Yvann": Audio_Analysis_Yvann,
}

WEB_DIRECTORY = "./web/js"

NODE_DISPLAY_NAME_MAPPINGS = {
    # Keyed by node id, i.e. the keys of NODE_CLASS_MAPPINGS.
    # NOTE(review): presumably the host application looks display names up
    # by node id -- confirm against the host's custom-node documentation.
    "Audio Reactive IPAdapter | YVANN": "Audio Reactive IPAdapter | YVANN",
    "Audio Analysis Advanced | YVANN": "Audio Analysis Advanced | YVANN",
    "Audio Frequency Analysis | YVANN": "Audio Frequency Analysis | YVANN",
    "Audio Drums Analysis | Yvann": "Audio Drums Analysis | Yvann",
    "Audio Vocals Analysis | Yvann": "Audio Vocals Analysis | Yvann",
    "Audio Analysis | Yvann": "Audio Analysis | Yvann",
    # Keyed by class name: Yvann.get_description() looks the display name up
    # with cls.__name__, so without these entries the help title falls back
    # to the raw class name. The first two entries pre-date this change and
    # are kept as-is.
    "AudioAnalysis_Advanced_YVANN": "Audio Analysis Advanced | YVANN",
    "AudioFrequencyAnalysis_YVANN": "Audio Frequency Analysis | YVANN",
    "Audio_Reactive_IPAdapter_YVANN": "Audio Reactive IPAdapter | YVANN",
    "Audio_Drums_Analysis_Yvann": "Audio Drums Analysis | Yvann",
    "Audio_Vocals_Analysis_Yvann": "Audio Vocals Analysis | Yvann",
    "Audio_Analysis_Yvann": "Audio Analysis | Yvann",
}

# Declare every public module-level name, not only the class mapping.
__all__ = ['NODE_CLASS_MAPPINGS', 'NODE_DISPLAY_NAME_MAPPINGS', 'WEB_DIRECTORY']
from aiohttp import web | ||
from server import PromptServer | ||
from pathlib import Path | ||
|
||
# Only register the route when a server instance is available.
if hasattr(PromptServer, "instance"):
    # NOTE: the assets are served from an extra static path so that they are
    # not picked up by the comfy mechanism that loads every script in web.
    PromptServer.instance.app.add_routes([
        web.static(
            "/yvann_web_async",
            (Path(__file__).parent.absolute() / "yvann_web_async").as_posix(),
        ),
    ])
|
||
|
||
|
||
# Store the generated help text on every registered node class that can
# produce one; classes without get_description are left untouched.
for node_name, node_class in NODE_CLASS_MAPPINGS.items():
    if not hasattr(node_class, 'get_description'):
        continue
    desc = node_class.get_description()
    setattr(node_class, 'DESCRIPTION', desc)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
# NOTE: this abstraction allows the documentation to be both centrally managed and inherited | ||
from abc import ABCMeta | ||
# Registry of class attributes to inject, keyed by class name. Defined ahead
# of the code that reads and writes it so the module reads top to bottom.
NODE_CONFIGS = {}


class NodeConfigMeta(type):
    """Metaclass that copies registered configuration onto new classes.

    When a class is created whose name is a key of ``NODE_CONFIGS``, every
    key/value pair registered under that name is set as a class attribute,
    replacing any attribute of the same name defined in the class body.
    The lookup happens at class creation, so a configuration has to be
    registered before the class statement runs.
    """

    def __new__(cls, name, bases, attrs, **kwargs):
        # Forward class keyword arguments (``class A(Base, key=value)``) so
        # that ``__init_subclass__`` hooks keep working under this metaclass.
        new_class = super().__new__(cls, name, bases, attrs, **kwargs)
        for key, value in NODE_CONFIGS.get(name, {}).items():
            setattr(new_class, key, value)
        return new_class


class CombinedMeta(NodeConfigMeta, ABCMeta):
    """Metaclass combining ``NodeConfigMeta`` with ``ABCMeta``."""


def add_node_config(node_name, config):
    """Register ``config`` as the attributes to inject into class ``node_name``.

    Args:
        node_name: Name of the class the configuration applies to.
        config: Mapping of attribute name to attribute value.

    A later registration under the same name replaces the earlier one.
    """
    NODE_CONFIGS[node_name] = config
|
||
# Help text shared by the three audio analysis nodes, which take the same
# inputs. The heading needs a space after the hashes to be an ATX heading.
_AUDIO_ANALYSIS_PARAMETERS = """
## Parameters
- `video_frames`: Input video frames to be processed.
- `audio`: Input audio to be processed.
- `frame_rate`: Frame rate of the video.
- `smoothing_factor`: Smoothing factor for the audio analysis
- `global_intensity`: Global intensity for the audio analysis
"""

add_node_config("Audio_Analysis_Yvann", {
    "BASE_DESCRIPTION": _AUDIO_ANALYSIS_PARAMETERS
})

add_node_config("Audio_Drums_Analysis_Yvann", {
    "BASE_DESCRIPTION": _AUDIO_ANALYSIS_PARAMETERS
})

add_node_config("Audio_Vocals_Analysis_Yvann", {
    "BASE_DESCRIPTION": _AUDIO_ANALYSIS_PARAMETERS
})
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,137 @@ | ||
import torch | ||
import os | ||
import folder_paths | ||
import matplotlib.pyplot as plt | ||
import tempfile | ||
import numpy as np | ||
from PIL import Image | ||
import librosa | ||
from nodes import SaveImage | ||
import pandas as pd | ||
from ... import Yvann | ||
|
||
class AudioNodeBase(Yvann):
    # Category string inherited by every audio node deriving from this base.
    # NOTE(review): presumably the menu path the host UI lists the nodes
    # under -- confirm against the host's node conventions.
    CATEGORY= "👁️ Yvann Nodes/Audio"
|
||
class Audio_Analysis_Yvann(AudioNodeBase):
    """Derive one loudness weight per video frame from an audio clip.

    The audio is cut into as many equal-length windows as there are video
    frames. The RMS energy of each window is smoothed, min-max normalized
    and scaled by ``global_intensity``; the result is returned as a list of
    floats, as a batch of constant-valued masks and as a plot.
    """

    @classmethod
    def INPUT_TYPES(cls):
        """Return the declaration of the node's inputs and their UI ranges."""
        return {
            "required": {
                "video_frames": ("IMAGE",),
                "audio": ("AUDIO",),
                "frame_rate": ("INT", {"default": 30, "min": 1, "max": 60, "step": 1}),
                "smoothing_factor": ("FLOAT", {"default": 0.25, "min": 0.0, "max": 1.0, "step": 0.01}),
                "global_intensity": ("FLOAT", {"default": 0.0, "min": -1.0, "max": 1.0, "step": 0.01}),
            }
        }

    RETURN_TYPES = ("AUDIO", "FLOAT", "MASK", "IMAGE")
    RETURN_NAMES = ("Audio", "Audio Weights", "Audio Masks", "Weights Graph")
    FUNCTION = "process_audio"

    def _get_audio_frame(self, waveform, i, samples_per_frame):
        """Return the samples of window ``i`` as a squeezed NumPy array."""
        start = i * samples_per_frame
        end = start + samples_per_frame
        return waveform[..., start:end].cpu().numpy().squeeze()

    def _rms_energy(self, waveform, num_frames, samples_per_frame):
        """Return the RMS energy of each of the ``num_frames`` windows.

        All channels of a window contribute to its single energy value. An
        empty window (``samples_per_frame == 0``) yields 0.0 instead of NaN.
        """
        # Convert once instead of once per window.
        samples = waveform.cpu().numpy()

        def window_rms(index):
            start = index * samples_per_frame
            window = samples[..., start:start + samples_per_frame]
            if window.size == 0:
                return 0.0
            return np.sqrt(np.mean(window ** 2))

        return np.array([window_rms(i) for i in range(num_frames)])

    def _smooth_weights(self, weights, smoothing_factor):
        """Smooth ``weights`` with a moving average; length is preserved.

        Factors of 0.01 or less disable smoothing. The window is
        ``smoothing_factor * 50`` samples long, at least 3 and never longer
        than ``weights`` itself.
        """
        if smoothing_factor <= 0.01 or len(weights) == 0:
            return weights
        kernel_size = max(3, int(smoothing_factor * 50))  # Ensure minimum kernel size of 3
        # mode='same' returns max(len(weights), len(kernel)) samples, so a
        # kernel longer than the clip would change the number of weights.
        kernel_size = min(kernel_size, len(weights))
        kernel = np.ones(kernel_size) / kernel_size
        return np.convolve(weights, kernel, mode='same')

    def _normalize_weights(self, weights):
        """Min-max scale ``weights`` to [0, 1]; a constant input maps to zeros."""
        min_val, max_val = np.min(weights), np.max(weights)
        if max_val > min_val:
            return (weights - min_val) / (max_val - min_val)
        return np.zeros_like(weights)

    def adjust_weights(self, weights, global_intensity):
        """Scale ``weights`` by ``1 + global_intensity / 2``, clamp at 0, round to 3 places."""
        factor = 1 + (global_intensity * 0.5)
        adjusted_weights = np.maximum(weights * factor, 0)
        return np.round(adjusted_weights, 3)

    def generate_masks(self, input_values, width, height):
        """Return a ``(len(values), height, width)`` float32 tensor of flat masks.

        Mask ``k`` is filled with value ``k`` of ``input_values``, which may
        be a single number, a pandas Series, a list (a list of lists is
        flattened) or another sequence of numbers such as a NumPy array.
        """
        # Ensure input_values is a flat sequence
        if isinstance(input_values, (float, int)):
            input_values = [input_values]
        elif isinstance(input_values, pd.Series):
            input_values = input_values.tolist()
        elif isinstance(input_values, list) and all(isinstance(item, list) for item in input_values):
            input_values = [item for sublist in input_values for item in sublist]

        # Build the whole batch at once instead of one tensor per value;
        # this also yields an empty batch rather than failing on no values.
        values = torch.as_tensor(np.asarray(input_values, dtype=np.float32).reshape(-1))
        return values[:, None, None].repeat(1, height, width)

    def _render_weights_graph(self, audio_weights, num_frames):
        """Plot the weights per frame and return the plot as an image tensor.

        Returns a float tensor of shape ``(1, H, W, 3)`` with values in
        [0, 1], where H and W are the pixel size of the rendered figure.
        """
        # Function-scope import: keeps this block self-contained.
        import io

        frames = list(range(1, num_frames + 1))
        buffer = io.BytesIO()
        fig = plt.figure(figsize=(10, 6))
        try:
            plt.plot(frames, audio_weights, label='Audio Weights', color='blue')
            plt.xlabel('Frame Number')
            plt.ylabel('Normalized Weights')
            plt.title('Normalized Weights for Audio (RMS Energy)')
            plt.legend()
            plt.grid(True)
            # Render into memory: no temporary file is left behind on disk.
            plt.savefig(buffer, format='png')
        finally:
            plt.close(fig)

        buffer.seek(0)
        with Image.open(buffer) as plot_image:
            pixels = np.array(plot_image.convert("RGB"))
        return torch.tensor(pixels).unsqueeze(0).float() / 255.0

    def process_audio(self, audio, video_frames, frame_rate, smoothing_factor, global_intensity):
        """Compute per-frame audio weights, masks and a plot of the weights.

        Args:
            audio: Dict with a ``'waveform'`` tensor and a ``'sample_rate'``.
            video_frames: 4-D tensor; its first three dimensions are read as
                frame count, height and width.
            frame_rate: Passed through in the returned audio dict; it does
                not influence the analysis.
            smoothing_factor: Strength of the moving-average smoothing.
            global_intensity: Scales the weights by ``1 + global_intensity / 2``.

        Returns:
            tuple: ``(audio dict, list of weights, mask batch, plot image)``.

        Raises:
            ValueError: If ``video_frames`` contains no frames.
        """
        waveform = audio['waveform']
        sample_rate = audio['sample_rate']

        num_frames, height, width, _ = video_frames.shape
        if num_frames == 0:
            raise ValueError("video_frames is empty; at least one frame is required")

        if waveform.dim() == 3:
            # Use the first batch entry; squeeze(0) only worked for batch size 1.
            waveform = waveform[0]
        if waveform.dim() == 1:
            waveform = waveform.unsqueeze(0)  # Add channel dimension if mono
        if waveform.shape[0] == 1:
            # Duplicate mono to stereo; other channel counts are left as-is
            # instead of being doubled.
            waveform = waveform.repeat(2, 1)

        total_samples = waveform.shape[-1]
        # Samples beyond num_frames * samples_per_frame are not analysed.
        samples_per_frame = total_samples // num_frames

        # NOTE(review): the returned waveform is 2-D (channels, samples);
        # downstream consumers may expect the batch dimension -- confirm.
        processed_audio = {
            'waveform': waveform,
            'sample_rate': sample_rate,
            'frame_rate': frame_rate
        }

        audio_weights = self._rms_energy(waveform, num_frames, samples_per_frame)
        audio_weights = self._smooth_weights(audio_weights, max(0.01, smoothing_factor))
        audio_weights = self._normalize_weights(audio_weights)
        audio_weights = [round(float(weight), 3) for weight in audio_weights]

        audio_weights = self.adjust_weights(np.array(audio_weights), global_intensity)

        weights_graph = self._render_weights_graph(audio_weights, num_frames)

        # Generate masks from audio weights
        audio_masks = self.generate_masks(audio_weights, width, height)

        return (
            processed_audio,
            audio_weights.tolist(),
            audio_masks,
            weights_graph
        )
Oops, something went wrong.