generated from databricks-industry-solutions/industry-solutions-blueprints
-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy path99_utils.py
51 lines (42 loc) · 1.31 KB
/
99_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
# Databricks notebook source
# Installing requirement libraries
%pip install -r ./requirements.txt --quiet
dbutils.library.restartPython()
# COMMAND ----------
# Common imports used throughout.
import matplotlib.pyplot as plt
import PIL
import torch
# COMMAND ----------
def show_image(image: PIL.Image.Image):
    """Display a single generated image with the axis hidden.

    Args:
        image: PIL image to render on the current matplotlib figure.
    """
    plt.imshow(image)
    # Hide ticks, labels, and frame; equivalent to plt.axis("off").
    plt.gca().set_axis_off()
    plt.show()
def show_image_grid(imgs, rows, cols, resize=256):
    """Arrange multiple generated images into a single grid image.

    Args:
        imgs: Sequence of PIL images; must be non-empty. All images are
            assumed to share one size after the optional resize — TODO
            confirm callers never mix sizes when ``resize`` is None.
        rows: Number of grid rows.
        cols: Number of grid columns; images are placed row-major, so
            ``len(imgs)`` should be at most ``rows * cols``.
        resize: If not None, each image is resized to this square edge
            length (pixels) before placement.

    Returns:
        A new RGB ``PIL.Image.Image`` of size ``(cols * w, rows * h)``
        containing the pasted images.
    """
    if resize is not None:
        imgs = [img.resize((resize, resize)) for img in imgs]
    w, h = imgs[0].size
    grid = PIL.Image.new("RGB", size=(cols * w, rows * h))
    # Removed unused locals grid_w, grid_h (were assigned, never read).
    for i, img in enumerate(imgs):
        # Row-major placement: column = i % cols, row = i // cols.
        grid.paste(img, box=(i % cols * w, i // cols * h))
    return grid
def caption_images(input_image, blip_processor, blip_model, device):
    """Caption an image with an annotation (BLIP-style) model.

    Args:
        input_image: Image input accepted by ``blip_processor``.
        blip_processor: Processor that tokenizes/encodes the image and
            decodes generated token ids.
        blip_model: Model exposing ``generate`` for caption token ids.
        device: Target device for the encoded inputs.

    Returns:
        The decoded caption string for the first generated sequence.
    """
    encoded = blip_processor(images=input_image, return_tensors="pt")
    # Move encodings to the target device in half precision before generation.
    pixel_values = encoded.to(device, torch.float16).pixel_values
    token_ids = blip_model.generate(pixel_values=pixel_values, max_length=50)
    captions = blip_processor.batch_decode(token_ids, skip_special_tokens=True)
    return captions[0]