/*
 * web-stable-diffusion / dist / stable_diffusion.js
 * Uploaded by silait ("Upload 40 files"), commit 7171c5f (verified), 19.2 kB.
 * (Residue from the hosting site's file viewer: "raw", "history", "blame".)
 */
/**
 * Wrapper to handle PNDM scheduler
 *
 * Keeps per-step scalar constants on the device, a rolling window of the last
 * four model outputs (`ets`), and the five compiled `pndm_scheduler_step_{0..4}`
 * functions. All members are detached from the tvm scope so they outlive the
 * scope they were created in; call `dispose()` to release them.
 */
class TVMPNDMScheduler {
  /**
   * @param schedulerConsts Object with `timesteps`, `sample_coeff`,
   *   `alpha_diff` and `model_output_denom_coeff` arrays (one entry per step).
   * @param latentShape Shape used to allocate the four `ets` history buffers.
   * @param tvm The tvmjs instance.
   * @param device Device on which constants and buffers are allocated.
   * @param vm Virtual machine providing the compiled step functions.
   */
  constructor(schedulerConsts, latentShape, tvm, device, vm) {
    this.timestep = [];
    this.sampleCoeff = [];
    this.alphaDiff = [];
    this.modelOutputDenomCoeff = [];
    this.ets = [];
    this.schedulerFunc = [];
    this.currSample = undefined;
    this.tvm = tvm;
    // prebuild constants
    // principle: always detach for class members
    // to avoid recycling output scope.
    function loadConsts(output, dtype, input) {
      for (let t = 0; t < input.length; ++t) {
        output.push(
          tvm.detachFromCurrentScope(
            tvm.empty([], dtype, device).copyFrom([input[t]])
          )
        );
      }
    }
    loadConsts(this.timestep, "int32", schedulerConsts["timesteps"]);
    loadConsts(this.sampleCoeff, "float32", schedulerConsts["sample_coeff"]);
    loadConsts(this.alphaDiff, "float32", schedulerConsts["alpha_diff"]);
    loadConsts(
      this.modelOutputDenomCoeff, "float32",
      schedulerConsts["model_output_denom_coeff"]);
    for (let i = 0; i < 4; ++i) {
      this.ets.push(
        this.tvm.detachFromCurrentScope(
          this.tvm.empty(latentShape, "float32", device)
        )
      );
    }
    for (let i = 0; i < 5; ++i) {
      this.schedulerFunc.push(
        tvm.detachFromCurrentScope(
          vm.getFunction("pndm_scheduler_step_" + i.toString())
        )
      );
    }
  }

  /**
   * Release every detached member: constants, step functions, the stashed
   * sample (if any) and the `ets` history buffers.
   */
  dispose() {
    for (let t = 0; t < this.timestep.length; ++t) {
      this.timestep[t].dispose();
      this.sampleCoeff[t].dispose();
      this.alphaDiff[t].dispose();
      this.modelOutputDenomCoeff[t].dispose();
    }
    for (let i = 0; i < this.schedulerFunc.length; ++i) {
      this.schedulerFunc[i].dispose();
    }
    if (this.currSample) {
      this.currSample.dispose();
    }
    for (let i = 0; i < this.ets.length; ++i) {
      this.ets[i].dispose();
    }
  }

  /**
   * Advance the scheduler by one step.
   *
   * @param modelOutput Model output for this step. Unless `counter` is 1 it is
   *   detached and appended to the `ets` history.
   * @param sample Current latents. At `counter` 0 they are stashed and reused
   *   as the sample for `counter` 1.
   * @param counter Zero-based step index. It selects the per-step constants and,
   *   capped at 4, which `pndm_scheduler_step_*` function runs.
   * @returns The value returned by the selected step function.
   */
  step(modelOutput, sample, counter) {
    // keep running history of last four inputs
    if (counter !== 1) {
      // The evicted entry is detached (owned by this scheduler), so no scope
      // will ever free it; release it explicitly to avoid a per-step leak.
      this.ets.shift().dispose();
      this.ets.push(this.tvm.detachFromCurrentScope(modelOutput));
    }
    if (counter === 0) {
      this.currSample = this.tvm.detachFromCurrentScope(sample);
    } else if (counter === 1) {
      // Hand the stashed sample back to the current scope so it is recycled
      // when that scope ends.
      sample = this.tvm.attachToCurrentScope(this.currSample);
      this.currSample = undefined;
    }
    const findex = counter < 4 ? counter : 4;
    return this.schedulerFunc[findex](
      sample,
      modelOutput,
      this.sampleCoeff[counter],
      this.alphaDiff[counter],
      this.modelOutputDenomCoeff[counter],
      this.ets[0],
      this.ets[1],
      this.ets[2],
      this.ets[3]
    );
  }
}
/**
 * Wrapper to handle multistep DPM-solver scheduler
 *
 * Keeps per-step scalar constants on the device, the previous converted model
 * output, and the two compiled scheduler functions. All members are detached
 * from the tvm scope; call `dispose()` to release them.
 */
