From 4bb61c64ea872fa8d1b3f90a69e3141031ed9b9b Mon Sep 17 00:00:00 2001
From: Yvann
Date: Thu, 5 Sep 2024 22:44:32 +0200
Subject: [PATCH] Save before remove weight algo choice from audio reactive node

---
 __init__.py                                   |   6 +-
 nodes/audio/AudioAnalysis_YVANN.py            |  28 +--
 nodes/audio/Audio_Reactive_IPAdapter_YVANN.py | 183 ++++++++++++++++++
 3 files changed, 202 insertions(+), 15 deletions(-)
 create mode 100644 nodes/audio/Audio_Reactive_IPAdapter_YVANN.py

diff --git a/__init__.py b/__init__.py
index 1aa6424..41944b4 100644
--- a/__init__.py
+++ b/__init__.py
@@ -1,15 +1,15 @@
-from .nodes.audio.AudioAnalysis_YVANN import AudioAnalysis_YVANN
+from .nodes.audio.Audio_Reactive_IPAdapter_YVANN import Audio_Reactive_IPAdapter_YVANN
 from .nodes.audio.AudioAnalysis_Advanced_YVANN import AudioAnalysis_Advanced_YVANN
 from .nodes.audio.AudioFrequencyAnalysis_YVANN import AudioFrequencyAnalysis_YVANN
 
 NODE_CLASS_MAPPINGS = {
-    "Audio Analysis | YVANN": AudioAnalysis_YVANN,
+    "Audio Reactive IPAdapter | YVANN": Audio_Reactive_IPAdapter_YVANN,
     "Audio Analysis Advanced | YVANN": AudioAnalysis_Advanced_YVANN,
     "Audio Frequency Analysis | YVANN": AudioFrequencyAnalysis_YVANN,
 }
 
 NODE_DISPLAY_NAME_MAPPINGS = {
-    "AudioAnalysis_YVANN": "Audio Analysis | YVANN",
+    "Audio Reactive IPAdapter | YVANN": "Audio Reactive IPAdapter | YVANN",
     "AudioAnalysis_Advanced_YVANN": "Audio Analysis Advanced | YVANN",
     "AudioFrequencyAnalysis_YVANN": "Audio Frequency Analysis | YVANN",
 }
diff --git a/nodes/audio/AudioAnalysis_YVANN.py b/nodes/audio/AudioAnalysis_YVANN.py
index c4786c5..b07b1b5 100644
--- a/nodes/audio/AudioAnalysis_YVANN.py
+++ b/nodes/audio/AudioAnalysis_YVANN.py
@@ -7,21 +7,21 @@
 from PIL import Image
 import librosa
 
-class AudioAnalysis_YVANN:
+class Audio_Reactive_IPAdapter_YVANN:
     @classmethod
     def INPUT_TYPES(cls):
         return {
             "required": {
                 "video_frames": ("IMAGE",),
                 "audio": ("AUDIO",),
-                "frame_rate": ("FLOAT", {"default": 30, "min": 0.1, "max": 120, "step": 0.1}),
+                "frame_rate": ("FLOAT",),
                 "weight_algorithm": (["rms_energy", "amplitude_envelope", "spectral_centroid", "onset_detection", "chroma_features"], {"default": "rms_energy"}),
                 "smoothing_factor": ("FLOAT", {"default": 0.5, "min": 0.01, "max": 1.0, "step": 0.01}),
             }
         }
 
-    RETURN_TYPES = ("AUDIO", "STRING", "AUDIO", "STRING", "AUDIO", "STRING", "AUDIO", "STRING", "IMAGE")
-    RETURN_NAMES = ("audio", "audio_weights_str", "drums_audio", "drums_weights_str", "vocals_audio", "vocals_weights_str", "bass_audio", "bass_weights_str", "Visual Weights Graph")
+    RETURN_TYPES = ("AUDIO", "FLOAT", "AUDIO", "FLOAT", "AUDIO", "FLOAT", "AUDIO", "FLOAT", "IMAGE")
+    RETURN_NAMES = ("audio", "Audio Weights", "drums_audio", "Drums Weights", "vocals_audio", "Vocals Weights", "bass_audio", "Bass Weights", "Visual Weights Graph")
     FUNCTION = "process_audio"
 
     def download_and_load_model(self):
