# File size: 4,103 Bytes
# 7ffa6dd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
import torch
import random
from torch.utils.data import IterableDataset, DataLoader
from datasets import load_dataset
from transformers import PreTrainedTokenizerFast
from typing import Optional, Iterator

# Relative sampling weight for each supported language. The dataset's
# __iter__ passes these values to random.choices as weights, and the keys
# double as the default language list when none is supplied.
LANGUAGES = dict(
    python=0.35,
    javascript=0.20,
    typescript=0.15,
    cpp=0.10,
    rust=0.08,
    go=0.07,
    java=0.05,
)

# Tag string prepended to a sample's text before tokenization, one per
# language, always of the form "<|name|>".
LANG_TOKEN_MAP = {name: f"<|{name}|>" for name in LANGUAGES}

class TheStackStreamDataset(IterableDataset):
    """Iterable dataset that interleaves streamed code samples by language.

    One streaming split of "bigcode/the-stack" is opened per language. On
    each step a language is drawn at random (weighted by ``LANGUAGES``), its
    next sample is optionally rewritten into fill-in-the-middle (FIM) form,
    prefixed with the language tag from ``LANG_TOKEN_MAP`` and tokenized.

    Args:
        tokenizer: Callable invoked as ``tokenizer(text, max_length=...,
            truncation=True)``; its result must expose ``["input_ids"]``.
        max_length: Maximum number of tokens requested from the tokenizer.
        languages: Language names to stream; defaults to every key of
            ``LANGUAGES``. Names missing from ``LANGUAGES`` get weight 1.0.
        split: Dataset split passed to ``load_dataset``.
        max_samples_per_lang: A language is retired once this many of its
            samples have been yielded.
        fim_rate: Probability of applying the FIM rewrite to a sample.
    """

    # Samples that tokenize to fewer tokens than this are discarded.
    MIN_TOKENS = 64

    def __init__(self, tokenizer, max_length: int = 2048,
                 languages: Optional[list] = None, split: str = "train",
                 max_samples_per_lang: int = 500_000, fim_rate: float = 0.5):
        super().__init__()
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.languages = languages or list(LANGUAGES.keys())
        self.split = split
        self.max_samples = max_samples_per_lang
        self.fim_rate = fim_rate

    def _get_lang_dataset(self, lang: str):
        """Open the streaming dataset for one language directory."""
        # NOTE(review): trust_remote_code=True lets the dataset repository run
        # its own loading code locally -- confirm this is acceptable here.
        return load_dataset(
            "bigcode/the-stack", data_dir=f"data/{lang}",
            split=self.split, streaming=True, trust_remote_code=True,
        )

    def _tokenize(self, code: str, lang: str) -> Optional[dict]:
        """Tokenize ``code`` with its language tag prepended.

        Returns:
            A dict with ``input_ids`` and ``labels`` (two separate long
            tensors holding the same ids), or ``None`` when the sample has
            fewer than ``MIN_TOKENS`` tokens.
        """
        text = f"{LANG_TOKEN_MAP.get(lang, '')}{code}"
        # NOTE(review): truncation presumably drops the tail, which for FIM
        # samples is the middle segment -- confirm against tokenizer config.
        tokens = self.tokenizer(text, max_length=self.max_length, truncation=True)
        ids = tokens["input_ids"]
        if len(ids) < self.MIN_TOKENS:
            return None
        return {"input_ids": torch.tensor(ids, dtype=torch.long),
                "labels": torch.tensor(ids, dtype=torch.long)}

    def _apply_fim(self, code: str) -> str:
        """Randomly rewrite ``code`` into prefix/suffix/middle FIM order.

        With probability ``fim_rate`` a span of 1-10 whole lines is cut out
        (never the first or the last line) and the text is returned as
        ``<|fim_prefix|>P<|fim_suffix|>S<|fim_middle|>M``. The segments keep
        their line breaks, so ``P + M + S`` is exactly the original ``code``.
        Inputs with fewer than four lines are returned unchanged.
        """
        # random.random() is in [0.0, 1.0); ">=" makes fim_rate=0.0 a strict
        # "never" while fim_rate=1.0 remains "always".
        if random.random() >= self.fim_rate:
            return code
        lines = code.split("\n")
        if len(lines) < 4:
            return code
        start = random.randint(1, len(lines) - 3)
        end = random.randint(start + 1, min(start + 10, len(lines) - 1))
        # Re-attach the newline that sat at each cut point; joining alone
        # would drop it and glue two lines together on reassembly.
        prefix = "\n".join(lines[:start]) + "\n"
        middle = "\n".join(lines[start:end]) + "\n"
        suffix = "\n".join(lines[end:])
        return f"<|fim_prefix|>{prefix}<|fim_suffix|>{suffix}<|fim_middle|>{middle}"

    def __iter__(self) -> Iterator[dict]:
        """Yield tokenized samples until every language is retired.

        A language is retired when its stream is exhausted or when
        ``max_samples`` of its samples have been yielded. Languages whose
        dataset cannot be opened are skipped with a printed warning.
        """
        # NOTE(review): nothing here shards by DataLoader worker; with
        # num_workers > 1 each worker presumably opens the same streams --
        # confirm whether that duplication is intended.
        streams = {}
        for lang in self.languages:
            try:
                streams[lang] = iter(self._get_lang_dataset(lang))
            except Exception as e:  # best effort: keep the other languages
                print(f"Warning: could not load {lang}: {e}")

        # Parallel lists so a retired language and its weight leave together.
        active = list(streams)
        weights = [LANGUAGES.get(lang, 1.0) for lang in active]
        counts = dict.fromkeys(active, 0)

        while active:
            idx = random.choices(range(len(active)), weights=weights, k=1)[0]
            lang = active[idx]
            # Only next() may signal exhaustion; a StopIteration coming from
            # anywhere else must not silently retire the language.
            try:
                sample = next(streams[lang])
            except StopIteration:
                del active[idx], weights[idx]
                continue

            # "content" may be absent or None; treat both as empty.
            code = sample.get("content") or ""
            if not code.strip():
                continue

            item = self._tokenize(self._apply_fim(code), lang)
            if item is not None:
                counts[lang] += 1
                yield item
            if counts[lang] >= self.max_samples:
                del active[idx], weights[idx]

