Instructions to use ainz/tiny-recursive-model with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use ainz/tiny-recursive-model with Transformers:
```python
# Load model directly
from transformers import AutoModel

model = AutoModel.from_pretrained("ainz/tiny-recursive-model", dtype="auto")
```
- Notebooks
- Google Colab
- Kaggle
| from transformers import PreTrainedModel, PretrainedConfig | |
| from transformers.models.gpt2.modeling_gpt2 import GPT2Attention, GPT2MLP | |
| from transformers.generation import GenerationMixin | |
| from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions | |
| import torch | |
| import torch.nn as nn | |
class TRMConfig(PretrainedConfig):
    """Configuration for ``TinyRecursiveModel``.

    Stores GPT-2 style hyperparameters plus two model-specific ones:
    ``n_physical_layers`` (how many distinct transformer blocks are created)
    and ``n_loops`` (how many block applications a forward pass performs).

    Args:
        vocab_size: Number of rows in the token embedding / LM head.
        n_positions: Number of rows in the learned position embedding.
        n_embd: Width of the hidden states.
        n_physical_layers: Number of distinct blocks instantiated.
        n_loops: Number of block applications per forward pass.
        n_head: Number of attention heads.
        activation_function: Activation name handed to the GPT-2 MLP.
        resid_pdrop: Dropout probability stored for residual branches.
        embd_pdrop: Dropout probability applied to the summed embeddings.
        attn_pdrop: Dropout probability stored for attention weights.
        layer_norm_epsilon: Epsilon used by every ``LayerNorm``.
        scale_attn_weights: Stored for the GPT-2 attention module.
        scale_attn_by_inverse_layer_idx: Stored for the GPT-2 attention
            module.
        reorder_and_upcast_attn: Stored for the GPT-2 attention module.
        **kwargs: Forwarded to ``PretrainedConfig.__init__``.
    """

    model_type = "recursive_gpt"

    def __init__(
        self,
        vocab_size=50257,
        n_positions=1024,
        n_embd=512,
        n_physical_layers=3,
        n_loops=8,
        n_head=8,
        activation_function="gelu_new",
        resid_pdrop=0.1,
        embd_pdrop=0.1,
        attn_pdrop=0.1,
        layer_norm_epsilon=1e-5,
        scale_attn_weights=True,
        scale_attn_by_inverse_layer_idx=False,
        reorder_and_upcast_attn=False,
        **kwargs,
    ):
        super().__init__(**kwargs)

        self.vocab_size = vocab_size
        self.n_positions = n_positions
        self.n_embd = n_embd
        self.n_physical_layers = n_physical_layers
        self.n_loops = n_loops
        self.n_head = n_head
        self.activation_function = activation_function
        self.resid_pdrop = resid_pdrop
        self.embd_pdrop = embd_pdrop
        self.attn_pdrop = attn_pdrop
        self.layer_norm_epsilon = layer_norm_epsilon
        self.scale_attn_weights = scale_attn_weights
        self.scale_attn_by_inverse_layer_idx = scale_attn_by_inverse_layer_idx
        self.reorder_and_upcast_attn = reorder_and_upcast_attn

        # Required for transformers compatibility: generic aliases that the
        # reused GPT-2 modules look up on the config object.
        self.hidden_size = n_embd
        self.num_attention_heads = n_head
        self.num_hidden_layers = n_physical_layers
        self.n_inner = None
        self.is_encoder_decoder = False

        # GPT2Attention sizes its causal-mask buffer from
        # `config.max_position_embeddings`. GPT2Config provides that name via
        # an attribute map; this config has none, so define it explicitly
        # unless the caller (or a loaded config.json) already supplied it
        # through **kwargs.
        if not hasattr(self, "max_position_embeddings"):
            self.max_position_embeddings = n_positions
