#!/usr/bin/env python3
"""
Apply generated SwiGLU MLP weights to a Gemma 4 31B safetensors model.

Layer files contain gate_proj.weight / up_proj.weight / down_proj.weight as
pre-computed delta tensors -- fused via Shape-Contoured Fusion (SCF).

SCF replaces the old naive additive delta approach:
  - down_proj : contoured multiplicative delta
                W <- W + dynamic_alpha * delta * W   (= W * (1 + dynamic_alpha * delta))
  - gate_proj : multiplicative gamma scaling (W * (1 + clamp(delta, +/-gamma_cap)))
  - up_proj   : intentionally unchanged (linear path, as in fuzer.py)

Gemma 4 31B interleaved attention: 5 SWA + 1 global per period (60 layers total).
Global layers (5, 11, 17, 23, 29, 35, 41, 47, 53, 59) may carry double-wide
MLP tensors; partial coverage is handled transparently via row/col clamping.
"""
import argparse
import json
import shutil
from pathlib import Path

import numpy as np
import torch
from safetensors.torch import load, load_file, save_file

PROJ_KEYS = ("gate_proj.weight", "up_proj.weight", "down_proj.weight")
INTERLEAVE_PERIOD = 6
GLOBAL_LAYER_OFFSET = 5


def is_global_attention_layer(layer_idx: int) -> bool:
    """Return True if *layer_idx* is a global-attention layer of the interleave.

    With INTERLEAVE_PERIOD = 6 and GLOBAL_LAYER_OFFSET = 5 this selects
    layers 5, 11, 17, ...; every other layer is treated as sliding-window.
    """
    return (
        layer_idx >= GLOBAL_LAYER_OFFSET
        and (layer_idx - GLOBAL_LAYER_OFFSET) % INTERLEAVE_PERIOD == 0
    )


def detect_key_prefix(tensor_keys, layer_idx: int, proj: str) -> str:
    """Return the prefix that precedes ``layers.<idx>.mlp.<proj>`` in *tensor_keys*.

    A key matches only when the suffix is the whole key or starts right after
    a ``.``. Without that anchor, e.g. ``model.sublayers.1.mlp.gate_proj.weight``
    would be mistaken for layer 1 and produce the bogus prefix ``model.sub``.

    Gemma 4 is a VLM: matches containing ``language_model`` are preferred over
    any other tower. If nothing matches, a default prefix is returned; the
    caller's membership check then reports the missing key.
    """
    suffix = f"layers.{layer_idx}.mlp.{proj}"
    dotted = "." + suffix
    matches = [k for k in tensor_keys if k == suffix or k.endswith(dotted)]
    for k in matches:
        if "language_model" in k:
            return k[: -len(suffix)]
    if matches:
        return matches[0][: -len(suffix)]
    return "model.language_model.model."
# Projections that SCF actually rewrites. up_proj is deliberately excluded:
# its deltas may be present in layer files but are never applied. Shards
# holding only up_proj tensors therefore never need to be loaded or rewritten.
_SCF_TARGETS = ("gate_proj.weight", "down_proj.weight")

# Human-readable label per SCF target, used in progress messages.
_SCF_LABELS = {"gate_proj.weight": "gate*gamma", "down_proj.weight": "down contoured"}

# Header metadata written into every rewritten safetensors file.
# save_file() writes no metadata unless it is passed explicitly. Hugging Face
# checkpoints carry {"format": "pt"}, and transformers inspects that entry.
_SAFETENSORS_METADATA = {"format": "pt"}


def discover_generated_layers(weights_dir: Path) -> dict:
    """Return ``{layer_idx: path}`` for every ``layer_<idx>*.safetensors`` file.

    The index is the second ``_``-separated field of the file stem
    (``layer_05.safetensors`` -> 5). Files whose index does not parse as an
    integer are ignored. When two files map to the same index (e.g.
    ``layer_5`` and ``layer_05``), the one that sorts last wins and a warning
    is printed.
    """
    layers = {}
    for f in sorted(weights_dir.glob("layer_*.safetensors")):
        try:
            idx = int(f.stem.split("_")[1])
        except (IndexError, ValueError):
            continue
        if idx in layers:
            print(f"[warn] Duplicate layer index {idx}: {layers[idx].name} "
                  f"overridden by {f.name}")
        layers[idx] = f
    return layers


