"""Process the WikiText dataset into pre-training samples for the BERT model,
using the Hugging Face datasets library.
"""
import time
import random
from typing import Iterator
import tokenizers
from datasets import load_dataset, Dataset
# path and name of each dataset
DATASETS = {
    "wikitext-2": ("wikitext", "wikitext-2-raw-v1"),
    "wikitext-103": ("wikitext", "wikitext-103-raw-v1"),
}
PATH, NAME = DATASETS["wikitext-103"]
TOKENIZER_PATH = "wikitext-103_wordpiece.json"
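# NOTE: the tokenizer file above is assumed to be a WordPiece tokenizer trained in a
# separate step (not shown here), with the BERT special tokens [PAD], [UNK], [CLS],
# [SEP] and [MASK] occupying the lowest vocabulary ids; the masking code below samples
# random replacement tokens starting from id 4 under that assumption.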
def create_docs(path: str, name: str, tokenizer: tokenizers.Tokenizer) -> list[list[list[int]]]:
    """Load the WikiText dataset and tokenize its text into documents."""
    dataset = load_dataset(path, name, split="train")
    docs: list[list[list[int]]] = []
    for line in dataset["text"]:
        line = line.strip()
        if not line or line.startswith("="):
            docs.append([])  # blank line or section heading: start a new document
        else:
            if not docs:
                docs.append([])  # guard in case text appears before the first heading
            tokens = tokenizer.encode(line).ids
            docs[-1].append(tokens)
    docs = [doc for doc in docs if doc]  # remove empty documents
    return docs
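# Illustrative shape of the list returned by create_docs (token ids are made up):
# [
#     [[2023, 2003, ...], [1037, 2742, ...]],  # document 0: one list of token ids per text line
#     [[2178, 6254, ...]],                     # document 1
# ]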
def create_dataset(
    docs: list[list[list[int]]],
    tokenizer: tokenizers.Tokenizer,
    max_seq_length: int = 512,
    doc_repeat: int = 10,
    mask_prob: float = 0.15,
    short_seq_prob: float = 0.1,
    max_predictions_per_seq: int = 20,
) -> Iterator[dict]:
    """Generate samples from all documents."""
    # each document is visited doc_repeat times, yielding different random pairs and masks each pass
    doc_indices = list(range(len(docs))) * doc_repeat
    for doc_idx in doc_indices:
        yield from generate_samples(
            doc_idx, docs, tokenizer, max_seq_length, mask_prob, short_seq_prob, max_predictions_per_seq
        )
def generate_samples(
    doc_idx: int,
    all_docs: list[list[list[int]]],
    tokenizer: tokenizers.Tokenizer,
    max_seq_length: int = 512,
    mask_prob: float = 0.15,
    short_seq_prob: float = 0.1,
    max_predictions_per_seq: int = 20,
) -> Iterator[dict]:
    """Generate samples from a given document."""
    # number of tokens to extract from this doc, excluding [CLS], [SEP], [SEP]
    target_length = max_seq_length - 3
    if random.random() < short_seq_prob:
        # a shorter sequence is used short_seq_prob (default 10%) of the time
        target_length = random.randint(2, target_length)
    # copy the document's chunks so they can be consumed without mutating all_docs
    chunks = list(all_docs[doc_idx])
    # exhaust the chunks and create samples
    while chunks:
        # scan forward until roughly target_length tokens are covered; chunks[:end] may
        # overshoot slightly, and truncate_seq_pair trims the excess later
        running_length = 0
        end = 1
        while end < len(chunks) and running_length < target_length:
            running_length += len(chunks[end - 1])
            end += 1
        # randomly split the selected chunks into two segments
        sep = random.randint(1, end - 1) if end > 1 else 1
        sentence_a = [tok for chunk in chunks[:sep] for tok in chunk]
        sentence_b = [tok for chunk in chunks[sep:end] for tok in chunk]
        # sentence B: half of the time it is taken from another document
        if not sentence_b or random.random() < 0.5:
            # pick another document (must not be the same as doc_idx)
            b_idx = random.randint(0, len(all_docs) - 2)
            if b_idx >= doc_idx:
                b_idx += 1
            # sentence B starts from a random position in the other document
            sentence_b = []
            running_length = len(sentence_a)
            i = random.randint(0, len(all_docs[b_idx]) - 1)
            while i < len(all_docs[b_idx]) and running_length < target_length:
                sentence_b.extend(all_docs[b_idx][i])
                running_length += len(all_docs[b_idx][i])
                i += 1
            is_random_next = True
            # the chunks that were not used for sentence A go back for the next sample
            chunks = chunks[sep:]
        else:
            is_random_next = False
            chunks = chunks[end:]
        # create a sample from the pair
        yield create_sample(
            sentence_a, sentence_b, is_random_next, tokenizer, max_seq_length, mask_prob, max_predictions_per_seq
        )
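# The pairs produced above follow BERT's next-sentence-prediction setup: roughly half of the
# samples keep the document's own continuation as sentence B (is_random_next=False), and the
# other half splice in a span from a different document (is_random_next=True).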
def create_sample(
    sentence_a: list[int],
    sentence_b: list[int],
    is_random_next: bool,
    tokenizer: tokenizers.Tokenizer,
    max_seq_length: int = 512,
    mask_prob: float = 0.15,
    max_predictions_per_seq: int = 20,
) -> dict:
    """Create a sample from a pair of sentences."""
    # collect the ids of the special tokens
    cls_id = tokenizer.token_to_id("[CLS]")
    sep_id = tokenizer.token_to_id("[SEP]")
    mask_id = tokenizer.token_to_id("[MASK]")
    pad_id = tokenizer.padding["pad_id"] if tokenizer.padding else tokenizer.token_to_id("[PAD]")
    # adjust the lengths to fit the maximum sequence length
    truncate_seq_pair(sentence_a, sentence_b, max_seq_length - 3)
    num_pad = max_seq_length - len(sentence_a) - len(sentence_b) - 3
    # create the unmodified token sequence: [CLS] A [SEP] B [SEP] followed by padding
    tokens = [cls_id] + sentence_a + [sep_id] + sentence_b + [sep_id] + ([pad_id] * num_pad)
    seg_id = [0] * (len(sentence_a) + 2) + [1] * (len(sentence_b) + 1) + [-1] * num_pad
    assert len(tokens) == len(seg_id) == max_seq_length
    # create the masked-LM prediction targets
    cand_indices = [i for i, tok in enumerate(tokens) if tok not in [cls_id, sep_id, pad_id]]
    random.shuffle(cand_indices)
    num_predictions = int(round((len(sentence_a) + len(sentence_b)) * mask_prob))
    num_predictions = min(max_predictions_per_seq, max(1, num_predictions))
    mlm_positions = sorted(cand_indices[:num_predictions])
    mlm_labels = []
    for i in mlm_positions:
        mlm_labels.append(tokens[i])
        # prob 0.8: replace with [MASK]; prob 0.1: replace with a random word; prob 0.1: keep original
        if random.random() < 0.8:
            tokens[i] = mask_id
        elif random.random() < 0.5:
            # the second draw is conditioned on the first being >= 0.8 (prob 0.2),
            # so random replacement and keep-original each get 0.2 * 0.5 = 0.1
            tokens[i] = random.randint(4, tokenizer.get_vocab_size() - 1)
    ret = {
        "tokens": tokens,
        "segment_ids": seg_id,
        "is_random_next": is_random_next,
        "masked_positions": mlm_positions,
        "masked_labels": mlm_labels,
    }
    return ret
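# Illustrative shape of one sample from create_sample (all ids, positions and the sequence
# length are made up for readability; real samples are max_seq_length entries long):
# {
#     "tokens":           [2, 7592, 4, 2088, 3, 4937, 4, 3, 0, 0],  # [CLS]=2, [SEP]=3, [MASK]=4, [PAD]=0
#     "segment_ids":      [0, 0, 0, 0, 0, 1, 1, 1, -1, -1],
#     "is_random_next":   False,
#     "masked_positions": [2, 6],
#     "masked_labels":    [7654, 5929],                              # original ids at the masked positions
# }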
def truncate_seq_pair(sentence_a: list[int], sentence_b: list[int], max_num_tokens: int) -> None:
    """Truncate a pair of sequences in place until their combined length fits max_num_tokens."""
    while len(sentence_a) + len(sentence_b) > max_num_tokens:
        # pick the longer sentence to remove tokens from
        candidate = sentence_a if len(sentence_a) > len(sentence_b) else sentence_b
        # remove one token from either end with equal probability
        if random.random() < 0.5:
            candidate.pop(0)
        else:
            candidate.pop()
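# Example with made-up ids: for sentence_a = [11, 12, 13, 14] and sentence_b = [21, 22],
# truncate_seq_pair(sentence_a, sentence_b, 5) pops one token in place from the longer
# list at a randomly chosen end, e.g. leaving sentence_a = [12, 13, 14].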
if __name__ == "__main__":
    print(time.time(), "started")
    tokenizer = tokenizers.Tokenizer.from_file(TOKENIZER_PATH)
    print(time.time(), "loaded tokenizer")
    docs = create_docs(PATH, NAME, tokenizer)
    print(time.time(), "created docs with %d documents" % len(docs))
    dataset = Dataset.from_generator(create_dataset, gen_kwargs={"docs": docs, "tokenizer": tokenizer})
    print(time.time(), "created dataset from generator")
    # save the dataset to a parquet file
    dataset.to_parquet("wikitext-103_train_data.parquet")
    print(time.time(), "saved dataset to parquet file")
    # load the dataset back from the parquet file (streaming=True yields an IterableDataset)
    dataset = Dataset.from_parquet("wikitext-103_train_data.parquet", streaming=True)
    print(time.time(), "loaded dataset from parquet file")
    # print a few samples
    for i, sample in enumerate(dataset):
        print(i)
        print(sample)
        print()
        if i >= 3:
            break
    print(time.time(), "finished")
