Distributed Finetuning

#15
by pooyahug - opened

Hi,
Could you please explain how I can do distributed fine-tuning with two PCs, each with a 4090 GPU? I think data parallelism would be a good fit. If you can share the code, I would appreciate it.
Thanks.

Google org

Hi @pooyahug

JAX's support for Single-Program Multi-Data (SPMD) programming allows you to train models efficiently across multiple devices.

Resources to Get Started:

  • Fine-tuning PaliGemma with JAX and Flax:
    You can explore a tutorial tailored to fine-tuning the PaliGemma model using JAX and Flax here. This will guide you through the setup and fundamental concepts of fine-tuning using these tools.
  • Introduction to Parallel Programming in JAX:
    To leverage data parallelism, refer to the Introduction to Parallel Programming Guide. This tutorial covers SPMD techniques in JAX for running computations in parallel across multiple GPUs; a minimal sketch of the pattern follows after this list.
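
To give a concrete feel for the SPMD data-parallel pattern those guides describe, here is a minimal sketch using `jax.pmap` on whatever local GPUs are visible. The toy linear model, learning rate, and batch shapes are placeholders for illustration, not PaliGemma-specific code.

```python
import functools

import jax
import jax.numpy as jnp

# Toy linear "model" and loss; placeholders standing in for a real network.
params = {"w": jnp.ones((16, 1)), "b": jnp.zeros((1,))}

def loss_fn(params, x, y):
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)

# One SPMD training step: every device runs the same program on its own shard
# of the batch, and gradients are averaged across devices with pmean.
@functools.partial(jax.pmap, axis_name="devices")
def train_step(params, x, y):
    grads = jax.grad(loss_fn)(params, x, y)
    grads = jax.lax.pmean(grads, axis_name="devices")
    return jax.tree_util.tree_map(lambda p, g: p - 0.01 * g, params, grads)

n_dev = jax.local_device_count()
# Replicate the parameters on every device and shard the batch along axis 0.
replicated = jax.device_put_replicated(params, jax.local_devices())
x = jnp.ones((n_dev, 8, 16))  # per-device batch of 8 examples
y = jnp.zeros((n_dev, 8, 1))
replicated = train_step(replicated, x, y)
```

The same pattern scales to however many local devices are present, since `pmap` simply maps the step over the leading device axis.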

Thanks, Selam.

I already saw this tutorial, but it was complicated to run it on two PCs, each with a 4090 GPU. I think this requires multi-node distributed training. I searched for how to do this with JAX and found that JAX doesn't support this kind of distributed training. I have a local network, and these two PCs are servers: I can define one of them as the master with a specific IP and the other as a worker. I ran the Hugging Face fine-tuning code with my custom dataset on one PC, but it is too slow. I would like to distribute the training across the two PCs to speed it up, and I need help changing my code so it is suitable for multi-node distributed training using the Hugging Face code.
If you want, I can share my Hugging Face fine-tuning code that runs on one PC.
Thanks.
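
For reference, the usual way to run an existing Hugging Face Trainer script on two machines is to launch the same script on both nodes with `torchrun`, pointing both at the master's IP; the Trainer then picks up the distributed environment and wraps the model in DistributedDataParallel itself. Below is a rough sketch, not a drop-in solution: the master address 192.168.1.10, port 29500, script name, model, and dataset are all placeholder assumptions to be replaced with your own.

```python
# finetune_sketch.py -- hypothetical file name; the body is a plain single-GPU
# Trainer script. The only multi-node part is HOW it is launched:
#
#   node 0 (master, e.g. 192.168.1.10):
#     torchrun --nnodes=2 --nproc_per_node=1 --node_rank=0 \
#              --master_addr=192.168.1.10 --master_port=29500 finetune_sketch.py
#   node 1 (worker):
#     torchrun --nnodes=2 --nproc_per_node=1 --node_rank=1 \
#              --master_addr=192.168.1.10 --master_port=29500 finetune_sketch.py
#
# The Trainer reads the torchrun environment variables and wraps the model in
# DistributedDataParallel, sharding each batch across the two nodes.

from datasets import load_dataset
from transformers import (AutoModelForSequenceClassification, AutoTokenizer,
                          Trainer, TrainingArguments)

# Tiny placeholder model and dataset so the sketch runs end to end; swap in
# the model and custom dataset from your existing single-PC script.
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
model = AutoModelForSequenceClassification.from_pretrained("distilbert-base-uncased")

dataset = load_dataset("imdb", split="train[:1%]")
dataset = dataset.map(
    lambda ex: tokenizer(ex["text"], truncation=True, padding="max_length", max_length=128),
    batched=True,
)

args = TrainingArguments(
    output_dir="out",
    per_device_train_batch_size=8,  # per GPU; with 2 nodes x 1 GPU the global batch is 16
    num_train_epochs=1,
    ddp_backend="nccl",             # NCCL over the local network between the two PCs
)
Trainer(model=model, args=args, train_dataset=dataset).train()
```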