class CodeCollator:
    """Pad variable-length tokenized samples into rectangular batch tensors.

    Args:
        pad_token_id: Value written into ``input_ids`` at padded positions.
        max_length: Upper bound on the batch width; longer samples are cut.
    """

    def __init__(self, pad_token_id=0, max_length=2048):
        self.pad_id = pad_token_id
        self.max_length = max_length

    def __call__(self, batch):
        """Collate ``batch`` into ``input_ids``, ``labels`` and ``attention_mask``.

        Args:
            batch: Non-empty sequence of dicts, each holding ``input_ids`` and
                ``labels`` as 1-D sequences of token ids.

        Returns:
            Dict of three long tensors with one row per sample. The width is
            the longest sample's length capped at ``max_length``. Padding uses
            ``pad_token_id`` in ``input_ids``, -100 in ``labels`` and 0 in
            ``attention_mask``.

        Raises:
            ValueError: If ``batch`` is empty.
        """
        if not batch:
            raise ValueError("CodeCollator received an empty batch")
        max_len = min(max(len(x["input_ids"]) for x in batch), self.max_length)
        shape = (len(batch), max_len)
        # Pad with the configured id; the default of 0 keeps zero padding.
        input_ids = torch.full(shape, self.pad_id, dtype=torch.long)
        labels = torch.full(shape, -100, dtype=torch.long)
        attention_mask = torch.zeros(shape, dtype=torch.long)
        for i, item in enumerate(batch):
            length = min(len(item["input_ids"]), max_len)
            input_ids[i, :length] = item["input_ids"][:length]
            labels[i, :length] = item["labels"][:length]
            attention_mask[i, :length] = 1
        return {"input_ids": input_ids, "labels": labels, "attention_mask": attention_mask}