Commit
Save before remove weight algo choice from audio reactive node
yvann-ba committed Sep 5, 2024
1 parent e2d51e0 commit 4bb61c6
Showing 3 changed files with 202 additions and 15 deletions.
6 changes: 3 additions & 3 deletions __init__.py
@@ -1,15 +1,15 @@
from .nodes.audio.AudioAnalysis_YVANN import AudioAnalysis_YVANN
from .nodes.audio.Audio_Reactive_IPAdapter_YVANN import Audio_Reactive_IPAdapter_YVANN
-from .nodes.audio.AudioAnalysis_Advanced_YVANN import AudioAnalysis_Advanced_YVANN
+from .nodes.audio.AudioFrequencyAnalysis_YVANN import AudioFrequencyAnalysis_YVANN

NODE_CLASS_MAPPINGS = {
"Audio Analysis | YVANN": AudioAnalysis_YVANN,
"Audio Reactive IPAdapter | YVANN": Audio_Reactive_IPAdapter_YVANN,
"Audio Analysis Advanced | YVANN": AudioAnalysis_Advanced_YVANN,
"Audio Frequency Analysis | YVANN": AudioFrequencyAnalysis_YVANN,
}

NODE_DISPLAY_NAME_MAPPINGS = {
"AudioAnalysis_YVANN": "Audio Analysis | YVANN",
"Audio Reactive IPAdapter | YVANN": "Audio Reactive IPAdapter | YVANN",
"AudioAnalysis_Advanced_YVANN": "Audio Analysis Advanced | YVANN",
"AudioFrequencyAnalysis_YVANN": "Audio Frequency Analysis | YVANN",
}
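
For context on what these mappings do: ComfyUI imports each package under `custom_nodes/` at startup and merges the `NODE_CLASS_MAPPINGS` and `NODE_DISPLAY_NAME_MAPPINGS` it exposes into a global registry, so renaming a key here renames the node everywhere it appears. A rough sketch of that discovery step (illustrative only, not the actual loader; the real one also handles web extensions and import failures):

```python
# Simplified sketch of ComfyUI-style custom node discovery (illustrative):
# import a custom node package and merge its two mapping dictionaries.
import importlib

ALL_CLASS_MAPPINGS = {}
ALL_DISPLAY_NAMES = {}

def load_custom_node(module_name: str) -> None:
    module = importlib.import_module(module_name)
    # Each package is expected to expose these two dictionaries, as above.
    ALL_CLASS_MAPPINGS.update(getattr(module, "NODE_CLASS_MAPPINGS", {}))
    ALL_DISPLAY_NAMES.update(getattr(module, "NODE_DISPLAY_NAME_MAPPINGS", {}))
```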
28 changes: 16 additions & 12 deletions nodes/audio/AudioAnalysis_YVANN.py
Original file line number Diff line number Diff line change
@@ -7,21 +7,21 @@
from PIL import Image
import librosa

-class AudioAnalysis_YVANN:
+class Audio_Reactive_IPAdapter_YVANN:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"video_frames": ("IMAGE",),
"audio": ("AUDIO",),
"frame_rate": ("FLOAT", {"default": 30, "min": 0.1, "max": 120, "step": 0.1}),
"frame_rate": ("FLOAT",),
"weight_algorithm": (["rms_energy", "amplitude_envelope", "spectral_centroid", "onset_detection", "chroma_features"], {"default": "rms_energy"}),
"smoothing_factor": ("FLOAT", {"default": 0.5, "min": 0.01, "max": 1.0, "step": 0.01}),
}
}

-RETURN_TYPES = ("AUDIO", "STRING", "AUDIO", "STRING", "AUDIO", "STRING", "AUDIO", "STRING", "IMAGE")
-RETURN_NAMES = ("audio", "audio_weights_str", "drums_audio", "drums_weights_str", "vocals_audio", "vocals_weights_str", "bass_audio", "bass_weights_str", "Visual Weights Graph")
+RETURN_TYPES = ("AUDIO", "FLOAT", "AUDIO", "FLOAT", "AUDIO", "FLOAT", "AUDIO", "FLOAT", "IMAGE")
+RETURN_NAMES = ("audio", "Audio Weights", "drums_audio", "Drums Weights", "vocals_audio", "Vocals Weights", "bass_audio", "Bass Weights", "Visual Weights Graph")
FUNCTION = "process_audio"

def download_and_load_model(self):
@@ -106,12 +106,17 @@ def process_audio(self, audio, video_frames, frame_rate, weight_algorithm, smoothing_factor):
total_samples = waveform.shape[-1]
samples_per_frame = total_samples // num_frames

-# Create isolated audio objects for each target
+# Create the isolated_audio entries, with volume normalization
isolated_audio = {}
target_indices = {'drums': 1, 'vocals': 0, 'bass': 2}
for target, index in target_indices.items():
target_waveform = estimates[:, index, :, :] # Shape: (1, 2, num_samples)

+# Normalize the volume
+max_val = torch.max(torch.abs(target_waveform))
+if max_val > 0:
+target_waveform = target_waveform / max_val

isolated_audio[target] = {
'waveform': target_waveform.cpu(), # Move back to CPU
'sample_rate': sample_rate,
@@ -129,11 +134,10 @@ def process_audio(self, audio, video_frames, frame_rate, weight_algorithm, smoothing_factor):
audio_weights = self._smooth_weights(audio_weights, smoothing_factor)
audio_weights = self._normalize_weights(audio_weights)

-audio_weights_str = [f"{i}:({float(weight):.2f})" for i, weight in enumerate(audio_weights)]
+audio_weights = [round(float(weight), 3) for weight in audio_weights]

# Calculate and normalize weights for each isolated audio target
target_weights = {}
-target_weights_str = {}
for target, index in target_indices.items():
target_waveform = isolated_audio[target]['waveform'].squeeze(0)
if weight_algorithm in ['spectral_centroid', 'onset_detection', 'chroma_features']:
@@ -145,7 +149,7 @@ def process_audio(self, audio, video_frames, frame_rate, weight_algorithm, smoothing_factor):
target_weights[target] = self._smooth_weights(target_weights[target], smoothing_factor)
target_weights[target] = self._normalize_weights(target_weights[target])

-target_weights_str[target] = [f"{i}:({float(weight):.2f})" for i, weight in enumerate(target_weights[target])]
+target_weights[target] = [round(float(weight), 3) for weight in target_weights[target]]

# Plot the weights
frames = list(range(1, num_frames + 1))
@@ -176,12 +180,12 @@ def process_audio(self, audio, video_frames, frame_rate, weight_algorithm, smoothing_factor):

return (
audio,
",\n".join(audio_weights_str),
audio_weights,
isolated_audio['drums'],
",\n".join(target_weights_str['drums']),
target_weights['drums'],
isolated_audio['vocals'],
",\n".join(target_weights_str['vocals']),
target_weights['vocals'],
isolated_audio['bass'],
",\n".join(target_weights_str['bass']),
target_weights['bass'],
weights_graph
)
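
The substantive change in this file is the output format: the weight lists are now returned as FLOAT sequences rather than preformatted strings. For a downstream node that still expects the old text form, the deleted formatting is easy to reproduce externally; a small sketch reconstructed from the removed lines above (the `weights_to_string` helper name is hypothetical):

```python
# Reconstructs the old string output from the deleted lines above:
# each frame index paired with its weight, two decimals, comma-newline joined.
def weights_to_string(weights):
    return ",\n".join(f"{i}:({float(w):.2f})" for i, w in enumerate(weights))

# Example: weights_to_string([0.0, 0.37, 1.0]) == "0:(0.00),\n1:(0.37),\n2:(1.00)"
```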
183 changes: 183 additions & 0 deletions nodes/audio/Audio_Reactive_IPAdapter_YVANN.py
@@ -0,0 +1,183 @@
import torch
import os
import folder_paths
import matplotlib.pyplot as plt
import tempfile
import numpy as np
from PIL import Image
import librosa
from nodes import SaveImage

class Audio_Reactive_IPAdapter_YVANN:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"video_frames": ("IMAGE",),
"audio": ("AUDIO",),
"frame_rate": ("INT", {"default": 30.0, "min": 1.0, "max": 60.0, "step": 1.0}),
"weight_algorithm": (["rms_energy", "amplitude_envelope", "spectral_centroid", "onset_detection", "chroma_features"], {"default": "rms_energy"}),
"smoothing_factor": ("FLOAT", {"default": 0.5, "min": 0.01, "max": 1.0, "step": 0.01}),
}
}

RETURN_TYPES = ("FLOAT", "AUDIO", "FLOAT", "AUDIO", "FLOAT", "IMAGE")
RETURN_NAMES = ("Audio Weights", "Drums Audio", "Drums Weights", "Vocals Audio", "Vocals Weights", "Weights Graph")
FUNCTION = "process_audio"

def download_and_load_model(self):
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
download_path = os.path.join(folder_paths.models_dir, "openunmix")
os.makedirs(download_path, exist_ok=True)

model_file = "umxl.pth"
model_path = os.path.join(download_path, model_file)

if not os.path.exists(model_path):
print("Downloading umxhq model...")
separator = torch.hub.load('sigsep/open-unmix-pytorch', 'umxl', device='cpu')
torch.save(separator.state_dict(), model_path)
print(f"Model saved to: {model_path}")
else:
print(f"Loading model from: {model_path}")
separator = torch.hub.load('sigsep/open-unmix-pytorch', 'umxl', device='cpu')
separator.load_state_dict(torch.load(model_path, map_location='cpu'))

separator = separator.to(device)
separator.eval()

return separator

def _get_audio_frame(self, waveform, i, samples_per_frame):
start = i * samples_per_frame
end = start + samples_per_frame
return waveform[..., start:end].cpu().numpy().squeeze()

def _amplitude_envelope(self, waveform, num_frames, samples_per_frame):
return np.array([np.max(np.abs(self._get_audio_frame(waveform, i, samples_per_frame))) for i in range(num_frames)])

def _rms_energy(self, waveform, num_frames, samples_per_frame):
return np.array([np.sqrt(np.mean(self._get_audio_frame(waveform, i, samples_per_frame)**2)) for i in range(num_frames)])

def _spectral_centroid(self, waveform, num_frames, samples_per_frame, sample_rate):
return np.array([np.mean(librosa.feature.spectral_centroid(y=self._get_audio_frame(waveform, i, samples_per_frame), sr=sample_rate)[0]) for i in range(num_frames)])

def _onset_detection(self, waveform, num_frames, samples_per_frame, sample_rate):
return np.array([np.mean(librosa.onset.onset_strength(y=self._get_audio_frame(waveform, i, samples_per_frame), sr=sample_rate)) for i in range(num_frames)])

def _chroma_features(self, waveform, num_frames, samples_per_frame, sample_rate):
return np.array([np.mean(librosa.feature.chroma_stft(y=self._get_audio_frame(waveform, i, samples_per_frame), sr=sample_rate)) for i in range(num_frames)])

def _smooth_weights(self, weights, smoothing_factor):
kernel_size = max(3, int(smoothing_factor * 50)) # Ensure minimum kernel size of 3
kernel = np.ones(kernel_size) / kernel_size
return np.convolve(weights, kernel, mode='same')

def _normalize_weights(self, weights):
min_val, max_val = np.min(weights), np.max(weights)
if max_val > min_val:
return (weights - min_val) / (max_val - min_val)
else:
return np.zeros_like(weights)
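
# Worked example (illustrative, not part of the commit): weights [0.2, 0.5, 0.8]
# with smoothing_factor 0.06 give kernel_size = max(3, int(0.06 * 50)) = 3; the
# moving average yields [0.233, 0.5, 0.433], and min-max normalization then
# maps this to [0.0, 1.0, 0.75].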

def process_audio(self, audio, video_frames, frame_rate, weight_algorithm, smoothing_factor):
model = self.download_and_load_model()

waveform = audio['waveform']
sample_rate = audio['sample_rate']

num_frames, height, width, _ = video_frames.shape

if waveform.dim() == 3:
waveform = waveform.squeeze(0)
if waveform.dim() == 1:
waveform = waveform.unsqueeze(0) # Add channel dimension if mono
if waveform.shape[0] != 2:
waveform = waveform.repeat(2, 1) # Duplicate mono to stereo if necessary

waveform = waveform.unsqueeze(0)

# Determine the device
device = next(model.parameters()).device
waveform = waveform.to(device)

estimates = model(waveform)

# Compute normalized audio weights for each frame
total_samples = waveform.shape[-1]
samples_per_frame = total_samples // num_frames

# Create isolated audio objects for each target
isolated_audio = {}
target_indices = {'drums': 1, 'vocals': 0}
for target, index in target_indices.items():
target_waveform = estimates[:, index, :, :] # Shape: (1, 2, num_samples)

isolated_audio[target] = {
'waveform': target_waveform.cpu(), # Move back to CPU
'sample_rate': sample_rate,
'frame_rate': frame_rate
}

# Apply the selected weight algorithm
weight_function = getattr(self, f"_{weight_algorithm}")
if weight_algorithm in ['spectral_centroid', 'onset_detection', 'chroma_features']:
audio_weights = weight_function(waveform.squeeze(0), num_frames, samples_per_frame, sample_rate)
else:
audio_weights = weight_function(waveform.squeeze(0), num_frames, samples_per_frame)

# Apply smoothing to audio weights
audio_weights = self._smooth_weights(audio_weights, smoothing_factor)
audio_weights = self._normalize_weights(audio_weights)

audio_weights = [round(float(weight), 3) for weight in audio_weights]

# Calculate and normalize weights for each isolated audio target
target_weights = {}
for target, index in target_indices.items():
target_waveform = isolated_audio[target]['waveform'].squeeze(0)
if weight_algorithm in ['spectral_centroid', 'onset_detection', 'chroma_features']:
target_weights[target] = weight_function(target_waveform, num_frames, samples_per_frame, sample_rate)
else:
target_weights[target] = weight_function(target_waveform, num_frames, samples_per_frame)

# Apply smoothing to target weights
target_weights[target] = self._smooth_weights(target_weights[target], smoothing_factor)
target_weights[target] = self._normalize_weights(target_weights[target])

target_weights[target] = [round(float(weight), 3) for weight in target_weights[target]]

# Plot the weights
frames = list(range(1, num_frames + 1))
plt.figure(figsize=(10, 6))
plt.plot(frames, audio_weights, label='Audio Weights', color='black')
plt.plot(frames, target_weights['drums'], label='Drums Weights', color='red')
plt.plot(frames, target_weights['vocals'], label='Vocals Weights', color='green')
plt.xlabel('Frame Number')
plt.ylabel('Normalized Weights')
plt.title(f'Normalized Weights for Audio Components ({weight_algorithm})')
plt.legend()
plt.grid(True)

# Save the plot to a temporary file
with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as tmpfile:
plt.savefig(tmpfile, format='png')
tmpfile_path = tmpfile.name
plt.close()

# Load the image from the temporary file and convert to tensor
weights_graph = Image.open(tmpfile_path).convert("RGB")
weights_graph = np.array(weights_graph)
weights_graph = torch.tensor(weights_graph).permute(2, 0, 1).unsqueeze(0).float() / 255.0

# Ensure the tensor has the correct shape [B, H, W, C]
weights_graph = weights_graph.permute(0, 2, 3, 1)

return (
audio_weights,
isolated_audio['drums'],
target_weights['drums'],
isolated_audio['vocals'],
target_weights['vocals'],
weights_graph
)
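
Taken together, the weight pipeline in this new node is: slice the waveform into one chunk per video frame, reduce each chunk to a scalar (RMS energy by default), smooth with a moving average whose width scales with `smoothing_factor`, then min-max normalize to [0, 1]. A self-contained sketch of that math outside ComfyUI (illustrative; a synthetic fade-in tone stands in for the node's separated stems):

```python
# Standalone sketch of the per-frame weight pipeline used above (illustrative):
# per-frame RMS energy -> moving-average smoothing -> min-max normalization.
import numpy as np

def frame_weights(samples, num_frames, smoothing_factor=0.5):
    samples_per_frame = len(samples) // num_frames
    # RMS energy per frame-sized chunk, as in _rms_energy.
    rms = np.array([
        np.sqrt(np.mean(samples[i * samples_per_frame:(i + 1) * samples_per_frame] ** 2))
        for i in range(num_frames)
    ])
    # Moving average, as in _smooth_weights: kernel width grows with the factor.
    kernel_size = max(3, int(smoothing_factor * 50))
    smoothed = np.convolve(rms, np.ones(kernel_size) / kernel_size, mode="same")
    # Min-max normalization to [0, 1], as in _normalize_weights.
    lo, hi = smoothed.min(), smoothed.max()
    return (smoothed - lo) / (hi - lo) if hi > lo else np.zeros_like(smoothed)

# A 4-second 440 Hz tone with a linear fade-in, mapped onto 120 video frames:
sr = 44100
t = np.linspace(0, 4, 4 * sr, endpoint=False)
signal = np.sin(2 * np.pi * 440 * t) * np.linspace(0, 1, t.size)
print(frame_weights(signal, num_frames=120)[:5])  # weights climb with the fade-in
```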