# ---------------------------------------------------------------------------
# Shape-Contoured Fusion applied to pre-computed delta tensors
# ---------------------------------------------------------------------------

def _overlap(delta: torch.Tensor, weight: torch.Tensor) -> tuple:
    """Return ``(rows, cols)`` of the top-left region shared by two 2-D tensors."""
    return min(delta.shape[0], weight.shape[0]), min(delta.shape[1], weight.shape[1])


def _report_coverage(layer_idx: int, name: str, delta: torch.Tensor,
                     weight: torch.Tensor, nr: int, nc: int) -> None:
    """Warn when a delta covers only part of its weight, or overhangs it."""
    if nr < weight.shape[0] or nc < weight.shape[1]:
        print(f"  [warn] Layer {layer_idx}: {name} delta covers "
              f"{nr}x{nc} of {weight.shape[0]}x{weight.shape[1]} -- partial fusion")
    if nr < delta.shape[0] or nc < delta.shape[1]:
        print(f"  [warn] Layer {layer_idx}: {name} delta is "
              f"{delta.shape[0]}x{delta.shape[1]} but weight is "
              f"{weight.shape[0]}x{weight.shape[1]} -- excess delta ignored")


def fuse_layer_deltas(
    layer_idx: int,
    gate_w: torch.Tensor,    # float32, modified in-place
    up_w: torch.Tensor,      # float32, intentionally NOT modified
    down_w: torch.Tensor,    # float32, modified in-place
    new_weights: dict,
    args: argparse.Namespace,
) -> None:
    """
    Apply SCF to one layer using pre-computed delta tensors.

    Each projection is processed only if its key is present in *new_weights*.
    Absent projections are never read, so callers may pass placeholder
    tensors for them. Deltas and weights are clamped to their shared
    top-left region.

    down_proj -- contoured multiplicative update:
        W <- W + dynamic_alpha * delta * W
        The update follows the existing weight profile. dynamic_alpha is
        args.alpha scaled by var(W) / (1 / fan_in), clipped to
        [0.1, 10] * args.alpha, so the step size tracks each layer's
        weight scale.

    gate_proj -- multiplicative gamma:
        W <- W * (1 + clamp(delta, -gamma_cap, +gamma_cap))
        This matches fuzer's W*gamma pattern without needing raw adapter weights.

    up_proj -- unchanged:
        The linear value path in SwiGLU must not receive non-linear scaling.
        This is intentional and mirrors fuzer's explicit decision.
    """
    if "down_proj.weight" in new_weights:
        delta_down = new_weights["down_proj.weight"].float()
        nr, nc = _overlap(delta_down, down_w)
        fan_in = down_w.shape[1]
        expected_var = 1.0 / fan_in
        down_var = down_w[:nr, :nc].var().item()
        dynamic_alpha = float(np.clip(
            args.alpha * (down_var / (expected_var + 1e-8)),
            args.alpha * 0.1,
            args.alpha * 10.0,
        ))
        contoured = dynamic_alpha * delta_down[:nr, :nc] * down_w[:nr, :nc]
        down_w[:nr, :nc] = down_w[:nr, :nc] + contoured
        _report_coverage(layer_idx, "down_proj", delta_down, down_w, nr, nc)

    if "gate_proj.weight" in new_weights:
        delta_gate = new_weights["gate_proj.weight"].float()
        nr, nc = _overlap(delta_gate, gate_w)
        gamma = 1.0 + delta_gate[:nr, :nc].clamp(-args.gamma_cap, args.gamma_cap)
        gate_w[:nr, :nc] = gate_w[:nr, :nc] * gamma
        _report_coverage(layer_idx, "gate_proj", delta_gate, gate_w, nr, nc)

    # up_proj: intentionally untouched -- linear path must stay unchanged


