Commit 173ae74
Add Help PopUp thanks to RyanOnTheInside, KJNodes and MTB, add mask features to all audio nodes
yvann-ba committed Sep 6, 2024
1 parent 4bb61c6 commit 173ae74
Showing 16 changed files with 1,265 additions and 193 deletions.
84 changes: 74 additions & 10 deletions __init__.py
@@ -1,17 +1,81 @@
from .nodes.audio.Audio_Reactive_IPAdapter_YVANN import Audio_Reactive_IPAdapter_YVANN
from .nodes.audio.AudioAnalysis_Advanced_YVANN import AudioAnalysis_Advanced_YVANN
from .nodes.audio.AudioFrequencyAnalysis_YVANN import AudioFrequencyAnalysis_YVANN
from .node_configs import CombinedMeta
from collections import OrderedDict

# Credit to RyanOnTheInside, KJNodes, MTB, and Akatz; their work helped me a lot


# Allows for central management and inheritance of class variables for help documentation
class Yvann(metaclass=CombinedMeta):
    @classmethod
    def get_description(cls):
        display_name = NODE_DISPLAY_NAME_MAPPINGS.get(cls.__name__, cls.__name__)
        footer = "For more information, visit [RyanOnTheInside GitHub](https://github.com/ryanontheinside).\n\n"
        footer += "For tutorials and example workflows visit [RyanOnTheInside Civitai](https://civitai.com/user/ryanontheinside).\n\n"
        display_name = display_name.replace(" | Yvann", "")

        desc = f"# {display_name}\n\n"

        # An explicit DESCRIPTION short-circuits the assembled documentation
        if hasattr(cls, 'DESCRIPTION'):
            desc += f"{cls.DESCRIPTION}\n\n{footer}"
            return desc

        if hasattr(cls, 'TOP_DESCRIPTION'):
            desc += f"### {cls.TOP_DESCRIPTION}\n\n"

        if hasattr(cls, "BASE_DESCRIPTION"):
            desc += cls.BASE_DESCRIPTION + "\n\n"

        # Walk the MRO from base to subclass so inherited ADDITIONAL_INFO prints first
        additional_info = OrderedDict()
        for c in cls.mro()[::-1]:
            if hasattr(c, 'ADDITIONAL_INFO'):
                info = c.ADDITIONAL_INFO.strip()
                additional_info[c.__name__] = info

        if additional_info:
            desc += "\n\n".join(additional_info.values()) + "\n\n"

        if hasattr(cls, 'BOTTOM_DESCRIPTION'):
            desc += f"{cls.BOTTOM_DESCRIPTION}\n\n"

        desc += footer
        return desc

from .nodes.audio.Audio_Drums_Analysis_Yvann import Audio_Drums_Analysis_Yvann
from .nodes.audio.Audio_Vocals_Analysis_Yvann import Audio_Vocals_Analysis_Yvann
from .nodes.audio.Audio_Analysis_Yvann import Audio_Analysis_Yvann

NODE_CLASS_MAPPINGS = {
    "Audio Reactive IPAdapter | YVANN": Audio_Reactive_IPAdapter_YVANN,
    "Audio Analysis Advanced | YVANN": AudioAnalysis_Advanced_YVANN,
    "Audio Frequency Analysis | YVANN": AudioFrequencyAnalysis_YVANN,
    "Audio Drums Analysis | Yvann": Audio_Drums_Analysis_Yvann,
    "Audio Vocals Analysis | Yvann": Audio_Vocals_Analysis_Yvann,
    "Audio Analysis | Yvann": Audio_Analysis_Yvann,
}

WEB_DIRECTORY = "./web/js"

# Keyed by class __name__ so get_description() can resolve display names
NODE_DISPLAY_NAME_MAPPINGS = {
    "Audio_Reactive_IPAdapter_YVANN": "Audio Reactive IPAdapter | YVANN",
    "AudioAnalysis_Advanced_YVANN": "Audio Analysis Advanced | YVANN",
    "AudioFrequencyAnalysis_YVANN": "Audio Frequency Analysis | YVANN",
    "Audio_Drums_Analysis_Yvann": "Audio Drums Analysis | Yvann",
    "Audio_Vocals_Analysis_Yvann": "Audio Vocals Analysis | Yvann",
    "Audio_Analysis_Yvann": "Audio Analysis | Yvann",
}

__all__ = ['NODE_CLASS_MAPPINGS']
from aiohttp import web
from server import PromptServer
from pathlib import Path

if hasattr(PromptServer, "instance"):

    # NOTE: we add an extra static path to avoid comfy mechanism
    # that loads every script in web.

    PromptServer.instance.app.add_routes(
        [web.static("/yvann_web_async", (Path(__file__).parent.absolute() / "yvann_web_async").as_posix())]
    )



for node_name, node_class in NODE_CLASS_MAPPINGS.items():
    if hasattr(node_class, 'get_description'):
        desc = node_class.get_description()
        node_class.DESCRIPTION = desc
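The loop above is what powers the new help popup: at import time, every node class gains a DESCRIPTION assembled by get_description() from attributes that CombinedMeta injects (see node_configs.py below). A minimal standalone sketch of that injection mechanism, with hypothetical names (DEMO_CONFIGS, DemoMeta, DemoNode) in place of the real ones:

# Standalone sketch (hypothetical names), mirroring NodeConfigMeta's injection
DEMO_CONFIGS = {"DemoNode": {"BASE_DESCRIPTION": "## Parameters\n- `audio`: input audio"}}

class DemoMeta(type):
    def __new__(cls, name, bases, attrs):
        new_class = super().__new__(cls, name, bases, attrs)
        # Attach every config entry registered for this class name as a class attribute
        for key, value in DEMO_CONFIGS.get(name, {}).items():
            setattr(new_class, key, value)
        return new_class

class DemoNode(metaclass=DemoMeta):
    pass

print(DemoNode.BASE_DESCRIPTION)  # -> ## Parameters\n- `audio`: input audio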
50 changes: 50 additions & 0 deletions node_configs.py
@@ -0,0 +1,50 @@
# NOTE: this abstraction allows for both the documentation to be centrally managed and inherited
from abc import ABCMeta

class NodeConfigMeta(type):
    def __new__(cls, name, bases, attrs):
        new_class = super().__new__(cls, name, bases, attrs)
        if name in NODE_CONFIGS:
            for key, value in NODE_CONFIGS[name].items():
                setattr(new_class, key, value)
        return new_class

class CombinedMeta(NodeConfigMeta, ABCMeta):
    pass

def add_node_config(node_name, config):
    NODE_CONFIGS[node_name] = config

NODE_CONFIGS = {}

add_node_config("Audio_Analysis_Yvann", {
"BASE_DESCRIPTION": """
##Parameters
- `video_frames`: Input video frames to be processed.
- `audio`: Input audio to be processed.
- `frame_rate`: Frame rate of the video.
- `smoothing_factor`: Smoothing factor for the audio analysis
- `global_intensity`: Global intensity for the audio analysis
"""
})

add_node_config("Audio_Drums_Analysis_Yvann", {
"BASE_DESCRIPTION": """
##Parameters
- `video_frames`: Input video frames to be processed.
- `audio`: Input audio to be processed.
- `frame_rate`: Frame rate of the video.
- `smoothing_factor`: Smoothing factor for the audio analysis
- `global_intensity`: Global intensity for the audio analysis
"""
})

add_node_config("Audio_Vocals_Analysis_Yvann", {
"BASE_DESCRIPTION": """
##Parameters
- `video_frames`: Input video frames to be processed.
- `audio`: Input audio to be processed.
- `frame_rate`: Frame rate of the video.
- `smoothing_factor`: Smoothing factor for the audio analysis
- `global_intensity`: Global intensity for the audio analysis
"""
})
137 changes: 137 additions & 0 deletions nodes/audio/Audio_Analysis_Yvann.py
@@ -0,0 +1,137 @@
import torch
import os
import folder_paths
import matplotlib.pyplot as plt
import tempfile
import numpy as np
from PIL import Image
import librosa
from nodes import SaveImage
import pandas as pd
from ... import Yvann

class AudioNodeBase(Yvann):
    CATEGORY = "👁️ Yvann Nodes/Audio"

class Audio_Analysis_Yvann(AudioNodeBase):
    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "video_frames": ("IMAGE",),
                "audio": ("AUDIO",),
                "frame_rate": ("INT", {"default": 30, "min": 1, "max": 60, "step": 1}),
                "smoothing_factor": ("FLOAT", {"default": 0.25, "min": 0.0, "max": 1.0, "step": 0.01}),
                "global_intensity": ("FLOAT", {"default": 0.0, "min": -1.0, "max": 1.0, "step": 0.01}),
            }
        }

    RETURN_TYPES = ("AUDIO", "FLOAT", "MASK", "IMAGE")
    RETURN_NAMES = ("Audio", "Audio Weights", "Audio Masks", "Weights Graph")
    FUNCTION = "process_audio"

    def _get_audio_frame(self, waveform, i, samples_per_frame):
        # Slice out the samples belonging to video frame i
        start = i * samples_per_frame
        end = start + samples_per_frame
        return waveform[..., start:end].cpu().numpy().squeeze()

    def _rms_energy(self, waveform, num_frames, samples_per_frame):
        # Root-mean-square energy of each per-frame slice
        return np.array([np.sqrt(np.mean(self._get_audio_frame(waveform, i, samples_per_frame) ** 2)) for i in range(num_frames)])

    def _smooth_weights(self, weights, smoothing_factor):
        if smoothing_factor <= 0.01:
            return weights
        kernel_size = max(3, int(smoothing_factor * 50))  # Ensure minimum kernel size of 3
        kernel = np.ones(kernel_size) / kernel_size  # Moving-average kernel
        return np.convolve(weights, kernel, mode='same')

    def _normalize_weights(self, weights):
        min_val, max_val = np.min(weights), np.max(weights)
        if max_val > min_val:
            return (weights - min_val) / (max_val - min_val)
        else:
            return np.zeros_like(weights)

    def adjust_weights(self, weights, global_intensity):
        # Map global_intensity in [-1, 1] to a gain in [0.5, 1.5], clamped at 0
        factor = 1 + (global_intensity * 0.5)
        adjusted_weights = np.maximum(weights * factor, 0)
        adjusted_weights = np.round(adjusted_weights, 3)
        return adjusted_weights

    def generate_masks(self, input_values, width, height):
        # Ensure input_values is a flat list of scalars
        if isinstance(input_values, (float, int)):
            input_values = [input_values]
        elif isinstance(input_values, pd.Series):
            input_values = input_values.tolist()
        elif isinstance(input_values, list) and all(isinstance(item, list) for item in input_values):
            input_values = [item for sublist in input_values for item in sublist]

        # Generate a batch of masks based on the input_values
        masks = []
        for value in input_values:
            # Assuming value is a float between 0 and 1 representing the mask's intensity
            mask = torch.ones((height, width), dtype=torch.float32) * value
            masks.append(mask)
        masks_out = torch.stack(masks, dim=0)

        return masks_out

    def process_audio(self, audio, video_frames, frame_rate, smoothing_factor, global_intensity):
        waveform = audio['waveform']
        sample_rate = audio['sample_rate']

        num_frames, height, width, _ = video_frames.shape

        # Normalize the waveform to shape (2, num_samples)
        if waveform.dim() == 3:
            waveform = waveform.squeeze(0)
        if waveform.dim() == 1:
            waveform = waveform.unsqueeze(0)  # Add channel dimension if mono
        if waveform.shape[0] != 2:
            waveform = waveform.repeat(2, 1)  # Duplicate mono to stereo if necessary

        total_samples = waveform.shape[-1]
        samples_per_frame = total_samples // num_frames

        processed_audio = {
            'waveform': waveform,
            'sample_rate': sample_rate,
            'frame_rate': frame_rate
        }

        audio_weights = self._rms_energy(waveform, num_frames, samples_per_frame)
        audio_weights = self._smooth_weights(audio_weights, max(0.01, smoothing_factor))
        audio_weights = self._normalize_weights(audio_weights)
        audio_weights = [round(float(weight), 3) for weight in audio_weights]

        audio_weights = self.adjust_weights(np.array(audio_weights), global_intensity)

        # Plot the weights
        frames = list(range(1, num_frames + 1))
        plt.figure(figsize=(10, 6))
        plt.plot(frames, audio_weights, label='Audio Weights', color='blue')
        plt.xlabel('Frame Number')
        plt.ylabel('Normalized Weights')
        plt.title('Normalized Weights for Audio (RMS Energy)')
        plt.legend()
        plt.grid(True)

        with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as tmpfile:
            plt.savefig(tmpfile, format='png')
            tmpfile_path = tmpfile.name
        plt.close()

        # Convert the saved plot into an IMAGE tensor in (batch, height, width, channels) layout
        weights_graph = Image.open(tmpfile_path).convert("RGB")
        weights_graph = np.array(weights_graph)
        weights_graph = torch.tensor(weights_graph).permute(2, 0, 1).unsqueeze(0).float() / 255.0
        weights_graph = weights_graph.permute(0, 2, 3, 1)

        # Generate masks from audio weights
        audio_masks = self.generate_masks(audio_weights, width, height)

        return (
            processed_audio,
            audio_weights.tolist(),
            audio_masks,
            weights_graph
        )
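End to end, process_audio reduces each video frame's slice of audio to a single scalar (RMS energy, then moving-average smoothing, min-max normalization, and a global intensity gain), and broadcasts each scalar into a uniform mask. A standalone sketch of the same pipeline on synthetic audio, assuming only numpy and torch:

import numpy as np
import torch

# Synthetic stereo audio: 2 seconds at 44.1 kHz, amplitude-modulated noise
sample_rate, num_frames = 44100, 60
waveform = torch.randn(2, 2 * sample_rate) * torch.linspace(0.1, 1.0, 2 * sample_rate)
samples_per_frame = waveform.shape[-1] // num_frames

# 1. Per-frame RMS energy
weights = np.array([
    np.sqrt(np.mean(waveform[..., i * samples_per_frame:(i + 1) * samples_per_frame].numpy() ** 2))
    for i in range(num_frames)
])

# 2. Moving-average smoothing (kernel size grows with smoothing_factor)
smoothing_factor = 0.25
kernel_size = max(3, int(smoothing_factor * 50))
weights = np.convolve(weights, np.ones(kernel_size) / kernel_size, mode='same')

# 3. Min-max normalization to [0, 1]
weights = (weights - weights.min()) / (weights.max() - weights.min())

# 4. Global intensity gain: [-1, 1] maps to a factor in [0.5, 1.5], clamped at 0
global_intensity = 0.2
weights = np.maximum(weights * (1 + global_intensity * 0.5), 0)

# 5. One uniform mask per frame, shaped (num_frames, height, width)
height, width = 64, 64
masks = torch.stack([torch.full((height, width), float(w)) for w in weights])
print(weights[:5], masks.shape)  # e.g. first five weights, torch.Size([60, 64, 64])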