diff --git a/xllm/core/framework/batch/dit_batch.cpp b/xllm/core/framework/batch/dit_batch.cpp
index 1c09327a7..6b192b183 100644
--- a/xllm/core/framework/batch/dit_batch.cpp
+++ b/xllm/core/framework/batch/dit_batch.cpp
@@ -62,7 +62,7 @@ DiTForwardInput DiTBatch::prepare_forward_input() {
 
   std::vector<torch::Tensor> images;
   std::vector<torch::Tensor> mask_images;
-
+  std::vector<torch::Tensor> control_images;
   std::vector<torch::Tensor> latents;
   std::vector<torch::Tensor> masked_image_latents;
   for (const auto& request : request_vec_) {
@@ -96,6 +96,7 @@ DiTForwardInput DiTBatch::prepare_forward_input() {
 
     images.emplace_back(input_params.image);
     mask_images.emplace_back(input_params.mask_image);
+    control_images.emplace_back(input_params.control_image);
   }
 
   if (input.prompts.size() != request_vec_.size()) {
@@ -122,6 +123,10 @@ DiTForwardInput DiTBatch::prepare_forward_input() {
     input.mask_images = torch::stack(mask_images);
   }
 
+  if (check_tensors_valid(control_images)) {
+    input.control_image = torch::stack(control_images);
+  }
+
   if (check_tensors_valid(prompt_embeds)) {
     input.prompt_embeds = torch::stack(prompt_embeds);
   }
diff --git a/xllm/core/framework/request/dit_request_params.cpp b/xllm/core/framework/request/dit_request_params.cpp
index 2d01537f6..00b1b5fb3 100644
--- a/xllm/core/framework/request/dit_request_params.cpp
+++ b/xllm/core/framework/request/dit_request_params.cpp
@@ -270,6 +270,16 @@ DiTRequestParams::DiTRequestParams(const proto::ImageGenerationRequest& request,
     }
   }
 
+  if (input.has_control_image()) {
+    std::string raw_bytes;
+    if (!butil::Base64Decode(input.control_image(), &raw_bytes)) {
+      LOG(ERROR) << "Base64 control_image decode failed";
+    }
+    if (!decoder.decode(raw_bytes, input_params.control_image)) {
+      LOG(ERROR) << "Control_image decode failed.";
+    }
+  }
+
   // generation params
   const auto& params = request.parameters();
   if (params.has_size()) {
diff --git a/xllm/core/framework/request/dit_request_state.h b/xllm/core/framework/request/dit_request_state.h
index fab43cb17..7e69bcb6a 100644
--- a/xllm/core/framework/request/dit_request_state.h
+++ b/xllm/core/framework/request/dit_request_state.h
@@ -92,6 +92,8 @@ struct DiTInputParams {
 
   torch::Tensor image;
 
+  torch::Tensor control_image;
+
   torch::Tensor mask_image;
 
   torch::Tensor masked_image_latent;
diff --git a/xllm/core/runtime/dit_forward_params.h b/xllm/core/runtime/dit_forward_params.h
index e52d4bb22..a96a5118a 100644
--- a/xllm/core/runtime/dit_forward_params.h
+++ b/xllm/core/runtime/dit_forward_params.h
@@ -84,6 +84,8 @@ struct DiTForwardInput {
 
   torch::Tensor mask_images;
 
+  torch::Tensor control_image;
+
   torch::Tensor masked_image_latents;
 
   torch::Tensor prompt_embeds;
diff --git a/xllm/models/dit/autoencoder_kl.h b/xllm/models/dit/autoencoder_kl.h
index f57acc5b4..e1e948b0e 100644
--- a/xllm/models/dit/autoencoder_kl.h
+++ b/xllm/models/dit/autoencoder_kl.h
@@ -62,10 +62,12 @@ class VAEImageProcessorImpl : public torch::nn::Module {
                         bool do_normalize = true,
                         bool do_binarize = false,
                         bool do_convert_rgb = false,
-                        bool do_convert_grayscale = false) {
+                        bool do_convert_grayscale = false,
+                        int64_t latent_channels = 4) {
    const auto& model_args = context.get_model_args();
+    options_ = context.get_tensor_options();
    scale_factor_ = 1 << model_args.block_out_channels().size();
-    latent_channels_ = 4;
+    latent_channels_ = latent_channels;
    do_resize_ = do_resize;
    do_normalize_ = do_normalize;
    do_binarize_ = do_binarize;
@@ -116,7 +118,6 @@ class VAEImageProcessorImpl : public torch::nn::Module {
     if (channel == latent_channels_) {
       return image;
     }
-
     auto [target_h, target_w] =
         get_default_height_width(processed, height, width);
     if (do_resize_) {
@@ -129,13 +130,12 @@ class VAEImageProcessorImpl : public torch::nn::Module {
     if (do_binarize_) {
       processed = (processed >= 0.5f).to(torch::kFloat32);
     }
-    processed = processed.to(image.dtype());
+    processed = processed.to(options_);
     return processed;
   }
 
   torch::Tensor postprocess(
       const torch::Tensor& tensor,
-      const std::string& output_type = "pt",
       std::optional<std::vector<bool>> do_denormalize = std::nullopt) {
     torch::Tensor processed = tensor.clone();
     if (do_normalize_) {
@@ -149,9 +149,6 @@ class VAEImageProcessorImpl : public torch::nn::Module {
         }
       }
     }
-    if (output_type == "np") {
-      return processed.permute({0, 2, 3, 1}).contiguous();
-    }
     return processed;
   }
 
@@ -202,6 +199,7 @@ class VAEImageProcessorImpl : public torch::nn::Module {
   bool do_binarize_ = false;
   bool do_convert_rgb_ = false;
   bool do_convert_grayscale_ = false;
+  torch::TensorOptions options_;
 };
 
 TORCH_MODULE(VAEImageProcessor);
diff --git a/xllm/models/dit/dit.h b/xllm/models/dit/dit.h
index e9d9302a5..d333758d5 100644
--- a/xllm/models/dit/dit.h
+++ b/xllm/models/dit/dit.h
@@ -1436,7 +1436,7 @@ class FluxTransformer2DModelImpl : public torch::nn::Module {
     proj_out_->verify_loaded_weights(prefix + "proj_out.");
   }
 
-  int64_t in_channels() { return out_channels_; }
+  int64_t in_channels() { return in_channels_; }
   bool guidance_embeds() { return guidance_embeds_; }
 
  private:
diff --git a/xllm/models/dit/pipeline_flux.h b/xllm/models/dit/pipeline_flux.h
index dfaea1290..2003e1b9a 100644
--- a/xllm/models/dit/pipeline_flux.h
+++ b/xllm/models/dit/pipeline_flux.h
@@ -30,43 +30,35 @@ class FluxPipelineImpl : public FluxPipelineBaseImpl {
     const auto& model_args = context.get_model_args("vae");
     options_ = context.get_tensor_options();
     vae_scale_factor_ = 1 << (model_args.block_out_channels().size() - 1);
-    device_ = options_.device();
-    dtype_ = options_.dtype().toScalarType();
     vae_shift_factor_ = model_args.shift_factor();
     vae_scaling_factor_ = model_args.scale_factor();
-    default_sample_size_ = 128;
-    tokenizer_max_length_ = 77;  // TODO: get from config file
+    tokenizer_max_length_ =
+        context.get_model_args("text_encoder").max_position_embeddings();
     LOG(INFO) << "Initializing Flux pipeline...";
-    vae_image_processor_ = VAEImageProcessor(
-        context.get_model_context("vae"), true, true, false, false, false);
+    vae_image_processor_ = VAEImageProcessor(context.get_model_context("vae"),
+                                             true,
+                                             true,
+                                             false,
+                                             false,
+                                             false,
+                                             model_args.latent_channels());
     vae_ = VAE(context.get_model_context("vae"));
-    LOG(INFO) << "VAE initialized.";
     pos_embed_ = register_module(
         "pos_embed",
-        FluxPosEmbed(10000,
+        FluxPosEmbed(ROPE_SCALE_BASE,
                      context.get_model_args("transformer").axes_dims_rope()));
     transformer_ = FluxDiTModel(context.get_model_context("transformer"));
-    LOG(INFO) << "DiT transformer initialized.";
     t5_ = T5EncoderModel(context.get_model_context("text_encoder_2"));
-    LOG(INFO) << "T5 initialized.";
     clip_text_model_ = CLIPTextModel(context.get_model_context("text_encoder"));
-    LOG(INFO) << "CLIP text model initialized.";
     scheduler_ =
         FlowMatchEulerDiscreteScheduler(context.get_model_context("scheduler"));
-    LOG(INFO) << "Flux pipeline initialized.";
     register_module("vae", vae_);
-    LOG(INFO) << "VAE registered.";
     register_module("vae_image_processor", vae_image_processor_);
-    LOG(INFO) << "VAE image processor registered.";
     register_module("transformer", transformer_);
-    LOG(INFO) << "DiT transformer registered.";
     register_module("t5", t5_);
-    LOG(INFO) << "T5 registered.";
     register_module("scheduler", scheduler_);
-    LOG(INFO) << "Scheduler registered.";
     register_module("clip_text_model", clip_text_model_);
-    LOG(INFO) << "CLIP text model registered.";
   }
 
   DiTForwardOutput forward(const DiTForwardInput& input) {
@@ -104,21 +96,21 @@ class FluxPipelineImpl : public FluxPipelineBaseImpl {
             : std::nullopt;
 
     std::vector<torch::Tensor> output = forward_(
-        prompts,                                       // prompt
-        prompts_2,                                     // prompt_2
-        negative_prompts,                              // negative_prompt
-        negative_prompts_2,                            // negative_prompt_2
-        generation_params.true_cfg_scale,              // cfg scale
-        std::make_optional(generation_params.height),  // height
-        std::make_optional(generation_params.width),   // width
-        generation_params.num_inference_steps,         // num_inference_steps
-        generation_params.guidance_scale,              // guidance_scale
-        generation_params.num_images_per_prompt,       // num_images_per_prompt
-        seed,                                          // seed
-        latents,                                       // latents
-        prompt_embeds,                                 // prompt_embeds
-        negative_prompt_embeds,                        // negative_prompt_embeds
-        pooled_prompt_embeds,                          // pooled_prompt_embeds
+        prompts,                                  // prompt
+        prompts_2,                                // prompt_2
+        negative_prompts,                         // negative_prompt
+        negative_prompts_2,                       // negative_prompt_2
+        generation_params.true_cfg_scale,         // cfg scale
+        generation_params.height,                 // height
+        generation_params.width,                  // width
+        generation_params.num_inference_steps,    // num_inference_steps
+        generation_params.guidance_scale,         // guidance_scale
+        generation_params.num_images_per_prompt,  // num_images_per_prompt
+        seed,                                     // seed
+        latents,                                  // latents
+        prompt_embeds,                            // prompt_embeds
+        negative_prompt_embeds,                   // negative_prompt_embeds
+        pooled_prompt_embeds,                     // pooled_prompt_embeds
         negative_pooled_prompt_embeds,  // negative_pooled_prompt_embeds
         generation_params.max_sequence_length  // max_sequence_length
     );
@@ -141,13 +133,13 @@ class FluxPipelineImpl : public FluxPipelineBaseImpl {
     LOG(INFO)
         << "Flux model components loaded, start to load weights to sub models";
     transformer_->load_model(std::move(transformer_loader));
-    transformer_->to(device_);
+    transformer_->to(options_.device());
     vae_->load_model(std::move(vae_loader));
-    vae_->to(device_);
+    vae_->to(options_.device());
     t5_->load_model(std::move(t5_loader));
-    t5_->to(device_);
+    t5_->to(options_.device());
     clip_text_model_->load_model(std::move(clip_loader));
-    clip_text_model_->to(device_);
+    clip_text_model_->to(options_.device());
     tokenizer_ = tokenizer_loader->tokenizer();
     tokenizer_2_ = tokenizer_2_loader->tokenizer();
   }
@@ -186,8 +178,8 @@ class FluxPipelineImpl : public FluxPipelineBaseImpl {
       std::optional<std::vector<std::string>> negative_prompt = std::nullopt,
      std::optional<std::vector<std::string>> negative_prompt_2 = std::nullopt,
       float true_cfg_scale = 1.0f,
-      std::optional<int64_t> height = std::nullopt,
-      std::optional<int64_t> width = std::nullopt,
+      int64_t height = 512,
+      int64_t width = 512,
       int64_t num_inference_steps = 28,
       float guidance_scale = 3.5f,
       int64_t num_images_per_prompt = 1,
@@ -199,12 +191,6 @@ class FluxPipelineImpl : public FluxPipelineBaseImpl {
       std::optional<torch::Tensor> negative_pooled_prompt_embeds = std::nullopt,
       int64_t max_sequence_length = 512) {
     torch::NoGradGuard no_grad;
-    int64_t actual_height = height.has_value()
-                                ? height.value()
-                                : default_sample_size_ * vae_scale_factor_;
-    int64_t actual_width = width.has_value()
-                               ? width.value()
-                               : default_sample_size_ * vae_scale_factor_;
     int64_t batch_size;
     if (prompt.has_value()) {
       batch_size = prompt.value().size();
@@ -244,8 +230,8 @@ class FluxPipelineImpl : public FluxPipelineBaseImpl {
     auto [prepared_latents, latent_image_ids] =
         prepare_latents(total_batch_size,
                         num_channels_latents,
-                        actual_height,
-                        actual_width,
+                        height,
+                        width,
                         seed.has_value() ? seed.value() : 42,
                         latents);
     // prepare timestep
@@ -263,7 +249,7 @@ class FluxPipelineImpl : public FluxPipelineBaseImpl {
                                scheduler_->base_shift(),
                                scheduler_->max_shift());
     auto [timesteps, num_inference_steps_actual] = retrieve_timesteps(
-        scheduler_, num_inference_steps, device_, new_sigmas, mu);
+        scheduler_, num_inference_steps, options_.device(), new_sigmas, mu);
     int64_t num_warmup_steps =
         std::max(static_cast<int64_t>(timesteps.numel()) -
                      num_inference_steps_actual * scheduler_->order(),
@@ -272,7 +258,7 @@ class FluxPipelineImpl : public FluxPipelineBaseImpl {
     torch::Tensor guidance;
     if (transformer_->guidance_embeds()) {
       torch::TensorOptions options =
-          torch::dtype(torch::kFloat32).device(device_);
+          torch::dtype(torch::kFloat32).device(options_.device());
       guidance = torch::full(at::IntArrayRef({1}), guidance_scale, options);
 
       guidance = guidance.expand({prepared_latents.size(0)});
@@ -284,8 +270,8 @@ class FluxPipelineImpl : public FluxPipelineBaseImpl {
     auto [rot_emb1, rot_emb2] =
         pos_embed_->forward_cache(text_ids,
                                   latent_image_ids,
-                                  height.value() / (vae_scale_factor_ * 2),
-                                  width.value() / (vae_scale_factor_ * 2));
+                                  height / (vae_scale_factor_ * 2),
+                                  width / (vae_scale_factor_ * 2));
     torch::Tensor image_rotary_emb = torch::stack({rot_emb1, rot_emb2}, 0);
     for (int64_t i = 0; i < timesteps.numel(); ++i) {
       torch::Tensor t = timesteps[i].unsqueeze(0);
@@ -326,13 +312,13 @@ class FluxPipelineImpl : public FluxPipelineBaseImpl {
     }
     torch::Tensor image;
     // Unpack latents
-    torch::Tensor unpacked_latents = unpack_latents(
-        prepared_latents, actual_height, actual_width, vae_scale_factor_);
+    torch::Tensor unpacked_latents =
+        unpack_latents(prepared_latents, height, width, vae_scale_factor_);
     unpacked_latents =
         (unpacked_latents / vae_scaling_factor_) + vae_shift_factor_;
-    unpacked_latents = unpacked_latents.to(dtype_);
+    unpacked_latents = unpacked_latents.to(options_.dtype());
     image = vae_->decode(unpacked_latents);
-    image = vae_image_processor_->postprocess(image, "pil");
+    image = vae_image_processor_->postprocess(image);
 
     return std::vector<torch::Tensor>{{image}};
   }
@@ -343,7 +329,6 @@ class FluxPipelineImpl : public FluxPipelineBaseImpl {
   FluxDiTModel transformer_{nullptr};
   float vae_scaling_factor_;
   float vae_shift_factor_;
-  int default_sample_size_;
   FluxPosEmbed pos_embed_{nullptr};
 };
 TORCH_MODULE(FluxPipeline);
diff --git a/xllm/models/dit/pipeline_flux_base.h b/xllm/models/dit/pipeline_flux_base.h
index b9f8442fe..4b3895586 100644
--- a/xllm/models/dit/pipeline_flux_base.h
+++ b/xllm/models/dit/pipeline_flux_base.h
@@ -36,6 +36,8 @@ limitations under the License.
 
 namespace xllm {
 
+constexpr int64_t ROPE_SCALE_BASE = 10000;
+
 float calculate_shift(int64_t image_seq_len,
                       int64_t base_seq_len = 256,
                       int64_t max_seq_len = 4096,
@@ -213,9 +215,9 @@ class FluxPipelineBaseImpl : public torch::nn::Module {
     auto input_ids =
         torch::tensor(text_input_ids_flat, torch::dtype(torch::kLong))
             .view({batch_size, max_sequence_length})
-            .to(device_);
+            .to(options_.device());
     torch::Tensor prompt_embeds = t5_->forward(input_ids);
-    prompt_embeds = prompt_embeds.to(device_).to(dtype_);
+    prompt_embeds = prompt_embeds.to(options_);
     int64_t seq_len = prompt_embeds.size(1);
     prompt_embeds = prompt_embeds.repeat({1, num_images_per_prompt, 1});
     prompt_embeds =
@@ -244,10 +246,10 @@ class FluxPipelineBaseImpl : public torch::nn::Module {
     auto input_ids =
         torch::tensor(text_input_ids_flat, torch::dtype(torch::kLong))
             .view({batch_size, tokenizer_max_length_})
-            .to(device_);
+            .to(options_.device());
     auto encoder_output = clip_text_model_->forward(input_ids);
     torch::Tensor prompt_embeds = encoder_output;
-    prompt_embeds = prompt_embeds.to(device_).to(dtype_);
+    prompt_embeds = prompt_embeds.to(options_);
     prompt_embeds = prompt_embeds.repeat({1, num_images_per_prompt});
     prompt_embeds =
         prompt_embeds.view({batch_size * num_images_per_prompt, -1});
@@ -281,8 +283,8 @@ class FluxPipelineBaseImpl : public torch::nn::Module {
       prompt_embeds = get_t5_prompt_embeds(
           prompt_2_list, num_images_per_prompt, max_sequence_length);
     }
-    torch::Tensor text_ids = torch::zeros({prompt_embeds.value().size(1), 3},
-                                          torch::device(device_).dtype(dtype_));
+    torch::Tensor text_ids =
+        torch::zeros({prompt_embeds.value().size(1), 3}, options_);
 
     return std::make_tuple(prompt_embeds.value(),
                            pooled_prompt_embeds.has_value()
diff --git a/xllm/models/dit/pipeline_flux_control.h b/xllm/models/dit/pipeline_flux_control.h
new file mode 100644
index 000000000..2521478aa
--- /dev/null
+++ b/xllm/models/dit/pipeline_flux_control.h
@@ -0,0 +1,356 @@
+/* Copyright 2025 The xLLM Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    https://github.com/jd-opensource/xllm/blob/main/LICENSE
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#pragma once
+#include "core/layers/pos_embedding.h"
+#include "core/layers/rotary_embedding.h"
+#include "dit.h"
+#include "pipeline_flux_base.h"
+// pipeline_flux_control compatible with huggingface weights
+// ref to:
+// https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/flux/pipeline_flux_control.py
+
+namespace xllm {
+
+class FluxControlPipelineImpl : public FluxPipelineBaseImpl {
+ public:
+  FluxControlPipelineImpl(const DiTModelContext& context) {
+    auto model_args = context.get_model_args("vae");
+    options_ = context.get_tensor_options();
+    vae_scale_factor_ = 1 << (model_args.block_out_channels().size() - 1);
+    vae_shift_factor_ = model_args.shift_factor();
+    vae_scaling_factor_ = model_args.scale_factor();
+    latent_channels_ = model_args.latent_channels();
+    tokenizer_max_length_ =
+        context.get_model_args("text_encoder").max_position_embeddings();
+    LOG(INFO) << "Initializing FluxControl pipeline...";
+    image_processor_ = VAEImageProcessor(context.get_model_context("vae"),
+                                         true,
+                                         true,
+                                         false,
+                                         false,
+                                         false,
+                                         latent_channels_);
+    vae_ = VAE(context.get_model_context("vae"));
+    pos_embed_ = register_module(
+        "pos_embed",
+        FluxPosEmbed(ROPE_SCALE_BASE,
+                     context.get_model_args("transformer").axes_dims_rope()));
+    transformer_ = FluxDiTModel(context.get_model_context("transformer"));
+    t5_ = T5EncoderModel(context.get_model_context("text_encoder_2"));
+    clip_text_model_ = CLIPTextModel(context.get_model_context("text_encoder"));
+    scheduler_ =
+        FlowMatchEulerDiscreteScheduler(context.get_model_context("scheduler"));
+    register_module("vae", vae_);
+    register_module("vae_image_processor", image_processor_);
+    register_module("transformer", transformer_);
+    register_module("t5", t5_);
+    register_module("scheduler", scheduler_);
+    register_module("clip_text_model", clip_text_model_);
+  }
+
+  DiTForwardOutput forward(const DiTForwardInput& input) {
+    const auto& generation_params = input.generation_params;
+    int64_t height = generation_params.height;
+    int64_t width = generation_params.width;
+    auto seed = generation_params.seed > 0 ? generation_params.seed : 42;
+    auto prompts = std::make_optional(input.prompts);
+    auto prompts_2 = input.prompts_2.empty()
+                         ? std::nullopt
+                         : std::make_optional(input.prompts_2);
+
+    auto control_image = input.control_image;
+
+    auto latents = input.latents.defined() ? std::make_optional(input.latents)
+                                           : std::nullopt;
+    auto prompt_embeds = input.prompt_embeds.defined()
+                             ? std::make_optional(input.prompt_embeds)
+                             : std::nullopt;
+    auto pooled_prompt_embeds =
+        input.pooled_prompt_embeds.defined()
+            ? std::make_optional(input.pooled_prompt_embeds)
+            : std::nullopt;
+
+    std::vector<torch::Tensor> output =
+        forward_(prompts,
+                 prompts_2,
+                 control_image,
+                 height,
+                 width,
+                 generation_params.strength,
+                 generation_params.num_inference_steps,
+                 generation_params.guidance_scale,
+                 generation_params.num_images_per_prompt,
+                 seed,
+                 latents,
+                 prompt_embeds,
+                 pooled_prompt_embeds,
+                 generation_params.max_sequence_length);
+
+    DiTForwardOutput out;
+    out.tensors = torch::chunk(output[0], output[0].size(0), 0);
+    LOG(INFO) << "Output tensor chunks size: " << out.tensors.size();
+    return out;
+  }
+
+  void load_model(std::unique_ptr loader) {
+    LOG(INFO) << "FluxControlPipeline loading model from"
+              << loader->model_root_path();
+    std::string model_path = loader->model_root_path();
+    auto transformer_loader = loader->take_component_loader("transformer");
+    auto vae_loader = loader->take_component_loader("vae");
+    auto t5_loader = loader->take_component_loader("text_encoder_2");
+    auto clip_loader = loader->take_component_loader("text_encoder");
+    auto tokenizer_loader = loader->take_component_loader("tokenizer");
+    auto tokenizer_2_loader = loader->take_component_loader("tokenizer_2");
+    LOG(INFO)
+        << "FluxControl model components loaded, start to load weights to "
+           "sub models";
+    transformer_->load_model(std::move(transformer_loader));
+    transformer_->to(options_.device());
+    vae_->load_model(std::move(vae_loader));
+    vae_->to(options_.device());
+    t5_->load_model(std::move(t5_loader));
+    t5_->to(options_.device());
+    clip_text_model_->load_model(std::move(clip_loader));
+    clip_text_model_->to(options_.device());
+    tokenizer_ = tokenizer_loader->tokenizer();
+    tokenizer_2_ = tokenizer_2_loader->tokenizer();
+  }
+
+ private:
+  torch::Tensor encode_vae_image(const torch::Tensor& image, int64_t seed) {
+    torch::Tensor latents = vae_->encode(image, seed);
+    latents = (latents - vae_shift_factor_) * vae_scaling_factor_;
+    return latents;
+  }
+
+  std::pair<torch::Tensor, int64_t> get_timesteps(int64_t num_inference_steps,
+                                                  float strength) {
+    int64_t init_timestep =
+        std::min(static_cast<int64_t>(num_inference_steps * strength),
+                 num_inference_steps);
+
+    int64_t t_start = std::max(num_inference_steps - init_timestep, int64_t(0));
+    int64_t start_idx = t_start * scheduler_->order();
+    auto timesteps = scheduler_->timesteps().slice(0, start_idx).to(options_);
+    scheduler_->set_begin_index(start_idx);
+    return {timesteps, num_inference_steps - t_start};
+  }
+
+  std::pair<torch::Tensor, torch::Tensor> prepare_latents(
+      int64_t batch_size,
+      int64_t num_channels_latents,
+      int64_t height,
+      int64_t width,
+      int64_t seed,
+      std::optional<torch::Tensor> latents = std::nullopt) {
+    int64_t adjusted_height = 2 * (height / (vae_scale_factor_ * 2));
+    int64_t adjusted_width = 2 * (width / (vae_scale_factor_ * 2));
+    std::vector<int64_t> shape = {
+        batch_size, num_channels_latents, adjusted_height, adjusted_width};
+    if (latents.has_value()) {
+      torch::Tensor latent_image_ids = prepare_latent_image_ids(
+          batch_size, adjusted_height / 2, adjusted_width / 2);
+      return {latents.value(), latent_image_ids};
+    }
+    torch::Tensor latents_tensor = randn_tensor(shape, seed, options_);
+    torch::Tensor packed_latents = pack_latents(latents_tensor,
+                                                batch_size,
+                                                num_channels_latents,
+                                                adjusted_height,
+                                                adjusted_width);
+    torch::Tensor latent_image_ids = prepare_latent_image_ids(
+        batch_size, adjusted_height / 2, adjusted_width / 2);
+    return {packed_latents, latent_image_ids};
+  }
+
+  torch::Tensor prepare_image(torch::Tensor image,
+                              int64_t width,
+                              int64_t height,
+                              int64_t batch_size,
+                              int64_t num_images_per_prompt) {
+    int image_batch_size = image.size(0);
+    int repeat_times;
+    image = image_processor_->preprocess(
+        image, height, width, "default", std::nullopt);
+    if (image_batch_size == 1) {
+      repeat_times = batch_size;
+    } else {
+      repeat_times = num_images_per_prompt;
+    }
+    const auto B = image.size(0);
+    const auto C = image.size(1);
+    const auto H = image.size(2);
+    const auto W = image.size(3);
+    image = image.unsqueeze(1)
+                .repeat({1, repeat_times, 1, 1, 1})
+                .reshape({B * repeat_times, C, H, W})
+                .to(options_);
+    return image;
+  }
+
+  std::vector<torch::Tensor> forward_(
+      std::optional<std::vector<std::string>> prompt = std::nullopt,
+      std::optional<std::vector<std::string>> prompt_2 = std::nullopt,
+      torch::Tensor control_image = torch::Tensor(),
+      int64_t height = 512,
+      int64_t width = 512,
+      float strength = 1.0f,
+      int64_t num_inference_steps = 50,
+      float guidance_scale = 30.0f,
+      int64_t num_images_per_prompt = 1,
+      int64_t seed = 42,
+      std::optional<torch::Tensor> latents = std::nullopt,
+      std::optional<torch::Tensor> prompt_embeds = std::nullopt,
+      std::optional<torch::Tensor> pooled_prompt_embeds = std::nullopt,
+      int64_t max_sequence_length = 512) {
+    torch::NoGradGuard no_grad;
+    int64_t actual_height = height;
+    int64_t actual_width = width;
+    int64_t batch_size;
+    if (prompt.has_value()) {
+      batch_size = prompt.value().size();
+    } else {
+      batch_size = prompt_embeds.value().size(0);
+    }
+    int64_t total_batch_size = batch_size * num_images_per_prompt;
+    // encode prompt
+    auto [encoded_prompt_embeds, encoded_pooled_embeds, text_ids] =
+        encode_prompt(prompt,
+                      prompt_2,
+                      prompt_embeds,
+                      pooled_prompt_embeds,
+                      num_images_per_prompt,
+                      max_sequence_length);
+
+    // prepare latent
+    int64_t num_channels_latents = transformer_->in_channels() / 8;
+    // control image to latents
+    control_image = prepare_image(control_image,
+                                  width,
+                                  height,
+                                  batch_size * num_images_per_prompt,
+                                  num_images_per_prompt);
+    if (control_image.dim() == 4) {
+      auto enc = vae_->encode(control_image, seed);
+      control_image = (enc - vae_shift_factor_) * vae_scaling_factor_;
+      control_image = control_image.to(options_);
+      auto shape = control_image.sizes();
+      auto height_control_image = shape[2];
+      auto width_control_image = shape[3];
+      control_image = pack_latents(control_image,
+                                   total_batch_size,
+                                   num_channels_latents,
+                                   height_control_image,
+                                   width_control_image);
+    }
+    auto [prepared_latents, latent_image_ids] =
+        prepare_latents(total_batch_size,
+                        num_channels_latents,
+                        actual_height,
+                        actual_width,
+                        seed,
+                        latents);
+    // prepare timestep
+    std::vector<float> new_sigmas;
+    for (int64_t i = 0; i < num_inference_steps; ++i) {
+      new_sigmas.push_back(1.0f - static_cast<float>(i) /
+                                      (num_inference_steps - 1) *
+                                      (1.0f - 1.0f / num_inference_steps));
+    }
+
+    int64_t image_seq_len = prepared_latents.size(1);
+    float mu = calculate_shift(image_seq_len,
+                               scheduler_->base_image_seq_len(),
+                               scheduler_->max_image_seq_len(),
+                               scheduler_->base_shift(),
+                               scheduler_->max_shift());
+    auto [timesteps, num_inference_steps_actual] = retrieve_timesteps(
+        scheduler_, num_inference_steps, options_.device(), new_sigmas, mu);
+    int64_t num_warmup_steps =
+        std::max(static_cast<int64_t>(timesteps.numel()) -
+                     num_inference_steps_actual * scheduler_->order(),
+                 static_cast<int64_t>(0LL));
+    // prepare guidance
+    torch::Tensor guidance;
+    if (transformer_->guidance_embeds()) {
+      torch::TensorOptions options = options_.dtype(torch::kFloat32);
+      guidance = torch::full(at::IntArrayRef({1}), guidance_scale, options);
+      guidance = guidance.expand({prepared_latents.size(0)});
+    }
+    scheduler_->set_begin_index(0);
+    torch::Tensor timestep =
+        torch::empty({prepared_latents.size(0)}, prepared_latents.options());
+    // image rotary positional embeddings outplace computation
+    auto [rot_emb1, rot_emb2] =
+        pos_embed_->forward_cache(text_ids,
+                                  latent_image_ids,
+                                  height / (vae_scale_factor_ * 2),
+                                  width / (vae_scale_factor_ * 2));
+    torch::Tensor image_rotary_emb = torch::stack({rot_emb1, rot_emb2}, 0);
+    for (int64_t i = 0; i < timesteps.numel(); ++i) {
+      torch::Tensor t = timesteps[i].unsqueeze(0);
+      timestep.fill_(t.item())
+          .to(prepared_latents.dtype())
+          .div_(1000.0f);
+      int64_t step_id = i + 1;
+      auto controlled_latents =
+          torch::cat({prepared_latents, control_image}, 2);
+      torch::Tensor noise_pred = transformer_->forward(controlled_latents,
+                                                       encoded_prompt_embeds,
+                                                       encoded_pooled_embeds,
+                                                       timestep,
+                                                       image_rotary_emb,
+                                                       guidance,
+                                                       step_id);
+      auto prev_latents = scheduler_->step(noise_pred, t, prepared_latents);
+      prepared_latents = prev_latents.detach();
+      std::vector<torch::Tensor> tensors = {prepared_latents, noise_pred};
+      noise_pred.reset();
+      prev_latents = torch::Tensor();
+
+      if (latents.has_value() &&
+          prepared_latents.dtype() != latents.value().dtype()) {
+        prepared_latents = prepared_latents.to(latents.value().dtype());
+      }
+    }
+    torch::Tensor image;
+    // Unpack latents
+    torch::Tensor unpacked_latents = unpack_latents(
+        prepared_latents, actual_height, actual_width, vae_scale_factor_);
+    unpacked_latents =
+        (unpacked_latents / vae_scaling_factor_) + vae_shift_factor_;
+    unpacked_latents = unpacked_latents.to(options_.dtype());
+    image = vae_->decode(unpacked_latents);
+    image = image_processor_->postprocess(image);
+    return std::vector<torch::Tensor>{{image}};
+  }
+
+ private:
+  FlowMatchEulerDiscreteScheduler scheduler_{nullptr};
+  VAE vae_{nullptr};
+  VAEImageProcessor image_processor_{nullptr};
+  FluxDiTModel transformer_{nullptr};
+  float vae_scaling_factor_;
+  float vae_shift_factor_;
+  int64_t vae_latent_channels_;
+  int64_t latent_channels_;
+  FluxPosEmbed pos_embed_{nullptr};
+};
+TORCH_MODULE(FluxControlPipeline);
+
+REGISTER_DIT_MODEL(flux_control, FluxControlPipeline);
+}  // namespace xllm
diff --git a/xllm/models/dit/pipeline_flux_fill.h b/xllm/models/dit/pipeline_flux_fill.h
index 73e1579fc..cfd6d51d6 100644
--- a/xllm/models/dit/pipeline_flux_fill.h
+++ b/xllm/models/dit/pipeline_flux_fill.h
@@ -29,49 +29,45 @@ class FluxFillPipelineImpl : public FluxPipelineBaseImpl {
   FluxFillPipelineImpl(const DiTModelContext& context) {
     auto model_args = context.get_model_args("vae");
     options_ = context.get_tensor_options();
-    device_ = options_.device();
-    dtype_ = options_.dtype().toScalarType();
     vae_scale_factor_ = 1 << (model_args.block_out_channels().size() - 1);
     vae_shift_factor_ = model_args.shift_factor();
     vae_scaling_factor_ = model_args.scale_factor();
     latent_channels_ = model_args.latent_channels();
-
-    default_sample_size_ = 128;
-    tokenizer_max_length_ = 77;  // TODO: get from config file
+    tokenizer_max_length_ =
+        context.get_model_args("text_encoder").max_position_embeddings();
     LOG(INFO) << "Initializing FluxFill pipeline...";
-    image_processor_ = VAEImageProcessor(
-        context.get_model_context("vae"), true, true, false, false, false);
-    mask_processor_ = VAEImageProcessor(
-        context.get_model_context("vae"), true, false, true, false, true);
+    image_processor_ = VAEImageProcessor(context.get_model_context("vae"),
+                                         true,
+                                         true,
+                                         false,
+                                         false,
+                                         false,
+                                         latent_channels_);
+    mask_processor_ = VAEImageProcessor(context.get_model_context("vae"),
+                                        true,
+                                        false,
+                                        true,
+                                        false,
+                                        true,
+                                        latent_channels_);
     vae_ = VAE(context.get_model_context("vae"));
     LOG(INFO) << "VAE initialized.";
     pos_embed_ = register_module(
         "pos_embed",
-        FluxPosEmbed(10000,
+        FluxPosEmbed(ROPE_SCALE_BASE,
                      context.get_model_args("transformer").axes_dims_rope()));
     transformer_ = FluxDiTModel(context.get_model_context("transformer"));
-    LOG(INFO) << "DiT transformer initialized.";
     t5_ = T5EncoderModel(context.get_model_context("text_encoder_2"));
-    LOG(INFO) << "T5 initialized.";
     clip_text_model_ = CLIPTextModel(context.get_model_context("text_encoder"));
-    LOG(INFO) << "CLIP text model initialized.";
     scheduler_ =
         FlowMatchEulerDiscreteScheduler(context.get_model_context("scheduler"));
-    LOG(INFO) << "FluxFill pipeline initialized.";
     register_module("vae", vae_);
-    LOG(INFO) << "VAE registered.";
     register_module("vae_image_processor", image_processor_);
-    LOG(INFO) << "VAE image processor registered.";
     register_module("mask_processor", mask_processor_);
-    LOG(INFO) << "mask processor registered.";
     register_module("transformer", transformer_);
-    LOG(INFO) << "DiT transformer registered.";
     register_module("t5", t5_);
-    LOG(INFO) << "T5 registered.";
     register_module("scheduler", scheduler_);
-    LOG(INFO) << "Scheduler registered.";
     register_module("clip_text_model", clip_text_model_);
-    LOG(INFO) << "CLIP text model registered.";
   }
 
   DiTForwardOutput forward(const DiTForwardInput& input) {
@@ -141,13 +137,13 @@ class FluxFillPipelineImpl : public FluxPipelineBaseImpl {
     LOG(INFO) << "FluxFill model components loaded, start to load weights to "
                  "sub models";
     transformer_->load_model(std::move(transformer_loader));
-    transformer_->to(device_);
+    transformer_->to(options_.device());
     vae_->load_model(std::move(vae_loader));
-    vae_->to(device_);
+    vae_->to(options_.device());
     t5_->load_model(std::move(t5_loader));
-    t5_->to(device_);
+    t5_->to(options_.device());
     clip_text_model_->load_model(std::move(clip_loader));
-    clip_text_model_->to(device_);
+    clip_text_model_->to(options_.device());
     tokenizer_ = tokenizer_loader->tokenizer();
     tokenizer_2_ = tokenizer_2_loader->tokenizer();
   }
@@ -174,7 +170,7 @@ class FluxFillPipelineImpl : public FluxPipelineBaseImpl {
 
     masked_image_latents =
         (masked_image_latents - vae_shift_factor_) * vae_scaling_factor_;
-    masked_image_latents = masked_image_latents.to(device_).to(dtype_);
+    masked_image_latents = masked_image_latents.to(options_);
 
     batch_size = batch_size * num_images_per_prompt;
     if (mask.size(0) < batch_size) {
@@ -203,7 +199,7 @@ class FluxFillPipelineImpl : public FluxPipelineBaseImpl {
         {batch_size, vae_scale_factor_ * vae_scale_factor_, height, width});
     mask = pack_latents(
         mask, batch_size, vae_scale_factor_ * vae_scale_factor_, height, width);
-    mask = mask.to(device_).to(dtype_);
+    mask = mask.to(options_);
 
     return {mask, masked_image_latents};
   }
@@ -222,8 +218,7 @@ class FluxFillPipelineImpl : public FluxPipelineBaseImpl {
     int64_t t_start = std::max(num_inference_steps - init_timestep, int64_t(0));
     int64_t start_idx = t_start * scheduler_->order();
 
-    auto timesteps =
-        scheduler_->timesteps().slice(0, start_idx).to(device_).to(dtype_);
+    auto timesteps = scheduler_->timesteps().slice(0, start_idx).to(options_);
     scheduler_->set_begin_index(start_idx);
     return {timesteps, num_inference_steps - t_start};
   }
@@ -245,7 +240,7 @@ class FluxFillPipelineImpl : public FluxPipelineBaseImpl {
     torch::Tensor latent_image_ids =
         prepare_latent_image_ids(batch_size, height / 2, width / 2);
     if (latents.has_value()) {
-      return {latents.value().to(device_).to(dtype_), latent_image_ids};
+      return {latents.value().to(options_), latent_image_ids};
     }
 
     torch::Tensor image_latents;
@@ -325,7 +320,8 @@ class FluxFillPipelineImpl : public FluxPipelineBaseImpl {
                                scheduler_->base_shift(),
                                scheduler_->max_shift());
 
-    retrieve_timesteps(scheduler_, num_inference_steps, device_, sigmas, mu);
+    retrieve_timesteps(
+        scheduler_, num_inference_steps, options_.device(), sigmas, mu);
     torch::Tensor timesteps;
     std::tie(timesteps, num_inference_steps) =
         get_timesteps(num_inference_steps, strength);
@@ -348,8 +344,7 @@ class FluxFillPipelineImpl : public FluxPipelineBaseImpl {
                         latents);
 
     if (masked_image_latents.has_value()) {
-      masked_image_latents =
-          masked_image_latents.value().to(device_).to(dtype_);
+      masked_image_latents = masked_image_latents.value().to(options_);
     } else {
       mask_image =
          mask_processor_->preprocess(mask_image.value(), height, width);
@@ -385,11 +380,12 @@ class FluxFillPipelineImpl : public FluxPipelineBaseImpl {
                                   width / (vae_scale_factor_ * 2));
 
     torch::Tensor image_rotary_emb =
-        torch::stack({rot_emb1, rot_emb2}, 0).to(device_);
+        torch::stack({rot_emb1, rot_emb2}, 0).to(options_.device());
 
     for (int64_t i = 0; i < timesteps.size(0); ++i) {
       torch::Tensor t = timesteps[i];
-      torch::Tensor timestep = t.expand({latents->size(0)}).to(device_);
+      torch::Tensor timestep =
+          t.expand({latents->size(0)}).to(options_.device());
       int64_t step_id = i + 1;
 
       torch::Tensor input_latents =
@@ -404,7 +400,7 @@ class FluxFillPipelineImpl : public FluxPipelineBaseImpl {
                                                        guidance,
                                                        step_id);
       auto prev_latents = scheduler_->step(noise_pred, t, latents.value());
-      latents = prev_latents.detach().to(device_);
+      latents = prev_latents.detach().to(options_.device());
     }
 
     torch::Tensor output_image;
@@ -414,7 +410,7 @@ class FluxFillPipelineImpl : public FluxPipelineBaseImpl {
         (unpacked_latents / vae_scaling_factor_) + vae_shift_factor_;
 
     output_image = vae_->decode(unpacked_latents);
-    output_image = image_processor_->postprocess(output_image, "pil");
+    output_image = image_processor_->postprocess(output_image);
 
     return std::vector<torch::Tensor>{{output_image}};
   }
@@ -426,7 +422,6 @@ class FluxFillPipelineImpl : public FluxPipelineBaseImpl {
   FluxDiTModel transformer_{nullptr};
   float vae_scaling_factor_;
   float vae_shift_factor_;
-  int default_sample_size_;
   int64_t latent_channels_;
   FluxPosEmbed pos_embed_{nullptr};
 };
diff --git a/xllm/models/models.h b/xllm/models/models.h
old mode 100755
new mode 100644
index 12e8e2d5d..161a4c318
--- a/xllm/models/models.h
+++ b/xllm/models/models.h
@@ -16,23 +16,24 @@ limitations under the License.
 #pragma once
 
 #if defined(USE_NPU)
-#include "dit/pipeline_flux.h"       // IWYU pragma: keep
-#include "dit/pipeline_flux_fill.h"  // IWYU pragma: keep
-#include "llm/deepseek_v2.h"         // IWYU pragma: keep
-#include "llm/deepseek_v2_mtp.h"     // IWYU pragma: keep
-#include "llm/deepseek_v3.h"         // IWYU pragma: keep
-#include "llm/glm4_moe.h"            // IWYU pragma: keep
-#include "llm/glm4_moe_mtp.h"        // IWYU pragma: keep
-#include "llm/kimi_k2.h"             // IWYU pragma: keep
-#include "llm/llama.h"               // IWYU pragma: keep
-#include "llm/llama3.h"              // IWYU pragma: keep
-#include "llm/llm_model_base.h"      // IWYU pragma: keep
-#include "llm/qwen2.h"               // IWYU pragma: keep
-#include "llm/qwen3_embedding.h"     // IWYU pragma: keep
-#include "vlm/minicpmv.h"            // IWYU pragma: keep
-#include "vlm/qwen2_5_vl.h"          // IWYU pragma: keep
-#include "vlm/qwen3_vl.h"            // IWYU pragma: keep
-#include "vlm/qwen3_vl_moe.h"        // IWYU pragma: keep
+#include "dit/pipeline_flux.h"          // IWYU pragma: keep
+#include "dit/pipeline_flux_control.h"  // IWYU pragma: keep
+#include "dit/pipeline_flux_fill.h"     // IWYU pragma: keep
+#include "llm/deepseek_v2.h"            // IWYU pragma: keep
+#include "llm/deepseek_v2_mtp.h"        // IWYU pragma: keep
+#include "llm/deepseek_v3.h"            // IWYU pragma: keep
+#include "llm/glm4_moe.h"               // IWYU pragma: keep
+#include "llm/glm4_moe_mtp.h"           // IWYU pragma: keep
+#include "llm/kimi_k2.h"                // IWYU pragma: keep
+#include "llm/llama.h"                  // IWYU pragma: keep
+#include "llm/llama3.h"                 // IWYU pragma: keep
+#include "llm/llm_model_base.h"         // IWYU pragma: keep
+#include "llm/qwen2.h"                  // IWYU pragma: keep
+#include "llm/qwen3_embedding.h"        // IWYU pragma: keep
+#include "vlm/minicpmv.h"               // IWYU pragma: keep
+#include "vlm/qwen2_5_vl.h"             // IWYU pragma: keep
+#include "vlm/qwen3_vl.h"               // IWYU pragma: keep
+#include "vlm/qwen3_vl_moe.h"           // IWYU pragma: keep
 #endif
 
 #include "llm/llm_model_base.h"  // IWYU pragma: keep
diff --git a/xllm/proto/image_generation.proto b/xllm/proto/image_generation.proto
index f16b9f178..cf8eb69bb 100644
--- a/xllm/proto/image_generation.proto
+++ b/xllm/proto/image_generation.proto
@@ -43,6 +43,9 @@ message Input {
 
   // An image batch of mask images generated by the VAE
   optional Tensor masked_image_latent = 12;
+
+  // Control Image
+  optional string control_image = 13;
 }
 
 // Generation parameters container
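
Usage note (illustrative, not part of the patch): the new `control_image` field in `message Input` carries a base64-encoded image, which the server decodes with butil::Base64Decode before handing it to the image decoder. A minimal client-side sketch follows, assuming the generated proto namespace `proto`, an `input` field on `ImageGenerationRequest` reachable via `mutable_input()`, and a local PNG path; those names are assumptions for illustration only, not taken from this diff.

// Hypothetical client-side sketch: attach a control image to an image
// generation request. Only set_control_image() corresponds to the field added
// by this patch; the surrounding accessors and file path are assumed.
#include <fstream>
#include <sstream>
#include <string>

#include "butil/base64.h"         // brpc base64 helpers (Base64Encode)
#include "image_generation.pb.h"  // generated from xllm/proto/image_generation.proto

proto::ImageGenerationRequest build_control_request(const std::string& png_path) {
  // Read the raw PNG bytes from disk.
  std::ifstream file(png_path, std::ios::binary);
  std::stringstream buffer;
  buffer << file.rdbuf();

  // Base64-encode the bytes; the server side decodes with butil::Base64Decode.
  std::string encoded;
  butil::Base64Encode(buffer.str(), &encoded);

  proto::ImageGenerationRequest request;
  request.mutable_input()->set_control_image(encoded);  // field 13 in message Input
  return request;
}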