class TVMDPMSolverMultistepScheduler {
  /**
   * @param schedulerConsts Object with `timesteps`, `alpha`, `sigma`, `c0`,
   *   `c1` and `c2` arrays (one entry per step).
   * @param latentShape Shape used to allocate the initial `lastModelOutput`.
   * @param tvm The tvmjs instance.
   * @param device Device on which constants and buffers are allocated.
   * @param vm Virtual machine providing the compiled scheduler functions.
   */
  constructor(schedulerConsts, latentShape, tvm, device, vm) {
    this.timestep = [];
    this.alpha = [];
    this.sigma = [];
    this.c0 = [];
    this.c1 = [];
    this.c2 = [];
    this.tvm = tvm;
    // prebuild constants
    // principle: always detach for class members
    // to avoid recycling output scope.
    function loadConsts(output, dtype, input) {
      for (let t = 0; t < input.length; ++t) {
        output.push(
          tvm.detachFromCurrentScope(
            tvm.empty([], dtype, device).copyFrom([input[t]])
          )
        );
      }
    }
    loadConsts(this.timestep, "int32", schedulerConsts["timesteps"]);
    loadConsts(this.alpha, "float32", schedulerConsts["alpha"]);
    loadConsts(this.sigma, "float32", schedulerConsts["sigma"]);
    loadConsts(this.c0, "float32", schedulerConsts["c0"]);
    loadConsts(this.c1, "float32", schedulerConsts["c1"]);
    loadConsts(this.c2, "float32", schedulerConsts["c2"]);
    // Placeholder history for the first step; its contents are uninitialized.
    this.lastModelOutput = this.tvm.detachFromCurrentScope(
      this.tvm.empty(latentShape, "float32", device)
    );
    this.convertModelOutputFunc = tvm.detachFromCurrentScope(
      vm.getFunction("dpm_solver_multistep_scheduler_convert_model_output")
    );
    this.stepFunc = tvm.detachFromCurrentScope(
      vm.getFunction("dpm_solver_multistep_scheduler_step")
    );
  }

  /** Release every detached member (constants, history buffer, functions). */
  dispose() {
    for (let t = 0; t < this.timestep.length; ++t) {
      this.timestep[t].dispose();
      this.alpha[t].dispose();
      this.sigma[t].dispose();
      this.c0[t].dispose();
      this.c1[t].dispose();
      this.c2[t].dispose();
    }
    this.lastModelOutput.dispose();
    this.convertModelOutputFunc.dispose();
    this.stepFunc.dispose();
  }

