# annotation/call_assistant_api.py
# (Hugging Face Hub page residue converted to comments: uploaded by MudeHui,
#  commit "Add application file", hash 1fb65ae)
# install the lib of : https://github.com/pengwangucla/cv_utils
from vis_common import *
import vis_utils as v_uts
import json
import os
import time
import base64
import requests
from openai import OpenAI
from tenacity import retry, wait_random_exponential, stop_after_attempt, wait_fixed
from GPT_prompts import REWRITE_PROMPT_0
API_KEY = os.environ.get("BYTE_API_KEY")
class EditActionClassifier():
    """Classify an image-edit instruction into an "edit class" via an
    OpenAI Assistant, polling the run until it completes.

    NOTE(review): a single thread is created once and reused across
    infer() calls, so earlier messages stay in the assistant's context.
    """

    def __init__(self):
        # OpenAI() reads OPENAI_API_KEY from the environment.
        self.client = OpenAI()
        self.assistant_key = "asst_57vfLupV8VCsCZx0BJOppSnw"
        self.thread = self.client.beta.threads.create()

    @retry(wait=wait_fixed(10), stop=stop_after_attempt(3))
    def infer(self, edit_action):
        """Send ``edit_action`` to the assistant and return its edit class.

        Returns:
            The value of the ``"edit class"`` key when the reply is JSON
            containing that key, the raw reply text otherwise, and ``""``
            if the assistant run failed.
        """
        message = self.client.beta.threads.messages.create(
            thread_id=self.thread.id,
            role="user",
            content=edit_action
        )
        run = self.client.beta.threads.runs.create(
            thread_id=self.thread.id,
            assistant_id=self.assistant_key,
        )
        pbar = tqdm(total=100)
        while run.status != 'completed':
            run = self.client.beta.threads.runs.retrieve(
                thread_id=self.thread.id,
                run_id=run.id
            )
            time.sleep(.5)  # Sleep and check run status again
            pbar.update(1)
            pbar.set_description('Run Status: ' + run.status)
            if run.status == 'failed':
                break
        pbar.close()  # BUG FIX: progress bar was never closed
        if run.status == 'failed':
            print("Run failed")
            return ""
        messages = self.client.beta.threads.messages.list(
            thread_id=self.thread.id
        )
        result = messages.data[0].content[0].text.value
        # BUG FIX: the original tested the undefined name `results`
        # (NameError) and left `class_name` unbound whenever the marker
        # was absent. Default to the raw reply text instead.
        class_name = result
        if "edit class" in result:
            try:
                class_name = json.loads(result)["edit class"]
            except Exception:
                print(f"{result}, can not be load by json")
        return class_name
def test_personal_dalle3():
    """Generate an image with DALL-E 3 on a personal OpenAI account and
    save it to ./cute_cat_with_hat.jpg.
    """
    # Call the API (OPENAI_API_KEY must be set in the environment).
    client = OpenAI()
    response = client.images.generate(
        model="dall-e-3",
        prompt="a cute cat with a hat on",
        size="1792x1024",
        quality="standard",
        n=1,
    )
    # BUG FIX: the original immediately overwrote this freshly generated
    # URL with a hard-coded, long-expired Azure blob URL, discarding the
    # result of the paid API call.
    image_url = response.data[0].url
    # Download the image from the URL
    image_response = requests.get(image_url)
    # Check if the request was successful
    if image_response.status_code == 200:
        # Save the image to a file
        with open('cute_cat_with_hat.jpg', 'wb') as file:
            file.write(image_response.content)
    else:
        print("Failed to download the image.")
def test_call_gpt4_api():
    """Smoke-test the Azure-proxied GPT-4 chat endpoint with a prompt
    built from REWRITE_PROMPT_0.
    """
    from langchain_community.chat_models import AzureChatOpenAI
    from langchain.schema import HumanMessage

    BASE_URL = "https://search-us.byteintl.net/gpt/openapi/online/v2/crawl/"
    # DEAD-CODE FIX: the original assigned "gpt-4-0613" and immediately
    # shadowed it; only the preview deployment was ever used.
    DEPLOYMENT_NAME = "gpt-4-1106-preview"
    model = AzureChatOpenAI(
        openai_api_base=BASE_URL,
        openai_api_version="2023-03-15-preview",
        deployment_name=DEPLOYMENT_NAME,
        openai_api_key=API_KEY,  # BYTE_API_KEY from the environment
        openai_api_type="azure",
        temperature=0.5,
        max_tokens=512,
    )
    content = REWRITE_PROMPT_0.format(prompt1="Create a diptych image that consists two images. \
        The left image is front-view of lying real white 12 years old man. \
        The right image keep everything the same but change the background of the subject to europe.")
    generate_log = model([HumanMessage(content=content)]).content
    print(generate_log)