def _copy_aux_files(src_dir: Path, dst_dir: Path, exclude=frozenset()) -> None:
    """Copy every entry of *src_dir* into *dst_dir*, except names in *exclude*.

    Raises:
        ValueError: if both paths are the same directory. Copying each file
            onto itself would otherwise fail half-way with shutil.SameFileError.
    """
    dst_resolved = dst_dir.resolve()
    if src_dir.resolve() == dst_resolved:
        raise ValueError(f"--output must differ from the model directory: {src_dir}")
    for entry in src_dir.iterdir():
        if entry.name in exclude:
            continue
        if entry.resolve() == dst_resolved:
            # Output dir nested inside the model dir: never copy it into itself.
            continue
        target = dst_dir / entry.name
        if entry.is_dir():
            shutil.copytree(entry, target, dirs_exist_ok=True)
        else:
            shutil.copy2(entry, target)


# ---------------------------------------------------------------------------
# Single-file apply
# ---------------------------------------------------------------------------

def apply_single_file(model_path: Path, output_dir: Path, layer_files: dict, args) -> int:
    """Fuse every layer in *layer_files* into one ``.safetensors`` model file.

    Only the projections SCF rewrites (gate_proj, down_proj) must exist in
    the model. A layer whose delta file has neither, or whose target key is
    missing from the model, is skipped. Unless ``args.dry_run`` is set, the
    result is written to ``output_dir / model_path.name``.

    Returns:
        Number of layers fused (or that would be fused, in a dry run).

    Raises:
        RuntimeError: if every layer was skipped.
    """
    dry_run = args.dry_run
    print(f"\n[model] Processing file: {model_path.name}")

    # NOTE(review): whether load_file is lazily memory-mapped depends on the
    # installed safetensors version -- confirm before relying on it for RAM.
    tensors = load_file(str(model_path))

    fused = 0
    skipped = 0
    for layer_idx, layer_path in sorted(layer_files.items()):
        layer_type = "global" if is_global_attention_layer(layer_idx) else "swa"
        new_weights = load_file(str(layer_path))

        targets = [proj for proj in _SCF_TARGETS if proj in new_weights]
        if not targets:
            print(f"  [skip] Layer {layer_idx}: none of {_SCF_TARGETS} found "
                  f"(up_proj is never modified). Got: {list(new_weights.keys())}")
            skipped += 1
            continue

        model_keys = {}
        for proj in targets:
            prefix = detect_key_prefix(tensors.keys(), layer_idx, proj)
            model_key = f"{prefix}layers.{layer_idx}.mlp.{proj}"
            if model_key not in tensors:
                print(f"  [skip] Key not found in model: {model_key!r}")
                break
            model_keys[proj] = model_key
        if len(model_keys) != len(targets):
            skipped += 1
            continue

        if not dry_run:
            # A single float32 copy per projection. `.clone().float()` would
            # allocate twice for bf16 weights.
            work = {
                proj: tensors[key].to(torch.float32, copy=True)
                for proj, key in model_keys.items()
            }
            # Placeholder for projections absent from this delta file.
            # fuse_layer_deltas never reads them, because it gates each
            # projection on its presence in new_weights.
            empty = torch.empty(0)
            fuse_layer_deltas(
                layer_idx,
                work.get("gate_proj.weight", empty),
                empty,
                work.get("down_proj.weight", empty),
                new_weights,
                args,
            )
            for proj, key in model_keys.items():
                tensors[key] = work[proj].to(tensors[key].dtype)

        fused += 1
        applied = " + ".join(_SCF_LABELS[proj] for proj in targets)
        print(f"  {'[dry]' if dry_run else '[ok]'} Fused layer {layer_idx:02d} [{layer_type}]"
              f"  {applied} (up unchanged)")

    if skipped > 0 and fused == 0:
        raise RuntimeError(
            f"No layers were fused -- all {skipped} layer(s) were skipped.\n"
            f"Sample model keys: {list(tensors.keys())[:4]}"
        )
    if skipped > 0:
        print(f"  [warn] {skipped} layer(s) skipped, {fused} fused.")

    if not dry_run:
        out_path = output_dir / model_path.name
        save_file(tensors, str(out_path), metadata=_SAFETENSORS_METADATA)
        print(f"  Saved -> {out_path.resolve()}")

    return fused


# ---------------------------------------------------------------------------
# Sharded apply
# ---------------------------------------------------------------------------

