| import torch |
| import random |
| from torch.utils.data import IterableDataset, DataLoader |
| from datasets import load_dataset |
| from transformers import PreTrainedTokenizerFast |
| from typing import Optional, Iterator |
|
|
# Relative sampling weight for each language; TheStackStreamDataset.__iter__
# passes these to random.choices when picking which stream to draw from next.
LANGUAGES = {
    "python": 0.35,
    "javascript": 0.20,
    "typescript": 0.15,
    "cpp": 0.10,
    "rust": 0.08,
    "go": 0.07,
    "java": 0.05,
}

# Marker string prepended to a sample's text before tokenization, keyed by
# language name.
LANG_TOKEN_MAP = {
    "python": "<|python|>",
    "javascript": "<|javascript|>",
    "typescript": "<|typescript|>",
    "cpp": "<|cpp|>",
    "rust": "<|rust|>",
    "go": "<|go|>",
    "java": "<|java|>",
}
|
|
class TheStackStreamDataset(IterableDataset):
    """Stream tokenized code samples from per-language subsets of ``bigcode/the-stack``.

    On each step one language is drawn at random (weighted by ``LANGUAGES``),
    the next sample of that language's stream is optionally rewritten into
    fill-in-the-middle (FIM) form, prefixed with the language marker from
    ``LANG_TOKEN_MAP``, tokenized, and yielded as a dict of ``input_ids`` and
    ``labels`` tensors.  A language is retired once its stream is exhausted
    or once ``max_samples_per_lang`` items have been yielded for it.

    Args:
        tokenizer: Callable tokenizer; invoked as
            ``tokenizer(text, max_length=..., truncation=True)`` and expected
            to return a mapping with an ``"input_ids"`` entry.
        max_length: Maximum number of tokens per sample (passed to the
            tokenizer together with ``truncation=True``).
        languages: Language names to stream.  Defaults to every key of
            ``LANGUAGES``.
        split: Dataset split handed to ``load_dataset``.
        max_samples_per_lang: Number of yielded items after which a language
            is dropped from sampling.
        fim_rate: Probability in ``[0, 1]`` of rewriting a sample into FIM
            form.  ``0`` disables FIM, ``1`` applies it whenever the sample
            has at least four lines.
        min_length: Samples that tokenize to fewer tokens than this are
            discarded.

    NOTE(review): no DataLoader-worker sharding is done here; whether the
    underlying streaming dataset shards itself per worker should be confirmed
    before using ``num_workers > 1``.
    """

    def __init__(self, tokenizer: "PreTrainedTokenizerFast", max_length: int = 2048,
                 languages: Optional[list] = None, split: str = "train",
                 max_samples_per_lang: int = 500_000, fim_rate: float = 0.5,
                 min_length: int = 64):
        super().__init__()
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.languages = languages or list(LANGUAGES.keys())
        self.split = split
        self.max_samples = max_samples_per_lang
        self.fim_rate = fim_rate
        self.min_length = min_length

    @staticmethod
    def _sampling_weights(langs) -> list:
        """Return the ``LANGUAGES`` weight of each name in *langs* (1.0 if unlisted)."""
        return [LANGUAGES.get(lang, 1.0) for lang in langs]

    def _get_lang_dataset(self, lang: str):
        """Open the streaming dataset for the ``data/<lang>`` directory of The Stack."""
        # NOTE(review): trust_remote_code=True permits code shipped with the
        # dataset repository to run locally -- confirm it is actually needed.
        return load_dataset(
            "bigcode/the-stack",
            data_dir=f"data/{lang}",
            split=self.split,
            streaming=True,
            trust_remote_code=True,
        )

    def _tokenize(self, code: str, lang: str) -> Optional[dict]:
        """Tokenize *code* prefixed with the marker for *lang*.

        Returns:
            ``{"input_ids": LongTensor, "labels": LongTensor}`` holding the
            same token ids in two independent tensors, or ``None`` when the
            sample has fewer than ``self.min_length`` tokens.
        """
        text = f"{LANG_TOKEN_MAP.get(lang, '')}{code}"
        encoded = self.tokenizer(text, max_length=self.max_length, truncation=True)
        ids = encoded["input_ids"]
        if len(ids) < self.min_length:
            return None
        input_ids = torch.tensor(ids, dtype=torch.long)
        # Clone so that mutating one tensor downstream never affects the other.
        return {"input_ids": input_ids, "labels": input_ids.clone()}

    def _apply_fim(self, code: str) -> str:
        """Rewrite *code* into prefix/suffix/middle order with probability ``fim_rate``.

        The split is made on line boundaries: the middle span starts at line
        1 or later, covers between 1 and 10 lines, and always leaves at least
        one line of suffix.  Code with fewer than four lines is returned
        unchanged.

        NOTE(review): the middle segment is placed last and truncation happens
        afterwards in ``_tokenize``; for long files the middle may be cut off,
        depending on the tokenizer's truncation side -- confirm this is
        acceptable.
        """
        # random.random() is in [0, 1): ">=" makes fim_rate=0 never apply FIM
        # and fim_rate=1 always apply it.
        if random.random() >= self.fim_rate:
            return code
        lines = code.split("\n")
        if len(lines) < 4:
            return code
        start = random.randint(1, len(lines) - 3)
        end = random.randint(start + 1, min(start + 10, len(lines) - 1))
        prefix = "\n".join(lines[:start])
        middle = "\n".join(lines[start:end])
        suffix = "\n".join(lines[end:])
        return f"<|fim_prefix|>{prefix}<|fim_suffix|>{suffix}<|fim_middle|>{middle}"

    def __iter__(self) -> Iterator[dict]:
        """Yield tokenized samples, interleaving languages by weighted random choice."""
        streams = {}
        for lang in self.languages:
            # Best effort: a language that fails to open is reported and skipped.
            try:
                streams[lang] = iter(self._get_lang_dataset(lang))
            except Exception as e:
                print(f"Warning: could not load {lang}: {e}")

        active = list(streams)
        weights = self._sampling_weights(active)
        counts = dict.fromkeys(active, 0)

        while active:
            lang = random.choices(active, weights=weights, k=1)[0]
            # Only next() is guarded, so a StopIteration can only mean that
            # this language's stream is exhausted.
            try:
                sample = next(streams[lang])
            except StopIteration:
                active.remove(lang)
                weights = self._sampling_weights(active)
                continue

            # "content" may be absent or explicitly None; treat both as empty.
            code = sample.get("content") or ""
            if not code.strip():
                continue

            item = self._tokenize(self._apply_fim(code), lang)
            if item is None:
                continue

            counts[lang] += 1
            yield item
            if counts[lang] >= self.max_samples:
                active.remove(lang)
                weights = self._sampling_weights(active)