def test_call_gpt4v_api():
    """Smoke-test the Azure-proxied GPT-4 Vision endpoint by sending a
    local image as a base64 data URL.
    """
    from langchain_community.chat_models import AzureChatOpenAI
    from langchain.schema import HumanMessage

    BASE_URL = "https://search-us.byteintl.net/gpt/openapi/online/v2/crawl/"
    DEPLOYMENT_NAME = "openai_gpt-4-vision"  # "gptv" or "openai_gpt-4-vision"
    model = AzureChatOpenAI(
        openai_api_base=BASE_URL,
        openai_api_version="2023-07-01-preview",
        deployment_name=DEPLOYMENT_NAME,
        openai_api_key=API_KEY,  # BYTE_API_KEY from the environment
        openai_api_type="azure",
        temperature=0.5,
        max_tokens=512,
    )
    # DEAD-CODE FIX: the original first built input_ip from a public
    # Wikipedia URL and immediately overwrote it with the base64 payload;
    # only the local image was ever sent.
    image_path = "./imgs/dataset.jpg"
    base64_image = v_uts.encode_b64(image_path)
    input_ip = {
        "url": f"data:image/jpeg;base64,{base64_image}"
    }
    generate_log = model([HumanMessage(content=[
        {
            "type": "text",
            "text": "What’s in this image?"
        },
        {
            "type": "image_url",
            "image_url": input_ip
        }
    ])])
    print(generate_log)
# curl --location --request POST 'https://search.bytedance.net/gpt/openapi/online/v2/crawl?ak=<business-side AK>' \
# --header 'Content-Type: application/json' \
# --header 'X-TT-LOGID: <caller log ID, to help locate issues>' \
# --data-raw '{
#     "prompt": "A poster of Microsoft",  // text description of the image to generate
#     "size": "1024x1024",                // image size; only 1024x1024 / 1024x1792 / 1792x1024 are supported
#     "quality": "standard",              // image quality, defaults to "standard"
#     "style": "vivid",                   // image style, defaults to "vivid"
#     "n": 1,
#     "model": "dall-e-3"                 // model name, required
# }'
# // response
# {
#     "created": 1702889995,
#     "data": [
#         {
#             "url": "https://dalleprodsec.blob.core.windows.net/private/images/0811eacd-bf25-4961-814f-36d7f453907c/generated_00.png?se=2023-12-19T09%3A00%3A09Z&sig=cIRz7je1Qbjlt5GjeyLGKoxPRFggr7NAxLSeeCuGyYk%3D&ske=2023-12-22T11%3A18%3A13Z&skoid=e52d5ed7-0657-4f62-bc12-7e5dbb260a96&sks=b&skt=2023-12-15T11%3A18%3A13Z&sktid=33e01921-4d64-4f8c-a055-5bdaffd5e33d&skv=2020-10-02&sp=r&spr=https&sr=b&sv=2020-10-02",
#             "revised_prompt": "A designed poster featuring the logo of a prominent technology company, accompanied by various emboldened text denoting the company's name and a motivational slogan. The distinct, four rectangular logo in bright colors is situated at the center of the poster, against a plain background. The composition strikes a balance between minimalism and impact, typifying the company's powerful image in the global technology industry."
#         }
#     ]
# }
def test_call_dalle3_api():
    """Smoke-test the Azure-proxied DALL-E 3 endpoint and save the
    generated image locally.

    Requires openai==1.2.0, httpx==0.23.0.
    """
    from openai import AzureOpenAI

    BASE_URL = "https://search-va.byteintl.net/gpt/openapi/online/v2/crawl"
    DEPLOYMENT_NAME = "dall-e-3"
    # SECURITY FIX: the original committed a hard-coded API key here.
    # Read it from the environment instead (falls back to the
    # module-level BYTE_API_KEY). The leaked key should be rotated.
    api_key = os.environ.get("DALLE3_API_KEY", API_KEY)
    client = AzureOpenAI(
        api_version="2023-12-01-preview",
        api_key=api_key,
        azure_endpoint=BASE_URL)
    result = client.images.generate(
        model=DEPLOYMENT_NAME,  # the name of your DALL-E 3 deployment
        prompt="A soldier girl holding a USA flag",
        n=1,
        size="1024x1024",
        quality="standard",
        style="vivid"
    )
    image_url = result.data[0].url
    image_response = requests.get(image_url)
    # Check if the request was successful
    if image_response.status_code == 200:
        # BUG FIX: the original wrote to '.jpg' — an extension-only,
        # hidden filename; save under a meaningful name instead.
        with open('dalle3_result.jpg', 'wb') as file:
            file.write(image_response.content)
    else:
        print("Failed to download the image.")
if __name__ == "__main__":
# classifier = EditActionClassifier()
# class_name = classifier.infer("Remove the background of the image")
# print(class_name)
# test_personal_dalle3()
# test_call_gpt4_api()
# test_call_gpt4v_api()
test_call_dalle3_api()