class TinyRecursiveModel(PreTrainedModel, GenerationMixin):
    """Causal language model that cycles through a small stack of GPT-2 blocks.

    The model owns ``config.n_physical_layers`` blocks (LayerNorm -> GPT-2
    attention -> LayerNorm -> GPT-2 MLP, each with a residual connection) and
    applies ``config.n_loops`` block passes per forward call, picking block
    ``loop % n_physical_layers`` on each pass. Token and learned position
    embeddings feed the loop; a final LayerNorm and a bias-free linear head
    produce the logits.

    NOTE(review): ``_tied_weights_keys`` lists ``lm_head.weight`` but this
    class defines no input/output embedding accessors -- confirm whether the
    head is meant to be tied to ``wte`` before changing either.
    """

    config_class = TRMConfig
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config):
        # PreTrainedModel.__init__ already stores `config` on the instance,
        # so it is not re-assigned here.
        super().__init__(config)

        # 1. Embeddings
        self.wte = nn.Embedding(config.vocab_size, config.n_embd)
        self.wpe = nn.Embedding(config.n_positions, config.n_embd)
        self.drop = nn.Dropout(config.embd_pdrop)

        # 2. Physical blocks - module names match the saved model structure
        self.physical_blocks = nn.ModuleList(
            [
                nn.ModuleDict(
                    {
                        "ln_1": nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon),
                        "attn": GPT2Attention(config, layer_idx=i),
                        "ln_2": nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon),
                        "mlp": GPT2MLP(4 * config.n_embd, config),
                    }
                )
                for i in range(config.n_physical_layers)
            ]
        )

        # 3. Final layer norm
        self.ln_f = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)

        # 4. Language modeling head
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        # Initialize weights
        self.post_init()

    @staticmethod
    def _expand_attention_mask(attention_mask, seq_len, dtype, device):
        """Turn a ``(batch, seq)`` keep-mask into a 4-D additive attention mask.

        Args:
            attention_mask: ``None``, an already 4-D mask (returned unchanged),
                or a ``(batch, seq)`` tensor whose non-zero entries mark
                tokens that may be attended to.
            seq_len: Sequence length of the current input.
            dtype: Floating dtype of the attention scores the mask is added to.
            device: Device on which to build the mask.

        Returns:
            ``None`` if no mask was given, the 4-D input unchanged, or a
            ``(batch, 1, seq, seq)`` tensor that is ``0`` where attention is
            allowed and ``torch.finfo(dtype).min`` where it is not.
        """
        if attention_mask is None or attention_mask.dim() == 4:
            return attention_mask

        keep = attention_mask.to(device=device, dtype=torch.bool)[:, None, None, :]
        # Fold causality into the mask as well, so attention back-ends that
        # rely solely on the supplied mask still cannot look ahead.
        causal = torch.tril(
            torch.ones(seq_len, seq_len, dtype=torch.bool, device=device)
        )[None, None, :, :]
        allowed = keep & causal

        additive = torch.zeros(allowed.shape, dtype=dtype, device=device)
        # finfo.min rather than -inf keeps softmax finite on rows whose
        # permitted keys are all padding.
        return additive.masked_fill(~allowed, torch.finfo(dtype).min)

    def forward(self, input_ids=None, attention_mask=None, labels=None, **kwargs):
        """Run the looped blocks and optionally compute the next-token loss.

        Args:
            input_ids: ``(batch, seq)`` tensor of token ids. If ``None`` the
                method returns ``None`` without doing any work.
            attention_mask: Optional ``(batch, seq)`` keep-mask (non-zero =
                attend) or a ready-made 4-D additive mask.
            labels: Optional tensor of target ids, same shape as
                ``input_ids``. They are shifted by one position internally.
            **kwargs: Accepted for API compatibility and ignored.

        Returns:
            ``CausalLMOutputWithCrossAttentions`` with ``loss`` (``None``
            without labels), ``logits`` and, in the ``hidden_states`` field,
            the single post-``ln_f`` tensor (not a per-layer tuple).
        """
        if input_ids is None:
            return None

        _, seq_len = input_ids.shape
        device = input_ids.device

        # Get embeddings
        token_embeds = self.wte(input_ids)
        pos_ids = torch.arange(0, seq_len, dtype=torch.long, device=device)
        pos_embeds = self.wpe(pos_ids)
        hidden_states = self.drop(token_embeds + pos_embeds)

        # GPT2Attention adds the mask straight onto the attention scores, so a
        # raw 0/1 padding mask must be converted to additive 4-D form first.
        attn_mask = self._expand_attention_mask(
            attention_mask, seq_len, hidden_states.dtype, device
        )

        # Apply recursive loops through physical blocks
        n_blocks = self.config.n_physical_layers
        for loop in range(self.config.n_loops):
            block = self.physical_blocks[loop % n_blocks]

            # Attention (pre-norm, residual)
            ln_output = block["ln_1"](hidden_states)
            attn_output = block["attn"](ln_output, attention_mask=attn_mask)[0]
            hidden_states = hidden_states + attn_output

            # MLP (pre-norm, residual)
            ln_output = block["ln_2"](hidden_states)
            mlp_output = block["mlp"](ln_output)
            hidden_states = hidden_states + mlp_output

        # Final layer norm and projection
        hidden_states = self.ln_f(hidden_states)
        logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            # Position t predicts token t + 1.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
            )

        return CausalLMOutputWithCrossAttentions(
            loss=loss,
            logits=logits,
            hidden_states=hidden_states,
            attentions=None,
            cross_attentions=None,
        )

    def prepare_inputs_for_generation(self, input_ids, **kwargs):
        """Build the ``forward`` keyword arguments for one generation step.

        The full ``input_ids`` are passed every step (no key/value cache is
        used by ``forward``). The attention mask, when present in ``kwargs``,
        is forwarded so padded positions stay masked during generation.
        """
        return {
            "input_ids": input_ids,
            "attention_mask": kwargs.get("attention_mask"),
        }

    def _reorder_cache(self, past, beam_idx):
        """Return ``past`` unchanged; ``beam_idx`` is not used."""
        return past