SaladCloud Blog

INSIDE SALAD

Cost-effective Stable Diffusion fine tuning on Salad

Shawn Rushefsky

Stable Diffusion XL (SDXL) fine tuning as a service

I recently wrote a blog post about fine tuning Stable Diffusion XL (SDXL) on interruptible GPUs at low cost, starring my dog Timber. The strong results and exceptional cost performance got me wondering: what would it take to turn that into a fully managed Stable Diffusion training platform?

I built one on SaladCloud to find out. The result? Costs as low as $0.00016179 per training step while successfully completing 1000/1000 training jobs.

The many faces of Timber, my Siberian Husky, from a Stable Diffusion XL fine tuning run

Challenges in developing a Stable Diffusion API

Developing and deploying what is essentially a Stable Diffusion XL training API on a distributed cloud like Salad involves a number of challenges. Salad is a distributed cloud with more than a million individual PCs around the world connected to its network, and the GPUs on Salad are latent (otherwise idle) Nvidia RTX/GTX series cards. Our goal is for the service to be resilient at any scale, which means accounting for the following:

  1. Training nodes can be interrupted at any time with no warning.
  2. We expect far more submitted training jobs than we have the capacity to run simultaneously.
  3. Training jobs are long, lasting anywhere from minutes to hours.
  4. Training nodes are too expensive to leave running while idle, so we want to avoid that whenever we can.
  5. Training nodes can have a very long cold start time.
  6. It is possible, though not common, for a training node to be “bad” for a variety of reasons.
  7. We want to handle downloading training data and uploading model checkpoints as securely as possible.
  8. We want observability into training jobs as they run.
  9. We want the platform as a whole to be easy enough for one person to operate. In this case, that one person is me.

Architecture

To handle node interruptions and concurrent training, we built a simple orchestration API, with training compute handled by GPU worker nodes. Additionally, we set up a simple autoscaler using a scheduled Cloudflare Worker. Except for the pool of training nodes, the entire platform uses Cloudflare serverless services.

Cloudflare Dreambooth training architecture

Relying heavily on serverless technologies for the platform layer greatly reduces operational labor, makes the platform nearly free at rest, and lets it scale comfortably to significantly more load. Given sufficient continuous load, serverless applications do tend to be more expensive than the alternatives, so feel free to swap out components as desired. This design doesn’t rely on any provider-specific features, so any SQL database and any key-value store would work just as well.

API components

GPU worker node components

  • 🤗 Accelerate to run the training script
  • Dreambooth LoRA SDXL training script from the 🤗 diffusers repository
  • Stable Diffusion XL as the base image generation model
  • SDXL VAE FP16 Fix as a replacement VAE to enable decoding in fp16
  • 3 Threads:
    • One thread to run the training script in a subprocess
    • One thread to periodically heartbeat the API, letting it know the job is still running
    • One thread to monitor the filesystem for checkpoint creation, and upload those checkpoints
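Here is a minimal sketch of how those three threads could be approximated in shell, with background jobs standing in for threads. It is not the worker we ran: `API_URL`, `API_KEY`, `JOB_ID`, `MACHINE_ID` and the `/job/<id>/heartbeat` and `/job/<id>/checkpoint` endpoints are hypothetical. It also assumes that the newest uploaded checkpoint (if any) has already been downloaded into the output directory, so that the diffusers script's `--resume_from_checkpoint latest` option can pick it up after an interruption. Uploading the final LoRA weights and reporting failures are not shown.

```shell
#!/usr/bin/env bash
# Minimal sketch of the three concurrent activities on a worker node.
# API_URL, API_KEY, JOB_ID, MACHINE_ID and both endpoint paths are hypothetical.

# Print checkpoint directories in $1 that are not yet listed in the log file $2.
pending_checkpoints() {
  local output_dir="$1" uploaded_log="$2" dir
  touch "$uploaded_log"
  for dir in "$output_dir"/checkpoint-*; do
    [ -d "$dir" ] || continue
    grep -qxF "$dir" "$uploaded_log" || echo "$dir"
  done
}

# Stream one checkpoint directory to the API as a gzipped tarball.
upload_checkpoint() {
  local ckpt="$1"
  tar -czf - -C "$(dirname "$ckpt")" "$(basename "$ckpt")" |
    curl --fail --silent --show-error -X POST \
      "$API_URL/job/$JOB_ID/checkpoint?machine_id=$MACHINE_ID" \
      --header "x-api-key: $API_KEY" \
      --header 'Content-Type: application/gzip' \
      --data-binary @-
}

# Heartbeat thread: tell the API every 30s that the job is still running.
heartbeat_loop() {
  while true; do
    curl --fail --silent --show-error -X POST \
      "$API_URL/job/$JOB_ID/heartbeat?machine_id=$MACHINE_ID" \
      --header "x-api-key: $API_KEY" > /dev/null || true
    sleep 30
  done
}

# Upload thread: upload each new checkpoint once and record it in the log.
# Caveat: this does not check that a checkpoint has finished being written.
upload_loop() {
  local output_dir="$1" uploaded_log="$2"
  while true; do
    pending_checkpoints "$output_dir" "$uploaded_log" | while read -r ckpt; do
      upload_checkpoint "$ckpt" && echo "$ckpt" >> "$uploaded_log"
    done
    sleep 10
  done
}

# Training thread: run the diffusers script, then stop the two helpers.
# Extra arguments (instance data, prompt, steps, ...) are passed through.
run_job() {
  local output_dir="$1"
  shift
  mkdir -p "$output_dir"
  heartbeat_loop &
  local heartbeat_pid=$!
  upload_loop "$output_dir" "$output_dir/.uploaded" &
  local upload_pid=$!
  local status=0
  accelerate launch train_dreambooth_lora_sdxl.py \
    --pretrained_model_name_or_path stabilityai/stable-diffusion-xl-base-1.0 \
    --pretrained_vae_model_name_or_path madebyollin/sdxl-vae-fp16-fix \
    --mixed_precision fp16 \
    --resume_from_checkpoint latest \
    --output_dir "$output_dir" \
    "$@" || status=$?
  kill "$heartbeat_pid" "$upload_pid" 2>/dev/null
  return "$status"
}
```

For example, `run_job /tmp/output --instance_data_dir /tmp/timber --instance_prompt "a photo of timber dog" --max_train_steps 1400 --checkpointing_steps 100 --train_text_encoder` would start a run with some of the benchmark parameters used later in this post; both paths are placeholders. Recording uploads in a log file keeps each checkpoint from being sent twice while the watcher polls.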

Distributing work

To get work, worker nodes make a GET request to the API, including their machine id as a query parameter. The API prioritizes handing out jobs that are in the running state but have stalled, as measured by a heartbeat timeout. It will also never hand a job out to a node where that job has previously failed.
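That selection rule can be sketched as a pure function over the job list. This is a minimal sketch that assumes a hypothetical flat-file layout instead of the database the API uses: each line of the jobs file is `<job_id> <state> <last_heartbeat_epoch>`, listed in submission order, and each line of the bans file is `<job_id> <machine_id>`. The timeout is a parameter because the post does not state the value used.

```shell
#!/usr/bin/env bash
# pick_job MACHINE_ID NOW TIMEOUT JOBS_FILE BANS_FILE
# Prints the job this machine should take next, or nothing if there is none.
# Hypothetical file formats:
#   JOBS_FILE: <job_id> <state> <last_heartbeat_epoch>  (submission order)
#   BANS_FILE: <job_id> <machine_id>                    (one line per failure)
pick_job() {
  awk -v machine="$1" -v now="$2" -v timeout="$3" '
    # First file: remember jobs that already failed on this machine.
    FILENAME == ARGV[1] { if ($2 == machine) banned[$1] = 1; next }
    # Never hand a job back to a machine where it previously failed.
    ($1 in banned) { next }
    # A running job whose heartbeat has timed out is stalled: prefer it.
    $2 == "running" && now - $3 > timeout && stalled == "" { stalled = $1 }
    # Otherwise fall back to the oldest pending job.
    $2 == "pending" && pending == "" { pending = $1 }
    END {
      if (stalled != "") print stalled
      else if (pending != "") print pending
    }
  ' "$5" "$4"
}
```

Preferring stalled jobs means an interrupted run is resumed from its latest checkpoint before a fresh job starts, so work already in progress tends to finish first.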

Marking a job failed

There are many reasons a job can fail, and some of them are specific to a particular node, such as outdated drivers or insufficient available VRAM. If the worker process is still alive when this occurs, it makes a POST request to a fail endpoint to report that the job has failed.

When this happens, the API does a few things:

  1. The machine is banned from processing this job again.
  2. If the job has received too many failure events, the job will be marked as failed. Otherwise, it stays running so that it will be handed out again to a different node.
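Both steps can be sketched against a hypothetical bans file with one `<job_id> <machine_id>` line per failure event. The failure limit is a parameter here because the post does not say which limit was used, and the function name is illustrative.

```shell
#!/usr/bin/env bash
# record_failure JOB_ID MACHINE_ID MAX_FAILURES BANS_FILE
# Records a failure event and prints the job's new state: running or failed.
record_failure() {
  local job_id="$1" machine_id="$2" max_failures="$3" bans_file="$4"
  local failures
  # Step 1: ban this machine from this job.
  echo "$job_id $machine_id" >> "$bans_file"
  # Step 2: count this job's failure events so far.
  failures=$(awk -v job="$job_id" '$1 == job { n++ } END { print n + 0 }' "$bans_file")
  if [ "$failures" -ge "$max_failures" ]; then
    echo failed
  else
    # Still running, so it will be handed out again to a different node.
    echo running
  fi
}
```

Because the ban and the failure count live in the same records, a job only becomes permanently failed after it has failed on several distinct machines, which separates bad jobs from bad nodes.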

Handling bad nodes

If a particular node has failed too many jobs, we want to reallocate it. Our first implementation did not account for this, and a single bad node caused 85% of the benchmark jobs to be marked failed by pulling and failing one job after another. We now run a scheduled Cloudflare Worker every 5 minutes that reallocates any node with more than the allowed number of failures.
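The scheduled check could derive its candidates from a hypothetical bans file with one `<job_id> <machine_id>` line per failure. This sketch only prints the machine ids that exceed the limit; mapping a machine id to a SaladCloud container instance and calling the reallocation API are left out, and the limit is an assumed parameter.

```shell
#!/usr/bin/env bash
# machines_to_reallocate MAX_FAILURES BANS_FILE
# Prints machine ids with more than MAX_FAILURES failures, one per line.
machines_to_reallocate() {
  awk -v max="$1" '
    { failures[$2]++ }
    END { for (m in failures) if (failures[m] > max) print m }
  ' "$2" | sort
}
```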

Autoscaling the worker pool

Our scheduled Cloudflare Worker also handles scaling the worker cluster. It tries to keep the number of replicas equal to the number of running and pending jobs, within configurable limits.

Dreambooth autoscaling
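The core of that scaling rule is a clamp. In this sketch the last two arguments stand in for the configurable limits (the benchmark below used 0 and 40); how the job counts are queried and how the result is applied to the SaladCloud container group are not shown.

```shell
#!/usr/bin/env bash
# desired_replicas RUNNING PENDING MIN_REPLICAS MAX_REPLICAS
# One replica per running or pending job, clamped to the configured limits.
desired_replicas() {
  local target=$(( $1 + $2 ))
  if [ "$target" -lt "$3" ]; then
    target=$3
  fi
  if [ "$target" -gt "$4" ]; then
    target=$4
  fi
  echo "$target"
}
```

For example, `desired_replicas 12 988 0 40` prints `40`, and `desired_replicas 0 0 0 40` prints `0`, so an empty queue scales the pool down to zero and no idle nodes are billed.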

Observing a training run

The training script we used from diffusers has a built-in integration with Weights and Biases, a popular ML/AI training dashboard platform. It lets you qualitatively observe the training progress, tracks your training arguments, monitors system usage, and more.

Observing a stable diffusion XL training run

Deployment on Salad

Deploying on Salad is simple. The worker pattern means we don’t need to enable inbound networking or configure any probes. The only environment configuration needed is a URL for the orchestration API, a key for the orchestration API, and an API key for Weights and Biases (optional).

Dreambooth lora training – container setup in Salad
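A container entrypoint could fail fast when that configuration is missing. In this sketch `API_URL` and `API_KEY` are assumed variable names for the orchestration API settings; `WANDB_API_KEY` is the variable the Weights and Biases client reads, and `--report_to wandb` is the training script option that enables its integration.

```shell
#!/usr/bin/env bash
# check_worker_env: returns non-zero if required settings are missing, and
# prints any extra training flags to use, one per line.
check_worker_env() {
  local missing=0
  if [ -z "${API_URL:-}" ]; then
    echo "API_URL is not set" >&2
    missing=1
  fi
  if [ -z "${API_KEY:-}" ]; then
    echo "API_KEY is not set" >&2
    missing=1
  fi
  # Weights and Biases is optional: only report to it when a key is present.
  if [ -n "${WANDB_API_KEY:-}" ]; then
    printf '%s\n' --report_to wandb
  fi
  return "$missing"
}
```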

Seeding the benchmark

To get a baseline idea of performance, we ran 1000 identical training jobs of 1400 steps each, with text encoder training. We skipped reporting samples to Weights and Biases for this benchmark. We let the autoscaler run between 0 and 40 nodes, each with 2 vCPUs, 16GB RAM, and an RTX 4090 GPU.

#!/usr/bin/env bash
# Submit num_jobs identical training jobs to the orchestration API.
set -euo pipefail

if [[ -z "${API_KEY:-}" ]]; then
  echo "API_KEY is not set" >&2
  exit 1
fi

num_jobs=1000

for i in $(seq 1 "$num_jobs"); do
  # --fail makes curl exit non-zero on an HTTP error, which stops the loop under set -e.
  curl --fail -X POST \
    'https://sdxl.dreambooth.run/job' \
    --header "x-api-key: $API_KEY" \
    --header 'Content-Type: application/json' \
    --data-raw '{
    "instance_prompt": "a photo of timber dog",
    "instance_data_prefix": "timber/",
    "max_train_steps": 1400,
    "learning_rate": 0.0000001,
    "checkpointing_steps": 100,
    "train_text_encoder": true,
    "text_encoder_lr": 0.000005
    }'
done

Visualizing a training run

Here’s an example training job that was interrupted twice, and was able to resume and complete training on a different node each time. The smaller marks are heartbeat events emitted by the worker every 30s, color coded by machine id. This run sat in the queue for 5.4 hours before a worker picked it up, and then ran for 54 minutes of billable time, calculated as the number of heartbeats * 30s. Although the run was interrupted multiple times, these interruptions cost less than 4 minutes of clock time, and the run still finished in the median amount of time.

The Pricing Calculator lists a cost of $0.324/hour, so 54 minutes of billable time comes to a total cost of $0.2916 to train the model and the text encoder for 1400 steps, or $0.000208 per step. The amount of time taken, and therefore the cost, varies greatly with the parameters you use for training: training the text encoder slows down training, using prior preservation also slows it down, and more steps take longer.
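The cost arithmetic above can be reproduced from the heartbeat count. This sketch assumes one heartbeat per 30 seconds of billable time, as described, so 54 minutes corresponds to 108 heartbeats; the function name is illustrative.

```shell
#!/usr/bin/env bash
# job_cost HEARTBEATS HOURLY_RATE STEPS
# Prints total cost and cost per training step, billing 30s per heartbeat.
job_cost() {
  awk -v beats="$1" -v rate="$2" -v steps="$3" 'BEGIN {
    hours = beats * 30 / 3600
    cost = hours * rate
    printf "%.4f %.6f\n", cost, cost / steps
  }'
}
```

Running `job_cost 108 0.324 1400` prints `0.2916 0.000208`: 0.9 billable hours at $0.324/hour, divided over 1400 steps.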

Results from the Stable Diffusion XL fine tuning

  • All 1000 jobs completed successfully!
  • Average Time Taken Per Training Job: 62.01 minutes (2.66s per training step)
  • Minimum Time Taken: 41.95 minutes (1.80s per training step)
  • Average Cost Per Training Job: $0.3348 ($0.00023918 per training step)
  • Minimum Cost: $0.2265 ($0.00016179 per training step)
  • 90.1% of jobs completed on one node with zero interruptions.
  • 9% of jobs were interrupted once, requiring 2 nodes to complete training.
  • The remaining 0.9% of jobs required 3 or more nodes to complete training.

Tips and Observations

  • Train the text encoder. It makes a huge difference.
  • More steps with a lower learning rate typically yielded better results.
  • Including the type of animal (e.g. “dog”) in the instance prompt and in validation prompts yielded much better results than just using the name alone. E.g. an instance prompt like “a photo of timber dog” performed much better than just “timber”.
  • As with image generation, there is a considerable amount of skill involved in getting consistently good results from training. You’ll need to experiment, trying different prompts and parameters until you are getting results you like.
  • We checkpointed every 100 steps for this, and that seems to have worked well, with interruptions costing relatively little time.
  • I found the Cloudflare developer platform very easy to work with, and very reasonably priced. Their Wrangler CLI lets you do all kinds of things, and the local development experience is very good, letting you use all of the Cloudflare services locally without interfering with your production resources.
  • There was some variance in how long jobs took to complete, with a minimum time of 41.95 minutes, a maximum time of 737.92 minutes, and a median time of 53.82 minutes. You can see in the charts below that the vast majority of jobs finish within 1 hour, with the 737.92 minute run representing a substantial outlier. You can also see that the mean time is pulled up from the median by the long tail of slower runs.
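Given a file with one job duration in minutes per line (for example, derived from the heartbeat records), these summary statistics can be computed with a short pipeline. The input file is hypothetical; the raw per-job durations are not included in this post.

```shell
#!/usr/bin/env bash
# summarize_minutes FILE: prints min, median, mean and max of one number per line.
summarize_minutes() {
  sort -n "$1" | awk '
    { v[NR] = $1; sum += $1 }
    END {
      if (NR == 0) exit 1
      if (NR % 2) median = v[(NR + 1) / 2]
      else median = (v[NR / 2] + v[NR / 2 + 1]) / 2
      printf "min=%.2f median=%.2f mean=%.2f max=%.2f\n", v[1], median, sum / NR, v[NR]
    }'
}
```

Comparing the mean (62.01 minutes) with the median (53.82 minutes) is a quick way to see how much the slow tail pulls the average up.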

Future Improvements

  • One very time-consuming part of the training process, as implemented in the training script we pulled from diffusers, is that to generate periodic validation images, it must reload all the model components for inference, generate the images, and then reload all the model components again for training. These interruptions can double the time needed to train a LoRA, or worse, depending on how often you want to generate validation images. Performing validation inference asynchronously would significantly improve training throughput. This would be implemented by bypassing the original script’s validation logic and instead queueing validation inference jobs as checkpoints are created and uploaded. A separate pool of workers would then pull these inference jobs and upload the results to Weights and Biases, or whatever tracking platform you decide to use. In this way, the training process would not need to be interrupted for validation.
  • While we did manage to automatically purge nodes that proved incapable of completing the training task, we did not filter out unreasonably slow nodes. For instance, while the mean training time for this task was 62 minutes, the worst run was nearly 12x slower, at 738 minutes. This case was rare enough that removing such outliers from the results makes very little difference to the mean cost of the entire batch (see charts above). Still, armed with this baseline performance information, we could in the future monitor for machines that are working too slowly (e.g. 2 standard deviations slower than the mean) and add them to the banned worker records.
  • If there is only 1 worker node and 1 available job, the worker can be banned from that job without accumulating enough failures to be automatically reallocated. The final job then does not complete until the node is reallocated, either manually or naturally. The autoscaling code could account for this. Our test required 1 manual reallocation due to this scenario.
  • This implementation used a pre-written script, unmodified, to actually perform the training. It’s very likely there are details which could be made more efficient by more deeply integrating the training logic into the worker code. For instance, we don’t really need to save checkpoints to the file system at all, but could instead directly upload them as they’re created.
  • The solution as implemented assumes a single user, and notably lacks any sort of permission management. To build this into a commercial service, you’d likely need to integrate ownership concepts and permission management.
  • Since training jobs may complete many hours after they are submitted, a notification system that alerts users when their model is ready to download would be useful. Currently, a user needs to poll the API periodically to check on the status.
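The slow-node check proposed in the list above could start from per-machine throughput. In this sketch each line of a hypothetical input file is `<machine_id> <seconds_per_step>`, and a machine is flagged when it is more than K sample standard deviations slower than the mean; K=2 matches the example threshold. In a small sample a single extreme outlier inflates the standard deviation, so a baseline computed from a large run such as this benchmark would be more robust.

```shell
#!/usr/bin/env bash
# slow_machines K FILE
# FILE lines: <machine_id> <seconds_per_step>. Prints machines slower than
# mean + K * (sample standard deviation), in file order.
slow_machines() {
  awk -v k="$1" '
    { id[NR] = $1; s[NR] = $2; sum += $2; sumsq += $2 * $2 }
    END {
      if (NR < 2) exit
      mean = sum / NR
      var = (sumsq - NR * mean * mean) / (NR - 1)
      sd = var > 0 ? sqrt(var) : 0
      for (i = 1; i <= NR; i++) if (s[i] > mean + k * sd) print id[i]
    }' "$2"
}
```

Flagged machine ids could then be added to the banned worker records in the same way as machines that fail too many jobs.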

Conclusions

Our exploration into fine-tuning Stable Diffusion XL on interruptible GPUs has demonstrated the feasibility and efficiency of our approach, despite the significant challenges posed by training interruptions, capacity limitations, and cost management. Leveraging Cloudflare’s serverless technologies alongside our custom orchestration and autoscaling solutions, we’ve created a resilient and manageable system capable of handling large-scale operations with notable cost efficiency and operational simplicity.

The successes of our deployment, underscored by the seamless completion of 1000/1000 benchmark jobs, highlight the system’s robustness and the potential for further improvements. Future enhancements, such as asynchronous validation and refined node performance assessments, promise to elevate the performance and cost-effectiveness of our service.

Given the extensive amount of experimentation required to get good results, a platform like this can be useful for individuals as well as those seeking to build commercial offerings. Once deployed, a person could submit many different combinations of parameters, prompts, and training data, and run many experiments in parallel.
