# test_unet3d.py
import jax
import jax.numpy as jnp
from flax import linen as nn
import numpy as np
from flaxdiff.models.video_unet import FlaxUNet3DConditionModel
import matplotlib.pyplot as plt
import os
# Force the CPU backend for testing.
# JAX reads JAX_PLATFORMS when it is imported, and `import jax` has already
# run at the top of this file, so assigning the environment variable here is
# too late to affect this process. It is kept so that child processes inherit
# it; the running process is configured through jax.config instead, which
# applies as long as no backend has been initialized yet.
os.environ['JAX_PLATFORMS'] = 'cpu'
jax.config.update('jax_platforms', 'cpu')
def _to_displayable_rgb(frame):
    """Min-max scale the first three channels of ``frame`` into [0, 1].

    ``imshow`` clips floating-point RGB data to [0, 1], so unbounded values
    are rescaled before plotting.

    Args:
        frame: Array-like whose last axis holds at least three channels.

    Returns:
        A float32 NumPy array holding the first three channels, scaled so the
        smallest value maps to 0 and the largest to 1.
    """
    rgb = np.asarray(frame, dtype=np.float32)[..., :3]
    lowest = rgb.min()
    span = rgb.max() - lowest
    if span == 0:
        # A constant image has no contrast to stretch; render it as black
        # instead of dividing by zero.
        return np.zeros_like(rgb)
    return (rgb - lowest) / span


def test_unet3d_model():
    """
    Test the FlaxUNet3DConditionModel with a simple random input
    and visualize the output.

    Builds a small model, initializes it on random inputs, runs one forward
    pass, checks that the output shape matches the input shape, and writes a
    side-by-side plot of frame 0 of the input and of the output to
    'unet3d_test_output.png' in the current working directory.

    Returns:
        A ``(model, params)`` tuple: the constructed module and the variables
        returned by ``model.init``.

    Raises:
        AssertionError: If the output shape differs from the input shape.
    """
    # One independent subkey per random consumer. A PRNG key must not be used
    # to draw samples and then be split again.
    sample_key, text_key, init_key = jax.random.split(jax.random.PRNGKey(42), 3)

    # Define model parameters
    model = FlaxUNet3DConditionModel(
        sample_size=32,  # Small sample for testing
        in_channels=4,
        out_channels=4,
        down_block_types=(
            "CrossAttnDownBlock3D",
            "CrossAttnDownBlock3D",
            "CrossAttnDownBlock3D",
            "DownBlock3D",
        ),
        up_block_types=(
            "UpBlock3D",
            "CrossAttnUpBlock3D",
            "CrossAttnUpBlock3D",
            "CrossAttnUpBlock3D",
        ),
        block_out_channels=(32, 64, 64, 64),  # Smaller channels for testing
        layers_per_block=1,
        cross_attention_dim=64,
        attention_head_dim=8,
        dropout=0.0,
        dtype=jnp.float32,
    )

    # Create dummy inputs
    batch_size = 1
    num_frames = 4
    sample = jax.random.normal(
        sample_key,
        shape=(batch_size, num_frames, 32, 32, 4),
        dtype=jnp.float32,
    )
    timestep = jnp.array([0], dtype=jnp.int32)

    # Create dummy text embeddings; the last axis matches cross_attention_dim.
    encoder_hidden_states = jax.random.normal(
        text_key,
        shape=(batch_size, 77, 64),  # 77 is standard for CLIP text tokens
        dtype=jnp.float32,
    )

    # Initialize the model
    params = model.init(init_key, sample, timestep, encoder_hidden_states)

    # Print model summary
    param_count = sum(leaf.size for leaf in jax.tree_util.tree_leaves(params))
    print(f"Model initialized with {param_count:,} parameters")

    # Run a forward pass
    output = model.apply(params, sample, timestep, encoder_hidden_states)
    prediction = output['sample']

    # Print shapes before asserting so they are visible if the check fails
    print(f"Input shape: {sample.shape}")
    print(f"Output shape: {prediction.shape}")

    # in_channels == out_channels, so the output is expected to have the same
    # shape as the input.
    # NOTE(review): assumes the model preserves the input layout - confirm.
    assert prediction.shape == sample.shape, (
        f"Expected output shape {sample.shape}, got {prediction.shape}"
    )

    # Visualize a sample frame from both input and output
    frame_idx = 0
    output_path = 'unet3d_test_output.png'
    fig, axes = plt.subplots(1, 2, figsize=(12, 6))
    panels = (
        (sample, f"Input Sample (Frame {frame_idx})"),
        (prediction, f"Model Output (Frame {frame_idx})"),
    )
    for axis, (volume, title) in zip(axes, panels):
        # Only the first 3 channels are shown, rescaled into imshow's valid
        # float range.
        axis.imshow(_to_displayable_rgb(volume[0, frame_idx]))
        axis.set_title(title)
        axis.axis('off')

    plt.tight_layout()
    plt.savefig(output_path)
    plt.close(fig)
    print(f"Visualization saved to '{output_path}'")

    return model, params
if __name__ == "__main__":
    # Allow running this test directly as a script, without a test runner.
    test_unet3d_model()