  /**
   * Advance the scheduler by one step.
   *
   * @param modelOutput Raw model output for this step. It is converted with
   *   the step's `alpha`/`sigma`, and the converted value becomes the history
   *   for the next step.
   * @param sample Current latents.
   * @param counter Zero-based step index selecting the per-step constants.
   * @returns The value returned by the compiled step function.
   */
  step(modelOutput, sample, counter) {
    const converted = this.convertModelOutputFunc(
      sample, modelOutput, this.alpha[counter], this.sigma[counter]);
    const prevLatents = this.stepFunc(
      sample,
      converted,
      this.lastModelOutput,
      this.c0[counter],
      this.c1[counter],
      this.c2[counter]
    );
    // The old history entry is detached (owned by this scheduler), so no scope
    // frees it; release it before replacing it to avoid a per-step leak.
    this.lastModelOutput.dispose();
    this.lastModelOutput = this.tvm.detachFromCurrentScope(converted);
    return prevLatents;
  }
}
class StableDiffusionPipeline {
  /**
   * Bind the canvas, create the VM and look up the compiled functions and
   * cached parameters used by the pipeline.
   *
   * @param tvm The tvmjs instance (not owned by this class).
   * @param tokenizer Tokenizer exposing `encode(prompt, true).input_ids`.
   * @param schedulerConsts Scheduler constant tables. Index 0 is used for the
   *   multi-step DPM solver and index 1 for PNDM.
   * @param cacheMetadata Metadata providing `clipParamSize`, `unetParamSize`
   *   and `vaeParamSize`.
   * @throws Error("Expect cacheMetadata") when `cacheMetadata` is missing.
   */
  constructor(tvm, tokenizer, schedulerConsts, cacheMetadata) {
    if (cacheMetadata == null) {
      throw new Error("Expect cacheMetadata");
    }
    this.tvm = tvm;
    this.tokenizer = tokenizer;
    this.maxTokenLength = 77;
    this.device = this.tvm.webgpu();
    this.tvm.bindCanvas(document.getElementById("canvas"));
    // VM functions
    this.vm = this.tvm.detachFromCurrentScope(
      this.tvm.createVirtualMachine(this.device)
    );
    this.schedulerConsts = schedulerConsts;
    this.clipToTextEmbeddings = this.tvm.detachFromCurrentScope(
      this.vm.getFunction("clip")
    );
    this.clipParams = this.tvm.detachFromCurrentScope(
      this.tvm.getParamsFromCache("clip", cacheMetadata.clipParamSize)
    );
    this.unetLatentsToNoisePred = this.tvm.detachFromCurrentScope(
      this.vm.getFunction("unet")
    );
    this.unetParams = this.tvm.detachFromCurrentScope(
      this.tvm.getParamsFromCache("unet", cacheMetadata.unetParamSize)
    );
    this.vaeToImage = this.tvm.detachFromCurrentScope(
      this.vm.getFunction("vae")
    );
    this.vaeParams = this.tvm.detachFromCurrentScope(
      this.tvm.getParamsFromCache("vae", cacheMetadata.vaeParamSize)
    );
    this.imageToRGBA = this.tvm.detachFromCurrentScope(
      this.vm.getFunction("image_to_rgba")
    );
    this.concatEmbeddings = this.tvm.detachFromCurrentScope(
      this.vm.getFunction("concat_embeddings")
    );
  }

  /** Release all detached functions, parameters and the VM. */
  dispose() {
    // note: tvm instance is not owned by this class
    this.concatEmbeddings.dispose();
    this.imageToRGBA.dispose();
    this.vaeParams.dispose();
    this.vaeToImage.dispose();
    this.unetParams.dispose();
    this.unetLatentsToNoisePred.dispose();
    this.clipParams.dispose();
    this.clipToTextEmbeddings.dispose();
    this.vm.dispose();
  }

  /**
   * Tokenize the prompt to TVMNDArray.
   *
   * Shorter encodings are padded with their last token. Longer ones are
   * truncated to `maxTokenLength`.
   *
   * @param prompt Input prompt
   * @returns The text id NDArray of shape [1, maxTokenLength] (int32), created
   *   in the current scope.
   */
  tokenize(prompt) {
    const encoded = this.tokenizer.encode(prompt, true).input_ids;
    const inputIDs = new Int32Array(this.maxTokenLength);
    if (encoded.length < this.maxTokenLength) {
      inputIDs.set(encoded);
      const lastTok = encoded[encoded.length - 1];
      inputIDs.fill(lastTok, encoded.length, inputIDs.length);
    } else {
      inputIDs.set(encoded.slice(0, this.maxTokenLength));
    }
    return this.tvm.empty([1, this.maxTokenLength], "int32", this.device).copyFrom(inputIDs);
  }

  /**
   * async preload webgpu pipelines when possible.
   */
  async asyncLoadWebGPUPiplines() {
    await this.tvm.asyncLoadWebGPUPiplines(this.vm.getInternalModule());
  }

