# CodeLLM / data/dataset.py
# Uploaded by devoppro — commit 7ffa6dd "Create data/dataset.py" (verified)
import torch
import random
from torch.utils.data import IterableDataset, DataLoader
from datasets import load_dataset
from transformers import PreTrainedTokenizerFast
from typing import Optional, Iterator
# Relative sampling weight of each language when interleaving streams
# (the listed weights nominally sum to 1). The insertion order here is also
# the default language order used when no explicit list is given.
LANGUAGES = dict(
    python=0.35,
    javascript=0.20,
    typescript=0.15,
    cpp=0.10,
    rust=0.08,
    go=0.07,
    java=0.05,
)

# Text tag prepended to every file of the given language before tokenizing.
LANG_TOKEN_MAP = dict(
    python="<|python|>",
    javascript="<|javascript|>",
    typescript="<|typescript|>",
    cpp="<|cpp|>",
    rust="<|rust|>",
    go="<|go|>",
    java="<|java|>",
)
class TheStackStreamDataset(IterableDataset):
    """Stream tokenized source files from ``bigcode/the-stack``, mixing languages.

    One streaming iterator is opened per language. At every step a language
    is drawn with probability proportional to its weight in ``LANGUAGES``
    (1.0 for languages not listed there). A language is no longer drawn once
    its stream is exhausted or ``max_samples_per_lang`` examples have been
    yielded for it.

    Args:
        tokenizer: Callable returning a mapping with an ``"input_ids"`` entry
            (e.g. a Hugging Face fast tokenizer).
        max_length: Maximum token length; longer inputs are truncated.
        languages: Language names to stream; defaults to all keys of
            ``LANGUAGES``.
        split: Split name passed to ``load_dataset``.
        max_samples_per_lang: Cap on yielded examples per language.
        fim_rate: Probability of rewriting a file into fill-in-the-middle form.
        min_length: Tokenized examples shorter than this are dropped.
    """

    def __init__(self, tokenizer, max_length=2048, languages=None,
                 split="train", max_samples_per_lang=500_000, fim_rate=0.5,
                 min_length=64):
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.languages = languages or list(LANGUAGES.keys())
        self.split = split
        self.max_samples = max_samples_per_lang
        self.fim_rate = fim_rate
        self.min_length = min_length

    def _get_lang_dataset(self, lang):
        """Open a streaming view of ``bigcode/the-stack`` restricted to ``data/{lang}``."""
        return load_dataset(
            "bigcode/the-stack", data_dir=f"data/{lang}",
            split=self.split, streaming=True, trust_remote_code=True,
        )

    def _tokenize(self, code, lang):
        """Prefix ``code`` with its language tag and tokenize it.

        Returns ``{"input_ids": t, "labels": t}`` (two separate long tensors
        with identical contents), or ``None`` if fewer than ``min_length``
        tokens remain after truncation.
        """
        # Unknown languages get no tag instead of raising.
        text = f"{LANG_TOKEN_MAP.get(lang, '')}{code}"
        # NOTE(review): right-truncation of a FIM example cuts into the middle,
        # which is placed last — consider truncating before FIM is applied.
        tokens = self.tokenizer(text, max_length=self.max_length, truncation=True)
        # Assumes a flat list of ids for a single input string — TODO confirm.
        ids = tokens["input_ids"]
        if len(ids) < self.min_length:
            return None
        return {"input_ids": torch.tensor(ids, dtype=torch.long),
                "labels": torch.tensor(ids, dtype=torch.long)}

    def _apply_fim(self, code):
        """With probability ``fim_rate``, rewrite ``code`` for fill-in-the-middle.

        A random span of 1-10 whole lines becomes the middle. The span never
        includes the first or last line. The result is laid out in
        prefix-suffix-middle order:
        ``<|fim_prefix|>{prefix}<|fim_suffix|>{suffix}<|fim_middle|>{middle}``,
        with ``prefix + middle + suffix == code``. Files with fewer than four
        lines are returned unchanged.
        """
        # Apply FIM with probability exactly fim_rate (0.0 disables it).
        if random.random() >= self.fim_rate:
            return code
        lines = code.split("\n")
        if len(lines) < 4:
            return code
        start = random.randint(1, len(lines) - 3)
        end = random.randint(start + 1, min(start + 10, len(lines) - 1))
        # The prefix and middle keep their trailing newline separators, so
        # the three pieces concatenate back to the original file.
        prefix = "\n".join(lines[:start]) + "\n"
        middle = "\n".join(lines[start:end]) + "\n"
        suffix = "\n".join(lines[end:])
        return f"<|fim_prefix|>{prefix}<|fim_suffix|>{suffix}<|fim_middle|>{middle}"

    @staticmethod
    def _lang_weights(langs):
        """Return the sampling weight of each language in ``langs`` (default 1.0)."""
        return [LANGUAGES.get(lang, 1.0) for lang in langs]

    def _open_streams(self):
        """Open one streaming iterator per language, skipping any that fail to load."""
        streams = {}
        for lang in self.languages:
            try:
                streams[lang] = iter(self._get_lang_dataset(lang))
            except Exception as e:
                # Best effort: one unavailable language must not abort the rest.
                print(f"Warning: could not load {lang}: {e}")
        return streams

    def __iter__(self):
        """Yield tokenized examples, interleaving languages by weight."""
        streams = self._open_streams()
        active = list(streams)
        weights = self._lang_weights(active)
        counts = dict.fromkeys(active, 0)
        while active:
            lang = random.choices(active, weights=weights, k=1)[0]
            try:
                sample = next(streams[lang])
            except StopIteration:
                # Stream exhausted: stop drawing this language.
                active.remove(lang)
                weights = self._lang_weights(active)
                continue
            code = sample.get("content") or ""
            if not code.strip():
                continue
            item = self._tokenize(self._apply_fim(code), lang)
            if item is None:
                continue
            counts[lang] += 1
            yield item
            if counts[lang] >= self.max_samples:
                # Per-language quota reached.
                active.remove(lang)
                weights = self._lang_weights(active)
class CodeCollator:
    """Pad a list of tokenized examples into a right-padded batch.

    Args:
        pad_token_id: Id written into padded ``input_ids`` positions.
        max_length: Sequences longer than this are truncated.

    Each example must provide ``"input_ids"`` and ``"labels"`` sequences.
    Padded label positions are set to -100. Padded attention-mask positions
    are set to 0.
    """

    def __init__(self, pad_token_id=0, max_length=2048):
        self.pad_id = pad_token_id
        self.max_length = max_length

    def __call__(self, batch):
        """Return a dict of ``(batch, seq_len)`` long tensors: input_ids, labels, attention_mask."""
        # default=0 lets an empty batch produce empty tensors instead of raising.
        max_len = min(max((len(x["input_ids"]) for x in batch), default=0),
                      self.max_length)
        shape = (len(batch), max_len)
        # Pad with the configured pad id (previously always 0).
        input_ids = torch.full(shape, self.pad_id, dtype=torch.long)
        # -100 marks positions to ignore in the loss.
        labels = torch.full(shape, -100, dtype=torch.long)
        attention_mask = torch.zeros(shape, dtype=torch.long)
        for i, item in enumerate(batch):
            length = min(len(item["input_ids"]), max_len)
            input_ids[i, :length] = item["input_ids"][:length]
            labels[i, :length] = item["labels"][:length]
            attention_mask[i, :length] = 1
        return {"input_ids": input_ids, "labels": labels, "attention_mask": attention_mask}