import os
import requests
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import tokenizers
import tqdm
# Download novels from Project Gutenberg
DATASOURCE = {
    "moby_dick": "https://www.gutenberg.org/ebooks/2701.txt.utf-8",
    "frankenstein": "https://www.gutenberg.org/ebooks/84.txt.utf-8",
    "dracula": "https://www.gutenberg.org/ebooks/345.txt.utf-8",
    "little_women": "https://www.gutenberg.org/ebooks/37106.txt.utf-8",
    "pride_and_prejudice": "https://www.gutenberg.org/ebooks/1342.txt.utf-8",
    "alice_in_wonderland": "https://www.gutenberg.org/ebooks/11.txt.utf-8",
    "crime_and_punishment": "https://www.gutenberg.org/ebooks/2554.txt.utf-8",
    "tom_sawyer": "https://www.gutenberg.org/ebooks/74.txt.utf-8",
    "tale_of_two_cities": "https://www.gutenberg.org/ebooks/98.txt.utf-8",
    "sherlock_holmes": "https://www.gutenberg.org/ebooks/1661.txt.utf-8",
    "war_and_peace": "https://www.gutenberg.org/ebooks/2600.txt.utf-8",
}

for filename, url in DATASOURCE.items():
    if not os.path.exists(f"{filename}.txt"):
        response = requests.get(url)
        with open(f"{filename}.txt", "wb") as f:
            f.write(response.content)
# Read and preprocess the text
def preprocess_gutenberg(filename):
    with open(filename, "r", encoding="utf-8") as f:
        text = f.read()
    # Find the start and end of the actual content
    start = text.find("*** START OF THE PROJECT GUTENBERG EBOOK")
    start = text.find("\n", start) + 1
    end = text.find("*** END OF THE PROJECT GUTENBERG EBOOK")
    # Extract the main content
    text = text[start:end].strip()
    # Basic preprocessing: strip extra whitespace and drop blank lines
    text = "\n".join(line.strip() for line in text.split("\n") if line.strip())
    return text

def get_dataset_text():
    all_text = []
    for filename in DATASOURCE:
        text = preprocess_gutenberg(f"{filename}.txt")
        all_text.append(text)
    return all_text
# Tokenization with BPE
if os.path.exists("gutenberg_tokenizer.json"):
    tokenizer = tokenizers.Tokenizer.from_file("gutenberg_tokenizer.json")
else:
    tokenizer = tokenizers.Tokenizer(tokenizers.models.BPE())
    # Configure the pre-tokenizer to add a space at the beginning of each sequence
    tokenizer.pre_tokenizer = tokenizers.pre_tokenizers.ByteLevel(add_prefix_space=True)
    # Configure the decoder so that word-boundary symbols are removed on decoding
    tokenizer.decoder = tokenizers.decoders.ByteLevel()
    # Train BPE
    VOCAB_SIZE = 10000
    trainer = tokenizers.trainers.BpeTrainer(
        vocab_size=VOCAB_SIZE,
        special_tokens=["[pad]", "[eos]"],
        show_progress=True
    )
    text = get_dataset_text()
    tokenizer.train_from_iterator(text, trainer=trainer)
    tokenizer.enable_padding(pad_id=tokenizer.token_to_id("[pad]"), pad_token="[pad]")
    # Save the trained tokenizer
    tokenizer.save("gutenberg_tokenizer.json", pretty=True)
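# Optional sanity check (not part of the original pipeline): round-trip a sample
# sentence through the tokenizer; ByteLevel decoding should reproduce the text.
sample = "It was the best of times, it was the worst of times."
ids = tokenizer.encode(sample).ids
print(f"Sample encodes to {len(ids)} tokens; decoded: {tokenizer.decode(ids)!r}")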
# Create PyTorch dataset
class GutenbergDataset(torch.utils.data.Dataset):
    def __init__(self, text, tokenizer, seq_len=512):
        self.seq_len = seq_len
        # Encode the entire text as a single stream of token ids
        self.encoded = tokenizer.encode(text).ids

    def __len__(self):
        return len(self.encoded) - self.seq_len

    def __getitem__(self, idx):
        chunk = self.encoded[idx:idx + self.seq_len + 1]  # +1 for the shifted target
        x = torch.tensor(chunk[:-1])
        y = torch.tensor(chunk[1:])
        return x, y
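# Illustrative example (hypothetical ids, not from the corpus): with seq_len=4 and
# encoded ids [5, 9, 2, 7, 3, 8], dataset[0] yields x = [5, 9, 2, 7] and
# y = [9, 2, 7, 3], i.e. y is x shifted one position to the left.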
def rotate_half(x):
    # Split the last dimension in half and rotate: (x1, x2) -> (-x2, x1)
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb(x, cos, sin):
    return (x * cos) + (rotate_half(x) * sin)

class RotaryPositionalEncoding(nn.Module):
    def __init__(self, dim, max_seq_len=1024):
        super().__init__()
        N = 10000
        inv_freq = 1.0 / (N ** (torch.arange(0, dim, 2).float() / dim))
        position = torch.arange(max_seq_len).float()
        # Duplicate the frequencies so they line up with the two halves used by rotate_half
        inv_freq = torch.cat((inv_freq, inv_freq), dim=-1)
        sinusoid_inp = torch.outer(position, inv_freq)
        self.register_buffer("cos", sinusoid_inp.cos())
        self.register_buffer("sin", sinusoid_inp.sin())

    def forward(self, x, seq_len=None):
        # x is expected in (batch, seq_len, num_heads, head_dim) layout
        if seq_len is None:
            seq_len = x.size(1)
        cos = self.cos[:seq_len].view(1, seq_len, 1, -1)
        sin = self.sin[:seq_len].view(1, seq_len, 1, -1)
        return apply_rotary_pos_emb(x, cos, sin)
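# Quick shape check (illustrative, with assumed dummy dimensions): RoPE should
# preserve the shape of a (batch, seq_len, num_heads, head_dim) tensor.
#   _rope = RotaryPositionalEncoding(dim=64, max_seq_len=128)
#   assert _rope(torch.randn(2, 16, 4, 64)).shape == (2, 16, 4, 64)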
class SwiGLU(nn.Module):
    def __init__(self, hidden_dim, intermediate_dim):
        super().__init__()
        self.gate = nn.Linear(hidden_dim, intermediate_dim)
        self.up = nn.Linear(hidden_dim, intermediate_dim)
        self.down = nn.Linear(intermediate_dim, hidden_dim)
        self.act = nn.SiLU()

    def forward(self, x):
        # SwiGLU: SiLU(gate(x)) elementwise-scales the up projection
        x = self.act(self.gate(x)) * self.up(x)
        x = self.down(x)
        return x
class GQA(nn.Module):
    def __init__(self, hidden_dim, num_heads, num_kv_heads=None, dropout=0.1):
        super().__init__()
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads or num_heads
        self.head_dim = hidden_dim // num_heads
        self.num_groups = num_heads // self.num_kv_heads
        self.dropout = dropout
        self.q_proj = nn.Linear(hidden_dim, hidden_dim)
        # Keys and values use fewer heads than queries (grouped-query attention)
        self.k_proj = nn.Linear(hidden_dim, self.num_kv_heads * self.head_dim)
        self.v_proj = nn.Linear(hidden_dim, self.num_kv_heads * self.head_dim)
        self.out_proj = nn.Linear(hidden_dim, hidden_dim)

    def forward(self, q, k, v, mask=None, rope=None):
        q_batch_size, q_seq_len, hidden_dim = q.shape
        k_batch_size, k_seq_len, _ = k.shape
        v_batch_size, v_seq_len, _ = v.shape
        # projection: (batch, seq_len, num_heads, head_dim)
        q = self.q_proj(q).view(q_batch_size, q_seq_len, -1, self.head_dim)
        k = self.k_proj(k).view(k_batch_size, k_seq_len, -1, self.head_dim)
        v = self.v_proj(v).view(v_batch_size, v_seq_len, -1, self.head_dim)
        # apply rotary positional encoding while the sequence is still dimension 1
        if rope:
            q = rope(q)
            k = rope(k)
        # move heads ahead of the sequence: (batch, num_heads, seq_len, head_dim)
        q = q.transpose(1, 2).contiguous()
        k = k.transpose(1, 2).contiguous()
        v = v.transpose(1, 2).contiguous()
        # compute grouped-query attention (enable_gqa requires PyTorch 2.5+)
        output = F.scaled_dot_product_attention(
            q, k, v,
            attn_mask=mask,
            dropout_p=self.dropout if self.training else 0.0,
            enable_gqa=True,
        )
        output = output.transpose(1, 2).reshape(q_batch_size, q_seq_len, hidden_dim).contiguous()
        output = self.out_proj(output)
        return output
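# Note (illustrative): with num_heads=8 and num_kv_heads=4, every key/value head is
# shared by num_groups = 2 query heads; enable_gqa=True lets scaled_dot_product_attention
# perform this sharing without explicitly repeating k and v.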
class DecoderLayer(nn.Module):
    def __init__(self, hidden_dim, num_heads, num_kv_heads, dropout=0.1):
        super().__init__()
        self.self_attn = GQA(hidden_dim, num_heads, num_kv_heads, dropout)
        self.mlp = SwiGLU(hidden_dim, 4 * hidden_dim)
        self.norm1 = nn.RMSNorm(hidden_dim)
        self.norm2 = nn.RMSNorm(hidden_dim)

    def forward(self, x, mask=None, rope=None):
        # self-attention sublayer (pre-norm, residual connection)
        out = self.norm1(x)
        out = self.self_attn(out, out, out, mask, rope)
        x = out + x
        # MLP sublayer (pre-norm, residual connection)
        out = self.norm2(x)
        out = self.mlp(out)
        return out + x
class TextGenerationModel(nn.Module):
    def __init__(self, num_layers, num_heads, num_kv_heads, hidden_dim,
                 max_seq_len, vocab_size, dropout=0.1):
        super().__init__()
        self.rope = RotaryPositionalEncoding(hidden_dim // num_heads, max_seq_len)
        self.embedding = nn.Embedding(vocab_size, hidden_dim)
        self.decoders = nn.ModuleList([
            DecoderLayer(hidden_dim, num_heads, num_kv_heads, dropout)
            for _ in range(num_layers)
        ])
        self.norm = nn.RMSNorm(hidden_dim)
        self.out = nn.Linear(hidden_dim, vocab_size)

    def forward(self, ids, mask=None):
        x = self.embedding(ids)
        for decoder in self.decoders:
            x = decoder(x, mask, self.rope)
        x = self.norm(x)
        return self.out(x)

def create_causal_mask(seq_len, device):
    """Create a causal mask for autoregressive attention."""
    mask = torch.triu(torch.full((seq_len, seq_len), float("-inf"), device=device), diagonal=1)
    return mask
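# Illustrative (seq_len=3, not from the original script): the mask is
#   [[0., -inf, -inf],
#    [0.,   0., -inf],
#    [0.,   0.,   0.]]
# so position i can only attend to positions j <= i.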
# Training configuration
model_config = {
    "num_layers": 8,
    "num_heads": 8,
    "num_kv_heads": 4,
    "hidden_dim": 768,
    "max_seq_len": 512,
    "vocab_size": len(tokenizer.get_vocab()),
    "dropout": 0.1,
}

# Initialize model, optimizer, etc.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = TextGenerationModel(**model_config).to(device)
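# Optional (not in the original script): report the parameter count as a sanity check
n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Model has {n_params:,} trainable parameters")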
# Create dataset and dataloader
BATCH_SIZE = 32
text = "\n".join(get_dataset_text())
dataset = GutenbergDataset(text, tokenizer, seq_len=model_config["max_seq_len"])
dataloader = torch.utils.data.DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
# Training loop
if os.path.exists("textgen_model.pth"):
    model.load_state_dict(torch.load("textgen_model.pth", map_location=device))
else:
    N_EPOCHS = 2
    LR = 0.0005
    WARMUP_STEPS = 2000
    CLIP_NORM = 6.0
    optimizer = optim.AdamW(model.parameters(), lr=LR)
    loss_fn = nn.CrossEntropyLoss(ignore_index=tokenizer.token_to_id("[pad]"))
    # Learning rate scheduling: linear warm-up followed by cosine decay
    warmup_scheduler = optim.lr_scheduler.LinearLR(
        optimizer, start_factor=0.01, end_factor=1.0, total_iters=WARMUP_STEPS)
    cosine_scheduler = optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=N_EPOCHS * len(dataloader) - WARMUP_STEPS, eta_min=0)
    scheduler = optim.lr_scheduler.SequentialLR(
        optimizer, schedulers=[warmup_scheduler, cosine_scheduler],
        milestones=[WARMUP_STEPS])
    print(f"Training for {N_EPOCHS} epochs with {len(dataloader)} steps per epoch")
    best_loss = float("inf")
    for epoch in range(N_EPOCHS):
        model.train()
        epoch_loss = 0
        progress_bar = tqdm.tqdm(dataloader, desc=f"Epoch {epoch+1}/{N_EPOCHS}")
        for x, y in progress_bar:
            x = x.to(device)
            y = y.to(device)
            # Create causal mask
            mask = create_causal_mask(x.shape[1], device)
            # Forward pass
            optimizer.zero_grad()
            outputs = model(x, mask.unsqueeze(0))
            # Compute loss
            loss = loss_fn(outputs.view(-1, outputs.shape[-1]), y.view(-1))
            # Backward pass
            loss.backward()
            torch.nn.utils.clip_grad_norm_(
                model.parameters(), CLIP_NORM, error_if_nonfinite=True
            )
            optimizer.step()
            scheduler.step()
            epoch_loss += loss.item()
            # Show the running loss in tqdm
            progress_bar.set_postfix(loss=loss.item())
        avg_loss = epoch_loss / len(dataloader)
        print(f"Epoch {epoch+1}/{N_EPOCHS}; Avg loss: {avg_loss:.4f}")
        # Save a checkpoint if the loss improved
        if avg_loss < best_loss:
            best_loss = avg_loss
            torch.save(model.state_dict(), "textgen_model.pth")
# Generation function
def generate_text(model, tokenizer, prompt, max_length=100, temperature=0.7):
    model.eval()
    device = next(model.parameters()).device
    # Encode the prompt
    input_ids = torch.tensor(tokenizer.encode(prompt).ids).unsqueeze(0).to(device)
    with torch.no_grad():
        for _ in range(max_length):
            # Run the model with a causal mask (as in training) and take the
            # logits of the last position as the prediction for the next token
            mask = create_causal_mask(input_ids.shape[1], device)
            outputs = model(input_ids, mask)
            next_token_logits = outputs[:, -1, :] / temperature
            # Sample from the temperature-scaled distribution
            probs = F.softmax(next_token_logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
            # Append to input_ids
            input_ids = torch.cat([input_ids, next_token], dim=1)
            # Stop if we predict the end-of-sequence token
            if next_token[0].item() == tokenizer.token_to_id("[eos]"):
                break
    return tokenizer.decode(input_ids[0].tolist())
# Test the model with some prompts
test_prompts = [
    "Once upon a time,",
    "We the people of the",
    "In the beginning was the",
]
print("\nGenerating sample texts:")
for prompt in test_prompts:
    generated = generate_text(model, tokenizer, prompt)
    print(f"\nPrompt: {prompt}")
    print(f"Generated: {generated}")
    print("-" * 80)