diff --git a/stable_warpfusion.ipynb b/stable_warpfusion.ipynb
new file mode 100644
index 00000000..177423ca
--- /dev/null
+++ b/stable_warpfusion.ipynb
@@ -0,0 +1,13393 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "TitleTop"
+ },
+ "source": [
+ "\n",
+ "\n",
+ "# This is a beta of WarpFusion.\n",
+ "#### May produce meh results and not be very stable.\n",
+ "[Local install guide](https://github.com/Sxela/WarpFusion/blob/main/README.md).\n",
+ "Use venv guide for this colab.\n",
+ "\n",
+ "\\\n",
+ "Tutorials by WarpFusion users:\n",
+ "- 05.05.2023, v0.10 [Video to AI Animation Tutorial For Beginners: Stable WarpFusion + Controlnet | MDMZ](https://youtu.be/HkM-7wxtkGA)\n",
+ "- 11.05.2023, v0.11 [How to use Stable Warp Fusion](https://www.youtube.com/watch?v=FxRTEILPCQQ)\n",
+ "\n",
+ "- 13.05.2023, v0.8 [Warp Fusion Local Install Guide (v0.8.6) with Diffusion Demonstration](https://www.youtube.com/watch?v=wqXy_r_9qw8)\n",
+ "\n",
+ "- 14.05.2023, v0.12 [Warp Fusion Alpha Masking Tutorial | Covers Both Auto-Masking and Custom Masking](https://www.youtube.com/watch?v=VMF7L0czyIg)\n",
+ "\n",
+ "- 23.05.2023, v0.12 [STABLE WARPFUSION TUTORIAL - Colab Pro & Local Install](https://www.youtube.com/watch?v=m8xaPnaooyg)\n",
+ "\n",
+ "- 15.06.2023, v0.13 [AI Animation out of Your Video: Stable Warpfusion Guide (Google Colab & Local Intallation)](https://www.youtube.com/watch?v=-B7WtxAAXLg)\n",
+ "\n",
+ "- 17.06.2023, v0.14 [Stable Warpfusion Tutorial: Turn Your Video to an AI Animation](https://www.youtube.com/watch?v=tUHCtQaBWCw)\n",
+ "\n",
+ "- 21.06.2023, v0.14 [Avoiding Common Problems with Stable Warpfusion](https://www.youtube.com/watch?v=GH420ol2sCw)\n",
+ "\n",
+ "- 21.06.2023, v0.15 [Warp Fusion: Step by Step Tutorial](https://www.youtube.com/watch?v=0AT8esyY0Fw)\n",
+ "\n",
+ "- 04.07.2023, v0.15 [Intense AI Video Maker (Stable WarpFusion Tutorial)](https://www.youtube.com/watch?v=mVze7REhjCI&ab_channel=MattWolfe)\n",
+ "\n",
+ "- 15.08.2023, v0.17 [BEST Laptop for AI ( SDXL & Stable Warpfusion ) ft. RTX 4090 - Make AI Art FREE and FAST!](https://www.youtube.com/watch?v=SM0Mxmhfj7A)\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\\\n",
+ "Kudos to my [patreon](https://www.patreon.com/sxela) supporters: \\\n",
+ "**Ced Pakusevskij**, **John Haugeland**, **Fernando Magalhaes**, **Zlata Ponirovskaya**, **Francisco Bucknor**, **Territory Technical**, **Noah Miller**, **Chris Boyle**, **Russ Gilbert**, **Ronny Khalil**, **Michael Carychao**, **Justin Aharoni**, **melt.immersive**, **Chris The Wizard**, **Diego Chavez**, **Willis Hsieh**, **Romuprod**, **HogWorld**, **Diesellord**, **PatchesFlows**, **Hueman Instrument**, **Sam Schrag**, **Laurent Plique**, **noviy maggg**, **Wayne Ellis**, **Richard Brooker**, **Aixsponza**, **Frank Hegedus**, **Shai Ankori**, **Pedro Henrique**, **TaijiNinja**, **Ryry**, **pierre zandrowicz**, **Christian Nielsen**, **Remi**, **Above The Void**, **Kytra@Midnight**, **Justin Bortnick aka RizzleRoc**, **Daniel Bedingfield**, **Marc Donahue**, **Digital Grails**, **Oleg Pashkovsky**, **Russ Creech**, **Kyle Trewartha**, **Ainoise**, **Omar Karim**, **Aether Elf**, **Ken Hill**, **DreamMachineAI**, **Kevin Doan**, **Prime Duo**, **Gili Ben Shahar**, **Ivan Morales**, **Jana Spacelight**, **Hassan Ragab**, **Jack**, **JB**, **JustMaier**, **Marcus**, **Diego Mellado Alarcon**, **Justin Mayers**, **James Stewart**, **ענבר לוי**, **Tyler Bernabe**, **EgorGa**, **Тот Самый Santa**, **levelbevel**, **Luis Gerardo Díaz**, **kc**, **Sofia Riccheri**, **James Gerde**, **Pete Puskas**, **Accounts PS**, **Mike Rowley**, **Hector Gomez Espinosa**, **Jorge Aguiñiga**, **gary chan**, **Aidan Poole**, **Renz**, **mike waller**, **EdwinAi**, **Itay Schiff**, **Eran Mahalu**, **Nate Denton**, **Mohamed Afifi**, **Thomas**, **jdevant**, **patrick van der vliet**, **Parshant Loungani**, **MainlyHighDotCom**, **Filip Fajt**, **WRAP Party**, **RABBITATTITUDE**, **Xclusive Connex**, **Trent Hunter**, **travis miller**, **Randy Gardner**, **Александра**, **Christopher Reid**, **Akshay Rustagi**, **Chris Kimling**, **Gregory Brodzik**, **Antioch Sanders**, **DJ Kre**, **Jenasis**, **Jeff Glazer**, **Ivan Navi**, **DeadTube**, **William Fraley**, **Edward Hickman**, 
**Junior**, **Michael Spadoni**, **Machenna56k**, **Keith Lambert**, **Tony Rizko**, **Dulce Baerga**, **Daniel**, **Ali Kadhim**, **Liliana Jaén Russo**, **Starhaven**, **Jonas Aarmo**, **Fredrik**, **Tommie Geraedts**, **VPM**, **H. Mateo Prio Sanchez**, **Prime Devation**, **Daniel Sheryshev**, **motionCA2023**, **HOOD GAMER**, **Matt Reardon**, **Rui Li**, **James McDowell**, **john doe**, **ed.enai**, **waveliterature**, **Mirrorsvn**, **Somebody that i used to know**, **ThurberMingus**, **Mohammad Alkloub**, **Sergiy Dovgal**, **sam**, **Wayne Davis**, **Ann Petrosyan**, **Roman Degtiarev**, **Steven Van Loon**, **Alex Royal**, **Noel Bangne**, **Maximus Powers**, **Martin W**, **Thomas Folk**, **Max Tran**, **hagai assouline**, **Yaho omail**, **David Cazal**, **Аnna Baum**, **Christian Acuna**, **Jules FilmsFX**, **fo Ti**, **Rafael Spanopoulos**, **Caíque Ribeiro Barboza**, **benjamin west**, **Daniel Stremel**, **The Artist**, **Ornecq Anthony**, **Stephane Bhiri**, **Alex from Directia.fr**, **Random Like U**, **Jordan Nering**, **Mazlum Batan**, **c0ffymachyne**, **D2GW**, **wcngai**, **Se7enS**, **kinsam lin**, **Bunseng Chuor**, **LX Cui**, **Ian MacKay**, **Molotov Cocktail**, **EDDY CHEN**, **Raghu Ram**, **She Devil Films**, **levan**, **Domm Dynamite**, **Ryan Guiterman**, **Kirill Shevyakov**, **Riker Gold**, **Manuel Radl**, **Adam Ivie**, **InfinityDeltaX**, **szj**, **John Walters**, **Hennesii**, **martinej Hoccus**, **반석 김**, **Iván Guillén**, **James**, **PENT**, **Alonso**, **Johnny B Good**, **JoKr**, **Андрей Корниенко**, **you xi**, **Sam Hains**, **ULTRA VFX**, **Kevin Sasso**, **qingfu Qin**, **indievish**, **Mark Irvine**, **mehdi Nejad**, **Victor Krachinov**, **Andy**, **Marco**, **WinkJP**, **Rovonn Russell**, **Cainz Juss**, **Quentin Uriel**, **Peter**, **Surog**, **philipp**, **Benjámin Madarász**, **James H**, **Tony Barnhill**, **Hanish Keloth**, **Directed by stro**, **Ramesh Kumar N**, **Jae Kim**, **Anthony Seychelles**, 
**侠 大**, **Onoff Hentai**, **Visual Science**, **Anup prabhakar**, **ENFERMATIKO**, **Kilo Boozer**, **Edward Sarker**, **Salar**, **Jeff Valle**, **Dull Haven**, **Pheo**, **abhinav tayal**, **Prabhakaran Raj**, **Thomas Mebane**, **sahar sobhani**, **Mimic The Mad**, **David Rutgos**, **Vincent Shirayuki**, **ALMAANY PRODUCTION**, **Nikola Serafimovski**, **Cere Toon**\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ " and all my patreon supporters for their endless support and constructive feedback!\\\n",
+ "Here's the current [public warp](https://colab.research.google.com/github/Sxela/DiscoDiffusion-Warp/blob/main/Disco_Diffusion_v5_2_Warp.ipynb) for videos with openai diffusion model\n",
+ "\n",
+ "# WarpFusion v0.24 by [Alex Spirin](https://twitter.com/devdef)\n",
+ "\n",
+ "\n",
+ "This version improves video init. You can now generate optical flow maps from input videos, and use those to:\n",
+ "- warp init frames for consistent style\n",
+ "- warp processed frames for less noise in final video\n",
+ "\n",
+ "\n",
+ "\n",
+ "## Init warping\n",
+ "#### [Explanation video](https://www.youtube.com/watch?v=ZuPBDRjwtu0)\n",
+ "The feature works like this: we take the 1st frame and diffuse it as usual as an image input with fixed skip steps. Then we warp it with its flow map into the 2nd frame and blend it with the 2nd frame of the original raw video. This way we get the style from the heavily stylized 1st frame (warped accordingly) and the content from the 2nd frame (to reduce warping artifacts and prevent overexposure).\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "--------------------------------------\n",
+ "\n",
+ "This is a variation of the awesome [DiscoDiffusion colab](https://colab.research.google.com/github/alembics/disco-diffusion/blob/main/Disco_Diffusion.ipynb#scrollTo=Changelog)\n",
+ "\n",
+ "If you like what I'm doing you can\n",
+ "- follow me on [twitter](https://twitter.com/devdef)\n",
+ "- tip me on [patreon](https://www.patreon.com/sxela)\n",
+ "\n",
+ "\n",
+ "Thank you for being awesome!\n",
+ "\n",
+ "--------------------------------------\n",
+ "\n",
+ "\n",
+ "\n",
+ "--------------------------------------\n",
+ "\n",
+ "This notebook was based on DiscoDiffusion (though it's not much like it anymore)\\\n",
+ "To learn more about DiscoDiffusion, join the [Disco Diffusion Discord](https://discord.gg/msEZBy4HxA)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "kDKwhb8xiKwu"
+ },
+ "source": [
+ "# Changelog, credits & license"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "WrxXo2FVivvi"
+ },
+ "source": [
+ "### Changelog"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "1ort2i_yiD51"
+ },
+ "source": [
+ "**v0.24**\\\n",
+ "20.10.2023\n",
+ "- fix pytorch dependencies error\n",
+ "- fix zoe depth error\n",
+ "- move installers to github repo\n",
+ "\n",
+ "16.10.2023\n",
+ "- fix pillow errors (UnidentifiedImageError: cannot identify image file)\n",
+ "- fix timm import error (isDirectory error)\n",
+ "- deprecate v2_depth model (use depth controlnet instead)\n",
+ "\n",
+ "14.10.2023\n",
+ "- fix xformers version\n",
+ "- fix flow preview error for less than 10 frames\n",
+ "\n",
+ "6.10.2023\n",
+ "- fix controlnet preview (next_frame error)\n",
+ "- fix dwpose 'final_boxes' error for frames with no people\n",
+ "- move width_height to video init cell to avoid people forgetting to run it to update width_height\n",
+ "\n",
+ "26.09.2023\n",
+ "- add FreeU Hack from https://huggingface.co/papers/2309.11497\n",
+ "- add option to apply FreeU before or after controlnet outputs\n",
+ "- add inpaint-softedge and temporal-depth controlnet models\n",
+ "- auto-download inpaint-softedge and temporal-depth checkpoints\n",
+ "- fix sd21 lineart model not working\n",
+ "- refactor get_controlnet_annotations a bit\n",
+ "- add inpaint-softedge and temporal-depth controlnet preprocessors\n",
+ "\n",
+ "**v0.23**\\\n",
+ "19.09.2023\n",
+ "- fix gui for non controlnet-mode\n",
+ "- add deflicker for win\n",
+ "- add experimental deflicker from https://video.stackexchange.com/questions/23384/remove-flickering-due-to-artificial-light-with-ffmpeg\n",
+ "- fix linear and none blend modes for video export\n",
+ "+ detect init_video fps and pass down to video export with respect to nth frame\n",
+ "+ do not reload already loaded controlnets\n",
+ "\n",
+ "18.09.2023\n",
+ "+ rename upscaler model path variable\n",
+ "+ make mask use image folder correctly as a mask source\n",
+ "\n",
+ "17.09.2023\n",
+ "+ add dw pose estimator from https://github.com/IDEA-Research/DWPose\n",
+ "+ add onnxruntime-gpu install, (update env for dw_pose)\n",
+ "+ add dw_pose model downloader\n",
+ "+ add controlnet preview, kudos to #kytr.ai, idea - https://discord.com/channels/973802253204996116/1124468917671309352\n",
+ "+ add temporalnet sdxl - v1 (3-channel)\n",
+ "+ add prores thanks to #sozeditit https://discord.com/channels/973802253204996116/1149027955998195742\n",
+ "+ make width_height accept 1 number to resize frame to that size keeping aspect ratio\n",
+ "+ add cc masked template for content-aware scheduling\n",
+ "+ add reverse frames extraction\n",
+ "+ move looped image to video init cell as video source mode\n",
+ "\n",
+ "+ fix settings not being loaded via button\n",
+ "+ fix bug when cc_masked_diffusion == 0\n",
+ "+ add message on missing audio during video export / mute exception\n",
+ "+ go back to root dir after running rife\n",
+ "\n",
+ "**v0.22**\\\n",
+ "13.09.2023\n",
+ "+ fix \"TypeError: eval() arg 1...\" error when loading non-existent settings on the initial run\n",
+ "\n",
+ "11.09.2023\n",
+ "+ add error message for model version mismatch\n",
+ "+ download dummypkl automatically\n",
+ "+ fix venv install real-esrgan model folder not being created\n",
+ "\n",
+ "8.09.2023\n",
+ "+ fix samtrack site-packages url\n",
+ "+ fix samtrack missing groundingdino config\n",
+ "\n",
+ "7.09.2023\n",
+ "+ make samtrack save separate bg mask\n",
+ "\n",
+ "6.09.2023\n",
+ "+ fix rife imports\n",
+ "+ fix samtrack imports\n",
+ "+ fix samtrack not saving video\n",
+ "\n",
+ "5.09.2023\n",
+ "+ add rife\n",
+ "+ fix samtrack imports\n",
+ "+ fix rife imports\n",
+ "\n",
+ "3.09.2023\n",
+ "+ fix samtrack local install for windows\n",
+ "+ fix samtrack incorrect frame indexing if starting not from 1st frame\n",
+ "+ fix schedules not loading\n",
+ "\n",
+ "30.08.2023\n",
+ "+ fix ---> 81 os.chdir(f'{root_dir}/Real-ESRGAN') file not found error thanks to Leandro Dreger\n",
+ "\n",
+ "29.08.2023\n",
+ "+ hide \"warp_mode\",\"use_patchmatch_inpaiting\",\"warp_num_k\",\"warp_forward\",\"sat_scale\" from gui as deprecated\n",
+ "+ clean up gui settings setters/getters\n",
+ "+ fix controlnet not updating in gui sometimes\n",
+ "\n",
+ "**v0.21**\\\n",
+ "3.09.2023\n",
+ "+ add dummy model init for sdxl - won't download unnecessary stuff\n",
+ "\n",
+ "28.08.2023\n",
+ "+ add v1 qr controlnet\n",
+ "+ add v2 controlnets: qr, depth, scribble, openpose, normalbae, lineart, softedge, canny, seg\n",
+ "+ upcast custom controlnet sources to RGB from Grayscale\n",
+ "+ add v2_768 control_multi mode\n",
+ "\n",
+ "**v0.20**\\\n",
+ "26.08.2023\n",
+ "+ temporarily disable reference for sdxl\n",
+ "\n",
+ "24.08.2023\n",
+ "+ fix ModuleNotFound: safetensors error\n",
+ "+ fix cv2.error ssize.empty error for face controlnet\n",
+ "\n",
+ "23.08.2023\n",
+ "+ fix clip import error\n",
+ "\n",
+ "22.08.2023\n",
+ "+ add controlnet tile\n",
+ "+ remove single controlnets from model versions\n",
+ "+ fix guidance for controlnet_multi (now works with much higher clamp_max)\n",
+ "+ fix instructpix2pix not working in newer versions (kudos to #stabbyrobot)\n",
+ "+ fix AttributeError: 'OpenAIWrapper' object has no attribute 'get_dtype' error\n",
+ "\n",
+ "21.08.2023\n",
+ "+ add control-lora loader from [ComfyUI](https://github.com/comfyanonymous/ComfyUI)\n",
+ "+ add stability ai control-loras: depth, softedge, canny\n",
+ "+ refactor controlnet code a bit\n",
+ "+ fix tiled vae for sdxl\n",
+ "+ stop on black frames, print comprehensive error message\n",
+ "\n",
+ "20.08.2023\n",
+ "+ add cell execution check (kudos to #soze)\n",
+ "+ add skip diffusion switch to generate video only (kudos to #soze)\n",
+ "+ add error msg when creating video from 0 frames\n",
+ "\n",
+ "19.08.2023\n",
+ "+ fix rec noise for control multi sdxl mode\n",
+ "+ fix control mode\n",
+ "+ fix control_multi annotator folder error\n",
+ "\n",
+ "18.08.2023\n",
+ "+ add sdxl diffusers controlnet loader from [ComfyUI](https://github.com/comfyanonymous/ComfyUI)\n",
+ "+ add sdxl controlnets\n",
+ "+ save annotators to controlnet folder\n",
+ "+ hide sdxl model load spam\n",
+ "+ fix sdxl tiled vae errors (still gives black output on sdxl with vanilla vae)\n",
+ "+ fix cc_masked_diffusion not loaded from gui\n",
+ "\n",
+ "\n",
+ "\n",
+ "**v0.19**\\\n",
+ "14.08.2023\n",
+ "+ fix beep error\n",
+ "+ bring back init scale, fix deflicker init scale error thanks to #rebirthai\n",
+ "+ make cc_masked_diffusion a schedule\n",
+ "\n",
+ "13.08.2023\n",
+ "\n",
+ "+ add extra per-controlnet settings: source, mode, resolution, preprocess\n",
+ "+ add global and per-controlnet settings to gui\n",
+ "+ add beep sounds by Infinitevibes\n",
+ "+ add normalize controlnet weights toggle\n",
+ "\n",
+ "**v0.18**\\\n",
+ "10.08.2023\n",
+ "+ refactor lora support\n",
+ "+ add other lora-like models support from automatic1111\n",
+ "+ fix loras not being unloaded correctly\n",
+ "\n",
+ "9.08.2023\n",
+ "+ add sdxl lora support\n",
+ "+ fix load settings file = -1 not getting latest file\n",
+ "\n",
+ "5.08.2023\n",
+ "+ add weighted keywords support\n",
+ "+ clear gpu vram on render interrupt\n",
+ "+ offload openpose models\n",
+ "\n",
+ "4.08.2023\n",
+ "+ add sdxl_refiner support (same limitations apply)\n",
+ "\n",
+ "3.08.2023\n",
+ "+ add sdxl_base support\n",
+ "+ temporarily disable loras/lycoris/init_scale for sdxl\n",
+ "+ bring back xformers :D\n",
+ "\n",
+ "**v0.17**\\\n",
+ "17.07.2023\n",
+ "+ fix SAMTrack ckpt folder error ty to laplaceswzz\n",
+ "+ fix SAMTrack video export error\n",
+ "\n",
+ "16.07.2023\n",
+ "+ add lycoris\n",
+ "+ add lycoris/lora selector\n",
+ "\n",
+ "10.07.2023\n",
+ "+ fix settings_path by #sozeditit\n",
+ "\n",
+ "8.7.2023\n",
+ "+ add SAMTrack from https://github.com/z-x-yang/Segment-and-Track-Anything\n",
+ "+ fix bug with 2+ masks\n",
+ "+ print error when reloading gui fails\n",
+ "+ renamed default_settings_path to settings_path and load_default_settings to load_settings_from_file\n",
+ "\n",
+ "\n",
+ "**v0.16**\\\n",
+ "5.7.2023\n",
+ "+ fix torchmetrics version thx to tomatoslasher\n",
+ "\n",
+ "27.6.2023\n",
+ "+ fix reference controlnet not working with multiprompt\n",
+ "\n",
+ "26.6.2023\n",
+ "+ fix consistency error between identical frames\n",
+ "+ add ffmpeg deflicker option to video export (dfl postfix)\n",
+ "+ export video with inv postfix for inverted mask video\n",
+ "\n",
+ "25.6.2023\n",
+ "+ add deflicker sd_batch_size,\n",
+ "normalize_prompt_weights,\n",
+ "mask_paths,\n",
+ "deflicker_scale,\n",
+ "deflicker_latent_scale to gui/saved settings\n",
+ "\n",
+ "22.6.2023\n",
+ "+ fix compare settings not working in new run\n",
+ "\n",
+ "21.6.2023\n",
+ "+ add universal frame loader\n",
+ "+ add masked prompts support for controlnet_multi-internal mode\n",
+ "+ add masked prompts for other modes\n",
+ "+ fix undefined mask error\n",
+ "\n",
+ "20.6.2023\n",
+ "+ add support for different prompt number per frame\n",
+ "+ add prompt weight blending between frames\n",
+ "\n",
+ "16.6.2023\n",
+ "+ add prompt weight parser/splitter\n",
+ "+ update lora parser to support multiple prompts. if duplicate loras are used in more than 1 prompt, last prompt lora weights will be used\n",
+ "+ unload unused loras\n",
+ "\n",
+ "10.6.2023\n",
+ "+ add multiple prompts support\n",
+ "+ add max batch size\n",
+ "+ add prompt weights\n",
+ "\n",
+ "**v0.15**\\\n",
+ "12.6.2023\n",
+ "+ fix audio not being saved for relative init video path ty to @louis.jeck\n",
+ "\n",
+ "7.6.2023\n",
+ "+ add pattern replacement (filtering) for prompt\n",
+ "+ fix constant bitrate error during video export causing noise in high-res videos\n",
+ "+ fix full-screen consistency mask error (also bumped missed consistency dilation to 2)\n",
+ "\n",
+ "5.6.2023\n",
+ "+ add alpha masked diffusion\n",
+ "+ add inverse alpha mask diffusion\n",
+ "+ save settings to exif\n",
+ "+ backup existing settings on resume run\n",
+ "+ load settings from png exif\n",
+ "+ add beep\n",
+ "+ move consistency mask dilation to render settings\n",
+ "\n",
+ "**v0.14**\n",
+ "\n",
+ "2.06.2023\n",
+ "+ fix torch v2 install\n",
+ "\n",
+ "1.06.2023\n",
+ "+ fix oom errors during upscaler (offload more, retry failed frame 1 time)\n",
+ "\n",
+ "31.05.2023\n",
+ "+ fix flow preview generation out of range exception\n",
+ "+ fix realesrgan not found error\n",
+ "+ fix upscale ratio not being int error\n",
+ "+ fix black screen when using tiled vae (because of divisible by 8)\n",
+ "\n",
+ "30.05.2023\n",
+ "+ fix markupsafe install error on local install\n",
+ "+ add realesrgan upscaler for video export\n",
+ "+ save upscaled video under different name\n",
+ "\n",
+ "29.05.2023\n",
+ "+ save controlnet debug with 6 digit frame numbers\n",
+ "+ fix torch not found message with installed torch\n",
+ "+ allow output size as multiple of 8 (add hack from auto1111 to controlnets)\n",
+ "\n",
+ "27.05.2023\n",
+ "+ add option to keep audio from the init video\n",
+ "\n",
+ "26.05.2023\n",
+ "+ extract videoFrames to folder named after video metadata\n",
+ "+ extract flow to folder named after flow source metadata and resolution\n",
+ "+ auto re-create flow on video/resolution change\n",
+ "\n",
+ "25.05.2023\n",
+ "+ add deflicker losses (the effect needs to be tested)\n",
+ "\n",
+ "24.05.2023\n",
+ "+ add safetensors support for vae\n",
+ "+ replace blend_code and normalize_code for start_code with code_randomness\n",
+ "+ fix control_inpainting_mask=None mode error\n",
+ "+ rename mask_callback to masked_diffusion\n",
+ "+ hide extra settings for tiled vae\n",
+ "+ set controlnet weight to 1 when turning it on\n",
+ "+ safeguard controlnet model dir from empty value\n",
+ "\n",
+ "24.05.2023\n",
+ "+ fix reference infinite recursion error\n",
+ "+ fix prompt schedules not working with \"0\"-like keys\n",
+ "\n",
+ "22.05.2023\n",
+ "+ remove torch downgrade for colab\n",
+ "+ remove xformers for torch v2/colab\n",
+ "+ add sdp attention from AUTOMATIC1111 to replace xformers (for torch v2)\n",
+ "+ fix outer not defined error for reference\n",
+ "\n",
+ "21.05.2023\n",
+ "+ skip flow preview generation if it fails\n",
+ "+ downgrade to torch v1.13 for colab hosted env\n",
+ "\n",
+ "18.05.2023\n",
+ "+ add reference controlnet (attention injection)\n",
+ "+ add reference mode and source\n",
+ "\n",
+ "17.05.2023\n",
+ "+ save schedules to settings before applying templates\n",
+ "+ add gui options to load settings, keep state on rerun/load from previous cells\n",
+ "+ fix schedules not kept on gui rerun\n",
+ "+ rename depth_source to cond_image_src to reflect its actual purpose\n",
+ "\n",
+ "13.05.2023\n",
+ "+ auto skip install for our docker env\n",
+ "+ fix AttributeError: 'Block' object has no attribute 'drop_path' error for depth controlnet\n",
+ "\n",
+ "9.05.2023\n",
+ "+ switch to AGPL license\n",
+ "+ downgrade torch v2 to 2.0.0 to fix install errors\n",
+ "+ add alternative consistency algo (also faster)\n",
+ "\n",
+ "8.05.2023\n",
+ "+ make torch v2 install optional\n",
+ "+ make installation skippable for consecutive runs\n",
+ "\n",
+ "6.05.2023\n",
+ "+ fix flow preview not being shown\n",
+ "+ fix prompt schedules not working in some cases\n",
+ "+ fix captions not updating in rec prompt\n",
+ "\n",
+ "5.05.2023\n",
+ "+ fix controlnet inpainting model None and cond_video modes\n",
+ "+ clean some discodiffusion legacy code (it's been a year :D)\n",
+ "+ add controlnet default main model (v1.5)\n",
+ "\n",
+ "29.04.2023\n",
+ "+ add controlnet multimodel options to gui (thanks to Gateway#5208)\n",
+ "+ tidy up the colab interface a bit\n",
+ "+ fix dependency errors for uformer\n",
+ "+ fix lineart/anime lineart errors\n",
+ "\n",
+ "27.04.2023\n",
+ "+ add controlnet v1.1 annotator options to gui\n",
+ "+ fix controlnet cond_image for cond_video with no preprocess\n",
+ "+ fix colormatch stylized frame not working with frame range\n",
+ "\n",
+ "26.04.2023\n",
+ "+ add consistency controls to video export cell\n",
+ "+ add rec noise save\\load for compatible settings\n",
+ "+ fix local install\n",
+ "+ save rec noise cache to recNoiseCache folder in project dir\n",
+ "\n",
+ "24.04.2023\n",
+ "+ add shuffle controlnet\n",
+ "+ add shuffle controlnet sources\n",
+ "\n",
+ "22.04.2023\n",
+ "+ add ip2p, lineart, lineart anime controlnet\n",
+ "+ bring together most of the installs (faster install, even faster restart & run all), only tested on colab\n",
+ "\n",
+ "21.04.2023\n",
+ "+ fix zoe depth model producing black frames with autocast on\n",
+ "\n",
+ "20.04.2023\n",
+ "+ add tiled vae\n",
+ "+ switch to controlnet v1.1 repo\n",
+ "+ update controlnet model urls and filenames to v1.1 and new naming convention\n",
+ "+ update existing controlnet modes to v1.1\n",
+ "\n",
+ "8.04.2023\n",
+ "+ add lora support thanks to brbbbq :D\n",
+ "+ add loras schedule parsing from prompts\n",
+ "+ add custom path for loras\n",
+ "+ add custom path for embeddings thanks to brbbbq\n",
+ "+ add torch built-in raft implementation\n",
+ "+ fix threads error on video export\n",
+ "+ disable guidance for lora\n",
+ "+ add compile option for raft @ torch v2 (a100 only)\n",
+ "+ force torch downgrade for T4 GPU on colab\n",
+ "+ add faces controlnet from https://huggingface.co/CrucibleAI/ControlNetMediaPipeFace\n",
+ "+ make gui not reset on run cell (there is still javascript delay before input is saved)\n",
+ "+ add custom download folder for controlnets\n",
+ "\n",
+ "3.04.2023\n",
+ "+ add rec steps % option\n",
+ "+ add masked guidance toggle to gui\n",
+ "+ add masked diffusion toggle to gui\n",
+ "+ add softclamp to gui\n",
+ "+ add temporalnet settings to gui\n",
+ "+ add controlnet annotator settings to gui\n",
+ "+ hide sat_scale (causes black screen)\n",
+ "+ hide inpainting model-specific settings\n",
+ "+ hide instructpix2pix-specific settings\n",
+ "+ add rec noise to gui\n",
+ "+ add predicted noise mode (reconstruction / rec) from AUTOMATIC1111 and https://github.com/AUTOMATIC1111/stable-diffusion-webui/issues/736\n",
+ "+ add prompt schedule for rec\n",
+ "+ add cfg scale for rec\n",
+ "+ add captions support to rec prompt\n",
+ "+ add source selector for rec noise\n",
+ "+ add temporalnet source selector (init/stylized)\n",
+ "+ skip temporalnet for 1st frame\n",
+ "+ add v1/v2 support for rec noise\n",
+ "+ add single controlnet support for rec noise\n",
+ "+ add multi controlnet to rec noise\n",
+ "\n",
+ "29.03.2023\n",
+ "- add TemporalNet from https://huggingface.co/CiaraRowles/TemporalNet\n",
+ "\n",
+ "17.03.2023\n",
+ "- fix resume_run not working properly\n",
+ "\n",
+ "08.03.2023\n",
+ "- fix PIL error in colab\n",
+ "- auto pull a fresh ControlNet thanks to Jonas.Klesen#8793\n",
+ "\n",
+ "07.03.2023\n",
+ "- add multicontrolnet\n",
+ "- add multicontrolnet autoloader\n",
+ "- add multicontrolnet weight, start/end steps, internal\\external mode\n",
+ "- add multicontrolnet/annotator cpu offload mode\n",
+ "- add empty negative image condition\n",
+ "- add softcap image range scaler\n",
+ "- add no_half_vae mode\n",
+ "- cast controlnet to fp16\n",
+ "\n",
+ "28.02.2023\n",
+ "- add separate base model for controlnet support\n",
+ "- add smaller controlnet support\n",
+ "- add invert mask for masked guidance\n",
+ "\n",
+ "27.02.2023\n",
+ "- fix frame_range starting not from zero not working\n",
+ "- add option to offload model before decoder stage\n",
+ "- add fix noise option for latent guidance\n",
+ "- add masked diffusion callback\n",
+ "- add noise, noise scale, fixed noise to masked diffusion\n",
+ "- add masked latent guidance\n",
+ "- add controlnet_preprocessing switch to allow raw input\n",
+ "- fix sampler being locked to euler\n",
+ "\n",
+ "26.02.2023\n",
+ "- fix prompts not working for loaded settings\n",
+ "\n",
+ "24.02.2023\n",
+ "- fix load settings not working for filepath\n",
+ "- fix norm colormatch error\n",
+ "- fix warp latent mode error\n",
+ "\n",
+ "\n",
+ "21.02.2023\n",
+ "- fix image_resolution error for controlnet models\n",
+ "- fix controlnet models not downloading (file not found error)\n",
+ "- fix settings not loading with -1 and empty batch folder\n",
+ "\n",
+ "18.02.2023\n",
+ "- add ControlNet models from https://github.com/lllyasviel/ControlNet\n",
+ "- add ControlNet downloads from https://colab.research.google.com/drive/1VRrDqT6xeETfMsfqYuCGhwdxcC2kLd2P\n",
+ "- add settings for ControlNet: canny filter ranges, detection size for depth/norm and other models\n",
+ "- add vae ckpt load for non-ControlNet models\n",
+ "- add selection by number to compare settings cell\n",
+ "- add noise to guiding image (init scale, latent scale)\n",
+ "- add noise resolution\n",
+ "- add guidance function selection for init scale\n",
+ "- add fixed seed option (completely fixed seed, not like fixed code)\n",
+ "\n",
+ "\n",
+ "14.02.2023\n",
+ "- add instruct pix2pix from https://github.com/timothybrooks/instruct-pix2pix\n",
+ "- add image_scale_schedule to support instruct pix2pix\n",
+ "- add frame_range to render a selected range of extracted frames only\n",
+ "- add load settings by run number\n",
+ "- add model cpu-gpu offload to free some vram\n",
+ "- fix prompts not being loaded from saved settings\n",
+ "- fix xformers cell hanging on Overwrite user query\n",
+ "- fix sampler not being loaded\n",
+ "- fix description_tooltip=turbo_frame_skips_steps error\n",
+ "- fix -1 settings not loading in empty folder\n",
+ "- fix colormatch offset mode first frame not found error\n",
+ "\n",
+ "\n",
+ "12.02.2023\n",
+ "- fix colormatch first frame error\n",
+ "- fix raft_model not found error when generate flow cell is run after the cell with warp_towards_init\n",
+ "\n",
+ "10.02.2023\n",
+ "- fix ANSI encoding error\n",
+ "- fix videoFramesCaptions error when captions were off\n",
+ "- fix caption keyframes \"nonetype is not iterable\" error\n",
+ "\n",
+ "9.02.2023\n",
+ "- fix blip config path\n",
+ "- shift caption frame number\n",
+ "\n",
+ "8.02.2023\n",
+ "- add separate color video / image as colormatch source\n",
+ "- add color video to gui options\n",
+ "- fix init frame / stylized frame colormatch with offset 0 error\n",
+ "- save settings to settings folder\n",
+ "- fix batchnum not inferred correctly due to moved settings files\n",
+ "\n",
+ "7.02.2023\n",
+ "- add conditioning source video for depth, inpainting models\n",
+ "- add conditioning source to gui\n",
+ "- add automatic caption generation\n",
+ "- add caption syntax to prompts\n",
+ "- convert loaded setting keys to int (\"0\" -> 0)\n",
+ "\n",
+ "2.02.2023\n",
+ "- fix xformers install in colab A100\n",
+ "- fix load default settings not working\n",
+ "- fix mask_clip not loading from settings\n",
+ "\n",
+ "26.01.2023\n",
+ "- fix error saving settings with content aware schedules\n",
+ "- fix legacy normalize latent = first latent error\n",
+ "\n",
+ "25.01.2023\n",
+ "- add ffmpeg install for windows\n",
+ "- add torch install/reinstall options for windows\n",
+ "- add xformers install/reinstall for windows\n",
+ "- disable shutdown runtime for non-colab env\n\n",
+ "12.01.2023\n",
+ "- add embeddings, prompt attention weights from AUTOMATIC1111 webui repo\n",
+ "- bring back colab gui for compatibility\n",
+ "- add default settings option\n",
+ "- add frame difference manual override for content-aware scheduling\n",
+ "- fix content-aware schedules\n",
+ "- fix PDF color transfer being used by default with LAB selected\n",
+ "- fix xformers URL for colab\n",
+ "- remove PatchMatch to fix jacinle/travis error on colab\n",
+ "- print settings on error while saving\n",
+ "- reload embeddings in Diffuse! cell\n",
+ "- fix pkl not saveable after loading embeddings\n",
+ "- fix xformers install (no need to downgrade torch thanks to TheFastBen)\n",
+ "\n",
+ "27.12.2022\n",
+ "- add v1 runwayml inpainting model support\n",
+ "- add inpainting mask source\n",
+ "- add inpainting mask strength\n",
+ "\n",
+ "\n",
+ "23.12.2022\n",
+ "- add samplers\n",
+ "- add v2-768-v support\n",
+ "- add mask clipping\n",
+ "- add tooltips to gui settings\n",
+ "- fix consistency map generation error (thanks to maurerflower)\n",
+ "- fix colab crashing on low vram env during model loading\n",
+ "- fix xformers install on colab\n",
+ "\n",
+ "17.12.2022\n",
+ "- add first beta gui\n",
+ "- remove settings included in gui from notebook\n",
+ "- add fix for loading pkl models on a100 that were saved not on a100\n",
+ "- fix gpu variable bug on local env\n",
+ "\n",
+ "13.12.2022\n",
+ "- downgrade torch without restart\n",
+ "\n",
+ "7.12.2022\n",
+ "- add v2-depth support\n",
+ "- add v2 support\n",
+ "- add v1 backwards compatibility\n",
+ "- add model selector\n",
+ "- add depth source option: prev frame or raw frame\n",
+ "- add fix for TIFF bug\n",
+ "- add torch downgrade for colab xformers\n",
+ "- add forward patch-warping option (beta, no consistency support yet)\n",
+ "- add *.mov video export thanks to cerspense#3301\n",
+ "\n",
+ "5.12.2022\n",
+ "- add force_os for xformers setup\n",
+ "- add load ckpt onto gpu option\n",
+ "\n",
+ "2.12.2022\n",
+ "- add colormatch turbo frames toggle\n",
+ "- add colormatch before stylizing toggle\n",
+ "- add faster flow generation (up to x4 depending on disk bandwidth)\n",
+ "- add faster flow-blended video export (up to x10 depending on disk bandwidth)\n",
+ "- add 10 evenly spaced frames' previews for flow and consistency maps\n",
+ "- add warning for missing ffmpeg on windows\n",
+ "- fix installation not working after being interrupted\n",
+ "- fix xformers install for A*000 series cards.\n",
+ "- fix error during RAFT init for non 3.7 python envs\n",
+ "- fix settings comparing typo\n",
+ "\n",
+ "24.11.2022\n",
+ "- fix int() casting error for flow remapping\n",
+ "- remove int() casting for frames' schedule\n",
+ "- turn off inconsistent areas' color matching (was calculated even when off)\n",
+ "- fix settings' comparison\n",
+ "\n",
+ "23.11.2022\n",
+ "- fix writefile for non-colab interface\n",
+ "- add xformers install for linux/windows\n",
+ "\n",
+ "20.11.2022\n",
+ "- add patchmatch inpainting for inconsistent areas\n",
+ "- add warp towards init (thanks to [Zippika](https://twitter.com/AlexanderRedde3) from [deforum](https://github.com/deforum/stable-diffusion) team)\n",
+ "- add grad with respect to denoised latent, not input (4x faster) (thanks to EnzymeZoo from [deforum](https://github.com/deforum/stable-diffusion) team)\n",
+ "- add init/frame scale towards real frame option (thanks to [Zippika](https://twitter.com/AlexanderRedde3) from [deforum](https://github.com/deforum/stable-diffusion) team)\n",
+ "- add json schedules\n",
+ "- add settings comparison (thanks to brbbbq)\n",
+ "- save output videos to a separate video folder (thanks to Colton)\n",
+ "- fix xformers not loading until restart\n",
+ "\n",
+ "14.11.2022\n",
+ "- add xformers for colab\n",
+ "- add latent init blending\n",
+ "- fix init scale loss to use 1/2 sized images\n",
+ "- add verbose mode\n",
+ "- fix frame correction for non-existent reference frames\n",
+ "- fix user-defined latent stats to support 4 channels (4d)\n",
+ "- fix start code to use 4d norm\n",
+ "- track latent stats across all frames\n",
+ "- print latent norm average stats\n",
+ "\n",
+ "11.11.2022\n",
+ "- add latent warp mode\n",
+ "- add consistency support for latent warp mode\n",
+ "- add masking support for latent warp mode\n",
+ "- add normalize_latent modes: init_frame, init_frame_offset, stylized_frame, stylized_frame_offset\n",
+ "- add normalize latent offset setting\n",
+ "\n",
+ "4.11.2022\n",
+ "- add normalize_latent modes: off, first_latent, user_defined\n",
+ "- add normalize_latent user preset std and mean settings\n",
+ "- add latent_norm_4d setting for per-channel latent normalization (was off in legacy colabs)\n",
+ "- add colormatch_frame modes: off, init_frame, init_frame_offset, stylized_frame, stylized_frame_offset\n",
+ "- add color match algorithm selection: LAB, PDF, mean (LAB was the default in legacy colabs)\n",
+ "- add color match offset setting\n",
+ "- add color match regrain flag\n",
+ "- add color match strength\n",
+ "\n",
+ "30.10.2022\n",
+ "- add cfg_scale schedule\n",
+ "- add option to apply schedule templates to peak difference frames only\n",
+ "- add flow multiplier (for glitches)\n",
+ "- add flow remapping (for even more glitches)\n",
+ "- add inverse mask\n",
+ "- fix masking in turbo mode (hopefully)\n",
+ "- fix deleting videoframes not working in some cases\n",
+ "\n",
+ "26.10.2022\n",
+ "- add negative prompts\n",
+ "- move google drive init cell higher\n",
+ "\n",
+ "22.10.2022\n",
+ "- add background mask support\n",
+ "- add background mask extraction from video (using https://github.com/Sxela/RobustVideoMattingCLI)\n",
+ "- add separate mask options during render and video creation\n",
+ "\n",
+ "21.10.2022\n",
+ "- add match first frame color toggle\n",
+ "- add match first frame latent option\n",
+ "- add karras noise + ramp up options\n",
+ "\n",
+ "11.10.2022\n",
+ "- add frame difference analysis\n",
+ "- make preview persistent\n",
+ "- fix some bugs with images not being sent\n",
+ "\n",
+ "9.10.2022\n",
+ "- add steps scheduling\n",
+ "- add init_scale scheduling\n",
+ "- add init_latent_scale scheduling\n",
+ "\n",
+ "8.10.2022\n",
+ "- add skip steps scheduling\n",
+ "- add flow_blend scheduling\n",
+ "\n",
+ "2.10.2022\n",
+ "- add auto session shutdown after run\n",
+ "- add awesome user-generated guide\n",
+ "\n",
+ "23.09.2022\n",
+ "- add channel mixing for consistency masks\n",
+ "- add multilayer consistency masks\n",
+ "- add jpeg-only consistency masks (weight less)\n",
+ "- add save as pickle option (model weight less, loads faster, uses less CPU RAM)\n",
+ "\n",
+ "18.09.2022\n",
+ "- add clip guidance (ViT-H/14, ViT-L/14, ViT-B/32)\n",
+ "- fix progress bar\n",
+ "- change output dir name to StableWarpFusion\n",
+ "\n",
+ "15.09.2022\n",
+ "- remove unnecessary image resizes, that caused a feedback loop in a few frames, kudos to everyoneishappy#5351 @ Discord\n",
+ "\n",
+ "7.09.2022\n",
+ "- added vram usage fix, now supports up to 1536x1536 images on 16gig gpus (with init_scales and sat_scale off)\n",
+ "- added frame range (start-end frame) for video inits\n",
+ "- added pseudo-inpainting by diffusing only inconsistent areas\n",
+ "- fixed changing width height not working correctly\n",
+ "- removed LAMA inpainting to reduce load and installation bugs\n",
+ "- hidden intermediate saves (unusable for now)\n",
+ "- fixed multiple image operations being applied during intermediate previews (even though the previews were not shown)\n",
+ "- moved Stable model loading to a later stage to allow processing optical flow for larger frame sizes\n",
+ "- fixed videoframes not being saved correctly without google drive\\locally\n",
+ "- fixed PIL module error for colab to work without restarting\n",
+ "- fix RAFT models download error x2\n",
+ "\n",
+ "2.09.2022\n",
+ "- Add Create a video from the init image\n",
+ "- Add Fixed start code toggle \\ blend setting\n",
+ "- Add Latent frame scale\n",
+ "- Fix prompt scheduling\n",
+ "- Return init scale \\ frames scale back to its original menus\n",
+ "- Hide all unused settings\n",
+ "\n",
+ "30.08.2022\n",
+ "- Add fixes to run locally\n",
+ "\n",
+ "25.08.2022\n",
+ "- use existing color matching to keep frames from going deep purple\n",
+ "- temporarily hide non-active stuff\n",
+ "- fix match_color_var typo\n",
+ "- fix model path interface\n",
+ "\n",
+ "- brought back LAMA inpainting\n",
+ "- fixed PIL error\n",
+ "\n",
+ "23.08.2022\n",
+ "- Add Stable Diffusion\n",
+ "\n",
+ "1.08.2022\n",
+ "- Add color matching from https://github.com/pengbo-learn/python-color-transfer (kudos to authors!)\n",
+ "- Add automatic brightness correction (thanks to @lowfuel and his awesome https://github.com/lowfuel/progrockdiffusion)\n",
+ "- Add early stopping\n",
+ "- Bump 4 leading zeros in frame names to 6. Fixes error for videos with more than 9999 frames\n",
+ "- Move LAMA and RAFT models to the models' folder\n",
+ "\n",
+ "09.07.2022\n",
+ "- Add inpainting from https://github.com/saic-mdal/lama\n",
+ "- Add init image for video mode\n",
+ "- Add separate video init for optical flows\n",
+ "- Fix leftover padding_mode definition\n",
+ "\n",
+ "28.06.2022\n",
+ "- Add input padding\n",
+ "- Add traceback print\n",
+ "- Add a (hopefully) self-explanatory warp settings form\n",
+ "\n",
+ "21.06.2022\n",
+ "- Add pythonic consistency check wrapper from [flow_tools](https://github.com/Sxela/flow_tools)\n",
+ "\n",
+ "15.06.2022\n",
+ "- Fix default prompt prohibiting prompt animation\n",
+ "\n",
+ "8.06.2022\n",
+ "- Fix path issue (thanks to Michael Carychao#0700)\n",
+ "- Add turbo-smooth settings to settings.txt\n",
+ "- Soften consistency clamping\n",
+ "\n",
+ "7.06.2022\n",
+ "- Add turbo-smooth\n",
+ "- Add consistency clipping for normal and turbo frames\n",
+ "- Add turbo frames skip steps\n",
+ "- Add disable consistency for turbo frames\n",
+ "\n",
+ "22.05.2022:\n",
+ "- Add saving frames and flow to google drive (suggested by Chris the Wizard#8082)\n",
+ "- Add back a more stable version of consistency checking\n",
+ "\n",
+ "\n",
+ "11.05.2022:\n",
+ "- Add custom diffusion model support (more on training it [here](https://www.patreon.com/posts/generating-faces-66246423))\n",
+ "\n",
+ "16.04.2022:\n",
+ "- Use width_height size instead of input video size\n",
+ "- Bring back adabins and 2d/3d anim modes\n",
+ "- Install RAFT only when video input animation mode is selected\n",
+ "- Generate optical flow maps only for video input animation mode even with flow_warp unchecked, so you can still save an optical flow blended video later\n",
+ "- Install AdaBins for 3d mode only (should do the same for midas)\n",
+ "- Add animation mode check to create video tab\n",
+ "\n",
+ "15.04.2022: Init"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "CreditsChTop"
+ },
+ "source": [
+ "### Credits ⬇️"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "Credits"
+ },
+ "source": [
+ "#### Credits\n",
+ "\n",
+ "This notebook uses:\n",
+ "\n",
+ "[Stable Diffusion](https://github.com/CompVis/stable-diffusion) by CompVis & StabilityAI\\\n",
+ "[K-diffusion wrapper](https://github.com/crowsonkb/k-diffusion) by Katherine Crowson\\\n",
+ "RAFT model by princeton-vl\\\n",
+ "Consistency Checking (legacy) from maua\\\n",
+ "Color correction from [pengbo-learn](https://github.com/pengbo-learn/python-color-transfer)\\\n",
+ "Auto brightness adjustment from [progrockdiffusion](https://github.com/lowfuel/progrockdiffusion)\n",
+ "\n",
+ "AUTOMATIC1111: weighted prompt keywords, lora, embeddings, attention hacks\\\n",
+ "Alt Img2img from AUTOMATIC1111: reconstructed noise\n",
+ "\n",
+ "ControlNet\\\n",
+ "Diffusers controlnet loader - [ComfyUI](https://github.com/comfyanonymous/ComfyUI)\\\n",
+ "TemporalNet\\\n",
+ "Controlnet Face\\\n",
+ "and lots of other controlnets (check model list)\\\n",
+ "BLIP\\\n",
+ "RobustVideoMatting (as external package)\\\n",
+ "CLIP\n",
+ "\n",
+ "[comfyanonymous](https://github.com/comfyanonymous/ComfyUI): controlnet loaders\n",
+ "\n",
+ "DiscoDiffusion legacy credits:\n",
+ "\n",
+ "Original notebook by [Somnai](https://twitter.com/Somnai_dreams), [Adam Letts](https://twitter.com/gandamu_ml) and lots of other awesome people!\n",
+ "\n",
+ "Turbo feature by [Chris Allen](https://twitter.com/zippy731)\n",
+ "\n",
+ "Improvements to ability to run on local systems, Windows support, and dependency installation by [HostsServer](https://twitter.com/HostsServer)\n",
+ "\n",
+ "Warp and custom model support by [Alex Spirin](https://twitter.com/devdef)\n",
+ "\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "LicenseTop"
+ },
+ "source": [
+ "### License"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "pxZe-Q8vIzBo"
+ },
+ "source": [
+ "This is the top-level license of this notebook.\n",
+ "AGPL is inherited from AUTOMATIC1111 code snippets used here.\n",
+ "You can find other licenses for code snippets or dependencies included below."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "OOPUnaVKIgPw"
+ },
+ "source": [
+ " GNU AFFERO GENERAL PUBLIC LICENSE\n",
+ " Version 3, 19 November 2007\n",
+ "\n",
+ " \n",
+ " Copyright (c)\n",
+ " 2023 comfyanonymous\n",
+ " 2023 AUTOMATIC1111\n",
+ " 2023 Alex Spirin\n",
+ "\n",
+ " Copyright (C) 2007 Free Software Foundation, Inc. \n",
+ " Everyone is permitted to copy and distribute verbatim copies\n",
+ " of this license document, but changing it is not allowed.\n",
+ "\n",
+ " Preamble\n",
+ "\n",
+ " The GNU Affero General Public License is a free, copyleft license for\n",
+ "software and other kinds of works, specifically designed to ensure\n",
+ "cooperation with the community in the case of network server software.\n",
+ "\n",
+ " The licenses for most software and other practical works are designed\n",
+ "to take away your freedom to share and change the works. By contrast,\n",
+ "our General Public Licenses are intended to guarantee your freedom to\n",
+ "share and change all versions of a program--to make sure it remains free\n",
+ "software for all its users.\n",
+ "\n",
+ " When we speak of free software, we are referring to freedom, not\n",
+ "price. Our General Public Licenses are designed to make sure that you\n",
+ "have the freedom to distribute copies of free software (and charge for\n",
+ "them if you wish), that you receive source code or can get it if you\n",
+ "want it, that you can change the software or use pieces of it in new\n",
+ "free programs, and that you know you can do these things.\n",
+ "\n",
+ " Developers that use our General Public Licenses protect your rights\n",
+ "with two steps: (1) assert copyright on the software, and (2) offer\n",
+ "you this License which gives you legal permission to copy, distribute\n",
+ "and/or modify the software.\n",
+ "\n",
+ " A secondary benefit of defending all users' freedom is that\n",
+ "improvements made in alternate versions of the program, if they\n",
+ "receive widespread use, become available for other developers to\n",
+ "incorporate. Many developers of free software are heartened and\n",
+ "encouraged by the resulting cooperation. However, in the case of\n",
+ "software used on network servers, this result may fail to come about.\n",
+ "The GNU General Public License permits making a modified version and\n",
+ "letting the public access it on a server without ever releasing its\n",
+ "source code to the public.\n",
+ "\n",
+ " The GNU Affero General Public License is designed specifically to\n",
+ "ensure that, in such cases, the modified source code becomes available\n",
+ "to the community. It requires the operator of a network server to\n",
+ "provide the source code of the modified version running there to the\n",
+ "users of that server. Therefore, public use of a modified version, on\n",
+ "a publicly accessible server, gives the public access to the source\n",
+ "code of the modified version.\n",
+ "\n",
+ " An older license, called the Affero General Public License and\n",
+ "published by Affero, was designed to accomplish similar goals. This is\n",
+ "a different license, not a version of the Affero GPL, but Affero has\n",
+ "released a new version of the Affero GPL which permits relicensing under\n",
+ "this license.\n",
+ "\n",
+ " The precise terms and conditions for copying, distribution and\n",
+ "modification follow.\n",
+ "\n",
+ " TERMS AND CONDITIONS\n",
+ "\n",
+ " 0. Definitions.\n",
+ "\n",
+ " \"This License\" refers to version 3 of the GNU Affero General Public License.\n",
+ "\n",
+ " \"Copyright\" also means copyright-like laws that apply to other kinds of\n",
+ "works, such as semiconductor masks.\n",
+ "\n",
+ " \"The Program\" refers to any copyrightable work licensed under this\n",
+ "License. Each licensee is addressed as \"you\". \"Licensees\" and\n",
+ "\"recipients\" may be individuals or organizations.\n",
+ "\n",
+ " To \"modify\" a work means to copy from or adapt all or part of the work\n",
+ "in a fashion requiring copyright permission, other than the making of an\n",
+ "exact copy. The resulting work is called a \"modified version\" of the\n",
+ "earlier work or a work \"based on\" the earlier work.\n",
+ "\n",
+ " A \"covered work\" means either the unmodified Program or a work based\n",
+ "on the Program.\n",
+ "\n",
+ " To \"propagate\" a work means to do anything with it that, without\n",
+ "permission, would make you directly or secondarily liable for\n",
+ "infringement under applicable copyright law, except executing it on a\n",
+ "computer or modifying a private copy. Propagation includes copying,\n",
+ "distribution (with or without modification), making available to the\n",
+ "public, and in some countries other activities as well.\n",
+ "\n",
+ " To \"convey\" a work means any kind of propagation that enables other\n",
+ "parties to make or receive copies. Mere interaction with a user through\n",
+ "a computer network, with no transfer of a copy, is not conveying.\n",
+ "\n",
+ " An interactive user interface displays \"Appropriate Legal Notices\"\n",
+ "to the extent that it includes a convenient and prominently visible\n",
+ "feature that (1) displays an appropriate copyright notice, and (2)\n",
+ "tells the user that there is no warranty for the work (except to the\n",
+ "extent that warranties are provided), that licensees may convey the\n",
+ "work under this License, and how to view a copy of this License. If\n",
+ "the interface presents a list of user commands or options, such as a\n",
+ "menu, a prominent item in the list meets this criterion.\n",
+ "\n",
+ " 1. Source Code.\n",
+ "\n",
+ " The \"source code\" for a work means the preferred form of the work\n",
+ "for making modifications to it. \"Object code\" means any non-source\n",
+ "form of a work.\n",
+ "\n",
+ " A \"Standard Interface\" means an interface that either is an official\n",
+ "standard defined by a recognized standards body, or, in the case of\n",
+ "interfaces specified for a particular programming language, one that\n",
+ "is widely used among developers working in that language.\n",
+ "\n",
+ " The \"System Libraries\" of an executable work include anything, other\n",
+ "than the work as a whole, that (a) is included in the normal form of\n",
+ "packaging a Major Component, but which is not part of that Major\n",
+ "Component, and (b) serves only to enable use of the work with that\n",
+ "Major Component, or to implement a Standard Interface for which an\n",
+ "implementation is available to the public in source code form. A\n",
+ "\"Major Component\", in this context, means a major essential component\n",
+ "(kernel, window system, and so on) of the specific operating system\n",
+ "(if any) on which the executable work runs, or a compiler used to\n",
+ "produce the work, or an object code interpreter used to run it.\n",
+ "\n",
+ " The \"Corresponding Source\" for a work in object code form means all\n",
+ "the source code needed to generate, install, and (for an executable\n",
+ "work) run the object code and to modify the work, including scripts to\n",
+ "control those activities. However, it does not include the work's\n",
+ "System Libraries, or general-purpose tools or generally available free\n",
+ "programs which are used unmodified in performing those activities but\n",
+ "which are not part of the work. For example, Corresponding Source\n",
+ "includes interface definition files associated with source files for\n",
+ "the work, and the source code for shared libraries and dynamically\n",
+ "linked subprograms that the work is specifically designed to require,\n",
+ "such as by intimate data communication or control flow between those\n",
+ "subprograms and other parts of the work.\n",
+ "\n",
+ " The Corresponding Source need not include anything that users\n",
+ "can regenerate automatically from other parts of the Corresponding\n",
+ "Source.\n",
+ "\n",
+ " The Corresponding Source for a work in source code form is that\n",
+ "same work.\n",
+ "\n",
+ " 2. Basic Permissions.\n",
+ "\n",
+ " All rights granted under this License are granted for the term of\n",
+ "copyright on the Program, and are irrevocable provided the stated\n",
+ "conditions are met. This License explicitly affirms your unlimited\n",
+ "permission to run the unmodified Program. The output from running a\n",
+ "covered work is covered by this License only if the output, given its\n",
+ "content, constitutes a covered work. This License acknowledges your\n",
+ "rights of fair use or other equivalent, as provided by copyright law.\n",
+ "\n",
+ " You may make, run and propagate covered works that you do not\n",
+ "convey, without conditions so long as your license otherwise remains\n",
+ "in force. You may convey covered works to others for the sole purpose\n",
+ "of having them make modifications exclusively for you, or provide you\n",
+ "with facilities for running those works, provided that you comply with\n",
+ "the terms of this License in conveying all material for which you do\n",
+ "not control copyright. Those thus making or running the covered works\n",
+ "for you must do so exclusively on your behalf, under your direction\n",
+ "and control, on terms that prohibit them from making any copies of\n",
+ "your copyrighted material outside their relationship with you.\n",
+ "\n",
+ " Conveying under any other circumstances is permitted solely under\n",
+ "the conditions stated below. Sublicensing is not allowed; section 10\n",
+ "makes it unnecessary.\n",
+ "\n",
+ " 3. Protecting Users' Legal Rights From Anti-Circumvention Law.\n",
+ "\n",
+ " No covered work shall be deemed part of an effective technological\n",
+ "measure under any applicable law fulfilling obligations under article\n",
+ "11 of the WIPO copyright treaty adopted on 20 December 1996, or\n",
+ "similar laws prohibiting or restricting circumvention of such\n",
+ "measures.\n",
+ "\n",
+ " When you convey a covered work, you waive any legal power to forbid\n",
+ "circumvention of technological measures to the extent such circumvention\n",
+ "is effected by exercising rights under this License with respect to\n",
+ "the covered work, and you disclaim any intention to limit operation or\n",
+ "modification of the work as a means of enforcing, against the work's\n",
+ "users, your or third parties' legal rights to forbid circumvention of\n",
+ "technological measures.\n",
+ "\n",
+ " 4. Conveying Verbatim Copies.\n",
+ "\n",
+ " You may convey verbatim copies of the Program's source code as you\n",
+ "receive it, in any medium, provided that you conspicuously and\n",
+ "appropriately publish on each copy an appropriate copyright notice;\n",
+ "keep intact all notices stating that this License and any\n",
+ "non-permissive terms added in accord with section 7 apply to the code;\n",
+ "keep intact all notices of the absence of any warranty; and give all\n",
+ "recipients a copy of this License along with the Program.\n",
+ "\n",
+ " You may charge any price or no price for each copy that you convey,\n",
+ "and you may offer support or warranty protection for a fee.\n",
+ "\n",
+ " 5. Conveying Modified Source Versions.\n",
+ "\n",
+ " You may convey a work based on the Program, or the modifications to\n",
+ "produce it from the Program, in the form of source code under the\n",
+ "terms of section 4, provided that you also meet all of these conditions:\n",
+ "\n",
+ " a) The work must carry prominent notices stating that you modified\n",
+ " it, and giving a relevant date.\n",
+ "\n",
+ " b) The work must carry prominent notices stating that it is\n",
+ " released under this License and any conditions added under section\n",
+ " 7. This requirement modifies the requirement in section 4 to\n",
+ " \"keep intact all notices\".\n",
+ "\n",
+ " c) You must license the entire work, as a whole, under this\n",
+ " License to anyone who comes into possession of a copy. This\n",
+ " License will therefore apply, along with any applicable section 7\n",
+ " additional terms, to the whole of the work, and all its parts,\n",
+ " regardless of how they are packaged. This License gives no\n",
+ " permission to license the work in any other way, but it does not\n",
+ " invalidate such permission if you have separately received it.\n",
+ "\n",
+ " d) If the work has interactive user interfaces, each must display\n",
+ " Appropriate Legal Notices; however, if the Program has interactive\n",
+ " interfaces that do not display Appropriate Legal Notices, your\n",
+ " work need not make them do so.\n",
+ "\n",
+ " A compilation of a covered work with other separate and independent\n",
+ "works, which are not by their nature extensions of the covered work,\n",
+ "and which are not combined with it such as to form a larger program,\n",
+ "in or on a volume of a storage or distribution medium, is called an\n",
+ "\"aggregate\" if the compilation and its resulting copyright are not\n",
+ "used to limit the access or legal rights of the compilation's users\n",
+ "beyond what the individual works permit. Inclusion of a covered work\n",
+ "in an aggregate does not cause this License to apply to the other\n",
+ "parts of the aggregate.\n",
+ "\n",
+ " 6. Conveying Non-Source Forms.\n",
+ "\n",
+ " You may convey a covered work in object code form under the terms\n",
+ "of sections 4 and 5, provided that you also convey the\n",
+ "machine-readable Corresponding Source under the terms of this License,\n",
+ "in one of these ways:\n",
+ "\n",
+ " a) Convey the object code in, or embodied in, a physical product\n",
+ " (including a physical distribution medium), accompanied by the\n",
+ " Corresponding Source fixed on a durable physical medium\n",
+ " customarily used for software interchange.\n",
+ "\n",
+ " b) Convey the object code in, or embodied in, a physical product\n",
+ " (including a physical distribution medium), accompanied by a\n",
+ " written offer, valid for at least three years and valid for as\n",
+ " long as you offer spare parts or customer support for that product\n",
+ " model, to give anyone who possesses the object code either (1) a\n",
+ " copy of the Corresponding Source for all the software in the\n",
+ " product that is covered by this License, on a durable physical\n",
+ " medium customarily used for software interchange, for a price no\n",
+ " more than your reasonable cost of physically performing this\n",
+ " conveying of source, or (2) access to copy the\n",
+ " Corresponding Source from a network server at no charge.\n",
+ "\n",
+ " c) Convey individual copies of the object code with a copy of the\n",
+ " written offer to provide the Corresponding Source. This\n",
+ " alternative is allowed only occasionally and noncommercially, and\n",
+ " only if you received the object code with such an offer, in accord\n",
+ " with subsection 6b.\n",
+ "\n",
+ " d) Convey the object code by offering access from a designated\n",
+ " place (gratis or for a charge), and offer equivalent access to the\n",
+ " Corresponding Source in the same way through the same place at no\n",
+ " further charge. You need not require recipients to copy the\n",
+ " Corresponding Source along with the object code. If the place to\n",
+ " copy the object code is a network server, the Corresponding Source\n",
+ " may be on a different server (operated by you or a third party)\n",
+ " that supports equivalent copying facilities, provided you maintain\n",
+ " clear directions next to the object code saying where to find the\n",
+ " Corresponding Source. Regardless of what server hosts the\n",
+ " Corresponding Source, you remain obligated to ensure that it is\n",
+ " available for as long as needed to satisfy these requirements.\n",
+ "\n",
+ " e) Convey the object code using peer-to-peer transmission, provided\n",
+ " you inform other peers where the object code and Corresponding\n",
+ " Source of the work are being offered to the general public at no\n",
+ " charge under subsection 6d.\n",
+ "\n",
+ " A separable portion of the object code, whose source code is excluded\n",
+ "from the Corresponding Source as a System Library, need not be\n",
+ "included in conveying the object code work.\n",
+ "\n",
+ " A \"User Product\" is either (1) a \"consumer product\", which means any\n",
+ "tangible personal property which is normally used for personal, family,\n",
+ "or household purposes, or (2) anything designed or sold for incorporation\n",
+ "into a dwelling. In determining whether a product is a consumer product,\n",
+ "doubtful cases shall be resolved in favor of coverage. For a particular\n",
+ "product received by a particular user, \"normally used\" refers to a\n",
+ "typical or common use of that class of product, regardless of the status\n",
+ "of the particular user or of the way in which the particular user\n",
+ "actually uses, or expects or is expected to use, the product. A product\n",
+ "is a consumer product regardless of whether the product has substantial\n",
+ "commercial, industrial or non-consumer uses, unless such uses represent\n",
+ "the only significant mode of use of the product.\n",
+ "\n",
+ " \"Installation Information\" for a User Product means any methods,\n",
+ "procedures, authorization keys, or other information required to install\n",
+ "and execute modified versions of a covered work in that User Product from\n",
+ "a modified version of its Corresponding Source. The information must\n",
+ "suffice to ensure that the continued functioning of the modified object\n",
+ "code is in no case prevented or interfered with solely because\n",
+ "modification has been made.\n",
+ "\n",
+ " If you convey an object code work under this section in, or with, or\n",
+ "specifically for use in, a User Product, and the conveying occurs as\n",
+ "part of a transaction in which the right of possession and use of the\n",
+ "User Product is transferred to the recipient in perpetuity or for a\n",
+ "fixed term (regardless of how the transaction is characterized), the\n",
+ "Corresponding Source conveyed under this section must be accompanied\n",
+ "by the Installation Information. But this requirement does not apply\n",
+ "if neither you nor any third party retains the ability to install\n",
+ "modified object code on the User Product (for example, the work has\n",
+ "been installed in ROM).\n",
+ "\n",
+ " The requirement to provide Installation Information does not include a\n",
+ "requirement to continue to provide support service, warranty, or updates\n",
+ "for a work that has been modified or installed by the recipient, or for\n",
+ "the User Product in which it has been modified or installed. Access to a\n",
+ "network may be denied when the modification itself materially and\n",
+ "adversely affects the operation of the network or violates the rules and\n",
+ "protocols for communication across the network.\n",
+ "\n",
+ " Corresponding Source conveyed, and Installation Information provided,\n",
+ "in accord with this section must be in a format that is publicly\n",
+ "documented (and with an implementation available to the public in\n",
+ "source code form), and must require no special password or key for\n",
+ "unpacking, reading or copying.\n",
+ "\n",
+ " 7. Additional Terms.\n",
+ "\n",
+ " \"Additional permissions\" are terms that supplement the terms of this\n",
+ "License by making exceptions from one or more of its conditions.\n",
+ "Additional permissions that are applicable to the entire Program shall\n",
+ "be treated as though they were included in this License, to the extent\n",
+ "that they are valid under applicable law. If additional permissions\n",
+ "apply only to part of the Program, that part may be used separately\n",
+ "under those permissions, but the entire Program remains governed by\n",
+ "this License without regard to the additional permissions.\n",
+ "\n",
+ " When you convey a copy of a covered work, you may at your option\n",
+ "remove any additional permissions from that copy, or from any part of\n",
+ "it. (Additional permissions may be written to require their own\n",
+ "removal in certain cases when you modify the work.) You may place\n",
+ "additional permissions on material, added by you to a covered work,\n",
+ "for which you have or can give appropriate copyright permission.\n",
+ "\n",
+ " Notwithstanding any other provision of this License, for material you\n",
+ "add to a covered work, you may (if authorized by the copyright holders of\n",
+ "that material) supplement the terms of this License with terms:\n",
+ "\n",
+ " a) Disclaiming warranty or limiting liability differently from the\n",
+ " terms of sections 15 and 16 of this License; or\n",
+ "\n",
+ " b) Requiring preservation of specified reasonable legal notices or\n",
+ " author attributions in that material or in the Appropriate Legal\n",
+ " Notices displayed by works containing it; or\n",
+ "\n",
+ " c) Prohibiting misrepresentation of the origin of that material, or\n",
+ " requiring that modified versions of such material be marked in\n",
+ " reasonable ways as different from the original version; or\n",
+ "\n",
+ " d) Limiting the use for publicity purposes of names of licensors or\n",
+ " authors of the material; or\n",
+ "\n",
+ " e) Declining to grant rights under trademark law for use of some\n",
+ " trade names, trademarks, or service marks; or\n",
+ "\n",
+ " f) Requiring indemnification of licensors and authors of that\n",
+ " material by anyone who conveys the material (or modified versions of\n",
+ " it) with contractual assumptions of liability to the recipient, for\n",
+ " any liability that these contractual assumptions directly impose on\n",
+ " those licensors and authors.\n",
+ "\n",
+ " All other non-permissive additional terms are considered \"further\n",
+ "restrictions\" within the meaning of section 10. If the Program as you\n",
+ "received it, or any part of it, contains a notice stating that it is\n",
+ "governed by this License along with a term that is a further\n",
+ "restriction, you may remove that term. If a license document contains\n",
+ "a further restriction but permits relicensing or conveying under this\n",
+ "License, you may add to a covered work material governed by the terms\n",
+ "of that license document, provided that the further restriction does\n",
+ "not survive such relicensing or conveying.\n",
+ "\n",
+ " If you add terms to a covered work in accord with this section, you\n",
+ "must place, in the relevant source files, a statement of the\n",
+ "additional terms that apply to those files, or a notice indicating\n",
+ "where to find the applicable terms.\n",
+ "\n",
+ " Additional terms, permissive or non-permissive, may be stated in the\n",
+ "form of a separately written license, or stated as exceptions;\n",
+ "the above requirements apply either way.\n",
+ "\n",
+ " 8. Termination.\n",
+ "\n",
+ " You may not propagate or modify a covered work except as expressly\n",
+ "provided under this License. Any attempt otherwise to propagate or\n",
+ "modify it is void, and will automatically terminate your rights under\n",
+ "this License (including any patent licenses granted under the third\n",
+ "paragraph of section 11).\n",
+ "\n",
+ " However, if you cease all violation of this License, then your\n",
+ "license from a particular copyright holder is reinstated (a)\n",
+ "provisionally, unless and until the copyright holder explicitly and\n",
+ "finally terminates your license, and (b) permanently, if the copyright\n",
+ "holder fails to notify you of the violation by some reasonable means\n",
+ "prior to 60 days after the cessation.\n",
+ "\n",
+ " Moreover, your license from a particular copyright holder is\n",
+ "reinstated permanently if the copyright holder notifies you of the\n",
+ "violation by some reasonable means, this is the first time you have\n",
+ "received notice of violation of this License (for any work) from that\n",
+ "copyright holder, and you cure the violation prior to 30 days after\n",
+ "your receipt of the notice.\n",
+ "\n",
+ " Termination of your rights under this section does not terminate the\n",
+ "licenses of parties who have received copies or rights from you under\n",
+ "this License. If your rights have been terminated and not permanently\n",
+ "reinstated, you do not qualify to receive new licenses for the same\n",
+ "material under section 10.\n",
+ "\n",
+ " 9. Acceptance Not Required for Having Copies.\n",
+ "\n",
+ " You are not required to accept this License in order to receive or\n",
+ "run a copy of the Program. Ancillary propagation of a covered work\n",
+ "occurring solely as a consequence of using peer-to-peer transmission\n",
+ "to receive a copy likewise does not require acceptance. However,\n",
+ "nothing other than this License grants you permission to propagate or\n",
+ "modify any covered work. These actions infringe copyright if you do\n",
+ "not accept this License. Therefore, by modifying or propagating a\n",
+ "covered work, you indicate your acceptance of this License to do so.\n",
+ "\n",
+ " 10. Automatic Licensing of Downstream Recipients.\n",
+ "\n",
+ " Each time you convey a covered work, the recipient automatically\n",
+ "receives a license from the original licensors, to run, modify and\n",
+ "propagate that work, subject to this License. You are not responsible\n",
+ "for enforcing compliance by third parties with this License.\n",
+ "\n",
+ " An \"entity transaction\" is a transaction transferring control of an\n",
+ "organization, or substantially all assets of one, or subdividing an\n",
+ "organization, or merging organizations. If propagation of a covered\n",
+ "work results from an entity transaction, each party to that\n",
+ "transaction who receives a copy of the work also receives whatever\n",
+ "licenses to the work the party's predecessor in interest had or could\n",
+ "give under the previous paragraph, plus a right to possession of the\n",
+ "Corresponding Source of the work from the predecessor in interest, if\n",
+ "the predecessor has it or can get it with reasonable efforts.\n",
+ "\n",
+ " You may not impose any further restrictions on the exercise of the\n",
+ "rights granted or affirmed under this License. For example, you may\n",
+ "not impose a license fee, royalty, or other charge for exercise of\n",
+ "rights granted under this License, and you may not initiate litigation\n",
+ "(including a cross-claim or counterclaim in a lawsuit) alleging that\n",
+ "any patent claim is infringed by making, using, selling, offering for\n",
+ "sale, or importing the Program or any portion of it.\n",
+ "\n",
+ " 11. Patents.\n",
+ "\n",
+ " A \"contributor\" is a copyright holder who authorizes use under this\n",
+ "License of the Program or a work on which the Program is based. The\n",
+ "work thus licensed is called the contributor's \"contributor version\".\n",
+ "\n",
+ " A contributor's \"essential patent claims\" are all patent claims\n",
+ "owned or controlled by the contributor, whether already acquired or\n",
+ "hereafter acquired, that would be infringed by some manner, permitted\n",
+ "by this License, of making, using, or selling its contributor version,\n",
+ "but do not include claims that would be infringed only as a\n",
+ "consequence of further modification of the contributor version. For\n",
+ "purposes of this definition, \"control\" includes the right to grant\n",
+ "patent sublicenses in a manner consistent with the requirements of\n",
+ "this License.\n",
+ "\n",
+ " Each contributor grants you a non-exclusive, worldwide, royalty-free\n",
+ "patent license under the contributor's essential patent claims, to\n",
+ "make, use, sell, offer for sale, import and otherwise run, modify and\n",
+ "propagate the contents of its contributor version.\n",
+ "\n",
+ " In the following three paragraphs, a \"patent license\" is any express\n",
+ "agreement or commitment, however denominated, not to enforce a patent\n",
+ "(such as an express permission to practice a patent or covenant not to\n",
+ "sue for patent infringement). To \"grant\" such a patent license to a\n",
+ "party means to make such an agreement or commitment not to enforce a\n",
+ "patent against the party.\n",
+ "\n",
+ " If you convey a covered work, knowingly relying on a patent license,\n",
+ "and the Corresponding Source of the work is not available for anyone\n",
+ "to copy, free of charge and under the terms of this License, through a\n",
+ "publicly available network server or other readily accessible means,\n",
+ "then you must either (1) cause the Corresponding Source to be so\n",
+ "available, or (2) arrange to deprive yourself of the benefit of the\n",
+ "patent license for this particular work, or (3) arrange, in a manner\n",
+ "consistent with the requirements of this License, to extend the patent\n",
+ "license to downstream recipients. \"Knowingly relying\" means you have\n",
+ "actual knowledge that, but for the patent license, your conveying the\n",
+ "covered work in a country, or your recipient's use of the covered work\n",
+ "in a country, would infringe one or more identifiable patents in that\n",
+ "country that you have reason to believe are valid.\n",
+ "\n",
+ " If, pursuant to or in connection with a single transaction or\n",
+ "arrangement, you convey, or propagate by procuring conveyance of, a\n",
+ "covered work, and grant a patent license to some of the parties\n",
+ "receiving the covered work authorizing them to use, propagate, modify\n",
+ "or convey a specific copy of the covered work, then the patent license\n",
+ "you grant is automatically extended to all recipients of the covered\n",
+ "work and works based on it.\n",
+ "\n",
+ " A patent license is \"discriminatory\" if it does not include within\n",
+ "the scope of its coverage, prohibits the exercise of, or is\n",
+ "conditioned on the non-exercise of one or more of the rights that are\n",
+ "specifically granted under this License. You may not convey a covered\n",
+ "work if you are a party to an arrangement with a third party that is\n",
+ "in the business of distributing software, under which you make payment\n",
+ "to the third party based on the extent of your activity of conveying\n",
+ "the work, and under which the third party grants, to any of the\n",
+ "parties who would receive the covered work from you, a discriminatory\n",
+ "patent license (a) in connection with copies of the covered work\n",
+ "conveyed by you (or copies made from those copies), or (b) primarily\n",
+ "for and in connection with specific products or compilations that\n",
+ "contain the covered work, unless you entered into that arrangement,\n",
+ "or that patent license was granted, prior to 28 March 2007.\n",
+ "\n",
+ " Nothing in this License shall be construed as excluding or limiting\n",
+ "any implied license or other defenses to infringement that may\n",
+ "otherwise be available to you under applicable patent law.\n",
+ "\n",
+ " 12. No Surrender of Others' Freedom.\n",
+ "\n",
+ " If conditions are imposed on you (whether by court order, agreement or\n",
+ "otherwise) that contradict the conditions of this License, they do not\n",
+ "excuse you from the conditions of this License. If you cannot convey a\n",
+ "covered work so as to satisfy simultaneously your obligations under this\n",
+ "License and any other pertinent obligations, then as a consequence you may\n",
+ "not convey it at all. For example, if you agree to terms that obligate you\n",
+ "to collect a royalty for further conveying from those to whom you convey\n",
+ "the Program, the only way you could satisfy both those terms and this\n",
+ "License would be to refrain entirely from conveying the Program.\n",
+ "\n",
+ " 13. Remote Network Interaction; Use with the GNU General Public License.\n",
+ "\n",
+ " Notwithstanding any other provision of this License, if you modify the\n",
+ "Program, your modified version must prominently offer all users\n",
+ "interacting with it remotely through a computer network (if your version\n",
+ "supports such interaction) an opportunity to receive the Corresponding\n",
+ "Source of your version by providing access to the Corresponding Source\n",
+ "from a network server at no charge, through some standard or customary\n",
+ "means of facilitating copying of software. This Corresponding Source\n",
+ "shall include the Corresponding Source for any work covered by version 3\n",
+ "of the GNU General Public License that is incorporated pursuant to the\n",
+ "following paragraph.\n",
+ "\n",
+ " Notwithstanding any other provision of this License, you have\n",
+ "permission to link or combine any covered work with a work licensed\n",
+ "under version 3 of the GNU General Public License into a single\n",
+ "combined work, and to convey the resulting work. The terms of this\n",
+ "License will continue to apply to the part which is the covered work,\n",
+ "but the work with which it is combined will remain governed by version\n",
+ "3 of the GNU General Public License.\n",
+ "\n",
+ " 14. Revised Versions of this License.\n",
+ "\n",
+ " The Free Software Foundation may publish revised and/or new versions of\n",
+ "the GNU Affero General Public License from time to time. Such new versions\n",
+ "will be similar in spirit to the present version, but may differ in detail to\n",
+ "address new problems or concerns.\n",
+ "\n",
+ " Each version is given a distinguishing version number. If the\n",
+ "Program specifies that a certain numbered version of the GNU Affero General\n",
+ "Public License \"or any later version\" applies to it, you have the\n",
+ "option of following the terms and conditions either of that numbered\n",
+ "version or of any later version published by the Free Software\n",
+ "Foundation. If the Program does not specify a version number of the\n",
+ "GNU Affero General Public License, you may choose any version ever published\n",
+ "by the Free Software Foundation.\n",
+ "\n",
+ " If the Program specifies that a proxy can decide which future\n",
+ "versions of the GNU Affero General Public License can be used, that proxy's\n",
+ "public statement of acceptance of a version permanently authorizes you\n",
+ "to choose that version for the Program.\n",
+ "\n",
+ " Later license versions may give you additional or different\n",
+ "permissions. However, no additional obligations are imposed on any\n",
+ "author or copyright holder as a result of your choosing to follow a\n",
+ "later version.\n",
+ "\n",
+ " 15. Disclaimer of Warranty.\n",
+ "\n",
+ " THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY\n",
+ "APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT\n",
+ "HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM \"AS IS\" WITHOUT WARRANTY\n",
+ "OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO,\n",
+ "THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR\n",
+ "PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM\n",
+ "IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF\n",
+ "ALL NECESSARY SERVICING, REPAIR OR CORRECTION.\n",
+ "\n",
+ " 16. Limitation of Liability.\n",
+ "\n",
+ " IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING\n",
+ "WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS\n",
+ "THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY\n",
+ "GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE\n",
+ "USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF\n",
+ "DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD\n",
+ "PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS),\n",
+ "EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF\n",
+ "SUCH DAMAGES.\n",
+ "\n",
+ " 17. Interpretation of Sections 15 and 16.\n",
+ "\n",
+ " If the disclaimer of warranty and limitation of liability provided\n",
+ "above cannot be given local legal effect according to their terms,\n",
+ "reviewing courts shall apply local law that most closely approximates\n",
+ "an absolute waiver of all civil liability in connection with the\n",
+ "Program, unless a warranty or assumption of liability accompanies a\n",
+ "copy of the Program in return for a fee.\n",
+ "\n",
+ " END OF TERMS AND CONDITIONS\n",
+ "\n",
+ " How to Apply These Terms to Your New Programs\n",
+ "\n",
+ " If you develop a new program, and you want it to be of the greatest\n",
+ "possible use to the public, the best way to achieve this is to make it\n",
+ "free software which everyone can redistribute and change under these terms.\n",
+ "\n",
+ " To do so, attach the following notices to the program. It is safest\n",
+ "to attach them to the start of each source file to most effectively\n",
+ "state the exclusion of warranty; and each file should have at least\n",
+ "the \"copyright\" line and a pointer to where the full notice is found.\n",
+ "\n",
+ "    <one line to give the program's name and a brief idea of what it does.>\n",
+ "    Copyright (C) <year>  <name of author>\n",
+ "\n",
+ " This program is free software: you can redistribute it and/or modify\n",
+ " it under the terms of the GNU Affero General Public License as published by\n",
+ " the Free Software Foundation, either version 3 of the License, or\n",
+ " (at your option) any later version.\n",
+ "\n",
+ " This program is distributed in the hope that it will be useful,\n",
+ " but WITHOUT ANY WARRANTY; without even the implied warranty of\n",
+ " MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the\n",
+ " GNU Affero General Public License for more details.\n",
+ "\n",
+ " You should have received a copy of the GNU Affero General Public License\n",
+ " along with this program. If not, see <https://www.gnu.org/licenses/>.\n",
+ "\n",
+ "Also add information on how to contact you by electronic and paper mail.\n",
+ "\n",
+ " If your software can interact with users remotely through a computer\n",
+ "network, you should also make sure that it provides a way for users to\n",
+ "get its source. For example, if your program is a web application, its\n",
+ "interface could display a \"Source\" link that leads users to an archive\n",
+ "of the code. There are many ways you could offer source, and different\n",
+ "solutions will be better for different programs; see section 13 for the\n",
+ "specific requirements.\n",
+ "\n",
+ " You should also get your employer (if you work as a programmer) or school,\n",
+ "if any, to sign a \"copyright disclaimer\" for the program, if necessary.\n",
+ "For more information on this, and how to apply and follow the GNU AGPL, see\n",
+ "<https://www.gnu.org/licenses/>."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "License"
+ },
+ "source": [
+ "Copyright 2023 Timothy Brooks, Aleksander Holynski, Alexei A. Efros\n",
+ "\n",
+ "Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the \"Software\"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:\n",
+ "\n",
+ "The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.\n",
+ "\n",
+ "THE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.\n",
+ "\n",
+ "Portions of code and models (such as pretrained checkpoints, which are fine-tuned starting from released Stable Diffusion checkpoints) are derived from the Stable Diffusion codebase (https://github.com/CompVis/stable-diffusion). Further restrictions may apply. Please consult the Stable Diffusion license `stable_diffusion/LICENSE`. Modified code is denoted as such in comments at the start of each file.\n",
+ "\n",
+ "\n",
+ "Copyright (c) 2022 Robin Rombach and Patrick Esser and contributors\n",
+ "\n",
+ "CreativeML Open RAIL-M\n",
+ "dated August 22, 2022\n",
+ "\n",
+ "Section I: PREAMBLE\n",
+ "\n",
+ "Multimodal generative models are being widely adopted and used, and have the potential to transform the way artists, among other individuals, conceive and benefit from AI or ML technologies as a tool for content creation.\n",
+ "\n",
+ "Notwithstanding the current and potential benefits that these artifacts can bring to society at large, there are also concerns about potential misuses of them, either due to their technical limitations or ethical considerations.\n",
+ "\n",
+ "In short, this license strives for both the open and responsible downstream use of the accompanying model. When it comes to the open character, we took inspiration from open source permissive licenses regarding the grant of IP rights. Referring to the downstream responsible use, we added use-based restrictions not permitting the use of the Model in very specific scenarios, in order for the licensor to be able to enforce the license in case potential misuses of the Model may occur. At the same time, we strive to promote open and responsible research on generative models for art and content generation.\n",
+ "\n",
+ "Even though downstream derivative versions of the model could be released under different licensing terms, the latter will always have to include - at minimum - the same use-based restrictions as the ones in the original license (this license). We believe in the intersection between open and responsible AI development; thus, this License aims to strike a balance between both in order to enable responsible open-science in the field of AI.\n",
+ "\n",
+ "This License governs the use of the model (and its derivatives) and is informed by the model card associated with the model.\n",
+ "\n",
+ "NOW THEREFORE, You and Licensor agree as follows:\n",
+ "\n",
+ "1. Definitions\n",
+ "\n",
+ "- \"License\" means the terms and conditions for use, reproduction, and Distribution as defined in this document.\n",
+ "- \"Data\" means a collection of information and/or content extracted from the dataset used with the Model, including to train, pretrain, or otherwise evaluate the Model. The Data is not licensed under this License.\n",
+ "- \"Output\" means the results of operating a Model as embodied in informational content resulting therefrom.\n",
+ "- \"Model\" means any accompanying machine-learning based assemblies (including checkpoints), consisting of learnt weights, parameters (including optimizer states), corresponding to the model architecture as embodied in the Complementary Material, that have been trained or tuned, in whole or in part on the Data, using the Complementary Material.\n",
+ "- \"Derivatives of the Model\" means all modifications to the Model, works based on the Model, or any other model which is created or initialized by transfer of patterns of the weights, parameters, activations or output of the Model, to the other model, in order to cause the other model to perform similarly to the Model, including - but not limited to - distillation methods entailing the use of intermediate data representations or methods based on the generation of synthetic data by the Model for training the other model.\n",
+ "- \"Complementary Material\" means the accompanying source code and scripts used to define, run, load, benchmark or evaluate the Model, and used to prepare data for training or evaluation, if any. This includes any accompanying documentation, tutorials, examples, etc, if any.\n",
+ "- \"Distribution\" means any transmission, reproduction, publication or other sharing of the Model or Derivatives of the Model to a third party, including providing the Model as a hosted service made available by electronic or other remote means - e.g. API-based or web access.\n",
+ "- \"Licensor\" means the copyright owner or entity authorized by the copyright owner that is granting the License, including the persons or entities that may have rights in the Model and/or distributing the Model.\n",
+ "- \"You\" (or \"Your\") means an individual or Legal Entity exercising permissions granted by this License and/or making use of the Model for whichever purpose and in any field of use, including usage of the Model in an end-use application - e.g. chatbot, translator, image generator.\n",
+ "- \"Third Parties\" means individuals or legal entities that are not under common control with Licensor or You.\n",
+ "- \"Contribution\" means any work of authorship, including the original version of the Model and any modifications or additions to that Model or Derivatives of the Model thereof, that is intentionally submitted to Licensor for inclusion in the Model by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, \"submitted\" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Model, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as \"Not a Contribution.\"\n",
+ "- \"Contributor\" means Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Model.\n",
+ "\n",
+ "Section II: INTELLECTUAL PROPERTY RIGHTS\n",
+ "\n",
+ "Both copyright and patent grants apply to the Model, Derivatives of the Model and Complementary Material. The Model and Derivatives of the Model are subject to additional terms as described in Section III.\n",
+ "\n",
+ "2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare, publicly display, publicly perform, sublicense, and distribute the Complementary Material, the Model, and Derivatives of the Model.\n",
+ "3. Grant of Patent License. Subject to the terms and conditions of this License and where and as applicable, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this paragraph) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Model and the Complementary Material, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Model to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Model and/or Complementary Material or a Contribution incorporated within the Model and/or Complementary Material constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for the Model and/or Work shall terminate as of the date such litigation is asserted or filed.\n",
+ "\n",
+ "Section III: CONDITIONS OF USAGE, DISTRIBUTION AND REDISTRIBUTION\n",
+ "\n",
+ "4. Distribution and Redistribution. You may host for Third Party remote access purposes (e.g. software-as-a-service), reproduce and distribute copies of the Model or Derivatives of the Model thereof in any medium, with or without modifications, provided that You meet the following conditions:\n",
+ "Use-based restrictions as referenced in paragraph 5 MUST be included as an enforceable provision by You in any type of legal agreement (e.g. a license) governing the use and/or distribution of the Model or Derivatives of the Model, and You shall give notice to subsequent users You Distribute to, that the Model or Derivatives of the Model are subject to paragraph 5. This provision does not apply to the use of Complementary Material.\n",
+ "You must give any Third Party recipients of the Model or Derivatives of the Model a copy of this License;\n",
+ "You must cause any modified files to carry prominent notices stating that You changed the files;\n",
+ "You must retain all copyright, patent, trademark, and attribution notices excluding those notices that do not pertain to any part of the Model, Derivatives of the Model.\n",
+ "You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions - respecting paragraph 4.a. - for use, reproduction, or Distribution of Your modifications, or for any such Derivatives of the Model as a whole, provided Your use, reproduction, and Distribution of the Model otherwise complies with the conditions stated in this License.\n",
+ "5. Use-based restrictions. The restrictions set forth in Attachment A are considered Use-based restrictions. Therefore You cannot use the Model and the Derivatives of the Model for the specified restricted uses. You may use the Model subject to this License, including only for lawful purposes and in accordance with the License. Use may include creating any content with, finetuning, updating, running, training, evaluating and/or reparametrizing the Model. You shall require all of Your users who use the Model or a Derivative of the Model to comply with the terms of this paragraph (paragraph 5).\n",
+ "6. The Output You Generate. Except as set forth herein, Licensor claims no rights in the Output You generate using the Model. You are accountable for the Output you generate and its subsequent uses. No use of the output can contravene any provision as stated in the License.\n",
+ "\n",
+ "Section IV: OTHER PROVISIONS\n",
+ "\n",
+ "7. Updates and Runtime Restrictions. To the maximum extent permitted by law, Licensor reserves the right to restrict (remotely or otherwise) usage of the Model in violation of this License, update the Model through electronic means, or modify the Output of the Model based on updates. You shall undertake reasonable efforts to use the latest version of the Model.\n",
+ "8. Trademarks and related. Nothing in this License permits You to make use of Licensors’ trademarks, trade names, logos or to otherwise suggest endorsement or misrepresent the relationship between the parties; and any rights not expressly granted herein are reserved by the Licensors.\n",
+ "9. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Model and the Complementary Material (and each Contributor provides its Contributions) on an \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Model, Derivatives of the Model, and the Complementary Material and assume any risks associated with Your exercise of permissions under this License.\n",
+ "10. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Model and the Complementary Material (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages.\n",
+ "11. Accepting Warranty or Additional Liability. While redistributing the Model, Derivatives of the Model and the Complementary Material thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability.\n",
+ "12. If any provision of this License is held to be invalid, illegal or unenforceable, the remaining provisions shall be unaffected thereby and remain valid as if such provision had not been set forth herein.\n",
+ "\n",
+ "END OF TERMS AND CONDITIONS\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "Attachment A\n",
+ "\n",
+ "Use Restrictions\n",
+ "\n",
+ "You agree not to use the Model or Derivatives of the Model:\n",
+ "- In any way that violates any applicable national, federal, state, local or international law or regulation;\n",
+ "- For the purpose of exploiting, harming or attempting to exploit or harm minors in any way;\n",
+ "- To generate or disseminate verifiably false information and/or content with the purpose of harming others;\n",
+ "- To generate or disseminate personal identifiable information that can be used to harm an individual;\n",
+ "- To defame, disparage or otherwise harass others;\n",
+ "- For fully automated decision making that adversely impacts an individual’s legal rights or otherwise creates or modifies a binding, enforceable obligation;\n",
+ "- For any use intended to or which has the effect of discriminating against or harming individuals or groups based on online or offline social behavior or known or predicted personal or personality characteristics;\n",
+ "- To exploit any of the vulnerabilities of a specific group of persons based on their age, social, physical or mental characteristics, in order to materially distort the behavior of a person pertaining to that group in a manner that causes or is likely to cause that person or another person physical or psychological harm;\n",
+ "- For any use intended to or which has the effect of discriminating against individuals or groups based on legally protected characteristics or categories;\n",
+ "- To provide medical advice and medical results interpretation;\n",
+ "- To generate or disseminate information for the purpose to be used for administration of justice, law enforcement, immigration or asylum processes, such as predicting an individual will commit fraud/crime commitment (e.g. by text profiling, drawing causal relationships between assertions made in documents, indiscriminate and arbitrarily-targeted use).\n",
+ "\n",
+ "Licensed under the MIT License\n",
+ "\n",
+ "Copyright (c) 2019 Intel ISL (Intel Intelligent Systems Lab)\n",
+ "\n",
+ "Copyright (c) 2021 Maxwell Ingham\n",
+ "\n",
+ "Copyright (c) 2022 Adam Letts\n",
+ "\n",
+ "Copyright (c) 2022 Alex Spirin\n",
+ "\n",
+ "Copyright (c) 2022 lowfuel\n",
+ "\n",
+ "Copyright (c) 2021-2022 Katherine Crowson\n",
+ "\n",
+ "Permission is hereby granted, free of charge, to any person obtaining a copy\n",
+ "of this software and associated documentation files (the \"Software\"), to deal\n",
+ "in the Software without restriction, including without limitation the rights\n",
+ "to use, copy, modify, merge, publish, distribute, sublicense, and/or sell\n",
+ "copies of the Software, and to permit persons to whom the Software is\n",
+ "furnished to do so, subject to the following conditions:\n",
+ "\n",
+ "The above copyright notice and this permission notice shall be included in\n",
+ "all copies or substantial portions of the Software.\n",
+ "\n",
+ "THE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\n",
+ "IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\n",
+ "FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\n",
+ "AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\n",
+ "LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\n",
+ "OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN\n",
+ "THE SOFTWARE."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "SetupTop"
+ },
+ "source": [
+ "# 1. Set Up"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "cellView": "form",
+ "id": "PrepFolders"
+ },
+ "outputs": [],
+ "source": [
+ "#@title 1.1 Prepare Folders\n",
+ "import subprocess, os, sys, ipykernel\n",
+ "\n",
+ "#cell execution check thanks to #soze\n",
+ "\n",
+ "\n",
+ "# Execution registry: one flag per notebook cell, set to True by the cell itself on its last line.\n",
+ "# check_execution() walks this dict in insertion order, so the key order here is the order in\n",
+ "# which the cells are expected to run.\n",
+ "executed_cells = {\n",
+ "    'prepare_folders':False,\n",
+ "    'install_pytorch':False,\n",
+ "    'install_sd_dependencies':False,\n",
+ "    'import_dependencies':False,\n",
+ "    'basic_settings':False,\n",
+ "    'animation_settings':False,\n",
+ "    'video_input_settings':False,\n",
+ "    'video_masking':False,\n",
+ "    'generate_optical_flow':False,\n",
+ "    'load_model':False,\n",
+ "    'tiled_vae':False,\n",
+ "    'save_loaded_model':False,\n",
+ "    'clip_guidance':False,\n",
+ "    'brightness_adjustment':False,\n",
+ "    'content_aware_scheduling':False,\n",
+ "    'plot_threshold_vs_frame_difference':False,\n",
+ "    'create_schedules':False,\n",
+ "    'frame_captioning':False,\n",
+ "    'flow_and_turbo_settings':False,\n",
+ "    'consistency_maps_mixing':False,\n",
+ "    'seed_and_grad_settings':False,\n",
+ "    'prompts':False,\n",
+ "    'warp_turbo_smooth_settings':False,\n",
+ "    'video_mask_settings': False,\n",
+ "    'frame_correction':False,\n",
+ "    'main_settings':False,\n",
+ "    'advanced':False,\n",
+ "    'lora':False,\n",
+ "    'reference_controlnet':False,\n",
+ "    'GUI':False,\n",
+ "    'do_the_run':False\n",
+ "}\n",
+ "\n",
+ "# Human-readable cell titles, keyed like executed_cells. check_execution() inserts them into the\n",
+ "# error it raises when an earlier cell has not completed.\n",
+ "executed_cells_errors = {\n",
+ "    'prepare_folders': '1.1 Prepare folders',\n",
+ "    'install_pytorch':'1.2 Install pytorch',\n",
+ "    'install_sd_dependencies':'1.3 Install SD Dependencies',\n",
+ "    'import_dependencies':'1.4 Import dependencies, define functions',\n",
+ "    'basic_settings':'2.Settings - Basic Settings',\n",
+ "    'animation_settings': '2.Settings - Animation Settings',\n",
+ "    'video_input_settings':'2.Settings - Video Input Settings',\n",
+ "    'video_masking':'2.Settings - Video Masking',\n",
+ "    'generate_optical_flow':'Optical map settings - Generate optical flow and consistency maps',\n",
+ "    'load_model':'Load up a stable. - define SD + K functions, load model',\n",
+ "    'tiled_vae':'Extra features - Tiled VAE',\n",
+ "    'save_loaded_model':'Extra features - Save loaded model',\n",
+ "    'clip_guidance':'CLIP guidance - CLIP guidance settings',\n",
+ "    'brightness_adjustment':'Automatic Brightness Adjustment',\n",
+ "    'content_aware_scheduling':'Content-aware scheduing - Content-aware scheduing',\n",
+ "    'plot_threshold_vs_frame_difference':'Content-aware scheduing - Plot threshold vs frame difference',\n",
+ "    'create_schedules':'Content-aware scheduing - Create schedules from frame difference',\n",
+ "    'frame_captioning':'Frame captioning - Generate captions for keyframes',\n",
+ "    'flow_and_turbo_settings':'Render settings - Non-gui - Flow and turbo settings',\n",
+ "    'consistency_maps_mixing':'Render settings - Non-gui - Consistency map mixing',\n",
+ "    'seed_and_grad_settings':'Render settings - Non-gui - Seed and grad Settings',\n",
+ "    'prompts':'Render settings - Non-gui - Prompts',\n",
+ "    'warp_turbo_smooth_settings':'Render settings - Non-gui - Warp Turbo Smooth Settings',\n",
+ "    'video_mask_settings':'Render settings - Non-gui - Video mask settings',\n",
+ "    'frame_correction':'Render settings - Non-gui - Frame correction',\n",
+ "    'main_settings':'Render settings - Non-gui - Main settings',\n",
+ "    'advanced':'Render settings - Non-gui - Advanced',\n",
+ "    'lora': 'LORA & embedding paths',\n",
+ "    'reference_controlnet': 'Reference controlnet (attention injection)',\n",
+ "    'GUI':'GUI',\n",
+ "    'do_the_run':'Diffuse! - Do the run'\n",
+ "\n",
+ "}\n",
+ "\n",
+ "\n",
+ "def check_execution(cell_name):\n",
+ "    \"\"\"Ensure every cell registered before `cell_name` has completed.\n",
+ "\n",
+ "    Walks executed_cells in order, returns as soon as `cell_name` is reached, and raises\n",
+ "    RuntimeError naming the first earlier cell whose flag is still False.\n",
+ "    \"\"\"\n",
+ "    for name, finished in executed_cells.items():\n",
+ "        if name == cell_name:\n",
+ "            # every prerequisite cell passed the check\n",
+ "            return\n",
+ "        if finished == False:\n",
+ "            title = executed_cells_errors[name]\n",
+ "            raise RuntimeError(\n",
+ "                f'The {title} cell was not run successfully and must be executed to continue. '\n",
+ "                'RUN ALL after starting runtime (CTRL-F9)'\n",
+ "            )\n",
+ "\n",
+ "cell_name = 'prepare_folders'\n",
+ "check_execution(cell_name)\n",
+ "\n",
+ "\n",
+ "def gitclone(url, recursive=False, dest=None, branch=None):\n",
+ "    \"\"\"Run `git clone` for `url` and print git's stdout.\n",
+ "\n",
+ "    Args:\n",
+ "        url: repository URL handed to git.\n",
+ "        recursive: when true, `--recursive` is added to the command.\n",
+ "        dest: optional target directory, added after the URL.\n",
+ "        branch: optional branch name, passed as `-b <branch>`.\n",
+ "    \"\"\"\n",
+ "    command = ['git', 'clone']\n",
+ "    if branch is not None:\n",
+ "        # extend, not append: every argv entry must be its own string, a nested list\n",
+ "        # makes subprocess.run fail\n",
+ "        command.extend(['-b', branch])\n",
+ "    command.append(url)\n",
+ "    if dest: command.append(dest)\n",
+ "    if recursive: command.append('--recursive')\n",
+ "\n",
+ "    res = subprocess.run(command, stdout=subprocess.PIPE).stdout.decode('utf-8')\n",
+ "    print(res)\n",
+ "\n",
+ "\n",
+ "def pipi(modulestr):\n",
+ "    \"\"\"Quietly pip-install `modulestr` via `python -m pip` and print pip's stdout.\"\"\"\n",
+ "    pip_command = ['python', '-m', 'pip', '-q', 'install', modulestr]\n",
+ "    completed = subprocess.run(pip_command, stdout=subprocess.PIPE)\n",
+ "    print(completed.stdout.decode('utf-8'))\n",
+ "\n",
+ "def pipie(modulestr):\n",
+ "    \"\"\"Install `modulestr` in editable mode (`pip install -e`) and print pip's stdout.\"\"\"\n",
+ "    # Editable installs are a pip feature; git has no `install` subcommand.\n",
+ "    res = subprocess.run(['python', '-m', 'pip', 'install', '-e', modulestr], stdout=subprocess.PIPE).stdout.decode('utf-8')\n",
+ "    print(res)\n",
+ "\n",
+ "def wget_p(url, outputdir):\n",
+ "    \"\"\"Download `url` with wget into the directory `outputdir` (wget -P) and print wget's stdout.\"\"\"\n",
+ "    wget_command = ['wget', url, '-P', f'{outputdir}']\n",
+ "    completed = subprocess.run(wget_command, stdout=subprocess.PIPE)\n",
+ "    print(completed.stdout.decode('utf-8'))\n",
+ "\n",
+ "# Detect Google Colab by importing its drive helper; any failure is treated as a local runtime.\n",
+ "try:\n",
+ "    from google.colab import drive\n",
+ "    print(\"Google Colab detected. Using Google Drive.\")\n",
+ "    is_colab = True\n",
+ "    #@markdown If you connect your Google Drive, you can save the final image of each run on your drive.\n",
+ "    google_drive = True #@param {type:\"boolean\"}\n",
+ "    #@markdown Click here if you'd like to save the diffusion model checkpoint file to (and/or load from) your Google Drive:\n",
+ "    save_models_to_google_drive = True #@param {type:\"boolean\"}\n",
+ "# NOTE(review): the bare except also hides errors other than a missing google.colab package.\n",
+ "except:\n",
+ "    is_colab = False\n",
+ "    google_drive = False\n",
+ "    save_models_to_google_drive = False\n",
+ "    print(\"Google Colab not detected.\")\n",
+ "\n",
+ "# Working root: a folder on the mounted Drive or /content on Colab, the current directory locally.\n",
+ "if is_colab:\n",
+ "    if google_drive is True:\n",
+ "        drive.mount('/content/drive')\n",
+ "        root_path = '/content/drive/MyDrive/AI/StableWarpFusion'\n",
+ "    else:\n",
+ "        root_path = '/content'\n",
+ "else:\n",
+ "    root_path = os.getcwd()\n",
+ "\n",
+ "import os\n",
+ "def createPath(filepath):\n",
+ "    \"\"\"Create the directory `filepath` with any missing parents; an existing directory is fine.\"\"\"\n",
+ "    os.makedirs(name=filepath, exist_ok=True)\n",
+ "\n",
+ "# Input and output image folders live under the working root.\n",
+ "initDirPath = os.path.join(root_path,'init_images')\n",
+ "createPath(initDirPath)\n",
+ "outDirPath = os.path.join(root_path,'images_out')\n",
+ "createPath(outDirPath)\n",
+ "root_dir = os.getcwd()\n",
+ "\n",
+ "# Model folder: /content/models on Colab unless both Drive options are enabled, in which case\n",
+ "# (and for local runs) it is the models folder under root_path.\n",
+ "if is_colab:\n",
+ "    root_dir = '/content/'\n",
+ "    if google_drive and not save_models_to_google_drive or not google_drive:\n",
+ "        model_path = '/content/models'\n",
+ "        createPath(model_path)\n",
+ "    if google_drive and save_models_to_google_drive:\n",
+ "        model_path = f'{root_path}/models'\n",
+ "        createPath(model_path)\n",
+ "else:\n",
+ "    model_path = f'{root_path}/models'\n",
+ "    createPath(model_path)\n",
+ "\n",
+ "#(c) Alex Spirin 2023\n",
+ "\n",
+ "class FrameDataset():\n",
+ "    \"\"\"Sorted list of frame file paths resolved from a file, a folder or a glob pattern.\n",
+ "\n",
+ "    Indexing returns a path; an index past the end returns the last frame.\n",
+ "    glob, generate_file_hash and extractFrames are not defined in this cell and must exist\n",
+ "    by the time the class is instantiated.\n",
+ "    \"\"\"\n",
+ "    def __init__(self, source_path, outdir_prefix='', videoframes_root=''):\n",
+ "        \"\"\"Resolve `source_path` into self.frame_paths.\n",
+ "\n",
+ "        Args:\n",
+ "            source_path: an existing file or folder, or a glob pattern matching frame files.\n",
+ "            outdir_prefix: label used in messages and as prefix of the extraction folder name.\n",
+ "            videoframes_root: parent folder for frames extracted from a file source.\n",
+ "\n",
+ "        Raises:\n",
+ "            Exception: nothing found at `source_path`, no frames extracted, or (for an existing\n",
+ "                path) files with a non-image extension or with more than one extension.\n",
+ "        \"\"\"\n",
+ "        self.frame_paths = None\n",
+ "        image_extenstions = ['jpeg', 'jpg', 'png', 'tiff', 'bmp', 'webp']\n",
+ "\n",
+ "        # Not an existing path: treat source_path as a glob pattern.\n",
+ "        if not os.path.exists(source_path):\n",
+ "            if len(glob(source_path))>0:\n",
+ "                self.frame_paths = sorted(glob(source_path))\n",
+ "            else:\n",
+ "                raise Exception(f'Frame source for {outdir_prefix} not found at {source_path}\\nPlease specify an existing source path.')\n",
+ "        if os.path.exists(source_path):\n",
+ "            if os.path.isfile(source_path):\n",
+ "                if os.path.splitext(source_path)[1][1:].lower() in image_extenstions:\n",
+ "                    self.frame_paths = [source_path]\n",
+ "                # NOTE(review): extraction below appears to run for image files too, which would\n",
+ "                # overwrite the assignment above - confirm whether an else is missing.\n",
+ "                hash = generate_file_hash(source_path)[:10]\n",
+ "                out_path = os.path.join(videoframes_root, outdir_prefix+'_'+hash)\n",
+ "\n",
+ "                extractFrames(source_path, out_path,\n",
+ "                              nth_frame=1, start_frame=0, end_frame=999999999)\n",
+ "                self.frame_paths = glob(os.path.join(out_path, '*.*'))\n",
+ "                if len(self.frame_paths)<1:\n",
+ "                    raise Exception(f'Couldn`t extract frames from {source_path}\\nPlease specify an existing source path.')\n",
+ "            elif os.path.isdir(source_path):\n",
+ "                self.frame_paths = glob(os.path.join(source_path, '*.*'))\n",
+ "                if len(self.frame_paths)<1:\n",
+ "                    raise Exception(f'Found 0 frames in {source_path}\\nPlease specify an existing source path.')\n",
+ "            # Every collected file must be an image, and all must share one extension.\n",
+ "            # The comparison is case-sensitive (ext is not lowercased here).\n",
+ "            extensions = []\n",
+ "            if self.frame_paths is not None:\n",
+ "                for f in self.frame_paths:\n",
+ "                    ext = os.path.splitext(f)[1][1:]\n",
+ "                    if ext not in image_extenstions:\n",
+ "                        raise Exception(f'Found non-image file extension: {ext} in {source_path}. Please provide a folder with image files of the same extension, or specify a glob pattern.')\n",
+ "                    if not ext in extensions:\n",
+ "                        extensions+=[ext]\n",
+ "                if len(extensions)>1:\n",
+ "                    raise Exception(f'Found multiple file extensions: {extensions} in {source_path}. Please provide a folder with image files of the same extension, or specify a glob pattern.')\n",
+ "\n",
+ "                self.frame_paths = sorted(self.frame_paths)\n",
+ "\n",
+ "            else: raise Exception(f'Frame source for {outdir_prefix} not found at {source_path}\\nPlease specify an existing source path.')\n",
+ "        print(f'Found {len(self.frame_paths)} frames at {source_path}')\n",
+ "\n",
+ "    def __getitem__(self, idx):\n",
+ "        \"\"\"Return the path of frame `idx`; indices past the end are clamped to the last frame.\"\"\"\n",
+ "        idx = min(idx, len(self.frame_paths)-1)\n",
+ "        return self.frame_paths[idx]\n",
+ "\n",
+ "    def __len__(self):\n",
+ "        \"\"\"Return the number of frame paths.\"\"\"\n",
+ "        return len(self.frame_paths)\n",
+ "\n",
+ "executed_cells[cell_name] = True\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "cellView": "form",
+ "id": "CptlIAdM9B1Y"
+ },
+ "outputs": [],
+ "source": [
+ "#@title 1.2 Install pytorch\n",
+ "cell_name = 'install_pytorch'\n",
+ "check_execution(cell_name)\n",
+ "gpu = None\n",
+ "\n",
+ "!python -m pip -q install requests\n",
+ "import requests\n",
+ "# NOTE(review): this fetches and later imports code from a branch URL; the content is not pinned.\n",
+ "installer_url = 'https://raw.githubusercontent.com/Sxela/WarpTools/main/installersw/warp_installer_024.py'\n",
+ "r = requests.get(installer_url, allow_redirects=True)\n",
+ "# Stop on an HTTP error; otherwise the error page would be saved and imported as a module.\n",
+ "r.raise_for_status()\n",
+ "# The context manager closes the file, so the installer is fully written before it is imported.\n",
+ "with open('warp_installer.py', 'wb') as installer_file:\n",
+ "    installer_file.write(r.content)\n",
+ "\n",
+ "import warp_installer\n",
+ "\n",
+ "force_os = 'off'\n",
+ "force_torch_reinstall = False #@param {'type':'boolean'}\n",
+ "force_xformers_reinstall = False\n",
+ "#@markdown Use v2 by default.\n",
+ "use_torch_v2 = True #@param {'type':'boolean'}\n",
+ "\n",
+ "import subprocess, sys\n",
+ "import os, platform\n",
+ "\n",
+ "# Print the GPU list (nvidia-smi -L); the else branch prints the full report instead, followed\n",
+ "# by the output of `nvidia-smi -i 0 -e 0`.\n",
+ "simple_nvidia_smi_display = True\n",
+ "if simple_nvidia_smi_display:\n",
+ "    #!nvidia-smi\n",
+ "    nvidiasmi_output = subprocess.run(['nvidia-smi', '-L'], stdout=subprocess.PIPE).stdout.decode('utf-8')\n",
+ "    print(nvidiasmi_output)\n",
+ "else:\n",
+ "    #!nvidia-smi -i 0 -e 0\n",
+ "    nvidiasmi_output = subprocess.run(['nvidia-smi'], stdout=subprocess.PIPE).stdout.decode('utf-8')\n",
+ "    print(nvidiasmi_output)\n",
+ "    nvidiasmi_ecc_note = subprocess.run(['nvidia-smi', '-i', '0', '-e', '0'], stdout=subprocess.PIPE).stdout.decode('utf-8')\n",
+ "    print(nvidiasmi_ecc_note)\n",
+ "\n",
+ "\n",
+ "if force_torch_reinstall:\n",
+ "    warp_installer.uninstall_pytorch(is_colab)\n",
+ "\n",
+ "# Any non-Linux platform (or force_os set to Windows) goes through the Windows torch installer.\n",
+ "if platform.system() != 'Linux' or force_os == 'Windows':\n",
+ "    warp_installer.install_torch_windows(force_torch_reinstall, use_torch_v2)\n",
+ "\n",
+ "# IS_DOCKER missing from the environment raises KeyError here; the bare except then defaults\n",
+ "# it to 0 so later cells can read it unconditionally.\n",
+ "try:\n",
+ "    if os.environ[\"IS_DOCKER\"] == \"1\":\n",
+ "        print('Docker found. Skipping install.')\n",
+ "except:\n",
+ "    os.environ[\"IS_DOCKER\"] = \"0\"\n",
+ "\n",
+ "# Outside Docker on Colab/Linux: label the GPU from the nvidia-smi report. gpu keeps its\n",
+ "# earlier value when no known name is found.\n",
+ "if (is_colab or (platform.system() == 'Linux') or force_os == 'Linux') and os.environ[\"IS_DOCKER\"]==\"0\":\n",
+ "    from subprocess import getoutput\n",
+ "    from IPython.display import HTML\n",
+ "    from IPython.display import clear_output\n",
+ "    import time\n",
+ "    #https://github.com/TheLastBen/fast-stable-diffusion\n",
+ "    s = getoutput('nvidia-smi')\n",
+ "    if 'T4' in s:\n",
+ "        gpu = 'T4'\n",
+ "    elif 'P100' in s:\n",
+ "        gpu = 'P100'\n",
+ "    elif 'V100' in s:\n",
+ "        gpu = 'V100'\n",
+ "    elif 'A100' in s:\n",
+ "        gpu = 'A100'\n",
+ "\n",
+ "    # A4000/A5000/A6000 are labelled like an A100\n",
+ "    for g in ['A4000','A5000','A6000']:\n",
+ "        if g in s:\n",
+ "            gpu = 'A100'\n",
+ "\n",
+ "    # 2060/2070/2080 are labelled like a T4\n",
+ "    for g in ['2080','2070','2060']:\n",
+ "        if g in s:\n",
+ "            gpu = 'T4'\n",
+ "    print(' DONE !')\n",
+ "\n",
+ "if is_colab:\n",
+ " warp_installer.install_torch_colab(force_torch_reinstall, use_torch_v2)\n",
+ "\n",
+ "executed_cells[cell_name] = True"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "cellView": "form",
+ "id": "wldaiwzcCbWy"
+ },
+ "outputs": [],
+ "source": [
+ "#@title 1.3 Install SD Dependencies\n",
+ "from IPython.utils import io\n",
+ "import shutil, traceback\n",
+ "import pathlib, shutil, os, sys\n",
+ "cell_name = 'install_sd_dependencies'\n",
+ "check_execution(cell_name)\n",
+ "\n",
+ "#@markdown Enable skip_install to avoid reinstalling dependencies after the initial setup.\n",
+ "skip_install = False #@param {'type':'boolean'}\n",
+ "os.makedirs('./embeddings', exist_ok=True)\n",
+ "\n",
+ "# IS_DOCKER is read without a default here; the Install pytorch cell sets it to 0 when absent.\n",
+ "if os.environ[\"IS_DOCKER\"]==\"1\":\n",
+ "    skip_install = True\n",
+ "    print('Docker detected. Skipping install.')\n",
+ "\n",
+ "if not is_colab:\n",
+ "    # If running locally, there's a good chance your env will need this in order to not crash upon np.matmul() or similar operations.\n",
+ "    os.environ['KMP_DUPLICATE_LIB_OK']='TRUE'\n",
+ "\n",
+ "PROJECT_DIR = os.path.abspath(os.getcwd())\n",
+ "USE_ADABINS = False\n",
+ "\n",
+ "# Reset the working paths. On Colab with Google Drive enabled nothing is reassigned, so the\n",
+ "# values from the Prepare Folders cell stay in effect.\n",
+ "if is_colab:\n",
+ "    if google_drive is not True:\n",
+ "        root_path = f'/content'\n",
+ "        model_path = '/content/models'\n",
+ "else:\n",
+ "    root_path = os.getcwd()\n",
+ "    model_path = f'{root_path}/models'\n",
+ "\n",
+ "# skip_install calls pull_repos instead of the full dependency install.\n",
+ "if skip_install:\n",
+ "    # pass\n",
+ "    warp_installer.pull_repos(is_colab)\n",
+ "else:\n",
+ "    warp_installer.install_dependencies_colab(is_colab, root_dir)\n",
+ "\n",
+ "# Make the BLIP folder under the project directory importable.\n",
+ "sys.path.append(f'{PROJECT_DIR}/BLIP')\n",
+ "executed_cells[cell_name] = True"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "cellView": "form",
+ "id": "InstallDeps"
+ },
+ "outputs": [],
+ "source": [
+ "#@title ### 1.4 Import dependencies, define functions\n",
+ "\n",
+ "\n",
+ "cell_name = 'import_dependencies'\n",
+ "check_execution(cell_name)\n",
+ "\n",
+ "# Names of the settings tracked in user_settings.\n",
+ "# NOTE(review): how these names are consumed is not visible in this cell - confirm at the use site.\n",
+ "user_settings_keys = ['latent_scale_schedule', 'init_scale_schedule', 'steps_schedule', 'style_strength_schedule',\n",
+ "    'cfg_scale_schedule', 'flow_blend_schedule', 'image_scale_schedule', 'flow_override_map',\n",
+ "    'text_prompts', 'negative_prompts', 'prompt_patterns_sched', 'latent_fixed_mean',\n",
+ "    'latent_fixed_std', 'rec_prompts', 'cc_masked_diffusion_schedule', 'mask_paths','user_comment', 'blend_json_schedules', 'VERBOSE', 'use_background_mask', 'invert_mask', 'background',\n",
+ "    'background_source', 'mask_clip_low', 'mask_clip_high', 'turbo_mode', 'turbo_steps', 'colormatch_turbo',\n",
+ "    'turbo_frame_skips_steps', 'soften_consistency_mask_for_turbo_frames', 'flow_warp', 'apply_mask_after_warp',\n",
+ "    'warp_num_k', 'warp_forward', 'warp_strength', 'warp_mode', 'warp_towards_init', 'check_consistency',\n",
+ "    'missed_consistency_weight', 'overshoot_consistency_weight', 'edges_consistency_weight', 'consistency_blur',\n",
+ "    'consistency_dilate', 'padding_ratio', 'padding_mode', 'match_color_strength', 'soften_consistency_mask',\n",
+ "    'mask_result', 'use_patchmatch_inpaiting', 'cond_image_src', 'set_seed', 'clamp_grad', 'clamp_max', 'sat_scale',\n",
+ "    'init_grad', 'grad_denoised', 'blend_latent_to_init', 'fixed_code', 'code_randomness', 'dynamic_thresh',\n",
+ "    'sampler', 'use_karras_noise', 'inpainting_mask_weight', 'inverse_inpainting_mask', 'inpainting_mask_source',\n",
+ "    'normalize_latent', 'normalize_latent_offset', 'latent_norm_4d', 'colormatch_frame', 'color_match_frame_str',\n",
+ "    'colormatch_offset', 'colormatch_method', 'colormatch_regrain', 'colormatch_after',\n",
+ "    'fixed_seed', 'rec_cfg', 'rec_steps_pct', 'rec_randomness', 'use_predicted_noise', 'overwrite_rec_noise',\n",
+ "    'save_controlnet_annotations', 'control_sd15_openpose_hands_face', 'control_sd15_depth_detector',\n",
+ "    'control_sd15_softedge_detector', 'control_sd15_seg_detector', 'control_sd15_scribble_detector',\n",
+ "    'control_sd15_lineart_coarse', 'control_sd15_inpaint_mask_source', 'control_sd15_shuffle_source',\n",
+ "    'control_sd15_shuffle_1st_source', 'controlnet_multimodel', 'controlnet_mode', 'normalize_cn_weights',\n",
+ "    'controlnet_preprocess', 'detect_resolution', 'bg_threshold', 'low_threshold', 'high_threshold',\n",
+ "    'value_threshold', 'distance_threshold', 'temporalnet_source', 'temporalnet_skip_1st_frame',\n",
+ "    'controlnet_multimodel_mode', 'max_faces', 'do_softcap', 'softcap_thresh', 'softcap_q', 'masked_guidance',\n",
+ "    'alpha_masked_diffusion', 'invert_alpha_masked_diffusion', 'normalize_prompt_weights', 'sd_batch_size',\n",
+ "    'controlnet_low_vram', 'deflicker_scale', 'deflicker_latent_scale', 'pose_detector','apply_freeu_after_control','do_freeunet']\n",
+ "# The first sixteen names of user_settings_keys, kept as a separate list.\n",
+ "# NOTE(review): presumably the settings whose stored values get evaluated on load - confirm.\n",
+ "user_settings_eval_keys = ['latent_scale_schedule', 'init_scale_schedule', 'steps_schedule', 'style_strength_schedule',\n",
+ "    'cfg_scale_schedule', 'flow_blend_schedule', 'image_scale_schedule', 'flow_override_map',\n",
+ "    'text_prompts', 'negative_prompts', 'prompt_patterns_sched', 'latent_fixed_mean',\n",
+ "    'latent_fixed_std', 'rec_prompts', 'cc_masked_diffusion_schedule', 'mask_paths']\n",
+ "#init settings\n",
+ "user_settings = {} #init empty to check for missing keys\n",
+ "# user_settings = dict([(key,'') for key in user_settings_keys])\n",
+ "image_prompts = {}\n",
+ "\n",
+ "import os, random, torch\n",
+ "import numpy as np\n",
+ "\n",
+ "\n",
+ "def seed_everything(seed, deterministic=False):\n",
+ "    \"\"\"Seed the Python, NumPy and torch (CPU and all CUDA devices) generators with `seed`.\n",
+ "\n",
+ "    The seed is also exported as PYTHONHASHSEED. With `deterministic` set, cuDNN is put into\n",
+ "    deterministic mode and its benchmark mode is switched off.\n",
+ "    \"\"\"\n",
+ "    print(f'Set global seed to {seed}')\n",
+ "    for seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):\n",
+ "        seed_fn(seed)\n",
+ "    os.environ['PYTHONHASHSEED'] = str(seed)\n",
+ "    if not deterministic:\n",
+ "        return\n",
+ "    torch.backends.cudnn.deterministic = True\n",
+ "    torch.backends.cudnn.benchmark = False\n",
+ "\n",
+ "import torch\n",
+ "from dataclasses import dataclass\n",
+ "from functools import partial\n",
+ "import cv2\n",
+ "import pandas as pd\n",
+ "import gc\n",
+ "import io\n",
+ "import math\n",
+ "from IPython import display\n",
+ "import lpips\n",
+ "# !wget \"https://download.pytorch.org/models/vgg16-397923af.pth\" -O /root/.cache/torch/hub/checkpoints/vgg16-397923af.pth\n",
+ "from PIL import Image, ImageOps, ImageDraw\n",
+ "import requests\n",
+ "from glob import glob\n",
+ "import json\n",
+ "from types import SimpleNamespace\n",
+ "from torch import nn\n",
+ "from torch.nn import functional as F\n",
+ "import torchvision.transforms as T\n",
+ "import torchvision.transforms.functional as TF\n",
+ "from tqdm.notebook import tqdm\n",
+ "# from CLIP import clip\n",
+ "# from resize_right import resize\n",
+ "# from guided_diffusion.script_util import create_model_and_diffusion, model_and_diffusion_defaults\n",
+ "from datetime import datetime\n",
+ "import numpy as np\n",
+ "import matplotlib.pyplot as plt\n",
+ "import random\n",
+ "from ipywidgets import Output\n",
+ "import hashlib\n",
+ "from functools import partial\n",
+ "if is_colab:\n",
+ " os.chdir('/content')\n",
+ " from google.colab import files\n",
+ "else:\n",
+ " os.chdir(f'{PROJECT_DIR}')\n",
+ "from IPython.display import Image as ipyimg\n",
+ "from numpy import asarray\n",
+ "from einops import rearrange, repeat\n",
+ "import torch, torchvision\n",
+ "import time\n",
+ "from omegaconf import OmegaConf\n",
+ "import warnings\n",
+ "warnings.filterwarnings(\"ignore\", category=UserWarning)\n",
+ "\n",
+ "import torch\n",
+ "# Use the first CUDA GPU when one is available, otherwise fall back to the CPU.\n",
+ "DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')\n",
+ "print('Using device:', DEVICE)\n",
+ "device = DEVICE # At least one of the modules expects this name..\n",
+ "\n",
+ "# get_device_capability() only accepts CUDA devices, so the cuDNN switch is skipped on CPU\n",
+ "# instead of crashing the cell.\n",
+ "if DEVICE.type == 'cuda':\n",
+ "    cuda_capability = torch.cuda.get_device_capability(DEVICE)\n",
+ "    if cuda_capability == (8,0): ## A100 fix thanks to Emad\n",
+ "        print('Disabling CUDNN for A100 gpu', file=sys.stderr)\n",
+ "        torch.backends.cudnn.enabled = False\n",
+ "    elif cuda_capability[0] == 8: ## same fix for every other compute capability 8.x GPU\n",
+ "        print('Disabling CUDNN for Ada gpu', file=sys.stderr)\n",
+ "        torch.backends.cudnn.enabled = False\n",
+ "\n",
+ "import open_clip\n",
+ "\n",
+ "#@title 1.5 Define necessary functions\n",
+ "\n",
+ "from typing import Mapping\n",
+ "\n",
+ "import mediapipe as mp\n",
+ "import numpy\n",
+ "from PIL import Image\n",
+ "\n",
+ "\n",
+ "def append_dims(x, n):\n",
+ "    \"\"\"Append trailing singleton dimensions to `x` until it has `n` dimensions.\"\"\"\n",
+ "    trailing_axes = (None,) * (n - x.ndim)\n",
+ "    return x[(...,) + trailing_axes]\n",
+ "\n",
+ "def expand_to_planes(x, shape):\n",
+ "    \"\"\"Pad `x` to len(shape) dimensions and repeat it along dimensions 2+ by shape[2:].\"\"\"\n",
+ "    padded = append_dims(x, len(shape))\n",
+ "    repeats = [1, 1, *shape[2:]]\n",
+ "    return padded.repeat(repeats)\n",
+ "\n",
+ "def alpha_sigma_to_t(alpha, sigma):\n",
+ "    \"\"\"Return atan2(sigma, alpha) rescaled by 2 / pi.\"\"\"\n",
+ "    angle = torch.atan2(sigma, alpha)\n",
+ "    return angle * 2 / math.pi\n",
+ "\n",
+ "def t_to_alpha_sigma(t):\n",
+ "    \"\"\"Return (cos, sin) of the angle t * pi / 2.\"\"\"\n",
+ "    angle = t * math.pi / 2\n",
+ "    return torch.cos(angle), torch.sin(angle)\n",
+ "\n",
+ "# Shortcuts into the mediapipe solutions namespace.\n",
+ "mp_drawing = mp.solutions.drawing_utils\n",
+ "mp_drawing_styles = mp.solutions.drawing_styles\n",
+ "mp_face_detection = mp.solutions.face_detection # Only for counting faces.\n",
+ "mp_face_mesh = mp.solutions.face_mesh\n",
+ "mp_face_connections = mp.solutions.face_mesh_connections.FACEMESH_TESSELATION\n",
+ "mp_hand_connections = mp.solutions.hands_connections.HAND_CONNECTIONS\n",
+ "mp_body_connections = mp.solutions.pose_connections.POSE_CONNECTIONS\n",
+ "\n",
+ "DrawingSpec = mp.solutions.drawing_styles.DrawingSpec\n",
+ "PoseLandmark = mp.solutions.drawing_styles.PoseLandmark\n",
+ "\n",
+ "# Line thickness and circle radius shared by every face DrawingSpec below.\n",
+ "f_thick = 2\n",
+ "f_rad = 1\n",
+ "# One DrawingSpec, with its own color, per facial feature.\n",
+ "right_iris_draw = DrawingSpec(color=(10, 200, 250), thickness=f_thick, circle_radius=f_rad)\n",
+ "right_eye_draw = DrawingSpec(color=(10, 200, 180), thickness=f_thick, circle_radius=f_rad)\n",
+ "right_eyebrow_draw = DrawingSpec(color=(10, 220, 180), thickness=f_thick, circle_radius=f_rad)\n",
+ "left_iris_draw = DrawingSpec(color=(250, 200, 10), thickness=f_thick, circle_radius=f_rad)\n",
+ "left_eye_draw = DrawingSpec(color=(180, 200, 10), thickness=f_thick, circle_radius=f_rad)\n",
+ "left_eyebrow_draw = DrawingSpec(color=(180, 220, 10), thickness=f_thick, circle_radius=f_rad)\n",
+ "mouth_draw = DrawingSpec(color=(10, 180, 10), thickness=f_thick, circle_radius=f_rad)\n",
+ "head_draw = DrawingSpec(color=(10, 200, 10), thickness=f_thick, circle_radius=f_rad)\n",
+ "\n",
+ "# mp_face_mesh.FACEMESH_CONTOURS has all the items we care about.\n",
+ "# Map each face-mesh edge to the DrawingSpec of the feature it belongs to; the iris edges are\n",
+ "# left out (commented) and the pupils are drawn separately via iris_landmark_spec.\n",
+ "face_connection_spec = {}\n",
+ "for edge in mp_face_mesh.FACEMESH_FACE_OVAL:\n",
+ "    face_connection_spec[edge] = head_draw\n",
+ "for edge in mp_face_mesh.FACEMESH_LEFT_EYE:\n",
+ "    face_connection_spec[edge] = left_eye_draw\n",
+ "for edge in mp_face_mesh.FACEMESH_LEFT_EYEBROW:\n",
+ "    face_connection_spec[edge] = left_eyebrow_draw\n",
+ "# for edge in mp_face_mesh.FACEMESH_LEFT_IRIS:\n",
+ "#    face_connection_spec[edge] = left_iris_draw\n",
+ "for edge in mp_face_mesh.FACEMESH_RIGHT_EYE:\n",
+ "    face_connection_spec[edge] = right_eye_draw\n",
+ "for edge in mp_face_mesh.FACEMESH_RIGHT_EYEBROW:\n",
+ "    face_connection_spec[edge] = right_eyebrow_draw\n",
+ "# for edge in mp_face_mesh.FACEMESH_RIGHT_IRIS:\n",
+ "#    face_connection_spec[edge] = right_iris_draw\n",
+ "for edge in mp_face_mesh.FACEMESH_LIPS:\n",
+ "    face_connection_spec[edge] = mouth_draw\n",
+ "# Landmark index -> DrawingSpec; draw_pupils() paints only the landmark indices listed here.\n",
+ "iris_landmark_spec = {468: right_iris_draw, 473: left_iris_draw}\n",
+ "\n",
+ "\n",
+ "def draw_pupils(image, landmark_list, drawing_spec, halfwidth: int = 2):\n",
+ "    \"\"\"We have a custom function to draw the pupils because the mp.draw_landmarks method requires a parameter for all\n",
+ "    landmarks. Until our PR is merged into mediapipe, we need this separate method.\n",
+ "\n",
+ "    Paints, in place, a filled square of side 2 * halfwidth at every landmark selected by `drawing_spec`.\n",
+ "    Landmarks with low visibility or presence, or with coordinates outside [0, 1), are skipped.\n",
+ "\n",
+ "    Args:\n",
+ "        image: H,W,3 array (BGR) that is modified in place.\n",
+ "        landmark_list: object whose `.landmark` sequence holds landmarks with normalized x and y.\n",
+ "        drawing_spec: a Mapping from landmark index to a spec with a `.color` (indices without an\n",
+ "            entry are skipped), or a single DrawingSpec applied to every landmark.\n",
+ "        halfwidth: half of the side length of the painted square, in pixels.\n",
+ "\n",
+ "    Raises:\n",
+ "        ValueError: if `image` is not a three-channel H,W,C array.\n",
+ "    \"\"\"\n",
+ "    if len(image.shape) != 3:\n",
+ "        raise ValueError(\"Input image must be H,W,C.\")\n",
+ "    image_rows, image_cols, image_channels = image.shape\n",
+ "    if image_channels != 3: # BGR channels\n",
+ "        raise ValueError('Input image must contain three channel bgr data.')\n",
+ "    for idx, landmark in enumerate(landmark_list.landmark):\n",
+ "        if (\n",
+ "            (landmark.HasField('visibility') and landmark.visibility < 0.9) or\n",
+ "            (landmark.HasField('presence') and landmark.presence < 0.5)\n",
+ "        ):\n",
+ "            continue\n",
+ "        if landmark.x >= 1.0 or landmark.x < 0 or landmark.y >= 1.0 or landmark.y < 0:\n",
+ "            continue\n",
+ "        image_x = int(image_cols*landmark.x)\n",
+ "        image_y = int(image_rows*landmark.y)\n",
+ "        draw_color = None\n",
+ "        if isinstance(drawing_spec, Mapping):\n",
+ "            if drawing_spec.get(idx) is None:\n",
+ "                continue\n",
+ "            else:\n",
+ "                draw_color = drawing_spec[idx].color\n",
+ "        elif isinstance(drawing_spec, DrawingSpec):\n",
+ "            draw_color = drawing_spec.color\n",
+ "        # Clamp the start of the square to the image: a negative slice start would wrap around\n",
+ "        # and select nothing, silently dropping pupils close to the top or left edge.\n",
+ "        top = max(image_y - halfwidth, 0)\n",
+ "        left = max(image_x - halfwidth, 0)\n",
+ "        image[top:image_y+halfwidth, left:image_x+halfwidth, :] = draw_color\n",
+ "\n",
+ "\n",
+ "def reverse_channels(image):\n",
+ "    \"\"\"Given a numpy array in RGB form, convert to BGR. Will also convert from BGR to RGB.\n",
+ "\n",
+ "    Expects the channels on the third axis; the result is a reversed slice of `image`, not a copy.\n",
+ "    \"\"\"\n",
+ "    # im[:,:,::-1] is a neat hack to convert BGR to RGB by reversing the indexing order.\n",
+ "    # im[:,:,[2,1,0]] would also work but makes a copy of the data.\n",
+ "    return image[:, :, ::-1]\n",
+ "\n",
+ "\n",
+ "def generate_annotation(\n",
+ "    input_image: Image.Image,\n",
+ "    max_faces: int,\n",
+ "    min_face_size_pixels: int = 0,\n",
+ "    return_annotation_data: bool = False\n",
+ "):\n",
+ "    \"\"\"\n",
+ "    Find up to 'max_faces' inside the provided input image.\n",
+ "    If min_face_size_pixels is provided and nonzero it will be used to filter faces that occupy less than this many\n",
+ "    pixels in the image.\n",
+ "    If return_annotation_data is TRUE (default: false) then in addition to returning the 'detected face' image, three\n",
+ "    additional parameters will be returned: faces before filtering, faces after filtering, and an annotation image.\n",
+ "    The faces_before_filtering return value is the number of faces detected in an image with no filtering.\n",
+ "    faces_after_filtering is the number of faces remaining after filtering small faces.\n",
+ "    The annotation image is a copy of the input image with the detected faces drawn on top of it.\n",
+ "    :return:\n",
+ "      None when no face is detected at all.\n",
+ "      If 'return_annotation_data==True', returns (numpy array, numpy array, int, int).\n",
+ "      If 'return_annotation_data==False' (default), returns a numpy array.\n",
+ "    \"\"\"\n",
+ "    with mp_face_mesh.FaceMesh(\n",
+ "        static_image_mode=True,\n",
+ "        max_num_faces=max_faces,\n",
+ "        refine_landmarks=True,\n",
+ "        min_detection_confidence=0.5,\n",
+ "    ) as facemesh:\n",
+ "        img_rgb = numpy.asarray(input_image)\n",
+ "        results = facemesh.process(img_rgb).multi_face_landmarks\n",
+ "        if results is None:\n",
+ "            return None\n",
+ "        faces_found_before_filtering = len(results)\n",
+ "\n",
+ "        # Filter faces that are too small\n",
+ "        filtered_landmarks = []\n",
+ "        for lm in results:\n",
+ "            landmarks = lm.landmark\n",
+ "            # Bounding box of all landmarks, in normalized image coordinates.\n",
+ "            face_rect = [\n",
+ "                landmarks[0].x,\n",
+ "                landmarks[0].y,\n",
+ "                landmarks[0].x,\n",
+ "                landmarks[0].y,\n",
+ "            ] # Left, up, right, down.\n",
+ "            for i in range(len(landmarks)):\n",
+ "                face_rect[0] = min(face_rect[0], landmarks[i].x)\n",
+ "                face_rect[1] = min(face_rect[1], landmarks[i].y)\n",
+ "                face_rect[2] = max(face_rect[2], landmarks[i].x)\n",
+ "                face_rect[3] = max(face_rect[3], landmarks[i].y)\n",
+ "            if min_face_size_pixels > 0:\n",
+ "                face_width = abs(face_rect[2] - face_rect[0])\n",
+ "                face_height = abs(face_rect[3] - face_rect[1])\n",
+ "                face_width_pixels = face_width * input_image.size[0]\n",
+ "                face_height_pixels = face_height * input_image.size[1]\n",
+ "                # The smaller side of the box decides whether the face is kept.\n",
+ "                face_size = min(face_width_pixels, face_height_pixels)\n",
+ "                if face_size >= min_face_size_pixels:\n",
+ "                    filtered_landmarks.append(lm)\n",
+ "            else:\n",
+ "                filtered_landmarks.append(lm)\n",
+ "\n",
+ "        faces_remaining_after_filtering = len(filtered_landmarks)\n",
+ "\n",
+ "        # Annotations are drawn in BGR for some reason, but we don't need to flip a zero-filled image at the start.\n",
+ "        empty = numpy.zeros_like(img_rgb)\n",
+ "\n",
+ "        # Draw detected faces:\n",
+ "        for face_landmarks in filtered_landmarks:\n",
+ "            mp_drawing.draw_landmarks(\n",
+ "                empty,\n",
+ "                face_landmarks,\n",
+ "                connections=face_connection_spec.keys(),\n",
+ "                landmark_drawing_spec=None,\n",
+ "                connection_drawing_spec=face_connection_spec\n",
+ "            )\n",
+ "            draw_pupils(empty, face_landmarks, iris_landmark_spec, 2)\n",
+ "\n",
+ "        # Flip BGR back to RGB.\n",
+ "        empty = reverse_channels(empty)\n",
+ "\n",
+ "        # We might have to generate a composite.\n",
+ "        if return_annotation_data:\n",
+ "            # Note that we're copying the input image AND flipping the channels so we can draw on top of it.\n",
+ "            annotated = reverse_channels(numpy.asarray(input_image)).copy()\n",
+ "            for face_landmarks in filtered_landmarks:\n",
+ "                # Draw on the composite, not on `empty`: `empty` is already finished and has been\n",
+ "                # flipped back to RGB, while `annotated` is the BGR copy made for this overlay.\n",
+ "                mp_drawing.draw_landmarks(\n",
+ "                    annotated,\n",
+ "                    face_landmarks,\n",
+ "                    connections=face_connection_spec.keys(),\n",
+ "                    landmark_drawing_spec=None,\n",
+ "                    connection_drawing_spec=face_connection_spec\n",
+ "                )\n",
+ "                draw_pupils(annotated, face_landmarks, iris_landmark_spec, 2)\n",
+ "            annotated = reverse_channels(annotated)\n",
+ "\n",
+ "        if not return_annotation_data:\n",
+ "            return empty\n",
+ "        else:\n",
+ "            return empty, annotated, faces_found_before_filtering, faces_remaining_after_filtering\n",
+ "\n",
+ "\n",
+ "\n",
+ "# https://gist.github.com/adefossez/0646dbe9ed4005480a2407c62aac8869\n",
+ "import PIL\n",
+ "\n",
+ "\n",
+ "def interp(t):\n",
+ " return 3 * t**2 - 2 * t ** 3\n",
+ "\n",
+ "def perlin(width, height, scale=10, device=None):\n",
+ "    \"\"\"Generate one octave of 2D Perlin-style noise.\n",
+ "\n",
+ "    Draws a (width + 1) x (height + 1) lattice of random gradients and evaluates\n",
+ "    `scale` sample points per lattice cell along each axis.\n",
+ "\n",
+ "    Returns a tensor of shape (width * scale, height * scale).\n",
+ "    \"\"\"\n",
+ "    grad_x, grad_y = torch.randn(2, width + 1, height + 1, 1, 1, device=device)\n",
+ "    # In-cell sample offsets, shaped (scale, 1) and (1, scale) so they broadcast to a grid.\n",
+ "    frac_x = torch.linspace(0, 1, scale + 1)[:-1, None].to(device)\n",
+ "    frac_y = torch.linspace(0, 1, scale + 1)[None, :-1].to(device)\n",
+ "    fade_x = 1 - interp(frac_x)\n",
+ "    fade_y = 1 - interp(frac_y)\n",
+ "    # One weighted gradient term per cell corner; accumulated in this fixed order.\n",
+ "    corner_terms = (\n",
+ "        fade_x * fade_y * (grad_x[:-1, :-1] * frac_x + grad_y[:-1, :-1] * frac_y),\n",
+ "        (1 - fade_x) * fade_y * (-grad_x[1:, :-1] * (1 - frac_x) + grad_y[1:, :-1] * frac_y),\n",
+ "        fade_x * (1 - fade_y) * (grad_x[:-1, 1:] * frac_x - grad_y[:-1, 1:] * (1 - frac_y)),\n",
+ "        (1 - fade_x) * (1 - fade_y) * (-grad_x[1:, 1:] * (1 - frac_x) - grad_y[1:, 1:] * (1 - frac_y)),\n",
+ "    )\n",
+ "    noise = 0\n",
+ "    for term in corner_terms:\n",
+ "        noise += term\n",
+ "    return noise.permute(0, 2, 1, 3).contiguous().view(width * scale, height * scale)\n",
+ "\n",
+ "def perlin_ms(octaves, width, height, grayscale, device=device):\n",
+ "    \"\"\"Sum several Perlin octaves into one channel (grayscale) or three channels.\n",
+ "\n",
+ "    Each entry of `octaves` is the amplitude of one octave. The lattice doubles\n",
+ "    and the per-cell sample count halves from one octave to the next, so every\n",
+ "    octave has the same output size. Channels start at 0.5 and are concatenated\n",
+ "    along dim 0.\n",
+ "    \"\"\"\n",
+ "    n_channels = 1 if grayscale else 3\n",
+ "    channels = [0.5] * n_channels\n",
+ "    for ch in range(n_channels):\n",
+ "        cell_scale = 2 ** len(octaves)\n",
+ "        grid_w, grid_h = width, height\n",
+ "        for amplitude in octaves:\n",
+ "            layer = perlin(grid_w, grid_h, cell_scale, device)\n",
+ "            channels[ch] += layer * amplitude\n",
+ "            cell_scale //= 2\n",
+ "            grid_w *= 2\n",
+ "            grid_h *= 2\n",
+ "    return torch.cat(channels)\n",
+ "\n",
+ "def create_perlin_noise(octaves=[1, 1, 1, 1], width=2, height=2, grayscale=True):\n",
+ "    \"\"\"Render multi-octave Perlin noise as an autocontrasted PIL image.\n",
+ "\n",
+ "    The noise is resized to the module-level (side_y, side_x). Grayscale noise\n",
+ "    is converted to an RGB image. The default `octaves` list is only read here.\n",
+ "    \"\"\"\n",
+ "    noise = perlin_ms(octaves, width, height, grayscale)\n",
+ "    target_size = (side_y, side_x)\n",
+ "    if not grayscale:\n",
+ "        # perlin_ms stacks the three channels along dim 0; split them back out.\n",
+ "        noise = noise.reshape(-1, 3, noise.shape[0]//3, noise.shape[1])\n",
+ "        noise = TF.resize(size=target_size, img=noise)\n",
+ "        image = TF.to_pil_image(noise.clamp(0, 1).squeeze())\n",
+ "    else:\n",
+ "        noise = TF.resize(size=target_size, img=noise.unsqueeze(0))\n",
+ "        image = TF.to_pil_image(noise.clamp(0, 1)).convert('RGB')\n",
+ "    return ImageOps.autocontrast(image)\n",
+ "\n",
+ "def regen_perlin():\n",
+ "    \"\"\"Build a fresh Perlin init tensor by averaging two noise images.\n",
+ "\n",
+ "    perlin_mode 'color' makes both images colour, 'gray' makes both grayscale,\n",
+ "    and any other value mixes a colour image with a grayscale one. The result is\n",
+ "    scaled by mul(2).sub(1), moved to `device` and expanded to `batch_size`.\n",
+ "    \"\"\"\n",
+ "    octaves_12 = [1.5**-i*0.5 for i in range(12)]\n",
+ "    octaves_8 = [1.5**-i*0.5 for i in range(8)]\n",
+ "    # Only 'gray' makes the first image grayscale; only 'color' keeps the second in colour.\n",
+ "    init = create_perlin_noise(octaves_12, 1, 1, perlin_mode == 'gray')\n",
+ "    init2 = create_perlin_noise(octaves_8, 4, 4, perlin_mode != 'color')\n",
+ "\n",
+ "    blended = TF.to_tensor(init).add(TF.to_tensor(init2)).div(2)\n",
+ "    blended = blended.to(device).unsqueeze(0).mul(2).sub(1)\n",
+ "    return blended.expand(batch_size, -1, -1, -1)\n",
+ "\n",
+ "def fetch(url_or_path):\n",
+ " if str(url_or_path).startswith('http://') or str(url_or_path).startswith('https://'):\n",
+ " r = requests.get(url_or_path)\n",
+ " r.raise_for_status()\n",
+ " fd = io.BytesIO()\n",
+ " fd.write(r.content)\n",
+ " fd.seek(0)\n",
+ " return fd\n",
+ " return open(url_or_path, 'rb')\n",
+ "\n",
+ "def read_image_workaround(path):\n",
+ "    \"\"\"Read an image with OpenCV and return it in RGB channel order.\n",
+ "\n",
+ "    OpenCV loads images as BGR while Pillow works in RGB; converting here avoids\n",
+ "    swapped colours when the result is handed to Pillow-based code.\n",
+ "    \"\"\"\n",
+ "    bgr_image = cv2.imread(path)\n",
+ "    rgb_image = cv2.cvtColor(bgr_image, cv2.COLOR_BGR2RGB)\n",
+ "    return rgb_image\n",
+ "\n",
+ "def parse_prompt(prompt):\n",
+ " if prompt.startswith('http://') or prompt.startswith('https://'):\n",
+ " vals = prompt.rsplit(':', 2)\n",
+ " vals = [vals[0] + ':' + vals[1], *vals[2:]]\n",
+ " else:\n",
+ " vals = prompt.rsplit(':', 1)\n",
+ " vals = vals + ['', '1'][len(vals):]\n",
+ " return vals[0], float(vals[1])\n",
+ "\n",
+ "def sinc(x):\n",
+ " return torch.where(x != 0, torch.sin(math.pi * x) / (math.pi * x), x.new_ones([]))\n",
+ "\n",
+ "def lanczos(x, a):\n",
+ "    \"\"\"Lanczos kernel of radius `a` sampled at `x`, normalized so its values sum to 1.\"\"\"\n",
+ "    inside_window = torch.logical_and(-a < x, x < a)\n",
+ "    taps = torch.where(inside_window, sinc(x) * sinc(x/a), x.new_zeros([]))\n",
+ "    total = taps.sum()\n",
+ "    return taps / total\n",
+ "\n",
+ "def ramp(ratio, width):\n",
+ " n = math.ceil(width / ratio + 1)\n",
+ " out = torch.empty([n])\n",
+ " cur = 0\n",
+ " for i in range(out.shape[0]):\n",
+ " out[i] = cur\n",
+ " cur += ratio\n",
+ " return torch.cat([-out[1:].flip([0]), out])[1:-1]\n",
+ "\n",
+ "def resample(input, size, align_corners=True):\n",
+ "    \"\"\"Resize an (n, c, h, w) batch to `size` = (dh, dw) with bicubic interpolation.\n",
+ "\n",
+ "    Any axis that is being shrunk is first filtered with a Lanczos kernel\n",
+ "    (reflect-padded so the spatial size is unchanged) before interpolating.\n",
+ "    \"\"\"\n",
+ "    n, c, h, w = input.shape\n",
+ "    dh, dw = size\n",
+ "\n",
+ "    # Filter every channel independently as a single-channel image.\n",
+ "    planes = input.reshape([n * c, 1, h, w])\n",
+ "\n",
+ "    if dh < h:\n",
+ "        kernel_h = lanczos(ramp(dh / h, 2), 2).to(planes.device, planes.dtype)\n",
+ "        pad_h = (kernel_h.shape[0] - 1) // 2\n",
+ "        planes = F.pad(planes, (0, 0, pad_h, pad_h), 'reflect')\n",
+ "        planes = F.conv2d(planes, kernel_h[None, None, :, None])\n",
+ "\n",
+ "    if dw < w:\n",
+ "        kernel_w = lanczos(ramp(dw / w, 2), 2).to(planes.device, planes.dtype)\n",
+ "        pad_w = (kernel_w.shape[0] - 1) // 2\n",
+ "        planes = F.pad(planes, (pad_w, pad_w, 0, 0), 'reflect')\n",
+ "        planes = F.conv2d(planes, kernel_w[None, None, None, :])\n",
+ "\n",
+ "    planes = planes.reshape([n, c, h, w])\n",
+ "    return F.interpolate(planes, size, mode='bicubic', align_corners=align_corners)\n",
+ "\n",
+ "class MakeCutouts(nn.Module):\n",
+ "    \"\"\"Produce `cutn` crops of a batch, each resampled to cut_size x cut_size.\n",
+ "\n",
+ "    Most crops are random squares taken from the zero-padded input; indices\n",
+ "    above cutn - cutn//4 use the whole padded input instead. Augmentations are\n",
+ "    applied to every crop unless skip_augs is set.\n",
+ "    \"\"\"\n",
+ "\n",
+ "    def __init__(self, cut_size, cutn, skip_augs=False):\n",
+ "        super().__init__()\n",
+ "        self.cut_size = cut_size\n",
+ "        self.cutn = cutn\n",
+ "        self.skip_augs = skip_augs\n",
+ "\n",
+ "        def jitter():\n",
+ "            # Small additive Gaussian noise inserted after each transform.\n",
+ "            return T.Lambda(lambda x: x + torch.randn_like(x) * 0.01)\n",
+ "\n",
+ "        self.augs = T.Compose([\n",
+ "            T.RandomHorizontalFlip(p=0.5),\n",
+ "            jitter(),\n",
+ "            T.RandomAffine(degrees=15, translate=(0.1, 0.1)),\n",
+ "            jitter(),\n",
+ "            T.RandomPerspective(distortion_scale=0.4, p=0.7),\n",
+ "            jitter(),\n",
+ "            T.RandomGrayscale(p=0.15),\n",
+ "            jitter(),\n",
+ "        ])\n",
+ "\n",
+ "    def forward(self, input):\n",
+ "        \"\"\"Return all cutouts of `input` concatenated along dim 0.\"\"\"\n",
+ "        padded = T.Pad(input.shape[2]//4, fill=0)(input)\n",
+ "        sideY, sideX = padded.shape[2:4]\n",
+ "        max_size = min(sideX, sideY)\n",
+ "        full_frame_after = self.cutn - self.cutn//4\n",
+ "        target = (self.cut_size, self.cut_size)\n",
+ "\n",
+ "        pieces = []\n",
+ "        for ch in range(self.cutn):\n",
+ "            if ch <= full_frame_after:\n",
+ "                # Random crop side: normal(0.8, 0.3) fraction of max_size, clipped to [cut_size/max_size, 1].\n",
+ "                frac = torch.zeros(1,).normal_(mean=.8, std=.3).clip(float(self.cut_size/max_size), 1.)\n",
+ "                size = int(max_size * frac)\n",
+ "                offsetx = torch.randint(0, abs(sideX - size + 1), ())\n",
+ "                offsety = torch.randint(0, abs(sideY - size + 1), ())\n",
+ "                piece = padded[:, :, offsety:offsety + size, offsetx:offsetx + size]\n",
+ "            else:\n",
+ "                piece = padded.clone()\n",
+ "\n",
+ "            if not self.skip_augs:\n",
+ "                piece = self.augs(piece)\n",
+ "            pieces.append(resample(piece, target))\n",
+ "\n",
+ "        return torch.cat(pieces, dim=0)\n",
+ "\n",
+ "cutout_debug = False\n",
+ "padargs = {}\n",
+ "\n",
+ "class MakeCutoutsDango(nn.Module):\n",
+ " def __init__(self, cut_size,\n",
+ " Overview=4,\n",
+ " InnerCrop = 0, IC_Size_Pow=0.5, IC_Grey_P = 0.2\n",
+ " ):\n",
+ " super().__init__()\n",
+ " self.cut_size = cut_size\n",
+ " self.Overview = Overview\n",
+ " self.InnerCrop = InnerCrop\n",
+ " self.IC_Size_Pow = IC_Size_Pow\n",
+ " self.IC_Grey_P = IC_Grey_P\n",
+ " if args.animation_mode == 'None':\n",
+ " self.augs = T.Compose([\n",
+ " T.RandomHorizontalFlip(p=0.5),\n",
+ " T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),\n",
+ " T.RandomAffine(degrees=10, translate=(0.05, 0.05), interpolation = T.InterpolationMode.BILINEAR),\n",
+ " T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),\n",
+ " T.RandomGrayscale(p=0.1),\n",
+ " T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),\n",
+ " T.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1),\n",
+ " ])\n",
+ " elif args.animation_mode == 'Video Input Legacy':\n",
+ " self.augs = T.Compose([\n",
+ " T.RandomHorizontalFlip(p=0.5),\n",
+ " T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),\n",
+ " T.RandomAffine(degrees=15, translate=(0.1, 0.1)),\n",
+ " T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),\n",
+ " T.RandomPerspective(distortion_scale=0.4, p=0.7),\n",
+ " T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),\n",
+ " T.RandomGrayscale(p=0.15),\n",
+ " T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),\n",
+ " # T.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1),\n",
+ " ])\n",
+ " elif args.animation_mode == '2D' or args.animation_mode == 'Video Input':\n",
+ " self.augs = T.Compose([\n",
+ " T.RandomHorizontalFlip(p=0.4),\n",
+ " T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),\n",
+ " T.RandomAffine(degrees=10, translate=(0.05, 0.05), interpolation = T.InterpolationMode.BILINEAR),\n",
+ " T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),\n",
+ " T.RandomGrayscale(p=0.1),\n",
+ " T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),\n",
+ " T.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.3),\n",
+ " ])\n",
+ "\n",
+ "\n",
+ " def forward(self, input):\n",
+ " cutouts = []\n",
+ " gray = T.Grayscale(3)\n",
+ " sideY, sideX = input.shape[2:4]\n",
+ " max_size = min(sideX, sideY)\n",
+ " min_size = min(sideX, sideY, self.cut_size)\n",
+ " l_size = max(sideX, sideY)\n",
+ " output_shape = [1,3,self.cut_size,self.cut_size]\n",
+ " output_shape_2 = [1,3,self.cut_size+2,self.cut_size+2]\n",
+ " pad_input = F.pad(input,((sideY-max_size)//2,(sideY-max_size)//2,(sideX-max_size)//2,(sideX-max_size)//2), **padargs)\n",
+ " cutout = resize(pad_input, out_shape=output_shape)\n",
+ "\n",
+ " if self.Overview>0:\n",
+ " if self.Overview<=4:\n",
+ " if self.Overview>=1:\n",
+ " cutouts.append(cutout)\n",
+ " if self.Overview>=2:\n",
+ " cutouts.append(gray(cutout))\n",
+ " if self.Overview>=3:\n",
+ " cutouts.append(TF.hflip(cutout))\n",
+ " if self.Overview==4:\n",
+ " cutouts.append(gray(TF.hflip(cutout)))\n",
+ " else:\n",
+ " cutout = resize(pad_input, out_shape=output_shape)\n",
+ " for _ in range(self.Overview):\n",
+ " cutouts.append(cutout)\n",
+ "\n",
+ " if cutout_debug:\n",
+ " if is_colab:\n",
+ " TF.to_pil_image(cutouts[0].clamp(0, 1).squeeze(0)).save(\"/content/cutout_overview0.jpg\",quality=99)\n",
+ " else:\n",
+ " TF.to_pil_image(cutouts[0].clamp(0, 1).squeeze(0)).save(\"cutout_overview0.jpg\",quality=99)\n",
+ "\n",
+ "\n",
+ " if self.InnerCrop >0:\n",
+ " for i in range(self.InnerCrop):\n",
+ " size = int(torch.rand([])**self.IC_Size_Pow * (max_size - min_size) + min_size)\n",
+ " offsetx = torch.randint(0, sideX - size + 1, ())\n",
+ " offsety = torch.randint(0, sideY - size + 1, ())\n",
+ " cutout = input[:, :, offsety:offsety + size, offsetx:offsetx + size]\n",
+ " if i <= int(self.IC_Grey_P * self.InnerCrop):\n",
+ " cutout = gray(cutout)\n",
+ " cutout = resize(cutout, out_shape=output_shape)\n",
+ " cutouts.append(cutout)\n",
+ " if cutout_debug:\n",
+ " if is_colab:\n",
+ " TF.to_pil_image(cutouts[-1].clamp(0, 1).squeeze(0)).save(\"/content/cutout_InnerCrop.jpg\",quality=99)\n",
+ " else:\n",
+ " TF.to_pil_image(cutouts[-1].clamp(0, 1).squeeze(0)).save(\"cutout_InnerCrop.jpg\",quality=99)\n",
+ " cutouts = torch.cat(cutouts)\n",
+ " if skip_augs is not True: cutouts=self.augs(cutouts)\n",
+ " return cutouts\n",
+ "\n",
+ "def spherical_dist_loss(x, y):\n",
+ " x = F.normalize(x, dim=-1)\n",
+ " y = F.normalize(y, dim=-1)\n",
+ " return (x - y).norm(dim=-1).div(2).arcsin().pow(2).mul(2)\n",
+ "\n",
+ "def tv_loss(input):\n",
+ " \"\"\"L2 total variation loss, as in Mahendran et al.\"\"\"\n",
+ " input = F.pad(input, (0, 1, 0, 1), 'replicate')\n",
+ " x_diff = input[..., :-1, 1:] - input[..., :-1, :-1]\n",
+ " y_diff = input[..., 1:, :-1] - input[..., :-1, :-1]\n",
+ " return (x_diff**2 + y_diff**2).mean([1, 2, 3])\n",
+ "\n",
+ "def get_image_from_lat(lat):\n",
+ "    \"\"\"Decode a latent with sd_model's first stage and return the first item as a PIL image.\"\"\"\n",
+ "    decoded = sd_model.decode_first_stage(lat.cuda())\n",
+ "    first = decoded[0]\n",
+ "    # Apply (x + 1) / 2 and clamp to [0, 1] before converting to PIL.\n",
+ "    pixels = first.add(1).div(2).clamp(0, 1)\n",
+ "    return TF.to_pil_image(pixels)\n",
+ "\n",
+ "\n",
+ "def get_lat_from_pil(frame):\n",
+ "    \"\"\"Encode an image into sd_model's first-stage latent space.\n",
+ "\n",
+ "    Args:\n",
+ "        frame: a PIL image or array-like; assumed to be height x width x channels\n",
+ "            with 0-255 values -- TODO confirm against callers.\n",
+ "\n",
+ "    Returns:\n",
+ "        The result of sd_model.get_first_stage_encoding for the encoded frame.\n",
+ "    \"\"\"\n",
+ "    # Convert first: a PIL image has no .shape, so reading it before the\n",
+ "    # conversion raised AttributeError.\n",
+ "    frame = np.array(frame)\n",
+ "    print(frame.shape, 'frame2pil.shape')\n",
+ "    # Scale to [0, 1], add a batch axis and move channels first.\n",
+ "    frame = (frame/255.)[None,...].transpose(0, 3, 1, 2)\n",
+ "    # Rescale to [-1, 1] on the GPU before encoding.\n",
+ "    frame = 2*torch.from_numpy(frame).float().cuda()-1.\n",
+ "    return sd_model.get_first_stage_encoding(sd_model.encode_first_stage(frame))\n",
+ "\n",
+ "\n",
+ "def range_loss(input):\n",
+ " return (input - input.clamp(-1, 1)).pow(2).mean([1, 2, 3])\n",
+ "\n",
+ "stop_on_next_loop = False # Make sure GPU memory doesn't get corrupted from cancelling the run mid-way through, allow a full frame to complete\n",
+ "TRANSLATION_SCALE = 1.0/200.0\n",
+ "\n",
+ "def get_sched_from_json(frame_num, sched_json, blend=False):\n",
+ "\n",
+ " frame_num = int(frame_num)\n",
+ " frame_num = max(frame_num, 0)\n",
+ " sched_int = {}\n",
+ " for key in sched_json.keys():\n",
+ " sched_int[int(key)] = sched_json[key]\n",
+ " sched_json = sched_int\n",
+ " keys = sorted(list(sched_json.keys())); #print(keys)\n",
+ " if frame_num<0:\n",
+ " frame_num = max(keys)\n",
+ " try:\n",
+ " frame_num = min(frame_num,max(keys)) #clamp frame num to 0:max(keys) range\n",
+ " except:\n",
+ " pass\n",
+ "\n",
+ " # print('clamped frame num ', frame_num)\n",
+ " if frame_num in keys:\n",
+ " return sched_json[frame_num]; #print('frame in keys')\n",
+ " if frame_num not in keys:\n",
+ " for i in range(len(keys)-1):\n",
+ " k1 = keys[i]\n",
+ " k2 = keys[i+1]\n",
+ " if frame_num > k1 and frame_num < k2:\n",
+ " if not blend:\n",
+ " print('frame between keys, no blend')\n",
+ " return sched_json[k1]\n",
+ " if blend:\n",
+ " total_dist = k2-k1\n",
+ " dist_from_k1 = frame_num - k1\n",
+ " return sched_json[k1]*(1 - dist_from_k1/total_dist) + sched_json[k2]*(dist_from_k1/total_dist)\n",
+ " #else: print(f'frame {frame_num} not in {k1} {k2}')\n",
+ " return 0\n",
+ "\n",
+ "def get_scheduled_arg(frame_num, schedule):\n",
+ " if isinstance(schedule, list):\n",
+ " return schedule[frame_num] if frame_num 0:\n",
+ " arr = np.array(init_image_alpha)\n",
+ " if mask_clip_high < 255:\n",
+ " arr = np.where(arr 0:\n",
+ " arr = np.where(arr>mask_clip_low, arr, 0)\n",
+ " init_image_alpha = Image.fromarray(arr)\n",
+ "\n",
+ " if background == 'color':\n",
+ " bg = Image.new('RGB', size, background_source)\n",
+ " if background == 'image':\n",
+ " bg = Image.open(background_source).convert('RGB').resize(size)\n",
+ " if background == 'init_video':\n",
+ " bg = Image.open(f'{videoFramesFolder}/{frame_num+1:06}.jpg').resize(size)\n",
+ " # init_image.putalpha(init_image_alpha)\n",
+ " if warp_mode == 'use_image':\n",
+ " bg.paste(init_image, (0,0), init_image_alpha)\n",
+ " if warp_mode == 'use_latent':\n",
+ " #convert bg to latent\n",
+ "\n",
+ " bg = np.array(bg)\n",
+ " bg = (bg/255.)[None,...].transpose(0, 3, 1, 2)\n",
+ " bg = 2*torch.from_numpy(bg).float().cuda()-1.\n",
+ " bg = sd_model.get_first_stage_encoding(sd_model.encode_first_stage(bg))\n",
+ " bg = bg.cpu().numpy()#[0].transpose(1,2,0)\n",
+ " init_image_alpha = np.array(init_image_alpha)[::8,::8][None, None, ...]\n",
+ " init_image_alpha = np.repeat(init_image_alpha, 4, axis = 1)/255\n",
+ " print(bg.shape, init_image.shape, init_image_alpha.shape, init_image_alpha.max(), init_image_alpha.min())\n",
+ " bg = init_image*init_image_alpha + bg*(1-init_image_alpha)\n",
+ " return bg\n",
+ "\n",
+ "def softcap(arr, thresh=0.8, q=0.95):\n",
+ "    \"\"\"Compress values of `arr` whose magnitude exceeds `thresh`.\n",
+ "\n",
+ "    Values beyond +/-thresh are rescaled linearly so that the q-quantile of\n",
+ "    |arr| lands on +/-1; values inside the threshold are returned unchanged.\n",
+ "\n",
+ "    Args:\n",
+ "        arr: tensor to compress.\n",
+ "        thresh: magnitude above which compression starts.\n",
+ "        q: quantile of |arr| that is mapped to magnitude 1.\n",
+ "\n",
+ "    Returns:\n",
+ "        A tensor with the same shape as `arr`.\n",
+ "    \"\"\"\n",
+ "    magnitude = abs(arr).float()\n",
+ "    cap = torch.quantile(magnitude, q)\n",
+ "    # Build the debug quantile levels on arr's own device; a hard-coded .cuda()\n",
+ "    # failed for CPU tensors and for tensors on a non-default GPU.\n",
+ "    debug_levels = torch.tensor([0.25,0.5,0.75,0.9,0.95,0.99,1], dtype=magnitude.dtype, device=magnitude.device)\n",
+ "    printf('q -----', torch.quantile(magnitude, debug_levels))\n",
+ "    # NOTE(review): cap == thresh divides by zero and cap < thresh flips the sign; not guarded here.\n",
+ "    cap_ratio = (1-thresh)/(cap-thresh)\n",
+ "    arr = torch.where(arr>thresh, thresh+(arr-thresh)*cap_ratio, arr)\n",
+ "    arr = torch.where(arr<-thresh, -thresh+(arr+thresh)*cap_ratio, arr)\n",
+ "    return arr\n",
+ "\n",
+ "def do_run():\n",
+ " seed = args.seed\n",
+ " print(range(args.start_frame, args.max_frames))\n",
+ " if args.animation_mode != \"None\":\n",
+ " batchBar = tqdm(total=args.max_frames, desc =\"Frames\")\n",
+ "\n",
+ " # if (args.animation_mode == 'Video Input') and (args.midas_weight > 0.0):\n",
+ " # midas_model, midas_transform, midas_net_w, midas_net_h, midas_resize_mode, midas_normalization = init_midas_depth_model(args.midas_depth_model)\n",
+ " for frame_num in range(args.start_frame, args.max_frames):\n",
+ " try:\n",
+ " sd_model.cpu()\n",
+ " sd_model.model.cpu()\n",
+ " sd_model.cond_stage_model.cpu()\n",
+ " sd_model.first_stage_model.cpu()\n",
+ " if 'control' in model_version:\n",
+ " for key in loaded_controlnets.keys():\n",
+ " loaded_controlnets[key].cpu()\n",
+ " except: pass\n",
+ " try:\n",
+ " apply_openpose.body_estimation.model.cpu()\n",
+ " apply_openpose.hand_estimation.model.cpu()\n",
+ " apply_openpose.face_estimation.model.cpu()\n",
+ " except: pass\n",
+ " try:\n",
+ " sd_model.model.diffusion_model.cpu()\n",
+ " except: pass\n",
+ " try:\n",
+ " apply_softedge.netNetwork.cpu()\n",
+ " except: pass\n",
+ " try:\n",
+ " apply_normal.netNetwork.cpu()\n",
+ " except: pass\n",
+ " torch.cuda.empty_cache()\n",
+ " gc.collect()\n",
+ " if stop_on_next_loop:\n",
+ " break\n",
+ "\n",
+ " # display.clear_output(wait=True)\n",
+ "\n",
+ " # Print Frame progress if animation mode is on\n",
+ " if args.animation_mode != \"None\":\n",
+ " display.display(batchBar.container)\n",
+ " batchBar.n = frame_num\n",
+ " batchBar.update(1)\n",
+ " batchBar.refresh()\n",
+ " # display.display(batchBar.container)\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ " # Inits if not video frames\n",
+ " if args.animation_mode != \"Video Input Legacy\":\n",
+ " if args.init_image == '':\n",
+ " init_image = None\n",
+ " else:\n",
+ " init_image = args.init_image\n",
+ " init_scale = get_scheduled_arg(frame_num, init_scale_schedule)\n",
+ " # init_scale = args.init_scale\n",
+ " steps = int(get_scheduled_arg(frame_num, steps_schedule))\n",
+ " style_strength = get_scheduled_arg(frame_num, style_strength_schedule)\n",
+ " skip_steps = int(steps-steps*style_strength)\n",
+ " # skip_steps = args.skip_steps\n",
+ "\n",
+ " if args.animation_mode == 'Video Input':\n",
+ " if frame_num == args.start_frame:\n",
+ " steps = int(get_scheduled_arg(frame_num, steps_schedule))\n",
+ " style_strength = get_scheduled_arg(frame_num, style_strength_schedule)\n",
+ " skip_steps = int(steps-steps*style_strength)\n",
+ " # skip_steps = args.skip_steps\n",
+ "\n",
+ " # init_scale = args.init_scale\n",
+ " init_scale = get_scheduled_arg(frame_num, init_scale_schedule)\n",
+ " # init_latent_scale = args.init_latent_scale\n",
+ " init_latent_scale = get_scheduled_arg(frame_num, latent_scale_schedule)\n",
+ " init_image = f'{videoFramesFolder}/{frame_num+1:06}.jpg'\n",
+ " if use_background_mask:\n",
+ " init_image_pil = Image.open(init_image)\n",
+ " init_image_pil = apply_mask(init_image_pil, frame_num, background, background_source, invert_mask)\n",
+ " init_image_pil.save(f'init_alpha_{frame_num}.png')\n",
+ " init_image = f'init_alpha_{frame_num}.png'\n",
+ " if (args.init_image != '') and args.init_image is not None:\n",
+ " init_image = args.init_image\n",
+ " if use_background_mask:\n",
+ " init_image_pil = Image.open(init_image)\n",
+ " init_image_pil = apply_mask(init_image_pil, frame_num, background, background_source, invert_mask)\n",
+ " init_image_pil.save(f'init_alpha_{frame_num}.png')\n",
+ " init_image = f'init_alpha_{frame_num}.png'\n",
+ " if VERBOSE:print('init image', args.init_image)\n",
+ " if frame_num > 0 and frame_num != frame_range[0]:\n",
+ " # print(frame_num)\n",
+ "\n",
+ " first_frame_source = batchFolder+f\"/{batch_name}({batchNum})_{args.start_frame:06}.png\"\n",
+ " if os.path.exists(first_frame_source):\n",
+ " first_frame = Image.open(first_frame_source)\n",
+ " else:\n",
+ " first_frame_source = batchFolder+f\"/{batch_name}({batchNum})_{args.start_frame-1:06}.png\"\n",
+ " first_frame = Image.open(first_frame_source)\n",
+ "\n",
+ "\n",
+ " # print(frame_num)\n",
+ "\n",
+ " # first_frame = Image.open(batchFolder+f\"/{batch_name}({batchNum})_{args.start_frame:06}.png\")\n",
+ " # first_frame_source = batchFolder+f\"/{batch_name}({batchNum})_{args.start_frame:06}.png\"\n",
+ " if not fixed_seed:\n",
+ " seed += 1\n",
+ " if resume_run and frame_num == start_frame:\n",
+ " print('if resume_run and frame_num == start_frame')\n",
+ " img_filepath = batchFolder+f\"/{batch_name}({batchNum})_{start_frame-1:06}.png\"\n",
+ " if turbo_mode and frame_num > turbo_preroll:\n",
+ " shutil.copyfile(img_filepath, 'oldFrameScaled.png')\n",
+ " else:\n",
+ " shutil.copyfile(img_filepath, 'prevFrame.png')\n",
+ " else:\n",
+ " # img_filepath = '/content/prevFrame.png' if is_colab else 'prevFrame.png'\n",
+ " img_filepath = 'prevFrame.png'\n",
+ "\n",
+ " next_step_pil = do_3d_step(img_filepath, frame_num, forward_clip=forward_weights_clip)\n",
+ " if warp_mode == 'use_image':\n",
+ " next_step_pil.save('prevFrameScaled.png')\n",
+ " else:\n",
+ " # init_image = 'prevFrameScaled_lat.pt'\n",
+ " # next_step_pil.save('prevFrameScaled.png')\n",
+ " torch.save(next_step_pil, 'prevFrameScaled_lat.pt')\n",
+ "\n",
+ " steps = int(get_scheduled_arg(frame_num, steps_schedule))\n",
+ " style_strength = get_scheduled_arg(frame_num, style_strength_schedule)\n",
+ " skip_steps = int(steps-steps*style_strength)\n",
+ " # skip_steps = args.calc_frames_skip_steps\n",
+ "\n",
+ " ### Turbo mode - skip some diffusions, use 3d morph for clarity and to save time\n",
+ " if turbo_mode:\n",
+ " if frame_num == turbo_preroll: #start tracking oldframe\n",
+ " if warp_mode == 'use_image':\n",
+ " next_step_pil.save('oldFrameScaled.png')#stash for later blending\n",
+ " if warp_mode == 'use_latent':\n",
+ " # lat_from_img = get_lat/_from_pil(next_step_pil)\n",
+ " torch.save(next_step_pil, 'oldFrameScaled_lat.pt')\n",
+ " elif frame_num > turbo_preroll:\n",
+ " #set up 2 warped image sequences, old & new, to blend toward new diff image\n",
+ " if warp_mode == 'use_image':\n",
+ " old_frame = do_3d_step('oldFrameScaled.png', frame_num, forward_clip=forward_weights_clip_turbo_step)\n",
+ " old_frame.save('oldFrameScaled.png')\n",
+ " if warp_mode == 'use_latent':\n",
+ " old_frame = do_3d_step('oldFrameScaled.png', frame_num, forward_clip=forward_weights_clip_turbo_step)\n",
+ "\n",
+ " # lat_from_img = get_lat_from_pil(old_frame)\n",
+ " torch.save(old_frame, 'oldFrameScaled_lat.pt')\n",
+ " if frame_num % int(turbo_steps) != 0:\n",
+ " print('turbo skip this frame: skipping clip diffusion steps')\n",
+ " filename = f'{args.batch_name}({args.batchNum})_{frame_num:06}.png'\n",
+ " blend_factor = ((frame_num % int(turbo_steps))+1)/int(turbo_steps)\n",
+ " print('turbo skip this frame: skipping clip diffusion steps and saving blended frame')\n",
+ " if warp_mode == 'use_image':\n",
+ " newWarpedImg = cv2.imread('prevFrameScaled.png')#this is already updated..\n",
+ " oldWarpedImg = cv2.imread('oldFrameScaled.png')\n",
+ " blendedImage = cv2.addWeighted(newWarpedImg, blend_factor, oldWarpedImg,1-blend_factor, 0.0)\n",
+ " cv2.imwrite(f'{batchFolder}/{filename}',blendedImage)\n",
+ " next_step_pil.save(f'{img_filepath}') # save it also as prev_frame to feed next iteration\n",
+ " if warp_mode == 'use_latent':\n",
+ " newWarpedImg = torch.load('prevFrameScaled_lat.pt')#this is already updated..\n",
+ " oldWarpedImg = torch.load('oldFrameScaled_lat.pt')\n",
+ " blendedImage = newWarpedImg*(blend_factor)+oldWarpedImg*(1-blend_factor)\n",
+ " blendedImage = get_image_from_lat(blendedImage).save(f'{batchFolder}/{filename}')\n",
+ " torch.save(next_step_pil,f'{img_filepath[:-4]}_lat.pt')\n",
+ "\n",
+ "\n",
+ " if turbo_frame_skips_steps is not None:\n",
+ " if warp_mode == 'use_image':\n",
+ " oldWarpedImg = cv2.imread('prevFrameScaled.png')\n",
+ " cv2.imwrite(f'oldFrameScaled.png',oldWarpedImg)#swap in for blending later\n",
+ " print('clip/diff this frame - generate clip diff image')\n",
+ " if warp_mode == 'use_latent':\n",
+ " oldWarpedImg = torch.load('prevFrameScaled_lat.pt')\n",
+ " torch.save(oldWarpedImg, f'oldFrameScaled_lat.pt',)#swap in for blending later\n",
+ " skip_steps = math.floor(steps * turbo_frame_skips_steps)\n",
+ " else: continue\n",
+ " else:\n",
+ " #if not a skip frame, will run diffusion and need to blend.\n",
+ " if warp_mode == 'use_image':\n",
+ " oldWarpedImg = cv2.imread('prevFrameScaled.png')\n",
+ " cv2.imwrite(f'oldFrameScaled.png',oldWarpedImg)#swap in for blending later\n",
+ " print('clip/diff this frame - generate clip diff image')\n",
+ " if warp_mode == 'use_latent':\n",
+ " oldWarpedImg = torch.load('prevFrameScaled_lat.pt')\n",
+ " torch.save(oldWarpedImg, f'oldFrameScaled_lat.pt',)#swap in for blending later\n",
+ " # oldWarpedImg = cv2.imread('prevFrameScaled.png')\n",
+ " # cv2.imwrite(f'oldFrameScaled.png',oldWarpedImg)#swap in for blending later\n",
+ " print('clip/diff this frame - generate clip diff image')\n",
+ " if warp_mode == 'use_image':\n",
+ " init_image = 'prevFrameScaled.png'\n",
+ " else:\n",
+ " init_image = 'prevFrameScaled_lat.pt'\n",
+ " if use_background_mask:\n",
+ " if warp_mode == 'use_latent':\n",
+ " # pass\n",
+ " latent = apply_mask(latent.cpu(), frame_num, background, background_source, invert_mask, warp_mode)#.save(init_image)\n",
+ "\n",
+ " if warp_mode == 'use_image':\n",
+ " apply_mask(Image.open(init_image), frame_num, background, background_source, invert_mask).save(init_image)\n",
+ " # init_scale = args.frames_scale\n",
+ " init_scale = get_scheduled_arg(frame_num, init_scale_schedule)\n",
+ " # init_latent_scale = args.frames_latent_scale\n",
+ " init_latent_scale = get_scheduled_arg(frame_num, latent_scale_schedule)\n",
+ "\n",
+ "\n",
+ " loss_values = []\n",
+ "\n",
+ " if seed is not None:\n",
+ " np.random.seed(seed)\n",
+ " random.seed(seed)\n",
+ " torch.manual_seed(seed)\n",
+ " torch.cuda.manual_seed_all(seed)\n",
+ " torch.backends.cudnn.deterministic = True\n",
+ "\n",
+ " target_embeds, weights = [], []\n",
+ "\n",
+ " if args.prompts_series is not None and frame_num >= len(args.prompts_series):\n",
+ " # frame_prompt = args.prompts_series[-1]\n",
+ " frame_prompt = get_sched_from_json(frame_num, args.prompts_series, blend=False)\n",
+ " elif args.prompts_series is not None:\n",
+ " # frame_prompt = args.prompts_series[frame_num]\n",
+ " frame_prompt = get_sched_from_json(frame_num, args.prompts_series, blend=False)\n",
+ " else:\n",
+ " frame_prompt = []\n",
+ "\n",
+ " if VERBOSE:print(args.image_prompts_series)\n",
+ " if args.image_prompts_series is not None and frame_num >= len(args.image_prompts_series):\n",
+ " image_prompt = get_sched_from_json(frame_num, args.image_prompts_series, blend=False)\n",
+ " elif args.image_prompts_series is not None:\n",
+ " image_prompt = get_sched_from_json(frame_num, args.image_prompts_series, blend=False)\n",
+ " else:\n",
+ " image_prompt = []\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ " init = None\n",
+ "\n",
+ "\n",
+ "\n",
+ " image_display = Output()\n",
+ " for i in range(args.n_batches):\n",
+ " if args.animation_mode == 'None':\n",
+ " display.clear_output(wait=True)\n",
+ " batchBar = tqdm(range(args.n_batches), desc =\"Batches\")\n",
+ " batchBar.n = i\n",
+ " batchBar.refresh()\n",
+ " print('')\n",
+ " display.display(image_display)\n",
+ " gc.collect()\n",
+ " torch.cuda.empty_cache()\n",
+ " steps = int(get_scheduled_arg(frame_num, steps_schedule))\n",
+ " style_strength = get_scheduled_arg(frame_num, style_strength_schedule)\n",
+ " skip_steps = int(steps-steps*style_strength)\n",
+ "\n",
+ "\n",
+ " if perlin_init:\n",
+ " init = regen_perlin()\n",
+ "\n",
+ " consistency_mask = None\n",
+ " if (check_consistency or (model_version == 'v1_inpainting') or ('control_sd15_inpaint' in controlnet_multimodel.keys())) and frame_num>0:\n",
+ " frame1_path = f'{videoFramesFolder}/{frame_num:06}.jpg'\n",
+ " if reverse_cc_order:\n",
+ " weights_path = f\"{flo_folder}/{frame1_path.split('/')[-1]}-21_cc.jpg\"\n",
+ " else:\n",
+ " weights_path = f\"{flo_folder}/{frame1_path.split('/')[-1]}_12-21_cc.jpg\"\n",
+ " consistency_mask = load_cc(weights_path, blur=consistency_blur, dilate=consistency_dilate)\n",
+ "\n",
+ " if diffusion_model == 'stable_diffusion':\n",
+ " if VERBOSE: print(args.side_x, args.side_y, init_image)\n",
+ " # init = Image.open(fetch(init_image)).convert('RGB')\n",
+ "\n",
+ " # init = init.resize((args.side_x, args.side_y), Image.LANCZOS)\n",
+ " # init = TF.to_tensor(init).to(device).unsqueeze(0).mul(2).sub(1)\n",
+ " # text_prompt = copy.copy(args.prompts_series[frame_num])\n",
+ " text_prompt = copy.copy(get_sched_from_json(frame_num, args.prompts_series, blend=False))\n",
+ " if VERBOSE:print(f'Frame {frame_num} Prompt: {text_prompt}')\n",
+ " text_prompt = [re.sub('\\<(.*?)\\>', '', o).strip(' ') for o in text_prompt] #remove loras from prompt\n",
+ " text_prompt = [re.sub(\":\\s*([\\d.]+)\\s*$\", '', o).strip(' ') for o in text_prompt] #remove weights from prompt\n",
+ " used_loras, used_loras_weights = get_loras_weights_for_frame(frame_num, new_prompt_loras)\n",
+ " frame_prompt_weights = get_sched_from_json(frame_num, prompt_weights, blend=blend_json_schedules)\n",
+ "\n",
+ " if VERBOSE:\n",
+ " print('used_loras, used_loras_weights', used_loras, used_loras_weights)\n",
+ " print('prompt weights, frame_prompt_weights', prompt_weights , frame_prompt_weights)\n",
+ " # used_loras_weights = [o for o in used_loras_weights if o is not None else 0.]\n",
+ " # if use_lycoris:\n",
+ " # # load_lycos(used_loras, used_loras_weights, used_loras_weights)\n",
+ " # else:\n",
+ " # pass\n",
+ " # load_loras(used_loras, used_loras_weights)\n",
+ "\n",
+ " load_networks(names=used_loras, te_multipliers=used_loras_weights, unet_multipliers=used_loras_weights, dyn_dims=[None]*len(used_loras), sd_model=sd_model)\n",
+ " # load_networks(['pixel-art-xl-v1.1'], [1], [1], sd_model=sd_model)\n",
+ " # else:\n",
+ " # loaded_networks.clear()\n",
+ " caption = get_caption(frame_num)\n",
+ " if caption:\n",
+ " # print('args.prompt_series',args.prompts_series[frame_num])\n",
+ " for i in range(len(text_prompt)):\n",
+ " if '{caption}' in text_prompt[i]:\n",
+ " print('Replacing ', '{caption}', 'with ', caption)\n",
+ " text_prompt[0] = text_prompt[i].replace('{caption}', caption)\n",
+ " prompt_patterns = get_sched_from_json(frame_num, prompt_patterns_sched, blend=False)\n",
+ " if prompt_patterns:\n",
+ " for key in prompt_patterns.keys():\n",
+ " for i in range(len(text_prompt)):\n",
+ " if key in text_prompt[i]:\n",
+ " print('Replacing ', key, 'with ', prompt_patterns[key])\n",
+ " text_prompt[i] = text_prompt[i].replace(key, prompt_patterns[key])\n",
+ "\n",
+ " if args.neg_prompts_series is not None:\n",
+ " neg_prompt = get_sched_from_json(frame_num, args.neg_prompts_series, blend=False)\n",
+ " else:\n",
+ " neg_prompt = copy.copy(text_prompt)\n",
+ "\n",
+ " if VERBOSE:print(f'Frame {frame_num} neg_prompt: {neg_prompt}')\n",
+ " if args.rec_prompts_series is not None:\n",
+ " rec_prompt = copy.copy(get_sched_from_json(frame_num, args.rec_prompts_series, blend=False))\n",
+ " if caption and '{caption}' in rec_prompt[0]:\n",
+ " print('Replacing ', '{caption}', 'with ', caption)\n",
+ " rec_prompt[0] = rec_prompt[0].replace('{caption}', caption)\n",
+ " else:\n",
+ " rec_prompt = copy.copy(text_prompt)\n",
+ " if VERBOSE:print(f'Frame {rec_prompt} rec_prompt: {rec_prompt}')\n",
+ "\n",
+ " if VERBOSE:\n",
+ " print(neg_prompt, 'neg_prompt')\n",
+ " print('init_scale pre sd run', init_scale)\n",
+ " # init_latent_scale = args.init_latent_scale\n",
+ " # if frame_num>0:\n",
+ " # init_latent_scale = args.frames_latent_scale\n",
+ " steps = int(get_scheduled_arg(frame_num, steps_schedule))\n",
+ " init_scale = get_scheduled_arg(frame_num, init_scale_schedule)\n",
+ " init_latent_scale = get_scheduled_arg(frame_num, latent_scale_schedule)\n",
+ " style_strength = get_scheduled_arg(frame_num, style_strength_schedule)\n",
+ " skip_steps = int(steps-steps*style_strength)\n",
+ " cfg_scale = get_scheduled_arg(frame_num, cfg_scale_schedule)\n",
+ " image_scale = get_scheduled_arg(frame_num, image_scale_schedule)\n",
+ " cc_masked_diffusion = get_scheduled_arg(frame_num, cc_masked_diffusion_schedule)\n",
+ " if VERBOSE:printf('skip_steps b4 run_sd: ', skip_steps)\n",
+ "\n",
+ " deflicker_src = {\n",
+ " 'processed1':f'{batchFolder}/{args.batch_name}({args.batchNum})_{frame_num-1:06}.png',\n",
+ " 'raw1': f'{videoFramesFolder}/{frame_num:06}.jpg',\n",
+ " 'raw2': f'{videoFramesFolder}/{frame_num+1:06}.jpg',\n",
+ " }\n",
+ "\n",
+ " init_grad_img = None\n",
+ " if init_grad: init_grad_img = f'{videoFramesFolder}/{frame_num+1:06}.jpg'\n",
+ " #setup depth source\n",
+ " if cond_image_src == 'init':\n",
+ " cond_image = f'{videoFramesFolder}/{frame_num+1:06}.jpg'\n",
+ " if cond_image_src == 'stylized':\n",
+ " cond_image = init_image\n",
+ " if cond_image_src == 'cond_video':\n",
+ " cond_image = f'{condVideoFramesFolder}/{frame_num+1:06}.jpg'\n",
+ "\n",
+ " ref_image = None\n",
+ " if reference_source == 'init':\n",
+ " ref_image = f'{videoFramesFolder}/{frame_num+1:06}.jpg'\n",
+ " if reference_source == 'stylized':\n",
+ " ref_image = init_image\n",
+ " if reference_source == 'prev_frame':\n",
+ " ref_image = f'{batchFolder}/{args.batch_name}({args.batchNum})_{frame_num-1:06}.png'\n",
+ " if reference_source == 'color_video':\n",
+ " if os.path.exists(f'{colorVideoFramesFolder}/{frame_num+1:06}.jpg'):\n",
+ " ref_image = f'{colorVideoFramesFolder}/{frame_num+1:06}.jpg'\n",
+ " elif os.path.exists(f'{colorVideoFramesFolder}/{1:06}.jpg'):\n",
+ " ref_image = f'{colorVideoFramesFolder}/{1:06}.jpg'\n",
+ " else:\n",
+ " raise Exception(\"Reference mode specified with no color video or image. Please specify color video or disable the shuffle model\")\n",
+ "\n",
+ "\n",
+ " #setup shuffle\n",
+ " shuffle_source = None\n",
+ " if 'control_sd15_shuffle' in controlnet_multimodel.keys():\n",
+ " if control_sd15_shuffle_source == 'color_video':\n",
+ " if os.path.exists(f'{colorVideoFramesFolder}/{frame_num+1:06}.jpg'):\n",
+ " shuffle_source = f'{colorVideoFramesFolder}/{frame_num+1:06}.jpg'\n",
+ " elif os.path.exists(f'{colorVideoFramesFolder}/{1:06}.jpg'):\n",
+ " shuffle_source = f'{colorVideoFramesFolder}/{1:06}.jpg'\n",
+ " else:\n",
+ " raise Exception(\"Shuffle controlnet specified with no color video or image. Please specify color video or disable the shuffle model\")\n",
+ " elif control_sd15_shuffle_source == 'init':\n",
+ " shuffle_source = init_image\n",
+ " elif control_sd15_shuffle_source == 'first_frame':\n",
+ " shuffle_source = f'{batchFolder}/{args.batch_name}({args.batchNum})_{0:06}.png'\n",
+ " elif control_sd15_shuffle_source == 'prev_frame':\n",
+ " shuffle_source = f'{batchFolder}/{args.batch_name}({args.batchNum})_{frame_num-1:06}.png'\n",
+ " if not os.path.exists(shuffle_source):\n",
+ " if control_sd15_shuffle_1st_source == 'init':\n",
+ " shuffle_source = init_image\n",
+ " elif control_sd15_shuffle_1st_source == None:\n",
+ " shuffle_source = None\n",
+ " elif control_sd15_shuffle_1st_source == 'color_video':\n",
+ " if os.path.exists(f'{colorVideoFramesFolder}/{frame_num+1:06}.jpg'):\n",
+ " shuffle_source = f'{colorVideoFramesFolder}/{frame_num+1:06}.jpg'\n",
+ " elif os.path.exists(f'{colorVideoFramesFolder}/{1:06}.jpg'):\n",
+ " shuffle_source = f'{colorVideoFramesFolder}/{1:06}.jpg'\n",
+ " else:\n",
+ " raise Exception(\"Shuffle controlnet specified with no color video or image. Please specify color video or disable the shuffle model\")\n",
+ " print('Shuffle source ',shuffle_source)\n",
+ "\n",
+ "\n",
+ " prev_frame = f'{videoFramesFolder}/{frame_num:06}.jpg'\n",
+ " next_frame = f'{videoFramesFolder}/{frame_num+1:06}.jpg'\n",
+ " #setup temporal source\n",
+ " if temporalnet_source =='init':\n",
+ " prev_frame = f'{videoFramesFolder}/{frame_num:06}.jpg'\n",
+ " if temporalnet_source == 'stylized':\n",
+ " prev_frame = f'{batchFolder}/{args.batch_name}({args.batchNum})_{frame_num-1:06}.png'\n",
+ " if temporalnet_source == 'cond_video':\n",
+ " prev_frame = f'{condVideoFramesFolder}/{frame_num:06}.jpg'\n",
+ " if not os.path.exists(prev_frame):\n",
+ " if temporalnet_skip_1st_frame:\n",
+ " print('prev_frame not found, replacing 1st videoframe init')\n",
+ " prev_frame = None\n",
+ " else:\n",
+ " prev_frame = f'{videoFramesFolder}/{frame_num+1:06}.jpg'\n",
+ "\n",
+ " #setup rec noise source\n",
+ " if rec_source == 'stylized':\n",
+ " rec_frame = init_image\n",
+ " elif rec_source == 'init':\n",
+ " rec_frame = f'{videoFramesFolder}/{frame_num+1:06}.jpg'\n",
+ "\n",
+ "\n",
+ " #setup masks for inpainting model\n",
+ " if model_version == 'v1_inpainting':\n",
+ " if inpainting_mask_source == 'consistency_mask':\n",
+ " cond_image = consistency_mask\n",
+ " if inpainting_mask_source in ['none', None,'', 'None', 'off']:\n",
+ " cond_image = None\n",
+ " if inpainting_mask_source == 'cond_video': cond_image = f'{condVideoFramesFolder}/{frame_num+1:06}.jpg'\n",
+ " # print('cond_image0',cond_image)\n",
+ "\n",
+ " #setup masks for controlnet inpainting model\n",
+ " control_inpainting_mask = None\n",
+ " if 'control_sd15_inpaint' in controlnet_multimodel.keys() or 'control_sd15_inpaint_softedge' in controlnet_multimodel.keys():\n",
+ " if control_sd15_inpaint_mask_source == 'consistency_mask':\n",
+ " control_inpainting_mask = consistency_mask\n",
+ " if control_sd15_inpaint_mask_source in ['none', None,'', 'None', 'off']:\n",
+ " # control_inpainting_mask = None\n",
+ " control_inpainting_mask = np.ones((args.side_y,args.side_x,3))\n",
+ " if control_sd15_inpaint_mask_source == 'cond_video':\n",
+ " control_inpainting_mask = f'{condVideoFramesFolder}/{frame_num+1:06}.jpg'\n",
+ " control_inpainting_mask = np.array(PIL.Image.open(control_inpainting_mask))\n",
+ " # print('cond_image0',cond_image)\n",
+ "\n",
+ " np_alpha = None\n",
+ " if alpha_masked_diffusion and frame_num>args.start_frame:\n",
+ " if VERBOSE: print('Using alpha masked diffusion')\n",
+ " print(f'{videoFramesAlpha}/{frame_num+1:06}.jpg')\n",
+ " if videoFramesAlpha == videoFramesFolder or not os.path.exists(f'{videoFramesAlpha}/{frame_num+1:06}.jpg'):\n",
+ " raise Exception('You have enabled alpha_masked_diffusion without providing an alpha mask source. Please go to mask cell and specify a masked video init or extract a mask from init video.')\n",
+ "\n",
+ " init_image_alpha = Image.open(f'{videoFramesAlpha}/{frame_num+1:06}.jpg').resize((args.side_x,args.side_y)).convert('L')\n",
+ " np_alpha = np.array(init_image_alpha)/255.\n",
+ "\n",
+ " mask_current_frame_many = None\n",
+ " if mask_frames_many is not None:\n",
+ " mask_current_frame_many = [torch.from_numpy(np.array(PIL.Image.open(o[frame_num]).resize((args.side_x,args.side_y)).convert('L'))/255.)[None,...].float() for o in mask_frames_many]\n",
+ " mask_current_frame_many.insert(0, torch.ones_like(mask_current_frame_many[0]))\n",
+ " assert len(mask_current_frame_many) == len(text_prompt), 'mask number doesn`t match prompt number'\n",
+ "\n",
+ " mask_current_frame_many = torch.stack(mask_current_frame_many).repeat((1,4,1,1))\n",
+ " # mask_current_frame_many = torch.where(mask_current_frame_many>0.5, 1., 0.).float()\n",
+ " controlnet_sources = {}\n",
+ " if controlnet_multimodel != {}:\n",
+ " controlnet_sources = get_control_source_images(frame_num, controlnet_multimodel_inferred, stylized_image=init_image)\n",
+ " elif 'control_' in model_version:\n",
+ " controlnet_sources[model_version] = cond_image\n",
+ " controlnet_sources['next_frame'] = next_frame\n",
+ "\n",
+ " # try:\n",
+ " sample, latent, depth_img = run_sd(args, init_image=init_image, skip_timesteps=skip_steps, H=args.side_y,\n",
+ " W=args.side_x, text_prompt=text_prompt, neg_prompt=neg_prompt, steps=steps,\n",
+ " seed=seed, init_scale = init_scale, init_latent_scale=init_latent_scale, cond_image=cond_image,\n",
+ " cfg_scale=cfg_scale, image_scale = image_scale, cond_fn=None,\n",
+ " init_grad_img=init_grad_img, consistency_mask=consistency_mask,\n",
+ " frame_num=frame_num, deflicker_src=deflicker_src, prev_frame=prev_frame,\n",
+ " rec_prompt=rec_prompt, rec_frame=rec_frame,control_inpainting_mask=control_inpainting_mask, shuffle_source=shuffle_source,\n",
+ " ref_image=ref_image, alpha_mask=np_alpha, prompt_weights=frame_prompt_weights,\n",
+ " mask_current_frame_many=mask_current_frame_many, controlnet_sources=controlnet_sources, cc_masked_diffusion =cc_masked_diffusion )\n",
+ " # except:\n",
+ " # traceback.print_exc()\n",
+ " # sys.exit()\n",
+ "\n",
+ " settings_json = save_settings(skip_save=True)\n",
+ " settings_exif = json2exif(settings_json)\n",
+ "\n",
+ "\n",
+ "\n",
+ " # depth_img.save(f'{root_dir}/depth_{frame_num}.png')\n",
+ " filename = f'{args.batch_name}({args.batchNum})_{frame_num:06}.png'\n",
+ " # if warp_mode == 'use_raw':torch.save(sample,f'{batchFolder}/{filename[:-4]}_raw.pt')\n",
+ " if warp_mode == 'use_latent':\n",
+ " torch.save(latent,f'{batchFolder}/{filename[:-4]}_lat.pt')\n",
+ " samples = sample*(steps-skip_steps)\n",
+ " samples = [{\"pred_xstart\": sample} for sample in samples]\n",
+ " # for j, sample in enumerate(samples):\n",
+ " # print(j, sample[\"pred_xstart\"].size)\n",
+ " # raise Exception\n",
+ " if VERBOSE: print(sample[0][0].shape)\n",
+ " image = sample[0][0]\n",
+ " if do_softcap:\n",
+ " image = softcap(image, thresh=softcap_thresh, q=softcap_q)\n",
+ " image = image.add(1).div(2).clamp(0, 1)\n",
+ " image = TF.to_pil_image(image)\n",
+ " if warp_towards_init != 'off' and frame_num!=0:\n",
+ " if warp_towards_init == 'init':\n",
+ " warp_init_filename = f'{videoFramesFolder}/{frame_num+1:06}.jpg'\n",
+ " else:\n",
+ " warp_init_filename = init_image\n",
+ " print('warping towards init')\n",
+ " init_pil = Image.open(warp_init_filename)\n",
+ " image = warp_towards_init_fn(image, init_pil)\n",
+ "\n",
+ " display.clear_output(wait=True)\n",
+ " fit(image, display_size).save('progress.png', exif=settings_exif)\n",
+ " display.display(display.Image('progress.png'))\n",
+ "\n",
+ " if mask_result and check_consistency and frame_num>0:\n",
+ "\n",
+ " if VERBOSE:print('imitating inpaint')\n",
+ " frame1_path = f'{videoFramesFolder}/{frame_num:06}.jpg'\n",
+ " weights_path = f\"{flo_folder}/{frame1_path.split('/')[-1]}-21_cc.jpg\"\n",
+ " consistency_mask = load_cc(weights_path, blur=consistency_blur, dilate=consistency_dilate)\n",
+ "\n",
+ " consistency_mask = cv2.GaussianBlur(consistency_mask,\n",
+ " (diffuse_inpaint_mask_blur,diffuse_inpaint_mask_blur),cv2.BORDER_DEFAULT)\n",
+ " if diffuse_inpaint_mask_thresh<1:\n",
+ " consistency_mask = np.where(consistency_mask args.start_frame) or ('color_video' in normalize_latent):\n",
+ " global first_latent\n",
+ " global first_latent_source\n",
+ "\n",
+ " if 'frame' in normalize_latent:\n",
+ " def img2latent(img_path):\n",
+ " frame2 = Image.open(img_path)\n",
+ " frame2pil = frame2.convert('RGB').resize(image.size,warp_interp)\n",
+ " frame2pil = np.array(frame2pil)\n",
+ " frame2pil = (frame2pil/255.)[None,...].transpose(0, 3, 1, 2)\n",
+ " frame2pil = 2*torch.from_numpy(frame2pil).float().cuda()-1.\n",
+ " frame2pil = sd_model.get_first_stage_encoding(sd_model.encode_first_stage(frame2pil))\n",
+ " return frame2pil\n",
+ "\n",
+ " try:\n",
+ " if VERBOSE:print('Matching latent to:')\n",
+ " filename = get_frame_from_color_mode(normalize_latent, normalize_latent_offset, frame_num)\n",
+ " match_latent = img2latent(filename)\n",
+ " first_latent = match_latent\n",
+ " first_latent_source = filename\n",
+ " # print(first_latent_source, first_latent)\n",
+ " except:\n",
+ " if VERBOSE:print(traceback.format_exc())\n",
+ " print(f'Frame with offset/position {normalize_latent_offset} not found')\n",
+ " if 'init' in normalize_latent:\n",
+ " try:\n",
+ " filename = f'{videoFramesFolder}/{0:06}.jpg'\n",
+ " match_latent = img2latent(filename)\n",
+ " first_latent = match_latent\n",
+ " first_latent_source = filename\n",
+ " except: pass\n",
+ " print(f'Color matching the 1st frame.')\n",
+ "\n",
+ " if colormatch_frame != 'off' and colormatch_after:\n",
+ " if not turbo_mode & (frame_num % int(turbo_steps) != 0) or colormatch_turbo:\n",
+ " try:\n",
+ " print('Matching color to:')\n",
+ " filename = get_frame_from_color_mode(colormatch_frame, colormatch_offset, frame_num)\n",
+ " match_frame = Image.open(filename)\n",
+ " first_frame = match_frame\n",
+ " first_frame_source = filename\n",
+ "\n",
+ " except:\n",
+ " print(f'Frame with offset/position {colormatch_offset} not found')\n",
+ " if 'init' in colormatch_frame:\n",
+ " try:\n",
+ " filename = f'{videoFramesFolder}/{1:06}.jpg'\n",
+ " match_frame = Image.open(filename)\n",
+ " first_frame = match_frame\n",
+ " first_frame_source = filename\n",
+ " except: pass\n",
+ " print(f'Color matching the 1st frame.')\n",
+ " print('Colormatch source - ', first_frame_source)\n",
+ " image = Image.fromarray(match_color_var(first_frame,\n",
+ " image, opacity=color_match_frame_str, f=colormatch_method_fn,\n",
+ " regrain=colormatch_regrain))\n",
+ "\n",
+ " if frame_num == args.start_frame:\n",
+ " settings_json = save_settings()\n",
+ " if args.animation_mode != \"None\":\n",
+ " # sys.exit(os.getcwd(), 'cwd')\n",
+ " if warp_mode == 'use_image':\n",
+ " image.save('prevFrame.png', exif=settings_exif)\n",
+ " else:\n",
+ " torch.save(latent, 'prevFrame_lat.pt')\n",
+ " filename = f'{args.batch_name}({args.batchNum})_{frame_num:06}.png'\n",
+ " image.save(f'{batchFolder}/{filename}', exif=settings_exif)\n",
+ " # np.save(latent, f'{batchFolder}/{filename[:-4]}.npy')\n",
+ " if args.animation_mode == 'Video Input':\n",
+ " # If turbo, save a blended image\n",
+ " if turbo_mode and frame_num > args.start_frame:\n",
+ " # Mix new image with prevFrameScaled\n",
+ " blend_factor = (1)/int(turbo_steps)\n",
+ " if warp_mode == 'use_image':\n",
+ " newFrame = cv2.imread('prevFrame.png') # This is already updated..\n",
+ " prev_frame_warped = cv2.imread('prevFrameScaled.png')\n",
+ " blendedImage = cv2.addWeighted(newFrame, blend_factor, prev_frame_warped, (1-blend_factor), 0.0)\n",
+ " cv2.imwrite(f'{batchFolder}/{filename}',blendedImage)\n",
+ " if warp_mode == 'use_latent':\n",
+ " newFrame = torch.load('prevFrame_lat.pt').cuda()\n",
+ " prev_frame_warped = torch.load('prevFrameScaled_lat.pt').cuda()\n",
+ " blendedImage = newFrame*(blend_factor)+prev_frame_warped*(1-blend_factor)\n",
+ " blendedImage = get_image_from_lat(blendedImage)\n",
+ " blendedImage.save(f'{batchFolder}/{filename}', exif=settings_exif)\n",
+ "\n",
+ " else:\n",
+ " image.save(f'{batchFolder}/{filename}', exif=settings_exif)\n",
+ " image.save('prevFrameScaled.png', exif=settings_exif)\n",
+ "\n",
+ " plt.plot(np.array(loss_values), 'r')\n",
+ " batchBar.close()\n",
+ "\n",
+ "def save_settings(skip_save=False):\n",
+ "    \"\"\"Collect the current run settings into a dict and optionally save them as JSON.\n",
+ "\n",
+ "    Every value is read from module-level variables of the notebook. The\n",
+ "    ``{batchFolder}/settings`` folder is created even when ``skip_save`` is True.\n",
+ "\n",
+ "    Args:\n",
+ "        skip_save: When True, only build and return the dict; nothing is written.\n",
+ "\n",
+ "    Returns:\n",
+ "        dict: The collected settings, keyed by setting name.\n",
+ "\n",
+ "    Saving is best-effort: any exception raised while backing up or writing the\n",
+ "    settings file is printed together with the settings instead of being raised.\n",
+ "    \"\"\"\n",
+ "    settings_out = batchFolder+f\"/settings\"\n",
+ "    os.makedirs(settings_out, exist_ok=True)\n",
+ "    # Each key appears exactly once (earlier revisions listed 'max_frames' and\n",
+ "    # 'mask_source' twice with the same value).\n",
+ "    setting_list = {\n",
+ "        'text_prompts': text_prompts,\n",
+ "        'user_comment':user_comment,\n",
+ "        'image_prompts': image_prompts,\n",
+ "        'range_scale': range_scale,\n",
+ "        'sat_scale': sat_scale,\n",
+ "        'max_frames': max_frames,\n",
+ "        'interp_spline': interp_spline,\n",
+ "        'init_image': init_image,\n",
+ "        'clamp_grad': clamp_grad,\n",
+ "        'clamp_max': clamp_max,\n",
+ "        'seed': seed,\n",
+ "        'width': width_height[0],\n",
+ "        'height': width_height[1],\n",
+ "        'diffusion_model': diffusion_model,\n",
+ "        'diffusion_steps': diffusion_steps,\n",
+ "        'video_init_path':video_init_path,\n",
+ "        'extract_nth_frame':extract_nth_frame,\n",
+ "        'flow_video_init_path':flow_video_init_path,\n",
+ "        'flow_extract_nth_frame':flow_extract_nth_frame,\n",
+ "        'video_init_seed_continuity': video_init_seed_continuity,\n",
+ "        'turbo_mode':turbo_mode,\n",
+ "        'turbo_steps':turbo_steps,\n",
+ "        'turbo_preroll':turbo_preroll,\n",
+ "        'flow_warp':flow_warp,\n",
+ "        'check_consistency':check_consistency,\n",
+ "        'turbo_frame_skips_steps' : turbo_frame_skips_steps,\n",
+ "        'forward_weights_clip' : forward_weights_clip,\n",
+ "        'forward_weights_clip_turbo_step' : forward_weights_clip_turbo_step,\n",
+ "        'padding_ratio':padding_ratio,\n",
+ "        'padding_mode':padding_mode,\n",
+ "        'consistency_blur':consistency_blur,\n",
+ "        'inpaint_blend':inpaint_blend,\n",
+ "        'match_color_strength':match_color_strength,\n",
+ "        'high_brightness_threshold':high_brightness_threshold,\n",
+ "        'high_brightness_adjust_ratio':high_brightness_adjust_ratio,\n",
+ "        'low_brightness_threshold':low_brightness_threshold,\n",
+ "        'low_brightness_adjust_ratio':low_brightness_adjust_ratio,\n",
+ "        'stop_early': stop_early,\n",
+ "        'high_brightness_adjust_fix_amount': high_brightness_adjust_fix_amount,\n",
+ "        'low_brightness_adjust_fix_amount': low_brightness_adjust_fix_amount,\n",
+ "        'max_brightness_threshold':max_brightness_threshold,\n",
+ "        'min_brightness_threshold':min_brightness_threshold,\n",
+ "        'enable_adjust_brightness':enable_adjust_brightness,\n",
+ "        'dynamic_thresh':dynamic_thresh,\n",
+ "        'warp_interp':warp_interp,\n",
+ "        'fixed_code':fixed_code,\n",
+ "        'code_randomness':code_randomness,\n",
+ "        # 'normalize_code': normalize_code,\n",
+ "        'mask_result':mask_result,\n",
+ "        'reverse_cc_order':reverse_cc_order,\n",
+ "        'flow_lq':flow_lq,\n",
+ "        'use_predicted_noise':use_predicted_noise,\n",
+ "        'clip_guidance_scale':clip_guidance_scale,\n",
+ "        'clip_type':clip_type,\n",
+ "        'clip_pretrain':clip_pretrain,\n",
+ "        'missed_consistency_weight':missed_consistency_weight,\n",
+ "        'overshoot_consistency_weight':overshoot_consistency_weight,\n",
+ "        'edges_consistency_weight':edges_consistency_weight,\n",
+ "        'style_strength_schedule':style_strength_schedule_bkup,\n",
+ "        'flow_blend_schedule':flow_blend_schedule_bkup,\n",
+ "        'steps_schedule':steps_schedule_bkup,\n",
+ "        'init_scale_schedule':init_scale_schedule_bkup,\n",
+ "        'latent_scale_schedule':latent_scale_schedule_bkup,\n",
+ "        'latent_scale_template': latent_scale_template,\n",
+ "        'init_scale_template':init_scale_template,\n",
+ "        'steps_template':steps_template,\n",
+ "        'style_strength_template':style_strength_template,\n",
+ "        'flow_blend_template':flow_blend_template,\n",
+ "        'cc_masked_template':cc_masked_template,\n",
+ "        'make_schedules':make_schedules,\n",
+ "        'normalize_latent':normalize_latent,\n",
+ "        'normalize_latent_offset':normalize_latent_offset,\n",
+ "        'colormatch_frame':colormatch_frame,\n",
+ "        'use_karras_noise':use_karras_noise,\n",
+ "        'end_karras_ramp_early':end_karras_ramp_early,\n",
+ "        'use_background_mask':use_background_mask,\n",
+ "        'apply_mask_after_warp':apply_mask_after_warp,\n",
+ "        'background':background,\n",
+ "        'background_source':background_source,\n",
+ "        'mask_source':mask_source,\n",
+ "        'extract_background_mask':extract_background_mask,\n",
+ "        'mask_video_path':mask_video_path,\n",
+ "        'negative_prompts':negative_prompts,\n",
+ "        'invert_mask':invert_mask,\n",
+ "        'warp_strength': warp_strength,\n",
+ "        'flow_override_map':flow_override_map,\n",
+ "        'cfg_scale_schedule':cfg_scale_schedule_bkup,\n",
+ "        'respect_sched':respect_sched,\n",
+ "        'color_match_frame_str':color_match_frame_str,\n",
+ "        'colormatch_offset':colormatch_offset,\n",
+ "        'latent_fixed_mean':latent_fixed_mean,\n",
+ "        'latent_fixed_std':latent_fixed_std,\n",
+ "        'colormatch_method':colormatch_method,\n",
+ "        'colormatch_regrain':colormatch_regrain,\n",
+ "        'warp_mode':warp_mode,\n",
+ "        'use_patchmatch_inpaiting':use_patchmatch_inpaiting,\n",
+ "        'blend_latent_to_init':blend_latent_to_init,\n",
+ "        'warp_towards_init':warp_towards_init,\n",
+ "        'init_grad':init_grad,\n",
+ "        'grad_denoised':grad_denoised,\n",
+ "        'colormatch_after':colormatch_after,\n",
+ "        'colormatch_turbo':colormatch_turbo,\n",
+ "        'model_version':model_version,\n",
+ "        'cond_image_src':cond_image_src,\n",
+ "        'warp_num_k':warp_num_k,\n",
+ "        'warp_forward':warp_forward,\n",
+ "        # Callables are recorded by name only.\n",
+ "        'sampler':sampler.__name__,\n",
+ "        'mask_clip':(mask_clip_low, mask_clip_high),\n",
+ "        'inpainting_mask_weight':inpainting_mask_weight ,\n",
+ "        'inverse_inpainting_mask':inverse_inpainting_mask,\n",
+ "        'model_path':model_path,\n",
+ "        'diff_override':diff_override,\n",
+ "        'image_scale_schedule':image_scale_schedule_bkup,\n",
+ "        'image_scale_template':image_scale_template,\n",
+ "        'frame_range': frame_range,\n",
+ "        'detect_resolution' :detect_resolution,\n",
+ "        'bg_threshold':bg_threshold,\n",
+ "        'diffuse_inpaint_mask_blur':diffuse_inpaint_mask_blur,\n",
+ "        'diffuse_inpaint_mask_thresh':diffuse_inpaint_mask_thresh,\n",
+ "        'add_noise_to_latent':add_noise_to_latent,\n",
+ "        'noise_upscale_ratio':noise_upscale_ratio,\n",
+ "        'fixed_seed':fixed_seed,\n",
+ "        'init_latent_fn':init_latent_fn.__name__,\n",
+ "        'value_threshold':value_threshold,\n",
+ "        'distance_threshold':distance_threshold,\n",
+ "        'masked_guidance':masked_guidance,\n",
+ "        'cc_masked_diffusion_schedule':cc_masked_diffusion_schedule_bkup,\n",
+ "        'alpha_masked_diffusion':alpha_masked_diffusion,\n",
+ "        'inverse_mask_order':inverse_mask_order,\n",
+ "        'invert_alpha_masked_diffusion':invert_alpha_masked_diffusion,\n",
+ "        'quantize':quantize,\n",
+ "        'cb_noise_upscale_ratio':cb_noise_upscale_ratio,\n",
+ "        'cb_add_noise_to_latent':cb_add_noise_to_latent,\n",
+ "        'cb_use_start_code':cb_use_start_code,\n",
+ "        'cb_fixed_code':cb_fixed_code,\n",
+ "        'cb_norm_latent':cb_norm_latent,\n",
+ "        'guidance_use_start_code':guidance_use_start_code,\n",
+ "        'offload_model':offload_model,\n",
+ "        'controlnet_preprocess':controlnet_preprocess,\n",
+ "        'small_controlnet_model_path':small_controlnet_model_path,\n",
+ "        'use_scale':use_scale,\n",
+ "        'g_invert_mask':g_invert_mask,\n",
+ "        # Stored as a JSON string rather than as a nested object.\n",
+ "        'controlnet_multimodel':json.dumps(controlnet_multimodel),\n",
+ "        'img_zero_uncond':img_zero_uncond,\n",
+ "        'do_softcap':do_softcap,\n",
+ "        'softcap_thresh':softcap_thresh,\n",
+ "        'softcap_q':softcap_q,\n",
+ "        'deflicker_latent_scale':deflicker_latent_scale,\n",
+ "        'deflicker_scale':deflicker_scale,\n",
+ "        'controlnet_multimodel_mode':controlnet_multimodel_mode,\n",
+ "        'no_half_vae':no_half_vae,\n",
+ "        'temporalnet_source':temporalnet_source,\n",
+ "        'temporalnet_skip_1st_frame':temporalnet_skip_1st_frame,\n",
+ "        'rec_randomness':rec_randomness,\n",
+ "        'rec_source':rec_source,\n",
+ "        'rec_cfg':rec_cfg,\n",
+ "        'rec_prompts':rec_prompts,\n",
+ "        'inpainting_mask_source':inpainting_mask_source,\n",
+ "        'rec_steps_pct':rec_steps_pct,\n",
+ "        'max_faces': max_faces,\n",
+ "        'num_flow_updates':num_flow_updates,\n",
+ "        'pose_detector':pose_detector,\n",
+ "        'control_sd15_openpose_hands_face':control_sd15_openpose_hands_face,\n",
+ "        'control_sd15_depth_detector':control_sd15_depth_detector,\n",
+ "        'control_sd15_softedge_detector':control_sd15_softedge_detector,\n",
+ "        'control_sd15_seg_detector':control_sd15_seg_detector,\n",
+ "        'control_sd15_scribble_detector':control_sd15_scribble_detector,\n",
+ "        'control_sd15_lineart_coarse':control_sd15_lineart_coarse,\n",
+ "        'control_sd15_inpaint_mask_source':control_sd15_inpaint_mask_source,\n",
+ "        'control_sd15_shuffle_source':control_sd15_shuffle_source,\n",
+ "        'control_sd15_shuffle_1st_source':control_sd15_shuffle_1st_source,\n",
+ "        'overwrite_rec_noise':overwrite_rec_noise,\n",
+ "        'use_legacy_cc':use_legacy_cc,\n",
+ "        'missed_consistency_dilation':missed_consistency_dilation,\n",
+ "        'edge_consistency_width':edge_consistency_width,\n",
+ "        'use_reference':use_reference,\n",
+ "        'reference_weight':reference_weight,\n",
+ "        'reference_source':reference_source,\n",
+ "        'reference_mode':reference_mode,\n",
+ "        'use_legacy_fixed_code':use_legacy_fixed_code,\n",
+ "        'consistency_dilate':consistency_dilate,\n",
+ "        'prompt_patterns_sched':prompt_patterns_sched,\n",
+ "        'sd_batch_size':sd_batch_size,\n",
+ "        'normalize_prompt_weights':normalize_prompt_weights,\n",
+ "        'controlnet_low_vram':controlnet_low_vram,\n",
+ "        'mask_paths':mask_paths,\n",
+ "        'controlnet_mode':controlnet_mode,\n",
+ "        'normalize_cn_weights':normalize_cn_weights,\n",
+ "        'apply_freeu_after_control':apply_freeu_after_control,\n",
+ "        'do_freeunet':do_freeunet\n",
+ "\n",
+ "    }\n",
+ "    if not skip_save:\n",
+ "        try:\n",
+ "            settings_fname = f\"{settings_out}/{batch_name}({batchNum})_settings.txt\"\n",
+ "            if os.path.exists(settings_fname):\n",
+ "                # Keep the previous settings file: rename it with its mtime appended.\n",
+ "                s_meta = os.path.getmtime(settings_fname)\n",
+ "                os.rename(settings_fname,settings_fname[:-4]+str(s_meta)+'.txt' )\n",
+ "            with open(settings_fname, \"w+\") as f: #save settings\n",
+ "                json.dump(setting_list, f, ensure_ascii=False, indent=4)\n",
+ "        except Exception as e:\n",
+ "            # Saving failed: report the error and dump the settings to stdout instead.\n",
+ "            print(e)\n",
+ "            print('Settings:', setting_list)\n",
+ "    return setting_list\n",
+ "\n",
+ "#@title 1.6 init main sd run function, cond_fn, color matching for SD\n",
+ "# Module-level placeholders; both are reset to None every time this cell runs.\n",
+ "init_latent = None\n",
+ "target_embed = None\n",
+ "\n",
+ "\n",
+ "import hashlib\n",
+ "import os\n",
+ "# import datetime\n",
+ "\n",
+ "# (c) Alex Spirin 2023\n",
+ "# We use input file hashes to automate video extraction\n",
+ "#\n",
+ "def generate_file_hash(input_file):\n",
+ "    \"\"\"Return a SHA-256 hex digest that fingerprints a file by name and metadata.\n",
+ "\n",
+ "    The digest covers, in this order, the file's base name, its size from\n",
+ "    ``os.path.getsize`` and its ``os.path.getctime`` timestamp. The file's\n",
+ "    contents are never read.\n",
+ "    \"\"\"\n",
+ "    base_name = os.path.basename(input_file)\n",
+ "    size_and_ctime = (os.path.getsize(input_file), os.path.getctime(input_file))\n",
+ "\n",
+ "    # Seed the hasher with the name, then append the stringified metadata values.\n",
+ "    digest = hashlib.sha256(base_name.encode('utf-8'))\n",
+ "    for stat_value in size_and_ctime:\n",
+ "        digest.update(str(stat_value).encode('utf-8'))\n",
+ "    return digest.hexdigest()\n",
+ "\n",
+ "def get_frame_from_path_start_end_nth(video_path:str , num_frame:int, start:int=0, end:int=0, nth:int=1) -> 'PIL.Image.Image':\n",
+ "    \"\"\"Return a single frame, as a PIL image, from a frame folder or a video file.\n",
+ "\n",
+ "    The source is sub-sampled as ``[start:end:nth]`` and ``num_frame`` indexes\n",
+ "    into that selection, clamped to its last frame.\n",
+ "\n",
+ "    Args:\n",
+ "        video_path: Folder of image frames, or a video file opened with cv2.VideoCapture.\n",
+ "        num_frame: Zero-based index into the selected frames; negative values are treated as 0.\n",
+ "        start: First source frame to consider; negative values are treated as 0.\n",
+ "        end: Source frame to stop before; 0 means up to the end of the source.\n",
+ "        nth: Step between selected source frames; values below 1 are treated as 1.\n",
+ "\n",
+ "    Returns:\n",
+ "        The requested frame. Folder frames are opened with ``PIL.Image.open``;\n",
+ "        video frames are converted from BGR to RGB before being wrapped in an image.\n",
+ "\n",
+ "    Raises:\n",
+ "        AssertionError: If ``video_path`` does not exist, or the folder holds no\n",
+ "            files with one of the supported extensions.\n",
+ "        IndexError: If ``start``/``end``/``nth`` select no frames at all.\n",
+ "        Exception: If the video cannot be opened or the frame cannot be read.\n",
+ "    \"\"\"\n",
+ "    # The return annotation is a string so it is not evaluated at definition time;\n",
+ "    # the PIL imports of this cell come after this def.\n",
+ "    # Local import under a private alias: this cell later executes\n",
+ "    # `from glob import glob`, which rebinds the global name `glob` to the function,\n",
+ "    # so a module-style `glob.glob(...)` lookup would fail inside this function.\n",
+ "    import glob as _glob\n",
+ "\n",
+ "    assert os.path.exists(video_path), f\"Video path or frame folder not found at {video_path}. Please specify the correct path.\"\n",
+ "    num_frame = max(0, num_frame)\n",
+ "    start = max(0, start)\n",
+ "    nth = max(1,nth)\n",
+ "    if os.path.isdir(video_path):\n",
+ "        frame_list = []\n",
+ "        image_extensions = ['jpg','png','tiff','jpeg','JPEG','bmp']\n",
+ "        # Use the first extension that matches any file in the folder.\n",
+ "        for image_extension in image_extensions:\n",
+ "            flist = _glob.glob(os.path.join(video_path, f'*.{image_extension}'))\n",
+ "            if len(flist)>0:\n",
+ "                # glob returns matches in arbitrary order; sort so indices follow file names.\n",
+ "                frame_list = sorted(flist)\n",
+ "                break\n",
+ "        assert len(frame_list) != 0, f'No frames with {\", \".join(image_extensions)} extensions found in folder {video_path}. Please specify the correct path.'\n",
+ "        if end == 0: end = len(frame_list)\n",
+ "        frame_list = frame_list[start:end:nth]\n",
+ "        if len(frame_list) == 0:\n",
+ "            raise IndexError(f'No frames selected from {video_path} with start={start}, end={end}, nth={nth}.')\n",
+ "        # Clamp to the last valid index (len - 1), not to len.\n",
+ "        num_frame = min(num_frame, len(frame_list) - 1)\n",
+ "        return PIL.Image.open(frame_list[num_frame])\n",
+ "\n",
+ "    elif os.path.isfile(video_path):\n",
+ "        video = cv2.VideoCapture(video_path)\n",
+ "        # try/finally releases the capture on every exit path, including errors.\n",
+ "        try:\n",
+ "            if not video.isOpened():\n",
+ "                raise Exception(f\"Error opening video file {video_path}. Please specify the correct path.\")\n",
+ "            total_frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT))\n",
+ "            if end == 0: end = total_frames\n",
+ "            frame_range = range(start, end, nth)\n",
+ "            if len(frame_range) == 0:\n",
+ "                raise IndexError(f'No frames selected from {video_path} with start={start}, end={end}, nth={nth}.')\n",
+ "            # Clamp against the sub-sampled selection rather than the raw frame count.\n",
+ "            num_frame = min(num_frame, len(frame_range) - 1)\n",
+ "            frame_number = frame_range[num_frame]\n",
+ "            video.set(cv2.CAP_PROP_POS_FRAMES, frame_number)\n",
+ "            ret, frame = video.read()\n",
+ "            if not ret:\n",
+ "                raise Exception(f\"Error reading frame {frame_number} from file {video_path}.\")\n",
+ "            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)\n",
+ "            return Image.fromarray(frame)\n",
+ "        finally:\n",
+ "            video.release()\n",
+ "\n",
+ "\n",
+ "import PIL\n",
+ "try:\n",
+ " import Image\n",
+ "except:\n",
+ " from PIL import Image\n",
+ "\n",
+ "# Module-level defaults assigned when this cell runs.\n",
+ "mask_result = False\n",
+ "early_stop = 0\n",
+ "inpainting_stop = 0\n",
+ "# PIL resampling filter; the frame loop passes it to Image.resize() when encoding frames to latents.\n",
+ "warp_interp = Image.BILINEAR\n",
+ "\n",
+ "#init SD\n",
+ "from glob import glob\n",
+ "import argparse, os, sys\n",
+ "import PIL\n",
+ "import torch\n",
+ "import numpy as np\n",
+ "from omegaconf import OmegaConf\n",
+ "from PIL import Image\n",
+ "from tqdm.auto import tqdm, trange\n",
+ "from itertools import islice\n",
+ "from einops import rearrange, repeat\n",
+ "from torchvision.utils import make_grid\n",
+ "from torch import autocast\n",
+ "from contextlib import nullcontext\n",
+ "import time\n",
+ "# from pytorch_lightning import seed_everything\n",
+ "\n",
+ "# Temporarily make {root_dir}/stablediffusion the working directory while the ldm\n",
+ "# modules below are imported, then switch back to root_dir.\n",
+ "# NOTE(review): presumably the chdir is what lets `ldm` resolve from the cwd - confirm.\n",
+ "os.chdir(f\"{root_dir}/stablediffusion\")\n",
+ "from ldm.util import instantiate_from_config\n",
+ "from ldm.models.diffusion.ddim import DDIMSampler\n",
+ "from ldm.models.diffusion.plms import PLMSSampler\n",
+ "from ldm.modules.distributions.distributions import DiagonalGaussianDistribution\n",
+ "os.chdir(f\"{root_dir}\")\n",
+ "\n",
+ "\n",
+ "\n",
+ "def extract_into_tensor(a, t, x_shape):\n",
+ "    \"\"\"Gather entries of ``a`` at indices ``t`` and reshape them for broadcasting.\n",
+ "\n",
+ "    Values are gathered along the last dimension of ``a``. The result is reshaped\n",
+ "    to ``(b, 1, ..., 1)`` with ``len(x_shape)`` dimensions in total, where ``b`` is\n",
+ "    the first dimension of ``t``.\n",
+ "    \"\"\"\n",
+ "    batch, *_ = t.shape\n",
+ "    singleton_dims = (1,) * (len(x_shape) - 1)\n",
+ "    gathered = a.gather(-1, t)\n",
+ "    return gathered.reshape(batch, *singleton_dims)\n",
+ "\n",
+ "from kornia import augmentation as KA\n",
+ "# Shared random-affine augmentation, always applied (p=1) with 'border' padding.\n",
+ "# NOTE(review): the positional args read as degrees=0 and translate=(1/14, 1/14)\n",
+ "# per kornia's RandomAffine signature - confirm against the installed version.\n",
+ "aug = KA.RandomAffine(0, (1/14, 1/14), p=1, padding_mode='border')\n",
+ "from torch.nn import functional as F\n",
+ "\n",
+ "from torch.cuda.amp import GradScaler\n",
+ "\n",
+ "def sd_cond_fn(x, t, denoised, init_image_sd, init_latent, init_scale,\n",
+ " init_latent_scale, target_embed, consistency_mask, guidance_start_code=None,\n",
+ " deflicker_fn=None, deflicker_lat_fn=None, deflicker_src=None, fft_fn=None, fft_latent_fn=None,\n",
+ " **kwargs):\n",
+ " if use_scale: scaler = GradScaler()\n",
+ " with torch.cuda.amp.autocast():\n",
+ " # print('denoised.shape')\n",
+ " # print(denoised.shape)\n",
+ " global add_noise_to_latent\n",
+ "\n",
+ " # init_latent_scale, init_scale, clip_guidance_scale, target_embed, init_latent, clamp_grad, clamp_max,\n",
+ " # **kwargs):\n",
+ " # global init_latent_scale\n",
+ " # global init_scale\n",
+ " global clip_guidance_scale\n",
+ " # global target_embed\n",
+ " # print(target_embed.shape)\n",
+ " global clamp_grad\n",
+ " global clamp_max\n",
+ " loss = 0.\n",
+ " if grad_denoised:\n",
+ " x = denoised\n",
+ " # denoised = x\n",
+ "\n",
+ " # print('grad denoised')\n",
+ " grad = torch.zeros_like(x)\n",
+ "\n",
+ " processed1 = deflicker_src['processed1']\n",
+ " if add_noise_to_latent:\n",
+ " if t != 0:\n",
+ " if guidance_use_start_code and guidance_start_code is not None:\n",
+ " noise = guidance_start_code\n",
+ " else:\n",
+ " noise = torch.randn_like(x)\n",
+ " noise = noise * t\n",
+ " if noise_upscale_ratio > 1:\n",
+ " noise = noise[::noise_upscale_ratio,::noise_upscale_ratio,:]\n",
+ " noise = torch.nn.functional.interpolate(noise, x.shape[2:],\n",
+ " mode='bilinear')\n",
+ " init_latent = init_latent + noise\n",
+ " if deflicker_lat_fn:\n",
+ " processed1 = deflicker_src['processed1'] + noise\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ " if sat_scale>0 or init_scale>0 or clip_guidance_scale>0 or deflicker_scale>0 or fft_scale>0:\n",
+ " with torch.autocast('cuda'):\n",
+ " denoised_small = denoised[:,:,::2,::2]\n",
+ " denoised_img = model_wrap_cfg.inner_model.inner_model.differentiable_decode_first_stage(denoised_small)\n",
+ "\n",
+ " if clip_guidance_scale>0:\n",
+ " #compare text clip embeds with denoised image embeds\n",
+ " # denoised_img = model_wrap_cfg.inner_model.inner_model.differentiable_decode_first_stage(denoised);# print(denoised.requires_grad)\n",
+ " # print('d b',denoised.std(), denoised.mean())\n",
+ " denoised_img = denoised_img[0].add(1).div(2)\n",
+ " denoised_img = normalize(denoised_img)\n",
+ " denoised_t = denoised_img.cuda()[None,...]\n",
+ " # print('d a',denoised_t.std(), denoised_t.mean())\n",
+ " image_embed = get_image_embed(denoised_t)\n",
+ "\n",
+ " # image_embed = get_image_embed(denoised.add(1).div(2))\n",
+ " loss = spherical_dist_loss(image_embed, target_embed).sum() * clip_guidance_scale\n",
+ "\n",
+ " if masked_guidance:\n",
+ " if consistency_mask is None:\n",
+ " consistency_mask = torch.ones_like(denoised)\n",
+ " # consistency_mask = consistency_mask.permute(2,0,1)[None,...]\n",
+ " # print('consistency_mask.shape, denoised.shape')\n",
+ " # print(consistency_mask.shape, denoised.shape)\n",
+ "\n",
+ " consistency_mask = torch.nn.functional.interpolate(consistency_mask, denoised.shape[2:],\n",
+ " mode='bilinear')\n",
+ " if g_invert_mask: consistency_mask = 1-consistency_mask\n",
+ "\n",
+ " if init_latent_scale>0:\n",
+ "\n",
+ " #compare init image latent with denoised latent\n",
+ " # print(denoised.shape, init_latent.shape)\n",
+ "\n",
+ " loss += init_latent_fn(denoised, init_latent).sum() * init_latent_scale\n",
+ "\n",
+ " if fft_scale>0 and fft_fn is not None:\n",
+ " loss += fft_fn(image1=denoised_img).sum() * fft_scale\n",
+ "\n",
+ " if fft_latent_scale>0 and fft_latent_fn is not None:\n",
+ " loss += fft_latent_fn(image1=denoised).sum() * fft_latent_scale\n",
+ "\n",
+ " if sat_scale>0:\n",
+ " loss += torch.abs(denoised_img - denoised_img.clamp(min=-1,max=1)).mean()\n",
+ "\n",
+ " if init_scale>0:\n",
+ " #compare init image with denoised latent image via lpips\n",
+ " # print('init_image_sd', init_image_sd)\n",
+ "\n",
+ " loss += lpips_model(denoised_img, init_image_sd[:,:,::2,::2]).sum() * init_scale\n",
+ "\n",
+ " if deflicker_scale>0 and deflicker_fn is not None:\n",
+ " # print('deflicker_fn(denoised_img).sum() * deflicker_scale',deflicker_fn(denoised_img).sum() * deflicker_scale)\n",
+ " loss += deflicker_fn(processed2=denoised_img).sum() * deflicker_scale\n",
+ " print('deflicker ', loss)\n",
+ "\n",
+ " if deflicker_latent_scale>0 and deflicker_lat_fn is not None:\n",
+ "\n",
+ " loss += deflicker_lat_fn(processed2=denoised, processed1=processed1).sum() * deflicker_latent_scale\n",
+ " print('deflicker lat', loss)\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ " # print('loss', loss)\n",
+ " if loss!=0. :\n",
+ " if use_scale:\n",
+ " scaled_grad_params = torch.autograd.grad(outputs=scaler.scale(loss),\n",
+ " inputs=x)\n",
+ " inv_scale = 1./scaler.get_scale()\n",
+ " grad_params = [p * inv_scale for p in scaled_grad_params]\n",
+ " grad = -grad_params[0]\n",
+ " # scaler.update()\n",
+ " else:\n",
+ " grad = -torch.autograd.grad(loss, x)[0]\n",
+ " if masked_guidance:\n",
+ " grad = grad*consistency_mask\n",
+ " if torch.isnan(grad).any():\n",
+ " print('got NaN grad')\n",
+ " return torch.zeros_like(x)\n",
+ " if VERBOSE:printf('loss, grad',loss, grad.max(), grad.mean(), grad.std(), denoised.mean(), denoised.std())\n",
+ " if clamp_grad:\n",
+ " magnitude = grad.square().mean().sqrt()\n",
+ " return grad * magnitude.clamp(max=clamp_max) / magnitude\n",
+ "\n",
+ " return grad\n",
+ "\n",
+ "import cv2\n",
+ "\n",
+ "# Import python_color_transfer while the working directory is its checkout folder,\n",
+ "# then change directory again (IPython %cd magics).\n",
+ "# NOTE(review): the first %cd is built from root_dir and the second from root_path - confirm\n",
+ "# both variables point where intended.\n",
+ "%cd \"{root_dir}/python-color-transfer\"\n",
+ "from python_color_transfer.color_transfer import ColorTransfer, Regrain\n",
+ "%cd \"{root_path}/\"\n",
+ "\n",
+ "# Shared ColorTransfer instance; its pdf_transfer method is the default `f` of match_color_var.\n",
+ "PT = ColorTransfer()\n",
+ "\n",
+ "def match_color_var(stylized_img, raw_img, opacity=1., f=PT.pdf_transfer, regrain=False):\n",
+ "    \"\"\"Apply the colour-transfer callable `f` to a pair of RGB images and blend the result.\n",
+ "\n",
+ "    `raw_img` is handed to `f` as `img_arr_in` and `stylized_img`, resized to the size of\n",
+ "    `raw_img`, as `img_arr_ref`. Both are rounded, cast to uint8 and converted RGB -> BGR first.\n",
+ "\n",
+ "    Args:\n",
+ "        stylized_img: array-like RGB image used as the reference for `f`.\n",
+ "        raw_img: array-like RGB image used as the input for `f`; fixes the output size.\n",
+ "        opacity: weight of the transferred image; the converted `raw_img` gets (1 - opacity).\n",
+ "        f: callable taking `img_arr_in` / `img_arr_ref` keywords (default: PT.pdf_transfer).\n",
+ "        regrain: when truthy, the result of `f` is post-processed with RG.regrain.\n",
+ "\n",
+ "    Returns:\n",
+ "        The blended image as a uint8 RGB array.\n",
+ "\n",
+ "    NOTE(review): RG is not created in this cell (only PT is) - confirm a module-level\n",
+ "    Regrain instance named RG exists before calling with regrain=True.\n",
+ "    \"\"\"\n",
+ "    def _as_bgr_uint8(img):\n",
+ "        # round -> uint8 -> BGR, the layout the colour-transfer callables are fed with\n",
+ "        return cv2.cvtColor(np.array(img).round().astype('uint8'), cv2.COLOR_RGB2BGR)\n",
+ "\n",
+ "    reference_bgr = _as_bgr_uint8(stylized_img)\n",
+ "    source_bgr = _as_bgr_uint8(raw_img)\n",
+ "\n",
+ "    # cv2.resize expects (width, height)\n",
+ "    target_size = (source_bgr.shape[1], source_bgr.shape[0])\n",
+ "    reference_bgr = cv2.resize(reference_bgr, target_size, interpolation=cv2.INTER_CUBIC )\n",
+ "\n",
+ "    transferred = f(img_arr_in=source_bgr, img_arr_ref=reference_bgr)\n",
+ "    if regrain:\n",
+ "        transferred = RG.regrain (img_arr_in=transferred, img_arr_col=reference_bgr)\n",
+ "\n",
+ "    blended = transferred*opacity+source_bgr*(1-opacity)\n",
+ "    return cv2.cvtColor(blended.round().astype('uint8'),cv2.COLOR_BGR2RGB)\n",
+ "\n",
+ "#https://gist.githubusercontent.com/trygvebw/c71334dd127d537a15e9d59790f7f5e1/raw/ed0bed6abaf75c0f1b270cf6996de3e07cbafc81/find_noise.py\n",
+ "\n",
+ "import torch\n",
+ "import numpy as np\n",
+ "# import k_diffusion as K\n",
+ "\n",
+ "from PIL import Image\n",
+ "from torch import autocast\n",
+ "from einops import rearrange, repeat\n",
+ "\n",
+ "def pil_img_to_torch(pil_img, half=False):\n",
+ "    \"\"\"Turn an (H, W, C) image into a float32 tensor of shape (1, C, H, W).\n",
+ "\n",
+ "    Values are divided by 255 and then mapped with 2*v - 1, so 0..255 input lands in -1..1.\n",
+ "    `half` is accepted for interface compatibility but has no effect on the result.\n",
+ "    \"\"\"\n",
+ "    unit_range = torch.from_numpy(np.array(pil_img).astype(np.float32) / 255.0)\n",
+ "    channels_first = rearrange(unit_range, 'h w c -> c h w')\n",
+ "    return (2.0 * channels_first - 1.0).unsqueeze(0)\n",
+ "\n",
+ "def pil_img_to_latent(model, img, batch_size=1, device='cuda', half=True):\n",
+ "    \"\"\"Encode an image into the first-stage latent of `model`, repeated `batch_size` times.\n",
+ "\n",
+ "    The image goes through pil_img_to_torch, is moved to `device`, its leading axis of size 1\n",
+ "    is repeated to `batch_size`, and the batch is passed through model.encode_first_stage and\n",
+ "    model.get_first_stage_encoding. `half` is only forwarded to pil_img_to_torch; both of its\n",
+ "    values take the same encoding path here.\n",
+ "    \"\"\"\n",
+ "    pixels = pil_img_to_torch(img, half=half).to(device)\n",
+ "    pixel_batch = repeat(pixels, '1 ... -> b ...', b=batch_size)\n",
+ "    encoded = model.encode_first_stage(pixel_batch)\n",
+ "    return model.get_first_stage_encoding(encoded)\n",
+ "\n",
+ "import torch\n",
+ "\n",
+ "\n",
+ "def find_noise_for_image(model, x, prompt, steps, cond_scale=0.0, verbose=False, normalize=True):\n",
+ "    \"\"\"Step the latent `x` through the reversed (increasing) sigma schedule of `model`.\n",
+ "\n",
+ "    Args:\n",
+ "        model: object providing get_learned_conditioning and apply_model; it is wrapped in\n",
+ "            K.external.CompVisDenoiser to obtain sigmas, scalings and timesteps.\n",
+ "        x: starting latent; the first axis is used as the batch size.\n",
+ "        prompt: text for the conditional branch (the unconditional branch uses '').\n",
+ "        steps: number of sigmas requested from the denoiser.\n",
+ "        cond_scale: weight of (conditional - unconditional) added to the unconditional prediction.\n",
+ "        verbose: print the sigma schedule when truthy.\n",
+ "        normalize: when truthy, return x / x.std() * sigmas[-1] instead of x.\n",
+ "\n",
+ "    Returns:\n",
+ "        The latent after the last step, optionally rescaled as described above.\n",
+ "    \"\"\"\n",
+ "    with torch.no_grad(), autocast('cuda'):\n",
+ "        uncond = model.get_learned_conditioning([''])\n",
+ "        cond = model.get_learned_conditioning([prompt])\n",
+ "\n",
+ "    batch_ones = x.new_ones([x.shape[0]])\n",
+ "    denoiser = K.external.CompVisDenoiser(model)\n",
+ "    sigmas = denoiser.get_sigmas(steps).flip(0)\n",
+ "\n",
+ "    if verbose:\n",
+ "        print(sigmas)\n",
+ "\n",
+ "    with torch.no_grad(), autocast('cuda'):\n",
+ "        for step in trange(1, len(sigmas)):\n",
+ "            is_first_step = step == 1\n",
+ "            sigma_prev = sigmas[step - 1]\n",
+ "            sigma_next = sigmas[step]\n",
+ "\n",
+ "            # unconditional and conditional halves are evaluated in one doubled batch\n",
+ "            latent_pair = torch.cat([x] * 2)\n",
+ "            sigma_pair = torch.cat([sigma_prev * batch_ones] * 2)\n",
+ "            cond_pair = torch.cat([uncond, cond])\n",
+ "\n",
+ "            c_out, c_in = [K.utils.append_dims(scaling, latent_pair.ndim) for scaling in denoiser.get_scalings(sigma_pair)]\n",
+ "\n",
+ "            # the very first step takes its timestep from the next sigma instead of the previous one\n",
+ "            if is_first_step:\n",
+ "                timestep = denoiser.sigma_to_t(torch.cat([sigma_next * batch_ones] * 2))\n",
+ "            else:\n",
+ "                timestep = denoiser.sigma_to_t(sigma_pair)\n",
+ "\n",
+ "            eps = model.apply_model(latent_pair * c_in, timestep, cond=cond_pair)\n",
+ "            denoised_uncond, denoised_cond = (latent_pair + eps * c_out).chunk(2)\n",
+ "            denoised = denoised_uncond + (denoised_cond - denoised_uncond) * cond_scale\n",
+ "\n",
+ "            divisor = (2 * sigma_next) if is_first_step else sigma_prev\n",
+ "            derivative = (x - denoised) / divisor\n",
+ "            x = x + derivative * (sigma_next - sigma_prev)\n",
+ "        print(x.shape)\n",
+ "        return (x / x.std()) * sigmas[-1] if normalize else x\n",
+ "\n",
+ "# Based on changes suggested by briansemrau in https://github.com/AUTOMATIC1111/stable-diffusion-webui/issues/736\n",
+ "#todo add batching for >2 cond size\n",
+ "import hashlib\n",
+ "def find_noise_for_image_sigma_adjustment(init_latent, prompt, image_conditioning, cfg_scale, steps, frame_num):\n",
+ "    \"\"\"Reconstruct a noise latent for `init_latent`, with a per-frame on-disk cache.\n",
+ "\n",
+ "    The latent is stepped through the reversed (increasing) sigma schedule of sd_model and the\n",
+ "    result is divided by the last sigma. The tensor is stored at\n",
+ "    {recNoiseCacheFolder}/{settings_hash}_{frame_num:06}.pt, where settings_hash is derived from\n",
+ "    the notebook-level settings collected below; an existing file is loaded and returned\n",
+ "    unless overwrite_rec_noise is set.\n",
+ "\n",
+ "    Args:\n",
+ "        init_latent: latent to start from.\n",
+ "        prompt: prompt(s) passed to the conditioning function.\n",
+ "        image_conditioning: None, a tensor (concat conditioning), or a dict of tensors when\n",
+ "            model_version contains 'control_multi'.\n",
+ "        cfg_scale: ignored - it is overwritten with the global rec_cfg.\n",
+ "        steps: nominal step count; multiplied by the global rec_steps_pct.\n",
+ "        frame_num: frame index, used only for the cache file name.\n",
+ "\n",
+ "    Returns:\n",
+ "        The reconstructed noise tensor (freshly computed or loaded from the cache).\n",
+ "\n",
+ "    Raises:\n",
+ "        Exception: for model_version 'control_multi' with controlnet_multimodel_mode 'external'.\n",
+ "    \"\"\"\n",
+ "    # Every setting listed here feeds the cache key: changing any of them forces a recompute.\n",
+ "    rec_noise_setting_list = {\n",
+ "        'init_image': init_image,\n",
+ "        'seed': seed,\n",
+ "        'width': width_height[0],\n",
+ "        'height': width_height[1],\n",
+ "        'diffusion_model': diffusion_model,\n",
+ "        'diffusion_steps': diffusion_steps,\n",
+ "        'video_init_path':video_init_path,\n",
+ "        'extract_nth_frame':extract_nth_frame,\n",
+ "        'flow_video_init_path':flow_video_init_path,\n",
+ "        'flow_extract_nth_frame':flow_extract_nth_frame,\n",
+ "        'video_init_seed_continuity': video_init_seed_continuity,\n",
+ "        'turbo_mode':turbo_mode,\n",
+ "        'turbo_steps':turbo_steps,\n",
+ "        'turbo_preroll':turbo_preroll,\n",
+ "        'flow_warp':flow_warp,\n",
+ "        'check_consistency':check_consistency,\n",
+ "        'turbo_frame_skips_steps' : turbo_frame_skips_steps,\n",
+ "        'forward_weights_clip' : forward_weights_clip,\n",
+ "        'forward_weights_clip_turbo_step' : forward_weights_clip_turbo_step,\n",
+ "        'padding_ratio':padding_ratio,\n",
+ "        'padding_mode':padding_mode,\n",
+ "        'consistency_blur':consistency_blur,\n",
+ "        'inpaint_blend':inpaint_blend,\n",
+ "        'match_color_strength':match_color_strength,\n",
+ "        'high_brightness_threshold':high_brightness_threshold,\n",
+ "        'high_brightness_adjust_ratio':high_brightness_adjust_ratio,\n",
+ "        'low_brightness_threshold':low_brightness_threshold,\n",
+ "        'low_brightness_adjust_ratio':low_brightness_adjust_ratio,\n",
+ "        'high_brightness_adjust_fix_amount': high_brightness_adjust_fix_amount,\n",
+ "        'low_brightness_adjust_fix_amount': low_brightness_adjust_fix_amount,\n",
+ "        'max_brightness_threshold':max_brightness_threshold,\n",
+ "        'min_brightness_threshold':min_brightness_threshold,\n",
+ "        'enable_adjust_brightness':enable_adjust_brightness,\n",
+ "        'dynamic_thresh':dynamic_thresh,\n",
+ "        'warp_interp':warp_interp,\n",
+ "        'reverse_cc_order':reverse_cc_order,\n",
+ "        'flow_lq':flow_lq,\n",
+ "        'use_predicted_noise':use_predicted_noise,\n",
+ "        'clip_guidance_scale':clip_guidance_scale,\n",
+ "        'clip_type':clip_type,\n",
+ "        'clip_pretrain':clip_pretrain,\n",
+ "        'missed_consistency_weight':missed_consistency_weight,\n",
+ "        'overshoot_consistency_weight':overshoot_consistency_weight,\n",
+ "        'edges_consistency_weight':edges_consistency_weight,\n",
+ "        'flow_blend_schedule':flow_blend_schedule,\n",
+ "        'steps_schedule':steps_schedule,\n",
+ "        'latent_scale_schedule':latent_scale_schedule,\n",
+ "        'flow_blend_template':flow_blend_template,\n",
+ "        'cc_masked_template':cc_masked_template,\n",
+ "        'make_schedules':make_schedules,\n",
+ "        'normalize_latent':normalize_latent,\n",
+ "        'normalize_latent_offset':normalize_latent_offset,\n",
+ "        'colormatch_frame':colormatch_frame,\n",
+ "        'use_karras_noise':use_karras_noise,\n",
+ "        'end_karras_ramp_early':end_karras_ramp_early,\n",
+ "        'use_background_mask':use_background_mask,\n",
+ "        'apply_mask_after_warp':apply_mask_after_warp,\n",
+ "        'background':background,\n",
+ "        'background_source':background_source,\n",
+ "        'mask_source':mask_source,\n",
+ "        'extract_background_mask':extract_background_mask,\n",
+ "        'mask_video_path':mask_video_path,\n",
+ "        'invert_mask':invert_mask,\n",
+ "        'warp_strength': warp_strength,\n",
+ "        'flow_override_map':flow_override_map,\n",
+ "        'respect_sched':respect_sched,\n",
+ "        'color_match_frame_str':color_match_frame_str,\n",
+ "        'colormatch_offset':colormatch_offset,\n",
+ "        'latent_fixed_mean':latent_fixed_mean,\n",
+ "        'latent_fixed_std':latent_fixed_std,\n",
+ "        'colormatch_method':colormatch_method,\n",
+ "        'colormatch_regrain':colormatch_regrain,\n",
+ "        'warp_mode':warp_mode,\n",
+ "        'use_patchmatch_inpaiting':use_patchmatch_inpaiting,\n",
+ "        'blend_latent_to_init':blend_latent_to_init,\n",
+ "        'warp_towards_init':warp_towards_init,\n",
+ "        'init_grad':init_grad,\n",
+ "        'grad_denoised':grad_denoised,\n",
+ "        'colormatch_after':colormatch_after,\n",
+ "        'colormatch_turbo':colormatch_turbo,\n",
+ "        'model_version':model_version,\n",
+ "        'cond_image_src':cond_image_src,\n",
+ "        'warp_num_k':warp_num_k,\n",
+ "        'warp_forward':warp_forward,\n",
+ "        'sampler':sampler.__name__,\n",
+ "        'mask_clip':(mask_clip_low, mask_clip_high),\n",
+ "        'inpainting_mask_weight':inpainting_mask_weight ,\n",
+ "        'inverse_inpainting_mask':inverse_inpainting_mask,\n",
+ "        # ('mask_source' used to be listed a second time here; the repeat had no effect on the hash)\n",
+ "        'model_path':model_path,\n",
+ "        'diff_override':diff_override,\n",
+ "        'image_scale_schedule':image_scale_schedule,\n",
+ "        'image_scale_template':image_scale_template,\n",
+ "        'detect_resolution' :detect_resolution,\n",
+ "        'bg_threshold':bg_threshold,\n",
+ "        'diffuse_inpaint_mask_blur':diffuse_inpaint_mask_blur,\n",
+ "        'diffuse_inpaint_mask_thresh':diffuse_inpaint_mask_thresh,\n",
+ "        'add_noise_to_latent':add_noise_to_latent,\n",
+ "        'noise_upscale_ratio':noise_upscale_ratio,\n",
+ "        'fixed_seed':fixed_seed,\n",
+ "        'init_latent_fn':init_latent_fn.__name__,\n",
+ "        'value_threshold':value_threshold,\n",
+ "        'distance_threshold':distance_threshold,\n",
+ "        'masked_guidance':masked_guidance,\n",
+ "        'cc_masked_diffusion_schedule':cc_masked_diffusion_schedule,\n",
+ "        'alpha_masked_diffusion':alpha_masked_diffusion,\n",
+ "        'inverse_mask_order':inverse_mask_order,\n",
+ "        'invert_alpha_masked_diffusion':invert_alpha_masked_diffusion,\n",
+ "        'quantize':quantize,\n",
+ "        'cb_noise_upscale_ratio':cb_noise_upscale_ratio,\n",
+ "        'cb_add_noise_to_latent':cb_add_noise_to_latent,\n",
+ "        'cb_use_start_code':cb_use_start_code,\n",
+ "        'cb_fixed_code':cb_fixed_code,\n",
+ "        'cb_norm_latent':cb_norm_latent,\n",
+ "        'guidance_use_start_code':guidance_use_start_code,\n",
+ "        'controlnet_preprocess':controlnet_preprocess,\n",
+ "        'small_controlnet_model_path':small_controlnet_model_path,\n",
+ "        'use_scale':use_scale,\n",
+ "        'g_invert_mask':g_invert_mask,\n",
+ "        'controlnet_multimodel':json.dumps(controlnet_multimodel),\n",
+ "        'img_zero_uncond':img_zero_uncond,\n",
+ "        'do_softcap':do_softcap,\n",
+ "        'softcap_thresh':softcap_thresh,\n",
+ "        'softcap_q':softcap_q,\n",
+ "        'deflicker_latent_scale':deflicker_latent_scale,\n",
+ "        'deflicker_scale':deflicker_scale,\n",
+ "        'controlnet_multimodel_mode':controlnet_multimodel_mode,\n",
+ "        'no_half_vae':no_half_vae,\n",
+ "        'temporalnet_source':temporalnet_source,\n",
+ "        'temporalnet_skip_1st_frame':temporalnet_skip_1st_frame,\n",
+ "        'rec_randomness':rec_randomness,\n",
+ "        'rec_source':rec_source,\n",
+ "        'rec_cfg':rec_cfg,\n",
+ "        'rec_prompts':rec_prompts,\n",
+ "        'inpainting_mask_source':inpainting_mask_source,\n",
+ "        'rec_steps_pct':rec_steps_pct,\n",
+ "        'max_faces': max_faces,\n",
+ "        'num_flow_updates':num_flow_updates,\n",
+ "        'pose_detector':pose_detector,\n",
+ "        'control_sd15_openpose_hands_face':control_sd15_openpose_hands_face,\n",
+ "        # Fixed: this key used to be fed control_sd15_openpose_hands_face (copy-paste), so a\n",
+ "        # change of depth detector never invalidated the cached noise.\n",
+ "        'control_sd15_depth_detector':control_sd15_depth_detector,\n",
+ "        'control_sd15_softedge_detector':control_sd15_softedge_detector,\n",
+ "        'control_sd15_seg_detector':control_sd15_seg_detector,\n",
+ "        'control_sd15_scribble_detector':control_sd15_scribble_detector,\n",
+ "        'control_sd15_lineart_coarse':control_sd15_lineart_coarse,\n",
+ "        'control_sd15_inpaint_mask_source':control_sd15_inpaint_mask_source,\n",
+ "        'control_sd15_shuffle_source':control_sd15_shuffle_source,\n",
+ "        'control_sd15_shuffle_1st_source':control_sd15_shuffle_1st_source,\n",
+ "        'consistency_dilate':consistency_dilate,\n",
+ "        'apply_freeu_after_control':apply_freeu_after_control,\n",
+ "        'do_freeunet':do_freeunet\n",
+ "    }\n",
+ "    settings_hash = hashlib.sha256(json.dumps(rec_noise_setting_list).encode('utf-8')).hexdigest()[:16]\n",
+ "    filepath = f'{recNoiseCacheFolder}/{settings_hash}_{frame_num:06}.pt'\n",
+ "    if os.path.exists(filepath) and not overwrite_rec_noise:\n",
+ "        print(filepath)\n",
+ "        noise = torch.load(filepath)\n",
+ "        print('loading existing noise')\n",
+ "        return noise\n",
+ "    steps = int(copy.copy(steps)*rec_steps_pct)\n",
+ "\n",
+ "    # The cfg_scale argument is deliberately replaced by the reconstruction-specific setting.\n",
+ "    cfg_scale=rec_cfg\n",
+ "    if 'sdxl' in model_version:\n",
+ "        cond = sd_model.get_learned_conditioning(prompt)\n",
+ "        uncond = sd_model.get_learned_conditioning([''])\n",
+ "    else:\n",
+ "        cond = prompt_parser.get_learned_conditioning(sd_model, prompt, steps)\n",
+ "        uncond = prompt_parser.get_learned_conditioning(sd_model, [''], steps)\n",
+ "        cond = prompt_parser.reconstruct_cond_batch(cond, 0)\n",
+ "        uncond = prompt_parser.reconstruct_cond_batch(uncond, 0)\n",
+ "\n",
+ "    x = init_latent\n",
+ "\n",
+ "    s_in = x.new_ones([x.shape[0]])\n",
+ "    # `skip` drops the first element returned by get_scalings for the v-denoiser wrapper,\n",
+ "    # so that exactly two scalings (c_out, c_in) remain in the loop below.\n",
+ "    if sd_model.parameterization == \"v\" or model_version == 'control_multi_v2_768':\n",
+ "        dnw = K.external.CompVisVDenoiser(sd_model)\n",
+ "        skip = 1\n",
+ "    else:\n",
+ "        dnw = K.external.CompVisDenoiser(sd_model)\n",
+ "        skip = 0\n",
+ "    sigmas = dnw.get_sigmas(steps).flip(0)\n",
+ "\n",
+ "    if 'sdxl' in model_version:\n",
+ "        # split the sdxl conditioning dict into its vector and cross-attention parts\n",
+ "        vector = cond['vector']\n",
+ "        uc_vector = uncond['vector']\n",
+ "        y = vector_in = torch.cat([uc_vector, vector])\n",
+ "        cond = cond['crossattn']\n",
+ "        uncond = uncond['crossattn']\n",
+ "        sd_model.conditioner.vector_in = vector_in\n",
+ "\n",
+ "    if cond.shape[1]>77:\n",
+ "        cond = cond[:,:77,:]\n",
+ "        print('Prompt length > 77 detected. Shorten your prompt or split into multiple prompts.')\n",
+ "        uncond = uncond[:,:77,:]\n",
+ "    for i in trange(1, len(sigmas)):\n",
+ "        # unconditional and conditional halves are evaluated in one doubled batch\n",
+ "        x_in = torch.cat([x] * 2)\n",
+ "        sigma_in = torch.cat([sigmas[i - 1] * s_in] * 2)\n",
+ "        cond_in = torch.cat([uncond, cond])\n",
+ "\n",
+ "        # image_conditioning = torch.cat([image_conditioning] * 2)\n",
+ "        # cond_in = {\"c_concat\": [image_conditioning], \"c_crossattn\": [cond_in]}\n",
+ "        if model_version == 'control_multi' and controlnet_multimodel_mode == 'external':\n",
+ "            raise Exception(\"Predicted noise not supported for external mode. Please turn predicted noise off or use internal mode.\")\n",
+ "        if image_conditioning is not None:\n",
+ "            if 'control_multi' not in model_version:\n",
+ "                if model_version in ['sdxl_base', 'sdxl_refiner']:\n",
+ "                    # NOTE(review): `i` is the sigma-loop index and batch_size is not defined in\n",
+ "                    # this function - confirm this slice selects what is intended.\n",
+ "                    sd_model.conditioner.vector_in = vector_in[i*batch_size:(i+1)*batch_size]\n",
+ "                if img_zero_uncond:\n",
+ "                    img_in = torch.cat([torch.zeros_like(image_conditioning),\n",
+ "                                        image_conditioning])\n",
+ "                else:\n",
+ "                    img_in = torch.cat([image_conditioning]*2)\n",
+ "                cond_in={\"c_crossattn\": [cond_in],'c_concat': [img_in]}\n",
+ "\n",
+ "            if 'control_multi' in model_version and controlnet_multimodel_mode != 'external':\n",
+ "                img_in = {}\n",
+ "                for key in image_conditioning.keys():\n",
+ "                    img_in[key] = torch.cat([torch.zeros_like(image_conditioning[key]),\n",
+ "                                             image_conditioning[key]]) if img_zero_uncond else torch.cat([image_conditioning[key]]*2)\n",
+ "\n",
+ "                cond_in = {\"c_crossattn\": [cond_in], 'c_concat': img_in,\n",
+ "                           'controlnet_multimodel':controlnet_multimodel_inferred,\n",
+ "                           'loaded_controlnets':loaded_controlnets}\n",
+ "                # NOTE(review): kept inside the control_multi branch, where cond_in is known to be a dict.\n",
+ "                if 'sdxl' in model_version:\n",
+ "                    cond_in['y'] = y\n",
+ "\n",
+ "        c_out, c_in = [K.utils.append_dims(k, x_in.ndim) for k in dnw.get_scalings(sigma_in)[skip:]]\n",
+ "\n",
+ "        # the very first step takes its timestep from the next sigma instead of the previous one\n",
+ "        if i == 1:\n",
+ "            t = dnw.sigma_to_t(torch.cat([sigmas[i] * s_in] * 2))\n",
+ "        else:\n",
+ "            t = dnw.sigma_to_t(sigma_in)\n",
+ "\n",
+ "        eps = sd_model.apply_model(x_in * c_in, t, cond=cond_in)\n",
+ "        denoised_uncond, denoised_cond = (x_in + eps * c_out).chunk(2)\n",
+ "\n",
+ "        denoised = denoised_uncond + (denoised_cond - denoised_uncond) * cfg_scale\n",
+ "\n",
+ "        if i == 1:\n",
+ "            d = (x - denoised) / (2 * sigmas[i])\n",
+ "        else:\n",
+ "            d = (x - denoised) / sigmas[i - 1]\n",
+ "\n",
+ "        dt = sigmas[i] - sigmas[i - 1]\n",
+ "        x = x + d * dt\n",
+ "\n",
+ "        # This shouldn't be necessary, but solved some VRAM issues\n",
+ "        del x_in, sigma_in, cond_in, c_out, c_in, t,\n",
+ "        del eps, denoised_uncond, denoised_cond, denoised, d, dt\n",
+ "\n",
+ "    # return (x / x.std()) * sigmas[-1]\n",
+ "    x = x / sigmas[-1]\n",
+ "    torch.save(x, filepath)\n",
+ "    return x# / sigmas[-1]\n",
+ "\n",
+ "#karras noise\n",
+ "#https://github.com/Birch-san/stable-diffusion/blob/693c8a336aa3453d30ce403f48eb545689a679e5/scripts/txt2img_fork.py#L62-L81\n",
+ "# Put the ./k-diffusion folder (relative to the current working directory) on the import path.\n",
+ "sys.path.append('./k-diffusion')\n",
+ "\n",
+ "def get_premature_sigma_min(\n",
+ "    steps: int,\n",
+ "    sigma_max: float,\n",
+ "    sigma_min_nominal: float,\n",
+ "    rho: float\n",
+ "    ) -> float:\n",
+ "    \"\"\"Evaluate the rho-ramp between sigma_max and sigma_min_nominal one point before its end.\n",
+ "\n",
+ "    The ramp is interpolated linearly in sigma ** (1 / rho) space and raised back to the power\n",
+ "    rho. It is sampled at the fraction (steps - 2) / (steps - 1), i.e. the second-to-last of\n",
+ "    `steps` evenly spaced points on [0, 1].\n",
+ "\n",
+ "    Args:\n",
+ "        steps: number of points on the ramp; steps == 1 divides by zero.\n",
+ "        sigma_max: value at the start of the ramp.\n",
+ "        sigma_min_nominal: value at the end of the ramp.\n",
+ "        rho: exponent of the ramp.\n",
+ "\n",
+ "    Returns:\n",
+ "        The ramp value at the second-to-last point.\n",
+ "    \"\"\"\n",
+ "    inv_rho = 1 / rho\n",
+ "    start = sigma_max ** inv_rho\n",
+ "    end = sigma_min_nominal ** inv_rho\n",
+ "    penultimate_fraction = (steps - 2) / (steps - 1)\n",
+ "    return (start + penultimate_fraction * (end - start)) ** rho\n",
+ "\n",
+ "import contextlib\n",
+ "# No-op context manager; run_sd uses it in place of sd_model.ema_scope() when\n",
+ "# model_version == 'v1_inpainting'.\n",
+ "none_context = contextlib.nullcontext()\n",
+ "\n",
+ "def masked_callback(args, callback_steps, masks, init_latent, start_code):\n",
+ "    \"\"\"Sampler callback that pastes the init latent back into the sample under step-gated masks.\n",
+ "\n",
+ "    Every mask is paired with a threshold from `callback_steps`. Masks whose threshold is >= the\n",
+ "    current step index args['i'] are resized to the latent size and multiplied together; the\n",
+ "    sample is then blended as x * (1 - mask) + init_latent * mask.\n",
+ "\n",
+ "    Args:\n",
+ "        args: dict from the sampler; 'i' (step index), 'x' (current latent) and 'sigma' are read\n",
+ "            and the blended latent is written back to args['x'].\n",
+ "        callback_steps: step thresholds, one per mask.\n",
+ "        masks: 4-D mask tensors; only channel 0 of each is used.\n",
+ "        init_latent: latent to paste in; it is cloned, so the caller's tensor is left untouched.\n",
+ "        start_code: noise used instead of fresh random noise when cb_use_start_code is set.\n",
+ "\n",
+ "    Returns:\n",
+ "        args['x'] - blended when at least one mask is active, otherwise unchanged.\n",
+ "\n",
+ "    Reads the globals cb_use_start_code, cb_noise_upscale_ratio, cb_add_noise_to_latent,\n",
+ "    cb_norm_latent and VERBOSE.\n",
+ "    \"\"\"\n",
+ "    init_latent = init_latent.clone()\n",
+ "    masks = [m[:,0:1,...] for m in masks]\n",
+ "\n",
+ "    # combine every mask that is still active at this step into a single one\n",
+ "    final_mask = None\n",
+ "    for (mask, callback_step) in zip(masks, callback_steps):\n",
+ "        if args['i'] <= callback_step:\n",
+ "            mask = torch.nn.functional.interpolate(mask, args['x'].shape[2:],\n",
+ "                                                   mode='bilinear')\n",
+ "            final_mask = mask if final_mask is None else final_mask*mask\n",
+ "\n",
+ "    if final_mask is None:\n",
+ "        return args['x']\n",
+ "\n",
+ "    # the noise is drawn even when it ends up unused, which keeps the RNG stream unchanged\n",
+ "    if cb_use_start_code:\n",
+ "        noise = start_code\n",
+ "    else:\n",
+ "        noise = torch.randn_like(args['x'])\n",
+ "    noise = noise*args['sigma']\n",
+ "    if cb_noise_upscale_ratio > 1:\n",
+ "        # Fixed: subsample the two spatial axes (the old slice hit the batch and channel axes)\n",
+ "        # and use the callback's own ratio rather than the guidance noise_upscale_ratio.\n",
+ "        noise = noise[:, :, ::cb_noise_upscale_ratio, ::cb_noise_upscale_ratio]\n",
+ "        noise = torch.nn.functional.interpolate(noise, args['x'].shape[2:],\n",
+ "                                                mode='bilinear')\n",
+ "    if VERBOSE: print('Applying callback at step ', args['i'])\n",
+ "    if cb_add_noise_to_latent:\n",
+ "        init_latent = init_latent+noise\n",
+ "    if cb_norm_latent:\n",
+ "        # match the per-channel mean, then std, of the init latent to those of the current sample\n",
+ "        target = args['x']\n",
+ "        target_mean = target.mean(dim=(2,3),keepdim=True)\n",
+ "        target_std = target.std(dim=(2,3),keepdim=True)\n",
+ "        init_latent = init_latent - (init_latent.mean(dim=(2,3),keepdim=True)-target_mean)\n",
+ "        init_latent = init_latent/(init_latent.std(dim=(2,3),keepdim=True)/target_std)\n",
+ "\n",
+ "    args['x'] = args['x']*(1-final_mask) + (init_latent)*final_mask\n",
+ "    return args['x']\n",
+ "\n",
+ "import torch.nn.functional as F\n",
+ "\n",
+ "def high_frequency_loss(image1, image2):\n",
+ "    \"\"\"\n",
+ "    Penalise differences between the 2-D FFT magnitude spectra of two images.\n",
+ "\n",
+ "    The magnitude difference is averaged over the eight neighbours of every frequency bin\n",
+ "    (3x3 kernel with a zero centre, summed over channels) and the mean square of that map\n",
+ "    is returned.\n",
+ "\n",
+ "    Args:\n",
+ "        image1 (torch.Tensor): First input image tensor of shape (batch_size, channels, height, width).\n",
+ "        image2 (torch.Tensor): Second input image tensor with the same batch size and channels.\n",
+ "            If its height/width differ from image1 it is resized (bilinear) to match.\n",
+ "\n",
+ "    Returns:\n",
+ "        torch.Tensor: Scalar loss value.\n",
+ "    \"\"\"\n",
+ "    # Callers may pass images decoded at different resolutions; the spectra can only be\n",
+ "    # subtracted when the spatial sizes agree.\n",
+ "    if image2.shape[-2:] != image1.shape[-2:]:\n",
+ "        image2 = F.interpolate(image2, size=image1.shape[-2:], mode='bilinear')\n",
+ "\n",
+ "    # Difference of the Fourier magnitude spectra\n",
+ "    magnitude_diff = torch.abs(torch.fft.fft2(image1)) - torch.abs(torch.fft.fft2(image2))\n",
+ "\n",
+ "    # Neighbour-averaging kernel (centre excluded), one copy per input channel, normalised to sum 1\n",
+ "    ring_kernel = torch.tensor([[1, 1, 1],\n",
+ "                                [1, 0, 1],\n",
+ "                                [1, 1, 1]], dtype=image1.dtype, device=image1.device)\n",
+ "    ring_kernel = ring_kernel.unsqueeze(0).unsqueeze(0).repeat(1,image1.shape[1],1,1)\n",
+ "    ring_kernel = ring_kernel / torch.sum(ring_kernel)\n",
+ "\n",
+ "    with torch.autocast('cuda'):\n",
+ "        smoothed_diff = F.conv2d(magnitude_diff, ring_kernel, padding=1)\n",
+ "\n",
+ "    # Mean squared value of the smoothed spectral difference\n",
+ "    return torch.mean(smoothed_diff**2)\n",
+ "\n",
+ "# Module-level state; run_sd declares `pred_noise` global.\n",
+ "pred_noise = None\n",
+ "def run_sd(opt, init_image, skip_timesteps, H, W, text_prompt, neg_prompt, steps, seed,\n",
+ " init_scale, init_latent_scale, cond_image, cfg_scale, image_scale,\n",
+ " cond_fn=None, init_grad_img=None, consistency_mask=None, frame_num=0,\n",
+ " deflicker_src=None, prev_frame=None, rec_prompt=None, rec_frame=None,\n",
+ " control_inpainting_mask=None, shuffle_source=None, ref_image=None, alpha_mask=None,\n",
+ " prompt_weights=None, mask_current_frame_many=None, controlnet_sources={}, cc_masked_diffusion=[0]):\n",
+ "\n",
+ " # sampler = sample_euler\n",
+ "\n",
+ " # if model_version in ['sdxl_base', 'sdxl_refiner']:\n",
+ " # print('Disabling init_scale for sdxl')\n",
+ " # init_scale = 0\n",
+ "\n",
+ "\n",
+ " seed_everything(seed)\n",
+ " sd_model.cuda()\n",
+ " sd_model.model.cuda()\n",
+ " sd_model.cond_stage_model.cuda()\n",
+ " sd_model.cuda()\n",
+ " sd_model.first_stage_model.cuda()\n",
+ " model_wrap.inner_model.cuda()\n",
+ " model_wrap.cuda()\n",
+ " model_wrap_cfg.cuda()\n",
+ " model_wrap_cfg.inner_model.cuda()\n",
+ " # global cfg_scale\n",
+ " if VERBOSE:\n",
+ " print('seed', 'clip_guidance_scale', 'init_scale', 'init_latent_scale', 'clamp_grad', 'clamp_max',\n",
+ " 'init_image', 'skip_timesteps', 'cfg_scale')\n",
+ " print(seed, clip_guidance_scale, init_scale, init_latent_scale, clamp_grad,\n",
+ " clamp_max, init_image, skip_timesteps, cfg_scale)\n",
+ " global start_code, inpainting_mask_weight, inverse_inpainting_mask, start_code_cb, guidance_start_code\n",
+ " global pred_noise, controlnet_preprocess\n",
+ " # global frame_num\n",
+ " global normalize_latent\n",
+ " global first_latent\n",
+ " global first_latent_source\n",
+ " global use_karras_noise\n",
+ " global end_karras_ramp_early\n",
+ " global latent_fixed_norm\n",
+ " global latent_norm_4d\n",
+ " global latent_fixed_mean\n",
+ " global latent_fixed_std\n",
+ " global n_mean_avg\n",
+ " global n_std_avg\n",
+ " global reference_latent\n",
+ "\n",
+ " batch_size = num_samples = 1\n",
+ " scale = cfg_scale\n",
+ "\n",
+ " C = 4 #4\n",
+ " f = 8 #8\n",
+ " H = H\n",
+ " W = W\n",
+ " if VERBOSE:print(W, H, 'WH')\n",
+ " prompt = text_prompt[0]\n",
+ "\n",
+ "\n",
+ " neg_prompt = neg_prompt[0]\n",
+ " ddim_steps = steps\n",
+ "\n",
+ " # init_latent_scale = 0. #20\n",
+ " prompt_clip = prompt\n",
+ "\n",
+ "\n",
+ " assert prompt is not None\n",
+ " prompts = text_prompt\n",
+ "\n",
+ " if VERBOSE:print('prompts', prompts, text_prompt)\n",
+ "\n",
+ " precision_scope = autocast\n",
+ "\n",
+ " t_enc = ddim_steps-skip_timesteps\n",
+ "\n",
+ " if init_image is not None:\n",
+ " if isinstance(init_image, str):\n",
+ " if not init_image.endswith('_lat.pt'):\n",
+ " with torch.no_grad():\n",
+ " with torch.autocast('cuda'):\n",
+ " init_image_sd = load_img_sd(init_image, size=(W,H)).cuda()\n",
+ " if no_half_vae:\n",
+ " sd_model.first_stage_model.float()\n",
+ " init_image_sd = init_image_sd.float()\n",
+ " init_latent = sd_model.get_first_stage_encoding(sd_model.encode_first_stage(init_image_sd))\n",
+ " x0 = init_latent\n",
+ " if init_image.endswith('_lat.pt'):\n",
+ " init_latent = torch.load(init_image).cuda()\n",
+ " init_image_sd = None\n",
+ " x0 = init_latent\n",
+ "\n",
+ " reference_latent = None\n",
+ " if ref_image is not None and reference_active:\n",
+ " if os.path.exists(ref_image):\n",
+ " with torch.no_grad(), torch.cuda.amp.autocast():\n",
+ " reference_img = load_img_sd(ref_image, size=(W,H)).cuda()\n",
+ " reference_latent = sd_model.get_first_stage_encoding(sd_model.encode_first_stage(reference_img))\n",
+ " else:\n",
+ " print('Failed to load reference image')\n",
+ " ref_image = None\n",
+ "\n",
+ "\n",
+ "\n",
+ " if use_predicted_noise:\n",
+ " if rec_frame is not None:\n",
+ " with torch.cuda.amp.autocast():\n",
+ " rec_frame_img = load_img_sd(rec_frame, size=(W,H)).cuda()\n",
+ " rec_frame_latent = sd_model.get_first_stage_encoding(sd_model.encode_first_stage(rec_frame_img))\n",
+ "\n",
+ " if init_grad_img is not None:\n",
+ " print('Replacing init image for cond fn')\n",
+ " init_image_sd = load_img_sd(init_grad_img, size=(W,H)).cuda()\n",
+ "\n",
+ " if blend_latent_to_init > 0. and first_latent is not None:\n",
+ " print('Blending to latent ', first_latent_source)\n",
+ " x0 = x0*(1-blend_latent_to_init) + blend_latent_to_init*first_latent\n",
+ " if normalize_latent!='off' and first_latent is not None:\n",
+ " if VERBOSE:\n",
+ " print('norm to 1st latent')\n",
+ " print('latent source - ', first_latent_source)\n",
+ " # noise2 - target\n",
+ " # noise - modified\n",
+ "\n",
+ " if latent_norm_4d:\n",
+ " n_mean = first_latent.mean(dim=(2,3),keepdim=True)\n",
+ " n_std = first_latent.std(dim=(2,3),keepdim=True)\n",
+ " else:\n",
+ " n_mean = first_latent.mean()\n",
+ " n_std = first_latent.std()\n",
+ "\n",
+ " if n_mean_avg is None and n_std_avg is None:\n",
+ " n_mean_avg = n_mean.clone().detach().cpu().numpy()[0,:,0,0]\n",
+ " n_std_avg = n_std.clone().detach().cpu().numpy()[0,:,0,0]\n",
+ " else:\n",
+ " n_mean_avg = n_mean_avg*n_smooth+(1-n_smooth)*n_mean.clone().detach().cpu().numpy()[0,:,0,0]\n",
+ " n_std_avg = n_std_avg*n_smooth+(1-n_smooth)*n_std.clone().detach().cpu().numpy()[0,:,0,0]\n",
+ "\n",
+ " if VERBOSE:\n",
+ " print('n_stats_avg (mean, std): ', n_mean_avg, n_std_avg)\n",
+ " if normalize_latent=='user_defined':\n",
+ " n_mean = latent_fixed_mean\n",
+ " if isinstance(n_mean, list) and len(n_mean)==4: n_mean = np.array(n_mean)[None,:, None, None]\n",
+ " n_std = latent_fixed_std\n",
+ " if isinstance(n_std, list) and len(n_std)==4: n_std = np.array(n_std)[None,:, None, None]\n",
+ " if latent_norm_4d: n2_mean = x0.mean(dim=(2,3),keepdim=True)\n",
+ " else: n2_mean = x0.mean()\n",
+ " x0 = x0 - (n2_mean-n_mean)\n",
+ " if latent_norm_4d: n2_std = x0.std(dim=(2,3),keepdim=True)\n",
+ " else: n2_std = x0.std()\n",
+ " x0 = x0/(n2_std/n_std)\n",
+ "\n",
+ " if clip_guidance_scale>0:\n",
+ " # text_features = clip_model.encode_text(text)\n",
+ " target_embed = F.normalize(clip_model.encode_text(open_clip.tokenize(prompt_clip).cuda()).float())\n",
+ " else:\n",
+ " target_embed = None\n",
+ "\n",
+ "\n",
+ " with torch.no_grad():\n",
+ " with torch.cuda.amp.autocast():\n",
+ " with precision_scope(\"cuda\"):\n",
+ " scope = none_context if model_version == 'v1_inpainting' else sd_model.ema_scope()\n",
+ " with scope:\n",
+ " tic = time.time()\n",
+ " all_samples = []\n",
+ " uc = None\n",
+ " if True:\n",
+ " if scale != 1.0:\n",
+ " if 'sdxl' in model_version:\n",
+ " uc = sd_model.get_learned_conditioning([neg_prompt])\n",
+ " else:\n",
+ " uc = prompt_parser.get_learned_conditioning(sd_model, [neg_prompt], ddim_steps)\n",
+ "\n",
+ " if isinstance(prompts, tuple):\n",
+ " prompts = list(prompts)\n",
+ " if 'sdxl' in model_version:\n",
+ " c = sd_model.get_learned_conditioning(prompts)\n",
+ " else:\n",
+ " c = prompt_parser.get_learned_conditioning(sd_model, prompts, ddim_steps)\n",
+ "\n",
+ " shape = [C, H // f, W // f]\n",
+ " if use_karras_noise:\n",
+ "\n",
+ " rho = 7.\n",
+ " # 14.6146\n",
+ " sigma_max=model_wrap.sigmas[-1].item()\n",
+ " sigma_min_nominal=model_wrap.sigmas[0].item()\n",
+ " # get the \"sigma before sigma_min\" from a slightly longer ramp\n",
+ " # https://github.com/crowsonkb/k-diffusion/pull/23#issuecomment-1234872495\n",
+ " premature_sigma_min = get_premature_sigma_min(\n",
+ " steps=steps+1,\n",
+ " sigma_max=sigma_max,\n",
+ " sigma_min_nominal=sigma_min_nominal,\n",
+ " rho=rho\n",
+ " )\n",
+ " sigmas = K.sampling.get_sigmas_karras(\n",
+ " n=steps,\n",
+ " sigma_min=premature_sigma_min if end_karras_ramp_early else sigma_min_nominal,\n",
+ " sigma_max=sigma_max,\n",
+ " rho=rho,\n",
+ " device='cuda',\n",
+ " ).float()\n",
+ " else:\n",
+ " sigmas = model_wrap.get_sigmas(ddim_steps).float()\n",
+ " alpha_mask_t = None\n",
+ " if alpha_mask is not None and init_image is not None:\n",
+ " print('alpha_mask.shape', alpha_mask.shape)\n",
+ " alpha_mask_t = torch.from_numpy(alpha_mask).float().to(init_latent.device)[None,None,...][:,0:1,...]\n",
+ " consistency_mask_t = None\n",
+ " if consistency_mask is not None and init_image is not None:\n",
+ " consistency_mask_t = torch.from_numpy(consistency_mask).float().to(init_latent.device).permute(2,0,1)[None,...][:,0:1,...]\n",
+ " if guidance_use_start_code:\n",
+ " guidance_start_code = torch.randn_like(init_latent)\n",
+ "\n",
+ " deflicker_fn = deflicker_lat_fn = fft_fn = fft_latent_fn = None\n",
+ " if frame_num > args.start_frame:\n",
+ " def absdiff(a,b):\n",
+ " return abs(a-b)\n",
+ " for key in deflicker_src.keys():\n",
+ " deflicker_src[key] = load_img_sd(deflicker_src[key], size=(W,H)).cuda()\n",
+ " deflicker_fn = partial(deflicker_loss, processed1=deflicker_src['processed1'][:,:,::2,::2],\n",
+ " raw1=deflicker_src['raw1'][:,:,::2,::2], raw2=deflicker_src['raw2'][:,:,::2,::2], criterion1= absdiff, criterion2=lpips_model)\n",
+ " fft_fn = partial(high_frequency_loss, image2=init_image_sd)\n",
+ " for key in deflicker_src.keys():\n",
+ " deflicker_src[key] = sd_model.get_first_stage_encoding(sd_model.encode_first_stage(deflicker_src[key]))\n",
+ " deflicker_lat_fn = partial(deflicker_loss,\n",
+ " raw1=deflicker_src['raw1'], raw2=deflicker_src['raw2'], criterion1= absdiff, criterion2=rmse)\n",
+ " fft_latent_fn = partial(high_frequency_loss, image2=init_latent)\n",
+ " cond_fn_partial = partial(sd_cond_fn, init_image_sd=init_image_sd,\n",
+ " init_latent=init_latent,\n",
+ " init_scale=init_scale,\n",
+ " init_latent_scale=init_latent_scale,\n",
+ " target_embed=target_embed,\n",
+ " consistency_mask = consistency_mask_t,\n",
+ " start_code = guidance_start_code,\n",
+ " deflicker_fn = deflicker_fn, deflicker_lat_fn=deflicker_lat_fn,\n",
+ " deflicker_src=deflicker_src, fft_fn=fft_fn, fft_latent_fn=fft_latent_fn\n",
+ " )\n",
+ " callback_partial = None\n",
+ " if cc_masked_diffusion > 0 and consistency_mask is not None or alpha_masked_diffusion and alpha_mask is not None:\n",
+ " if cb_fixed_code:\n",
+ " if start_code_cb is None:\n",
+ " if VERBOSE:print('init start code')\n",
+ " start_code_cb = torch.randn_like(x0)\n",
+ " else:\n",
+ " start_code_cb = torch.randn_like(x0)\n",
+ " # start_code = torch.randn_like(x0)\n",
+ " callback_steps = []\n",
+ " callback_masks = []\n",
+ " if (cc_masked_diffusion > 0) and (consistency_mask is not None):\n",
+ " callback_masks.append(consistency_mask_t)\n",
+ " callback_steps.append(int((ddim_steps-skip_timesteps)*cc_masked_diffusion))\n",
+ " if alpha_masked_diffusion and alpha_mask is not None:\n",
+ " if invert_alpha_masked_diffusion:\n",
+ " alpha_mask_t = 1.-alpha_mask_t\n",
+ " callback_masks.append(alpha_mask_t)\n",
+ " callback_steps.append(int((ddim_steps-skip_timesteps)*alpha_masked_diffusion))\n",
+ " if inverse_mask_order:\n",
+ " callback_masks.reverse()\n",
+ " callback_steps.reverse()\n",
+ "\n",
+ "\n",
+ " if VERBOSE: print('callback steps', callback_steps)\n",
+ " callback_partial = partial(masked_callback,\n",
+ " callback_steps=callback_steps,\n",
+ " masks=callback_masks,\n",
+ " init_latent=init_latent, start_code=start_code_cb)\n",
+ " if new_prompt_loras == {}:\n",
+ " # only use cond fn when loras are off\n",
+ " model_fn = make_cond_model_fn(model_wrap_cfg, cond_fn_partial)\n",
+ " # model_fn = make_static_thresh_model_fn(model_fn, dynamic_thresh)\n",
+ " else:\n",
+ " model_fn = model_wrap_cfg\n",
+ "\n",
+ " model_fn = make_static_thresh_model_fn(model_fn, dynamic_thresh)\n",
+ " depth_img = None\n",
+ " depth_cond = None\n",
+ " if 'control_' in model_version:\n",
+ " input_image = np.array(Image.open(cond_image).resize(size=(W,H))); #print(type(input_image), 'input_image', input_image.shape)\n",
+ "\n",
+ "\n",
+ " if 'control_multi' in model_version:\n",
+ " if offload_model and not controlnet_low_vram:\n",
+ " for key in loaded_controlnets.keys():\n",
+ " loaded_controlnets[key].cuda()\n",
+ "\n",
+ " models = list(controlnet_multimodel.keys()); print(models)\n",
+ " else: models = model_version\n",
+ "\n",
+ "\n",
+ " if 'control_' in model_version:\n",
+ "\n",
+ " controlnet_sources['control_inpainting_mask'] = control_inpainting_mask\n",
+ " controlnet_sources['shuffle_source'] = shuffle_source\n",
+ " controlnet_sources['prev_frame'] = prev_frame\n",
+ " controlnet_sources['init_image'] = init_image\n",
+ " init_image = np.array(Image.open(controlnet_sources['init_image']).convert('RGB').resize(size=(W,H)))\n",
+ " detected_maps, models = get_controlnet_annotations(model_version, W, H, models, controlnet_sources)\n",
+ "\n",
+ " gc.collect()\n",
+ " torch.cuda.empty_cache()\n",
+ " gc.collect()\n",
+ " if VERBOSE: print('Postprocessing cond maps')\n",
+ " def postprocess_map(detected_map):\n",
+ " control = torch.from_numpy(detected_map.copy()).float().cuda() / 255.0\n",
+ " control = torch.stack([control for _ in range(num_samples)], dim=0)\n",
+ " depth_cond = einops.rearrange(control, 'b h w c -> b c h w').clone()\n",
+ " # if VERBOSE: print('depth_cond', depth_cond.min(), depth_cond.max(), depth_cond.mean(), depth_cond.std(), depth_cond.shape)\n",
+ " return depth_cond\n",
+ "\n",
+ " if 'control_multi' in model_version:\n",
+ " print('init shape', init_latent.shape, H,W)\n",
+ " for m in models:\n",
+ " if save_controlnet_annotations:\n",
+ " PIL.Image.fromarray(detected_maps[m].astype('uint8')).save(f'{controlnetDebugFolder}/{args.batch_name}({args.batchNum})_{m}_{frame_num:06}.jpg', quality=95)\n",
+ " detected_maps[m] = postprocess_map(detected_maps[m])\n",
+ " if VERBOSE: print('detected_maps[m].shape', m, detected_maps[m].shape)\n",
+ "\n",
+ " depth_cond = detected_maps\n",
+ " else: depth_cond = postprocess_map(detected_maps[model_version])\n",
+ "\n",
+ "\n",
+ " if model_version == 'v1_instructpix2pix':\n",
+ " if isinstance(cond_image, str):\n",
+ " print('Got img cond: ', cond_image)\n",
+ " with torch.no_grad():\n",
+ " with torch.cuda.amp.autocast():\n",
+ " input_image = Image.open(cond_image).resize(size=(W,H))\n",
+ " input_image = 2 * torch.tensor(np.array(input_image)).float() / 255 - 1\n",
+ " input_image = rearrange(input_image, \"h w c -> 1 c h w\").to(sd_model.device)\n",
+ " depth_cond = sd_model.encode_first_stage(input_image).mode()\n",
+ "\n",
+ " if model_version == 'v1_inpainting':\n",
+ " print('using inpainting')\n",
+ " if cond_image is not None:\n",
+ " if inverse_inpainting_mask: cond_image = 1 - cond_image\n",
+ " cond_image = Image.fromarray((cond_image*255).astype('uint8'))\n",
+ "\n",
+ " batch = make_batch_sd(Image.open(init_image).resize((W,H)) , cond_image, txt=prompt, device=device, num_samples=1, inpainting_mask_weight=inpainting_mask_weight)\n",
+ " c_cat = list()\n",
+ " for ck in sd_model.concat_keys:\n",
+ " cc = batch[ck].float()\n",
+ " if ck != sd_model.masked_image_key:\n",
+ "\n",
+ " cc = torch.nn.functional.interpolate(cc, scale_factor=1/8)\n",
+ " else:\n",
+ " cc = sd_model.get_first_stage_encoding(sd_model.encode_first_stage(cc))\n",
+ " c_cat.append(cc)\n",
+ " depth_cond = torch.cat(c_cat, dim=1)\n",
+ " # print('depth cond', depth_cond)\n",
+ " if mask_current_frame_many is not None:\n",
+ " mask_current_frame_many = torch.nn.functional.interpolate(mask_current_frame_many, x0.shape[2:])\n",
+ " extra_args = {'cond': c, 'uncond': uc, 'cond_scale': scale,\n",
+ " 'image_cond':depth_cond, 'prompt_weights':prompt_weights,\n",
+ " 'prompt_masks':mask_current_frame_many}\n",
+ " if model_version == 'v1_instructpix2pix':\n",
+ " extra_args['image_scale'] = image_scale\n",
+ " # extra_args['cond'] = sd_model.get_learned_conditioning(prompts)\n",
+ " # extra_args['uncond'] = sd_model.get_learned_conditioning([\"\"])\n",
+ " if skip_timesteps>0:\n",
+ " if offload_model:\n",
+ " sd_model.model.cuda()\n",
+ " sd_model.model.diffusion_model.cuda()\n",
+ " #using non-random start code\n",
+ " if fixed_code:\n",
+ " if start_code is None:\n",
+ " if VERBOSE:print('init start code')\n",
+ " start_code = torch.randn_like(x0)\n",
+ " if not use_legacy_fixed_code:\n",
+ " rand_code = torch.randn_like(x0)\n",
+ " combined_code = ((1 - code_randomness) * start_code + code_randomness * rand_code) / ((code_randomness**2 + (1-code_randomness)**2) ** 0.5)\n",
+ " noise = combined_code - (x0 / sigmas[0])\n",
+ " noise = noise * sigmas[ddim_steps - t_enc -1]\n",
+ "\n",
+ " #older version\n",
+ " if use_legacy_fixed_code:\n",
+ " normalize_code = True\n",
+ " if normalize_code:\n",
+ " noise2 = torch.randn_like(x0)* sigmas[ddim_steps - t_enc -1]\n",
+ " if latent_norm_4d: n_mean = noise2.mean(dim=(2,3),keepdim=True)\n",
+ " else: n_mean = noise2.mean()\n",
+ " if latent_norm_4d: n_std = noise2.std(dim=(2,3),keepdim=True)\n",
+ " else: n_std = noise2.std()\n",
+ "\n",
+ " noise = torch.randn_like(x0)\n",
+ " noise = (start_code*(1-code_randomness)+(code_randomness)*noise) * sigmas[ddim_steps - t_enc -1]\n",
+ " if normalize_code:\n",
+ " if latent_norm_4d: n2_mean = noise.mean(dim=(2,3),keepdim=True)\n",
+ " else: n2_mean = noise.mean()\n",
+ " noise = noise - (n2_mean-n_mean)\n",
+ " if latent_norm_4d: n2_std = noise.std(dim=(2,3),keepdim=True)\n",
+ " else: n2_std = noise.std()\n",
+ " noise = noise/(n2_std/n_std)\n",
+ "\n",
+ " else:\n",
+ " noise = torch.randn_like(x0) * sigmas[ddim_steps - t_enc -1] #correct one\n",
+ " if use_predicted_noise:\n",
+ " print('using predicted noise')\n",
+ " rand_noise = torch.randn_like(x0)\n",
+ " rec_noise = find_noise_for_image_sigma_adjustment(init_latent=rec_frame_latent, prompt=rec_prompt, image_conditioning=depth_cond, cfg_scale=scale, steps=ddim_steps, frame_num=frame_num)\n",
+ " combined_noise = ((1 - rec_randomness) * rec_noise + rec_randomness * rand_noise) / ((rec_randomness**2 + (1-rec_randomness)**2) ** 0.5)\n",
+ " noise = combined_noise - (x0 / sigmas[0])\n",
+ " noise = noise * sigmas[ddim_steps - t_enc -1]#faster collapse\n",
+ "\n",
+ " print('noise')\n",
+ " # noise = noise[::4,::4,:]\n",
+ " # noise = torch.nn.functional.interpolate(noise, scale_factor=4, mode='bilinear')\n",
+ " if t_enc != 0:\n",
+ " xi = x0 + noise\n",
+ " #printf('xi', xi.shape, xi.min().item(), xi.max().item(), xi.std().item(), xi.mean().item())\n",
+ " # print(xi.mean(), xi.std(), xi.min(), xi.max())\n",
+ " sigma_sched = sigmas[ddim_steps - t_enc - 1:]\n",
+ " # sigma_sched = sigmas[ddim_steps - t_enc:]\n",
+ " print('xi', xi.shape)\n",
+ " # with torch.autocast('cuda'):\n",
+ " # with torch.autocast('cuda', dtype=torch.float16):\n",
+ " samples_ddim = sampler(model_fn, xi, sigma_sched,\n",
+ " extra_args=extra_args, callback=callback_partial)\n",
+ " else:\n",
+ " samples_ddim = x0\n",
+ "\n",
+ " if offload_model:\n",
+ " sd_model.model.cpu()\n",
+ " sd_model.model.diffusion_model.cpu()\n",
+ " torch.cuda.empty_cache()\n",
+ " gc.collect()\n",
+ " else:\n",
+ " if offload_model:\n",
+ " sd_model.model.cuda()\n",
+ " sd_model.model.diffusion_model.cuda()\n",
+ " # if use_predicted_noise and frame_num>0:\n",
+ " if use_predicted_noise:\n",
+ " print('using predicted noise')\n",
+ " rand_noise = torch.randn_like(x0)\n",
+ " rec_noise = find_noise_for_image_sigma_adjustment(init_latent=rec_frame_latent, prompt=rec_prompt, image_conditioning=depth_cond, cfg_scale=scale, steps=ddim_steps, frame_num=frame_num)\n",
+ " combined_noise = ((1 - rec_randomness) * rec_noise + rec_randomness * rand_noise) / ((rec_randomness**2 + (1-rec_randomness)**2) ** 0.5)\n",
+ " x = combined_noise# - (x0 / sigmas[0])\n",
+ "\n",
+ " else: x = torch.randn([batch_size, *shape], device=device)\n",
+ " x = x * sigmas[0]\n",
+ " # with torch.autocast('cuda',dtype=torch.float16):\n",
+ " samples_ddim = sampler(model_fn, x, sigmas, extra_args=extra_args, callback=callback_partial)\n",
+ " if offload_model:\n",
+ " sd_model.model.cpu()\n",
+ " sd_model.model.diffusion_model.cpu()\n",
+ " torch.cuda.empty_cache()\n",
+ " gc.collect()\n",
+ " if first_latent is None:\n",
+ " if VERBOSE:print('setting 1st latent')\n",
+ " first_latent_source = 'samples ddim (1st frame output)'\n",
+ " first_latent = samples_ddim\n",
+ "\n",
+ " if offload_model:\n",
+ " sd_model.cond_stage_model.cpu()\n",
+ " if 'control_multi' in model_version:\n",
+ " for key in loaded_controlnets.keys():\n",
+ " loaded_controlnets[key].cpu()\n",
+ "\n",
+ " gc.collect()\n",
+ " torch.cuda.empty_cache()\n",
+ " if offload_model:\n",
+ " sd_model.first_stage_model.cuda()\n",
+ " if no_half_vae:\n",
+ " sd_model.first_stage_model.float()\n",
+ " x_samples_ddim = sd_model.decode_first_stage(samples_ddim.float())\n",
+ " else:\n",
+ " x_samples_ddim = sd_model.decode_first_stage(samples_ddim)\n",
+ " if offload_model:\n",
+ " sd_model.first_stage_model.cpu()\n",
+ " printf('x_samples_ddim', x_samples_ddim.min(), x_samples_ddim.max(), x_samples_ddim.std(), x_samples_ddim.mean())\n",
+ " scale_raw_sample = False\n",
+ " if scale_raw_sample:\n",
+ " m = x_samples_ddim.mean()\n",
+ " x_samples_ddim-=m;\n",
+ " r = (x_samples_ddim.max()-x_samples_ddim.min())/2\n",
+ "\n",
+ " x_samples_ddim/=r\n",
+ " x_samples_ddim+=m;\n",
+ " if VERBOSE:printf('x_samples_ddim scaled', x_samples_ddim.min(), x_samples_ddim.max(), x_samples_ddim.std(), x_samples_ddim.mean())\n",
+ "\n",
+ " assert not x_samples_ddim.isnan().any(), \"\"\"\n",
+ "Error: NaN encountered in VAE decode. You will get a black image.\n",
+ "\n",
+ "To avoid this you can try:\n",
+ "1) enabling no_half_vae in load model cell, then re-running it.\n",
+ "2) disabling tiled vae and re-running tiled vae cell\n",
+ "3) If you are using SDXL, you can try keeping no_half_vae off,\n",
+ "then downloading and using this vae checkpoint as your external vae_ckpt: https://huggingface.co/madebyollin/sdxl-vae-fp16-fix/resolve/main/sdxl_vae.safetensors\"\"\"\n",
+ "\n",
+ " all_samples.append(x_samples_ddim)\n",
+ " return all_samples, samples_ddim, depth_img\n",
+ "\n",
+ "def get_batch(keys, value_dict, N, device=\"cuda\"):\n",
+ "    \"\"\"Assemble the conditional and unconditional conditioner input batches.\n",
+ "\n",
+ "    Hardcoded demo setups; might undergo some changes in the future.\n",
+ "\n",
+ "    Args:\n",
+ "        keys: iterable of conditioner input keys to fill in.\n",
+ "        value_dict: mapping the values are read from. 'txt' reads 'prompt' and\n",
+ "            'negative_prompt'; the size/crop/score keys read their matching\n",
+ "            entries; any other key is copied through unchanged.\n",
+ "        N: sequence of ints giving the batch shape; N[0] is compared against\n",
+ "            the number of prompts when 'prompt' is not a plain string.\n",
+ "        device: device the created tensors are moved to.\n",
+ "\n",
+ "    Returns:\n",
+ "        (batch, batch_uc): dicts for the conditional and unconditional passes.\n",
+ "        Tensors without a dedicated unconditional value are cloned into batch_uc.\n",
+ "    \"\"\"\n",
+ "\n",
+ "    def _tile_text(text):\n",
+ "        # Repeat one prompt (or prompt list) to fill the whole batch shape N.\n",
+ "        return np.repeat([text], repeats=math.prod(N)).reshape(N).tolist()\n",
+ "\n",
+ "    def _tile_tensor(values):\n",
+ "        # Build a tensor from values and repeat it along the batch dimensions.\n",
+ "        return torch.tensor(values).to(device).repeat(*N, 1)\n",
+ "\n",
+ "    # Keys whose tensor is built from a pair of value_dict entries.\n",
+ "    pair_sources = {\n",
+ "        \"original_size_as_tuple\": (\"orig_height\", \"orig_width\"),\n",
+ "        \"crop_coords_top_left\": (\"crop_coords_top\", \"crop_coords_left\"),\n",
+ "        \"target_size_as_tuple\": (\"target_height\", \"target_width\"),\n",
+ "    }\n",
+ "\n",
+ "    batch = {}\n",
+ "    batch_uc = {}\n",
+ "    for key in keys:\n",
+ "        if key == \"txt\":\n",
+ "            prompt = value_dict[\"prompt\"]\n",
+ "            # A plain string must always be tiled: comparing its character count\n",
+ "            # with N[0] would otherwise let e.g. prompt 'a' with N=[1] slip\n",
+ "            # through as a bare string instead of a list of prompts.\n",
+ "            if isinstance(prompt, str) or len(prompt) != N[0]:\n",
+ "                batch[\"txt\"] = _tile_text(prompt)\n",
+ "            else:\n",
+ "                # Already one prompt per batch item; pass it through untouched.\n",
+ "                batch[\"txt\"] = prompt\n",
+ "            batch_uc[\"txt\"] = _tile_text(value_dict[\"negative_prompt\"])\n",
+ "        elif key in pair_sources:\n",
+ "            batch[key] = _tile_tensor([value_dict[src] for src in pair_sources[key]])\n",
+ "        elif key == \"aesthetic_score\":\n",
+ "            batch[\"aesthetic_score\"] = _tile_tensor([value_dict[\"aesthetic_score\"]])\n",
+ "            batch_uc[\"aesthetic_score\"] = _tile_tensor([value_dict[\"negative_aesthetic_score\"]])\n",
+ "        else:\n",
+ "            batch[key] = value_dict[key]\n",
+ "\n",
+ "    # Share every remaining tensor with the unconditional batch as an independent copy.\n",
+ "    for key, value in batch.items():\n",
+ "        if key not in batch_uc and isinstance(value, torch.Tensor):\n",
+ "            batch_uc[key] = torch.clone(value)\n",
+ "    return batch, batch_uc\n",
+ "\n",
+ "def get_unique_embedder_keys_from_conditioner(conditioner):\n",
+ "    \"\"\"Return the distinct input_key values of conditioner.embedders as a list.\n",
+ "\n",
+ "    The keys are collected through a set, so the order of the result is not guaranteed.\n",
+ "    \"\"\"\n",
+ "    unique_keys = set()\n",
+ "    for embedder in conditioner.embedders:\n",
+ "        unique_keys.add(embedder.input_key)\n",
+ "    return list(unique_keys)\n",
+ "\n",
+ "# Module-level settings; the code that reads these two names is outside this section.\n",
+ "diffusion_model = \"stable_diffusion\"\n",
+ "diffusion_sampling_mode = 'ddim'\n",
+ "\n",
+ "# Per-channel image normalization transform.\n",
+ "# NOTE(review): T is presumably torchvision.transforms, and the mean/std values look like the\n",
+ "# published CLIP preprocessing statistics - confirm against the imports and callers.\n",
+ "normalize = T.Normalize(mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711])\n",
+ "# LPIPS (VGG backbone) instance moved to the active device at definition time;\n",
+ "# the sampling code passes it to deflicker_loss as criterion2.\n",
+ "lpips_model = lpips.LPIPS(net='vgg').to(device)\n",
+ "\n",
+ "#todo\n",
+ "#offload face model\n",
+ "#offload canny, mlsd\n",
+ "def get_controlnet_annotations(model_version, W, H, models, controlnet_sources):\n",
+ " detected_maps = {}\n",
+ " #controlnet sources have image paths\n",
+ " prev_frame = controlnet_sources['prev_frame']\n",
+ " if prev_frame is None:\n",
+ " controlnet_sources.pop('next_frame')\n",
+ " elif not os.path.exists(controlnet_sources['next_frame']):\n",
+ " if 'control_sd15_temporal_depth' in controlnet_multimodel_inferred.keys():\n",
+ " controlnet_sources['next_frame'] = controlnet_sources['control_sd15_temporal_depth']\n",
+ "\n",
+ " init_image = controlnet_sources['init_image']\n",
+ " init_image = np.array(Image.open(controlnet_sources['init_image']).convert('RGB').resize(size=(W,H)))\n",
+ " control_inpainting_mask = controlnet_sources['control_inpainting_mask']\n",
+ "\n",
+ " #todo: check that input images are hwc3 and int8, because loading grayscale images may return hw and 0-1 float images\n",
+ " shuffle_source = controlnet_sources['shuffle_source']\n",
+ " models_out = copy.deepcopy(models)\n",
+ "\n",
+ " controlnet_sources_pil = dict([(o,np.array(Image.open(controlnet_sources[o]).convert('RGB').resize(size=(W,H)))) for o in models])\n",
+ "\n",
+ " models_to_preprocess = [o for o in models if controlnet_multimodel_inferred[o]['preprocess']]\n",
+ "\n",
+ " for control_key in models:\n",
+ " if control_key in [\"control_sd15_inpaint_softedge\"]:\n",
+ " #has detect res\n",
+ " #has preprocess option\n",
+ " if offload_model: apply_softedge.netNetwork.cuda()\n",
+ " input_image = controlnet_sources_pil[control_key]\n",
+ " detect_resolution = controlnet_multimodel_inferred[control_key][\"detect_resolution\"]\n",
+ " if not controlnet_multimodel_inferred[control_key][\"preprocess\"]:\n",
+ " detected_maps[control_key] = input_image\n",
+ " else:\n",
+ " input_image = HWC3(input_image)\n",
+ " with torch.cuda.amp.autocast(True), torch.no_grad():\n",
+ " detected_map = apply_softedge(resize_image(input_image, detect_resolution))\n",
+ " detected_map = HWC3(detected_map)\n",
+ " softedge_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_LINEAR)\n",
+ " # detected_maps[control_key] = detected_map\n",
+ " if offload_model: apply_softedge.netNetwork.cpu()\n",
+ "\n",
+ " if control_inpainting_mask is None:\n",
+ " if VERBOSE: print(f'skipping {control_key} as control_inpainting_mask is None')\n",
+ " models_out = [o for o in models_out if o != control_key]\n",
+ " if VERBOSE: print('models after removing temp', models_out)\n",
+ " else:\n",
+ " print('Applying inpaint'\n",
+ " )\n",
+ " control_inpainting_mask *= 255\n",
+ " control_inpainting_mask = 255 - control_inpainting_mask\n",
+ " if VERBOSE: print('control_inpainting_mask',control_inpainting_mask.shape,\n",
+ " control_inpainting_mask.min(), control_inpainting_mask.max())\n",
+ " if VERBOSE: print('control_inpainting_mask', (control_inpainting_mask[...,0] == control_inpainting_mask[...,0]).mean())\n",
+ " img = init_image #use prev warped frame\n",
+ " h, w, C = img.shape\n",
+ " #contolnet inpaint mask - H, W, 0-255 np array\n",
+ " detected_mask = cv2.resize(control_inpainting_mask[:, :, 0], (w, h), interpolation=cv2.INTER_LINEAR)\n",
+ " detected_map = img.astype(np.float32).copy()\n",
+ " detected_map[detected_mask > 127] = -255.0 # use -1 as inpaint value\n",
+ " detected_map = np.where(detected_map == -255, -1*softedge_map, detected_map)\n",
+ " detected_maps[control_key] = detected_map\n",
+ "\n",
+ " if control_key in ['control_sd15_temporal_depth']:\n",
+ " if prev_frame is not None:\n",
+ " #no detect resolution\n",
+ " #no preprocessign option\n",
+ " #source options - prev raw, prev stylized\n",
+ " if offload_model:\n",
+ " apply_depth.model.cuda()\n",
+ "\n",
+ " #has detect res\n",
+ " #has preprocess option\n",
+ " if offload_model: apply_depth.model.cuda()\n",
+ " input_image = controlnet_sources_pil[control_key]\n",
+ " detect_resolution = controlnet_multimodel_inferred[control_key][\"detect_resolution\"]\n",
+ " if not controlnet_multimodel_inferred[control_key][\"preprocess\"]:\n",
+ " detected_maps[control_key] = detected_map\n",
+ " else:\n",
+ " input_image = HWC3(np.array(input_image)); #print(type(input_image))\n",
+ " # Image.fromarray(input_image.astype('uint8')).save('./test.jpg')\n",
+ " input_image = resize_image(input_image, detect_resolution); #print((input_image.dtype), input_image.shape, input_image.size)\n",
+ " with torch.cuda.amp.autocast(False), torch.no_grad():\n",
+ " detected_map = apply_depth(input_image)\n",
+ " detected_map = HWC3(detected_map)\n",
+ " detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_LINEAR)\n",
+ " detected_maps[control_key] = detected_map\n",
+ " if 'next_frame' in controlnet_sources_pil.keys():\n",
+ " next_frame = controlnet_sources_pil['next_frame']\n",
+ " input_image = HWC3(np.array(next_frame)); #print(type(input_image))\n",
+ "\n",
+ " input_image = resize_image(input_image, detect_resolution); #print((input_image.dtype), input_image.shape, input_image.size)\n",
+ " with torch.cuda.amp.autocast(False), torch.no_grad():\n",
+ " detected_map = apply_depth(input_image)\n",
+ " detected_map = HWC3(detected_map)\n",
+ " detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_LINEAR)\n",
+ " detected_maps[control_key][...,2] = detected_map[...,0]\n",
+ "\n",
+ " if offload_model: apply_depth.model.cpu()\n",
+ "\n",
+ " detected_map = np.array(Image.open(prev_frame).resize(size=(W,H)).convert('L')); #print(type(input_image), 'input_image', input_image.shape)\n",
+ " detected_maps[control_key][...,0]= detected_map\n",
+ " # Image.fromarray(input_image.astype('uint8')).save('./temp_test.jpg')\n",
+ " else:\n",
+ " if VERBOSE: print('skipping temporalnet as prev_frame is None')\n",
+ " models_out = [o for o in models_out if o != control_key]\n",
+ " if VERBOSE: print('models after removing temp', models_out)\n",
+ "\n",
+ " if control_key in ['control_sd15_temporalnet', 'control_sdxl_temporalnet_v1']:\n",
+ " #no detect resolution\n",
+ " #no preprocessign option\n",
+ " #source options - prev raw, prev stylized\n",
+ " if prev_frame is not None:\n",
+ " detected_map = np.array(Image.open(prev_frame).resize(size=(W,H))); #print(type(input_image), 'input_image', input_image.shape)\n",
+ " detected_maps[control_key] = detected_map\n",
+ " else:\n",
+ " if VERBOSE: print('skipping temporalnet as prev_frame is None')\n",
+ " models_out = [o for o in models_out if o != control_key]\n",
+ " if VERBOSE: print('models after removing temp', models_out)\n",
+ "\n",
+ " if control_key == 'control_sd15_face':\n",
+ " #has detect res\n",
+ " #has preprocess option\n",
+ " input_image = controlnet_sources_pil[control_key]\n",
+ " detect_resolution = controlnet_multimodel_inferred[control_key][\"detect_resolution\"]\n",
+ " if not controlnet_multimodel_inferred[control_key][\"preprocess\"]:\n",
+ " detected_maps[control_key] = input_image\n",
+ "\n",
+ " else:\n",
+ " input_image = resize_image(input_image,\n",
+ " detect_resolution)\n",
+ " detected_map = generate_annotation(input_image, max_faces)\n",
+ " if detected_map is not None:\n",
+ " detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_LINEAR)\n",
+ " detected_maps[control_key] = detected_map\n",
+ " else:\n",
+ " if VERBOSE: print('No faces detected')\n",
+ " models_out = [o for o in models_out if o != control_key ]\n",
+ " if VERBOSE: print('models after removing face', models_out)\n",
+ "\n",
+ " if control_key == 'control_sd15_normal':\n",
+ " #has detect res\n",
+ " #has preprocess option\n",
+ " if offload_model: apply_depth.model.cuda()\n",
+ " input_image = controlnet_sources_pil[control_key]\n",
+ " detect_resolution = controlnet_multimodel_inferred[control_key][\"detect_resolution\"]\n",
+ " if not controlnet_multimodel_inferred[control_key][\"preprocess\"]:\n",
+ " detected_maps[control_key] = input_image[:, :, ::-1]\n",
+ " else:\n",
+ " input_image = HWC3(np.array(input_image)); print(type(input_image))\n",
+ " input_image = resize_image(input_image, detect_resolution); print((input_image.dtype))\n",
+ " with torch.cuda.amp.autocast(True), torch.no_grad():\n",
+ " _,detected_map = apply_depth(input_image, bg_th=bg_threshold)\n",
+ " detected_map = HWC3(detected_map)\n",
+ " detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_LINEAR)[:, :, ::-1]\n",
+ " detected_maps[control_key] = detected_map\n",
+ " if offload_model: apply_depth.model.cpu()\n",
+ "\n",
+ " if control_key in ['control_sd15_normalbae',\"control_sd21_normalbae\"]:\n",
+ " #has detect res\n",
+ " #has preprocess option\n",
+ " if offload_model: apply_normal.model.cuda()\n",
+ " input_image = controlnet_sources_pil[control_key]\n",
+ " detect_resolution = controlnet_multimodel_inferred[control_key][\"detect_resolution\"]\n",
+ "\n",
+ " if not controlnet_multimodel_inferred[control_key][\"preprocess\"]:\n",
+ " detected_maps[control_key] = input_image[:, :, ::-1]\n",
+ " else:\n",
+ " input_image = HWC3(np.array(input_image)); print(type(input_image))\n",
+ " input_image = resize_image(input_image, detect_resolution); print((input_image.dtype))\n",
+ " with torch.cuda.amp.autocast(True), torch.no_grad():\n",
+ " detected_map = apply_normal(input_image)\n",
+ " detected_map = HWC3(detected_map)\n",
+ " detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_LINEAR)[:, :, ::-1]\n",
+ " detected_maps[control_key] = detected_map\n",
+ " if offload_model: apply_normal.model.cpu()\n",
+ "\n",
+ " if control_key in [\"control_sd21_depth\",'control_sd15_depth','control_sdxl_depth',\n",
+ " 'control_sdxl_lora_128_depth','control_sdxl_lora_256_depth']:\n",
+ " if offload_model:\n",
+ " apply_depth.model.cuda()\n",
+ "\n",
+ " #has detect res\n",
+ " #has preprocess option\n",
+ " if offload_model: apply_depth.model.cuda()\n",
+ " input_image = controlnet_sources_pil[control_key]\n",
+ " detect_resolution = controlnet_multimodel_inferred[control_key][\"detect_resolution\"]\n",
+ "\n",
+ " if not controlnet_multimodel_inferred[control_key][\"preprocess\"]:\n",
+ " detected_maps[control_key] = input_image\n",
+ " else:\n",
+ " input_image = HWC3(np.array(input_image))\n",
+ " # Image.fromarray(input_image.astype('uint8')).save('./test.jpg')\n",
+ " input_image = resize_image(input_image, detect_resolution)\n",
+ "\n",
+ " if control_sd15_depth_detector == 'Midas':\n",
+ " with torch.cuda.amp.autocast(True), torch.no_grad():\n",
+ " detected_map,_ = apply_depth(input_image)\n",
+ " if control_sd15_depth_detector == 'Zoe':\n",
+ " with torch.cuda.amp.autocast(False), torch.no_grad():\n",
+ " detected_map = apply_depth(input_image)\n",
+ " #print('dectected map depth',detected_map.shape, detected_map.min(), detected_map.max(), detected_map.mean(), detected_map.std(), )\n",
+ " detected_map = HWC3(detected_map)\n",
+ " detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_LINEAR)\n",
+ " detected_maps[control_key] = detected_map\n",
+ " if offload_model: apply_depth.model.cpu()\n",
+ "\n",
+ " if control_key in [\"control_sd21_canny\",'control_sd15_canny','control_sdxl_canny', 'control_sdxl_lora_128_canny','control_sdxl_lora_256_canny']:\n",
+ " #has detect res\n",
+ " #has preprocess option\n",
+ "\n",
+ " input_image = controlnet_sources_pil[control_key]\n",
+ " detect_resolution = controlnet_multimodel_inferred[control_key][\"detect_resolution\"]\n",
+ " if not controlnet_multimodel_inferred[control_key][\"preprocess\"]:\n",
+ " detected_maps[control_key] = input_image\n",
+ " else:\n",
+ " input_image = HWC3(input_image)\n",
+ " detected_map = apply_canny(resize_image(input_image, detect_resolution), low_threshold, high_threshold)\n",
+ " detected_map = HWC3(detected_map)\n",
+ " detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_NEAREST)\n",
+ " detected_maps[control_key] = detected_map\n",
+ "\n",
+ " if control_key in [\"control_sd21_softedge\",'control_sd15_softedge','control_sdxl_softedge', 'control_sdxl_lora_128_softedge', 'control_sdxl_lora_256_softedge']:\n",
+ " #has detect res\n",
+ " #has preprocess option\n",
+ " if offload_model: apply_softedge.netNetwork.cuda()\n",
+ " input_image = controlnet_sources_pil[control_key]\n",
+ " detect_resolution = controlnet_multimodel_inferred[control_key][\"detect_resolution\"]\n",
+ " if not controlnet_multimodel_inferred[control_key][\"preprocess\"]:\n",
+ " detected_maps[control_key] = input_image\n",
+ " else:\n",
+ " input_image = HWC3(input_image)\n",
+ " with torch.cuda.amp.autocast(True), torch.no_grad():\n",
+ " detected_map = apply_softedge(resize_image(input_image, detect_resolution))\n",
+ " detected_map = HWC3(detected_map)\n",
+ " detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_LINEAR)\n",
+ " detected_maps[control_key] = detected_map\n",
+ " if offload_model: apply_softedge.netNetwork.cpu()\n",
+ "\n",
+ " if control_key == 'control_sd15_mlsd':\n",
+ " #has detect res\n",
+ " #has preprocess option\n",
+ " input_image = controlnet_sources_pil[control_key]\n",
+ " detect_resolution = controlnet_multimodel_inferred[control_key][\"detect_resolution\"]\n",
+ " if not controlnet_multimodel_inferred[control_key][\"preprocess\"]:\n",
+ " detected_maps[control_key] = input_image\n",
+ " else:\n",
+ " input_image = HWC3(input_image)\n",
+ " with torch.cuda.amp.autocast(True), torch.no_grad():\n",
+ " detected_map = apply_mlsd(resize_image(input_image, detect_resolution), value_threshold, distance_threshold)\n",
+ " detected_map = HWC3(detected_map)\n",
+ " detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_NEAREST)\n",
+ " detected_maps[control_key] = detected_map\n",
+ "\n",
+ " if control_key in [\"control_sd21_openpose\",'control_sd15_openpose','control_sdxl_openpose']:\n",
+ " #has detect res\n",
+ " #has preprocess option\n",
+ " if offload_model:\n",
+ " if pose_detector == 'openpose':\n",
+ " apply_openpose.body_estimation.model.cuda()\n",
+ " apply_openpose.hand_estimation.model.cuda()\n",
+ " apply_openpose.face_estimation.model.cuda()\n",
+ "\n",
+ " input_image = controlnet_sources_pil[control_key]\n",
+ " detect_resolution = controlnet_multimodel_inferred[control_key][\"detect_resolution\"]\n",
+ " if not controlnet_multimodel_inferred[control_key][\"preprocess\"]:\n",
+ " detected_maps[control_key] = input_image\n",
+ " else:\n",
+ " input_image = HWC3(input_image)\n",
+ " resized_img = resize_image(input_image,\n",
+ " detect_resolution)\n",
+ " try:\n",
+ " with torch.cuda.amp.autocast(True), torch.no_grad():\n",
+ " if pose_detector == 'openpose':\n",
+ " detected_map = apply_openpose(resized_img, hand_and_face=control_sd15_openpose_hands_face)\n",
+ " elif pose_detector == 'dw_pose':\n",
+ " detected_map = apply_openpose(resized_img)\n",
+ " except:\n",
+ " detected_map = np.zeros_like(resized_img)\n",
+ "\n",
+ " detected_map = HWC3(detected_map)\n",
+ " detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_NEAREST)\n",
+ " detected_maps[control_key] = detected_map\n",
+ " if offload_model:\n",
+ " if pose_detector == 'openpose':\n",
+ " apply_openpose.body_estimation.model.cpu()\n",
+ " apply_openpose.hand_estimation.model.cpu()\n",
+ " apply_openpose.face_estimation.model.cpu()\n",
+ "\n",
+ " if control_key in ['control_sd15_scribble',\"control_sd21_scribble\"]:\n",
+ " #has detect res\n",
+ " #has preprocess option\n",
+ " input_image = controlnet_sources_pil[control_key]\n",
+ " detect_resolution = controlnet_multimodel_inferred[control_key][\"detect_resolution\"]\n",
+ " if not controlnet_multimodel_inferred[control_key][\"preprocess\"]:\n",
+ " detected_maps[control_key] = input_image\n",
+ " else:\n",
+ " input_image = HWC3(input_image)\n",
+ "\n",
+ " if offload_model: apply_scribble.netNetwork.cuda()\n",
+ " input_image = HWC3(input_image)\n",
+ " with torch.cuda.amp.autocast(True), torch.no_grad():\n",
+ " detected_map = apply_scribble(resize_image(input_image, detect_resolution))\n",
+ "\n",
+ " detected_map = HWC3(detected_map)\n",
+ " detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_LINEAR)\n",
+ " detected_map = nms(detected_map, 127, 3.0)\n",
+ " detected_map = cv2.GaussianBlur(detected_map, (0, 0), 3.0)\n",
+ " detected_map[detected_map > 4] = 255\n",
+ " detected_map[detected_map < 255] = 0\n",
+ " detected_maps[control_key] = detected_map\n",
+ " if offload_model: apply_scribble.netNetwork.cpu()\n",
+ "\n",
+ " if control_key in [\"control_sd21_seg\", \"control_sd15_seg\",'control_sdxl_seg']:\n",
+ "\n",
+ " #has detect res\n",
+ " #has preprocess option\n",
+ " input_image = controlnet_sources_pil[control_key]\n",
+ " detect_resolution = controlnet_multimodel_inferred[control_key][\"detect_resolution\"]\n",
+ " if not controlnet_multimodel_inferred[control_key][\"preprocess\"]:\n",
+ " detected_maps[control_key] = input_image\n",
+ " else:\n",
+ " input_image = HWC3(input_image)\n",
+ " with torch.cuda.amp.autocast(True), torch.no_grad():\n",
+ " detected_map = apply_seg(resize_image(input_image, detect_resolution))\n",
+ "\n",
+ " detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_NEAREST)\n",
+ " detected_maps[control_key] = detected_map\n",
+ "\n",
+ " if control_key in [\"control_sd21_lineart\", \"control_sd15_lineart\"]:\n",
+ " #has detect res\n",
+ " #has preprocess option\n",
+ " input_image = controlnet_sources_pil[control_key]\n",
+ " detect_resolution = controlnet_multimodel_inferred['control_sd15_lineart'][\"detect_resolution\"]\n",
+ " if not controlnet_multimodel_inferred[control_key][\"preprocess\"]:\n",
+ " detected_maps[control_key] = input_image\n",
+ " else:\n",
+ " input_image = HWC3(input_image)\n",
+ " with torch.cuda.amp.autocast(True), torch.no_grad():\n",
+ " detected_map = apply_lineart(resize_image(input_image, detect_resolution), coarse=control_sd15_lineart_coarse)\n",
+ "\n",
+ " detected_map = HWC3(detected_map)\n",
+ " detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_NEAREST)\n",
+ " detected_maps[control_key] = detected_map\n",
+ "\n",
+ " if control_key in [\"control_sd15_lineart_anime\"]:\n",
+ " #has detect res\n",
+ " #has preprocess option\n",
+ " input_image = controlnet_sources_pil[control_key]\n",
+ " detect_resolution = controlnet_multimodel_inferred[control_key][\"detect_resolution\"]\n",
+ " if not controlnet_multimodel_inferred[control_key][\"preprocess\"]:\n",
+ " detected_maps[control_key] = input_image\n",
+ " else:\n",
+ " input_image = HWC3(input_image)\n",
+ " with torch.cuda.amp.autocast(True), torch.no_grad():\n",
+ " detected_map = apply_lineart_anime(resize_image(input_image, detect_resolution))\n",
+ "\n",
+ " detected_map = HWC3(detected_map)\n",
+ " detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_NEAREST)\n",
+ " detected_maps[control_key] = detected_map\n",
+ "\n",
+ " if control_key in [\"control_sd15_ip2p\"]:\n",
+ " #no detect res\n",
+ " #no preprocess option\n",
+ " #ip2p has no separate detect resolution\n",
+ " input_image = controlnet_sources_pil[control_key]\n",
+ " detect_resolution = controlnet_multimodel_inferred[control_key][\"detect_resolution\"]\n",
+ " input_image = HWC3(input_image)\n",
+ " detected_map = input_image.copy()\n",
+ " detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_NEAREST)\n",
+ " detected_maps[control_key] = detected_map\n",
+ "\n",
+ " if control_key in [\"control_sd15_tile\",\"control_sd15_qr\", \"control_sd21_qr\"]:\n",
+ " #no detect res\n",
+ " #no preprocess option\n",
+ " input_image = controlnet_sources_pil[control_key]\n",
+ " detect_resolution = controlnet_multimodel_inferred[control_key][\"detect_resolution\"]\n",
+ " input_image = HWC3(input_image)\n",
+ " detected_map = input_image.copy()\n",
+ " detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_NEAREST)\n",
+ " detected_maps[control_key] = detected_map\n",
+ "\n",
+ " if control_key in [\"control_sd15_shuffle\"]:\n",
+ " #shuffle has no separate detect resolution\n",
+ " #no preprocess option\n",
+ " shuffle_image = np.array(Image.open(shuffle_source))\n",
+ " shuffle_image = HWC3(shuffle_image)\n",
+ " shuffle_image = cv2.resize(shuffle_image, (W, H), interpolation=cv2.INTER_NEAREST)\n",
+ "\n",
+ " dH, dW, dC = shuffle_image.shape\n",
+ " detected_map = apply_shuffle(shuffle_image, w=dW, h=dH, f=256)\n",
+ " detected_maps[control_key] = detected_map\n",
+ "\n",
+ " if control_key in [\"control_sd15_inpaint\"]:\n",
+ " #defaults to init image (stylized prev frame)\n",
+ " #inpaint has no separate detect resolution\n",
+ " #no preprocess option\n",
+ " input_image = controlnet_sources_pil[control_key]\n",
+ " detect_resolution = controlnet_multimodel_inferred[control_key][\"detect_resolution\"]\n",
+ " if control_inpainting_mask is None:\n",
+ " if VERBOSE: print('skipping control_sd15_inpaint as control_inpainting_mask is None')\n",
+ " models_out = [o for o in models_out if o != control_key]\n",
+ " if VERBOSE: print('models after removing temp', models_out)\n",
+ " else:\n",
+ " control_inpainting_mask *= 255\n",
+ " control_inpainting_mask = 255 - control_inpainting_mask\n",
+ " if VERBOSE: print('control_inpainting_mask',control_inpainting_mask.shape,\n",
+ " control_inpainting_mask.min(), control_inpainting_mask.max())\n",
+ " if VERBOSE: print('control_inpainting_mask', (control_inpainting_mask[...,0] == control_inpainting_mask[...,0]).mean())\n",
+ " img = input_image\n",
+ " h, w, C = img.shape\n",
+ " #contolnet inpaint mask - H, W, 0-255 np array\n",
+ " detected_mask = cv2.resize(control_inpainting_mask[:, :, 0], (w, h), interpolation=cv2.INTER_LINEAR)\n",
+ " detected_map = img.astype(np.float32).copy()\n",
+ " detected_map[detected_mask > 127] = -255.0 # use -1 as inpaint value\n",
+ " detected_maps[control_key] = detected_map\n",
+ "\n",
+ "\n",
+ "\n",
+ " return detected_maps, models_out\n",
+ "\n",
+ "executed_cells[cell_name] = True\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "SettingsTop"
+ },
+ "source": [
+ "# 2. Settings"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "cellView": "form",
+ "id": "BasicSettings"
+ },
+ "outputs": [],
+ "source": [
+ "#@markdown ####**Basic Settings:**\n",
+ "\n",
+ "cell_name = 'basic_settings'\n",
+ "check_execution(cell_name)\n",
+ "\n",
+ "batch_name = 'stable_warpfusion_0.24.0' #@param{type: 'string'}\n",
+ "steps = 50\n",
+ "##@param [25,50,100,150,250,500,1000]{type: 'raw', allow-input: true}\n",
+ "# stop_early = 0 #@param{type: 'number'}\n",
+ "stop_early = 0\n",
+ "stop_early = min(steps-1,stop_early)\n",
+ "\n",
+ "\n",
+ "clip_guidance_scale = 0 #\n",
+ "tv_scale = 0\n",
+ "range_scale = 0\n",
+ "cutn_batches = 4\n",
+ "skip_augs = False\n",
+ "\n",
+ "#@markdown ---\n",
+ "\n",
+ "#@markdown ####**Init Settings:**\n",
+ "init_image = \"\" #@param{type: 'string'}\n",
+ "init_scale = 0\n",
+ "##@param{type: 'integer'}\n",
+ "skip_steps = 25\n",
+ "##@param{type: 'integer'}\n",
+ "##@markdown *Make sure you set skip_steps to ~50% of your steps if you want to use an init image.\\\n",
+ "##@markdown A good init_scale for Stable Diffusion is 0*\n",
+ "\n",
+ "\n",
+ "#Update Model Settings\n",
+ "timestep_respacing = f'ddim{steps}'\n",
+ "diffusion_steps = (1000//steps)*steps if steps < 1000 else steps\n",
+ "\n",
+ "\n",
+ "#Make folder for batch\n",
+ "batchFolder = f'{outDirPath}/{batch_name}'\n",
+ "createPath(batchFolder)\n",
+ "\n",
+ "executed_cells[cell_name] = True"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "cellView": "form",
+ "id": "Bsax1iBhv4bt"
+ },
+ "outputs": [],
+ "source": [
+ "#@title ##Video Input Settings:\n",
+ "cell_name = 'animation_settings'\n",
+ "check_execution(cell_name)\n",
+ "executed_cells[cell_name] = True\n",
+ "#@markdown ###**Output Size Settings**\n",
+ "#@markdown Specify desired output size here [width,height] or use a single number to resize the frame keeping aspect ratio.\\\n",
+ "#@markdown Don't forget to rerun all steps after changing the width height (including forcing optical flow generation)\n",
+ "width_height = 1280#@param{type: 'raw'}\n",
+ "#Get corrected sizes\n",
+ "#@markdown Make sure the resolution is divisible by that number. The Default 64 is the most stable.\n",
+ "\n",
+ "force_multiple_of = \"64\" #@param [8,64]\n",
+ "force_multiple_of = int(force_multiple_of)\n",
+ "if isinstance(width_height, list):\n",
+ " width_height = [int(o) for o in width_height]\n",
+ " side_x = (width_height[0]//force_multiple_of)*force_multiple_of;\n",
+ " side_y = (width_height[1]//force_multiple_of)*force_multiple_of;\n",
+ " if side_x != width_height[0] or side_y != width_height[1]:\n",
+ " print(f'Changing output size to {side_x}x{side_y}. Dimensions must by multiples of {force_multiple_of}.')\n",
+ " width_height = (side_x, side_y)\n",
+ "else:\n",
+ " width_height = int(width_height)\n",
+ "\n",
+ "\n",
+ "\n",
+ "cell_name = 'video_input_settings'\n",
+ "check_execution(cell_name)\n",
+ "\n",
+ "animation_mode = 'Video Input'\n",
+ "import os, platform\n",
+ "if platform.system() != 'Linux' and not os.path.exists(\"ffmpeg.exe\"):\n",
+ " print(\"Warning! ffmpeg.exe not found. Please download ffmpeg and place it in current working dir.\")\n",
+ "\n",
+ "#@markdown ###**Video Input Settings**\n",
+ "#@markdown ---\n",
+ "\n",
+ "video_source = 'video_init' #@param ['video_init', 'looped_init_image']\n",
+ "\n",
+ "#@markdown Use video_init to process your video file.\\\n",
+ "#@markdown If you don't have a video file, you can use looped_init_image to create a looping video from a single init_image\\\n",
+ "#@markdown Use this if you just want to test settings. This will create a small video (1 sec = 24 frames)\\\n",
+ "#@markdown This way you will be able to iterate faster without the need to process flow maps for a long final video before even getting to testing prompts.\n",
+ "looped_video_duration_sec = 2 #@param {'type':'number'}\n",
+ "\n",
+ "video_init_path = \"/content/drive/MyDrive/vids/init/y2mate.com - Jennifer Connelly HOT 90s GIRLS_1080p.mp4\" #@param {type: 'string'}\n",
+ "\n",
+ "if video_source=='looped_init_image':\n",
+ " actual_size = Image.open(init_image).size\n",
+ " if isinstance(width_height, int):\n",
+ " width_height = fit_size(actual_size, width_height)\n",
+ "\n",
+ " force_multiple_of = int(force_multiple_of)\n",
+ " side_x = (width_height[0]//force_multiple_of)*force_multiple_of;\n",
+ " side_y = (width_height[1]//force_multiple_of)*force_multiple_of;\n",
+ " if side_x != width_height[0] or side_y != width_height[1]:\n",
+ " print(f'Changing output size to {side_x}x{side_y}. Dimensions must by multiples of {force_multiple_of}.')\n",
+ " width_height = (side_x, side_y)\n",
+ " subprocess.run(['ffmpeg', '-loop', '1', '-i', init_image, '-c:v', 'libx264', '-t', str(looped_video_duration_sec), '-pix_fmt',\n",
+ " 'yuv420p', '-vf', f'scale={side_x}:{side_y}', f\"{root_dir}/out.mp4\", '-y'], stdout=subprocess.PIPE).stdout.decode('utf-8')\n",
+ " print('Video saved to ', f\"{root_dir}/out.mp4\")\n",
+ " video_init_path = f\"{root_dir}/out.mp4\"\n",
+ "\n",
+ "extract_nth_frame = 1#@param {type: 'number'}\n",
+ "reverse = False #@param {type: 'boolean'}\n",
+ "no_vsync = True #@param {type: 'boolean'}\n",
+ "#@markdown *Specify frame range. end_frame=0 means fill the end of video*\n",
+ "start_frame = 0#@param {type: 'number'}\n",
+ "end_frame = 0#@param {type: 'number'}\n",
+ "# Keep the value as entered: it is embedded in the frame-folder postfix further down.\n",
+ "end_frame_orig = end_frame\n",
+ "# Test for None first: `None <= 0` raises TypeError, so the comparison must not run on None.\n",
+ "# A missing or non-positive end_frame means \"up to the end of the video\".\n",
+ "if end_frame is None or end_frame <= 0:\n",
+ "    end_frame = 99999999999999999999999999999\n",
+ "#@markdown ####**Separate guiding video** (optical flow source):\n",
+ "#@markdown Leave blank to use the first video.\n",
+ "flow_video_init_path = \"\" #@param {type: 'string'}\n",
+ "flow_extract_nth_frame = 1#@param {type: 'number'}\n",
+ "if flow_video_init_path == '':\n",
+ " flow_video_init_path = None\n",
+ "#@markdown ####**Image Conditioning Video Source**:\n",
+ "#@markdown Used together with image-conditioned models, like controlnet, depth, or inpainting model.\n",
+ "#@markdown You can use your own video as depth mask or as inpainting mask.\n",
+ "cond_video_path = \"\" #@param {type: 'string'}\n",
+ "cond_extract_nth_frame = 1#@param {type: 'number'}\n",
+ "if cond_video_path == '':\n",
+ " cond_video_path = None\n",
+ "\n",
+ "#@markdown ####**Colormatching Video Source**:\n",
+ "#@markdown Used as colormatching source. Specify image or video.\n",
+ "color_video_path = \"\" #@param {type: 'string'}\n",
+ "color_extract_nth_frame = 1#@param {type: 'number'}\n",
+ "if color_video_path == '':\n",
+ " color_video_path = None\n",
+ "#@markdown Enable to store frames, flow maps, alpha maps on drive\n",
+ "store_frames_on_google_drive = False #@param {type: 'boolean'}\n",
+ "video_init_seed_continuity = False\n",
+ "\n",
+ "def extractFrames(video_path, output_path, nth_frame, start_frame, end_frame):\n",
+ "    \"\"\"Export frames of `video_path` into `output_path` as numbered JPEGs via ffmpeg.\n",
+ "\n",
+ "    The ffmpeg filter first selects frames whose index lies between `start_frame`\n",
+ "    and `end_frame`, then keeps one out of every `nth_frame` of those. JPEGs\n",
+ "    already present in `output_path` are deleted first. The notebook-level\n",
+ "    settings `reverse` and `no_vsync` are read as globals.\n",
+ "\n",
+ "    Exits through sys.exit() when `video_path` does not exist.\n",
+ "    \"\"\"\n",
+ "    # Validate the input before touching output_path, so a mistyped path cannot\n",
+ "    # wipe frames extracted by an earlier run.\n",
+ "    if not os.path.exists(video_path):\n",
+ "        sys.exit(f'\\nERROR!\\n\\nVideo not found: {video_path}.\\nPlease check your video path.\\n')\n",
+ "\n",
+ "    createPath(output_path)\n",
+ "    print(f\"Exporting Video Frames (1 every {nth_frame})...\")\n",
+ "    for old_frame in glob(output_path+'/*.jpg'):\n",
+ "        old_frame = old_frame.replace('\\\\','/')\n",
+ "        try:\n",
+ "            pathlib.Path(old_frame).unlink()\n",
+ "        except OSError as err:\n",
+ "            # Report and carry on: one undeletable file should not abort the cleanup.\n",
+ "            print('error deleting frame ', old_frame, err)\n",
+ "\n",
+ "    vf = f'select=between(n\\\\,{start_frame}\\\\,{end_frame}) , select=not(mod(n\\\\,{nth_frame}))'\n",
+ "    if reverse:\n",
+ "        vf += ',reverse'\n",
+ "    vsync = '0' if no_vsync else 'vfr'\n",
+ "    ffmpeg_args = ['-i', f'{video_path}', '-vf', f'{vf}',\n",
+ "                   '-vsync', vsync, '-q:v', '2', '-loglevel', 'error', '-stats',\n",
+ "                   f'{output_path}/%06d.jpg']\n",
+ "    try:\n",
+ "        subprocess.run(['ffmpeg'] + ffmpeg_args, stdout=subprocess.PIPE)\n",
+ "    except OSError:\n",
+ "        # `ffmpeg` could not be launched; retry under the ffmpeg.exe name.\n",
+ "        subprocess.run(['ffmpeg.exe'] + ffmpeg_args, stdout=subprocess.PIPE)\n",
+ "\n",
+ "import cv2\n",
+ "def get_fps(video_init_path):\n",
+ "    \"\"\"Return the frame rate OpenCV reports for `video_init_path`.\n",
+ "\n",
+ "    Returns -1 when the path does not exist or is not a regular file.\n",
+ "    \"\"\"\n",
+ "    is_existing_file = os.path.exists(video_init_path) and os.path.isfile(video_init_path)\n",
+ "    if not is_existing_file:\n",
+ "        return -1\n",
+ "    capture = cv2.VideoCapture(video_init_path)\n",
+ "    frame_rate = capture.get(cv2.CAP_PROP_FPS)\n",
+ "    capture.release()\n",
+ "    return frame_rate\n",
+ "\n",
+ "if animation_mode == 'Video Input':\n",
+ "    # The frame-folder postfix combines generate_file_hash() of the source with the\n",
+ "    # detected fps and the frame-range settings.\n",
+ "    detected_fps = get_fps(video_init_path)\n",
+ "    postfix = f'{generate_file_hash(video_init_path)[:10]}-{detected_fps:.6}_{start_frame}_{end_frame_orig}_{extract_nth_frame}'\n",
+ "    print(f'Detected video fps of {detected_fps:.6}. With extract_nth_frame={extract_nth_frame} the suggested export fps would be {detected_fps/extract_nth_frame:.6}.')\n",
+ "    if flow_video_init_path:\n",
+ "        flow_postfix = f'{generate_file_hash(flow_video_init_path)[:10]}_{flow_extract_nth_frame}'\n",
+ "\n",
+ "    # Working folders go under batchFolder when frames are stored on Google Drive or\n",
+ "    # when not running on colab, and under root_dir otherwise.\n",
+ "    if store_frames_on_google_drive or not is_colab: #suggested by Chris the Wizard#8082 at discord\n",
+ "        _frames_root = batchFolder\n",
+ "    else:\n",
+ "        _frames_root = root_dir\n",
+ "    videoFramesFolder = f'{_frames_root}/videoFrames/{postfix}'\n",
+ "    # Without a separate flow video, the flow frames are the init video frames.\n",
+ "    flowVideoFramesFolder = f'{_frames_root}/flowVideoFrames/{flow_postfix}' if flow_video_init_path else videoFramesFolder\n",
+ "    condVideoFramesFolder = f'{_frames_root}/condVideoFrames'\n",
+ "    colorVideoFramesFolder = f'{_frames_root}/colorVideoFrames'\n",
+ "    controlnetDebugFolder = f'{_frames_root}/controlnetDebug'\n",
+ "    recNoiseCacheFolder = f'{_frames_root}/recNoiseCache'\n",
+ "\n",
+ "    for _folder in (controlnetDebugFolder, recNoiseCacheFolder):\n",
+ "        os.makedirs(_folder, exist_ok=True)\n",
+ "\n",
+ "    extractFrames(video_init_path, videoFramesFolder, extract_nth_frame, start_frame, end_frame)\n",
+ "\n",
+ "    if flow_video_init_path:\n",
+ "        print(flow_video_init_path, flowVideoFramesFolder, flow_extract_nth_frame)\n",
+ "        extractFrames(flow_video_init_path, flowVideoFramesFolder, flow_extract_nth_frame, start_frame, end_frame)\n",
+ "\n",
+ "    if cond_video_path:\n",
+ "        print(cond_video_path, condVideoFramesFolder, cond_extract_nth_frame)\n",
+ "        extractFrames(cond_video_path, condVideoFramesFolder, cond_extract_nth_frame, start_frame, end_frame)\n",
+ "\n",
+ "    if color_video_path:\n",
+ "        try:\n",
+ "            # If PIL can open the path it is a still image: store it as the single color frame.\n",
+ "            os.makedirs(colorVideoFramesFolder, exist_ok=True)\n",
+ "            Image.open(color_video_path).save(os.path.join(colorVideoFramesFolder,'000001.jpg'))\n",
+ "        except:\n",
+ "            # Otherwise fall back to extracting frames from it as a video.\n",
+ "            print(color_video_path, colorVideoFramesFolder, color_extract_nth_frame)\n",
+ "            extractFrames(color_video_path, colorVideoFramesFolder, color_extract_nth_frame, start_frame, end_frame)\n",
+ "\n",
+ "def fit_size(size,maxsize=512):\n",
+ "    \"\"\"Scale an (x, y) size so that its larger side becomes `maxsize`.\n",
+ "\n",
+ "    Both sides are multiplied by the same ratio and truncated with int().\n",
+ "    Returns the new size as a 2-tuple.\n",
+ "    \"\"\"\n",
+ "    scale = maxsize/max(size)\n",
+ "    width, height = size\n",
+ "    return int(width*scale), int(height*scale)\n",
+ "\n",
+ "actual_size = Image.open(sorted(glob(videoFramesFolder+'/*.*'))[0]).size\n",
+ "if isinstance(width_height, int):\n",
+ " width_height = fit_size(actual_size, width_height)\n",
+ "\n",
+ "force_multiple_of = int(force_multiple_of)\n",
+ "side_x = (width_height[0]//force_multiple_of)*force_multiple_of;\n",
+ "side_y = (width_height[1]//force_multiple_of)*force_multiple_of;\n",
+ "if side_x != width_height[0] or side_y != width_height[1]:\n",
+ " print(f'Changing output size to {side_x}x{side_y}. Dimensions must by multiples of {force_multiple_of}.')\n",
+ "width_height = (side_x, side_y)\n",
+ "\n",
+ "executed_cells[cell_name] = True\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "cellView": "form",
+ "id": "gZrXG3Vpfijs"
+ },
+ "outputs": [],
+ "source": [
+ "#@title Video Masking\n",
+ "\n",
+ "cell_name = 'video_masking'\n",
+ "check_execution(cell_name)\n",
+ "\n",
+ "#@markdown Generate background mask from your init video or use a video as a mask\n",
+ "mask_source = 'init_video' #@param ['init_video','mask_video']\n",
+ "#@markdown Check to rotoscope the video and create a mask from it. If unchecked, the raw monochrome video will be used as a mask.\n",
+ "extract_background_mask = False #@param {'type':'boolean'}\n",
+ "#@markdown Specify path to a mask video for mask_video mode.\n",
+ "mask_video_path = '' #@param {'type':'string'}\n",
+ "if extract_background_mask:\n",
+ " os.chdir(root_dir)\n",
+ " !python -m pip -q install av pims\n",
+ " gitclone('https://github.com/Sxela/RobustVideoMattingCLI')\n",
+ " if mask_source == 'init_video':\n",
+ " videoFramesAlpha = videoFramesFolder+'Alpha'\n",
+ " createPath(videoFramesAlpha)\n",
+ " !python \"{root_dir}/RobustVideoMattingCLI/rvm_cli.py\" --input_path \"{videoFramesFolder}\" --output_alpha \"{root_dir}/alpha.mp4\"\n",
+ " extractFrames(f\"{root_dir}/alpha.mp4\", f\"{videoFramesAlpha}\", 1, 0, 999999999)\n",
+ " if mask_source == 'mask_video':\n",
+ " videoFramesAlpha = videoFramesFolder+'Alpha'\n",
+ " createPath(videoFramesAlpha)\n",
+ " maskVideoFrames = videoFramesFolder+'Mask'\n",
+ " createPath(maskVideoFrames)\n",
+ " extractFrames(mask_video_path, f\"{maskVideoFrames}\", extract_nth_frame, start_frame, end_frame)\n",
+ " !python \"{root_dir}/RobustVideoMattingCLI/rvm_cli.py\" --input_path \"{maskVideoFrames}\" --output_alpha \"{root_dir}/alpha.mp4\"\n",
+ " extractFrames(f\"{root_dir}/alpha.mp4\", f\"{videoFramesAlpha}\", 1, 0, 999999999)\n",
+ "else:\n",
+ " if mask_source == 'init_video':\n",
+ " videoFramesAlpha = videoFramesFolder\n",
+ " if mask_source == 'mask_video':\n",
+ " videoFramesAlpha = videoFramesFolder+'Alpha'\n",
+ " createPath(videoFramesAlpha)\n",
+ " extractFrames(mask_video_path, f\"{videoFramesAlpha}\", extract_nth_frame, start_frame, end_frame)\n",
+ " #extract video\n",
+ "\n",
+ "\n",
+ "executed_cells[cell_name] = True\n",
+ "\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "ycrPKG1G3hY0"
+ },
+ "source": [
+ "# Optical map settings\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "cellView": "form",
+ "id": "-SkN-otqgT_Q"
+ },
+ "outputs": [],
+ "source": [
+ "#@title Generate optical flow and consistency maps\n",
+ "#@markdown Run once per init video and width_height setting.\n",
+ "#if you're running locally, just restart this runtime, no need to edit PIL files.\n",
+ "\n",
+ "cell_name = 'generate_optical_flow'\n",
+ "check_execution(cell_name)\n",
+ "\n",
+ "flow_warp = True\n",
+ "check_consistency = True\n",
+ "force_flow_generation = False #@param {type:'boolean'}\n",
+ "\n",
+ "use_legacy_cc = False #@param{'type':'boolean'}\n",
+ "\n",
+ "#@title Setup Optical Flow\n",
+ "##@markdown Run once per session. Doesn't download again if model path exists.\n",
+ "##@markdown Use force download to reload raft models if needed\n",
+ "force_download = False #\\@param {type:'boolean'}\n",
+ "# import wget\n",
+ "import zipfile, shutil\n",
+ "\n",
+ "if (os.path.exists(f'{root_dir}/raft')) and force_download:\n",
+ " try:\n",
+ " shutil.rmtree(f'{root_dir}/raft')\n",
+ " except:\n",
+ " print('error deleting existing RAFT model')\n",
+ "if (not (os.path.exists(f'{root_dir}/raft'))) or force_download:\n",
+ " os.chdir(root_dir)\n",
+ " gitclone('https://github.com/Sxela/WarpFusion')\n",
+ "else:\n",
+ " os.chdir(root_dir)\n",
+ " os.chdir('WarpFusion')\n",
+ " !git pull\n",
+ " os.chdir(root_dir)\n",
+ "\n",
+ "try:\n",
+ " from python_color_transfer.color_transfer import ColorTransfer, Regrain\n",
+ "except:\n",
+ " os.chdir(root_dir)\n",
+ " gitclone('https://github.com/pengbo-learn/python-color-transfer')\n",
+ "\n",
+ "os.chdir(root_dir)\n",
+ "sys.path.append('./python-color-transfer')\n",
+ "\n",
+ "if animation_mode == 'Video Input':\n",
+ " os.chdir(root_dir)\n",
+ " gitclone('https://github.com/Sxela/flow_tools')\n",
+ "\n",
+ "#@title Define color matching and brightness adjustment\n",
+ "os.chdir(f\"{root_dir}/python-color-transfer\")\n",
+ "from python_color_transfer.color_transfer import ColorTransfer, Regrain\n",
+ "os.chdir(root_path)\n",
+ "\n",
+ "PT = ColorTransfer()\n",
+ "RG = Regrain()\n",
+ "\n",
+ "def match_color(stylized_img, raw_img, opacity=1.):\n",
+ "    \"\"\"Color-match `raw_img` to `stylized_img` and blend the result by `opacity`.\n",
+ "\n",
+ "    When `opacity` is not greater than 0, `raw_img` is returned untouched.\n",
+ "    Otherwise both images are rounded to uint8 and converted RGB->BGR, passed\n",
+ "    through PT.pdf_transfer and RG.regrain, mixed with the BGR source by\n",
+ "    `opacity`, and returned as an RGB uint8 array.\n",
+ "    \"\"\"\n",
+ "    if not opacity > 0:\n",
+ "        return raw_img\n",
+ "\n",
+ "    def as_bgr_uint8(rgb_image):\n",
+ "        return cv2.cvtColor(np.array(rgb_image).round().astype('uint8'),cv2.COLOR_RGB2BGR)\n",
+ "\n",
+ "    reference_bgr = as_bgr_uint8(stylized_img)\n",
+ "    source_bgr = as_bgr_uint8(raw_img)\n",
+ "    transferred = PT.pdf_transfer(img_arr_in=source_bgr, img_arr_ref=reference_bgr)\n",
+ "    regrained = RG.regrain(img_arr_in=transferred, img_arr_col=reference_bgr)\n",
+ "    blended = regrained*opacity+source_bgr*(1-opacity)\n",
+ "    return cv2.cvtColor(blended.round().astype('uint8'),cv2.COLOR_BGR2RGB)\n",
+ "\n",
+ "from PIL import Image, ImageOps, ImageStat, ImageEnhance\n",
+ "\n",
+ "def get_stats(image):\n",
+ "    \"\"\"Return (brightness, contrast) for `image`.\n",
+ "\n",
+ "    Brightness is the ImageStat per-band mean averaged over the bands;\n",
+ "    contrast is the per-band standard deviation averaged the same way.\n",
+ "    \"\"\"\n",
+ "    band_stats = ImageStat.Stat(image)\n",
+ "    means = band_stats.mean\n",
+ "    deviations = band_stats.stddev\n",
+ "    return sum(means) / len(means), sum(deviations) / len(deviations)\n",
+ "\n",
+ "#implementation taken from https://github.com/lowfuel/progrockdiffusion\n",
+ "\n",
+ "def adjust_brightness(image):\n",
+ "\n",
+ " brightness, contrast = get_stats(image)\n",
+ " if brightness > high_brightness_threshold:\n",
+ " print(\" Brightness over threshold. Compensating!\")\n",
+ " filter = ImageEnhance.Brightness(image)\n",
+ " image = filter.enhance(high_brightness_adjust_ratio)\n",
+ " image = np.array(image)\n",
+ " image = np.where(image>high_brightness_threshold, image-high_brightness_adjust_fix_amount, image).clip(0,255).round().astype('uint8')\n",
+ " image = Image.fromarray(image)\n",
+ " if brightness < low_brightness_threshold:\n",
+ " print(\" Brightness below threshold. Compensating!\")\n",
+ " filter = ImageEnhance.Brightness(image)\n",
+ " image = filter.enhance(low_brightness_adjust_ratio)\n",
+ " image = np.array(image)\n",
+ " image = np.where(imagemax_brightness_threshold, image-high_brightness_adjust_fix_amount, image).clip(0,255).round().astype('uint8')\n",
+ " image = np.where(image BGR instead of RGB\n",
+ " ch_idx = 2-i if convert_to_bgr else i\n",
+ " flow_image[:,:,ch_idx] = np.floor(255 * col)\n",
+ " return flow_image\n",
+ "\n",
+ "\n",
+ "def flow_to_image(flow_uv, clip_flow=None, convert_to_bgr=False):\n",
+ "    \"\"\"\n",
+ "    Turn a two-channel flow field into a color visualization.\n",
+ "\n",
+ "    Args:\n",
+ "        flow_uv (np.ndarray): Flow UV image of shape [H,W,2]\n",
+ "        clip_flow (float, optional): When given, flow values are clipped to [0, clip_flow]. Defaults to None.\n",
+ "        convert_to_bgr (bool, optional): Passed on to flow_uv_to_colors to request BGR output. Defaults to False.\n",
+ "    Returns:\n",
+ "        The result of flow_uv_to_colors for the normalized u and v components.\n",
+ "    \"\"\"\n",
+ "    assert flow_uv.ndim == 3, 'input flow must have three dimensions'\n",
+ "    assert flow_uv.shape[2] == 2, 'input flow must have shape [H,W,2]'\n",
+ "    if clip_flow is not None:\n",
+ "        flow_uv = np.clip(flow_uv, 0, clip_flow)\n",
+ "    u, v = flow_uv[:,:,0], flow_uv[:,:,1]\n",
+ "    rad_max = np.max(np.sqrt(np.square(u) + np.square(v)))\n",
+ "    # The small epsilon keeps the division defined for an all-zero flow.\n",
+ "    epsilon = 1e-5\n",
+ "    normalized_u = u / (rad_max + epsilon)\n",
+ "    normalized_v = v / (rad_max + epsilon)\n",
+ "    return flow_uv_to_colors(normalized_u, normalized_v, convert_to_bgr)\n",
+ "\n",
+ "\n",
+ "from torch import Tensor\n",
+ "\n",
+ "# if True:\n",
+ "if animation_mode == 'Video Input':\n",
+ " in_path = videoFramesFolder if not flow_video_init_path else flowVideoFramesFolder\n",
+ " flo_folder = in_path+'_out_flo_fwd'\n",
+ " #the main idea comes from neural-style-tf frame warping with optical flow maps\n",
+ " #https://github.com/cysmith/neural-style-tf\n",
+ " # path = f'{root_dir}/RAFT/core'\n",
+ " # import sys\n",
+ " # sys.path.append(f'{root_dir}/RAFT/core')\n",
+ " # %cd {path}\n",
+ "\n",
+ " # from utils.utils import InputPadder\n",
+ "\n",
+ " class InputPadder:\n",
+ "     \"\"\"Pads images so that their height and width are divisible by 8.\"\"\"\n",
+ "\n",
+ "     def __init__(self, dims, mode='sintel'):\n",
+ "         \"\"\"Compute the padding for inputs whose last two dims are (height, width).\"\"\"\n",
+ "         self.ht, self.wd = dims[-2:]\n",
+ "         # Distance to the next multiple of 8 (0 when already aligned).\n",
+ "         pad_ht = (-self.ht) % 8\n",
+ "         pad_wd = (-self.wd) % 8\n",
+ "         first_wd = pad_wd//2\n",
+ "         first_ht = pad_ht//2\n",
+ "         if mode == 'sintel':\n",
+ "             # Both axes: padding split between the two sides.\n",
+ "             self._pad = [first_wd, pad_wd - first_wd, first_ht, pad_ht - first_ht]\n",
+ "         else:\n",
+ "             # Width split between the two sides; all height padding in the last entry.\n",
+ "             self._pad = [first_wd, pad_wd - first_wd, 0, pad_ht]\n",
+ "\n",
+ "     def pad(self, *inputs):\n",
+ "         \"\"\"Return a list with every input padded by F.pad in 'replicate' mode.\"\"\"\n",
+ "         padded = []\n",
+ "         for tensor in inputs:\n",
+ "             padded.append(F.pad(tensor, self._pad, mode='replicate'))\n",
+ "         return padded\n",
+ "\n",
+ "     def unpad(self,x):\n",
+ "         \"\"\"Slice the stored padding back off the last two dims of `x`.\"\"\"\n",
+ "         ht, wd = x.shape[-2:]\n",
+ "         wd_start, wd_trim, ht_start, ht_trim = self._pad\n",
+ "         return x[..., ht_start:ht-ht_trim, wd_start:wd-wd_trim]\n",
+ "\n",
+ " # from raft import RAFT\n",
+ " import numpy as np\n",
+ " import argparse, PIL, cv2\n",
+ " from PIL import Image\n",
+ " from tqdm.notebook import tqdm\n",
+ " from glob import glob\n",
+ " import torch\n",
+ " import scipy.ndimage\n",
+ "\n",
+ " args2 = argparse.Namespace()\n",
+ " args2.small = False\n",
+ " args2.mixed_precision = True\n",
+ "\n",
+ " # Value written as the first 4 bytes (float32) of every flow file produced by writeFlow.\n",
+ " TAG_CHAR = np.array([202021.25], np.float32)\n",
+ "\n",
+ " def writeFlow(filename,uv,v=None):\n",
+ "     \"\"\"\n",
+ "     https://github.com/NVIDIA/flownet2-pytorch/blob/master/utils/flow_utils.py\n",
+ "     Copyright 2017 NVIDIA CORPORATION\n",
+ "\n",
+ "     Licensed under the Apache License, Version 2.0 (the \"License\");\n",
+ "     you may not use this file except in compliance with the License.\n",
+ "     You may obtain a copy of the License at\n",
+ "\n",
+ "     http://www.apache.org/licenses/LICENSE-2.0\n",
+ "\n",
+ "     Unless required by applicable law or agreed to in writing, software\n",
+ "     distributed under the License is distributed on an \"AS IS\" BASIS,\n",
+ "     WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
+ "     See the License for the specific language governing permissions and\n",
+ "     limitations under the License.\n",
+ "\n",
+ "     Write optical flow to file.\n",
+ "\n",
+ "     If v is None, uv is assumed to contain both u and v channels,\n",
+ "     stacked in depth.\n",
+ "     Original code by Deqing Sun, adapted from Daniel Scharstein.\n",
+ "\n",
+ "     File layout: TAG_CHAR (float32), width (int32), height (int32), then the\n",
+ "     float32 values row by row with u and v interleaved per pixel.\n",
+ "\n",
+ "     Args:\n",
+ "         filename: path of the file to (over)write.\n",
+ "         uv: [H,W,2] array holding both channels, or the [H,W] u channel when v is given.\n",
+ "         v: optional [H,W] v channel with the same shape as u.\n",
+ "     Raises:\n",
+ "         AssertionError: if the array shapes do not match the description above.\n",
+ "     \"\"\"\n",
+ "     nBands = 2\n",
+ "\n",
+ "     if v is None:\n",
+ "         assert(uv.ndim == 3)\n",
+ "         assert(uv.shape[2] == 2)\n",
+ "         u = uv[:,:,0]\n",
+ "         v = uv[:,:,1]\n",
+ "     else:\n",
+ "         u = uv\n",
+ "\n",
+ "     assert(u.shape == v.shape)\n",
+ "     height,width = u.shape\n",
+ "     # arrange into matrix form (built before the file is opened, so bad input\n",
+ "     # cannot leave a truncated file behind)\n",
+ "     tmp = np.zeros((height, width*nBands))\n",
+ "     tmp[:,np.arange(width)*2] = u\n",
+ "     tmp[:,np.arange(width)*2 + 1] = v\n",
+ "     # `with` closes the handle even when one of the writes fails\n",
+ "     with open(filename,'wb') as f:\n",
+ "         # write the header\n",
+ "         f.write(TAG_CHAR)\n",
+ "         np.array(width).astype(np.int32).tofile(f)\n",
+ "         np.array(height).astype(np.int32).tofile(f)\n",
+ "         tmp.astype(np.float32).tofile(f)\n",
+ "\n",
+ " def load_cc(path, blur=2, dilate=0):\n",
+ "     \"\"\"Build a 3-channel consistency mask from the map image stored at `path`.\n",
+ "\n",
+ "     Channels 0, 1 and 2 of the image (scaled to 0..1) are clipped from below at\n",
+ "     one minus missed_/overshoot_/edges_consistency_weight, multiplied together\n",
+ "     and thresholded at 0.5. The mask is then optionally dilated and blurred,\n",
+ "     repeated over 3 channels, saved as a debug JPEG and returned.\n",
+ "     \"\"\"\n",
+ "     layers = np.array(Image.open(path))/255\n",
+ "     weights = np.ones_like(layers[...,0])\n",
+ "     channel_strengths = (missed_consistency_weight, overshoot_consistency_weight, edges_consistency_weight)\n",
+ "     for channel, strength in enumerate(channel_strengths):\n",
+ "         weights *= layers[...,channel].clip(1-strength,1)\n",
+ "     weights = np.where(weights<0.5, 0, 1)\n",
+ "     if dilate>0:\n",
+ "         # The inverted mask is what gets dilated, then it is inverted back.\n",
+ "         weights = (1-binary_dilation(1-weights, disk(dilate))).astype('uint8')\n",
+ "     if blur>0:\n",
+ "         # NOTE(review): `weights` is an integer array here and gaussian_filter keeps the\n",
+ "         # input dtype by default, so the blurred mask presumably stays integer-valued\n",
+ "         # instead of getting soft edges - confirm this is intended.\n",
+ "         weights = scipy.ndimage.gaussian_filter(weights, [blur, blur])\n",
+ "     weights = np.repeat(weights[...,None],3, axis=2)\n",
+ "     debug_path = f'{controlnetDebugFolder}/{args.batch_name}({args.batchNum})_cc_mask.jpg'\n",
+ "     PIL.Image.fromarray((weights*255).astype('uint8')).save(debug_path, quality=95)\n",
+ "     if DEBUG:\n",
+ "         print('weight min max mean std', weights.shape, weights.min(), weights.max(), weights.mean(), weights.std())\n",
+ "     return weights\n",
+ "\n",
+ "\n",
+ "\n",
+ " def load_img(img, size):\n",
+ "     \"\"\"Load the image file `img` as a float CUDA tensor of shape (1, 3, H, W).\n",
+ "\n",
+ "     The image is converted to RGB and resized to `size` with the notebook-level\n",
+ "     `warp_interp` resampling filter before conversion.\n",
+ "     \"\"\"\n",
+ "     rgb_image = Image.open(img).convert('RGB').resize(size, warp_interp)\n",
+ "     channels_first = torch.from_numpy(np.array(rgb_image)).permute(2,0,1)\n",
+ "     return channels_first.float()[None,...].cuda()\n",
+ "\n",
+ " def get_flow(frame1, frame2, model, iters=20, half=True):\n",
+ "     \"\"\"Run `model` on a frame pair and return the flow as an [H,W,C] numpy array.\n",
+ "\n",
+ "     Both frames are padded with InputPadder before the model call and the\n",
+ "     padding is cropped off the predicted flow again, so the result matches the\n",
+ "     input frame size. `model(frame1, frame2)` must return a pair whose second\n",
+ "     item is the batched flow tensor; only the first batch item is returned.\n",
+ "\n",
+ "     `iters` and `half` are kept for call compatibility and are not used.\n",
+ "     \"\"\"\n",
+ "     padder = InputPadder(frame1.shape)\n",
+ "     frame1, frame2 = padder.pad(frame1, frame2)\n",
+ "     _, flow12 = model(frame1, frame2)\n",
+ "     # Remove the padding added above (a no-op when the frame sides are multiples of 8).\n",
+ "     flow12 = padder.unpad(flow12)\n",
+ "     return flow12[0].permute(1, 2, 0).detach().cpu().numpy()\n",
+ "\n",
+ " def warp_flow(img, flow, mul=1.):\n",
+ "     \"\"\"Remap `img` through `flow` using cv2.remap with Lanczos interpolation.\n",
+ "\n",
+ "     A copy of `flow` is turned into a sampling map by adding the column index\n",
+ "     to channel 0 and the row index to channel 1; the caller's array is not\n",
+ "     modified.\n",
+ "     \"\"\"\n",
+ "     height, width = flow.shape[:2]\n",
+ "     sample_map = flow.copy()\n",
+ "     sample_map[:, :, 0] += np.arange(width)\n",
+ "     sample_map[:, :, 1] += np.arange(height)[:, np.newaxis]\n",
+ "     # NOTE(review): `mul` scales the whole sampling map (pixel grid + offset), not\n",
+ "     # only the offset - confirm this is intended for mul != 1.\n",
+ "     sample_map *= mul\n",
+ "     return cv2.remap(img, sample_map, None, cv2.INTER_LANCZOS4)\n",
+ "\n",
+ " def makeEven(_x):\n",
+ "     \"\"\"Return `_x` unchanged when it is even, otherwise `_x + 1`.\"\"\"\n",
+ "     if _x % 2 == 0:\n",
+ "         return _x\n",
+ "     return _x+1\n",
+ "\n",
+ " def fit(img,maxsize=512):\n",
+ "     \"\"\"Shrink `img` so that its longer side does not exceed `maxsize`.\n",
+ "\n",
+ "     Images that already fit are returned as is. Otherwise both sides are scaled\n",
+ "     by the same ratio, truncated, bumped to even numbers with makeEven, and the\n",
+ "     image is resized with the notebook-level `warp_interp` filter.\n",
+ "     \"\"\"\n",
+ "     longest_side = max(*img.size)\n",
+ "     if not longest_side > maxsize:\n",
+ "         return img\n",
+ "     scale = maxsize/longest_side\n",
+ "     width, height = img.size\n",
+ "     target_size = (makeEven(int(width*scale)),makeEven(int(height*scale)))\n",
+ "     return img.resize(target_size, warp_interp)\n",
+ "\n",
+ "\n",
+ " def warp(frame1, frame2, flo_path, blend=0.5, weights_path=None, forward_clip=0.,\n",
+ " pad_pct=0.1, padding_mode='reflect', inpaint_blend=0., video_mode=False, warp_mul=1.):\n",
+ " # Warp frame1 along a flow field and blend the result with frame2; returns a PIL image.\n",
+ " # flo_path: path to a .npy flow file, or an already loaded flow array.\n",
+ " # blend: share of the warped frame in the output (0 keeps frame2 only).\n",
+ " # weights_path: optional consistency map (read via load_cc) limiting where the warped frame is used.\n",
+ " # forward_clip: lower bound applied to the consistency weights.\n",
+ " # pad_pct: padding added around flow and frame1, as a fraction of the largest entry of the flow shape.\n",
+ " # padding_mode: np.pad mode used for frame1 (the flow itself is zero-padded).\n",
+ " # video_mode: forces warp_mul to 1 and skips color matching and brightness adjustment.\n",
+ " # inpaint_blend is not referenced in this function.\n",
+ " # Reads module-level settings: consistency_blur, consistency_dilate, match_color_strength,\n",
+ " # use_patchmatch_inpaiting, warp_mode, is_colab, enable_adjust_brightness, warp_interp.\n",
+ " printf('blend warp', blend)\n",
+ "\n",
+ " if isinstance(flo_path, str):\n",
+ " flow21 = np.load(flo_path)\n",
+ " else: flow21 = flo_path\n",
+ " # print('loaded flow from ', flo_path, ' witch shape ', flow21.shape)\n",
+ " pad = int(max(flow21.shape)*pad_pct)\n",
+ " flow21 = np.pad(flow21, pad_width=((pad,pad),(pad,pad),(0,0)),mode='constant')\n",
+ " # print('frame1.size, frame2.size, padded flow21.shape')\n",
+ " # print(frame1.size, frame2.size, flow21.shape)\n",
+ "\n",
+ "\n",
+ " frame1pil = np.array(frame1.convert('RGB'))#.resize((flow21.shape[1]-pad*2,flow21.shape[0]-pad*2),warp_interp))\n",
+ " frame1pil = np.pad(frame1pil, pad_width=((pad,pad),(pad,pad),(0,0)),mode=padding_mode)\n",
+ " if video_mode:\n",
+ " warp_mul=1.\n",
+ " frame1_warped21 = warp_flow(frame1pil, flow21, warp_mul)\n",
+ " # Crop the padding back off so the warped frame matches the unpadded flow size.\n",
+ " frame1_warped21 = frame1_warped21[pad:frame1_warped21.shape[0]-pad,pad:frame1_warped21.shape[1]-pad,:]\n",
+ "\n",
+ " # frame2 is resized to the unpadded flow resolution (width, height).\n",
+ " frame2pil = np.array(frame2.convert('RGB').resize((flow21.shape[1]-pad*2,flow21.shape[0]-pad*2),warp_interp))\n",
+ " # if not video_mode: frame2pil = match_color(frame1_warped21, frame2pil, opacity=match_color_strength)\n",
+ " if weights_path:\n",
+ " forward_weights = load_cc(weights_path, blur=consistency_blur, dilate=consistency_dilate)\n",
+ " # print('forward_weights')\n",
+ " # print(forward_weights.shape)\n",
+ " if not video_mode and match_color_strength>0.: frame2pil = match_color(frame1_warped21, frame2pil, opacity=match_color_strength)\n",
+ "\n",
+ " forward_weights = forward_weights.clip(forward_clip,1.)\n",
+ " # PatchMatch inpainting itself is commented out below; this branch only prints a notice.\n",
+ " if use_patchmatch_inpaiting>0 and warp_mode == 'use_image':\n",
+ " if not is_colab: print('Patchmatch only working on colab/linux')\n",
+ " else: print('PatchMatch disabled.')\n",
+ " # if not video_mode and is_colab:\n",
+ " # print('patchmatching')\n",
+ " # # print(np.array(blended_w).shape, forward_weights[...,0][...,None].shape )\n",
+ " # patchmatch_mask = (forward_weights[...,0][...,None]*-255.+255).astype('uint8')\n",
+ " # frame2pil = np.array(frame2pil)*(1-use_patchmatch_inpaiting)+use_patchmatch_inpaiting*np.array(patch_match.inpaint(frame1_warped21, patchmatch_mask, patch_size=5))\n",
+ " # # blended_w = Image.fromarray(blended_w)\n",
+ " # Where the weights are low, frame2 replaces the warped frame before the final blend.\n",
+ " blended_w = frame2pil*(1-blend) + blend*(frame1_warped21*forward_weights+frame2pil*(1-forward_weights))\n",
+ " else:\n",
+ " if not video_mode and match_color_strength>0.: frame2pil = match_color(frame1_warped21, frame2pil, opacity=match_color_strength)\n",
+ " blended_w = frame2pil*(1-blend) + frame1_warped21*(blend)\n",
+ "\n",
+ "\n",
+ "\n",
+ " blended_w = Image.fromarray(blended_w.round().astype('uint8'))\n",
+ " # if use_patchmatch_inpaiting and warp_mode == 'use_image':\n",
+ " # print('patchmatching')\n",
+ " # print(np.array(blended_w).shape, forward_weights[...,0][...,None].shape )\n",
+ " # patchmatch_mask = (forward_weights[...,0][...,None]*-255.+255).astype('uint8')\n",
+ " # blended_w = patch_match.inpaint(blended_w, patchmatch_mask, patch_size=5)\n",
+ " # blended_w = Image.fromarray(blended_w)\n",
+ " if not video_mode:\n",
+ " if enable_adjust_brightness: blended_w = adjust_brightness(blended_w)\n",
+ " return blended_w\n",
+ "\n",
+ " def warp_lat(frame1, frame2, flo_path, blend=0.5, weights_path=None, forward_clip=0.,\n",
+ " pad_pct=0.1, padding_mode='reflect', inpaint_blend=0., video_mode=False, warp_mul=1.):\n",
+ " # Latent-space counterpart of warp(): warps frame1 along the flow and blends it with the\n",
+ " # first-stage encoding of frame2; returns a float torch tensor with a leading batch axis.\n",
+ " # frame1: tensor whose first item is read channels-first (converted with .cpu().numpy()[0]).\n",
+ " # frame2: PIL image, resized and encoded with sd_model.\n",
+ " # flo_path: path to a .npy flow file (always loaded with np.load here).\n",
+ " # inpaint_blend is not referenced in this function.\n",
+ " # warp_downscaled is hard-coded to True, so the `not warp_downscaled` branches are inactive.\n",
+ " warp_downscaled = True\n",
+ " flow21 = np.load(flo_path)\n",
+ " # NOTE: pad is derived from the full-resolution flow shape, before the 1/8 downscale below.\n",
+ " pad = int(max(flow21.shape)*pad_pct)\n",
+ " if warp_downscaled:\n",
+ " # Resize the flow to 1/8 resolution and divide its values by 8 as well.\n",
+ " flow21 = flow21.transpose(2,0,1)[None,...]\n",
+ " flow21 = torch.nn.functional.interpolate(torch.from_numpy(flow21).float(), scale_factor = 1/8, mode = 'bilinear')\n",
+ " flow21 = flow21.numpy()[0].transpose(1,2,0)/8\n",
+ " # flow21 = flow21[::8,::8,:]/8\n",
+ "\n",
+ " flow21 = np.pad(flow21, pad_width=((pad,pad),(pad,pad),(0,0)),mode='constant')\n",
+ "\n",
+ " if not warp_downscaled:\n",
+ " frame1 = torch.nn.functional.interpolate(frame1, scale_factor = 8)\n",
+ " frame1pil = frame1.cpu().numpy()[0].transpose(1,2,0)\n",
+ "\n",
+ " frame1pil = np.pad(frame1pil, pad_width=((pad,pad),(pad,pad),(0,0)),mode=padding_mode)\n",
+ " if video_mode:\n",
+ " warp_mul=1.\n",
+ " frame1_warped21 = warp_flow(frame1pil, flow21, warp_mul)\n",
+ " # Crop the padding back off.\n",
+ " frame1_warped21 = frame1_warped21[pad:frame1_warped21.shape[0]-pad,pad:frame1_warped21.shape[1]-pad,:]\n",
+ " if not warp_downscaled:\n",
+ " frame2pil = frame2.convert('RGB').resize((flow21.shape[1]-pad*2,flow21.shape[0]-pad*2),warp_interp)\n",
+ " else:\n",
+ " # Resize frame2 to 8x the unpadded (downscaled) flow size before encoding.\n",
+ " frame2pil = frame2.convert('RGB').resize(((flow21.shape[1]-pad*2)*8,(flow21.shape[0]-pad*2)*8),warp_interp)\n",
+ " frame2pil = np.array(frame2pil)\n",
+ " # Scale pixels to [-1, 1], channels-first with a batch axis, then encode with the SD first stage.\n",
+ " frame2pil = (frame2pil/255.)[None,...].transpose(0, 3, 1, 2)\n",
+ " frame2pil = 2*torch.from_numpy(frame2pil).float().cuda()-1.\n",
+ " frame2pil = sd_model.get_first_stage_encoding(sd_model.encode_first_stage(frame2pil))\n",
+ " if not warp_downscaled: frame2pil = torch.nn.functional.interpolate(frame2pil, scale_factor = 8)\n",
+ " frame2pil = frame2pil.cpu().numpy()[0].transpose(1,2,0)\n",
+ " # if not video_mode: frame2pil = match_color(frame1_warped21, frame2pil, opacity=match_color_strength)\n",
+ " if weights_path:\n",
+ " forward_weights = load_cc(weights_path, blur=consistency_blur, dilate=consistency_dilate)\n",
+ " print(forward_weights[...,:1].shape, 'forward_weights.shape')\n",
+ " # Repeat the first weight channel four times along the last axis.\n",
+ " forward_weights = np.repeat(forward_weights[...,:1],4, axis=-1)\n",
+ " # print('forward_weights')\n",
+ " # print(forward_weights.shape)\n",
+ " print('frame2pil.shape, frame1_warped21.shape, flow21.shape', frame2pil.shape, frame1_warped21.shape, flow21.shape)\n",
+ " forward_weights = forward_weights.clip(forward_clip,1.)\n",
+ " # Take every 8th row/column of the weights to match the downscaled flow.\n",
+ " if warp_downscaled: forward_weights = forward_weights[::8,::8,:]; print(forward_weights.shape, 'forward_weights.shape')\n",
+ " blended_w = frame2pil*(1-blend) + blend*(frame1_warped21*forward_weights+frame2pil*(1-forward_weights))\n",
+ " else:\n",
+ " if not video_mode and not warp_mode == 'use_latent' and match_color_strength>0.: frame2pil = match_color(frame1_warped21, frame2pil, opacity=match_color_strength)\n",
+ " blended_w = frame2pil*(1-blend) + frame1_warped21*(blend)\n",
+ " # Back to channels-first with a batch axis.\n",
+ " blended_w = blended_w.transpose(2,0,1)[None,...]\n",
+ " blended_w = torch.from_numpy(blended_w).float()\n",
+ " if not warp_downscaled:\n",
+ " # blended_w = blended_w[::8,::8,:]\n",
+ " blended_w = torch.nn.functional.interpolate(blended_w, scale_factor = 1/8, mode='bilinear')\n",
+ "\n",
+ "\n",
+ " return blended_w# torch.nn.functional.interpolate(torch.from_numpy(blended_w), scale_factor = 1/8)\n",
+ "\n",
+ "\n",
+ " # Flow is computed from the dedicated flow video's frames when flow_video_init_path is set,\n",
+ " # otherwise from the main video frames.\n",
+ " in_path = videoFramesFolder if not flow_video_init_path else flowVideoFramesFolder\n",
+ " flo_folder = in_path+'_out_flo_fwd'\n",
+ "\n",
+ " # Sibling folders derived from the frames folder; flo_fwd_folder starts out equal to flo_folder.\n",
+ " temp_flo = in_path+'_temp_flo'\n",
+ " flo_fwd_folder = in_path+'_out_flo_fwd'\n",
+ " flo_bck_folder = in_path+'_out_flo_bck'\n",
+ "\n",
+ " # IPython magic: change the working directory to root_path.\n",
+ " %cd {root_path}\n",
+ "\n",
+ "# (c) Alex Spirin 2023\n",
+ "\n",
+ "import cv2\n",
+ "\n",
+ "def extract_occlusion_mask(flow, threshold=10):\n",
+ "  \"\"\"\n",
+ "  Mark pixels whose flow magnitude exceeds a threshold.\n",
+ "\n",
+ "  Parameters:\n",
+ "  flow: Batched torch-style flow tensor. Only the first item is used; it is read as\n",
+ "    channels-first, with the two displacement components in channels 0 and 1.\n",
+ "  threshold (int): The threshold value for the magnitude of the motion vector.\n",
+ "\n",
+ "  Returns:\n",
+ "  tuple: (occlusion_mask, mag) - a uint8 array that is 1 where mag > threshold and 0\n",
+ "    elsewhere, and the per-pixel flow magnitude it was derived from.\n",
+ "  \"\"\"\n",
+ "  # Work on a detached channels-last numpy copy so the caller's tensor is untouched.\n",
+ "  flow = flow.clone()[0].permute(1, 2, 0).detach().cpu().numpy()\n",
+ "\n",
+ "  # Compute the magnitude of the motion vector.\n",
+ "  mag, _ = cv2.cartToPolar(flow[..., 0], flow[..., 1])\n",
+ "\n",
+ "  # Threshold the magnitude to identify occlusions.\n",
+ "  occlusion_mask = (mag > threshold).astype(np.uint8)\n",
+ "\n",
+ "  return occlusion_mask, mag\n",
+ "\n",
+ "import cv2\n",
+ "import numpy as np\n",
+ "\n",
+ "def edge_detector(image, threshold=0.5, edge_width=1):\n",
+ "  \"\"\"\n",
+ "  Build a binary edge map from the Sobel gradient magnitude of an image.\n",
+ "\n",
+ "  Parameters:\n",
+ "  image (numpy.ndarray): Colour image; converted to grayscale with cv2.COLOR_BGR2GRAY.\n",
+ "  threshold (float): Cut-off applied to the min-max normalised gradient magnitude.\n",
+ "  edge_width (int): Passed to cv2.Sobel as the kernel size (ksize).\n",
+ "\n",
+ "  Returns:\n",
+ "  numpy.ndarray: uint8 image that is 255 on edges and 0 elsewhere.\n",
+ "  \"\"\"\n",
+ "  grayscale = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)\n",
+ "\n",
+ "  # Horizontal and vertical Sobel derivatives, kept in float64.\n",
+ "  grad_x = cv2.Sobel(grayscale, cv2.CV_64F, 1, 0, ksize=edge_width)\n",
+ "  grad_y = cv2.Sobel(grayscale, cv2.CV_64F, 0, 1, ksize=edge_width)\n",
+ "\n",
+ "  # Gradient magnitude, rescaled to the range [0, 1].\n",
+ "  magnitude = np.sqrt(grad_x ** 2 + grad_y ** 2)\n",
+ "  magnitude = cv2.normalize(magnitude, None, 0, 1, cv2.NORM_MINMAX)\n",
+ "\n",
+ "  return (magnitude > threshold).astype(np.uint8) * 255\n",
+ "\n",
+ "def get_unreliable(flow):\n",
+ "  \"\"\"Find target pixels that no source pixel lands on when following `flow`.\n",
+ "\n",
+ "  Returns (unfilled, in_bounds): `unfilled` is a boolean (H, W, 3) array that is True where\n",
+ "  nothing landed, and `in_bounds` is a boolean (H, W) array that is True for source pixels\n",
+ "  whose destination stays inside the frame. Used to remove trails and ghosting.\n",
+ "  \"\"\"\n",
+ "  height, width = flow.shape[:2]\n",
+ "  grid_x, grid_y = np.meshgrid(np.arange(width), np.arange(height))\n",
+ "\n",
+ "  # Destination coordinates of every source pixel.\n",
+ "  dest_x = grid_x + flow[..., 0]\n",
+ "  dest_y = grid_y + flow[..., 1]\n",
+ "  in_bounds = (dest_x >= 0) & (dest_x < width) & (dest_y >= 0) & (dest_y < height)\n",
+ "\n",
+ "  # Start from -1 everywhere and stamp 255 wherever a source pixel arrives.\n",
+ "  landing = np.zeros((flow.shape[0], flow.shape[1], 3))*1.-1\n",
+ "  rows = dest_y[in_bounds].astype(np.int32)\n",
+ "  cols = dest_x[in_bounds].astype(np.int32)\n",
+ "  landing[rows, cols] = 255\n",
+ "\n",
+ "  # Whatever is still -1 received no pixel.\n",
+ "  return landing==-1, in_bounds\n",
+ "\n",
+ "from scipy.ndimage import binary_fill_holes\n",
+ "from skimage.morphology import disk, binary_erosion, binary_dilation, binary_opening, binary_closing\n",
+ "\n",
+ "import cv2\n",
+ "\n",
+ "def remove_small_holes(mask, min_size=50):\n",
+ "  \"\"\"Return a copy of `mask` in which every contour with an area below `min_size` is painted over with 255.\"\"\"\n",
+ "  filled = mask.copy()\n",
+ "\n",
+ "  # Two-level (RETR_CCOMP) contour hierarchy of the binary image.\n",
+ "  contours, hierarchy = cv2.findContours(filled, cv2.RETR_CCOMP, cv2.CHAIN_APPROX_SIMPLE)\n",
+ "\n",
+ "  for contour in contours:\n",
+ "    if cv2.contourArea(contour) >= min_size:\n",
+ "      continue\n",
+ "    # Small region: fill it (thickness -1) with white.\n",
+ "    cv2.drawContours(filled, [contour], 0, 255, -1, cv2.LINE_AA, hierarchy, 0)\n",
+ "\n",
+ "  return filled\n",
+ "\n",
+ "\n",
+ "def filter_unreliable(mask, dilation=1):\n",
+ "  \"\"\"Clean up the first channel of a boolean `mask`: remove small regions, erode with disk(1), then dilate with disk(dilation).\"\"\"\n",
+ "  inverted = (1-mask[...,0].astype('uint8'))*255\n",
+ "  cleaned = 255-remove_small_holes(inverted, 200)\n",
+ "  # img = binary_fill_holes(img)\n",
+ "  eroded = binary_erosion(cleaned, disk(1))\n",
+ "  return binary_dilation(eroded, disk(dilation))\n",
+ "from torchvision.utils import flow_to_image as flow_to_image_torch\n",
+ "def make_cc_map(predicted_flows, predicted_flows_bwd, dilation=1, edge_width=11):\n",
+ "  \"\"\"Build a three-channel consistency map from a forward/backward flow pair.\n",
+ "\n",
+ "  Channel 0: 255 minus the filtered 'no source pixel' mask of the forward flow.\n",
+ "  Channel 1: 255 where the backward flow stays inside the frame.\n",
+ "  Channel 2: 255 minus the edges detected in the backward flow visualisation.\n",
+ "  \"\"\"\n",
+ "  bwd_flow_img = flow_to_image(predicted_flows_bwd)\n",
+ "  flow_edges = edge_detector(bwd_flow_img.astype('uint8'), threshold=0.1, edge_width=edge_width)\n",
+ "  unfilled, _ = get_unreliable(predicted_flows)\n",
+ "  _, bwd_in_bounds = get_unreliable(predicted_flows_bwd)\n",
+ "\n",
+ "  cc_map = np.ones_like(unfilled)*255\n",
+ "  cc_map[...,0] = 255-(filter_unreliable(unfilled, dilation)*255)\n",
+ "  cc_map[...,1] = (bwd_in_bounds*255)\n",
+ "  cc_map[...,2] = 255-flow_edges\n",
+ "  return cc_map\n",
+ "\n",
+ "\n",
+ "def hstack(images):\n",
+ "  \"\"\"Paste images side by side onto one RGB canvas and return it.\n",
+ "\n",
+ "  `images` is a sequence of PIL images or of file paths (decided by its first element).\n",
+ "  A 3px black outline is drawn onto every image first, modifying the given image objects.\n",
+ "  \"\"\"\n",
+ "  if isinstance(images[0], str):\n",
+ "    images = [Image.open(path).convert('RGB') for path in images]\n",
+ "  widths, heights = zip(*(im.size for im in images))\n",
+ "  for im in images:\n",
+ "    ImageDraw.Draw(im).rectangle(((0, 0), (im.size[0], im.size[1])), outline=\"black\", width=3)\n",
+ "\n",
+ "  canvas = Image.new('RGB', (sum(widths), max(heights)))\n",
+ "\n",
+ "  cursor_x = 0\n",
+ "  for im in images:\n",
+ "    canvas.paste(im, (cursor_x,0))\n",
+ "    cursor_x += im.size[0]\n",
+ "  return canvas\n",
+ "\n",
+ "import locale\n",
+ "def getpreferredencoding(do_setlocale = True):\n",
+ "  \"\"\"Stand-in for locale.getpreferredencoding that always reports UTF-8; `do_setlocale` is ignored.\"\"\"\n",
+ "  return \"UTF-8\"\n",
+ "\n",
+ "# On Colab, route locale.getpreferredencoding to the UTF-8 stand-in.\n",
+ "if is_colab:\n",
+ "  locale.getpreferredencoding = getpreferredencoding\n",
+ "\n",
+ "def vstack(images):\n",
+ "  \"\"\"Paste images top to bottom onto one RGB canvas and return it.\n",
+ "\n",
+ "  `images` holds PIL images or file paths (decided by its first element).\n",
+ "  \"\"\"\n",
+ "  if isinstance(next(iter(images)), str):\n",
+ "    images = [Image.open(path).convert('RGB') for path in images]\n",
+ "  widths, heights = zip(*(im.size for im in images))\n",
+ "\n",
+ "  canvas = Image.new('RGB', (max(widths), sum(heights)))\n",
+ "\n",
+ "  cursor_y = 0\n",
+ "  for im in images:\n",
+ "    canvas.paste(im, (0, cursor_y))\n",
+ "    cursor_y += im.size[1]\n",
+ "  return canvas\n",
+ "\n",
+ "# Colab only: comment out the (TiffTags.IFD, \"L\", \"long\") entry in the PIL TiffImagePlugin.py\n",
+ "# of every python3.7-3.10 install that is present. Each file is patched from its own contents only.\n",
+ "if is_colab:\n",
+ "  for i in [7,8,9,10]:\n",
+ "    tiff_plugin_path = f'/usr/local/lib/python3.{i}/dist-packages/PIL/TiffImagePlugin.py'\n",
+ "    try:\n",
+ "      with open(tiff_plugin_path, 'r') as file :\n",
+ "        filedata = file.read()\n",
+ "      # Skip files that are already patched so re-running the cell does not stack up '#' characters.\n",
+ "      if '#(TiffTags.IFD, \"L\", \"long\"),' not in filedata:\n",
+ "        filedata = filedata.replace('(TiffTags.IFD, \"L\", \"long\"),', '#(TiffTags.IFD, \"L\", \"long\"),')\n",
+ "        with open(tiff_plugin_path, 'w') as file :\n",
+ "          file.write(filedata)\n",
+ "    except (OSError, UnicodeError):\n",
+ "      # Best effort: this interpreter version is not installed, or its file is not readable/writable.\n",
+ "      pass\n",
+ "      # print(f'Error writing /usr/local/lib/python3.{i}/dist-packages/PIL/TiffImagePlugin.py')\n",
+ "\n",
+ "class flowDataset():\n",
+ "  \"\"\"Consecutive frame pairs from a folder, prepared as optical-flow model input.\n",
+ "\n",
+ "  Item i holds frames i and i+1, each resized to the module-level `width_height`,\n",
+ "  padded with InputPadder and concatenated along the first axis.\n",
+ "  \"\"\"\n",
+ "  def __init__(self, in_path, half=True, normalize=False):\n",
+ "    \"\"\"Collect the sorted frame paths found under `in_path` (pattern '*.*').\n",
+ "\n",
+ "    in_path: folder with the extracted video frames.\n",
+ "    half: stored on the instance; not used by the dataset itself.\n",
+ "    normalize: when true, __getitem__ rescales pixel values from [0, 255] to [-1, 1].\n",
+ "    Raises AssertionError when fewer than two frames are found (one pair needs two frames).\n",
+ "    \"\"\"\n",
+ "    frames = sorted(glob(in_path+'/*.*'))\n",
+ "    assert len(frames)>=2, f'WARNING!\\nCannot create flow maps: Found {len(frames)} frames extracted from your video input.\\nPlease check your video path.'\n",
+ "    self.frames = frames\n",
+ "    # Keep the options on the instance; __getitem__ previously read an unrelated global `normalize`.\n",
+ "    self.half = half\n",
+ "    self.normalize = normalize\n",
+ "\n",
+ "  def __len__(self):\n",
+ "    \"\"\"Number of consecutive pairs, i.e. one less than the number of frames.\"\"\"\n",
+ "    return len(self.frames)-1\n",
+ "\n",
+ "  def load_img(self, img, size):\n",
+ "    \"\"\"Open `img` as RGB, resize it to `size` and return a float tensor with a leading batch axis.\"\"\"\n",
+ "    img = Image.open(img).convert('RGB').resize(size, warp_interp)\n",
+ "    return torch.from_numpy(np.array(img)).permute(2,0,1).float()[None,...]\n",
+ "\n",
+ "  def __getitem__(self, i):\n",
+ "    \"\"\"Return frames i and i+1 as one padded batch, normalised when `self.normalize` is set.\"\"\"\n",
+ "    frame1, frame2 = self.frames[i], self.frames[i+1]\n",
+ "    frame1 = self.load_img(frame1, width_height)\n",
+ "    frame2 = self.load_img(frame2, width_height)\n",
+ "    padder = InputPadder(frame1.shape)\n",
+ "    frame1, frame2 = padder.pad(frame1, frame2)\n",
+ "    batch = torch.cat([frame1, frame2])\n",
+ "    if self.normalize:\n",
+ "      batch = 2 * (batch / 255.0) - 1.0\n",
+ "    return batch\n",
+ "\n",
+ "from torch.utils.data import DataLoader\n",
+ "\n",
+ "def save_preview(flow21, out_flow21_fn):\n",
+ "  \"\"\"Render `flow21` with flow_to_image and save it to `out_flow21_fn` (quality 90).\n",
+ "\n",
+ "  Failures are reported on stdout instead of being raised.\n",
+ "  \"\"\"\n",
+ "  try:\n",
+ "    Image.fromarray(flow_to_image(flow21)).save(out_flow21_fn, quality=90)\n",
+ "  except Exception as e:\n",
+ "    # Best effort: a missing preview must not abort the run, but keep the reason visible.\n",
+ "    print('Error saving flow preview for frame ', out_flow21_fn, repr(e))\n",
+ "\n",
+ "#copyright Alex Spirin @ 2022\n",
+ "def blended_roll(img_copy, shift, axis):\n",
+ "  \"\"\"np.roll that also accepts fractional shifts.\n",
+ "\n",
+ "  Whole-number shifts are rolled directly. Otherwise the array is rolled by floor(shift)\n",
+ "  and by ceil(shift), and the two results are mixed linearly by the fractional part.\n",
+ "  \"\"\"\n",
+ "  whole = int(shift)\n",
+ "  if whole == shift:\n",
+ "    return np.roll(img_copy, whole, axis=axis)\n",
+ "\n",
+ "  lower, upper = math.floor(shift), math.ceil(shift)\n",
+ "  rolled_lower = img_copy if lower == 0 else np.roll(img_copy, lower, axis=axis)\n",
+ "  rolled_upper = np.roll(img_copy, upper, axis=axis)\n",
+ "  # Weight of the lower roll: the distance from `shift` up to the next whole number.\n",
+ "  lower_weight = upper-shift\n",
+ "  return rolled_lower*lower_weight + rolled_upper*(1-lower_weight)\n",
+ "\n",
+ "#copyright Alex Spirin @ 2022\n",
+ "def move_cluster(img,i,res2, center, mode='blended_roll'):\n",
+ "  \"\"\"Roll a copy of `img` by the motion of cluster `i` and return it with that cluster's mask.\n",
+ "\n",
+ "  `center[i]` is the cluster motion: its first component is rolled along axis 1 and its\n",
+ "  second along axis 0. The mask is 1 where the first channel of `res2` equals that motion;\n",
+ "  it has a trailing singleton axis. `mode` selects fractional ('blended_roll') or\n",
+ "  truncated-integer ('int_roll') rolling; any other value leaves the copy unshifted.\n",
+ "  \"\"\"\n",
+ "  motion = center[i]\n",
+ "  along_axis1, along_axis0 = motion\n",
+ "  cluster_mask = np.where(res2==motion, 1, 0)[...,0][...,None]\n",
+ "  shifted = img.copy()\n",
+ "  if mode=='blended_roll':\n",
+ "    shifted = blended_roll(shifted, along_axis0, 0)\n",
+ "    shifted = blended_roll(shifted, along_axis1, 1)\n",
+ "  elif mode=='int_roll':\n",
+ "    shifted = np.roll(shifted, int(along_axis0), axis=0)\n",
+ "    shifted = np.roll(shifted, int(along_axis1), axis=1)\n",
+ "  return shifted, cluster_mask\n",
+ "\n",
+ "import cv2\n",
+ "\n",
+ "\n",
+ "def get_k(flow, K):\n",
+ "  \"\"\"Cluster the 2-vectors of `flow` into K groups with cv2.kmeans.\n",
+ "\n",
+ "  Returns (quantized, center): the flow with every vector replaced by its cluster centre\n",
+ "  (same shape as `flow`), and the array of cluster centres.\n",
+ "  \"\"\"\n",
+ "  # One float32 row per flow vector, as cv2.kmeans expects.\n",
+ "  samples = np.float32(flow.reshape((-1,2)))\n",
+ "  # Stop after 10 iterations or once the centres move less than 1.0; 10 random restarts.\n",
+ "  stop_criteria = (cv2.TERM_CRITERIA_EPS + cv2.TERM_CRITERIA_MAX_ITER, 10, 1.0)\n",
+ "  _compactness, labels, center = cv2.kmeans(samples,K,None,stop_criteria,10,cv2.KMEANS_RANDOM_CENTERS)\n",
+ "  # Look up each vector's centre and restore the flow layout.\n",
+ "  quantized = center[labels.flatten()].reshape((flow.shape))\n",
+ "  return quantized, center\n",
+ "\n",
+ "def k_means_warp(flo, img, num_k):\n",
+ "  \"\"\"Warp `img` by splitting the flow into `num_k` motion clusters and rolling each cluster as a whole.\n",
+ "\n",
+ "  flo: loaded flow array whose last axis holds the two motion components.\n",
+ "  img: PIL image; converted to RGB and resized to the flow resolution.\n",
+ "  num_k: number of k-means clusters. A falsy value falls back to 8, the count that used to\n",
+ "    be hard-coded here regardless of the argument.\n",
+ "  Returns the warped frame as a PIL image.\n",
+ "  \"\"\"\n",
+ "  # flo = np.load(flo)\n",
+ "  num_k = int(num_k) if num_k else 8\n",
+ "  img = np.array((img).convert('RGB'))\n",
+ "\n",
+ "  # print(img.shape)\n",
+ "  res2, center = get_k(flo, num_k)\n",
+ "  # Apply clusters from the smallest to the largest mean absolute motion.\n",
+ "  center = sorted(list(center), key=lambda x: abs(x).mean())\n",
+ "\n",
+ "  img = cv2.resize(img, (res2.shape[:-1][::-1]))\n",
+ "  img_out = np.ones_like(img)*255.\n",
+ "\n",
+ "  for i in range(num_k):\n",
+ "    img_rolled, mask_i = move_cluster(img,i,res2,center)\n",
+ "    # Inside the cluster mask take the rolled image, elsewhere keep what is already there.\n",
+ "    img_out = img_out*(1-mask_i) + img_rolled*(mask_i)\n",
+ "\n",
+ "  # cv2_imshow(img_out)\n",
+ "  return Image.fromarray(img_out.astype('uint8'))\n",
+ "\n",
+ "def flow_batch(i, batch, pool):\n",
+ " with torch.cuda.amp.autocast():\n",
+ " batch = batch[0]\n",
+ " frame_1 = batch[0][None,...].cuda()\n",
+ " frame_2 = batch[1][None,...].cuda()\n",
+ " frame1 = ds.frames[i]\n",
+ " frame1 = frame1.replace('\\\\','/')\n",
+ " out_flow21_fn = f\"{flo_fwd_folder}/{frame1.split('/')[-1]}\"\n",
+ " if flow_lq: frame_1, frame_2 = frame_1, frame_2\n",
+ " if use_jit_raft:\n",
+ " _, flow21 = raft_model(frame_2, frame_1)\n",
+ " else:\n",
+ " flow21 = raft_model(frame_2, frame_1, num_flow_updates=num_flow_updates)[-1] #flow_bwd\n",
+ " mag = (flow21[:,0:1,...]**2 + flow21[:,1:,...]**2).sqrt()\n",
+ " mag_thresh = 0.5\n",
+ " #zero out flow values for non-moving frames below threshold to avoid noisy flow/cc maps\n",
+ " if mag.max()0) and not force_flow_generation: print(f'Skipping flow generation:\\nFound {len(flows)} existing flow files in current working folder: {flo_folder}.\\nIf you wish to generate new flow files, check force_flow_generation and run this cell again.')\n",
+ "\n",
+ " # Generate flow maps when none exist yet or when regeneration is forced.\n",
+ " if (len(flows)==0) or force_flow_generation:\n",
+ " # Frame-pair dataset; pixel values are normalised only for the non-JIT RAFT path.\n",
+ " # NOTE(review): flowDataset asserts on the frame count in its constructor, so the\n",
+ " # warning printed below may never be reached - confirm.\n",
+ " ds = flowDataset(in_path, normalize=not use_jit_raft)\n",
+ "\n",
+ " frames = sorted(glob(in_path+'/*.*'));\n",
+ " if len(frames)<2:\n",
+ " print(f'WARNING!\\nCannot create flow maps: Found {len(frames)} frames extracted from your video input.\\nPlease check your video path.')\n",
+ " if len(frames)>=2:\n",
+ " if __name__ == '__main__':\n",
+ "\n",
+ " dl = DataLoader(ds, num_workers=num_workers)\n",
+ " # Pick the RAFT implementation: a TorchScript file (half or fp32) or torchvision's raft_large.\n",
+ " if use_jit_raft:\n",
+ " if flow_lq:\n",
+ " raft_model = torch.jit.load(f'{root_dir}/WarpFusion/raft/raft_half.jit').eval()\n",
+ " # raft_model = torch.nn.DataParallel(RAFT(args2))\n",
+ " else: raft_model = torch.jit.load(f'{root_dir}/WarpFusion/raft/raft_fp32.jit').eval()\n",
+ " # raft_model.load_state_dict(torch.load(f'{root_path}/RAFT/models/raft-things.pth'))\n",
+ " # raft_model = raft_model.module.cuda().eval()\n",
+ " else:\n",
+ " # NOTE(review): raft_model is deleted at the end of this block; presumably it is reset\n",
+ " # to None earlier in the cell before this check runs again - confirm.\n",
+ " if raft_model is None or not compile_raft:\n",
+ " from torchvision.models.optical_flow import Raft_Large_Weights, Raft_Small_Weights\n",
+ " from torchvision.models.optical_flow import raft_large, raft_small\n",
+ " raft_weights = Raft_Large_Weights.C_T_SKHT_V1\n",
+ " raft_device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
+ "\n",
+ " raft_model = raft_large(weights=raft_weights, progress=False).to(raft_device)\n",
+ " # raft_model = raft_small(weights=Raft_Small_Weights.DEFAULT, progress=False).to(raft_device)\n",
+ " raft_model = raft_model.eval()\n",
+ " if gpu != 'T4' and compile_raft: raft_model = torch.compile(raft_model)\n",
+ " if flow_lq:\n",
+ " raft_model = raft_model.half()\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ " temp_flo = in_path+'_temp_flo'\n",
+ " # flo_fwd_folder = in_path+'_out_flo_fwd'\n",
+ " # Flow files go to a per-resolution subfolder, which is emptied of existing files first.\n",
+ " flo_fwd_folder = in_path+f'_out_flo_fwd/{side_x}_{side_y}/'\n",
+ " for f in pathlib.Path(f'{flo_fwd_folder}').glob('*.*'):\n",
+ " f.unlink()\n",
+ "\n",
+ " os.makedirs(flo_fwd_folder, exist_ok=True)\n",
+ " os.makedirs(temp_flo, exist_ok=True)\n",
+ " cc_path = f'{root_dir}/flow_tools/check_consistency.py'\n",
+ " # Run flow_batch over every frame pair without gradients; the pool is handed to flow_batch.\n",
+ " with torch.no_grad():\n",
+ " p = Pool(threads)\n",
+ " for i,batch in enumerate(tqdm(dl)):\n",
+ " flow_batch(i, batch, p)\n",
+ " p.close()\n",
+ " p.join()\n",
+ "\n",
+ " # Drop the model and loader objects and collect garbage.\n",
+ " del raft_model, p, dl, ds\n",
+ " gc.collect()\n",
+ " if is_colab: locale.getpreferredencoding = getpreferredencoding\n",
+ " # Legacy consistency check: runs check_consistency.py as a shell command on the saved flow files.\n",
+ " if check_consistency and use_legacy_cc:\n",
+ " fwd = f\"{flo_fwd_folder}/*jpg.npy\"\n",
+ " bwd = f\"{flo_fwd_folder}/*jpg_12.npy\"\n",
+ "\n",
+ " if reverse_cc_order:\n",
+ " #old version, may be incorrect\n",
+ " print('Doing bwd->fwd cc check')\n",
+ " !python \"{cc_path}\" --flow_fwd \"{fwd}\" --flow_bwd \"{bwd}\" --output \"{flo_fwd_folder}/\" --image_output --output_postfix=\"-21_cc\" --blur=0. --save_separate_channels --skip_numpy_output\n",
+ " else:\n",
+ " print('Doing fwd->bwd cc check')\n",
+ " !python \"{cc_path}\" --flow_fwd \"{bwd}\" --flow_bwd \"{fwd}\" --output \"{flo_fwd_folder}/\" --image_output --output_postfix=\"-21_cc\" --blur=0. --save_separate_channels --skip_numpy_output\n",
+ " # delete forward flow\n",
+ " # for f in pathlib.Path(flo_fwd_folder).glob('*jpg_12.npy'):\n",
+ " # f.unlink()\n",
+ "\n",
+ "# Build a preview grid for up to five frames: each column stacks the video frame and, when\n",
+ "# they can be opened, its alpha mask, consistency map and flow image.\n",
+ "flo_imgs = glob(flo_fwd_folder+'/*.jpg.jpg')[:5]\n",
+ "vframes = []\n",
+ "for flo_img in flo_imgs:\n",
+ "  hframes = []\n",
+ "  flo_img = flo_img.replace('\\\\','/')\n",
+ "  # The flow image is named after the frame plus one extra 4-character extension.\n",
+ "  frame_name = flo_img.split('/')[-1][:-4]\n",
+ "  frame = Image.open(videoFramesFolder + '/' + frame_name)\n",
+ "  hframes.append(frame)\n",
+ "  try:\n",
+ "    alpha = Image.open(videoFramesAlpha + '/' + frame_name).resize(frame.size)\n",
+ "    hframes.append(alpha)\n",
+ "  except Exception:\n",
+ "    pass  # alpha frames are optional\n",
+ "  try:\n",
+ "    cc_img = Image.open(flo_img[:-4]+'-21_cc.jpg').convert('L').resize(frame.size)\n",
+ "    hframes.append(cc_img)\n",
+ "  except Exception:\n",
+ "    pass  # consistency maps are optional\n",
+ "  try:\n",
+ "    flow_preview_img = Image.open(flo_img).resize(frame.size)\n",
+ "    hframes.append(flow_preview_img)\n",
+ "  except Exception:\n",
+ "    pass  # an unreadable flow image is skipped\n",
+ "  v_imgs = vstack(hframes)\n",
+ "  vframes.append(v_imgs)\n",
+ "# hstack() indexes its first element, so only call it when at least one column exists.\n",
+ "preview = hstack(vframes) if vframes else None\n",
+ "del vframes\n",
+ "# hframes only exists when the loop ran at least once.\n",
+ "if flo_imgs: del hframes\n",
+ "\n",
+ "executed_cells[cell_name] = True\n",
+ "if preview is None: print('No flow previews found in', flo_fwd_folder)\n",
+ "fit(preview, 1024) if preview is not None else None"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "UZAstRLK6bDz"
+ },
+ "source": [
+ "# Load up a Stable Diffusion model.\n",
+ "\n",
+ "Don't forget to place your checkpoint at /content/ and change the path accordingly.\n",
+ "\n",
+ "\n",
+ "You need to log on to https://huggingface.co and\n",
+ "\n",
+ "get checkpoints here -\n",
+ "https://huggingface.co/CompVis/stable-diffusion-v-1-4-original\n",
+ "\n",
+ "https://huggingface.co/CompVis/stable-diffusion-v-1-4-original/resolve/main/sd-v1-4.ckpt\n",
+ "or\n",
+ "https://huggingface.co/CompVis/stable-diffusion-v-1-4-original/resolve/main/sd-v1-4-full-ema.ckpt\n",
+ "\n",
+ "You can pick 1.2 or 1.3 as well, just be sure to grab the \"original\" flavor.\n",
+ "\n",
+ "For v2 go here:\n",
+ "https://huggingface.co/stabilityai/stable-diffusion-2-depth\n",
+ "https://huggingface.co/stabilityai/stable-diffusion-2-base\n",
+ "\n",
+ "Inpainting model: https://huggingface.co/runwayml/stable-diffusion-v1-5\n",
+ "\n",
+ "If you're getting black frames with SDXL, turn off tiled VAE, enable no_half_vae, or use this VAE - https://huggingface.co/madebyollin/sdxl-vae-fp16-fix/resolve/main/sdxl_vae.safetensors"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "cellView": "form",
+ "id": "T_ikoekNjpnS"
+ },
+ "outputs": [],
+ "source": [
+ "#@markdown specify path to your Stable Diffusion checkpoint (the \"original\" flavor)\n",
+ "#@title define SD + K functions, load model\n",
+ "cell_name = 'load_model'\n",
+ "check_execution(cell_name)\n",
+ "\n",
+ "from safetensors import safe_open\n",
+ "import argparse\n",
+ "import math,os,time\n",
+ "try:\n",
+ " os.chdir( f'{root_dir}/src/taming-transformers')\n",
+ " import taming\n",
+ " os.chdir( f'{root_dir}')\n",
+ " os.chdir( f'{root_dir}/k-diffusion')\n",
+ " import k_diffusion as K\n",
+ " os.chdir( f'{root_dir}')\n",
+ "except:\n",
+ " import taming\n",
+ " import k_diffusion as K\n",
+ "import wget\n",
+ "import accelerate\n",
+ "import torch\n",
+ "import torch.nn as nn\n",
+ "from tqdm.notebook import trange, tqdm\n",
+ "sys.path.append('./k-diffusion')\n",
+ "\n",
+ "\n",
+ "\n",
+ "from pytorch_lightning import seed_everything\n",
+ "from k_diffusion.sampling import sample_euler, sample_euler_ancestral, sample_heun, sample_dpm_2, sample_dpm_2_ancestral, sample_lms, sample_dpm_fast, sample_dpm_adaptive, sample_dpmpp_2s_ancestral, sample_dpmpp_sde, sample_dpmpp_2m\n",
+ "\n",
+ "from omegaconf import OmegaConf\n",
+ "from ldm.util import instantiate_from_config\n",
+ "\n",
+ "from torch import autocast\n",
+ "import numpy as np\n",
+ "\n",
+ "from einops import rearrange\n",
+ "from torchvision.utils import make_grid\n",
+ "from torchvision import transforms\n",
+ "\n",
+ "# Release models left over from a previous run of this cell before loading new ones.\n",
+ "# Each name is removed independently, so a missing one no longer keeps the others alive.\n",
+ "for _stale_name in ('sd_model', 'model_wrap_cfg', 'model_wrap'):\n",
+ "  globals().pop(_stale_name, None)\n",
+ "\n",
+ "\n",
+ "torch.cuda.empty_cache()\n",
+ "gc.collect()\n",
+ "\n",
+ "\n",
+ "# Download URLs keyed by model name.\n",
+ "model_urls = {\n",
+ " \"sd_v1_5\":\"https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/main/v1-5-pruned-emaonly.safetensors\",\n",
+ " \"dpt_hybrid-midas-501f0c75\":\"https://github.com/intel-isl/DPT/releases/download/1_0/dpt_hybrid-midas-501f0c75.pt\"\n",
+ "}\n",
+ "\n",
+ "# ControlNet download URLs keyed by control model name (control_sd15_*, control_sd21_*, control_sdxl_*).\n",
+ "# https://huggingface.co/lllyasviel/ControlNet-v1-1/tree/main\n",
+ "control_model_urls = {\n",
+ " \"control_sd15_canny\":\"https://huggingface.co/lllyasviel/ControlNet-v1-1/resolve/main/control_v11p_sd15_canny.pth\",\n",
+ " \"control_sd15_depth\":\"https://huggingface.co/lllyasviel/ControlNet-v1-1/resolve/main/control_v11f1p_sd15_depth.pth\",\n",
+ " \"control_sd15_softedge\":\"https://huggingface.co/lllyasviel/ControlNet-v1-1/resolve/main/control_v11p_sd15_softedge.pth\", # replaces hed, v11 uses softedge model here\n",
+ " \"control_sd15_mlsd\":\"https://huggingface.co/lllyasviel/ControlNet-v1-1/resolve/main/control_v11p_sd15_mlsd.pth\",\n",
+ " \"control_sd15_normalbae\":\"https://huggingface.co/lllyasviel/ControlNet-v1-1/resolve/main/control_v11p_sd15_normalbae.pth\",\n",
+ " \"control_sd15_openpose\":\"https://huggingface.co/lllyasviel/ControlNet-v1-1/resolve/main/control_v11p_sd15_openpose.pth\",\n",
+ " \"control_sd15_scribble\":\"https://huggingface.co/lllyasviel/ControlNet-v1-1/resolve/main/control_v11p_sd15_scribble.pth\",\n",
+ " \"control_sd15_seg\":\"https://huggingface.co/lllyasviel/ControlNet-v1-1/resolve/main/control_v11p_sd15_seg.pth\",\n",
+ " \"control_sd15_temporalnet\":\"https://huggingface.co/CiaraRowles/TemporalNet/resolve/main/diff_control_sd15_temporalnet_fp16.safetensors\",\n",
+ " \"control_sd15_face\":\"https://huggingface.co/CrucibleAI/ControlNetMediaPipeFace/resolve/main/control_v2p_sd15_mediapipe_face.safetensors\",\n",
+ " \"control_sd15_ip2p\":\"https://huggingface.co/lllyasviel/ControlNet-v1-1/resolve/main/control_v11e_sd15_ip2p.pth\",\n",
+ " \"control_sd15_inpaint\":\"https://huggingface.co/lllyasviel/ControlNet-v1-1/resolve/main/control_v11p_sd15_inpaint.pth\",\n",
+ " \"control_sd15_lineart\":\"https://huggingface.co/lllyasviel/ControlNet-v1-1/resolve/main/control_v11p_sd15_lineart.pth\",\n",
+ " \"control_sd15_lineart_anime\":\"https://huggingface.co/lllyasviel/ControlNet-v1-1/resolve/main/control_v11p_sd15s2_lineart_anime.pth\",\n",
+ " \"control_sd15_shuffle\":\"https://huggingface.co/lllyasviel/ControlNet-v1-1/resolve/main/control_v11e_sd15_shuffle.pth\",\n",
+ " \"control_sdxl_canny\":\"https://huggingface.co/diffusers/controlnet-canny-sdxl-1.0/resolve/main/diffusion_pytorch_model.fp16.safetensors\",\n",
+ " \"control_sdxl_depth\":\"https://huggingface.co/diffusers/controlnet-depth-sdxl-1.0/resolve/main/diffusion_pytorch_model.fp16.safetensors\",\n",
+ " \"control_sdxl_softedge\":\"https://huggingface.co/SargeZT/controlnet-sd-xl-1.0-softedge-dexined/resolve/main/controlnet-sd-xl-1.0-softedge-dexined.safetensors\",\n",
+ " \"control_sdxl_seg\":\"https://huggingface.co/SargeZT/sdxl-controlnet-seg/resolve/main/diffusion_pytorch_model.bin\",\n",
+ " \"control_sdxl_openpose\":\"https://huggingface.co/thibaud/controlnet-openpose-sdxl-1.0/resolve/main/OpenPoseXL2.safetensors\",\n",
+ " \"control_sdxl_lora_128_depth\":\"https://huggingface.co/stabilityai/control-lora/resolve/main/control-LoRAs-rank128/control-lora-depth-rank128.safetensors\",\n",
+ " \"control_sdxl_lora_256_depth\":\"https://huggingface.co/stabilityai/control-lora/resolve/main/control-LoRAs-rank256/control-lora-depth-rank256.safetensors\",\n",
+ " \"control_sdxl_lora_128_canny\":\"https://huggingface.co/stabilityai/control-lora/resolve/main/control-LoRAs-rank128/control-lora-canny-rank128.safetensors\",\n",
+ " \"control_sdxl_lora_256_canny\":\"https://huggingface.co/stabilityai/control-lora/resolve/main/control-LoRAs-rank256/control-lora-canny-rank256.safetensors\",\n",
+ " \"control_sdxl_lora_128_softedge\":\"https://huggingface.co/stabilityai/control-lora/resolve/main/control-LoRAs-rank128/control-lora-sketch-rank128-metadata.safetensors\",\n",
+ " \"control_sdxl_lora_256_softedge\":\"https://huggingface.co/stabilityai/control-lora/resolve/main/control-LoRAs-rank256/control-lora-sketch-rank256.safetensors\",\n",
+ " \"control_sd15_tile\":\"https://huggingface.co/lllyasviel/ControlNet-v1-1/resolve/main/control_v11f1e_sd15_tile.pth\",\n",
+ " \"control_sd15_qr\":\"https://huggingface.co/DionTimmer/controlnet_qrcode/resolve/main/control_v1p_sd15_qrcode.safetensors\",\n",
+ " \"control_sd21_qr\":\"https://huggingface.co/DionTimmer/controlnet_qrcode/resolve/main/control_v11p_sd21_qrcode.safetensors\",\n",
+ " \"control_sd21_depth\":\"https://huggingface.co/thibaud/controlnet-sd21/resolve/main/control_v11p_sd21_zoedepth.safetensors\",\n",
+ " \"control_sd21_scribble\":\"https://huggingface.co/thibaud/controlnet-sd21/resolve/main/control_v11p_sd21_scribble.safetensors\",\n",
+ " \"control_sd21_openpose\":\"https://huggingface.co/thibaud/controlnet-sd21/resolve/main/control_v11p_sd21_openposev2.safetensors\",\n",
+ " \"control_sd21_normalbae\":\"https://huggingface.co/thibaud/controlnet-sd21/resolve/main/control_v11p_sd21_normalbae.safetensors\",\n",
+ " \"control_sd21_lineart\":\"https://huggingface.co/thibaud/controlnet-sd21/resolve/main/control_v11p_sd21_lineart.safetensors\",\n",
+ " \"control_sd21_softedge\":\"https://huggingface.co/thibaud/controlnet-sd21/resolve/main/control_v11p_sd21_hed.safetensors\",\n",
+ " \"control_sd21_canny\":\"https://huggingface.co/thibaud/controlnet-sd21/resolve/main/control_v11p_sd21_canny.safetensors\",\n",
+ " \"control_sd21_seg\":\"https://huggingface.co/thibaud/controlnet-sd21/resolve/main/control_v11p_sd21_ade20k.safetensors\",\n",
+ " \"control_sdxl_temporalnet_v1\": \"https://huggingface.co/CiaraRowles/controlnet-temporalnet-sdxl-1.0/resolve/main/diffusion_pytorch_model.safetensors\",\n",
+ " \"control_sd15_inpaint_softedge\":\"https://huggingface.co/sxela/WarpControlnets/resolve/main/control_v01e_sd15_inpaint_softedge.pth\",\n",
+ " \"control_sd15_temporal_depth\":\"https://huggingface.co/sxela/WarpControlnets/resolve/main/control_v01e_sd15_temporal_depth.pth\"\n",
+ "}\n",
+ "control_model_filenames = {\n",
+ " \"control_sdxl_canny\":\"diffusers-controlnet-canny-sdxl-1.0.fp16.safetensors\",\n",
+ " \"control_sdxl_depth\":\"diffusers-controlnet-depth-sdxl-1.0.fp16.safetensors\",\n",
+ " \"control_sdxl_softedge\":\"SargeZT-controlnet-sd-xl-1.0-softedge-dexined.safetensors\",\n",
+ " \"control_sdxl_seg\":\"SargeZT-sdxl-controlnet-seg.bin\",\n",
+ " \"control_sdxl_openpose\":\"thibaud-OpenPoseXL2.safetensors\",\n",
+ " \"control_sdxl_lora_128_depth\":\"stability-control-lora-depth-rank128.safetensors\",\n",
+ " \"control_sdxl_lora_256_depth\":\"stability-control-lora-depth-rank256.safetensors\",\n",
+ " \"control_sdxl_lora_128_canny\":\"stability-control-lora-canny-rank128.safetensors\",\n",
+ " \"control_sdxl_lora_256_canny\":\"stability-control-lora-canny-rank256.safetensors\",\n",
+ " \"control_sdxl_lora_128_softedge\":\"stability-control-lora-sketch-rank128.safetensors\",\n",
+ " \"control_sdxl_lora_256_softedge\":\"stability-control-lora-sketch-rank256.safetensors\",\n",
+ " \"control_sdxl_temporalnet_v1\":\"CiaraRowles-temporalnet-sdxl-v1.safetensors\" #old-style cn with 3 input channels\n",
+ "}\n",
+ "\n",
+ "def model_to(model, device):\n",
+ " for param in model.state.values():\n",
+ " # Not sure there are any global tensors in the state dict\n",
+ " if isinstance(param, torch.Tensor):\n",
+ " param.data = param.data.to(device)\n",
+ " if param._grad is not None:\n",
+ " param._grad.data = param._grad.data.to(device)\n",
+ " elif isinstance(param, dict):\n",
+ " for subparam in param.values():\n",
+ " if isinstance(subparam, torch.Tensor):\n",
+ " subparam.data = subparam.data.to(device)\n",
+ " if subparam._grad is not None:\n",
+ " subparam._grad.data = subparam._grad.data.to(device)\n",
+ "\n",
+ "\n",
+    "# import wget\n",
+    "model_version = 'control_multi'#@param ['control_multi_v2','control_multi_v2_768','control_multi_sdxl','control_multi','sdxl_base','sdxl_refiner','v1','v1_inpainting','v1_instructpix2pix','v2_512', 'v2_768_v']\n",
+    "\n",
+    "# yaml config for every model_version that is loaded through config_path.\n",
+    "# control_multi / control_multi_v2 / control_multi_v2_768 are absent on purpose: their cldm_*.yaml\n",
+    "# is picked right before the checkpoint is loaded, so config_path is left untouched for them.\n",
+    "sd_config_dir = f\"{root_dir}/stablediffusion/configs/stable-diffusion\"\n",
+    "sgm_config_dir = f\"{root_dir}/generative-models/configs/inference\"\n",
+    "model_config_paths = {\n",
+    "    'v1': f\"{sd_config_dir}/v1-inference.yaml\",\n",
+    "    'v1_inpainting': f\"{sd_config_dir}/v1-inpainting-inference.yaml\",\n",
+    "    'v1_instructpix2pix': f\"{sd_config_dir}/v1_instruct_pix2pix.yaml\",\n",
+    "    'v2_512': f\"{sd_config_dir}/v2-inference.yaml\",\n",
+    "    'v2_768_v': f\"{sd_config_dir}/v2-inference-v.yaml\",\n",
+    "    'v2_depth': f\"{sd_config_dir}/v2-midas-inference.yaml\",\n",
+    "    'sdxl_base': f\"{sgm_config_dir}/sd_xl_base.yaml\",\n",
+    "    'control_multi_sdxl': f\"{sgm_config_dir}/sd_xl_base.yaml\",\n",
+    "    'sdxl_refiner': f\"{sgm_config_dir}/sd_xl_refiner.yaml\",\n",
+    "}\n",
+    "if model_version in model_config_paths:\n",
+    "    config_path = model_config_paths[model_version]\n",
+    "\n",
+    "if model_version == 'v2_depth':\n",
+    "    # the depth model needs the midas checkpoint next to it; fetch it once\n",
+    "    midas_ckpt_path = f\"{root_dir}/midas_models/dpt_hybrid-midas-501f0c75.pt\"\n",
+    "    os.makedirs(f'{root_dir}/midas_models', exist_ok=True)\n",
+    "    if not os.path.exists(midas_ckpt_path):\n",
+    "        midas_url = model_urls['dpt_hybrid-midas-501f0c75']\n",
+    "        wget.download(midas_url, midas_ckpt_path)\n",
+    "        # !wget -O \"{root_dir}/midas_models/dpt_hybrid-midas-501f0c75.pt\" https://github.com/intel-isl/DPT/releases/download/1_0/dpt_hybrid-midas-501f0c75.pt\n",
+    "\n",
+    "if 'sdxl' in model_version:\n",
+    "    # sgm is imported from inside its repo folder; always step back to root_dir,\n",
+    "    # even when the import fails, so the rest of the notebook keeps its working directory\n",
+    "    os.chdir(f'{root_dir}/generative-models')\n",
+    "    try:\n",
+    "        import sgm\n",
+    "        from sgm.modules.diffusionmodules.discretizer import LegacyDDPMDiscretization\n",
+    "    finally:\n",
+    "        os.chdir(f'{root_dir}')\n",
+    "else:\n",
+    "    # comfy is removed from the namespace when it is there; NameError only means there is nothing to remove\n",
+    "    try: del comfy\n",
+    "    except NameError: pass\n",
+    "\n",
+    "control_helpers = {\n",
+    "    \"control_sd15_canny\":None,\n",
+    "    \"control_sd15_depth\":\"dpt_hybrid-midas-501f0c75.pt\",\n",
+    "    \"control_sd15_softedge\":\"network-bsds500.pth\",\n",
+    "    \"control_sd15_mlsd\":\"mlsd_large_512_fp32.pth\",\n",
+    "    \"control_sd15_normalbae\":\"dpt_hybrid-midas-501f0c75.pt\",\n",
+    "    \"control_sd15_openpose\":[\"body_pose_model.pth\", \"hand_pose_model.pth\"],\n",
+    "    \"control_sd15_scribble\":None,\n",
+    "    \"control_sd15_seg\":\"upernet_global_small.pth\",\n",
+    "    \"control_sd15_temporalnet\":None,\n",
+    "    \"control_sd15_face\":None,\n",
+    "    \"control_sdxl_temporalnet_v1\":None\n",
+    "}\n",
+    "\n",
+    "vae_ckpt = '' #@param {'type':'string'}\n",
+    "if vae_ckpt == '': vae_ckpt = None\n",
+    "load_to = 'cpu' #@param ['cpu','gpu']\n",
+    "if load_to == 'gpu': load_to = 'cuda'\n",
+    "quantize = True\n",
+    "#@markdown Enable no_half_vae if you are getting black frames.\n",
+    "no_half_vae = False #@param {'type':'boolean'}\n",
+    "import gc\n",
+    "init_dummy = True #@param {'type':'boolean'}\n",
+    "if 'sdxl' not in model_version:\n",
+    "    init_dummy = False\n",
+    "    print('disabling init dummy for non-sdxl models')\n",
+ "\n",
+ "from accelerate.utils import named_module_tensors, set_module_tensor_to_device\n",
+ "\n",
+ "def handle_size_mismatch(sd):\n",
+ " context_dim = sd.get('model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn2.to_k.weight', None)\n",
+ " # print(context_dim.shape[-1])\n",
+ " suggested_model_version = ''\n",
+ " base = ''\n",
+ " if context_dim is None:\n",
+ " return('Unknown model base type. Make sure you are not using LORA as your base checkpoint. Please check your checkpoint base model is the same that you have selected from the model_version dropdown.')\n",
+ "\n",
+ " else:\n",
+ " context_dim = context_dim.shape[-1]\n",
+ " if context_dim == 768:\n",
+ " suggested_model_version = 'control_multi' #if 'control' in model_version else 'v1'\n",
+ " base = 'v1.x'\n",
+ " elif context_dim == 1024:\n",
+ " suggested_model_version = 'control_multi_v2_512' #if 'control' in model_version else 'v2_512'\n",
+ " base = 'v2.x'\n",
+ " elif context_dim == 1280:\n",
+ " suggested_model_version = 'sdxl_refiner'\n",
+ " base = 'sdxl_refiner'\n",
+ " elif context_dim == 2048:\n",
+ " suggested_model_version = 'control_multi_sdxl' #if 'control' in model_version else 'sdxl_base'\n",
+ " base = 'sdxl_base'\n",
+ " else:\n",
+ " return('Unknown model base type. Please check your checkpoint base model is the same that you have selected from the model_version dropdown.')\n",
+ "\n",
+ " return(f\"\"\"\n",
+ "Model version / checkpoint base type mismatch.\n",
+ "You have selected {model_version} model_version and provided a checkpoint with {base} base model version.\n",
+ "Double check your model checkpoint base model or try switching model_version to {suggested_model_version} and running this cell again.\"\"\")\n",
+ "\n",
+    "def move_tensors(module, device='cuda'):\n",
+    "    \"\"\"Re-create every tensor reported by ``named_module_tensors(module)`` on ``device``.\n",
+    "\n",
+    "    Values are not carried over: each tensor is replaced by zeros shaped like the\n",
+    "    old one, or - when ``device`` is the meta device - set with ``value=None``.\n",
+    "    \"\"\"\n",
+    "    to_meta = device == torch.device('meta')\n",
+    "    for tensor_name, _unused in named_module_tensors(module):\n",
+    "        current = getattr(module, tensor_name)\n",
+    "        replacement = None if to_meta else torch.zeros_like(current, device=device)\n",
+    "        set_module_tensor_to_device(module, tensor_name, device, value=replacement)\n",
+ "\n",
+    "def _sdxl_swappable_submodules(model):\n",
+    "    \"\"\"Return (owner, attribute name) pairs of the sdxl submodules that are swapped between meta and cuda.\"\"\"\n",
+    "    text_model = model.conditioner.embedders[0].transformer.text_model\n",
+    "    return [\n",
+    "        (text_model, 'embeddings'),\n",
+    "        (text_model, 'encoder'),\n",
+    "        (model.conditioner.embedders[1].model, 'transformer'),\n",
+    "        (model.first_stage_model, 'encoder'),\n",
+    "        (model.first_stage_model, 'decoder'),\n",
+    "        (model.model, 'diffusion_model'),\n",
+    "    ]\n",
+    "\n",
+    "def _sdxl_submodules_to_meta(model):\n",
+    "    \"\"\"Move the swappable sdxl submodules to the meta device and put them in eval mode.\"\"\"\n",
+    "    for owner, attr in _sdxl_swappable_submodules(model):\n",
+    "        setattr(owner, attr, getattr(owner, attr).to(torch.device('meta')).eval())\n",
+    "\n",
+    "def _sdxl_submodules_to_cuda(model):\n",
+    "    \"\"\"Re-create the swappable sdxl submodules on cuda via to_empty; the caller loads real weights afterwards.\"\"\"\n",
+    "    for owner, attr in _sdxl_swappable_submodules(model):\n",
+    "        setattr(owner, attr, getattr(owner, attr).to_empty(device='cuda').cuda())\n",
+    "\n",
+    "def maybe_instantiate(ckpt, config):\n",
+    "    \"\"\"Return a model object for ``ckpt``, taking the cheapest route available.\n",
+    "\n",
+    "    1. ``ckpt`` is a ``.pkl`` file: unpickle it and return it in eval mode.\n",
+    "    2. ``init_dummy`` is on and ``{root_dir}/{model_version}.dummypkl`` exists (it is downloaded\n",
+    "       first for sdxl_base / control_multi_sdxl): unpickle that dummy and, for those two\n",
+    "       versions, re-create its large submodules on cuda.\n",
+    "    3. Otherwise build the model with ``instantiate_from_config(config.model)`` and, when\n",
+    "       ``init_dummy`` is on and no dummy file exists yet, save one for the next run.\n",
+    "\n",
+    "    NOTE: pickle.load runs code stored in the file - only use pickles from a source you trust.\n",
+    "    \"\"\"\n",
+    "    if ckpt.endswith('.pkl'):\n",
+    "        with open(ckpt, 'rb') as f:\n",
+    "            model = pickle.load(f).eval()\n",
+    "        return model #return loaded pickle\n",
+    "\n",
+    "    is_sdxl_base = model_version in ['sdxl_base', 'control_multi_sdxl']\n",
+    "    dummy_path = os.path.join(root_dir,f'{model_version}.dummypkl')\n",
+    "\n",
+    "    if init_dummy and is_sdxl_base and not os.path.exists(dummy_path):\n",
+    "        #download dummmypkl\n",
+    "        dummypkl_url = 'https://github.com/Sxela/WarpFusion/releases/download/v0.1.0/control_multi_sdxl.dummypkl'\n",
+    "        print('Downloading dummypkl file.')\n",
+    "        wget.download(dummypkl_url, dummy_path)\n",
+    "\n",
+    "    #load dummy\n",
+    "    if init_dummy and os.path.exists(dummy_path):\n",
+    "        try:\n",
+    "            print('Loading dummy pkl')\n",
+    "            #try load dummy pkl instead of initializing model\n",
+    "            with open(dummy_path, 'rb') as f:\n",
+    "                model = pickle.load(f).eval()\n",
+    "            if is_sdxl_base:\n",
+    "                _sdxl_submodules_to_cuda(model)\n",
+    "            return model\n",
+    "        except Exception:\n",
+    "            # fall through to a regular instantiation; Ctrl-C and SystemExit are no longer swallowed\n",
+    "            print(traceback.format_exc())\n",
+    "            print('Found pkl file but failed loading. Probably codebase mismatch, try resaving.')\n",
+    "\n",
+    "    # instantiate and save dummy\n",
+    "    from IPython.utils import io\n",
+    "    with io.capture_output(stderr=False) as captured:\n",
+    "        model = instantiate_from_config(config.model)\n",
+    "\n",
+    "    if init_dummy and not os.path.exists(dummy_path):\n",
+    "        if use_torch_v2:\n",
+    "            model.half()\n",
+    "        if is_sdxl_base:\n",
+    "            # the pickled dummy keeps only the structure: the large submodules are on the meta device\n",
+    "            _sdxl_submodules_to_meta(model)\n",
+    "        torch.cuda.empty_cache()\n",
+    "        gc.collect()\n",
+    "        print('Saving dummy pkl')\n",
+    "        with open(dummy_path, 'wb') as f:\n",
+    "            pickle.dump(model, f)\n",
+    "        model.half()\n",
+    "        if is_sdxl_base:\n",
+    "            # bring the submodules back so the returned model can receive weights\n",
+    "            _sdxl_submodules_to_cuda(model)\n",
+    "    return model\n",
+ "\n",
+    "def _read_checkpoint_file(path):\n",
+    "    \"\"\"Read a checkpoint into a dict, with tensors placed on the global ``load_to`` device.\n",
+    "\n",
+    "    ``.safetensors`` files are read key by key with safe_open, everything else with torch.load.\n",
+    "    \"\"\"\n",
+    "    if path.endswith('.safetensors'):\n",
+    "        tensors = {}\n",
+    "        with safe_open(path, framework=\"pt\", device=load_to) as f:\n",
+    "            for key in f.keys():\n",
+    "                tensors[key] = f.get_tensor(key)\n",
+    "        return tensors\n",
+    "    return torch.load(path, map_location=load_to)\n",
+    "\n",
+    "def _unwrap_state_dict(checkpoint):\n",
+    "    \"\"\"Print the global step if the checkpoint has one and return its ``state_dict`` (or the checkpoint itself).\"\"\"\n",
+    "    if \"global_step\" in checkpoint:\n",
+    "        print(f\"Global Step: {checkpoint['global_step']}\")\n",
+    "    if \"state_dict\" in checkpoint:\n",
+    "        return checkpoint[\"state_dict\"]\n",
+    "    return checkpoint\n",
+    "\n",
+    "def _print_key_report(missing, unexpected):\n",
+    "    \"\"\"Print the missing / unexpected keys returned by load_state_dict, skipping empty lists.\"\"\"\n",
+    "    if len(missing) > 0:\n",
+    "        print(\"missing keys:\")\n",
+    "        print(missing, len(missing))\n",
+    "    if len(unexpected) > 0:\n",
+    "        print(\"unexpected keys:\")\n",
+    "        print(unexpected, len(unexpected))\n",
+    "\n",
+    "def load_model_from_config(config, ckpt, vae_ckpt=None, controlnet=None, verbose=False):\n",
+    "    \"\"\"Create the model described by ``config`` and load the weights from ``ckpt`` into it.\n",
+    "\n",
+    "    Args:\n",
+    "        config: config object passed on to maybe_instantiate.\n",
+    "        ckpt: path to the main checkpoint (``.safetensors`` or a torch.load-able file).\n",
+    "        vae_ckpt: optional path to a VAE checkpoint; its tensors replace every\n",
+    "            ``first_stage_model.*`` entry of the main state dict.\n",
+    "        controlnet: optional path to a checkpoint loaded into ``model.control_model``.\n",
+    "        verbose: print missing / unexpected keys after each load.\n",
+    "\n",
+    "    Returns:\n",
+    "        The model, moved to cuda (and converted to half precision unless ``gpu`` is 'A100').\n",
+    "\n",
+    "    Raises:\n",
+    "        RuntimeError: with the handle_size_mismatch explanation when the checkpoint does not\n",
+    "            fit the model; any other loading error is re-raised unchanged instead of being\n",
+    "            swallowed.\n",
+    "    \"\"\"\n",
+    "    with torch.no_grad():\n",
+    "        model = maybe_instantiate(ckpt, config)\n",
+    "\n",
+    "        if gpu != 'A100':\n",
+    "            model.half()\n",
+    "\n",
+    "        print(f\"Loading model from {ckpt}\")\n",
+    "        sd = _unwrap_state_dict(_read_checkpoint_file(ckpt))\n",
+    "        gc.collect()\n",
+    "\n",
+    "        if vae_ckpt is not None:\n",
+    "            print(f\"Loading VAE from {vae_ckpt}\")\n",
+    "            vae_sd = _read_checkpoint_file(vae_ckpt)\n",
+    "            if \"state_dict\" in vae_sd:\n",
+    "                vae_sd = vae_sd[\"state_dict\"]\n",
+    "            # VAE keys carry no prefix, so strip 'first_stage_model.' before looking them up\n",
+    "            sd = {\n",
+    "                k: vae_sd[k[len(\"first_stage_model.\") :]] if k.startswith(\"first_stage_model.\") else v\n",
+    "                for k, v in sd.items()\n",
+    "            }\n",
+    "        if 'sdxl' in model_version:\n",
+    "            sd['denoiser.sigmas'] = torch.zeros(1000).to(load_to)\n",
+    "        try:\n",
+    "            m, u = model.load_state_dict(sd, strict=False)\n",
+    "        except RuntimeError as e:\n",
+    "            if e.args and 'Error(s) in loading state_dict' in str(e.args[0]):\n",
+    "                print('Checkpoint and model_version size mismatch.')\n",
+    "                msg = handle_size_mismatch(sd)\n",
+    "                raise RuntimeError(msg) from e\n",
+    "            # not a size mismatch: let it surface here rather than fail later on unbound m / u\n",
+    "            raise\n",
+    "        if gpu != 'A100':\n",
+    "            model.half()\n",
+    "        model.cuda()\n",
+    "        torch.cuda.empty_cache()\n",
+    "        gc.collect()\n",
+    "        if verbose:\n",
+    "            _print_key_report(m, u)\n",
+    "\n",
+    "        if controlnet is not None:\n",
+    "            ckpt = controlnet\n",
+    "            print(f\"Loading model from {ckpt}\")\n",
+    "            sd = _unwrap_state_dict(_read_checkpoint_file(ckpt))\n",
+    "            gc.collect()\n",
+    "            m, u = model.control_model.load_state_dict(sd, strict=False)\n",
+    "            if verbose:\n",
+    "                _print_key_report(m, u)\n",
+    "\n",
+    "        return model\n",
+ "\n",
+ "import clip\n",
+ "from kornia import augmentation as KA\n",
+ "from torch.nn import functional as F\n",
+ "from resize_right import resize\n",
+ "\n",
+ "def spherical_dist_loss(x, y):\n",
+ " x = F.normalize(x, dim=-1)\n",
+ " y = F.normalize(y, dim=-1)\n",
+ " return (x - y).norm(dim=-1).div(2).arcsin().pow(2).mul(2)\n",
+ "\n",
+ "from einops import rearrange, repeat\n",
+ "\n",
+ "def make_cond_model_fn(model, cond_fn):\n",
+ " def model_fn(x, sigma, **kwargs):\n",
+ " with torch.enable_grad():\n",
+ " # with torch.no_grad():\n",
+ " x = x.detach().requires_grad_()\n",
+ " # print('x.shape, sigma', x.shape, sigma)\n",
+ " denoised = model(x, sigma, **kwargs);# print(denoised.requires_grad)\n",
+ " # with torch.enable_grad():\n",
+ " # denoised = denoised.detach().requires_grad_()\n",
+ " cond_grad = cond_fn(x, sigma, denoised=denoised, **kwargs).detach();# print(cond_grad.requires_grad)\n",
+ " cond_denoised = denoised.detach() + cond_grad * K.utils.append_dims(sigma**2, x.ndim)\n",
+ " return cond_denoised\n",
+ " return model_fn\n",
+ "\n",
+ "def make_cond_model_fn(model, cond_fn):\n",
+ " def model_fn(x, sigma, **kwargs):\n",
+ " with torch.enable_grad():\n",
+ " # with torch.no_grad():\n",
+ " # x = x.detach().requires_grad_()\n",
+ " # print('x.shape, sigma', x.shape, sigma)\n",
+ " denoised = model(x, sigma, **kwargs);# print(denoised.requires_grad)\n",
+ " # with torch.enable_grad():\n",
+ " # print(sigma**0.5, sigma, sigma**2)\n",
+ " denoised = denoised.detach().requires_grad_()\n",
+ " cond_grad = cond_fn(x, sigma, denoised=denoised, **kwargs).detach();# print(cond_grad.requires_grad)\n",
+ " cond_denoised = denoised.detach() + cond_grad * K.utils.append_dims(sigma**2, x.ndim)\n",
+ " return cond_denoised\n",
+ " return model_fn\n",
+ "\n",
+ "\n",
+ "def make_static_thresh_model_fn(model, value=1.):\n",
+ " def model_fn(x, sigma, **kwargs):\n",
+ " # print('x.shape, sigma', x.shape, sigma)\n",
+ " return model(x, sigma, **kwargs).clamp(-value, value)\n",
+ " return model_fn\n",
+ "\n",
+    "def get_image_embed(x):\n",
+    "    \"\"\"Encode ``x`` with the global ``clip_model`` and return the normalized float embedding.\n",
+    "\n",
+    "    ``x`` is first resized to ``clip_size`` (reflect padding) when ``x.shape[2:4]`` differs from it.\n",
+    "    \"\"\"\n",
+    "    needs_resize = x.shape[2:4] != clip_size\n",
+    "    if needs_resize:\n",
+    "        x = resize(x, out_shape=clip_size, pad_mode='reflect')\n",
+    "    embedding = clip_model.encode_image(x).float()\n",
+    "    return F.normalize(embedding)\n",
+ "\n",
+    "def load_img_sd(path, size):\n",
+    "    \"\"\"Load the image at ``path`` as a float tensor scaled to [-1, 1].\n",
+    "\n",
+    "    The image is converted to RGB and resized to ``size`` with LANCZOS resampling.\n",
+    "    The returned tensor has a leading batch axis of 1 followed by the channel axis.\n",
+    "    \"\"\"\n",
+    "    rgb_image = Image.open(path).convert(\"RGB\").resize(size, resample=Image.LANCZOS)\n",
+    "    if VERBOSE:\n",
+    "        print(f'resized to {rgb_image.size}')\n",
+    "    pixels = np.array(rgb_image).astype(np.float32) / 255.0\n",
+    "    # add the batch axis, then move channels in front of height and width\n",
+    "    batched = pixels[None].transpose(0, 3, 1, 2)\n",
+    "    tensor = torch.from_numpy(batched)\n",
+    "    return 2.*tensor - 1.\n",
+ "\n",
+ "# import lpips\n",
+ "# lpips_model = lpips.LPIPS(net='vgg').to(device)\n",
+ "batch_size = 1 #max batch size\n",
+ "class CFGDenoiser(nn.Module):\n",
+ " def __init__(self, model):\n",
+ " super().__init__()\n",
+ " self.inner_model = model\n",
+ "\n",
+ " def forward(self, x, sigma, uncond, cond, cond_scale, image_cond=None,\n",
+ " prompt_weights=None, prompt_masks=None):\n",
+ " if 'sdxl' in model_version:\n",
+ " vector = cond['vector']\n",
+ " uc_vector = uncond['vector']\n",
+ " vector_in = torch.cat([uc_vector, vector])\n",
+ " cond = cond['crossattn']\n",
+ " uncond = uncond['crossattn']\n",
+ " else:\n",
+ " cond = prompt_parser.reconstruct_cond_batch(cond, 0)\n",
+ " uncond = prompt_parser.reconstruct_cond_batch(uncond, 0)\n",
+ " if cond.shape[1]>77:\n",
+ " cond = cond[:,:77,:]\n",
+ " print('Prompt length > 77 detected. Shorten your prompt or split into multiple prompts.')\n",
+ " uncond = uncond[:,:77,:]\n",
+ "\n",
+ " batch_size = sd_batch_size\n",
+ " # print('batch size in cfgd ', batch_size, sd_batch_size)\n",
+ " cond_uncond_size = cond.shape[0]+uncond.shape[0]\n",
+ " # print('cond_uncond_size',cond_uncond_size)\n",
+ " x_in = torch.cat([x] * batch_size)\n",
+ " sigma_in = torch.cat([sigma] * batch_size)\n",
+ " # print('cond.shape, uncond.shape', cond.shape, uncond.shape)\n",
+ " cond_in = torch.cat([uncond, cond])\n",
+ " res = None\n",
+ " uc_mask_shape = torch.ones(cond_in.shape[0], device=cond_in.device)\n",
+ " uc_mask_shape[0] = 0\n",
+ " # sd_model.model.diffusion_model.uc_mask_shape = uc_mask_shape[]\n",
+ " if prompt_weights is None:\n",
+ " prompt_weights = [1.]*cond.shape[0]\n",
+ " if prompt_weights is not None:\n",
+ " assert len(prompt_weights) >= cond.shape[0], 'The number of prompts is more than prompt weigths.'\n",
+ " prompt_weights = prompt_weights[:cond.shape[0]]\n",
+ " prompt_weights = torch.tensor(prompt_weights).to(cond.device)\n",
+ " prompt_weights = prompt_weights/prompt_weights.sum()\n",
+ "\n",
+ " if prompt_masks is not None:\n",
+ " print('Using masked prompts')\n",
+ " assert len(prompt_masks) == cond.shape[0], 'The number of masks doesn`t match the number of prompts-1.'\n",
+ " prompt_masks = torch.tensor(prompt_masks).to(cond.device)\n",
+ " # print('prompt_masks', prompt_masks.shape)\n",
+ " # we use masks so that the 1st mask is full white, and others are applied on top of it\n",
+ "\n",
+ " n_batches = cond_uncond_size//batch_size if cond_uncond_size % batch_size == 0 else (cond_uncond_size//batch_size)+1\n",
+ " # print('n_batches',n_batches)\n",
+ " if image_cond is None:\n",
+ " for i in range(n_batches):\n",
+ " if model_version in ['sdxl_base', 'sdxl_refiner']:\n",
+ " sd_model.conditioner.vector_in = vector_in[i*batch_size:(i+1)*batch_size]\n",
+ " sd_model.model.diffusion_model.uc_mask_shape = uc_mask_shape[i*batch_size:(i+1)*batch_size]\n",
+ " pred = self.inner_model(x_in[i*batch_size:(i+1)*batch_size], sigma_in[i*batch_size:(i+1)*batch_size],\n",
+ " cond=cond_in[i*batch_size:(i+1)*batch_size])\n",
+ " res = pred if res is None else torch.cat([res, pred])\n",
+ " uncond, cond = res[0:1], res[1:]\n",
+ "\n",
+ " #we can use either weights or masks\n",
+ " if prompt_masks is None:\n",
+ " cond = (cond * prompt_weights[:, None, None, None]).sum(dim=0, keepdim=True)\n",
+ " else:\n",
+ " cond_out = cond[0]\n",
+ " for i in range(len(cond)):\n",
+ " if i == 0: continue\n",
+ " cond_out = (cond[i]*prompt_masks[i] + cond_out*(1-prompt_masks[i]))\n",
+ " cond = cond_out[None,...]\n",
+ " del cond_out\n",
+ "\n",
+ " return uncond + (cond - uncond) * cond_scale\n",
+ " else:\n",
+ " if 'control_multi' not in model_version:\n",
+ " if img_zero_uncond:\n",
+ " img_in = torch.cat([torch.zeros_like(image_cond),\n",
+ " image_cond.repeat(cond.shape[0],1,1,1)])\n",
+ " else:\n",
+ " img_in = torch.cat([image_cond]*(1+cond.shape[0]))\n",
+ "\n",
+ " for i in range(n_batches):\n",
+ " sd_model.model.diffusion_model.uc_mask_shape = uc_mask_shape[i*batch_size:(i+1)*batch_size]\n",
+ " cond_dict = {\"c_crossattn\": [cond_in[i*batch_size:(i+1)*batch_size]], 'c_concat':[img_in[i*batch_size:(i+1)*batch_size]]}\n",
+ " pred = self.inner_model(x_in[i*batch_size:(i+1)*batch_size], sigma_in[i*batch_size:(i+1)*batch_size], cond=cond_dict)\n",
+ " res = pred if res is None else torch.cat([res, pred])\n",
+ " uncond, cond = res[0:1], res[1:]\n",
+ "\n",
+ " if prompt_masks is None:\n",
+ " cond = (cond * prompt_weights[:, None, None, None]).sum(dim=0, keepdim=True)\n",
+ " else:\n",
+ " cond_out = cond[0]\n",
+ " for i in range(len(cond)):\n",
+ " if i == 0: continue\n",
+ " cond_out = (cond[i]*prompt_masks[i] + cond_out*(1-prompt_masks[i]))\n",
+ " cond = cond_out[None,...]\n",
+ " del cond_out\n",
+ "\n",
+ " return uncond + (cond - uncond) * cond_scale\n",
+ "\n",
+ "\n",
+ " if 'control_multi' in model_version and controlnet_multimodel_mode != 'external':\n",
+ " img_in = {}\n",
+ " for key in image_cond.keys():\n",
+ " if img_zero_uncond or key == 'control_sd15_shuffle':\n",
+ " img_in[key] = torch.cat([torch.zeros_like(image_cond[key]),\n",
+ " image_cond[key].repeat(cond.shape[0],1,1,1)])\n",
+ " else:\n",
+ " img_in[key] = torch.cat([image_cond[key]]*(1+cond.shape[0]))\n",
+ "\n",
+ " for i in range(n_batches):\n",
+ "\n",
+ " sd_model.model.diffusion_model.uc_mask_shape = uc_mask_shape[i*batch_size:(i+1)*batch_size]\n",
+ " batch_img_in = {}\n",
+ " for key in img_in.keys():\n",
+ "\n",
+ " batch_img_in[key] = img_in[key][i*batch_size:(i+1)*batch_size]\n",
+ " # print('img_in[key].shape, batch_img_in[key]',img_in[key].shape, batch_img_in[key].shape)\n",
+ " cond_dict = {\n",
+ " \"c_crossattn\": [cond_in[i*batch_size:(i+1)*batch_size]],\n",
+ " 'c_concat': batch_img_in,\n",
+ " 'controlnet_multimodel':controlnet_multimodel_inferred,\n",
+ " 'loaded_controlnets':loaded_controlnets\n",
+ " }\n",
+ " if 'sdxl' in model_version:\n",
+ " y = vector_in[i*batch_size:(i+1)*batch_size]\n",
+ " cond_dict['y'] = y\n",
+ " x_in = torch.cat([x]*cond_dict[\"c_crossattn\"][0].shape[0])\n",
+ " sigma_in = torch.cat([sigma]*cond_dict[\"c_crossattn\"][0].shape[0])\n",
+ " # print(x_in.shape, cond_dict[\"c_crossattn\"][0].shape)\n",
+ " pred = self.inner_model(x_in,\n",
+ " sigma_in, cond=cond_dict)\n",
+ " # print(pred.shape)\n",
+ " res = pred if res is None else torch.cat([res, pred])\n",
+ " # print(res.shape)\n",
+ " gc.collect()\n",
+ " torch.cuda.empty_cache()\n",
+ " # print('res shape', res.shape)\n",
+ " uncond, cond = res[0:1], res[1:]\n",
+ " if prompt_masks is None:\n",
+ " # print('cond shape', cond.shape, prompt_weights[:, None, None, None].shape)\n",
+ " cond = (cond * prompt_weights[:, None, None, None]).sum(dim=0, keepdim=True)\n",
+ " else:\n",
+ " cond_out = cond[0]\n",
+ " for i in range(len(cond)):\n",
+ " if i == 0: continue\n",
+ " cond_out = (cond[i]*prompt_masks[i] + cond_out*(1-prompt_masks[i]))\n",
+ " cond = cond_out[None,...]\n",
+ " del cond_out\n",
+ " # print('cond.shape', cond.shape)\n",
+ "\n",
+ " return uncond + (cond - uncond) * cond_scale\n",
+ " if 'control_multi' in model_version and controlnet_multimodel_mode == 'external':\n",
+ "\n",
+ " #wormalize weights\n",
+ " weights = np.array([controlnet_multimodel[m][\"weight\"] for m in controlnet_multimodel.keys()])\n",
+ " weights = weights/weights.sum()\n",
+ " result = None\n",
+ " # print(weights)\n",
+ " for i,controlnet in enumerate(controlnet_multimodel.keys()):\n",
+ " try:\n",
+ " if img_zero_uncond or controlnet == 'control_sd15_shuffle':\n",
+ " img_in = torch.cat([torch.zeros_like(image_cond[controlnet]),\n",
+ " image_cond[controlnet].repeat(cond.shape[0],1,1,1)])\n",
+ " else:\n",
+ " img_in = torch.cat([image_cond[controlnet]]*(1+cond.shape[0]))\n",
+ " except:\n",
+ " pass\n",
+ "\n",
+ " if weights[i]!=0:\n",
+ " controlnet_settings = controlnet_multimodel[controlnet]\n",
+ "\n",
+ " self.inner_model.inner_model.control_model = loaded_controlnets[controlnet]\n",
+ " for i in range(n_batches):\n",
+ " sd_model.model.diffusion_model.uc_mask_shape = uc_mask_shape[i*batch_size:(i+1)*batch_size]\n",
+ " cond_dict = {\"c_crossattn\": [cond_in[i*batch_size:(i+1)*batch_size]],\n",
+ " 'c_concat':[img_in[i*batch_size:(i+1)*batch_size]]}\n",
+ " if 'sdxl' in model_version:\n",
+ " y = vector_in[i*batch_size:(i+1)*batch_size]\n",
+ " cond_dict['y'] = y\n",
+ " pred = self.inner_model(x_in[i*batch_size:(i+1)*batch_size], sigma_in[i*batch_size:(i+1)*batch_size], cond=cond_dict)\n",
+ " gc.collect()\n",
+ " res = pred if res is None else torch.cat([res, pred])\n",
+ "\n",
+ " uncond, cond = res[0:1], res[1:]\n",
+ " # uncond, cond = self.inner_model(x_in, sigma_in, cond={\"c_crossattn\": [cond_in],\n",
+ " # 'c_concat': [img_in]}).chunk(2)\n",
+ " if prompt_masks is None:\n",
+ " cond = (cond * prompt_weights[:, None, None, None]).sum(dim=0, keepdim=True)\n",
+ " else:\n",
+ " cond_out = cond[0]\n",
+ " for i in range(len(cond)):\n",
+ " if i == 0: continue\n",
+ " cond_out = (cond[i]*prompt_masks[i] + cond_out*(1-prompt_masks[i]))\n",
+ " cond = cond_out[None,...]\n",
+ " del cond_out\n",
+ "\n",
+ " if result is None:\n",
+ " result = (uncond + (cond - uncond) * cond_scale)*weights[i]\n",
+ " else: result = result + (uncond + (cond - uncond) * cond_scale)*weights[i]\n",
+ " return result\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "import einops\n",
+    "class InstructPix2PixCFGDenoiser(nn.Module):\n",
+    "    \"\"\"Guidance wrapper with separate text and image scales for instruct-pix2pix models.\n",
+    "\n",
+    "    One call to the wrapped model evaluates three rows: (text, image), (no text, image)\n",
+    "    and (no text, zero image); ``forward`` combines the three predictions.\n",
+    "    \"\"\"\n",
+    "\n",
+    "    def __init__(self, model):\n",
+    "        super().__init__()\n",
+    "        self.inner_model = model\n",
+    "\n",
+    "    def forward(self, z, sigma, cond, uncond, cond_scale, image_scale, image_cond, **kwargs):\n",
+    "        \"\"\"Return ``out_uncond + cond_scale*(out_cond - out_img_cond) + image_scale*(out_img_cond - out_uncond)``.\"\"\"\n",
+    "        text_c = prompt_parser.reconstruct_cond_batch(cond, 0)\n",
+    "        text_uc = prompt_parser.reconstruct_cond_batch(uncond, 0)\n",
+    "        blank_image = torch.zeros_like(image_cond)\n",
+    "\n",
+    "        # row order: text + image, image only, neither\n",
+    "        stacked_cond = {\n",
+    "            \"c_crossattn\": [torch.cat([text_c, text_uc, text_uc])],\n",
+    "            \"c_concat\": [torch.cat([image_cond, image_cond, blank_image])],\n",
+    "        }\n",
+    "        stacked_z = einops.repeat(z, \"1 ... -> n ...\", n=3)\n",
+    "        stacked_sigma = einops.repeat(sigma, \"1 ... -> n ...\", n=3)\n",
+    "\n",
+    "        prediction = self.inner_model(stacked_z, stacked_sigma, cond=stacked_cond)\n",
+    "        out_cond, out_img_cond, out_uncond = prediction.chunk(3)\n",
+    "        text_term = cond_scale * (out_cond - out_img_cond)\n",
+    "        image_term = image_scale * (out_img_cond - out_uncond)\n",
+    "        return out_uncond + text_term + image_term\n",
+ "\n",
+    "dynamic_thresh = 2.\n",
+    "device = 'cuda'\n",
+    "# config_path = f\"{root_dir}/stable-diffusion/configs/stable-diffusion/v1-inference.yaml\"\n",
+    "model_path = \"/content/drive/MyDrive/models/revAnimated_v122.safetensors\" #@param {'type':'string'}\n",
+    "import pickle\n",
+    "#@markdown ---\n",
+    "#@markdown ControlNet download settings\n",
+    "#@markdown ControlNet downloads are managed by controlnet_multi settings in Main settings tab.\n",
+    "use_small_controlnet = True\n",
+    "# #@param {'type':'boolean'}\n",
+    "small_controlnet_model_path = ''\n",
+    "# #@param {'type':'string'}\n",
+    "download_control_model = True\n",
+    "# #@param {'type':'boolean'}\n",
+    "force_download = False #@param {'type':'boolean'}\n",
+    "controlnet_models_dir = \"/content/drive/MyDrive/models/ControlNet\" #@param {'type':'string'}\n",
+    "# Outside Colab an empty, missing or /content/... path cannot be used: fall back to the local ControlNet folder.\n",
+    "# The emptiness test runs first so that None never reaches .startswith().\n",
+    "if not is_colab and (not controlnet_models_dir or controlnet_models_dir.startswith('/content')):\n",
+    "    controlnet_models_dir = f\"{root_dir}/ControlNet/models\"\n",
+    "    print('You have a controlnet path set up for google drive, but we are not on Colab. Defaulting controlnet model path to ', controlnet_models_dir)\n",
+    "os.makedirs(controlnet_models_dir, exist_ok=True)\n",
+ "#@markdown ---\n",
+ "\n",
+    "# legacy per-controlnet switches; they only matter for model_version == 'control_multi'\n",
+    "control_sd15_canny = False\n",
+    "control_sd15_depth = False\n",
+    "control_sd15_softedge = True\n",
+    "control_sd15_mlsd = False\n",
+    "control_sd15_normalbae = False\n",
+    "control_sd15_openpose = False\n",
+    "control_sd15_scribble = False\n",
+    "control_sd15_seg = False\n",
+    "control_sd15_temporalnet = False\n",
+    "control_sd15_face = False\n",
+    "\n",
+    "if model_version != 'control_multi':\n",
+    "    control_versions = [model_version]\n",
+    "else:\n",
+    "    # names of the switched-on controlnets, in the order the switches are declared above\n",
+    "    control_versions = [\n",
+    "        name for name, enabled in (\n",
+    "            ('control_sd15_canny', control_sd15_canny),\n",
+    "            ('control_sd15_depth', control_sd15_depth),\n",
+    "            ('control_sd15_softedge', control_sd15_softedge),\n",
+    "            ('control_sd15_mlsd', control_sd15_mlsd),\n",
+    "            ('control_sd15_normalbae', control_sd15_normalbae),\n",
+    "            ('control_sd15_openpose', control_sd15_openpose),\n",
+    "            ('control_sd15_scribble', control_sd15_scribble),\n",
+    "            ('control_sd15_seg', control_sd15_seg),\n",
+    "            ('control_sd15_temporalnet', control_sd15_temporalnet),\n",
+    "            ('control_sd15_face', control_sd15_face),\n",
+    "        ) if enabled\n",
+    "    ]\n",
+ "\n",
+    "if 'control_multi' in model_version:\n",
+    "    # annotator / cldm are imported from inside the ControlNet repo folder; step back out even\n",
+    "    # when an import fails, so the relative paths used further down still resolve\n",
+    "    os.chdir(f\"{root_dir}/ControlNet/\")\n",
+    "    try:\n",
+    "        from annotator.util import resize_image, HWC3\n",
+    "        from cldm.model import create_model, load_state_dict\n",
+    "    finally:\n",
+    "        os.chdir('../')\n",
+    "\n",
+    "# checked for every model_version, so a wrong path always ends in this message\n",
+    "assert os.path.exists(model_path), f'Model not found at path: {model_path}. Please enter a valid path to the checkpoint file.'\n",
+    "\n",
+    "if model_version in ['control_multi', 'control_multi_v2', 'control_multi_v2_768']:\n",
+    "    # control_multi is built on the v1.5 cldm config, both control_multi_v2 variants on the v2.1 one\n",
+    "    cldm_yaml = 'cldm_v15.yaml' if model_version == 'control_multi' else 'cldm_v21.yaml'\n",
+    "    config = OmegaConf.load(f\"{root_dir}/ControlNet/models/{cldm_yaml}\")\n",
+    "    sd_model = load_model_from_config(config=config,\n",
+    "                                      ckpt=model_path, vae_ckpt=vae_ckpt, verbose=True)\n",
+    "\n",
+    "    #legacy\n",
+    "    sd_model.cond_stage_model.half()\n",
+    "    sd_model.model.half()\n",
+    "    sd_model.control_model.half()\n",
+    "    sd_model.cuda()\n",
+    "\n",
+    "    gc.collect()\n",
+    "elif model_path.endswith('.pkl'):\n",
+    "    # NOTE: pickle.load runs code stored in the file - only open pickles from a source you trust\n",
+    "    with open(model_path, 'rb') as f:\n",
+    "        sd_model = pickle.load(f).cuda().eval()\n",
+    "    if gpu == 'A100':\n",
+    "        sd_model = sd_model.float()\n",
+    "else:\n",
+    "    config = OmegaConf.load(config_path)\n",
+    "    from IPython.utils import io\n",
+    "\n",
+    "    sd_model = load_model_from_config(config, model_path, vae_ckpt=vae_ckpt, verbose=True).cuda()\n",
+ "\n",
+ "sys.path.append('./stablediffusion/')\n",
+ "from modules import prompt_parser, sd_hijack\n",
+ "if 'sdxl' in model_version:\n",
+ " discretizer = LegacyDDPMDiscretization()\n",
+ " # sd_model.alphas_cumprod = torch.from_numpy(discretizer.alphas_cumprod)\n",
+ " sd_model.register_buffer('alphas_cumprod', torch.from_numpy(discretizer.alphas_cumprod))\n",
+ " sd_model.model.conditioning_key = 'c_crossattn'\n",
+ " sd_model.cond_stage_model = sd_model.conditioner\n",
+ " sd_model.parameterization = 'eps'\n",
+ "\n",
+ " def apply_model_vector(x_noisy, t, cond, self=sd_model, **kwargs):\n",
+ " cond = {\n",
+ " 'crossattn': cond,\n",
+ " 'vector':self.conditioner.vector_in\n",
+ " }\n",
+ " return self.model.forward(x_noisy, t, cond, **kwargs)\n",
+ "\n",
+ " def get_first_stage_encoding(z, self=sd_model):\n",
+ " return z\n",
+ "\n",
+ " def get_unconditional_conditioning_sdxl(batch_c, self=sd_model.conditioner):\n",
+ "\n",
+ " def get_unique_embedder_keys_from_conditioner(conditioner):\n",
+ " return list(set([x.input_key for x in conditioner.embedders]))\n",
+ "\n",
+ " W, H = width_height\n",
+ "\n",
+ " init_dict = {\n",
+ " \"orig_width\": W,\n",
+ " \"orig_height\": H,\n",
+ " \"target_width\": W,\n",
+ " \"target_height\": H,\n",
+ " }\n",
+ "\n",
+ " prompt = batch_c\n",
+ " negative_prompt = prompt\n",
+ " num_samples = len(prompt)\n",
+ " force_uc_zero_embeddings = []\n",
+ "\n",
+ " value_dict = init_dict\n",
+ " value_dict[\"prompt\"] = prompt\n",
+ " value_dict[\"negative_prompt\"] = ['']\n",
+ "\n",
+ " value_dict[\"crop_coords_top\"] = 0\n",
+ " value_dict[\"crop_coords_left\"] = 0\n",
+ "\n",
+ " value_dict[\"aesthetic_score\"] = 6.0\n",
+ " value_dict[\"negative_aesthetic_score\"] = 2.5\n",
+ "\n",
+ " batch_c, _ = get_batch(\n",
+ " get_unique_embedder_keys_from_conditioner(sd_model.conditioner),\n",
+ " value_dict,\n",
+ " [num_samples],\n",
+ " )\n",
+ "\n",
+ " ucg_rates = list()\n",
+ " for embedder in self.embedders:\n",
+ " ucg_rates.append(embedder.ucg_rate)\n",
+ " embedder.ucg_rate = 0.0\n",
+ " c = self(batch_c)\n",
+ "\n",
+ " [print(c[key].shape) for key in c.keys()]\n",
+ " for embedder, rate in zip(self.embedders, ucg_rates):\n",
+ " embedder.ucg_rate = rate\n",
+ "\n",
+ " return c\n",
+ "\n",
+ " sd_model.get_learned_conditioning = get_unconditional_conditioning_sdxl\n",
+ " sd_model.apply_model = apply_model_vector\n",
+ " sd_model.disable_first_stage_autocast = no_half_vae\n",
+ " sd_model.get_first_stage_encoding = get_first_stage_encoding\n",
+ "\n",
+ "# Wrap the model for k-diffusion: CompVisVDenoiser when the model is 'v'-parameterized or the\n",
+ "# model_version is 'control_multi_v2_768', CompVisDenoiser otherwise.\n",
+ "model_wrap = (\n",
+ "    K.external.CompVisVDenoiser\n",
+ "    if sd_model.parameterization == \"v\" or model_version == 'control_multi_v2_768'\n",
+ "    else K.external.CompVisDenoiser\n",
+ ")(sd_model, quantize=quantize)\n",
+ "# First and last entries of the wrapper's sigma table.\n",
+ "sigma_min, sigma_max = (\n",
+ "    model_wrap.sigmas[0].item(),\n",
+ "    model_wrap.sigmas[-1].item(),\n",
+ ")\n",
+ "# Default CFG wrapper; replaced by the InstructPix2Pix variant for that model_version only.\n",
+ "model_wrap_cfg = CFGDenoiser(model_wrap)\n",
+ "if model_version == 'v1_instructpix2pix':\n",
+ "    model_wrap_cfg = InstructPix2PixCFGDenoiser(model_wrap)\n",
+ "\n",
+ "#@markdown If you're having crashes (CPU out of memory errors) while running this cell on standard colab env, consider saving the model as pickle.\\\n",
+ "#@markdown You can save the pickled model on your google drive and use it instead of the usual stable diffusion model.\\\n",
+ "#@markdown To do that, run the notebook with a high-ram env, run all cells before and including this cell as well, and save pickle in the next cell. Then you can switch to a low-ram env and load the pickled model.\n",
+ "\n",
+ "def make_batch_sd(\n",
+ "        image,\n",
+ "        mask,\n",
+ "        txt,\n",
+ "        device,\n",
+ "        num_samples=1, inpainting_mask_weight=1):\n",
+ "    \"\"\"Assemble the inpainting batch dict (image, txt, mask, masked_image).\n",
+ "\n",
+ "    :param image: object with a `.convert('RGB')` method convertible by np.array (e.g. a PIL image).\n",
+ "    :param mask: same kind of object (converted with 'L'), or None to use an all-ones mask.\n",
+ "    :param txt: prompt; repeated num_samples times in the batch.\n",
+ "    :param device: device the tensors are moved to.\n",
+ "    :param num_samples: number of copies along the batch dimension.\n",
+ "    :param inpainting_mask_weight: lerp weight between the untouched image (0) and the masked image (1).\n",
+ "    :return: dict with keys 'image', 'txt', 'mask', 'masked_image'.\n",
+ "    \"\"\"\n",
+ "    # HWC uint8 -> 1xCxHxW float32 scaled to [-1, 1]\n",
+ "    rgb = np.array(image.convert(\"RGB\"))[None].transpose(0,3,1,2)\n",
+ "    image = torch.from_numpy(rgb).to(dtype=torch.float32)/127.5-1.0\n",
+ "\n",
+ "    if mask is None:\n",
+ "        mask = image.new_ones(1, 1, *image.shape[-2:])\n",
+ "    else:\n",
+ "        # grayscale mask scaled to [0, 1], then binarized at 0.5\n",
+ "        mask_np = np.array(mask.convert(\"L\")).astype(np.float32)/255.0\n",
+ "        mask_np = (mask_np[None,None] >= 0.5).astype(np.float32)\n",
+ "        mask = torch.from_numpy(mask_np)\n",
+ "\n",
+ "    # blend between the original image and the image with the masked area zeroed out\n",
+ "    masked_image = torch.lerp(image, image * (mask < 0.5), inpainting_mask_weight)\n",
+ "\n",
+ "    def expand(t):\n",
+ "        # move to the target device and repeat along the batch axis\n",
+ "        return repeat(t.to(device=device), \"1 ... -> n ...\", n=num_samples)\n",
+ "\n",
+ "    return {\n",
+ "        \"image\": expand(image),\n",
+ "        \"txt\": num_samples * [txt],\n",
+ "        \"mask\": expand(mask),\n",
+ "        \"masked_image\": expand(masked_image),\n",
+ "    }\n",
+ "\n",
+ "def inpainting_conditioning(source_image, image_mask = None, inpainting_mask_weight = 1, sd_model=sd_model):\n",
+ " #based on https://github.com/AUTOMATIC1111/stable-diffusion-webui\n",
+ "\n",
+ " # Handle the different mask inputs\n",
+ " if image_mask is not None:\n",
+ "\n",
+ " if torch.is_tensor(image_mask):\n",
+ "\n",
+ " conditioning_mask = image_mask[:,:1,...]\n",
+ " # print('mask conditioning_mask', conditioning_mask.shape)\n",
+ " else:\n",
+ " print(image_mask.shape, source_image.shape)\n",
+ " # conditioning_mask = np.array(image_mask.convert(\"L\"))\n",
+ " conditioning_mask = image_mask[...,0].astype(np.float32) / 255.0\n",
+ " conditioning_mask = torch.from_numpy(conditioning_mask[None, None]).float()\n",
+ "\n",
+ " # Inpainting model uses a discretized mask as input, so we round to either 1.0 or 0.0\n",
+ " conditioning_mask = torch.round(conditioning_mask)\n",
+ " else:\n",
+ " conditioning_mask = source_image.new_ones(1, 1, *source_image.shape[-2:])\n",
+ " print(conditioning_mask.shape, source_image.shape)\n",
+ " # Create another latent image, this time with a masked version of the original input.\n",
+ " # Smoothly interpolate between the masked and unmasked latent conditioning image using a parameter.\n",
+ " conditioning_mask = conditioning_mask.to(source_image.device).to(source_image.dtype)\n",
+ " conditioning_image = torch.lerp(\n",
+ " source_image,\n",
+ " source_image * (1.0 - conditioning_mask),\n",
+ " inpainting_mask_weight\n",
+ " )\n",
+ "\n",
+ " # Encode the new masked image using first stage of network.\n",
+ " conditioning_image = sd_model.get_first_stage_encoding( sd_model.encode_first_stage(conditioning_image))\n",
+ "\n",
+ " # Create the concatenated conditioning tensor to be fed to `c_concat`\n",
+ " conditioning_mask = torch.nn.functional.interpolate(conditioning_mask, size=conditioning_image.shape[-2:])\n",
+ " conditioning_mask = conditioning_mask.expand(conditioning_image.shape[0], -1, -1, -1)\n",
+ " image_conditioning = torch.cat([conditioning_mask, conditioning_image], dim=1)\n",
+ " image_conditioning = image_conditioning.to('cuda').type( sd_model.dtype)\n",
+ "\n",
+ " return image_conditioning\n",
+ "\n",
+ "import torch\n",
+ "# divisible by 8 fix from AUTOMATIC1111\n",
+ "def cat8(tensors, *args, **kwargs):\n",
+ " if len(tensors) == 2:\n",
+ " a, b = tensors\n",
+ " if a.shape[-2:] != b.shape[-2:]:\n",
+ " a = torch.nn.functional.interpolate(a, b.shape[-2:], mode=\"nearest\")\n",
+ "\n",
+ " tensors = (a, b)\n",
+ "\n",
+ " return torch.cat(tensors, *args, **kwargs)\n",
+ "\n",
+ "from torch import fft\n",
+ "def Fourier_filter(x, threshold, scale):\n",
+ " # FFT\n",
+ " x_freq = fft.fftn(x, dim=(-2, -1))\n",
+ " x_freq = fft.fftshift(x_freq, dim=(-2, -1))\n",
+ "\n",
+ " B, C, H, W = x_freq.shape\n",
+ " mask = torch.ones((B, C, H, W)).cuda()\n",
+ "\n",
+ " crow, ccol = H // 2, W //2\n",
+ " mask[..., crow - threshold:crow + threshold, ccol - threshold:ccol + threshold] = scale\n",
+ " x_freq = x_freq * mask\n",
+ "\n",
+ " # IFFT\n",
+ " x_freq = fft.ifftshift(x_freq, dim=(-2, -1))\n",
+ " x_filtered = fft.ifftn(x_freq, dim=(-2, -1)).real\n",
+ "\n",
+ " return x_filtered\n",
+ "\n",
+ "# FreeU factors read by apply_freeu: b1/b2 multiply part of the backbone features `h`,\n",
+ "# s1/s2 are passed as `scale` to Fourier_filter for the skip features.\n",
+ "b1 = 1.2\n",
+ "b2 = 1.4\n",
+ "s1 = 0.9\n",
+ "s2 = 0.2\n",
+ "\n",
+ "def apply_freeu(h, _hs):\n",
+ "    \"\"\"Apply FreeU re-weighting to a backbone tensor `h` and its skip tensor `_hs`.\n",
+ "\n",
+ "    Only acts when `h` has 1280 or 640 channels (dim 1): the first half of those channels is\n",
+ "    multiplied in place by b1 / b2 and `_hs` is replaced by its Fourier-filtered float32 version\n",
+ "    (threshold 1, scale s1 / s2). For any other channel count both inputs are returned unchanged.\n",
+ "    \"\"\"\n",
+ "    # channel count -> (number of leading channels to boost, backbone factor, skip scale)\n",
+ "    freeu_settings = {1280: (640, b1, s1), 640: (320, b2, s2)}.get(h.shape[1])\n",
+ "    if freeu_settings is not None:\n",
+ "        boosted, backbone_factor, skip_scale = freeu_settings\n",
+ "        h[:,:boosted] = h[:,:boosted] * backbone_factor\n",
+ "        _hs = Fourier_filter(_hs.float(), threshold=1, scale=skip_scale)\n",
+ "    return h, _hs\n",
+ "\n",
+ "\n",
+ "def cldm_forward(x, timesteps=None, context=None, control=None, only_mid_control=False, self = sd_model.model.diffusion_model,**kwargs):\n",
+ "    \"\"\"UNet forward pass with optional ControlNet residuals and FreeU.\n",
+ "\n",
+ "    :param x: input tensor; cast to self.dtype for the pass and the result is cast back to x.dtype.\n",
+ "    :param timesteps: timesteps fed to timestep_embedding.\n",
+ "    :param context: conditioning handed to every block.\n",
+ "    :param control: None, or a list of residuals consumed from the end with pop() - the first one is\n",
+ "        added to the middle block output, the following ones to the skip connections. The list is mutated.\n",
+ "    :param only_mid_control: when true, only the middle block residual is applied.\n",
+ "    :param self: the diffusion model whose blocks are run (bound at definition time).\n",
+ "    :return: self.out applied to the final hidden state.\n",
+ "\n",
+ "    FreeU is driven by the globals do_freeunet / apply_freeu_after_control.\n",
+ "    \"\"\"\n",
+ "    emb = self.time_embed(timestep_embedding(timesteps, self.model_channels, repeat_only=False))\n",
+ "\n",
+ "    # encoder: keep every intermediate activation for the skip connections\n",
+ "    skips = []\n",
+ "    h = x.type(self.dtype)\n",
+ "    for block in self.input_blocks:\n",
+ "        h = block(h, emb, context)\n",
+ "        skips.append(h)\n",
+ "\n",
+ "    h = self.middle_block(h, emb, context)\n",
+ "    if control is not None:\n",
+ "        h += control.pop()\n",
+ "\n",
+ "    # decoder: pair each block with the most recent unused skip activation\n",
+ "    for block in self.output_blocks:\n",
+ "        skip = skips.pop()\n",
+ "\n",
+ "        if do_freeunet and not apply_freeu_after_control:\n",
+ "            h, skip = apply_freeu(h, skip)\n",
+ "\n",
+ "        if control is not None and not only_mid_control:\n",
+ "            skip += control.pop()\n",
+ "\n",
+ "        if do_freeunet and apply_freeu_after_control:\n",
+ "            h, skip = apply_freeu(h, skip)\n",
+ "\n",
+ "        h = block(cat8([h, skip], dim=1), emb, context)\n",
+ "\n",
+ "    return self.out(h.type(x.dtype))\n",
+ "\n",
+ "def sdxl_cn_forward(x, timesteps=None, context=None, y=None, control=None, self=sd_model.model.diffusion_model,\n",
+ "                    only_mid_control=False, **kwargs):\n",
+ "    \"\"\"\n",
+ "    UNet forward pass for SDXL with optional ControlNet residuals and FreeU.\n",
+ "    :param x: an [N x C x ...] Tensor of inputs.\n",
+ "    :param timesteps: a 1-D batch of timesteps.\n",
+ "    :param context: conditioning plugged in via crossattn\n",
+ "    :param y: label/vector conditioning; required exactly when self.num_classes is not None.\n",
+ "    :param control: None, or a list of residuals consumed from the end with pop() (the list is mutated).\n",
+ "    :param only_mid_control: when true, only the middle block residual is applied.\n",
+ "    :return: self.out applied to the final hidden state, cast to x.dtype beforehand.\n",
+ "    \"\"\"\n",
+ "    class_conditional = self.num_classes is not None\n",
+ "    assert (y is not None) == class_conditional, \"must specify y if and only if the model is class-conditional\"\n",
+ "\n",
+ "    emb = self.time_embed(timestep_embedding(timesteps, self.model_channels, repeat_only=False))\n",
+ "    if class_conditional:\n",
+ "        assert y.shape[0] == x.shape[0]\n",
+ "        emb = emb + self.label_emb(y)\n",
+ "\n",
+ "    # encoder: keep every intermediate activation for the skip connections\n",
+ "    skips = []\n",
+ "    h = x\n",
+ "    for block in self.input_blocks:\n",
+ "        h = block(h, emb, context)\n",
+ "        skips.append(h)\n",
+ "\n",
+ "    h = self.middle_block(h, emb, context)\n",
+ "    if control is not None:\n",
+ "        h += control.pop()\n",
+ "\n",
+ "    # decoder: FreeU runs before or after the control residual depending on apply_freeu_after_control\n",
+ "    for block in self.output_blocks:\n",
+ "        skip = skips.pop()\n",
+ "\n",
+ "        if do_freeunet and not apply_freeu_after_control:\n",
+ "            h, skip = apply_freeu(h, skip)\n",
+ "\n",
+ "        if control is not None and not only_mid_control:\n",
+ "            skip += control.pop()\n",
+ "\n",
+ "        if do_freeunet and apply_freeu_after_control:\n",
+ "            h, skip = apply_freeu(h, skip)\n",
+ "\n",
+ "        h = block(cat8([h, skip], dim=1), emb, context)\n",
+ "\n",
+ "    return self.out(h.type(x.dtype))\n",
+ "\n",
+ "if model_version == 'control_multi_sdxl':\n",
+ " sd_model.model.diffusion_model.forward = sdxl_cn_forward\n",
+ "\n",
+ "try:\n",
+ " if 'sdxl' not in model_version:\n",
+ " sd_model.model.diffusion_model.forward = cldm_forward\n",
+ "\n",
+ "except Exception as e:\n",
+ " print(e)\n",
+ " # pass\n",
+ "\n",
+ "if 'sdxl' in model_version:\n",
+ " @torch.enable_grad()\n",
+ " def differentiable_decode_first_stage(z, self=sd_model):\n",
+ " z = 1.0 / self.scale_factor * z\n",
+ " with torch.autocast(\"cuda\", enabled=not self.disable_first_stage_autocast):\n",
+ " out = self.first_stage_model.decode(z)\n",
+ " return out\n",
+ " sd_model.differentiable_decode_first_stage = differentiable_decode_first_stage\n",
+ "\n",
+ "#from colab\n",
+ "def apply_model_sdxl_cn(x_noisy, t, cond, self=sd_model, *args, **kwargs):\n",
+ " if 'sdxl' in model_version:\n",
+ " y = cond['y']\n",
+ " else: y = None\n",
+ "\n",
+ " t_ratio = 1-t[0]/self.num_timesteps;\n",
+ " assert isinstance(cond, dict)\n",
+ " diffusion_model = self.model.diffusion_model\n",
+ " cond_txt = torch.cat(cond['c_crossattn'], 1)\n",
+ " # y = torch.cat(cond['y'], 1)\n",
+ "\n",
+ " #if dict - we've got a multicontrolnet\n",
+ " if cond['c_concat'] is None:\n",
+ " control = None\n",
+ " eps = diffusion_model(x=x_noisy, timesteps=t, context=cond_txt, control=control, only_mid_control=self.only_mid_control, y=y)\n",
+ " return eps\n",
+ "\n",
+ " if isinstance(cond['c_concat'], dict):\n",
+ " try:\n",
+ " uc_mask_shape = diffusion_model.uc_mask_shape[:, None, None, None]\n",
+ "\n",
+ " except:\n",
+ " uc_mask_shape = torch.ones(x_noisy.shape[0], device=x_noisy.device)\n",
+ " controlnet_multimodel = cond['controlnet_multimodel']\n",
+ " loaded_controlnets = cond['loaded_controlnets']\n",
+ " control_wsum = None\n",
+ " #loop throught all controlnets to get controls\n",
+ " active_models = {}\n",
+ " for key in controlnet_multimodel.keys():\n",
+ " settings = controlnet_multimodel[key]\n",
+ " if settings['weight']!=0 and t_ratio>=settings['start'] and t_ratio<=settings['end']:\n",
+ " active_models[key] = controlnet_multimodel[key]\n",
+ " weights = np.array([active_models[m][\"weight\"] for m in active_models.keys()])\n",
+ " if self.normalize_weights: weights = weights/weights.sum()\n",
+ "\n",
+ " if self.low_vram:\n",
+ " diffusion_model.cpu()\n",
+ " for key in active_models.keys():\n",
+ " loaded_controlnets[key].cpu()\n",
+ " torch.cuda.empty_cache(); gc.collect()\n",
+ " for i,key in enumerate(active_models.keys()):\n",
+ " if self.debug:\n",
+ " print('controlnet_multimodel keys ', controlnet_multimodel[key].keys())\n",
+ " print('Using layer weights ', controlnet_multimodel[key]['layer_weights'], key)\n",
+ " try:\n",
+ " cond_hint = torch.cat([cond['c_concat'][key]], 1)\n",
+ " if 'zero_uncond' in controlnet_multimodel[key].keys():\n",
+ " if controlnet_multimodel[key]['zero_uncond']:\n",
+ " if self.debug: print(f'Using zero uncond {list(uc_mask_shape.detach().cpu().numpy())} for {key}')\n",
+ " cond_hint*=uc_mask_shape # try zeroing the prediction, should mimic zero uncond, need to research more\n",
+ " if self.low_vram:\n",
+ " loaded_controlnets[key].half().to(device=cond_hint.device)\n",
+ " with torch.autocast('cuda'), torch.no_grad(), torch.inference_mode():\n",
+ " control = loaded_controlnets[key].cuda()(x=x_noisy, hint=cond_hint,\n",
+ " timesteps=t, context=cond_txt, y=y)\n",
+ "\n",
+ " if 'layer_weights' in controlnet_multimodel[key].keys():\n",
+ " control_scales = controlnet_multimodel[key]['layer_weights']\n",
+ " if self.debug: print('Using layer weights ', control_scales, key)\n",
+ " control_scales = control_scales[:len(control)]\n",
+ " control = [c * scale for c, scale in zip(control, control_scales)]\n",
+ " if key == 'control_sd15_shuffle':\n",
+ " #apply avg pooling for shuffle control\n",
+ " control = [torch.mean(c, dim=(2, 3), keepdim=True) for c in control]\n",
+ " if control_wsum is None: control_wsum = [weights[i]*o for o in control]\n",
+ " else: control_wsum = [weights[i]*c+cs for c,cs in zip(control,control_wsum)]\n",
+ " if self.low_vram:\n",
+ " loaded_controlnets[key].cpu()\n",
+ " torch.cuda.empty_cache(); gc.collect()\n",
+ " except Exception as e:\n",
+ " assert type(e) != torch.cuda.OutOfMemoryError, 'Got CUDA out of memory during ControlNet proccessing.'\n",
+ " print(e)\n",
+ " control = control_wsum\n",
+ " else:\n",
+ " cond_hint = torch.cat(cond['c_concat'], 1)\n",
+ " control = self.control_model(x=x_noisy, hint=cond_hint,\n",
+ " timesteps=t, context=cond_txt, y=y)\n",
+ "\n",
+ " if control is not None:\n",
+ " control = [c * scale for c, scale in zip(control, self.control_scales)]\n",
+ " if self.global_average_pooling:\n",
+ " control = [torch.mean(c, dim=(2, 3), keepdim=True) for c in control]\n",
+ " if self.low_vram:\n",
+ " for key in active_models.keys():\n",
+ " loaded_controlnets[key].cpu()\n",
+ " torch.cuda.empty_cache(); gc.collect()\n",
+ " diffusion_model.half().cuda()\n",
+ "\n",
+ " eps = diffusion_model(x=x_noisy, timesteps=t, context=cond_txt, control=control, y=y, only_mid_control=self.only_mid_control)\n",
+ "\n",
+ " return eps\n",
+ "if model_version == 'control_multi':\n",
+ " #try using it with v1.5 cn as well\n",
+ " sd_model.apply_model = apply_model_sdxl_cn\n",
+ "if model_version == 'control_multi_sdxl':\n",
+ " sd_model.apply_model = apply_model_sdxl_cn\n",
+ " os.chdir(f'{root_dir}/ComfyUI')\n",
+ " import sys, os\n",
+ " sys.path.insert(0, os.path.join(os.path.dirname(os.path.realpath('./')), \"comfy\"))\n",
+ " sys.argv=['']\n",
+ " import os, sys\n",
+ " sys.path.append('./comfy')\n",
+ " from comfy.sd import load_controlnet\n",
+ " os.chdir(f'{root_dir}')\n",
+ "\n",
+ " sd_model.num_timesteps = 1000\n",
+ " sd_model.debug = False\n",
+ " sd_model.global_average_pooling = False\n",
+ " sd_model.only_mid_control = False\n",
+ "\n",
+ " import comfy\n",
+ " sd_model.model.model_config = nn.Module() #dummy module to assign config to it\n",
+ " sd_model.model.model_config.unet_config = OmegaConf.load(config_path).model.params.network_config.params\n",
+ " sd_model.model.model_config.unet_config = dict(sd_model.model.model_config.unet_config)\n",
+ " sd_model.model.model_config.unet_config['out_channels'] = ''\n",
+ " sd_model.model.model_config.unet_config.pop('spatial_transformer_attn_type')\n",
+ " sd_model.model.model_config.unet_config['image_size'] = 32\n",
+ " def get_dtype(self=sd_model.model):\n",
+ " return next(self.parameters()).dtype\n",
+ " sd_model.model.get_dtype = get_dtype\n",
+ "\n",
+ "executed_cells[cell_name] = True\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "z36v90fNgLMF"
+ },
+ "source": [
+ "# Extra features"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "cellView": "form",
+ "id": "37oGv9dhVRDE"
+ },
+ "outputs": [],
+ "source": [
+ "#@title Tiled VAE\n",
+ "#@markdown Enable if you're getting CUDA Out of memory errors during encode_first_stage or decode_first_stage.\n",
+ "#@markdown Is slower.\n",
+ "#tiled vae from https://github.com/CompVis/latent-diffusion\n",
+ "cell_name = 'tiled_vae'\n",
+ "check_execution(cell_name)\n",
+ "\n",
+ "use_tiled_vae = True #@param {'type':'boolean'}\n",
+ "tile_size = 128 #\\@param {'type':'number'}\n",
+ "stride = 96 #\\@param {'type':'number'}\n",
+ "#@markdown how many tiles per side [H,W]\n",
+ "num_tiles = [2,2] #@param {'type':'raw'}\n",
+ "padding = [0.5,0.5] #\\@param {'type':'raw'}\n",
+ "\n",
+ "# Normalise the 'unset' values (0, empty string, None) to None before anything indexes into them.\n",
+ "if num_tiles in [0, '', None]:\n",
+ "    num_tiles = None\n",
+ "\n",
+ "if padding in [0, '', None]:\n",
+ "    padding = None\n",
+ "\n",
+ "# The summary needs num_tiles[0] and num_tiles[1]; skip it when tiling by count is disabled\n",
+ "# (it used to run before the normalisation above and crashed for 0 / '' / None).\n",
+ "if num_tiles is not None:\n",
+ "    print(f'Splitting WxH {width_height} into {num_tiles[0]*num_tiles[1]} {width_height[0]//num_tiles[1]}x{width_height[1]//num_tiles[0]} tiles' )\n",
+ "def get_fold_unfold( x, kernel_size, stride, uf=1, df=1, self=sd_model): # todo load once not every time, shorten code\n",
+ "    \"\"\"\n",
+ "    Build the Unfold / Fold pair used to cut `x` into crops and to stitch processed crops back together.\n",
+ "    :param x: img of size (bs, c, h, w)\n",
+ "    :param kernel_size: crop size (h, w) on the input side\n",
+ "    :param stride: crop stride (h, w) on the input side\n",
+ "    :param uf: upscale factor of the output relative to the input (only with df == 1)\n",
+ "    :param df: downscale factor of the output relative to the input (only with uf == 1)\n",
+ "    :return: fold, unfold, normalization (fold of the weights, zeros replaced by 1e-6), weighting\n",
+ "    :raises NotImplementedError: for any other uf / df combination\n",
+ "    \"\"\"\n",
+ "    bs, nc, h, w = x.shape\n",
+ "\n",
+ "    # number of crops in image\n",
+ "    Ly = (h - kernel_size[0]) // stride[0] + 1\n",
+ "    Lx = (w - kernel_size[1]) // stride[1] + 1\n",
+ "\n",
+ "    # pick how input-side sizes map to output-side sizes\n",
+ "    if uf == 1 and df == 1:\n",
+ "        to_output = lambda size: size\n",
+ "    elif uf > 1 and df == 1:\n",
+ "        to_output = lambda size: size * uf\n",
+ "    elif df > 1 and uf == 1:\n",
+ "        to_output = lambda size: size // df\n",
+ "    else:\n",
+ "        raise NotImplementedError\n",
+ "\n",
+ "    # unfold always works on the input resolution\n",
+ "    unfold = torch.nn.Unfold(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)\n",
+ "\n",
+ "    # fold works on the output resolution\n",
+ "    out_kh, out_kw = to_output(kernel_size[0]), to_output(kernel_size[1])\n",
+ "    out_h, out_w = to_output(h), to_output(w)\n",
+ "    fold = torch.nn.Fold(output_size=(out_h, out_w),\n",
+ "                         kernel_size=(out_kh, out_kw),\n",
+ "                         dilation=1, padding=0,\n",
+ "                         stride=(to_output(stride[0]), to_output(stride[1])))\n",
+ "\n",
+ "    weighting = self.get_weighting(out_kh, out_kw, Ly, Lx, x.device).to(x.dtype)\n",
+ "    normalization = fold(weighting).view(1, 1, out_h, out_w) # normalizes the overlap\n",
+ "    weighting = weighting.view((1, 1, out_kh, out_kw, Ly * Lx))\n",
+ "\n",
+ "    # avoid dividing by zero where no crop contributes\n",
+ "    normalization = torch.where(normalization==0.,1e-6, normalization)\n",
+ "    # if 'sdxl' in model_version:\n",
+ "    # normalization = torch.where(normalization!=1.,1, 1)\n",
+ "    # weighting = torch.where(weighting!=1.,1, 1)\n",
+ "    return fold, unfold, normalization, weighting\n",
+ "\n",
+ "#non divisible by 8 fails here\n",
+ "\n",
+ "@torch.no_grad()\n",
+ "def encode_first_stage(x, self=sd_model):\n",
+ " ts = time.time()\n",
+ " if hasattr(self, \"split_input_params\"):\n",
+ " with torch.autocast('cuda'):\n",
+ " # if no_half_vae:\n",
+ " # self.first_stage_model = self.first_stage_model.float()\n",
+ " # x = x.float()\n",
+ " if self.split_input_params[\"patch_distributed_vq\"]:\n",
+ " print('------using tiled vae------')\n",
+ " bs, nc, h, w = x.shape\n",
+ " df = self.split_input_params[\"vqf\"]\n",
+ " if self.split_input_params[\"num_tiles\"] is not None:\n",
+ " num_tiles = self.split_input_params[\"num_tiles\"]\n",
+ " ks = [h//num_tiles[0], w//num_tiles[1]]\n",
+ " else:\n",
+ " ks = self.split_input_params[\"ks\"] # eg. (128, 128)\n",
+ " ks = [o*(df) for o in ks]\n",
+ " # ks = self.split_input_params[\"ks\"] # eg. (128, 128)\n",
+ " # ks = [o*df for o in ks]\n",
+ "\n",
+ "\n",
+ " if self.split_input_params[\"padding\"] is not None:\n",
+ " padding = self.split_input_params[\"padding\"]\n",
+ " stride = [int(ks[0]*padding[0]), int(ks[1]*padding[1])]\n",
+ " else:\n",
+ " stride = self.split_input_params[\"stride\"] # eg. (64, 64)\n",
+ " stride = [o*(df) for o in stride]\n",
+ " # stride = self.split_input_params[\"stride\"] # eg. (64, 64)\n",
+ " # stride = [o*df for o in stride]\n",
+ " # ks = [512,512]\n",
+ " # stride = [512,512]\n",
+ "\n",
+ "\n",
+ " # print('kernel, stride', ks, stride)\n",
+ "\n",
+ " self.split_input_params['original_image_size'] = x.shape[-2:]\n",
+ " bs, nc, h, w = x.shape\n",
+ "\n",
+ " target_h = math.ceil(h/ks[0])*ks[0]\n",
+ " target_w = math.ceil(w/ks[1])*ks[1]\n",
+ " padh = target_h - h\n",
+ " padw = target_w - w\n",
+ " pad = (0, padw, 0, padh)\n",
+ " if target_h != h or target_w != w:\n",
+ " print('Padding.')\n",
+ " # print('padding from ', h, w, 'to ', target_h, target_w)\n",
+ " x = torch.nn.functional.pad(x, pad, mode='reflect')\n",
+ " # print('padded from ', h, w, 'to ', z.shape[2:])\n",
+ "\n",
+ " if ks[0] > h or ks[1] > w:\n",
+ " ks = (min(ks[0], h), min(ks[1], w))\n",
+ " print(\"reducing Kernel\")\n",
+ "\n",
+ " if stride[0] > h or stride[1] > w:\n",
+ " stride = (min(stride[0], h), min(stride[1], w))\n",
+ " print(\"reducing stride\")\n",
+ "\n",
+ " fold, unfold, normalization, weighting = get_fold_unfold(x, ks, stride, df=df)\n",
+ " z = unfold(x) # (bn, nc * prod(**ks), L)\n",
+ " # Reshape to img shape\n",
+ " z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L )\n",
+ " # print('z', z.shape)\n",
+ " # print('z stats', z.min(), z.max(), z.std(), z.mean())\n",
+ "\n",
+ " if no_half_vae:\n",
+ " self.disable_first_stage_autocast = True\n",
+ " self.first_stage_model.float()\n",
+ " z = z.float()\n",
+ " with torch.autocast('cuda', enabled=False):\n",
+ " output_list = [self.get_first_stage_encoding(self.first_stage_model.encode(z[:, :, :, :, i].float()), tiled_vae_call=True)\n",
+ " for i in range(z.shape[-1])]\n",
+ " # o = self.first_stage_model.encode(z[:, :, :, :, 0].float())\n",
+ " # o = self.scale_factor * o\n",
+ " # print('o stats', o.min(), o.max(), o.std(), o.mean())\n",
+ " # print(z.shape)\n",
+ " else:\n",
+ " output_list = [self.get_first_stage_encoding(self.first_stage_model.encode(z[:, :, :, :, i]), tiled_vae_call=True)\n",
+ " for i in range(z.shape[-1])]\n",
+ " # print('output_list stats', output_list[0].min(), output_list[0].max(), output_list[0].std(), output_list[0].mean())\n",
+ " o = torch.stack(output_list, axis=-1)\n",
+ " if 'sdxl' in model_version:\n",
+ " o = self.scale_factor * o\n",
+ " # print('o stats', o.min(), o.max(), o.std(), o.mean())\n",
+ " o = o * weighting\n",
+ " # print('o stats', o.min(), o.max(), o.std(), o.mean())\n",
+ " # print('o', o.shape)\n",
+ " # Reverse reshape to img shape\n",
+ " o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L)\n",
+ " # print('o stats', o.min(), o.max(), o.std(), o.mean())\n",
+ " # stitch crops together\n",
+ " decoded = fold(o)\n",
+ " # print('decoded stats', decoded.min(), decoded.max(), decoded.std(), decoded.mean())\n",
+ " decoded = decoded / normalization\n",
+ " print('Tiled vae encoder took ', f'{time.time()-ts:.2}')\n",
+ " # print('decoded stats', decoded.min(), decoded.max(), decoded.std(), decoded.mean())\n",
+ " return decoded[...,:h//df, :w//df]\n",
+ "\n",
+ " else:\n",
+ " print('Vae encoder took ', f'{time.time()-ts:.2}')\n",
+ " # print('x stats', x.min(), x.max(), x.std(), x.mean())\n",
+ " return self.first_stage_model.encode(x)\n",
+ " else:\n",
+ " print('Vae encoder took ', f'{time.time()-ts:.2}')\n",
+ " # print('x stats', x.min(), x.max(), x.std(), x.mean())\n",
+ " return self.first_stage_model.encode(x)\n",
+ "\n",
+ "@torch.no_grad()\n",
+ "def decode_first_stage(z, predict_cids=False, force_not_quantize=False, self=sd_model):\n",
+ " ts = time.time()\n",
+ " if predict_cids:\n",
+ " if z.dim() == 4:\n",
+ " z = torch.argmax(z.exp(), dim=1).long()\n",
+ " z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None)\n",
+ " z = rearrange(z, 'b h w c -> b c h w').contiguous()\n",
+ "\n",
+ " z = 1. / self.scale_factor * z\n",
+ "\n",
+ " if hasattr(self, \"split_input_params\"):\n",
+ " with torch.autocast('cuda'):\n",
+ "\n",
+ " print('------using tiled vae------')\n",
+ " # print('latent shape: ', z.shape)\n",
+ " # print(self.split_input_params)\n",
+ " if self.split_input_params[\"patch_distributed_vq\"]:\n",
+ " bs, nc, h, w = z.shape\n",
+ " if self.split_input_params[\"num_tiles\"] is not None:\n",
+ " num_tiles = self.split_input_params[\"num_tiles\"]\n",
+ " ks = [h//num_tiles[0], w//num_tiles[1]]\n",
+ " else:\n",
+ " ks = self.split_input_params[\"ks\"] # eg. (128, 128)\n",
+ "\n",
+ " if self.split_input_params[\"padding\"] is not None:\n",
+ " padding = self.split_input_params[\"padding\"]\n",
+ " stride = [int(ks[0]*padding[0]), int(ks[1]*padding[1])]\n",
+ " else:\n",
+ " stride = self.split_input_params[\"stride\"] # eg. (64, 64)\n",
+ "\n",
+ " uf = self.split_input_params[\"vqf\"]\n",
+ "\n",
+ "\n",
+ "\n",
+ " target_h = math.ceil(h/ks[0])*ks[0]\n",
+ " target_w = math.ceil(w/ks[1])*ks[1]\n",
+ " padh = target_h - h\n",
+ " padw = target_w - w\n",
+ " pad = (0, padw, 0, padh)\n",
+ " if target_h != h or target_w != w:\n",
+ " print('Padding.')\n",
+ " # print('padding from ', h, w, 'to ', target_h, target_w)\n",
+ " z = torch.nn.functional.pad(z, pad, mode='reflect')\n",
+ " # print('padded from ', h, w, 'to ', z.shape[2:])\n",
+ "\n",
+ "\n",
+ " if ks[0] > h or ks[1] > w:\n",
+ " ks = (min(ks[0], h), min(ks[1], w))\n",
+ " print(\"reducing Kernel\")\n",
+ "\n",
+ " if stride[0] > h or stride[1] > w:\n",
+ " stride = (min(stride[0], h), min(stride[1], w))\n",
+ " print(\"reducing stride\")\n",
+ "\n",
+ "\n",
+ "\n",
+ " # print(ks, stride)\n",
+ " fold, unfold, normalization, weighting = get_fold_unfold(z, ks, stride, uf=uf)\n",
+ "\n",
+ "\n",
+ " z = unfold(z) # (bn, nc * prod(**ks), L)\n",
+ " # print('z unfold, normalization, weighting',z.shape, normalization.shape, weighting.shape)\n",
+ " # 1. Reshape to img shape\n",
+ " z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L )\n",
+ " # print('z unfold view , normalization, weighting',z.shape)\n",
+ " # 2. apply model loop over last dim\n",
+ "\n",
+ " if no_half_vae:\n",
+ " with torch.autocast('cuda', enabled=False):\n",
+ " self.disable_first_stage_autocast = True\n",
+ " self.first_stage_model.float()\n",
+ " z = z.float()\n",
+ " output_list = [self.first_stage_model.decode(z[:, :, :, :, i].float())\n",
+ " for i in range(z.shape[-1])]\n",
+ " else:\n",
+ " output_list = [self.first_stage_model.decode(z[:, :, :, :, i])\n",
+ " for i in range(z.shape[-1])]\n",
+ "\n",
+ " o = torch.stack(output_list, axis=-1) # # (bn, nc, ks[0], ks[1], L)\n",
+ " # print('out stack', o.shape)\n",
+ "\n",
+ " o = o * weighting\n",
+ " # Reverse 1. reshape to img shape\n",
+ " o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L)\n",
+ " # stitch crops together\n",
+ " decoded = fold(o)\n",
+ " decoded = decoded / normalization # norm is shape (1, 1, h, w)\n",
+ " print('Tiled vae decoder took ', f'{time.time()-ts:.2}')\n",
+ " # print('decoded stats', decoded.min(), decoded.max(), decoded.std(), decoded.mean())\n",
+ " # assert False\n",
+ " return decoded[...,:h*uf, :w*uf]\n",
+ " else:\n",
+ " print('Vae decoder took ', f'{time.time()-ts:.2}')\n",
+ " # print('z stats', z.min(), z.max(), z.std(), z.mean())\n",
+ " return self.first_stage_model.decode(z)\n",
+ "\n",
+ " else:\n",
+ " # print('z stats', z.min(), z.max(), z.std(), z.mean())\n",
+ " print('Vae decoder took ', f'{time.time()-ts:.2}')\n",
+ " return self.first_stage_model.decode(z)\n",
+ "\n",
+ "\n",
+ "\n",
+ "def get_first_stage_encoding(encoder_posterior, self=sd_model, tiled_vae_call=False):\n",
+ "    \"\"\"Turn the first stage encoder output into a scaled latent.\n",
+ "\n",
+ "    :param encoder_posterior: a DiagonalGaussianDistribution (sampled) or a tensor (used as is).\n",
+ "    :param self: model providing split_input_params / is_sdxl / scale_factor (bound at definition time).\n",
+ "    :param tiled_vae_call: True when called per crop from the tiled encoder.\n",
+ "    :return: the input unchanged when tiling is enabled and this is not a per-crop call, or when the\n",
+ "        model is SDXL; otherwise scale_factor times the sampled/given latent.\n",
+ "    :raises NotImplementedError: for any other encoder_posterior type.\n",
+ "    \"\"\"\n",
+ "    if hasattr(self, \"split_input_params\") and not tiled_vae_call:\n",
+ "        #pass for tiled vae\n",
+ "        return encoder_posterior\n",
+ "    # Ask the model that was passed in, not whatever the global `sd_model` currently points to,\n",
+ "    # so the check stays consistent with the scale_factor used below.\n",
+ "    if self.is_sdxl:\n",
+ "        # print('skipping for sdxl')\n",
+ "        return encoder_posterior\n",
+ "    if isinstance(encoder_posterior, DiagonalGaussianDistribution):\n",
+ "        z = encoder_posterior.sample()\n",
+ "    elif isinstance(encoder_posterior, torch.Tensor):\n",
+ "        z = encoder_posterior\n",
+ "    else:\n",
+ "        raise NotImplementedError(f\"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented\")\n",
+ "    return self.scale_factor * z\n",
+ "\n",
+ "# Enable or disable the tiled VAE by swapping the encode/decode helpers on sd_model.\n",
+ "# `_tiled_vae_patched` on the model records whether the helpers currently installed are the tiled ones, so that\n",
+ "# - re-running this cell with tiling enabled does not overwrite the backups with the tiled functions, and\n",
+ "# - backups taken from a previously loaded model are never restored onto a freshly loaded one.\n",
+ "if use_tiled_vae:\n",
+ "\n",
+ "    ks = tile_size\n",
+ "    vqf = 8 # read as 'vqf' by the tiled encode/decode: scale factor between image and latent sizes\n",
+ "    split_input_params = {\"ks\": (ks,ks), \"stride\": (stride, stride),\n",
+ "                          \"num_tiles\": num_tiles, \"padding\": padding,\n",
+ "                          \"vqf\": vqf,\n",
+ "                          \"patch_distributed_vq\": True,\n",
+ "                          \"tie_braker\": False,\n",
+ "                          \"clip_max_weight\": 0.5,\n",
+ "                          \"clip_min_weight\": 0.01,\n",
+ "                          \"clip_max_tie_weight\": 0.5,\n",
+ "                          \"clip_min_tie_weight\": 0.01}\n",
+ "\n",
+ "    # split_input_params = {\"ks\": (ks,ks), \"stride\": (stride, stride),\n",
+ "    # \"num_tiles\": num_tiles, \"padding\": padding,\n",
+ "    # \"vqf\": vqf,\n",
+ "    # \"patch_distributed_vq\": True,\n",
+ "    # \"tie_braker\": False,\n",
+ "    # \"clip_max_weight\": 0.9,\n",
+ "    # \"clip_min_weight\": 0.001,\n",
+ "    # \"clip_max_tie_weight\": 0.9,\n",
+ "    # \"clip_min_tie_weight\": 0.001}\n",
+ "\n",
+ "    # Snapshot the original helpers only while they are still the originals.\n",
+ "    if not getattr(sd_model, '_tiled_vae_patched', False):\n",
+ "        bkup_decode_first_stage = sd_model.decode_first_stage\n",
+ "        bkup_encode_first_stage = sd_model.encode_first_stage\n",
+ "        bkup_get_first_stage_encoding = sd_model.get_first_stage_encoding\n",
+ "        try:\n",
+ "            bkup_get_fold_unfold = sd_model.get_fold_unfold\n",
+ "        except AttributeError:\n",
+ "            # the model has no get_fold_unfold of its own - nothing to back up\n",
+ "            pass\n",
+ "\n",
+ "    sd_model.split_input_params = split_input_params\n",
+ "    sd_model.decode_first_stage = decode_first_stage\n",
+ "    sd_model.encode_first_stage = encode_first_stage\n",
+ "    sd_model.get_first_stage_encoding = get_first_stage_encoding\n",
+ "    sd_model.get_fold_unfold = get_fold_unfold\n",
+ "    sd_model._tiled_vae_patched = True\n",
+ "\n",
+ "else:\n",
+ "    if hasattr(sd_model, \"split_input_params\"):\n",
+ "        delattr(sd_model, \"split_input_params\")\n",
+ "    # Restore only if this very model object was patched by a previous run of this cell.\n",
+ "    if getattr(sd_model, '_tiled_vae_patched', False):\n",
+ "        sd_model.decode_first_stage = bkup_decode_first_stage\n",
+ "        sd_model.encode_first_stage = bkup_encode_first_stage\n",
+ "        sd_model.get_first_stage_encoding = bkup_get_first_stage_encoding\n",
+ "        try:\n",
+ "            sd_model.get_fold_unfold = bkup_get_fold_unfold\n",
+ "        except NameError:\n",
+ "            # no backup exists because the model never had its own get_fold_unfold\n",
+ "            pass\n",
+ "        sd_model._tiled_vae_patched = False\n",
+ "\n",
+    "def get_weighting(h, w, Ly, Lx, device, self=sd_model):\n",
+    "    \"\"\"Build the per-pixel blending weights used when tiles are stitched together.\n",
+    "\n",
+    "    The border-distance map of an h x w tile is clipped to the configured\n",
+    "    min/max weight, flattened to (1, h*w, 1) and repeated for all Ly*Lx tiles.\n",
+    "    When `tie_braker` is set in `self.split_input_params`, the result is also\n",
+    "    multiplied by a clipped border-distance map computed over the Ly x Lx tile grid.\n",
+    "    \"\"\"\n",
+    "    params = self.split_input_params\n",
+    "    tile_weight = torch.clip(delta_border(h, w),\n",
+    "                             params[\"clip_min_weight\"],\n",
+    "                             params[\"clip_max_weight\"])\n",
+    "    tile_weight = tile_weight.view(1, h * w, 1).repeat(1, 1, Ly * Lx).to(device)\n",
+    "    if not params[\"tie_braker\"]:\n",
+    "        return tile_weight\n",
+    "\n",
+    "    grid_weight = torch.clip(delta_border(Ly, Lx),\n",
+    "                             params[\"clip_min_tie_weight\"],\n",
+    "                             params[\"clip_max_tie_weight\"])\n",
+    "    return tile_weight * grid_weight.view(1, 1, Ly * Lx).to(device)\n",
+ "\n",
+    "def meshgrid(h, w):\n",
+    "    \"\"\"Return an (h, w, 2) integer tensor holding the (row, column) index of every cell.\"\"\"\n",
+    "    rows = torch.arange(0, h).view(h, 1, 1).expand(h, w, 1)\n",
+    "    cols = torch.arange(0, w).view(1, w, 1).expand(h, w, 1)\n",
+    "    # cat materialises the expanded views into a regular tensor\n",
+    "    return torch.cat((rows, cols), dim=-1)\n",
+    "def delta_border(h, w):\n",
+    "    \"\"\"\n",
+    "    :param h: height\n",
+    "    :param w: width\n",
+    "    :return: normalized distance to image border,\n",
+    "     with min distance = 0 at border and max dist = 0.5 at image center\n",
+    "    \"\"\"\n",
+    "    # largest row/column index, used to scale positions into the 0..1 range\n",
+    "    extent = torch.tensor([h - 1, w - 1]).view(1, 1, 2)\n",
+    "    rel_pos = meshgrid(h, w) / extent\n",
+    "    to_top_left = rel_pos.min(dim=-1, keepdim=True)[0]\n",
+    "    to_bottom_right = (1 - rel_pos).min(dim=-1, keepdim=True)[0]\n",
+    "    return torch.cat([to_top_left, to_bottom_right], dim=-1).min(dim=-1)[0]\n",
+ "\n",
+    "# For SDXL model versions the weighting helper defined above is attached to the model object.\n",
+    "if 'sdxl' in model_version:\n",
+    "  sd_model.get_weighting = get_weighting\n",
+    "\n",
+    "# Record that this cell has finished running.\n",
+    "executed_cells[cell_name] = True\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "cellView": "form",
+ "id": "wc8CzsYuMLb2"
+ },
+ "outputs": [],
+ "source": [
+    "#@title Save loaded model\n",
+    "#@markdown For this cell to work you need to load model in the previous cell.\\\n",
+    "#@markdown Saves an already loaded model as an object file, that weighs less, loads faster, and requires less CPU RAM.\\\n",
+    "#@markdown After saving model as pickle, you can then load it as your usual stable diffusion model in the cell above.\\\n",
+    "#@markdown The model will be saved under the same name with .pkl extension.\n",
+    "cell_name = 'save_loaded_model'\n",
+    "check_execution(cell_name)\n",
+    "\n",
+    "save_model_pickle = False #@param {'type':'boolean'}\n",
+    "save_folder = \"/content/drive/MyDrive/models\" #@param {'type':'string'}\n",
+    "if save_folder != '' and save_model_pickle:\n",
+    "  os.makedirs(save_folder, exist_ok=True)\n",
+    "  # Keep only the checkpoint's file name and swap its extension for .pkl;\n",
+    "  # splitext only strips the last extension, so dots inside the name survive.\n",
+    "  model_file_name = os.path.basename(model_path.replace('\\\\', '/'))\n",
+    "  # os.path.join adds the separator between the folder and the file name\n",
+    "  out_path = os.path.join(save_folder, os.path.splitext(model_file_name)[0]+'.pkl')\n",
+    "  with open(out_path, 'wb') as f:\n",
+    "    pickle.dump(sd_model, f)\n",
+    "  print('Model successfully saved as: ',out_path)\n",
+    "\n",
+    "executed_cells[cell_name] = True\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "mcI6h0A7NcZ-"
+ },
+ "source": [
+ "## CLIP guidance"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "cellView": "form",
+ "id": "Rz341_ND0U90"
+ },
+ "outputs": [],
+ "source": [
+    "#@title CLIP guidance settings\n",
+    "#@markdown You can use clip guidance to further push style towards your text input.\\\n",
+    "#@markdown Please note that enabling it (by using clip_guidance_scale>0) will greatly increase render times and VRAM usage.\\\n",
+    "#@markdown For now it does 1 sample of the whole image per step (similar to 1 outer_cut in discodiffusion).\n",
+    "cell_name = 'clip_guidance'\n",
+    "check_execution(cell_name)\n",
+    "# clip_type, clip_pretrain = 'ViT-B-32-quickgelu', 'laion400m_e32'\n",
+    "# clip_type, clip_pretrain ='ViT-L-14', 'laion2b_s32b_b82k'\n",
+    "clip_type = 'ViT-H-14' #@param ['ViT-L-14','ViT-B-32-quickgelu', 'ViT-H-14']\n",
+    "# pretrained-weights tag used for each supported CLIP architecture\n",
+    "clip_pretrain_by_type = {\n",
+    "  'ViT-H-14': 'laion2b_s32b_b79k',\n",
+    "  'ViT-L-14': 'laion2b_s32b_b82k',\n",
+    "  'ViT-B-32-quickgelu': 'laion400m_e32',\n",
+    "}\n",
+    "if clip_type in clip_pretrain_by_type:\n",
+    "  clip_pretrain = clip_pretrain_by_type[clip_type]\n",
+    "\n",
+    "clip_guidance_scale = 0 #@param {'type':\"number\"}\n",
+    "if clip_guidance_scale > 0:\n",
+    "  # Fail early instead of silently loading weights left over from an earlier run.\n",
+    "  if clip_type not in clip_pretrain_by_type:\n",
+    "    raise ValueError(f'Unsupported clip_type: {clip_type}')\n",
+    "  clip_model, _, clip_preprocess = open_clip.create_model_and_transforms(clip_type, pretrained=clip_pretrain)\n",
+    "  _=clip_model.half().cuda().eval()\n",
+    "  clip_size = clip_model.visual.image_size\n",
+    "  # the CLIP weights are frozen; no gradients are tracked for them\n",
+    "  for param in clip_model.parameters():\n",
+    "      param.requires_grad = False\n",
+    "else:\n",
+    "  # Guidance is off: drop a model left over from a previous run, if there is one.\n",
+    "  try:\n",
+    "    del clip_model\n",
+    "    gc.collect()\n",
+    "  except NameError:\n",
+    "    # no CLIP model was loaded in this session\n",
+    "    pass\n",
+    "\n",
+    "executed_cells[cell_name] = True\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "yyC0Qb0qOcsJ"
+ },
+ "source": [
+ "## Automatic Brightness Adjustment"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "cellView": "form",
+ "id": "UlJJ5qNSKo3K"
+ },
+ "outputs": [],
+ "source": [
+    "#@markdown ###Automatic Brightness Adjustment\n",
+    "#@markdown Automatically adjust image brightness when its mean value reaches a certain threshold\\\n",
+    "#@markdown Ratio means the value by which pixel values are multiplied when the threshold is reached\\\n",
+    "#@markdown Fix amount is being directly added to\\subtracted from pixel values to prevent oversaturation due to multiplications\\\n",
+    "#@markdown Fix amount is also being applied to border values defined by min\\max threshold, like 1 and 254 to keep the image from having burnt out\\pitch black areas while still being within set high\\low thresholds\n",
+    "cell_name = 'brightness_adjustment'\n",
+    "check_execution(cell_name)\n",
+    "\n",
+    "#@markdown The idea comes from https://github.com/lowfuel/progrockdiffusion\n",
+    "\n",
+    "# This cell only stores the settings; no image is modified here.\n",
+    "enable_adjust_brightness = False #@param {'type':'boolean'}\n",
+    "high_brightness_threshold = 180 #@param {'type':'number'}\n",
+    "high_brightness_adjust_ratio = 0.97 #@param {'type':'number'}\n",
+    "high_brightness_adjust_fix_amount = 2 #@param {'type':'number'}\n",
+    "max_brightness_threshold = 254 #@param {'type':'number'}\n",
+    "low_brightness_threshold = 40 #@param {'type':'number'}\n",
+    "low_brightness_adjust_ratio = 1.03 #@param {'type':'number'}\n",
+    "low_brightness_adjust_fix_amount = 2 #@param {'type':'number'}\n",
+    "min_brightness_threshold = 1 #@param {'type':'number'}\n",
+    "\n",
+    "executed_cells[cell_name] = True\n",
+    "\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "fKzFgXM6cHYE"
+ },
+ "source": []
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "T8xpuFgUEeLz"
+ },
+ "source": [
+ "## Content-aware scheduling"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "cellView": "form",
+ "id": "u2Wh6TVcTn5o"
+ },
+ "outputs": [],
+ "source": [
+    "#@title Content-aware scheduling\n",
+ "#@markdown Allows automated settings scheduling based on video frames difference. If a scene changes, it will be detected and reflected in the schedule.\\\n",
+ "#@markdown rmse function is faster than lpips, but less precise.\\\n",
+ "#@markdown After the analysis is done, check the graph and pick a threshold that works best for your video. 0.5 is a good one for lpips, 1.2 is a good one for rmse. Don't forget to adjust the templates with new threshold in the cell below.\n",
+ "cell_name = 'content_aware_scheduling'\n",
+ "check_execution(cell_name)\n",
+    "def load_img_lpips(path, size=(512,512)):\n",
+    "  \"\"\"Load an image as a normalized (1, 3, H, W) float tensor on the GPU.\n",
+    "\n",
+    "  The image is converted to RGB, resized to `size` with Lanczos resampling,\n",
+    "  divided by 127, moved to channels-first layout and passed through `normalize`.\n",
+    "  NOTE(review): the divisor is 127 (not 127.5 or 255), so values reach slightly\n",
+    "  above 2 before `normalize` - confirm this is what `normalize` expects.\n",
+    "  \"\"\"\n",
+    "  pil_image = Image.open(path).convert(\"RGB\").resize(size, resample=Image.LANCZOS)\n",
+    "  pixels = np.array(pil_image).astype(np.float32) / 127\n",
+    "  # add a batch axis and go from (1, H, W, C) to (1, C, H, W)\n",
+    "  batch = torch.from_numpy(pixels[None].transpose(0, 3, 1, 2))\n",
+    "  return normalize(batch).cuda()\n",
+ "\n",
+ "diff = None\n",
+ "analyze_video = False #@param {'type':'boolean'}\n",
+ "\n",
+ "diff_function = 'lpips' #@param ['rmse','lpips','rmse+lpips']\n",
+ "\n",
+    "def l1_loss(x,y):\n",
+    "  \"\"\"Return sqrt(mean((x - y)**2)) as a 0-d tensor.\n",
+    "\n",
+    "  NOTE(review): despite the name this is the root-mean-square error, not an L1 loss.\n",
+    "  \"\"\"\n",
+    "  squared_error = (x - y) ** 2\n",
+    "  return squared_error.mean().sqrt()\n",
+ "\n",
+ "\n",
+    "def rmse(x,y):\n",
+    "  \"\"\"Return abs(mean(x - y)) as a 0-d tensor.\n",
+    "\n",
+    "  NOTE(review): despite the name this is the absolute mean difference, not a\n",
+    "  root-mean-square error; positive and negative differences cancel out.\n",
+    "  \"\"\"\n",
+    "  mean_difference = (x - y).mean()\n",
+    "  return mean_difference.abs()\n",
+ "\n",
+    "def joint_loss(x,y):\n",
+    "  \"\"\"Return the product of `rmse(x, y)` and `lpips_model(x, y)`.\"\"\"\n",
+    "  pixel_term = rmse(x,y)\n",
+    "  perceptual_term = lpips_model(x,y)\n",
+    "  return pixel_term*perceptual_term\n",
+ "\n",
+    "# `rmse` is the default metric; the other choices replace it below.\n",
+    "diff_func = rmse\n",
+    "if diff_function == 'lpips':\n",
+    "  diff_func = lpips_model\n",
+    "if diff_function == 'rmse+lpips':\n",
+    "  diff_func = joint_loss\n",
+    "\n",
+    "if analyze_video:\n",
+    "  # the first frame has no predecessor, so its difference is 0\n",
+    "  diff = [0]\n",
+    "  frames = sorted(glob(f'{videoFramesFolder}/*.jpg'))\n",
+    "  from tqdm.notebook import trange\n",
+    "  # Load every frame only once: the current frame is reused as the previous\n",
+    "  # frame of the next step instead of being read and resized a second time.\n",
+    "  prev_frame = load_img_lpips(frames[0]) if len(frames) > 1 else None\n",
+    "  for i in trange(1,len(frames)):\n",
+    "    cur_frame = load_img_lpips(frames[i])\n",
+    "    with torch.no_grad():\n",
+    "      diff.append(diff_func(prev_frame, cur_frame).sum().mean().detach().cpu().numpy())\n",
+    "    prev_frame = cur_frame\n",
+    "\n",
+    "  import numpy as np\n",
+    "  import matplotlib.pyplot as plt\n",
+    "\n",
+    "  plt.rcParams[\"figure.figsize\"] = [12.50, 3.50]\n",
+    "  plt.rcParams[\"figure.autolayout\"] = True\n",
+    "\n",
+    "  y = diff\n",
+    "  plt.title(f\"{diff_function} frame difference\")\n",
+    "  plt.plot(y, color=\"red\")\n",
+    "  # the 97th percentile of the differences is offered as a starting threshold\n",
+    "  calc_thresh = np.percentile(np.array(diff), 97)\n",
+    "  plt.axhline(y=calc_thresh, color='b', linestyle='dashed')\n",
+    "\n",
+    "  plt.show()\n",
+    "  print(f'suggested threshold: {calc_thresh.round(2)}')\n",
+    "\n",
+    "executed_cells[cell_name] = True\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "cellView": "form",
+ "id": "GjCvKjYX29Gr"
+ },
+ "outputs": [],
+ "source": [
+    "#@title Plot threshold vs frame difference\n",
+    "#@markdown The suggested threshold may be incorrect, so you can plot your value and see if it covers the peaks.\n",
+    "cell_name = 'plot_threshold_vs_frame_difference'\n",
+    "check_execution(cell_name)\n",
+    "\n",
+    "if diff is None:\n",
+    "  print('Please analyze frames in the previous cell to plot graph')\n",
+    "else:\n",
+    "  import numpy as np\n",
+    "  import matplotlib.pyplot as plt\n",
+    "\n",
+    "  plt.rcParams[\"figure.figsize\"] = [12.50, 3.50]\n",
+    "  plt.rcParams[\"figure.autolayout\"] = True\n",
+    "\n",
+    "  y = diff\n",
+    "  plt.title(f\"{diff_function} frame difference\")\n",
+    "  plt.plot(y, color=\"red\")\n",
+    "  # dashed blue line: suggested threshold (97th percentile); solid red line: user value\n",
+    "  calc_thresh = np.percentile(np.array(diff), 97)\n",
+    "  plt.axhline(y=calc_thresh, color='b', linestyle='dashed')\n",
+    "  user_threshold = 0.13 #@param {'type':'raw'}\n",
+    "  plt.axhline(y=user_threshold, color='r')\n",
+    "\n",
+    "  plt.show()\n",
+    "  # indices of the frames whose difference exceeds the user threshold\n",
+    "  peaks = [i for i,d in enumerate(diff) if d>user_threshold]\n",
+    "  print(f'Peaks at frames: {peaks} for user_threshold of {user_threshold}')\n",
+    "\n",
+    "executed_cells[cell_name] = True\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "cellView": "form",
+ "id": "rtwnApyva73r"
+ },
+ "outputs": [],
+ "source": [
+ "\n",
+ "#@title Create schedules from frame difference\n",
+ "cell_name = 'create_schedules'\n",
+ "check_execution(cell_name)\n",
+ "\n",
+    "def adjust_schedule(diff, normal_val, new_scene_val, thresh, falloff_frames, sched=None):\n",
+    "  \"\"\"Build a per-frame schedule that reacts to scene changes.\n",
+    "\n",
+    "  :param diff: per-frame difference values, one entry per frame\n",
+    "  :param normal_val: value used for ordinary frames when `sched` is None\n",
+    "  :param new_scene_val: value set on frame 0 and on every frame whose difference exceeds `thresh`\n",
+    "  :param thresh: difference above which a frame counts as a new scene\n",
+    "  :param falloff_frames: number of frames over which the value fades linearly from\n",
+    "    `new_scene_val` back to the normal value; 0 disables the fade\n",
+    "  :param sched: optional existing schedule; when given, ordinary frames take\n",
+    "    `get_scheduled_arg(frame, sched)` and each fade targets the scheduled value at\n",
+    "    `frame + falloff_frames`\n",
+    "  :return: 1-D float numpy array with one value per frame\n",
+    "  \"\"\"\n",
+    "  diff_array = np.array(diff)\n",
+    "  n_frames = len(diff_array)\n",
+    "\n",
+    "  # Always use a float buffer so interpolated falloff values are not truncated\n",
+    "  # when both the difference values and the template values are integers.\n",
+    "  diff_new = np.full(n_frames, normal_val, dtype=float)\n",
+    "\n",
+    "  # Write the existing schedule first, in its own pass. Doing this inside the\n",
+    "  # scene loop would overwrite the falloff ramp written for an earlier frame.\n",
+    "  if sched is not None:\n",
+    "    for i in range(n_frames):\n",
+    "      diff_new[i] = get_scheduled_arg(i, sched)\n",
+    "\n",
+    "  for i in range(n_frames):\n",
+    "    if diff_array[i]>thresh or i==0:\n",
+    "      diff_new[i] = new_scene_val\n",
+    "      if falloff_frames>0:\n",
+    "        falloff_val = normal_val\n",
+    "        if sched is not None:\n",
+    "          falloff_val = get_scheduled_arg(i+falloff_frames, sched)\n",
+    "        for j in range(falloff_frames):\n",
+    "          if i+j>n_frames-1: break\n",
+    "          # linear blend: all new_scene_val at j=0, approaching falloff_val as j grows\n",
+    "          diff_new[i+j] = new_scene_val*(falloff_frames-j)/falloff_frames+falloff_val*j/falloff_frames\n",
+    "  return diff_new\n",
+ "\n",
+    "def check_and_adjust_sched(sched, template, diff, respect_sched=True):\n",
+    "  \"\"\"Apply a scene-change template to a schedule.\n",
+    "\n",
+    "  `sched` is returned untouched when `template` is None, '' or []. Otherwise the\n",
+    "  template is unpacked as [normal value, new scene value, threshold, falloff frames]\n",
+    "  and handed to `adjust_schedule`; the existing schedule is passed along only when\n",
+    "  `respect_sched` is true. The result is a list of floats rounded to 3 decimals.\n",
+    "  \"\"\"\n",
+    "  if template in (None, '', []):\n",
+    "    return sched\n",
+    "  normal_val, new_scene_val, thresh, falloff_frames = template\n",
+    "  sched_source = sched if respect_sched else None\n",
+    "  adjusted = adjust_schedule(diff, normal_val, new_scene_val, thresh, falloff_frames, sched_source)\n",
+    "  return list(adjusted.astype('float').round(3))\n",
+ "\n",
+    "#@markdown fill in templates for schedules you'd like to create from frames' difference\\\n",
+    "#@markdown leave blank to use schedules from previous cells\\\n",
+    "#@markdown format: **[normal value, high difference value, difference threshold, falloff from high to normal (number of frames)]**\\\n",
+    "#@markdown For example, setting flow blend template to [0.999, 0.3, 0.5, 5] will use 0.999 everywhere unless a scene has changed (frame difference >0.5) and then set flow_blend for this frame to 0.3 and gradually fade to 0.999 in 5 frames\n",
+    "\n",
+    "# '', None and [] all mean 'no template': check_and_adjust_sched returns the schedule unchanged for them.\n",
+    "latent_scale_template = '' #@param {'type':'raw'}\n",
+    "init_scale_template = '' #@param {'type':'raw'}\n",
+    "steps_template = '' #@param {'type':'raw'}\n",
+    "style_strength_template = '' #@param {'type':'raw'}\n",
+    "flow_blend_template = [0.8, 0., 0.51, 2] #@param {'type':'raw'}\n",
+    "cc_masked_template = [0.7, 0, 0.51, 2] #@param {'type':'raw'}\n",
+    "cfg_scale_template = None #@param {'type':'raw'}\n",
+    "image_scale_template = None #@param {'type':'raw'}\n",
+    "\n",
+    "#@markdown Turning this off will disable templates and will use schedules set in previous cell\n",
+    "make_schedules = False #@param {'type':'boolean'}\n",
+    "#@markdown Turning this on will respect previously set schedules and only alter the frames with peak difference\n",
+    "respect_sched = True #@param {'type':'boolean'}\n",
+    "# NOTE(review): presumably a user-supplied replacement for the analysed `diff` list - confirm where it is read.\n",
+    "diff_override = [] #@param {'type':'raw'}\n",
+    "\n",
+    "#shift+1 required\n",
+    "executed_cells[cell_name] = True\n",
+    "\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "U5rrnKtV7FoY"
+ },
+ "source": [
+ "## Frame Captioning\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "cellView": "form",
+ "id": "T1RMzlod7KFX"
+ },
+ "outputs": [],
+ "source": [
+    "#@title Generate captions for keyframes\n",
+    "#@markdown Automatically generate captions for every n-th frame, \\\n",
+    "#@markdown or keyframe list: at keyframe, at offset from keyframe, between keyframes.\\\n",
+    "#@markdown keyframe source: Every n-th frame, user-input, Content-aware scheduling keyframes\n",
+    "\n",
+    "cell_name = 'frame_captioning'\n",
+    "check_execution(cell_name)\n",
+    "\n",
+    "inputFrames = sorted(glob(f'{videoFramesFolder}/*.jpg'))\n",
+    "make_captions = False #@param {'type':'boolean'}\n",
+    "keyframe_source = 'Every n-th frame' #@param ['Content-aware scheduling keyframes', 'User-defined keyframe list', 'Every n-th frame']\n",
+    "#@markdown This option only works with keyframe source == User-defined keyframe list\n",
+    "user_defined_keyframes = [3,4,5] #@param\n",
+    "#@markdown This option only works with keyframe source == Content-aware scheduling keyframes\n",
+    "diff_thresh = 0.33 #@param {'type':'number'}\n",
+    "#@markdown This option only works with keyframe source == Every n-th frame\n",
+    "nth_frame = 60 #@param {'type':'number'}\n",
+    "# Keyframes are 1-based frame numbers (frame n is inputFrames[n-1]).\n",
+    "if keyframe_source == 'Content-aware scheduling keyframes':\n",
+    "  if diff in [None, '', []]:\n",
+    "    print('ERROR: Keyframes were not generated. Please go back to Content-aware scheduling cell, enable analyze_video and run it or choose a different caption keyframe source.')\n",
+    "    caption_keyframes = None\n",
+    "  else:\n",
+    "    caption_keyframes = [1]+[i+1 for i,o in enumerate(diff) if o>=diff_thresh]\n",
+    "if keyframe_source == 'User-defined keyframe list':\n",
+    "  caption_keyframes = user_defined_keyframes\n",
+    "if keyframe_source == 'Every n-th frame':\n",
+    "  caption_keyframes = list(range(1, len(inputFrames), nth_frame))\n",
+    "#@markdown Remaps keyframes based on selected offset mode\n",
+    "offset_mode = 'Fixed' #@param ['Fixed', 'Between Keyframes', 'None']\n",
+    "#@markdown Only works with offset_mode == Fixed\n",
+    "fixed_offset = 0 #@param {'type':'number'}\n",
+    "\n",
+    "# Defined even when captions are not generated, because get_caption reads it.\n",
+    "videoFramesCaptions = videoFramesFolder+'Captions'\n",
+    "if make_captions and caption_keyframes is not None:\n",
+    "  # Load BLIP only once per session; later runs reuse the model, transform and device.\n",
+    "  if 'blip_model' not in globals():\n",
+    "    os.chdir('./BLIP')\n",
+    "    try:\n",
+    "      from models.blip import blip_decoder\n",
+    "    finally:\n",
+    "      # go back to the previous working directory even if the import fails\n",
+    "      os.chdir('../')\n",
+    "    from PIL import Image\n",
+    "    import torch\n",
+    "    from torchvision import transforms\n",
+    "    from torchvision.transforms.functional import InterpolationMode\n",
+    "\n",
+    "    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
+    "    image_size = 384\n",
+    "    transform = transforms.Compose([\n",
+    "        transforms.Resize((image_size,image_size),interpolation=InterpolationMode.BICUBIC),\n",
+    "        transforms.ToTensor(),\n",
+    "        transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))\n",
+    "        ])\n",
+    "\n",
+    "    model_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_caption_capfilt_large.pth'# -O /content/model_base_caption_capfilt_large.pth'\n",
+    "\n",
+    "    blip_model = blip_decoder(pretrained=model_url, image_size=384, vit='base',med_config='./BLIP/configs/med_config.json')\n",
+    "    blip_model.eval()\n",
+    "    blip_model = blip_model.to(device)\n",
+    "\n",
+    "  print('Using keyframes: ', caption_keyframes[:20], ' (first 20 keyframes displayed)')\n",
+    "  # Work on a copy so the source list (e.g. user_defined_keyframes) is not modified.\n",
+    "  keyframes = list(caption_keyframes)\n",
+    "  # The last keyframe has no successor and always stays where it is.\n",
+    "  last_idx = len(keyframes)-1\n",
+    "  if offset_mode == 'Fixed':\n",
+    "    for i in range(last_idx):\n",
+    "      # shift by the fixed offset, but never past the next keyframe\n",
+    "      keyframes[i] = min(caption_keyframes[i]+fixed_offset, caption_keyframes[i+1])\n",
+    "    print('Remapped keyframes to ', keyframes[:20])\n",
+    "  if offset_mode == 'Between Keyframes':\n",
+    "    for i in range(last_idx):\n",
+    "      # midpoint between this keyframe and the next one\n",
+    "      keyframes[i] = caption_keyframes[i] + int((caption_keyframes[i+1]-caption_keyframes[i])/2)\n",
+    "    print('Remapped keyframes to ', keyframes[:20])\n",
+    "\n",
+    "  createPath(videoFramesCaptions)\n",
+    "\n",
+    "  from tqdm.notebook import trange\n",
+    "\n",
+    "  # remove captions left over from a previous run\n",
+    "  for f in pathlib.Path(videoFramesCaptions).glob('*.txt'):\n",
+    "    f.unlink()\n",
+    "  for i in tqdm(keyframes):\n",
+    "    # skip keyframes that do not correspond to an extracted frame\n",
+    "    if not 1 <= i <= len(inputFrames):\n",
+    "      print(f'Skipping keyframe {i}: no such input frame.')\n",
+    "      continue\n",
+    "\n",
+    "    with torch.no_grad():\n",
+    "      keyFrameFilename = inputFrames[i-1]\n",
+    "      # force 3 channels: the transform normalizes with 3-channel statistics\n",
+    "      raw_image = Image.open(keyFrameFilename).convert('RGB')\n",
+    "      image = transform(raw_image).unsqueeze(0).to(device)\n",
+    "      caption = blip_model.generate(image, sample=True, top_p=0.9, max_length=30, min_length=5)\n",
+    "      # the caption file gets the same base name as the frame file\n",
+    "      captionFilename = os.path.join(videoFramesCaptions, keyFrameFilename.replace('\\\\','/').split('/')[-1][:-4]+'.txt')\n",
+    "      with open(captionFilename, 'w') as f:\n",
+    "        f.write(caption[0])\n",
+ "\n",
+    "def load_caption(caption_file):\n",
+    "  \"\"\"Return the full text content of `caption_file`.\"\"\"\n",
+    "  with open(caption_file, 'r') as caption_handle:\n",
+    "    return caption_handle.read()\n",
+ "\n",
+    "def get_caption(frame_num):\n",
+    "  \"\"\"Return the caption that applies to `frame_num`, or None if there are no caption files.\n",
+    "\n",
+    "  Caption files in `videoFramesCaptions` are named after frame numbers. The lookup\n",
+    "  uses `frame_num + 1`: numbers before the first caption get the first caption,\n",
+    "  numbers at or after the last caption get the last one, and anything in between\n",
+    "  gets the caption whose number is the closest one not above it.\n",
+    "  NOTE(review): the +1 suggests `frame_num` is 0-based while file names are 1-based - confirm.\n",
+    "  \"\"\"\n",
+    "  caption_files = sorted(glob(os.path.join(videoFramesCaptions,'*.txt')))\n",
+    "  target = frame_num+1\n",
+    "  if not caption_files:\n",
+    "    return None\n",
+    "  # frame number = caption file name without folder and 4-character extension\n",
+    "  frame_numbers = [int(o.replace('\\\\','/').split('/')[-1][:-4]) for o in caption_files]\n",
+    "  if target < frame_numbers[0]:\n",
+    "    return load_caption(caption_files[0])\n",
+    "  if target >= frame_numbers[-1]:\n",
+    "    return load_caption(caption_files[-1])\n",
+    "  # walk over consecutive (start, end) caption intervals\n",
+    "  for caption_file, start, end in zip(caption_files, frame_numbers, frame_numbers[1:]):\n",
+    "    if start <= target < end:\n",
+    "      return load_caption(caption_file)\n",
+    "  return None\n",
+ "\n",
+ "executed_cells[cell_name] = True\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "_MleAG1V0ss6"
+ },
+ "source": [
+ "# Render settings\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "7vXkwEkB9KTG"
+ },
+ "source": [
+ "## Non-gui\n",
+ "These settings are used as initial settings for the GUI unless you specify default_settings_path. Then the GUI settings will be loaded from the specified file."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "cellView": "form",
+ "id": "ZsuiToUttxZ-"
+ },
+ "outputs": [],
+ "source": [
+    "#@title Flow and turbo settings\n",
+    "#@markdown #####**Video Optical Flow Settings:**\n",
+    "cell_name = 'flow_and_turbo_settings'\n",
+    "check_execution(cell_name)\n",
+    "\n",
+    "flow_warp = True #@param {type: 'boolean'}\n",
+    "#calculate optical flow from video frames and warp the previous frame with it\n",
+    "flow_blend = 0.999\n",
+    "##@param {type: 'number'} #0 - take next frame, 1 - take prev warped frame\n",
+    "check_consistency = True #@param {type: 'boolean'}\n",
+    " #NOTE(review): presumably toggles the optical flow consistency check - confirm against the render code\n",
+    "\n",
+    "#======= TURBO MODE\n",
+    "#@markdown ---\n",
+    "#@markdown ####**Turbo Mode:**\n",
+    "#@markdown (Starts after frame 1,) skips diffusion steps and just uses flow map to warp images for skipped frames.\n",
+    "#@markdown Speeds up rendering by 2x-4x, and may improve image coherence between frames. frame_blend_mode smooths abrupt texture changes across 2 frames.\n",
+    "#@markdown For different settings tuned for Turbo Mode, refer to the original Disco-Turbo Github: https://github.com/zippy731/disco-diffusion-turbo\n",
+    "\n",
+    "turbo_mode = False #@param {type:\"boolean\"}\n",
+    "# note: stored as a string, not an int\n",
+    "turbo_steps = \"3\" #@param [\"2\",\"3\",\"4\",\"5\",\"6\"] {type:\"string\"}\n",
+    "turbo_preroll = 1 # frames\n",
+    "\n",
+    "executed_cells[cell_name] = True\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "cellView": "form",
+ "id": "mxNoyb1tzbPO"
+ },
+ "outputs": [],
+ "source": [
+    "#@title Consistency map mixing\n",
+    "#@markdown You can mix consistency map layers separately\\\n",
+    "#@markdown missed_consistency_weight - masks pixels that have missed their expected position in the next frame \\\n",
+    "#@markdown overshoot_consistency_weight - masks pixels warped from outside the frame\\\n",
+    "#@markdown edges_consistency_weight - masks moving objects' edges\\\n",
+    "#@markdown The default values to simulate previous versions' behavior are 1,1,1\n",
+    "cell_name = 'consistency_maps_mixing'\n",
+    "check_execution(cell_name)\n",
+    "\n",
+    "# The sliders limit each weight to 0..1 in steps of 0.05; no extra clamping is done in this cell.\n",
+    "missed_consistency_weight = 1 #@param {'type':'slider', 'min':'0', 'max':'1', 'step':'0.05'}\n",
+    "overshoot_consistency_weight = 1 #@param {'type':'slider', 'min':'0', 'max':'1', 'step':'0.05'}\n",
+    "edges_consistency_weight = 1 #@param {'type':'slider', 'min':'0', 'max':'1', 'step':'0.05'}\n",
+    "\n",
+    "executed_cells[cell_name] = True\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "cellView": "form",
+ "id": "OUmYjPGSzcwG"
+ },
+ "outputs": [],
+ "source": [
+    "#@title ####**Seed and grad Settings:**\n",
+    "cell_name = 'seed_and_grad_settings'\n",
+    "check_execution(cell_name)\n",
+    "# note: the seed is stored as a string, not an int\n",
+    "set_seed = '4275770367' #@param{type: 'string'}\n",
+    "\n",
+    "\n",
+    "#@markdown *Clamp grad is used with any of the init_scales or sat_scale above 0*\\\n",
+    "#@markdown Clamp grad limits the amount various criterions, controlled by *_scale parameters, are pushing the image towards the desired result.\\\n",
+    "#@markdown For example, high scale values may cause artifacts, and clamp_grad removes this effect.\n",
+    "#@markdown 0.7 is a good clamp_max value.\n",
+    "# NOTE(review): eta is not exposed in the form; presumably the sampler's eta parameter - confirm.\n",
+    "eta = 0.55\n",
+    "clamp_grad = True #@param{type: 'boolean'}\n",
+    "# NOTE(review): the default here is 2 although the text above recommends 0.7.\n",
+    "clamp_max = 2 #@param{type: 'number'}\n",
+    "\n",
+    "executed_cells[cell_name] = True\n",
+    "\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "PgnJ26Bh3Ru8"
+ },
+ "source": [
+ "### Prompts\n",
+ "`animation_mode: None` will only use the first set. `animation_mode: 2D / Video` will run through them per the set frames and hold on the last one."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "uGhc6Atr3TF-"
+ },
+ "outputs": [],
+ "source": [
+    "cell_name = 'prompts'\n",
+    "check_execution(cell_name)\n",
+    "\n",
+    "# Both dicts map a key (0 here) to a list of prompt strings.\n",
+    "# NOTE(review): the key is presumably the frame number at which the prompt set starts - confirm.\n",
+    "text_prompts = {0: ['a beautiful highly detailed cyberpunk mechanical \\\n",
+    "augmented most beautiful (woman) ever, cyberpunk 2077, neon, dystopian, \\\n",
+    "hightech, trending on artstation']}\n",
+    "\n",
+    "negative_prompts = {\n",
+    "    0: [\"text, naked, nude, logo, cropped, two heads, four arms, lazy eye, blurry, unfocused\"]\n",
+    "}\n",
+    "\n",
+    "executed_cells[cell_name] = True\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "GWWNdYvj3Xst"
+ },
+ "source": [
+ "### Warp Turbo Smooth Settings"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "2P5GfX3G3VKC"
+ },
+ "source": [
+    "turbo_frame_skips_steps - lets you set a different frames_skip_steps value for turbo frames. None means turbo frames are only warped, without diffusion.\n",
+ "\n",
+ "soften_consistency_mask - clip the lower values of consistency mask to this value. Raw video frames will leak stronger with lower values.\n",
+ "\n",
+ "soften_consistency_mask_for_turbo_frames - same, but for turbo frames\n",
+ "\n",
+ "\n",
+ "\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "cellView": "form",
+ "id": "TAHtTRh_3ga5"
+ },
+ "outputs": [],
+ "source": [
+    "#@title ##Warp Turbo Smooth Settings\n",
+    "#@markdown Skip steps for turbo frames. Select 100% to skip diffusion rendering for turbo frames completely.\n",
+    "cell_name = 'warp_turbo_smooth_settings'\n",
+    "check_execution(cell_name)\n",
+    "\n",
+    "turbo_frame_skips_steps = '100% (don`t diffuse turbo frames, fastest)' #@param ['70%','75%','80%','85%', '90%', '95%', '100% (don`t diffuse turbo frames, fastest)']\n",
+    "\n",
+    "# Convert the form choice: the 100% option becomes None, any other 'NN%' becomes the fraction NN/100.\n",
+    "if turbo_frame_skips_steps == '100% (don`t diffuse turbo frames, fastest)':\n",
+    "  turbo_frame_skips_steps = None\n",
+    "else:\n",
+    "  turbo_frame_skips_steps = int(turbo_frame_skips_steps.split('%')[0])/100\n",
+    "#None - disable and use default skip steps\n",
+    "\n",
+    "#@markdown ###Consistency mask postprocessing\n",
+    "#@markdown ####Soften consistency mask\n",
+    "#@markdown Lower values mean less stylized frames and more raw video input in areas with fast movement, but fewer trails and ghosting.\\\n",
+    "#@markdown Gives glitchy datamoshing look.\\\n",
+    "#@markdown Higher values keep stylized frames, but add trails and ghosting.\n",
+    "\n",
+    "soften_consistency_mask = 0 #@param {type:\"slider\", min:0, max:1, step:0.1}\n",
+    "forward_weights_clip = soften_consistency_mask\n",
+    "#0 behaves like consistency on, 1 - off, in between - blends\n",
+    "soften_consistency_mask_for_turbo_frames = 0 #@param {type:\"slider\", min:0, max:1, step:0.1}\n",
+    "forward_weights_clip_turbo_step = soften_consistency_mask_for_turbo_frames\n",
+    "#None - disable and use forward_weights_clip for turbo frames, 0 behaves like consistency on, 1 - off, in between - blends\n",
+    "#@markdown ####Blur consistency mask.\n",
+    "#@markdown Softens transition between raw video init and stylized frames in occluded areas.\n",
+    "consistency_blur = 1 #@param\n",
+    "#@markdown ####Dilate consistency mask.\n",
+    "#@markdown Expands consistency mask without blurring the edges.\n",
+    "consistency_dilate = 3 #@param\n",
+    "\n",
+    "\n",
+    "# disable_cc_for_turbo_frames = False #@param {\"type\":\"boolean\"}\n",
+    "#disable consistency for turbo frames, the same as forward_weights_clip_turbo_step = 1, but a bit faster\n",
+    "\n",
+    "#@markdown ###Frame padding\n",
+    "#@markdown Increase padding if you have a shaky\\moving camera footage and are getting black borders.\n",
+    "\n",
+    "padding_ratio = 0.2 #@param {type:\"slider\", min:0, max:1, step:0.1}\n",
+    "#relative to image size, in range 0-1\n",
+    "padding_mode = 'reflect' #@param ['reflect','edge','wrap']\n",
+    "\n",
+    "\n",
+    "#safeguard the params: clamp every ratio into the 0..1 range\n",
+    "if turbo_frame_skips_steps is not None:\n",
+    "  turbo_frame_skips_steps = min(max(0,turbo_frame_skips_steps),1)\n",
+    "forward_weights_clip = min(max(0,forward_weights_clip),1)\n",
+    "# forward_weights_clip_turbo_step is always a number in this cell, so this check always passes here\n",
+    "if forward_weights_clip_turbo_step is not None:\n",
+    "  forward_weights_clip_turbo_step = min(max(0,forward_weights_clip_turbo_step),1)\n",
+    "padding_ratio = min(max(0,padding_ratio),1)\n",
+    "##@markdown ###Inpainting\n",
+    "##@markdown Inpaint occluded areas on top of raw frames. 0 - 0% inpainting opacity (no inpainting), 1 - 100% inpainting opacity. Other values blend between raw and inpainted frames.\n",
+    "\n",
+    "# the form field for this value is disabled (double '#'), so it is fixed at 0\n",
+    "inpaint_blend = 0\n",
+    "##@param {type:\"slider\", min:0,max:1,value:1,step:0.1}\n",
+    "\n",
+    "#@markdown ###Color matching\n",
+    "#@markdown Match color of inconsistent areas to unoccluded ones, after inconsistent areas were replaced with raw init video or inpainted\\\n",
+    "#@markdown 0 - off, other values control effect opacity\n",
+    "\n",
+    "match_color_strength = 0 #@param {'type':'slider', 'min':'0', 'max':'1', 'step':'0.1'}\n",
+    "\n",
+    "disable_cc_for_turbo_frames = False\n",
+    "\n",
+    "#@markdown ###Warp settings\n",
+    "\n",
+    "warp_mode = 'use_image' #@param ['use_latent', 'use_image']\n",
+    "warp_towards_init = 'off' #@param ['stylized', 'off']\n",
+    "\n",
+    "# A RAFT model is loaded only when warp_towards_init is enabled: the half\n",
+    "# precision file when flow_lq is set, the fp32 file otherwise.\n",
+    "if warp_towards_init != 'off':\n",
+    "  if flow_lq:\n",
+    "          raft_model = torch.jit.load(f'{root_dir}/WarpFusion/raft/raft_half.jit').eval()\n",
+    "        # raft_model = torch.nn.DataParallel(RAFT(args2))\n",
+    "  else: raft_model = torch.jit.load(f'{root_dir}/WarpFusion/raft/raft_fp32.jit').eval()\n",
+    "\n",
+    "cond_image_src = 'init' #@param ['init', 'stylized']\n",
+    "\n",
+    "executed_cells[cell_name] = True\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "4bCGxkUZ3r68"
+ },
+ "source": [
+ "### Video masking (render-time)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "cellView": "form",
+ "id": "_R1MvKb53sL7"
+ },
+ "outputs": [],
+ "source": [
+    "#@title Video mask settings\n",
+    "#@markdown Check to enable background masking during render. Not recommended, better use masking when creating the output video for more control and faster testing.\n",
+    "cell_name = 'video_mask_settings'\n",
+    "check_execution(cell_name)\n",
+    "\n",
+    "# This cell only stores the masking options; no mask is applied here.\n",
+    "use_background_mask = False #@param {'type':'boolean'}\n",
+    "#@markdown Check to invert the mask.\n",
+    "invert_mask = False #@param {'type':'boolean'}\n",
+    "#@markdown Apply mask right before feeding init image to the model. Unchecking will only mask current raw init frame.\n",
+    "apply_mask_after_warp = True #@param {'type':'boolean'}\n",
+    "#@markdown Choose background source to paste masked stylized image onto: image, color, init video.\n",
+    "background = \"init_video\" #@param ['image', 'color', 'init_video']\n",
+    "#@markdown Specify the init image path or color depending on your background source choice.\n",
+    "# NOTE(review): presumably ignored when background is 'init_video' - confirm in the render code.\n",
+    "background_source = 'red' #@param {'type':'string'}\n",
+    "\n",
+    "executed_cells[cell_name] = True\n",
+    "\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "nm_EeEeu391T"
+ },
+ "source": [
+ "### Frame correction (latent & color matching)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "cellView": "form",
+ "id": "0PAmcATq3-el"
+ },
+ "outputs": [],
+ "source": [
+ "#@title Frame correction\n",
+ "#@markdown Match frame pixels or latent to other frames to prevent oversaturation and feedback loop artifacts\n",
+ "#@markdown ###Latent matching\n",
+ "#@markdown Match the range of latent vector towards the 1st frame or a user defined range. Doesn't restrict colors, but may limit contrast.\n",
+ "cell_name = 'frame_correction'\n",
+ "check_execution(cell_name)\n",
+ "\n",
+ "\n",
+ "normalize_latent = 'off' #@param ['off', 'color_video', 'color_video_offset', 'user_defined', 'stylized_frame', 'init_frame', 'stylized_frame_offset', 'init_frame_offset']\n",
+ "#@markdown in offset mode, specifies the offset back from current frame, and 0 means current frame. In non-offset mode specifies the fixed frame number. 0 means the 1st frame.\n",
+ "\n",
+ "normalize_latent_offset = 0 #@param {'type':'number'}\n",
+ "#@markdown User defined stats to normalize the latent towards\n",
+ "latent_fixed_mean = 0. #@param {'type':'raw'}\n",
+ "latent_fixed_std = 0.9 #@param {'type':'raw'}\n",
+ "#@markdown Match latent on per-channel basis\n",
+ "latent_norm_4d = True #@param {'type':'boolean'}\n",
+ "#@markdown ###Color matching\n",
+ "#@markdown Color match frame towards stylized or raw init frame. Helps prevent images going deep purple. As a drawback, may lock colors to the selected fixed frame. Select stylized_frame with colormatch_offset = 0 to reproduce previous notebooks.\n",
+ "colormatch_frame = 'stylized_frame' #@param ['off', 'color_video', 'color_video_offset','stylized_frame', 'init_frame', 'stylized_frame_offset', 'init_frame_offset']\n",
+ "#@markdown Color match strength. 1 mimics legacy behavior\n",
+ "color_match_frame_str = 0.5 #@param {'type':'number'}\n",
+ "#@markdown in offset mode, specifies the offset back from current frame, and 0 means current frame. In non-offset mode specifies the fixed frame number. 0 means the 1st frame.\n",
+ "colormatch_offset = 0 #@param {'type':'number'}\n",
+ "colormatch_method = 'PDF'#@param ['LAB', 'PDF', 'mean']\n",
+ "# Pick the PT transfer routine whose name matches the selected method.\n",
+ "# 'PDF' is the fallback; previously 'LAB' and 'PDF' were swapped.\n",
+ "colormatch_method_fn = PT.pdf_transfer\n",
+ "if colormatch_method == 'LAB':\n",
+ "  colormatch_method_fn = PT.lab_transfer\n",
+ "if colormatch_method == 'mean':\n",
+ "  colormatch_method_fn = PT.mean_std_transfer\n",
+ "#@markdown Match source frame's texture\n",
+ "colormatch_regrain = False #@param {'type':'boolean'}\n",
+ "\n",
+ "executed_cells[cell_name] = True\n",
+ "\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "xTVNjezk3aTa"
+ },
+ "source": [
+ "### Main settings.\n",
+ "\n",
+ "Duplicated in the GUI and can be loaded there."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "cellView": "form",
+ "id": "yAD7sBet32in"
+ },
+ "outputs": [],
+ "source": [
+ "# @title Basic\n",
+ "\n",
+ "cell_name = 'main_settings'\n",
+ "check_execution(cell_name)\n",
+ "# DD-style losses, renders 2 times slower (!) and more memory intensive :D\n",
+ "\n",
+ "latent_scale_schedule = [0,0] #controls coherency with previous frame in latent space. 0 is a good starting value. 1+ render slower, but may improve image coherency. 100 is a good value if you decide to turn it on.\n",
+ "init_scale_schedule = [0,0] #controls coherency with previous frame in pixel space. 0 - off, 1000 - a good starting value if you decide to turn it on.\n",
+ "sat_scale = 0\n",
+ "\n",
+ "init_grad = False #True - compare result to real frame, False - to stylized frame\n",
+ "grad_denoised = True #fastest, on by default, calc grad towards denoised x instead of input x\n",
+ "\n",
+ "steps_schedule = {\n",
+ " 0: 25\n",
+ "} #schedules total steps. useful with low strength, when you end up with only 10 steps at 0.2 strength x50 steps. Increasing max steps for low strength gives model more time to get to your text prompt\n",
+ "style_strength_schedule = [0.7]#[0.5]+[0.2]*149+[0.3]*3+[0.2] #use this instead of skip steps. It means how many steps we should do. 0.8 = we diffuse for 80% steps, so we skip 20%. So for skip steps 70% use 0.3\n",
+ "flow_blend_schedule = [0.8] #for example [0.1]*3+[0.999]*18+[0.3] will fade-in for 3 frames, keep style for 18 frames, and fade-out for the rest\n",
+ "cfg_scale_schedule = [15] #text2image strength, 7.5 is a good default\n",
+ "blend_json_schedules = True #True - interpolate values between keyframes. False - use latest keyframe\n",
+ "\n",
+ "dynamic_thresh = 30\n",
+ "\n",
+ "fixed_code = False #Aka fixed seed. you can use this with fast moving videos, but be careful with still images\n",
+ "code_randomness = 0.1 # Only affects fixed code. high values make the output collapse\n",
+ "# normalize_code = True #Only affects fixed code.\n",
+ "\n",
+ "warp_strength = 1 #leave 1 for no change. 1.01 is already a strong value.\n",
+ "flow_override_map = []#[*range(1,15)]+[16]*10+[*range(17+10,17+10+20)]+[18+10+20]*15+[*range(19+10+20+15,9999)] #map flow to frames. set to [] to disable. [1]*10+[*range(10,9999)] repeats 1st frame flow 10 times, then continues as usual\n",
+ "\n",
+ "blend_latent_to_init = 0\n",
+ "\n",
+ "colormatch_after = False #colormatch after stylizing. On in previous notebooks.\n",
+ "colormatch_turbo = False #apply colormatching for turbo frames. On in previous notebooks\n",
+ "\n",
+ "user_comment = 'testing cc layers'\n",
+ "\n",
+ "mask_result = False #imitates inpainting by leaving only inconsistent areas to be diffused\n",
+ "\n",
+ "use_karras_noise = False #Should work better with current sample, needs more testing.\n",
+ "end_karras_ramp_early = False\n",
+ "\n",
+ "warp_interp = Image.LANCZOS\n",
+ "VERBOSE = True\n",
+ "\n",
+ "use_patchmatch_inpaiting = 0\n",
+ "\n",
+ "warp_num_k = 128 # number of patches per frame\n",
+ "warp_forward = False #use k-means patched warping (moves large areas instead of single pixels)\n",
+ "\n",
+ "inverse_inpainting_mask = False\n",
+ "inpainting_mask_weight = 1.\n",
+ "mask_source = 'none'\n",
+ "mask_clip_low = 0\n",
+ "mask_clip_high = 255\n",
+ "sampler = sample_euler\n",
+ "image_scale = 2\n",
+ "image_scale_schedule = {0:1.5, 1:2}\n",
+ "\n",
+ "inpainting_mask_source = 'none'\n",
+ "\n",
+ "fixed_seed = False #fixes seed\n",
+ "offload_model = True #offloads model to cpu before running decoder. May save a bit of VRAM\n",
+ "\n",
+ "use_predicted_noise = False\n",
+ "rec_randomness = 0.\n",
+ "rec_cfg = 1.\n",
+ "rec_prompts = {0: ['woman walking on a treadmill']}\n",
+ "rec_source = 'init'\n",
+ "rec_steps_pct = 1\n",
+ "\n",
+ "#controlnet settings\n",
+ "controlnet_preprocess = True #preprocess input conditioning image for controlnet. If false, use raw conditioning as input to the model without detection/preprocessing\n",
+ "detect_resolution = 768 #control net conditioning image resolution\n",
+ "bg_threshold = 0.4 #controlnet depth/normal bg cutoff threshold\n",
+ "low_threshold = 100 #canny filter parameters\n",
+ "high_threshold = 200 #canny filter parameters\n",
+ "value_threshold = 0.1 #mlsd model settings\n",
+ "distance_threshold = 0.1 #mlsd model settings\n",
+ "\n",
+ "temporalnet_source = 'stylized'\n",
+ "temporalnet_skip_1st_frame = True\n",
+ "controlnet_multimodel_mode = 'internal' #external or internal. internal - sums controlnet values before feeding those into diffusion model, external - sum outputs of different controlnets after passing through diffusion model. external seems slower but smoother.\n",
+ "\n",
+ "do_softcap = False #softly clamp latent excessive values. reduces feedback loop effect a bit\n",
+ "softcap_thresh = 0.9 # scale down absolute values above that threshold (latents are being clamped at [-1:1] range, so 0.9 will downscale values above 0.9 to fit into that range, [-1.5:1.5] will be scaled to [-1:1], but only absolute values over 0.9 will be affected)\n",
+ "softcap_q = 1. # percentile to downscale. 1 - downscale full range with outliers, 0.9 - downscale only 90% values above thresh, clamp 10%\n",
+ "\n",
+ "max_faces = 10\n",
+ "masked_guidance = False #use mask for init/latent guidance to ignore inconsistencies and only guide based on the consistent areas\n",
+ "cc_masked_diffusion_schedule = [0.7] # 0 - off. 0.5-0.7 are good values. make inconsistent area passes only before this % of actual steps, then diffuse whole image\n",
+ "alpha_masked_diffusion = 0. # 0 - off. 0.5-0.7 are good values. make alpha masked area passes only before this % of actual steps, then diffuse whole image\n",
+ "invert_alpha_masked_diffusion = False\n",
+ "\n",
+ "save_controlnet_annotations = True\n",
+ "pose_detector = 'dw_pose'\n",
+ "control_sd15_openpose_hands_face = True\n",
+ "control_sd15_depth_detector = 'Zoe' # Zoe or Midas\n",
+ "control_sd15_softedge_detector = 'PIDI' # HED or PIDI\n",
+ "control_sd15_seg_detector = 'Seg_UFADE20K' # Seg_OFCOCO Seg_OFADE20K Seg_UFADE20K\n",
+ "control_sd15_scribble_detector = 'PIDI' # HED or PIDI\n",
+ "control_sd15_lineart_coarse = False\n",
+ "control_sd15_inpaint_mask_source = 'consistency_mask' # consistency_mask, None, cond_video\n",
+ "control_sd15_shuffle_source = 'color_video' # color_video, init, prev_frame, first_frame\n",
+ "control_sd15_shuffle_1st_source = 'color_video' # color_video, init, None,\n",
+ "overwrite_rec_noise = False\n",
+ "\n",
+ "controlnet_multimodel = {\n",
+ " \"control_sd15_depth\": {\n",
+ " \"weight\": 0,\n",
+ " \"start\": 0,\n",
+ " \"end\": 1,\n",
+ " \"preprocess\": '',\n",
+ " \"mode\": '',\n",
+ " \"detect_resolution\": '',\n",
+ " \"source\": ''\n",
+ " },\n",
+ " \"control_sd15_canny\": {\n",
+ " \"weight\": 0,\n",
+ " \"start\": 0,\n",
+ " \"end\": 1,\n",
+ " \"preprocess\": '',\n",
+ " \"mode\": '',\n",
+ " \"detect_resolution\": '',\n",
+ " \"source\": \"\"\n",
+ " },\n",
+ " \"control_sd15_softedge\": {\n",
+ " \"weight\": 1,\n",
+ " \"start\": 0,\n",
+ " \"end\": 1,\n",
+ " \"preprocess\": '',\n",
+ " \"mode\": '',\n",
+ " \"detect_resolution\": '',\n",
+ " \"source\": \"\"\n",
+ " },\n",
+ " \"control_sd15_mlsd\": {\n",
+ " \"weight\": 0,\n",
+ " \"start\": 0,\n",
+ " \"end\": 1,\n",
+ " \"preprocess\": '',\n",
+ " \"mode\": '',\n",
+ " \"detect_resolution\": '',\n",
+ " \"source\": \"\"\n",
+ " },\n",
+ " \"control_sd15_normalbae\": {\n",
+ " \"weight\": 1,\n",
+ " \"start\": 0,\n",
+ " \"end\": 1,\n",
+ " \"preprocess\": '',\n",
+ " \"mode\": '',\n",
+ " \"detect_resolution\": '',\n",
+ " \"source\": \"\"\n",
+ " },\n",
+ " \"control_sd15_openpose\": {\n",
+ " \"weight\": 1,\n",
+ " \"start\": 0,\n",
+ " \"end\": 1,\n",
+ " \"preprocess\": '',\n",
+ " \"mode\": '',\n",
+ " \"detect_resolution\": '',\n",
+ " \"source\": \"\"\n",
+ " },\n",
+ " \"control_sd15_scribble\": {\n",
+ " \"weight\": 0,\n",
+ " \"start\": 0,\n",
+ " \"end\": 1,\n",
+ " \"preprocess\": '',\n",
+ " \"mode\": '',\n",
+ " \"detect_resolution\": '',\n",
+ " \"source\": \"\"\n",
+ " },\n",
+ " \"control_sd15_seg\": {\n",
+ " \"weight\": 0,\n",
+ " \"start\": 0,\n",
+ " \"end\": 1,\n",
+ " \"preprocess\": '',\n",
+ " \"mode\": '',\n",
+ " \"detect_resolution\": '',\n",
+ " \"source\": \"\"\n",
+ " },\n",
+ " \"control_sd15_temporalnet\": {\n",
+ " \"weight\": 0,\n",
+ " \"start\": 0,\n",
+ " \"end\": 1,\n",
+ " \"preprocess\": '',\n",
+ " \"mode\": '',\n",
+ " \"source\": \"\"\n",
+ " },\n",
+ " \"control_sd15_face\": {\n",
+ " \"weight\": 0,\n",
+ " \"start\": 0,\n",
+ " \"end\": 1,\n",
+ " \"preprocess\": '',\n",
+ " \"mode\": '',\n",
+ " \"detect_resolution\": '',\n",
+ " \"source\": \"\"\n",
+ " },\n",
+ " \"control_sd15_ip2p\": {\n",
+ " \"weight\": 0,\n",
+ " \"start\": 0,\n",
+ " \"end\": 1,\n",
+ " \"preprocess\": '',\n",
+ " \"mode\": '',\n",
+ " \"source\": \"\"\n",
+ " },\n",
+ " \"control_sd15_inpaint\": {\n",
+ " \"weight\": 1,\n",
+ " \"start\": 0,\n",
+ " \"end\": 1,\n",
+ " \"preprocess\": '',\n",
+ " \"mode\": '',\n",
+ " \"source\": \"stylized\"\n",
+ " },\n",
+ " \"control_sd15_lineart\": {\n",
+ " \"weight\": 0,\n",
+ " \"start\": 0,\n",
+ " \"end\": 1,\n",
+ " \"preprocess\": '',\n",
+ " \"mode\": '',\n",
+ " \"detect_resolution\": '',\n",
+ " \"source\": \"\"\n",
+ " },\n",
+ " \"control_sd15_lineart_anime\": {\n",
+ " \"weight\": 0,\n",
+ " \"start\": 0,\n",
+ " \"end\": 1,\n",
+ " \"preprocess\": '',\n",
+ " \"mode\": '',\n",
+ " \"detect_resolution\": '',\n",
+ " \"source\": \"\"\n",
+ " },\n",
+ " \"control_sd15_shuffle\":{\n",
+ " \"weight\": 0,\n",
+ " \"start\": 0,\n",
+ " \"end\": 1,\n",
+ " \"preprocess\": '',\n",
+ " \"mode\": '',\n",
+ " \"source\": \"\"\n",
+ " }\n",
+ "}\n",
+ "if model_version == 'control_multi_sdxl':\n",
+ " controlnet_multimodel = {\n",
+ " \"control_sdxl_canny\": {\n",
+ " \"weight\": 0,\n",
+ " \"start\": 0,\n",
+ " \"end\": 1,\n",
+ " \"preprocess\": '',\n",
+ " \"mode\": '',\n",
+ " \"detect_resolution\": '',\n",
+ " \"source\": ''\n",
+ " },\n",
+ " \"control_sdxl_depth\": {\n",
+ " \"weight\": 1,\n",
+ " \"start\": 0,\n",
+ " \"end\": 1,\n",
+ " },\n",
+ " \"control_sdxl_seg\": {\n",
+ " \"weight\": 0,\n",
+ " \"start\": 0,\n",
+ " \"end\": 1,\n",
+ " },\n",
+ " \"control_sdxl_openpose\": {\n",
+ " \"weight\": 0,\n",
+ " \"start\": 0,\n",
+ " \"end\": 1,\n",
+ " },\n",
+ " \"control_sdxl_softedge\": {\n",
+ " \"weight\": 1,\n",
+ " \"start\": 0,\n",
+ " \"end\": 1,\n",
+ " }\n",
+ " }\n",
+ "if model_version in ['control_multi_v2','control_multi_v2_768']:\n",
+ " controlnet_multimodel = {\n",
+ " \"control_sd21_canny\": {\n",
+ " \"weight\": 0,\n",
+ " \"start\": 0,\n",
+ " \"end\": 1\n",
+ " }\n",
+ "}\n",
+ "executed_cells[cell_name] = True"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "OeF4nJaf3eiD"
+ },
+ "source": [
+ "### Advanced.\n",
+ "\n",
+ "Barely used. Not duplicated in the gui. You will need to run this cell to apply settings."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "DY8NX-kP35h3"
+ },
+ "outputs": [],
+ "source": [
+ "#these variables are not in the GUI and are not being loaded.\n",
+ "cell_name = 'advanced'\n",
+ "check_execution(cell_name)\n",
+ "\n",
+ "# torch.backends.cudnn.enabled = True # disabling this may increase performance on Ampere and Ada GPUs\n",
+ "\n",
+ "diffuse_inpaint_mask_blur = 25 #used in mask result to extend the mask\n",
+ "diffuse_inpaint_mask_thresh = 0.8 #used in mask result to extend the mask\n",
+ "\n",
+ "add_noise_to_latent = True #add noise to latent vector during latent guidance\n",
+ "noise_upscale_ratio = 1 #noise upscale ratio for latent noise during latent guidance\n",
+ "guidance_use_start_code = True #fix latent noise across steps during latent guidance\n",
+ "init_latent_fn = spherical_dist_loss #function to compute latent distance, l1_loss, rmse, spherical_dist_loss\n",
+ "use_scale = False #use gradient scaling (for mixed precision)\n",
+ "g_invert_mask = False #invert guidance mask\n",
+ "\n",
+ "cb_noise_upscale_ratio = 1 #noise in masked diffusion callback\n",
+ "cb_add_noise_to_latent = True #noise in masked diffusion callback\n",
+ "cb_use_start_code = True #fix noise per frame in masked diffusion callback\n",
+ "cb_fixed_code = False #fix noise across all animation in masked diffusion callback (overcooks fast af)\n",
+ "cb_norm_latent = False #norm cb latent to normal distribution stats in masked diffusion callback\n",
+ "\n",
+ "img_zero_uncond = False #by default image conditioned models use same image for negative conditioning (i.e. both positive and negative image conditionings are the same). you can use empty negative condition by enabling this\n",
+ "\n",
+ "use_legacy_fixed_code = False\n",
+ "\n",
+ "#deflicker settings (assigned once; a second identical pair further down the cell was redundant)\n",
+ "deflicker_scale = 0.\n",
+ "deflicker_latent_scale = 0\n",
+ "\n",
+ "prompt_patterns_sched = {}\n",
+ "\n",
+ "normalize_prompt_weights = True\n",
+ "controlnet_low_vram = False\n",
+ "\n",
+ "sd_batch_size = 2\n",
+ "\n",
+ "mask_paths = []\n",
+ "\n",
+ "controlnet_mode = 'balanced'\n",
+ "normalize_cn_weights = True\n",
+ "#mirror the controlnet weight normalization flag onto the loaded model and switch its debug flag off\n",
+ "sd_model.normalize_weights = normalize_cn_weights\n",
+ "sd_model.debug = False\n",
+ "\n",
+ "apply_freeu_after_control = False\n",
+ "do_freeunet = False\n",
+ "\n",
+ "executed_cells[cell_name] = True"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "SZ6qrVEJeG1u"
+ },
+ "source": [
+ "# Lora & Embedding paths"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "cellView": "form",
+ "id": "FXZFBaFv79sK"
+ },
+ "outputs": [],
+ "source": [
+ "#@title LORA & embedding paths\n",
+ "cell_name = 'lora'\n",
+ "check_execution(cell_name)\n",
+ "\n",
+ "weight_load_location = 'cpu'\n",
+ "from modules import devices, shared\n",
+ "#@markdown Specify folders containing your Loras and Textual Inversion Embeddings. Detected loras will be listed after you run the cell.\n",
+ "lora_dir = 'c:\\\\code\\\\warp\\\\models/loras' #@param {'type':'string'}\n",
+ "if not is_colab and lora_dir.startswith('/content'):\n",
+ " lora_dir = './loras'\n",
+ " print('Overriding lora dir to ./loras for non-colab env because you path begins with /content. Change path to desired folder')\n",
+ "\n",
+ "custom_embed_dir = 'c:\\\\code\\\\warp\\\\models/embeddings' #@param {'type':'string'}\n",
+ "if not is_colab and custom_embed_dir.startswith('/content'):\n",
+ " custom_embed_dir = './embeddings'\n",
+ " os.makedirs(custom_embed_dir, exist_ok=True)\n",
+ " print('Overriding embeddings dir to ./embeddings for non-colab env because you path begins with /content. Change path to desired folder')\n",
+ "\n",
+ "# %cd C:\\code\\warp\\18_venv\\stablediffusion\\modules\\Lora\n",
+ "\n",
+ "os.chdir(f'{root_dir}/stablediffusion/modules/Lora')\n",
+ "from networks import list_available_networks, available_networks, load_networks, assign_network_names_to_compvis_modules, loaded_networks\n",
+ "import networks\n",
+ "os.chdir(root_dir)\n",
+ "list_available_networks(lora_dir)\n",
+ "import re\n",
+ "\n",
+ "print('Found loras: ',[*available_networks.keys()])\n",
+ "if 'sdxl' in model_version: sd_model.is_sdxl = True\n",
+ "else: sd_model.is_sdxl = False\n",
+ "\n",
+ "if not hasattr(torch.nn, 'Linear_forward_before_network'):\n",
+ " torch.nn.Linear_forward_before_network = torch.nn.Linear.forward\n",
+ "\n",
+ "if not hasattr(torch.nn, 'Linear_load_state_dict_before_network'):\n",
+ " torch.nn.Linear_load_state_dict_before_network = torch.nn.Linear._load_from_state_dict\n",
+ "\n",
+ "if not hasattr(torch.nn, 'Conv2d_forward_before_network'):\n",
+ " torch.nn.Conv2d_forward_before_network = torch.nn.Conv2d.forward\n",
+ "\n",
+ "if not hasattr(torch.nn, 'Conv2d_load_state_dict_before_network'):\n",
+ " torch.nn.Conv2d_load_state_dict_before_network = torch.nn.Conv2d._load_from_state_dict\n",
+ "\n",
+ "if not hasattr(torch.nn, 'MultiheadAttention_forward_before_network'):\n",
+ " torch.nn.MultiheadAttention_forward_before_network = torch.nn.MultiheadAttention.forward\n",
+ "\n",
+ "if not hasattr(torch.nn, 'MultiheadAttention_load_state_dict_before_network'):\n",
+ " torch.nn.MultiheadAttention_load_state_dict_before_network = torch.nn.MultiheadAttention._load_from_state_dict\n",
+ "\n",
+ "\n",
+ "\n",
+ "def inject_network(sd_model):\n",
+ "  \"\"\"Swap the torch.nn Linear/Conv2d/MultiheadAttention methods for the `networks` versions.\n",
+ "\n",
+ "  Replaces `forward` and `_load_from_state_dict` on each of the three layer classes, then\n",
+ "  calls `assign_network_names_to_compvis_modules` on `sd_model`. The result is only bound\n",
+ "  to the local name; nothing is returned.\n",
+ "  \"\"\"\n",
+ "  print('injecting loras')\n",
+ "  replacements = (\n",
+ "    (torch.nn.Linear, 'forward', networks.network_Linear_forward),\n",
+ "    (torch.nn.Linear, '_load_from_state_dict', networks.network_Linear_load_state_dict),\n",
+ "    (torch.nn.Conv2d, 'forward', networks.network_Conv2d_forward),\n",
+ "    (torch.nn.Conv2d, '_load_from_state_dict', networks.network_Conv2d_load_state_dict),\n",
+ "    (torch.nn.MultiheadAttention, 'forward', networks.network_MultiheadAttention_forward),\n",
+ "    (torch.nn.MultiheadAttention, '_load_from_state_dict', networks.network_MultiheadAttention_load_state_dict),\n",
+ "  )\n",
+ "  for layer_cls, method_name, patched_fn in replacements:\n",
+ "    setattr(layer_cls, method_name, patched_fn)\n",
+ "\n",
+ "  sd_model = assign_network_names_to_compvis_modules(sd_model)\n",
+ "\n",
+ "def unload_network():\n",
+ "  \"\"\"Put back the torch.nn methods stashed in the `*_before_network` attributes on torch.nn.\"\"\"\n",
+ "  for layer_name in ('Linear', 'Conv2d', 'MultiheadAttention'):\n",
+ "    layer_cls = getattr(torch.nn, layer_name)\n",
+ "    layer_cls.forward = getattr(torch.nn, f'{layer_name}_forward_before_network')\n",
+ "    layer_cls._load_from_state_dict = getattr(torch.nn, f'{layer_name}_load_state_dict_before_network')\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "# (c) Alex Spirin 2023\n",
+ "\n",
+ "\n",
+ "def split_lora_from_prompts(prompts):\n",
+ "  \"\"\"Strip `<...>` tags from keyframed prompts and collect lora weights per keyframe.\n",
+ "\n",
+ "  Args:\n",
+ "    prompts: dict mapping keyframe -> list of prompt strings. Tags are colon-separated,\n",
+ "      e.g. `<lora:name:weight>`: the 2nd field is used as the lora name, the last as its weight.\n",
+ "\n",
+ "  Returns:\n",
+ "    (new_prompts, new_prompt_loras). new_prompts has the same keys and list lengths as `prompts`,\n",
+ "    with every `<...>` tag removed and leading/trailing spaces stripped. new_prompt_loras maps\n",
+ "    lora name -> {keyframe: float weight}.\n",
+ "\n",
+ "  Raises:\n",
+ "    ValueError: if a tag's weight field cannot be parsed as a float.\n",
+ "  \"\"\"\n",
+ "  tag_re = r'<(.*?)>' #raw pattern; '<' and '>' need no escaping\n",
+ "  new_prompt_loras = {}\n",
+ "  new_prompts = {}\n",
+ "\n",
+ "  #iterate through prompts keyframes and fill in lora schedules\n",
+ "  for key, prompt_list in prompts.items():\n",
+ "    prompt_loras = []\n",
+ "    new_prompts[key] = []\n",
+ "    for subp in prompt_list:\n",
+ "      #collect tag bodies, then remove the tags from the prompt text\n",
+ "      prompt_loras += re.findall(tag_re, subp)\n",
+ "      new_prompts[key].append(re.sub(tag_re, '', subp).strip(' '))\n",
+ "\n",
+ "    #get a dict of loras:weights for this keyframe (a later tag with the same name wins)\n",
+ "    prompt_loras_dict = {}\n",
+ "    for tag in prompt_loras:\n",
+ "      parts = tag.split(':')\n",
+ "      if len(parts) < 2:\n",
+ "        continue #no name field (e.g. '<foo>'): nothing to schedule\n",
+ "      #a tag without an explicit weight field (e.g. '<lora:name>') defaults to weight 1\n",
+ "      prompt_loras_dict[parts[1]] = parts[-1] if len(parts) > 2 else '1'\n",
+ "\n",
+ "    #fill lora dict based on keyframe, lora:weight\n",
+ "    for lora_key, weight in prompt_loras_dict.items():\n",
+ "      new_prompt_loras.setdefault(lora_key, {})[key] = float(weight)\n",
+ "\n",
+ "  return new_prompts, new_prompt_loras\n",
+ "\n",
+ "def get_prompt_weights(prompts):\n",
+ " weight_re = r\":\\s*([\\d.]+)\\s*$\"\n",
+ " new_prompts = {}\n",
+ " prompt_weights_dict = {}\n",
+ " max_len = 0\n",
+ " for key in prompts.keys():\n",
+ " prompt_list = prompts[key]\n",
+ " if len(prompt_list) == 1:\n",
+ " prompt_weights_dict[key] = [1] #if 1 prompt set weight to 1\n",
+ " new_prompts[key] = prompt_list\n",
+ " else:\n",
+ " weights = []\n",
+ " new_prompt = []\n",
+ " for i in range(len(prompt_list)):\n",
+ " subprompt = prompt_list[i]\n",
+ " m = re.findall(weight_re, subprompt) #find :number at the end of the string\n",
+ " new_prompt.append(re.sub(weight_re, '', subprompt).strip(' '))\n",
+ " m = m[0] if len(m)>0 else 1\n",
+ " weights.append(m)\n",
+ "\n",
+ " prompt_weights_dict[key] = weights\n",
+ " new_prompts[key] = new_prompt\n",
+ " max_len = max(max_len,len(prompt_weights_dict[key]))\n",
+ "\n",
+ " for key in prompt_weights_dict.keys():\n",
+ " weights = prompt_weights_dict[key]\n",
+ " if len(weights) self.attn_weight:\n",
+ " self_attention_context = torch.cat([self_attention_context] + self.bank, dim=1)\n",
+ " self.bank.clear()\n",
+ " self_attn1 = self.attn1(x_norm1, context=self_attention_context)\n",
+ "\n",
+ " x = self_attn1 + x\n",
+ " x = self.attn2(self.norm2(x), context=context) + x\n",
+ " x = self.ff(self.norm3(x)) + x\n",
+ " return x\n",
+ "\n",
+ "\n",
+ "\n",
+ "# Attention Injection by Lvmin Zhang\n",
+ "# https://github.com/lllyasviel\n",
+ "# https://github.com/Mikubill/sd-webui-controlnet\n",
+ "outer = sd_model.model.diffusion_model\n",
+ "def control_forward(x, timesteps=None, context=None, control=None, only_mid_control=False, self=outer, **kwargs):\n",
+ "  \"\"\"UNet forward pass with optional controlnet residuals and a reference pre-pass.\n",
+ "\n",
+ "  Installed as `outer.forward` when reference control is active. If the global\n",
+ "  `reference_latent` is not None, a noised copy of it (`sd_model.q_sample`) is blended with an\n",
+ "  unconditional variant chosen by the global `reference_mode`, and run through\n",
+ "  `outer.original_forward` with `outer.attention_auto_machine` set to Write. The machine is\n",
+ "  then switched to Read and its weight set to the global `reference_weight` before the\n",
+ "  regular pass.\n",
+ "\n",
+ "  Args:\n",
+ "    x: input tensor for the UNet.\n",
+ "    timesteps: timesteps tensor, used for the time embedding and for noising the reference.\n",
+ "    context: conditioning passed to every block.\n",
+ "    control: optional list of controlnet residuals, consumed from the end with `pop()`. The\n",
+ "      first popped item is added to the middle block output; the rest are added to the skip\n",
+ "      connections unless `only_mid_control` is set.\n",
+ "    only_mid_control: when True, skip connections are used without control residuals.\n",
+ "    self: the UNet whose blocks are executed (bound to `outer` at definition time).\n",
+ "    **kwargs: accepted and ignored.\n",
+ "\n",
+ "  Returns:\n",
+ "    The result of `self.out(h)` for the final hidden state, cast back to `x.dtype`.\n",
+ "  \"\"\"\n",
+ "  if reference_latent is not None:\n",
+ "    # print('Using reference')\n",
+ "    used_hint_cond_latent = reference_latent\n",
+ "    #fall back to [0,1] when no mask shape has been set on the unet\n",
+ "    uc_mask_shape = getattr(outer, 'uc_mask_shape', [0,1])\n",
+ "    uc_mask = torch.tensor(uc_mask_shape, dtype=x.dtype, device=x.device)[:, None, None, None]\n",
+ "    ref_cond_xt = sd_model.q_sample(used_hint_cond_latent, torch.round(timesteps.float()).long())\n",
+ "\n",
+ "    if reference_mode=='Controlnet':\n",
+ "      ref_uncond_xt = x.clone()\n",
+ "      # print('ControlNet More Important - Using standard cfg for reference.')\n",
+ "    elif reference_mode=='Prompt':\n",
+ "      ref_uncond_xt = ref_cond_xt.clone()\n",
+ "      # print('Prompt More Important - Using no cfg for reference.')\n",
+ "    else:\n",
+ "      ldm_time_max = getattr(sd_model, 'num_timesteps', 1000)\n",
+ "      time_weight = (timesteps.float() / float(ldm_time_max)).clip(0, 1)[:, None, None, None]\n",
+ "      time_weight *= torch.pi * 0.5\n",
+ "      # We should use sin/cos to make sure that the std of weighted matrix follows original ddpm schedule\n",
+ "      ref_uncond_xt = x * torch.sin(time_weight) + ref_cond_xt.clone() * torch.cos(time_weight)\n",
+ "      # print('Balanced - Using time-balanced cfg for reference.')\n",
+ "    #empty every attention module's bank before the reference (Write) pass\n",
+ "    for module in outer.attn_module_list:\n",
+ "      module.bank = []\n",
+ "    ref_xt = ref_cond_xt * uc_mask + ref_uncond_xt * (1 - uc_mask)\n",
+ "    outer.attention_auto_machine = AttentionAutoMachine.Write\n",
+ "    # print('ok')\n",
+ "    outer.original_forward(x=ref_xt, timesteps=timesteps, context=context)\n",
+ "    outer.attention_auto_machine = AttentionAutoMachine.Read\n",
+ "    outer.attention_auto_machine_weight = reference_weight\n",
+ "\n",
+ "  hs = []\n",
+ "  #input and middle blocks run without gradient tracking; skip activations are kept in hs\n",
+ "  with torch.no_grad():\n",
+ "    t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)\n",
+ "    emb = self.time_embed(t_emb)\n",
+ "    h = x.type(self.dtype)\n",
+ "    for module in self.input_blocks:\n",
+ "      h = module(h, emb, context)\n",
+ "      hs.append(h)\n",
+ "    h = self.middle_block(h, emb, context)\n",
+ "\n",
+ "  if control is not None: h += control.pop()\n",
+ "\n",
+ "  for module in self.output_blocks:\n",
+ "    if only_mid_control or control is None:\n",
+ "      h = torch.cat([h, hs.pop()], dim=1)\n",
+ "    else:\n",
+ "      h = torch.cat([h, hs.pop() + control.pop()], dim=1)\n",
+ "    h = module(h, emb, context)\n",
+ "\n",
+ "  h = h.type(x.dtype)\n",
+ "  return self.out(h)\n",
+ "\n",
+ "import inspect, re\n",
+ "\n",
+ "def varname(p):\n",
+ "  \"\"\"Return the identifier passed to `varname(...)` at the call site, or None.\n",
+ "\n",
+ "  Searches the caller's source context for a `varname(identifier)` call and returns the\n",
+ "  identifier text. The value of `p` is not inspected. Returns None when no such call is found\n",
+ "  or when the caller's source is unavailable.\n",
+ "  \"\"\"\n",
+ "  code_context = inspect.getframeinfo(inspect.currentframe().f_back).code_context\n",
+ "  #code_context is None when the caller's source can't be retrieved\n",
+ "  for line in code_context or ():\n",
+ "    m = re.search(r'\\bvarname\\s*\\(\\s*([A-Za-z_][A-Za-z0-9_]*)\\s*\\)', line)\n",
+ "    if m:\n",
+ "      return m.group(1)\n",
+ "  return None\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "use_reference = False #@param {'type':'boolean'}\n",
+ "reference_weight = 0.5 #@param\n",
+ "reference_source = 'init' #@param ['stylized', 'init', 'prev_frame','color_video']\n",
+ "reference_mode = 'Balanced' #@param ['Balanced', 'Controlnet', 'Prompt']\n",
+ "\n",
+ "reference_active = reference_weight>0 and use_reference and reference_source != 'None'\n",
+ "if 'sdxl' in model_version:\n",
+ " reference_active = False\n",
+ " print('Temporarily disabling reference controlnet for SDXL')\n",
+ "if reference_active:\n",
+ " # outer = sd_model.model.diffusion_model\n",
+ " try:\n",
+ " outer.forward = outer.original_forward\n",
+ " except: pass\n",
+ " outer.original_forward = outer.forward\n",
+ " outer.attention_auto_machine_weight = reference_weight\n",
+ " outer.forward = control_forward\n",
+ " outer.attention_auto_machine = AttentionAutoMachine.Read\n",
+ " print('Using reference control.')\n",
+ "\n",
+ " attn_modules = [module for module in torch_dfs(outer) if isinstance(module, BasicTransformerBlock)]\n",
+ " attn_modules = sorted(attn_modules, key=lambda x: - x.norm1.normalized_shape[0])\n",
+ "\n",
+ " for i, module in enumerate(attn_modules):\n",
+ " module._original_inner_forward = module._forward\n",
+ " module._forward = hacked_basic_transformer_inner_forward.__get__(module, BasicTransformerBlock)\n",
+ " module.bank = []\n",
+ " module.attn_weight = float(i) / float(len(attn_modules))\n",
+ "\n",
+ " outer.attn_module_list = attn_modules\n",
+ " for module in outer.attn_module_list:\n",
+ " module.bank = []\n",
+ "\n",
+ "\n",
+ "executed_cells[cell_name] = True"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "-I3pjiyu9X9c"
+ },
+ "source": [
+ "# GUI"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "cellView": "form",
+ "id": "41JKuLDjL5Td"
+ },
+ "outputs": [],
+ "source": [
+ "#@title gui\n",
+ "cell_name = 'GUI'\n",
+ "check_execution(cell_name)\n",
+ "global_keys = ['global', '', -1, '-1','global_settings']\n",
+ "\n",
+ "#@markdown Load settings from txt file or output frame image\n",
+ "gui_difficulty_dict = {\n",
+ " \"I'm too young to die.\":[\"flow_warp\", \"warp_strength\",\"warp_mode\",\"padding_mode\",\"padding_ratio\",\n",
+ " \"warp_towards_init\", \"flow_override_map\", \"mask_clip\", \"warp_num_k\",\"warp_forward\",\n",
+ " \"blend_json_schedules\", \"VERBOSE\",\"offload_model\", \"do_softcap\", \"softcap_thresh\",\n",
+ " \"softcap_q\", \"user_comment\",\"turbo_mode\",\"turbo_steps\", \"colormatch_turbo\",\n",
+ " \"turbo_frame_skips_steps\",\"soften_consistency_mask_for_turbo_frames\", \"check_consistency\",\n",
+ " \"missed_consistency_weight\",\"overshoot_consistency_weight\", \"edges_consistency_weight\",\n",
+ " \"soften_consistency_mask\",\"consistency_blur\",\"match_color_strength\",\"mask_result\",\n",
+ " \"use_patchmatch_inpaiting\",\"normalize_latent\",\"normalize_latent_offset\",\"latent_fixed_mean\",\n",
+ " \"latent_fixed_std\",\"latent_norm_4d\",\"use_karras_noise\", \"cond_image_src\", \"inpainting_mask_source\",\n",
+ " \"inverse_inpainting_mask\", \"inpainting_mask_weight\", \"init_grad\", \"grad_denoised\",\n",
+ " \"image_scale_schedule\",\"blend_latent_to_init\",\"dynamic_thresh\",\"rec_cfg\", \"rec_source\",\n",
+ " \"rec_steps_pct\", \"controlnet_multimodel_mode\",\n",
+ " \"overwrite_rec_noise\",\n",
+ " \"colormatch_after\",\"sat_scale\", \"clamp_grad\", \"apply_mask_after_warp\"],\n",
+ " \"Hey, not too rough.\":[\"flow_warp\", \"warp_strength\",\"warp_mode\",\n",
+ " \"warp_towards_init\", \"flow_override_map\", \"mask_clip\", \"warp_num_k\",\"warp_forward\",\n",
+ "\n",
+ " \"check_consistency\",\n",
+ "\n",
+ " \"use_patchmatch_inpaiting\",\"init_grad\", \"grad_denoised\",\n",
+ " \"image_scale_schedule\",\"blend_latent_to_init\",\"rec_cfg\",\n",
+ "\n",
+ " \"colormatch_after\",\"sat_scale\", \"clamp_grad\", \"apply_mask_after_warp\"],\n",
+ " \"Hurt me plenty.\":\"\",\n",
+ " \"Ultra-Violence.\":[\"warp_mode\",\"use_patchmatch_inpaiting\",\"warp_num_k\",\"warp_forward\",\"sat_scale\",]\n",
+ "}\n",
+ "import traceback\n",
+ "gui_difficulty = \"Hey, not too rough.\" #@param [\"I'm too young to die.\", \"Hey, not too rough.\", \"Ultra-Violence.\"]\n",
+ "print(f'Using \"{gui_difficulty}\" gui difficulty. Please switch to another difficulty\\nto unlock up to {len(gui_difficulty_dict[gui_difficulty])} more settings when you`re ready :D')\n",
+ "settings_path = '-1' #@param {'type':'string'}\n",
+ "load_settings_from_file = True #@param {'type':'boolean'}\n",
+ "#@markdown Disable to load settings into GUI from colab cells. You will need to re-run colab cells you've edited to apply changes, then re-run the gui cell.\\\n",
+ "#@markdown Enable to keep GUI state.\n",
+ "keep_gui_state_on_cell_rerun = True #@param {'type':'boolean'}\n",
+ "settings_out = batchFolder+f\"/settings\"\n",
+ "from ipywidgets import HTML, IntRangeSlider, FloatRangeSlider, jslink, Layout, VBox, HBox, Tab, Label, IntText, Dropdown, Text, Accordion, Button, Output, Textarea, FloatSlider, FloatText, Checkbox, SelectionSlider, Valid\n",
+ "\n",
+ "def desc_widget(widget, desc, width=80, h=True):\n",
+ "    \"\"\"Wrap `widget` in a box together with a text label showing `desc`.\n",
+ "\n",
+ "    Args:\n",
+ "        widget: the widget to describe. Checkboxes are returned untouched.\n",
+ "        desc: label text.\n",
+ "        width: label width; a string is passed to Layout as is, any other\n",
+ "            value is converted to its string form first.\n",
+ "        h: True to place label and widget in an HBox, False for a VBox.\n",
+ "\n",
+ "    Returns:\n",
+ "        The original Checkbox, or an HBox/VBox holding [label, widget].\n",
+ "    \"\"\"\n",
+ "    if isinstance(widget, Checkbox):\n",
+ "        return widget\n",
+ "\n",
+ "    # Always build a layout, so that every kind of `width` value yields a bound\n",
+ "    # `layout` instead of an UnboundLocalError further down.\n",
+ "    layout = Layout(width=width if isinstance(width, str) else f'{width}')\n",
+ "\n",
+ "    text = Label(desc, layout = layout, tooltip = widget.tooltip, description_tooltip = widget.description_tooltip)\n",
+ "    return HBox([text, widget]) if h else VBox([text, widget])\n",
+ "\n",
+ "# Controlnets whose `preprocess` dropdown is hidden in the controlnet GUI row.\n",
+ "no_preprocess_cn = [\n",
+ "    'control_sd21_qr',\n",
+ "    'control_sd15_qr',\n",
+ "    'control_sd15_temporalnet',\n",
+ "    'control_sdxl_temporalnet_v1',\n",
+ "    'control_sd15_ip2p',\n",
+ "    'control_sd15_shuffle',\n",
+ "    'control_sd15_inpaint',\n",
+ "    'control_sd15_tile',\n",
+ "]\n",
+ "\n",
+ "# Controlnets whose `detect_resolution` field is hidden in the controlnet GUI row.\n",
+ "no_resolution_cn = [\n",
+ "    'control_sd21_qr',\n",
+ "    'control_sd15_qr',\n",
+ "    'control_sd15_temporalnet',\n",
+ "    'control_sdxl_temporalnet_v1',\n",
+ "    'control_sd15_ip2p',\n",
+ "    'control_sd15_shuffle',\n",
+ "    'control_sd15_inpaint',\n",
+ "    'control_sd15_tile',\n",
+ "]\n",
+ "\n",
+ "\n",
+ "class ControlNetControls(HBox):\n",
+ "    \"\"\"A row of widgets holding the settings of a single controlnet model.\n",
+ "\n",
+ "    The row exposes a synthetic `value` attribute. Reading it returns a dict\n",
+ "    with the keys weight, start, end, preprocess, mode, detect_resolution and\n",
+ "    source. Assigning such a dict pushes the settings into the child widgets.\n",
+ "    \"\"\"\n",
+ "\n",
+ "    @staticmethod\n",
+ "    def _fill_defaults(name, values):\n",
+ "        \"\"\"Fill missing or global-valued optional settings of `values` in place.\"\"\"\n",
+ "        if (\"preprocess\" not in values) or values[\"preprocess\"] in global_keys:\n",
+ "            values[\"preprocess\"] = 'global'\n",
+ "\n",
+ "        if (\"mode\" not in values) or values[\"mode\"] in global_keys:\n",
+ "            values[\"mode\"] = 'global'\n",
+ "\n",
+ "        if (\"detect_resolution\" not in values) or values[\"detect_resolution\"] in global_keys:\n",
+ "            values[\"detect_resolution\"] = -1\n",
+ "\n",
+ "        if (\"source\" not in values) or values[\"source\"] in global_keys:\n",
+ "            values[\"source\"] = 'stylized' if name == 'control_sd15_inpaint' else 'global'\n",
+ "        # 'init' is stored under the name 'raw_frame'.\n",
+ "        if values[\"source\"] == 'init':\n",
+ "            values[\"source\"] = 'raw_frame'\n",
+ "        return values\n",
+ "\n",
+ "    def __init__(self, name, values, **kwargs):\n",
+ "        \"\"\"Build the row for controlnet `name`.\n",
+ "\n",
+ "        Args:\n",
+ "            name: controlnet model name, shown as the row label.\n",
+ "            values: dict with weight, start and end; the optional keys\n",
+ "                preprocess, mode, detect_resolution and source are filled\n",
+ "                with defaults in place when missing.\n",
+ "            **kwargs: accepted but not used.\n",
+ "        \"\"\"\n",
+ "        # Weight and range widgets are only shown for a positive weight.\n",
+ "        shown = 'visible' if values['weight']>0 else 'hidden'\n",
+ "        self.label = HTML(\n",
+ "            description=name,\n",
+ "            description_tooltip=name, style={'description_width': 'initial' },\n",
+ "            layout = Layout(position='relative', left='-25px', width='200px'))\n",
+ "        self.name = name\n",
+ "        self.enable = Checkbox(value=values['weight']>0,description='',indent=True, description_tooltip='Enable model.',\n",
+ "                               style={'description_width': '25px' },layout=Layout(width='70px', left='-25px'))\n",
+ "        self.weight = FloatText(value = values['weight'], description=' ', step=0.05,\n",
+ "                                description_tooltip = 'Controlnet model weights. ',\n",
+ "                                layout=Layout(width='100px', visibility=shown),\n",
+ "                                style={'description_width': '25px' })\n",
+ "        self.start_end = FloatRangeSlider(\n",
+ "            value=[values['start'],values['end']],\n",
+ "            min=0,\n",
+ "            max=1,\n",
+ "            step=0.01,\n",
+ "            description=' ',\n",
+ "            description_tooltip='Controlnet active step range settings. For example, [||||||||||] 50 steps, [-------|||] 0.3 style strength (effective steps - 0.3x50 = 15), [--||||||--] - controlnet working range with start = 0.2 and end = 0.8, effective steps from 0.2x50 = 10 to 0.8x50 = 40',\n",
+ "            disabled=False,\n",
+ "            continuous_update=False,\n",
+ "            orientation='horizontal',\n",
+ "            readout=True,\n",
+ "            layout = Layout(width='300px', visibility=shown),\n",
+ "            style={'description_width': '50px' }\n",
+ "        )\n",
+ "\n",
+ "        self._fill_defaults(name, values)\n",
+ "\n",
+ "        self.preprocess = Dropdown(description='',\n",
+ "                                   options = ['True', 'False', 'global'], value = values['preprocess'],\n",
+ "                                   description_tooltip='Preprocess input for this controlnet', layout=Layout(width='80px'))\n",
+ "\n",
+ "        self.mode = Dropdown(description='',\n",
+ "                             options = ['balanced', 'controlnet', 'prompt', 'global'], value = values['mode'],\n",
+ "                             description_tooltip='Controlnet mode. Pay more attention to controlnet prediction, to prompt or somewhere in-between.',\n",
+ "                             layout=Layout(width='100px'))\n",
+ "\n",
+ "        self.detect_resolution = IntText(value = values['detect_resolution'], description='',\n",
+ "                                         description_tooltip = 'Controlnet detect_resolution.',layout=Layout(width='80px'), style={'description_width': 'initial' })\n",
+ "\n",
+ "        self.source = Text(value=values['source'], description = '', layout=Layout(width='200px'),\n",
+ "                           description_tooltip='controlnet input source, either a file or video, raw_frame, cond_video, color_video, or stylized - to use previously stylized frame ad input. leave empty for global source')\n",
+ "\n",
+ "        self.enable.observe(self.on_change)\n",
+ "        self.weight.observe(self.on_change)\n",
+ "        settings = [self.enable, self.label, self.weight, self.start_end, self.mode, self.source, self.detect_resolution, self.preprocess]\n",
+ "        # Some controlnets never show the preprocess / detect_resolution widgets.\n",
+ "        if name in no_preprocess_cn: self.preprocess.layout.visibility = 'hidden'\n",
+ "        if name in no_resolution_cn: self.detect_resolution.layout.visibility = 'hidden'\n",
+ "\n",
+ "        if values['weight']==0:\n",
+ "            self.preprocess.layout.visibility = 'hidden'\n",
+ "            self.mode.layout.visibility = 'hidden'\n",
+ "            self.detect_resolution.layout.visibility = 'hidden'\n",
+ "            self.source.layout.visibility = 'hidden'\n",
+ "        super().__init__(settings, layout = Layout(valign='center'))\n",
+ "\n",
+ "    def on_change(self, change):\n",
+ "        \"\"\"Show or hide the row's setting widgets when enable/weight changes.\"\"\"\n",
+ "        if change['name'] != 'value':\n",
+ "            return\n",
+ "        if self.enable.value:\n",
+ "            # Enabling a model that still has zero weight gives it weight 1.\n",
+ "            if change['old'] == False and self.weight.value==0:\n",
+ "                self.weight.value = 1\n",
+ "            state = 'visible'\n",
+ "        else:\n",
+ "            state = 'hidden'\n",
+ "        for widget in (self.weight, self.start_end, self.mode, self.source):\n",
+ "            widget.layout.visibility = state\n",
+ "        # Keep these hidden for models listed in the no_*_cn lists, matching __init__.\n",
+ "        self.preprocess.layout.visibility = 'hidden' if self.name in no_preprocess_cn else state\n",
+ "        self.detect_resolution.layout.visibility = 'hidden' if self.name in no_resolution_cn else state\n",
+ "\n",
+ "    def __setattr__(self, attr, values):\n",
+ "        \"\"\"Route assignments to `value` into the child widgets; others behave normally.\"\"\"\n",
+ "        if attr != 'value':\n",
+ "            super().__setattr__(attr, values)\n",
+ "            return\n",
+ "        self.enable.value = values['weight']>0\n",
+ "        self.weight.value = values['weight']\n",
+ "        self.start_end.value=[values['start'],values['end']]\n",
+ "        self._fill_defaults(self.name, values)\n",
+ "        self.preprocess.value = values['preprocess']\n",
+ "        self.mode.value = values['mode']\n",
+ "        self.detect_resolution.value = values['detect_resolution']\n",
+ "        self.source.value=values['source']\n",
+ "\n",
+ "    def __getattr__(self, attr):\n",
+ "        \"\"\"Return the settings dict for `value`; raise AttributeError otherwise.\n",
+ "\n",
+ "        Only called when normal attribute lookup fails.\n",
+ "        \"\"\"\n",
+ "        if attr == 'value':\n",
+ "            # A disabled or non-positive weight is reported as 0.\n",
+ "            weight = 0\n",
+ "            if self.weight.value>0 and self.enable.value: weight = self.weight.value\n",
+ "            (start,end) = self.start_end.value\n",
+ "            return {\n",
+ "                \"weight\": weight,\n",
+ "                \"start\":start,\n",
+ "                \"end\":end,\n",
+ "                'preprocess': self.preprocess.value,\n",
+ "                'mode': self.mode.value,\n",
+ "                'detect_resolution': self.detect_resolution.value,\n",
+ "                'source': self.source.value,\n",
+ "            }\n",
+ "        # Raise explicitly: reading self.<attr> here would re-enter __getattr__.\n",
+ "        raise AttributeError(f\"{type(self).__name__!r} object has no attribute {attr!r}\")\n",
+ "\n",
+ "class ControlGUI(VBox):\n",
+ "    \"\"\"Table of ControlNetControls rows, one per controlnet available for model_version.\n",
+ "\n",
+ "    Reading `value` returns {controlnet name: settings dict} for rows whose\n",
+ "    weight is above 0. Assigning a dict to `value` disables every row and\n",
+ "    then loads the settings of the rows named in the dict.\n",
+ "    \"\"\"\n",
+ "\n",
+ "    def __init__(self, args):\n",
+ "        \"\"\"Build the header row and one controls row per possible controlnet.\n",
+ "\n",
+ "        Args:\n",
+ "            args: dict mapping controlnet names to their settings dicts;\n",
+ "                controlnets missing from it start with zero weight.\n",
+ "        \"\"\"\n",
+ "        enable_label = HTML(\n",
+ "            description='Enable',\n",
+ "            description_tooltip='Enable', style={'description_width': '50px' },\n",
+ "            layout = Layout(width='40px', left='-50px', ))\n",
+ "        model_label = HTML(\n",
+ "            description='Model name',\n",
+ "            description_tooltip='Model name', style={'description_width': '100px' },\n",
+ "            layout = Layout(width='265px'))\n",
+ "        weight_label = HTML(\n",
+ "            description='weight',\n",
+ "            description_tooltip='Model weight. 0 weight effectively disables the model. The total sum of all the weights will be normalized to 1.', style={'description_width': 'initial' },\n",
+ "            layout = Layout(position='relative', left='-25px', width='125px'))#65\n",
+ "        range_label = HTML(\n",
+ "            description='active range (% or total steps)',\n",
+ "            description_tooltip='Model`s active range. % of total steps when the model is active.\\n Controlnet active step range settings. For example, [||||||||||] 50 steps, [-------|||] 0.3 style strength (effective steps - 0.3x50 = 15), [--||||||--] - controlnet working range with start = 0.2 and end = 0.8, effective steps from 0.2x50 = 10 to 0.8x50 = 40', style={'description_width': 'initial' },\n",
+ "            layout = Layout(position='relative', left='-25px', width='200px'))\n",
+ "        mode_label = HTML(\n",
+ "            description='mode',\n",
+ "            description_tooltip='Controlnet mode. Pay more attention to controlnet prediction, to prompt or somewhere in-between.', layout = Layout(width='110px', left='0px', ))\n",
+ "        source_label = HTML(\n",
+ "            description='source',\n",
+ "            description_tooltip='controlnet input source, either a file or video, raw_frame, cond_video, color_video, or stylized - to use previously stylized frame ad input. leave empty for global source',\n",
+ "            layout = Layout(width='210px', left='0px', ))\n",
+ "        resolution_label = HTML(\n",
+ "            description='resolution',\n",
+ "            description_tooltip='Controlnet detect_resolution. The size of the image fed into annotator model if current controlnet has one.',\n",
+ "            layout = Layout(width='90px', left='0px', ))\n",
+ "        preprocess_label = HTML(\n",
+ "            description='preprocess',\n",
+ "            description_tooltip='Preprocess (put through annotator model) input for this controlnet. When disabled, puts raw image from selected source into the controlnet. For example, if you have sequence of pdeth maps from your 3d software, you need to put path to those maps into source field and disable preprocessing.',\n",
+ "            layout = Layout(width='80px', left='0px', ))\n",
+ "        # The first child is the header row; the controlnet rows follow it.\n",
+ "        controls_list = [HBox([enable_label,model_label, weight_label, range_label, mode_label, source_label, resolution_label, preprocess_label ])]\n",
+ "        controls_dict = {}\n",
+ "        possible_controlnets = ['control_sd15_depth',\n",
+ "            'control_sd15_canny',\n",
+ "            'control_sd15_softedge',\n",
+ "            'control_sd15_mlsd',\n",
+ "            'control_sd15_normalbae',\n",
+ "            'control_sd15_openpose',\n",
+ "            'control_sd15_scribble',\n",
+ "            'control_sd15_seg',\n",
+ "            'control_sd15_temporalnet',\n",
+ "            'control_sd15_face',\n",
+ "            'control_sd15_ip2p',\n",
+ "            'control_sd15_inpaint',\n",
+ "            'control_sd15_lineart',\n",
+ "            'control_sd15_lineart_anime',\n",
+ "            'control_sd15_shuffle',\n",
+ "            'control_sd15_tile',\n",
+ "            'control_sd15_qr',\n",
+ "            'control_sd15_inpaint_softedge',\n",
+ "            'control_sd15_temporal_depth',\n",
+ "        ]\n",
+ "        possible_controlnets_sdxl = [\n",
+ "            'control_sdxl_canny',\n",
+ "            'control_sdxl_depth',\n",
+ "            'control_sdxl_softedge',\n",
+ "            'control_sdxl_seg',\n",
+ "            'control_sdxl_openpose',\n",
+ "            'control_sdxl_lora_128_depth',\n",
+ "            \"control_sdxl_lora_256_depth\",\n",
+ "            \"control_sdxl_lora_128_canny\",\n",
+ "            \"control_sdxl_lora_256_canny\",\n",
+ "            \"control_sdxl_lora_128_softedge\",\n",
+ "            \"control_sdxl_lora_256_softedge\",\n",
+ "            \"control_sdxl_temporalnet_v1\"\n",
+ "        ]\n",
+ "        possible_controlnets_v2 = [\n",
+ "            'control_sd21_qr',\n",
+ "            \"control_sd21_depth\",\n",
+ "            \"control_sd21_scribble\",\n",
+ "            \"control_sd21_openpose\",\n",
+ "            \"control_sd21_normalbae\",\n",
+ "            \"control_sd21_lineart\",\n",
+ "            \"control_sd21_softedge\",\n",
+ "            \"control_sd21_canny\",\n",
+ "            \"control_sd21_seg\"\n",
+ "        ]\n",
+ "        # Pick the controlnet list for the global model_version; the sd15 list is the fallback.\n",
+ "        if model_version == 'control_multi_sdxl':\n",
+ "            self.possible_controlnets = possible_controlnets_sdxl\n",
+ "        elif model_version in ['control_multi_v2','control_multi_v2_768']:\n",
+ "            self.possible_controlnets = possible_controlnets_v2\n",
+ "        else:\n",
+ "            self.possible_controlnets = possible_controlnets\n",
+ "\n",
+ "        for key in self.possible_controlnets:\n",
+ "            if key in args.keys():\n",
+ "                w = ControlNetControls(key, args[key])\n",
+ "            else:\n",
+ "                w = ControlNetControls(key, {\n",
+ "                    \"weight\":0,\n",
+ "                    \"start\":0,\n",
+ "                    \"end\":1\n",
+ "                })\n",
+ "            w.name = key\n",
+ "            controls_list.append(w)\n",
+ "            controls_dict[key] = w\n",
+ "\n",
+ "        self.args = args\n",
+ "        self.ws = controls_dict\n",
+ "        super(ControlGUI, self).__init__(controls_list)\n",
+ "\n",
+ "    def __setattr__(self, attr, values):\n",
+ "        \"\"\"Route assignments to `value` into the rows; others behave normally.\"\"\"\n",
+ "        if attr != 'value':\n",
+ "            super().__setattr__(attr, values)\n",
+ "            return\n",
+ "        for w in self.children:\n",
+ "            if not isinstance(w, ControlNetControls):\n",
+ "                continue\n",
+ "            # Rows missing from `values` stay disabled.\n",
+ "            w.enable.value = False\n",
+ "            if w.name in values.keys():\n",
+ "                w.value = values[w.name]\n",
+ "\n",
+ "    def __getattr__(self, attr):\n",
+ "        \"\"\"Return the enabled rows' settings for `value`; raise AttributeError otherwise.\n",
+ "\n",
+ "        Only called when normal attribute lookup fails.\n",
+ "        \"\"\"\n",
+ "        if attr == 'value':\n",
+ "            res = {}\n",
+ "            for key in self.possible_controlnets:\n",
+ "                row_value = self.ws[key].value\n",
+ "                if row_value['weight'] > 0:\n",
+ "                    res[key] = row_value\n",
+ "            return res\n",
+ "        raise AttributeError(f\"{type(self).__name__!r} object has no attribute {attr!r}\")\n",
+ "\n",
+ "def set_visibility(key, value, obj):\n",
+ " if isinstance(obj, dict):\n",
+ " if key in obj.keys():\n",
+ " obj[key].layout.visibility = value\n",
+ "\n",
+ "def get_settings_from_gui(user_settings_keys, guis):\n",
+ "    \"\"\"Read every key of `user_settings_keys` from the GUI into `user_settings`.\n",
+ "\n",
+ "    Values are fetched with get_value(); keys listed in\n",
+ "    user_settings_eval_keys are additionally passed through eval(). A failing\n",
+ "    eval is printed and the unevaluated value is stored.\n",
+ "\n",
+ "    Returns:\n",
+ "        The global `user_settings` dict, updated in place.\n",
+ "    \"\"\"\n",
+ "    for key in user_settings_keys:\n",
+ "        # Both mask clip bounds are read from the single 'mask_clip' widget.\n",
+ "        value = get_value('mask_clip' if key in ('mask_clip_low', 'mask_clip_high') else key, guis)\n",
+ "\n",
+ "        if key in ('latent_fixed_mean', 'latent_fixed_std'):\n",
+ "            value = str(value)\n",
+ "\n",
+ "        # NOTE(review): eval() runs whatever text was typed into the GUI field.\n",
+ "        if key in user_settings_eval_keys:\n",
+ "            try:\n",
+ "                value = eval(value)\n",
+ "            except Exception as err:\n",
+ "                print(err, key, value)\n",
+ "\n",
+ "        # Split the mask clip range into its low / high bound.\n",
+ "        if key == 'mask_clip_low':\n",
+ "            value = value[0]\n",
+ "        elif key == 'mask_clip_high':\n",
+ "            value = value[1]\n",
+ "\n",
+ "        user_settings[key] = value\n",
+ "    return user_settings\n",
+ "\n",
+ "\n",
+ "def set_globals_from_gui(user_settings_keys, guis):\n",
+ "    \"\"\"Copy the GUI value of every key in `user_settings_keys` into globals().\n",
+ "\n",
+ "    Keys that are not already present in globals() are reported and skipped.\n",
+ "    Keys listed in user_settings_eval_keys are passed through eval(); an\n",
+ "    eval error propagates to the caller.\n",
+ "    \"\"\"\n",
+ "    for key in user_settings_keys:\n",
+ "        if key not in globals():\n",
+ "            print(f'Variable {key} is not defined or present in globals()')\n",
+ "            continue\n",
+ "\n",
+ "        # Both mask clip bounds are read from the single 'mask_clip' widget.\n",
+ "        value = get_value('mask_clip' if key in ('mask_clip_low', 'mask_clip_high') else key, guis)\n",
+ "\n",
+ "        if key in ('latent_fixed_mean', 'latent_fixed_std'):\n",
+ "            value = str(value)\n",
+ "\n",
+ "        # NOTE(review): eval() runs whatever text was typed into the GUI field.\n",
+ "        if key in user_settings_eval_keys:\n",
+ "            value = eval(value)\n",
+ "\n",
+ "        # Split the mask clip range into its low / high bound.\n",
+ "        if key == 'mask_clip_low':\n",
+ "            value = value[0]\n",
+ "        elif key == 'mask_clip_high':\n",
+ "            value = value[1]\n",
+ "\n",
+ "        globals()[key] = value\n",
+ "\n",
+ "#try keep settings on occasional run cell\n",
+ "if keep_gui_state_on_cell_rerun:\n",
+ "    try:\n",
+ "        # user_settings = get_settings_from_gui(user_settings, guis)\n",
+ "        set_globals_from_gui(user_settings_keys, guis)\n",
+ "    except Exception:\n",
+ "        # `except Exception` lets KeyboardInterrupt / SystemExit through.\n",
+ "        # NameErrors about `get_value` / `guis` are silenced (presumably the GUI\n",
+ "        # has not been built yet - TODO confirm); anything else is reported.\n",
+ "        _keep_state_trace = traceback.format_exc()\n",
+ "        if not any(_msg in _keep_state_trace for _msg in (\n",
+ "                \"NameError: name 'get_value' is not defined\",\n",
+ "                \"NameError: name 'guis' is not defined\")):\n",
+ "            print('Error keeping state')\n",
+ "            print(_keep_state_trace)\n",
+ "\n",
+ "# Miscellaneous settings: widget per setting name, initialised from the current globals.\n",
+ "gui_misc = {\n",
+ "    \"user_comment\": Textarea(\n",
+ "        value=user_comment,\n",
+ "        layout=Layout(width='80%'),\n",
+ "        description='user_comment:',\n",
+ "        description_tooltip='Enter a comment to differentiate between save files.',\n",
+ "    ),\n",
+ "    \"blend_json_schedules\": Checkbox(\n",
+ "        value=blend_json_schedules,\n",
+ "        description='blend_json_schedules',\n",
+ "        indent=True,\n",
+ "        description_tooltip='Smooth values between keyframes.',\n",
+ "        tooltip='Smooth values between keyframes.',\n",
+ "    ),\n",
+ "    \"VERBOSE\": Checkbox(\n",
+ "        value=VERBOSE,\n",
+ "        description='VERBOSE',\n",
+ "        indent=True,\n",
+ "        description_tooltip='Print all logs',\n",
+ "    ),\n",
+ "    \"offload_model\": Checkbox(\n",
+ "        value=offload_model,\n",
+ "        description='offload_model',\n",
+ "        indent=True,\n",
+ "        description_tooltip='Offload unused models to CPU and back to GPU to save VRAM. May reduce speed.',\n",
+ "    ),\n",
+ "    \"do_softcap\": Checkbox(\n",
+ "        value=do_softcap,\n",
+ "        description='do_softcap',\n",
+ "        indent=True,\n",
+ "        description_tooltip='Softly clamp latent excessive values. Reduces feedback loop effect a bit.',\n",
+ "    ),\n",
+ "    \"softcap_thresh\": FloatSlider(\n",
+ "        value=softcap_thresh,\n",
+ "        min=0,\n",
+ "        max=1,\n",
+ "        step=0.05,\n",
+ "        description='softcap_thresh:',\n",
+ "        readout=True,\n",
+ "        readout_format='.1f',\n",
+ "        description_tooltip='Scale down absolute values above that threshold (latents are being clamped at [-1:1] range, so 0.9 will downscale values above 0.9 to fit into that range, [-1.5:1.5] will be scaled to [-1:1], but only absolute values over 0.9 will be affected',\n",
+ "    ),\n",
+ "    \"softcap_q\": FloatSlider(\n",
+ "        value=softcap_q,\n",
+ "        min=0,\n",
+ "        max=1,\n",
+ "        step=0.05,\n",
+ "        description='softcap_q:',\n",
+ "        readout=True,\n",
+ "        readout_format='.1f',\n",
+ "        description_tooltip='Percentile to downscale. 1-downscle full range with outliers, 0.9 - downscale only 90% values above thresh, clamp 10%',\n",
+ "    ),\n",
+ "    \"sd_batch_size\": IntText(\n",
+ "        value=sd_batch_size,\n",
+ "        description='sd_batch_size:',\n",
+ "        description_tooltip='Diffusion batch size. Default=2 for 1 positive + 1 negative prompt. ',\n",
+ "    ),\n",
+ "    \"do_freeunet\": Checkbox(\n",
+ "        value=do_freeunet,\n",
+ "        description='do_freeunet',\n",
+ "        indent=True,\n",
+ "        description_tooltip='Apply freeunet fix',\n",
+ "    ),\n",
+ "    \"apply_freeu_after_control\": Checkbox(\n",
+ "        value=apply_freeu_after_control,\n",
+ "        description='apply_freeu_after_control',\n",
+ "        indent=True,\n",
+ "        description_tooltip='Apply freeunet fix after adding controlnet outputs',\n",
+ "    ),\n",
+ "}\n",
+ "\n",
+ "# Masking settings: widget per setting name, initialised from the current globals.\n",
+ "gui_mask = {\n",
+ "    \"use_background_mask\": Checkbox(\n",
+ "        value=use_background_mask,\n",
+ "        description='use_background_mask',\n",
+ "        indent=True,\n",
+ "        description_tooltip='Enable masking. In order to use it, you have to either extract or provide an existing mask in Video Masking cell.\\n',\n",
+ "    ),\n",
+ "    \"invert_mask\": Checkbox(\n",
+ "        value=invert_mask,\n",
+ "        description='invert_mask',\n",
+ "        indent=True,\n",
+ "        description_tooltip='Inverts the mask, allowing to process either backgroung or characters, depending on your mask.',\n",
+ "    ),\n",
+ "    \"background\": Dropdown(\n",
+ "        description='background',\n",
+ "        options=['image', 'color', 'init_video'],\n",
+ "        value=background,\n",
+ "        description_tooltip='Background type. Image - uses static image specified in background_source, color - uses fixed color specified in background_source, init_video - uses raw init video for masked areas.',\n",
+ "    ),\n",
+ "    \"background_source\": Text(\n",
+ "        value=background_source,\n",
+ "        description='background_source',\n",
+ "        description_tooltip='Specify image path or color name of hash.',\n",
+ "    ),\n",
+ "    \"apply_mask_after_warp\": Checkbox(\n",
+ "        value=apply_mask_after_warp,\n",
+ "        description='apply_mask_after_warp',\n",
+ "        indent=True,\n",
+ "        description_tooltip='On to reduce ghosting. Apply mask after warping and blending warped image with current raw frame. If off, only current frame will be masked, previous frame will be warped and blended wuth masked current frame.',\n",
+ "    ),\n",
+ "    # A single range slider carries both mask_clip_low and mask_clip_high.\n",
+ "    \"mask_clip\": IntRangeSlider(\n",
+ "        value=(mask_clip_low, mask_clip_high),\n",
+ "        min=0,\n",
+ "        max=255,\n",
+ "        step=1,\n",
+ "        description='Mask clipping:',\n",
+ "        description_tooltip='Values below the selected range will be treated as black mask, values above - as white.',\n",
+ "        disabled=False,\n",
+ "        continuous_update=False,\n",
+ "        orientation='horizontal',\n",
+ "        readout=True,\n",
+ "    ),\n",
+ "    \"mask_paths\": Textarea(\n",
+ "        value=str(mask_paths),\n",
+ "        layout=Layout(width='80%'),\n",
+ "        description='mask_paths:',\n",
+ "        description_tooltip='A list of paths to prompt mask files/folders/glob patterns. Format: [\"/somepath/somefile.mp4\", \"./otherpath/dirwithfiles/*.jpg]',\n",
+ "    ),\n",
+ "}\n",
+ "\n",
+ "# Turbo mode settings: widget per setting name, initialised from the current globals.\n",
+ "gui_turbo = {\n",
+ "    \"turbo_mode\":Checkbox(value=turbo_mode,description='turbo_mode',indent=True, description_tooltip='Turbo mode skips diffusion process on turbo_steps number of frames. Frames are still being warped and blended. Speeds up the render at the cost of possible trails an ghosting.' ),\n",
+ "    \"turbo_steps\": IntText(value = turbo_steps, description='turbo_steps:', description_tooltip='Number of turbo frames'),\n",
+ "    \"colormatch_turbo\":Checkbox(value=colormatch_turbo,description='colormatch_turbo',indent=True, description_tooltip='Apply frame color matching during turbo frames. May increease rendering speed, but may add minor flickering.'),\n",
+ "    # Options go up in 5% steps; the list used to repeat '80%' in place of '90%'.\n",
+ "    \"turbo_frame_skips_steps\" : SelectionSlider(description='turbo_frame_skips_steps',\n",
+ "        options = ['70%','75%','80%','85%', '90%', '95%', '100% (don`t diffuse turbo frames, fastest)'], value = '100% (don`t diffuse turbo frames, fastest)', description_tooltip='Skip steps for turbo frames. Select 100% to skip diffusion rendering for turbo frames completely.'),\n",
+ "    \"soften_consistency_mask_for_turbo_frames\": FloatSlider(value=soften_consistency_mask_for_turbo_frames, min=0, max=1, step=0.05, description='soften_consistency_mask_for_turbo_frames:', readout=True, readout_format='.1f', description_tooltip='Clips the consistency mask, reducing it`s effect'),\n",
+ "\n",
+ "}\n",
+ "\n",
+ "# Warp settings: widget per setting name, initialised from the current globals.\n",
+ "gui_warp = {\n",
+ "    \"flow_warp\": Checkbox(\n",
+ "        value=flow_warp,\n",
+ "        description='flow_warp',\n",
+ "        indent=True,\n",
+ "        description_tooltip='Blend current raw init video frame with previously stylised frame with respect to consistency mask. 0 - raw frame, 1 - stylized frame',\n",
+ "    ),\n",
+ "    \"flow_blend_schedule\": Textarea(\n",
+ "        value=str(flow_blend_schedule),\n",
+ "        layout=Layout(width='80%'),\n",
+ "        description='flow_blend_schedule:',\n",
+ "        description_tooltip='Blend current raw init video frame with previously stylised frame with respect to consistency mask. 0 - raw frame, 1 - stylized frame',\n",
+ "    ),\n",
+ "    \"warp_num_k\": IntText(\n",
+ "        value=warp_num_k,\n",
+ "        description='warp_num_k:',\n",
+ "        description_tooltip='Nubmer of clusters in forward-warp mode. The more - the smoother is the motion. Lower values move larger chunks of image at a time.',\n",
+ "    ),\n",
+ "    \"warp_forward\": Checkbox(\n",
+ "        value=warp_forward,\n",
+ "        description='warp_forward',\n",
+ "        indent=True,\n",
+ "        description_tooltip='Experimental. Enable patch-based flow warping. Groups pixels by motion direction and moves them together, instead of moving individual pixels.',\n",
+ "    ),\n",
+ "    # \"warp_interp\": Textarea(value='Image.LANCZOS',layout=Layout(width=f'80%'), description = 'warp_interp:'),\n",
+ "    \"warp_strength\": FloatText(\n",
+ "        value=warp_strength,\n",
+ "        description='warp_strength:',\n",
+ "        description_tooltip='Experimental. Motion vector multiplier. Provides a glitchy effect.',\n",
+ "    ),\n",
+ "    \"flow_override_map\": Textarea(\n",
+ "        value=str(flow_override_map),\n",
+ "        layout=Layout(width='80%'),\n",
+ "        description='flow_override_map:',\n",
+ "        description_tooltip='Experimental. Motion vector maps mixer. Allows changing frame-motion vetor indexes or repeating motion, provides a glitchy effect.',\n",
+ "    ),\n",
+ "    \"warp_mode\": Dropdown(\n",
+ "        description='warp_mode',\n",
+ "        options=['use_latent', 'use_image'],\n",
+ "        value=warp_mode,\n",
+ "        description_tooltip='Experimental. Apply warp to latent vector. May get really blurry, but reduces feedback loop effect for slow movement',\n",
+ "    ),\n",
+ "    \"warp_towards_init\": Dropdown(\n",
+ "        description='warp_towards_init',\n",
+ "        options=['stylized', 'off'],\n",
+ "        value=warp_towards_init,\n",
+ "        description_tooltip='Experimental. After a frame is stylized, computes the difference between output and input for that frame, and warps the output back to input, preserving its shape.',\n",
+ "    ),\n",
+ "    \"padding_ratio\": FloatSlider(\n",
+ "        value=padding_ratio,\n",
+ "        min=0,\n",
+ "        max=1,\n",
+ "        step=0.05,\n",
+ "        description='padding_ratio:',\n",
+ "        readout=True,\n",
+ "        readout_format='.1f',\n",
+ "        description_tooltip='Amount of padding. Padding is used to avoid black edges when the camera is moving out of the frame.',\n",
+ "    ),\n",
+ "    \"padding_mode\": Dropdown(\n",
+ "        description='padding_mode',\n",
+ "        options=['reflect', 'edge', 'wrap'],\n",
+ "        value=padding_mode,\n",
+ "    ),\n",
+ "}\n",
+ "\n",
+ "# warp_interp = Image.LANCZOS\n",
+ "\n",
+ "# Consistency settings: widget per setting name, initialised from the current globals.\n",
+ "gui_consistency = {\n",
+ "    \"check_consistency\":Checkbox(value=check_consistency,description='check_consistency',indent=True, description_tooltip='Enables consistency checking (CC). CC is used to avoid ghosting and trails, that appear due to lack of information while warping frames. It allows replacing motion edges, frame borders, incorrectly moved areas with raw init frame data.'),\n",
+ "    # The backslash in this tooltip is escaped: a bare backslash before 'm' is an invalid escape sequence.\n",
+ "    \"missed_consistency_weight\":FloatSlider(value=missed_consistency_weight, min=0, max=1, step=0.05, description='missed_consistency_weight:', readout=True, readout_format='.1f', description_tooltip='Multiplier for incorrectly predicted\\\\moved areas. For example, if an object moves and background appears behind it. We can predict what to put in that spot, so we can either duplicate the object, resulting in trail, or use init video data for that region.'),\n",
+ "    \"overshoot_consistency_weight\":FloatSlider(value=overshoot_consistency_weight, min=0, max=1, step=0.05, description='overshoot_consistency_weight:', readout=True, readout_format='.1f', description_tooltip='Multiplier for areas that appeared out of the frame. We can either leave them black or use raw init video.'),\n",
+ "    \"edges_consistency_weight\":FloatSlider(value=edges_consistency_weight, min=0, max=1, step=0.05, description='edges_consistency_weight:', readout=True, readout_format='.1f', description_tooltip='Multiplier for motion edges. Moving objects are most likely to leave trails, this option together with missed consistency weight helps prevent that, but in a more subtle manner.'),\n",
+ "    \"soften_consistency_mask\" : FloatSlider(value=soften_consistency_mask, min=0, max=1, step=0.05, description='soften_consistency_mask:', readout=True, readout_format='.1f'),\n",
+ "    \"consistency_blur\": FloatText(value = consistency_blur, description='consistency_blur:'),\n",
+ "    \"consistency_dilate\": FloatText(value = consistency_dilate, description='consistency_dilate:', description_tooltip='expand consistency mask without blurring the edges'),\n",
+ "    # Plain label entry, not a setting widget.\n",
+ "    \"barely used\": Label(' '),\n",
+ "    \"match_color_strength\" : FloatSlider(value=match_color_strength, min=0, max=1, step=0.05, description='match_color_strength:', readout=True, readout_format='.1f', description_tooltip='Enables colormathing raw init video pixls in inconsistent areas only to the stylized frame. May reduce flickering for inconsistent areas.'),\n",
+ "    \"mask_result\": Checkbox(value=mask_result,description='mask_result',indent=True, description_tooltip='Stylizes only inconsistent areas. Takes consistent areas from the previous frame.'),\n",
+ "    \"use_patchmatch_inpaiting\": FloatSlider(value=use_patchmatch_inpaiting, min=0, max=1, step=0.05, description='use_patchmatch_inpaiting:', readout=True, readout_format='.1f', description_tooltip='Uses patchmatch inapinting for inconsistent areas. Is slow.'),\n",
+ "}\n",
+ "\n",
+ "gui_diffusion = {\n",
+ " \"use_karras_noise\":Checkbox(value=use_karras_noise,description='use_karras_noise',indent=True, description_tooltip='Enable for samplers that have K at their name`s end.'),\n",
+ " \"sampler\": Dropdown(description='sampler',options= [('sample_euler', sample_euler),\n",
+ " ('sample_euler_ancestral',sample_euler_ancestral),\n",
+ " ('sample_heun',sample_heun),\n",
+ " ('sample_dpm_2', sample_dpm_2),\n",
+ " ('sample_dpm_2_ancestral',sample_dpm_2_ancestral),\n",
+ " ('sample_lms', sample_lms),\n",
+ " ('sample_dpm_fast', sample_dpm_fast),\n",
+ " ('sample_dpm_adaptive',sample_dpm_adaptive),\n",
+ " ('sample_dpmpp_2s_ancestral', sample_dpmpp_2s_ancestral),\n",
+ " ('sample_dpmpp_sde', sample_dpmpp_sde),\n",
+ " ('sample_dpmpp_2m', sample_dpmpp_2m)], value = sampler),\n",
+ " \"prompt_patterns_sched\": Textarea(value=str(prompt_patterns_sched),layout=Layout(width=f'80%'), description = 'Replace patterns:'),\n",
+ " \"text_prompts\" : Textarea(value=str(text_prompts),layout=Layout(width=f'80%'), description = 'Prompt:'),\n",
+ " \"negative_prompts\" : Textarea(value=str(negative_prompts), layout=Layout(width=f'80%'), description = 'Negative Prompt:'),\n",
+ " \"cond_image_src\":Dropdown(description='cond_image_src', options = ['init', 'stylized','cond_video'] ,\n",
+ " value = cond_image_src, description_tooltip='Depth map source for depth model. It can either take raw init video frame or previously stylized frame.'),\n",
+ " \"inpainting_mask_source\":Dropdown(description='inpainting_mask_source', options = ['none', 'consistency_mask', 'cond_video'] ,\n",
+ " value = inpainting_mask_source, description_tooltip='Inpainting model mask source. none - full white mask (inpaint whole image), consistency_mask - inpaint inconsistent areas only'),\n",
+ " \"inverse_inpainting_mask\":Checkbox(value=inverse_inpainting_mask,description='inverse_inpainting_mask',indent=True, description_tooltip='Inverse inpainting mask'),\n",
+ " \"inpainting_mask_weight\":FloatSlider(value=inpainting_mask_weight, min=0, max=1, step=0.05, description='inpainting_mask_weight:', readout=True, readout_format='.1f',\n",
+ " description_tooltip= 'Inpainting mask weight. 0 - Disables inpainting mask.'),\n",
+ " \"set_seed\": IntText(value = set_seed, description='set_seed:', description_tooltip='Seed. Use -1 for random.'),\n",
+ " \"clamp_grad\":Checkbox(value=clamp_grad,description='clamp_grad',indent=True, description_tooltip='Enable limiting the effect of external conditioning per diffusion step'),\n",
+ " \"clamp_max\": FloatText(value = clamp_max, description='clamp_max:',description_tooltip='limit the effect of external conditioning per diffusion step'),\n",
+ " \"latent_scale_schedule\":Textarea(value=str(latent_scale_schedule),layout=Layout(width=f'80%'), description = 'latent_scale_schedule:', description_tooltip='Latents scale defines how much minimize difference between output and input stylized image in latent space.'),\n",
+ " \"init_scale_schedule\": Textarea(value=str(init_scale_schedule),layout=Layout(width=f'80%'), description = 'init_scale_schedule:', description_tooltip='Init scale defines how much minimize difference between output and input stylized image in RGB space.'),\n",
+ " \"sat_scale\": FloatText(value = sat_scale, description='sat_scale:', description_tooltip='Saturation scale limits oversaturation.'),\n",
+ " \"init_grad\": Checkbox(value=init_grad,description='init_grad',indent=True, description_tooltip='On - compare output to real frame, Off - to stylized frame'),\n",
+ " \"grad_denoised\" : Checkbox(value=grad_denoised,description='grad_denoised',indent=True, description_tooltip='Fastest, On by default, calculate gradients with respect to denoised image instead of input image per diffusion step.' ),\n",
+ " \"steps_schedule\" : Textarea(value=str(steps_schedule),layout=Layout(width=f'80%'), description = 'steps_schedule:',\n",
+ " description_tooltip= 'Total diffusion steps schedule. Use list format like [50,70], where each element corresponds to a frame, last element being repeated forever, or dictionary like {0:50, 20:70} format to specify keyframes only.'),\n",
+ " \"style_strength_schedule\" : Textarea(value=str(style_strength_schedule),layout=Layout(width=f'80%'), description = 'style_strength_schedule:',\n",
+ " description_tooltip= 'Diffusion (style) strength. Actual number of diffusion steps taken (at 50 steps with 0.3 or 30% style strength you get 15 steps, which also means 35 0r 70% skipped steps). Inverse of skep steps. Use list format like [0.5,0.35], where each element corresponds to a frame, last element being repeated forever, or dictionary like {0:0.5, 20:0.35} format to specify keyframes only.'),\n",
+ " \"cfg_scale_schedule\": Textarea(value=str(cfg_scale_schedule),layout=Layout(width=f'80%'), description = 'cfg_scale_schedule:', description_tooltip= 'Guidance towards text prompt. 7 is a good starting value, 1 is off (text prompt has no effect).'),\n",
+ " \"image_scale_schedule\": Textarea(value=str(image_scale_schedule),layout=Layout(width=f'80%'), description = 'image_scale_schedule:', description_tooltip= 'Only used with InstructPix2Pix Model. Guidance towards text prompt. 1.5 is a good starting value'),\n",
+ " \"blend_latent_to_init\": FloatSlider(value=blend_latent_to_init, min=0, max=1, step=0.05, description='blend_latent_to_init:', readout=True, readout_format='.1f', description_tooltip = 'Blend latent vector with raw init'),\n",
+ " # \"use_karras_noise\": Checkbox(value=False,description='use_karras_noise',indent=True),\n",
+ " # \"end_karras_ramp_early\": Checkbox(value=False,description='end_karras_ramp_early',indent=True),\n",
+ " \"fixed_seed\": Checkbox(value=fixed_seed,description='fixed_seed',indent=True, description_tooltip= 'Fixed seed.'),\n",
+ " \"fixed_code\": Checkbox(value=fixed_code,description='fixed_code',indent=True, description_tooltip= 'Fixed seed analog. Fixes diffusion noise.'),\n",
+ " \"code_randomness\": FloatSlider(value=code_randomness, min=0, max=1, step=0.05, description='code_randomness:', readout=True, readout_format='.1f', description_tooltip= 'Fixed seed amount/effect strength.'),\n",
+ " # \"normalize_code\":Checkbox(value=normalize_code,description='normalize_code',indent=True, description_tooltip= 'Whether to normalize the noise after adding fixed seed.'),\n",
+ " \"dynamic_thresh\": FloatText(value = dynamic_thresh, description='dynamic_thresh:', description_tooltip= 'Limit diffusion model prediction output. Lower values may introduce clamping/feedback effect'),\n",
+ " \"use_predicted_noise\":Checkbox(value=use_predicted_noise,description='use_predicted_noise',indent=True, description_tooltip='Reconstruct initial noise from init / stylized image.'),\n",
+ " \"rec_prompts\" : Textarea(value=str(rec_prompts),layout=Layout(width=f'80%'), description = 'Rec Prompt:'),\n",
+ " \"rec_randomness\": FloatSlider(value=rec_randomness, min=0, max=1, step=0.05, description='rec_randomness:', readout=True, readout_format='.1f', description_tooltip= 'Reconstructed noise randomness. 0 - reconstructed noise only. 1 - random noise.'),\n",
+ " \"rec_cfg\": FloatText(value = rec_cfg, description='rec_cfg:', description_tooltip= 'CFG scale for noise reconstruction. 1-1.9 are the best values.'),\n",
+ " \"rec_source\": Dropdown(description='rec_source', options = ['init', 'stylized'] ,\n",
+ " value = rec_source, description_tooltip='Source for noise reconstruction. Either raw init frame or stylized frame.'),\n",
+ " \"rec_steps_pct\":FloatSlider(value=rec_steps_pct, min=0, max=1, step=0.05, description='rec_steps_pct:', readout=True, readout_format='.2f', description_tooltip= 'Reconstructed noise steps in relation to total steps. 1 = 100% steps.'),\n",
+ " \"overwrite_rec_noise\":Checkbox(value=overwrite_rec_noise,description='overwrite_rec_noise',indent=True,\n",
+ " description_tooltip= 'Overwrite reconstructed noise cache. By default reconstructed noise is not calculated if the settings haven`t changed too much. You can eit prompt, neg prompt, cfg scale, style strength, steps withot reconstructing the noise every time.'),\n",
+ "\n",
+ " \"masked_guidance\":Checkbox(value=masked_guidance,description='masked_guidance',indent=True,\n",
+ " description_tooltip= 'Use mask for init/latent guidance to ignore inconsistencies and only guide based on the consistent areas.'),\n",
+ " \"cc_masked_diffusion_schedule\": Textarea(value=str(cc_masked_diffusion_schedule),layout=Layout(width=f'80%'),\n",
+ " description = 'cc_masked_diffusion', description_tooltip= '0 - off. 0.5-0.7 are good values. Make inconsistent area passes only before this % of actual steps, then diffuse whole image.'),\n",
+ " \"alpha_masked_diffusion\": FloatSlider(value=alpha_masked_diffusion, min=0, max=1, step=0.05,\n",
+ " description='alpha_masked_diffusion:', readout=True, readout_format='.2f', description_tooltip= '0 - off. 0.5-0.7 are good values. Make alpha masked area passes only before this % of actual steps, then diffuse whole image.'),\n",
+ " \"invert_alpha_masked_diffusion\":Checkbox(value=invert_alpha_masked_diffusion,description='invert_alpha_masked_diffusion',indent=True,\n",
+ " description_tooltip= 'invert alpha ask for masked diffusion'),\n",
+ " \"normalize_prompt_weights\":Checkbox(value=normalize_prompt_weights,description='normalize_prompt_weights',indent=True,\n",
+ " description_tooltip='Scale prompt weights to sum up to 1.'),\n",
+ " \"deflicker_scale\": FloatText(value = deflicker_scale, description='deflicker_scale:',\n",
+ " description_tooltip= 'Deflicker loss scale in image pixel space'),\n",
+ " \"deflicker_latent_scale\": FloatText(value = deflicker_latent_scale,\n",
+ " description='deflicker_latent_scale:', description_tooltip= 'Deflicker loss scale in image latent space'),\n",
+ "\n",
+ "}\n",
+ "gui_colormatch = {\n",
+ " \"normalize_latent\": Dropdown(description='normalize_latent',\n",
+ " options = ['off', 'user_defined', 'color_video', 'color_video_offset',\n",
+ " 'stylized_frame', 'init_frame', 'stylized_frame_offset', 'init_frame_offset'], value =normalize_latent ,description_tooltip= 'Normalize latent to prevent it from overflowing. User defined: use fixed input values (latent_fixed_*) Stylized/init frame - match towards stylized/init frame with a fixed number (specified in the offset field below). Stylized\\init frame offset - match to a frame with a number = current frame - offset (specified in the offset filed below).'),\n",
+ " \"normalize_latent_offset\":IntText(value = normalize_latent_offset, description='normalize_latent_offset:', description_tooltip= 'Offset from current frame number for *_frame_offset mode, or fixed frame number for *frame mode.'),\n",
+ " \"latent_fixed_mean\": FloatText(value = latent_fixed_mean, description='latent_fixed_mean:', description_tooltip= 'User defined mean value for normalize_latent=user_Defined mode'),\n",
+ " \"latent_fixed_std\": FloatText(value = latent_fixed_std, description='latent_fixed_std:', description_tooltip= 'User defined standard deviation value for normalize_latent=user_Defined mode'),\n",
+ " \"latent_norm_4d\": Checkbox(value=latent_norm_4d,description='latent_norm_4d',indent=True, description_tooltip= 'Normalize on a per-channel basis (on by default)'),\n",
+ " \"colormatch_frame\": Dropdown(description='colormatch_frame', options = ['off', 'stylized_frame', 'color_video', 'color_video_offset', 'init_frame', 'stylized_frame_offset', 'init_frame_offset'],\n",
+ " value = colormatch_frame,\n",
+ " description_tooltip= 'Match frame colors to prevent it from overflowing. Stylized/init frame - match towards stylized/init frame with a fixed number (specified in the offset filed below). Stylized\\init frame offset - match to a frame with a number = current frame - offset (specified in the offset field below).'),\n",
+ " \"color_match_frame_str\": FloatText(value = color_match_frame_str, description='color_match_frame_str:', description_tooltip= 'Colormatching strength. 0 - no colormatching effect.'),\n",
+ " \"colormatch_offset\":IntText(value =colormatch_offset, description='colormatch_offset:', description_tooltip= 'Offset from current frame number for *_frame_offset mode, or fixed frame number for *frame mode.'),\n",
+ " \"colormatch_method\": Dropdown(description='colormatch_method', options = ['LAB', 'PDF', 'mean'], value =colormatch_method ),\n",
+ " # \"colormatch_regrain\": Checkbox(value=False,description='colormatch_regrain',indent=True),\n",
+ " \"colormatch_after\":Checkbox(value=colormatch_after,description='colormatch_after',indent=True, description_tooltip= 'On - Colormatch output frames when saving to disk, may differ from the preview. Off - colormatch before stylizing.'),\n",
+ "\n",
+ "}\n",
+ "\n",
+ "gui_controlnet = {\n",
+ " \"controlnet_preprocess\": Checkbox(value=controlnet_preprocess,description='controlnet_preprocess',indent=True,\n",
+ " description_tooltip= 'preprocess input conditioning image for controlnet. If false, use raw conditioning as input to the model without detection/preprocessing.'),\n",
+ " \"detect_resolution\":IntText(value = detect_resolution, description='detect_resolution:', description_tooltip= 'Control net conditioning image resolution. The size of the image passed into controlnet preprocessors. Suggest keeping this as high as you can fit into your VRAM for more details.'),\n",
+ " \"bg_threshold\":FloatText(value = bg_threshold, description='bg_threshold:', description_tooltip='Control net depth/normal bg cutoff threshold'),\n",
+ " \"low_threshold\":IntText(value = low_threshold, description='low_threshold:', description_tooltip= 'Control net canny filter parameters'),\n",
+ " \"high_threshold\":IntText(value = high_threshold, description='high_threshold:', description_tooltip= 'Control net canny filter parameters'),\n",
+ " \"value_threshold\":FloatText(value = value_threshold, description='value_threshold:', description_tooltip='Control net mlsd filter parameters'),\n",
+ " \"distance_threshold\":FloatText(value = distance_threshold, description='distance_threshold:', description_tooltip='Control net mlsd filter parameters'),\n",
+ " \"temporalnet_source\":Dropdown(description ='temporalnet_source', options = ['init', 'stylized'] ,\n",
+ " value = temporalnet_source, description_tooltip='Temporalnet guidance source. Previous init or previous stylized frame'),\n",
+ " \"temporalnet_skip_1st_frame\": Checkbox(value = temporalnet_skip_1st_frame,description='temporalnet_skip_1st_frame',indent=True,\n",
+ " description_tooltip='Skip temporalnet for 1st frame (if not skipped, will use raw init for guidance'),\n",
+ " \"controlnet_multimodel_mode\":Dropdown(description='controlnet_multimodel_mode', options = ['internal','external'], value =controlnet_multimodel_mode, description_tooltip='internal - sums controlnet values before feeding those into diffusion model, external - sum outputs of differnet contolnets after passing through diffusion model. external seems slower but smoother.' ),\n",
+ " \"max_faces\":IntText(value = max_faces, description='max_faces:', description_tooltip= 'Max faces to detect. Control net face parameters'),\n",
+ " \"controlnet_low_vram\":Checkbox(value = controlnet_low_vram,description='controlnet_low_vram',indent=True,\n",
+ " description_tooltip='Only load currently used controlnet to gpu. Slow, saves VRAM.'),\n",
+ " \"save_controlnet_annotations\": Checkbox(value = save_controlnet_annotations,description='save_controlnet_annotations',indent=True,\n",
+ " description_tooltip='Save controlnet annotator predictions. They will be saved to your project dir /controlnetDebug folder.'),\n",
+ " \"control_sd15_openpose_hands_face\":Checkbox(value = control_sd15_openpose_hands_face,description='control_sd15_openpose_hands_face',indent=True,\n",
+ " description_tooltip='Enable full openpose mode with hands and facial features.'),\n",
+ " \"control_sd15_depth_detector\" :Dropdown(description='control_sd15_depth_detector', options = ['Zoe','Midas'], value =control_sd15_depth_detector,\n",
+ " description_tooltip='Depth annotator model.' ),\n",
+ " \"pose_detector\" :Dropdown(description='pose_detector', options = ['openpose','dw_pose'], value =pose_detector,\n",
+ " description_tooltip='Pose detector model.' ),\n",
+ " \"control_sd15_softedge_detector\":Dropdown(description='control_sd15_softedge_detector', options = ['HED','PIDI'], value =control_sd15_softedge_detector,\n",
+ " description_tooltip='Softedge annotator model.' ),\n",
+ " \"control_sd15_seg_detector\":Dropdown(description='control_sd15_seg_detector', options = ['Seg_OFCOCO', 'Seg_OFADE20K', 'Seg_UFADE20K'], value =control_sd15_seg_detector,\n",
+ " description_tooltip='Segmentation annotator model.' ),\n",
+ " \"control_sd15_scribble_detector\":Dropdown(description='control_sd15_scribble_detector', options = ['HED','PIDI'], value =control_sd15_scribble_detector,\n",
+ " description_tooltip='Sccribble annotator model.' ),\n",
+ " \"control_sd15_lineart_coarse\":Checkbox(value = control_sd15_lineart_coarse,description='control_sd15_lineart_coarse',indent=True,\n",
+ " description_tooltip='Coarse strokes mode.'),\n",
+ " \"control_sd15_inpaint_mask_source\":Dropdown(description='control_sd15_inpaint_mask_source', options = ['consistency_mask', 'None', 'cond_video'], value =control_sd15_inpaint_mask_source,\n",
+ " description_tooltip='Inpainting controlnet mask source. consistency_mask - inpaints inconsistent areas, None - whole image, cond_video - loads external mask' ),\n",
+ " \"control_sd15_shuffle_source\":Dropdown(description='control_sd15_shuffle_source', options = ['color_video', 'init', 'prev_frame', 'first_frame'], value =control_sd15_shuffle_source,\n",
+ " description_tooltip='Shuffle controlnet source. color_video: uses color video frames (or single image) as source, init - uses current frame`s init as source (stylized+warped with consistency mask and flow_blend opacity), prev_frame - uses previously stylized frame (stylized, not warped), first_frame - first stylized frame' ),\n",
+ " \"control_sd15_shuffle_1st_source\":Dropdown(description='control_sd15_shuffle_1st_source', options = ['color_video', 'init', 'None'], value =control_sd15_shuffle_1st_source,\n",
+ " description_tooltip='Set 1st frame source for shuffle model. If you need to geet the 1st frame style from your image, and for the consecutive frames you want to use the resulting stylized images. color_video: uses color video frames (or single image) as source, init - uses current frame`s init as source (raw video frame), None - skips this controlnet for the 1st frame. For example, if you like the 1st frame you`re getting and want to keep its style, but don`t want to use an external image as a source.'),\n",
+ " \"controlnet_multimodel\":ControlGUI(controlnet_multimodel),\n",
+ " \"controlnet_mode\":Dropdown(description='controlnet_mode',\n",
+ " options = ['balanced', 'controlnet', 'prompt'], value = controlnet_mode,\n",
+ " description_tooltip='Controlnet mode. Pay more attention to controlnet prediction, to prompt or somewhere in-between.'),\n",
+ " \"normalize_cn_weights\":Checkbox(value = normalize_cn_weights,description='normalize_cn_weights',indent=True,\n",
+ " description_tooltip='Normalize controlnet weights to add up to 1. Off = keep raw controlnet weight values.'),\n",
+ "\n",
+ "}\n",
+ "\n",
+ "colormatch_regrain = False\n",
+ "\n",
+ "guis = [gui_diffusion, gui_controlnet, gui_warp, gui_consistency, gui_turbo, gui_mask, gui_colormatch, gui_misc]\n",
+ "\n",
+ "for key in gui_difficulty_dict[gui_difficulty]:\n",
+ " for gui in guis:\n",
+ " set_visibility(key, 'hidden', gui)\n",
+ "\n",
+    "class FilePath(HBox):\n",
+    "  \"\"\"Text field paired with a Valid indicator that shows whether the typed path exists.\n",
+    "\n",
+    "  Reading `.value` on the container returns the text currently held by the inner\n",
+    "  Text field (`model_path`). Keyword arguments are forwarded to that Text widget.\n",
+    "  \"\"\"\n",
+    "\n",
+    "  def __init__(self,  **kwargs):\n",
+    "    self.model_path = Text(value='',  continuous_update = True,**kwargs)\n",
+    "    self.path_checker = Valid(\n",
+    "    value=False, layout=Layout(width='2000px')\n",
+    "    )\n",
+    "\n",
+    "    # observe() is registered without a names filter, so on_change filters by trait name itself.\n",
+    "    self.model_path.observe(self.on_change)\n",
+    "    super().__init__([self.model_path, self.path_checker])\n",
+    "\n",
+    "  def __getattr__(self, attr):\n",
+    "    \"\"\"Expose the inner text field's value as `self.value`; called only when normal lookup fails.\"\"\"\n",
+    "    if attr == 'value':\n",
+    "      return self.model_path.value\n",
+    "    # The previous fallback called __getattr__ on the bare `super` type, which failed with a\n",
+    "    # misleading message. Raise the conventional AttributeError so hasattr()/getattr() behave.\n",
+    "    raise AttributeError(f'{type(self).__name__!r} object has no attribute {attr!r}')\n",
+    "\n",
+    "  def on_change(self, change):\n",
+    "    \"\"\"Refresh the validity indicator whenever the text field's `value` trait changes.\"\"\"\n",
+    "    if change['name'] != 'value':\n",
+    "      return\n",
+    "    path_exists = os.path.exists(change['new'])\n",
+    "    self.path_checker.value = path_exists\n",
+    "    if path_exists:\n",
+    "      self.path_checker.description = ''\n",
+    "    else:\n",
+    "      self.path_checker.description = 'The file does not exist. Please specify the correct path.'\n",
+ "\n",
+    "def add_labels_dict(gui):\n",
+    "  \"\"\"Style the widgets of one GUI section and return the widgets to display, keyed like `gui`.\n",
+    "\n",
+    "  Every widget in `gui` gets a 250px description column (widgets are modified in place).\n",
+    "  - ControlGUI entries are styled but left out of the returned dict.\n",
+    "  - Checkbox entries are wrapped in an HBox together with an HTML label that carries the\n",
+    "    checkbox description and tooltip; the checkbox's own description is then cleared and\n",
+    "    the wrapper inherits the checkbox's visibility.\n",
+    "  - Everything else is returned as-is; widgets other than Textarea are set to 500px wide.\n",
+    "\n",
+    "  Args:\n",
+    "    gui: dict mapping setting name -> widget.\n",
+    "\n",
+    "  Returns:\n",
+    "    dict mapping setting name -> widget (or HBox wrapper) ready for layout.\n",
+    "  \"\"\"\n",
+    "  style = {'description_width': '250px' }\n",
+    "  gui_labels = {}\n",
+    "  for key, widget in gui.items():\n",
+    "    widget.style = style\n",
+    "    if isinstance(widget, ControlGUI):\n",
+    "      continue\n",
+    "    if isinstance(widget, Checkbox):\n",
+    "      html_label = HTML(\n",
+    "          description=widget.description,\n",
+    "          description_tooltip=widget.description_tooltip, style={'description_width': 'initial' },\n",
+    "          layout = Layout(position='relative', left='-25px'))\n",
+    "      labeled_row = HBox([widget, html_label])\n",
+    "      labeled_row.layout.visibility = widget.layout.visibility\n",
+    "      # The text now lives in the HTML label, so blank the checkbox's own description.\n",
+    "      widget.description = ''\n",
+    "      gui_labels[key] = labeled_row\n",
+    "      continue\n",
+    "    if not isinstance(widget, Textarea):\n",
+    "      widget.layout.width = '500px'\n",
+    "    gui_labels[key] = widget\n",
+    "\n",
+    "  return gui_labels\n",
+ "\n",
+ "\n",
+ "gui_diffusion_label, gui_controlnet_label, gui_warp_label, gui_consistency_label, gui_turbo_label, gui_mask_label, gui_colormatch_label, gui_misc_label = [add_labels_dict(o) for o in guis]\n",
+ "\n",
+ "cond_keys = ['latent_scale_schedule','init_scale_schedule','clamp_grad',\n",
+ " 'clamp_max','init_grad','grad_denoised','masked_guidance','deflicker_scale','deflicker_latent_scale' ]\n",
+ "conditioning_w = Accordion([VBox([gui_diffusion_label[o] for o in cond_keys])])\n",
+ "conditioning_w.set_title(0, 'External Conditioning...')\n",
+ "\n",
+ "seed_keys = ['set_seed', 'fixed_seed', 'fixed_code', 'code_randomness']\n",
+ "seed_w = Accordion([VBox([gui_diffusion_label[o] for o in seed_keys])])\n",
+ "seed_w.set_title(0, 'Seed...')\n",
+ "\n",
+ "rec_keys = ['use_predicted_noise','rec_prompts','rec_cfg','rec_randomness', 'rec_source', 'rec_steps_pct', 'overwrite_rec_noise']\n",
+ "rec_w = Accordion([VBox([gui_diffusion_label[o] for o in rec_keys])])\n",
+ "rec_w.set_title(0, 'Reconstructed noise...')\n",
+ "\n",
+ "prompt_keys = ['text_prompts', 'negative_prompts', 'prompt_patterns_sched',\n",
+ "'steps_schedule', 'style_strength_schedule',\n",
+ "'cfg_scale_schedule', 'blend_latent_to_init', 'dynamic_thresh',\n",
+ "'cond_image_src', 'cc_masked_diffusion_schedule', 'alpha_masked_diffusion', 'invert_alpha_masked_diffusion', 'normalize_prompt_weights']\n",
+ "if model_version == 'v1_instructpix2pix':\n",
+ " prompt_keys.append('image_scale_schedule')\n",
+ "if model_version == 'v1_inpainting':\n",
+ " prompt_keys+=['inpainting_mask_source', 'inverse_inpainting_mask', 'inpainting_mask_weight']\n",
+ "prompt_keys = [o for o in prompt_keys if o not in seed_keys+cond_keys]\n",
+ "prompt_w = [gui_diffusion_label[o] for o in prompt_keys]\n",
+ "\n",
+ "gui_diffusion_list = [*prompt_w, gui_diffusion_label['sampler'],\n",
+ "gui_diffusion_label['use_karras_noise'], conditioning_w, seed_w, rec_w]\n",
+ "\n",
+ "control_annotator_keys = ['normalize_cn_weights', 'save_controlnet_annotations','bg_threshold','low_threshold','high_threshold','value_threshold',\n",
+ " 'distance_threshold', 'max_faces', 'control_sd15_openpose_hands_face','control_sd15_depth_detector' ,'pose_detector','control_sd15_softedge_detector',\n",
+ "'control_sd15_seg_detector','control_sd15_scribble_detector','control_sd15_lineart_coarse','control_sd15_inpaint_mask_source',\n",
+ "'control_sd15_shuffle_source','control_sd15_shuffle_1st_source', 'temporalnet_source', 'temporalnet_skip_1st_frame']\n",
+ "control_global_keys = ['controlnet_preprocess', 'detect_resolution', 'controlnet_mode']\n",
+ "control_global_w_list = [gui_controlnet_label[o] for o in control_global_keys]\n",
+ "control_global_w_list.append(gui_diffusion_label[\"cond_image_src\"])\n",
+ "control_global_w = Accordion([VBox(control_global_w_list)])\n",
+ "control_global_w.set_title(0, 'Controlnet global settings...')\n",
+ "\n",
+ "control_annotator_w = Accordion([VBox([gui_controlnet_label[o] for o in control_annotator_keys])])\n",
+ "control_annotator_w.set_title(0, 'Controlnet annotator settings...')\n",
+ "controlnet_model_w = Accordion([gui_controlnet['controlnet_multimodel']])\n",
+ "controlnet_model_w.set_title(0, 'Controlnet models settings...')\n",
+ "control_keys = [ 'controlnet_multimodel_mode', 'controlnet_low_vram']\n",
+ "control_w = [gui_controlnet_label[o] for o in control_keys]\n",
+ "gui_control_list = [controlnet_model_w, control_global_w, control_annotator_w, *control_w]\n",
+ "\n",
+ "#misc\n",
+ "misc_keys = [\"user_comment\",\"blend_json_schedules\",\"VERBOSE\",\"offload_model\",'sd_batch_size','do_freeunet','apply_freeu_after_control']\n",
+ "misc_w = [gui_misc_label[o] for o in misc_keys]\n",
+ "\n",
+ "softcap_keys = ['do_softcap','softcap_thresh','softcap_q']\n",
+ "softcap_w = Accordion([VBox([gui_misc_label[o] for o in softcap_keys])])\n",
+ "softcap_w.set_title(0, 'Softcap settings...')\n",
+ "\n",
+ "load_settings_btn = Button(description='Load settings')\n",
+ "def btn_eventhandler(obj):\n",
+ " global guis\n",
+ " guis = load_settings(load_settings_path.value, guis)\n",
+ "load_settings_btn.on_click(btn_eventhandler)\n",
+ "load_settings_path = FilePath(placeholder='Please specify the path to the settings file to load.', description_tooltip='Please specify the path to the settings file to load.')\n",
+ "settings_w = Accordion([VBox([load_settings_path, load_settings_btn])])\n",
+ "settings_w.set_title(0, 'Load settings...')\n",
+ "gui_misc_list = [*misc_w, softcap_w, settings_w]\n",
+ "\n",
+ "guis_labels_source = [gui_diffusion_list]\n",
+ "guis_titles_source = ['diffusion']\n",
+ "if 'control' in model_version:\n",
+ " guis_labels_source += [gui_control_list]\n",
+ " guis_titles_source += ['controlnet']\n",
+ "\n",
+ "guis_labels_source += [gui_warp_label, gui_consistency_label,\n",
+ "gui_turbo_label, gui_mask_label, gui_colormatch_label, gui_misc_list]\n",
+ "guis_titles_source += ['warp', 'consistency', 'turbo', 'mask', 'colormatch', 'misc']\n",
+ "\n",
+ "guis_labels = [VBox([*o.values()]) if isinstance(o, dict) else VBox(o) for o in guis_labels_source]\n",
+ "\n",
+ "app = Tab(guis_labels)\n",
+ "for i,title in enumerate(guis_titles_source):\n",
+ " app.set_title(i, title)\n",
+ "\n",
+ "def get_value(key, obj):\n",
+ " if isinstance(obj, dict):\n",
+ " if key in obj.keys():\n",
+ " return obj[key].value\n",
+ " else:\n",
+ " for o in obj.keys():\n",
+ " res = get_value(key, obj[o])\n",
+ " if res is not None: return res\n",
+ " if isinstance(obj, list):\n",
+ " for o in obj:\n",
+ " res = get_value(key, o)\n",
+ " if res is not None: return res\n",
+ " return None\n",
+ "\n",
+ "def set_value(key, value, obj):\n",
+ " if isinstance(obj, dict):\n",
+ " if key in obj.keys():\n",
+ " obj[key].value = value\n",
+ " else:\n",
+ " for o in obj.keys():\n",
+ " set_value(key, value, obj[o])\n",
+ "\n",
+ " if isinstance(obj, list):\n",
+ " for o in obj:\n",
+ " set_value(key, value, o)\n",
+ "\n",
+ "\n",
+ "\n",
+ "import json\n",
+ "def infer_settings_path(path):\n",
+ " default_settings_path = path\n",
+ " if default_settings_path == '-1':\n",
+ " settings_files = sorted(glob(os.path.join(settings_out, '*.txt')),\n",
+ " key=os.path.getctime)\n",
+ " if len(settings_files)>0:\n",
+ " default_settings_path = settings_files[-1]\n",
+ " else:\n",
+ " print('Skipping load latest run settings: no settings files found.')\n",
+ " return ''\n",
+ " else:\n",
+ " try:\n",
+ " if type(eval(default_settings_path)) == int:\n",
+ " files = sorted(glob(os.path.join(settings_out, '*.txt')))\n",
+ " for f in files:\n",
+ " if f'({default_settings_path})' in f:\n",
+ " default_settings_path = f\n",
+ " except: pass\n",
+ "\n",
+ " path = default_settings_path\n",
+ " return path\n",
+ "\n",
+    "def load_settings(path, guis):\n",
+    "  \"\"\"Load saved run settings and push them into the GUI widgets.\n",
+    "\n",
+    "  Args:\n",
+    "    path: settings file path or a shortcut understood by infer_settings_path.\n",
+    "      A path ending in 'png' is opened as an image whose EXIF tag 37510 holds the\n",
+    "      settings JSON; any other path is parsed as a JSON file.\n",
+    "    guis: nested list/dict of widgets, updated in place through set_value.\n",
+    "\n",
+    "  Returns:\n",
+    "    The `guis` object that was passed in.\n",
+    "\n",
+    "  A setting that fails to convert or apply is printed together with the error and\n",
+    "  skipped; loading continues with the remaining settings.\n",
+    "  \"\"\"\n",
+    "  path = infer_settings_path(path)\n",
+    "\n",
+    "  if not os.path.exists(path):\n",
+    "    output.clear_output()\n",
+    "    print('Please specify a valid path to a settings file.')\n",
+    "    return guis\n",
+    "  if path.endswith('png'):\n",
+    "    # Release the image file handle as soon as the EXIF payload has been read.\n",
+    "    with PIL.Image.open(path) as img:\n",
+    "      exif_data = img._getexif()\n",
+    "    settings = json.loads(exif_data[37510])\n",
+    "  else:\n",
+    "    print('Loading settings from: ', path)\n",
+    "    with open(path, 'rb') as f:\n",
+    "      settings = json.load(f)\n",
+    "\n",
+    "  # Iterate over a snapshot of the keys: the legacy normalize_latent branch may add a key\n",
+    "  # to `settings`, which would otherwise abort the loop with a RuntimeError.\n",
+    "  for source_key in list(settings):\n",
+    "    # `key` can be renamed below; `source_key` always indexes `settings` safely.\n",
+    "    key = source_key\n",
+    "    try:\n",
+    "      val = settings[source_key]\n",
+    "      if key == 'normalize_latent' and val == 'first_latent':\n",
+    "        # Legacy value: map to init_frame with offset 0 and apply the offset right away,\n",
+    "        # whether or not the file carries its own normalize_latent_offset entry.\n",
+    "        val = 'init_frame'\n",
+    "        settings['normalize_latent_offset'] = 0\n",
+    "        set_value('normalize_latent_offset', 0, guis)\n",
+    "      if key == 'turbo_frame_skips_steps' and val is None:\n",
+    "        val = '100% (don`t diffuse turbo frames, fastest)'\n",
+    "      if key == 'seed':\n",
+    "        key = 'set_seed'\n",
+    "      if key == 'grad_denoised ':\n",
+    "        key = 'grad_denoised'\n",
+    "      if type(val) in [dict, list]:\n",
+    "        if type(val) is dict:\n",
+    "          # NOTE(review): int() raises on non-numeric keys, so such a dict setting is\n",
+    "          # reported and skipped - confirm that is intended for every dict-valued setting.\n",
+    "          val = {int(k): v for k, v in val.items()}\n",
+    "        # Containers are handed to the widgets as JSON text.\n",
+    "        val = json.dumps(val)\n",
+    "      if key == 'cc_masked_diffusion':\n",
+    "        key = 'cc_masked_diffusion_schedule'\n",
+    "        val = f'[{val}]'\n",
+    "      if key == 'mask_clip':\n",
+    "        # NOTE(review): eval() executes whatever the settings file contains; only load\n",
+    "        # settings files from sources you trust.\n",
+    "        val = eval(val)\n",
+    "      if key == 'sampler':\n",
+    "        val = getattr(K.sampling, val)\n",
+    "      if key == 'controlnet_multimodel':\n",
+    "        val = val.replace('control_sd15_hed', 'control_sd15_softedge')\n",
+    "        val = json.loads(val)\n",
+    "      set_value(key, val, guis)\n",
+    "    except Exception as e:\n",
+    "      print(source_key)\n",
+    "      print(settings[source_key])\n",
+    "      print(e)\n",
+    "  print('Successfully loaded settings from ', path )\n",
+    "  return guis\n",
+ "\n",
+ "def dump_gui():\n",
+ " print('smth changed', time.time())\n",
+ "\n",
+ "output = Output()\n",
+ "\n",
+ "display.display(app)\n",
+ "if settings_path != '' and load_settings_from_file:\n",
+ " guis = load_settings(settings_path, guis)\n",
+ "\n",
+ "executed_cells[cell_name] = True"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "DiffuseTop"
+ },
+ "source": [
+ "# 4. Diffuse!\n",
+    "If you get an out-of-memory (OOM) or PIL error here, click \"Restart and run all\" once."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "cellView": "form",
+ "id": "DoTheRun"
+ },
+ "outputs": [],
+ "source": [
+ "#@title Do the Run!\n",
+ "#@markdown Preview max size\n",
+ "\n",
+ "cell_name = 'do_the_run'\n",
+ "check_execution(cell_name)\n",
+ "\n",
+ "only_preview_controlnet = False #@param {'type':'boolean'}\n",
+ "\n",
+ "deflicker_scale = 0. #makes glitches :D\n",
+ "deflicker_latent_scale = 0.\n",
+ "fft_scale = 0.\n",
+ "fft_latent_scale = 0.\n",
+ "\n",
+ "if 'sdxl' in model_version: sd_model.is_sdxl = True\n",
+ "else: sd_model.is_sdxl = False\n",
+ "\n",
+ "try:\n",
+ " sd_model.cpu()\n",
+ " sd_model.model.cpu()\n",
+ " sd_model.cond_stage_model.cpu()\n",
+ " sd_model.first_stage_model.cpu()\n",
+ " if 'control' in model_version:\n",
+ " for key in loaded_controlnets.keys():\n",
+ " loaded_controlnets[key].cpu()\n",
+ "except: pass\n",
+ "try:\n",
+ " apply_openpose.body_estimation.model.cpu()\n",
+ " apply_openpose.hand_estimation.model.cpu()\n",
+ " apply_openpose.face_estimation.model.cpu()\n",
+ "except: pass\n",
+ "try:\n",
+ " sd_model.model.diffusion_model.cpu()\n",
+ "except: pass\n",
+ "try:\n",
+ " apply_softedge.netNetwork.cpu()\n",
+ "except: pass\n",
+ "try:\n",
+ " apply_normal.netNetwork.cpu()\n",
+ "except: pass\n",
+ "try:\n",
+ " apply_depth.model.cpu()\n",
+ "except: pass\n",
+ "torch.cuda.empty_cache()\n",
+ "gc.collect()\n",
+ "\n",
+ "user_settings = get_settings_from_gui(user_settings_keys, guis)\n",
+ "#assign user_settings back to globals()\n",
+ "for key in user_settings.keys():\n",
+ " globals()[key] = user_settings[key]\n",
+ "\n",
+ "sd_model.low_vram = True if controlnet_low_vram else False\n",
+ "\n",
+ "mask_frames_many = None\n",
+ "if mask_paths != []:\n",
+ " mask_frames_many = []\n",
+ " for i in range(len(mask_paths)) :\n",
+ " mask_path = mask_paths[i]\n",
+ " prefix = f'mask_{i}'\n",
+ " mask_frames_many.append(FrameDataset(mask_path, outdir_prefix=prefix,\n",
+ " videoframes_root=f'{batchFolder}/videoFrames'))\n",
+ "\n",
+ "from glob import glob\n",
+ "controlnet_multimodel_inferred = copy.deepcopy(controlnet_multimodel)\n",
+ "\n",
+ "#set global settings by default\n",
+ "\n",
+ "global_keys = ['global', '', -1, '-1','global_settings']\n",
+ "fileDatasetsByPath = {}\n",
+ "\n",
+ "for key in controlnet_multimodel.keys():\n",
+ " if (not \"preprocess\" in controlnet_multimodel[key].keys()) or controlnet_multimodel[key][\"preprocess\"] in global_keys:\n",
+ " controlnet_multimodel_inferred[key][\"preprocess\"] = controlnet_preprocess\n",
+ "\n",
+ " if (not \"mode\" in controlnet_multimodel[key].keys()) or controlnet_multimodel[key][\"mode\"] in global_keys:\n",
+ " controlnet_multimodel_inferred[key][\"mode\"] = controlnet_mode\n",
+ "\n",
+ " if (not \"detect_resolution\" in controlnet_multimodel[key].keys()) or controlnet_multimodel[key][\"detect_resolution\"] in global_keys:\n",
+ " controlnet_multimodel_inferred[key][\"detect_resolution\"] = detect_resolution\n",
+ "\n",
+ " if (not \"source\" in controlnet_multimodel[key].keys()) or controlnet_multimodel[key][\"source\"] in global_keys:\n",
+ " controlnet_multimodel_inferred[key][\"source\"] = cond_image_src\n",
+ " if controlnet_multimodel_inferred[key][\"source\"] == 'init': controlnet_multimodel_inferred[key][\"source\"] = 'raw_frame'\n",
+ "\n",
+ " if controlnet_multimodel_inferred[key][\"source\"] == 'raw_frame':\n",
+ " #cache file daatsets with same sources\n",
+ " if videoFramesFolder not in fileDatasetsByPath.keys():\n",
+ " fileDatasetsByPath[videoFramesFolder] = FrameDataset(videoFramesFolder, f'{key}_source', '' )\n",
+ " controlnet_multimodel_inferred[key][\"source\"] = fileDatasetsByPath[videoFramesFolder]\n",
+ "\n",
+ " elif controlnet_multimodel_inferred[key][\"source\"] == 'cond_video':\n",
+ " if condVideoFramesFolder not in fileDatasetsByPath.keys():\n",
+ " fileDatasetsByPath[condVideoFramesFolder] = FrameDataset(condVideoFramesFolder, f'{key}_source', '' )\n",
+ " controlnet_multimodel_inferred[key][\"source\"] = fileDatasetsByPath[condVideoFramesFolder]\n",
+ "\n",
+ " elif controlnet_multimodel_inferred[key][\"source\"] == 'color_video':\n",
+ " if colorVideoFramesFolder not in fileDatasetsByPath.keys():\n",
+ " fileDatasetsByPath[colorVideoFramesFolder] = FrameDataset(colorVideoFramesFolder, f'{key}_source', '' )\n",
+ " controlnet_multimodel_inferred[key][\"source\"] = fileDatasetsByPath[colorVideoFramesFolder]\n",
+ "\n",
+ " elif controlnet_multimodel_inferred[key][\"source\"] not in ['raw_frame', 'stylized']:\n",
+ " if controlnet_multimodel_inferred[key][\"source\"] not in fileDatasetsByPath.keys():\n",
+ " fileDatasetsByPath[controlnet_multimodel_inferred[key][\"source\"]] = FrameDataset(controlnet_multimodel_inferred[key][\"source\"], f'{key}_source', '')\n",
+ " controlnet_multimodel_inferred[key][\"source\"] = fileDatasetsByPath[controlnet_multimodel_inferred[key][\"source\"]]\n",
+ "\n",
+ " if controlnet_multimodel_inferred[key][\"mode\"] == 'balanced':\n",
+ " controlnet_multimodel_inferred[key][\"layer_weights\"] = [1]*13\n",
+ " controlnet_multimodel_inferred[key][\"zero_uncond\"] = False\n",
+ " elif controlnet_multimodel_inferred[key][\"mode\"] == 'controlnet':\n",
+ " controlnet_multimodel_inferred[key][\"layer_weights\"] = [(0.825 ** float(12 - i)) for i in range(13)]\n",
+ " controlnet_multimodel_inferred[key][\"zero_uncond\"] = True\n",
+ " elif controlnet_multimodel_inferred[key][\"mode\"] == 'prompt':\n",
+ " controlnet_multimodel_inferred[key][\"layer_weights\"] = [(0.825 ** float(12 - i)) for i in range(13)]\n",
+ " controlnet_multimodel_inferred[key][\"zero_uncond\"] = False\n",
+ "\n",
+ "def get_control_source_images(frame_num, controlnet_multimodel_inferred, stylized_image):\n",
+ " controlnet_sources = {}\n",
+ " for key in controlnet_multimodel_inferred.keys():\n",
+ " control_source = controlnet_multimodel_inferred[key]['source']\n",
+ " if control_source == 'stylized':\n",
+ " controlnet_sources[key] = stylized_image\n",
+ " elif isinstance(control_source, FrameDataset):\n",
+ " controlnet_sources[key] = control_source[frame_num] #for raw, cond, color videos\n",
+ " return controlnet_sources\n",
+ "\n",
+ "image_prompts = {}\n",
+ "# drop controlnets whose weight is exactly 0 so they are neither downloaded nor loaded below\n",
+ "controlnet_multimodel_temp = {}\n",
+ "for key, cn_settings in controlnet_multimodel.items():\n",
+ "  weight = cn_settings[\"weight\"]\n",
+ "  if weight == 0:\n",
+ "    continue\n",
+ "  controlnet_multimodel_temp[key] = cn_settings\n",
+ "controlnet_multimodel = controlnet_multimodel_temp\n",
+ "\n",
+ "inverse_mask_order = False\n",
+ "# Probe for xformers. Any import-time failure (package missing, broken binary build) means it cannot be used.\n",
+ "try:\n",
+ "  import xformers.ops\n",
+ "  xformers_available = True\n",
+ "except Exception:\n",
+ "  xformers_available = False\n",
+ "# not everyone has torch 2.x to use sdp\n",
+ "can_use_sdp = callable(getattr(torch.nn.functional, \"scaled_dot_product_attention\", None))\n",
+ "# Request xformers attention only when the package actually imported.\n",
+ "# Previously the flag was also set to True when neither xformers nor sdp was available,\n",
+ "# i.e. xformers was requested although its import had just failed.\n",
+ "shared.opts.xformers = xformers_available\n",
+ "shared.cmd_opts.xformers = xformers_available\n",
+ "\n",
+ "import copy\n",
+ "apply_depth = None;\n",
+ "apply_canny = None; apply_mlsd = None;\n",
+ "apply_hed = None; apply_openpose = None;\n",
+ "apply_seg = None;\n",
+ "#loaded_controlnets = {}\n",
+ "torch.cuda.empty_cache(); gc.collect();\n",
+ "sd_model.control_scales = ([1]*13)\n",
+ "\n",
+ "skip_diffuse_cell = False #@param {'type':'boolean'}\n",
+ "if 'control_multi' in model_version:\n",
+ " try:\n",
+ " sd_model.control_model.cpu()\n",
+ " except: pass\n",
+ " print('Checking downloaded Annotator and ControlNet Models')\n",
+ " for controlnet in controlnet_multimodel.keys():\n",
+ " controlnet_settings = controlnet_multimodel[controlnet]\n",
+ " weight = controlnet_settings[\"weight\"]\n",
+ " if weight!=0 and not skip_diffuse_cell:\n",
+ " small_url = control_model_urls[controlnet]\n",
+ " if controlnet in control_model_filenames.keys():\n",
+ " local_filename = control_model_filenames[controlnet]\n",
+ " else: local_filename = small_url.split('/')[-1]\n",
+ " print(f\"Loading {controlnet} from checkpoint: {local_filename}\")\n",
+ " small_controlnet_model_path = f\"{controlnet_models_dir}/{local_filename}\"\n",
+ " if use_small_controlnet and os.path.exists(model_path) and not os.path.exists(small_controlnet_model_path):\n",
+ " print(f'Model found at {model_path}. Small model not found at {small_controlnet_model_path}.')\n",
+ " if not os.path.exists(small_controlnet_model_path) or force_download:\n",
+ " try:\n",
+ " pathlib.Path(small_controlnet_model_path).unlink()\n",
+ " except: pass\n",
+ " print(f'Downloading small {controlnet} model... ')\n",
+ " wget.download(small_url, small_controlnet_model_path)\n",
+ " print(f'Downloaded small {controlnet} model.')\n",
+ "\n",
+ " print('Loading ControlNet Models')\n",
+ " try:\n",
+ " to_pop = set(loaded_controlnets.keys()).symmetric_difference(set( controlnet_multimodel.keys()))\n",
+ " for key in to_pop:\n",
+ " if key in loaded_controlnets.keys():\n",
+ " loaded_controlnets.pop(key)\n",
+ " except NameError:\n",
+ " loaded_controlnets = {}\n",
+ "\n",
+ " for controlnet in controlnet_multimodel.keys():\n",
+ " controlnet_settings = controlnet_multimodel[controlnet]\n",
+ " weight = controlnet_settings[\"weight\"]\n",
+ " if weight!=0 and not skip_diffuse_cell:\n",
+ " if controlnet in loaded_controlnets.keys():\n",
+ " continue\n",
+ " small_url = control_model_urls[controlnet]\n",
+ " if controlnet in control_model_filenames.keys():\n",
+ " local_filename = control_model_filenames[controlnet]\n",
+ " else: local_filename = small_url.split('/')[-1]\n",
+ " small_controlnet_model_path = f\"{controlnet_models_dir}/{local_filename}\"\n",
+ " if model_version == 'control_multi_sdxl':\n",
+ " from IPython.utils import io\n",
+ " with io.capture_output(stderr=False) as captured:\n",
+ " cn = load_controlnet(small_controlnet_model_path)\n",
+ " if type(cn) == comfy.sd.ControlLora:\n",
+ " cn.pre_run(sd_model.model, lambda a: model_wrap.sigma_to_t(model_wrap.t_to_sigma(torch.tensor(a) * 999.0)))\n",
+ " loaded_controlnets[controlnet] = cn.control_model.cpu().half()\n",
+ " if model_version in ['control_multi', 'control_multi_v2','control_multi_v2_768']:\n",
+ " loaded_controlnets[controlnet] = copy.deepcopy(sd_model.control_model)\n",
+ " if os.path.exists(small_controlnet_model_path):\n",
+ " ckpt = small_controlnet_model_path\n",
+ " print(f\"Loading model from {ckpt}\")\n",
+ " if ckpt.endswith('.safetensors'):\n",
+ " pl_sd = {}\n",
+ " with safe_open(ckpt, framework=\"pt\", device=load_to) as f:\n",
+ " for key in f.keys():\n",
+ " pl_sd[key] = f.get_tensor(key)\n",
+ " else: pl_sd = torch.load(ckpt, map_location=load_to)\n",
+ "\n",
+ " if \"global_step\" in pl_sd:\n",
+ " print(f\"Global Step: {pl_sd['global_step']}\")\n",
+ " if \"state_dict\" in pl_sd:\n",
+ " sd = pl_sd[\"state_dict\"]\n",
+ " else: sd = pl_sd\n",
+ " if \"control_model.input_blocks.0.0.bias\" in sd:\n",
+ " sd = dict([(o.split('control_model.')[-1],sd[o]) for o in sd.keys() if o != 'difference'])\n",
+ " del pl_sd\n",
+ "\n",
+ " gc.collect()\n",
+ " m, u = loaded_controlnets[controlnet].load_state_dict(sd, strict=True)\n",
+ " loaded_controlnets[controlnet].half()\n",
+ " if len(m) > 0 and verbose:\n",
+ " print(\"missing keys:\")\n",
+ " print(m, len(m))\n",
+ " if len(u) > 0 and verbose:\n",
+ " print(\"unexpected keys:\")\n",
+ " print(u, len(u))\n",
+ " else:\n",
+ " print('Small controlnet model not found in path but specified in settings. Please adjust settings or check controlnet path.')\n",
+ " sys.exit(0)\n",
+ "\n",
+ "if not skip_diffuse_cell:\n",
+ "# print('Loading annotators.')\n",
+ " controlnet_keys = controlnet_multimodel.keys() if 'control_multi' in model_version else model_version\n",
+ " depth_cns = set([\"control_sd21_depth\", 'control_sd15_depth','control_sd15_normal',\n",
+ " 'control_sdxl_depth', 'control_sdxl_lora_128_depth', 'control_sdxl_lora_256_depth', \"control_sd15_temporal_depth\"])\n",
+ " if len(depth_cns.intersection(set(controlnet_keys)))>0:\n",
+ " if control_sd15_depth_detector == 'Midas' or \"control_sd15_normal\" in controlnet_keys:\n",
+ " from annotator.midas import MidasDetector\n",
+ " apply_depth = MidasDetector()\n",
+ " print('Loaded MidasDetector')\n",
+ " if control_sd15_depth_detector == 'Zoe':\n",
+ " from annotator.zoe import ZoeDetector\n",
+ " apply_depth = ZoeDetector()\n",
+ " print('Loaded ZoeDetector')\n",
+ "\n",
+ " normalbae_cns = set([\"control_sd15_normalbae\", \"control_sd21_normalbae\"])\n",
+ " if len(normalbae_cns.intersection(set(controlnet_keys)))>0:\n",
+ " from annotator.normalbae import NormalBaeDetector\n",
+ " apply_normal = NormalBaeDetector()\n",
+ " print('Loaded NormalBaeDetector')\n",
+ "\n",
+ " canny_cns = set(['control_sd15_canny','control_sdxl_canny',\n",
+ " 'control_sdxl_lora_128_canny', 'control_sdxl_lora_256_canny'])\n",
+ " if len(canny_cns.intersection(set(controlnet_keys)))>0:\n",
+ " from annotator.canny import CannyDetector\n",
+ " apply_canny = CannyDetector()\n",
+ " print('Loaded CannyDetector')\n",
+ "\n",
+ " softedge_cns = set([\"control_sd21_softedge\", 'control_sd15_softedge', 'control_sdxl_softedge',\n",
+ " 'control_sdxl_lora_128_softedge', 'control_sdxl_lora_256_softedge',\"control_sd15_inpaint_softedge\"])\n",
+ " if len(softedge_cns.intersection(set(controlnet_keys)))>0:\n",
+ " if control_sd15_softedge_detector == 'HED':\n",
+ " from annotator.hed import HEDdetector\n",
+ " apply_softedge = HEDdetector()\n",
+ " print('Loaded HEDdetector')\n",
+ " if control_sd15_softedge_detector == 'PIDI':\n",
+ " from annotator.pidinet import PidiNetDetector\n",
+ " apply_softedge = PidiNetDetector()\n",
+ " print('Loaded PidiNetDetector')\n",
+ " scribble_cns = set(['control_sd15_scribble', \"control_sd21_scribble\"])\n",
+ " if len(scribble_cns.intersection(set(controlnet_keys)))>0:\n",
+ " from annotator.util import nms\n",
+ " if control_sd15_scribble_detector == 'HED':\n",
+ " from annotator.hed import HEDdetector\n",
+ " apply_scribble = HEDdetector()\n",
+ " print('Loaded HEDdetector')\n",
+ " if control_sd15_scribble_detector == 'PIDI':\n",
+ " from annotator.pidinet import PidiNetDetector\n",
+ " apply_scribble = PidiNetDetector()\n",
+ " print('Loaded PidiNetDetector')\n",
+ "\n",
+ " if \"control_sd15_mlsd\" in controlnet_keys:\n",
+ " from annotator.mlsd import MLSDdetector\n",
+ " apply_mlsd = MLSDdetector()\n",
+ " print('Loaded MLSDdetector')\n",
+ "\n",
+ " openpose_cns = set([\"control_sd15_openpose\", \"control_sdxl_openpose\", \"control_sd21_openpose\"])\n",
+ " if len(openpose_cns.intersection(set(controlnet_keys)))>0:\n",
+ " if pose_detector == 'openpose':\n",
+ " from annotator.openpose import OpenposeDetector\n",
+ " apply_openpose = OpenposeDetector()\n",
+ " print('Loaded OpenposeDetector')\n",
+ " elif pose_detector == 'dw_pose':\n",
+ " import gdown\n",
+ " if not os.path.exists(f\"{root_dir}/ControlNet/annotator/ckpts/dw-ll_ucoco_384.onnx\"):\n",
+ " gdown.download(id='12L8E2oAgZy4VACGSK9RaZBZrfgx7VTA2', output=f\"{root_dir}/ControlNet/annotator/ckpts/dw-ll_ucoco_384.onnx\")\n",
+ " if not os.path.exists(f\"{root_dir}/ControlNet/annotator/ckpts/yolox_l.onnx\"):\n",
+ " gdown.download(id='1w9pXC8tT0p9ndMN-CArp1__b2GbzewWI', output=f\"{root_dir}/ControlNet/annotator/ckpts/yolox_l.onnx\")\n",
+ " os.chdir(f\"{root_dir}/ControlNet\")\n",
+ " from annotator.dwpose import DWposeDetector\n",
+ " apply_openpose = DWposeDetector()\n",
+ " print('Loaded DWposeDetector')\n",
+ " os.chdir(root_dir)\n",
+ "\n",
+ " seg_cns = set([\"control_sd15_seg\", \"control_sdxl_seg\", \"control_sd21_seg\"])\n",
+ " if len(seg_cns.intersection(set(controlnet_keys)))>0:\n",
+ " if control_sd15_seg_detector == 'Seg_OFCOCO':\n",
+ " from annotator.oneformer import OneformerCOCODetector\n",
+ " apply_seg = OneformerCOCODetector()\n",
+ " print('Loaded OneformerCOCODetector')\n",
+ " elif control_sd15_seg_detector == 'Seg_OFADE20K':\n",
+ " from annotator.oneformer import OneformerADE20kDetector\n",
+ " apply_seg = OneformerADE20kDetector()\n",
+ " print('Loaded OneformerADE20kDetector')\n",
+ " elif control_sd15_seg_detector == 'Seg_UFADE20K':\n",
+ " from annotator.uniformer import UniformerDetector\n",
+ " apply_seg = UniformerDetector()\n",
+ " print('Loaded UniformerDetector')\n",
+ " if \"control_sd15_shuffle\" in controlnet_keys:\n",
+ " from annotator.shuffle import ContentShuffleDetector\n",
+ " apply_shuffle = ContentShuffleDetector()\n",
+ " print('Loaded ContentShuffleDetector')\n",
+ "\n",
+ " lineart_cns = set([\"control_sd15_lineart\", \"control_sd21_lineart\"])\n",
+ " if len(lineart_cns.intersection(set(controlnet_keys)))>0:\n",
+ " from annotator.lineart import LineartDetector\n",
+ " apply_lineart = LineartDetector()\n",
+ " print('Loaded LineartDetector')\n",
+ " if \"control_sd15_lineart_anime\" in controlnet_keys:\n",
+ " from annotator.lineart_anime import LineartAnimeDetector\n",
+ " apply_lineart_anime = LineartAnimeDetector()\n",
+ " print('Loaded LineartAnimeDetector')\n",
+ "\n",
+ "def deflicker_loss(processed2, processed1, raw1, raw2, criterion1, criterion2):\n",
+ " raw_diff = criterion1(raw2, raw1)\n",
+ " proc_diff = criterion1(processed1, processed2)\n",
+ " return criterion2(raw_diff, proc_diff)\n",
+ "\n",
+ "# unload_network()\n",
+ "sd_model.cuda()\n",
+ "sd_hijack.model_hijack.hijack(sd_model)\n",
+ "sd_hijack.model_hijack.embedding_db.add_embedding_dir(custom_embed_dir)\n",
+ "if 'sdxl' not in model_version:\n",
+ " sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(sd_model, force_reload=True)\n",
+ "\n",
+ "latent_scale_schedule_bkup = copy.copy(latent_scale_schedule)\n",
+ "init_scale_schedule_bkup = copy.copy(init_scale_schedule)\n",
+ "steps_schedule_bkup = copy.copy(steps_schedule)\n",
+ "style_strength_schedule_bkup = copy.copy(style_strength_schedule)\n",
+ "flow_blend_schedule_bkup = copy.copy(flow_blend_schedule)\n",
+ "cfg_scale_schedule_bkup = copy.copy(cfg_scale_schedule)\n",
+ "image_scale_schedule_bkup = copy.copy(image_scale_schedule)\n",
+ "cc_masked_diffusion_schedule_bkup = copy.copy(cc_masked_diffusion_schedule)\n",
+ "\n",
+ "\n",
+ "if make_schedules:\n",
+ "  # content-aware scheduling needs per-frame difference data: either the analyzed `diff` or a manual `diff_override`\n",
+ "  if diff is None and diff_override == []: sys.exit(f'\\nERROR!\\n\\nframes were not anayzed. Please enable analyze_video in the previous cell, run it, and then run this cell again\\n')\n",
+ "  if diff_override != []: diff = diff_override\n",
+ "\n",
+ "  print('Applied schedules:')\n",
+ "  latent_scale_schedule = check_and_adjust_sched(latent_scale_schedule, latent_scale_template, diff, respect_sched)\n",
+ "  init_scale_schedule = check_and_adjust_sched(init_scale_schedule, init_scale_template, diff, respect_sched)\n",
+ "  steps_schedule = check_and_adjust_sched(steps_schedule, steps_template, diff, respect_sched)\n",
+ "  style_strength_schedule = check_and_adjust_sched(style_strength_schedule, style_strength_template, diff, respect_sched)\n",
+ "  flow_blend_schedule = check_and_adjust_sched(flow_blend_schedule, flow_blend_template, diff, respect_sched)\n",
+ "  # fixed: this was built from flow_blend_schedule (copy-paste slip), which discarded the user's cc_masked_diffusion_schedule\n",
+ "  cc_masked_diffusion_schedule = check_and_adjust_sched(cc_masked_diffusion_schedule, cc_masked_template, diff, respect_sched)\n",
+ "\n",
+ "  cfg_scale_schedule = check_and_adjust_sched(cfg_scale_schedule, cfg_scale_template, diff, respect_sched)\n",
+ "  # NOTE(review): image_scale_schedule is adjusted with cfg_scale_template - confirm whether a dedicated image scale template was intended\n",
+ "  image_scale_schedule = check_and_adjust_sched(image_scale_schedule, cfg_scale_template, diff, respect_sched)\n",
+ "\n",
+ "  # show the first 100 values of every schedule that ended up as a per-frame list\n",
+ "  applied_schedules = [\n",
+ "    ('cc_masked_diffusion_schedule', cc_masked_diffusion_schedule),\n",
+ "    ('latent_scale_schedule', latent_scale_schedule),\n",
+ "    ('init_scale_schedule', init_scale_schedule),\n",
+ "    ('steps_schedule', steps_schedule),\n",
+ "    ('style_strength_schedule', style_strength_schedule),\n",
+ "    ('flow_blend_schedule', flow_blend_schedule),\n",
+ "    ('cfg_scale_schedule', cfg_scale_schedule),\n",
+ "    ('image_scale_schedule', image_scale_schedule),\n",
+ "  ]\n",
+ "  for name, sched in applied_schedules:\n",
+ "    if type(sched) == list and len(sched)>2:\n",
+ "      print(name, ': ', sched[:100])\n",
+ "\n",
+ "use_karras_noise = False\n",
+ "end_karras_ramp_early = False\n",
+ "# use_predicted_noise = False\n",
+ "warp_interp = Image.LANCZOS\n",
+ "start_code_cb = None #variable for cb_code\n",
+ "guidance_start_code = None #variable for guidance code\n",
+ "\n",
+ "display_size = 720 #@param\n",
+ "\n",
+ "image_prompts = {}\n",
+ "sd_model.normalize_weights = normalize_cn_weights\n",
+ "sd_model.low_vram = True if controlnet_low_vram else False\n",
+ "\n",
+ "# Convert the dropdown string into a fraction of steps (None for the 100% option).\n",
+ "# Only strings are converted: after a previous run of this cell the value is already\n",
+ "# None or a float, and calling .split() on it would crash the re-run.\n",
+ "if isinstance(turbo_frame_skips_steps, str):\n",
+ "  if turbo_frame_skips_steps == '100% (don`t diffuse turbo frames, fastest)':\n",
+ "    turbo_frame_skips_steps = None\n",
+ "  else:\n",
+ "    turbo_frame_skips_steps = int(turbo_frame_skips_steps.split('%')[0])/100\n",
+ "\n",
+ "disable_cc_for_turbo_frames = False\n",
+ "\n",
+ "# pick the color transfer routine; anything other than 'PDF' or 'mean' uses LAB transfer\n",
+ "if colormatch_method == 'mean':\n",
+ "  colormatch_method_fn = PT.mean_std_transfer\n",
+ "elif colormatch_method == 'PDF':\n",
+ "  colormatch_method_fn = PT.pdf_transfer\n",
+ "else:\n",
+ "  colormatch_method_fn = PT.lab_transfer\n",
+ "\n",
+ "turbo_preroll = 1\n",
+ "intermediate_saves = None\n",
+ "intermediates_in_subfolder = True\n",
+ "steps_per_checkpoint = None\n",
+ "\n",
+ "forward_weights_clip = soften_consistency_mask\n",
+ "forward_weights_clip_turbo_step = soften_consistency_mask_for_turbo_frames\n",
+ "inpaint_blend = 0\n",
+ "\n",
+ "if animation_mode == 'Video Input':\n",
+ " max_frames = len(glob(f'{videoFramesFolder}/*.jpg'))\n",
+ "\n",
+ "def split_prompts(prompts):\n",
+ "  \"\"\"Expand a sparse {frame: prompt} mapping into a per-frame pandas Series.\n",
+ "\n",
+ "  The Series starts with max_frames (module-level) empty slots. Slots without an\n",
+ "  explicit prompt are forward-filled from the previous keyed prompt, and any slots\n",
+ "  before the first keyed prompt are back-filled from it.\n",
+ "  \"\"\"\n",
+ "  per_frame = pd.Series([np.nan] * max_frames)\n",
+ "  for frame_idx, frame_prompt in prompts.items():\n",
+ "    per_frame[frame_idx] = frame_prompt\n",
+ "  return per_frame.ffill().bfill()\n",
+ "\n",
+ "key_frames = True\n",
+ "interp_spline = 'Linear'\n",
+ "perlin_init = False\n",
+ "perlin_mode = 'mixed'\n",
+ "\n",
+ "if warp_towards_init != 'off':\n",
+ " if flow_lq:\n",
+ " raft_model = torch.jit.load(f'{root_dir}/WarpFusion/raft/raft_half.jit').eval()\n",
+ " # raft_model = torch.nn.DataParallel(RAFT(args2))\n",
+ " else: raft_model = torch.jit.load(f'{root_dir}/WarpFusion/raft/raft_fp32.jit').eval()\n",
+ "\n",
+ "\n",
+ "def printf(*msg, file=f'{root_dir}/log.txt'):\n",
+ "  \"\"\"Append one timestamped line to the log file.\n",
+ "\n",
+ "  Args:\n",
+ "    *msg: values converted with str() and joined with single spaces.\n",
+ "    file: path of the log file, opened in append mode.\n",
+ "  \"\"\"\n",
+ "  timestamp = datetime.now().strftime(\"%d/%m/%Y %H:%M:%S\")\n",
+ "  with open(file, 'a') as log_file:\n",
+ "    text = \" \".join(map(str, msg))\n",
+ "    print(f'{timestamp}> {text}', file=log_file)\n",
+ "printf('--------Beginning new run------')\n",
+ "##@markdown `n_batches` ignored with animation modes.\n",
+ "display_rate = 9999999\n",
+ "##@param{type: 'number'}\n",
+ "n_batches = 1\n",
+ "##@param{type: 'number'}\n",
+ "start_code = None\n",
+ "first_latent = None\n",
+ "first_latent_source = 'not set'\n",
+ "os.chdir(root_dir)\n",
+ "n_mean_avg = None\n",
+ "n_std_avg = None\n",
+ "n_smooth = 0.5\n",
+ "#Update Model Settings\n",
+ "timestep_respacing = f'ddim{steps}'\n",
+ "diffusion_steps = (1000//steps)*steps if steps < 1000 else steps\n",
+ "\n",
+ "batch_size = 1\n",
+ "\n",
+ "def move_files(start_num, end_num, old_folder, new_folder):\n",
+ "  \"\"\"Move the current run's frame PNGs numbered start_num..end_num-1 from old_folder to new_folder.\n",
+ "\n",
+ "  File names are built from the module-level batch_name and batchNum.\n",
+ "  \"\"\"\n",
+ "  for frame_idx in range(start_num, end_num):\n",
+ "    frame_name = f'/{batch_name}({batchNum})_{frame_idx:06}.png'\n",
+ "    os.rename(old_folder + frame_name, new_folder + frame_name)\n",
+ "\n",
+ "noise_upscale_ratio = int(noise_upscale_ratio)\n",
+ "#@markdown ---\n",
+ "#@markdown Frames to run. Leave empty or [0,0] to run all frames.\n",
+ "frame_range = [0,0] #@param\n",
+ "resume_run = False #@param{type: 'boolean'}\n",
+ "run_to_resume = 'latest' #@param{type: 'string'}\n",
+ "resume_from_frame = 'latest' #@param{type: 'string'}\n",
+ "retain_overwritten_frames = False #@param{type: 'boolean'}\n",
+ "if retain_overwritten_frames is True:\n",
+ " retainFolder = f'{batchFolder}/retained'\n",
+ " createPath(retainFolder)\n",
+ "\n",
+ "if animation_mode == 'Video Input':\n",
+ " frames = sorted(glob(in_path+'/*.*'));\n",
+ " if len(frames)==0:\n",
+ " sys.exit(\"ERROR: 0 frames found.\\nPlease check your video input path and rerun the video settings cell.\")\n",
+ " flows = glob(flo_folder+'/*.*')\n",
+ " if (len(flows)==0) and flow_warp:\n",
+ " sys.exit(\"ERROR: 0 flow files found.\\nPlease rerun the flow generation cell.\")\n",
+ "settings_out = batchFolder+f\"/settings\"\n",
+ "if resume_run:\n",
+ " if run_to_resume == 'latest':\n",
+ " try:\n",
+ " batchNum\n",
+ " except:\n",
+ " batchNum = len(glob(f\"{settings_out}/{batch_name}(*)_settings.txt\"))-1\n",
+ " else:\n",
+ " batchNum = int(run_to_resume)\n",
+ " if resume_from_frame == 'latest':\n",
+ " start_frame = len(glob(batchFolder+f\"/{batch_name}({batchNum})_*.png\"))\n",
+ " if animation_mode != 'Video Input' and turbo_mode == True and start_frame > turbo_preroll and start_frame % int(turbo_steps) != 0:\n",
+ " start_frame = start_frame - (start_frame % int(turbo_steps))\n",
+ " else:\n",
+ " start_frame = int(resume_from_frame)+1\n",
+ " if animation_mode != 'Video Input' and turbo_mode == True and start_frame > turbo_preroll and start_frame % int(turbo_steps) != 0:\n",
+ " start_frame = start_frame - (start_frame % int(turbo_steps))\n",
+ " if retain_overwritten_frames is True:\n",
+ " existing_frames = len(glob(batchFolder+f\"/{batch_name}({batchNum})_*.png\"))\n",
+ " frames_to_save = existing_frames - start_frame\n",
+ " print(f'Moving {frames_to_save} frames to the Retained folder')\n",
+ " move_files(start_frame, existing_frames, batchFolder, retainFolder)\n",
+ "else:\n",
+ " start_frame = 0\n",
+ " batchNum = len(glob(settings_out+\"/*.txt\"))\n",
+ " while os.path.isfile(f\"{settings_out}/{batch_name}({batchNum})_settings.txt\") is True or os.path.isfile(f\"{batchFolder}/{batch_name}-{batchNum}_settings.txt\") is True:\n",
+ " batchNum += 1\n",
+ "\n",
+ "print(f'Starting Run: {batch_name}({batchNum}) at frame {start_frame}')\n",
+ "\n",
+ "if set_seed == 'random_seed' or set_seed == -1:\n",
+ "  random.seed()\n",
+ "  # randint includes both bounds; cap at 2**32 - 1 so the seed always fits an unsigned 32-bit value\n",
+ "  # (presumably required by the seeding calls downstream - 2**32 itself is out of range for numpy's legacy seed)\n",
+ "  seed = random.randint(0, 2**32 - 1)\n",
+ "  # print(f'Using seed: {seed}')\n",
+ "else:\n",
+ "  seed = int(set_seed)\n",
+ "\n",
+ "new_prompt_loras = {}\n",
+ "prompt_weights = {}\n",
+ "if text_prompts:\n",
+ " _, new_prompt_loras = split_lora_from_prompts(text_prompts)\n",
+ "\n",
+ " print('Inferred loras schedule:\\n', new_prompt_loras)\n",
+ " _, prompt_weights = get_prompt_weights(text_prompts)\n",
+ "\n",
+ " print('---prompt_weights---', prompt_weights, text_prompts)\n",
+ "if new_prompt_loras not in [{}, [], '', None]:# and model_version not in ['sdxl_base', 'sdxl_refiner']:\n",
+ "#inject lora even with empty weights to unload?\n",
+ " inject_network(sd_model)\n",
+ "else:\n",
+ " loaded_networks.clear()\n",
+ "\n",
+ "args = {\n",
+ " 'batchNum': batchNum,\n",
+ " 'prompts_series':text_prompts if text_prompts else None,\n",
+ " 'rec_prompts_series':rec_prompts if rec_prompts else None,\n",
+ " 'neg_prompts_series':negative_prompts if negative_prompts else None,\n",
+ " 'image_prompts_series':image_prompts if image_prompts else None,\n",
+ " 'seed': seed,\n",
+ " 'display_rate':display_rate,\n",
+ " 'n_batches':n_batches if animation_mode == 'None' else 1,\n",
+ " 'batch_size':batch_size,\n",
+ " 'batch_name': batch_name,\n",
+ " 'steps': steps,\n",
+ " 'diffusion_sampling_mode': diffusion_sampling_mode,\n",
+ " 'width_height': width_height,\n",
+ " 'clip_guidance_scale': clip_guidance_scale,\n",
+ " 'tv_scale': tv_scale,\n",
+ " 'range_scale': range_scale,\n",
+ " 'sat_scale': sat_scale,\n",
+ " 'cutn_batches': cutn_batches,\n",
+ " 'init_image': init_image,\n",
+ " 'init_scale': init_scale,\n",
+ " 'skip_steps': skip_steps,\n",
+ " 'side_x': side_x,\n",
+ " 'side_y': side_y,\n",
+ " 'timestep_respacing': timestep_respacing,\n",
+ " 'diffusion_steps': diffusion_steps,\n",
+ " 'animation_mode': animation_mode,\n",
+ " 'video_init_path': video_init_path,\n",
+ " 'extract_nth_frame': extract_nth_frame,\n",
+ " 'video_init_seed_continuity': video_init_seed_continuity,\n",
+ " 'key_frames': key_frames,\n",
+ " 'max_frames': max_frames if animation_mode != \"None\" else 1,\n",
+ " 'interp_spline': interp_spline,\n",
+ " 'start_frame': start_frame,\n",
+ " 'padding_mode': padding_mode,\n",
+ " 'text_prompts': text_prompts,\n",
+ " 'image_prompts': image_prompts,\n",
+ " 'intermediate_saves': intermediate_saves,\n",
+ " 'intermediates_in_subfolder': intermediates_in_subfolder,\n",
+ " 'steps_per_checkpoint': steps_per_checkpoint,\n",
+ " 'perlin_init': perlin_init,\n",
+ " 'perlin_mode': perlin_mode,\n",
+ " 'set_seed': set_seed,\n",
+ " 'clamp_grad': clamp_grad,\n",
+ " 'clamp_max': clamp_max,\n",
+ " 'skip_augs': skip_augs,\n",
+ "}\n",
+ "if frame_range not in [None, [0,0], '', [0], 0]:\n",
+ " args['start_frame'] = frame_range[0]\n",
+ " args['max_frames'] = min(args['max_frames'],frame_range[1])\n",
+ "args = SimpleNamespace(**args)\n",
+ "\n",
+ "import traceback\n",
+ "\n",
+ "gc.collect()\n",
+ "torch.cuda.empty_cache()\n",
+ "try:\n",
+ " if only_preview_controlnet:\n",
+ " if 'control_multi' in model_version:\n",
+ " init_image = glob(videoFramesFolder+'/*.*')[0]\n",
+ " models = list(controlnet_multimodel.keys())\n",
+ " models = [o for o in models if o not in no_preprocess_cn]; print(models)\n",
+ " controlnet_sources = {}\n",
+ " if controlnet_multimodel != {}:\n",
+ " W, H = width_height\n",
+ " controlnet_sources = get_control_source_images(frame_range[0], controlnet_multimodel_inferred, stylized_image=init_image)\n",
+ " controlnet_sources['control_inpainting_mask'] = init_image\n",
+ " controlnet_sources['shuffle_source'] = init_image\n",
+ " controlnet_sources['init_image'] = init_image\n",
+ " controlnet_sources['prev_frame'] = init_image\n",
+ " controlnet_sources['next_frame'] = f'{videoFramesFolder}/{frame_range[0]+1:06}.jpg'\n",
+ " detected_maps, models = get_controlnet_annotations(model_version, W, H, models, controlnet_sources)\n",
+ " gc.collect()\n",
+ " torch.cuda.empty_cache()\n",
+ " for m in models:\n",
+ " display.display(fit(PIL.Image.fromarray(detected_maps[m].astype('uint8')), maxsize=display_size))\n",
+ " elif not skip_diffuse_cell:\n",
+ " do_run()\n",
+ "except:\n",
+ " try:\n",
+ " sd_model.cpu()\n",
+ " if 'control' in model_version:\n",
+ " for key in loaded_controlnets.keys():\n",
+ " loaded_controlnets[key].cpu()\n",
+ " torch.cuda.empty_cache()\n",
+ " gc.collect()\n",
+ " except: pass\n",
+ " traceback.print_exc()\n",
+ "\n",
+ "print('n_stats_avg (mean, std): ', n_mean_avg, n_std_avg)\n",
+ "\n",
+ "gc.collect()\n",
+ "torch.cuda.empty_cache()\n",
+ "executed_cells[cell_name] = True"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "CreateVidTop"
+ },
+ "source": [
+ "# 5. Create the video"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "cellView": "form",
+ "id": "CreateVid"
+ },
+ "outputs": [],
+ "source": [
+ "import PIL\n",
+ "#@title ### **Create video**\n",
+ "#@markdown Video file will save in the same folder as your images.\n",
+ "cell_name = 'create_video'\n",
+ "check_execution(cell_name)\n",
+ "\n",
+ "from tqdm.notebook import trange\n",
+ "skip_video_for_run_all = False #@param {type: 'boolean'}\n",
+ "#@markdown ### **Video masking (post-processing)**\n",
+ "#@markdown Use previously generated background mask during video creation\n",
+ "use_background_mask_video = False #@param {type: 'boolean'}\n",
+ "invert_mask_video = False #@param {type: 'boolean'}\n",
+ "#@markdown Choose background source: image, color, init video.\n",
+ "background_video = \"init_video\" #@param ['image', 'color', 'init_video']\n",
+ "#@markdown Specify the init image path or color depending on your background video source choice.\n",
+ "background_source_video = 'red' #@param {type: 'string'}\n",
+ "blend_mode = \"optical flow\" #@param ['None', 'linear', 'optical flow']\n",
+ "# if (blend_mode == \"optical flow\") & (animation_mode != 'Video Input Legacy'):\n",
+ "#@markdown ### **Video blending (post-processing)**\n",
+ "# print('Please enable Video Input mode and generate optical flow maps to use optical flow blend mode')\n",
+ "blend = 0.5#@param {type: 'number'}\n",
+ "check_consistency = True #@param {type: 'boolean'}\n",
+ "postfix = ''\n",
+ "missed_consistency_weight = 1 #@param {'type':'slider', 'min':'0', 'max':'1', 'step':'0.05'}\n",
+ "overshoot_consistency_weight = 1 #@param {'type':'slider', 'min':'0', 'max':'1', 'step':'0.05'}\n",
+ "edges_consistency_weight = 1 #@param {'type':'slider', 'min':'0', 'max':'1', 'step':'0.05'}\n",
+ "# bitrate = 10 #@param {'type':'slider', 'min':'5', 'max':'28', 'step':'1'}\n",
+ "failed_frames = []\n",
+ "\n",
+ "def try_process_frame(i, func):\n",
+ " global failed_frames\n",
+ " try:\n",
+ " func(i)\n",
+ " except:\n",
+ " print('Error processing frame ', i)\n",
+ "\n",
+ " print('retrying 1 time')\n",
+ " gc.collect()\n",
+ " torch.cuda.empty_cache()\n",
+ " try:\n",
+ " func(i)\n",
+ " except Exception as e:\n",
+ " print('Error processing frame ', i, '. Please lower thread number to 1-3.', e)\n",
+ " failed_frames.append(i)\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "if use_background_mask_video:\n",
+ " postfix+='_mask'\n",
+ " if invert_mask_video:\n",
+ " postfix+='_inv'\n",
+ "#@markdown #### Upscale settings\n",
+ "upscale_ratio = \"1\" #@param [1,2,3,4]\n",
+ "upscale_ratio = int(upscale_ratio)\n",
+ "upscale_model = 'realesr-animevideov3' #@param ['RealESRGAN_x4plus', 'RealESRNet_x4plus', 'RealESRGAN_x4plus_anime_6B', 'RealESRGAN_x2plus', 'realesr-animevideov3', 'realesr-general-x4v3']\n",
+ "\n",
+ "#@markdown #### Multithreading settings\n",
+ "#@markdown Suggested range - from 1 to number of cores on SSD and double number of cores - on HDD. Mostly limited by your drive bandwidth.\n",
+ "#@markdown Results for 500 frames @ 6 cores: 5 threads - 2:38, 10 threads - 0:55, 20 - 0:56, 1: 5:53\n",
+ "threads = 12#@param {type:\"number\"}\n",
+ "threads = max(min(threads, 64),1)\n",
+ "frames = []\n",
+ "if upscale_ratio>1:\n",
+ " try:\n",
+ " for key in loaded_controlnets.keys():\n",
+ " loaded_controlnets[key].cpu()\n",
+ " except: pass\n",
+ " try:\n",
+ " sd_model.model.cpu()\n",
+ " sd_model.cond_stage_model.cpu()\n",
+ " sd_model.cpu()\n",
+ " sd_model.first_stage_model.cpu()\n",
+ " model_wrap.inner_model.cpu()\n",
+ " model_wrap.cpu()\n",
+ " model_wrap_cfg.cpu()\n",
+ " model_wrap_cfg.inner_model.cpu()\n",
+ " except: pass\n",
+ " torch.cuda.empty_cache()\n",
+ " gc.collect()\n",
+ " torch.cuda.empty_cache()\n",
+ " gc.collect()\n",
+ " os.makedirs(f'{root_dir}/Real-ESRGAN', exist_ok=True)\n",
+ " os.chdir(f'{root_dir}/Real-ESRGAN')\n",
+ " print(f'Upscaling to x{upscale_ratio} using {upscale_model}')\n",
+ " from realesrgan.archs.srvgg_arch import SRVGGNetCompact\n",
+ " from basicsr.utils.download_util import load_file_from_url\n",
+ " from realesrgan import RealESRGANer\n",
+ " from basicsr.archs.rrdbnet_arch import RRDBNet\n",
+ " os.chdir(root_dir)\n",
+ " # model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type='prelu')\n",
+ " # netscale = 4\n",
+ " # file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-animevideov3.pth']\n",
+ "\n",
+ "\n",
+ " up_model_name = upscale_model\n",
+ " if up_model_name == 'RealESRGAN_x4plus': # x4 RRDBNet model\n",
+ " up_model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)\n",
+ " netscale = 4\n",
+ " file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth']\n",
+ " elif up_model_name == 'RealESRNet_x4plus': # x4 RRDBNet model\n",
+ " up_model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)\n",
+ " netscale = 4\n",
+ " file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/RealESRNet_x4plus.pth']\n",
+ " elif up_model_name == 'RealESRGAN_x4plus_anime_6B': # x4 RRDBNet model with 6 blocks\n",
+ " up_model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4)\n",
+ " netscale = 4\n",
+ " file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth']\n",
+ " elif up_model_name == 'RealESRGAN_x2plus': # x2 RRDBNet model\n",
+ " up_model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2)\n",
+ " netscale = 2\n",
+ " file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth']\n",
+ " elif up_model_name == 'realesr-animevideov3': # x4 VGG-style model (XS size)\n",
+ " up_model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type='prelu')\n",
+ " netscale = 4\n",
+ " file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-animevideov3.pth']\n",
+ " elif up_model_name == 'realesr-general-x4v3': # x4 VGG-style model (S size)\n",
+ " up_model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')\n",
+ " netscale = 4\n",
+ " file_url = [\n",
+ " 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-wdn-x4v3.pth',\n",
+ " 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth'\n",
+ " ]\n",
+ " upscaler_model_path = os.path.join('weights', up_model_name + '.pth')\n",
+ " if not os.path.isfile(upscaler_model_path):\n",
+ " ROOT_DIR = root_dir\n",
+ " for url in file_url:\n",
+ " # model_path will be updated\n",
+ " upscaler_model_path = load_file_from_url(\n",
+ " url=url, model_dir=os.path.join(ROOT_DIR, 'weights'), progress=True, file_name=None)\n",
+ "\n",
+ " dni_weight = None\n",
+ "\n",
+ " upsampler = RealESRGANer(\n",
+ " scale=netscale,\n",
+ " model_path=upscaler_model_path,\n",
+ " dni_weight=dni_weight,\n",
+ " model=up_model,\n",
+ " tile=0,\n",
+ " tile_pad=10,\n",
+ " pre_pad=0,\n",
+ " half=True,\n",
+ " device='cuda',\n",
+ " )\n",
+ "\n",
+ "#@markdown ### **Video settings**\n",
+ "use_deflicker = True #@param {'type':'boolean'}\n",
+ "# if platform.system() != 'Linux' and use_deflicker:\n",
+ "# use_deflicker = False\n",
+ "# print('Disabling ffmpeg deflicker filter for windows install, as it is causing a crash.')\n",
+ "if skip_video_for_run_all == True:\n",
+ " print('Skipping video creation, uncheck skip_video_for_run_all if you want to run it')\n",
+ "\n",
+ "else:\n",
+ " # import subprocess in case this cell is run without the above cells\n",
+ " import subprocess\n",
+ " from base64 import b64encode\n",
+ "\n",
+ " from multiprocessing.pool import ThreadPool as Pool\n",
+ "\n",
+ " pool = Pool(threads)\n",
+ "\n",
+ " latest_run = batchNum\n",
+ "\n",
+ " folder = batch_name #@param\n",
+ " run = latest_run#@param\n",
+ " final_frame = 'final_frame'\n",
+ "\n",
+ " #@markdown This is the frame where the video will start\n",
+ " init_frame = 1#@param {type:\"number\"}\n",
+ " #@markdown You can change i to the number of the last frame you want to generate. It will raise an error if that number of frames does not exist.\n",
+ "\n",
+ " last_frame = final_frame#@param {type:\"number\"}\n",
+ " #@markdown Export fps. Leave as -1 to get fps from your init video divided by nth frame, and to keep video duration the same.\n",
+ " fps = -1#@param {type:\"number\"}\n",
+ " if fps == -1:\n",
+ " if 'extract_nth_frame' in globals().keys():\n",
+ " if 'detected_fps' in globals().keys():\n",
+ " fps = detected_fps/extract_nth_frame\n",
+ " elif 'video_init_path' in globals().keys():\n",
+ " fps = get_fps(video_init_path)/extract_nth_frame\n",
+ "\n",
+ " assert fps != -1, 'please specify a valid FPS value > 0'\n",
+ " print(f'Using detected fps of {fps}')\n",
+ " output_format = 'h264_mp4' #@param ['h264_mp4','qtrle_mov','prores_mov']\n",
+ "\n",
+ " if last_frame == 'final_frame':\n",
+ " last_frame = len(glob(batchFolder+f\"/{folder}({run})_*.png\"))\n",
+ " print(f'Total frames: {last_frame}')\n",
+ "\n",
+ " video_out = batchFolder+f\"/video\"\n",
+ " os.makedirs(video_out, exist_ok=True)\n",
+ " image_path = f\"{outDirPath}/{folder}/{folder}({run})_%06d.png\"\n",
+ " filepath = f\"{video_out}/{folder}({run})_{'_noblend'}.{output_format.split('_')[-1]}\"\n",
+ "\n",
+ " if upscale_ratio>1:\n",
+ " postfix+=f'_x{upscale_ratio}_{upscale_model}'\n",
+ " if use_deflicker:\n",
+ " postfix+='_dfl'\n",
+ " if (blend_mode == 'optical flow') & (True) :\n",
+ " image_path = f\"{outDirPath}/{folder}/flow/{folder}({run})_%06d.png\"\n",
+ " postfix += '_flow'\n",
+ "\n",
+ " video_out = batchFolder+f\"/video\"\n",
+ " os.makedirs(video_out, exist_ok=True)\n",
+ " filepath = f\"{video_out}/{folder}({run})_{postfix}.{output_format.split('_')[-1]}\"\n",
+ " if last_frame == 'final_frame':\n",
+ " last_frame = len(glob(batchFolder+f\"/flow/{folder}({run})_*.png\"))\n",
+ " flo_out = batchFolder+f\"/flow\"\n",
+ "\n",
+ " os.makedirs(flo_out, exist_ok=True)\n",
+ "\n",
+ " frames_in = sorted(glob(batchFolder+f\"/{folder}({run})_*.png\"))\n",
+ " assert len(frames_in)>1, 'Less than 1 frame found in the specified run, make sure you have specified correct batch name and run number.'\n",
+ "\n",
+ " frame0 = Image.open(frames_in[0])\n",
+ " if use_background_mask_video:\n",
+ " frame0 = apply_mask(frame0, 0, background_video, background_source_video, invert_mask_video)\n",
+ " if upscale_ratio>1:\n",
+ " frame0 = np.array(frame0)[...,::-1]\n",
+ " output, _ = upsampler.enhance(frame0, outscale=upscale_ratio)\n",
+ " frame0 = PIL.Image.fromarray((output)[...,::-1].astype('uint8'))\n",
+ " frame0.save(flo_out+'/'+frames_in[0].replace('\\\\','/').split('/')[-1])\n",
+ "\n",
+ " def process_flow_frame(i):\n",
+ "   \"\"\"Warp frames_in[i-1] onto frames_in[i] with warp() and save the result to the flow folder.\n",
+ "\n",
+ "   The background mask and the upscaler are applied first when they are enabled.\n",
+ "   \"\"\"\n",
+ "   prev_path, cur_path = frames_in[i-1], frames_in[i]\n",
+ "   frame1 = Image.open(prev_path)\n",
+ "   frame2 = Image.open(cur_path)\n",
+ "\n",
+ "   # Number parsed from the previous frame's file name ([:-4] drops the extension), plus one.\n",
+ "   prev_name = prev_path.split('/')[-1]\n",
+ "   next_number = int(prev_name.split('_')[-1][:-4])+1\n",
+ "   frame1_stem = f\"{next_number:06}.jpg\"\n",
+ "   flo_path = f\"{flo_folder}/{frame1_stem}.npy\"\n",
+ "\n",
+ "   # Consistency weights are passed to warp() only when check_consistency is on.\n",
+ "   weights_path = None\n",
+ "   if check_consistency:\n",
+ "     weights_path = (f\"{flo_folder}/{frame1_stem}-21_cc.jpg\" if reverse_cc_order\n",
+ "                     else f\"{flo_folder}/{frame1_stem}_12-21_cc.jpg\")\n",
+ "   tic = time.time()  # not read again in this function\n",
+ "   printf('process_flow_frame warp')\n",
+ "   frame = warp(frame1, frame2, flo_path, blend=blend, weights_path=weights_path,\n",
+ "                pad_pct=padding_ratio, padding_mode=padding_mode, inpaint_blend=0, video_mode=True)\n",
+ "   if use_background_mask_video:\n",
+ "     frame = apply_mask(frame, i, background_video, background_source_video, invert_mask_video)\n",
+ "   if upscale_ratio>1:\n",
+ "     # The upsampler receives the image with its channel axis reversed; the output is reversed back.\n",
+ "     reversed_channels = np.array(frame)[...,::-1]\n",
+ "     output, _ = upsampler.enhance(reversed_channels.clip(0,255), outscale=upscale_ratio)\n",
+ "     frame = PIL.Image.fromarray((output)[...,::-1].clip(0,255).astype('uint8'))\n",
+ "   frame.save(batchFolder+f\"/flow/{folder}({run})_{i:06}.png\")\n",
+ "\n",
+ " with Pool(threads) as p:\n",
+ " fn = partial(try_process_frame, func=process_flow_frame)\n",
+ " total_frames = range(init_frame, min(len(frames_in), last_frame))\n",
+ " result = list(tqdm(p.imap(fn, total_frames), total=len(total_frames)))\n",
+ "\n",
+ " if blend_mode == 'linear':\n",
+ "   # Linear blend: every output frame is a weighted average of two consecutive input frames.\n",
+ "   image_path = f\"{outDirPath}/{folder}/blend/{folder}({run})_%06d.png\"\n",
+ "   postfix += '_blend'\n",
+ "\n",
+ "   video_out = batchFolder+f\"/video\"\n",
+ "   os.makedirs(video_out, exist_ok=True)\n",
+ "   filepath = f\"{video_out}/{folder}({run})_{postfix}.{output_format.split('_')[-1]}\"\n",
+ "   if last_frame == 'final_frame':\n",
+ "     last_frame = len(glob(batchFolder+f\"/blend/{folder}({run})_*.png\"))\n",
+ "   blend_out = batchFolder+f\"/blend\"\n",
+ "   os.makedirs(blend_out, exist_ok = True)\n",
+ "   # glob() gives no ordering guarantee; sort so that frames_in[i-1] precedes frames_in[i].\n",
+ "   frames_in = sorted(glob(batchFolder+f\"/{folder}({run})_*.png\"))\n",
+ "   assert len(frames_in)>0, 'No frames found in the specified run, make sure you have specified correct batch name and run number.'\n",
+ "\n",
+ "   frame0 = Image.open(frames_in[0])\n",
+ "   if use_background_mask_video:\n",
+ "     frame0 = apply_mask(frame0, 0, background_video, background_source_video, invert_mask_video)\n",
+ "   if upscale_ratio>1:\n",
+ "     # The upsampler receives the image with its channel axis reversed; the output is reversed back.\n",
+ "     frame0 = np.array(frame0)[...,::-1]\n",
+ "     output, _ = upsampler.enhance(frame0.clip(0,255), outscale=upscale_ratio)\n",
+ "     frame0 = PIL.Image.fromarray((output)[...,::-1].clip(0,255).astype('uint8'))\n",
+ "   # The first frame has no predecessor to blend with, so it is stored unblended in the blend folder\n",
+ "   # (it used to be written to flo_out, which this cell only assigns in the optical flow branch).\n",
+ "   frame0.save(blend_out+'/'+frames_in[0].replace('\\\\','/').split('/')[-1])\n",
+ "\n",
+ "   def process_blend_frame(i):\n",
+ "     \"\"\"Blend frames_in[i-1] and frames_in[i] with weights (1-blend) and blend, then save the result.\n",
+ "\n",
+ "     The background mask and the upscaler are applied first when they are enabled.\n",
+ "     \"\"\"\n",
+ "     frame1_path = frames_in[i-1]\n",
+ "     frame2_path = frames_in[i]\n",
+ "\n",
+ "     frame1 = Image.open(frame1_path)\n",
+ "     frame2 = Image.open(frame2_path)\n",
+ "     frame = Image.fromarray((np.array(frame1)*(1-blend) + np.array(frame2)*(blend)).round().astype('uint8'))\n",
+ "     if use_background_mask_video:\n",
+ "       frame = apply_mask(frame, i, background_video, background_source_video, invert_mask_video)\n",
+ "     if upscale_ratio>1:\n",
+ "       frame = np.array(frame)[...,::-1]\n",
+ "       output, _ = upsampler.enhance(frame.clip(0,255), outscale=upscale_ratio)\n",
+ "       frame = PIL.Image.fromarray((output)[...,::-1].clip(0,255).astype('uint8'))\n",
+ "     frame.save(batchFolder+f\"/blend/{folder}({run})_{i:06}.png\")\n",
+ "\n",
+ "   with Pool(threads) as p:\n",
+ "     fn = partial(try_process_frame, func=process_blend_frame)\n",
+ "     total_frames = range(init_frame, min(len(frames_in), last_frame))\n",
+ "     result = list(tqdm(p.imap(fn, total_frames), total=len(total_frames)))\n",
+ " deflicker_str = ''\n",
+ " # Encoder options for each value of the output_format dropdown.\n",
+ " encoder_args = {\n",
+ "   'h264_mp4': ['-c:v', 'libx264', '-pix_fmt', 'yuv420p'],\n",
+ "   'qtrle_mov': ['-c:v', 'qtrle'],\n",
+ "   'prores_mov': ['-c:v', 'prores_aw', '-profile:v', '2', '-pix_fmt', 'yuv422p10'],\n",
+ " }\n",
+ " if output_format not in encoder_args:\n",
+ "   raise ValueError(f'Unsupported output_format: {output_format}')\n",
+ " # The mp4 preset sets the input rate with -framerate, the mov presets with -r (unchanged).\n",
+ " input_rate_flag = '-framerate' if output_format == 'h264_mp4' else '-r'\n",
+ " cmd = [\n",
+ "   'ffmpeg',\n",
+ "   '-y',\n",
+ "   '-vcodec', 'png',\n",
+ "   input_rate_flag, str(fps),\n",
+ "   '-start_number', str(init_frame),\n",
+ "   '-i', image_path,\n",
+ "   '-frames:v', str(last_frame+1),\n",
+ " ] + encoder_args[output_format]\n",
+ " # Only the last -vf given for an output is used by ffmpeg, so the fps filter of the mov presets\n",
+ " # and the optional deflicker filter are joined into a single comma-separated filter chain.\n",
+ " video_filters = []\n",
+ " if output_format in ('qtrle_mov', 'prores_mov'):\n",
+ "   video_filters.append(f'fps={fps}')\n",
+ " if use_deflicker:\n",
+ "   video_filters.append('deflicker=mode=pm:size=10')\n",
+ " if video_filters:\n",
+ "   cmd += ['-vf', ','.join(video_filters)]\n",
+ " cmd += [filepath]\n",
+ " experimental_deflicker = False #@param {'type':'boolean'}\n",
+ "\n",
+ " def _run_ffmpeg(ffmpeg_cmd, cwd):\n",
+ "   \"\"\"Run ffmpeg_cmd in cwd and wait for it to finish.\n",
+ "\n",
+ "   Raises RuntimeError carrying ffmpeg's stderr text when the exit code is not 0.\n",
+ "   \"\"\"\n",
+ "   process = subprocess.Popen(ffmpeg_cmd, cwd=cwd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)\n",
+ "   stdout, stderr = process.communicate()\n",
+ "   if process.returncode != 0:\n",
+ "     # stderr arrives as bytes; decode it so the message is readable in the cell output.\n",
+ "     error_text = stderr.decode('utf-8', errors='replace')\n",
+ "     print(error_text)\n",
+ "     raise RuntimeError(error_text)\n",
+ "\n",
+ " if upscale_ratio>1:\n",
+ "   # The upscaler objects are not used past this point in the cell.\n",
+ "   del up_model, upsampler\n",
+ "   gc.collect()\n",
+ " _run_ffmpeg(cmd, cwd=f'{batchFolder}')\n",
+ " print(f\"The video is ready and saved to {filepath}\")\n",
+ " keep_audio = True #@param {'type':'boolean'}\n",
+ " if experimental_deflicker:\n",
+ "   # Overlays the video on a half-transparent copy of itself shifted by .033 s (see filter_complex).\n",
+ "   f_deflicker = filepath[:-4]+'_deflicker'+filepath[-4:]\n",
+ "   cmd_d=['ffmpeg', '-y','-fflags', '+genpts', '-i', filepath, '-fflags', '+genpts', '-i', filepath,\n",
+ "          '-filter_complex', \"[0:v]setpts=PTS-STARTPTS[top]; [1:v]setpts=PTS-STARTPTS+.033/TB, format=yuva420p, colorchannelmixer=aa=0.5[bottom]; [top][bottom]overlay=shortest=1\",\n",
+ "          f_deflicker]\n",
+ "\n",
+ "   if os.path.exists(filepath):\n",
+ "     _run_ffmpeg(cmd_d, cwd=f'{root_dir}')\n",
+ "     print(f\"The deflickered video is saved to {f_deflicker}\")\n",
+ "   else: print('Error deflickering video: either init or output video don`t exist.')\n",
+ "   # The audio step below works on the deflickered file.\n",
+ "   filepath = f_deflicker\n",
+ "\n",
+ " if keep_audio:\n",
+ "   # Copies the video stream and maps the audio stream of the init video into a new file.\n",
+ "   f_audio = filepath[:-4]+'_audio'+filepath[-4:]\n",
+ "   if os.path.exists(filepath) and os.path.exists(video_init_path):\n",
+ "\n",
+ "     cmd_a = ['ffmpeg', '-y', '-i', filepath, '-i', video_init_path, '-map', '0:v', '-map', '1:a', '-c:v', 'copy', '-shortest', f_audio]\n",
+ "     _run_ffmpeg(cmd_a, cwd=f'{root_dir}')\n",
+ "     print(f\"The video with added audio is saved to {f_audio}\")\n",
+ "   else: print('Error adding audio from init video to output video: either init or output video don`t exist.')\n",
+ "\n",
+ "executed_cells[cell_name] = True"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "cellView": "form",
+ "id": "zEUdU6k2JC84"
+ },
+ "outputs": [],
+ "source": [
+ "#@title Shutdown runtime\n",
+ "#@markdown Useful with the new Colab policy.\\\n",
+ "#@markdown If on, shuts down the runtime after every cell has been run successfully.\n",
+ "\n",
+ "# Opt-in switch: the runtime is released only when this flag is set and is_colab is truthy.\n",
+ "shut_down_after_run_all = False #@param {'type':'boolean'}\n",
+ "if shut_down_after_run_all and is_colab:\n",
+ " # google.colab is imported here because it is only needed in this branch.\n",
+ " from google.colab import runtime\n",
+ " runtime.unassign()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "cellView": "form",
+ "id": "wqRuPO8zt6wU"
+ },
+ "outputs": [],
+ "source": [
+ "#@title Beep\n",
+ "beep = True #@param {'type':'boolean'}\n",
+ "#@markdown warp and pleasant soundtracks by #infinitevibes\n",
+ "sound = 'warp' #@param ['beep', 'warp', 'pleasant']\n",
+ "# Plays a notification sound when beep is enabled; the sound dropdown picks one of three variants.\n",
+ "# 44100 is used throughout as the sample rate and is the rate passed to Audio().\n",
+ "from IPython.display import Audio\n",
+ "if beep:\n",
+ " if sound == 'warp':\n",
+ " #fm sound params\n",
+ " duration = 2.5\n",
+ " carrier_freq = 110\n",
+ " modulator_freq = 50\n",
+ " modulation_index = 0.5\n",
+ "\n",
+ " # time axis in seconds\n",
+ " t = np.arange(duration * 44100) / 44100\n",
+ "\n",
+ " # carrier is computed here but only carrier_freq and modulator are used for fm_sound below\n",
+ " carrier = np.sin(2 * np.pi * carrier_freq * t)\n",
+ " modulator = np.sin(2 * np.pi * modulator_freq * t)\n",
+ "\n",
+ " #perform fm synthesis\n",
+ " fm_sound = np.sin(2 * np.pi * (carrier_freq + modulation_index * modulator) * t)\n",
+ "\n",
+ " #ding sound params\n",
+ " ding_duration = 3\n",
+ " ding_freq = 1760/2\n",
+ " ding_volume = 0.27\n",
+ "\n",
+ " #generate ding + overtones\n",
+ " ding_sound = ding_volume * np.sin(2 * np.pi * ding_freq * np.arange(ding_duration * 44100) / 44100)\n",
+ " # overtones 2..4 are added at volume 0.5/overtone relative to ding_volume\n",
+ " for overtone in range(2, 5):\n",
+ " ding_sound += ding_volume * 0.5/overtone * np.sin(2 * np.pi * overtone * ding_freq * np.arange(ding_duration * 44100) / 44100)\n",
+ "\n",
+ " #release\n",
+ " decay_rate = 2 # (higher values make the sound fade out faster)\n",
+ " ding_sound *= np.exp(-decay_rate * np.arange(ding_duration * 44100) / 44100)\n",
+ "\n",
+ " #reverb\n",
+ " reverb_delay = 0.01 # in seconds\n",
+ " reverb_decay = 0.2 # (lower values result in more reverb)\n",
+ " delay_samples = int(reverb_delay * 44100)\n",
+ " # mixes in a scaled copy of the samples that lie delay_samples further ahead\n",
+ " ding_sound[:-delay_samples] += reverb_decay * ding_sound[delay_samples:]\n",
+ "\n",
+ " #concatenate the fm and ding\n",
+ " final_sound = np.concatenate([fm_sound, ding_sound])\n",
+ "\n",
+ " #play\n",
+ " display.display(Audio(final_sound, rate=44100, autoplay=True))\n",
+ " if sound == 'pleasant':\n",
+ " note_duration = 1\n",
+ " overlap = 0.7\n",
+ " sampling_rate = 44100\n",
+ "\n",
+ " note_frequencies = [261.63, 293.66, 329.63, 392.00, 261.63*2, 130.81]\n",
+ "\n",
+ " t = np.arange(note_duration * sampling_rate) / sampling_rate\n",
+ "\n",
+ " # buffer sized for all notes played back to back; the notes are summed into it at their start offsets\n",
+ " arpeggio = np.zeros(int(note_duration * sampling_rate * len(note_frequencies)))\n",
+ " for i, freq in enumerate(note_frequencies):\n",
+ "\n",
+ " note = np.sin(2 * np.pi * freq * t)\n",
+ " # linear fade-in and fade-out over crossfade_duration seconds at both ends of the note\n",
+ " crossfade_duration = 0.01\n",
+ " crossfade_samples = int(crossfade_duration * sampling_rate)\n",
+ " note[:crossfade_samples] *= np.linspace(0, 1, crossfade_samples)\n",
+ " note[-crossfade_samples:] *= np.linspace(1, 0, crossfade_samples)\n",
+ "\n",
+ " # each note starts (1 - overlap) * note_duration seconds after the previous one\n",
+ " start = int(i * (1 - overlap) * note_duration * sampling_rate)\n",
+ " arpeggio[start:start+note.size] += note\n",
+ "\n",
+ " display.display(Audio(arpeggio, rate=sampling_rate, autoplay=True))\n",
+ " if sound == 'beep':\n",
+ " # NOTE(review): the source indentation is not visible here; the sine beep appears to be the non-Colab path - confirm\n",
+ " if not is_colab:\n",
+ " from IPython.display import Audio\n",
+ "\n",
+ " # Define the beep sound parameters\n",
+ " duration = 1 # Duration of the beep sound in seconds\n",
+ " freq = 440 # Frequency of the beep sound in Hz\n",
+ "\n",
+ " # Generate the beep sound\n",
+ " beep_sound = 0.1 * np.sin(2 * np.pi * freq * np.arange(duration * 44100) / 44100)\n",
+ "\n",
+ " # Play the beep sound\n",
+ " display.display(Audio(beep_sound, rate=44100, autoplay=True))\n",
+ "\n",
+ " # on Colab a hosted .ogg file is played through JavaScript instead\n",
+ " if is_colab:\n",
+ " from google.colab import output\n",
+ " output.eval_js('new Audio(\"https://upload.wikimedia.org/wikipedia/commons/0/05/Beep-09.ogg\").play()')"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "FCLS9fPYVR6N"
+ },
+ "source": [
+ "# Extras"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "lD-D3s3F8iu0"
+ },
+ "source": [
+ "## Compare settings"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "cellView": "form",
+ "id": "OVw02jGI8YDB"
+ },
+ "outputs": [],
+ "source": [
+ "#@title Insert paths to two settings.txt files to compare\n",
+ "\n",
+ "file1 = '0' #@param {'type':'string'}\n",
+ "file2 = '0' #@param {'type':'string'}\n",
+ "\n",
+ "import json\n",
+ "from glob import glob\n",
+ "import os\n",
+ "\n",
+ "changes = []\n",
+ "added = []\n",
+ "removed = []\n",
+ "\n",
+ "def infer_settings_path(path):\n",
+ "  \"\"\"Resolve a settings reference to a settings file path.\n",
+ "\n",
+ "  path may be:\n",
+ "    '-1'            - the most recently created *.txt file in settings_out;\n",
+ "    an integer text - the first *.txt file in settings_out whose name contains '(<that text>)';\n",
+ "    anything else   - returned unchanged.\n",
+ "  Returns '' when '-1' is given and settings_out has no *.txt files. An integer with\n",
+ "  no matching file is returned unchanged.\n",
+ "  \"\"\"\n",
+ "  if path == '-1':\n",
+ "    settings_files = sorted(glob(os.path.join(settings_out, '*.txt')),\n",
+ "                            key=os.path.getctime)\n",
+ "    if len(settings_files) == 0:\n",
+ "      print('Skipping load latest run settings: no settings files found.')\n",
+ "      return ''\n",
+ "    return settings_files[-1]\n",
+ "\n",
+ "  # Parse the run number with int() instead of eval(): the value is user input and must not be executed.\n",
+ "  try:\n",
+ "    int(path)\n",
+ "  except (TypeError, ValueError):\n",
+ "    return path\n",
+ "\n",
+ "  run_tag = f'({path})'\n",
+ "  for candidate in sorted(glob(os.path.join(settings_out, '*.txt'))):\n",
+ "    if run_tag in candidate:\n",
+ "      return candidate\n",
+ "  return path\n",
+ "\n",
+ "file1 = infer_settings_path(file1)\n",
+ "file2 = infer_settings_path(file2)\n",
+ "\n",
+ "if file1 != '' and file2 != '':\n",
+ "  with open(file1, 'rb') as f:\n",
+ "    f1 = json.load(f)\n",
+ "  with open(file2, 'rb') as f:\n",
+ "    f2 = json.load(f)\n",
+ "  # Sorted so the report order is stable between runs; a plain set iterates in arbitrary order.\n",
+ "  joint_keys = sorted(set(f1.keys()) | set(f2.keys()))\n",
+ "  print(f'Comparing\\n{file1.split(\"/\")[-1]}\\n{file2.split(\"/\")[-1]}\\n')\n",
+ "  for key in joint_keys:\n",
+ "    if key in f1 and key in f2:\n",
+ "      # present in both files: report only when the value differs\n",
+ "      if f1[key] != f2[key]:\n",
+ "        changes.append(f'{key}: {f1[key]} -> {f2[key]}')\n",
+ "    elif key in f1:\n",
+ "      removed.append(f'{key}: {f1[key]} -> ')\n",
+ "    else:\n",
+ "      added.append(f'{key}: -> {f2[key]}')\n",
+ "\n",
+ "# The three sections are always printed, even when nothing was compared.\n",
+ "print('Changed:\\n')\n",
+ "for o in changes:\n",
+ "  print(o)\n",
+ "\n",
+ "print('\\n\\nAdded in file2:\\n')\n",
+ "for o in added:\n",
+ "  print(o)\n",
+ "\n",
+ "print('\\n\\nRemoved in file2:\\n')\n",
+ "for o in removed:\n",
+ "  print(o)\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "lcMr-zF_VUcv"
+ },
+ "source": [
+ "## Masking and tracking\n",
+ "\n",
+ "Can be run separately from the rest of the notebook"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "cellView": "form",
+ "id": "UymeCZ9XVXO2"
+ },
+ "outputs": [],
+ "source": [
+ "#@title Install SAMTrack-CLI\n",
+ "#@markdown originally from https://github.com/z-x-yang/Segment-and-Track-Anything \\\n",
+ "#@markdown Restart the notebook after install.\n",
+ "#https://stackoverflow.com/questions/64261546/how-to-solve-error-microsoft-visual-c-14-0-or-greater-is-required-when-inst\n",
+ "# Clones Segment-and-Track-Anything-CLI under root_dir, installs its dependencies and downloads\n",
+ "# the three checkpoints into its ckpt folder. Non-Linux systems take a separate install path.\n",
+ "import os, platform\n",
+ "try:\n",
+ " #cd to root if root dir defined\n",
+ " os.chdir(root_dir)\n",
+ "# NOTE(review): the bare except also hides a failing chdir (e.g. a missing folder), not only an undefined root_dir\n",
+ "except:\n",
+ " root_dir = os.getcwd()\n",
+ "\n",
+ "!git clone https://github.com/Sxela/Segment-and-Track-Anything-CLI\n",
+ "os.chdir(os.path.join(root_dir,'Segment-and-Track-Anything-CLI'))\n",
+ "\n",
+ "!python -m pip install -e ./sam\n",
+ "# Linux installs GroundingDINO as an editable package; other systems only clone it into ./src\n",
+ "# and install its requirements.\n",
+ "if platform.system() == 'Linux':\n",
+ " !python -m pip install -e git+https://github.com/IDEA-Research/GroundingDINO.git@main#egg=GroundingDINO\n",
+ "else:\n",
+ " os.makedirs('./src', exist_ok=True)\n",
+ " !git clone https://github.com/IDEA-Research/GroundingDINO \"{os.path.join(root_dir,'Segment-and-Track-Anything-CLI')}/src/GroundingDINO\"\n",
+ " !python -m pip install -r \"{os.path.join(root_dir,'Segment-and-Track-Anything-CLI')}/src/GroundingDINO/requirements.txt\"\n",
+ "!python -m pip install numpy opencv-python pycocotools matplotlib Pillow scikit-image\n",
+ "!python -m pip install gdown\n",
+ "\n",
+ "!git clone https://github.com/ClementPinard/Pytorch-Correlation-extension.git\n",
+ "if platform.system() == 'Linux':\n",
+ " !python -m pip install -e ./Pytorch-Correlation-extension\n",
+ "else:\n",
+ " !python -m pip install -r ./Pytorch-Correlation-extension/requirements.txt\n",
+ "\n",
+ "os.chdir(os.path.join(root_dir,'Segment-and-Track-Anything-CLI'))\n",
+ "os.makedirs(os.path.join(root_dir,'Segment-and-Track-Anything-CLI', 'ckpt'), exist_ok=True)\n",
+ "\n",
+ "# Each checkpoint is downloaded only when its file is not already present in ./ckpt.\n",
+ "import gdown\n",
+ "# download aot-ckpt\n",
+ "if not os.path.exists('./ckpt/R50_DeAOTL_PRE_YTB_DAV.pth'):\n",
+ " gdown.download(id='1QoChMkTVxdYZ_eBlZhK2acq9KMQZccPJ', output='./ckpt/R50_DeAOTL_PRE_YTB_DAV.pth')\n",
+ "\n",
+ "# NOTE(review): wget is imported but not pip-installed in this cell - confirm it is installed elsewhere\n",
+ "import wget\n",
+ "# download sam-ckpt\n",
+ "if not os.path.exists('./ckpt/sam_vit_b_01ec64.pth'):\n",
+ " wget.download(\"https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth\",\n",
+ " \"ckpt/\")\n",
+ "\n",
+ "if not os.path.exists('./ckpt/groundingdino_swint_ogc.pth'):\n",
+ "# download grounding-dino ckpt\n",
+ " wget.download(\"https://huggingface.co/ShilongLiu/GroundingDINO/resolve/main/groundingdino_swint_ogc.pth\",\n",
+ " \"ckpt/\")\n",
+ "\n",
+ "# wget is imported a second time here; the repeated import has no further effect\n",
+ "import wget\n",
+ "import zipfile\n",
+ "\n",
+ "if platform.system() != 'Linux':\n",
+ " #download prebuilt binaries for cuda 11.8, torch 2, python 3.10, win11\n",
+ " if not os.path.exists('./site-packages.zip'):\n",
+ " wget.download(\"https://raw.githubusercontent.com/Sxela/Segment-and-Track-Anything-CLI/main/site-packages.zip\",\n",
+ " \"./site-packages.zip\")\n",
+ "\n",
+ " # the archive is unpacked into the env/Lib folder under root_dir\n",
+ " with zipfile.ZipFile(\"site-packages.zip\", 'r') as zip_ref:\n",
+ " zip_ref.extractall(f'{root_dir}/env/Lib/')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "cellView": "form",
+ "id": "S5zVfW5damBg"
+ },
+ "outputs": [],
+ "source": [
+ "#@title Detection setup\n",
+ "#@markdown Use this cell to tweak detection settings, that will be later used on the whole video.\n",
+ "#@markdown Run this cell to get detection preview.\\\n",
+ "#@markdown Code mostly taken from https://github.com/z-x-yang/Segment-and-Track-Anything/blob/main/demo_instseg.ipynb\n",
+ "import os, pathlib, shutil, sys, subprocess\n",
+ "from glob import glob\n",
+ "try:\n",
+ " #cd to root if root dir defined\n",
+ " os.chdir(root_dir)\n",
+ "except:\n",
+ " root_dir = os.getcwd()\n",
+ "\n",
+ "os.chdir(os.path.join(root_dir,'Segment-and-Track-Anything-CLI'))\n",
+ "\n",
+ "#(c) Alex Spirin 2023\n",
+ "\n",
+ "import hashlib\n",
+ "# We use input file hashes to automate video extraction\n",
+ "#\n",
+ "def generate_file_hash(input_file):\n",
+ "  \"\"\"Return a SHA-256 hex digest built from the file's name, size and ctime.\n",
+ "\n",
+ "  The file contents are not read, only its base name and os.path metadata.\n",
+ "  \"\"\"\n",
+ "  fingerprint_parts = (\n",
+ "    os.path.basename(input_file),\n",
+ "    os.path.getsize(input_file),\n",
+ "    os.path.getctime(input_file),\n",
+ "  )\n",
+ "  digest = hashlib.sha256()\n",
+ "  for part in fingerprint_parts:\n",
+ "    digest.update(str(part).encode('utf-8'))\n",
+ "  return digest.hexdigest()\n",
+ "\n",
+ "def createPath(filepath):\n",
+ " \"\"\"Create the directory filepath with any missing parents; an existing directory is not an error.\"\"\"\n",
+ " os.makedirs(filepath, exist_ok=True)\n",
+ "\n",
+ "\n",
+ "def extractFrames(video_path, output_path, nth_frame, start_frame, end_frame):\n",
+ "  \"\"\"Export every nth_frame-th frame of video_path between start_frame and end_frame as jpg files.\n",
+ "\n",
+ "  Frames are written to output_path as %06d.jpg; jpg files already in that folder are deleted first.\n",
+ "  Exits through sys.exit() when video_path does not exist.\n",
+ "  \"\"\"\n",
+ "  createPath(output_path)\n",
+ "  print(f\"Exporting Video Frames (1 every {nth_frame})...\")\n",
+ "  # Remove frames left from a previous extraction; one undeletable file no longer stops the clean-up.\n",
+ "  for f in [o.replace('\\\\','/') for o in glob(output_path+'/*.jpg')]:\n",
+ "    try:\n",
+ "      pathlib.Path(f).unlink()\n",
+ "    except OSError:\n",
+ "      print('error deleting frame ', f)\n",
+ "  vf = f'select=between(n\\\\,{start_frame}\\\\,{end_frame}) , select=not(mod(n\\\\,{nth_frame}))'\n",
+ "  if not os.path.exists(video_path):\n",
+ "    sys.exit(f'\\nERROR!\\n\\nVideo not found: {video_path}.\\nPlease check your video path.\\n')\n",
+ "\n",
+ "  ffmpeg_args = ['-i', f'{video_path}', '-vf', f'{vf}', '-vsync', 'vfr', '-q:v', '2', '-loglevel', 'error', '-stats', f'{output_path}/%06d.jpg']\n",
+ "  try:\n",
+ "    subprocess.run(['ffmpeg'] + ffmpeg_args, stdout=subprocess.PIPE)\n",
+ "  except OSError:\n",
+ "    # ffmpeg could not be started from PATH: fall back to the executable kept in root_dir.\n",
+ "    subprocess.run([f'{root_dir}/ffmpeg.exe'] + ffmpeg_args, stdout=subprocess.PIPE)\n",
+ "\n",
+ "\n",
+ "class FrameDataset():\n",
+ "  \"\"\"Sorted list of image frame paths resolved from a glob pattern, an image, a video file or a folder.\n",
+ "\n",
+ "  Indexing returns a path; indices past the end return the last frame.\n",
+ "  \"\"\"\n",
+ "  def __init__(self, source_path, outdir_prefix, videoframes_root):\n",
+ "    \"\"\"Resolve source_path into self.frame_paths.\n",
+ "\n",
+ "    source_path: existing file or folder, or a glob pattern when no such path exists.\n",
+ "    outdir_prefix, videoframes_root: used to name the folder that video frames are extracted to.\n",
+ "    Raises Exception when nothing is found, when a non-image file is found, or when the\n",
+ "    frames use more than one file extension.\n",
+ "    \"\"\"\n",
+ "    self.frame_paths = None\n",
+ "    image_extenstions = ['jpeg', 'jpg', 'png', 'tiff', 'bmp', 'webp']\n",
+ "    not_found_message = f'Frame source for {outdir_prefix} not found at {source_path}\\nPlease specify an existing source path.'\n",
+ "\n",
+ "    if not os.path.exists(source_path):\n",
+ "      # Not an existing path: treat it as a glob pattern.\n",
+ "      matches = sorted(glob(source_path))\n",
+ "      if len(matches) == 0:\n",
+ "        raise Exception(not_found_message)\n",
+ "      self.frame_paths = matches\n",
+ "    elif os.path.isfile(source_path):\n",
+ "      if os.path.splitext(source_path)[1][1:].lower() in image_extenstions:\n",
+ "        # A single image is used directly (this assignment used to be overwritten by the extraction below).\n",
+ "        self.frame_paths = [source_path]\n",
+ "      else:\n",
+ "        # Any other file is treated as a video; its frames go to a folder named after the file hash.\n",
+ "        source_hash = generate_file_hash(source_path)[:10]\n",
+ "        out_path = os.path.join(videoframes_root, outdir_prefix+'_'+source_hash)\n",
+ "\n",
+ "        extractFrames(source_path, out_path,\n",
+ "                      nth_frame=1, start_frame=0, end_frame=999999999)\n",
+ "        self.frame_paths = glob(os.path.join(out_path, '*.*'))\n",
+ "        if len(self.frame_paths)<1:\n",
+ "          raise Exception(f'Couldn`t extract frames from {source_path}\\nPlease specify an existing source path.')\n",
+ "    elif os.path.isdir(source_path):\n",
+ "      self.frame_paths = glob(os.path.join(source_path, '*.*'))\n",
+ "      if len(self.frame_paths)<1:\n",
+ "        raise Exception(f'Found 0 frames in {source_path}\\nPlease specify an existing source path.')\n",
+ "\n",
+ "    if self.frame_paths is None:\n",
+ "      raise Exception(not_found_message)\n",
+ "\n",
+ "    extensions = []\n",
+ "    for f in self.frame_paths:\n",
+ "      ext = os.path.splitext(f)[1][1:]\n",
+ "      # Compared in lower case, like the single-file check above, so '.PNG' or '.JPG' files are accepted.\n",
+ "      if ext.lower() not in image_extenstions:\n",
+ "        raise Exception(f'Found non-image file extension: {ext} in {source_path}. Please provide a folder with image files of the same extension, or specify a glob pattern.')\n",
+ "      if not ext in extensions:\n",
+ "        extensions+=[ext]\n",
+ "      if len(extensions)>1:\n",
+ "        raise Exception(f'Found multiple file extensions: {extensions} in {source_path}. Please provide a folder with image files of the same extension, or specify a glob pattern.')\n",
+ "\n",
+ "    self.frame_paths = sorted(self.frame_paths)\n",
+ "    print(f'Found {len(self.frame_paths)} frames at {source_path}')\n",
+ "\n",
+ "  def __getitem__(self, idx):\n",
+ "    \"\"\"Return the path of frame idx, clamped to the last frame.\"\"\"\n",
+ "    idx = min(idx, len(self.frame_paths)-1)\n",
+ "    return self.frame_paths[idx]\n",
+ "\n",
+ "  def __len__(self):\n",
+ "    \"\"\"Return the number of frames.\"\"\"\n",
+ "    return len(self.frame_paths)\n",
+ "\n",
+ "# mostly taken from https://github.com/z-x-yang/Segment-and-Track-Anything/blob/main/demo_instseg.ipynb\n",
+ "\n",
+ "import os\n",
+ "import cv2\n",
+ "from SegTracker import SegTracker\n",
+ "from model_args import aot_args,sam_args,segtracker_args\n",
+ "from PIL import Image\n",
+ "from aot_tracker import _palette\n",
+ "import numpy as np\n",
+ "import torch\n",
+ "import imageio\n",
+ "import matplotlib.pyplot as plt\n",
+ "from scipy.ndimage import binary_dilation\n",
+ "\n",
+ "import gc\n",
+ "def save_prediction(pred_mask,output_dir,file_name):\n",
+ "  \"\"\"Save pred_mask as a palettized ('P' mode) image at output_dir/file_name, coloured with _palette.\"\"\"\n",
+ "  indexed = Image.fromarray(pred_mask.astype(np.uint8)).convert(mode='P')\n",
+ "  indexed.putpalette(_palette)\n",
+ "  target_file = os.path.join(output_dir,file_name)\n",
+ "  indexed.save(target_file)\n",
+ "def colorize_mask(pred_mask):\n",
+ "  \"\"\"Return pred_mask as an RGB numpy array, with each mask value coloured through _palette.\"\"\"\n",
+ "  indexed = Image.fromarray(pred_mask.astype(np.uint8)).convert(mode='P')\n",
+ "  indexed.putpalette(_palette)\n",
+ "  return np.array(indexed.convert(mode='RGB'))\n",
+ "def draw_mask(img, mask, alpha=0.7, id_countour=False):\n",
+ "  \"\"\"Overlay mask on img and return the result with the dtype of img.\n",
+ "\n",
+ "  img: image array; it is no longer modified in place.\n",
+ "  mask: array of object ids, 0 meaning background.\n",
+ "  alpha: weight of the overlay colour in the blend.\n",
+ "  id_countour: when True every object id is blended and outlined separately (slow),\n",
+ "    otherwise the whole mask is coloured with colorize_mask() and outlined once.\n",
+ "  \"\"\"\n",
+ "  # Work on a copy: the previous code aliased img, so the caller's array was painted over.\n",
+ "  img_mask = img.copy()\n",
+ "  if id_countour:\n",
+ "    # very slow ~ 1s per image\n",
+ "    obj_ids = np.unique(mask)\n",
+ "    obj_ids = obj_ids[obj_ids!=0]\n",
+ "\n",
+ "    for obj_id in obj_ids:\n",
+ "      # Overlay color on binary mask; ids above 255 are drawn in black\n",
+ "      if obj_id <= 255:\n",
+ "        color = _palette[obj_id*3:obj_id*3+3]\n",
+ "      else:\n",
+ "        color = [0,0,0]\n",
+ "      foreground = img * (1-alpha) + np.ones_like(img) * alpha * np.array(color)\n",
+ "      binary_mask = (mask == obj_id)\n",
+ "\n",
+ "      # Compose image\n",
+ "      img_mask[binary_mask] = foreground[binary_mask]\n",
+ "\n",
+ "      # one-pixel outline: the dilated mask minus the mask itself\n",
+ "      countours = binary_dilation(binary_mask,iterations=1) ^ binary_mask\n",
+ "      img_mask[countours, :] = 0\n",
+ "  else:\n",
+ "    binary_mask = (mask!=0)\n",
+ "    countours = binary_dilation(binary_mask,iterations=1) ^ binary_mask\n",
+ "    foreground = img*(1-alpha)+colorize_mask(mask)*alpha\n",
+ "    img_mask[binary_mask] = foreground[binary_mask]\n",
+ "    img_mask[countours,:] = 0\n",
+ "\n",
+ "  return img_mask.astype(img.dtype)\n",
+ "\n",
+ "video_path = 'C:\\\\code\\\\warp\\\\inits\\\\y2mate.com - Jennifer Connelly HOT 90s GIRLS_1080p.mp4' #@param {'type':'string'}\n",
+ "video_name = video_path.replace('\\\\','/').split('/')[-1]\n",
+ "io_args = {\n",
+ " 'input_video': video_path,\n",
+ " 'output_mask_dir': f'./assets/{video_name}_masks', # save pred masks\n",
+ " 'output_video': f'./assets/{video_name}_seg.mp4', # mask+frame vizualization, mp4 or avi, else the same as input video\n",
+ " 'output_gif': f'./assets/{video_name}_seg.gif', # mask visualization\n",
+ "}\n",
+ "prefix = ''\n",
+ "# Extracted frames go under batchFolder when that name is defined, otherwise under root_dir.\n",
+ "try:\n",
+ "  videoframes_root = f'{batchFolder}/videoFrames'\n",
+ "except NameError:\n",
+ "  videoframes_root = f'{root_dir}/videoFrames'\n",
+ "\n",
+ "frames = FrameDataset(video_path, outdir_prefix=prefix, videoframes_root=videoframes_root)\n",
+ "\n",
+ "# choose good parameters in sam_args based on the first frame segmentation result\n",
+ "# other arguments can be modified in model_args.py\n",
+ "# note the object number limit is 255 by default, which requires < 10GB GPU memory with amp\n",
+ "sam_args['generator_args'] = {\n",
+ " 'points_per_side': 60,\n",
+ " 'pred_iou_thresh': 0.8,\n",
+ " 'stability_score_thresh': 0.9,\n",
+ " 'crop_n_layers': 1,\n",
+ " 'crop_n_points_downscale_factor': 2,\n",
+ " 'min_mask_region_area': 200,\n",
+ " }\n",
+ "\n",
+ "# Set Text args\n",
+ "'''\n",
+ "parameter:\n",
+ " grounding_caption: Text prompt to detect objects in key-frames\n",
+ " box_threshold: threshold for box\n",
+ " text_threshold: threshold for label(text)\n",
+ " box_size_threshold: If the size ratio between the box and the frame is larger than the box_size_threshold, the box will be ignored. This is used to filter out large boxes.\n",
+ " reset_image: reset the image embeddings for SAM\n",
+ "'''\n",
+ "frame_number = 0 #@param {'type':'number'}\n",
+ "frame_number = int(frame_number)\n",
+ "#@markdown Text prompt to detect objects in key-frames\n",
+ "grounding_caption = \"person\" #@param {'type':'string'}\n",
+ "#@markdown Box detection confidence threshold\n",
+ "box_threshold = 0.3 #@param {'type':'number'}\n",
+ "#@markdown Text confidence threshold\n",
+ "text_threshold = 0.3 #@param {'type':'number'}\n",
+ "#@markdown Box to Image ratio threshold (with box_size_threshold = 0.8 detections over 80% of the image will be ignored)\n",
+ "box_size_threshold = 1 #@param {'type':'number'}\n",
+ "\n",
+ "reset_image = True\n",
+ "\n",
+ "frame_idx = 0\n",
+ "segtracker = SegTracker(segtracker_args,sam_args,aot_args)\n",
+ "segtracker.restart_tracker()\n",
+ "\n",
+ "# Runs detection + segmentation on one frame and shows the annotated frame and the coloured mask.\n",
+ "with torch.cuda.amp.autocast():\n",
+ "  frame = cv2.imread(frames[frame_number])\n",
+ "  # cv2.imread returns None instead of raising when the file cannot be read\n",
+ "  if frame is None:\n",
+ "    raise Exception(f'Could not read frame {frame_number} from {frames[frame_number]}')\n",
+ "  frame = cv2.cvtColor(frame,cv2.COLOR_BGR2RGB)\n",
+ "  pred_mask, annotated_frame = segtracker.detect_and_seg(frame, grounding_caption, box_threshold, text_threshold,\n",
+ "                                                         box_size_threshold, reset_image=reset_image)\n",
+ "  torch.cuda.empty_cache()\n",
+ "  obj_ids = np.unique(pred_mask)\n",
+ "  obj_ids = obj_ids[obj_ids!=0]\n",
+ "  # Report the frame that was previewed; frame_idx is always 0 here.\n",
+ "  print(\"processed frame {}, obj_num {}\".format(frame_number,len(obj_ids)),end='\\n')\n",
+ "  init_res = draw_mask(annotated_frame, pred_mask,id_countour=False)\n",
+ "  plt.figure(figsize=(10,10))\n",
+ "  plt.axis('off')\n",
+ "  plt.imshow(init_res)\n",
+ "  plt.show()\n",
+ "  plt.figure(figsize=(10,10))\n",
+ "  plt.axis('off')\n",
+ "  plt.imshow(colorize_mask(pred_mask))\n",
+ "  plt.show()\n",
+ "\n",
+ "  # Release the tracker and cached GPU memory once the preview is drawn.\n",
+ "  del segtracker\n",
+ "  torch.cuda.empty_cache()\n",
+ "  gc.collect()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "cellView": "form",
+ "id": "aeU9iVz3mIHo"
+ },
+ "outputs": [],
+ "source": [
+ "#@title Mask whole video.\n",
+ "use_cli = False #@param {'type':'boolean'}\n",
+ "import subprocess\n",
+ "start_frame = 0 #@param {'type':'number'}\n",
+ "end_frame = 10 #@param {'type':'number'}\n",
+ "#@markdown The interval to run SAM to segment new objects\n",
+ "sam_gap = 50 #@param {'type':'number'}\n",
+ "#@markdown minimal mask area to add a new mask as a new object\n",
+ "min_area = 200 #@param {'type':'number'}\n",
+ "#@markdown maximal object number to track in a video\n",
+ "max_obj_num = 255 #@param {'type':'number'}\n",
+ "#@markdown the area of a new object in the background should be > 80%\n",
+ "min_new_obj_iou = 0.8 #@param {'type':'number'}\n",
+ "save_separate_masks = True\n",
+ "save_joint_mask = False #@param {'type':'boolean'}\n",
+ "save_mask = save_joint_mask\n",
+ "save_video = False #@param {'type':'boolean'}\n",
+ "save_gif = False #@param {'type':'boolean'}\n",
+ "# grounding_caption\n",
+ "# box_threshold\n",
+ "# text_threshold\n",
+ "# box_size_threshold\n",
+ "# video_path\n",
+ "output_multimask_dir = os.path.join(videoframes_root, f'{generate_file_hash(video_path)[:10]}_masks')\n",
+ "if use_cli:\n",
+ "    def run_command(cmd, cwd='./'):\n",
+ "        \"\"\"Run `cmd` with `cwd` as working directory, echo its stdout and return its exit code.\"\"\"\n",
+ "        # cwd must be forwarded to Popen: the command below refers to run.py by a relative path.\n",
+ "        with subprocess.Popen(cmd, cwd=cwd, stdout=subprocess.PIPE, bufsize=1, universal_newlines=True) as p:\n",
+ "            for line in p.stdout:\n",
+ "                print(line)\n",
+ "            # EOF on stdout does not mean the process has been reaped yet; poll() could still\n",
+ "            # return None here, so block on wait() to get the real exit code.\n",
+ "            exit_code = p.wait()\n",
+ "        return exit_code\n",
+ "\n",
+ "    # !python /content/Segment-and-Track-Anything/run.py\\\n",
+ "    # --video_path /content/SaveInsta.App_-_3067564057762969265_1317509610.mp4\\\n",
+ "    # --save_separate_masks --outdir /content/out/\n",
+ "\n",
+ "\n",
+ "    cmd = ['python', 'run.py','--video_path', video_path, '--save_separate_masks', '--outdir', output_multimask_dir,\n",
+ "           '--caption', grounding_caption, '--box_threshold', box_threshold, '--text_threshold', text_threshold, '--box_size_threshold', box_size_threshold,\n",
+ "           '--sam_gap', sam_gap, '--min_area', min_area, '--max_obj_num', max_obj_num, '--min_new_obj_iou',min_new_obj_iou]\n",
+ "    # Popen only accepts strings, the numeric form values have to be converted.\n",
+ "    cmd = [str(o) for o in cmd]\n",
+ "    returncode = run_command(cmd, cwd=os.path.join(root_dir,'Segment-and-Track-Anything-CLI'))\n",
+ "    # Check the value returned by run_command; there is no `process` object in this cell.\n",
+ "    if returncode != 0:\n",
+ "        raise RuntimeError(returncode)\n",
+ "    else:\n",
+ "        print(f\"The video is ready and saved to {output_multimask_dir}\")\n",
+ "else:\n",
+ " os.makedirs('./debug/seg_result', exist_ok=True)\n",
+ " os.makedirs('./debug/aot_result', exist_ok=True)\n",
+ " segtracker_args = {\n",
+ " 'sam_gap': sam_gap,\n",
+ " 'min_area': min_area,\n",
+ " 'max_obj_num': max_obj_num,\n",
+ " 'min_new_obj_iou': min_new_obj_iou\n",
+ " }\n",
+ "\n",
+ " if save_mask:\n",
+ " output_dir = io_args['output_mask_dir']\n",
+ " os.makedirs(output_dir, exist_ok=True)\n",
+ " pred_list = []\n",
+ " masked_pred_list = []\n",
+ "\n",
+ " segtracker = SegTracker(segtracker_args, sam_args, aot_args)\n",
+ " segtracker.restart_tracker()\n",
+ " from tqdm.notebook import tqdm, trange\n",
+ " if start_frame == 0 and end_frame == 0:\n",
+ " frame_range = trange(len(frames))\n",
+ " else:\n",
+ " frame_range = trange(start_frame, end_frame+1)\n",
+ " for frame_idx in frame_range:\n",
+ " frame = cv2.imread(frames[frame_idx])\n",
+ " frame = cv2.cvtColor(frame,cv2.COLOR_BGR2RGB)\n",
+ " if frame_idx == start_frame:\n",
+ " pred_mask, _ = segtracker.detect_and_seg(frame, grounding_caption, box_threshold, text_threshold, box_size_threshold, reset_image)\n",
+ " torch.cuda.empty_cache()\n",
+ " gc.collect()\n",
+ " segtracker.add_reference(frame, pred_mask)\n",
+ " elif ((frame_idx-start_frame) % sam_gap) == 0:\n",
+ " seg_mask, _ = segtracker.detect_and_seg(frame, grounding_caption, box_threshold, text_threshold,\n",
+ " box_size_threshold, reset_image)\n",
+ " # save_prediction(seg_mask, './debug/seg_result', str(frame_idx)+'.png')\n",
+ " torch.cuda.empty_cache()\n",
+ " gc.collect()\n",
+ " track_mask = segtracker.track(frame)\n",
+ " # save_prediction(track_mask, './debug/aot_result', str(frame_idx)+'.png')\n",
+ "\n",
+ " # find new objects, and update tracker with new objects\n",
+ " new_obj_mask = segtracker.find_new_objs(track_mask, seg_mask)\n",
+ " if np.sum(new_obj_mask > 0) > frame.shape[0] * frame.shape[1] * 0.4:\n",
+ " new_obj_mask = np.zeros_like(new_obj_mask)\n",
+ " if save_mask: save_prediction(new_obj_mask,output_dir,str(frame_idx)+'_new.png')\n",
+ " pred_mask = track_mask + new_obj_mask\n",
+ " segtracker.add_reference(frame, pred_mask)\n",
+ " else:\n",
+ " pred_mask = segtracker.track(frame,update_memory=True)\n",
+ " torch.cuda.empty_cache()\n",
+ " gc.collect()\n",
+ "\n",
+ " if save_mask: save_prediction(pred_mask,output_dir,str(frame_idx)+'.png')\n",
+ "\n",
+ " pred_list.append(pred_mask)\n",
+ "\n",
+ " print(\"processed frame {}, obj_num {}\".format(frame_idx,segtracker.get_obj_num()),end='\\r')\n",
+ "\n",
+ "\n",
+ " if save_video:\n",
+ " # draw pred mask on frame and save as a video\n",
+ " cap = cv2.VideoCapture(io_args['input_video'])\n",
+ " fps = cap.get(cv2.CAP_PROP_FPS)\n",
+ " width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))\n",
+ " height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))\n",
+ " num_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))\n",
+ "\n",
+ " if io_args['input_video'][-3:]=='mp4':\n",
+ " fourcc = cv2.VideoWriter_fourcc(*\"mp4v\")\n",
+ " elif io_args['input_video'][-3:] == 'avi':\n",
+ " fourcc = cv2.VideoWriter_fourcc(*\"MJPG\")\n",
+ " # fourcc = cv2.VideoWriter_fourcc(*\"XVID\")\n",
+ " else:\n",
+ " fourcc = int(cap.get(cv2.CAP_PROP_FOURCC))\n",
+ " out = cv2.VideoWriter(io_args['output_video'], fourcc, fps, (width, height))\n",
+ "\n",
+ " frame_idx = 0\n",
+ "\n",
+ " progress_bar = tqdm(total=num_frames)\n",
+ " progress_bar.set_description(\"Processing frames...\")\n",
+ "\n",
+ " while cap.isOpened():\n",
+ " ret, frame = cap.read()\n",
+ " if not ret:\n",
+ " break\n",
+ " frame = cv2.cvtColor(frame,cv2.COLOR_BGR2RGB)\n",
+ " try:\n",
+ " pred_mask = pred_list[frame_idx]\n",
+ " except: break\n",
+ " masked_frame = draw_mask(frame,pred_mask)\n",
+ " # masked_frame = masked_pred_list[frame_idx]\n",
+ " masked_frame = cv2.cvtColor(masked_frame,cv2.COLOR_RGB2BGR)\n",
+ " out.write(masked_frame)\n",
+ " print('frame {} writed'.format(frame_idx),end='\\r')\n",
+ " frame_idx += 1\n",
+ " progress_bar.update(1)\n",
+ " out.release()\n",
+ " cap.release()\n",
+ " print(\"\\n{} saved\".format(io_args['output_video']))\n",
+ " print('\\nfinished')\n",
+ "\n",
+ " if save_gif:\n",
+ " # save colorized masks as a gif\n",
+ " imageio.mimsave(io_args['output_gif'],pred_list,fps=fps)\n",
+ " print(\"{} saved\".format(io_args['output_gif']))\n",
+ "\n",
+ " from multiprocessing.pool import ThreadPool as Pool\n",
+ " from functools import partial\n",
+ " import PIL\n",
+ "\n",
+ " threads = 12\n",
+ "\n",
+ " # Write one black/white mask image per object id for a single frame.\n",
+ " # For every id i in 0..max_ids the pixels of predicted_masks[frame_num] equal to i become 255 and\n",
+ " # all other pixels 0; the image is saved as <output_folder>/mask<i:03>/alpha_<frame_num:06>.jpg.\n",
+ " # The mask<i:03> folders are not created here (write_masks_frame_multi creates them before calling this).\n",
+ " # NOTE(review): presumably each entry of predicted_masks is a 2D array of integer object ids - confirm.\n",
+ " def write_masks_frame(frame_num, predicted_masks, output_folder, max_ids=255):\n",
+ " predicted_masks_frame = predicted_masks[frame_num]\n",
+ " for i in range(max_ids+1):\n",
+ " img_out = PIL.Image.fromarray(((predicted_masks_frame==i)*255).astype('uint8'))\n",
+ " img_out.save(os.path.join(output_folder, f'mask{i:03}', f'alpha_{frame_num:06}.jpg'))\n",
+ "\n",
+ " # Save the per-object masks of every frame in predicted_masks using a thread pool.\n",
+ " # Creates <output_folder>/mask<i:03> for each id 0..max_ids, then maps write_masks_frame over all\n",
+ " # frame indices with `threads` workers (Pool is multiprocessing.pool.ThreadPool here).\n",
+ " def write_masks_frame_multi(predicted_masks, output_folder, max_ids):\n",
+ " for i in range(max_ids+1):\n",
+ " os.makedirs(os.path.join(output_folder, f'mask{i:03}'), exist_ok=True)\n",
+ "\n",
+ " with Pool(threads) as p:\n",
+ " fn = partial(write_masks_frame, predicted_masks=predicted_masks, output_folder=output_folder, max_ids=max_ids)\n",
+ " # list(...) drains the lazy imap iterator so all frames are written before the pool is closed\n",
+ " result = list(tqdm(p.imap(fn, range(len(predicted_masks))), total=len(predicted_masks)))\n",
+ "\n",
+ " if save_separate_masks:\n",
+ " print('Saving Separate masks')\n",
+ " write_masks_frame_multi(predicted_masks=pred_list, output_folder=output_multimask_dir, max_ids=segtracker.get_obj_num())\n",
+ " print(f'Saved masks to {output_multimask_dir}')\n",
+ "\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "fx1n8BRf7h67"
+ },
+ "source": [
+ "## RIFE\n",
+ "Frame interpolation.\n",
+ "Can be run separately from the rest of the notebook."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "cellView": "form",
+ "id": "BS5_JY9i7iwI"
+ },
+ "outputs": [],
+ "source": [
+ "# @title Install\n",
+ "# Clones the RIFE repository into root_dir and fetches the trained model archive.\n",
+ "\n",
+ "import os\n",
+ "import zipfile\n",
+ "# Work from root_dir when an earlier cell defined it; otherwise os.chdir(root_dir) raises\n",
+ "# (NameError), the except branch runs and the current directory becomes root_dir.\n",
+ "try:\n",
+ " #cd to root if root dir defined\n",
+ " os.chdir(root_dir)\n",
+ "except:\n",
+ " root_dir = os.getcwd()\n",
+ "\n",
+ "!python -m pip install gdown\n",
+ "\n",
+ "!git clone https://github.com/Sxela/ECCV2022-RIFE\n",
+ "os.makedirs(f'{root_dir}/ECCV2022-RIFE/train_log', exist_ok=True)\n",
+ "os.chdir(os.path.join(root_dir,'ECCV2022-RIFE'))\n",
+ "\n",
+ "import gdown, zipfile\n",
+ "# The model archive is downloaded only when the zip is not already on disk.\n",
+ "# NOTE(review): the extraction below looks like it is part of the same branch, i.e. it only runs\n",
+ "# right after a fresh download - confirm against the notebook's real indentation.\n",
+ "if not os.path.exists(f'{root_dir}/ECCV2022-RIFE/RIFE_trained_model_v3.6.zip'):\n",
+ " gdown.download(id='1APIzVeI-4ZZCEuIRE1m6WYfSCaOsi_7_', output=f'{root_dir}/ECCV2022-RIFE/RIFE_trained_model_v3.6.zip')\n",
+ "\n",
+ " with zipfile.ZipFile(f'{root_dir}/ECCV2022-RIFE/RIFE_trained_model_v3.6.zip', 'r') as zip_ref:\n",
+ " zip_ref.extractall(f'{root_dir}/ECCV2022-RIFE/')\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "cellView": "form",
+ "id": "X5b3JdllJv-X"
+ },
+ "outputs": [],
+ "source": [
+ "# @title Interpolate frames\n",
+ "#@markdown Can be used to interpolate results.\\\n",
+ "#@markdown If you have a high-fps output video (like 60fps), you can also try skipping frames to reduce high-frequency flicker. If you have already used nth frame during your video render, skipping frames here may produce weird results.\\\n",
+ "#@markdown 2^exponent frames to generate\n",
+ "#@markdown exponent=1 will generate x2 frames, exponent=2 - x4 frames\n",
+ "import os, cv2, pathlib, subprocess\n",
+ "try:\n",
+ " #cd to root if root dir defined\n",
+ " os.chdir(root_dir)\n",
+ "except:\n",
+ " root_dir = os.getcwd()\n",
+ "\n",
+ "exponent = 2 #@param\n",
+ "print(f'Will generate x{2**exponent} frames.')\n",
+ "\n",
+ "#@markdown video or frames to interpolate\n",
+ "video_path = \"C:\\\\Users\\\\User\\\\Downloads\\\\stable_warpfusion_0.20.0(126)__flow_audio.mp4\" #@param {'type':'string'}\n",
+ "#@markdown use nth frame (drop non-nth frame). 1 uses every frame, 2 uses every 2nd frame, 3 - every third, etc.\n",
+ "nth_frame = 2 #@param\n",
+ "#@markdown input video fps. only used if video_path is a folder with frames. if video_path is a video, will use video fps divided by nth frame\n",
+ "fps = 30 #@param {'type':'number'}\n",
+ "assert os.path.exists(video_path), 'Please specify an existing video_path.'\n",
+ "if os.path.isfile(video_path):\n",
+ " videoCapture = cv2.VideoCapture(video_path)\n",
+ " det_fps = videoCapture.get(cv2.CAP_PROP_FPS)\n",
+ " fps = 2**exponent*det_fps/nth_frame\n",
+ " videoCapture.release()\n",
+ " print(f'Detected fps of {det_fps}. With nth_frame={nth_frame} the output fps will be {fps}')\n",
+ "\n",
+ "\n",
+ "def createPath(filepath):\n",
+ " # Create the directory `filepath` together with any missing parent folders;\n",
+ " # exist_ok=True makes an already existing directory a no-op instead of an error.\n",
+ " os.makedirs(filepath, exist_ok=True)\n",
+ "\n",
+ "from glob import glob\n",
+ "\n",
+ "def extractFrames(video_path, output_path, nth_frame, start_frame, end_frame):\n",
+ "    \"\"\"Export every `nth_frame`-th frame of `video_path` between `start_frame` and `end_frame` as JPEGs.\n",
+ "\n",
+ "    `output_path` is wiped and recreated, then ffmpeg writes the selected frames into it\n",
+ "    as %06d.jpg. Terminates via sys.exit() when `video_path` does not exist and raises\n",
+ "    FileNotFoundError when no ffmpeg executable can be launched.\n",
+ "    \"\"\"\n",
+ "    # Imported locally: this cell imports shutil only further down and never imports sys,\n",
+ "    # so relying on module-level names breaks when the cell is run on a fresh kernel.\n",
+ "    import shutil\n",
+ "    import sys\n",
+ "\n",
+ "    # Validate the input before touching the output folder so a typo in the path\n",
+ "    # does not destroy previously extracted frames.\n",
+ "    if not os.path.exists(video_path):\n",
+ "        sys.exit(f'\\nERROR!\\n\\nVideo not found: {video_path}.\\nPlease check your video path.\\n')\n",
+ "\n",
+ "    # Start from an empty folder so frames of an earlier run never mix with the new ones.\n",
+ "    if os.path.exists(output_path):\n",
+ "        shutil.rmtree(output_path)\n",
+ "    os.makedirs(output_path, exist_ok=True)\n",
+ "\n",
+ "    print(f\"Exporting Video Frames (1 every {nth_frame})...\")\n",
+ "    # vf = f'select=not(mod(n\\\\,{nth_frame}))'\n",
+ "    vf = f'select=between(n\\\\,{start_frame}\\\\,{end_frame}) , select=not(mod(n\\\\,{nth_frame}))'\n",
+ "    ffmpeg_args = ['-i', f'{video_path}', '-vf', f'{vf}', '-vsync', 'vfr', '-q:v', '2', '-loglevel', 'error', '-stats', f'{output_path}/%06d.jpg']\n",
+ "\n",
+ "    # Try the binary bundled one level up, its Windows name, then whatever is on PATH.\n",
+ "    for ffmpeg_bin in ('../ffmpeg', '../ffmpeg.exe', 'ffmpeg'):\n",
+ "        try:\n",
+ "            result = subprocess.run([ffmpeg_bin, *ffmpeg_args], stdout=subprocess.PIPE)\n",
+ "        except OSError:\n",
+ "            # executable missing or not runnable - fall through to the next candidate\n",
+ "            continue\n",
+ "        if result.returncode != 0:\n",
+ "            print(f'ffmpeg exited with code {result.returncode} while extracting frames.')\n",
+ "        return\n",
+ "    raise FileNotFoundError('ffmpeg executable not found (tried ../ffmpeg, ../ffmpeg.exe and ffmpeg on PATH).')\n",
+ "\n",
+ "# cli usage example\n",
+ "\n",
+ "# import subprocess\n",
+ "# os.chdir(os.path.join(root_dir,'ECCV2022-RIFE'))\n",
+ "# exponent = 2\n",
+ "# video_path = \"C:\\\\code\\\\warp\\\\20-12\\\\images_out\\\\stable_warpfusion_0.20.0\\\\video\\\\stable_warpfusion_0.20.0(33)__flow_audio.mp4\"\n",
+ "\n",
+ "# cmd = ['python', 'inference_video.py', f'--exp={exponent}', f'--video={video_path}']\n",
+ "# process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)\n",
+ "\n",
+ "# while process.poll() is None:\n",
+ "# out = process.stdout.readline()\n",
+ "# if out != b'':\n",
+ "# print(out.decode('UTF-8'))\n",
+ "# print(process.stdout.read())\n",
+ "# process.stdout.close()\n",
+ "\n",
+ "\n",
+ "os.chdir(os.path.join(root_dir,'ECCV2022-RIFE'))\n",
+ "\n",
+ "\n",
+ "import os\n",
+ "import cv2\n",
+ "import torch\n",
+ "import argparse\n",
+ "import numpy as np\n",
+ "from tqdm import tqdm\n",
+ "from torch.nn import functional as F\n",
+ "import warnings\n",
+ "import _thread\n",
+ "from queue import Queue, Empty\n",
+ "from model.pytorch_msssim import ssim_matlab\n",
+ "import shutil\n",
+ "\n",
+ "frames_temp_dir = os.path.join(root_dir,'ECCV2022-RIFE','temp_frames')\n",
+ "# if os.path.exists(frames_temp_dir):shutil.rmtree(frames_temp_dir)\n",
+ "# else: os.makedirs(frames_temp_dir, exist_ok=True)\n",
+ "if os.path.isfile(video_path):\n",
+ " extractFrames(video_path, frames_temp_dir, nth_frame, start_frame=0, end_frame=999999999)\n",
+ "if os.path.isdir(video_path):\n",
+ " frames_temp_dir = video_path\n",
+ "#extract frames to temp dir\n",
+ "\n",
+ "warnings.filterwarnings(\"ignore\")\n",
+ "\n",
+ "def transferAudio(sourceVideo, targetVideo):\n",
+ "    \"\"\"Mux the audio track of `sourceVideo` into `targetVideo`, replacing `targetVideo` in place.\n",
+ "\n",
+ "    The audio is first stream-copied; if that yields no usable file it is re-encoded to AAC.\n",
+ "    When both attempts fail the silent video is restored under its original name.\n",
+ "    Intermediate files live in <root_dir>/ECCV2022-RIFE/temp, which is removed afterwards.\n",
+ "    \"\"\"\n",
+ "    import shutil\n",
+ "    import subprocess\n",
+ "\n",
+ "    tempdir = f\"{root_dir}/ECCV2022-RIFE/temp\"\n",
+ "    tempAudioFileName = f\"{tempdir}/audio.mkv\"\n",
+ "\n",
+ "    def run_ffmpeg(*ffmpeg_args):\n",
+ "        # Arguments go to ffmpeg as a list (no shell), so paths with spaces or quotes need no escaping.\n",
+ "        try:\n",
+ "            subprocess.run(['ffmpeg', '-y', *ffmpeg_args])\n",
+ "        except OSError as err:\n",
+ "            # A missing ffmpeg is treated like a failed run; the size checks below handle it.\n",
+ "            print(f'Could not run ffmpeg: {err}')\n",
+ "\n",
+ "    def is_missing_or_empty(path):\n",
+ "        # A failed mux leaves a zero-byte file, or no file at all.\n",
+ "        return not os.path.exists(path) or os.path.getsize(path) == 0\n",
+ "\n",
+ "    # split audio from original video file and store in \"temp\" directory\n",
+ "    # clear old \"temp\" directory if it exists, then create a fresh one\n",
+ "    if os.path.isdir(tempdir):\n",
+ "        shutil.rmtree(tempdir)\n",
+ "    os.makedirs(tempdir, exist_ok=True)\n",
+ "    # extract audio from video\n",
+ "    run_ffmpeg('-i', sourceVideo, '-c:a', 'copy', '-vn', tempAudioFileName)\n",
+ "\n",
+ "    targetNoAudio = os.path.splitext(targetVideo)[0] + \"_noaudio\" + os.path.splitext(targetVideo)[1]\n",
+ "    # os.replace overwrites an existing destination on every platform (os.rename fails on Windows)\n",
+ "    os.replace(targetVideo, targetNoAudio)\n",
+ "    # combine audio file and new video file\n",
+ "    run_ffmpeg('-i', targetNoAudio, '-i', tempAudioFileName, '-c', 'copy', targetVideo)\n",
+ "\n",
+ "    if is_missing_or_empty(targetVideo): # if ffmpeg failed to merge the video and audio together try converting the audio to aac\n",
+ "        tempAudioFileName = f\"{tempdir}/audio.m4a\"\n",
+ "        run_ffmpeg('-i', sourceVideo, '-c:a', 'aac', '-b:a', '160k', '-vn', tempAudioFileName)\n",
+ "        run_ffmpeg('-i', targetNoAudio, '-i', tempAudioFileName, '-c', 'copy', targetVideo)\n",
+ "        if is_missing_or_empty(targetVideo): # if aac is not supported by selected format\n",
+ "            # put the silent video back, overwriting the zero-byte leftover if there is one\n",
+ "            os.replace(targetNoAudio, targetVideo)\n",
+ "            print(\"Audio transfer failed. Interpolated video will have no audio\")\n",
+ "        else:\n",
+ "            print(\"Lossless audio transfer failed. Audio was transcoded to AAC (M4A) instead.\")\n",
+ "\n",
+ "            # remove audio-less video\n",
+ "            os.remove(targetNoAudio)\n",
+ "    else:\n",
+ "        os.remove(targetNoAudio)\n",
+ "\n",
+ "    # remove temp directory\n",
+ "    shutil.rmtree(tempdir)\n",
+ "\n",
+ "from types import SimpleNamespace\n",
+ "args = SimpleNamespace(**\n",
+ " {\n",
+ "'video' : None ,\n",
+ "'output' : None ,\n",
+ "'img' : None ,\n",
+ "'montage' : False ,\n",
+ "'modelDir' : 'train_log' ,\n",
+ "'fp16' : False ,\n",
+ "'UHD' : False ,\n",
+ "'scale' : 1.0 ,\n",
+ "'skip' : False ,\n",
+ "'fps' : None ,\n",
+ "'png' : False ,\n",
+ "'ext' : 'mp4' ,\n",
+ "'exp' : 1 ,\n",
+ " }\n",
+ ")\n",
+ "\n",
+ "# args.video = video_path\n",
+ "args.fps = fps\n",
+ "args.img = frames_temp_dir\n",
+ "args.exp = exponent\n",
+ "# Name the result after the input. Only a real file has an extension to strip; a frames\n",
+ "# folder may have no dot at all (which used to give an empty name) or dots that belong to\n",
+ "# a version number, so for folders the whole path is kept minus trailing separators.\n",
+ "if os.path.isdir(video_path):\n",
+ "    output_stem = video_path.rstrip('/\\\\')\n",
+ "else:\n",
+ "    output_stem = os.path.splitext(video_path)[0]\n",
+ "args.output = '{}_{}X_{}fps_{}nth.{}'.format(output_stem, (2 ** args.exp), int(np.round(args.fps)), nth_frame, args.ext)\n",
+ "\n",
+ "assert (not args.video is None or not args.img is None)\n",
+ "if args.skip:\n",
+ " print(\"skip flag is abandoned, please refer to issue #207.\")\n",
+ "if args.UHD and args.scale==1.0:\n",
+ " args.scale = 0.5\n",
+ "assert args.scale in [0.25, 0.5, 1.0, 2.0, 4.0]\n",
+ "# if not args.img is None:\n",
+ "# args.png = True\n",
+ "\n",
+ "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
+ "torch.set_grad_enabled(False)\n",
+ "if torch.cuda.is_available():\n",
+ " torch.backends.cudnn.enabled = True\n",
+ " torch.backends.cudnn.benchmark = True\n",
+ " if(args.fp16):\n",
+ " torch.set_default_tensor_type(torch.cuda.HalfTensor)\n",
+ "os.chdir(os.path.join(root_dir,'ECCV2022-RIFE'))\n",
+ "try:\n",
+ " try:\n",
+ " try:\n",
+ " from model.RIFE_HDv2 import Model\n",
+ " model = Model()\n",
+ " model.load_model(args.modelDir, -1)\n",
+ " print(\"Loaded v2.x HD model.\")\n",
+ " except:\n",
+ " from train_log.RIFE_HDv3 import Model\n",
+ " model = Model()\n",
+ " model.load_model(args.modelDir, -1)\n",
+ " print(\"Loaded v3.x HD model.\")\n",
+ " except:\n",
+ " from model.RIFE_HD import Model\n",
+ " model = Model()\n",
+ " model.load_model(args.modelDir, -1)\n",
+ " print(\"Loaded v1.x HD model\")\n",
+ "except:\n",
+ " from model.RIFE import Model\n",
+ " model = Model()\n",
+ " model.load_model(args.modelDir, -1)\n",
+ " print(\"Loaded ArXiv-RIFE model\")\n",
+ "model.eval()\n",
+ "model.device()\n",
+ "\n",
+ "def videogen_fn(videoCapture):\n",
+ "    \"\"\"Yield the frames read from `videoCapture` with the last axis reversed, then a single None.\"\"\"\n",
+ "    while True:\n",
+ "        grabbed, bgr_frame = videoCapture.read()\n",
+ "        if not grabbed:\n",
+ "            # one None marks the end of the stream for the consumer\n",
+ "            yield None\n",
+ "            return\n",
+ "        # reverse the channel order and copy so the result owns contiguous memory\n",
+ "        yield bgr_frame[:, :, ::-1].copy()\n",
+ "\n",
+ "\n",
+ "from functools import partial\n",
+ "fourcc = cv2.VideoWriter_fourcc('m', 'p', '4', 'v')\n",
+ "fpsNotAssigned = True\n",
+ "if args.video is not None:\n",
+ " videoCapture = cv2.VideoCapture(args.video)\n",
+ " fps = videoCapture.get(cv2.CAP_PROP_FPS)\n",
+ " tot_frame = videoCapture.get(cv2.CAP_PROP_FRAME_COUNT)\n",
+ " if args.fps is None:\n",
+ " fpsNotAssigned = True\n",
+ " args.fps = fps * (2 ** args.exp)\n",
+ " else:\n",
+ " fpsNotAssigned = False\n",
+ " videogen = videogen_fn(videoCapture)\n",
+ " lastframe = next(videogen)\n",
+ "\n",
+ " fourcc = cv2.VideoWriter_fourcc('m', 'p', '4', 'v')\n",
+ " video_path_wo_ext, ext = os.path.splitext(args.video)\n",
+ " print('{}.{}, {} frames in total, {}FPS to {}FPS'.format(video_path_wo_ext, args.ext, tot_frame, fps, args.fps))\n",
+ " if args.png == False and fpsNotAssigned == True:\n",
+ " print(\"The audio will be merged after interpolation process\")\n",
+ " else:\n",
+ " print(\"Will not merge audio because using png or fps flag!\")\n",
+ "else:\n",
+ " videogen = []\n",
+ " for f in os.listdir(args.img):\n",
+ " # if 'png' in f:\n",
+ " if f.endswith('.png') or f.endswith('.jpg'):\n",
+ " videogen.append(f)\n",
+ " tot_frame = len(videogen)\n",
+ " videogen.sort(key= lambda x:int(x[:-4]))\n",
+ " lastframe = cv2.imread(os.path.join(args.img, videogen[0]), cv2.IMREAD_UNCHANGED)[:, :, ::-1].copy()\n",
+ " videogen = videogen[1:]\n",
+ "h, w, _ = lastframe.shape\n",
+ "vid_out_name = None\n",
+ "vid_out = None\n",
+ "if args.png:\n",
+ " if not os.path.exists('vid_out'):\n",
+ " os.mkdir('vid_out')\n",
+ "else:\n",
+ " if args.output is not None:\n",
+ " vid_out_name = args.output\n",
+ " else:\n",
+ " vid_out_name = '{}_{}X_{}fps.{}'.format(video_path_wo_ext, (2 ** args.exp), int(np.round(args.fps)), args.ext)\n",
+ " print(f'Exporting video to {vid_out_name}')\n",
+ " vid_out = cv2.VideoWriter(vid_out_name, fourcc, args.fps, (w, h))\n",
+ "\n",
+ "def clear_write_buffer(user_args, write_buffer):\n",
+ "    \"\"\"Consume frames from `write_buffer` until a None arrives and write each one out.\n",
+ "\n",
+ "    With `user_args.png` set, frames go to vid_out/<7-digit counter>.png; otherwise they are\n",
+ "    appended to the module-level `vid_out` writer. Frames are channel-reversed before writing.\n",
+ "    \"\"\"\n",
+ "    png_index = 0\n",
+ "    while True:\n",
+ "        rgb = write_buffer.get()\n",
+ "        if rgb is None:\n",
+ "            return\n",
+ "        bgr = rgb[:, :, ::-1]\n",
+ "        if not user_args.png:\n",
+ "            vid_out.write(bgr)\n",
+ "            continue\n",
+ "        cv2.imwrite('vid_out/{:0>7d}.png'.format(png_index), bgr)\n",
+ "        png_index += 1\n",
+ "\n",
+ "def build_read_buffer(user_args, read_buffer, videogen):\n",
+ "    \"\"\"Feed the frames of `videogen` into `read_buffer`, always followed by one None end marker.\n",
+ "\n",
+ "    When `user_args.img` is set, `videogen` holds file names inside that folder and each one\n",
+ "    is loaded with cv2 and channel-reversed. With `user_args.montage` the frame is cropped\n",
+ "    to the module-level `left`/`w` window. Errors are reported, not raised, because this\n",
+ "    runs in a background thread and the consumer must still receive the end marker.\n",
+ "    \"\"\"\n",
+ "    try:\n",
+ "        for frame in videogen:\n",
+ "            if frame is None:\n",
+ "                # videogen_fn ends its stream with None; stop so the marker is queued only once\n",
+ "                break\n",
+ "            if user_args.img is not None:\n",
+ "                frame_path = os.path.join(user_args.img, frame)\n",
+ "                frame = cv2.imread(frame_path, cv2.IMREAD_UNCHANGED)\n",
+ "                if frame is None:\n",
+ "                    # cv2.imread signals an unreadable file by returning None\n",
+ "                    raise IOError(f'Could not read frame {frame_path}')\n",
+ "                frame = frame[:, :, ::-1].copy()\n",
+ "            if user_args.montage:\n",
+ "                frame = frame[:, left: left + w]\n",
+ "            read_buffer.put(frame)\n",
+ "    except Exception as e:\n",
+ "        # Say why the output is going to be shorter than expected instead of hiding it.\n",
+ "        print(f'Frame reading stopped early: {e!r}')\n",
+ "    finally:\n",
+ "        read_buffer.put(None)\n",
+ "\n",
+ "def make_inference(I0, I1, n):\n",
+ "    \"\"\"Return `n` in-between frames for the pair (I0, I1), ordered from I0 towards I1.\n",
+ "\n",
+ "    The midpoint comes from the global `model`; both halves are then filled recursively\n",
+ "    with n//2 frames each, and the midpoint itself is only emitted when `n` is odd.\n",
+ "    \"\"\"\n",
+ "    global model\n",
+ "    mid = model.inference(I0, I1, args.scale)\n",
+ "    if n == 1:\n",
+ "        return [mid]\n",
+ "    half = n//2\n",
+ "    left_part = make_inference(I0, mid, n=half)\n",
+ "    right_part = make_inference(mid, I1, n=half)\n",
+ "    centre = [mid] if n%2 else []\n",
+ "    return left_part + centre + right_part\n",
+ "\n",
+ "def pad_image(img):\n",
+ "    \"\"\"Pad `img` with the module-level `padding`; the result is cast to half precision when args.fp16 is set.\"\"\"\n",
+ "    padded = F.pad(img, padding)\n",
+ "    return padded.half() if args.fp16 else padded\n",
+ "\n",
+ "if args.montage:\n",
+ " left = w // 4\n",
+ " w = w // 2\n",
+ "tmp = max(32, int(32 / args.scale))\n",
+ "ph = ((h - 1) // tmp + 1) * tmp\n",
+ "pw = ((w - 1) // tmp + 1) * tmp\n",
+ "padding = (0, pw - w, 0, ph - h)\n",
+ "pbar = tqdm(total=tot_frame)\n",
+ "if args.montage:\n",
+ " lastframe = lastframe[:, left: left + w]\n",
+ "write_buffer = Queue(maxsize=500)\n",
+ "read_buffer = Queue(maxsize=500)\n",
+ "_thread.start_new_thread(build_read_buffer, (args, read_buffer, videogen))\n",
+ "_thread.start_new_thread(clear_write_buffer, (args, write_buffer))\n",
+ "\n",
+ "I1 = torch.from_numpy(np.transpose(lastframe, (2,0,1))).to(device, non_blocking=True).unsqueeze(0).float() / 255.\n",
+ "I1 = pad_image(I1)\n",
+ "temp = None # save lastframe when processing static frame\n",
+ "\n",
+ "while True:\n",
+ " if temp is not None:\n",
+ " frame = temp\n",
+ " temp = None\n",
+ " else:\n",
+ " frame = read_buffer.get()\n",
+ " if frame is None:\n",
+ " break\n",
+ " I0 = I1\n",
+ " I1 = torch.from_numpy(np.transpose(frame, (2,0,1))).to(device, non_blocking=True).unsqueeze(0).float() / 255.\n",
+ " I1 = pad_image(I1)\n",
+ " I0_small = F.interpolate(I0, (32, 32), mode='bilinear', align_corners=False)\n",
+ " I1_small = F.interpolate(I1, (32, 32), mode='bilinear', align_corners=False)\n",
+ " ssim = ssim_matlab(I0_small[:, :3], I1_small[:, :3])\n",
+ "\n",
+ " break_flag = False\n",
+ " if ssim > 0.996:\n",
+ " frame = read_buffer.get() # read a new frame\n",
+ " if frame is None:\n",
+ " break_flag = True\n",
+ " frame = lastframe\n",
+ " else:\n",
+ " temp = frame\n",
+ " I1 = torch.from_numpy(np.transpose(frame, (2,0,1))).to(device, non_blocking=True).unsqueeze(0).float() / 255.\n",
+ " I1 = pad_image(I1)\n",
+ " I1 = model.inference(I0, I1, args.scale)\n",
+ " I1_small = F.interpolate(I1, (32, 32), mode='bilinear', align_corners=False)\n",
+ " ssim = ssim_matlab(I0_small[:, :3], I1_small[:, :3])\n",
+ " frame = (I1[0] * 255).byte().cpu().numpy().transpose(1, 2, 0)[:h, :w]\n",
+ "\n",
+ " if ssim < 0.2:\n",
+ " output = []\n",
+ " for i in range((2 ** args.exp) - 1):\n",
+ " output.append(I0)\n",
+ " '''\n",
+ " output = []\n",
+ " step = 1 / (2 ** args.exp)\n",
+ " alpha = 0\n",
+ " for i in range((2 ** args.exp) - 1):\n",
+ " alpha += step\n",
+ " beta = 1-alpha\n",
+ " output.append(torch.from_numpy(np.transpose((cv2.addWeighted(frame[:, :, ::-1], alpha, lastframe[:, :, ::-1], beta, 0)[:, :, ::-1].copy()), (2,0,1))).to(device, non_blocking=True).unsqueeze(0).float() / 255.)\n",
+ " '''\n",
+ " else:\n",
+ " output = make_inference(I0, I1, 2**args.exp-1) if args.exp else []\n",
+ "\n",
+ " if args.montage:\n",
+ " write_buffer.put(np.concatenate((lastframe, lastframe), 1))\n",
+ " for mid in output:\n",
+ " mid = (((mid[0] * 255.).byte().cpu().numpy().transpose(1, 2, 0)))\n",
+ " write_buffer.put(np.concatenate((lastframe, mid[:h, :w]), 1))\n",
+ " else:\n",
+ " write_buffer.put(lastframe)\n",
+ " for mid in output:\n",
+ " mid = (((mid[0] * 255.).byte().cpu().numpy().transpose(1, 2, 0)))\n",
+ " write_buffer.put(mid[:h, :w])\n",
+ " pbar.update(1)\n",
+ " lastframe = frame\n",
+ " if break_flag:\n",
+ " break\n",
+ "\n",
+ "# Queue the very last frame, which the loop above never writes itself.\n",
+ "if args.montage:\n",
+ "    write_buffer.put(np.concatenate((lastframe, lastframe), 1))\n",
+ "else:\n",
+ "    write_buffer.put(lastframe)\n",
+ "# End marker for clear_write_buffer: without it the writer thread blocks forever on get().\n",
+ "# The single writer handles items in order, so once the queue is empty the marker has been\n",
+ "# taken and every frame before it is fully written - only then is releasing vid_out safe.\n",
+ "write_buffer.put(None)\n",
+ "import time\n",
+ "while(not write_buffer.empty()):\n",
+ "    time.sleep(0.1)\n",
+ "pbar.close()\n",
+ "if not vid_out is None:\n",
+ "    vid_out.release()\n",
+ "\n",
+ "# move audio to new video file if appropriate\n",
+ "if args.png == False and fpsNotAssigned == True and not args.video is None:\n",
+ " try:\n",
+ " transferAudio(args.video, vid_out_name)\n",
+ " except:\n",
+ " print(\"Audio transfer failed. Interpolated video will have no audio\")\n",
+ " targetNoAudio = os.path.splitext(vid_out_name)[0] + \"_noaudio\" + os.path.splitext(vid_out_name)[1]\n",
+ " os.rename(targetNoAudio, vid_out_name)\n",
+ "\n",
+ "# Copy the audio of the input video onto the interpolated result (written next to it as *_audio.<ext>).\n",
+ "filepath = args.output\n",
+ "video_init_path = video_path\n",
+ "if os.path.isfile(video_path):\n",
+ "    print('\\nTransferring audio.')\n",
+ "    f_audio = \".\".join(filepath.split('.')[:-1])+'_audio.'+filepath.split('.')[-1]\n",
+ "    process = None\n",
+ "    # Try the binary bundled one level up, its Windows name, then whatever is on PATH;\n",
+ "    # a missing executable must not abort the cell before the final chdir below.\n",
+ "    for ffmpeg_bin in ('../ffmpeg', '../ffmpeg.exe', 'ffmpeg'):\n",
+ "        cmd_a = [ffmpeg_bin, '-y', '-i', filepath, '-i', video_init_path, '-map', '0:v', '-map', '1:a', '-c:v', 'copy', '-shortest', f_audio]\n",
+ "        try:\n",
+ "            process = subprocess.Popen(cmd_a, stdout=subprocess.PIPE, stderr=subprocess.PIPE)\n",
+ "        except OSError:\n",
+ "            continue\n",
+ "        break\n",
+ "    if process is None:\n",
+ "        print('Error exporting audio. Could not launch ffmpeg.')\n",
+ "    else:\n",
+ "        stdout, stderr = process.communicate()\n",
+ "        if process.returncode!=0: print('Error exporting audio. Your input file probably has no audio.')\n",
+ "\n",
+ "os.chdir(root_dir)"
+ ]
+ }
+ ],
+ "metadata": {
+ "accelerator": "GPU",
+ "anaconda-cloud": {},
+ "colab": {
+ "collapsed_sections": [
+ "WrxXo2FVivvi",
+ "CreditsChTop",
+ "LicenseTop",
+ "z36v90fNgLMF",
+ "mcI6h0A7NcZ-",
+ "yyC0Qb0qOcsJ",
+ "T8xpuFgUEeLz",
+ "U5rrnKtV7FoY",
+ "_MleAG1V0ss6",
+ "PgnJ26Bh3Ru8",
+ "GWWNdYvj3Xst",
+ "4bCGxkUZ3r68",
+ "nm_EeEeu391T",
+ "OeF4nJaf3eiD",
+ "SZ6qrVEJeG1u",
+ "CgG4Uq5vepSI",
+ "FCLS9fPYVR6N"
+ ],
+ "gpuType": "T4",
+ "machine_shape": "hm",
+ "private_outputs": true,
+ "provenance": []
+ },
+ "kernelspec": {
+ "display_name": "Python 3",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.9.0"
+ },
+ "vscode": {
+ "interpreter": {
+ "hash": "81794d4967e6c3204c66dcd87b604927b115b27c00565d3d43f05ba2f3a2cb0d"
+ }
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0
+}
\ No newline at end of file