{ "cells": [ { "cell_type": "markdown", "source": [ "Training an LLM Base Model on Custom Dataset" ], "metadata": { "id": "3ss0bmUVAvh2" } }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "S4ltaEGQkF_6" }, "outputs": [], "source": [ "# !pip install transformers\n", "# !pip install tokenizers\n", "# !pip install --upgrade accelerate" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "94gs7C92kXAT" }, "outputs": [], "source": [ "from transformers import TextDataset, DataCollatorForLanguageModeling\n", "from transformers import Trainer, TrainingArguments\n", "from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline\n", "import torch\n", "torch.random.manual_seed(0)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "7_0iV8aukXDL", "colab": { "base_uri": "https://localhost:8080/", "height": 912, "referenced_widgets": [ "d39964b1dd7d4a718f92f90ce94332dd", "b042b62836054249947010107c0f0c99", "26918b54b15f4448af2a35f8ae2b6bf4", "b6e7f7488ad1430ba0a7c9b90c07c311", "5116779d4df74402b282893d42ee41c4", "d7e182df921a4f389b6b47298942b95d", "d5349f7408c348cdb741d984c69abaa4", "6e55129b390e46bbb5638b55c10b47ca", "ba66f43f624a4fcfa32903d8e5f3a983", "38aeb0a88a1a444ba13fec2b58dd7b73", "055e02a289934bbeaffee2220350cd59", "4128291d55784d1f998ba3347117fb5d", "d423036f5b3047659c003968b104af4e", "af2fe09e5b88431e84540334a85c8fc2", "447ccf81d809484f873daad7c9389894", "14825412ed2b4f4d8b3e6e7f6eb55e0e", "0288733db9df417e8163a408521c9b57", "49faa28041cd441b95f83e1f4ed069db", "e57f48c32dc44ccab64c105885a51a81", "8614d68854774bd2ad50ec6be11ed2a0", "736c4b67778e4a27995f6e5bf1ec3526", "93559149b74b462190b2c5b97dff41cc", "50de155693054877bb0aee4f8cf7a608", "5f986db655e1429c981f21e54d530e4c", "11f74feb8827455fb1a2d0ea8b42f162", "6622f31cc14a4689a22ec3b2b329303e", "a44b0af62c55490cb6b56e8fa1802749", "073f8bc5620a4f0c835ab8522a7fc280", "56e5695326ce47c4beaac8eb4baeac5b", "c6c4f6691ca24f3aacfe232da2ebb309", "07c7c7cca5944caa97aaaa011890feb8", "f50624a3203e4f0b8ff521cbde71663b", "7cf53d8e281e40719866c7eeec2afd36", "50309f466cd04e54b61d54d0fee9d0cf", "04830a6dad02436f9477822ed7a4e23f", "000b0590e5a2451d8f34fa711d262d98", "4cb0d3a3b05d4354917658e273cc9d09", "e90c224cfa944717a12bcae72d0f4e7a", "b7fbecd6c5c04404965bf2db40239de1", "5d9f0aeefdbe4512a8200a02c0362948", "6fe72bd4f3854994817729bcbc493d2b", "fbeaf890e6094963a709ce69de46d242", "a6ba9bf71778454aa42c1795897dfb4e", "8290a8ee65fc4450bbe2c978ac6481bb", "53c13c1db22743b6abafed0ac0454d98", "f805e29e0b404084932bde75f77c03ac", "a1e4d1313a504872ae2c0b241be775a4", "7a447f6843a2415b9505713db83a6756", "915c48bf441e4d5da335c8a5d0502bb6", "2baa04fc61d44f4ea811f4a205c08e61", "58a84bcc664b40ea99fd3c9ec30fcaf1", "26f44639cb0f419582a1505d814a94a1", "3f7a3c54b01948aa9c632cdaa3aca6ea", "09d218dcb90c44edae1f7ddaa75c7e1a", "debfe9fe53584f82b6705c9b56b0c4e7", "5bfbd6fbb75d42fab78a7d12b50edb25", "7e340cf45651483f9c90b7367255d44d", "c01d6632c2a04af8bd4a55548f52ce7e", "ae71ac39dd734228a3746fbd1bf02c65", "dbc04a51b3fb47ada00e931d2d7e0f83", "985a5427be8347959ea01d553ccc7c17", "45bde5b2a4294367bb1b86ea13569cdb", "7ff5a7f74db2405eb3a4a593349a365e", "bd7dcbb232b049caaf8474c5f8d9490c", "4dca7e0901684c1b8d657c704707ce1b", "d0b2639ec57f405da4b31aa00add6eea", "cc69db22289e4ce59337c7c641fd16a1", "313e45ed0d6d4cf7976a833a2448a236", "ba4a19d1417d4ac2a0a3b31bf2fac626", "55dce05daea14b5db581d0f7fdd41ea9", "092635cbd10d45e4a39edfc237d937e4", "79bcebffef0c435b831eb4f92d4f67fc", "0e987db5f542449cb945b87a64505d41", "47f1fdb58d1b4fd9b1b584bc5e7ea433", "6f938a36f9ca4308b4693c10517063e0", "70b48db507754bb9b78615e1b9c3eff3", "e41ff7480563485e912bd3be50faa0e7", "23508f96a9504f079537fd2edbc6bf57", "0f04d52523f84ec1b90ff01f941433f5", "b6dc1b57009b47c4b286e382e5899886", "e182994b9db84cf68df3a0492261cddf", "3a78bbedbea0465c93bd6111328dd6f3", "647ed9ad3ceb4a74b417b4e03829c267", "6c7a3fc9e7514e43acbaaf991a4cf2d4", "709a4747833a466fbd517d32daa102b1", "c887b342b39f494aa8e765a44c22ab38", "633baf65ad274b84ad24e582127c4f19", "a78ab6b418194d948a388d8206a875e5", "e07107c5b3e448d780f0cb4ee1261641", "02ad87e88eea4c798cf16f0b226b3023", "6d386ba8627d4598a6b6d03cea672092", "4219c8634e7c4463b60cbc55e7c1513b", "b5a7e7caab484c6aa407c03ccbbac140", "a61c19e6688646adb8c18b21bd78e022", "d670e4062ce84e46a72c5fecd44d45e8", "8da073ee7e664b79aa576675c17f6ccb", "96f6d7c5ca324fc4af397a516bdc7f44", "e2c6a718440f444b80a9d9822bb4553a", "4bcbe4c07803457db66b340c1ecd3be8", "d5a343c470b7494487be38b4cb203928", "65c89914fb81432c9d083f929a08cf0a", "1204223f7e6d4737bf0dd11fcd77d8b5", "c90b4364014a43ab823b5955300c8511", "b43728af1f9340dc8d3f5f8fe859723c", "7e7e4a74cf594737ac09926e3a218dda", "9cb08d30f0d74d2497486ebba9b2c5a7", "c8e826ff7a0a4792bd9e677e03049325", "e7f4cfafbbe842529866f42b515f42b5", "2139b13ad42c4095b6e039a142af5bdc", "643ea5486a6a4d9c802ddf08d696a9ba" ] }, "outputId": "01941e5e-cc94-48e8-91c5-abcbaa337c65" }, "outputs": [ { "output_type": "stream", "name": "stderr", "text": [ "/usr/local/lib/python3.10/dist-packages/huggingface_hub/utils/_token.py:89: UserWarning: \n", "The secret `HF_TOKEN` does not exist in your Colab secrets.\n", "To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.\n", "You will be able to reuse this secret in all of your notebooks.\n", "Please note that authentication is recommended but still optional to access public models or datasets.\n", " warnings.warn(\n" ] }, { "output_type": "display_data", "data": { "text/plain": [ "tokenizer_config.json: 0%| | 0.00/5.37k [00:00] 5.18M --.-KB/s in 0.06s \n", "\n", "2024-09-01 18:18:57 (82.5 MB/s) - ‘shakespeare.txt’ saved [5436475/5436475]\n", "\n" ] } ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "PE_ak0pNkXF0", "outputId": "9fcde11e-44d1-4918-e50b-eceafa70e29a" }, "outputs": [ { "output_type": "stream", "name": "stderr", "text": [ "/usr/local/lib/python3.10/dist-packages/transformers/data/datasets/language_modeling.py:53: FutureWarning: This dataset will be removed from the library soon, preprocessing should be handled with the 🤗 Datasets library. You can have a look at this example script for pointers: https://github.com/huggingface/transformers/blob/main/examples/pytorch/language-modeling/run_mlm.py\n", " warnings.warn(\n" ] } ], "source": [ "# prepare the dataset\n", "\n", "train_path = 'shakespeare.txt'\n", "train_dataset = TextDataset(\n", " tokenizer=tokenizer,\n", " file_path=train_path,\n", " block_size=128\n", ")\n", "\n", "data_collator = DataCollatorForLanguageModeling(\n", " tokenizer=tokenizer,\n", " mlm=False\n", ")" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "AErVFzrakXHK" }, "outputs": [], "source": [ "# set up training args\n", "training_args = TrainingArguments(\n", " output_dir='./results',\n", " overwrite_output_dir=True,\n", " num_train_epochs=3,\n", " per_device_eval_batch_size=64,\n", " eval_steps=100,\n", " #save_steps=1000, # saves checkpoints, takes up disk\n", " warmup_steps=50,\n", " # logging\n", " logging_dir='./logs',\n", " logging_strategy=\"steps\",\n", " logging_steps=25,\n", " logging_first_step=True,\n", " report_to=[\"tensorboard\"],\n", "\n", ")\n", "\n", "trainer = Trainer(\n", " model=model,\n", " args=training_args,\n", " data_collator=data_collator,\n", " train_dataset=train_dataset,\n", ")" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 1000 }, "id": "um5prRF4kXJp", "outputId": "36ecf09f-be26-4a6d-8a2f-6166b0b4284d" }, "outputs": [ { "output_type": "stream", "name": "stderr", "text": [ "We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)\n" ] }, { "output_type": "display_data", "data": { "text/plain": [ "" ], "text/html": [ "\n", "
\n", " \n", " \n", " [4656/4656 45:09, Epoch 3/3]\n", "
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
StepTraining Loss
13.289200
253.195100
503.382400
753.459100
1003.567700
1253.568800
1503.557200
1753.581600
2003.425700
2253.491600
2503.420500
2753.451900
3003.556100
3253.466900
3503.411900
3753.390200
4003.424500
4253.387800
4503.360000
4753.460000
5003.429500
5253.458000
5503.334800
5753.415500
6003.310700
6253.371000
6503.319700
6753.395400
7003.344300
7253.235000
7503.310300
7753.324300
8003.296200
8253.280800
8503.279700
8753.234800
9003.230800
9253.295100
9503.159400
9753.240300
10003.187900
10253.231200
10503.197900
10753.112400
11003.190200
11253.183500
11503.240200
11753.131100
12003.168000
12253.202500
12503.171100
12753.043600
13003.066800
13253.129000
13503.123000
13753.169500
14003.148600
14253.119000
14503.039100
14753.088700
15003.143300
15253.094000
15503.089400
15752.503100
16002.387900
16252.394800
16502.363800
16752.353600
17002.404800
17252.370900
17502.460000
17752.388400
18002.373300
18252.413900
18502.307400
18752.338300
19002.283300
19252.395700
19502.405500
19752.336500
20002.331700
20252.373600
20502.350900
20752.449000
21002.363100
21252.436800
21502.377200
21752.346000
22002.386100
22252.341500
22502.355800
22752.381200
23002.357700
23252.324900
23502.366300
23752.309800
24002.333500
24252.353800
24502.379700
24752.337700
25002.311100
25252.355500
25502.287700
25752.296300
26002.367100
26252.320200
26502.324200
26752.322800
27002.333600
27252.257800
27502.274100
27752.292500
28002.308200
28252.209500
28502.316600
28752.307500
29002.288900
29252.242300
29502.230800
29752.346700
30002.311700
30252.276300
30502.272500
30752.187100
31002.289200
31251.300800
31501.073100
31751.006100
32001.013200
32250.994400
32500.975700
32750.978000
33000.975800
33250.904700
33500.946200
33750.956400
34000.947600
34250.968400
34500.984700
34750.914800
35000.929300
35250.946900
35500.925000
35750.959300
36000.958300
36250.929800
36500.937600
36750.942800
37000.925400
37250.926600
37500.933500
37750.872000
38000.892800
38250.899300
38500.894300
38750.924800
39000.907700
39250.903700
39500.876600
39750.895000
40000.922400
40250.877600
40500.872300
40750.885800
41000.863900
41250.879400
41500.866700
41750.862600
42000.910600
42250.871400
42500.862100
42750.864400
43000.893700
43250.877400
43500.858900
43750.839400
44000.852300
44250.844400
44500.851700
44750.852300
45000.845800
45250.854800
45500.856700
45750.854700
46000.846400
46250.836100
46500.848600

" ] }, "metadata": {} }, { "output_type": "execute_result", "data": { "text/plain": [ "TrainOutput(global_step=4656, training_loss=2.180167237763962, metrics={'train_runtime': 2710.0517, 'train_samples_per_second': 13.742, 'train_steps_per_second': 1.718, 'total_flos': 3.3657646372356096e+16, 'train_loss': 2.180167237763962, 'epoch': 3.0})" ] }, "metadata": {}, "execution_count": 7 } ], "source": [ "trainer.train()" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "-Jwzlm_GkXND", "outputId": "3704579f-e973-4453-e8f5-eb92e53f9a6e" }, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "('./results/tokenizer_config.json',\n", " './results/special_tokens_map.json',\n", " './results/tokenizer.json')" ] }, "metadata": {}, "execution_count": 8 } ], "source": [ "output_dir = './results'\n", "model.save_pretrained(output_dir)\n", "tokenizer.save_pretrained(output_dir)" ] }, { "cell_type": "code", "source": [ "del model\n", "del tokenizer" ], "metadata": { "id": "o4wHKGZV-kT7" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "xbYUpT1-mBdi", "colab": { "base_uri": "https://localhost:8080/", "height": 500, "referenced_widgets": [ "5bb18661e23f436d8c9a1c537acbfbc5", "dcf2881689b3410dbdd7018c298540ea", "1a17c356e18c430aa10748b59a06bedc", "d69a0d1051cf4bc89e82846c4ad5c78f", "820b2da1a46e477897940c8c27ad7037", "d0d0f2aab40c4ff79af8832cbde8a377", "a6e968af650245de9b91590ed5efc961", "791ee28c827d4ec88c178b08c62fc7f7", "0d2bc07297bc47d9ac48fb3db3d78107", "2610474d5cf34645bd5445d3d219d17c", "b76c108adbb8461abe5bb984b5f85186" ] }, "outputId": "dc49302b-cc28-48f5-abb2-55d6f67c3560" }, "outputs": [ { "output_type": "display_data", "data": { "text/plain": [ "Loading checkpoint shards: 0%| | 0/2 [00:00