def apply_sharded(model_dir: Path, output_dir: Path, layer_files: dict, args) -> int:
    """Fuse layer deltas into a sharded checkpoint described by its index JSON.

    The fusion plan is built per projection. A layer whose gate_proj and
    down_proj live in different shards is fused piecewise, once per shard,
    instead of being skipped. Only shards that hold an SCF target are loaded
    and rewritten. Every other file in *model_dir* is copied verbatim. In a
    dry run no shard is read and nothing is written.

    Returns:
        Number of unique layers fused (or that would be fused, in a dry run).

    Raises:
        FileNotFoundError: if ``model.safetensors.index.json`` is missing.
        RuntimeError: if no SCF target of any layer is present in the index.
    """
    dry_run = args.dry_run
    index_path = model_dir / "model.safetensors.index.json"
    if not index_path.exists():
        raise FileNotFoundError(f"Sharded index missing: {index_path}")
    with open(index_path, encoding="utf-8") as f:
        index = json.load(f)
    weight_map = index["weight_map"]

    # shard name -> [(layer_idx, proj, model_key, delta_tensor, layer_type), ...]
    fusion_plan: dict = {}
    skipped = 0

    for layer_idx, layer_path in sorted(layer_files.items()):
        layer_type = "global" if is_global_attention_layer(layer_idx) else "swa"
        new_weights = load_file(str(layer_path))

        targets = [proj for proj in _SCF_TARGETS if proj in new_weights]
        if not targets:
            print(f"  [skip] Layer {layer_idx}: none of {_SCF_TARGETS} found "
                  f"(up_proj is never modified). Got: {list(new_weights.keys())}")
            skipped += 1
            continue

        proj_registered = 0
        for proj in targets:
            prefix = detect_key_prefix(weight_map.keys(), layer_idx, proj)
            model_key = f"{prefix}layers.{layer_idx}.mlp.{proj}"
            if model_key not in weight_map:
                print(f"  [skip] Layer {layer_idx}: {model_key!r} not in weight_map")
                continue
            fusion_plan.setdefault(weight_map[model_key], []).append(
                (layer_idx, proj, model_key, new_weights[proj], layer_type)
            )
            proj_registered += 1

        if proj_registered == 0:
            skipped += 1

    if not fusion_plan:
        sample = list(weight_map.keys())[:6]
        raise RuntimeError(
            f"No layers matched in weight_map. Sample keys: {sample}"
        )

    if not dry_run:
        output_dir.mkdir(parents=True, exist_ok=True)
        # Copy config, tokenizer, index and all untouched shards verbatim.
        # The shards named in fusion_plan are written by save_file below.
        _copy_aux_files(model_dir, output_dir, exclude=set(fusion_plan))

    fused_layer_idxs: set = set()

    for shard_name, ops in sorted(fusion_plan.items()):
        # Regroup by layer so fuse_layer_deltas runs once per layer per shard.
        by_layer: dict = {}
        for layer_idx, proj, model_key, delta, layer_type in ops:
            by_layer.setdefault(layer_idx, []).append((proj, model_key, delta, layer_type))

        # A dry run only reports the plan; reading the shard would be wasted I/O.
        tensors = None if dry_run else load_file(str(model_dir / shard_name))

        for layer_idx, proj_ops in sorted(by_layer.items()):
            layer_type = proj_ops[0][3]

            if not dry_run:
                # Deltas restricted to the projections stored in this shard.
                partial_new_weights = {proj: delta for proj, _, delta, _ in proj_ops}
                work = {
                    proj: tensors[model_key].to(torch.float32, copy=True)
                    for proj, model_key, _, _ in proj_ops
                }
                # Placeholder for projections stored in another shard.
                # fuse_layer_deltas never reads them, because their keys are
                # absent from partial_new_weights.
                empty = torch.empty(0)
                fuse_layer_deltas(
                    layer_idx,
                    work.get("gate_proj.weight", empty),
                    empty,
                    work.get("down_proj.weight", empty),
                    partial_new_weights,
                    args,
                )
                for proj, model_key, _, _ in proj_ops:
                    tensors[model_key] = work[proj].to(tensors[model_key].dtype)

            fused_layer_idxs.add(layer_idx)
            proj_names = [p.split(".")[0] for p, *_ in proj_ops]
            print(f"  {'[dry]' if dry_run else '[ok]'} Fused layer {layer_idx:02d} [{layer_type}]"
                  f"  ({', '.join(proj_names)} in this shard)")

        if not dry_run:
            save_file(tensors, str(output_dir / shard_name), metadata=_SAFETENSORS_METADATA)
            print(f"  [ok] Saved shard {shard_name} ({len(by_layer)} layer(s))")

        del tensors  # free RAM before loading the next shard

    if skipped > 0:
        print(f"  [warn] {skipped} layer(s) fully skipped, "
              f"{len(fused_layer_idxs)} unique layer(s) fused.")

    return len(fused_layer_idxs)