@@ -106,12 +106,17 @@ def process_audio(self, audio, video_frames, frame_rate, weight_algorithm, smoot
         total_samples = waveform.shape[-1]
         samples_per_frame = total_samples // num_frames
 
-        # Create isolated audio objects for each target
+        # Create isolated audio objects for each target, with volume normalization
         isolated_audio = {}
         target_indices = {'drums': 1, 'vocals': 0, 'bass': 2}
         for target, index in target_indices.items():
             target_waveform = estimates[:, index, :, :]  # Shape: (1, 2, num_samples)
 
+            # Normalize the volume
+            max_val = torch.max(torch.abs(target_waveform))
+            if max_val > 0:
+                target_waveform = target_waveform / max_val
+
             isolated_audio[target] = {
                 'waveform': target_waveform.cpu(),  # Move back to CPU
                 'sample_rate': sample_rate,
@@ -129,11 +134,10 @@ def process_audio(self, audio, video_frames, frame_rate, weight_algorithm, smoot
         audio_weights = self._smooth_weights(audio_weights, smoothing_factor)
         audio_weights = self._normalize_weights(audio_weights)
 
-        audio_weights_str = [f"{i}:({float(weight):.2f})" for i, weight in enumerate(audio_weights)]
+        audio_weights = [round(float(weight), 3) for weight in audio_weights]
 
         # Calculate and normalize weights for each isolated audio target
         target_weights = {}
-        target_weights_str = {}
         for target, index in target_indices.items():
             target_waveform = isolated_audio[target]['waveform'].squeeze(0)
             if weight_algorithm in ['spectral_centroid', 'onset_detection', 'chroma_features']:
@@ -145,7 +149,7 @@ def process_audio(self, audio, video_frames, frame_rate, weight_algorithm, smoot
             target_weights[target] = self._smooth_weights(target_weights[target], smoothing_factor)
             target_weights[target] = self._normalize_weights(target_weights[target])
 
-            target_weights_str[target] = [f"{i}:({float(weight):.2f})" for i, weight in enumerate(target_weights[target])]
+            target_weights[target] = [round(float(weight), 3) for weight in target_weights[target]]
 
         # Plot the weights
         frames = list(range(1, num_frames + 1))
@@ -176,12 +180,12 @@ def process_audio(self, audio, video_frames, frame_rate, weight_algorithm, smoot
 
         return (
             audio,
-            ",\n".join(audio_weights_str),
+            audio_weights,
             isolated_audio['drums'],
-            ",\n".join(target_weights_str['drums']),
+            target_weights['drums'],
             isolated_audio['vocals'],
-            ",\n".join(target_weights_str['vocals']),
+            target_weights['vocals'],
             isolated_audio['bass'],
-            ",\n".join(target_weights_str['bass']),
+            target_weights['bass'],
             weights_graph
         )
diff --git a/nodes/audio/Audio_Reactive_IPAdapter_YVANN.py b/nodes/audio/Audio_Reactive_IPAdapter_YVANN.py
new file mode 100644
index 0000000..5918a72
--- /dev/null
+++ b/nodes/audio/Audio_Reactive_IPAdapter_YVANN.py
@@ -0,0 +1,183 @@
+import torch
+import os
+import folder_paths
+import matplotlib.pyplot as plt
+import tempfile
+import numpy as np
+from PIL import Image
+import librosa
+from nodes import SaveImage
+
+class Audio_Reactive_IPAdapter_YVANN:
+    @classmethod
+    def INPUT_TYPES(cls):
+        return {
+            "required": {
+                "video_frames": ("IMAGE",),
+                "audio": ("AUDIO",),
+                "frame_rate": ("INT", {"default": 30, "min": 1, "max": 60, "step": 1}),
+                "weight_algorithm": (["rms_energy", "amplitude_envelope", "spectral_centroid", "onset_detection", "chroma_features"], {"default": "rms_energy"}),
+                "smoothing_factor": ("FLOAT", {"default": 0.5, "min": 0.01, "max": 1.0, "step": 0.01}),
+            }
+        }
+
+    RETURN_TYPES = ("FLOAT", "AUDIO", "FLOAT", "AUDIO", "FLOAT", "IMAGE")
+    RETURN_NAMES = ("Audio Weights", "Drums Audio", "Drums Weights", "Vocals Audio", "Vocals Weights", "Weights Graph")
+    FUNCTION = "process_audio"
+
+    def download_and_load_model(self):
+        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        download_path = os.path.join(folder_paths.models_dir, "openunmix")
+        os.makedirs(download_path, exist_ok=True)
+
+        model_file = "umxl.pth"
+        model_path = os.path.join(download_path, model_file)
+
+        if not os.path.exists(model_path):
+            print("Downloading umxl model...")
+            separator = torch.hub.load('sigsep/open-unmix-pytorch', 'umxl', device='cpu')
+            torch.save(separator.state_dict(), model_path)
+            print(f"Model saved to: {model_path}")
+        else:
+            print(f"Loading model from: {model_path}")
+            separator = torch.hub.load('sigsep/open-unmix-pytorch', 'umxl', device='cpu')
+            separator.load_state_dict(torch.load(model_path, map_location='cpu'))
+
+        separator = separator.to(device)
+        separator.eval()
+
+        return separator
+
+    def _get_audio_frame(self, waveform, i, samples_per_frame):
+        start = i * samples_per_frame
+        end = start + samples_per_frame
+        return waveform[..., start:end].cpu().numpy().squeeze()
+
+    def _amplitude_envelope(self, waveform, num_frames, samples_per_frame):
+        return np.array([np.max(np.abs(self._get_audio_frame(waveform, i, samples_per_frame))) for i in range(num_frames)])
+
+    def _rms_energy(self, waveform, num_frames, samples_per_frame):
+        return np.array([np.sqrt(np.mean(self._get_audio_frame(waveform, i, samples_per_frame)**2)) for i in range(num_frames)])
+
+    def _spectral_centroid(self, waveform, num_frames, samples_per_frame, sample_rate):
+        return np.array([np.mean(librosa.feature.spectral_centroid(y=self._get_audio_frame(waveform, i, samples_per_frame), sr=sample_rate)[0]) for i in range(num_frames)])
+
+    def _onset_detection(self, waveform, num_frames, samples_per_frame, sample_rate):
+        return np.array([np.mean(librosa.onset.onset_strength(y=self._get_audio_frame(waveform, i, samples_per_frame), sr=sample_rate)) for i in range(num_frames)])
+
+    def _chroma_features(self, waveform, num_frames, samples_per_frame, sample_rate):
+        return np.array([np.mean(librosa.feature.chroma_stft(y=self._get_audio_frame(waveform, i, samples_per_frame), sr=sample_rate)) for i in range(num_frames)])
+
+    def _smooth_weights(self, weights, smoothing_factor):
+        kernel_size = max(3, int(smoothing_factor * 50))  # Ensure minimum kernel size of 3
+        kernel = np.ones(kernel_size) / kernel_size
+        return np.convolve(weights, kernel, mode='same')
+
+    def _normalize_weights(self, weights):
+        min_val, max_val = np.min(weights), np.max(weights)
+        if max_val > min_val:
+            return (weights - min_val) / (max_val - min_val)
+        else:
+            return np.zeros_like(weights)
+
+    def process_audio(self, audio, video_frames, frame_rate, weight_algorithm, smoothing_factor):
+        model = self.download_and_load_model()
+
+        waveform = audio['waveform']
+        sample_rate = audio['sample_rate']
+
+        num_frames, height, width, _ = video_frames.shape
+
+        if waveform.dim() == 3:
+            waveform = waveform.squeeze(0)
+        if waveform.dim() == 1:
+            waveform = waveform.unsqueeze(0)  # Add channel dimension if mono
+        if waveform.shape[0] != 2:
+            waveform = waveform.repeat(2, 1)  # Duplicate mono to stereo if necessary
+
+        waveform = waveform.unsqueeze(0)
+
+        # Determine the device
+        device = next(model.parameters()).device
+        waveform = waveform.to(device)
+
+        estimates = model(waveform)
+
+        # Compute normalized audio weights for each frame
+        total_samples = waveform.shape[-1]
+        samples_per_frame = total_samples // num_frames
+
+        # Create isolated audio objects for each target
+        isolated_audio = {}
+        target_indices = {'drums': 1, 'vocals': 0}
+        for target, index in target_indices.items():
+            target_waveform = estimates[:, index, :, :]  # Shape: (1, 2, num_samples)
+
+            isolated_audio[target] = {
+                'waveform': target_waveform.cpu(),  # Move back to CPU
+                'sample_rate': sample_rate,
+                'frame_rate': frame_rate
+            }
+
+        # Apply the selected weight algorithm
+        weight_function = getattr(self, f"_{weight_algorithm}")
+        if weight_algorithm in ['spectral_centroid', 'onset_detection', 'chroma_features']:
+            audio_weights = weight_function(waveform.squeeze(0), num_frames, samples_per_frame, sample_rate)
+        else:
+            audio_weights = weight_function(waveform.squeeze(0), num_frames, samples_per_frame)
+
+        # Apply smoothing to audio weights
weights + audio_weights = self._smooth_weights(audio_weights, smoothing_factor) + audio_weights = self._normalize_weights(audio_weights) + + audio_weights = [round(float(weight), 3) for weight in audio_weights] + + # Calculate and normalize weights for each isolated audio target + target_weights = {} + for target, index in target_indices.items(): + target_waveform = isolated_audio[target]['waveform'].squeeze(0) + if weight_algorithm in ['spectral_centroid', 'onset_detection', 'chroma_features']: + target_weights[target] = weight_function(target_waveform, num_frames, samples_per_frame, sample_rate) + else: + target_weights[target] = weight_function(target_waveform, num_frames, samples_per_frame) + + # Apply smoothing to target weights + target_weights[target] = self._smooth_weights(target_weights[target], smoothing_factor) + target_weights[target] = self._normalize_weights(target_weights[target]) + + target_weights[target] = [round(float(weight), 3) for weight in target_weights[target]] + + # Plot the weights + frames = list(range(1, num_frames + 1)) + plt.figure(figsize=(10, 6)) + plt.plot(frames, audio_weights, label='Audio Weights', color='black') + plt.plot(frames, target_weights['drums'], label='Drums Weights', color='red') + plt.plot(frames, target_weights['vocals'], label='Vocals Weights', color='green') + plt.xlabel('Frame Number') + plt.ylabel('Normalized Weights') + plt.title(f'Normalized Weights for Audio Components ({weight_algorithm})') + plt.legend() + plt.grid(True) + + # Save the plot to a temporary file + with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as tmpfile: + plt.savefig(tmpfile, format='png') + tmpfile_path = tmpfile.name + plt.close() + + # Load the image from the temporary file and convert to tensor + weights_graph = Image.open(tmpfile_path).convert("RGB") + weights_graph = np.array(weights_graph) + weights_graph = torch.tensor(weights_graph).permute(2, 0, 1).unsqueeze(0).float() / 255.0 + + # Ensure the tensor has the correct shape [B, H, W, C] + weights_graph = weights_graph.permute(0, 2, 3, 1) + + return ( + audio_weights, + isolated_audio['drums'], + target_weights['drums'], + isolated_audio['vocals'], + target_weights['vocals'], + weights_graph + )