Running PixArt-Σ/Flux.1 Image Generation on Lower VRAM GPUs: A Short Tutorial in Python

Diffusers and Quanto giving hope to the GPU-challenged

(Image: generated locally by PixArt-Σ with less than 8 GB of VRAM)

Image generation tools are hotter than ever, and they've never been more powerful. Models like PixArt Sigma and Flux.1 are leading the charge, thanks to their open-weight models and permissive licenses. This setup allows for creative tinkering, including training LoRAs without sharing data outside your computer.

However, working with these models can be challenging if you're using an older or less VRAM-rich GPU. Typically, there is a trade-off between quality, speed, and VRAM usage. In this blog post, we'll focus on optimizing for speed and lower VRAM usage while maintaining as much quality as possible. This approach works exceptionally well for PixArt due to its smaller size, but results may vary with Flux.1. I'll share some alternative solutions for Flux.1 at the end of this post.

Both PixArt Sigma and Flux.1 are transformer-based, which means they benefit from the same quantization techniques used by large language models (LLMs). Quantization compresses the model's components so they use less memory, which lets you keep all of them in GPU VRAM simultaneously. This leads to faster generation than methods that shuffle weights between the GPU and CPU, which slow things down.

Let's dive into the setup!

Setting Up Your Local Environment

First, ensure you have the Nvidia drivers and Anaconda installed.

Next, create a Python environment and install the main requirements:

conda install pytorch torchvision torchaudio pytorch-cuda=12.1 -c pytorch -c nvidia

Then the Diffusers and Quanto libraries:

pip install pillow==10.3.0 loguru~=0.7.2 optimum-quanto==0.2.4 diffusers==0.30.0 transformers==4.44.2 accelerate==0.33.0 sentencepiece==0.2.0

Quantization Code

Here's a simple script to get you started with PixArt-Sigma:

```python
from optimum.quanto import qint8, qint4, quantize, freeze
from diffusers import PixArtSigmaPipeline
import torch

# Load the pipeline in half precision; the weights stay on the CPU for now
pipeline = PixArtSigmaPipeline.from_pretrained(
    "PixArt-alpha/PixArt-Sigma-XL-2-1024-MS", torch_dtype=torch.float16
)

# Quantize the vision transformer to 8-bit weights
quantize(pipeline.transformer, weights=qint8)
freeze(pipeline.transformer)

# Quantize the (large) text encoder more aggressively, to 4-bit weights
quantize(pipeline.text_encoder, weights=qint4, exclude="proj_out")
freeze(pipeline.text_encoder)

# Move the quantized pipeline to the GPU
pipe = pipeline.to("cuda")

for i in range(2):
    generator = torch.Generator(device="cpu").manual_seed(i)
    prompt = "Cyberpunk cityscape, small black crow, neon lights, dark alleys, skyscrapers, futuristic, vibrant colors, high contrast, highly detailed"
    image = pipe(prompt, height=512, width=768, guidance_scale=3.5, generator=generator).images[0]
    image.save(f"Sigma_{i}.png")
```

Understanding the Script

Here are the major steps of the implementation:

1. Import necessary libraries: we import libraries for quantization, model loading, and GPU handling.
2. Load the model: we load the PixArt Sigma pipeline in half precision (float16), on the CPU first.
3. Quantize the model: we quantize the transformer and text encoder components of the pipeline, each at a different level. The text encoder is quantized to qint4 because it is quite large. With the vision transformer quantized to qint8, the full pipeline uses about 7.5 GB of VRAM; left unquantized, it uses around 8.5 GB (a short sketch for measuring this yourself follows this list).
4. Move to GPU: we move the pipeline to the GPU with .to("cuda") for faster processing.
5. Generate images: we use the pipeline to generate images from a given prompt and save the output.
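
If you want to verify the VRAM figures from step 3 on your own hardware, PyTorch's CUDA memory statistics make it easy. The snippet below is a minimal sketch, assuming `pipe` is the quantized pipeline from the script above; it resets the peak-memory counter and reports the peak allocation after a single generation:

```python
import torch

# Minimal sketch: measure the peak VRAM used by one generation.
# Assumes `pipe` is the quantized PixArt pipeline already moved to "cuda".
torch.cuda.reset_peak_memory_stats()

prompt = "Cyberpunk cityscape, small black crow, neon lights"
image = pipe(prompt, height=512, width=768, guidance_scale=3.5).images[0]

peak_gb = torch.cuda.max_memory_allocated() / 1024**3
print(f"Peak VRAM allocated: {peak_gb:.2f} GB")
```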
Running the Script

Save the script and run it in your environment. You should see images generated from the prompt "Cyberpunk cityscape, small black crow, neon lights, dark alleys, skyscrapers, futuristic, vibrant colors, high contrast, highly detailed" saved as Sigma_0.png and Sigma_1.png. Generation takes about 6 seconds per image on an RTX 3080 GPU.

(Image: generated locally by PixArt-Σ)

You can achieve similar results with Flux.1 Schnell, despite its additional components, but it requires more aggressive quantization, which noticeably lowers quality (unless you have access to more VRAM, say 16 or 25 GB):

```python
import torch
from optimum.quanto import qint2, qint4, quantize, freeze
from diffusers.pipelines.flux.pipeline_flux import FluxPipeline

# Load Flux.1 Schnell in bfloat16
pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16)

# Quantize the CLIP text encoder to 4-bit and the T5 text encoder to 2-bit
quantize(pipe.text_encoder, weights=qint4, exclude="proj_out")
freeze(pipe.text_encoder)
quantize(pipe.text_encoder_2, weights=qint2, exclude="proj_out")
freeze(pipe.text_encoder_2)

# Quantize the vision transformer to 4-bit
quantize(pipe.transformer, weights=qint4, exclude="proj_out")
freeze(pipe.transformer)

pipe = pipe.to("cuda")

for i in range(10):
    generator = torch.Generator(device="cpu").manual_seed(i)
    prompt = "Cyberpunk cityscape, small black crow, neon lights, dark alleys, skyscrapers, futuristic, vibrant colors, high contrast, highly detailed"
    image = pipe(prompt, height=512, width=768, guidance_scale=3.5, generator=generator, num_inference_steps=4).images[0]
    image.save(f"Schnell_{i}.png")
```

(Image: generated locally by Flux.1 Schnell; lower quality and poor prompt adherence due to excessive quantization)

We can see that quantizing the T5 text encoder to qint2 and the vision transformer to qint4 is too aggressive here and has a significant impact on quality for Flux.1 Schnell.

Here are some alternatives for running Flux.1 Schnell. If PixArt-Sigma is not sufficient for your needs and you don't have enough VRAM to run Flux.1 at sufficient quality, you have two main options:

- ComfyUI or Forge: GUI tools popular with enthusiasts; they mostly sacrifice speed for quality.
- Replicate API: costs $0.003 per image generation for Schnell.

Deployment

I had a little fun deploying PixArt Sigma on an older machine I have. Here is a brief summary of how I went about it.

First, the list of components:

- HTMX and Tailwind: these are the face of the project. HTMX makes the website interactive without a lot of extra code, and Tailwind gives it a nice look.
- FastAPI: it takes requests from the website and decides what to do with them.
- Celery worker: think of this as the hard worker. It takes the orders from FastAPI and actually creates the images.
- Redis cache/pub-sub: this is the communication center. It helps the different parts of the project talk to each other and remember important state.
- GCS (Google Cloud Storage): this is where we keep the finished images.

Now, how do they all work together? Here's a simple rundown (a minimal code sketch of the FastAPI-to-Celery hand-off follows below):

1. When you visit the website and make a request, HTMX and Tailwind make sure it looks good.
2. FastAPI gets the request and, through Redis, tells the Celery worker what kind of image to make.
3. The Celery worker goes to work, creating the image.
4. Once the image is ready, it gets stored in GCS, so it's easy to access.

Service URL: https://image-generation-app-340387183829.europe-west1.run.app

(Demo of the app)
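
To make the hand-off concrete, here is a minimal sketch of how FastAPI can queue work for a Celery worker over Redis. The module layout, queue URLs, and the generate_image task name are hypothetical placeholders, not the actual code behind the service above:

```python
# A sketch only: module layout, queue URLs, and the task name are hypothetical.

# tasks.py (Celery worker side)
from celery import Celery

celery_app = Celery(
    "image_tasks",
    broker="redis://localhost:6379/0",   # Redis carries the task queue
    backend="redis://localhost:6379/1",  # and stores task results
)

@celery_app.task
def generate_image(prompt: str) -> str:
    # In the real worker this would run the quantized PixArt pipeline
    # and upload the finished image to GCS, returning its URL.
    ...

# api.py (FastAPI side)
from fastapi import FastAPI
# from tasks import generate_image  # import the task defined above

app = FastAPI()

@app.post("/generate")
def submit(prompt: str):
    # Hand the job to the Celery worker through Redis and return immediately
    task = generate_image.delay(prompt)
    return {"task_id": task.id}

@app.get("/result/{task_id}")
def get_result(task_id: str):
    # The frontend (HTMX) polls this endpoint until the image URL is ready
    res = generate_image.AsyncResult(task_id)
    return {"status": res.status, "url": res.result if res.ready() else None}
```

The nice property of this split is that the web process never touches the GPU: the worker pulls jobs off the Redis queue at its own pace, so heavy generation stays decoupled from web traffic.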
Conclusion

By quantizing the model components, we can significantly reduce VRAM usage while maintaining good image quality and improving generation speed. This method is particularly effective for models like PixArt Sigma. For Flux.1, while the results might be mixed, the principles of quantization remain applicable.

References:

https://huggingface.co/blog/quanto-diffusers
https://lightning.ai/lightning-ai/studios/deploy-an-image-generation-api-with-flux