7 changes: 6 additions & 1 deletion xllm/core/framework/batch/dit_batch.cpp
@@ -62,7 +62,7 @@ DiTForwardInput DiTBatch::prepare_forward_input() {

std::vector<torch::Tensor> images;
std::vector<torch::Tensor> mask_images;

std::vector<torch::Tensor> control_images;
std::vector<torch::Tensor> latents;
std::vector<torch::Tensor> masked_image_latents;
for (const auto& request : request_vec_) {
@@ -96,6 +96,7 @@ DiTForwardInput DiTBatch::prepare_forward_input() {

images.emplace_back(input_params.image);
mask_images.emplace_back(input_params.mask_image);
control_images.emplace_back(input_params.control_image);
}

if (input.prompts.size() != request_vec_.size()) {
@@ -122,6 +123,10 @@ DiTForwardInput DiTBatch::prepare_forward_input() {
input.mask_images = torch::stack(mask_images);
}

if (check_tensors_valid(control_images)) {
input.control_image = torch::stack(control_images);
}

if (check_tensors_valid(prompt_embeds)) {
input.prompt_embeds = torch::stack(prompt_embeds);
}
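
For context on the new control_images branch, a minimal standalone sketch of the collect-then-stack pattern used here. The all_tensors_valid helper below is a hypothetical stand-in for the repo's check_tensors_valid (not shown in this diff); shapes are example values.

#include <torch/torch.h>
#include <vector>

// Hypothetical stand-in for check_tensors_valid(): stacking is only safe
// when every request actually supplied a defined, non-empty tensor.
static bool all_tensors_valid(const std::vector<torch::Tensor>& tensors) {
  if (tensors.empty()) return false;
  for (const auto& t : tensors) {
    if (!t.defined() || t.numel() == 0) return false;
  }
  return true;
}

int main() {
  std::vector<torch::Tensor> control_images;
  control_images.emplace_back(torch::rand({3, 512, 512}));
  control_images.emplace_back(torch::rand({3, 512, 512}));

  torch::Tensor batched;
  if (all_tensors_valid(control_images)) {
    // Adds one leading batch dimension: [N, 3, 512, 512],
    // analogous to input.control_image in prepare_forward_input().
    batched = torch::stack(control_images);
  }
  return 0;
}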
10 changes: 10 additions & 0 deletions xllm/core/framework/request/dit_request_params.cpp
@@ -270,6 +270,16 @@ DiTRequestParams::DiTRequestParams(const proto::ImageGenerationRequest& request,
}
}

if (input.has_control_image()) {
std::string raw_bytes;
if (!butil::Base64Decode(input.control_image(), &raw_bytes)) {
LOG(ERROR) << "Base64 control_image decode failed";
}
if (!decoder.decode(raw_bytes, input_params.control_image)) {
LOG(ERROR) << "Control_image decode failed.";
}
}

// generation params
const auto& params = request.parameters();
if (params.has_size()) {
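
For reference, a minimal round-trip sketch of the base64 handling mirrored by the decode above. The file path and the client-side use of butil::Base64Encode are illustrative assumptions for how a control_image payload would be prepared; they are not code from this PR.

#include <butil/base64.h>
#include <fstream>
#include <sstream>
#include <string>

int main() {
  // Hypothetical: read raw PNG/JPEG bytes from disk on the client side.
  std::ifstream file("control.png", std::ios::binary);
  std::stringstream buffer;
  buffer << file.rdbuf();
  const std::string raw_bytes = buffer.str();

  // Encode for the request's control_image field ...
  std::string encoded;
  butil::Base64Encode(raw_bytes, &encoded);

  // ... and the server-side inverse, as in dit_request_params.cpp.
  std::string decoded;
  if (!butil::Base64Decode(encoded, &decoded)) {
    return 1;  // decode failed
  }
  return decoded == raw_bytes ? 0 : 1;
}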
2 changes: 2 additions & 0 deletions xllm/core/framework/request/dit_request_state.h
@@ -92,6 +92,8 @@ struct DiTInputParams {

torch::Tensor image;

torch::Tensor control_image;

torch::Tensor mask_image;

torch::Tensor masked_image_latent;
2 changes: 2 additions & 0 deletions xllm/core/runtime/dit_forward_params.h
@@ -84,6 +84,8 @@ struct DiTForwardInput {

torch::Tensor mask_images;

torch::Tensor control_image;

torch::Tensor masked_image_latents;

torch::Tensor prompt_embeds;
14 changes: 6 additions & 8 deletions xllm/models/dit/autoencoder_kl.h
@@ -62,10 +62,12 @@ class VAEImageProcessorImpl : public torch::nn::Module {
bool do_normalize = true,
bool do_binarize = false,
bool do_convert_rgb = false,
bool do_convert_grayscale = false) {
bool do_convert_grayscale = false,
int64_t latent_channels = 4) {
const auto& model_args = context.get_model_args();
options_ = context.get_tensor_options();
scale_factor_ = 1 << model_args.block_out_channels().size();
latent_channels_ = 4;
latent_channels_ = latent_channels;
do_resize_ = do_resize;
do_normalize_ = do_normalize;
do_binarize_ = do_binarize;
@@ -116,7 +118,6 @@ class VAEImageProcessorImpl : public torch::nn::Module {
if (channel == latent_channels_) {
return image;
}

auto [target_h, target_w] =
get_default_height_width(processed, height, width);
if (do_resize_) {
@@ -129,13 +130,12 @@ class VAEImageProcessorImpl : public torch::nn::Module {
if (do_binarize_) {
processed = (processed >= 0.5f).to(torch::kFloat32);
}
processed = processed.to(image.dtype());
processed = processed.to(options_);
return processed;
}

torch::Tensor postprocess(
const torch::Tensor& tensor,
const std::string& output_type = "pt",
std::optional<std::vector<bool>> do_denormalize = std::nullopt) {
torch::Tensor processed = tensor.clone();
if (do_normalize_) {
@@ -149,9 +149,6 @@ }
}
}
}
if (output_type == "np") {
return processed.permute({0, 2, 3, 1}).contiguous();
}
return processed;
}

@@ -202,6 +199,7 @@ class VAEImageProcessorImpl : public torch::nn::Module {
bool do_binarize_ = false;
bool do_convert_rgb_ = false;
bool do_convert_grayscale_ = false;
torch::TensorOptions options_;
};
TORCH_MODULE(VAEImageProcessor);

2 changes: 1 addition & 1 deletion xllm/models/dit/dit.h
@@ -1436,7 +1436,7 @@ class FluxTransformer2DModelImpl : public torch::nn::Module {
proj_out_->verify_loaded_weights(prefix + "proj_out.");
}

int64_t in_channels() { return out_channels_; }
int64_t in_channels() { return in_channels_; }
bool guidance_embeds() { return guidance_embeds_; }

private:
97 changes: 41 additions & 56 deletions xllm/models/dit/pipeline_flux.h
@@ -30,43 +30,35 @@ class FluxPipelineImpl : public FluxPipelineBaseImpl {
const auto& model_args = context.get_model_args("vae");
options_ = context.get_tensor_options();
vae_scale_factor_ = 1 << (model_args.block_out_channels().size() - 1);
device_ = options_.device();
dtype_ = options_.dtype().toScalarType();

vae_shift_factor_ = model_args.shift_factor();
vae_scaling_factor_ = model_args.scale_factor();
default_sample_size_ = 128;
tokenizer_max_length_ = 77; // TODO: get from config file
tokenizer_max_length_ =
context.get_model_args("text_encoder").max_position_embeddings();
LOG(INFO) << "Initializing Flux pipeline...";
vae_image_processor_ = VAEImageProcessor(
context.get_model_context("vae"), true, true, false, false, false);
vae_image_processor_ = VAEImageProcessor(context.get_model_context("vae"),
true,
true,
false,
false,
false,
model_args.latent_channels());
vae_ = VAE(context.get_model_context("vae"));
LOG(INFO) << "VAE initialized.";
pos_embed_ = register_module(
"pos_embed",
FluxPosEmbed(10000,
FluxPosEmbed(ROPE_SCALE_BASE,
context.get_model_args("transformer").axes_dims_rope()));
transformer_ = FluxDiTModel(context.get_model_context("transformer"));
LOG(INFO) << "DiT transformer initialized.";
t5_ = T5EncoderModel(context.get_model_context("text_encoder_2"));
LOG(INFO) << "T5 initialized.";
clip_text_model_ = CLIPTextModel(context.get_model_context("text_encoder"));
LOG(INFO) << "CLIP text model initialized.";
scheduler_ =
FlowMatchEulerDiscreteScheduler(context.get_model_context("scheduler"));
LOG(INFO) << "Flux pipeline initialized.";
register_module("vae", vae_);
LOG(INFO) << "VAE registered.";
register_module("vae_image_processor", vae_image_processor_);
LOG(INFO) << "VAE image processor registered.";
register_module("transformer", transformer_);
LOG(INFO) << "DiT transformer registered.";
register_module("t5", t5_);
LOG(INFO) << "T5 registered.";
register_module("scheduler", scheduler_);
LOG(INFO) << "Scheduler registered.";
register_module("clip_text_model", clip_text_model_);
LOG(INFO) << "CLIP text model registered.";
}

DiTForwardOutput forward(const DiTForwardInput& input) {
@@ -104,21 +96,21 @@ class FluxPipelineImpl : public FluxPipelineBaseImpl {
: std::nullopt;

std::vector<torch::Tensor> output = forward_(
prompts, // prompt
prompts_2, // prompt_2
negative_prompts, // negative_prompt
negative_prompts_2, // negative_prompt_2
generation_params.true_cfg_scale, // cfg scale
std::make_optional(generation_params.height), // height
std::make_optional(generation_params.width), // width
generation_params.num_inference_steps, // num_inference_steps
generation_params.guidance_scale, // guidance_scale
generation_params.num_images_per_prompt, // num_images_per_prompt
seed, // seed
latents, // latents
prompt_embeds, // prompt_embeds
negative_prompt_embeds, // negative_prompt_embeds
pooled_prompt_embeds, // pooled_prompt_embeds
prompts, // prompt
prompts_2, // prompt_2
negative_prompts, // negative_prompt
negative_prompts_2, // negative_prompt_2
generation_params.true_cfg_scale, // cfg scale
generation_params.height, // height
generation_params.width, // width
generation_params.num_inference_steps, // num_inference_steps
generation_params.guidance_scale, // guidance_scale
generation_params.num_images_per_prompt, // num_images_per_prompt
seed, // seed
latents, // latents
prompt_embeds, // prompt_embeds
negative_prompt_embeds, // negative_prompt_embeds
pooled_prompt_embeds, // pooled_prompt_embeds
negative_pooled_prompt_embeds, // negative_pooled_prompt_embeds
generation_params.max_sequence_length // max_sequence_length
);
Expand All @@ -141,13 +133,13 @@ class FluxPipelineImpl : public FluxPipelineBaseImpl {
LOG(INFO)
<< "Flux model components loaded, start to load weights to sub models";
transformer_->load_model(std::move(transformer_loader));
transformer_->to(device_);
transformer_->to(options_.device());
vae_->load_model(std::move(vae_loader));
vae_->to(device_);
vae_->to(options_.device());
t5_->load_model(std::move(t5_loader));
t5_->to(device_);
t5_->to(options_.device());
clip_text_model_->load_model(std::move(clip_loader));
clip_text_model_->to(device_);
clip_text_model_->to(options_.device());
tokenizer_ = tokenizer_loader->tokenizer();
tokenizer_2_ = tokenizer_2_loader->tokenizer();
}
@@ -186,8 +178,8 @@ class FluxPipelineImpl : public FluxPipelineBaseImpl {
std::optional<std::vector<std::string>> negative_prompt = std::nullopt,
std::optional<std::vector<std::string>> negative_prompt_2 = std::nullopt,
float true_cfg_scale = 1.0f,
std::optional<int64_t> height = std::nullopt,
std::optional<int64_t> width = std::nullopt,
int64_t height = 512,
int64_t width = 512,
int64_t num_inference_steps = 28,
float guidance_scale = 3.5f,
int64_t num_images_per_prompt = 1,
Expand All @@ -199,12 +191,6 @@ class FluxPipelineImpl : public FluxPipelineBaseImpl {
std::optional<torch::Tensor> negative_pooled_prompt_embeds = std::nullopt,
int64_t max_sequence_length = 512) {
torch::NoGradGuard no_grad;
int64_t actual_height = height.has_value()
? height.value()
: default_sample_size_ * vae_scale_factor_;
int64_t actual_width = width.has_value()
? width.value()
: default_sample_size_ * vae_scale_factor_;
int64_t batch_size;
if (prompt.has_value()) {
batch_size = prompt.value().size();
@@ -244,8 +230,8 @@ class FluxPipelineImpl : public FluxPipelineBaseImpl {
auto [prepared_latents, latent_image_ids] =
prepare_latents(total_batch_size,
num_channels_latents,
actual_height,
actual_width,
height,
width,
seed.has_value() ? seed.value() : 42,
latents);
// prepare timestep
@@ -263,7 +249,7 @@ class FluxPipelineImpl : public FluxPipelineBaseImpl {
scheduler_->base_shift(),
scheduler_->max_shift());
auto [timesteps, num_inference_steps_actual] = retrieve_timesteps(
scheduler_, num_inference_steps, device_, new_sigmas, mu);
scheduler_, num_inference_steps, options_.device(), new_sigmas, mu);
int64_t num_warmup_steps =
std::max(static_cast<int64_t>(timesteps.numel()) -
num_inference_steps_actual * scheduler_->order(),
@@ -272,7 +258,7 @@ class FluxPipelineImpl : public FluxPipelineBaseImpl {
torch::Tensor guidance;
if (transformer_->guidance_embeds()) {
torch::TensorOptions options =
torch::dtype(torch::kFloat32).device(device_);
torch::dtype(torch::kFloat32).device(options_.device());

guidance = torch::full(at::IntArrayRef({1}), guidance_scale, options);
guidance = guidance.expand({prepared_latents.size(0)});
@@ -284,8 +270,8 @@ class FluxPipelineImpl : public FluxPipelineBaseImpl {
auto [rot_emb1, rot_emb2] =
pos_embed_->forward_cache(text_ids,
latent_image_ids,
height.value() / (vae_scale_factor_ * 2),
width.value() / (vae_scale_factor_ * 2));
height / (vae_scale_factor_ * 2),
width / (vae_scale_factor_ * 2));
torch::Tensor image_rotary_emb = torch::stack({rot_emb1, rot_emb2}, 0);
for (int64_t i = 0; i < timesteps.numel(); ++i) {
torch::Tensor t = timesteps[i].unsqueeze(0);
@@ -326,13 +312,13 @@ class FluxPipelineImpl : public FluxPipelineBaseImpl {
}
torch::Tensor image;
// Unpack latents
torch::Tensor unpacked_latents = unpack_latents(
prepared_latents, actual_height, actual_width, vae_scale_factor_);
torch::Tensor unpacked_latents =
unpack_latents(prepared_latents, height, width, vae_scale_factor_);
unpacked_latents =
(unpacked_latents / vae_scaling_factor_) + vae_shift_factor_;
unpacked_latents = unpacked_latents.to(dtype_);
unpacked_latents = unpacked_latents.to(options_.dtype());
image = vae_->decode(unpacked_latents);
image = vae_image_processor_->postprocess(image, "pil");
image = vae_image_processor_->postprocess(image);
return std::vector<torch::Tensor>{{image}};
}

Expand All @@ -343,7 +329,6 @@ class FluxPipelineImpl : public FluxPipelineBaseImpl {
FluxDiTModel transformer_{nullptr};
float vae_scaling_factor_;
float vae_shift_factor_;
int default_sample_size_;
FluxPosEmbed pos_embed_{nullptr};
};
TORCH_MODULE(FluxPipeline);
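
As background on the height / (vae_scale_factor_ * 2) grid passed to pos_embed_->forward_cache above, a small arithmetic sketch of how pixel dimensions map to the rotary-embedding grid. The block count of 4 and the 512x512 output size are assumed example values (Flux-style VAE config), not values read from this repo.

#include <cstdint>
#include <iostream>

int main() {
  // Assumed Flux-style VAE config with 4 entries in block_out_channels.
  const int64_t num_blocks = 4;
  const int64_t vae_scale_factor = 1LL << (num_blocks - 1);  // 8

  const int64_t height = 512;  // pixel-space output size
  const int64_t width = 512;

  // The transformer works on packed 2x2 latent patches, so the grid handed
  // to the positional embedding is pixels / (vae_scale_factor * 2).
  const int64_t grid_h = height / (vae_scale_factor * 2);  // 32
  const int64_t grid_w = width / (vae_scale_factor * 2);   // 32

  std::cout << "rotary-embedding grid: " << grid_h << " x " << grid_w << "\n";
  return 0;
}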
14 changes: 8 additions & 6 deletions xllm/models/dit/pipeline_flux_base.h
@@ -36,6 +36,8 @@ limitations under the License.

namespace xllm {

constexpr int64_t ROPE_SCALE_BASE = 10000;

float calculate_shift(int64_t image_seq_len,
int64_t base_seq_len = 256,
int64_t max_seq_len = 4096,
@@ -213,9 +215,9 @@ class FluxPipelineBaseImpl : public torch::nn::Module {
auto input_ids =
torch::tensor(text_input_ids_flat, torch::dtype(torch::kLong))
.view({batch_size, max_sequence_length})
.to(device_);
.to(options_.device());
torch::Tensor prompt_embeds = t5_->forward(input_ids);
prompt_embeds = prompt_embeds.to(device_).to(dtype_);
prompt_embeds = prompt_embeds.to(options_);
int64_t seq_len = prompt_embeds.size(1);
prompt_embeds = prompt_embeds.repeat({1, num_images_per_prompt, 1});
prompt_embeds =
@@ -244,10 +246,10 @@ class FluxPipelineBaseImpl : public torch::nn::Module {
auto input_ids =
torch::tensor(text_input_ids_flat, torch::dtype(torch::kLong))
.view({batch_size, tokenizer_max_length_})
.to(device_);
.to(options_.device());
auto encoder_output = clip_text_model_->forward(input_ids);
torch::Tensor prompt_embeds = encoder_output;
prompt_embeds = prompt_embeds.to(device_).to(dtype_);
prompt_embeds = prompt_embeds.to(options_);
prompt_embeds = prompt_embeds.repeat({1, num_images_per_prompt});
prompt_embeds =
prompt_embeds.view({batch_size * num_images_per_prompt, -1});
@@ -281,8 +283,8 @@ class FluxPipelineBaseImpl : public torch::nn::Module {
prompt_embeds = get_t5_prompt_embeds(
prompt_2_list, num_images_per_prompt, max_sequence_length);
}
torch::Tensor text_ids = torch::zeros({prompt_embeds.value().size(1), 3},
torch::device(device_).dtype(dtype_));
torch::Tensor text_ids =
torch::zeros({prompt_embeds.value().size(1), 3}, options_);

return std::make_tuple(prompt_embeds.value(),
pooled_prompt_embeds.has_value()
Expand Down
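
A closing illustration of the .to(device_).to(dtype_) to .to(options_) consolidation used throughout these pipeline files: a minimal sketch (the half-precision CPU options are assumed example values) showing that Tensor::to(TensorOptions) performs the device move and the dtype cast in a single call.

#include <torch/torch.h>

int main() {
  // Assumed options; in the pipeline they come from context.get_tensor_options().
  const torch::TensorOptions options =
      torch::TensorOptions().dtype(torch::kHalf).device(torch::kCPU);

  torch::Tensor prompt_embeds = torch::rand({1, 77, 4096});  // fp32 on CPU

  // Previously expressed as chained .to(device_).to(dtype_);
  // one .to(TensorOptions) call covers both conversions.
  prompt_embeds = prompt_embeds.to(options);

  TORCH_CHECK(prompt_embeds.scalar_type() == torch::kHalf);
  return 0;
}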