enzostvs's picture
enzostvs HF staff
set instance prompt only at the generation
fba0083
raw
history blame
2.78 kB
/** @type {import('./$types').RequestHandler} */
import { json, type RequestEvent } from '@sveltejs/kit';
import { env } from '$env/dynamic/private'
import { env as publicEnv } from '$env/dynamic/public';
import { promises } from 'fs';
import { randomUUID } from 'crypto';
import { tokenIsAvailable } from '$lib/utils';
import prisma from '$lib/prisma';
/**
 * Generates an image for a model through the inference API.
 *
 * Expects a JSON body containing at least `model.id` and a non-empty string
 * `inputs`. The model's `instance_prompt` (when one is stored) is prepended to
 * `inputs` only for the inference request; the prompt saved in the gallery is
 * the caller's original `inputs`.
 *
 * When the `hf_access_token` cookie resolves to a user, the image is also
 * written to `PUBLIC_FILE_UPLOAD_DIR` and a private gallery entry is created.
 * That persistence step is best-effort: a failure is logged and the image is
 * still returned, with `gallery` left undefined.
 *
 * @returns JSON `{ image, gallery }` where `image` is a base64 PNG data URL,
 *   or `{ error: { token: message } }` with status 400 on invalid input or
 *   when the inference request fails.
 */
export async function POST({ request, cookies } : RequestEvent) {
  const token = cookies.get('hf_access_token')

  // A malformed JSON body is a client error: fall through to the 400 checks
  // below instead of letting the parse error surface as a 500.
  const generation = await request.json().catch(() => null)

  if (!generation?.model?.id) {
    return json({
      error: {
        token: "A model id is required"
      }
    }, { status: 400 })
  }

  // `inputs` is used as text everywhere below (prompt, file name, DB column),
  // so anything other than a non-empty string is rejected up front.
  if (typeof generation.inputs !== 'string' || !generation.inputs) {
    return json({
      error: {
        token: "An inputs is required"
      }
    }, { status: 400 })
  }

  const model = await prisma.model.findFirst({
    where: {
      id: generation.model.id
    },
    select: {
      instance_prompt: true,
    }
  })

  // Skip the separator when the model has no instance prompt so the request
  // does not start with a stray space.
  const prompt = [model?.instance_prompt, generation.inputs].filter(Boolean).join(" ")

  let image: Buffer
  try {
    const response = await fetch(env.SECRET_INFERENCE_API_URL + "/models/" + generation.model.id, {
      method: "POST",
      headers: {
        Authorization: `Bearer ${env.SECRET_HF_TOKEN}`,
        'Content-Type': 'application/json',
        'x-use-cache': "0"
      },
      body: JSON.stringify({
        ...generation,
        inputs: prompt,
      }),
    })

    // A non-2xx body is an error payload, not image bytes: never save it or
    // hand it back to the client as a PNG.
    if (!response.ok) {
      return json({
        error: {
          token: await readInferenceError(response)
        }
      }, { status: 400 })
    }

    image = Buffer.from(await response.arrayBuffer())
  } catch (error) {
    return json({
      error: {
        token: error instanceof Error ? error.message : String(error)
      }
    }, { status: 400 })
  }

  let gallery
  if (token) {
    const user = await tokenIsAvailable(token)
    if (user?.sub) {
      try {
        // `recursive: true` makes this a no-op when the directory already
        // exists, so no directory handle has to be opened (and closed) just
        // to probe for it.
        await promises.mkdir(publicEnv.PUBLIC_FILE_UPLOAD_DIR, { recursive: true })

        const file_name_formatted = buildImageFileName(generation.inputs)
        await promises.writeFile(`${publicEnv.PUBLIC_FILE_UPLOAD_DIR}/${file_name_formatted}`, image)

        gallery = await prisma.gallery.create({
          data: {
            image: file_name_formatted,
            prompt: generation.inputs,
            isPublic: false,
            user: {
              connect: {
                sub: user.sub
              }
            },
            model: {
              connect: {
                id: generation.model.id
              }
            },
          }
        })
      } catch (error) {
        // Saving to the gallery must not cost the caller the generated image.
        console.error(error)
      }
    }
  }

  return json({
    image: "data:image/png;base64," + image.toString('base64'),
    gallery
  })
}

/** Extracts a readable message from a failed inference API response. */
async function readInferenceError(response: Response): Promise<string> {
  try {
    const body = await response.json()
    if (typeof body?.error === 'string' && body.error) {
      return body.error
    }
  } catch {
    // Body was not JSON; fall back to the status-based message below.
  }
  return `Inference API request failed with status ${response.status}`
}

/** Builds a unique, filesystem-safe PNG file name derived from the prompt. */
function buildImageFileName(inputs: string): string {
  // UUID (36) + "_" + ".png" leaves room for ~200 characters before hitting
  // the 255-byte file-name limit common to most filesystems; the slug is
  // ASCII-only after the replacement, so characters equal bytes.
  const maxSlugLength = 200
  const slug = inputs.replace(/[^a-zA-Z0-9]/g, "-").slice(0, maxSlugLength)
  return randomUUID() + "_" + slug + ".png"
}