OLA-VLM / ola_vlm/train/probe_dsg_train_mem.py
from ola_vlm.train.probe_dsg_train import train
import torch.multiprocessing as mp

if __name__ == "__main__":
    # The flash-attention path is kept commented out; this entry point
    # launches training with the eager attention implementation.
    # try:
    #     train(attn_implementation="flash_attention_2")
    # except:
    train(attn_implementation="eager")
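# A minimal sketch of how the commented-out branch above could be re-enabled,
# assuming the intent is to prefer flash_attention_2 and fall back to eager
# only when it is unavailable. Catching specific exceptions rather than using
# a bare `except` avoids silently retrying after unrelated training failures:
#
#     if __name__ == "__main__":
#         try:
#             train(attn_implementation="flash_attention_2")
#         except (ImportError, RuntimeError):
#             train(attn_implementation="eager")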