Distributed Finetuning
Hi,
Could you please explain how I can do distributed fine-tuning with two PCs, each with a 4090 GPU? I think data parallelism is a good fit. If you can share code, I would appreciate it.
Thanks.
Hi @pooyahug
JAX's support for Single-Program, Multiple-Data (SPMD) programming lets you train models efficiently across multiple devices.
Resources to Get Started:
- Fine-tuning PaliGemma with JAX and Flax:
  You can explore a tutorial tailored to fine-tuning the PaliGemma model using JAX and Flax here. It will guide you through the setup and the fundamental concepts of fine-tuning with these tools.
- Introduction to Parallel Programming in JAX:
  To leverage data parallelism, refer to the Introduction to Parallel Programming guide. This tutorial covers SPMD techniques in JAX for running computations in parallel across multiple GPUs.
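To make the data-parallel idea concrete, here is a minimal single-host sketch using `jax.pmap`. The toy linear model, loss, and learning rate are placeholders (not taken from the tutorials above); substitute your own model and training step.

```python
import functools
import jax
import jax.numpy as jnp

def loss_fn(params, x, y):
    # Toy linear model; replace with your own model/loss.
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)

@functools.partial(jax.pmap, axis_name="devices")
def train_step(params, x, y):
    grads = jax.grad(loss_fn)(params, x, y)
    # Average gradients across all local GPUs (the SPMD all-reduce step).
    grads = jax.lax.pmean(grads, axis_name="devices")
    return jax.tree_util.tree_map(lambda p, g: p - 1e-2 * g, params, grads)

# Replicate parameters onto every local device and shard the batch
# along the leading (device) axis.
n_dev = jax.local_device_count()
params = {"w": jnp.zeros((8, 1)), "b": jnp.zeros((1,))}
params = jax.device_put_replicated(params, jax.local_devices())
x = jnp.ones((n_dev, 32, 8))   # [devices, per-device batch, features]
y = jnp.ones((n_dev, 32, 1))
params = train_step(params, x, y)
```

Each local GPU receives one slice of the batch, computes gradients on its slice, and the `pmean` keeps the replicated parameters in sync.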
Thanks, Selam,
I already went through that tutorial, but it was difficult to adapt to two PCs, each with a 4090 GPU. I believe what I need is multi-node distributed training, and from what I found while searching, JAX does not seem to support this kind of setup. The two PCs are servers on a local network; I can designate one as the master with a specific IP and the other as a worker. I ran HuggingFace fine-tuning code with my custom dataset on one PC, but it is too slow, so I would like to distribute the training across the two PCs to speed it up. I need help changing my code so it is suitable for multi-node distributed training with the HuggingFace code.
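For reference, here is a rough sketch (not my actual script) of how I understand the underlying multi-node setup would look with plain PyTorch DDP, which I believe the HuggingFace Trainer uses under the hood. The rank, world size, IP, and port below are placeholders for my local network and would normally be provided by the launcher.

```python
import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

def setup_distributed():
    # Placeholder values for my local network (normally set by the launcher
    # before the script starts):
    #   master PC: RANK=0        second PC: RANK=1
    #   both PCs:  WORLD_SIZE=2, MASTER_ADDR=<master IP>, MASTER_PORT=29500
    dist.init_process_group(backend="nccl", init_method="env://")
    torch.cuda.set_device(int(os.environ.get("LOCAL_RANK", "0")))

def wrap_model(model: torch.nn.Module) -> DDP:
    # Each process keeps a full copy of the model; gradients are
    # all-reduced across the two PCs during backward().
    model = model.cuda()
    return DDP(model, device_ids=[torch.cuda.current_device()])
```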
If it helps, I can share my HuggingFace fine-tuning code that runs on one PC.
Thanks.