import random
from typing import Iterator, Optional

import torch
from torch.utils.data import IterableDataset, DataLoader
from datasets import load_dataset
from transformers import PreTrainedTokenizerFast

# Relative sampling weight per language when interleaving streams in
# TheStackStreamDataset.__iter__ (these sum to 1.0). A requested language that
# is missing from this table is sampled with weight 1.0.
LANGUAGES = {
    "python": 0.35,
    "javascript": 0.20,
    "typescript": 0.15,
    "cpp": 0.10,
    "rust": 0.08,
    "go": 0.07,
    "java": 0.05,
}

# Control token prepended to every sample of the given language. Languages not
# listed here get no prefix token.
LANG_TOKEN_MAP = {
    "python": "<|python|>",
    "javascript": "<|javascript|>",
    "typescript": "<|typescript|>",
    "cpp": "<|cpp|>",
    "rust": "<|rust|>",
    "go": "<|go|>",
    "java": "<|java|>",
}


class TheStackStreamDataset(IterableDataset):
    """Stream source files from ``bigcode/the-stack``, interleaving languages.

    Each yielded item is a dict with ``input_ids`` and ``labels``: identical
    1-D ``torch.long`` tensors. Before tokenization, each file is prefixed
    with its language token from ``LANG_TOKEN_MAP``. With probability
    ``fim_rate`` it is also rewritten into fill-in-the-middle order.

    Args:
        tokenizer: callable used like a Hugging Face tokenizer; must return a
            mapping with ``"input_ids"``. NOTE(review): the language and
            ``<|fim_*|>`` markers are presumably registered as special tokens
            on this tokenizer — confirm.
        max_length: token truncation length.
        languages: languages to stream; defaults to every key of ``LANGUAGES``.
        split: split name passed to ``load_dataset``.
        max_samples_per_lang: stop sampling a language once this many items
            have been yielded for it. The counter lives inside ``__iter__``,
            so the cap applies per iterator (e.g. per DataLoader worker).
        fim_rate: probability of applying the FIM transform to a file.
        min_length: files that tokenize to fewer tokens than this are skipped.
    """

    def __init__(self, tokenizer, max_length=2048, languages=None, split="train",
                 max_samples_per_lang=500_000, fim_rate=0.5, min_length=64):
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.languages = languages or list(LANGUAGES.keys())
        self.split = split
        self.max_samples = max_samples_per_lang
        self.fim_rate = fim_rate
        self.min_length = min_length

    def _get_lang_dataset(self, lang):
        """Open a streaming (lazy) Hugging Face dataset for one language."""
        # NOTE(review): assumes the-stack stores each language under
        # ``data/<key>`` using the keys of LANGUAGES; some directory names
        # (e.g. for C++) may differ — confirm against the dataset layout.
        return load_dataset(
            "bigcode/the-stack",
            data_dir=f"data/{lang}",
            split=self.split,
            streaming=True,
            trust_remote_code=True,
        )

    def _tokenize(self, code, lang) -> Optional[dict]:
        """Tokenize ``code`` with its language token prepended.

        Returns ``None`` when the result has fewer than ``self.min_length``
        tokens; otherwise a dict of ``input_ids`` and an identical ``labels``
        tensor (any next-token shift is presumably done by the model).
        """
        text = f"{LANG_TOKEN_MAP.get(lang, '')}{code}"
        # NOTE(review): FIM is applied before this truncation, and the middle
        # segment comes last in the FIM string. For long files the middle can
        # therefore be cut off entirely; consider FIM at the token level.
        ids = self.tokenizer(text, max_length=self.max_length, truncation=True)["input_ids"]
        if len(ids) < self.min_length:
            return None
        return {
            "input_ids": torch.tensor(ids, dtype=torch.long),
            "labels": torch.tensor(ids, dtype=torch.long),
        }

    def _apply_fim(self, code):
        """With probability ``self.fim_rate``, rewrite ``code`` for FIM training.

        The file is cut at line boundaries into a non-empty prefix, a middle of
        1-10 lines, and a non-empty suffix. These are emitted in
        prefix-suffix-middle order. Files with fewer than 4 lines are returned
        unchanged.
        """
        # ``>=`` makes the rate exact: fim_rate=0 never applies FIM and
        # fim_rate=1 always does.
        if random.random() >= self.fim_rate:
            return code
        lines = code.split("\n")
        if len(lines) < 4:
            return code
        start = random.randint(1, len(lines) - 3)
        end = random.randint(start + 1, min(start + 10, len(lines) - 1))
        # Keep the newline that ends the prefix and the middle. That way
        # prefix + middle + suffix reproduces ``code`` exactly.
        prefix = "\n".join(lines[:start]) + "\n"
        middle = "\n".join(lines[start:end]) + "\n"
        suffix = "\n".join(lines[end:])
        return f"<|fim_prefix|>{prefix}<|fim_suffix|>{suffix}<|fim_middle|>{middle}"

    def __iter__(self) -> Iterator[dict]:
        """Yield tokenized samples, picking each one's language at random.

        Languages are drawn with the weights in ``LANGUAGES``. A language is
        retired when its stream is exhausted or once ``max_samples_per_lang``
        items have been yielded for it. Iteration ends when no language
        remains. A language whose dataset fails to load is skipped with a
        printed warning.
        """
        streams = {}
        for lang in self.languages:
            try:
                streams[lang] = iter(self._get_lang_dataset(lang))
            except Exception as e:  # best effort: train on what does load
                print(f"Warning: could not load {lang}: {e}")

        # ``active`` and ``weights`` are parallel lists; keep them in sync.
        active = list(streams)
        weights = [LANGUAGES.get(lang, 1.0) for lang in active]
        counts = dict.fromkeys(active, 0)

        def retire(lang):
            idx = active.index(lang)
            del active[idx]
            del weights[idx]

        while active:
            lang = random.choices(active, weights=weights, k=1)[0]
            # Only next() is guarded, so a StopIteration raised anywhere else
            # is never mistaken for the end of this language's stream.
            try:
                sample = next(streams[lang])
            except StopIteration:
                retire(lang)
                continue
            code = sample.get("content") or ""  # tolerate missing/None content
            if not code.strip():
                continue
            item = self._tokenize(self._apply_fim(code), lang)
            if item is None:
                continue
            counts[lang] += 1
            yield item
            if counts[lang] >= self.max_samples:
                retire(lang)


