File size: 1,117 Bytes
0b7b08a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import datasets
import os
from tqdm import tqdm
import webdataset as wds
import json

# Directory handed to `datasets.load_from_disk` by the script below.
DATASET_ROOT = "/gpfs/u/home/LMCG/LMCGljnn/scratch-shared/the_pile/all/train"
# Directory that is created by the script and receives the numbered
# "%05d.tar" shard files.
OUT_DIR = "/gpfs/u/home/LMCG/LMCGljnn/scratch-shared/junyan/raw/the_pile"
# Passed as `maxcount` to `wds.ShardWriter`, i.e. the per-shard sample limit.
SAMPLE_PER_SHARD = 100000

def main():
    """Export the dataset stored at DATASET_ROOT as WebDataset tar shards.

    The dataset is read in batches of 4096 rows. Every row whose ``meta``
    dict has ``pile_set_name == "Github"`` is skipped; every other row is
    written to a shard under OUT_DIR as one sample consisting of:

    * ``__key__`` - the 1-based running count of written samples, as a string
    * ``txt``     - the row's ``text`` field, UTF-8 encoded
    * ``json``    - the row's ``meta`` dict, dumped as indented JSON, UTF-8

    A new shard is started every SAMPLE_PER_SHARD samples. At the end the
    number of written samples and the total number of rows are printed.
    """
    # exist_ok=True: a bare makedirs() raises FileExistsError when the
    # directory is already there, which made every re-run crash.
    # NOTE(review): a re-run will presumably rewrite shards starting from
    # 00000.tar again -- confirm that is acceptable for this output dir.
    os.makedirs(OUT_DIR, exist_ok=True)

    print("load dataset...")
    pile = datasets.load_from_disk(DATASET_ROOT)
    total_num = pile.num_rows
    print("total num:", total_num)

    shard_pattern = os.path.join(OUT_DIR, "%05d.tar")
    num = 0
    # Both the progress bar and the shard writer are context-managed so they
    # are closed even if an exception interrupts the export.
    # encoder=False: the fields are already bytes and are written as-is.
    with tqdm(total=total_num) as pbar, wds.ShardWriter(
        shard_pattern, maxcount=SAMPLE_PER_SHARD, encoder=False
    ) as sink:
        for batch in pile.iter(4096):
            for text, meta in zip(batch["text"], batch["meta"]):
                # Progress counts every row seen, including skipped ones.
                pbar.update(1)
                if meta.get("pile_set_name") == "Github":
                    continue
                num += 1
                sink.write({
                    '__key__': str(num),
                    'txt': text.encode("utf-8"),
                    'json': json.dumps(meta, indent=4).encode("utf-8"),
                })
    print(f"{num} out of {total_num} is written")


if __name__ == "__main__":
    main()