Update the Imaginate image generation API (#836)
* Update the Imaginate feature server API * Change connection status strategy to prevent console errors * Possible CORS fix? Maybe revert. * Update to the final API and fix bugs
This commit is contained in:
parent
9d56e86203
commit
5be59f7fce
|
|
@ -615,7 +615,7 @@ fn node_section_imaginate(imaginate_layer: &ImaginateLayer, layer: &Layer, persi
|
|||
}))],
|
||||
ImaginateStatus::Generating => vec![WidgetHolder::new(Widget::TextButton(TextButton {
|
||||
label: "Terminate".into(),
|
||||
tooltip: "Cancel in-progress image generation and keep the latest progress".into(),
|
||||
tooltip: "Cancel the in-progress image generation and keep the latest progress".into(),
|
||||
on_update: WidgetCallback::new(|_| DocumentMessage::ImaginateTerminate.into()),
|
||||
..Default::default()
|
||||
}))],
|
||||
|
|
@ -745,7 +745,7 @@ fn node_section_imaginate(imaginate_layer: &ImaginateLayer, layer: &Layer, persi
|
|||
},
|
||||
LayoutGroup::Row {
|
||||
widgets: {
|
||||
let tooltip = "Number of iterations to improve the image generation quality, with diminishing returns around 40".to_string();
|
||||
let tooltip = "Number of iterations to improve the image generation quality, with diminishing returns around 40 when using the Euler A sampling method".to_string();
|
||||
vec![
|
||||
WidgetHolder::new(Widget::TextLabel(TextLabel {
|
||||
value: "Sampling Steps".into(),
|
||||
|
|
@ -774,13 +774,7 @@ fn node_section_imaginate(imaginate_layer: &ImaginateLayer, layer: &Layer, persi
|
|||
},
|
||||
LayoutGroup::Row {
|
||||
widgets: {
|
||||
let tooltip = "
|
||||
Algorithm used to generate the image during each sampling step.\n\
|
||||
\n\
|
||||
'DPM Fast' and 'DPM Adaptive' do not support live refreshing updates.
|
||||
"
|
||||
.trim()
|
||||
.to_string();
|
||||
let tooltip = "Algorithm used to generate the image during each sampling step".to_string();
|
||||
|
||||
let sampling_methods = ImaginateSamplingMethod::list();
|
||||
let mut entries = Vec::with_capacity(sampling_methods.len());
|
||||
|
|
|
|||
|
|
@ -584,12 +584,14 @@ export default defineComponent({
|
|||
},
|
||||
async activateEyedropperSample() {
|
||||
// TODO: Replace this temporary solution that only works in Chromium-based browsers with the custom color sampler used by the Eyedropper tool
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
if (!(window as any).EyeDropper) {
|
||||
this.editor.instance.eyedropperSampleForColorPicker();
|
||||
return;
|
||||
}
|
||||
|
||||
try {
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
const result = await new (window as any).EyeDropper().open();
|
||||
this.setColorCode(result.sRGBHex);
|
||||
} catch {
|
||||
|
|
|
|||
|
|
@ -1,14 +1,15 @@
|
|||
import { escapeJSON } from "@/utility-functions/escape";
|
||||
/* eslint-disable camelcase */
|
||||
|
||||
// import { escapeJSON } from "@/utility-functions/escape";
|
||||
import { blobToBase64 } from "@/utility-functions/files";
|
||||
import { type RequestResult, requestWithUploadDownloadProgress } from "@/utility-functions/network";
|
||||
import { stripIndents } from "@/utility-functions/strip-indents";
|
||||
import { type Editor } from "@/wasm-communication/editor";
|
||||
import type { XY } from "@/wasm-communication/messages";
|
||||
import { type ImaginateGenerationParameters } from "@/wasm-communication/messages";
|
||||
|
||||
const MAX_POLLING_RETRIES = 4;
|
||||
const SERVER_STATUS_CHECK_TIMEOUT = 5000;
|
||||
const SAMPLING_MODES_POLLING_UNSUPPORTED = ["DPM fast", "DPM adaptive"];
|
||||
const PROGRESS_EVERY_N_STEPS = 5;
|
||||
|
||||
let timer: NodeJS.Timeout | undefined;
|
||||
let terminated = false;
|
||||
|
|
@ -50,7 +51,7 @@ export async function imaginateGenerate(
|
|||
|
||||
// Begin polling for updates to the in-progress image generation at the specified interval
|
||||
// Don't poll if the chosen interval is 0, or if the chosen sampling method does not support polling
|
||||
if (refreshFrequency > 0 && !SAMPLING_MODES_POLLING_UNSUPPORTED.includes(parameters.samplingMethod)) {
|
||||
if (refreshFrequency > 0) {
|
||||
const interval = Math.max(refreshFrequency * 1000, 500);
|
||||
scheduleNextPollingUpdate(interval, Date.now(), 0, editor, hostname, documentId, layerPath, parameters.resolution);
|
||||
}
|
||||
|
|
@ -62,9 +63,9 @@ export async function imaginateGenerate(
|
|||
}
|
||||
|
||||
// Extract the final image from the response and convert it to a data blob
|
||||
// Highly unstable API
|
||||
const base64 = JSON.parse(body)?.data[0]?.[0] as string | undefined;
|
||||
if (typeof base64 !== "string" || !base64.startsWith("data:image/png;base64,")) throw new Error("Could not read final image result from server response");
|
||||
const base64Data = JSON.parse(body)?.images?.[0] as string | undefined;
|
||||
const base64 = typeof base64Data === "string" && base64Data.length > 0 ? `data:image/png;base64,${base64Data}` : undefined;
|
||||
if (!base64) throw new Error("Could not read final image result from server response");
|
||||
const blob = await (await fetch(base64)).blob();
|
||||
|
||||
// Send the backend an updated status
|
||||
|
|
@ -156,9 +157,11 @@ function scheduleNextPollingUpdate(
|
|||
|
||||
try {
|
||||
const [blob, percentComplete] = await pollImage(hostname);
|
||||
|
||||
// After waiting for the polling result back from the server, if during that intervening time the user has terminated the generation, exit so we don't overwrite that terminated status
|
||||
if (terminated) return;
|
||||
|
||||
preloadAndSetImaginateBlobURL(editor, blob, documentId, layerPath, resolution.x, resolution.y);
|
||||
if (blob) preloadAndSetImaginateBlobURL(editor, blob, documentId, layerPath, resolution.x, resolution.y);
|
||||
editor.instance.setImaginateGeneratingStatus(documentId, layerPath, percentComplete, "Generating");
|
||||
|
||||
scheduleNextPollingUpdate(interval, nextTimeoutBegan, 0, editor, hostname, documentId, layerPath, resolution);
|
||||
|
|
@ -178,46 +181,30 @@ function scheduleNextPollingUpdate(
|
|||
}
|
||||
|
||||
// API COMMUNICATION FUNCTIONS
|
||||
// These are highly unstable APIs that will need to be updated very frequently, so we currently assume usage of this exact commit from the server:
|
||||
// https://github.com/AUTOMATIC1111/stable-diffusion-webui/commit/7d6042b908c064774ee10961309d396eabdc6c4a
|
||||
|
||||
function endpoint(hostname: string): string {
|
||||
// Highly unstable API
|
||||
return `${hostname}api/predict/`;
|
||||
async function pollImage(hostname: string): Promise<[Blob | undefined, number]> {
|
||||
// Fetch the percent progress and in-progress image from the API
|
||||
const result = await fetch(`${hostname}sdapi/v1/progress`, { signal: pollingAbortController.signal, method: "GET" });
|
||||
const { current_image, progress } = await result.json();
|
||||
|
||||
// Convert to a usable format
|
||||
const progressPercent = progress * 100;
|
||||
const base64 = typeof current_image === "string" && current_image.length > 0 ? `data:image/png;base64,${current_image}` : undefined;
|
||||
|
||||
// Deal with a missing image
|
||||
if (!base64) {
|
||||
// The image is not ready yet (because it's only had a few samples since generation began), but we do have a progress percentage
|
||||
if (!Number.isNaN(progressPercent) && progressPercent >= 0 && progressPercent <= 100) {
|
||||
return [undefined, progressPercent];
|
||||
}
|
||||
|
||||
async function pollImage(hostname: string): Promise<[Blob, number]> {
|
||||
// Highly unstable API
|
||||
const result = await fetch(endpoint(hostname), {
|
||||
signal: pollingAbortController.signal,
|
||||
headers: {
|
||||
accept: "*/*",
|
||||
"accept-language": "en-US,en;q=0.9",
|
||||
"content-type": "application/json",
|
||||
},
|
||||
referrer: hostname,
|
||||
referrerPolicy: "strict-origin-when-cross-origin",
|
||||
body: stripIndents`
|
||||
{
|
||||
"fn_index":3,
|
||||
"data":[],
|
||||
"session_hash":"0000000000"
|
||||
}`,
|
||||
method: "POST",
|
||||
mode: "cors",
|
||||
credentials: "omit",
|
||||
});
|
||||
const json = await result.json();
|
||||
// Highly unstable API
|
||||
const percentComplete = Math.abs(Number(json.data[0].match(/(?<="width:).*?(?=%")/)[0])); // The API sometimes returns negative values presumably due to a bug
|
||||
// Highly unstable API
|
||||
const base64 = json.data[2];
|
||||
|
||||
if (typeof base64 !== "string" || !base64.startsWith("data:image/png;base64,")) return Promise.reject();
|
||||
// Something else is wrong and the image wasn't provided as expected
|
||||
return Promise.reject();
|
||||
}
|
||||
|
||||
// The image was provided so we turn it into a data blob
|
||||
const blob = await (await fetch(base64)).blob();
|
||||
|
||||
return [blob, percentComplete];
|
||||
return [blob, progressPercent];
|
||||
}
|
||||
|
||||
async function generate(
|
||||
|
|
@ -231,138 +218,83 @@ async function generate(
|
|||
xhr?: XMLHttpRequest;
|
||||
}> {
|
||||
let body;
|
||||
let endpoint;
|
||||
if (image === undefined || parameters.denoisingStrength === undefined) {
|
||||
// Highly unstable API
|
||||
body = stripIndents`
|
||||
{
|
||||
"fn_index":13,
|
||||
"data":[
|
||||
"${escapeJSON(parameters.prompt)}",
|
||||
"${escapeJSON(parameters.negativePrompt)}",
|
||||
"None",
|
||||
"None",
|
||||
${parameters.samples},
|
||||
"${parameters.samplingMethod}",
|
||||
${parameters.restoreFaces},
|
||||
${parameters.tiling},
|
||||
1,
|
||||
1,
|
||||
${parameters.cfgScale},
|
||||
${parameters.seed},
|
||||
-1,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
false,
|
||||
${parameters.resolution.y},
|
||||
${parameters.resolution.x},
|
||||
false,
|
||||
0.7,
|
||||
0,
|
||||
0,
|
||||
"None",
|
||||
false,
|
||||
false,
|
||||
null,
|
||||
"",
|
||||
"Seed",
|
||||
"",
|
||||
"Nothing",
|
||||
"",
|
||||
true,
|
||||
false,
|
||||
false,
|
||||
null,
|
||||
""
|
||||
],
|
||||
"session_hash":"0000000000"
|
||||
}`;
|
||||
endpoint = `${hostname}sdapi/v1/txt2img`;
|
||||
|
||||
body = {
|
||||
// enable_hr: false,
|
||||
// denoising_strength: 0,
|
||||
// firstphase_width: 0,
|
||||
// firstphase_height: 0,
|
||||
prompt: parameters.prompt,
|
||||
// styles: [],
|
||||
seed: Number(parameters.seed),
|
||||
// subseed: -1,
|
||||
// subseed_strength: 0,
|
||||
// seed_resize_from_h: -1,
|
||||
// seed_resize_from_w: -1,
|
||||
// batch_size: 1,
|
||||
// n_iter: 1,
|
||||
steps: parameters.samples,
|
||||
cfg_scale: parameters.cfgScale,
|
||||
width: parameters.resolution.x,
|
||||
height: parameters.resolution.y,
|
||||
restore_faces: parameters.restoreFaces,
|
||||
tiling: parameters.tiling,
|
||||
negative_prompt: parameters.negativePrompt,
|
||||
// eta: 0,
|
||||
// s_churn: 0,
|
||||
// s_tmax: 0,
|
||||
// s_tmin: 0,
|
||||
// s_noise: 1,
|
||||
override_settings: {
|
||||
show_progress_every_n_steps: PROGRESS_EVERY_N_STEPS,
|
||||
},
|
||||
sampler_index: parameters.samplingMethod,
|
||||
};
|
||||
} else {
|
||||
const sourceImageBase64 = await blobToBase64(image);
|
||||
|
||||
// Highly unstable API
|
||||
body = stripIndents`
|
||||
{
|
||||
"fn_index":33,
|
||||
"data":[
|
||||
0,
|
||||
"${escapeJSON(parameters.prompt)}",
|
||||
"${escapeJSON(parameters.negativePrompt)}",
|
||||
"None",
|
||||
"None",
|
||||
"${sourceImageBase64}",
|
||||
null,
|
||||
null,
|
||||
null,
|
||||
"Draw mask",
|
||||
${parameters.samples},
|
||||
"${parameters.samplingMethod}",
|
||||
4,
|
||||
"fill",
|
||||
${parameters.restoreFaces},
|
||||
${parameters.tiling},
|
||||
1,
|
||||
1,
|
||||
${parameters.cfgScale},
|
||||
${parameters.denoisingStrength},
|
||||
${parameters.seed},
|
||||
-1,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
false,
|
||||
${parameters.resolution.y},
|
||||
${parameters.resolution.x},
|
||||
"Just resize",
|
||||
false,
|
||||
32,
|
||||
"Inpaint masked",
|
||||
"",
|
||||
"",
|
||||
"None",
|
||||
"",
|
||||
true,
|
||||
true,
|
||||
"",
|
||||
"",
|
||||
true,
|
||||
50,
|
||||
true,
|
||||
1,
|
||||
0,
|
||||
false,
|
||||
4,
|
||||
1,
|
||||
"",
|
||||
128,
|
||||
8,
|
||||
["left","right","up","down"],
|
||||
1,
|
||||
0.05,
|
||||
128,
|
||||
4,
|
||||
"fill",
|
||||
["left","right","up","down"],
|
||||
false,
|
||||
false,
|
||||
null,
|
||||
"",
|
||||
"",
|
||||
64,
|
||||
"None",
|
||||
"Seed",
|
||||
"",
|
||||
"Nothing",
|
||||
"",
|
||||
true,
|
||||
false,
|
||||
false,
|
||||
null,
|
||||
"",
|
||||
""
|
||||
],
|
||||
"session_hash":"0000000000"
|
||||
}`;
|
||||
endpoint = `${hostname}sdapi/v1/img2img`;
|
||||
|
||||
body = {
|
||||
init_images: [sourceImageBase64],
|
||||
// resize_mode: 0,
|
||||
denoising_strength: parameters.denoisingStrength,
|
||||
// mask: "",
|
||||
// mask_blur: 4,
|
||||
// inpainting_fill: 0,
|
||||
// inpaint_full_res: true,
|
||||
// inpaint_full_res_padding: 0,
|
||||
// inpainting_mask_invert: 0,
|
||||
prompt: parameters.prompt,
|
||||
// styles: [],
|
||||
seed: Number(parameters.seed),
|
||||
// subseed: -1,
|
||||
// subseed_strength: 0,
|
||||
// seed_resize_from_h: -1,
|
||||
// seed_resize_from_w: -1,
|
||||
// batch_size: 1,
|
||||
// n_iter: 1,
|
||||
steps: parameters.samples,
|
||||
cfg_scale: parameters.cfgScale,
|
||||
width: parameters.resolution.x,
|
||||
height: parameters.resolution.y,
|
||||
restore_faces: parameters.restoreFaces,
|
||||
tiling: parameters.tiling,
|
||||
negative_prompt: parameters.negativePrompt,
|
||||
// eta: 0,
|
||||
// s_churn: 0,
|
||||
// s_tmax: 0,
|
||||
// s_tmin: 0,
|
||||
// s_noise: 1,
|
||||
override_settings: {
|
||||
show_progress_every_n_steps: PROGRESS_EVERY_N_STEPS,
|
||||
},
|
||||
sampler_index: parameters.samplingMethod,
|
||||
// include_init_images: false,
|
||||
};
|
||||
}
|
||||
|
||||
// Prepare a promise that will resolve after the outbound request upload is complete
|
||||
|
|
@ -381,7 +313,7 @@ async function generate(
|
|||
uploadedResolve();
|
||||
}
|
||||
};
|
||||
const [result, xhr] = requestWithUploadDownloadProgress(endpoint(hostname), "POST", body, uploadProgress, abortAndResetPolling);
|
||||
const [result, xhr] = requestWithUploadDownloadProgress(endpoint, "POST", JSON.stringify(body), uploadProgress, abortAndResetPolling);
|
||||
result.catch(() => uploadedReject());
|
||||
|
||||
// Return the promise that resolves when the request upload is complete, the promise that resolves when the response download is complete, and the XHR so it can be aborted
|
||||
|
|
@ -389,26 +321,7 @@ async function generate(
|
|||
}
|
||||
|
||||
async function terminate(hostname: string): Promise<void> {
|
||||
const body = stripIndents`
|
||||
{
|
||||
"fn_index":2,
|
||||
"data":[],
|
||||
"session_hash":"0000000000"
|
||||
}`;
|
||||
|
||||
await fetch(endpoint(hostname), {
|
||||
headers: {
|
||||
accept: "*/*",
|
||||
"accept-language": "en-US,en;q=0.9",
|
||||
"content-type": "application/json",
|
||||
},
|
||||
referrer: hostname,
|
||||
referrerPolicy: "strict-origin-when-cross-origin",
|
||||
body,
|
||||
method: "POST",
|
||||
mode: "cors",
|
||||
credentials: "omit",
|
||||
});
|
||||
await fetch(`${hostname}sdapi/v1/interrupt`, { method: "POST" });
|
||||
}
|
||||
|
||||
async function checkConnection(hostname: string): Promise<boolean> {
|
||||
|
|
@ -417,33 +330,18 @@ async function checkConnection(hostname: string): Promise<boolean> {
|
|||
|
||||
const timeout = setTimeout(() => statusAbortController.abort(), SERVER_STATUS_CHECK_TIMEOUT);
|
||||
|
||||
const body = stripIndents`
|
||||
{
|
||||
"fn_index":100,
|
||||
"data":[],
|
||||
"session_hash":"0000000000"
|
||||
}`;
|
||||
|
||||
try {
|
||||
await fetch(endpoint(hostname), {
|
||||
signal: statusAbortController.signal,
|
||||
headers: {
|
||||
accept: "*/*",
|
||||
"accept-language": "en-US,en;q=0.9",
|
||||
"content-type": "application/json",
|
||||
},
|
||||
referrer: hostname,
|
||||
referrerPolicy: "strict-origin-when-cross-origin",
|
||||
body,
|
||||
method: "POST",
|
||||
mode: "cors",
|
||||
credentials: "omit",
|
||||
});
|
||||
// Intentionally misuse this API endpoint by using it just to check for a code 200 response, regardless of what the result is
|
||||
const { status } = await fetch(`${hostname}sdapi/v1/progress?skip_current_image=true`, { signal: statusAbortController.signal, method: "GET" });
|
||||
|
||||
// This code means the server has indeed responded and the endpoint exists (otherwise it would be 404)
|
||||
if (status === 200) {
|
||||
clearTimeout(timeout);
|
||||
|
||||
return true;
|
||||
} catch (_) {
|
||||
}
|
||||
} catch {
|
||||
// Do nothing here
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue