import logging
import os
import subprocess
import sys
from argparse import ArgumentParser
from pathlib import Path

import datasets
import lxml.html
from datasets import config, load_from_disk
from datasets.utils.logging import set_verbosity_info

set_verbosity_info()
logger = logging.getLogger(__name__)

# Parsing deeply nested HTML can hit the default recursion limit.
sys.setrecursionlimit(10000)
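
# Example invocation (script name and paths are illustrative):
#   python detect_html_lang_attr.py \
#       --dataset-path /path/to/dataset \
#       --save-path /path/to/output \
#       --num-proc 4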


def get_args():
    parser = ArgumentParser()
    parser.add_argument(
        "--dataset-path",
        type=str,
        required=True,
        help="Path to the dataset folder (as saved with `datasets.save_to_disk`).",
    )
    parser.add_argument("--save-path", type=str, help="Where to save the dataset.")
    parser.add_argument("--use-datasets-caching", action="store_true")
    parser.add_argument(
        "--num-proc",
        type=int,
        default=1,
        help="Number of processes to use for preprocessing.",
    )
    parser.add_argument(
        "--num-examples",
        type=int,
        default=None,
        help="Optional argument to select a subset (used for debugging). Example: `10`.",
    )
    args = parser.parse_args()

    return args


def main():
    # Set up logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    args = get_args()
    logger.info(f"The job is run with the following arguments:\n{args}")

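    # If the output directory already exists, this seed was processed by an
    # earlier run; skip it instead of redoing the work.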
    if args.save_path and os.path.isdir(args.save_path):
        logger.info(f"Seed id {args.save_path.split('/')[-1]} was already processed")
        return

    if not args.use_datasets_caching:
        datasets.set_caching_enabled(False)
    else:
        logger.info(
            f"The datasets results will be cached at {config.HF_DATASETS_CACHE}."
        )

    ds = load_from_disk(args.dataset_path)
    logger.info(f"The dataset is {ds}")

    if args.num_examples:
        ds = ds.select(range(args.num_examples))

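    # The root <html> element of a page may declare the document language via
    # its `lang` attribute; copy it into a new column (None when the row has
    # no usable text or HTML).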
    def detect_lang(example):
        if not example["text"] or not example["html_str"]:
            example["html_lang_attr"] = None
        else:
            root = lxml.html.fromstring(example["html_str"])
            example["html_lang_attr"] = root.attrib.get("lang")
        return example

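    # Process one example at a time; num_proc shards the work across processes.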
    ds = ds.map(
        detect_lang,
        batched=False,
        num_proc=args.num_proc,
    )

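    # Default to the input path when --save-path is not given, so the
    # processed dataset replaces the original.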
    if args.save_path:
        save_path = Path(args.save_path)
    else:
        save_path = Path(args.dataset_path)

    num_detected = sum(e is not None for e in ds["train"]["html_lang_attr"])
    logger.info(
        f"Lang attribute detected for {num_detected} rows out of {len(ds['train'])} rows."
    )

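    # Write to a temporary path first and move it into place afterwards, so an
    # interrupted run never leaves a partially written dataset at the final path.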
    save_path_tmp = f"{str(save_path.absolute())}.tmp"
    logger.info(f"Saving the dataset at {save_path_tmp}")
    ds.save_to_disk(save_path_tmp)
    logger.info(f"Moving the saved dataset to {str(save_path.absolute())}")
    subprocess.run(["mv", save_path_tmp, str(save_path.absolute())], check=True)


if __name__ == "__main__":
    main()