# ---------------------------------------------------------------------------
# Entry point
# ---------------------------------------------------------------------------

def main():
    """CLI entry point: parse arguments, collect layer deltas, dispatch fusion.

    --model may be a single ``.safetensors`` file, a directory containing
    ``model.safetensors``, or a sharded directory containing
    ``model.safetensors.index.json``.
    """
    parser = argparse.ArgumentParser(
        description="Apply delta weights to a model via Shape-Contoured Fusion."
    )
    parser.add_argument("--model", required=True)
    parser.add_argument("--weights", required=True)
    parser.add_argument("--output", required=True)
    parser.add_argument("--layers", type=int, nargs="+", default=None)
    parser.add_argument("--dry-run", action="store_true")
    parser.add_argument("--alpha", type=float, default=0.02,
                        help="down-proj variance scale multiplier (default: 0.02)")
    parser.add_argument("--gamma-cap", type=float, default=0.05,
                        help="max fractional gate_proj adjustment (default: 0.05)")
    args = parser.parse_args()

    if args.gamma_cap < 0:
        # torch.clamp with min > max sets every element to max, so a negative
        # cap would silently collapse all gate deltas to a single value.
        parser.error("--gamma-cap must be non-negative")

    model_path = Path(args.model)
    weights_dir = Path(args.weights)
    output_dir = Path(args.output)

    all_layer_files = discover_generated_layers(weights_dir)
    if not all_layer_files:
        raise FileNotFoundError(
            f"No layer_*.safetensors files found in: {weights_dir.resolve()}"
        )

    layer_files = all_layer_files
    if args.layers is not None:
        missing = sorted(set(args.layers) - set(all_layer_files))
        if missing:
            print(f"[warn] --layers requested but no delta file found for: {missing}")
        layer_files = {i: all_layer_files[i] for i in args.layers if i in all_layer_files}
        if not layer_files:
            raise ValueError(f"--layers filter empty. Available: {sorted(all_layer_files)}")

    print(f"[info] Found {len(layer_files)} layer file(s): indices {sorted(layer_files.keys())}")
    print(f"[info] SCF params: alpha={args.alpha}, gamma_cap={args.gamma_cap}")

    if not args.dry_run:
        output_dir.mkdir(parents=True, exist_ok=True)

    if model_path.is_file():
        if model_path.suffix != ".safetensors":
            raise ValueError(f"--model file is not a .safetensors file: {model_path}")
        apply_single_file(model_path, output_dir, layer_files, args)

    elif model_path.is_dir():
        single = model_path / "model.safetensors"
        index = model_path / "model.safetensors.index.json"

        if single.exists() and not index.exists():
            if not args.dry_run:
                _copy_aux_files(model_path, output_dir, exclude={"model.safetensors"})
            apply_single_file(single, output_dir, layer_files, args)
        elif index.exists():
            apply_sharded(model_path, output_dir, layer_files, args)
        else:
            raise FileNotFoundError(
                f"No model.safetensors or model.safetensors.index.json in {model_path}"
            )
    else:
        raise FileNotFoundError(f"--model not found: {model_path}")

    config_path = (
        model_path / "config.json"
        if model_path.is_dir()
        else model_path.parent / "config.json"
    )
    config_dst = output_dir / "config.json"
    # Skip the copy when source and destination are the same file: copy2
    # would raise SameFileError after all fusion work had already finished.
    if (config_path.exists() and not args.dry_run
            and config_path.resolve() != config_dst.resolve()):
        shutil.copy2(config_path, config_dst)
        print("  [ok] Copied config.json (activation unchanged).")


if __name__ == "__main__":
    main()