  /**
   * Run generation pipeline.
   *
   * @param prompt Input prompt.
   * @param negPrompt Input negative prompt.
   * @param progressCallback Callback to check progress, invoked as
   *   (stage, counter, numSteps, totalNumSteps).
   * @param schedulerId The ID of the scheduler to use (number or numeric string).
   *   - 0 for multi-step DPM solver,
   *   - any other value (e.g. 1) for PNDM solver.
   * @param vaeCycle Optionally draw the VAE result every `vaeCycle` iterations;
   *   -1 (number or numeric string) disables intermediate drawing.
   * @param beginRenderVae Begin rendering VAE after skipping these warmup runs.
   */
  async generate(
    prompt,
    negPrompt = "",
    progressCallback = undefined,
    schedulerId = 0,
    vaeCycle = -1,
    beginRenderVae = 10
  ) {
    // Both options may arrive as strings (e.g. read from DOM inputs); normalize
    // them once instead of relying on loose equality further down.
    const useDPMSolver = Number(schedulerId) === 0;
    const requestedVaeCycle = Number(vaeCycle);
    // Principle: beginScope/endScope in synchronized blocks,
    // this helps to recycle intermediate memories
    // detach states that needs to go across async boundaries.
    //--------------------------
    // Stage 0: CLIP
    //--------------------------
    this.tvm.beginScope();
    // get latents
    const latentShape = [1, 4, 64, 64];
    // Declared locally: assigning an undeclared name would create an implicit
    // global (and throw a ReferenceError in strict mode).
    let scheduler;
    let unetNumSteps;
    if (useDPMSolver) {
      scheduler = new TVMDPMSolverMultistepScheduler(
        this.schedulerConsts[0], latentShape, this.tvm, this.device, this.vm);
      unetNumSteps = this.schedulerConsts[0]["num_steps"];
    } else {
      scheduler = new TVMPNDMScheduler(
        this.schedulerConsts[1], latentShape, this.tvm, this.device, this.vm);
      unetNumSteps = this.schedulerConsts[1]["num_steps"];
    }
    const totalNumSteps = unetNumSteps + 2;
    if (progressCallback !== undefined) {
      progressCallback("clip", 0, 1, totalNumSteps);
    }
    const embeddings = this.tvm.withNewScope(() => {
      const posInputIDs = this.tokenize(prompt);
      const negInputIDs = this.tokenize(negPrompt);
      const posEmbeddings = this.clipToTextEmbeddings(
        posInputIDs, this.clipParams);
      const negEmbeddings = this.clipToTextEmbeddings(
        negInputIDs, this.clipParams);
      // maintain new latents
      return this.tvm.detachFromCurrentScope(
        this.concatEmbeddings(negEmbeddings, posEmbeddings)
      );
    });
    // use uniform distribution with same variance as normal(0, 1)
    const scale = Math.sqrt(12) / 2;
    let latents = this.tvm.detachFromCurrentScope(
      this.tvm.uniform(latentShape, -scale, scale, this.tvm.webgpu())
    );
    this.tvm.endScope();
    //---------------------------
    // Stage 1: UNet + Scheduler
    //---------------------------
    if (requestedVaeCycle !== -1) {
      // show first frame
      this.tvm.withNewScope(() => {
        const image = this.vaeToImage(latents, this.vaeParams);
        this.tvm.showImage(this.imageToRGBA(image));
      });
      await this.device.sync();
    }
    const renderCycle = requestedVaeCycle === -1 ? unetNumSteps : requestedVaeCycle;
    let lastSync = undefined;
    for (let counter = 0; counter < unetNumSteps; ++counter) {
      if (progressCallback !== undefined) {
        progressCallback("unet", counter, unetNumSteps, totalNumSteps);
      }
      const timestep = scheduler.timestep[counter];
      // recycle noisePred, track latents manually
      const newLatents = this.tvm.withNewScope(() => {
        // Attaching the previous latents lets the scope free them on exit.
        this.tvm.attachToCurrentScope(latents);
        const noisePred = this.unetLatentsToNoisePred(
          latents, timestep, embeddings, this.unetParams);
        // maintain new latents
        return this.tvm.detachFromCurrentScope(
          scheduler.step(noisePred, latents, counter)
        );
      });
      latents = newLatents;
      // use skip one sync, although likely not as useful.
      if (lastSync !== undefined) {
        await lastSync;
      }
      // async event checker
      lastSync = this.device.sync();
      // Optionally, we can draw intermediate result of VAE.
      if ((counter + 1) % renderCycle === 0 &&
        (counter + 1) !== unetNumSteps &&
        counter >= beginRenderVae) {
        this.tvm.withNewScope(() => {
          const image = this.vaeToImage(latents, this.vaeParams);
          this.tvm.showImage(this.imageToRGBA(image));
        });
        await this.device.sync();
      }
    }
    scheduler.dispose();
    embeddings.dispose();
    //-----------------------------
    // Stage 2: VAE and draw image
    //-----------------------------
    if (progressCallback !== undefined) {
      progressCallback("vae", 0, 1, totalNumSteps);
    }
    this.tvm.withNewScope(() => {
      const image = this.vaeToImage(latents, this.vaeParams);
      this.tvm.showImage(this.imageToRGBA(image));
    });
    latents.dispose();
    await this.device.sync();
    if (progressCallback !== undefined) {
      progressCallback("vae", 1, 1, totalNumSteps);
    }
  }