class CodeCollator:
    """Right-pad a list of ``{"input_ids", "labels"}`` items into a batch.

    Every row is padded (or truncated) to the longest item in the batch,
    capped at ``max_length``. Padding positions hold ``pad_token_id`` in
    ``input_ids``, ``-100`` in ``labels`` (the default ``ignore_index`` of
    ``torch.nn.CrossEntropyLoss``), and ``0`` in ``attention_mask``.
    """

    def __init__(self, pad_token_id=0, max_length=2048):
        self.pad_id = pad_token_id
        self.max_length = max_length

    def __call__(self, batch):
        if not batch:
            raise ValueError("CodeCollator received an empty batch")
        max_len = min(max(len(x["input_ids"]) for x in batch), self.max_length)
        shape = (len(batch), max_len)
        # Pad with the configured pad id rather than a hard-coded 0.
        input_ids = torch.full(shape, self.pad_id, dtype=torch.long)
        labels = torch.full(shape, -100, dtype=torch.long)
        attention_mask = torch.zeros(shape, dtype=torch.long)
        for i, item in enumerate(batch):
            length = min(len(item["input_ids"]), max_len)
            input_ids[i, :length] = item["input_ids"][:length]
            labels[i, :length] = item["labels"][:length]
            attention_mask[i, :length] = 1
        return {"input_ids": input_ids, "labels": labels, "attention_mask": attention_mask}