Baladithya Balamurugan
Wave 3: close the HIGH review findings (kill-switch wiring, HeldoutSplit, EKS entrypoint bug)
bd0c358
# API Reference — composer-replication-framework
Complete reference for every public symbol in `composer_replication`. The source of truth is the `.py` files in `composer_replication/`; docstrings are quoted verbatim where they exist and supplemented where missing.
**Legend**
- ⚠️ **UNTESTED-CONTRACT** — symbol exists and is callable, but its behaviour is not pinned by an automated test in `composer_replication/**/tests/` or `spikes/**/tests/`.
- 🟡 **SKELETON** — class/method body raises `NotImplementedError`; ships as design-of-record per ADR-005 / ADR-006.
**Module groups (in this document)**
1. `composer_replication` (top-level re-exports)
2. `composer_replication.loss`
3. `composer_replication.batch`
4. `composer_replication.opsd`
5. `composer_replication.distillation`
6. `composer_replication.teacher_replay`
7. `composer_replication.replaysim`
8. `composer_replication.ingestion` (+ `.claude_code`)
9. `composer_replication.hint_generator`
10. `composer_replication.trainer` (+ `.composer_trainer`, `.data_collator`)
11. `composer_replication.diloco`
12. `composer_replication.diloco.serverless` (+ `.executor`, `.allreduce`, `.modal`, `.hf_jobs`, `.replica_entrypoint`)
13. `composer_replication.recipes.prime_rl.composer_loss`
14. `composer_replication.recipes.monarch.actors`
15. `composer_replication.diloco.serverless` — cloud executors (`.eks`, `.sagemaker`)
16. `composer_replication.datagen.docker_sandbox`
17. `composer_replication.safety` (+ `.kill_switch`)
---
## 1. `composer_replication` — top-level package
The package re-exports the most common entry points from sub-modules. `__all__` is the canonical list of public top-level names.
### `composer_replication.__version__: str`
Package version string. Currently `"0.1.0"`.
```python
import composer_replication
print(composer_replication.__version__) # "0.1.0"
```
### `composer_replication._DILOCO_AVAILABLE: bool`
`True` iff `torchft` is importable in the running Python environment; gates `make_diloco_outer_loop`. When `torchft` is missing, the flag is `False` and `make_diloco_outer_loop` is bound to `None`.
```python
from composer_replication import _DILOCO_AVAILABLE
if _DILOCO_AVAILABLE:
    from composer_replication import make_diloco_outer_loop
```
### Re-exports
| Name | Source module |
|---|---|
| `compose_loss` | `composer_replication.loss` |
| `LossComponents` | `composer_replication.loss` |
| `build_batch` | `composer_replication.batch` |
| `generalized_jsd_loss` | `composer_replication.opsd` |
| `ClaudeCodeIngester` | `composer_replication.ingestion.claude_code` |
| `IngestionStats` | `composer_replication.ingestion.claude_code` |
| `SYSTEM_PROMPT` | `composer_replication.ingestion.claude_code` |
| `DEFAULT_TEACHERS` | `composer_replication.teacher_replay` |
| `DPOPair` | `composer_replication.teacher_replay` |
| `TeacherCallResult` | `composer_replication.teacher_replay` |
| `TeacherSpec` | `composer_replication.teacher_replay` |
| `TraceState` | `composer_replication.teacher_replay` |
| `extract_dpo_pairs` | `composer_replication.teacher_replay` |
| `replay_trace` | `composer_replication.teacher_replay` |
| `ComposerReplicationTrainer` | `composer_replication.trainer` |
| `make_diloco_outer_loop` | `composer_replication.diloco` (or `None` if `torchft` missing) |
See each source module below for full signatures.
---
## 2. `composer_replication.loss`
Verification-harness 3-channel loss. Free function, does not depend on `trl`.
### `class LossComponents`
```python
@dataclass
class LossComponents:
    lm_ce: torch.Tensor
    sdpo_jsd: torch.Tensor
    trace_replay_dpo: torch.Tensor
    total: torch.Tensor

    def detached(self) -> dict[str, float]: ...
```
Per-channel breakdown of the total loss for logging and ablation. All four fields are scalar `torch.Tensor`s (`shape=()`); `total = lm_ce + alpha_sdpo * sdpo_jsd + beta_replay * trace_replay_dpo`.
**`detached() -> dict[str, float]`** — returns Python-float copies of all four fields with no grad. Useful for W&B logging.
```python
from composer_replication import compose_loss, build_batch
components = compose_loss(model, build_batch(tokenizer))
print(components.detached()) # {'lm_ce': 2.34, 'sdpo_jsd': 0.12, ...}
components.total.backward()
```
### `compose_loss(model, inputs, *, ...) -> LossComponents`
<a id="compose_loss"></a>
```python
def compose_loss(
    model: torch.nn.Module,
    inputs: dict[str, torch.Tensor],
    *,
    alpha_sdpo: float = 0.1,
    beta_replay: float = 0.05,
    sdpo_jsd_beta: float = 0.5,
    sdpo_temperature: float = 1.0,
    sdpo_token_clip: float | None = None,
    replay_dpo_beta: float = 0.1,
    lm_ce_label_smoothing: float = 0.0,
    dpo_variant: Literal["dpo", "simpo"] = "dpo",
    sdpo_wrapper: Literal["none", "taid", "entropy_opd"] = "none",
    taid_t: float | None = None,
    simpo_beta: float = 2.0,
    simpo_gamma: float = 1.0,
    entropy_opd_h_max: float | None = None,
) -> LossComponents
```
Compute `total = lm_ce + alpha_sdpo * sdpo_jsd + beta_replay * trace_replay_dpo`.
**Required keys in `inputs`**
- `input_ids`: `(B, T_s)` student rollout token ids.
- `response_mask`: `(B, T_s)` 1 on assistant-response tokens, 0 elsewhere.
**Optional keys** (channel auto-disables if missing OR if its weight = 0):
- SDPO: `ctx_teacher_input_ids` `(B, T_t)`, `sdpo_loss_mask` `(B, T_t)`.
- DPO (`dpo_variant="dpo"`): `dpo_chosen_input_ids`, `dpo_chosen_response_mask`, `dpo_rejected_input_ids`, `dpo_rejected_response_mask`, `dpo_chosen_ref_logprobs`, `dpo_rejected_ref_logprobs` (precomputed).
- SimPO (`dpo_variant="simpo"`): same DPO ids/masks; reference logprobs are silently ignored.
- TAID (`sdpo_wrapper="taid"`): no extra `inputs` keys needed; the optional `sdpo_loss_mask` is reused as the per-token TAID mask. Pass `taid_t` directly (or drive it from `TAIDScheduler`).
**Parameters**
| Name | Type | Default | Meaning |
|---|---|---|---|
| `model` | `torch.nn.Module` | — | HF causal-LM. Must accept `input_ids=` and return an object with `.logits`. |
| `inputs` | `dict[str, torch.Tensor]` | — | Batch dict (see required/optional keys above). |
| `alpha_sdpo` | `float` | `0.1` | Weight on SDPO/JSD channel. `0.0` disables. |
| `beta_replay` | `float` | `0.05` | Weight on trace-replay DPO channel. `0.0` disables. |
| `sdpo_jsd_beta` | `float` | `0.5` | β param for `generalized_jsd_loss` (0=fwd KL, 0.5=JSD, 1=rev KL). Unused when `sdpo_wrapper="taid"`. |
| `sdpo_temperature` | `float` | `1.0` | Softmax temperature in SDPO. Unused when `sdpo_wrapper="taid"`. |
| `sdpo_token_clip` | `float \| None` | `None` | Per-token JSD clamp. |
| `replay_dpo_beta` | `float` | `0.1` | β in standard DPO logit. |
| `lm_ce_label_smoothing` | `float` | `0.0` | `F.cross_entropy(label_smoothing=)`. |
| `dpo_variant` | `Literal["dpo","simpo"]` | `"dpo"` | Channel-3 algorithm. |
| `sdpo_wrapper` | `Literal["none","taid","entropy_opd"]` | `"none"` | Channel-2 wrapper. |
| `taid_t` | `float \| None` | `None` | Current TAID interpolation coefficient in `[0, 1]`. Required when `sdpo_wrapper="taid"`. Drive from `TAIDScheduler` or pass a fixed value. |
| `simpo_beta` | `float` | `2.0` | SimPO β (paper default). |
| `simpo_gamma` | `float` | `1.0` | SimPO target margin γ (paper default). |
| `entropy_opd_h_max` | `float \| None` | `None` | Max-entropy normalizer; `None` ⇒ `log(V)`. |
**Returns** `LossComponents` (see above).
**Raises** `ValueError` if `dpo_variant` or `sdpo_wrapper` is unknown, if `sdpo_wrapper="taid"` is requested without `taid_t`, or if `taid_t` is outside `[0, 1]`.
```python
from composer_replication import compose_loss, build_batch
batch = build_batch(tokenizer)
out = compose_loss(model, batch, alpha_sdpo=0.1, beta_replay=0.05)
out.total.backward()
print(out.detached())
```
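Ignoring tensors, the weighting and auto-disable rule above reduces to plain arithmetic. A dependency-free sketch in which floats stand in for scalar tensors; `channel_enabled` and `compose_total` are hypothetical names, and the key lists are copied from the optional-key list above:

```python
from typing import Mapping

# Keys each optional channel needs (from the "Optional keys" list).
SDPO_KEYS = ("ctx_teacher_input_ids",)
DPO_KEYS = (
    "dpo_chosen_input_ids", "dpo_chosen_response_mask",
    "dpo_rejected_input_ids", "dpo_rejected_response_mask",
)  # dpo_variant="dpo" additionally needs the two *_ref_logprobs keys


def channel_enabled(weight: float, inputs: Mapping[str, object], keys: tuple[str, ...]) -> bool:
    """A channel fires only if its weight is non-zero AND all of its keys are present."""
    return weight != 0.0 and all(k in inputs for k in keys)


def compose_total(
    lm_ce: float,
    sdpo_jsd: float,
    trace_replay_dpo: float,
    inputs: Mapping[str, object],
    alpha_sdpo: float = 0.1,
    beta_replay: float = 0.05,
) -> float:
    """Scalar version of total = lm_ce + alpha_sdpo*sdpo_jsd + beta_replay*trace_replay_dpo."""
    total = lm_ce
    if channel_enabled(alpha_sdpo, inputs, SDPO_KEYS):
        total += alpha_sdpo * sdpo_jsd
    if channel_enabled(beta_replay, inputs, DPO_KEYS):
        total += beta_replay * trace_replay_dpo
    return total
```

Dropping `ctx_teacher_input_ids` from the batch has the same effect as `alpha_sdpo=0.0`: channel 2 contributes nothing to `total`.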
---
## 3. `composer_replication.batch`
Verification-harness batch builder.
### `build_batch(tokenizer, *, ...) -> dict[str, torch.Tensor]`
```python
def build_batch(
    tokenizer: Any,
    *,
    device: torch.device | str = "cpu",
    seed: int = 42,
    variant: str = "factorial",
    align_sdpo_shapes: bool = False,
) -> dict[str, torch.Tensor]
```
Construct a full 3-channel batch from a real HF tokenizer. The DPO ref-logprobs are dummy tensors: the smoke test verifies that the loss composition wires together, not the reference-policy precompute.
**Returned keys**: `input_ids`, `response_mask`, `ctx_teacher_input_ids`, `sdpo_loss_mask`, `dpo_chosen_input_ids`, `dpo_chosen_response_mask`, `dpo_rejected_input_ids`, `dpo_rejected_response_mask`, `dpo_chosen_ref_logprobs`, `dpo_rejected_ref_logprobs`.
**Parameters**
| Name | Type | Default | Meaning |
|---|---|---|---|
| `tokenizer` | HF `AutoTokenizer` (duck-typed) | — | Must support `apply_chat_template` and `__call__`. |
| `device` | `torch.device \| str` | `"cpu"` | Target device for all returned tensors. |
| `seed` | `int` | `42` | Fixes `torch.manual_seed`. |
| `variant` | `str` | `"factorial"` | One of `"factorial"`, `"binary_search"`. |
| `align_sdpo_shapes` | `bool` | `False` | If True, truncate/pad `ctx_teacher_input_ids` to `input_ids` length so the SDPO channel actually fires. |
**Raises** `ValueError` if `variant` is unknown.
```python
from transformers import AutoTokenizer
from composer_replication import build_batch
tok = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
batch = build_batch(tok, variant="factorial", align_sdpo_shapes=True)
print({k: v.shape for k, v in batch.items()})
```
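The point of `align_sdpo_shapes=True` is that each teacher-context row ends up the same length as its `input_ids` row. A sketch on plain Python lists; `align_row` is hypothetical, and both the side of truncation/padding (right here) and the pad id are assumptions, so check `batch.py` before relying on either:

```python
def align_row(
    ids: list[int], mask: list[int], target_len: int, pad_id: int
) -> tuple[list[int], list[int]]:
    """Right-truncate or right-pad one teacher row (and its mask) to target_len."""
    if len(ids) != len(mask):
        raise ValueError("ids and mask must have the same length")
    if len(ids) >= target_len:
        return ids[:target_len], mask[:target_len]
    pad = target_len - len(ids)
    # Padded positions get mask 0 so they never contribute to the SDPO loss.
    return ids + [pad_id] * pad, mask + [0] * pad
```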
---
## 4. `composer_replication.opsd`
Self-distillation generalized-JSD loss, lifted verbatim from `siyan-zhao/OPSD` (MIT) per ADR-006.
### `generalized_jsd_loss(student_logits, teacher_logits, labels=None, beta=0.5, ...) -> torch.Tensor`
```python
def generalized_jsd_loss(
    student_logits: torch.Tensor,
    teacher_logits: torch.Tensor,
    labels: torch.Tensor | None = None,
    beta: float = 0.5,
    temperature: float = 1.0,
    reduction: str = "batchmean",
    logits_are_probs: bool = False,
    top_k: int | None = None,
    token_clip: float | None = None,
) -> torch.Tensor
```
Generalized JSD between student and teacher distributions. In the SDPO recipe, student and teacher are the SAME model (same parameters) run on different contexts.
**Parameters**
| Name | Type | Default | Meaning |
|---|---|---|---|
| `student_logits` | `Tensor (B, T, V)` | — | Student logits with grad. |
| `teacher_logits` | `Tensor (B, T, V)` | — | Teacher logits (no grad in SDPO). |
| `labels` | `Tensor (B, T) \| None` | `None` | Per-token mask. `-100` positions are ignored (HF convention). |
| `beta` | `float` in [0, 1] | `0.5` | 0=fwd KL, 1=rev KL, 0.5=symmetric JSD. |
| `temperature` | `float` | `1.0` | Softmax temperature. |
| `reduction` | `str` | `"batchmean"` | `"batchmean"`, `"sum"`, `"mean"`, `"none"`. |
| `logits_are_probs` | `bool` | `False` | Skip softmax if inputs are already probabilities. |
| `top_k` | `int \| None` | `None` | Restrict KL to teacher's top-k tokens. |
| `token_clip` | `float \| None` | `None` | Clip per-token JSD for stability. |
**Returns** scalar tensor (or `(B, T)` if `reduction="none"`).
**Raises** `ValueError` for unknown `reduction`.
```python
import torch
from composer_replication.opsd import generalized_jsd_loss
s = torch.randn(2, 8, 32, requires_grad=True)
t = torch.randn(2, 8, 32)
loss = generalized_jsd_loss(s, t, beta=0.5, reduction="batchmean")
loss.backward()
```
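The `beta` endpoints and midpoint in the table can be checked on a single token with a dependency-free sketch. It assumes OPSD keeps the formulation used by TRL's GKD trainer: mixture `m = (1 - β)·p_student + β·p_teacher`, loss `β·KL(p_teacher‖m) + (1 - β)·KL(p_student‖m)`, with `β = 0` and `β = 1` special-cased to plain KL. Verify against `opsd.py` before relying on it; `generalized_jsd_token`, `log_softmax`, and `kl` are hypothetical helpers:

```python
import math


def log_softmax(logits: list[float], temperature: float = 1.0) -> list[float]:
    scaled = [x / temperature for x in logits]
    m = max(scaled)  # subtract the max for numerical stability
    lse = m + math.log(sum(math.exp(x - m) for x in scaled))
    return [x - lse for x in scaled]


def kl(logp: list[float], logq: list[float]) -> float:
    """KL(p || q) computed from log-probabilities."""
    return sum(math.exp(a) * (a - b) for a, b in zip(logp, logq))


def generalized_jsd_token(
    student_logits: list[float], teacher_logits: list[float],
    beta: float = 0.5, temperature: float = 1.0,
) -> float:
    ls = log_softmax(student_logits, temperature)
    lt = log_softmax(teacher_logits, temperature)
    if beta == 0.0:
        return kl(lt, ls)  # forward KL(teacher || student)
    if beta == 1.0:
        return kl(ls, lt)  # reverse KL(student || teacher)
    # Mixture m = (1 - beta) * p_student + beta * p_teacher
    lm = [math.log((1 - beta) * math.exp(a) + beta * math.exp(b)) for a, b in zip(ls, lt)]
    return beta * kl(lt, lm) + (1 - beta) * kl(ls, lm)
```

The endpoints need special cases because the mixture collapses onto one side: at `β = 0` the mixture is the student itself, so the mixture formula would return 0 instead of forward KL.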
---
## 5. `composer_replication.distillation`
Pluggable self-distillation losses (ADR-007). All pure PyTorch.
### `simpo_loss(chosen_avg_logprobs, rejected_avg_logprobs, *, beta=2.0, gamma=1.0) -> torch.Tensor`
```python
def simpo_loss(
    chosen_avg_logprobs: torch.Tensor,
    rejected_avg_logprobs: torch.Tensor,
    *,
    beta: float = 2.0,
    gamma: float = 1.0,
) -> torch.Tensor
```
Reference-free DPO with target margin γ (Meng et al., NeurIPS 2024). `L = -log σ(β · (avg_logπ(c) − avg_logπ(r)) − γ)`.
**Parameters**
| Name | Type | Default | Meaning |
|---|---|---|---|
| `chosen_avg_logprobs` | `Tensor (B,)` | — | Per-sequence avg logprob over chosen response tokens. |
| `rejected_avg_logprobs` | `Tensor (B,)` | — | Same for rejected. |
| `beta` | `float` | `2.0` | Scaling factor (paper default). |
| `gamma` | `float` | `1.0` | Target margin (paper default). |
**Returns** scalar; **Raises** `ValueError` if shapes mismatch.
```python
import torch
from composer_replication.distillation import simpo_loss
loss = simpo_loss(torch.tensor([-2.1, -1.8]), torch.tensor([-3.0, -2.5]),
beta=2.0, gamma=1.0)
```
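The SimPO formula above can be reproduced on plain floats, which makes the role of `γ` easy to see. A sketch assuming a batch-mean reduction (the real reduction is not stated here); `simpo_loss_ref` and `softplus` are hypothetical helpers:

```python
import math


def softplus(z: float) -> float:
    """Numerically stable log(1 + exp(z))."""
    return max(z, 0.0) + math.log1p(math.exp(-abs(z)))


def simpo_loss_ref(
    chosen_avg: list[float], rejected_avg: list[float],
    beta: float = 2.0, gamma: float = 1.0,
) -> float:
    if len(chosen_avg) != len(rejected_avg):
        raise ValueError("chosen/rejected shapes differ")
    # -log sigmoid(x) == softplus(-x), with x = beta * (c - r) - gamma
    losses = [softplus(-(beta * (c - r) - gamma)) for c, r in zip(chosen_avg, rejected_avg)]
    return sum(losses) / len(losses)  # batch mean (assumed reduction)
```

A pair scores exactly `log 2` when the chosen response beats the rejected one by exactly `γ/β` in average logprob; larger gaps push the loss toward zero.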
### `avg_sequence_logprob(model_logprobs, response_mask) -> torch.Tensor`
⚠️ UNTESTED-CONTRACT (helper exported from `simpo.py` but not asserted by a test).
```python
def avg_sequence_logprob(
    model_logprobs: torch.Tensor,
    response_mask: torch.Tensor,
) -> torch.Tensor
```
Convert `(B, T)` per-token logprobs + `(B, T)` response mask into `(B,)` per-sequence average over response tokens.
```python
from composer_replication.distillation.simpo import avg_sequence_logprob
import torch
lp = torch.randn(2, 8); m = torch.tensor([[0,0,1,1,1,0,0,0],[0,1,1,1,1,1,0,0]])
out = avg_sequence_logprob(lp, m) # shape (2,)
```
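The reduction is a masked mean per row. A list-based sketch of the same arithmetic; `avg_sequence_logprob_ref` is hypothetical, and its guard for all-zero masks is an assumption (the real helper's behaviour on empty responses is not documented here):

```python
def avg_sequence_logprob_ref(
    logprobs: list[list[float]], mask: list[list[int]]
) -> list[float]:
    """(B, T) token logprobs + (B, T) 0/1 mask -> (B,) mean over masked tokens."""
    out = []
    for lp_row, m_row in zip(logprobs, mask):
        total = sum(lp * m for lp, m in zip(lp_row, m_row))
        count = sum(m_row)
        out.append(total / max(count, 1))  # guard for all-zero rows (assumed behaviour)
    return out
```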
### `taid_loss(student_logits, teacher_logits, mask=None, *, t) -> torch.Tensor`
```python
def taid_loss(
    student_logits: torch.Tensor,
    teacher_logits: torch.Tensor,
    mask: torch.Tensor | None = None,
    *,
    t: float | torch.Tensor,
) -> torch.Tensor
```
Faithful port of `SakanaAI/TAID` (arXiv:2501.16937). Forward-KL distillation against a logit-space-interpolated target whose anchor is the **current student detached**:
```
p_t = softmax( (1 - t) · stop_grad(student_logits) + t · teacher_logits )
L = - mean_token Σ_v p_t(v) · log_softmax(student_logits)(v)
```
At `t=0` the target collapses to the detached student (no teacher signal in the gradient). At `t=1` it reduces to standard forward-KL distillation against the teacher.
**Wave 15 breaking change.** The previous signature `taid_loss(student, teacher, student_init, *, schedule_step, total_steps, schedule, alpha_min, alpha_max, jsd_beta, temperature, reduction)` was algorithmically wrong (probability-space mix, frozen step-0 anchor, JSD criterion). All those kwargs are removed; the schedule is now the caller's responsibility (see `TAIDScheduler` below for the upstream adaptive scheme).
**Parameters**
| Name | Type | Default | Meaning |
|---|---|---|---|
| `student_logits` | `Tensor (B, T, V)` | — | Current student (with grad). |
| `teacher_logits` | `Tensor (B, T, V)` | — | Teacher logits. |
| `mask` | `Tensor (B, T) \| None` | `None` | Token mask. `None` ⇒ all-ones. |
| `t` | `float \| Tensor` | — | Interpolation coefficient in `[0, 1]`. |
**Raises** `ValueError` for shape mismatch.
```python
from composer_replication.distillation import taid_loss
loss = taid_loss(s_logits, t_logits, mask, t=0.4)
```
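The TAID target and loss above can be evaluated on plain floats to check both endpoints. A sketch over lists of per-token logit vectors; `taid_token_loss` and `taid_loss_ref` are hypothetical, and with no autograd involved the `stop_grad` on the student term is only noted in a comment:

```python
import math


def softmax(xs: list[float]) -> list[float]:
    m = max(xs)
    es = [math.exp(x - m) for x in xs]
    z = sum(es)
    return [e / z for e in es]


def log_softmax(xs: list[float]) -> list[float]:
    m = max(xs)
    lse = m + math.log(sum(math.exp(x - m) for x in xs))
    return [x - lse for x in xs]


def taid_token_loss(student: list[float], teacher: list[float], t: float) -> float:
    if len(student) != len(teacher):
        raise ValueError("vocab size mismatch")
    # Target: softmax of the logit-space mix. With autograd, the student term here
    # would be detached (stop_grad); plain floats carry no gradient anyway.
    target = softmax([(1 - t) * s + t * q for s, q in zip(student, teacher)])
    log_q = log_softmax(student)
    return -sum(p * lq for p, lq in zip(target, log_q))


def taid_loss_ref(student, teacher, t, mask=None) -> float:
    """Mean of per-token losses over masked positions; mask=None means all ones."""
    mask = [1] * len(student) if mask is None else mask
    num = sum(m * taid_token_loss(s, q, t) for s, q, m in zip(student, teacher, mask))
    return num / max(sum(mask), 1)
```

At `t=0` the loss equals the student's own entropy, which contributes zero gradient because the detached target matches the student distribution; at `t=1` it is ordinary cross-entropy against the teacher softmax.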
### `TAIDScheduler(num_train_steps, *, t_start=0.4, t_end=1.0, alpha=5e-4, beta=0.99, disable_adaptive=False, device="cpu")`
Stateful schedule that mirrors upstream `TAID.update_t`. `t` is monotone non-decreasing: it follows a linear floor from `t_start` to `t_end` and is bumped above that floor by an amount driven by an EMA of the relative loss change. Typical usage:
```python
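from composer_replication.distillation import TAIDScheduler, taid_loss

num_train_steps = 10_000
sched = TAIDScheduler(num_train_steps=num_train_steps)  # paper defaults
for step in range(num_train_steps):
    # s_logits, t_logits, mask: this step's student/teacher forward outputs
    loss = taid_loss(s_logits, t_logits, mask, t=sched.t)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    sched.update_t(loss.detach(), global_step=step)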
```
**Parameters**
| Name | Type | Default | Meaning |
|---|---|---|---|
| `num_train_steps` | `int` | — | Total planned training steps; sets the linear floor. |
| `t_start` | `float` | `0.4` | Initial `t` (paper default). |
| `t_end` | `float` | `1.0` | Terminal `t`; hard ceiling at every step. |
| `alpha` | `float` | `5e-4` | Adaptive bump magnitude. |
| `beta` | `float` | `0.99` | EMA decay on relative-loss-change momentum. |
| `disable_adaptive` | `bool` | `False` | If True, fall back to deterministic linear schedule. |
| `device` | `torch.device \| str` | `"cpu"` | Where to allocate state buffers. |
**Properties / methods**
- `sched.t -> float` — current `t` as a Python float (zero-arg property).
- `sched.update_t(loss, global_step) -> Tensor | None` — update internal state. First finite-loss call only seeds `prev_loss` and returns `None`; subsequent calls return the (positive) `delta_t` added on top of the linear floor.
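The contract above (seed on the first finite loss, positive bump while `t < 1`, never below the linear floor, capped at `t_end`) can be illustrated with a dependency-free sketch. `TAIDScheduleSketch` is hypothetical and its bump formula `alpha · sigmoid(momentum) · (1 - t)` is an assumption, not a transcription of upstream `TAID.update_t`; it also omits `disable_adaptive` and `device`:

```python
import math
from typing import Optional


class TAIDScheduleSketch:
    """Hypothetical simplified stand-in for TAIDScheduler; NOT upstream's exact update rule."""

    def __init__(self, num_train_steps: int, *, t_start: float = 0.4, t_end: float = 1.0,
                 alpha: float = 5e-4, beta: float = 0.99) -> None:
        self.num_train_steps = num_train_steps
        self.t_start, self.t_end = t_start, t_end
        self.alpha, self.beta = alpha, beta
        self.t = t_start
        self.prev_loss: Optional[float] = None
        self.momentum = 0.0

    def linear_floor(self, step: int) -> float:
        progress = min(step / self.num_train_steps, 1.0)
        return self.t_start + (self.t_end - self.t_start) * progress

    def update_t(self, loss: float, global_step: int) -> Optional[float]:
        if not math.isfinite(loss):
            return None  # assumed: non-finite losses are skipped
        if self.prev_loss is None:
            self.prev_loss = loss  # first finite call only seeds the baseline
            return None
        rel_change = (self.prev_loss - loss) / (abs(self.prev_loss) + 1e-12)
        self.momentum = self.beta * self.momentum + (1 - self.beta) * rel_change
        # sigmoid(momentum) > 0, so the bump is positive while t < 1
        delta_t = self.alpha / (1.0 + math.exp(-self.momentum)) * (1.0 - self.t)
        # Monotone: never below the linear floor or the previous t; capped at t_end.
        self.t = min(self.t_end, max(self.linear_floor(global_step), self.t + delta_t))
        self.prev_loss = loss
        return delta_t
```

With the small default `alpha`, the linear floor dominates early in training and the adaptive bump matters mostly when loss keeps falling quickly.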
### `entropy_aware_opd_loss(student_logits, teacher_logits, *, labels=None, h_max=None, temperature=1.0, reduction="batchmean") -> torch.Tensor`
```python
def entropy_aware_opd_loss(
    student_logits: torch.Tensor,
    teacher_logits: torch.Tensor,
    *,
    labels: torch.Tensor | None = None,
    h_max: float | None = None,
    temperature: float = 1.0,
    reduction: str = "batchmean",
) -> torch.Tensor
```
Per-token mixture of forward and reverse KL gated by teacher entropy: `w(t) = clamp(H_teacher(t)/h_max, 0, 1)`. High-entropy tokens use forward KL (mode-covering), low-entropy tokens use reverse KL (mode-seeking).
**Parameters**
| Name | Type | Default | Meaning |
|---|---|---|---|
| `student_logits` | `Tensor (B,T,V)` | — | Student logits (grad). |
| `teacher_logits` | `Tensor (B,T,V)` | — | Teacher logits (no grad). |
| `labels` | `Tensor (B,T) \| None` | `None` | 0/1 mask, applied multiplicatively after the per-token mix. |
| `h_max` | `float \| None` | `None` ⇒ `log(V)` | Max-entropy normalizer. |
| `temperature` | `float` | `1.0` | Softmax temperature on both. |
| `reduction` | `str` | `"batchmean"` | `"batchmean"`, `"sum"`, `"mean"`, `"none"`. |
**Raises** `ValueError` on shape mismatch (student vs teacher; labels vs per-token loss) or unknown `reduction`.
```python
from composer_replication.distillation import entropy_aware_opd_loss
loss = entropy_aware_opd_loss(s_logits, t_logits, temperature=1.0)
loss.backward()
```
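The entropy gate can be worked through for one token. A sketch assuming the mix is `w·KL(teacher‖student) + (1 - w)·KL(student‖teacher)`, which matches the description that high-entropy tokens lean on forward KL; `teacher_entropy_token` and `entropy_aware_opd_token` are hypothetical single-token versions of `teacher_entropy` and `entropy_aware_opd_loss`:

```python
import math
from typing import Optional


def log_softmax(xs: list[float], temperature: float = 1.0) -> list[float]:
    scaled = [x / temperature for x in xs]
    m = max(scaled)
    lse = m + math.log(sum(math.exp(x - m) for x in scaled))
    return [x - lse for x in scaled]


def teacher_entropy_token(teacher_logits: list[float], temperature: float = 1.0) -> float:
    """Entropy in nats of one token's teacher distribution."""
    lt = log_softmax(teacher_logits, temperature)
    return -sum(math.exp(a) * a for a in lt)


def entropy_aware_opd_token(student_logits: list[float], teacher_logits: list[float],
                            h_max: Optional[float] = None, temperature: float = 1.0) -> float:
    ls = log_softmax(student_logits, temperature)
    lt = log_softmax(teacher_logits, temperature)
    h_max = math.log(len(teacher_logits)) if h_max is None else h_max  # default log(V)
    w = min(max(teacher_entropy_token(teacher_logits, temperature) / h_max, 0.0), 1.0)
    fwd = sum(math.exp(a) * (a - b) for a, b in zip(lt, ls))  # KL(teacher || student)
    rev = sum(math.exp(b) * (b - a) for a, b in zip(lt, ls))  # KL(student || teacher)
    return w * fwd + (1.0 - w) * rev
```

A uniform teacher has entropy `log(V)`, so `w = 1` and the token uses pure forward KL; a sharply peaked teacher drives `w` toward 0 and the token uses reverse KL.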
### `teacher_entropy(teacher_logits) -> torch.Tensor`
⚠️ UNTESTED-CONTRACT (helper exposed from `entropy_aware_opd.py`'s `__all__` but not directly asserted).
Per-token entropy in nats. Input `(B,T,V)`, output `(B,T)`.
```python
from composer_replication.distillation.entropy_aware_opd import teacher_entropy
H = teacher_entropy(teacher_logits) # (B, T)
```
---
## 6. `composer_replication.teacher_replay`
N-teacher OpenRouter parallel client + DPO-pair extractor. `httpx` is lazy-imported inside `replay_trace`; the deterministic local logic is testable without it.
### `DEFAULT_TEACHERS: list[TeacherSpec]`
Three-teacher default set: `anthropic/claude-opus-4.7`, `openai/gpt-5`, `deepseek/deepseek-v4-pro` with paper-baseline OpenRouter pricing.
```python
from composer_replication.teacher_replay import DEFAULT_TEACHERS
print([t["slug"] for t in DEFAULT_TEACHERS])
```
### `class TeacherSpec(TypedDict)`
```python
class TeacherSpec(TypedDict):
    slug: str
    input_per_mtok: float
    output_per_mtok: float
```
OpenRouter model slug + per-million-token pricing.
```python
from composer_replication.teacher_replay import TeacherSpec
spec: TeacherSpec = {"slug": "openai/gpt-5",
                     "input_per_mtok": 1.25, "output_per_mtok": 10.0}
```
### `class TraceState(TypedDict)`
```python
class TraceState(TypedDict):
    state_id: str          # unique within the trace
    messages: list[dict]   # OpenAI-style chat history up to (and incl.) this user prompt
    student_action: str    # what the student actually did at this step
```
One step of a frozen agentic trace. `student_action` is the raw text emitted by the student; teachers are queried with `messages` and asked to predict the assistant's next action.
```python
from composer_replication.teacher_replay import TraceState
state: TraceState = {"state_id": "ex001::0042",
                     "messages": [{"role": "user", "content": "..."}],
                     "student_action": "[TOOL_USE] name=Read input={...}"}
```
### `class TeacherCallResult(TypedDict)`
```python
class TeacherCallResult(TypedDict):
    state_id: str
    teacher_slug: str
    response_text: str | None  # None on error
    latency_s: float
    prompt_tokens: int
    completion_tokens: int
    cost_usd: float
    error: str | None          # None on success
```
One row of the N×T result grid (N teachers × T states) returned by `replay_trace`.
```python
from composer_replication.teacher_replay import TeacherCallResult
r: TeacherCallResult = {"state_id": "x", "teacher_slug": "openai/gpt-5",
                        "response_text": "ok", "latency_s": 1.2, "prompt_tokens": 100,
                        "completion_tokens": 5, "cost_usd": 0.001, "error": None}
```
### `class DPOPair(TypedDict)`
```python
class DPOPair(TypedDict):
    state_id: str
    state_messages: list[dict]
    chosen: str    # teacher-consensus action
    rejected: str  # student action
    n_teachers_agreeing: int
```
One preference pair extracted from teacher-vs-student disagreement.
```python
from composer_replication.teacher_replay import DPOPair
p: DPOPair = {"state_id": "x", "state_messages": [...], "chosen": "...",
              "rejected": "...", "n_teachers_agreeing": 2}
```
### `async replay_trace(states, teachers=DEFAULT_TEACHERS, max_total_usd=5.0, api_key=None) -> list[TeacherCallResult]`
```python
async def replay_trace(
    states: Sequence[TraceState],
    teachers: Sequence[TeacherSpec] = tuple(DEFAULT_TEACHERS),
    max_total_usd: float = 5.0,
    api_key: str | None = None,
) -> list[TeacherCallResult]
```
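For each state, fans out one parallel call per teacher via OpenRouter. Cumulative spend is capped at `max_total_usd` and checked after each state: the state whose calls cross the cap still completes, so total spend can exceed the cap by up to one state's fan-out.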
**Parameters**
| Name | Type | Default | Meaning |
|---|---|---|---|
| `states` | `Sequence[TraceState]` | — | Frozen trace, one entry per assistant turn. |
| `teachers` | `Sequence[TeacherSpec]` | `DEFAULT_TEACHERS` | Models to query in parallel. |
| `max_total_usd` | `float` | `5.0` | Cumulative spend cap. |
| `api_key` | `str \| None` | `None` | OpenRouter key; defaults to `OPENROUTER_API_KEY` env or `~/.hermes/.env`. |
**Returns** a flat list of `TeacherCallResult`s of length `len(states) * len(teachers)`, or shorter if the spend cap stops the run early.
**Raises** `RuntimeError` if `OPENROUTER_API_KEY` is not findable; `ImportError` if `httpx` is missing at call time.
```python
import asyncio
from composer_replication import replay_trace
results = asyncio.run(replay_trace(states=my_trace, max_total_usd=1.0))
```
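The budget behaviour described above (fan out per state, check spend only once that state finishes) can be sketched with `asyncio.gather`. `replay_with_cap` and `fake_call` are hypothetical, no OpenRouter request is made, and the `>=` comparison is an assumption:

```python
import asyncio
from typing import Any, Awaitable, Callable, Sequence

CallFn = Callable[[Any, Any], Awaitable[dict]]


async def replay_with_cap(states: Sequence[Any], teachers: Sequence[Any],
                          call: CallFn, max_total_usd: float) -> list[dict]:
    """Hypothetical sketch: per-state fan-out with a cumulative spend cap."""
    results: list[dict] = []
    spent = 0.0
    for state in states:
        # One concurrent call per teacher for this state; gather preserves order.
        batch = await asyncio.gather(*(call(state, t) for t in teachers))
        results.extend(batch)
        spent += sum(r["cost_usd"] for r in batch)
        if spent >= max_total_usd:  # checked only after the whole state completes
            break
    return results


async def fake_call(state: Any, teacher: Any) -> dict:
    await asyncio.sleep(0)  # stands in for an HTTP request
    return {"state_id": state, "teacher_slug": teacher, "cost_usd": 0.25}
```

With three teachers at $0.25 per call and a $1.00 cap, `asyncio.run(replay_with_cap(["s0", "s1", "s2"], ["a", "b", "c"], fake_call, 1.0))` returns six rows: the first state spends $0.75, the second crosses the cap, and the third never runs.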
### `extract_dpo_pairs(states, teacher_actions, agreement_threshold=2) -> list[DPOPair]`
```python
def extract_dpo_pairs(
    states: Sequence[TraceState],
    teacher_actions: Sequence[TeacherCallResult],
    agreement_threshold: int = 2,
) -> list[DPOPair]
```
Group `teacher_actions` by `state_id`, normalize whitespace, and emit one `DPOPair` per state where at least `agreement_threshold` teachers agreed on an action that differs from the student's. `chosen` is the original (un-normalized) teacher response text.
**Parameters**
| Name | Type | Default | Meaning |
|---|---|---|---|
| `states` | `Sequence[TraceState]` | — | Same as passed to `replay_trace`. |
| `teacher_actions` | `Sequence[TeacherCallResult]` | — | Output of `replay_trace`. |
| `agreement_threshold` | `int` | `2` | Min teachers that must agree for a pair to fire. |
**Returns** list of `DPOPair`. At most one pair per state (the most-agreed-upon action wins).
```python
from composer_replication import extract_dpo_pairs
pairs = extract_dpo_pairs(my_states, results, agreement_threshold=2)
```
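The grouping and voting rule can be sketched with `collections.Counter`. `extract_pairs_sketch` is hypothetical, and three details are assumptions rather than documented behaviour: whitespace normalization collapses runs via `" ".join(text.split())`, errored calls are skipped, and when several raw strings normalize to the winning action the first one is kept as `chosen`:

```python
from collections import Counter, defaultdict
from typing import Any, Sequence


def _norm(text: str) -> str:
    return " ".join(text.split())  # assumed whitespace normalization


def extract_pairs_sketch(states: Sequence[dict], teacher_actions: Sequence[dict],
                         agreement_threshold: int = 2) -> list[dict[str, Any]]:
    by_state: dict[str, list[str]] = defaultdict(list)
    for r in teacher_actions:
        if r.get("error") is None and r.get("response_text"):
            by_state[r["state_id"]].append(r["response_text"])
    pairs = []
    for state in states:
        raw = by_state.get(state["state_id"], [])
        if not raw:
            continue
        # Most-agreed-upon normalized action wins; at most one pair per state.
        action, n = Counter(_norm(t) for t in raw).most_common(1)[0]
        if n < agreement_threshold or action == _norm(state["student_action"]):
            continue
        chosen = next(t for t in raw if _norm(t) == action)  # keep un-normalized text
        pairs.append({"state_id": state["state_id"], "state_messages": state["messages"],
                      "chosen": chosen, "rejected": state["student_action"],
                      "n_teachers_agreeing": n})
    return pairs
```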
### `save_pairs(pairs, path) -> None`
⚠️ UNTESTED-CONTRACT.
```python
def save_pairs(pairs: Sequence[DPOPair], path: str | Path) -> None
```
Write pairs to JSONL (one dict per line). Creates parent dirs.
```python
from composer_replication.teacher_replay import save_pairs
save_pairs(pairs, "/tmp/dpo_pairs.jsonl")
```
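Writing and reading this JSONL layout needs only the standard library. `save_pairs_sketch` and `load_pairs_sketch` are hypothetical; in particular, the real function's encoding and `ensure_ascii` choices are assumptions:

```python
import json
from pathlib import Path
from typing import Any, Iterable, Union


def save_pairs_sketch(pairs: Iterable[dict[str, Any]], path: Union[str, Path]) -> None:
    """Hypothetical JSONL writer: one JSON object per line, parent dirs created."""
    path = Path(path)
    path.parent.mkdir(parents=True, exist_ok=True)
    with path.open("w", encoding="utf-8") as f:
        for pair in pairs:
            f.write(json.dumps(pair, ensure_ascii=False) + "\n")


def load_pairs_sketch(path: Union[str, Path]) -> list[dict[str, Any]]:
    """Read the file back, skipping blank lines."""
    with Path(path).open(encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]
```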
---
## 7. `composer_replication.replaysim`
ADR-004 normalization layer over `teacher_replay`. Re-exports `DPOPair`, `TeacherCallResult`, `extract_dpo_pairs`, `replay_trace` from `teacher_replay`.
### `class NormalizedDPOPair`
```python
@dataclass
class NormalizedDPOPair:
    state_id: str
    state_messages: list[dict[str, Any]]
    chosen_messages: list[dict[str, Any]]
    rejected_messages: list[dict[str, Any]]
    n_teachers_agreeing: int
    metadata: dict[str, Any]
```
Post-normalization shape. `chosen_messages`/`rejected_messages` are chat-format (`[{"role": "assistant", "content": ...}]`). `metadata` carries op-graph provenance, including `{"skipped": True}` when the normalizer was bypassed (`skip_dj=True`).
```python
from composer_replication.replaysim import NormalizedDPOPair
n = NormalizedDPOPair(state_id="x", state_messages=[],
chosen_messages=[{"role": "assistant", "content": "ok"}],
rejected_messages=[{"role": "assistant", "content": "no"}],
n_teachers_agreeing=2, metadata={})
```
### `class DJNormalizer`
```python
class DJNormalizer:
    DEFAULT_RECIPE: ClassVar[Path]  # composer_replication/recipes/replaysim/default.yaml

    def __init__(
        self,
        recipe_path: str | os.PathLike[str] | None = None,
        *,
        skip_dj: bool = False,
    ) -> None: ...

    def normalize(
        self,
        pairs: Iterable[DPOPair | dict[str, Any]],
    ) -> list[NormalizedDPOPair]: ...
```
`data-juicer`-backed normalizer. Pipeline: each `DPOPair` → JSONL record → `data_juicer.core.DefaultExecutor.run()` against the recipe → JSONL → `NormalizedDPOPair`.
**Constructor parameters**
| Name | Type | Default | Meaning |
|---|---|---|---|
| `recipe_path` | `str \| PathLike \| None` | `None` ⇒ default recipe | data-juicer YAML recipe path. |
| `skip_dj` | `bool` (kw-only) | `False` | If True: passthrough; records get `metadata={"skipped": True}` and no ops run. |
**`normalize(pairs) -> list[NormalizedDPOPair]`** runs the op-graph. Output may be shorter than input if filter ops drop records.
**Raises** `RuntimeError` at construction time if `skip_dj=False` and `data_juicer` is not importable. `FileNotFoundError` if `recipe_path` (default or explicit) is missing and `skip_dj=False`.
```python
from composer_replication.replaysim import DJNormalizer
norm = DJNormalizer(skip_dj=True)
out = norm.normalize(my_pairs)
```
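With `skip_dj=True` no data-juicer ops run, so the conversion reduces to reshaping each `DPOPair` into the chat-format fields described for `NormalizedDPOPair`. A sketch of that passthrough, assuming it does nothing beyond wrapping `chosen`/`rejected` as single assistant messages; `passthrough_normalize` and `NormalizedPairSketch` are hypothetical:

```python
from dataclasses import dataclass, field
from typing import Any, Iterable


@dataclass
class NormalizedPairSketch:
    """Same fields as NormalizedDPOPair (redeclared so this runs standalone)."""
    state_id: str
    state_messages: list[dict[str, Any]]
    chosen_messages: list[dict[str, Any]]
    rejected_messages: list[dict[str, Any]]
    n_teachers_agreeing: int
    metadata: dict[str, Any] = field(default_factory=dict)


def passthrough_normalize(pairs: Iterable[dict[str, Any]]) -> list[NormalizedPairSketch]:
    """Hypothetical skip_dj=True path: wrap strings as chat messages, run no ops."""
    return [
        NormalizedPairSketch(
            state_id=p["state_id"],
            state_messages=list(p["state_messages"]),
            chosen_messages=[{"role": "assistant", "content": p["chosen"]}],
            rejected_messages=[{"role": "assistant", "content": p["rejected"]}],
            n_teachers_agreeing=p["n_teachers_agreeing"],
            metadata={"skipped": True},
        )
        for p in pairs
    ]
```

Unlike the data-juicer path, this passthrough never drops records, so output length always equals input length.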
### `async replay_and_normalize_trace(*, states, teachers=None, agreement_threshold=2, max_total_usd=5.0, normalizer=None, **replay_kwargs) -> tuple[list[TeacherCallResult], list[NormalizedDPOPair]]`
```python
async def replay_and_normalize_trace(
    *,
    states: Any,
    teachers: Any = None,
    agreement_threshold: int = 2,
    max_total_usd: float = 5.0,
    normalizer: DJNormalizer | None = None,
    **replay_kwargs: Any,
) -> tuple[list[TeacherCallResult], list[NormalizedDPOPair]]
```
End-to-end async: replay → extract pairs → normalize.
**Parameters**
| Name | Type | Default | Meaning |
|---|---|---|---|
| `states` | `Sequence[TraceState]` | — | Frozen trace. |
| `teachers` | `Sequence[TeacherSpec] \| None` | `None` ⇒ defaults | Forwarded to `replay_trace`. |
| `agreement_threshold` | `int` | `2` | Forwarded to `extract_dpo_pairs`. |
| `max_total_usd` | `float` | `5.0` | Spend cap. |
| `normalizer` | `DJNormalizer \| None` | `None` ⇒ `DJNormalizer()` | Pass `DJNormalizer(skip_dj=True)` to bypass. |
| `**replay_kwargs` | `Any` | — | Forwarded to `replay_trace` (e.g. `api_key`). |
**Returns** `(raw_teacher_actions, normalized_pairs)`.
```python
import asyncio
from composer_replication.replaysim import replay_and_normalize_trace, DJNormalizer
raw, norm = asyncio.run(replay_and_normalize_trace(
states=my_states, normalizer=DJNormalizer(skip_dj=True)))
```
### `replay_and_normalize_trace_sync(*args, **kwargs) -> tuple[list[TeacherCallResult], list[NormalizedDPOPair]]`
⚠️ UNTESTED-CONTRACT (sync wrapper around the async function; tests call the async form via `asyncio.run`).
```python
def replay_and_normalize_trace_sync(*args, **kwargs) -> ...
```
Sync convenience wrapping `asyncio.run(replay_and_normalize_trace(...))`.
```python
from composer_replication.replaysim.normalize import replay_and_normalize_trace_sync
raw, norm = replay_and_normalize_trace_sync(states=my_states)
```
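A sync wrapper built on `asyncio.run` only works where no event loop is already running; inside a notebook cell or an `async def`, `asyncio.run` raises `RuntimeError`. The sketch below makes that check explicit; both functions are hypothetical stand-ins, not the module's code:

```python
import asyncio
from typing import Any


async def _replay_and_normalize_stub(**kwargs: Any) -> tuple[list, list]:
    """Hypothetical stand-in for replay_and_normalize_trace."""
    await asyncio.sleep(0)
    return [], []


def replay_and_normalize_sync_sketch(**kwargs: Any) -> tuple[list, list]:
    try:
        asyncio.get_running_loop()
    except RuntimeError:
        # No loop in this thread: asyncio.run creates one, runs the coroutine, closes it.
        return asyncio.run(_replay_and_normalize_stub(**kwargs))
    raise RuntimeError("an event loop is already running; await the async form instead")
```

In async code, call `await replay_and_normalize_trace(...)` directly instead of the sync form.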
---
## 8. `composer_replication.ingestion` & `composer_replication.ingestion.claude_code`
Trace-source adapters (ADR-002). v0.1 supports Claude Code session JSONL.
### `SYSTEM_PROMPT: str`
Default synthetic system prompt injected at `messages[0]` for ingested traces (most Claude Code sessions don't write one). Truncated head: `"You are a senior software engineer working as a coding agent in a terminal environment..."`.
```python
from composer_replication import SYSTEM_PROMPT
print(SYSTEM_PROMPT[:60])
```
### `class IngestionStats`
```python
@dataclass
class IngestionStats:
    n_records_total: int = 0
    n_records_skipped: int = 0
    n_states_emitted: int = 0
    n_assistant_turns: int = 0
    n_tool_use_blocks: int = 0
    n_text_blocks: int = 0
    skipped_subagent: int = 0
    skipped_summary: int = 0
    skipped_truncated_lines: int = 0
    version_warnings: list[str] | None = None  # initialized to [] in __post_init__
```
Counters populated by `ClaudeCodeIngester.ingest()` and exposed as `ingester.last_stats`.
```python
from composer_replication import IngestionStats
s = IngestionStats(n_records_total=5)
print(s.version_warnings) # []
```
### `class ClaudeCodeIngester`
```python
class ClaudeCodeIngester:
    def __init__(
        self,
        *,
        system_prompt: str = SYSTEM_PROMPT,
        skip_sidechain: bool = True,
        strip_thinking: bool = True,
        max_history_tokens: int | None = None,
    ) -> None: ...

    def ingest(self, path: Path) -> Iterator[TraceState]: ...
```
Convert a Claude Code session JSONL to a stream of `TraceState`s — one per assistant TURN (not per `tool_use` block).
**Constructor parameters**
| Name | Type | Default | Meaning |
|---|---|---|---|
| `system_prompt` | `str` | `SYSTEM_PROMPT` | Synthetic system message injected at history[0]. |
| `skip_sidechain` | `bool` | `True` | Skip subagent files (`agent-*.jsonl`) and records with `isSidechain=True`. |
| `strip_thinking` | `bool` | `True` | Remove `[THINKING]` blocks from history handed to teachers (kept inside `student_action`). |
| `max_history_tokens` | `int \| None` | `None` | ⚠️ UNTESTED-CONTRACT — accepted but currently not used to truncate. |
**`ingest(path) -> Iterator[TraceState]`**: generator over `TraceState` objects. Each turn's `state_id` is `f"{path.stem}::{idx:04d}"`. Side effect: replaces `self.last_stats` with a fresh `IngestionStats` and updates it as records stream.
```python
from pathlib import Path
from composer_replication import ClaudeCodeIngester
ing = ClaudeCodeIngester()
for state in ing.ingest(Path("session.jsonl")):
    print(state["state_id"])
print(ing.last_stats.n_states_emitted)
```
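Because `ingest` is a generator, its body runs only as you iterate, which is why the example reads `last_stats` after the loop. A sketch of that behaviour, assuming the `last_stats` reset happens inside the generator body; `IngesterSketch` and `StatsSketch` are hypothetical and do not parse Claude Code JSONL:

```python
from dataclasses import dataclass
from typing import Iterator


@dataclass
class StatsSketch:
    n_states_emitted: int = 0


class IngesterSketch:
    """Hypothetical stand-in illustrating generator-side stats, not ClaudeCodeIngester."""

    def __init__(self) -> None:
        self.last_stats = StatsSketch()

    def ingest(self, stem: str, turns: list[str]) -> Iterator[str]:
        # A generator body runs lazily: this reset happens on the first next(), not at call time.
        self.last_stats = StatsSketch()
        for idx, _turn in enumerate(turns):
            self.last_stats.n_states_emitted += 1
            yield f"{stem}::{idx:04d}"  # same state_id format as documented above
```

Calling `ingest(...)` without iterating leaves the previous `last_stats` untouched.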
---
## 9. `composer_replication.hint_generator`
⚠️ UNTESTED-CONTRACT (entire module — used by the data collator config but not pinned by a test).
Template-based hint registry for SDPO error-site injection.
### `class HintContext(TypedDict, total=False)`
```python
class HintContext(TypedDict, total=False):
    error_kind: str
    error_message: str
    available_tools: list[str]
    tool_name: str
    tool_schema: dict
    intent: str
```
Per-error context dict consumed by hint templates.
### `HINT_TEMPLATES: dict[str, Callable[[HintContext], str]]`
Default registry keys: `"tool_not_found"`, `"json_decode"`, `"type_error"`, `"runtime_error"`, `"repeated_failure"`.
### `dispatch(error_kind, ctx=None) -> str | None`
```python
def dispatch(error_kind: str, ctx: HintContext | None = None) -> str | None
```
Look up `error_kind` in `HINT_TEMPLATES`. Returns the template's hint text, or `None` if the kind is unknown.
```python
from composer_replication.hint_generator import dispatch
hint = dispatch("json_decode") # "Reminder: tool arguments must be valid JSON. ..."
```
### `register(error_kind, fn) -> None`
```python
def register(error_kind: str, fn: Callable[[HintContext], str]) -> None
```
Add or override a custom hint template.
```python
from composer_replication.hint_generator import register
register("my_error", lambda ctx: "Reminder: try X.")
```
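`dispatch` and `register` together form a plain dict-backed registry. A sketch of that pattern; `TEMPLATES`, `register_sketch`, `dispatch_sketch`, and the hint text are all hypothetical, and passing an empty context when `ctx` is `None` is an assumption:

```python
from typing import Callable, Optional, TypedDict


class HintCtxSketch(TypedDict, total=False):
    error_kind: str
    available_tools: list[str]


HintFn = Callable[[HintCtxSketch], str]

# Hypothetical registry with one made-up template.
TEMPLATES: dict[str, HintFn] = {
    "tool_not_found": lambda ctx: "Available tools: " + ", ".join(ctx.get("available_tools", [])),
}


def register_sketch(error_kind: str, fn: HintFn) -> None:
    TEMPLATES[error_kind] = fn  # adds a new kind or overrides an existing one


def dispatch_sketch(error_kind: str, ctx: Optional[HintCtxSketch] = None) -> Optional[str]:
    fn = TEMPLATES.get(error_kind)
    if fn is None:
        return None  # unknown kind
    return fn(ctx if ctx is not None else {})  # assumed: None -> empty context
```

Because `HintContext` is `total=False`, templates should read keys with `.get(...)` and tolerate missing fields, as the lambda above does.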
### Individual template functions
⚠️ UNTESTED-CONTRACT — exported only via `HINT_TEMPLATES`, useful as building blocks:
- `hint_tool_not_found(ctx) -> str`
- `hint_json_decode(ctx) -> str`
- `hint_type_error(ctx) -> str`
- `hint_runtime_error(ctx) -> str`
- `hint_repeated_failure(ctx) -> str`
Each accepts a `HintContext` and returns hint text. Signatures are uniform: `Callable[[HintContext], str]`.
```python
from composer_replication.hint_generator import hint_tool_not_found
text = hint_tool_not_found({"available_tools": ["Read", "Write"]})
```
---
## 10. `composer_replication.trainer` & sub-modules
Production trainer (TRL `GRPOTrainer` subclass) plus data collator.
### `class ComposerReplicationTrainer`
```python
class ComposerReplicationTrainer(GRPOTrainer):
    def __init__(
        self,
        *args: Any,
        alpha_sdpo: float = 0.1,
        beta_replay: float = 0.05,
        sdpo_jsd_beta: float = 0.5,
        sdpo_temperature: float = 1.0,
        sdpo_token_clip: float | None = None,
        replay_dpo_beta: float = 0.1,
        **kwargs: Any,
    ) -> None: ...

    def _compute_loss(
        self,
        model: torch.nn.Module,
        inputs: dict[str, torch.Tensor],
    ) -> torch.Tensor: ...
```
`trl.GRPOTrainer` subclass that overrides `_compute_loss(model, inputs)` to compose `total = grpo + α·sdpo + β·trace_replay_dpo`. When `trl` is not installed, the parent class falls back to `object` so the module imports — but instantiation will fail because the parent's GRPO machinery is missing.
**Constructor (kw-only beyond GRPOTrainer's own `*args, **kwargs`)**
| Name | Type | Default | Meaning |
|---|---|---|---|
| `alpha_sdpo` | `float` | `0.1` | Channel-2 weight. |
| `beta_replay` | `float` | `0.05` | Channel-3 weight. |
| `sdpo_jsd_beta` | `float` | `0.5` | β for `generalized_jsd_loss`. |
| `sdpo_temperature` | `float` | `1.0` | SDPO softmax temperature. |
| `sdpo_token_clip` | `float \| None` | `None` | Per-token JSD clip. |
| `replay_dpo_beta` | `float` | `0.1` | DPO β. |
**`_compute_loss(model, inputs) -> torch.Tensor`** — overrides `GRPOTrainer._compute_loss`. Calls `super()._compute_loss` for channel 1, then `_compute_sdpo_loss` and `_compute_trace_replay_loss`, then composes. Logs per-channel components every `args.logging_steps` (default 50). **Raises** whatever `super()` raises (TRL-shaped errors).
**Internal methods (publicly accessible, exercised by spike tests)**
- ⚠️ UNTESTED-CONTRACT `_compute_sdpo_loss(model, inputs) -> torch.Tensor` — generalized-JSD between student forward and `ctx_teacher_input_ids` forward. Returns `0.0` (with grad) when `alpha_sdpo == 0`, the key is missing, or shapes mismatch. Logs a warning on shape mismatch.
- ⚠️ UNTESTED-CONTRACT `_compute_trace_replay_loss(model, inputs) -> torch.Tensor` — standard DPO over `dpo_chosen_*` and `dpo_rejected_*`, using precomputed `dpo_chosen_ref_logprobs` / `dpo_rejected_ref_logprobs`.
- ⚠️ UNTESTED-CONTRACT `@staticmethod _sequence_logprobs(model, input_ids, response_mask) -> torch.Tensor` — sum logprobs over response tokens; standard DPO accounting.
```python
from composer_replication import ComposerReplicationTrainer
trainer = ComposerReplicationTrainer(
    model=my_model, args=my_grpo_args, train_dataset=ds,
    data_collator=my_collator, alpha_sdpo=0.1, beta_replay=0.05,
)
# trainer.train() # uses overridden _compute_loss
```
### `make_dr_grpo_config(**overrides) -> trl.GRPOConfig`
Builds a `trl.GRPOConfig` configured to the **Dr. GRPO** recipe (Composer 2.5's
base objective per the Composer 2 tech report, arXiv:2603.24477; Dr.GRPO =
Liu et al. arXiv:2503.20783). Forces three knobs unless explicitly overridden,
with drift-guard assertions:
- `loss_type="dr_grpo"` — removes the length bias introduced by GRPO's per-response length normalization.
- `scale_rewards="none"` — NO std-dev advantage normalization (Dr.GRPO requirement).
- `num_iterations=1` — single-epoch / strict on-policy.
Any field is overridable via kwargs (`learning_rate=`, `output_dir=`, `beta=`, …).
**Honest KL-estimator delta** (ADR-012 #1): TRL 1.5.0's `GRPOTrainer._compute_loss`
uses the **k3** estimator `exp(ref_logp−logp)−(ref_logp−logp)−1`, NOT the k1
estimator `−log r` the Dr.GRPO/Composer report frames; the delta is small for r≈1
and TRL is not monkeypatched — the delta is documented, not hidden. Exported from
both `composer_replication` and `composer_replication.trainer`.
```python
from composer_replication import make_dr_grpo_config
args = make_dr_grpo_config(output_dir="runs/x", learning_rate=1e-6)
```
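The k1/k3 delta described in the KL-estimator note can be made concrete with a small plain-Python sketch that does not use TRL. It evaluates both per-token estimators from a policy log-prob `logp` and a reference log-prob `ref_logp`, with `r = exp(ref_logp - logp)`. The identity k3 = k1 + (r − 1) holds, and E over samples from the policy of r is 1, so the extra term has zero mean. Both estimators therefore target the same KL. k3 replaces k1's sign-indefinite per-token values with non-negative, lower-variance ones. The function names are illustrative only.

```python
import math


def kl_k1(logp: float, ref_logp: float) -> float:
    # k1 = -log r = logp - ref_logp; unbiased, but can be negative per token.
    return logp - ref_logp


def kl_k3(logp: float, ref_logp: float) -> float:
    # k3 = r - 1 - log r with r = exp(ref_logp - logp); always >= 0.
    log_r = ref_logp - logp
    return math.exp(log_r) - log_r - 1.0


# Per-token comparison as the log-ratio grows away from 0 (r = 1).
for delta in (0.0, 0.01, 0.1, 1.0):
    logp, ref_logp = math.log(0.5), math.log(0.5) + delta
    print(f"log r = {delta:4.2f}  k1 = {kl_k1(logp, ref_logp):+.5f}  k3 = {kl_k3(logp, ref_logp):.5f}")
```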
### `make_po_config(objective="dr_grpo", **overrides) -> trl.GRPOConfig`
Builds a `trl.GRPOConfig` for a **named policy-optimization objective** from the
`PO_OBJECTIVES` menu (ADR-014). All presets are PURE CONFIG over trl 1.5.0's
`GRPOTrainer` (verified by introspection) — no custom `_compute_loss` needed.
`**overrides` set/override any `GRPOConfig` field on top.
- Raises `ValueError` on an unknown objective (lists the valid menu).
- Raises `AssertionError` if a requested knob silently failed to apply (drift guard;
e.g. GSPO guards `importance_sampling_level=="sequence"`).
```python
from composer_replication import make_po_config, PO_OBJECTIVES
args = make_po_config("dapo", output_dir="runs/dapo", learning_rate=2e-6)
```
### `PO_OBJECTIVES: dict[str, dict]`
The selectable base policy-optimization objectives (named presets over real trl
1.5.0 `GRPOConfig` knobs). Keys and what each sets:
| Objective | `loss_type` | `scale_rewards` | Distinguishing knob | Paper |
|---|---|---|---|---|
| `grpo` | `grpo` | `group` (std-norm) | `importance_sampling_level="token"` | DeepSeekMath 2402.03300 |
| `dr_grpo` *(default)* | `dr_grpo` | `none` | length-bias removed | 2503.20783 |
| `bnpo` | `bnpo` | `batch` | batch-normalized | trl |
| `dapo` | `dapo` | `none` | `epsilon_high=0.28` (decoupled clip-higher), `mask_truncated_completions`, `beta=0` | 2503.14476 |
| `gspo` | `grpo` | `group` | `importance_sampling_level="sequence"` | Qwen 2507.18071 |
| `cispo` | `cispo` | `none` | `epsilon_high=5.0` (detached IS coef) | MiniMax-M1 2506.13585 |
> **Diagnostic gotcha:** for any PO-objective ablation, log the *distinguishing*
> diagnostic (`clip_ratio/high_mean` for DAPO, the sequence-level ratio for GSPO).
> A `0` means the knob never engaged — NOT that the objectives are equal. (This is
> exactly the inert-knob artifact the A1 DAPO-vs-Dr.GRPO washout hit at lr=1e-6.)
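One way to act on this gotcha is to scan the logged metrics before comparing runs. The sketch below assumes the logs are a list of dicts shaped like Hugging Face `Trainer.state.log_history`, and that the diagnostic key is the `clip_ratio/high_mean` name given above. `knob_engaged` is a hypothetical helper, and the history rows are toy values.

```python
def knob_engaged(log_history: list[dict], key: str = "clip_ratio/high_mean") -> bool:
    """True if `key` was logged at least once with a non-zero value.

    Raises KeyError if the key never appears: a missing diagnostic cannot
    distinguish "knob inert" from "metric not logged".
    """
    values = [row[key] for row in log_history if key in row]
    if not values:
        raise KeyError(f"{key!r} never logged; cannot tell whether the knob engaged")
    return any(v != 0 for v in values)


# Toy log rows for illustration only.
history = [
    {"step": 10, "loss": 0.41, "clip_ratio/high_mean": 0.0},
    {"step": 20, "loss": 0.39, "clip_ratio/high_mean": 0.0},
]
if not knob_engaged(history):
    print("clip-higher never fired: a DAPO vs Dr.GRPO comparison here is inconclusive")
```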
### `class TraceTurn(TypedDict, total=False)` — `trainer.data_collator`
```python
class TraceTurn(TypedDict, total=False):
    role: str  # "user" | "assistant" | "tool"
    content: str
    tool_call: dict | None
    tool_error: str | None
    error_meta: dict
```
One turn of an agentic trace as consumed by `ComposerDataCollator`.
### `class TraceExample(TypedDict, total=False)` — `trainer.data_collator`
```python
class TraceExample(TypedDict, total=False):
    trace_id: str
    turns: list[TraceTurn]
    final_reward: float
    dpo_pairs: list[dict] | None
```
One training example: `(turns, optional dpo_pairs)`. `dpo_pairs` shape matches `DPOPair`.
### `class TokenizerLike` — `trainer.data_collator`
⚠️ UNTESTED-CONTRACT (duck-typed protocol; used as a type hint).
```python
class TokenizerLike:
    pad_token_id: int
    def __call__(self, text: str | list[str], **kwargs: Any) -> dict[str, list]: ...
    def apply_chat_template(self, messages: list[dict], **kwargs: Any) -> str | list[int]: ...
```
Minimal protocol the collator needs. Compatible with HF `AutoTokenizer`.
### `class CollatorConfig` — `trainer.data_collator`
```python
@dataclass
class CollatorConfig:
    max_seq_len: int = 4096
    max_dpo_seq_len: int = 2048
    pad_token_id: int = 0
    ignore_index: int = -100
    enable_sdpo: bool = True
    hint_generator: Callable[[str, dict], str | None] | None = None
    enable_replay_dpo: bool = True
    rlvr_reward_key: str = "final_reward"
```
Tunables for `ComposerDataCollator`.
| Field | Default | Meaning |
|---|---|---|
| `max_seq_len` | `4096` | Truncation cap for student/teacher sequences. |
| `max_dpo_seq_len` | `2048` | Truncation cap for DPO chosen/rejected sequences. |
| `pad_token_id` | `0` | Padding token id. |
| `ignore_index` | `-100` | HF "ignore in loss" sentinel for SDPO mask. |
| `enable_sdpo` | `True` | Toggle channel-2 fields. |
| `hint_generator` | `Callable[[str, dict], str \| None] \| None` (`None`) | `(error_kind, error_meta) -> hint_text`. SDPO is no-op without this. |
| `enable_replay_dpo` | `True` | Toggle channel-3 fields. |
| `rlvr_reward_key` | `"final_reward"` | Key in `TraceExample` to read scalar reward. |
```python
from composer_replication.trainer.data_collator import CollatorConfig
cfg = CollatorConfig(max_seq_len=2048, hint_generator=my_dispatch)
```
### `class ComposerDataCollator` — `trainer.data_collator`
```python
@dataclass
class ComposerDataCollator:
    tokenizer: TokenizerLike
    config: CollatorConfig = field(default_factory=CollatorConfig)

    def __call__(
        self, batch: Sequence[TraceExample]
    ) -> dict[str, torch.Tensor]: ...
```
Build trainer-ready batches from raw traces + optional DPO pairs.
**Output dict keys** (tested in `spikes/005-integrated-trainer-skeleton/tests/test_data_collator.py`):
- Channel 1 (always): `input_ids`, `attention_mask`, `response_mask`, `rewards`.
- Channel 2 (when `enable_sdpo=True` AND batch has at least one error site AND `hint_generator` is set): `ctx_teacher_input_ids`, `sdpo_loss_mask`.
- Channel 3 (when `enable_replay_dpo=True` AND batch has at least one `dpo_pair`): `dpo_chosen_input_ids`, `dpo_chosen_response_mask`, `dpo_rejected_input_ids`, `dpo_rejected_response_mask`. (Reference logprobs are NOT computed here — the trainer does that pass.)
```python
from composer_replication.trainer.data_collator import (
ComposerDataCollator, CollatorConfig)
collator = ComposerDataCollator(tokenizer=tok, config=CollatorConfig())
batch = collator([{"trace_id": "x", "turns": [...], "final_reward": 1.0}])
```
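Channels 2 and 3 appear only when their conditions hold. Downstream code can therefore infer, from the batch keys alone, which loss channels a batch will feed. The small sketch below is built from the key lists above. `active_channels` is a hypothetical helper, and the test uses placeholder values instead of tensors.

```python
CHANNEL_KEYS: dict[int, tuple[str, ...]] = {
    1: ("input_ids", "attention_mask", "response_mask", "rewards"),
    2: ("ctx_teacher_input_ids", "sdpo_loss_mask"),
    3: ("dpo_chosen_input_ids", "dpo_chosen_response_mask",
        "dpo_rejected_input_ids", "dpo_rejected_response_mask"),
}


def active_channels(batch: dict) -> set[int]:
    """Return the channels whose full key set is present in a collated batch."""
    active = {ch for ch, keys in CHANNEL_KEYS.items() if all(k in batch for k in keys)}
    if 1 not in active:
        raise ValueError("channel-1 keys missing; not a ComposerDataCollator batch")
    return active
```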
---
## 11. `composer_replication.diloco`
DiLoCo outer-loop wrapper around `torchft.local_sgd.DiLoCo`. Optional dep — when `torchft` is missing the package re-export `composer_replication.make_diloco_outer_loop` is `None`.
### Module-level attributes
- `DiLoCo: Any` is `torchft.local_sgd.DiLoCo` if importable, else `None`.
- `Manager: Any` is `torchft.manager.Manager` if importable, else `None`.
- `_DummyWork: Any` is `torchft.work._DummyWork` if importable, else `None`.
- `_TORCHFT_AVAILABLE: bool` — whether the imports succeeded.
```python
from composer_replication.diloco import _TORCHFT_AVAILABLE, DiLoCo
```
### `make_diloco_outer_loop(manager, model_fragments, inner_optimizer, *, ...) -> torchft.local_sgd.DiLoCo`
```python
def make_diloco_outer_loop(
    manager: Any,
    model_fragments: list[torch.nn.Module],
    inner_optimizer: torch.optim.Optimizer,
    *,
    outer_lr: float = 0.7,
    outer_momentum: float = 0.9,
    nesterov: bool = True,
    sync_every: int = 100,
    fragment_sync_delay: int = 0,
    fragment_update_alpha: float = 0.0,
) -> Any
```
Construct a `torchft.DiLoCo` configured with framework-default hyperparams (DiLoCo paper §3.2: `lr=0.7, momentum=0.9, Nesterov`).
**Parameters**
| Name | Type | Default | Meaning |
|---|---|---|---|
| `manager` | `torchft.Manager` (or duck-typed `MockManager`) | — | Provides `allreduce`, `should_commit`, `current_step`, `start_quorum`, etc. |
| `model_fragments` | `list[torch.nn.Module]` | — | One module for vanilla DiLoCo; N modules for Streaming DiLoCo. |
| `inner_optimizer` | `torch.optim.Optimizer` | — | Inner-step optimizer (steps every batch). |
| `outer_lr` | `float` | `0.7` | Outer SGD lr. |
| `outer_momentum` | `float` | `0.9` | Outer SGD momentum. |
| `nesterov` | `bool` | `True` | Nesterov momentum on outer SGD. |
| `sync_every` | `int` | `100` | Inner steps per outer round. |
| `fragment_sync_delay` | `int` | `0` | 0 = vanilla; >0 = Streaming DiLoCo (requires CUDA streams). |
| `fragment_update_alpha` | `float` | `0.0` | 0 = full replacement on sync; >0 = exponential mix. |
**Returns** a `torchft.local_sgd.DiLoCo` instance — usable as a context manager.
**Raises** `RuntimeError` if `torchft` is not installed.
```python
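import torch
from composer_replication.diloco import make_diloco_outer_loop
# Assumes `model`, `mgr` (torchft Manager or MockManager), `data_loader`,
# and `criterion` already exist in the caller's scope.
opt = torch.optim.AdamW(model.parameters(), lr=1e-5)
outer = make_diloco_outer_loop(manager=mgr, model_fragments=[model],
                               inner_optimizer=opt, sync_every=100)
with outer:
    for inputs, targets in data_loader:
        opt.zero_grad()
        loss = criterion(model(inputs), targets)
        loss.backward()
        opt.step()  # inner step; the outer sync runs every `sync_every` steps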
```
---
## 12. `composer_replication.diloco.serverless`
ADR-005 serverless DiLoCo executors + object-store all-reduce.
### `class ReplicaHandle` — `serverless.executor`
```python
@dataclass
class ReplicaHandle:
rank: int
backend_name: str
metadata: dict[str, Any] = field(default_factory=dict)
```
Opaque handle returned by `ServerlessExecutor.launch_replicas`. `metadata` is backend-specific.
```python
from composer_replication.diloco.serverless import ReplicaHandle
h = ReplicaHandle(rank=0, backend_name="local_process",
metadata={"pid": 12345})
```
### `class ServerlessExecutor` (Protocol) — `serverless.executor`
```python
@runtime_checkable
class ServerlessExecutor(Protocol):
    backend_name: str
    supports_inter_replica_network: bool

    def launch_replicas(
        self,
        n_replicas: int,
        entrypoint: str | Callable[..., Any],
        entrypoint_args: Mapping[str, Any],
        *,
        gpu: str | None = None,
        timeout: int = 3600,
    ) -> list[ReplicaHandle]: ...
    def poll(self, handle: ReplicaHandle) -> str: ...
    def stream_logs(self, handle: ReplicaHandle, *, n_lines: int = 200) -> str: ...
    def cancel(self, handle: ReplicaHandle) -> None: ...
    def collect(
        self, handles: list[ReplicaHandle], *, timeout: int | None = None,
    ) -> list[dict[str, Any]]: ...
```
Structural protocol for serverless backends.
- `launch_replicas(...)` returns `list[ReplicaHandle]` of length `n_replicas` in rank order. `entrypoint` is one of: an importable module path (its `main()` is called), a `module.function` path, or a `Callable` (accepted by the Local executor only). `entrypoint_args` may include `rank_env` (default `"REPLICA_RANK"`).
- `poll(handle) -> str`: one of `"pending"`, `"running"`, `"succeeded"`, `"failed"`, `"cancelled"`.
- `stream_logs(handle, n_lines=200) -> str`: best-effort recent stdout/stderr.
- `cancel(handle) -> None`: best-effort.
- `collect(handles, timeout=None) -> list[dict]`: blocks; each result dict has `rank`, `status`, `exit_code`, `error` (and `result` from `LocalProcessExecutor`).
```python
from composer_replication.diloco.serverless import ServerlessExecutor
def supports(x: ServerlessExecutor) -> bool:
    return isinstance(x, ServerlessExecutor)  # runtime_checkable
```
### `class LocalProcessExecutor` — `serverless.executor`
```python
class LocalProcessExecutor:
    backend_name = "local_process"
    supports_inter_replica_network = True
    def __init__(self) -> None: ...
    # implements ServerlessExecutor protocol
```
Reference implementation using Python `multiprocessing` (`spawn` context). Used for tests, CI smokes, and local development with `file://` rendezvous.
`launch_replicas(...)`: emits a soft warning on `gpu != None` (local processes share whatever GPUs are visible). `metadata = {"pid": ..., "start_ts": ...}`.
```python
from composer_replication.diloco.serverless import LocalProcessExecutor
ex = LocalProcessExecutor()
handles = ex.launch_replicas(
n_replicas=2,
entrypoint="composer_replication.diloco.serverless.replica_entrypoint",
entrypoint_args={"rendezvous_uri": "/tmp/run/", "world_size": 2,
"trainer_module": "my.trainer"},
)
results = ex.collect(handles, timeout=60)
```
### `class ObjectStoreAllReduce` — `serverless.allreduce`
```python
class ObjectStoreAllReduce:
    def __init__(
        self,
        uri: str,
        rank: int,
        world_size: int,
        *,
        round_id: int | None = None,
        timeout_s: float = 1800.0,
        poll_interval_s: float = 1.0,
    ) -> None: ...
    @property
    def round_id(self) -> int: ...
    def allreduce(
        self, tensor: torch.Tensor, *, name: str | None = None,
    ) -> torch.Tensor: ...
```
fsspec-backed pseudo-gradient rendezvous. `uri` accepts `s3://`, `gs://`, `az://`, `hf://`, `file://`, or a plain local path.
**Constructor parameters**
| Name | Type | Default | Meaning |
|---|---|---|---|
| `uri` | `str` | — | fsspec URI or local path. Trailing `/` enforced. |
| `rank` | `int` | — | This replica's rank. |
| `world_size` | `int` | — | Total replicas. |
| `round_id` | `int \| None` (kw-only) | `None` ⇒ start at 0 | Initial round counter. |
| `timeout_s` | `float` (kw-only) | `1800.0` | Per-`allreduce` timeout. |
| `poll_interval_s` | `float` (kw-only) | `1.0` | Sleep between peer-file existence checks. |
**`allreduce(tensor, name=None) -> torch.Tensor`**: serializes `tensor.detach().cpu()` to `round_NNNNNN/rank_RRRR.pt`, blocks until all peers post, then averages. **Modifies `tensor` in place** AND returns it. Increments the internal `_round_counter`.
**Raises** `ValueError` on invalid `rank`, `RuntimeError` if non-local URI is requested without `fsspec` installed, `TimeoutError` if peers don't show up before `timeout_s`.
```python
from composer_replication.diloco.serverless import ObjectStoreAllReduce
import torch
store = ObjectStoreAllReduce("/tmp/run/", rank=0, world_size=2)
g = torch.zeros(10)
store.allreduce(g) # blocks for rank 1
```
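The rendezvous protocol can be illustrated with the standard library alone. The sketch mirrors the `round_NNNNNN/rank_RRRR.*` file layout described above, under three simplifying assumptions: a shared local directory, JSON payloads instead of `torch.save`, and Python lists instead of tensors. The write-then-rename step, which keeps peers from reading half-written files, is also an assumption rather than something the source states. `file_allreduce` is a hypothetical name.

```python
import json
import os
import time


def file_allreduce(root: str, rank: int, world_size: int, round_id: int,
                   values: list[float], *, timeout_s: float = 30.0,
                   poll_interval_s: float = 0.05) -> list[float]:
    if not 0 <= rank < world_size:
        raise ValueError(f"rank {rank} not in [0, {world_size})")
    round_dir = os.path.join(root, f"round_{round_id:06d}")
    os.makedirs(round_dir, exist_ok=True)
    # Post this rank's contribution atomically: write a temp file, then rename.
    final = os.path.join(round_dir, f"rank_{rank:04d}.json")
    with open(final + ".tmp", "w") as f:
        json.dump(values, f)
    os.replace(final + ".tmp", final)
    # Block until every peer's file exists, or give up at the deadline.
    peers = [os.path.join(round_dir, f"rank_{r:04d}.json") for r in range(world_size)]
    deadline = time.monotonic() + timeout_s
    while not all(os.path.exists(p) for p in peers):
        if time.monotonic() > deadline:
            raise TimeoutError(f"peers missing in {round_dir}")
        time.sleep(poll_interval_s)
    # Average element-wise across ranks.
    contributions = []
    for p in peers:
        with open(p) as f:
            contributions.append(json.load(f))
    return [sum(col) / world_size for col in zip(*contributions)]
```

Running two ranks in separate threads against one temporary directory yields the same averaged list on both ranks.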
### `class MockManager` — `serverless.allreduce`
```python
class MockManager:
    def __init__(self, store: ObjectStoreAllReduce) -> None: ...
    # torchft.Manager-shaped surface:
    num_participants: int
    rank: int
    _use_async_quorum: bool  # always False
    _step: int
    _state_dict_fns: dict[str, tuple[Any, Any]]
    def allreduce(self, tensor: torch.Tensor, **_kwargs: Any) -> "_ImmediateWork": ...
    def should_commit(self) -> bool: ...
    def start_quorum(self) -> None: ...
    def wait_quorum(self) -> int: ...
    def current_step(self) -> int: ...
    def allow_state_dict_read(self) -> None: ...
    def disallow_state_dict_read(self) -> None: ...
    def register_state_dict_fn(self, key: str, load_fn: Any, save_fn: Any) -> None: ...
    def is_leader(self) -> bool: ...
```
Drop-in replacement for `torchft.Manager` that routes `allreduce` through `ObjectStoreAllReduce`. All other methods are no-ops or simple counters appropriate for single-shot serverless DiLoCo.
- `allreduce(tensor)` returns an `_ImmediateWork` whose `.wait()` is a no-op (the tensor is already averaged).
- `should_commit()` always `True` (no fault-tolerance failover).
- `start_quorum()` bumps `_step`.
- `is_leader()` returns `rank == 0`.
```python
from composer_replication.diloco.serverless import MockManager, ObjectStoreAllReduce
store = ObjectStoreAllReduce("/tmp/run/", rank=0, world_size=2)
mgr = MockManager(store)
# pass mgr into make_diloco_outer_loop(manager=mgr, ...)
```
### `class _ImmediateWork` — `serverless.allreduce`
⚠️ UNTESTED-CONTRACT internal helper exported from `__all__`. `Work`-shaped wrapper with `.wait() -> True` and `.get_future() -> torch.futures.Future`. Consumed by torchft DiLoCo's `perform_sync`.
```python
from composer_replication.diloco.serverless.allreduce import _ImmediateWork
```
### `class ModalExecutor` — `serverless.modal`
🟡 SKELETON — raises `NotImplementedError`; see ADR-005. Class body documents the v0 implementation pattern (Modal `app.function` + `function.spawn(rank=...)`).
```python
from composer_replication.diloco.serverless.modal import ModalExecutor
# ModalExecutor() # would NotImplementedError when instantiated
```
### `class HFJobsExecutor` — `serverless.hf_jobs`
🟡 SKELETON — raises `NotImplementedError`; see ADR-005. Class body documents the v0 pattern using `huggingface_hub.run_job` against `hf://datasets/.../` rendezvous.
```python
from composer_replication.diloco.serverless.hf_jobs import HFJobsExecutor
# instantiation will fail until v0 implementation lands
```
### `replica_entrypoint.main(...)` — `serverless.replica_entrypoint`
```python
def main(
    rendezvous_uri: str,
    world_size: int,
    trainer_module: str,
    trainer_fn: str = "train",
    trainer_kwargs: dict[str, Any] | None = None,
) -> Any
```
Script run by every replica. Reads `REPLICA_RANK` env var, builds `ObjectStoreAllReduce` + `MockManager`, imports `trainer_module`, and calls `getattr(mod, trainer_fn)(**trainer_kwargs, manager=..., rank=..., world_size=...)`. Returns whatever the train fn returns.
**Raises** `RuntimeError` if `REPLICA_RANK` env var is missing; `ValueError` if rank ∉ `[0, world_size)`.
The `if __name__ == "__main__"` block accepts CLI flags `--rendezvous`, `--world-size`, `--trainer-module`, `--trainer-fn`, `--trainer-kwargs-json`.
```python
# In-process invocation
import os
os.environ["REPLICA_RANK"] = "0"
from composer_replication.diloco.serverless.replica_entrypoint import main
result = main(rendezvous_uri="/tmp/run/", world_size=1,
trainer_module="my.trainer", trainer_fn="train")
```
---
## 13. `composer_replication.recipes.prime_rl.composer_loss`
PRIME-RL adapter (ADR-006). Maps PRIME-RL's `LossInputs` struct onto channel 1 (DPPO + KL on the importance ratio, mirroring PRIME-RL's upstream `default_loss_fn` at `prime_rl/trainer/rl/loss.py` lines 116-165). Channel 2 raises `NotImplementedError`; channel 3 is out of scope.
### `loss_fn(inputs, *, alpha_sdpo=0.0, beta_dpo=0.0, dppo_mask_high=0.2, dppo_mask_low=0.2, adv_tau=1.0, kl_tau=1e-3) -> torch.Tensor`
```python
def loss_fn(
    inputs: Any,  # PRIME-RL's LossInputs (duck-typed)
    *,
    alpha_sdpo: float = 0.0,
    beta_dpo: float = 0.0,
    dppo_mask_high: float = 0.2,
    dppo_mask_low: float = 0.2,
    adv_tau: float = 1.0,
    kl_tau: float = 1e-3,
) -> Any  # torch.Tensor scalar
```
PRIME-RL passes per-sample **1-D `(seq,)` tensors** (not batched). The function mirrors PRIME-RL's upstream DPPO+KL formula:
- Mask gate is on **probability-space** `probs_diff = exp(trainer_lp) - exp(inference_lp)` (NOT on the log-ratio).
- A token is dropped iff its advantage sign matches the offending bound: positive-advantage tokens are dropped when `probs_diff > dppo_mask_high`, negative-advantage tokens when `probs_diff < -dppo_mask_low`. (PRIME-RL stores both bounds with `Field(..., ge=0)` and applies the sign internally.)
- The PG term is `keep * (adv_tau * advantages) * exp(trainer_lp - inference_lp)` (importance-ratio corrected, not REINFORCE).
- A KL penalty `kl_tau * log_importance_ratio**2` is added on the full `loss_mask` (DPPO masking does not gate it).
- Reduction is a plain `sum()`; PRIME-RL's outer `compute_loss` divides by `loss_scale`.
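The rules above can be checked against a plain-Python rendering that computes the loss value only, with no autograd, over per-token lists. One convention is an assumption: the sketch uses loss = −PG + KL, so that minimizing the loss raises the advantage-weighted ratio. The bullets state the terms but not how the PG term enters the final sum. `composer_loss_value` is a hypothetical name.

```python
import math


def composer_loss_value(trainer_lp, inference_lp, advantages, loss_mask, *,
                        dppo_mask_high=0.2, dppo_mask_low=0.2,
                        adv_tau=1.0, kl_tau=1e-3):
    for name, v in (("dppo_mask_high", dppo_mask_high), ("dppo_mask_low", dppo_mask_low),
                    ("adv_tau", adv_tau), ("kl_tau", kl_tau)):
        if v < 0:
            raise ValueError(f"{name} must be >= 0, got {v}")
    total = 0.0
    for t_lp, i_lp, adv, m in zip(trainer_lp, inference_lp, advantages, loss_mask):
        if not m:
            continue
        log_ratio = t_lp - i_lp
        probs_diff = math.exp(t_lp) - math.exp(i_lp)  # gate is in probability space
        drop = ((adv > 0 and probs_diff > dppo_mask_high)
                or (adv < 0 and probs_diff < -dppo_mask_low))
        if not drop:
            total -= (adv_tau * adv) * math.exp(log_ratio)  # importance-ratio PG term
        total += kl_tau * log_ratio ** 2  # KL on the full loss_mask, never DPPO-gated
    return total  # plain sum; the caller divides by loss_scale
```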
**Parameters**
| Name | Type | Default | Meaning |
|---|---|---|---|
| `inputs` | PRIME-RL `LossInputs` (duck-typed) | — | Must expose `trainer_logprobs`, `inference_logprobs`, `advantages`, `loss_mask` (all 1-D), and optionally `teacher_logprobs`. |
| `alpha_sdpo` | `float` (kw-only) | `0.0` | Channel-2 weight. Must be `0` in v0; >0 → `NotImplementedError`. |
| `beta_dpo` | `float` (kw-only) | `0.0` | Channel-3 weight. Non-zero emits a `UserWarning`. |
| `dppo_mask_high` | `float` (kw-only), `>= 0` | `0.2` | Upper probability-diff threshold. PRIME-RL `DefaultLossConfig` default. |
| `dppo_mask_low` | `float` (kw-only), `>= 0` | `0.2` | Magnitude of lower probability-diff threshold (sign flipped internally). PRIME-RL default. |
| `adv_tau` | `float` (kw-only), `>= 0` | `1.0` | Advantage temperature. PRIME-RL default. |
| `kl_tau` | `float` (kw-only), `>= 0` | `1e-3` | KL term temperature. PRIME-RL default. |
**Returns** scalar `torch.Tensor` (PRIME-RL's trainer calls `.backward()`).
**Raises** `ValueError` if any of `trainer_logprobs`, `inference_logprobs`, `advantages`, `loss_mask` is not 1-D, or any of the four `>=0`-constrained knobs is negative. `NotImplementedError` if `alpha_sdpo > 0` (channel 2 deferred).
```python
from composer_replication.recipes.prime_rl.composer_loss import loss_fn
# In PRIME-RL config:
# loss:
#   custom:
#     import_path: composer_replication.recipes.prime_rl.composer_loss:loss_fn
#     kwargs:
#       dppo_mask_high: 0.2
#       dppo_mask_low: 0.2
#       adv_tau: 1.0
#       kl_tau: 1.0e-3
```
---
## 14. `composer_replication.recipes.monarch.actors`
🟡 SKELETON module per ADR-006. Importable; classes raise `NotImplementedError` on instantiation. Documents the actor signatures so the recipe matrix is complete.
### `class TrainerActor` 🟡
```python
class TrainerActor:
    backend = "monarch"
    role = "trainer"
    def __init__(self) -> None: raise NotImplementedError(...)
    async def train_outer_step(self, batch_id: int) -> dict[str, Any]: raise NotImplementedError
```
Hosts the framework's 3-channel composer trainer. Real impl deferred to v0.2+.
### `class GeneratorActor` 🟡
```python
class GeneratorActor:
    backend = "monarch"
    role = "generator"
    def __init__(self) -> None: raise NotImplementedError(...)
    async def rollout(self, prompts: list[str]) -> list[str]: raise NotImplementedError
```
vLLM-backed rollout actor.
### `class RewarderActor` 🟡
```python
class RewarderActor:
    backend = "monarch"
    role = "rewarder"
    def __init__(self) -> None: raise NotImplementedError(...)
    async def score(self, completions: list[str]) -> list[float]: raise NotImplementedError
```
Rewarder implementing the `verifiers` protocol.
### `class TeacherPoolActor` 🟡
```python
class TeacherPoolActor:
    backend = "monarch"
    role = "teacher_pool"
    def __init__(self) -> None: raise NotImplementedError(...)
```
Channel-3 teacher pool wrapping `composer_replication.teacher_replay`.
```python
# All Monarch actors raise on instantiation in v0:
from composer_replication.recipes.monarch.actors import TrainerActor
# TrainerActor() # NotImplementedError
```
---
## 15. `composer_replication.diloco.serverless` — cloud executors
ADR-005 production cloud executors that implement the `ServerlessExecutor` Protocol (see §12) against real clouds. Both are the working counterparts of the `ModalExecutor` / `HFJobsExecutor` 🟡 skeletons. Cross-replica communication goes EXCLUSIVELY through the S3 `ObjectStoreAllReduce` rendezvous, so the trainer, loss, and DiLoCo math stay byte-for-byte identical regardless of backend. Both set `supports_inter_replica_network = False` (replicas rendezvous only through the object store), and both lazy-import their cloud SDK so that `import composer_replication.diloco.serverless` does not pull in `kubernetes` / `boto3`.
**Extras**: `EKSExecutor` needs `kubernetes>=29` (`pip install -e .[eks]`); `SageMakerExecutor` needs `boto3>=1.34` (`pip install -e .[aws]`). The `[serverless]` extra pulls both (plus `fsspec` + `huggingface_hub` for the rendezvous backends). The SDK is required only at adapter-init / method time — when an api/client is injected (tests), a missing top-level package is tolerated.
### `class EKSExecutor` — `serverless.eks`
```python
class EKSExecutor:
    backend_name = "eks"
    supports_inter_replica_network = False  # S3 rendezvous only

    def __init__(
        self,
        image: str,
        *,
        namespace: str = "default",
        service_account_name: str | None = None,
        node_selector: dict[str, str] | None = None,
        tolerations: list[Any] | None = None,
        runtime_class_name: str | None = None,
        command: list[str] | None = None,
        cpu_request: str = "4",
        memory_request: str = "16Gi",
        ttl_seconds_after_finished: int = 3600,
        backoff_limit: int = 0,
        gpu_resource_key: str = "nvidia.com/gpu",
        run_id: str | None = None,
        batch_api: Any = None,
        core_api: Any = None,
    ) -> None: ...
    # implements ServerlessExecutor protocol (launch_replicas/poll/stream_logs/cancel/collect)
```
Run N DiLoCo replicas as a **single Kubernetes Indexed Job** on EKS (`completions == parallelism == n_replicas`, `completionMode="Indexed"`). The control plane assigns each pod a `JOB_COMPLETION_INDEX` 0..N-1 which IS the rank; the executor bridges it to the repo entrypoint's `REPLICA_RANK` env var via the downward API, so `replica_entrypoint` works unchanged. `launch_replicas` creates ONE Job but still returns N `ReplicaHandle`s (`handles[i].rank == i`) sharing the same `job_name`/`namespace` (gang semantics).
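Writing the Indexed-Job shape out as a plain manifest dict makes it easy to review what the executor should submit. The sketch assumes the standard Kubernetes `batch/v1` Job schema. `build_indexed_job` is a hypothetical helper, and it omits GPU limits, node selectors, tolerations, and the service account. The `REPLICA_RANK` bridge uses the downward API on the `batch.kubernetes.io/job-completion-index` annotation, which Kubernetes sets on Indexed Job pods.

```python
from typing import Any


def build_indexed_job(name: str, image: str, n_replicas: int,
                      entrypoint_args: dict[str, Any], *,
                      timeout: int = 3600, backoff_limit: int = 0,
                      ttl_seconds_after_finished: int = 3600) -> dict[str, Any]:
    if n_replicas < 1:
        raise ValueError("n_replicas must be >= 1")
    env: list[dict[str, Any]] = [{
        # Rank = completion index, exposed through the downward API.
        "name": "REPLICA_RANK",
        "valueFrom": {"fieldRef": {
            "fieldPath": "metadata.annotations['batch.kubernetes.io/job-completion-index']"}},
    }]
    # Scalar entrypoint_args become upper-cased literal env vars.
    env += [{"name": k.upper(), "value": str(v)}
            for k, v in entrypoint_args.items() if isinstance(v, (str, int, float, bool))]
    return {
        "apiVersion": "batch/v1",
        "kind": "Job",
        "metadata": {"name": name},
        "spec": {
            "completionMode": "Indexed",
            "completions": n_replicas,
            "parallelism": n_replicas,       # gang: all ranks run at once
            "backoffLimit": backoff_limit,   # fail fast
            "activeDeadlineSeconds": timeout,
            "ttlSecondsAfterFinished": ttl_seconds_after_finished,
            "template": {"spec": {
                "restartPolicy": "Never",
                "containers": [{
                    "name": "replica",
                    "image": image,
                    "command": ["python", "-m",
                                "composer_replication.diloco.serverless.replica_entrypoint"],
                    "env": env,
                }],
            }},
        },
    }
```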
**Constructor parameters**
| Name | Type | Default | Meaning |
|---|---|---|---|
| `image` | `str` | — | Container image with `composer_replication` installed; runs the replica entrypoint. |
| `namespace` | `str` (kw-only) | `"default"` | k8s namespace for the Job. |
| `service_account_name` | `str \| None` (kw-only) | `None` | ServiceAccount referenced on the PodSpec for IRSA / EKS Pod Identity S3 access. REFERENCED only — never created. |
| `node_selector` | `dict[str,str] \| None` (kw-only) | `None` | Extra node selector merged under the GPU node selector. |
| `tolerations` | `list[Any] \| None` (kw-only) | `None` | PodSpec tolerations; a `nvidia.com/gpu` NoSchedule toleration is auto-added on GPU requests when none supplied. |
| `runtime_class_name` | `str \| None` (kw-only) | `None` | Pre-existing RuntimeClass (`"gvisor"` / `"kata"`). Combining with `gpu` is advanced/unverified (see source warning). |
| `command` | `list[str] \| None` (kw-only) | `None` ⇒ `["python","-m","composer_replication.diloco.serverless.replica_entrypoint"]` | Container command. |
| `cpu_request` / `memory_request` | `str` (kw-only) | `"4"` / `"16Gi"` | PodSpec resource requests. |
| `ttl_seconds_after_finished` | `int` (kw-only) | `3600` | Auto-delete the finished Job (cascading) after this many seconds. |
| `backoff_limit` | `int` (kw-only) | `0` | Job retry budget (fail-fast; NOT the k8s default of 6). |
| `gpu_resource_key` | `str` (kw-only) | `"nvidia.com/gpu"` | GPU resource key. |
| `run_id` | `str \| None` (kw-only) | `None` ⇒ `"diloco"` | Prefix baked into the generated Job name. |
| `batch_api` / `core_api` | `Any` (kw-only) | `None` | DI'd `BatchV1Api` / `CoreV1Api`; lazily built (in-cluster then kube-config) when `None`. Tests inject mocks. |
**`ServerlessExecutor` Protocol methods** (see §12 for the full Protocol)
- `launch_replicas(n_replicas, entrypoint, entrypoint_args, *, gpu=None, timeout=3600) -> list[ReplicaHandle]` — creates ONE Indexed Job; `entrypoint` is ignored (k8s runs the fixed container command), scalar `entrypoint_args` (e.g. `rendezvous_uri`) are passed as upper-cased literal env vars, `gpu` maps to a `nvidia.com/gpu` limit + node selector via the `_GPU_SPEC_TABLE` (`"A100"`/`"H100"`/`"A10G"`/`"T4"`), and `timeout` becomes the Job's `active_deadline_seconds` hard wall-clock kill.
- `poll(handle) -> str` — reads the single Job status and maps THIS rank: `rank in completed_indexes` ⇒ `"succeeded"`, `rank in failed_indexes` ⇒ `"failed"`, whole-Job `Failed` condition ⇒ `"failed"`, `active > 0` ⇒ `"running"`, else `"pending"`; a 404 (Job gone) ⇒ `"cancelled"`.
- `stream_logs(handle, *, n_lines=200) -> str` — finds the pod by completion-index annotation/label (or the `<job>-<rank>-` name prefix), tails its log; returns a placeholder string (never raises) when the pod has not started.
- `cancel(handle) -> None` — deletes the WHOLE shared Indexed Job (`propagation_policy="Background"` so pods are cascadingly deleted, not orphaned). Intentional gang teardown; idempotent (404 swallowed, unknown handle is a no-op).
- `collect(handles, *, timeout=None) -> list[dict]` — polls (sleeping between, status is eventually consistent) until every rank is terminal or the deadline; per-rank dict is `{"rank","status","exit_code","error","job_name"}` (`exit_code` 0/1/`None`).
**Raises** `RuntimeError` at construction if the `kubernetes` client is absent AND no api was injected; `ValueError` if `n_replicas < 1`.
```python
from composer_replication.diloco.serverless.eks import EKSExecutor
ex = EKSExecutor(image="ghcr.io/me/diloco:latest",
service_account_name="diloco-s3")
handles = ex.launch_replicas(
n_replicas=4,
entrypoint="", # ignored on k8s
entrypoint_args={"rendezvous_uri": "s3://bucket/run/",
"trainer_module": "my.trainer", "world_size": 4},
gpu="H100",
)
results = ex.collect(handles, timeout=3600)
```
### `class SageMakerExecutor` — `serverless.sagemaker`
```python
class SageMakerExecutor:
    backend_name = "sagemaker"
    supports_inter_replica_network = False  # S3 rendezvous only

    def __init__(
        self,
        *,
        role_arn: str,
        image_uri: str,
        output_s3_path: str,
        instance_type: str = "ml.g5.2xlarge",
        cpu_instance_type: str = "ml.m5.xlarge",
        volume_size_gb: int = 100,
        run_id: str | None = None,
        region: str | None = None,
        sagemaker_client: Any = None,
        logs_client: Any = None,
    ) -> None: ...
    # implements ServerlessExecutor protocol
```
Run replicas as **N independent single-instance SageMaker Training Jobs**, NOT one multi-instance job. A multi-instance job would couple replicas through SageMaker's intra-job NCCL/MPI fabric and break the "each replica syncs only through S3" design. Rank is passed through the per-job `Environment` map (`REPLICA_RANK=i` / `WORLD_SIZE=N`), so `replica_entrypoint` reads it unchanged. `EnableNetworkIsolation` is pinned to `False` and is never exposed as a knob, because `True` would sever the container's S3 access and deadlock the allreduce poll loop.
**Constructor parameters** (all kw-only)
| Name | Type | Default | Meaning |
|---|---|---|---|
| `role_arn` | `str` | — | IAM execution role SageMaker assumes; must grant S3 to the rendezvous + output buckets (the boto3 analog of EKS IRSA). Caller needs `iam:PassRole`. |
| `image_uri` | `str` | — | ECR training-container image. The executor also passes `ContainerEntrypoint` explicitly so a generic image works. |
| `output_s3_path` | `str` | — | `s3://...` prefix for `OutputDataConfig.S3OutputPath`. |
| `instance_type` | `str` | `"ml.g5.2xlarge"` | Default instance type when `gpu` is unmapped. |
| `cpu_instance_type` | `str` | `"ml.m5.xlarge"` | Instance type used when `gpu=None` (CPU smoke). |
| `volume_size_gb` | `int` | `100` | `ResourceConfig.VolumeSizeInGB` per job. |
| `run_id` | `str \| None` | `None` ⇒ random token | Prefix for generated training-job names. |
| `region` | `str \| None` | `None` | AWS region for the lazy boto3 clients (`None` ⇒ ambient default). |
| `sagemaker_client` | `Any` | `None` | Inject a pre-built `boto3.client("sagemaker")` (or a mock); built lazily otherwise. |
| `logs_client` | `Any` | `None` | Inject a pre-built `boto3.client("logs")`; built lazily on first `stream_logs`. |
**`ServerlessExecutor` Protocol methods**
- `launch_replicas(n_replicas, entrypoint, entrypoint_args, *, gpu=None, timeout=3600) -> list[ReplicaHandle]` — submits N independent jobs. `entrypoint` is ignored (the container command is baked into the image or passed explicitly). `entrypoint_args` MUST include `rendezvous_uri` (`s3://...`) and `trainer_module`; optional keys are `trainer_fn` (default `"train"`) and `trainer_kwargs` (a dict, JSON-encoded into `ContainerArguments`). `gpu` maps to an instance type via `_GPU_INSTANCE_MAP` (`"A100"`/`"H100"`/`"H200"`/`"B200"`/`"L40S"`/`"A10G"`/`"L4"`; a literal `ml.*` string is used as-is); `timeout` → `StoppingCondition.MaxRuntimeInSeconds`. On a mid-launch failure it stops the already-launched sibling jobs (best effort) and then raises.
- `poll(handle) -> str` — maps `describe_training_job`'s `TrainingJobStatus` via `_STATUS_MAP` (`InProgress` → `running`, `Completed` → `succeeded`, `Failed` → `failed`, `Stopping` → `running`, `Stopped` → `cancelled`); refines `InProgress` → `"pending"` while the job is still queued (`SecondaryStatus` in `_PENDING_SECONDARY`); a vanished job (`ResourceNotFound`) ⇒ `"cancelled"`.
- `stream_logs(handle, *, n_lines=200) -> str` — discovers the CloudWatch stream under `/aws/sagemaker/TrainingJobs` by `<job-name>/` prefix and tails it; falls back to a CloudWatch console URL on any error.
- `cancel(handle) -> None` — best-effort `stop_training_job` (swallows `ResourceNotFound` / already-terminal `ValidationException`).
- `collect(handles, *, timeout=None) -> list[dict]` — polls each handle until it reaches a terminal status (`Completed`/`Failed`/`Stopped`) or the shared deadline passes; results are aligned to input order. Each result dict has the keys `rank`, `status`, `exit_code`, `error`, `result`, `training_job_name` (`result` is the `S3ModelArtifacts` path).
**Raises** `RuntimeError` if boto3 is absent and no client was injected; `ValueError` if `n_replicas < 1` or `entrypoint_args` lacks `rendezvous_uri` / `trainer_module`.
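The `poll` mapping can be sketched as a pure function over a `describe_training_job` response. This is an illustration, not the executor's code: the helper name is hypothetical, `None` stands in for the `ResourceNotFound` case, and the members of `_PENDING_SECONDARY` are guesses at SageMaker's queue-phase `SecondaryStatus` values.

```python
from __future__ import annotations

_STATUS_MAP = {
    "InProgress": "running",
    "Completed": "succeeded",
    "Failed": "failed",
    "Stopping": "running",
    "Stopped": "cancelled",
}
# Hypothetical queue-phase SecondaryStatus values; the real _PENDING_SECONDARY may differ.
_PENDING_SECONDARY = frozenset({"Pending", "Starting", "LaunchingMLInstances", "PreparingTrainingStack"})


def map_training_job_status(describe: dict | None) -> str:
    """Translate a describe_training_job response into a handle status string.

    `describe=None` stands in for a job that vanished (ResourceNotFound).
    """
    if describe is None:
        return "cancelled"
    raw = describe.get("TrainingJobStatus", "")
    if raw == "InProgress" and describe.get("SecondaryStatus") in _PENDING_SECONDARY:
        return "pending"  # accepted but still queued / provisioning
    # SageMaker reports only the five statuses above; the fallback is defensive.
    return _STATUS_MAP.get(raw, "running")
```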
```python
from composer_replication.diloco.serverless.sagemaker import SageMakerExecutor
ex = SageMakerExecutor(
role_arn="arn:aws:iam::123:role/sm-exec",
image_uri="123.dkr.ecr.us-east-1.amazonaws.com/diloco:latest",
output_s3_path="s3://bucket/out/", region="us-east-1")
handles = ex.launch_replicas(
n_replicas=2, entrypoint="",
entrypoint_args={"rendezvous_uri": "s3://bucket/run/",
"trainer_module": "my.trainer"},
gpu="A100")
results = ex.collect(handles)
```
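A sketch of the shared-deadline loop that `collect` describes, with an injected `poll` callable so it runs without AWS. The name `collect_statuses`, the `interval_s` parameter, and the `"timeout"` placeholder for unfinished handles are assumptions; the real method returns the richer per-handle dicts listed above.

```python
from __future__ import annotations

import time
from typing import Callable, Sequence

TERMINAL = frozenset({"succeeded", "failed", "cancelled"})


def collect_statuses(
    handles: Sequence[object],
    poll: Callable[[object], str],
    *,
    timeout: float | None = None,
    interval_s: float = 10.0,
) -> list[str]:
    """Poll every handle until terminal or until one shared deadline expires."""
    deadline = None if timeout is None else time.monotonic() + timeout
    final: dict[int, str] = {}  # keyed by input position, so order is preserved
    while True:
        for i, handle in enumerate(handles):
            if i not in final:
                status = poll(handle)
                if status in TERMINAL:
                    final[i] = status
        if len(final) == len(handles):
            break
        if deadline is not None and time.monotonic() >= deadline:
            break
        time.sleep(interval_s)
    # Handles still running at the deadline are reported as "timeout" (assumption).
    return [final.get(i, "timeout") for i in range(len(handles))]
```

Using one deadline for all handles bounds total wall-clock time by `timeout` rather than `timeout × n_replicas`.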
---
## 16. `composer_replication.datagen.docker_sandbox`
ADR-010 §3 hardened container backend for the FeatureDeletion (FD) env. `DockerSandbox` is a drop-in implementation of the `Sandbox` Protocol (`composer_replication.datagen.sandbox.Sandbox`) that runs the agent's tool calls and the verifiable test command inside an ephemeral, locked-down Docker container instead of a raw host subprocess. It is the production execution path for genuinely UNTRUSTED model-generated code (the `LocalSubprocessSandbox` sibling runs in the host process and is NOT a host-security boundary).
**Extra**: needs `docker>=7` (`pip install -e .[datagen]` or `pip install docker`). The SDK is **lazy-imported inside methods** so the pure-Python core and the `FakeSandbox` path never require it; a clear `RuntimeError` is raised if the SDK or the daemon is absent.
### The `Sandbox` Protocol — `datagen.sandbox`
```python
@runtime_checkable
class Sandbox(Protocol):
    def boot(self, image: str) -> None: ...
    def exec(self, action: dict) -> str: ...
    def run_tests(self, test_command: str, tests: tuple[str, ...]) -> TestRunResult: ...
    def trajectory(self) -> list[dict]: ...
    def is_command_allowed(self, command: str) -> bool: ...
```
Structural protocol every FD execution backend implements (`DockerSandbox`, `LocalSubprocessSandbox`, `FakeSandbox`). `boot` prepares the execution environment, `exec` runs one tool-call `action` dict (`{"command": ...}`) and returns combined stdout/stderr, `run_tests` runs the verifiable pytest command over the given node ids and returns a `TestRunResult`, `trajectory` returns the recorded action list, `is_command_allowed` is the first-token denylist check.
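Because `Sandbox` is a `runtime_checkable` `Protocol`, any class with these five methods passes an `isinstance` check without inheriting from it (the check only verifies that the methods exist, not their signatures). The toy `EchoSandbox` below is hypothetical and executes nothing; `TestRunResult` is redefined locally as a stand-in, since this reference names only its `collected_ok` field.

```python
from __future__ import annotations

from dataclasses import dataclass
from typing import Protocol, runtime_checkable


@dataclass
class TestRunResult:
    """Hypothetical stand-in; the real TestRunResult fields are not listed here."""
    passed: bool
    collected_ok: bool
    output: str = ""


@runtime_checkable
class Sandbox(Protocol):
    def boot(self, image: str) -> None: ...
    def exec(self, action: dict) -> str: ...
    def run_tests(self, test_command: str, tests: tuple[str, ...]) -> TestRunResult: ...
    def trajectory(self) -> list[dict]: ...
    def is_command_allowed(self, command: str) -> bool: ...


class EchoSandbox:
    """Toy backend: records actions and echoes commands instead of running them."""

    def __init__(self, denylist: frozenset[str] = frozenset({"curl", "wget"})) -> None:
        self._actions: list[dict] = []
        self._denylist = denylist
        self._image: str | None = None

    def boot(self, image: str) -> None:
        self._image = image

    def exec(self, action: dict) -> str:
        self._actions.append(dict(action))  # record first, as the real backends do
        command = action.get("command", "")
        if not self.is_command_allowed(command):
            return f"ERROR: command not allowed: {command}"
        return f"[{self._image}] {command}"

    def run_tests(self, test_command: str, tests: tuple[str, ...]) -> TestRunResult:
        # Toy: reports success without running anything.
        return TestRunResult(passed=True, collected_ok=True, output=f"{test_command} {' '.join(tests)}")

    def trajectory(self) -> list[dict]:
        return list(self._actions)  # a copy, so callers cannot mutate the record

    def is_command_allowed(self, command: str) -> bool:
        tokens = command.split()
        return bool(tokens) and tokens[0] not in self._denylist
```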
### `class DockerSandbox`
```python
@dataclass
class DockerSandbox:
    workdir: str
    runtime: str | None = None
    mem_limit: str = "1g"
    memswap_limit: str = "1g"
    pids_limit: int = 256
    nano_cpus: int = 2_000_000_000  # 2 CPUs
    user: str = "1000:1000"
    container_workdir: str = "/work"
    tmpfs_size: str = "64m"
    exec_timeout_s: int = 600
    keep_root_writable: bool = False
    def container_kwargs(self, image: str) -> dict: ...
    # Sandbox Protocol:
    def boot(self, image: str) -> None: ...
    def exec(self, action: dict) -> str: ...
    def run_tests(self, test_command: str, tests: tuple[str, ...]) -> TestRunResult: ...
    def trajectory(self) -> list[dict]: ...
    def is_command_allowed(self, command: str) -> bool: ...
    def close(self) -> None: ...
    @staticmethod
    def reap_leaked(client=None) -> int: ...
```
Hardened ephemeral-container `Sandbox`. The PRIMARY reward-hack control is the host-side `scrub_tree(workdir)` that `boot()` runs BEFORE the container starts (the bind mount is shared, so scrubbing on the host before boot is equivalent to scrubbing inside the container); the command denylist is cheap defense-in-depth, not the wall. The container itself follows a lockdown recipe (CIS Docker 5.x + gVisor guidance):
- `network_disabled=True` + `network_mode="none"`: no egress (the reward-hack exfiltration control).
- `read_only=True` root fs plus a small `tmpfs` for `/tmp`; the workdir is bind-mounted RW at `/work`.
- `cap_drop=["ALL"]` + `security_opt=["no-new-privileges:true"]`, running as a non-root `user`.
- `pids_limit` (fork-bomb guard), `mem_limit == memswap_limit` (OOM guard with no swap), and `nano_cpus` (CPU quota).
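The lockdown recipe maps closely onto keyword arguments of docker-py's `client.containers.run(...)`. The helper below approximates what `container_kwargs(image)` might return; it is not its actual output. The helper name, the `sleep infinity` keep-alive command, and the exact `tmpfs` option string are assumptions.

```python
from __future__ import annotations

from typing import Any


def hardened_run_kwargs(
    image: str,
    workdir: str,
    *,
    runtime: str | None = None,
    mem_limit: str = "1g",
    pids_limit: int = 256,
    nano_cpus: int = 2_000_000_000,
    user: str = "1000:1000",
    container_workdir: str = "/work",
    tmpfs_size: str = "64m",
    keep_root_writable: bool = False,
) -> dict[str, Any]:
    """Approximate docker-py containers.run(**kwargs) for the lockdown recipe (hypothetical helper)."""
    kwargs: dict[str, Any] = {
        "image": image,
        "command": ["sleep", "infinity"],  # assumption: keep-alive target for exec_run
        "detach": True,
        "network_disabled": True,  # no egress
        "network_mode": "none",
        "read_only": not keep_root_writable,  # read-only root fs unless the escape hatch is set
        "tmpfs": {"/tmp": f"size={tmpfs_size}"},  # small writable scratch
        "volumes": {workdir: {"bind": container_workdir, "mode": "rw"}},
        "working_dir": container_workdir,
        "cap_drop": ["ALL"],
        "security_opt": ["no-new-privileges:true"],
        "user": user,  # non-root uid:gid
        "pids_limit": pids_limit,  # fork-bomb guard
        "mem_limit": mem_limit,
        "memswap_limit": mem_limit,  # equal to mem_limit, so no swap
        "nano_cpus": nano_cpus,  # CPU quota in 1e-9 CPU units
        "labels": {"composer_replication": "datagen"},  # lets a reaper find leaks
    }
    if runtime is not None:  # e.g. "runsc" for gVisor; None keeps the daemon default
        kwargs["runtime"] = runtime
    return kwargs
```

With `client = docker.from_env()`, the dict would be used as `client.containers.run(**hardened_run_kwargs(...))`.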
**Constructor fields**
| Field | Type | Default | Meaning |
|---|---|---|---|
| `workdir` | `str` | — | Host path to the materialized repo; bind-mounted RW at `/work` and scrubbed on the host before boot (PRIMARY control). Must exist by `boot()` time. |
| `runtime` | `str \| None` | `None` ⇒ daemon default (runc) | `"runsc"` (gVisor) for untrusted model code; requires host `sudo runsc install` + dockerd restart. Gate with `runsc_available()`. |
| `mem_limit` / `memswap_limit` | `str` | `"1g"` / `"1g"` | OOM guard; equal values forbid swap. |
| `pids_limit` | `int` | `256` | Fork-bomb guard. |
| `nano_cpus` | `int` | `2_000_000_000` | CPU quota in units of 1e-9 CPUs (the default is 2 CPUs). |
| `user` | `str` | `"1000:1000"` | Non-root uid:gid the agent code runs as. |
| `container_workdir` | `str` | `"/work"` | Bind-mount target + working dir. |
| `tmpfs_size` | `str` | `"64m"` | Size of the `/tmp` tmpfs scratch. |
| `exec_timeout_s` | `int` | `600` | Wall-clock cap injected via coreutils `timeout` (exec_run has no timeout param — docker-py #2651). |
| `keep_root_writable` | `bool` | `False` | Escape hatch if read-only fs breaks tooling. |
**`Sandbox` Protocol methods**
- `boot(image) -> None` — scrubs the host workdir (primary control), reaps leaked siblings, then starts the hardened container. **Raises** `RuntimeError` if the image is not found locally (the container is `--network none`, so it cannot pull) or on a Docker API error.
- `exec(action) -> str` — records the `action`, denylist-checks its first token, runs the command in the live container (combined stdout/stderr, non-UTF-8 bytes decoded with `errors="replace"`). Returns an `ERROR:` string for a denied command.
- `run_tests(test_command, tests) -> TestRunResult` — `shlex.quote`s each node id, runs the verifiable command, parses pass/fail conservatively (PASSED-or-rc0 ⇒ passed; `errors during collection` ⇒ `collected_ok=False`).
- `trajectory() -> list[dict]` — copy of the recorded action list.
- `is_command_allowed(command) -> bool` — first-token denylist (`SANDBOX_DENYLIST`); not a boundary on its own.
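Three of the mechanics above (the first-token denylist, the coreutils `timeout` wrapper that stands in for the missing `exec_run` timeout, and `shlex.quote`-ing node ids) can be sketched in plain Python. `EXAMPLE_DENYLIST`, the basename comparison, and the `sh -c` wrapping are assumptions, not the real `SANDBOX_DENYLIST` or `exec` implementation.

```python
from __future__ import annotations

import os
import shlex

# Hypothetical denylist; the real SANDBOX_DENYLIST contents are not listed in this reference.
EXAMPLE_DENYLIST = frozenset({"curl", "wget", "nc", "ssh"})


def first_token_allowed(command: str, denylist: frozenset[str] = EXAMPLE_DENYLIST) -> bool:
    """Cheap first-token check; defense-in-depth only, not a security boundary."""
    try:
        tokens = shlex.split(command)
    except ValueError:  # e.g. unbalanced quotes: reject (assumption)
        return False
    # Compare the basename so "/usr/bin/curl" is caught as well (assumption).
    return bool(tokens) and os.path.basename(tokens[0]) not in denylist


def wrap_with_timeout(command: str, exec_timeout_s: int = 600) -> list[str]:
    """exec_run has no timeout parameter, so cap wall-clock time with coreutils `timeout`."""
    return ["timeout", str(exec_timeout_s), "sh", "-c", command]


def build_test_command(test_command: str, tests: tuple[str, ...]) -> str:
    """Append shell-quoted pytest node ids to the verifiable test command."""
    return " ".join([test_command, *(shlex.quote(t) for t in tests)])
```

Quoting matters because parametrized node ids such as `test_a[1 2]` contain spaces and brackets that the shell would otherwise split or glob.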
**Lifecycle**: `close()` force-removes the container (idempotent, swallows errors); `reap_leaked(client=None) -> int` (staticmethod) sweeps containers labelled `composer_replication=datagen` and returns the count removed. Also a context manager (`__enter__` / `__exit__` → `close()`).
**Module helpers**: `runsc_available(client=None) -> bool` (True iff the gVisor `runsc` runtime is registered with the daemon — gate any runsc behavior on this); `_require_docker()` / `_make_client()` (lazy SDK import + daemon ping, each raising a clear `RuntimeError`).
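Both helpers can be expressed against a docker-py client object. This sketch assumes `client.info()` exposes a `Runtimes` mapping and uses `client.containers.list(all=True, filters={"label": ...})` to find leaked sandboxes; the function names are hypothetical, and the shipped `runsc_available` / `reap_leaked` may differ in detail (for example, in which errors they swallow).

```python
from __future__ import annotations

from typing import Any

DATAGEN_LABEL = "composer_replication=datagen"


def runsc_available_sketch(client: Any) -> bool:
    """True iff the daemon lists a runtime named "runsc" (gVisor)."""
    try:
        runtimes = client.info().get("Runtimes") or {}
    except Exception:  # daemon unreachable: treat as unavailable
        return False
    return "runsc" in runtimes


def reap_leaked_sketch(client: Any) -> int:
    """Force-remove containers carrying the datagen label; return how many were removed."""
    removed = 0
    for container in client.containers.list(all=True, filters={"label": DATAGEN_LABEL}):
        try:
            container.remove(force=True)
            removed += 1
        except Exception:  # best-effort: it may already be gone
            pass
    return removed
```

With a real daemon, `client` would be `docker.from_env()`; small fake objects are enough to exercise the logic offline.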
```python
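from composer_replication.datagen.docker_sandbox import DockerSandbox, runsc_available
# Use gVisor only when the daemon has it registered; otherwise fall back to runc.
runtime = "runsc" if runsc_available() else None
# workdir must exist on the host, and the image must already be present locally
# (boot() raises RuntimeError otherwise).
with DockerSandbox(workdir="/tmp/repo", runtime=runtime) as sb:
    sb.boot("python:3.12-slim")
    out = sb.exec({"command": "python -c 'print(1+1)'"})
    res = sb.run_tests("python -m pytest -q", ("tests/test_x.py::test_a",))
# Leaving the `with` block calls close(), which force-removes the container.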
```
---
## 17. `composer_replication.safety` & `composer_replication.safety.kill_switch`
ADR-015 run-level collapse safeguard — the #2 collapse safety net for the self-evolving RL flywheel. The per-task controls live in `composer_replication.datagen` (4-gate validator, `HackMonitor` provenance, sandbox denylist); this package adds the missing ACROSS-GENERATION / run-level control that watches in-loop (proxy) reward against a disjoint held-out (real) eval and HALTS the run when collapse / reward-hacking is caught in the act. Pure-Python: no torch, no cloud deps; fully CPU-testable (`kl_to_init` is just a float the caller computes upstream).
Public surface (`composer_replication.safety.__all__`): `HeldOutGuard`, `TripwireStatus`, `CollapseStopError`, `kl_token_trust_filter`.
### `class TripwireStatus` — `safety.kill_switch`
```python
@dataclass(frozen=True)
class TripwireStatus:
    fire: bool
    reason: str
    step: int
    proxy_real_gap: float
    in_loop_ema: float
    heldout_ema: float
    kl_ema: float | None = None
    @property
    def halt(self) -> bool: ...  # documented alias for `fire`
```
Structured verdict returned by every `HeldOutGuard.update(...)`. `fire` (alias `halt`) ⇒ the run should HALT; `reason` is the human-readable WHY (empty when not firing); `proxy_real_gap` is the RSI "Hacking Gap" (in-loop gain − held-out gain since baseline); `*_ema` are the denoised metric EMAs.
### `class CollapseStopError(RuntimeError)` — `safety.kill_switch`
```python
class CollapseStopError(RuntimeError):
    def __init__(self, status: TripwireStatus) -> None: ...
    status: TripwireStatus
```
Typed exception for exception-based trainer control flow; carries the fired `TripwireStatus` for logging. Raised (optionally) via `HeldOutGuard.raise_if_fired(...)`.
### `class HeldOutGuard` — `safety.kill_switch`
```python
@dataclass
class HeldOutGuard:
    kl_hard_stop: float = 0.08       # nats/token; top of GRPO "good" band
    max_proxy_real_gap: float = 0.10  # absolute Hacking-Gap blowout ceiling
    min_steps: int = 20               # no fire before this many updates
    decline_patience: int = 3         # consecutive held-out declines to fire (a)
    ema_alpha: float = 0.9            # EMA weight on the PRIOR (0.9 => slow)
    rise_eps: float = 1e-4            # min EMA delta to count as rising/declining
    def update(self, round_idx: int, in_loop_reward: float, heldout_score: float,
               kl_to_init: float | None = None, entropy: float | None = None,
               reward_std: float | None = None) -> TripwireStatus: ...
    def should_halt(self) -> bool: ...
    @property
    def last_status(self) -> TripwireStatus | None: ...
    def raise_if_fired(self, status: TripwireStatus | None = None) -> None: ...
    def proxy_real_gap(self) -> float: ...
    def calibrate_kl_threshold(self, baseline_kls: list[float], factor: float = 3.0) -> float: ...
```
Stateful across-generation collapse / reward-hacking kill-switch. Call `update(...)` once per checkpoint (the same cadence as `DifficultyCurriculum.update`); it folds each metric into a denoised EMA and returns a `TripwireStatus`.
**Fires (`fire=True`) when ANY of three conditions holds** (never before `min_steps` updates, which guards against early-run noise):
- **(a) collapse-caught-in-the-act** — the in-loop reward EMA is RISING while the held-out score EMA has DECLINED for `>= decline_patience` consecutive checkpoints (the canonical reward-hacking signature: proxy up, real down).
- **(b) KL hard stop** — the `kl_to_init` EMA exceeds `kl_hard_stop` (default **0.08 nats/token**, the top of the GRPO "good progression" band). Checked first (cheapest unambiguous breach).
- **(c) proxy-real gap blowout** — the Hacking Gap (proxy gain − real gain since baseline) widens beyond `max_proxy_real_gap`, catching a fast single-generation divergence even before the full decline window.
Once fired, the verdict is **latched** — every subsequent `update` keeps `fire=True` so a transient recovery cannot silently un-halt the run.
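To make the interaction of the three tripwires, `min_steps`, and the latch concrete, here is a simplified model. `MiniGuard` is hypothetical: it assumes the EMA form `alpha * prior + (1 - alpha) * x`, takes the first update as the baseline for the Hacking Gap, and counts a decline only when the in-loop EMA rises and the held-out EMA falls in the same update. The shipped `HeldOutGuard` may differ on each point.

```python
from __future__ import annotations

from dataclasses import dataclass, field


@dataclass
class MiniGuard:
    """Hypothetical, simplified model of HeldOutGuard's tripwires (not the shipped class)."""

    kl_hard_stop: float = 0.08
    max_proxy_real_gap: float = 0.10
    min_steps: int = 20
    decline_patience: int = 3
    ema_alpha: float = 0.9
    rise_eps: float = 1e-4
    _n: int = field(default=0, init=False)
    _proxy: float | None = field(default=None, init=False)
    _real: float | None = field(default=None, init=False)
    _kl: float | None = field(default=None, init=False)
    _proxy_base: float = field(default=0.0, init=False)
    _real_base: float = field(default=0.0, init=False)
    _declines: int = field(default=0, init=False)
    _reason: str = field(default="", init=False)

    def _ema(self, prev: float | None, x: float) -> float:
        # ema_alpha weights the PRIOR value, so 0.9 reacts slowly.
        return x if prev is None else self.ema_alpha * prev + (1.0 - self.ema_alpha) * x

    def gap(self) -> float:
        # Hacking Gap: in-loop gain minus held-out gain, both relative to the first update.
        if self._proxy is None or self._real is None:
            return 0.0
        return (self._proxy - self._proxy_base) - (self._real - self._real_base)

    def update(self, in_loop: float, heldout: float, kl: float | None = None) -> tuple[bool, str]:
        self._n += 1
        prev_proxy, prev_real = self._proxy, self._real
        self._proxy = self._ema(prev_proxy, in_loop)
        self._real = self._ema(prev_real, heldout)
        if prev_proxy is None:  # the first update fixes the baselines (assumption)
            self._proxy_base, self._real_base = self._proxy, self._real
        if kl is not None:
            self._kl = self._ema(self._kl, kl)
        rising = prev_proxy is not None and self._proxy - prev_proxy > self.rise_eps
        declining = prev_real is not None and prev_real - self._real > self.rise_eps
        self._declines = self._declines + 1 if (rising and declining) else 0
        if self._reason:  # latched: a transient recovery never un-halts
            return True, self._reason
        if self._n < self.min_steps:
            return False, ""
        if self._kl is not None and self._kl > self.kl_hard_stop:  # (b) checked first
            self._reason = "(b) KL hard stop"
        elif self._declines >= self.decline_patience:
            self._reason = "(a) proxy rising while held-out declines"
        elif self.gap() > self.max_proxy_real_gap:
            self._reason = "(c) proxy-real gap blowout"
        return bool(self._reason), self._reason
```

Setting `ema_alpha=0.0` (allowed, since the valid range is `[0, 1)`) turns the EMAs into raw values, which makes small scenarios easy to trace by hand.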
**Constructor thresholds**
| Field | Type | Default | Meaning |
|---|---|---|---|
| `kl_hard_stop` | `float` | `0.08` | Per-token KL(policy‖init) hard-stop ceiling, nats/token (condition (b)). Must be > 0. |
| `max_proxy_real_gap` | `float` | `0.10` | Absolute Hacking-Gap blowout ceiling (condition (c)). |
| `min_steps` | `int` | `20` | No tripwire fires before this many `update` calls. |
| `decline_patience` | `int` | `3` | Consecutive held-out declines (while in-loop rises) to fire (a). Must be >= 1. |
| `ema_alpha` | `float` | `0.9` | EMA weight on the PRIOR (0.9 ⇒ slow). Must be in `[0, 1)`. |
| `rise_eps` | `float` | `1e-4` | Min EMA delta to count as "rising" / "declining". |
**API**
- `update(round_idx, in_loop_reward, heldout_score, kl_to_init=None, entropy=None, reward_std=None) -> TripwireStatus` — fold one checkpoint's metrics and return the verdict. `round_idx` is for logging only (the internal counter `_n` drives `min_steps`). `kl_to_init` is **token-mean KL in nats/token** (this repo's `token_mean_kl` convention) — do NOT pass a sequence-summed KL (it will fire instantly). `entropy` / `reward_std` are tracked + exposed but not hard gates.
- `should_halt() -> bool` — True iff the most recent `update` fired. **Idempotent** (does not advance EMA state).
- `last_status -> TripwireStatus | None` (property) — the most recent verdict, or `None` before the first `update`.
- `raise_if_fired(status=None) -> None` — convert a fired verdict (the passed status, or `last_status`) into a `CollapseStopError`; a no-op otherwise. For exception-based loops.
- `proxy_real_gap() -> float` — the RSI Hacking Gap: in-loop EMA gain minus held-out EMA gain, each measured from its run-start baseline; `0.0` before the first `update`.
- `calibrate_kl_threshold(baseline_kls, factor=3.0) -> float` — set `kl_hard_stop` from early-run baseline KLs (`factor` × mean). SAFETY CLAMP: only ever TIGHTENS (`min(factor*mean, current)`), never loosens past the documented band. **Raises** `ValueError` on empty `baseline_kls`.
**Raises** `ValueError` at construction if `ema_alpha` is outside `[0, 1)`, `kl_hard_stop <= 0`, or `decline_patience < 1`.
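The tighten-only clamp in `calibrate_kl_threshold` reduces to one line. Here is a standalone sketch in hypothetical function form; the real method also stores the result on `kl_hard_stop`.

```python
from __future__ import annotations

from statistics import fmean


def calibrate_kl_threshold(current: float, baseline_kls: list[float], factor: float = 3.0) -> float:
    """Tighten-only calibration: min(factor * mean(baseline_kls), current)."""
    if not baseline_kls:
        raise ValueError("baseline_kls must be non-empty")
    return min(factor * fmean(baseline_kls), current)
```

For example, baseline KLs of `0.01`, `0.02`, `0.03` nats/token have mean `0.02`, so `factor=3.0` proposes `0.06`, which is below the default `0.08` and is adopted; a noisy baseline with mean `0.05` would propose `0.15`, and the clamp keeps `0.08`.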
> **HeldoutSplit discipline (design-of-record).** `heldout_score` must come from a DISJOINT held-out eval pool — REAL tasks the generator NEVER trains on. If the held-out set drifts with the train set, the proxy-real gap signal is meaningless. See ADR-015 and the `safety.holdout` design notes referenced from the module docstring.
```python
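from composer_replication.safety import HeldOutGuard
guard = HeldOutGuard(kl_hard_stop=0.08)
for rnd in range(num_generations):
    # Placeholders computed by the caller each generation: r_proxy (in-loop reward),
    # r_real (score on the disjoint held-out pool), token_mean_kl_value (nats/token).
    status = guard.update(rnd, in_loop_reward=r_proxy, heldout_score=r_real,
                          kl_to_init=token_mean_kl_value)
    if status.fire:  # or: guard.should_halt()
        guard.raise_if_fired(status)  # raises CollapseStopError carrying `status`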
```
### `kl_token_trust_filter(logratio_sq_half, threshold=0.08) -> bool` — `safety.kill_switch`
```python
def kl_token_trust_filter(logratio_sq_half: float, threshold: float = 0.08) -> bool: ...
```
Per-token KL trust-region mask mirroring torchrl's GRPO "KL-Mask". The caller computes `0.5 * (log π/π_ref)**2` (the Schulman k2 per-token KL estimator, nats); this returns `True` when the token should be MASKED OUT (its KL contribution exceeds `threshold`). Pulls no torch into the module — wire it into a loss downstream.
```python
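import math
from composer_replication.safety import kl_token_trust_filter
# Illustrative token probabilities under the policy (0.30) and the reference (0.20).
logratio = math.log(0.30) - math.log(0.20)
mask_out = kl_token_trust_filter(0.5 * logratio**2, threshold=0.08)  # True => drop this token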
```
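In a loss, the scalar rule is applied per token across a sequence. A minimal pure-Python sketch, assuming log-probs are available for the sampled tokens under both the policy and the reference model; the helper names and the strict `>` comparison at the threshold are assumptions.

```python
from __future__ import annotations

import math


def k2_per_token(logp_policy: list[float], logp_ref: list[float]) -> list[float]:
    """Schulman k2 per-token KL estimate: 0.5 * (log pi - log pi_ref) ** 2, in nats."""
    if len(logp_policy) != len(logp_ref):
        raise ValueError("policy and reference log-probs must align token-for-token")
    return [0.5 * (lp - lr) ** 2 for lp, lr in zip(logp_policy, logp_ref)]


def trust_mask(k2_values: list[float], threshold: float = 0.08) -> list[bool]:
    """True marks a token to MASK OUT, applying the scalar rule element-wise."""
    return [v > threshold for v in k2_values]


# Illustrative log-probs for two tokens (made-up values).
policy = [math.log(0.5), math.log(0.9)]
ref = [math.log(0.5), math.log(0.3)]
mask = trust_mask(k2_per_token(policy, ref))
```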
---
## Notes on test coverage
Tested contracts (referenced spike/test paths):
- `compose_loss` + `LossComponents` + `build_batch`: `composer_replication/tests/test_compose_loss_integration.py`, `spikes/006-real-hf-model-smoke/tests/`.
- `generalized_jsd_loss`: `spikes/005-integrated-trainer-skeleton/tests/test_opsd_loss.py`.
- `simpo_loss`, `taid_loss`, `taid_alpha_schedule`, `taid_blended_logits`, `entropy_aware_opd_loss`: `composer_replication/distillation/tests/test_distillation_losses.py`.
- `replay_trace`, `extract_dpo_pairs`, `DPOPair`, `TraceState`, `TeacherCallResult`, `TeacherSpec`, `DEFAULT_TEACHERS`: `spikes/005-integrated-trainer-skeleton/tests/test_teacher_replay.py`.
- `DJNormalizer`, `NormalizedDPOPair`, `replay_and_normalize_trace`: `composer_replication/replaysim/tests/test_replaysim.py`.
- `ClaudeCodeIngester`, `IngestionStats`, `SYSTEM_PROMPT`: `spikes/007-real-trace-ingestion/tests/`.
- `ComposerDataCollator`, `CollatorConfig`, `TraceTurn`, `TraceExample`: `spikes/005-integrated-trainer-skeleton/tests/test_data_collator.py`.
- `ComposerReplicationTrainer._compute_loss` (composition arithmetic): `spikes/005-integrated-trainer-skeleton/tests/test_loss_composition_smoke.py`.
- `make_diloco_outer_loop` + sign convention: `spikes/008-streaming-diloco/tests/test_diloco_smoke.py`.
- `ObjectStoreAllReduce`, `MockManager`, `LocalProcessExecutor`, `ReplicaHandle`, `ServerlessExecutor`, `replica_entrypoint.main`: `composer_replication/diloco/serverless/tests/test_serverless_local.py`, `test_serverless_diloco_integration.py`.
- `recipes.prime_rl.composer_loss.loss_fn`: `composer_replication/recipes/prime_rl/tests/test_composer_loss.py`.
- `EKSExecutor`, `SageMakerExecutor` (§15): `composer_replication/diloco/serverless/tests/` (DI'd `batch_api`/`core_api` / `sagemaker_client` mocks; no live cloud).
- `DockerSandbox` (§16): `container_kwargs` lockdown config asserted without a live daemon; daemon-gated paths in `composer_replication/datagen/tests/`.
- `HeldOutGuard`, `TripwireStatus`, `CollapseStopError`, `kl_token_trust_filter` (§17): `composer_replication/safety/tests/` (pure-Python; CPU-only).
Untested-contract symbols (⚠️) and skeletons (🟡) are flagged inline above.
---
**Document path**: `docs/API_REFERENCE.md` (repo-relative)