diff --git a/editor/src/messages/portfolio/document/properties_panel/utility_functions.rs b/editor/src/messages/portfolio/document/properties_panel/utility_functions.rs index f0bbe900..1f8cca79 100644 --- a/editor/src/messages/portfolio/document/properties_panel/utility_functions.rs +++ b/editor/src/messages/portfolio/document/properties_panel/utility_functions.rs @@ -615,7 +615,7 @@ fn node_section_imaginate(imaginate_layer: &ImaginateLayer, layer: &Layer, persi }))], ImaginateStatus::Generating => vec![WidgetHolder::new(Widget::TextButton(TextButton { label: "Terminate".into(), - tooltip: "Cancel in-progress image generation and keep the latest progress".into(), + tooltip: "Cancel the in-progress image generation and keep the latest progress".into(), on_update: WidgetCallback::new(|_| DocumentMessage::ImaginateTerminate.into()), ..Default::default() }))], @@ -745,7 +745,7 @@ fn node_section_imaginate(imaginate_layer: &ImaginateLayer, layer: &Layer, persi }, LayoutGroup::Row { widgets: { - let tooltip = "Number of iterations to improve the image generation quality, with diminishing returns around 40".to_string(); + let tooltip = "Number of iterations to improve the image generation quality, with diminishing returns around 40 when using the Euler A sampling method".to_string(); vec![ WidgetHolder::new(Widget::TextLabel(TextLabel { value: "Sampling Steps".into(), @@ -774,13 +774,7 @@ fn node_section_imaginate(imaginate_layer: &ImaginateLayer, layer: &Layer, persi }, LayoutGroup::Row { widgets: { - let tooltip = " - Algorithm used to generate the image during each sampling step.\n\ - \n\ - 'DPM Fast' and 'DPM Adaptive' do not support live refreshing updates. - " - .trim() - .to_string(); + let tooltip = "Algorithm used to generate the image during each sampling step".to_string(); let sampling_methods = ImaginateSamplingMethod::list(); let mut entries = Vec::with_capacity(sampling_methods.len()); diff --git a/frontend/src/components/floating-menus/ColorPicker.vue b/frontend/src/components/floating-menus/ColorPicker.vue index f870e07c..ff187592 100644 --- a/frontend/src/components/floating-menus/ColorPicker.vue +++ b/frontend/src/components/floating-menus/ColorPicker.vue @@ -584,12 +584,14 @@ export default defineComponent({ }, async activateEyedropperSample() { // TODO: Replace this temporary solution that only works in Chromium-based browsers with the custom color sampler used by the Eyedropper tool + // eslint-disable-next-line @typescript-eslint/no-explicit-any if (!(window as any).EyeDropper) { this.editor.instance.eyedropperSampleForColorPicker(); return; } try { + // eslint-disable-next-line @typescript-eslint/no-explicit-any const result = await new (window as any).EyeDropper().open(); this.setColorCode(result.sRGBHex); } catch { diff --git a/frontend/src/utility-functions/imaginate.ts b/frontend/src/utility-functions/imaginate.ts index 6d5098d7..84bcc657 100644 --- a/frontend/src/utility-functions/imaginate.ts +++ b/frontend/src/utility-functions/imaginate.ts @@ -1,14 +1,15 @@ -import { escapeJSON } from "@/utility-functions/escape"; +/* eslint-disable camelcase */ + +// import { escapeJSON } from "@/utility-functions/escape"; import { blobToBase64 } from "@/utility-functions/files"; import { type RequestResult, requestWithUploadDownloadProgress } from "@/utility-functions/network"; -import { stripIndents } from "@/utility-functions/strip-indents"; import { type Editor } from "@/wasm-communication/editor"; import type { XY } from "@/wasm-communication/messages"; import { type ImaginateGenerationParameters } from "@/wasm-communication/messages"; const MAX_POLLING_RETRIES = 4; const SERVER_STATUS_CHECK_TIMEOUT = 5000; -const SAMPLING_MODES_POLLING_UNSUPPORTED = ["DPM fast", "DPM adaptive"]; +const PROGRESS_EVERY_N_STEPS = 5; let timer: NodeJS.Timeout | undefined; let terminated = false; @@ -50,7 +51,7 @@ export async function imaginateGenerate( // Begin polling for updates to the in-progress image generation at the specified interval // Don't poll if the chosen interval is 0, or if the chosen sampling method does not support polling - if (refreshFrequency > 0 && !SAMPLING_MODES_POLLING_UNSUPPORTED.includes(parameters.samplingMethod)) { + if (refreshFrequency > 0) { const interval = Math.max(refreshFrequency * 1000, 500); scheduleNextPollingUpdate(interval, Date.now(), 0, editor, hostname, documentId, layerPath, parameters.resolution); } @@ -62,9 +63,9 @@ export async function imaginateGenerate( } // Extract the final image from the response and convert it to a data blob - // Highly unstable API - const base64 = JSON.parse(body)?.data[0]?.[0] as string | undefined; - if (typeof base64 !== "string" || !base64.startsWith("data:image/png;base64,")) throw new Error("Could not read final image result from server response"); + const base64Data = JSON.parse(body)?.images?.[0] as string | undefined; + const base64 = typeof base64Data === "string" && base64Data.length > 0 ? `data:image/png;base64,${base64Data}` : undefined; + if (!base64) throw new Error("Could not read final image result from server response"); const blob = await (await fetch(base64)).blob(); // Send the backend an updated status @@ -156,9 +157,11 @@ function scheduleNextPollingUpdate( try { const [blob, percentComplete] = await pollImage(hostname); + + // After waiting for the polling result back from the server, if during that intervening time the user has terminated the generation, exit so we don't overwrite that terminated status if (terminated) return; - preloadAndSetImaginateBlobURL(editor, blob, documentId, layerPath, resolution.x, resolution.y); + if (blob) preloadAndSetImaginateBlobURL(editor, blob, documentId, layerPath, resolution.x, resolution.y); editor.instance.setImaginateGeneratingStatus(documentId, layerPath, percentComplete, "Generating"); scheduleNextPollingUpdate(interval, nextTimeoutBegan, 0, editor, hostname, documentId, layerPath, resolution); @@ -178,46 +181,30 @@ function scheduleNextPollingUpdate( } // API COMMUNICATION FUNCTIONS -// These are highly unstable APIs that will need to be updated very frequently, so we currently assume usage of this exact commit from the server: -// https://github.com/AUTOMATIC1111/stable-diffusion-webui/commit/7d6042b908c064774ee10961309d396eabdc6c4a -function endpoint(hostname: string): string { - // Highly unstable API - return `${hostname}api/predict/`; -} +async function pollImage(hostname: string): Promise<[Blob | undefined, number]> { + // Fetch the percent progress and in-progress image from the API + const result = await fetch(`${hostname}sdapi/v1/progress`, { signal: pollingAbortController.signal, method: "GET" }); + const { current_image, progress } = await result.json(); -async function pollImage(hostname: string): Promise<[Blob, number]> { - // Highly unstable API - const result = await fetch(endpoint(hostname), { - signal: pollingAbortController.signal, - headers: { - accept: "*/*", - "accept-language": "en-US,en;q=0.9", - "content-type": "application/json", - }, - referrer: hostname, - referrerPolicy: "strict-origin-when-cross-origin", - body: stripIndents` - { - "fn_index":3, - "data":[], - "session_hash":"0000000000" - }`, - method: "POST", - mode: "cors", - credentials: "omit", - }); - const json = await result.json(); - // Highly unstable API - const percentComplete = Math.abs(Number(json.data[0].match(/(?<="width:).*?(?=%")/)[0])); // The API sometimes returns negative values presumably due to a bug - // Highly unstable API - const base64 = json.data[2]; + // Convert to a usable format + const progressPercent = progress * 100; + const base64 = typeof current_image === "string" && current_image.length > 0 ? `data:image/png;base64,${current_image}` : undefined; - if (typeof base64 !== "string" || !base64.startsWith("data:image/png;base64,")) return Promise.reject(); + // Deal with a missing image + if (!base64) { + // The image is not ready yet (because it's only had a few samples since generation began), but we do have a progress percentage + if (!Number.isNaN(progressPercent) && progressPercent >= 0 && progressPercent <= 100) { + return [undefined, progressPercent]; + } + // Something else is wrong and the image wasn't provided as expected + return Promise.reject(); + } + + // The image was provided so we turn it into a data blob const blob = await (await fetch(base64)).blob(); - - return [blob, percentComplete]; + return [blob, progressPercent]; } async function generate( @@ -231,138 +218,83 @@ async function generate( xhr?: XMLHttpRequest; }> { let body; + let endpoint; if (image === undefined || parameters.denoisingStrength === undefined) { - // Highly unstable API - body = stripIndents` - { - "fn_index":13, - "data":[ - "${escapeJSON(parameters.prompt)}", - "${escapeJSON(parameters.negativePrompt)}", - "None", - "None", - ${parameters.samples}, - "${parameters.samplingMethod}", - ${parameters.restoreFaces}, - ${parameters.tiling}, - 1, - 1, - ${parameters.cfgScale}, - ${parameters.seed}, - -1, - 0, - 0, - 0, - false, - ${parameters.resolution.y}, - ${parameters.resolution.x}, - false, - 0.7, - 0, - 0, - "None", - false, - false, - null, - "", - "Seed", - "", - "Nothing", - "", - true, - false, - false, - null, - "" - ], - "session_hash":"0000000000" - }`; + endpoint = `${hostname}sdapi/v1/txt2img`; + + body = { + // enable_hr: false, + // denoising_strength: 0, + // firstphase_width: 0, + // firstphase_height: 0, + prompt: parameters.prompt, + // styles: [], + seed: Number(parameters.seed), + // subseed: -1, + // subseed_strength: 0, + // seed_resize_from_h: -1, + // seed_resize_from_w: -1, + // batch_size: 1, + // n_iter: 1, + steps: parameters.samples, + cfg_scale: parameters.cfgScale, + width: parameters.resolution.x, + height: parameters.resolution.y, + restore_faces: parameters.restoreFaces, + tiling: parameters.tiling, + negative_prompt: parameters.negativePrompt, + // eta: 0, + // s_churn: 0, + // s_tmax: 0, + // s_tmin: 0, + // s_noise: 1, + override_settings: { + show_progress_every_n_steps: PROGRESS_EVERY_N_STEPS, + }, + sampler_index: parameters.samplingMethod, + }; } else { const sourceImageBase64 = await blobToBase64(image); - // Highly unstable API - body = stripIndents` - { - "fn_index":33, - "data":[ - 0, - "${escapeJSON(parameters.prompt)}", - "${escapeJSON(parameters.negativePrompt)}", - "None", - "None", - "${sourceImageBase64}", - null, - null, - null, - "Draw mask", - ${parameters.samples}, - "${parameters.samplingMethod}", - 4, - "fill", - ${parameters.restoreFaces}, - ${parameters.tiling}, - 1, - 1, - ${parameters.cfgScale}, - ${parameters.denoisingStrength}, - ${parameters.seed}, - -1, - 0, - 0, - 0, - false, - ${parameters.resolution.y}, - ${parameters.resolution.x}, - "Just resize", - false, - 32, - "Inpaint masked", - "", - "", - "None", - "", - true, - true, - "", - "", - true, - 50, - true, - 1, - 0, - false, - 4, - 1, - "", - 128, - 8, - ["left","right","up","down"], - 1, - 0.05, - 128, - 4, - "fill", - ["left","right","up","down"], - false, - false, - null, - "", - "", - 64, - "None", - "Seed", - "", - "Nothing", - "", - true, - false, - false, - null, - "", - "" - ], - "session_hash":"0000000000" - }`; + endpoint = `${hostname}sdapi/v1/img2img`; + + body = { + init_images: [sourceImageBase64], + // resize_mode: 0, + denoising_strength: parameters.denoisingStrength, + // mask: "", + // mask_blur: 4, + // inpainting_fill: 0, + // inpaint_full_res: true, + // inpaint_full_res_padding: 0, + // inpainting_mask_invert: 0, + prompt: parameters.prompt, + // styles: [], + seed: Number(parameters.seed), + // subseed: -1, + // subseed_strength: 0, + // seed_resize_from_h: -1, + // seed_resize_from_w: -1, + // batch_size: 1, + // n_iter: 1, + steps: parameters.samples, + cfg_scale: parameters.cfgScale, + width: parameters.resolution.x, + height: parameters.resolution.y, + restore_faces: parameters.restoreFaces, + tiling: parameters.tiling, + negative_prompt: parameters.negativePrompt, + // eta: 0, + // s_churn: 0, + // s_tmax: 0, + // s_tmin: 0, + // s_noise: 1, + override_settings: { + show_progress_every_n_steps: PROGRESS_EVERY_N_STEPS, + }, + sampler_index: parameters.samplingMethod, + // include_init_images: false, + }; } // Prepare a promise that will resolve after the outbound request upload is complete @@ -381,7 +313,7 @@ async function generate( uploadedResolve(); } }; - const [result, xhr] = requestWithUploadDownloadProgress(endpoint(hostname), "POST", body, uploadProgress, abortAndResetPolling); + const [result, xhr] = requestWithUploadDownloadProgress(endpoint, "POST", JSON.stringify(body), uploadProgress, abortAndResetPolling); result.catch(() => uploadedReject()); // Return the promise that resolves when the request upload is complete, the promise that resolves when the response download is complete, and the XHR so it can be aborted @@ -389,26 +321,7 @@ async function generate( } async function terminate(hostname: string): Promise { - const body = stripIndents` - { - "fn_index":2, - "data":[], - "session_hash":"0000000000" - }`; - - await fetch(endpoint(hostname), { - headers: { - accept: "*/*", - "accept-language": "en-US,en;q=0.9", - "content-type": "application/json", - }, - referrer: hostname, - referrerPolicy: "strict-origin-when-cross-origin", - body, - method: "POST", - mode: "cors", - credentials: "omit", - }); + await fetch(`${hostname}sdapi/v1/interrupt`, { method: "POST" }); } async function checkConnection(hostname: string): Promise { @@ -417,33 +330,18 @@ async function checkConnection(hostname: string): Promise { const timeout = setTimeout(() => statusAbortController.abort(), SERVER_STATUS_CHECK_TIMEOUT); - const body = stripIndents` - { - "fn_index":100, - "data":[], - "session_hash":"0000000000" - }`; - try { - await fetch(endpoint(hostname), { - signal: statusAbortController.signal, - headers: { - accept: "*/*", - "accept-language": "en-US,en;q=0.9", - "content-type": "application/json", - }, - referrer: hostname, - referrerPolicy: "strict-origin-when-cross-origin", - body, - method: "POST", - mode: "cors", - credentials: "omit", - }); + // Intentionally misuse this API endpoint by using it just to check for a code 200 response, regardless of what the result is + const { status } = await fetch(`${hostname}sdapi/v1/progress?skip_current_image=true`, { signal: statusAbortController.signal, method: "GET" }); - clearTimeout(timeout); - - return true; - } catch (_) { - return false; + // This code means the server has indeed responded and the endpoint exists (otherwise it would be 404) + if (status === 200) { + clearTimeout(timeout); + return true; + } + } catch { + // Do nothing here } + + return false; }