# workflow_image_to_text.py
import streamlit as st
# Streamlit secrets: Clarifai credentials and workflow identifiers.
# The lookups below run at import time.
# USER_ID and APP_ID together build the UserAppIDSet sent with each request.
USER_ID = st.secrets["USER_ID"]
# PAT is sent as the "authorization: Key <PAT>" gRPC metadata entry.
PAT = st.secrets["PAT"]
APP_ID = st.secrets["APP_ID"]
# Workflow invoked by generate_image_caption.
WORKFLOW_ID_IMAGE = st.secrets["WORKFLOW_ID_IMAGE"]
# Workflow invoked by generate_story_from_image_caption.
WORKFLOW_ID_STORY_GPT3 = st.secrets["WORKFLOW_ID_STORY_GPT3"]
from clarifai_grpc.channel.clarifai_channel import ClarifaiChannel
from clarifai_grpc.grpc.api import resources_pb2, service_pb2, service_pb2_grpc
from clarifai_grpc.grpc.api.status import status_code_pb2
@st.cache_data(persist=True)
def generate_image_caption(image) -> str:
    """Run the image workflow on one image and return the text it produces.

    Sends a single-input ``PostWorkflowResults`` request for
    ``WORKFLOW_ID_IMAGE`` and returns the raw text of the first output of
    the first workflow result. The function is wrapped in
    ``st.cache_data(persist=True)``.

    Args:
        image: Payload placed in the ``base64`` field of the Clarifai
            ``Image`` message (presumably raw image bytes -- confirm
            against the caller).

    Returns:
        The ``data.text.raw`` value of the first output.

    Raises:
        Exception: If the response status is not ``SUCCESS``, or if the
            response carries no results/outputs to read the text from.
    """
    stub = service_pb2_grpc.V2Stub(ClarifaiChannel.get_grpc_channel())
    auth_metadata = (("authorization", "Key " + PAT),)
    # The user/app identifier is required when authenticating with a PAT.
    user_app_id = resources_pb2.UserAppIDSet(user_id=USER_ID, app_id=APP_ID)

    request = service_pb2.PostWorkflowResultsRequest(
        user_app_id=user_app_id,
        workflow_id=WORKFLOW_ID_IMAGE,
        inputs=[
            resources_pb2.Input(
                data=resources_pb2.Data(
                    image=resources_pb2.Image(base64=image),
                ),
            ),
        ],
    )
    response = stub.PostWorkflowResults(request, metadata=auth_metadata)

    if response.status.code != status_code_pb2.SUCCESS:
        print(response.status)
        raise Exception(
            "Post workflow results failed, status: " + response.status.description
        )

    # One input was sent, so one WorkflowResult is expected. Check before
    # indexing so an empty response fails with a clear message instead of a
    # bare IndexError.
    if not response.results or not response.results[0].outputs:
        raise Exception(
            "Post workflow results returned no outputs for workflow: "
            + str(WORKFLOW_ID_IMAGE)
        )
    return response.results[0].outputs[0].data.text.raw
@st.cache_data(persist=True)
def generate_story_from_image_caption(image_caption_with_user_input) -> str:
    """Run the story workflow on a text prompt and return the generated text.

    Sends a single-input ``PostWorkflowResults`` request for
    ``WORKFLOW_ID_STORY_GPT3`` and returns the raw text of the first output
    of the first workflow result. The function is wrapped in
    ``st.cache_data(persist=True)``.

    Args:
        image_caption_with_user_input: Text placed in the ``raw`` field of
            the Clarifai ``Text`` message that is sent as the workflow input.

    Returns:
        The ``data.text.raw`` value of the first output.

    Raises:
        Exception: If the response status is not ``SUCCESS``, or if the
            response carries no results/outputs to read the text from.
    """
    stub = service_pb2_grpc.V2Stub(ClarifaiChannel.get_grpc_channel())
    # The user/app identifier is required when authenticating with a PAT.
    user_app_id = resources_pb2.UserAppIDSet(user_id=USER_ID, app_id=APP_ID)
    auth_metadata = (("authorization", "Key " + PAT),)

    request = service_pb2.PostWorkflowResultsRequest(
        user_app_id=user_app_id,
        workflow_id=WORKFLOW_ID_STORY_GPT3,
        inputs=[
            resources_pb2.Input(
                data=resources_pb2.Data(
                    text=resources_pb2.Text(raw=image_caption_with_user_input),
                ),
            ),
        ],
    )
    response = stub.PostWorkflowResults(request, metadata=auth_metadata)

    if response.status.code != status_code_pb2.SUCCESS:
        print(response.status)
        raise Exception(
            "Post workflow results failed, status: " + response.status.description
        )

    # One input was sent, so one WorkflowResult is expected. Check before
    # indexing so an empty response fails with a clear message instead of a
    # bare IndexError.
    if not response.results or not response.results[0].outputs:
        raise Exception(
            "Post workflow results returned no outputs for workflow: "
            + str(WORKFLOW_ID_STORY_GPT3)
        )
    return response.results[0].outputs[0].data.text.raw