  /** Clear the bound canvas. */
  clearCanvas() {
    this.tvm.clearCanvas();
  }
}
/**
 * An instance that can be used to facilitate deployment.
 *
 * Lazily loads the config, the tvm runtime and the StableDiffusionPipeline,
 * and wires them to the page's DOM controls.
 */
class StableDiffusionInstance {
  constructor() {
    this.tvm = undefined;
    this.pipeline = undefined;
    this.config = undefined;
    // Not read anywhere in this class; kept so external code that inspects it
    // keeps working.
    this.generateInProgress = false;
    // Guard flag used by generate() to ignore overlapping requests.
    this.requestInProgress = false;
    this.logger = console.log;
  }

  /**
   * Initialize TVM
   *
   * @param wasmUrl URL to wasm source.
   * @param cacheUrl URL to NDArray cache (resolved against document.URL when
   *   it does not start with "http").
   * @throws Error when WebGPU is unavailable or fails to initialize; the
   *   instance is reset first.
   */
  async #asyncInitTVM(wasmUrl, cacheUrl) {
    if (this.tvm !== undefined) {
      return;
    }
    // getElementById returns null (never undefined) for a missing element, so
    // only mirror log messages into the page when #log actually exists.
    if (document.getElementById("log") !== null) {
      this.logger = function (message) {
        console.log(message);
        const d = document.createElement("div");
        d.innerHTML = message;
        document.getElementById("log").appendChild(d);
      };
    }
    const wasmSource = await (await fetch(wasmUrl)).arrayBuffer();
    const tvm = await tvmjs.instantiate(
      new Uint8Array(wasmSource),
      new EmccWASI(),
      this.logger
    );
    // initialize WebGPU
    let output;
    try {
      output = await tvmjs.detectGPUDevice();
      if (output !== undefined) {
        let label = "WebGPU";
        if (output.adapterInfo.description.length !== 0) {
          label += " - " + output.adapterInfo.description;
        } else {
          label += " - " + output.adapterInfo.vendor;
        }
        document.getElementById(
          "gpu-tracker-label").innerHTML = ("Initialize GPU device: " + label);
        tvm.initWebGPU(output.device);
      }
    } catch (err) {
      document.getElementById("gpu-tracker-label").innerHTML = (
        "Find an error initializing the WebGPU device " + err.toString()
      );
      console.log(err.stack);
      this.reset();
      throw new Error("Find an error initializing WebGPU: " + err.toString());
    }
    // Handled outside the try block so this specific message is not
    // overwritten by the generic handler above.
    if (output === undefined) {
      document.getElementById(
        "gpu-tracker-label").innerHTML = "This browser env do not support WebGPU";
      this.reset();
      throw new Error("This browser env do not support WebGPU");
    }
    this.tvm = tvm;
    function initProgressCallback(report) {
      document.getElementById("progress-tracker-label").innerHTML = report.text;
      document.getElementById("progress-tracker-progress").value = report.progress * 100;
    }
    tvm.registerInitProgressCallback(initProgressCallback);
    const resolvedCacheUrl = cacheUrl.startsWith("http")
      ? cacheUrl
      : new URL(cacheUrl, document.URL).href;
    await tvm.fetchNDArrayCache(resolvedCacheUrl, tvm.webgpu());
  }

  /**
   * Initialize the pipeline
   *
   * @param schedulerConstUrl URLs of the scheduler constant JSON files, in the
   *   order expected by StableDiffusionPipeline.
   * @param tokenizerName The name of the tokenizer.
   * @throws Error when #asyncInitTVM has not completed.
   */
  async #asyncInitPipeline(schedulerConstUrl, tokenizerName) {
    if (this.tvm == null) {
      throw new Error("asyncInitTVM is not called");
    }
    if (this.pipeline !== undefined) return;
    // Fetch in parallel; Promise.all preserves the URL order in the result.
    const schedulerConst = await Promise.all(
      Array.from(schedulerConstUrl, async (url) => (await fetch(url)).json())
    );
    const tokenizer = await tvmjsGlobalEnv.getTokenizer(tokenizerName);
    this.pipeline = this.tvm.withNewScope(() => {
      return new StableDiffusionPipeline(this.tvm, tokenizer, schedulerConst, this.tvm.cacheMetadata);
    });
    await this.pipeline.asyncLoadWebGPUPiplines();
  }

  /**
   * Async initialize config (fetched once from stable-diffusion-config.json).
   */
  async #asyncInitConfig() {
    if (this.config !== undefined) return;
    this.config = await (await fetch("stable-diffusion-config.json")).json();
  }

  /**
   * Function to create progress callback tracker.
   * @returns A progress callback tracker that writes to the progress DOM elements.
   */
  #getProgressCallback() {
    const tstart = performance.now();
    function progressCallback(stage, counter, numSteps, totalNumSteps) {
      const timeElapsed = (performance.now() - tstart) / 1000;
      let text = "Generating ... at stage " + stage;
      if (stage == "unet") {
        counter += 1;
        text += " step [" + counter + "/" + numSteps + "]";
      }
      if (stage == "vae") {
        counter = totalNumSteps;
      }
      text += ", " + Math.ceil(timeElapsed) + " secs elapsed.";
      document.getElementById("progress-tracker-label").innerHTML = text;
      document.getElementById("progress-tracker-progress").value = (counter / totalNumSteps) * 100;
    }
    return progressCallback;
  }

  /**
   * Async initialize instance (config, then TVM, then pipeline).
   */
  async asyncInit() {
    if (this.pipeline !== undefined) return;
    await this.#asyncInitConfig();
    await this.#asyncInitTVM(this.config.wasmUrl, this.config.cacheUrl);
    await this.#asyncInitPipeline(this.config.schedulerConstUrl, this.config.tokenizer);
  }

  /**
   * Async initialize
   *
   * Registers the "generate", "clearCanvas" and "showImage" async server
   * functions on the given tvm instance.
   *
   * @param tvmInstance The tvm instance.
   * @throws Error when this instance already holds a tvm instance.
   */
  async asyncInitOnRPCServerLoad(tvmInstance) {
    if (this.tvm !== undefined) {
      throw new Error("Cannot reuse a loaded instance for rpc");
    }
    this.tvm = tvmInstance;
    this.tvm.beginScope();
    // NOTE(review): this path never assigns `this.pipeline`; the "generate"
    // handler assumes it has been set elsewhere — confirm.
    this.tvm.registerAsyncServerFunc("generate", async (prompt, schedulerId, vaeCycle) => {
      document.getElementById("inputPrompt").value = prompt;
      const negPrompt = "";
      document.getElementById("negativePrompt").value = "";
      await this.pipeline.generate(prompt, negPrompt, this.#getProgressCallback(), schedulerId, vaeCycle);
    });
    this.tvm.registerAsyncServerFunc("clearCanvas", async () => {
      this.tvm.clearCanvas();
    });
    this.tvm.registerAsyncServerFunc("showImage", async (data) => {
      this.tvm.showImage(data);
    });
    this.tvm.endScope();
  }

  /**
   * Run generate using the values of the page's input elements.
   *
   * Overlapping requests are ignored. Errors are logged and reset the instance.
   */
  async generate() {
    if (this.requestInProgress) {
      this.logger("Request in progress, generate request ignored");
      return;
    }
    this.requestInProgress = true;
    try {
      await this.asyncInit();
      const prompt = document.getElementById("inputPrompt").value;
      const negPrompt = document.getElementById("negativePrompt").value;
      const schedulerId = document.getElementById("schedulerId").value;
      const vaeCycle = document.getElementById("vaeCycle").value;
      await this.pipeline.generate(prompt, negPrompt, this.#getProgressCallback(), schedulerId, vaeCycle);
    } catch (err) {
      this.logger("Generate error, " + err.toString());
      console.log(err.stack);
      this.reset();
    } finally {
      // Always clear the guard, even if the error handler itself throws.
      this.requestInProgress = false;
    }
  }

  /**
   * Reset the instance: drop the tvm reference and dispose the pipeline.
   */
  reset() {
    this.tvm = undefined;
    if (this.pipeline !== undefined) {
      this.pipeline.dispose();
    }
    this.pipeline = undefined;
  }
}
// Expose the page-level instance explicitly on globalThis instead of through an
// undeclared assignment, which would throw a ReferenceError in strict mode.
globalThis.localStableDiffusionInst = new StableDiffusionInstance();

// Entry point for the page's "generate" action.
tvmjsGlobalEnv.asyncOnGenerate = async function () {
  await globalThis.localStableDiffusionInst.generate();
};

// Entry point for RPC mode: a fresh instance is created for each server load.
tvmjsGlobalEnv.asyncOnRPCServerLoad = async function (tvm) {
  const inst = new StableDiffusionInstance();
  await inst.asyncInitOnRPCServerLoad(tvm);
};