|
|
class CodeCollator:
    """Pad a list of tokenized samples into fixed-width batch tensors.

    Args:
        pad_token_id: Value written into the padded positions of
            ``input_ids``.
        max_length: Upper bound on the batch width; longer samples are
            truncated on the right.
    """

    def __init__(self, pad_token_id: int = 0, max_length: int = 2048):
        self.pad_id = pad_token_id
        self.max_length = max_length

    def __call__(self, batch: list) -> dict:
        """Collate *batch* into ``input_ids``, ``labels`` and ``attention_mask``.

        Each item must provide 1-D ``"input_ids"`` and ``"labels"`` tensors.
        The batch width is the longest ``input_ids`` in the batch, capped at
        ``self.max_length``.

        Returns:
            Dict of three ``torch.long`` tensors of shape ``(len(batch), width)``:
            ``input_ids`` padded with ``self.pad_id``, ``labels`` padded with
            ``-100`` (the default ``ignore_index`` of
            ``torch.nn.CrossEntropyLoss``), and ``attention_mask`` with 1 on
            real tokens and 0 on padding.

        Raises:
            ValueError: If *batch* is empty.
        """
        if not batch:
            raise ValueError("CodeCollator received an empty batch")

        width = min(max(len(item["input_ids"]) for item in batch), self.max_length)
        shape = (len(batch), width)
        # Pad with the configured pad id rather than a hard-coded zero.
        input_ids = torch.full(shape, self.pad_id, dtype=torch.long)
        labels = torch.full(shape, -100, dtype=torch.long)
        attention_mask = torch.zeros(shape, dtype=torch.long)

        for row, item in enumerate(batch):
            length = min(len(item["input_ids"]), width)
            input_ids[row, :length] = item["input_ids"][:length]
            labels[row, :length] = item["labels"][:length]
            attention_mask[row, :length] = 1

        return {"input_ids": input_ids, "labels": labels, "attention_mask": attention_mask}