API Reference — composer-replication-framework
Complete reference for every public symbol in composer_replication. Source-of-truth is the .py files in composer_replication/; docstrings have been pulled verbatim where they exist and supplemented where missing.
Legend
- ⚠️ UNTESTED-CONTRACT: symbol exists and is callable, but its behaviour is not pinned by an automated test in composer_replication/**/tests/ or spikes/**/tests/.
- 🟡 SKELETON: class/method body raises NotImplementedError; ships as design-of-record per ADR-005 / ADR-006.
Module groups (in this document)
- composer_replication (top-level re-exports)
- composer_replication.loss
- composer_replication.batch
- composer_replication.opsd
- composer_replication.distillation
- composer_replication.teacher_replay
- composer_replication.replaysim
- composer_replication.ingestion (+ .claude_code)
- composer_replication.hint_generator
- composer_replication.trainer (+ .composer_trainer, .data_collator)
- composer_replication.diloco
- composer_replication.diloco.serverless (+ .executor, .allreduce, .modal, .hf_jobs, .replica_entrypoint)
- composer_replication.recipes.prime_rl.composer_loss
- composer_replication.recipes.monarch.actors
- composer_replication.diloco.serverless: cloud executors (.eks, .sagemaker)
- composer_replication.datagen.docker_sandbox
- composer_replication.safety (+ .kill_switch)
1. composer_replication — top-level package
The package re-exports the most common entry points from sub-modules. __all__ is the canonical list of public top-level names.
composer_replication.__version__: str
Package version string. Currently "0.1.0".
import composer_replication
print(composer_replication.__version__) # "0.1.0"
composer_replication._DILOCO_AVAILABLE: bool
True iff torchft is importable in the running Python environment (gates make_diloco_outer_loop). When torchft is missing, this flag is False and make_diloco_outer_loop is set to None.
from composer_replication import _DILOCO_AVAILABLE
if _DILOCO_AVAILABLE:
    from composer_replication import make_diloco_outer_loop
Re-exports
| Name | Source module |
|---|---|
| compose_loss | composer_replication.loss |
| LossComponents | composer_replication.loss |
| build_batch | composer_replication.batch |
| generalized_jsd_loss | composer_replication.opsd |
| ClaudeCodeIngester | composer_replication.ingestion.claude_code |
| IngestionStats | composer_replication.ingestion.claude_code |
| SYSTEM_PROMPT | composer_replication.ingestion.claude_code |
| DEFAULT_TEACHERS | composer_replication.teacher_replay |
| DPOPair | composer_replication.teacher_replay |
| TeacherCallResult | composer_replication.teacher_replay |
| TeacherSpec | composer_replication.teacher_replay |
| TraceState | composer_replication.teacher_replay |
| extract_dpo_pairs | composer_replication.teacher_replay |
| replay_trace | composer_replication.teacher_replay |
| ComposerReplicationTrainer | composer_replication.trainer |
| make_diloco_outer_loop | composer_replication.diloco (or None if torchft missing) |
See each source module below for full signatures.
2. composer_replication.loss
Verification-harness 3-channel loss. Free function, does not depend on trl.
class LossComponents
@dataclass
class LossComponents:
    lm_ce: torch.Tensor
    sdpo_jsd: torch.Tensor
    trace_replay_dpo: torch.Tensor
    total: torch.Tensor
    def detached(self) -> dict[str, float]: ...
Per-channel breakdown of the total loss for logging and ablation. All four fields are scalar torch.Tensors (shape=()); total = lm_ce + alpha_sdpo * sdpo_jsd + beta_replay * trace_replay_dpo.
detached() -> dict[str, float] — returns Python-float copies of all four fields with no grad. Useful for W&B logging.
from composer_replication import compose_loss, build_batch
components = compose_loss(model, build_batch(tokenizer))
print(components.detached()) # {'lm_ce': 2.34, 'sdpo_jsd': 0.12, ...}
components.total.backward()
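To make the weighting concrete, here is a minimal pure-Python sketch of the composition formula, with plain floats standing in for the scalar tensors. The helper name compose_total and the channel values are hypothetical; only the formula and the default weights are taken from this reference.

```python
def compose_total(lm_ce, sdpo_jsd, trace_replay_dpo,
                  alpha_sdpo=0.1, beta_replay=0.05):
    """Weighted sum of the three channels (hypothetical helper, not part of the package)."""
    return lm_ce + alpha_sdpo * sdpo_jsd + beta_replay * trace_replay_dpo

# Made-up channel values, for illustration only.
channels = {"lm_ce": 2.34, "sdpo_jsd": 0.12, "trace_replay_dpo": 0.70}
total = compose_total(**channels)
print(round(total, 3))  # 2.387

# A weight of 0.0 removes that channel's contribution from the total.
lm_only = compose_total(**channels, alpha_sdpo=0.0, beta_replay=0.0)
```

With both weights at 0.0 the total reduces to lm_ce, which is how a single-channel ablation would look in this arithmetic.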
compose_loss(model, inputs, *, ...) -> LossComponents
def compose_loss(
model: torch.nn.Module,
inputs: dict[str, torch.Tensor],
*,
alpha_sdpo: float = 0.1,
beta_replay: float = 0.05,
sdpo_jsd_beta: float = 0.5,
sdpo_temperature: float = 1.0,
sdpo_token_clip: float | None = None,
replay_dpo_beta: float = 0.1,
lm_ce_label_smoothing: float = 0.0,
dpo_variant: Literal["dpo", "simpo"] = "dpo",
sdpo_wrapper: Literal["none", "taid", "entropy_opd"] = "none",
taid_t: float | None = None,
simpo_beta: float = 2.0,
simpo_gamma: float = 1.0,
entropy_opd_h_max: float | None = None,
) -> LossComponents
Compute total = lm_ce + alpha_sdpo * sdpo_jsd + beta_replay * trace_replay_dpo.
Required keys in inputs
- input_ids: (B, T_s) student rollout token ids.
- response_mask: (B, T_s) 1 on assistant-response tokens, 0 elsewhere.
Optional keys (a channel auto-disables if its keys are missing OR if its weight = 0):
- SDPO: ctx_teacher_input_ids (B, T_t), sdpo_loss_mask (B, T_t).
- DPO (dpo_variant="dpo"): dpo_chosen_input_ids, dpo_chosen_response_mask, dpo_rejected_input_ids, dpo_rejected_response_mask, dpo_chosen_ref_logprobs, dpo_rejected_ref_logprobs (precomputed).
- SimPO (dpo_variant="simpo"): same DPO ids/masks; reference logprobs are silently ignored.
- TAID (sdpo_wrapper="taid"): no extra inputs keys needed; the optional sdpo_loss_mask is reused as the per-token TAID mask. Pass taid_t directly (or drive it from TAIDScheduler).
Parameters
| Name | Type | Default | Meaning |
|---|---|---|---|
| model | torch.nn.Module | (required) | HF causal-LM. Must accept input_ids= and return an object with .logits. |
| inputs | dict[str, torch.Tensor] | (required) | Batch dict (see required/optional keys above). |
| alpha_sdpo | float | 0.1 | Weight on SDPO/JSD channel. 0.0 disables. |
| beta_replay | float | 0.05 | Weight on trace-replay DPO channel. 0.0 disables. |
| sdpo_jsd_beta | float | 0.5 | β param for generalized_jsd_loss (0=fwd KL, 0.5=JSD, 1=rev KL). Unused when sdpo_wrapper="taid". |
| sdpo_temperature | float | 1.0 | Softmax temperature in SDPO. Unused when sdpo_wrapper="taid". |
| sdpo_token_clip | float \| None | None | Per-token JSD clamp. |
| replay_dpo_beta | float | 0.1 | β in standard DPO logit. |
| lm_ce_label_smoothing | float | 0.0 | F.cross_entropy(label_smoothing=). |
| dpo_variant | Literal["dpo","simpo"] | "dpo" | Channel-3 algorithm. |
| sdpo_wrapper | Literal["none","taid","entropy_opd"] | "none" | Channel-2 wrapper. |
| taid_t | float \| None | None | Current TAID interpolation coefficient in [0, 1]. Required when sdpo_wrapper="taid". Drive from TAIDScheduler or pass a fixed value. |
| simpo_beta | float | 2.0 | SimPO β (paper default). |
| simpo_gamma | float | 1.0 | SimPO target margin γ (paper default). |
| entropy_opd_h_max | float \| None | None | Max-entropy normalizer; None ⇒ log(V). |
Returns LossComponents (see above).
Raises ValueError if dpo_variant or sdpo_wrapper is unknown, if sdpo_wrapper="taid" is requested without taid_t, or if taid_t is outside [0, 1].
from composer_replication import compose_loss, build_batch
batch = build_batch(tokenizer)
out = compose_loss(model, batch, alpha_sdpo=0.1, beta_replay=0.05)
out.total.backward()
print(out.detached())
3. composer_replication.batch
Verification-harness batch builder.
build_batch(tokenizer, *, ...) -> dict[str, torch.Tensor]
def build_batch(
tokenizer: Any,
*,
device: torch.device | str = "cpu",
seed: int = 42,
variant: str = "factorial",
align_sdpo_shapes: bool = False,
) -> dict[str, torch.Tensor]
Construct a full 3-channel batch from a real HF tokenizer. The DPO ref-logprobs are dummy tensors: the smoke test verifies that the loss composition wires together, not the reference-policy precompute.
Returned keys: input_ids, response_mask, ctx_teacher_input_ids, sdpo_loss_mask, dpo_chosen_input_ids, dpo_chosen_response_mask, dpo_rejected_input_ids, dpo_rejected_response_mask, dpo_chosen_ref_logprobs, dpo_rejected_ref_logprobs.
Parameters
| Name | Type | Default | Meaning |
|---|---|---|---|
| tokenizer | HF AutoTokenizer (duck-typed) | (required) | Must support apply_chat_template and __call__. |
| device | torch.device \| str | "cpu" | Target device for all returned tensors. |
| seed | int | 42 | Fixes torch.manual_seed. |
| variant | str | "factorial" | One of "factorial", "binary_search". |
| align_sdpo_shapes | bool | False | If True, truncate/pad ctx_teacher_input_ids to input_ids length so the SDPO channel actually fires. |
Raises ValueError if variant is unknown.
from transformers import AutoTokenizer
from composer_replication import build_batch
tok = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
batch = build_batch(tok, variant="factorial", align_sdpo_shapes=True)
print({k: v.shape for k, v in batch.items()})
4. composer_replication.opsd
Self-distillation generalized-JSD loss, lifted verbatim from siyan-zhao/OPSD (MIT) per ADR-006.
generalized_jsd_loss(student_logits, teacher_logits, labels=None, beta=0.5, ...) -> torch.Tensor
def generalized_jsd_loss(
student_logits: torch.Tensor,
teacher_logits: torch.Tensor,
labels: torch.Tensor | None = None,
beta: float = 0.5,
temperature: float = 1.0,
reduction: str = "batchmean",
logits_are_probs: bool = False,
top_k: int | None = None,
token_clip: float | None = None,
) -> torch.Tensor
Generalized JSD between student and teacher distributions. In the SDPO recipe, both sets of logits come from the SAME model evaluated on different contexts, so student and teacher share parameters.
Parameters
| Name | Type | Default | Meaning |
|---|---|---|---|
| student_logits | Tensor (B, T, V) | (required) | Student logits with grad. |
| teacher_logits | Tensor (B, T, V) | (required) | Teacher logits (no grad in SDPO). |
| labels | Tensor (B, T) \| None | None | Per-token mask. -100 positions are ignored (HF convention). |
| beta | float in [0, 1] | 0.5 | 0=fwd KL, 1=rev KL, 0.5=symmetric JSD. |
| temperature | float | 1.0 | Softmax temperature. |
| reduction | str | "batchmean" | "batchmean", "sum", "mean", "none". |
| logits_are_probs | bool | False | Skip softmax if inputs are already probabilities. |
| top_k | int \| None | None | Restrict KL to teacher's top-k tokens. |
| token_clip | float \| None | None | Clip per-token JSD for stability. |
Returns scalar tensor (or (B, T) if reduction="none").
Raises ValueError for unknown reduction.
import torch
from composer_replication.opsd import generalized_jsd_loss
s = torch.randn(2, 8, 32, requires_grad=True)
t = torch.randn(2, 8, 32)
loss = generalized_jsd_loss(s, t, beta=0.5, reduction="batchmean")
loss.backward()
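The role of beta can be illustrated on a single token position with plain Python. This is a sketch under stated assumptions, not the ported OPSD code: the mixture form, the KL directions at the endpoints (forward = teacher || student, reverse = student || teacher), and the helper names are all assumptions made for illustration.

```python
import math

def softmax(logits):
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    z = sum(exps)
    return [e / z for e in exps]

def kl(p, q):
    """KL(p || q) in nats for discrete distributions with q > 0 wherever p > 0."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0.0)

def generalized_jsd(student_logits, teacher_logits, beta=0.5):
    """Single-position generalized JSD (hypothetical helper)."""
    s = softmax(student_logits)
    t = softmax(teacher_logits)
    if beta == 0.0:
        return kl(t, s)  # forward KL (assumed direction: teacher || student)
    if beta == 1.0:
        return kl(s, t)  # reverse KL (assumed direction: student || teacher)
    # Interior beta: divergence of each distribution from the beta-weighted mixture.
    mix = [beta * ti + (1.0 - beta) * si for ti, si in zip(t, s)]
    return beta * kl(t, mix) + (1.0 - beta) * kl(s, mix)

student = [2.0, 0.5, -1.0]
teacher = [0.1, 1.2, 0.3]
for b in (0.0, 0.5, 1.0):
    print(b, generalized_jsd(student, teacher, beta=b))
```

At beta=0.5 the value is symmetric in its two arguments and bounded by log(2), which is the usual Jensen-Shannon behaviour.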
5. composer_replication.distillation
Pluggable self-distillation losses (ADR-007). All pure PyTorch.
simpo_loss(chosen_avg_logprobs, rejected_avg_logprobs, *, beta=2.0, gamma=1.0) -> torch.Tensor
def simpo_loss(
chosen_avg_logprobs: torch.Tensor,
rejected_avg_logprobs: torch.Tensor,
*,
beta: float = 2.0,
gamma: float = 1.0,
) -> torch.Tensor
Reference-free DPO with target margin γ (Meng et al., NeurIPS 2024). L = -log σ(β · (avg_logπ(c) − avg_logπ(r)) − γ).
Parameters
| Name | Type | Default | Meaning |
|---|---|---|---|
| chosen_avg_logprobs | Tensor (B,) | (required) | Per-sequence avg logprob over chosen response tokens. |
| rejected_avg_logprobs | Tensor (B,) | (required) | Same for rejected. |
| beta | float | 2.0 | Scaling factor (paper default). |
| gamma | float | 1.0 | Target margin (paper default). |
Returns a scalar tensor. Raises ValueError if the two input shapes mismatch.
import torch
from composer_replication.distillation import simpo_loss
loss = simpo_loss(torch.tensor([-2.1, -1.8]), torch.tensor([-3.0, -2.5]),
beta=2.0, gamma=1.0)
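The formula above can be checked by hand with a pure-Python sketch that uses the same inputs as the example. The function name simpo_loss_py is hypothetical, and averaging over the batch is an assumption about the reduction; only the per-pair expression comes from this reference.

```python
import math

def simpo_loss_py(chosen_avg_logprobs, rejected_avg_logprobs, beta=2.0, gamma=1.0):
    """Mean over pairs of -log sigmoid(beta * (c - r) - gamma) (hypothetical helper)."""
    if len(chosen_avg_logprobs) != len(rejected_avg_logprobs):
        raise ValueError("chosen and rejected must have the same length")
    losses = []
    for c, r in zip(chosen_avg_logprobs, rejected_avg_logprobs):
        margin = beta * (c - r) - gamma
        # -log(sigmoid(x)) == log(1 + exp(-x)); this form is stable for either sign of x.
        losses.append(max(-margin, 0.0) + math.log1p(math.exp(-abs(margin))))
    return sum(losses) / len(losses)

loss = simpo_loss_py([-2.1, -1.8], [-3.0, -2.5])
print(loss)
```

The two pairs have margins of roughly 0.8 and 0.4 after subtracting gamma, so the first pair contributes the smaller loss.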
avg_sequence_logprob(model_logprobs, response_mask) -> torch.Tensor
⚠️ UNTESTED-CONTRACT (helper exported from simpo.py but not asserted by a test).
def avg_sequence_logprob(
model_logprobs: torch.Tensor,
response_mask: torch.Tensor,
) -> torch.Tensor
Convert (B, T) per-token logprobs + (B, T) response mask into (B,) per-sequence average over response tokens.
from composer_replication.distillation.simpo import avg_sequence_logprob
import torch
lp = torch.randn(2, 8); m = torch.tensor([[0,0,1,1,1,0,0,0],[0,1,1,1,1,1,0,0]])
out = avg_sequence_logprob(lp, m) # shape (2,)
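The masked average is easy to verify on nested lists. The sketch below is a plain-Python illustration with a hypothetical function name; the guard against an all-zero mask is an assumption, since this reference does not say how the helper treats that case.

```python
def avg_sequence_logprob_py(model_logprobs, response_mask):
    """Per-row mean of logprobs over positions where the mask is 1 (hypothetical helper)."""
    out = []
    for row_lp, row_mask in zip(model_logprobs, response_mask):
        n_response = sum(row_mask)
        total = sum(lp * m for lp, m in zip(row_lp, row_mask))
        # Assumed behaviour: avoid division by zero when a row has no response tokens.
        out.append(total / max(n_response, 1))
    return out

logprobs = [[-0.5, -1.0, -2.0, -3.0],
            [-1.0, -1.0, -4.0, -2.0]]
mask = [[0, 1, 1, 0],
        [0, 0, 1, 1]]
avg = avg_sequence_logprob_py(logprobs, mask)
print(avg)  # [-1.5, -3.0]
```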
taid_loss(student_logits, teacher_logits, mask=None, *, t) -> torch.Tensor
def taid_loss(
student_logits: torch.Tensor,
teacher_logits: torch.Tensor,
mask: torch.Tensor | None = None,
*,
t: float | torch.Tensor,
) -> torch.Tensor
Faithful port of SakanaAI/TAID (arXiv:2501.16937). Forward-KL distillation against a logit-space-interpolated target whose anchor is the current student detached:
p_t = softmax( (1 - t) · stop_grad(student_logits) + t · teacher_logits )
L = - mean_token Σ_v p_t(v) · log_softmax(student_logits)(v)
At t=0 the target collapses to the detached student (no teacher signal in the gradient). At t=1 it reduces to standard forward-KL distillation against the teacher.
Wave 15 breaking change. The previous signature taid_loss(student, teacher, student_init, *, schedule_step, total_steps, schedule, alpha_min, alpha_max, jsd_beta, temperature, reduction) was algorithmically wrong (probability-space mix, frozen step-0 anchor, JSD criterion). All those kwargs are removed; the schedule is now the caller's responsibility (see TAIDScheduler below for the upstream adaptive scheme).
Parameters
| Name | Type | Default | Meaning |
|---|---|---|---|
| student_logits | Tensor (B, T, V) | (required) | Current student (with grad). |
| teacher_logits | Tensor (B, T, V) | (required) | Teacher logits. |
| mask | Tensor (B, T) \| None | None | Token mask. None ⇒ all-ones. |
| t | float \| Tensor | (required) | Interpolation coefficient in [0, 1]. |
Raises ValueError for shape mismatch.
from composer_replication.distillation import taid_loss
loss = taid_loss(s_logits, t_logits, mask, t=0.4)
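The two limiting cases of t can be reproduced for a single token with plain Python. This sketch follows the p_t and L formulas given above; the function names are hypothetical, and because there is no autograd here the stop-gradient on the student term is only noted in a comment.

```python
import math

def log_softmax(logits):
    m = max(logits)
    lse = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - lse for x in logits]

def taid_token_loss(student_logits, teacher_logits, t):
    """Cross-entropy of the student against the interpolated target (hypothetical helper)."""
    # Logit-space interpolation. With autograd, the student term here would be detached.
    mixed = [(1.0 - t) * s + t * q for s, q in zip(student_logits, teacher_logits)]
    p_t = [math.exp(v) for v in log_softmax(mixed)]
    log_q = log_softmax(student_logits)
    return -sum(p * lq for p, lq in zip(p_t, log_q))

student = [2.0, 0.5, -1.0]
teacher = [0.0, 1.5, 0.2]
for t in (0.0, 0.4, 1.0):
    print(t, taid_token_loss(student, teacher, t))
```

At t=0 the printed value equals the student's own entropy, and at t=1 it equals the cross-entropy of the student against the teacher distribution.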
TAIDScheduler(num_train_steps, *, t_start=0.4, t_end=1.0, alpha=5e-4, beta=0.99, disable_adaptive=False)
Stateful schedule that mirrors upstream TAID.update_t. Monotone non-decreasing, bumped above the linear floor by an EMA on the relative loss change. Use as:
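from composer_replication.distillation import TAIDScheduler, taid_loss
num_train_steps = 10_000
sched = TAIDScheduler(num_train_steps=num_train_steps)  # paper defaults
for step in range(num_train_steps):
    # s, t, mask and optimizer come from your own training loop (recompute logits each step)
    loss = taid_loss(s, t, mask, t=sched.t)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    sched.update_t(loss.detach(), global_step=step)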
Parameters
| Name | Type | Default | Meaning |
|---|---|---|---|
| num_train_steps | int | (required) | Total planned training steps; sets the linear floor. |
| t_start | float | 0.4 | Initial t (paper default). |
| t_end | float | 1.0 | Terminal t; hard ceiling at every step. |
| alpha | float | 5e-4 | Adaptive bump magnitude. |
| beta | float | 0.99 | EMA decay on relative-loss-change momentum. |
| disable_adaptive | bool | False | If True, fall back to deterministic linear schedule. |
| device | torch.device \| str | "cpu" | Where to allocate state buffers. |
Properties / methods
- sched.t -> float: current t as a Python float (zero-arg property).
- sched.update_t(loss, global_step) -> Tensor | None: update internal state. The first finite-loss call only seeds prev_loss and returns None; subsequent calls return the (positive) delta_t added on top of the linear floor.
entropy_aware_opd_loss(student_logits, teacher_logits, *, labels=None, h_max=None, temperature=1.0, reduction="batchmean") -> torch.Tensor
def entropy_aware_opd_loss(
student_logits: torch.Tensor,
teacher_logits: torch.Tensor,
*,
labels: torch.Tensor | None = None,
h_max: float | None = None,
temperature: float = 1.0,
reduction: str = "batchmean",
) -> torch.Tensor
Per-token mixture of forward and reverse KL gated by teacher entropy: w(t) = clamp(H_teacher(t)/h_max, 0, 1). High-entropy tokens use forward KL (mode-covering), low-entropy tokens use reverse KL (mode-seeking).
Parameters
| Name | Type | Default | Meaning |
|---|---|---|---|
| student_logits | Tensor (B,T,V) | (required) | Student logits (grad). |
| teacher_logits | Tensor (B,T,V) | (required) | Teacher logits (no grad). |
| labels | Tensor (B,T) \| None | None | 0/1 mask, applied multiplicatively after the per-token mix. |
| h_max | float \| None | None ⇒ log(V) | Max-entropy normalizer. |
| temperature | float | 1.0 | Softmax temperature on both. |
| reduction | str | "batchmean" | "batchmean", "sum", "mean", "none". |
Raises ValueError on shape mismatch (student vs teacher; labels vs per-token loss) or unknown reduction.
from composer_replication.distillation import entropy_aware_opd_loss
loss = entropy_aware_opd_loss(s_logits, t_logits, temperature=1.0)
loss.backward()
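The entropy gate can be illustrated for one token in plain Python. This is a sketch under assumptions: the helper names are hypothetical, and the KL directions (forward = teacher || student, reverse = student || teacher) are assumed rather than taken from the source files.

```python
import math

def softmax(logits):
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    z = sum(exps)
    return [e / z for e in exps]

def kl(p, q):
    """KL(p || q) in nats."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0.0)

def entropy_aware_token_loss(student_logits, teacher_logits, h_max=None):
    """Returns (loss, w) for one position (hypothetical helper)."""
    s = softmax(student_logits)
    t = softmax(teacher_logits)
    teacher_h = -sum(p * math.log(p) for p in t if p > 0.0)
    if h_max is None:
        h_max = math.log(len(t))  # log(V): entropy of the uniform distribution
    w = min(max(teacher_h / h_max, 0.0), 1.0)
    forward_kl = kl(t, s)  # assumed direction: teacher || student
    reverse_kl = kl(s, t)  # assumed direction: student || teacher
    return w * forward_kl + (1.0 - w) * reverse_kl, w

student = [1.0, 0.0, -1.0]
_, w_uniform = entropy_aware_token_loss(student, [0.0, 0.0, 0.0])
_, w_peaked = entropy_aware_token_loss(student, [20.0, 0.0, 0.0])
print(w_uniform, w_peaked)
```

A uniform teacher drives w to about 1 (forward KL dominates), while a sharply peaked teacher drives w toward 0 (reverse KL dominates).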
teacher_entropy(teacher_logits) -> torch.Tensor
⚠️ UNTESTED-CONTRACT (helper exposed from entropy_aware_opd.py's __all__ but not directly asserted).
Per-token entropy in nats. Input (B,T,V), output (B,T).
from composer_replication.distillation.entropy_aware_opd import teacher_entropy
H = teacher_entropy(teacher_logits) # (B, T)
6. composer_replication.teacher_replay
N-teacher OpenRouter parallel client + DPO-pair extractor. httpx is lazy-imported inside replay_trace; the deterministic local logic is testable without it.
DEFAULT_TEACHERS: list[TeacherSpec]
Three-teacher default set: anthropic/claude-opus-4.7, openai/gpt-5, deepseek/deepseek-v4-pro with paper-baseline OpenRouter pricing.
from composer_replication.teacher_replay import DEFAULT_TEACHERS
print([t["slug"] for t in DEFAULT_TEACHERS])
class TeacherSpec(TypedDict)
class TeacherSpec(TypedDict):
    slug: str
    input_per_mtok: float
    output_per_mtok: float
OpenRouter model slug + per-million-token pricing.
spec: TeacherSpec = {"slug": "openai/gpt-5",
"input_per_mtok": 1.25, "output_per_mtok": 10.0}
class TraceState(TypedDict)
class TraceState(TypedDict):
    state_id: str        # unique within the trace
    messages: list[dict] # OpenAI-style chat history up to (and incl.) this user prompt
    student_action: str  # what the student actually did at this step
One step of a frozen agentic trace. student_action is the raw text emitted by the student; teachers are queried with messages and asked to predict the assistant's next action.
state: TraceState = {"state_id": "ex001::0042",
"messages": [{"role": "user", "content": "..."}],
"student_action": "[TOOL_USE] name=Read input={...}"}
class TeacherCallResult(TypedDict)
class TeacherCallResult(TypedDict):
    state_id: str
    teacher_slug: str
    response_text: str | None  # None on error
    latency_s: float
    prompt_tokens: int
    completion_tokens: int
    cost_usd: float
    error: str | None          # None on success
One row of N×T results from replay_trace.
r: TeacherCallResult = {"state_id": "x", "teacher_slug": "openai/gpt-5",
"response_text": "ok", "latency_s": 1.2, "prompt_tokens": 100,
"completion_tokens": 5, "cost_usd": 0.001, "error": None}
class DPOPair(TypedDict)
class DPOPair(TypedDict):
    state_id: str
    state_messages: list[dict]
    chosen: str    # teacher-consensus action
    rejected: str  # student action
    n_teachers_agreeing: int
One preference pair extracted from teacher-vs-student disagreement.
p: DPOPair = {"state_id": "x", "state_messages": [...], "chosen": "...",
"rejected": "...", "n_teachers_agreeing": 2}
async replay_trace(states, teachers=DEFAULT_TEACHERS, max_total_usd=5.0, api_key=None) -> list[TeacherCallResult]
async def replay_trace(
states: Sequence[TraceState],
teachers: Sequence[TeacherSpec] = tuple(DEFAULT_TEACHERS),
max_total_usd: float = 5.0,
api_key: str | None = None,
) -> list[TeacherCallResult]
For each state, fan out one parallel call per teacher via OpenRouter. Cumulative spend is hard-capped at max_total_usd; the replay stops after the state that crosses the cap completes.
Parameters
| Name | Type | Default | Meaning |
|---|---|---|---|
| states | Sequence[TraceState] | (required) | Frozen trace, one entry per assistant turn. |
| teachers | Sequence[TeacherSpec] | DEFAULT_TEACHERS | Models to query in parallel. |
| max_total_usd | float | 5.0 | Cumulative spend cap. |
| api_key | str \| None | None | OpenRouter key; defaults to OPENROUTER_API_KEY env or ~/.hermes/.env. |
Returns a flat list of TeacherCallResults. Its length is len(states) * len(teachers), or shorter if the budget cutoff stops the replay early.
Raises RuntimeError if OPENROUTER_API_KEY cannot be found; ImportError if httpx is missing at call time.
import asyncio
from composer_replication import replay_trace
results = asyncio.run(replay_trace(states=my_trace, max_total_usd=1.0))
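The spend cap can be illustrated offline, without any network calls. The sketch below is one reading of "stops after the offending state completes": the cap is checked before a state starts, so the state that crosses it still runs for all teachers. The cost formula (tokens times per-million-token price), the helper names, and the token counts are assumptions made for illustration.

```python
def call_cost_usd(spec, prompt_tokens, completion_tokens):
    """Cost of one call from per-million-token pricing (assumed formula)."""
    return (prompt_tokens * spec["input_per_mtok"]
            + completion_tokens * spec["output_per_mtok"]) / 1_000_000

def replay_with_budget(per_state_usage, teachers, max_total_usd):
    """per_state_usage: one (prompt_tokens, completion_tokens) tuple per state (made up)."""
    spent = 0.0
    completed = []
    for idx, (prompt_tokens, completion_tokens) in enumerate(per_state_usage):
        if spent >= max_total_usd:
            break  # cap already reached: do not start another state
        # Every teacher for this state runs to completion before the next check.
        spent += sum(call_cost_usd(t, prompt_tokens, completion_tokens) for t in teachers)
        completed.append(idx)
    return completed, spent

teachers = [{"slug": "openai/gpt-5", "input_per_mtok": 1.25, "output_per_mtok": 10.0}]
usage = [(200_000, 20_000), (400_000, 40_000), (100_000, 10_000)]
completed, spent = replay_with_budget(usage, teachers, max_total_usd=1.0)
print(completed, round(spent, 2))  # [0, 1] 1.35
```

Under this reading the final spend can exceed max_total_usd by at most the cost of one state, which is worth keeping in mind when choosing the cap.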
extract_dpo_pairs(states, teacher_actions, agreement_threshold=2) -> list[DPOPair]
def extract_dpo_pairs(
states: Sequence[TraceState],
teacher_actions: Sequence[TeacherCallResult],
agreement_threshold: int = 2,
) -> list[DPOPair]
Group teacher_actions by state_id, normalize whitespace, and emit one DPOPair per state where ≥agreement_threshold teachers agreed on an action that differs from the student's. chosen is the original (un-normalized) teacher response text.
Parameters
| Name | Type | Default | Meaning |
|---|---|---|---|
| states | Sequence[TraceState] | (required) | Same as passed to replay_trace. |
| teacher_actions | Sequence[TeacherCallResult] | (required) | Output of replay_trace. |
| agreement_threshold | int | 2 | Min teachers that must agree for a pair to fire. |
Returns list of DPOPair. At most one pair per state (the most-agreed-upon action wins).
from composer_replication import extract_dpo_pairs
pairs = extract_dpo_pairs(my_states, results, agreement_threshold=2)
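The grouping and voting logic described above can be sketched with the standard library. This is an illustration, not the package's implementation: the function names are hypothetical, and skipping failed calls and the tie-breaking order of Counter.most_common are assumptions.

```python
from collections import Counter, defaultdict

def normalize_ws(text):
    """Collapse runs of whitespace so near-identical actions compare equal."""
    return " ".join(text.split())

def extract_pairs_sketch(states, teacher_actions, agreement_threshold=2):
    by_state = defaultdict(list)
    for result in teacher_actions:
        if result["response_text"] is not None:  # assumed: failed calls cast no vote
            by_state[result["state_id"]].append(result["response_text"])
    pairs = []
    for state in states:
        responses = by_state.get(state["state_id"], [])
        if not responses:
            continue
        top_action, votes = Counter(normalize_ws(r) for r in responses).most_common(1)[0]
        if votes < agreement_threshold:
            continue
        if top_action == normalize_ws(state["student_action"]):
            continue  # teachers agree with the student: no preference signal
        # Keep the original (un-normalized) text of the first matching response.
        chosen = next(r for r in responses if normalize_ws(r) == top_action)
        pairs.append({"state_id": state["state_id"],
                      "state_messages": state["messages"],
                      "chosen": chosen,
                      "rejected": state["student_action"],
                      "n_teachers_agreeing": votes})
    return pairs

states = [{"state_id": "s0",
           "messages": [{"role": "user", "content": "fix the bug"}],
           "student_action": "run tests"}]
results = [{"state_id": "s0", "response_text": "read  the file"},
           {"state_id": "s0", "response_text": "read the file"},
           {"state_id": "s0", "response_text": "run tests"}]
pairs = extract_pairs_sketch(states, results)
print(pairs[0]["chosen"], pairs[0]["n_teachers_agreeing"])
```

Two of the three responses differ only in whitespace, so they count as one action with two votes and produce a single pair.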
save_pairs(pairs, path) -> None
⚠️ UNTESTED-CONTRACT.
def save_pairs(pairs: Sequence[DPOPair], path: str | Path) -> None
Write pairs to JSONL (one dict per line). Creates parent dirs.
from composer_replication.teacher_replay import save_pairs
save_pairs(pairs, "/tmp/dpo_pairs.jsonl")
7. composer_replication.replaysim
ADR-004 normalization layer over teacher_replay. Re-exports DPOPair, TeacherCallResult, extract_dpo_pairs, replay_trace from teacher_replay.
class NormalizedDPOPair
@dataclass
class NormalizedDPOPair:
    state_id: str
    state_messages: list[dict[str, Any]]
    chosen_messages: list[dict[str, Any]]
    rejected_messages: list[dict[str, Any]]
    n_teachers_agreeing: int
    metadata: dict[str, Any]
Post-normalization shape. chosen_messages/rejected_messages are chat-format ([{"role": "assistant", "content": ...}]). metadata carries op-graph provenance, including {"skipped": True} when the normalizer was bypassed (skip_dj=True).
from composer_replication.replaysim import NormalizedDPOPair
n = NormalizedDPOPair(state_id="x", state_messages=[],
chosen_messages=[{"role": "assistant", "content": "ok"}],
rejected_messages=[{"role": "assistant", "content": "no"}],
n_teachers_agreeing=2, metadata={})
class DJNormalizer
class DJNormalizer:
    DEFAULT_RECIPE: ClassVar[Path]  # composer_replication/recipes/replaysim/default.yaml
    def __init__(
        self,
        recipe_path: str | os.PathLike[str] | None = None,
        *,
        skip_dj: bool = False,
    ) -> None: ...
    def normalize(
        self,
        pairs: Iterable[DPOPair | dict[str, Any]],
    ) -> list[NormalizedDPOPair]: ...
data-juicer-backed normalizer. Pipeline: each DPOPair → JSONL record → data_juicer.core.DefaultExecutor.run() against the recipe → JSONL → NormalizedDPOPair.
Constructor parameters
| Name | Type | Default | Meaning |
|---|---|---|---|
| recipe_path | str \| PathLike \| None | None ⇒ default recipe | data-juicer YAML recipe path. |
| skip_dj | bool (kw-only) | False | If True: passthrough; records get metadata={"skipped": True} and no ops run. |
normalize(pairs) -> list[NormalizedDPOPair] runs the op-graph. Output may be shorter than input if filter ops drop records.
Raises RuntimeError at construction time if skip_dj=False and data_juicer is not importable. FileNotFoundError if recipe_path (default or explicit) is missing and skip_dj=False.
from composer_replication.replaysim import DJNormalizer
norm = DJNormalizer(skip_dj=True)
out = norm.normalize(my_pairs)
async replay_and_normalize_trace(*, states, teachers=None, agreement_threshold=2, max_total_usd=5.0, normalizer=None, **replay_kwargs) -> tuple[list[TeacherCallResult], list[NormalizedDPOPair]]
async def replay_and_normalize_trace(
*,
states: Any,
teachers: Any = None,
agreement_threshold: int = 2,
max_total_usd: float = 5.0,
normalizer: DJNormalizer | None = None,
**replay_kwargs: Any,
) -> tuple[list[TeacherCallResult], list[NormalizedDPOPair]]
End-to-end async: replay → extract pairs → normalize.
Parameters
| Name | Type | Default | Meaning |
|---|---|---|---|
| states | Sequence[TraceState] | (required) | Frozen trace. |
| teachers | Sequence[TeacherSpec] \| None | None ⇒ defaults | Forwarded to replay_trace. |
| agreement_threshold | int | 2 | Forwarded to extract_dpo_pairs. |
| max_total_usd | float | 5.0 | Spend cap. |
| normalizer | DJNormalizer \| None | None ⇒ DJNormalizer() | Pass DJNormalizer(skip_dj=True) to bypass. |
| **replay_kwargs | Any | (none) | Forwarded to replay_trace (e.g. api_key). |
Returns (raw_teacher_actions, normalized_pairs).
import asyncio
from composer_replication.replaysim import replay_and_normalize_trace, DJNormalizer
raw, norm = asyncio.run(replay_and_normalize_trace(
states=my_states, normalizer=DJNormalizer(skip_dj=True)))
replay_and_normalize_trace_sync(*args, **kwargs) -> tuple[list[TeacherCallResult], list[NormalizedDPOPair]]
⚠️ UNTESTED-CONTRACT (sync wrapper around the async function; tests call the async form via asyncio.run).
def replay_and_normalize_trace_sync(*args, **kwargs) -> ...
Sync convenience wrapping asyncio.run(replay_and_normalize_trace(...)).
from composer_replication.replaysim.normalize import replay_and_normalize_trace_sync
raw, norm = replay_and_normalize_trace_sync(states=my_states)
8. composer_replication.ingestion & composer_replication.ingestion.claude_code
Trace-source adapters (ADR-002). v0.1 supports Claude Code session JSONL.
SYSTEM_PROMPT: str
Default synthetic system prompt injected at messages[0] for ingested traces (most Claude Code sessions don't write one). Truncated head: "You are a senior software engineer working as a coding agent in a terminal environment...".
from composer_replication import SYSTEM_PROMPT
print(SYSTEM_PROMPT[:60])
class IngestionStats
@dataclass
class IngestionStats:
    n_records_total: int = 0
    n_records_skipped: int = 0
    n_states_emitted: int = 0
    n_assistant_turns: int = 0
    n_tool_use_blocks: int = 0
    n_text_blocks: int = 0
    skipped_subagent: int = 0
    skipped_summary: int = 0
    skipped_truncated_lines: int = 0
    version_warnings: list[str] | None = None  # initialized to [] in __post_init__
Counters populated by ClaudeCodeIngester.ingest() and exposed as ingester.last_stats.
from composer_replication import IngestionStats
s = IngestionStats(n_records_total=5)
print(s.version_warnings) # []
class ClaudeCodeIngester
class ClaudeCodeIngester:
    def __init__(
        self,
        *,
        system_prompt: str = SYSTEM_PROMPT,
        skip_sidechain: bool = True,
        strip_thinking: bool = True,
        max_history_tokens: int | None = None,
    ) -> None: ...
    def ingest(self, path: Path) -> Iterator[TraceState]: ...
Convert a Claude Code session JSONL to a stream of TraceStates — one per assistant TURN (not per tool_use block).
Constructor parameters
| Name | Type | Default | Meaning |
|---|---|---|---|
| system_prompt | str | SYSTEM_PROMPT | Synthetic system message injected at history[0]. |
| skip_sidechain | bool | True | Skip subagent files (agent-*.jsonl) and records with isSidechain=True. |
| strip_thinking | bool | True | Remove [THINKING] blocks from history handed to teachers (kept inside student_action). |
| max_history_tokens | int \| None | None | ⚠️ UNTESTED-CONTRACT: accepted but currently not used to truncate. |
ingest(path) -> Iterator[TraceState]: generator over TraceState objects. Each turn's state_id is f"{path.stem}::{idx:04d}". Side effect: replaces self.last_stats with a fresh IngestionStats and updates it as records stream.
from pathlib import Path
from composer_replication import ClaudeCodeIngester
ing = ClaudeCodeIngester()
for state in ing.ingest(Path("session.jsonl")):
    print(state["state_id"])
print(ing.last_stats.n_states_emitted)
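The state_id format is simple enough to reproduce and parse with the standard library. The two helper names below are hypothetical; only the f-string pattern comes from the description of ingest above.

```python
from pathlib import Path

def make_state_id(path: Path, idx: int) -> str:
    """Build an id from the file stem and a zero-padded turn index."""
    return f"{path.stem}::{idx:04d}"

def split_state_id(state_id: str) -> tuple[str, int]:
    """Inverse of make_state_id; splits on the last '::' in case the stem contains one."""
    stem, idx = state_id.rsplit("::", 1)
    return stem, int(idx)

state_id = make_state_id(Path("sessions/ex001.jsonl"), 42)
print(state_id)                  # ex001::0042
print(split_state_id(state_id))  # ('ex001', 42)
```

Because only path.stem is used, two session files with the same name in different directories would produce colliding ids, so callers merging traces may want to keep files uniquely named.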
9. composer_replication.hint_generator
⚠️ UNTESTED-CONTRACT (entire module — used by the data collator config but not pinned by a test).
Template-based hint registry for SDPO error-site injection.
class HintContext(TypedDict, total=False)
class HintContext(TypedDict, total=False):
    error_kind: str
    error_message: str
    available_tools: list[str]
    tool_name: str
    tool_schema: dict
    intent: str
Per-error context dict consumed by hint templates.
HINT_TEMPLATES: dict[str, Callable[[HintContext], str]]
Default registry keys: "tool_not_found", "json_decode", "type_error", "runtime_error", "repeated_failure".
dispatch(error_kind, ctx=None) -> str | None
def dispatch(error_kind: str, ctx: HintContext | None = None) -> str | None
Look up error_kind in HINT_TEMPLATES. Returns the template's hint text, or None if the kind is unknown.
from composer_replication.hint_generator import dispatch
hint = dispatch("json_decode") # "Reminder: tool arguments must be valid JSON. ..."
register(error_kind, fn) -> None
def register(error_kind: str, fn: Callable[[HintContext], str]) -> None
Add or override a custom hint template.
from composer_replication.hint_generator import register
register("my_error", lambda ctx: "Reminder: try X.")
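The registry pattern behind dispatch and register can be sketched in a few lines. This is a standalone illustration, not the module's source: the TEMPLATES dict, the template text, and passing an empty dict when ctx is None are all assumptions.

```python
from typing import Callable, Optional

# Standalone registry for illustration; the real module exposes HINT_TEMPLATES.
TEMPLATES: dict[str, Callable[[dict], str]] = {}

def register(error_kind: str, fn: Callable[[dict], str]) -> None:
    """Add a template, replacing any existing one for the same kind."""
    TEMPLATES[error_kind] = fn

def dispatch(error_kind: str, ctx: Optional[dict] = None) -> Optional[str]:
    """Return the hint text for error_kind, or None if the kind is unknown."""
    fn = TEMPLATES.get(error_kind)
    if fn is None:
        return None
    return fn(ctx or {})  # assumed: templates receive an empty context by default

register("tool_not_found",
         lambda ctx: "Available tools: " + ", ".join(ctx.get("available_tools", [])))

print(dispatch("tool_not_found", {"available_tools": ["Read", "Write"]}))
print(dispatch("unknown_kind"))  # None
```

Returning None for unknown kinds lets a caller skip hint injection without wrapping every lookup in a try/except.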
Individual template functions
⚠️ UNTESTED-CONTRACT — exported only via HINT_TEMPLATES, useful as building blocks:
- hint_tool_not_found(ctx) -> str
- hint_json_decode(ctx) -> str
- hint_type_error(ctx) -> str
- hint_runtime_error(ctx) -> str
- hint_repeated_failure(ctx) -> str
Each accepts a HintContext and returns hint text. Signatures are uniform: Callable[[HintContext], str].
from composer_replication.hint_generator import hint_tool_not_found
text = hint_tool_not_found({"available_tools": ["Read", "Write"]})
10. composer_replication.trainer & sub-modules
Production trainer (TRL GRPOTrainer subclass) plus data collator.
class ComposerReplicationTrainer
class ComposerReplicationTrainer(GRPOTrainer):
    def __init__(
        self,
        *args: Any,
        alpha_sdpo: float = 0.1,
        beta_replay: float = 0.05,
        sdpo_jsd_beta: float = 0.5,
        sdpo_temperature: float = 1.0,
        sdpo_token_clip: float | None = None,
        replay_dpo_beta: float = 0.1,
        **kwargs: Any,
    ) -> None: ...
    def _compute_loss(
        self,
        model: torch.nn.Module,
        inputs: dict[str, torch.Tensor],
    ) -> torch.Tensor: ...
trl.GRPOTrainer subclass that overrides _compute_loss(model, inputs) to compose total = grpo + α·sdpo + β·trace_replay_dpo. When trl is not installed, the parent class falls back to object so the module imports — but instantiation will fail because the parent's GRPO machinery is missing.
Constructor (kw-only beyond GRPOTrainer's own *args, **kwargs)
| Name | Type | Default | Meaning |
|---|---|---|---|
| alpha_sdpo | float | 0.1 | Channel-2 weight. |
| beta_replay | float | 0.05 | Channel-3 weight. |
| sdpo_jsd_beta | float | 0.5 | β for generalized_jsd_loss. |
| sdpo_temperature | float | 1.0 | SDPO softmax temperature. |
| sdpo_token_clip | float \| None | None | Per-token JSD clip. |
| replay_dpo_beta | float | 0.1 | DPO β. |
_compute_loss(model, inputs) -> torch.Tensor — overrides GRPOTrainer._compute_loss. Calls super()._compute_loss for channel 1, then _compute_sdpo_loss and _compute_trace_replay_loss, then composes. Logs per-channel components every args.logging_steps (default 50). Raises whatever super() raises (TRL-shaped errors).
Internal methods (publicly accessible, exercised by spike tests)
- ⚠️ UNTESTED-CONTRACT _compute_sdpo_loss(model, inputs) -> torch.Tensor: generalized-JSD between student forward and ctx_teacher_input_ids forward. Returns 0.0 (with grad) when alpha_sdpo == 0, the key is missing, or shapes mismatch. Logs a warning on shape mismatch.
- ⚠️ UNTESTED-CONTRACT _compute_trace_replay_loss(model, inputs) -> torch.Tensor: standard DPO over dpo_chosen_* and dpo_rejected_*, using precomputed dpo_chosen_ref_logprobs / dpo_rejected_ref_logprobs.
- ⚠️ UNTESTED-CONTRACT @staticmethod _sequence_logprobs(model, input_ids, response_mask) -> torch.Tensor: sum logprobs over response tokens; standard DPO accounting.
from composer_replication import ComposerReplicationTrainer
trainer = ComposerReplicationTrainer(
model=my_model, args=my_grpo_args, train_dataset=ds,
data_collator=my_collator, alpha_sdpo=0.1, beta_replay=0.05,
)
# trainer.train() # uses overridden _compute_loss
make_dr_grpo_config(**overrides) -> trl.GRPOConfig
Builds a trl.GRPOConfig configured to the Dr. GRPO recipe (Composer 2.5's
base objective per the Composer 2 tech report, arXiv:2603.24477; Dr.GRPO =
Liu et al. arXiv:2503.20783). Forces three knobs unless explicitly overridden,
with drift-guard assertions:
- `loss_type="dr_grpo"`: removes GRPO's length-standardization length bias.
- `scale_rewards="none"`: NO std-dev advantage normalization (Dr.GRPO requirement).
- `num_iterations=1`: single-epoch / strict on-policy.
Any field is overridable via kwargs (learning_rate=, output_dir=, beta=, …).
Honest KL-estimator delta (ADR-012 #1): TRL 1.5.0's GRPOTrainer._compute_loss
uses the k3 estimator exp(ref_logp−logp)−(ref_logp−logp)−1, NOT the k1
estimator −log r the Dr.GRPO/Composer report frames; the delta is small for r≈1
and TRL is not monkeypatched — the delta is documented, not hidden. Exported from
both composer_replication and composer_replication.trainer.
from composer_replication import make_dr_grpo_config
args = make_dr_grpo_config(output_dir="runs/x", learning_rate=1e-6)
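The k1/k3 difference noted above can be checked numerically. The sketch below assumes `r = exp(ref_logp - logp)`, which is the reading under which the k3 expression quoted above equals `(r - 1) - log r`; the function names `k1` and `k3` are illustrative and do not come from TRL or this package.

```python
import math

def k1(logp, ref_logp):
    """k1 estimator: -log r, with r = exp(ref_logp - logp) (assumed convention)."""
    return -(ref_logp - logp)

def k3(logp, ref_logp):
    """k3 estimator: exp(ref_logp - logp) - (ref_logp - logp) - 1."""
    d = ref_logp - logp
    return math.exp(d) - d - 1.0

# Per-token values for a few log-ratios; logp is an arbitrary placeholder.
for d in (0.0, 0.01, 0.1, 0.5):
    logp, ref_logp = -1.0, -1.0 + d
    print(f"log r = {d:4.2f}  k1 = {k1(logp, ref_logp):+.5f}  k3 = {k3(logp, ref_logp):+.5f}")
```

Per token, k3 is never negative while k1 can take either sign, and the two differ by exactly `r - 1`.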
make_po_config(objective="dr_grpo", **overrides) -> trl.GRPOConfig
Builds a trl.GRPOConfig for a named policy-optimization objective from the
PO_OBJECTIVES menu (ADR-014). All presets are PURE CONFIG over trl 1.5.0's
GRPOTrainer (verified by introspection) — no custom _compute_loss needed.
**overrides set/override any GRPOConfig field on top.
- Raises `ValueError` on an unknown objective (lists the valid menu).
- Raises `AssertionError` if a requested knob silently failed to apply (drift guard; e.g. GSPO guards `importance_sampling_level=="sequence"`).
from composer_replication import make_po_config, PO_OBJECTIVES
args = make_po_config("dapo", output_dir="runs/dapo", learning_rate=2e-6)
PO_OBJECTIVES: dict[str, dict]
The selectable base policy-optimization objectives (named presets over real trl
1.5.0 GRPOConfig knobs). Keys and what each sets:
| Objective | `loss_type` | `scale_rewards` | Distinguishing knob | Paper |
|---|---|---|---|---|
| `grpo` | `grpo` | `group` (std-norm) | IS=token | DeepSeekMath 2402.03300 |
| `dr_grpo` (default) | `dr_grpo` | `none` | length-bias removed | 2503.20783 |
| `bnpo` | `bnpo` | `batch` | batch-normalized | trl |
| `dapo` | `dapo` | `none` | `epsilon_high=0.28` (decoupled clip-higher), `mask_truncated_completions`, `beta=0` | 2503.14476 |
| `gspo` | `grpo` | `group` | `importance_sampling_level="sequence"` | Qwen 2507.18071 |
| `cispo` | `cispo` | `none` | `epsilon_high=5.0` (detached IS coef) | MiniMax-M1 2506.13585 |
Diagnostic gotcha: for any PO-objective ablation, log the distinguishing diagnostic (`clip_ratio/high_mean` for DAPO, the sequence-level ratio for GSPO). A `0` means the knob never engaged, NOT that the objectives are equal. (This is exactly the inert-knob artifact the A1 DAPO-vs-Dr.GRPO washout hit at lr=1e-6.)
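The preset menu can be pictured as a dictionary of knob sets with overrides merged on top. The sketch below mirrors the table using plain dicts; it does not construct a `trl.GRPOConfig`. `PRESETS` and `resolve_preset` are hypothetical names, and `mask_truncated_completions=True` is an assumption, since the table lists the knob without a value.

```python
# Hypothetical mirror of the preset menu, using plain dicts instead of GRPOConfig.
PRESETS = {
    "grpo": {"loss_type": "grpo", "scale_rewards": "group",
             "importance_sampling_level": "token"},
    "dr_grpo": {"loss_type": "dr_grpo", "scale_rewards": "none"},
    "bnpo": {"loss_type": "bnpo", "scale_rewards": "batch"},
    "dapo": {"loss_type": "dapo", "scale_rewards": "none", "epsilon_high": 0.28,
             "mask_truncated_completions": True,  # assumed value
             "beta": 0.0},
    "gspo": {"loss_type": "grpo", "scale_rewards": "group",
             "importance_sampling_level": "sequence"},
    "cispo": {"loss_type": "cispo", "scale_rewards": "none", "epsilon_high": 5.0},
}

def resolve_preset(objective="dr_grpo", **overrides):
    """Return the preset knobs for `objective`, with overrides applied on top."""
    if objective not in PRESETS:
        raise ValueError(f"unknown objective {objective!r}; valid: {sorted(PRESETS)}")
    merged = {**PRESETS[objective], **overrides}
    # Drift guard: every preset knob that was not explicitly overridden must survive.
    for key, value in PRESETS[objective].items():
        if key not in overrides:
            assert merged[key] == value, f"{key} failed to apply"
    return merged

knobs = resolve_preset("dapo", learning_rate=2e-6)
print(knobs)
```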
class TraceTurn(TypedDict, total=False) — trainer.data_collator
class TraceTurn(TypedDict, total=False):
    role: str  # "user" | "assistant" | "tool"
    content: str
    tool_call: dict | None
    tool_error: str | None
    error_meta: dict
One turn of an agentic trace as consumed by ComposerDataCollator.
class TraceExample(TypedDict, total=False) — trainer.data_collator
class TraceExample(TypedDict, total=False):
    trace_id: str
    turns: list[TraceTurn]
    final_reward: float
    dpo_pairs: list[dict] | None
One training example: (turns, optional dpo_pairs). dpo_pairs shape matches DPOPair.
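As an illustration of the shape these two TypedDicts describe, the sketch below re-declares them locally (so it runs without the package installed) and builds one trace. The `error_sites` helper is hypothetical, and treating a turn with a non-`None` `tool_error` as an "error site" is an assumption rather than the collator's documented rule.

```python
from typing import Optional, TypedDict

# Local re-declarations mirroring the documented fields.
class TraceTurn(TypedDict, total=False):
    role: str
    content: str
    tool_call: Optional[dict]
    tool_error: Optional[str]
    error_meta: dict

class TraceExample(TypedDict, total=False):
    trace_id: str
    turns: list[TraceTurn]
    final_reward: float
    dpo_pairs: Optional[list[dict]]

example: TraceExample = {
    "trace_id": "demo-001",
    "turns": [
        {"role": "user", "content": "Fix the failing test."},
        {"role": "assistant", "content": "Reading the file.",
         "tool_call": {"name": "read_file", "args": {"path": "src/app.py"}}},
        {"role": "tool", "content": "", "tool_error": "file_not_found",
         "error_meta": {"path": "src/app.py"}},
    ],
    "final_reward": 0.0,
    "dpo_pairs": None,
}

def error_sites(trace: TraceExample) -> list[int]:
    """Indices of turns carrying a tool error (assumed definition of an error site)."""
    return [i for i, turn in enumerate(trace.get("turns", []))
            if turn.get("tool_error") is not None]

print(error_sites(example))
```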
class TokenizerLike — trainer.data_collator
⚠️ UNTESTED-CONTRACT (duck-typed protocol; used as a type hint).
class TokenizerLike:
    pad_token_id: int
    def __call__(self, text: str | list[str], **kwargs: Any) -> dict[str, list]: ...
    def apply_chat_template(self, messages: list[dict], **kwargs: Any) -> str | list[int]: ...
Minimal protocol the collator needs. Compatible with HF AutoTokenizer.
class CollatorConfig — trainer.data_collator
@dataclass
class CollatorConfig:
    max_seq_len: int = 4096
    max_dpo_seq_len: int = 2048
    pad_token_id: int = 0
    ignore_index: int = -100
    enable_sdpo: bool = True
    hint_generator: Callable[[str, dict], str | None] | None = None
    enable_replay_dpo: bool = True
    rlvr_reward_key: str = "final_reward"
Tunables for ComposerDataCollator.
| Field | Default | Meaning |
|---|---|---|
| `max_seq_len` | `4096` | Truncation cap for student/teacher sequences. |
| `max_dpo_seq_len` | `2048` | Truncation cap for DPO chosen/rejected sequences. |
| `pad_token_id` | `0` | Padding token id. |
| `ignore_index` | `-100` | HF "ignore in loss" sentinel for SDPO mask. |
| `enable_sdpo` | `True` | Toggle channel-2 fields. |
| `hint_generator` | `Callable[[str, dict], str \| None] \| None` (`None`) | `(error_kind, error_meta) -> hint_text`. SDPO is no-op without this. |
| `enable_replay_dpo` | `True` | Toggle channel-3 fields. |
| `rlvr_reward_key` | `"final_reward"` | Key in `TraceExample` to read scalar reward. |
from composer_replication.trainer.data_collator import CollatorConfig
cfg = CollatorConfig(max_seq_len=2048, hint_generator=my_dispatch)
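A `hint_generator` only has to match `Callable[[str, dict], str | None]`. The sketch below shows one possible dispatch function; the error kinds, hint texts, and the `path` key in `error_meta` are all hypothetical, and returning `None` to mean "no hint for this error" is an assumption about how the collator treats the result.

```python
from typing import Optional

# Hypothetical error kinds and hint templates, for illustration only.
HINTS = {
    "file_not_found": "List the directory before reading the file.",
    "timeout": "Narrow the command or raise its timeout before retrying.",
}

def my_dispatch(error_kind: str, error_meta: dict) -> Optional[str]:
    """Map (error_kind, error_meta) to a hint string, or None when no hint applies."""
    template = HINTS.get(error_kind)
    if template is None:
        return None
    path = error_meta.get("path")  # hypothetical metadata key
    return f"{template} (path: {path})" if path else template

print(my_dispatch("file_not_found", {"path": "src/app.py"}))
print(my_dispatch("unknown_kind", {}))
```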
class ComposerDataCollator — trainer.data_collator
@dataclass
class ComposerDataCollator:
    tokenizer: TokenizerLike
    config: CollatorConfig = field(default_factory=CollatorConfig)
    def __call__(
        self, batch: Sequence[TraceExample]
    ) -> dict[str, torch.Tensor]: ...
Build trainer-ready batches from raw traces + optional DPO pairs.
Output dict keys (tested in spikes/005-integrated-trainer-skeleton/tests/test_data_collator.py):
- Channel 1 (always): `input_ids`, `attention_mask`, `response_mask`, `rewards`.
- Channel 2 (when `enable_sdpo=True` AND batch has at least one error site AND `hint_generator` is set): `ctx_teacher_input_ids`, `sdpo_loss_mask`.
- Channel 3 (when `enable_replay_dpo=True` AND batch has at least one `dpo_pair`): `dpo_chosen_input_ids`, `dpo_chosen_response_mask`, `dpo_rejected_input_ids`, `dpo_rejected_response_mask`. (Reference logprobs are NOT computed here; the trainer does that pass.)
from composer_replication.trainer.data_collator import (
    ComposerDataCollator, CollatorConfig)
collator = ComposerDataCollator(tokenizer=tok, config=CollatorConfig())
batch = collator([{"trace_id": "x", "turns": [...], "final_reward": 1.0}])
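The channel gating described above can be summarised as a small predicate over the config flags and batch contents. The sketch below restates those conditions in plain Python; `expected_keys` is a hypothetical helper for reasoning about which keys to expect, not part of the collator.

```python
def expected_keys(*, enable_sdpo, has_error_site, has_hint_generator,
                  enable_replay_dpo, has_dpo_pair):
    """Return the output keys implied by the documented channel conditions."""
    keys = ["input_ids", "attention_mask", "response_mask", "rewards"]  # channel 1
    if enable_sdpo and has_error_site and has_hint_generator:           # channel 2
        keys += ["ctx_teacher_input_ids", "sdpo_loss_mask"]
    if enable_replay_dpo and has_dpo_pair:                               # channel 3
        keys += ["dpo_chosen_input_ids", "dpo_chosen_response_mask",
                 "dpo_rejected_input_ids", "dpo_rejected_response_mask"]
    return keys

# Default config but no hint generator: channel 2 stays off.
keys = expected_keys(enable_sdpo=True, has_error_site=True, has_hint_generator=False,
                     enable_replay_dpo=True, has_dpo_pair=True)
print(keys)
```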
11. composer_replication.diloco
DiLoCo outer-loop wrapper around torchft.local_sgd.DiLoCo. Optional dep — when torchft is missing the package re-export composer_replication.make_diloco_outer_loop is None.
Module-level attributes
- `DiLoCo: Any`: `torchft.local_sgd.DiLoCo` if importable else `None`.
- `Manager: Any`: `torchft.manager.Manager` if importable else `None`.
- `_DummyWork: Any`: `torchft.work._DummyWork` if importable else `None`.
- `_TORCHFT_AVAILABLE: bool`: whether the imports succeeded.
from composer_replication.diloco import _TORCHFT_AVAILABLE, DiLoCo
make_diloco_outer_loop(manager, model_fragments, inner_optimizer, *, ...) -> torchft.local_sgd.DiLoCo
def make_diloco_outer_loop(
    manager: Any,
    model_fragments: list[torch.nn.Module],
    inner_optimizer: torch.optim.Optimizer,
    *,
    outer_lr: float = 0.7,
    outer_momentum: float = 0.9,
    nesterov: bool = True,
    sync_every: int = 100,
    fragment_sync_delay: int = 0,
    fragment_update_alpha: float = 0.0,
) -> Any
Construct a torchft.DiLoCo configured with framework-default hyperparams (DiLoCo paper §3.2: lr=0.7, momentum=0.9, Nesterov).
Parameters
| Name | Type | Default | Meaning |
|---|---|---|---|
| `manager` | `torchft.Manager` (or duck-typed `MockManager`) | — | Provides `allreduce`, `should_commit`, `current_step`, `start_quorum`, etc. |
| `model_fragments` | `list[torch.nn.Module]` | — | One module for vanilla DiLoCo; N modules for Streaming DiLoCo. |
| `inner_optimizer` | `torch.optim.Optimizer` | — | Inner-step optimizer (steps every batch). |
| `outer_lr` | `float` | `0.7` | Outer SGD lr. |
| `outer_momentum` | `float` | `0.9` | Outer SGD momentum. |
| `nesterov` | `bool` | `True` | Nesterov momentum on outer SGD. |
| `sync_every` | `int` | `100` | Inner steps per outer round. |
| `fragment_sync_delay` | `int` | `0` | 0 = vanilla; >0 = Streaming DiLoCo (requires CUDA streams). |
| `fragment_update_alpha` | `float` | `0.0` | 0 = full replacement on sync; >0 = exponential mix. |
Returns a torchft.local_sgd.DiLoCo instance — usable as a context manager.
Raises RuntimeError if torchft is not installed.
import torch
from composer_replication.diloco import make_diloco_outer_loop
opt = torch.optim.AdamW(model.parameters(), lr=1e-5)
outer = make_diloco_outer_loop(manager=mgr, model_fragments=[model],
                               inner_optimizer=opt, sync_every=100)
with outer:
    for _ in range(N):
        opt.zero_grad()
        loss = compute_loss(model)  # placeholder: your own forward pass and loss
        loss.backward()
        opt.step()
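For intuition about what one outer round does with the default hyperparameters, the sketch below applies a Nesterov-momentum SGD step to the averaged pseudo-gradient (global parameters minus the replica average), using plain Python lists. It follows the outer update as described in the DiLoCo paper and the momentum formula of `torch.optim.SGD`; it is not a reproduction of torchft's internals, and `outer_step` is a hypothetical name.

```python
def outer_step(global_params, replica_params, momentum_buf=None,
               lr=0.7, momentum=0.9, nesterov=True):
    """One DiLoCo-style outer update on scalar parameters (illustrative sketch)."""
    n = len(replica_params)
    new_params, new_buf = [], []
    for i, theta in enumerate(global_params):
        avg_local = sum(rp[i] for rp in replica_params) / n
        g = theta - avg_local  # pseudo-gradient: how far the replicas moved on average
        # Momentum buffer as in torch.optim.SGD with dampening=0.
        buf = g if momentum_buf is None else momentum * momentum_buf[i] + g
        step = g + momentum * buf if nesterov else buf
        new_params.append(theta - lr * step)
        new_buf.append(buf)
    return new_params, new_buf

# Two replicas that drifted from a shared starting point of 1.0.
params, buf = outer_step([1.0], [[0.8], [0.6]])
print(params, buf)
```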
12. composer_replication.diloco.serverless
ADR-005 serverless DiLoCo executors + object-store all-reduce.
class ReplicaHandle — serverless.executor
@dataclass
class ReplicaHandle:
    rank: int
    backend_name: str
    metadata: dict[str, Any] = field(default_factory=dict)
Opaque handle returned by ServerlessExecutor.launch_replicas. metadata is backend-specific.
from composer_replication.diloco.serverless import ReplicaHandle
h = ReplicaHandle(rank=0, backend_name="local_process",
                  metadata={"pid": 12345})
class ServerlessExecutor (Protocol) — serverless.executor
@runtime_checkable
class ServerlessExecutor(Protocol):
    backend_name: str
    supports_inter_replica_network: bool
    def launch_replicas(
        self,
        n_replicas: int,
        entrypoint: str | Callable[..., Any],
        entrypoint_args: Mapping[str, Any],
        *,
        gpu: str | None = None,
        timeout: int = 3600,
    ) -> list[ReplicaHandle]: ...
    def poll(self, handle: ReplicaHandle) -> str: ...
    def stream_logs(self, handle: ReplicaHandle, *, n_lines: int = 200) -> str: ...
    def cancel(self, handle: ReplicaHandle) -> None: ...
    def collect(
        self, handles: list[ReplicaHandle], *, timeout: int | None = None,
    ) -> list[dict[str, Any]]: ...
Structural protocol for serverless backends.
- `launch_replicas(...)` returns `list[ReplicaHandle]` of length `n_replicas` in rank order. `entrypoint` is either an importable module path (uses `main()`), a `module.function` path, or a `Callable` (Local executor only). `entrypoint_args` may include `rank_env` (default `"REPLICA_RANK"`).
- `poll(handle) -> str`: one of `"pending"`, `"running"`, `"succeeded"`, `"failed"`, `"cancelled"`.
- `stream_logs(handle, n_lines=200) -> str`: best-effort recent stdout/stderr.
- `cancel(handle) -> None`: best-effort.
- `collect(handles, timeout=None) -> list[dict]`: blocks; each result dict has `rank`, `status`, `exit_code`, `error` (and `result` from `LocalProcessExecutor`).
from composer_replication.diloco.serverless import ServerlessExecutor
def supports(x: ServerlessExecutor) -> bool:
    return isinstance(x, ServerlessExecutor)  # runtime_checkable
class LocalProcessExecutor — serverless.executor
class LocalProcessExecutor:
    backend_name = "local_process"
    supports_inter_replica_network = True
    def __init__(self) -> None: ...
    # implements ServerlessExecutor protocol
Reference implementation using Python multiprocessing (spawn context). Used for tests, CI smokes, and local development with file:// rendezvous.
launch_replicas(...): emits a soft warning on gpu != None (local processes share whatever GPUs are visible). metadata = {"pid": ..., "start_ts": ...}.
from composer_replication.diloco.serverless import LocalProcessExecutor
ex = LocalProcessExecutor()
handles = ex.launch_replicas(
    n_replicas=2,
    entrypoint="composer_replication.diloco.serverless.replica_entrypoint",
    entrypoint_args={"rendezvous_uri": "/tmp/run/", "world_size": 2,
                     "trainer_module": "my.trainer"},
)
results = ex.collect(handles, timeout=60)
class ObjectStoreAllReduce — serverless.allreduce
class ObjectStoreAllReduce:
    def __init__(
        self,
        uri: str,
        rank: int,
        world_size: int,
        *,
        round_id: int | None = None,
        timeout_s: float = 1800.0,
        poll_interval_s: float = 1.0,
    ) -> None: ...
    @property
    def round_id(self) -> int: ...
    def allreduce(
        self, tensor: torch.Tensor, *, name: str | None = None,
    ) -> torch.Tensor: ...
fsspec-backed pseudo-gradient rendezvous. uri accepts s3://, gs://, az://, hf://, file://, or a plain local path.
Constructor parameters
| Name | Type | Default | Meaning |
|---|---|---|---|
| `uri` | `str` | — | fsspec URI or local path. Trailing `/` enforced. |
| `rank` | `int` | — | This replica's rank. |
| `world_size` | `int` | — | Total replicas. |
| `round_id` | `int \| None` (kw-only) | `None` ⇒ start at 0 | Initial round counter. |
| `timeout_s` | `float` (kw-only) | `1800.0` | Per-allreduce timeout. |
| `poll_interval_s` | `float` (kw-only) | `1.0` | Sleep between peer-file existence checks. |
allreduce(tensor, name=None) -> torch.Tensor: serializes tensor.detach().cpu() to round_NNNNNN/rank_RRRR.pt, blocks until all peers post, then averages. Modifies tensor in place AND returns it. Increments the internal _round_counter.
Raises ValueError on invalid rank, RuntimeError if non-local URI is requested without fsspec installed, TimeoutError if peers don't show up before timeout_s.
from composer_replication.diloco.serverless import ObjectStoreAllReduce
import torch
store = ObjectStoreAllReduce("/tmp/run/", rank=0, world_size=2)
g = torch.zeros(10)
store.allreduce(g) # blocks for rank 1
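The rendezvous pattern (each rank posts a file for the round, then waits until all peers have posted and averages) can be simulated on a local directory. The sketch below uses JSON lists in place of `.pt` tensors and checks once instead of polling; the zero-padded widths in the path are inferred from the `round_NNNNNN/rank_RRRR` pattern, and the temp-file-plus-rename step is an added design choice, not something the source states. `post` and `try_average` are hypothetical names.

```python
import json
import os
import tempfile

def post(root, round_id, rank, values):
    """Write this rank's contribution for the round (JSON stand-in for a .pt file)."""
    round_dir = os.path.join(root, f"round_{round_id:06d}")
    os.makedirs(round_dir, exist_ok=True)
    tmp_path = os.path.join(round_dir, f".rank_{rank:04d}.tmp")
    final_path = os.path.join(round_dir, f"rank_{rank:04d}.json")
    with open(tmp_path, "w") as f:
        json.dump(values, f)
    os.replace(tmp_path, final_path)  # rename so peers never see a partial file

def try_average(root, round_id, world_size):
    """Return the element-wise mean once every rank has posted, else None."""
    round_dir = os.path.join(root, f"round_{round_id:06d}")
    paths = [os.path.join(round_dir, f"rank_{r:04d}.json") for r in range(world_size)]
    if not all(os.path.exists(p) for p in paths):
        return None  # a real implementation would sleep and poll until a timeout
    vectors = []
    for p in paths:
        with open(p) as f:
            vectors.append(json.load(f))
    return [sum(col) / world_size for col in zip(*vectors)]

root = tempfile.mkdtemp()
post(root, 0, 0, [1.0, 2.0])
before = try_average(root, 0, world_size=2)  # rank 1 has not posted yet
post(root, 0, 1, [3.0, 4.0])
after = try_average(root, 0, world_size=2)
print(before, after)
```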
class MockManager — serverless.allreduce
class MockManager:
    def __init__(self, store: ObjectStoreAllReduce) -> None: ...
    # torchft.Manager-shaped surface:
    num_participants: int
    rank: int
    _use_async_quorum: bool  # always False
    _step: int
    _state_dict_fns: dict[str, tuple[Any, Any]]
    def allreduce(self, tensor: torch.Tensor, **_kwargs: Any) -> "_ImmediateWork": ...
    def should_commit(self) -> bool: ...
    def start_quorum(self) -> None: ...
    def wait_quorum(self) -> int: ...
    def current_step(self) -> int: ...
    def allow_state_dict_read(self) -> None: ...
    def disallow_state_dict_read(self) -> None: ...
    def register_state_dict_fn(self, key: str, load_fn: Any, save_fn: Any) -> None: ...
    def is_leader(self) -> bool: ...
Drop-in replacement for torchft.Manager that routes allreduce through ObjectStoreAllReduce. All other methods are no-ops or simple counters appropriate for single-shot serverless DiLoCo.
- `allreduce(tensor)` returns an `_ImmediateWork` whose `.wait()` is a no-op (the tensor is already averaged).
- `should_commit()` always `True` (no fault-tolerance failover).
- `start_quorum()` bumps `_step`.
- `is_leader()` returns `rank == 0`.
from composer_replication.diloco.serverless import MockManager, ObjectStoreAllReduce
store = ObjectStoreAllReduce("/tmp/run/", rank=0, world_size=2)
mgr = MockManager(store)
# pass mgr into make_diloco_outer_loop(manager=mgr, ...)
class _ImmediateWork — serverless.allreduce
⚠️ UNTESTED-CONTRACT internal helper exported from __all__. Work-shaped wrapper with .wait() -> True and .get_future() -> torch.futures.Future. Consumed by torchft DiLoCo's perform_sync.
from composer_replication.diloco.serverless.allreduce import _ImmediateWork
class ModalExecutor — serverless.modal
🟡 SKELETON — raises NotImplementedError; see ADR-005. Class body documents the v0 implementation pattern (Modal app.function + function.spawn(rank=...)).
from composer_replication.diloco.serverless.modal import ModalExecutor
# ModalExecutor()  # would raise NotImplementedError when instantiated
class HFJobsExecutor — serverless.hf_jobs
🟡 SKELETON — raises NotImplementedError; see ADR-005. Class body documents the v0 pattern using huggingface_hub.run_job against hf://datasets/.../ rendezvous.
from composer_replication.diloco.serverless.hf_jobs import HFJobsExecutor
# instantiation will fail until v0 implementation lands
replica_entrypoint.main(...) — serverless.replica_entrypoint
def main(
    rendezvous_uri: str,
    world_size: int,
    trainer_module: str,
    trainer_fn: str = "train",
    trainer_kwargs: dict[str, Any] | None = None,
) -> Any
Script run by every replica. Reads REPLICA_RANK env var, builds ObjectStoreAllReduce + MockManager, imports trainer_module, and calls getattr(mod, trainer_fn)(**trainer_kwargs, manager=..., rank=..., world_size=...). Returns whatever the train fn returns.
Raises RuntimeError if REPLICA_RANK env var is missing; ValueError if rank ∉ [0, world_size).
The if __name__ == "__main__" block accepts CLI flags --rendezvous, --world-size, --trainer-module, --trainer-fn, --trainer-kwargs-json.
# In-process invocation
import os
os.environ["REPLICA_RANK"] = "0"
from composer_replication.diloco.serverless.replica_entrypoint import main
result = main(rendezvous_uri="/tmp/run/", world_size=1,
              trainer_module="my.trainer", trainer_fn="train")
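The rank handling described above (read `REPLICA_RANK`, fail loudly when it is missing or out of range) can be sketched as a small standalone function. `read_rank` is a hypothetical helper written for illustration; the `env` parameter exists only so the example can be exercised without touching the real process environment.

```python
import os

def read_rank(world_size, env=None, rank_env="REPLICA_RANK"):
    """Read and validate the replica rank from an environment mapping."""
    env = os.environ if env is None else env
    raw = env.get(rank_env)
    if raw is None:
        raise RuntimeError(f"{rank_env} env var is missing")
    rank = int(raw)
    if not 0 <= rank < world_size:
        raise ValueError(f"rank {rank} not in [0, {world_size})")
    return rank

rank = read_rank(world_size=4, env={"REPLICA_RANK": "3"})
print(rank)
```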
13. composer_replication.recipes.prime_rl.composer_loss
PRIME-RL adapter (ADR-006). Maps PRIME-RL's LossInputs struct onto channel 1 (DPPO + KL on the importance ratio, mirroring PRIME-RL's upstream default_loss_fn at prime_rl/trainer/rl/loss.py lines 116-165). Channel 2 raises NotImplementedError; channel 3 is out of scope.
loss_fn(inputs, *, alpha_sdpo=0.0, beta_dpo=0.0, dppo_mask_high=0.2, dppo_mask_low=0.2, adv_tau=1.0, kl_tau=1e-3) -> torch.Tensor
def loss_fn(
    inputs: Any,  # PRIME-RL's LossInputs (duck-typed)
    *,
    alpha_sdpo: float = 0.0,
    beta_dpo: float = 0.0,
    dppo_mask_high: float = 0.2,
    dppo_mask_low: float = 0.2,
    adv_tau: float = 1.0,
    kl_tau: float = 1e-3,
) -> Any # torch.Tensor scalar
PRIME-RL passes per-sample 1-D (seq,) tensors (not batched). The function mirrors PRIME-RL's upstream DPPO+KL formula:
- Mask gate is on probability-space `probs_diff = exp(trainer_lp) - exp(inference_lp)` (NOT on the log-ratio).
- A token is dropped iff its advantage sign matches the offending bound: positive-advantage tokens are dropped when `probs_diff > dppo_mask_high`, negative-advantage tokens when `probs_diff < -dppo_mask_low`. (PRIME-RL stores both bounds with `Field(..., ge=0)` and applies the sign internally.)
- The PG term is `keep * (adv_tau * advantages) * exp(trainer_lp - inference_lp)` (importance-ratio corrected, not REINFORCE).
- A KL penalty `kl_tau * log_importance_ratio**2` is added on the full `loss_mask` (DPPO masking does not gate it).
- Reduction is a plain `sum()`; PRIME-RL's outer `compute_loss` divides by `loss_scale`.
Parameters
| Name | Type | Default | Meaning |
|---|---|---|---|
| `inputs` | PRIME-RL `LossInputs` (duck-typed) | — | Must expose `trainer_logprobs`, `inference_logprobs`, `advantages`, `loss_mask` (all 1-D), and optionally `teacher_logprobs`. |
| `alpha_sdpo` | `float` (kw-only) | `0.0` | Channel-2 weight. Must be 0 in v0; `>0` → `NotImplementedError`. |
| `beta_dpo` | `float` (kw-only) | `0.0` | Channel-3 weight. Non-zero emits a `UserWarning`. |
| `dppo_mask_high` | `float` (kw-only), `>= 0` | `0.2` | Upper probability-diff threshold. PRIME-RL `DefaultLossConfig` default. |
| `dppo_mask_low` | `float` (kw-only), `>= 0` | `0.2` | Magnitude of lower probability-diff threshold (sign flipped internally). PRIME-RL default. |
| `adv_tau` | `float` (kw-only), `>= 0` | `1.0` | Advantage temperature. PRIME-RL default. |
| `kl_tau` | `float` (kw-only), `>= 0` | `1e-3` | KL term temperature. PRIME-RL default. |
Returns scalar torch.Tensor (PRIME-RL's trainer calls .backward()).
Raises ValueError if any of trainer_logprobs, inference_logprobs, advantages, loss_mask is not 1-D, or any of the four >=0-constrained knobs is negative. NotImplementedError if alpha_sdpo > 0 (channel 2 deferred).
from composer_replication.recipes.prime_rl.composer_loss import loss_fn
# In PRIME-RL config:
# loss:
#   custom:
#     import_path: composer_replication.recipes.prime_rl.composer_loss:loss_fn
#     kwargs:
#       dppo_mask_high: 0.2
#       dppo_mask_low: 0.2
#       adv_tau: 1.0
#       kl_tau: 1.0e-3
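The DPPO+KL formula described in this section can be traced token by token. The sketch below works on plain Python lists rather than tensors and is not the package's `loss_fn`. Two points are assumptions because the text does not state them: the PG term is subtracted (so that minimising the loss raises the objective), and `keep` is taken to mean "inside `loss_mask` and not dropped by the DPPO gate".

```python
import math

def dppo_kl_loss(trainer_lp, inference_lp, advantages, loss_mask,
                 dppo_mask_high=0.2, dppo_mask_low=0.2, adv_tau=1.0, kl_tau=1e-3):
    """Per-sample DPPO+KL loss over 1-D sequences (illustrative sketch)."""
    total = 0.0
    for t_lp, i_lp, adv, masked_in in zip(trainer_lp, inference_lp, advantages, loss_mask):
        if not masked_in:
            continue
        # Gate on the probability-space difference, not on the log-ratio.
        probs_diff = math.exp(t_lp) - math.exp(i_lp)
        dropped = ((adv > 0 and probs_diff > dppo_mask_high)
                   or (adv < 0 and probs_diff < -dppo_mask_low))
        log_ratio = t_lp - i_lp
        if not dropped:
            # Importance-ratio-corrected PG term (sign is an assumption).
            total -= (adv_tau * adv) * math.exp(log_ratio)
        # KL penalty applies on the full loss mask, even for dropped tokens.
        total += kl_tau * log_ratio ** 2
    return total  # plain sum; the caller divides by its loss scale

on_policy = dppo_kl_loss([-1.0, -2.0], [-1.0, -2.0], [1.0, -0.5], [1, 1])
gated = dppo_kl_loss([math.log(0.9)], [math.log(0.5)], [1.0], [1])
print(on_policy, gated)
```

In the second call the trainer probability exceeds the inference probability by 0.4, so the positive-advantage token is dropped from the PG term and only the KL penalty remains.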
14. composer_replication.recipes.monarch.actors
🟡 SKELETON module per ADR-006. Importable; classes raise NotImplementedError on instantiation. Documents the actor signatures so the recipe matrix is complete.
class TrainerActor 🟡
class TrainerActor:
    backend = "monarch"
    role = "trainer"
    def __init__(self) -> None: raise NotImplementedError(...)
    async def train_outer_step(self, batch_id: int) -> dict[str, Any]: raise NotImplementedError
Hosts the framework's 3-channel composer trainer. Real impl deferred to v0.2+.
class GeneratorActor 🟡
class GeneratorActor:
    backend = "monarch"
    role = "generator"
    def __init__(self) -> None: raise NotImplementedError(...)
    async def rollout(self, prompts: list[str]) -> list[str]: raise NotImplementedError
vLLM-backed rollout actor.
class RewarderActor 🟡
class RewarderActor:
    backend = "monarch"
    role = "rewarder"
    def __init__(self) -> None: raise NotImplementedError(...)
    async def score(self, completions: list[str]) -> list[float]: raise NotImplementedError
verifiers-protocol rewarder.
class TeacherPoolActor 🟡
class TeacherPoolActor:
    backend = "monarch"
    role = "teacher_pool"
    def __init__(self) -> None: raise NotImplementedError(...)
Channel-3 teacher pool wrapping composer_replication.teacher_replay.
# All Monarch actors raise on instantiation in v0:
from composer_replication.recipes.monarch.actors import TrainerActor
# TrainerActor() # NotImplementedError
15. composer_replication.diloco.serverless — cloud executors
ADR-005 production cloud executors that implement the ServerlessExecutor Protocol (see §12) against real clouds. Both are the loud-success siblings of the ModalExecutor / HFJobsExecutor 🟡 skeletons: cross-replica communication is EXCLUSIVELY the S3 ObjectStoreAllReduce rendezvous, so the trainer / loss / DiLoCo math stay byte-for-byte identical regardless of backend. Both set supports_inter_replica_network = False (replicas rendezvous only through the object store) and both lazy-import their cloud SDK so import composer_replication.diloco.serverless is free of kubernetes / boto3.
Extras: EKSExecutor needs kubernetes>=29 (pip install -e .[eks]); SageMakerExecutor needs boto3>=1.34 (pip install -e .[aws]). The [serverless] extra pulls both (plus fsspec + huggingface_hub for the rendezvous backends). The SDK is required only at adapter-init / method time — when an api/client is injected (tests), a missing top-level package is tolerated.
class EKSExecutor — serverless.eks
class EKSExecutor:
    backend_name = "eks"
    supports_inter_replica_network = False  # S3 rendezvous only
    def __init__(
        self,
        image: str,
        *,
        namespace: str = "default",
        service_account_name: str | None = None,
        node_selector: dict[str, str] | None = None,
        tolerations: list[Any] | None = None,
        runtime_class_name: str | None = None,
        command: list[str] | None = None,
        cpu_request: str = "4",
        memory_request: str = "16Gi",
        ttl_seconds_after_finished: int = 3600,
        backoff_limit: int = 0,
        gpu_resource_key: str = "nvidia.com/gpu",
        run_id: str | None = None,
        batch_api: Any = None,
        core_api: Any = None,
    ) -> None: ...
    # implements ServerlessExecutor protocol (launch_replicas/poll/stream_logs/cancel/collect)
Run N DiLoCo replicas as a single Kubernetes Indexed Job on EKS (completions == parallelism == n_replicas, completionMode="Indexed"). The control plane assigns each pod a JOB_COMPLETION_INDEX 0..N-1 which IS the rank; the executor bridges it to the repo entrypoint's REPLICA_RANK env var via the downward API, so replica_entrypoint works unchanged. launch_replicas creates ONE Job but still returns N ReplicaHandles (handles[i].rank == i) sharing the same job_name/namespace (gang semantics).
Constructor parameters
| Name | Type | Default | Meaning |
|---|---|---|---|
| `image` | `str` | — | Container image with `composer_replication` installed; runs the replica entrypoint. |
| `namespace` | `str` (kw-only) | `"default"` | k8s namespace for the Job. |
| `service_account_name` | `str \| None` (kw-only) | `None` | ServiceAccount referenced on the PodSpec for IRSA / EKS Pod Identity S3 access. REFERENCED only, never created. |
| `node_selector` | `dict[str,str] \| None` (kw-only) | `None` | Extra node selector merged under the GPU node selector. |
| `tolerations` | `list[Any] \| None` (kw-only) | `None` | PodSpec tolerations; a `nvidia.com/gpu` `NoSchedule` toleration is auto-added on GPU requests when none supplied. |
| `runtime_class_name` | `str \| None` (kw-only) | `None` | Pre-existing RuntimeClass (`"gvisor"` / `"kata"`). Combining with `gpu` is advanced/unverified (see source warning). |
| `command` | `list[str] \| None` (kw-only) | `None` ⇒ `["python","-m","composer_replication.diloco.serverless.replica_entrypoint"]` | Container command. |
| `cpu_request` / `memory_request` | `str` (kw-only) | `"4"` / `"16Gi"` | PodSpec resource requests. |
| `ttl_seconds_after_finished` | `int` (kw-only) | `3600` | Auto-delete the finished Job (cascading) after this many seconds. |
| `backoff_limit` | `int` (kw-only) | `0` | Job retry budget (fail-fast; NOT the k8s default of 6). |
| `gpu_resource_key` | `str` (kw-only) | `"nvidia.com/gpu"` | GPU resource key. |
| `run_id` | `str \| None` (kw-only) | `None` ⇒ `"diloco"` | Prefix baked into the generated Job name. |
| `batch_api` / `core_api` | `Any` (kw-only) | `None` | DI'd `BatchV1Api` / `CoreV1Api`; lazily built (in-cluster then kube-config) when `None`. Tests inject mocks. |
ServerlessExecutor Protocol methods (see §12 for the full Protocol)
- `launch_replicas(n_replicas, entrypoint, entrypoint_args, *, gpu=None, timeout=3600) -> list[ReplicaHandle]`: creates ONE Indexed Job; `entrypoint` is ignored (k8s runs the fixed container command), scalar `entrypoint_args` (e.g. `rendezvous_uri`) are passed as upper-cased literal env vars, `gpu` maps to a `nvidia.com/gpu` limit + node selector via the `_GPU_SPEC_TABLE` (`"A100"`/`"H100"`/`"A10G"`/`"T4"`), and `timeout` becomes the Job's `active_deadline_seconds` hard wall-clock kill.
- `poll(handle) -> str`: reads the single Job status and maps THIS rank: `rank in completed_indexes` ⇒ `"succeeded"`, `rank in failed_indexes` ⇒ `"failed"`, whole-Job `Failed` condition ⇒ `"failed"`, `active > 0` ⇒ `"running"`, else `"pending"`; a 404 (Job gone) ⇒ `"cancelled"`.
- `stream_logs(handle, *, n_lines=200) -> str`: finds the pod by completion-index annotation/label (or the `<job>-<rank>-` name prefix), tails its log; returns a placeholder string (never raises) when the pod has not started.
- `cancel(handle) -> None`: deletes the WHOLE shared Indexed Job (`propagation_policy="Background"` so pods are cascadingly deleted, not orphaned). Intentional gang teardown; idempotent (404 swallowed, unknown handle is a no-op).
- `collect(handles, *, timeout=None) -> list[dict]`: polls (sleeping between, status is eventually consistent) until every rank is terminal or the deadline; per-rank dict is `{"rank","status","exit_code","error","job_name"}` (`exit_code` `0`/`1`/`None`).
Raises RuntimeError at construction if the kubernetes client is absent AND no api was injected; ValueError if n_replicas < 1.
from composer_replication.diloco.serverless.eks import EKSExecutor
ex = EKSExecutor(image="ghcr.io/me/diloco:latest",
                 service_account_name="diloco-s3")
handles = ex.launch_replicas(
    n_replicas=4,
    entrypoint="",  # ignored on k8s
    entrypoint_args={"rendezvous_uri": "s3://bucket/run/",
                     "trainer_module": "my.trainer", "world_size": 4},
    gpu="H100",
)
results = ex.collect(handles, timeout=3600)
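The per-rank mapping that `poll` applies to a single Indexed Job status can be sketched as a pure function. Kubernetes reports completed and failed indexes as a compact string such as `"0-2,5"`; whether the executor parses that string itself or receives a set is not stated above, so `parse_indexes` and `rank_status` are hypothetical helpers. The 404 ⇒ `"cancelled"` case is omitted because it depends on the API call failing, not on the status fields.

```python
def parse_indexes(spec):
    """Parse a Kubernetes index list such as '0-2,5' into a set of ints."""
    indexes = set()
    for part in filter(None, (spec or "").split(",")):
        lo, _, hi = part.partition("-")
        indexes.update(range(int(lo), int(hi or lo) + 1))
    return indexes

def rank_status(rank, completed="", failed="", job_failed=False, active=0):
    """Map one rank to a status string, in the precedence order described above."""
    if rank in parse_indexes(completed):
        return "succeeded"
    if rank in parse_indexes(failed):
        return "failed"
    if job_failed:
        return "failed"
    if active > 0:
        return "running"
    return "pending"

statuses = [rank_status(r, completed="0-1", failed="3", active=1) for r in range(4)]
print(statuses)
```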
class SageMakerExecutor — serverless.sagemaker
class SageMakerExecutor:
    backend_name = "sagemaker"
    supports_inter_replica_network = False  # S3 rendezvous only
    def __init__(
        self,
        *,
        role_arn: str,
        image_uri: str,
        output_s3_path: str,
        instance_type: str = "ml.g5.2xlarge",
        cpu_instance_type: str = "ml.m5.xlarge",
        volume_size_gb: int = 100,
        run_id: str | None = None,
        region: str | None = None,
        sagemaker_client: Any = None,
        logs_client: Any = None,
    ) -> None: ...
    # implements ServerlessExecutor protocol
Run replicas as N independent single-instance SageMaker Training Jobs (NOT one multi-instance job — that would couple replicas through SageMaker's intra-job NCCL/MPI fabric and break the "each replica syncs only through S3" design). Rank goes through the per-job Environment map (REPLICA_RANK=i / WORLD_SIZE=N), so replica_entrypoint reads it unchanged. Pins EnableNetworkIsolation=False (never a knob) — True would sever the container's S3 access and dead-lock the allreduce poll loop.
Constructor parameters (all kw-only)
| Name | Type | Default | Meaning |
|---|---|---|---|
| `role_arn` | `str` | — | IAM execution role SageMaker assumes; must grant S3 to the rendezvous + output buckets (the boto3 analog of EKS IRSA). Caller needs `iam:PassRole`. |
| `image_uri` | `str` | — | ECR training-container image. The executor also passes `ContainerEntrypoint` explicitly so a generic image works. |
| `output_s3_path` | `str` | — | `s3://...` prefix for `OutputDataConfig.S3OutputPath`. |
| `instance_type` | `str` | `"ml.g5.2xlarge"` | Default instance type when `gpu` is unmapped. |
| `cpu_instance_type` | `str` | `"ml.m5.xlarge"` | Instance type used when `gpu=None` (CPU smoke). |
| `volume_size_gb` | `int` | `100` | `ResourceConfig.VolumeSizeInGB` per job. |
| `run_id` | `str \| None` | `None` ⇒ random token | Prefix for generated training-job names. |
| `region` | `str \| None` | `None` | AWS region for the lazy boto3 clients (`None` ⇒ ambient default). |
| `sagemaker_client` | `Any` | `None` | Inject a pre-built `boto3.client("sagemaker")` (or a mock); built lazily otherwise. |
| `logs_client` | `Any` | `None` | Inject a pre-built `boto3.client("logs")`; built lazily on first `stream_logs`. |
ServerlessExecutor Protocol methods
- `launch_replicas(n_replicas, entrypoint, entrypoint_args, *, gpu=None, timeout=3600) -> list[ReplicaHandle]`: submits N independent jobs; `entrypoint` is ignored (container command is baked / passed explicitly). `entrypoint_args` MUST include `rendezvous_uri` (`s3://...`) and `trainer_module`; optional `trainer_fn` (default `"train"`) and `trainer_kwargs` (dict, JSON-encoded into `ContainerArguments`). `gpu` maps to an instance type via `_GPU_INSTANCE_MAP` (`"A100"`/`"H100"`/`"H200"`/`"B200"`/`"L40S"`/`"A10G"`/`"L4"`; a literal `ml.*` string is honoured); `timeout` ⇒ `StoppingCondition.MaxRuntimeInSeconds`. On mid-launch failure it best-effort stops already-launched siblings then raises.
- `poll(handle) -> str`: maps `describe_training_job`'s `TrainingJobStatus` via `_STATUS_MAP` (`InProgress`⇒`running`, `Completed`⇒`succeeded`, `Failed`⇒`failed`, `Stopping`⇒`running`, `Stopped`⇒`cancelled`); refines `InProgress` ⇒ `"pending"` while still queued (`SecondaryStatus` in `_PENDING_SECONDARY`); a vanished job (`ResourceNotFound`) ⇒ `"cancelled"`.
- `stream_logs(handle, *, n_lines=200) -> str`: discovers the CloudWatch stream under `/aws/sagemaker/TrainingJobs` by `<job-name>/` prefix and tails it; falls back to a CloudWatch console URL on any error.
- `cancel(handle) -> None`: best-effort `stop_training_job` (swallows `ResourceNotFound` / already-terminal `ValidationException`).
- `collect(handles, *, timeout=None) -> list[dict]`: polls per handle until terminal (`Completed`/`Failed`/`Stopped`) or the shared deadline; results aligned to input order; dict is `{"rank","status","exit_code","error","result","training_job_name"}` (`result` is the `S3ModelArtifacts` path).
Raises RuntimeError if boto3 is absent and no client was injected; ValueError if n_replicas < 1 or entrypoint_args lacks rendezvous_uri / trainer_module.
from composer_replication.diloco.serverless.sagemaker import SageMakerExecutor
ex = SageMakerExecutor(
    role_arn="arn:aws:iam::123:role/sm-exec",
    image_uri="123.dkr.ecr.us-east-1.amazonaws.com/diloco:latest",
    output_s3_path="s3://bucket/out/", region="us-east-1")
handles = ex.launch_replicas(
    n_replicas=2, entrypoint="",
    entrypoint_args={"rendezvous_uri": "s3://bucket/run/",
                     "trainer_module": "my.trainer"},
    gpu="A100")
results = ex.collect(handles)
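The status translation that `poll` performs can be written as a lookup plus one refinement. In the sketch below, `STATUS_MAP` copies the pairs listed above, while the contents of `PENDING_SECONDARY` are a guess: `Starting` and `Pending` are real SageMaker `SecondaryStatus` values, but the actual membership of `_PENDING_SECONDARY` is not documented here. `map_status` is a hypothetical name.

```python
STATUS_MAP = {
    "InProgress": "running",
    "Completed": "succeeded",
    "Failed": "failed",
    "Stopping": "running",
    "Stopped": "cancelled",
}
# Assumed membership; the real _PENDING_SECONDARY set may differ.
PENDING_SECONDARY = frozenset({"Starting", "Pending"})

def map_status(training_job_status, secondary_status=None):
    """Translate SageMaker job status fields into the executor's status strings."""
    status = STATUS_MAP[training_job_status]
    if training_job_status == "InProgress" and secondary_status in PENDING_SECONDARY:
        return "pending"  # job accepted but not yet running
    return status

print(map_status("InProgress", "Starting"), map_status("InProgress", "Training"))
```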
16. composer_replication.datagen.docker_sandbox
ADR-010 §3 hardened container backend for the FeatureDeletion (FD) env. DockerSandbox is a drop-in implementation of the Sandbox Protocol (composer_replication.datagen.sandbox.Sandbox) that runs the agent's tool calls and the verifiable test command inside an ephemeral, locked-down Docker container instead of a raw host subprocess. It is the production execution path for genuinely UNTRUSTED model-generated code (the LocalSubprocessSandbox sibling runs in the host process and is NOT a host-security boundary).
Extra: needs docker>=7 (pip install -e .[datagen] or pip install docker). The SDK is lazy-imported inside methods so the pure-Python core and the FakeSandbox path never require it; a clear RuntimeError is raised if the SDK or the daemon is absent.
The Sandbox Protocol — datagen.sandbox
@runtime_checkable
class Sandbox(Protocol):
    def boot(self, image: str) -> None: ...
    def exec(self, action: dict) -> str: ...
    def run_tests(self, test_command: str, tests: tuple[str, ...]) -> TestRunResult: ...
    def trajectory(self) -> list[dict]: ...
    def is_command_allowed(self, command: str) -> bool: ...
Structural protocol every FD execution backend implements (DockerSandbox, LocalSubprocessSandbox, FakeSandbox). boot prepares the execution environment, exec runs one tool-call action dict ({"command": ...}) and returns combined stdout/stderr, run_tests runs the verifiable pytest command over the given node ids and returns a TestRunResult, trajectory returns the recorded action list, is_command_allowed is the first-token denylist check.
class DockerSandbox
@dataclass
class DockerSandbox:
    workdir: str
    runtime: str | None = None
    mem_limit: str = "1g"
    memswap_limit: str = "1g"
    pids_limit: int = 256
    nano_cpus: int = 2_000_000_000  # 2 CPUs
    user: str = "1000:1000"
    container_workdir: str = "/work"
    tmpfs_size: str = "64m"
    exec_timeout_s: int = 600
    keep_root_writable: bool = False
    def container_kwargs(self, image: str) -> dict: ...
    # Sandbox Protocol:
    def boot(self, image: str) -> None: ...
    def exec(self, action: dict) -> str: ...
    def run_tests(self, test_command: str, tests: tuple[str, ...]) -> TestRunResult: ...
    def trajectory(self) -> list[dict]: ...
    def is_command_allowed(self, command: str) -> bool: ...
    def close(self) -> None: ...
    @staticmethod
    def reap_leaked(client=None) -> int: ...
Hardened ephemeral-container Sandbox. The lockdown recipe (CIS Docker 5.x + gVisor guidance): network_disabled=True + network_mode="none" (no egress — the reward-hack exfil control), read_only=True root fs + small tmpfs for /tmp, workdir bind-mounted RW at /work, cap_drop=["ALL"] + security_opt=["no-new-privileges:true"], non-root user, pids_limit (fork-bomb guard) + mem_limit==memswap_limit (OOM, no swap) + nano_cpus (CPU quota). The PRIMARY reward-hack control is the host-side scrub_tree(workdir) run in boot() BEFORE the container starts (the bind mount is shared, so host-pre-boot scrubbing == in-container scrubbing); the command denylist is cheap defense-in-depth, not the wall.
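The lockdown recipe above maps onto keyword arguments that docker-py's containers.run() accepts. The sketch below only builds such a dictionary and does not talk to a daemon; lockdown_kwargs is a hypothetical helper written for illustration, and the real container_kwargs() may differ in keys or values.

```python
def lockdown_kwargs(image, workdir, *, runtime=None, user="1000:1000",
                    mem_limit="1g", pids_limit=256, nano_cpus=2_000_000_000,
                    container_workdir="/work", tmpfs_size="64m"):
    """Build kwargs in the shape docker-py's containers.run() accepts (illustrative)."""
    kwargs = {
        "image": image,
        "detach": True,
        "network_disabled": True,          # no egress
        "network_mode": "none",
        "read_only": True,                 # read-only root filesystem
        "tmpfs": {"/tmp": f"size={tmpfs_size}"},
        "volumes": {workdir: {"bind": container_workdir, "mode": "rw"}},
        "working_dir": container_workdir,
        "cap_drop": ["ALL"],
        "security_opt": ["no-new-privileges:true"],
        "user": user,                      # non-root uid:gid
        "pids_limit": pids_limit,          # fork-bomb guard
        "mem_limit": mem_limit,
        "memswap_limit": mem_limit,        # equal to mem_limit, so no swap
        "nano_cpus": nano_cpus,            # 1e-9 CPU units
        "labels": {"composer_replication": "datagen"},
    }
    if runtime is not None:
        kwargs["runtime"] = runtime        # e.g. "runsc" when gVisor is registered
    return kwargs


kw = lockdown_kwargs("python:3.12-slim", "/tmp/repo", runtime="runsc")
```

Asserting on a dictionary like this is one way to check the hardening configuration without a live Docker daemon.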
Constructor fields
| Field | Type | Default | Meaning |
|---|---|---|---|
| workdir | str | none (required) | Host path to the materialized repo; bind-mounted RW at /work and scrubbed on the host before boot (PRIMARY control). Must exist by boot() time. |
| runtime | str \| None | None ⇒ daemon default (runc) | "runsc" (gVisor) for untrusted model code; requires host sudo runsc install + dockerd restart. Gate with runsc_available(). |
| mem_limit / memswap_limit | str | "1g" / "1g" | OOM guard; equal values forbid swap. |
| pids_limit | int | 256 | Fork-bomb guard. |
| nano_cpus | int | 2_000_000_000 | CPU quota in units of 1e-9 CPUs (2 CPUs). |
| user | str | "1000:1000" | Non-root uid:gid the agent code runs as. |
| container_workdir | str | "/work" | Bind-mount target + working dir. |
| tmpfs_size | str | "64m" | Size of the /tmp tmpfs scratch. |
| exec_timeout_s | int | 600 | Wall-clock cap injected via coreutils timeout (exec_run has no timeout param; see docker-py #2651). |
| keep_root_writable | bool | False | Escape hatch if read-only fs breaks tooling. |
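Since the wall-clock cap is injected through coreutils timeout rather than an SDK parameter, the command has to be wrapped before it is handed to the container. A minimal sketch of such a wrapper follows; wrap_with_timeout and its grace_s parameter are hypothetical, and it assumes the image ships GNU coreutils timeout and sh.

```python
def wrap_with_timeout(command: str, timeout_s: int = 600, grace_s: int = 5) -> list[str]:
    """Prefix a shell command with coreutils `timeout`.

    `timeout -k GRACE DURATION CMD...` sends SIGTERM after DURATION seconds and
    SIGKILL GRACE seconds later if the command is still running.
    """
    if timeout_s <= 0:
        raise ValueError("timeout_s must be positive")
    return ["timeout", "-k", str(grace_s), str(timeout_s), "sh", "-c", command]


argv = wrap_with_timeout("python -m pytest -q", timeout_s=600)
```

With GNU timeout, a command that is stopped by the initial SIGTERM at the deadline exits with status 124, which a caller can use to tell a timeout apart from an ordinary test failure.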
Sandbox Protocol methods
- boot(image) -> None: scrubs the host workdir (primary control), reaps leaked siblings, then starts the hardened container. Raises RuntimeError if the image is not found locally (the container is --network none, so it cannot pull) or on a Docker API error.
- exec(action) -> str: records the action, denylist-checks its first token, runs the command in the live container (combined stdout/stderr, non-UTF-8 bytes decoded with errors="replace"). Returns an ERROR: string for a denied command.
- run_tests(test_command, tests) -> TestRunResult: shlex.quotes each node id, runs the verifiable command, parses pass/fail conservatively (PASSED-or-rc0 ⇒ passed; errors during collection ⇒ collected_ok=False).
- trajectory() -> list[dict]: copy of the recorded action list.
- is_command_allowed(command) -> bool: first-token denylist (SANDBOX_DENYLIST); not a boundary on its own.
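Two of the behaviours listed above, the first-token denylist and the shell-quoting of node ids, can be sketched with the standard library alone. This is an illustration under assumptions: TOY_DENYLIST is invented (the real SANDBOX_DENYLIST contents are not shown here), both helper functions are hypothetical, and treating an empty or unparseable command as denied is a choice made for this sketch.

```python
import shlex

# Hypothetical denylist for illustration only.
TOY_DENYLIST = frozenset({"curl", "wget", "ssh", "nc"})


def first_token_allowed(command: str, denylist=TOY_DENYLIST) -> bool:
    """Return False when the command's first token is on the denylist."""
    try:
        tokens = shlex.split(command)
    except ValueError:  # unbalanced quotes: refuse rather than guess
        return False
    return bool(tokens) and tokens[0] not in denylist


def build_test_command(test_command: str, tests: tuple[str, ...]) -> str:
    """Append shell-quoted pytest node ids to the base test command."""
    return " ".join([test_command, *(shlex.quote(t) for t in tests)])


allowed = first_token_allowed("python -c 'print(1+1)'")
denied = first_token_allowed("curl http://example.com")
bypass = first_token_allowed("sh -c 'curl http://example.com'")  # first token is "sh"
cmd = build_test_command("python -m pytest -q", ("tests/test_x.py::test_a[x y]",))
```

The bypass case passes the check even though it runs a denied tool, which is why a first-token check can only serve as defense-in-depth next to the network and filesystem lockdown.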
Lifecycle: close() force-removes the container (idempotent, swallows errors); reap_leaked(client=None) -> int (staticmethod) sweeps containers labelled composer_replication=datagen and returns the count removed. Also a context manager (__enter__ / __exit__ ⇒ close()).
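The lifecycle contract (idempotent close() that swallows errors, plus context-manager support) follows a common pattern that can be sketched independently of Docker. ToyHandle below is a hypothetical class written for this illustration; its remove callable stands in for whatever force-removes the container.

```python
class ToyHandle:
    """Hypothetical resource handle whose close() is idempotent and never raises."""

    def __init__(self, remove):
        self._remove = remove      # callable that force-removes the resource
        self._closed = False

    def close(self) -> None:
        if self._closed:
            return                 # second and later calls are no-ops
        self._closed = True
        try:
            self._remove()
        except Exception:
            pass                   # cleanup is best-effort; errors are swallowed

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        self.close()


calls = []
with ToyHandle(lambda: calls.append("removed")) as h:
    pass
h.close()  # already closed by __exit__, so nothing happens
```

Swallowing errors in close() keeps cleanup from masking the exception that caused the with-block to exit, at the cost of hiding cleanup failures, which is one reason a separate sweep such as reap_leaked is useful.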
Module helpers: runsc_available(client=None) -> bool (True iff the gVisor runsc runtime is registered with the daemon — gate any runsc behavior on this); _require_docker() / _make_client() (lazy SDK import + daemon ping, each raising a clear RuntimeError).
from composer_replication.datagen.docker_sandbox import DockerSandbox, runsc_available
runtime = "runsc" if runsc_available() else None
with DockerSandbox(workdir="/tmp/repo", runtime=runtime) as sb:
    sb.boot("python:3.12-slim")
    out = sb.exec({"command": "python -c 'print(1+1)'"})
    res = sb.run_tests("python -m pytest -q", ("tests/test_x.py::test_a",))
17. composer_replication.safety & composer_replication.safety.kill_switch
ADR-015 run-level collapse safeguard — the #2 collapse safety net for the self-evolving RL flywheel. The per-task controls live in composer_replication.datagen (4-gate validator, HackMonitor provenance, sandbox denylist); this package adds the missing ACROSS-GENERATION / run-level control that watches in-loop (proxy) reward against a disjoint held-out (real) eval and HALTS the run when collapse / reward-hacking is caught in the act. Pure-Python: no torch, no cloud deps; fully CPU-testable (kl_to_init is just a float the caller computes upstream).
Public surface (composer_replication.safety.__all__): HeldOutGuard, TripwireStatus, CollapseStopError, kl_token_trust_filter.
class TripwireStatus — safety.kill_switch
@dataclass(frozen=True)
class TripwireStatus:
    fire: bool
    reason: str
    step: int
    proxy_real_gap: float
    in_loop_ema: float
    heldout_ema: float
    kl_ema: float | None = None
    @property
    def halt(self) -> bool: ...  # documented alias for `fire`
Structured verdict returned by every HeldOutGuard.update(...). fire (alias halt) ⇒ the run should HALT; reason is the human-readable WHY (empty when not firing); proxy_real_gap is the RSI "Hacking Gap" (in-loop gain − held-out gain since baseline); *_ema are the denoised metric EMAs.
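The Hacking Gap arithmetic can be made concrete with a small worked example. The function below is a hypothetical helper, not part of the package; it assumes the baselines are the metric values recorded at the start of the run.

```python
def hacking_gap(in_loop_ema: float, heldout_ema: float,
                in_loop_baseline: float, heldout_baseline: float) -> float:
    """Proxy gain minus real gain, both measured against their baselines."""
    proxy_gain = in_loop_ema - in_loop_baseline
    real_gain = heldout_ema - heldout_baseline
    return proxy_gain - real_gain


# In-loop reward climbed from 0.40 to 0.65 while held-out only moved 0.40 -> 0.45:
# proxy gain is about 0.25, real gain about 0.05, so the gap is about 0.20.
gap = hacking_gap(0.65, 0.45, 0.40, 0.40)
```

A gap near zero means the held-out eval is improving as fast as the in-loop reward; a gap of about 0.20, as here, would exceed the default max_proxy_real_gap of 0.10.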
class CollapseStopError(RuntimeError) — safety.kill_switch
class CollapseStopError(RuntimeError):
    def __init__(self, status: TripwireStatus) -> None: ...
    status: TripwireStatus
Typed exception for exception-based trainer control flow; carries the fired TripwireStatus for logging. Raised (optionally) via HeldOutGuard.raise_if_fired(...).
class HeldOutGuard — safety.kill_switch
@dataclass
class HeldOutGuard:
    kl_hard_stop: float = 0.08        # nats/token; top of GRPO "good" band
    max_proxy_real_gap: float = 0.10  # absolute Hacking-Gap blowout ceiling
    min_steps: int = 20               # no fire before this many updates
    decline_patience: int = 3         # consecutive held-out declines to fire (a)
    ema_alpha: float = 0.9            # EMA weight on the PRIOR (0.9 => slow)
    rise_eps: float = 1e-4            # min EMA delta to count as rising/declining
    def update(self, round_idx: int, in_loop_reward: float, heldout_score: float,
               kl_to_init: float | None = None, entropy: float | None = None,
               reward_std: float | None = None) -> TripwireStatus: ...
    def should_halt(self) -> bool: ...
    @property
    def last_status(self) -> TripwireStatus | None: ...
    def raise_if_fired(self, status: TripwireStatus | None = None) -> None: ...
    def proxy_real_gap(self) -> float: ...
    def calibrate_kl_threshold(self, baseline_kls: list[float], factor: float = 3.0) -> float: ...
Stateful across-generation collapse / reward-hacking kill-switch. Call update(...) once per checkpoint (the same cadence as DifficultyCurriculum.update); it folds each metric into a denoised EMA and returns a TripwireStatus.
Fires (fire=True) when ANY of three conditions holds (none can fire before min_steps, which guards against early-run noise):
- (a) collapse-caught-in-the-act: the in-loop reward EMA is RISING while the held-out score EMA has DECLINED for >= decline_patience consecutive checkpoints (the canonical reward-hacking signature: proxy up, real down).
- (b) KL hard stop: the kl_to_init EMA exceeds kl_hard_stop (default 0.08 nats/token, the top of the GRPO "good progression" band). Checked first (cheapest unambiguous breach).
- (c) proxy-real gap blowout: the Hacking Gap (proxy gain − real gain since baseline) widens beyond max_proxy_real_gap, catching a fast single-generation divergence even before the full decline window.
Once fired, the verdict is latched — every subsequent update keeps fire=True so a transient recovery cannot silently un-halt the run.
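The three tripwires and the latch can be sketched as a toy class. This is a simplified illustration, not the package's HeldOutGuard: ToyGuard returns a reason string instead of a TripwireStatus, seeds each EMA with the first observation, takes the first EMA value as the baseline, and resets the decline counter whenever the rising-and-declining pattern breaks. All of those details are assumptions.

```python
from dataclasses import dataclass, field


@dataclass
class ToyGuard:
    """Simplified illustration of the three tripwires plus latching."""
    kl_hard_stop: float = 0.08
    max_proxy_real_gap: float = 0.10
    min_steps: int = 20
    decline_patience: int = 3
    ema_alpha: float = 0.9
    rise_eps: float = 1e-4
    _n: int = 0
    _ema: dict = field(default_factory=dict)
    _base: dict = field(default_factory=dict)
    _declines: int = 0
    _fired: str = ""

    def _fold(self, key, value):
        """EMA with ema_alpha weighting the PRIOR; returns (previous, new)."""
        prev = self._ema.get(key)
        new = value if prev is None else self.ema_alpha * prev + (1 - self.ema_alpha) * value
        self._ema[key] = new
        self._base.setdefault(key, new)  # assumption: baseline = first EMA value
        return prev, new

    def update(self, in_loop_reward, heldout_score, kl_to_init=None) -> str:
        self._n += 1
        p_prev, p = self._fold("proxy", in_loop_reward)
        r_prev, r = self._fold("real", heldout_score)
        kl = self._fold("kl", kl_to_init)[1] if kl_to_init is not None else None
        rising = p_prev is not None and p - p_prev > self.rise_eps
        declining = r_prev is not None and r_prev - r > self.rise_eps
        self._declines = self._declines + 1 if (rising and declining) else 0
        gap = (p - self._base["proxy"]) - (r - self._base["real"])
        if self._fired or self._n < self.min_steps:
            return self._fired  # latched verdict, or still warming up
        if kl is not None and kl > self.kl_hard_stop:
            self._fired = "(b) KL hard stop"
        elif self._declines >= self.decline_patience:
            self._fired = "(a) proxy rising while held-out declines"
        elif gap > self.max_proxy_real_gap:
            self._fired = "(c) proxy-real gap blowout"
        return self._fired


guard = ToyGuard(min_steps=3, ema_alpha=0.5)
# Proxy reward climbs by 0.1 per round while the held-out score falls by 0.1.
reasons = [guard.update(0.5 + 0.1 * i, 0.5 - 0.1 * i) for i in range(5)]
after_recovery = guard.update(0.5, 0.9)  # still latched despite the recovery
```

In this run the gap condition trips on the third update, before the decline counter reaches decline_patience, and every later call returns the same latched reason.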
Constructor thresholds
| Field | Type | Default | Meaning |
|---|---|---|---|
| kl_hard_stop | float | 0.08 | Per-token KL(policy‖init) hard-stop ceiling, nats/token (condition (b)). Must be > 0. |
| max_proxy_real_gap | float | 0.10 | Absolute Hacking-Gap blowout ceiling (condition (c)). |
| min_steps | int | 20 | No tripwire fires before this many update calls. |
| decline_patience | int | 3 | Consecutive held-out declines (while in-loop rises) to fire (a). Must be >= 1. |
| ema_alpha | float | 0.9 | EMA weight on the PRIOR (0.9 ⇒ slow). Must be in [0, 1). |
| rise_eps | float | 1e-4 | Min EMA delta to count as "rising" / "declining". |
API
- update(round_idx, in_loop_reward, heldout_score, kl_to_init=None, entropy=None, reward_std=None) -> TripwireStatus: fold one checkpoint's metrics and return the verdict. round_idx is for logging only (the internal counter _n drives min_steps). kl_to_init is token-mean KL in nats/token (this repo's token_mean_kl convention); do NOT pass a sequence-summed KL (it will fire instantly). entropy / reward_std are tracked + exposed but not hard gates.
- should_halt() -> bool: True iff the most recent update fired. Idempotent (does not advance EMA state).
- last_status -> TripwireStatus | None (property): the most recent verdict, or None before the first update.
- raise_if_fired(status=None) -> None: convert a fired verdict (the passed status, or last_status) into a CollapseStopError; a no-op otherwise. For exception-based loops.
- proxy_real_gap() -> float: the RSI Hacking Gap (EMA-minus-baseline, both since run start); 0.0 before the first update.
- calibrate_kl_threshold(baseline_kls, factor=3.0) -> float: set kl_hard_stop from early-run baseline KLs (factor × mean). SAFETY CLAMP: only ever TIGHTENS (min(factor*mean, current)), never loosens past the documented band. Raises ValueError on empty baseline_kls.
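The tighten-only clamp in calibrate_kl_threshold reduces to a one-line min(). The standalone function below is a hypothetical mirror of that documented rule (the real method updates the guard's own kl_hard_stop); it is a sketch of the arithmetic only.

```python
from statistics import fmean


def tightened_kl_threshold(current: float, baseline_kls: list[float],
                           factor: float = 3.0) -> float:
    """Return min(factor * mean(baseline_kls), current); never loosens `current`."""
    if not baseline_kls:
        raise ValueError("baseline_kls must be non-empty")
    return min(factor * fmean(baseline_kls), current)


# Quiet baseline: 3 x mean(0.01, 0.02, 0.015) is about 0.045, tighter than 0.08.
tight = tightened_kl_threshold(0.08, [0.01, 0.02, 0.015])
# Noisy baseline: 3 x mean(0.05, 0.07) is about 0.18, so the clamp keeps 0.08.
unchanged = tightened_kl_threshold(0.08, [0.05, 0.07])
```

The clamp means a noisy warm-up phase cannot raise the ceiling above the default band; it can only leave the ceiling where it was.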
Raises ValueError at construction if ema_alpha ∉ [0, 1), kl_hard_stop <= 0, or decline_patience < 1.
HeldoutSplit discipline (design-of-record).
heldout_score must come from a DISJOINT held-out eval pool: REAL tasks the generator NEVER trains on. If the held-out set drifts with the train set, the proxy-real gap signal is meaningless. See ADR-015 and the safety.holdout design notes referenced from the module docstring.
from composer_replication.safety import HeldOutGuard
guard = HeldOutGuard(kl_hard_stop=0.08)
for rnd in range(num_generations):
    status = guard.update(rnd, in_loop_reward=r_proxy, heldout_score=r_real,
                          kl_to_init=token_mean_kl_value)
    if status.fire:  # or: guard.should_halt()
        guard.raise_if_fired(status)  # -> CollapseStopError
kl_token_trust_filter(logratio_sq_half, threshold=0.08) -> bool — safety.kill_switch
def kl_token_trust_filter(logratio_sq_half: float, threshold: float = 0.08) -> bool
Per-token KL trust-region mask mirroring torchrl's GRPO "KL-Mask". The caller computes 0.5 * (log π/π_ref)**2 (the Schulman k2 per-token KL estimator, nats); this returns True when the token should be MASKED OUT (its KL contribution exceeds threshold). Pulls no torch into the module — wire it into a loss downstream.
from composer_replication.safety import kl_token_trust_filter
mask_out = kl_token_trust_filter(0.5 * logratio**2, threshold=0.08)
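To show where the logratio_sq_half argument comes from, the sketch below computes the k2 estimate from per-token log-probabilities in plain Python and applies the masking rule. toy_trust_filter is a local, hypothetical mirror of the documented contract; the strict > comparison at exactly the threshold is an assumption, and the log-probabilities are made-up illustrative values.

```python
def k2_kl(logp: float, logp_ref: float) -> float:
    """Schulman k2 per-token KL estimate in nats: 0.5 * (log pi - log pi_ref)^2."""
    return 0.5 * (logp - logp_ref) ** 2


def toy_trust_filter(logratio_sq_half: float, threshold: float = 0.08) -> bool:
    """True means the token should be MASKED OUT of the loss."""
    return logratio_sq_half > threshold


logps = [-1.20, -0.30, -2.00]       # policy log-probs per token (illustrative)
ref_logps = [-1.25, -0.90, -2.10]   # reference log-probs per token (illustrative)

mask_out = [toy_trust_filter(k2_kl(p, q)) for p, q in zip(logps, ref_logps)]
kept = [p for p, m in zip(logps, mask_out) if not m]
```

Only the middle token, whose log-ratio of 0.6 gives a k2 value of about 0.18, crosses the 0.08 threshold; in a tensor-based loss the same boolean would typically be inverted and used as a multiplicative mask.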
Notes on test coverage
Tested contracts (referenced spike/test paths):
- compose_loss + LossComponents + build_batch: composer_replication/tests/test_compose_loss_integration.py, spikes/006-real-hf-model-smoke/tests/.
- generalized_jsd_loss: spikes/005-integrated-trainer-skeleton/tests/test_opsd_loss.py.
- simpo_loss, taid_loss, taid_alpha_schedule, taid_blended_logits, entropy_aware_opd_loss: composer_replication/distillation/tests/test_distillation_losses.py.
- replay_trace, extract_dpo_pairs, DPOPair, TraceState, TeacherCallResult, TeacherSpec, DEFAULT_TEACHERS: spikes/005-integrated-trainer-skeleton/tests/test_teacher_replay.py.
- DJNormalizer, NormalizedDPOPair, replay_and_normalize_trace: composer_replication/replaysim/tests/test_replaysim.py.
- ClaudeCodeIngester, IngestionStats, SYSTEM_PROMPT: spikes/007-real-trace-ingestion/tests/.
- ComposerDataCollator, CollatorConfig, TraceTurn, TraceExample: spikes/005-integrated-trainer-skeleton/tests/test_data_collator.py.
- ComposerReplicationTrainer._compute_loss (composition arithmetic): spikes/005-integrated-trainer-skeleton/tests/test_loss_composition_smoke.py.
- make_diloco_outer_loop + sign convention: spikes/008-streaming-diloco/tests/test_diloco_smoke.py.
- ObjectStoreAllReduce, MockManager, LocalProcessExecutor, ReplicaHandle, ServerlessExecutor, replica_entrypoint.main: composer_replication/diloco/serverless/tests/test_serverless_local.py, test_serverless_diloco_integration.py.
- recipes.prime_rl.composer_loss.loss_fn: composer_replication/recipes/prime_rl/tests/test_composer_loss.py.
- EKSExecutor, SageMakerExecutor (§15): composer_replication/diloco/serverless/tests/ (DI'd batch_api / core_api / sagemaker_client mocks; no live cloud).
- DockerSandbox (§16): container_kwargs lockdown config asserted without a live daemon; daemon-gated paths in composer_replication/datagen/tests/.
- HeldOutGuard, TripwireStatus, CollapseStopError, kl_token_trust_filter (§17): composer_replication/safety/tests/ (pure-Python; CPU-only).
Untested-contract symbols (⚠️) and skeletons (🟡) are flagged inline above.
Document path: docs/API_REFERENCE.md (repo-relative)