Spaces:
Running
Running
Upload 8 files
Browse files- __init__.py +36 -0
- donor_audit_v3.py +245 -0
- eos_scanner.py +17 -1
- llama.py +59 -0
- model_tools.md +14 -1
- moe_defs.py +197 -0
- tokeninspector.py +135 -0
- tokensurgeon.py +867 -0
__init__.py
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import List

from mergekit.moe.arch import MoEOutputArchitecture
from mergekit.moe.deepseek import DeepseekMoE
from mergekit.moe.mixtral import MixtralMoE

# Output architectures that are always available.
ALL_OUTPUT_ARCHITECTURES: List[MoEOutputArchitecture] = [MixtralMoE(), DeepseekMoE()]

# Optional architectures below are registered only if their module imports
# cleanly; an ImportError silently skips the registration.
try:
    from mergekit.moe.qwen import QwenMoE
except ImportError:
    pass
else:
    ALL_OUTPUT_ARCHITECTURES.append(QwenMoE())

try:
    from mergekit.moe.qwen3 import Qwen3MoE
except ImportError:
    pass
else:
    ALL_OUTPUT_ARCHITECTURES.append(Qwen3MoE())

# Custom LlamaMoE support (see llama.py in this package).
try:
    from mergekit.moe.llama import LlamaMoE
except ImportError:
    # This will trigger if llama.py is missing or has a syntax error
    pass
else:
    ALL_OUTPUT_ARCHITECTURES.append(LlamaMoE())

__all__ = [
    "ALL_OUTPUT_ARCHITECTURES",
    "MoEOutputArchitecture",
]
|
donor_audit_v3.py
ADDED
|
@@ -0,0 +1,245 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (C) 2025 Arcee AI & Kraken Architect
|
| 2 |
+
# SPDX-License-Identifier: BUSL-1.1
|
| 3 |
+
|
| 4 |
+
import logging
|
| 5 |
+
import os
|
| 6 |
+
import sys
|
| 7 |
+
from typing import List, Optional
|
| 8 |
+
|
| 9 |
+
import click
|
| 10 |
+
import torch
|
| 11 |
+
import yaml
|
| 12 |
+
from tqdm import tqdm
|
| 13 |
+
|
| 14 |
+
from mergekit.common import ModelReference
|
| 15 |
+
from mergekit.config import MergeConfiguration
|
| 16 |
+
from mergekit.io.lazy_tensor_loader import LazyTensorLoader, ShardedTensorIndex
|
| 17 |
+
from mergekit.merge_methods.easy_define import merge_method
|
| 18 |
+
|
| 19 |
+
logging.basicConfig(level=logging.INFO, format="%(message)s")
|
| 20 |
+
LOG = logging.getLogger("donor_audit")
|
| 21 |
+
|
| 22 |
+
@merge_method(
    name="donor_audit",
    pretty_name="Donor Audit",
    reference_url="https://arxiv.org/abs/2408.07990",
)
def _donor_audit_registration(tensors: List[torch.Tensor]) -> torch.Tensor:
    """No-op merge that only registers the ``donor_audit`` method name.

    The real analysis lives in ``main``; mergekit just needs a callable
    bound to the name. Passes the first input tensor through unchanged.
    """
    return tensors[0]
|
| 30 |
+
|
| 31 |
+
def rsce_weight(tvs: torch.Tensor) -> torch.Tensor:
    """Derive per-matrix mixing weights from task-vector energy.

    Each slice along dim 0 is scored by its mean squared magnitude, then the
    scores are normalized to sum to one. When the total energy is numerically
    zero, falls back to a uniform distribution. (Copied from RSCE v3.)
    """
    reduce_dims = list(range(1, tvs.dim()))
    # Mean square energy per task vector.
    energies = (tvs * tvs).mean(dim=reduce_dims)
    total = energies.sum().item()
    if abs(total) < 1e-8:
        # No signal anywhere: split influence evenly.
        return torch.ones_like(energies) / energies.shape[0]
    return energies / total
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def log_rsce_audit(layer_name: str, weights: torch.Tensor, names: List[str]):
    """Print a per-donor influence bar chart and append it to rsce_audit.log.

    Args:
        layer_name: Name of the tensor the weights were computed from.
        weights: 1-D tensor of normalized influence weights (sums to ~1).
        names: Display name for each weight, in matching order.
    """
    w_list = weights.tolist()
    bar_char = "█"

    # Header
    print(f"\n{'='*60}")
    print("RSCE DONOR AUDIT REPORT")
    print(f"Target Tensor: {layer_name}")
    print(f"{'='*60}")

    # Bars are scaled relative to the largest weight so the loudest donor
    # fills the full 40-character width. FIX: this value is loop-invariant
    # and was previously recomputed on every iteration; `default=` keeps
    # the empty-input behavior (no max() over an empty list).
    peak = max(w_list, default=0.0)
    max_val = peak if peak > 0 else 1.0

    lines = []
    for name, w in zip(names, w_list):
        pct = w * 100

        # Relative bar length (relative to the loudest model).
        bar = bar_char * int((w / max_val) * 40)

        # Truncate long model names for clean column alignment.
        clean_name = os.path.basename(name)
        if len(clean_name) > 30:
            clean_name = clean_name[:27] + "..."

        lines.append(f"{clean_name:<30} | {bar:<40} | {pct:6.2f}% (Raw: {w:.4f})")

    log_entry = "\n".join(lines)
    print(log_entry)
    print(f"{'='*60}\n")

    # Append to file so audits from multiple runs accumulate in one place.
    with open("rsce_audit.log", "a", encoding="utf-8") as f:
        f.write(f"\n[Audit {layer_name}]\n" + log_entry + "\n")
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def find_layer0_tensor(loader: LazyTensorLoader) -> str:
    """
    Scans a model loader to find a suitable Layer 0 tensor for auditing.

    Prefers dense MLP / attention projections — down_proj, then gate_proj,
    then c_attn (GPT-NeoX / Qwen) — as they are usually dense and
    representative. (FIX: the previous docstring claimed self_attn/q_proj
    priority, which the code never implemented.)

    Raises:
        RuntimeError: if no layer-0 weight tensor exists in the index.
    """
    # Collect every layer-0 weight tensor; bias tensors are skipped.
    candidates = [
        key
        for key in loader.index.tensor_paths.keys()
        if key.endswith(".weight")
        and (".layers.0." in key or ".h.0." in key or ".blocks.0." in key)
    ]

    if not candidates:
        raise RuntimeError("Could not find any Layer 0 weights in the base model.")

    # Data-driven priority scan; first marker with a hit wins.
    for marker in ("down_proj", "gate_proj", "c_attn"):
        for candidate in candidates:
            if marker in candidate:
                return candidate

    # Fallback: any layer-0 weight is better than nothing.
    return candidates[0]
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
def load_tensor_safe(model_path: str, tensor_name: str, device="cpu") -> torch.Tensor:
    """Load one tensor from a model file or directory, exiting the script on failure."""
    try:
        # Build the shard index directly; LoaderCache would only add caching
        # overhead for this one-shot script.
        if os.path.isfile(model_path):
            index = ShardedTensorIndex.from_file(model_path)
        else:
            index = ShardedTensorIndex.from_disk(model_path)
        loader = LazyTensorLoader(index, lazy_unpickle=True)

        # Basic fuzzy fallback for naming mismatches between architectures:
        # if the exact name is absent, match on the per-layer suffix.
        if tensor_name not in index.tensor_paths:
            suffix = tensor_name.split("layers.0.")[-1]
            for key in index.tensor_paths.keys():
                if key.endswith(suffix) and ("layers.0." in key or "h.0." in key):
                    tensor_name = key
                    break

        tensor = loader.get_tensor(tensor_name, device=device)
        # float32 keeps the downstream energy math stable.
        return tensor.float()
    except Exception as e:
        LOG.error(f"Failed to load {tensor_name} from {model_path}: {e}")
        sys.exit(1)
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
@click.command()
@click.argument("config_file", type=click.Path(exists=True))
@click.option("--lora-merge-cache", default=None, help="Cache directory for merged LoRAs")
@click.option("--cuda/--no-cuda", default=False, help="Use GPU for calculation (faster math, higher VRAM)")
def main(config_file, lora_merge_cache, cuda):
    """
    RSCE Donor Audit Tool V3.

    Loads Layer 0 from all models in the config and calculates their
    Task Vector magnitude/energy contribution relative to the base model.
    """
    device = "cuda" if cuda and torch.cuda.is_available() else "cpu"
    LOG.info(f"Running audit on {device}...")

    # 1. Parse Config
    with open(config_file, "r", encoding="utf-8") as f:
        config_data = yaml.safe_load(f)
    config = MergeConfiguration.model_validate(config_data)

    # 2. Identify Models
    base_model_ref = config.base_model
    if not base_model_ref:
        LOG.error("Config must specify a `base_model` for RSCE auditing.")
        sys.exit(1)

    # Extract donor models from the flat models list or from slices.
    donor_refs = []
    if config.models:
        donor_refs = [m.model for m in config.models]
    elif config.slices:
        # Flatten slices, preserving first-seen order and deduplicating.
        seen = set()
        for s in config.slices:
            for source in s.sources:
                if source.model != base_model_ref and source.model not in seen:
                    donor_refs.append(source.model)
                    seen.add(source.model)

    # Filter out the base model if it appeared among the donors.
    donor_refs = [d for d in donor_refs if d != base_model_ref]

    LOG.info(f"Base Model: {base_model_ref.model.path}")
    LOG.info(f"Found {len(donor_refs)} donor models.")

    # 3. Resolve Paths (handle LoRAs, download missing HF repos)
    def resolve_path(ref: ModelReference):
        if ref.lora:
            if not lora_merge_cache:
                LOG.warning("LoRA detected but --lora-merge-cache not set. This might fail.")
            return ref.merged(cache_dir=lora_merge_cache).model.path

        if not os.path.exists(ref.model.path):
            try:
                from huggingface_hub import snapshot_download
                return snapshot_download(ref.model.path, allow_patterns=["*.safetensors", "*.bin", "*.json"])
            except Exception:
                # FIX: was a bare `except:`, which also swallowed
                # KeyboardInterrupt/SystemExit. Fall back to the raw path
                # and let the tensor loader surface the real error.
                return ref.model.path
        return ref.model.path

    base_path = resolve_path(base_model_ref)
    donor_paths = [resolve_path(d) for d in donor_refs]

    # 4. Identify Target Tensor (Layer 0)
    base_index = ShardedTensorIndex.from_disk(base_path)
    base_loader = LazyTensorLoader(base_index, lazy_unpickle=True)
    target_tensor_name = find_layer0_tensor(base_loader)

    LOG.info(f"Selected audit tensor: {target_tensor_name}")
    LOG.info("Loading tensors into memory...")

    # 5. Load All Tensors
    base_tensor = load_tensor_safe(base_path, target_tensor_name, device)

    donor_tensors = []
    valid_donor_refs = []

    for d_path, d_ref in zip(tqdm(donor_paths, desc="Loading Donors"), donor_refs):
        dt = load_tensor_safe(d_path, target_tensor_name, device)

        # V3: catch shape mismatches (e.g. a 7B model mixed into a 12B merge).
        if dt.shape != base_tensor.shape:
            LOG.warning(f"\n[!] Shape mismatch for {d_ref.model.path}: expected {base_tensor.shape}, got {dt.shape}. Skipping this model.")
            continue

        donor_tensors.append(dt)
        valid_donor_refs.append(d_ref)

    if not donor_tensors:
        LOG.error("No valid donor tensors found with matching shapes. Exiting.")
        sys.exit(1)

    # 6. Perform RSCE Audit Math
    LOG.info("Calculating Task Vector Energy...")

    # The base model's task vector is all-zero by definition; it acts as an
    # anchor so donor energies are reported relative to "no change".
    base_tv = torch.zeros_like(base_tensor)
    donor_tvs = [dt - base_tensor for dt in donor_tensors]

    all_tvs = torch.stack([base_tv] + donor_tvs, dim=0)
    raw_weights = rsce_weight(all_tvs)

    display_names = ["Base Model (Anchor)"] + [d.model.path for d in valid_donor_refs]

    # 7. Output
    log_rsce_audit(target_tensor_name, raw_weights, display_names)

    LOG.info("Audit complete.")


if __name__ == "__main__":
    main()
|
eos_scanner.py
CHANGED
|
@@ -31,9 +31,25 @@ def load_json(path):
|
|
| 31 |
return None
|
| 32 |
|
| 33 |
def get_model_metadata(model_path):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 34 |
data = {
|
| 35 |
"path": model_path,
|
| 36 |
-
"name":
|
| 37 |
"gen_eos_id": "MISSING", # From generation_config.json
|
| 38 |
"tok_eos_str": "MISSING", # From tokenizer_config.json
|
| 39 |
"vocab_eos_id": "MISSING", # The actual ID of the string in tokenizer.json
|
|
|
|
| 31 |
return None
|
| 32 |
|
| 33 |
def get_model_metadata(model_path):
|
| 34 |
+
# --- NAME FIX LOGIC START ---
|
| 35 |
+
# Normalize path to handle trailing slashes or mixed separators
|
| 36 |
+
norm_path = os.path.normpath(model_path)
|
| 37 |
+
base_name = os.path.basename(norm_path)
|
| 38 |
+
|
| 39 |
+
# If the folder is named "fixed", grab the parent folder name instead
|
| 40 |
+
if base_name == "fixed":
|
| 41 |
+
parent_name = os.path.basename(os.path.dirname(norm_path))
|
| 42 |
+
display_name = f"{parent_name}/fixed"
|
| 43 |
+
else:
|
| 44 |
+
display_name = base_name
|
| 45 |
+
|
| 46 |
+
# Clean up the huggingface cache prefix
|
| 47 |
+
display_name = display_name.replace("!models--", "")
|
| 48 |
+
# --- NAME FIX LOGIC END ---
|
| 49 |
+
|
| 50 |
data = {
|
| 51 |
"path": model_path,
|
| 52 |
+
"name": display_name,
|
| 53 |
"gen_eos_id": "MISSING", # From generation_config.json
|
| 54 |
"tok_eos_str": "MISSING", # From tokenizer_config.json
|
| 55 |
"vocab_eos_id": "MISSING", # The actual ID of the string in tokenizer.json
|
llama.py
ADDED
|
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import tqdm
|
| 3 |
+
import transformers
|
| 4 |
+
from mergekit.moe.arch import MoEOutputArchitecture
|
| 5 |
+
from mergekit.moe.common import copy_tensor_out, initialize_io, select_dtype
|
| 6 |
+
from mergekit.moe.config import MoEMergeConfig
|
| 7 |
+
from mergekit.options import MergeOptions
|
| 8 |
+
from mergekit.architecture import arch_info_for_config
|
| 9 |
+
|
| 10 |
+
class LlamaMoE(MoEOutputArchitecture):
    """Assembles a Mixtral-format MoE checkpoint from dense Llama experts.

    Each donor's dense MLP (gate/up/down) becomes one expert's w1/w3/w2;
    attention and norm weights come from the base model.
    """

    def name(self) -> str:
        return "LlamaMoE"

    def supports_config(self, config: MoEMergeConfig, explain: bool = False, trust_remote_code: bool = False) -> bool:
        """Return True only when the base model is a plain Llama model."""
        model_cfg = config.base_model.config(trust_remote_code=trust_remote_code)
        if model_cfg.model_type != "llama":
            if explain:
                print("LlamaMoE only supports Llama base models")
            return False
        return True

    def write_model(self, out_path: str, config: MoEMergeConfig, merge_options: MergeOptions, router_weights: list[torch.Tensor], shared_router_weights=None):
        """Write the merged MoE checkpoint and its config to `out_path`.

        Args:
            out_path: Output directory for the new model.
            config: MoE merge configuration (base model + experts).
            merge_options: Global mergekit options (dtype, trust_remote_code, ...).
            router_weights: One gate/router weight tensor per layer.
            shared_router_weights: Unused; kept for interface compatibility.
        """
        import json
        import os

        base_model = config.base_model
        base_cfg = base_model.config(trust_remote_code=merge_options.trust_remote_code)

        # 1. Generate the config.json
        out_cfg = base_cfg.to_dict()
        # Note: Most Llama MoEs use the Mixtral architecture name for compatibility with loaders
        out_cfg["architectures"] = ["MixtralForCausalLM"]
        out_cfg["num_local_experts"] = len(config.experts)
        out_cfg["num_experts_per_tok"] = config.experts_per_token

        out_dtype = select_dtype(config, base_cfg)

        # 2. Initialize IO
        loaders, base_loader, writer = initialize_io(config, out_path, merge_options)

        # BUG FIX: out_cfg was previously built but never written, leaving the
        # output model without a config.json. Persist it now.
        os.makedirs(out_path, exist_ok=True)
        with open(os.path.join(out_path, "config.json"), "w", encoding="utf-8") as f:
            json.dump(out_cfg, f, indent=2)

        # 3. Map Tensors
        for weight_info in tqdm.tqdm(arch_info_for_config(base_cfg).all_weights(base_cfg), desc="Weights"):
            tensor_name = weight_info.name
            if ".mlp." in tensor_name:
                for expert_idx, expert in enumerate(config.experts):
                    # Map Llama's gate_proj/up_proj/down_proj to Mixtral's w1/w3/w2
                    expert_name = tensor_name.replace(".mlp.gate_proj", f".block_sparse_moe.experts.{expert_idx}.w1")
                    expert_name = expert_name.replace(".mlp.down_proj", f".block_sparse_moe.experts.{expert_idx}.w2")
                    expert_name = expert_name.replace(".mlp.up_proj", f".block_sparse_moe.experts.{expert_idx}.w3")

                    expert_loader = loaders.get(expert.source_model)
                    copy_tensor_out(weight_info, expert_loader, writer, expert=expert, output_name=expert_name, out_dtype=out_dtype)
            else:
                # Copy Attention and Norms from base model
                copy_tensor_out(weight_info, base_loader, writer, out_dtype=out_dtype)

        # 4. Write Router Weights
        for layer_idx, weight in enumerate(router_weights):
            writer.save_tensor(f"model.layers.{layer_idx}.block_sparse_moe.gate.weight", weight.to(dtype=out_dtype))

        writer.finalize()
|
model_tools.md
CHANGED
|
@@ -32,8 +32,21 @@ Tools to enhance LLM quantizations and merging
|
|
| 32 |
# [metadata_audit.py](https://huggingface.co/spaces/Naphula/model_tools/blob/main/metadata_audit.py)
|
| 33 |
- Checks multiple models within subdirectories for vocab or rope mismatch (useful for large merges). Calibrated for Mistral Nemo 12B by default.
|
| 34 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 35 |
# [eos_scanner.py](https://huggingface.co/spaces/Naphula/model_tools/blob/main/eos_scanner.py)
|
| 36 |
-
- This tool scans the tokenizer jsons to detect any mismatches with EOS tokens, which cause early termination bugs. You can then use the [gen_id_patcher.py](https://huggingface.co/spaces/Naphula/model_tools/blob/main/gen_id_patcher.py) to patch missing `generation_config.json` files for EOS token. See [this post](https://huggingface.co/Naphula/Q0_Bench/discussions/1?not-for-all-audiences=true#6987717c762f0a45f672e250) as well as the [EOS Scanner ReadMe](https://huggingface.co/spaces/Naphula/model_tools/blob/main/eos_scanner_readme.md) for more info.
|
| 37 |
|
| 38 |
# [weight_counter.py](https://huggingface.co/spaces/Naphula/model_tools/blob/main/weight_counter.py)
|
| 39 |
- This counts the number of models in a yaml and adds up the total weight values. Useful for large della/ties merges.
|
|
|
|
| 32 |
# [metadata_audit.py](https://huggingface.co/spaces/Naphula/model_tools/blob/main/metadata_audit.py)
|
| 33 |
- Checks multiple models within subdirectories for vocab or rope mismatch (useful for large merges). Calibrated for Mistral Nemo 12B by default.
|
| 34 |
|
| 35 |
+
# llama moe
|
| 36 |
+
- Add support for Llama Mixture of Experts. If you want to merge custom Llama MoE you can add these scripts to your mergekit environment:
|
| 37 |
+
- [mergekit-main\mergekit\architecture\moe_defs.py](https://huggingface.co/spaces/Naphula/model_tools/blob/main/moe_defs.py)
|
| 38 |
+
- [mergekit-main\mergekit\__init__.py](https://huggingface.co/spaces/Naphula/model_tools/blob/main/__init__.py)
|
| 39 |
+
- [mergekit-main\mergekit\moe\llama.py](https://huggingface.co/spaces/Naphula/model_tools/blob/main/llama.py)
|
| 40 |
+
- Then assign the num_experts_per_tok in config.json (or the config.yaml)
|
| 41 |
+
|
| 42 |
+
# [tokensurgeon.py](https://huggingface.co/spaces/Naphula/model_tools/blob/main/tokensurgeon.py)
|
| 43 |
+
- Uses adaptive VRAM from Grim Jim's `measure.py` like `graph_v18` to prevent OOM. Use recommended [batch file](https://huggingface.co/spaces/Naphula/model_tools/blob/main/fix_tokenizers.bat) here or modify sh. This supposedly avoids 'cardboard town' fake patches.
|
| 44 |
+
|
| 45 |
+
# [tokeninspector.py](https://huggingface.co/spaces/Naphula/model_tools/blob/main/tokeninspector.py)
|
| 46 |
+
- Audit your tokensurgeon results.
|
| 47 |
+
|
| 48 |
# [eos_scanner.py](https://huggingface.co/spaces/Naphula/model_tools/blob/main/eos_scanner.py)
|
| 49 |
+
- Updated! This tool scans the tokenizer jsons to detect any mismatches with EOS tokens, which cause early termination bugs. You can then use the [gen_id_patcher.py](https://huggingface.co/spaces/Naphula/model_tools/blob/main/gen_id_patcher.py) to patch missing `generation_config.json` files for EOS token. See [this post](https://huggingface.co/Naphula/Q0_Bench/discussions/1?not-for-all-audiences=true#6987717c762f0a45f672e250) as well as the [EOS Scanner ReadMe](https://huggingface.co/spaces/Naphula/model_tools/blob/main/eos_scanner_readme.md) for more info.
|
| 50 |
|
| 51 |
# [weight_counter.py](https://huggingface.co/spaces/Naphula/model_tools/blob/main/weight_counter.py)
|
| 52 |
- This counts the number of models in a yaml and adds up the total weight values. Useful for large della/ties merges.
|
moe_defs.py
ADDED
|
@@ -0,0 +1,197 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (C) 2025 Arcee AI
|
| 2 |
+
# SPDX-License-Identifier: LGPL-3.0-only
|
| 3 |
+
|
| 4 |
+
from typing import ClassVar, List, Optional
|
| 5 |
+
|
| 6 |
+
from pydantic import BaseModel
|
| 7 |
+
from transformers import PretrainedConfig
|
| 8 |
+
|
| 9 |
+
from mergekit.architecture.base import (
|
| 10 |
+
ModuleArchitecture,
|
| 11 |
+
WeightInfo,
|
| 12 |
+
)
|
| 13 |
+
from mergekit.architecture.json_definitions import NAME_TO_ARCH
|
| 14 |
+
|
| 15 |
+
MISTRAL_INFO = NAME_TO_ARCH["MistralForCausalLM"][0]
MISTRAL_MODULE_ARCH = MISTRAL_INFO.modules["default"].architecture


class MixtralModuleArchitecture(ModuleArchitecture, BaseModel):
    """Mixtral layer layout: Mistral weights with sparse-MoE expert MLPs."""

    ARCHITECTURE_NAME: ClassVar[str] = "MixtralForCausalLM"
    num_local_experts: int

    def name(self) -> str:
        return "mixtral"

    @classmethod
    def from_config(cls, config: PretrainedConfig):
        # Reads `num_local_experts` from the HF config.
        return cls(num_local_experts=config.num_local_experts)

    def pre_weights(self, config: PretrainedConfig) -> List[WeightInfo]:
        return MISTRAL_MODULE_ARCH.pre_weights(config)

    def post_weights(self, config: PretrainedConfig) -> List[WeightInfo]:
        return MISTRAL_MODULE_ARCH.post_weights(config)

    def num_layers_config_key(self) -> str:
        return MISTRAL_MODULE_ARCH.num_layers_config_key()

    def layer_weights(
        self, index: int, config: PretrainedConfig
    ) -> Optional[List[WeightInfo]]:
        prefix = f"model.layers.{index}"
        # Per-expert MLP weights (w1/w2/w3), then the router gate.
        moe_names = [
            f"{prefix}.block_sparse_moe.experts.{expert_idx}.{param}.weight"
            for expert_idx in range(self.num_local_experts)
            for param in ("w1", "w2", "w3")
        ]
        moe_names.append(f"{prefix}.block_sparse_moe.gate.weight")

        res = [WeightInfo(name=n) for n in moe_names]
        # Non-MLP weights (attention, norms) are inherited from Mistral.
        res.extend(
            wi
            for wi in MISTRAL_MODULE_ARCH.layer_weights(index, config)
            if ".mlp." not in wi.name
        )
        return res
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
QWEN3_INFO = NAME_TO_ARCH["Qwen3ForCausalLM"][0]
QWEN3_MODULE_ARCH = QWEN3_INFO.modules["default"].architecture


class Qwen3MoeModuleArchitecture(ModuleArchitecture, BaseModel):
    """Qwen3-MoE layer layout: Qwen3 weights with per-expert MLPs under .mlp."""

    ARCHITECTURE_NAME: ClassVar[str] = "Qwen3MoeForCausalLM"
    num_experts: int

    def name(self) -> str:
        return "qwen3_moe"

    @classmethod
    def from_config(cls, config: PretrainedConfig):
        # Reads `num_experts` from the HF config.
        return cls(num_experts=config.num_experts)

    def pre_weights(self, config: PretrainedConfig) -> List[WeightInfo]:
        return QWEN3_MODULE_ARCH.pre_weights(config)

    def post_weights(self, config: PretrainedConfig) -> List[WeightInfo]:
        return QWEN3_MODULE_ARCH.post_weights(config)

    def num_layers_config_key(self) -> str:
        return QWEN3_MODULE_ARCH.num_layers_config_key()

    def layer_weights(
        self, index: int, config: PretrainedConfig
    ) -> Optional[List[WeightInfo]]:
        prefix = f"model.layers.{index}"
        # Per-expert MLP projections, then the router gate.
        moe_names = [
            f"{prefix}.mlp.experts.{expert_idx}.{param}.weight"
            for expert_idx in range(self.num_experts)
            for param in ("up_proj", "gate_proj", "down_proj")
        ]
        moe_names.append(f"{prefix}.mlp.gate.weight")

        res = [WeightInfo(name=n) for n in moe_names]
        # Everything outside .mlp. (attention, norms) comes from dense Qwen3.
        res.extend(
            wi
            for wi in QWEN3_MODULE_ARCH.layer_weights(index, config)
            if ".mlp." not in wi.name
        )
        return res
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
AFMOE_PARTIAL_INFO = NAME_TO_ARCH["_AfmoePartialForCausalLM"][0]
AFMOE_PARTIAL_MODULE_ARCH = AFMOE_PARTIAL_INFO.modules["default"].architecture


class AfmoeModuleArchitecture(ModuleArchitecture, BaseModel):
    """AFMoE layer layout: partial-arch weights plus optional per-expert MLPs."""

    ARCHITECTURE_NAME: ClassVar[str] = "AfmoeForCausalLM"
    num_experts: int

    def name(self) -> str:
        return "afmoe"

    @classmethod
    def from_config(cls, config: PretrainedConfig):
        # Reads `num_experts` from the HF config.
        return cls(num_experts=config.num_experts)

    def pre_weights(self, config: PretrainedConfig) -> List[WeightInfo]:
        return AFMOE_PARTIAL_MODULE_ARCH.pre_weights(config)

    def post_weights(self, config: PretrainedConfig) -> List[WeightInfo]:
        return AFMOE_PARTIAL_MODULE_ARCH.post_weights(config)

    def num_layers_config_key(self) -> str:
        return AFMOE_PARTIAL_MODULE_ARCH.num_layers_config_key()

    def layer_weights(
        self, index: int, config: PretrainedConfig
    ) -> Optional[List[WeightInfo]]:
        # Start from the partial architecture's own weights for this layer.
        res = AFMOE_PARTIAL_MODULE_ARCH.layer_weights(index, config) or []
        prefix = f"model.layers.{index}"
        # Expert weights are marked optional: not every layer has them.
        res.extend(
            WeightInfo(
                name=f"{prefix}.mlp.experts.{expert_idx}.{param}.weight",
                optional=True,
            )
            for expert_idx in range(self.num_experts)
            for param in ("up_proj", "gate_proj", "down_proj")
        )
        return res
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
# Llama MoE: dense Llama layers with the MLP replaced by a Mixtral-style
# block_sparse_moe expert array.
LLAMA_INFO = NAME_TO_ARCH["LlamaForCausalLM"][0]
LLAMA_MODULE_ARCH = LLAMA_INFO.modules["default"].architecture

class LlamaMoeModuleArchitecture(ModuleArchitecture, BaseModel):
    """Llama-based MoE layout; writes `LlamaMoeForCausalLM` to config.json."""

    ARCHITECTURE_NAME: ClassVar[str] = "LlamaMoeForCausalLM"
    num_experts: int

    def name(self) -> str:
        return "llama_moe"

    @classmethod
    def from_config(cls, config: PretrainedConfig):
        # Reads `num_experts` from the HF config, defaulting to 8 when absent.
        return cls(num_experts=getattr(config, "num_experts", 8))

    def pre_weights(self, config: PretrainedConfig) -> List[WeightInfo]:
        # Standard Llama embeddings/norms.
        return LLAMA_MODULE_ARCH.pre_weights(config)

    def post_weights(self, config: PretrainedConfig) -> List[WeightInfo]:
        # Standard Llama final norm/head.
        return LLAMA_MODULE_ARCH.post_weights(config)

    def num_layers_config_key(self) -> str:
        return LLAMA_MODULE_ARCH.num_layers_config_key()

    def layer_weights(self, index: int, config: PretrainedConfig) -> Optional[List[WeightInfo]]:
        prefix = f"model.layers.{index}"

        # Expert weights: the dense MLP projections mapped into an expert array.
        res = [
            WeightInfo(name=f"{prefix}.block_sparse_moe.experts.{expert_idx}.{param}.weight")
            for expert_idx in range(self.num_experts)
            for param in ("gate_proj", "up_proj", "down_proj")
        ]

        # Router (gate) weight for this layer.
        res.append(WeightInfo(name=f"{prefix}.block_sparse_moe.gate.weight"))

        # Non-MLP weights (attention, input norms) come straight from Llama;
        # the original .mlp. weights are excluded since experts replace them.
        res.extend(
            wi
            for wi in LLAMA_MODULE_ARCH.layer_weights(index, config)
            if ".mlp." not in wi.name
        )
        return res
|
tokeninspector.py
ADDED
|
@@ -0,0 +1,135 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# python tokeninspector.py "B:\12B\models--mistralai--Mistral-Nemo-Instruct-2407" "B:\12B\models--aixonlab--Aether-12b.backup" "B:\12B\models--aixonlab--Aether-12b"
|
| 2 |
+
|
| 3 |
+
import os
|
| 4 |
+
import click
|
| 5 |
+
import torch
|
| 6 |
+
import transformers
|
| 7 |
+
from mergekit.io.lazy_tensor_loader import LazyTensorLoader
|
| 8 |
+
|
| 9 |
+
def get_embed_tensor(model_path):
    """Lazily loads the embedding tensor from a model directory."""
    # Tensor names that identify the input embedding across architectures.
    embed_markers = ("embed_tokens.weight", "wte.weight")
    try:
        loader = LazyTensorLoader.from_disk(model_path)
        for key in loader.index.tensor_paths.keys():
            if any(marker in key for marker in embed_markers):
                return loader.get_tensor(key)
    except Exception as e:
        print(f" [!] Error loading tensors from {model_path}: {e}")
    return None
|
| 19 |
+
|
| 20 |
+
@click.command()
@click.argument("base_model", type=click.Path(exists=True))
@click.argument("donor_model", type=click.Path(exists=True))
@click.argument("output_model", type=click.Path(exists=True))
def main(base_model, donor_model, output_model):
    """Audit a tokensurgeon result: verify OUTPUT carries the DONOR's
    tokenizer while preserving the BASE model's embedding vectors.

    All three arguments are local model directories (existence is checked
    by click).
    """
    print("="*60)
    print("🔍 TOKEN SURGEON AUDIT TOOL")
    print("="*60)

    # [1] Vocab size: the output should have adopted the donor's tokenizer.
    print("\n[1] Loading Tokenizers...")
    tok_base = transformers.AutoTokenizer.from_pretrained(base_model, trust_remote_code=True)
    tok_donor = transformers.AutoTokenizer.from_pretrained(donor_model, trust_remote_code=True)
    tok_out = transformers.AutoTokenizer.from_pretrained(output_model, trust_remote_code=True)

    print(f" Base: {len(tok_base)} tokens")
    print(f" Donor: {len(tok_donor)} tokens")
    print(f" Output: {len(tok_out)} tokens")

    if len(tok_out) != len(tok_donor):
        print(" ❌ FAIL: Output vocab size does not match Donor vocab size!")
    else:
        print(" ✅ PASS: Output vocab size matches Donor.")

    # [2] The output embedding matrix must have at least one row per donor token.
    print("\n[2] Loading Embedding Tensors (Lazy Load)...")
    emb_base = get_embed_tensor(base_model)
    emb_donor = get_embed_tensor(donor_model)
    emb_out = get_embed_tensor(output_model)

    print(f" Base Matrix: {emb_base.shape if emb_base is not None else 'Not found'}")
    print(f" Donor Matrix: {emb_donor.shape if emb_donor is not None else 'Not found'}")
    print(f" Output Matrix: {emb_out.shape if emb_out is not None else 'Not found'}")

    if emb_out is not None and emb_donor is not None:
        if emb_out.shape[0] >= len(tok_donor):
            print(" ✅ PASS: Output embedding matrix size is sufficient for Donor vocab.")
        else:
            print(" ❌ FAIL: Output embedding matrix is smaller than Donor vocab!")

    vocab_base = tok_base.get_vocab()
    vocab_donor = tok_donor.get_vocab()

    # Tokens present in both vocabs should have been copied verbatim;
    # donor-only tokens should have been approximated.
    shared_tokens = set(vocab_base.keys()).intersection(set(vocab_donor.keys()))
    donor_only_tokens = set(vocab_donor.keys()) - set(vocab_base.keys())

    # [3] A shared token's vector must be identical, merely relocated to the
    # donor-assigned row id.
    print("\n[3] Testing a Shared Token (Verifying exact transfer)...")
    if shared_tokens:
        # Pick a common word that is likely to exist in both
        test_shared = None
        for candidate in [" the", " hello", "The", "Hello", "Ġthe", "Ġhello", "the", "hello"]:
            if candidate in shared_tokens:
                test_shared = candidate
                break
        if not test_shared:
            # Fall back to an arbitrary (middle) shared token.
            test_shared = list(shared_tokens)[len(shared_tokens)//2]

        id_base = vocab_base[test_shared]
        id_out = vocab_donor[test_shared]  # output uses donor vocab

        print(f" Token: '{test_shared}'")
        print(f" ID in Base: {id_base} | ID in Output: {id_out}")

        if emb_base is not None and emb_out is not None:
            vec_base = emb_base[id_base].float()
            vec_out = emb_out[id_out].float()

            # Cosine similarity ~1.0 means the base vector was moved unchanged.
            cos_sim = torch.nn.functional.cosine_similarity(vec_base, vec_out, dim=0).item()
            print(f" Cosine similarity between Base and Output vectors: {cos_sim:.6f}")
            if cos_sim > 0.999:
                print(" ✅ PASS: Embeddings match perfectly. The vector was successfully moved to the new ID.")
            else:
                print(" ❌ FAIL: Embeddings for shared token do not match!")
    else:
        print(" ⚠️ No shared tokens found between vocabularies.")

    # [4] A donor-only token should have received a non-trivial approximated vector.
    print("\n[4] Testing a New Token (Verifying OMP approximation)...")
    if donor_only_tokens:
        # Try to find a special token or a distinct word
        test_new = list(donor_only_tokens)[0]
        for t in donor_only_tokens:
            if "<" in t or "[" in t or "im_start" in t:
                test_new = t
                break

        id_out = vocab_donor[test_new]
        print(f" Token: '{test_new}' (Only exists in Donor)")
        print(f" ID in Output: {id_out}")

        if emb_out is not None:
            vec_out = emb_out[id_out].float()
            norm = vec_out.norm().item()
            print(f" Vector L2 Norm: {norm:.4f}")
            # A near-zero norm suggests the token was zero-initialized as junk.
            if norm > 0.01:
                print(" ✅ PASS: Vector is non-zero. OMP successfully approximated a new embedding.")
            else:
                print(" ⚠️ WARN: Vector is zero or very close to zero. It may have been treated as a junk token.")
    else:
        print(" ⚠️ No donor-only tokens found. Vocabularies are identical.")

    # [5] End-to-end: the output tokenizer must encode exactly like the donor's.
    print("\n[5] Testing Tokenizer Encoding Behavior...")
    test_text = "Hello world! This is a test of the new tokenizer. <|im_start|>system\n12345<|im_end|>"
    enc_donor = tok_donor.encode(test_text)
    enc_out = tok_out.encode(test_text)

    if enc_donor == enc_out:
        print(" ✅ PASS: Output model encodes text exactly identically to the Donor model.")
    else:
        print(" ❌ FAIL: Output model encoding differs from Donor model!")
        print(f" Donor: {enc_donor[:10]}...")
        print(f" Output: {enc_out[:10]}...")

    print("\n" + "="*60)
    print("Audit Complete.")
    print("="*60)

if __name__ == '__main__':
    main()
|
tokensurgeon.py
ADDED
|
@@ -0,0 +1,867 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (C) 2025 Arcee AI
|
| 2 |
+
# SPDX-License-Identifier: LGPL-3.0-only
|
| 3 |
+
|
| 4 |
+
import enum
|
| 5 |
+
import logging
|
| 6 |
+
from typing import Dict, List, Optional, Tuple
|
| 7 |
+
|
| 8 |
+
import click
|
| 9 |
+
import torch
|
| 10 |
+
import torch.distributions.constraints
|
| 11 |
+
import tqdm
|
| 12 |
+
import transformers
|
| 13 |
+
from pydantic import BaseModel
|
| 14 |
+
|
| 15 |
+
from mergekit.architecture import (
|
| 16 |
+
ConfiguredModelArchitecture,
|
| 17 |
+
WeightInfo,
|
| 18 |
+
arch_info_for_config,
|
| 19 |
+
)
|
| 20 |
+
from mergekit.common import ModelReference, set_config_value
|
| 21 |
+
from mergekit.io.tasks import (
|
| 22 |
+
LoaderCache,
|
| 23 |
+
)
|
| 24 |
+
from mergekit.io.tensor_writer import TensorWriter
|
| 25 |
+
from mergekit.options import MergeOptions, PrettyPrintHelp, add_merge_options
|
| 26 |
+
from mergekit.tokenizer.normalization import (
|
| 27 |
+
NormalizedToken,
|
| 28 |
+
normalized_vocabulary,
|
| 29 |
+
token_prefixes,
|
| 30 |
+
)
|
| 31 |
+
from mergekit.tokensurgeon import (
|
| 32 |
+
SubwordMethod,
|
| 33 |
+
WeightingScheme,
|
| 34 |
+
batch_mp_rope,
|
| 35 |
+
batch_omp,
|
| 36 |
+
common_interp_approximate,
|
| 37 |
+
compute_token_basis,
|
| 38 |
+
landmark_pca_approximate,
|
| 39 |
+
subword_approximate,
|
| 40 |
+
well_trained_tokens,
|
| 41 |
+
)
|
| 42 |
+
from mergekit.tokensurgeon.common_interpolation import DistanceMetric
|
| 43 |
+
|
| 44 |
+
LOG = logging.getLogger(__name__)
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
class TokenAssignmentStats(BaseModel):
    """Counts of how donor-vocabulary tokens were assigned embeddings."""

    exact_match: int = 0
    byte_match: int = 0
    prefix_match: int = 0
    to_approximate: int = 0

    def pretty_print(self):
        """Return a human-readable multi-line summary of the counts."""
        labeled = (
            ("Exact matches", self.exact_match),
            ("Byte matches", self.byte_match),
            ("Prefix matches", self.prefix_match),
            ("Tokens to approximate", self.to_approximate),
        )
        lines = ["Token Breakdown:"]
        # Only categories with a non-zero count are shown.
        lines.extend(f" {label}: {count}" for label, count in labeled if count)
        lines.append(f" Total: {sum(count for _, count in labeled)}")
        return "\n".join(lines)
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
class ApproximationMethod(enum.Enum):
    # Strategy for synthesizing embeddings for tokens missing from the
    # original vocabulary (dispatched in compute_new_embeddings).
    COMMON_INTERPOLATION = "common_interpolation"  # weighted interpolation over shared-token neighbours
    SUBWORD = "subword"  # combine embeddings of the token's subword pieces
    MEAN = "mean"  # mean of the original embedding matrix
    ZERO = "zero"  # all-zero vectors
    RANDN = "randn"  # unit gaussian noise
    JOHN_HEWITT = "john_hewitt"  # sample from a gaussian fit to the original embeddings
    ORTHOGONAL_MATCHING_PURSUIT = "omp"  # sparse reconstruction over shared tokens via OMP
    LANDMARK_PCA = "landmark_pca"  # alignment using shared tokens as landmarks
    SPARSE_TOKEN_BASIS = "stb"  # least-squares coefficients over a precomputed token basis
    MATCHING_PURSUIT_ROPE = "mp_rope"  # matching-pursuit variant using RoPE head parameters
|
| 81 |
+
|
| 82 |
+
class TokenSurgeonOptions(BaseModel):
    """User-facing options for a tokensurgeon run."""

    model: ModelReference  # model whose embeddings are re-mapped
    donor: ModelReference  # model providing the target tokenizer/vocabulary
    out_path: str  # output directory for the resulting model
    method: ApproximationMethod = ApproximationMethod.COMMON_INTERPOLATION  # strategy for unseen tokens
    weight_scheme: WeightingScheme = WeightingScheme.DISTANCE_PROPORTIONAL  # neighbour weighting for interpolation
    k: int = 64  # neighbour/basis size used by the approximation methods
    cosine_similarity: bool = False  # use cosine instead of euclidean distance for neighbour search
    subword_method: SubwordMethod = SubwordMethod.MEAN  # how subword embeddings are combined (SUBWORD method)
    batch_size: Optional[int] = None  # tokens approximated per batch; None/<=0 -> adaptive default of 512
    new_vocab_noise: Optional[float] = None  # stddev of gaussian noise added to newly synthesized rows
    new_vocab_scale: Optional[float] = None  # scale factor applied to newly synthesized rows
|
| 95 |
+
|
| 96 |
+
def get_arch_info(
    model: ModelReference, options: MergeOptions
) -> ConfiguredModelArchitecture:
    """Load a model's config and resolve its architecture description."""
    model_cfg = model.config(trust_remote_code=options.trust_remote_code)
    return ConfiguredModelArchitecture(
        info=arch_info_for_config(model_cfg), config=model_cfg
    )
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
def get_embedding_info(
    arch_info: ConfiguredModelArchitecture,
) -> Tuple[WeightInfo, WeightInfo]:
    """Get WeightInfo for the input and output embeddings of a model."""
    modules = arch_info.info.modules
    if len(modules) != 1:
        raise RuntimeError("Model has multiple modules - not supported by tokensurgeon")
    module_def = arch_info.get_module(next(iter(modules.keys())))

    # Exactly one embedding weight is expected on each side; duplicates are
    # ambiguous and rejected.
    embed = None
    for wi in module_def.pre_weights():
        if not wi.is_embed:
            continue
        if embed is not None:
            raise RuntimeError("Multiple input embeddings found")
        embed = wi

    lm_head = None
    for wi in module_def.post_weights():
        if not wi.is_embed:
            continue
        if lm_head is not None:
            raise RuntimeError("Multiple output embeddings found")
        lm_head = wi

    return embed, lm_head
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
def maybe_aliases(weight_info: WeightInfo, tied: bool) -> Tuple[str, ...]:
    """Collect alternate tensor names for a weight, optionally including its
    tied-weight names."""
    names = list(weight_info.aliases or [])
    if tied:
        names.extend(weight_info.tied_names or [])
    return tuple(names)
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
def get_stuff(
    model: ModelReference,
    options: MergeOptions,
    arch_info: Optional[ConfiguredModelArchitecture] = None,
    get_tied: bool = False,
    device: str = "cpu",
) -> Tuple[Dict[NormalizedToken, int], Optional[torch.Tensor], Optional[torch.Tensor]]:
    """Load a model's normalized vocabulary plus its input-embedding and
    lm-head tensors."""
    if arch_info is None:
        arch_info = get_arch_info(model, options)
    tokenizer = transformers.AutoTokenizer.from_pretrained(
        model.model.path,
        revision=model.model.revision,
        trust_remote_code=options.trust_remote_code,
    )
    vocab = normalized_vocabulary(tokenizer)
    embed_wi, lm_head_wi = get_embedding_info(arch_info)
    loader = LoaderCache().get(model)

    def _load(wi):
        # Fetch one tensor; tolerate absence only when the weight is optional.
        return loader.get_tensor(
            wi.name,
            device=device,
            aliases=maybe_aliases(wi, get_tied),
            raise_on_missing=not wi.optional,
        )

    return vocab, _load(embed_wi), _load(lm_head_wi)
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
def match_byte_token(
    token: NormalizedToken, original_vocab: Dict[NormalizedToken, int]
) -> Optional[int]:
    """Translate between raw single characters and ``<0xNN>`` byte tokens.

    Returns the matching id in ``original_vocab``, or None when no byte-level
    equivalence applies.
    """
    if not isinstance(token, str):
        return None
    if len(token) == 1 and ord(token) < 256:
        # A raw character may exist in the other vocab as a byte token.
        byte_form = f"<0x{ord(token):02X}>"
        if byte_form in original_vocab:
            return original_vocab[byte_form]
    elif len(token) == 6 and token.startswith("<0x") and token.endswith(">"):
        # A byte token may exist in the other vocab as a raw character.
        try:
            byte_value = int(token[3:-1], 16)
        except ValueError:
            return None
        char_form = chr(byte_value)
        if char_form in original_vocab:
            return original_vocab[char_form]
    return None
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
def match_prefix(
    token: NormalizedToken, original_vocab: Dict[NormalizedToken, int]
) -> Optional[int]:
    """Return the id of the first known prefix of ``token``, or None."""
    return next(
        (
            original_vocab[prefix]
            for prefix in token_prefixes(token)
            if prefix in original_vocab
        ),
        None,
    )
|
| 197 |
+
|
| 198 |
+
|
| 199 |
+
def get_out_arch_info(
    model: ModelReference,
    donor: ModelReference,
    new_vocab_size: int,
    common_options: MergeOptions,
) -> ConfiguredModelArchitecture:
    """Build the output architecture: the base model's config with the donor's
    vocab size and special-token settings applied."""
    donor_cfg = donor.config(trust_remote_code=common_options.trust_remote_code)
    out_cfg = model.config(trust_remote_code=common_options.trust_remote_code)
    out_arch = arch_info_for_config(out_cfg)
    set_config_value(
        out_cfg, out_arch.vocab_size_config_key or "vocab_size", new_vocab_size
    )
    # Special-token ids must follow the donor tokenizer, not the base model.
    special_keys = (
        "pad_token_id",
        "eos_token_id",
        "bos_token_id",
        "unk_token_id",
        "mask_token_id",
        "padding_side",
    )
    for key in special_keys:
        if hasattr(donor_cfg, key):
            set_config_value(out_cfg, key, getattr(donor_cfg, key))
    return ConfiguredModelArchitecture(info=out_arch, config=out_cfg)
|
| 222 |
+
|
| 223 |
+
|
| 224 |
+
def john_hewitt_init(orig_embed: torch.Tensor, num_new_tokens: int) -> torch.Tensor:
    """Initialize new token embeddings by sampling from a multivariate normal
    distribution fit to the existing embedding matrix.

    Args:
        orig_embed: Existing (vocab_size, dim) embedding matrix.
        num_new_tokens: Number of new embedding rows to generate.

    Returns:
        A (num_new_tokens, dim) tensor in ``orig_embed``'s dtype.
    """
    orig_embed_f32 = orig_embed.to(torch.float32)
    mean = orig_embed_f32.mean(dim=0)
    centered = orig_embed_f32 - mean
    covariance = centered.T @ centered / orig_embed_f32.shape[0]
    is_pd = torch.distributions.constraints.positive_definite.check(covariance).all()
    if not is_pd:
        LOG.warning(
            "Covariance matrix is not positive definite - falling back to small randn"
        )
        # BUGFIX: previously torch.randn(len(num_new_tokens), ...), which
        # raises TypeError — num_new_tokens is already the row count.
        return (
            torch.randn(
                num_new_tokens,
                orig_embed.shape[1],
                device=orig_embed.device,
                dtype=orig_embed.dtype,
            )
            * 0.02
        )
    dist = torch.distributions.multivariate_normal.MultivariateNormal(
        loc=mean,
        covariance_matrix=covariance,
    )
    new_embeds = dist.sample((num_new_tokens,))
    return new_embeds.to(orig_embed.dtype)
|
| 249 |
+
|
| 250 |
+
|
| 251 |
+
def compute_new_embeddings(
    orig_embed: torch.Tensor,
    donor_embed: torch.Tensor,
    orig_vocab: Dict[NormalizedToken, int],
    donor_vocab: Dict[NormalizedToken, int],
    target_tokens: List[NormalizedToken],
    is_lm_head: bool,
    token_basis: Optional[Tuple[torch.Tensor, torch.Tensor]],
    orig_tokenizer: transformers.PreTrainedTokenizerBase,
    options: TokenSurgeonOptions,
    shared_data: Optional[Dict] = None,
    compute_device: torch.device = torch.device("cpu"),
) -> torch.Tensor:
    """Synthesize embedding rows for donor tokens absent from the original
    vocabulary, using the method selected in ``options``.

    Args:
        orig_embed: Original model's embedding matrix (vocab x dim).
        donor_embed: Donor model's embedding matrix.
        orig_vocab / donor_vocab: Normalized token -> row index maps.
        target_tokens: Tokens to approximate; each must be in donor_vocab.
        is_lm_head: True when approximating the output (lm_head) matrix.
        token_basis: (donor_basis, orig_basis) pair; required for STB.
        orig_tokenizer: Original tokenizer, used by the SUBWORD method.
        shared_data: Optional precomputed shared-token embeddings so callers
            can avoid rebuilding them for every batch.
        compute_device: Device on which the approximation math runs.

    Returns:
        Tensor of shape (len(target_tokens), dim).
    """
    assert all(t in donor_vocab for t in target_tokens)
    if options.method == ApproximationMethod.MEAN:
        mean = orig_embed.mean(dim=0).to(compute_device)
        # BUGFIX: .expand() returns a read-only broadcast view; callers apply
        # in-place noise/scale to the result, which fails on expanded views.
        # .clone() materializes a real (len(target_tokens), dim) tensor.
        return mean.unsqueeze(0).expand(len(target_tokens), -1).clone()
    elif options.method == ApproximationMethod.ZERO:
        return torch.zeros(
            len(target_tokens),
            orig_embed.shape[1],
            device=compute_device,
            dtype=orig_embed.dtype,
        )
    elif options.method == ApproximationMethod.RANDN:
        return torch.randn(
            len(target_tokens),
            orig_embed.shape[1],
            device=compute_device,
            dtype=orig_embed.dtype,
        )
    elif options.method == ApproximationMethod.JOHN_HEWITT:
        return john_hewitt_init(orig_embed.to(compute_device), len(target_tokens))
    elif options.method in (
        ApproximationMethod.COMMON_INTERPOLATION,
        ApproximationMethod.ORTHOGONAL_MATCHING_PURSUIT,
        ApproximationMethod.LANDMARK_PCA,
        ApproximationMethod.MATCHING_PURSUIT_ROPE,
    ):
        # These methods express each target as a combination of tokens shared
        # by both vocabularies, then rebuild it from the original embeddings.
        if shared_data is not None:
            donor_shared_embeds = shared_data["donor_shared_embeds"].to(compute_device)
            orig_shared_embeds = shared_data["orig_shared_embeds"].to(compute_device)
        else:
            shared_vocab = list(
                sorted(
                    set(orig_vocab.keys()) & set(donor_vocab.keys()),
                    key=lambda x: donor_vocab[x],
                )
            )
            donor_shared_embeds = donor_embed[
                torch.tensor([donor_vocab[t] for t in shared_vocab])
            ].to(compute_device)

            orig_shared_embeds = orig_embed[
                torch.tensor([orig_vocab[t] for t in shared_vocab])
            ].to(compute_device)

        res = None
        in_donor = None
        targets = donor_embed[torch.tensor([donor_vocab[t] for t in target_tokens])].to(compute_device)

        if options.method == ApproximationMethod.LANDMARK_PCA:
            return landmark_pca_approximate(
                targets,
                donor_shared_embeds,
                orig_shared_embeds,
            )
        elif options.method == ApproximationMethod.COMMON_INTERPOLATION:
            indices, coeffs = common_interp_approximate(
                targets,
                donor_shared_embeds,
                k=options.k,
                metric=(
                    DistanceMetric.COSINE
                    if options.cosine_similarity
                    else DistanceMetric.EUCLIDEAN
                ),
                weight_scheme=options.weight_scheme,
            )
        elif options.method == ApproximationMethod.MATCHING_PURSUIT_ROPE:
            # RoPE-aware variant needs head counts and rope bases of both models.
            model_config = options.model.config(trust_remote_code=False)
            donor_config = options.donor.config(trust_remote_code=False)
            indices, coeffs, res, in_donor = batch_mp_rope(
                targets,
                donor_shared_embeds,
                orig_shared_embeds,
                k=options.k,
                num_heads_a=donor_config.num_attention_heads,
                num_heads_b=model_config.num_attention_heads,
                a_rope_base=donor_config.rope_theta,
                b_rope_base=model_config.rope_theta,
            )
        else:
            indices, coeffs = batch_omp(targets, donor_shared_embeds, options.k)

        if res is None:
            # Recombine the selected shared tokens' ORIGINAL embeddings with
            # the coefficients found in donor space.
            res = (
                torch.bmm(
                    coeffs.unsqueeze(1), orig_shared_embeds[indices].to(torch.float)
                )
                .squeeze(1)
                .to(orig_embed.dtype)
            )
        return res
    elif options.method == ApproximationMethod.SUBWORD:
        return subword_approximate(
            orig_embed.to(compute_device),
            target_tokens,
            is_lm_head,
            orig_tokenizer,
            options.subword_method,
        )
    elif options.method == ApproximationMethod.SPARSE_TOKEN_BASIS:
        assert token_basis is not None, "Token basis must be provided for STB"
        donor_basis, orig_basis = token_basis
        donor_basis = donor_basis.to(compute_device).to(torch.float32)
        orig_basis = orig_basis.to(compute_device).to(torch.float32)

        # Centered donor embeddings of the targets.
        target_donor_embeds = donor_embed[
            torch.tensor([donor_vocab[t] for t in target_tokens])
        ].to(compute_device).to(torch.float32) - donor_embed.mean(dim=0).to(compute_device)

        # Solve for basis coefficients in donor space, then re-project through
        # the original-space basis.
        coeffs = torch.linalg.lstsq(
            donor_basis.T,
            target_donor_embeds.T,
        ).solution.T

        if LOG.isEnabledFor(logging.DEBUG):
            donor_rt = coeffs @ donor_basis
            err = (donor_rt - target_donor_embeds).norm(dim=1)
            err_rel = err / target_donor_embeds.norm(dim=1).clamp_min(1e-6)
            sim = torch.nn.functional.cosine_similarity(
                donor_rt, target_donor_embeds, dim=1
            )
            LOG.debug(f"Reconstruction error: {err.mean().item():.4f}")
            LOG.debug(f"Relative reconstruction error: {err_rel.mean().item():.4f}")
            LOG.debug(f"Cosine similarity: {sim.mean().item():.4f}")

        return coeffs @ orig_basis + orig_embed.mean(dim=0).to(compute_device)
    else:
        raise ValueError(f"Unknown approximation method: {options.method}")
|
| 392 |
+
|
| 393 |
+
|
| 394 |
+
def build_embedding_matrix(
    weight_info: WeightInfo,
    orig_embed: torch.Tensor,
    donor_embed: torch.Tensor,
    orig_vocab: Dict[NormalizedToken, int],
    donor_vocab: Dict[NormalizedToken, int],
    junk_tokens: List[int],
    allow_prefix: bool,
    allow_byte: bool,
    is_lm_head: bool,
    options: TokenSurgeonOptions,
    compute_device: torch.device,
) -> torch.Tensor:
    """Assemble the output embedding matrix, indexed by donor-vocabulary ids.

    Tokens also present in the original vocab (exactly, via byte-token
    equivalence, or via prefix, as allowed) are copied verbatim; the rest are
    approximated in adaptively-sized batches; junk tokens are zeroed.

    Returns:
        Tensor of shape (donor_vocab_size, dim) on orig_embed's device/dtype.
    """
    LOG.info(f"Building new tensor for {weight_info.name}")
    stats = TokenAssignmentStats()
    # Donor ids may exceed len(donor_vocab); size the matrix to the max id.
    out_vocab_size = max(len(donor_vocab), max(donor_vocab.values()) + 1)

    if options.method == ApproximationMethod.SPARSE_TOKEN_BASIS:
        token_basis = compute_token_basis(
            orig_embed,
            donor_embed,
            orig_vocab,
            donor_vocab,
            junk_tokens,
            options,
        )
    else:
        token_basis = None

    res = torch.zeros(
        out_vocab_size,
        orig_embed.shape[1],
        device=orig_embed.device,
        dtype=orig_embed.dtype,
    )
    # First pass: copy every token we can match directly, queue the rest.
    new_tokens = []
    for token, donor_idx in donor_vocab.items():
        if token in orig_vocab:
            orig_idx = orig_vocab[token]
            res[donor_idx] = orig_embed[orig_idx]
            stats.exact_match += 1
        elif (
            allow_byte and (orig_idx := match_byte_token(token, orig_vocab)) is not None
        ):
            res[donor_idx] = orig_embed[orig_idx]
            stats.byte_match += 1
        elif allow_prefix and (orig_idx := match_prefix(token, orig_vocab)) is not None:
            res[donor_idx] = orig_embed[orig_idx]
            stats.prefix_match += 1
        else:
            new_tokens.append(token)
            stats.to_approximate += 1

    LOG.info(stats.pretty_print())
    if new_tokens:
        LOG.info(f"Approximating {len(new_tokens)} tokens")

        # Only the original tokenizer is needed here (for the subword method);
        # the donor tokenizer was previously also loaded but never used.
        orig_tokenizer = transformers.AutoTokenizer.from_pretrained(
            options.model.model.path,
            revision=options.model.model.revision,
            trust_remote_code=True,
        )

        # Precompute shared embeds to avoid doing it in every batch
        shared_vocab = list(
            sorted(
                set(orig_vocab.keys()) & set(donor_vocab.keys()),
                key=lambda x: donor_vocab[x],
            )
        )
        donor_shared_embeds = donor_embed[
            torch.tensor([donor_vocab[t] for t in shared_vocab])
        ]
        orig_shared_embeds = orig_embed[
            torch.tensor([orig_vocab[t] for t in shared_vocab])
        ]
        shared_data = {
            "donor_shared_embeds": donor_shared_embeds,
            "orig_shared_embeds": orig_shared_embeds,
        }

        batch_size = options.batch_size
        if batch_size is None or batch_size <= 0:
            batch_size = 512

        # Adaptive batching: shrink the batch by 25% on CUDA OOM and retry.
        i = 0
        total_tokens = len(new_tokens)
        oom_count = 0

        pbar = tqdm.tqdm(total=total_tokens, desc="Approximating tokens")

        while i < total_tokens:
            end = min(i + batch_size, total_tokens)
            current_batch = new_tokens[i:end]

            try:
                new_embeds = compute_new_embeddings(
                    orig_embed,
                    donor_embed,
                    orig_vocab,
                    donor_vocab,
                    target_tokens=current_batch,
                    is_lm_head=is_lm_head,
                    token_basis=token_basis,
                    orig_tokenizer=orig_tokenizer,
                    options=options,
                    shared_data=shared_data,
                    compute_device=compute_device,
                )

                # Optional perturbation/scaling of newly synthesized rows.
                if options.new_vocab_noise:
                    new_embeds += torch.randn_like(new_embeds) * options.new_vocab_noise
                if options.new_vocab_scale:
                    new_embeds *= options.new_vocab_scale

                for ne_idx, token in enumerate(current_batch):
                    res[donor_vocab[token]] = new_embeds[ne_idx].to(res.device)

                # Success, move to next batch
                pbar.update(end - i)
                i = end
                oom_count = 0

                # Optional cleanup between batches keeps peak VRAM lower.
                if compute_device.type == "cuda":
                    torch.cuda.empty_cache()

            except torch.OutOfMemoryError:
                oom_count += 1
                if compute_device.type == "cuda":
                    torch.cuda.empty_cache()
                    import gc
                    gc.collect()

                old_batch = batch_size
                batch_size = max(1, int(batch_size * 0.75))

                if batch_size == old_batch and batch_size == 1:
                    LOG.error("OOM even with batch size 1. Cannot continue.")
                    raise

                LOG.warning(f"OOM error. Reducing batch size from {old_batch} to {batch_size} (attempt {oom_count})")

                if oom_count > 10:
                    LOG.error("Too many OOM errors, giving up.")
                    raise

        pbar.close()

    if junk_tokens:
        # Tokens flagged as untrained/junk in the donor get zero rows.
        LOG.info(f"Zero-initializing {len(junk_tokens)} junk tokens")
        for token_id in junk_tokens:
            res[token_id] = torch.zeros(
                orig_embed.shape[1],
                device=orig_embed.device,
                dtype=orig_embed.dtype,
            )
    return res
|
| 558 |
+
|
| 559 |
+
|
| 560 |
+
class AllowMatch(enum.Enum):
    """Where a heuristic token match (prefix match or byte match) may be applied.

    Used by the CLI's ``--prefix-match`` / ``--byte-match`` options to restrict
    the heuristic to one of the two matrices being rebuilt, both, or neither.
    """

    LM_HEAD_ONLY = "lm_head"  # apply the heuristic only when building the LM head matrix
    EMBED_ONLY = "embed"  # apply the heuristic only when building the input embedding matrix
    YES = "yes"  # apply to both matrices
    NO = "no"  # never apply
|
| 565 |
+
|
| 566 |
+
|
| 567 |
+
@click.command("mergekit-tokensurgeon", cls=PrettyPrintHelp)
@click.argument("model", type=str)
@click.argument("donor", type=str)
@click.argument("out_path", type=str)
@click.option(
    "--k",
    "-k",
    type=int,
    default=64,
    help="Number of nearest neighbours to use for embedding interpolation",
    show_default=True,
)
@click.option(
    "--cosine-similarity/--no-cosine-similarity",
    "-c/-nc",
    is_flag=True,
    default=False,
    help="Use cosine similarity for nearest neighbour search",
    show_default=True,
)
@click.option(
    "--approximation-method",
    "-a",
    type=click.Choice([m.value for m in ApproximationMethod]),
    default=ApproximationMethod.ORTHOGONAL_MATCHING_PURSUIT.value,
    help="Method for approximating missing tokens",
    show_default=True,
)
@click.option(
    "--weight-scheme",
    "-w",
    type=click.Choice([w.value for w in WeightingScheme]),
    default=WeightingScheme.DISTANCE_PROPORTIONAL.value,
    help="Weighting scheme for common-vocabulary interpolation",
    show_default=True,
)
@click.option(
    "--subword-method",
    "-s",
    type=click.Choice([m.value for m in SubwordMethod]),
    default=SubwordMethod.MEAN.value,
    help="Method for approximating embeddings with subword tokens",
    show_default=True,
)
@click.option(
    "--batch-size",
    type=int,
    default=512,
    help="Number of tokens to process in each batch. -1 for no batching.",
    show_default=True,
)
@click.option(
    "--prefix-match",
    "-pm",
    type=click.Choice([m.value for m in AllowMatch]),
    default=AllowMatch.NO.value,
    help="Allow prefix match for tokens",
    show_default=True,
)
@click.option(
    "--byte-match",
    "-bm",
    type=click.Choice([m.value for m in AllowMatch]),
    default=AllowMatch.NO.value,
    help="Allow byte match for tokens",
    show_default=True,
)
@click.option(
    "--magikarp/--no-magikarp",
    is_flag=True,
    default=False,
    help="Filter out poorly trained tokens",
    show_default=True,
)
@click.option(
    "--new-vocab-noise",
    "-nvn",
    type=float,
    default=None,
    help="Add gaussian noise to new vocab embeddings",
    show_default=True,
)
@click.option(
    "--new-vocab-scale",
    "-nvs",
    type=float,
    default=None,
    help="Scale computed new vocab embeddings by this factor",
    show_default=True,
)
@add_merge_options
def main(
    model: str,
    donor: str,
    out_path: str,
    k: int,
    cosine_similarity: bool,
    approximation_method: str,
    weight_scheme: str,
    subword_method: str,
    batch_size: Optional[int],
    prefix_match: str,
    byte_match: str,
    magikarp: bool,
    new_vocab_noise: Optional[float],
    new_vocab_scale: Optional[float],
    merge_options: MergeOptions,
):
    """Graft the donor model's vocabulary onto MODEL and write the result to OUT_PATH.

    Loads the embedding (and, if present, LM head) matrices of both models,
    rebuilds them against the donor vocabulary — copying rows for shared tokens
    and approximating rows for donor-only tokens — then saves the surgically
    modified model together with the donor's tokenizer.
    """
    # Stdlib helpers used during post-write cleanup; grouped here instead of
    # being imported piecemeal mid-function.
    import gc
    import os
    import re
    import shutil

    merge_options.apply_global_options()
    logging.warning("This script is experimental and may produce unexpected results.")
    options = TokenSurgeonOptions(
        model=ModelReference.model_validate(model),
        donor=ModelReference.model_validate(donor),
        out_path=out_path,
        k=k,
        cosine_similarity=cosine_similarity,
        method=ApproximationMethod(approximation_method),
        weight_scheme=WeightingScheme(weight_scheme),
        subword_method=SubwordMethod(subword_method),
        batch_size=batch_size,
        new_vocab_noise=new_vocab_noise,
        new_vocab_scale=new_vocab_scale,
    )
    # Re-bind the raw CLI strings as enum members for the containment tests below.
    prefix_match = AllowMatch(prefix_match)
    byte_match = AllowMatch(byte_match)

    cache = LoaderCache()
    cache.setup(options=merge_options)

    # Heavy math runs on the chosen accelerator; tensors are parked on CPU.
    compute_device = torch.device(
        merge_options.device
        if merge_options.device
        else "cuda" if torch.cuda.is_available() else "cpu"
    )
    storage_device = "cpu"

    arch_info = get_arch_info(options.model, merge_options)
    embed_wi, lm_head_wi = get_embedding_info(arch_info)
    orig_vocab, orig_embed, orig_lm_head = get_stuff(
        options.model, merge_options, arch_info=arch_info, device=storage_device
    )
    donor_vocab, donor_embed, donor_lm_head = get_stuff(
        options.donor, merge_options, arch_info=None, get_tied=True, device=storage_device
    )

    if magikarp:
        # Magikarp filtering: restrict interpolation sources to tokens that look
        # well-trained in BOTH models, and zero-init donor tokens that look
        # under-trained in both.
        LOG.debug("Finding well-trained tokens in original model")
        well_trained_orig_tokens = set(
            well_trained_tokens(
                orig_vocab,
                orig_embed,
                orig_lm_head,
            )
        )
        LOG.debug("Finding well-trained tokens in donor model")
        well_trained_donor_tokens = set(
            well_trained_tokens(
                donor_vocab,
                donor_embed,
                donor_lm_head,
            )
        )
        common_well_trained_tokens = (
            well_trained_orig_tokens & well_trained_donor_tokens
        )
        LOG.info(f"Found {len(common_well_trained_tokens)} common well-trained tokens")
        orig_vocab = {
            tok: idx
            for tok, idx in orig_vocab.items()
            if tok in common_well_trained_tokens
        }
        junk_tokens = [
            idx
            for tok, idx in donor_vocab.items()
            if (tok not in well_trained_donor_tokens)
            and (tok not in well_trained_orig_tokens)
        ]
    else:
        junk_tokens = []

    if orig_embed is not None:
        if donor_embed is None:
            raise RuntimeError(
                f"Missing tensor {embed_wi.name} in model {options.donor}"
            )
        new_embed = build_embedding_matrix(
            embed_wi,
            orig_embed,
            donor_embed,
            orig_vocab=orig_vocab,
            donor_vocab=donor_vocab,
            junk_tokens=junk_tokens,
            # BUGFIX: this is the input-embedding path, so EMBED_ONLY (not
            # LM_HEAD_ONLY) must enable the heuristic here; the two were swapped.
            allow_prefix=prefix_match in (AllowMatch.YES, AllowMatch.EMBED_ONLY),
            allow_byte=byte_match in (AllowMatch.YES, AllowMatch.EMBED_ONLY),
            is_lm_head=False,
            options=options,
            compute_device=compute_device,
        )
    else:
        if not embed_wi.optional:
            raise RuntimeError(
                f"Missing tensor {embed_wi.name} in model {options.model}"
            )
        new_embed = None

    if orig_lm_head is not None:
        if donor_lm_head is None:
            raise RuntimeError(
                f"Missing tensor {lm_head_wi.name} in model {options.donor}"
            )
        new_lm_head = build_embedding_matrix(
            lm_head_wi,
            orig_lm_head,
            donor_lm_head,
            orig_vocab=orig_vocab,
            donor_vocab=donor_vocab,
            junk_tokens=junk_tokens,
            # BUGFIX: LM-head path gates on LM_HEAD_ONLY (was EMBED_ONLY).
            allow_prefix=prefix_match in (AllowMatch.YES, AllowMatch.LM_HEAD_ONLY),
            allow_byte=byte_match in (AllowMatch.YES, AllowMatch.LM_HEAD_ONLY),
            is_lm_head=True,
            options=options,
            compute_device=compute_device,
        )
    else:
        if not lm_head_wi.optional:
            raise RuntimeError(
                f"Missing tensor {lm_head_wi.name} in model {options.model}"
            )
        new_lm_head = None

    # Vocab size for the output config comes from whichever matrix was rebuilt.
    new_vocab_size = None
    if new_embed is not None:
        new_vocab_size = new_embed.shape[0]
    elif new_lm_head is not None:
        new_vocab_size = new_lm_head.shape[0]
    LOG.info(f"Saving new model to {out_path}")
    out_arch_info = get_out_arch_info(
        options.model, options.donor, new_vocab_size, merge_options
    )
    writer = TensorWriter(
        out_path,
        max_shard_size=merge_options.out_shard_size,
        safe_serialization=merge_options.safe_serialization,
        use_async=merge_options.async_write,
        max_write_threads=merge_options.write_threads,
    )
    for weight_info in tqdm.tqdm(out_arch_info.all_weights(), desc="Saving weights"):
        if embed_wi is not None and weight_info.name == embed_wi.name:
            tensor = new_embed
        elif lm_head_wi is not None and weight_info.name == lm_head_wi.name:
            tensor = new_lm_head
        else:
            # All non-vocabulary tensors pass through from the original model.
            tensor = cache.get(options.model).get_tensor(
                weight_info.name, aliases=weight_info.aliases, raise_on_missing=False
            )
        if tensor is None:
            if weight_info.optional:
                continue
            raise RuntimeError(
                f"Missing tensor {weight_info.name} in model {options.model}"
            )
        writer.save_tensor(weight_info.name, tensor, clone=merge_options.clone_tensors)

    # Force close lazy loader file handles so Windows allows deletion/renaming
    cache.flush_all()
    gc.collect()

    # Delete pre-existing safetensors files (but not the writer's freshly
    # written "*-NNN.safetensors" temp shards) to prevent FileExistsError when
    # finalize() renames the shards into place.
    temp_pattern = re.compile(r"^.*-\d+\.safetensors$")
    for fname in os.listdir(out_path):
        if fname.endswith(".safetensors") and not temp_pattern.match(fname):
            try:
                os.remove(os.path.join(out_path, fname))
            except Exception as e:
                LOG.warning(f"Could not remove old file {fname}: {e}")
        elif fname == "model.safetensors.index.json":
            try:
                os.remove(os.path.join(out_path, fname))
            except Exception:
                # Best-effort: a stale index file is harmless if removal fails.
                pass

    writer.finalize()
    out_arch_info.config.save_pretrained(out_path)

    # The output model speaks the donor's vocabulary, so ship the donor tokenizer.
    tokenizer_out = transformers.AutoTokenizer.from_pretrained(
        options.donor.model.path,
        revision=options.donor.model.revision,
        trust_remote_code=merge_options.trust_remote_code,
    )
    tokenizer_out.save_pretrained(out_path)

    # Also copy generation_config.json if it exists in the donor
    donor_gen_config = os.path.join(options.donor.model.path, "generation_config.json")
    if os.path.exists(donor_gen_config):
        shutil.copy(donor_gen_config, os.path.join(out_path, "generation_config.json"))

    LOG.info("Done!")
)
|
| 611 |
+
@click.option(
|
| 612 |
+
"--batch-size",
|
| 613 |
+
type=int,
|
| 614 |
+
default=512,
|
| 615 |
+
help="Number of tokens to process in each batch. -1 for no batching.",
|
| 616 |
+
show_default=True,
|
| 617 |
+
)
|
| 618 |
+
@click.option(
|
| 619 |
+
"--prefix-match",
|
| 620 |
+
"-pm",
|
| 621 |
+
type=click.Choice([m.value for m in AllowMatch]),
|
| 622 |
+
default=AllowMatch.NO.value,
|
| 623 |
+
help="Allow prefix match for tokens",
|
| 624 |
+
show_default=True,
|
| 625 |
+
)
|
| 626 |
+
@click.option(
|
| 627 |
+
"--byte-match",
|
| 628 |
+
"-bm",
|
| 629 |
+
type=click.Choice([m.value for m in AllowMatch]),
|
| 630 |
+
default=AllowMatch.NO.value,
|
| 631 |
+
help="Allow byte match for tokens",
|
| 632 |
+
show_default=True,
|
| 633 |
+
)
|
| 634 |
+
@click.option(
|
| 635 |
+
"--magikarp/--no-magikarp",
|
| 636 |
+
is_flag=True,
|
| 637 |
+
default=False,
|
| 638 |
+
help="Filter out poorly trained tokens",
|
| 639 |
+
show_default=True,
|
| 640 |
+
)
|
| 641 |
+
@click.option(
|
| 642 |
+
"--new-vocab-noise",
|
| 643 |
+
"-nvn",
|
| 644 |
+
type=float,
|
| 645 |
+
default=None,
|
| 646 |
+
help="Add gaussian noise to new vocab embeddings",
|
| 647 |
+
show_default=True,
|
| 648 |
+
)
|
| 649 |
+
@click.option(
|
| 650 |
+
"--new-vocab-scale",
|
| 651 |
+
"-nvs",
|
| 652 |
+
type=float,
|
| 653 |
+
default=None,
|
| 654 |
+
help="Scale computed new vocab embeddings by this factor",
|
| 655 |
+
show_default=True,
|
| 656 |
+
)
|
| 657 |
+
@add_merge_options
|
| 658 |
+
def main(
|
| 659 |
+
model: str,
|
| 660 |
+
donor: str,
|
| 661 |
+
out_path: str,
|
| 662 |
+
k: int,
|
| 663 |
+
cosine_similarity: bool,
|
| 664 |
+
approximation_method: str,
|
| 665 |
+
weight_scheme: str,
|
| 666 |
+
subword_method: str,
|
| 667 |
+
batch_size: Optional[int],
|
| 668 |
+
prefix_match: str,
|
| 669 |
+
byte_match: str,
|
| 670 |
+
magikarp: bool,
|
| 671 |
+
new_vocab_noise: Optional[float],
|
| 672 |
+
new_vocab_scale: Optional[float],
|
| 673 |
+
merge_options: MergeOptions,
|
| 674 |
+
):
|
| 675 |
+
merge_options.apply_global_options()
|
| 676 |
+
logging.warning("This script is experimental and may produce unexpected results.")
|
| 677 |
+
options = TokenSurgeonOptions(
|
| 678 |
+
model=ModelReference.model_validate(model),
|
| 679 |
+
donor=ModelReference.model_validate(donor),
|
| 680 |
+
out_path=out_path,
|
| 681 |
+
k=k,
|
| 682 |
+
cosine_similarity=cosine_similarity,
|
| 683 |
+
method=ApproximationMethod(approximation_method),
|
| 684 |
+
weight_scheme=WeightingScheme(weight_scheme),
|
| 685 |
+
subword_method=SubwordMethod(subword_method),
|
| 686 |
+
batch_size=batch_size,
|
| 687 |
+
new_vocab_noise=new_vocab_noise,
|
| 688 |
+
new_vocab_scale=new_vocab_scale,
|
| 689 |
+
)
|
| 690 |
+
prefix_match = AllowMatch(prefix_match)
|
| 691 |
+
byte_match = AllowMatch(byte_match)
|
| 692 |
+
|
| 693 |
+
cache = LoaderCache()
|
| 694 |
+
cache.setup(options=merge_options)
|
| 695 |
+
|
| 696 |
+
compute_device = torch.device(merge_options.device if merge_options.device else "cuda" if torch.cuda.is_available() else "cpu")
|
| 697 |
+
storage_device = "cpu"
|
| 698 |
+
|
| 699 |
+
arch_info = get_arch_info(options.model, merge_options)
|
| 700 |
+
embed_wi, lm_head_wi = get_embedding_info(arch_info)
|
| 701 |
+
orig_vocab, orig_embed, orig_lm_head = get_stuff(
|
| 702 |
+
options.model, merge_options, arch_info=arch_info, device=storage_device
|
| 703 |
+
)
|
| 704 |
+
donor_vocab, donor_embed, donor_lm_head = get_stuff(
|
| 705 |
+
options.donor, merge_options, arch_info=None, get_tied=True, device=storage_device
|
| 706 |
+
)
|
| 707 |
+
|
| 708 |
+
if magikarp:
|
| 709 |
+
LOG.debug("Finding well-trained tokens in original model")
|
| 710 |
+
well_trained_orig_tokens = set(
|
| 711 |
+
well_trained_tokens(
|
| 712 |
+
orig_vocab,
|
| 713 |
+
orig_embed,
|
| 714 |
+
orig_lm_head,
|
| 715 |
+
)
|
| 716 |
+
)
|
| 717 |
+
LOG.debug("Finding well-trained tokens in donor model")
|
| 718 |
+
well_trained_donor_tokens = set(
|
| 719 |
+
well_trained_tokens(
|
| 720 |
+
donor_vocab,
|
| 721 |
+
donor_embed,
|
| 722 |
+
donor_lm_head,
|
| 723 |
+
)
|
| 724 |
+
)
|
| 725 |
+
common_well_trained_tokens = (
|
| 726 |
+
well_trained_orig_tokens & well_trained_donor_tokens
|
| 727 |
+
)
|
| 728 |
+
LOG.info(f"Found {len(common_well_trained_tokens)} common well-trained tokens")
|
| 729 |
+
orig_vocab = {
|
| 730 |
+
tok: idx
|
| 731 |
+
for tok, idx in orig_vocab.items()
|
| 732 |
+
if tok in common_well_trained_tokens
|
| 733 |
+
}
|
| 734 |
+
junk_tokens = [
|
| 735 |
+
idx
|
| 736 |
+
for tok, idx in donor_vocab.items()
|
| 737 |
+
if (tok not in well_trained_donor_tokens)
|
| 738 |
+
and (tok not in well_trained_orig_tokens)
|
| 739 |
+
]
|
| 740 |
+
else:
|
| 741 |
+
junk_tokens = []
|
| 742 |
+
|
| 743 |
+
if orig_embed is not None:
|
| 744 |
+
if donor_embed is None:
|
| 745 |
+
raise RuntimeError(
|
| 746 |
+
f"Missing tensor {embed_wi.name} in model {options.donor}"
|
| 747 |
+
)
|
| 748 |
+
new_embed = build_embedding_matrix(
|
| 749 |
+
embed_wi,
|
| 750 |
+
orig_embed,
|
| 751 |
+
donor_embed,
|
| 752 |
+
orig_vocab=orig_vocab,
|
| 753 |
+
donor_vocab=donor_vocab,
|
| 754 |
+
junk_tokens=junk_tokens,
|
| 755 |
+
allow_prefix=prefix_match in (AllowMatch.YES, AllowMatch.LM_HEAD_ONLY),
|
| 756 |
+
allow_byte=byte_match in (AllowMatch.YES, AllowMatch.LM_HEAD_ONLY),
|
| 757 |
+
is_lm_head=False,
|
| 758 |
+
options=options,
|
| 759 |
+
compute_device=compute_device,
|
| 760 |
+
)
|
| 761 |
+
else:
|
| 762 |
+
if not embed_wi.optional:
|
| 763 |
+
raise RuntimeError(
|
| 764 |
+
f"Missing tensor {embed_wi.name} in model {options.model}"
|
| 765 |
+
)
|
| 766 |
+
new_embed = None
|
| 767 |
+
|
| 768 |
+
if orig_lm_head is not None:
|
| 769 |
+
if donor_lm_head is None:
|
| 770 |
+
raise RuntimeError(
|
| 771 |
+
f"Missing tensor {lm_head_wi.name} in model {options.donor}"
|
| 772 |
+
)
|
| 773 |
+
new_lm_head = build_embedding_matrix(
|
| 774 |
+
lm_head_wi,
|
| 775 |
+
orig_lm_head,
|
| 776 |
+
donor_lm_head,
|
| 777 |
+
orig_vocab=orig_vocab,
|
| 778 |
+
donor_vocab=donor_vocab,
|
| 779 |
+
junk_tokens=junk_tokens,
|
| 780 |
+
allow_prefix=prefix_match in (AllowMatch.YES, AllowMatch.EMBED_ONLY),
|
| 781 |
+
allow_byte=byte_match in (AllowMatch.YES, AllowMatch.EMBED_ONLY),
|
| 782 |
+
is_lm_head=True,
|
| 783 |
+
options=options,
|
| 784 |
+
compute_device=compute_device,
|
| 785 |
+
)
|
| 786 |
+
else:
|
| 787 |
+
if not lm_head_wi.optional:
|
| 788 |
+
raise RuntimeError(
|
| 789 |
+
f"Missing tensor {lm_head_wi.name} in model {options.model}"
|
| 790 |
+
)
|
| 791 |
+
new_lm_head = None
|
| 792 |
+
|
| 793 |
+
new_vocab_size = None
|
| 794 |
+
if new_embed is not None:
|
| 795 |
+
new_vocab_size = new_embed.shape[0]
|
| 796 |
+
elif new_lm_head is not None:
|
| 797 |
+
new_vocab_size = new_lm_head.shape[0]
|
| 798 |
+
LOG.info(f"Saving new model to {out_path}")
|
| 799 |
+
out_arch_info = get_out_arch_info(
|
| 800 |
+
options.model, options.donor, new_vocab_size, merge_options
|
| 801 |
+
)
|
| 802 |
+
writer = TensorWriter(
|
| 803 |
+
out_path,
|
| 804 |
+
max_shard_size=merge_options.out_shard_size,
|
| 805 |
+
safe_serialization=merge_options.safe_serialization,
|
| 806 |
+
use_async=merge_options.async_write,
|
| 807 |
+
max_write_threads=merge_options.write_threads,
|
| 808 |
+
)
|
| 809 |
+
for weight_info in tqdm.tqdm(out_arch_info.all_weights(), desc="Saving weights"):
|
| 810 |
+
if weight_info.name == embed_wi.name:
|
| 811 |
+
tensor = new_embed
|
| 812 |
+
elif lm_head_wi is not None and weight_info.name == lm_head_wi.name:
|
| 813 |
+
tensor = new_lm_head
|
| 814 |
+
else:
|
| 815 |
+
tensor = cache.get(options.model).get_tensor(
|
| 816 |
+
weight_info.name, aliases=weight_info.aliases, raise_on_missing=False
|
| 817 |
+
)
|
| 818 |
+
if tensor is None:
|
| 819 |
+
if weight_info.optional:
|
| 820 |
+
continue
|
| 821 |
+
raise RuntimeError(
|
| 822 |
+
f"Missing tensor {weight_info.name} in model {options.model}"
|
| 823 |
+
)
|
| 824 |
+
writer.save_tensor(weight_info.name, tensor, clone=merge_options.clone_tensors)
|
| 825 |
+
|
| 826 |
+
# Force close lazy loader file handles so Windows allows deletion/renaming
|
| 827 |
+
cache.flush_all()
|
| 828 |
+
import gc
|
| 829 |
+
gc.collect()
|
| 830 |
+
|
| 831 |
+
# Delete original safetensors files to prevent FileExistsError during rename
|
| 832 |
+
import os
|
| 833 |
+
import re
|
| 834 |
+
temp_pattern = re.compile(r"^.*-\d+\.safetensors$")
|
| 835 |
+
for fname in os.listdir(out_path):
|
| 836 |
+
if fname.endswith(".safetensors") and not temp_pattern.match(fname):
|
| 837 |
+
try:
|
| 838 |
+
os.remove(os.path.join(out_path, fname))
|
| 839 |
+
except Exception as e:
|
| 840 |
+
LOG.warning(f"Could not remove old file {fname}: {e}")
|
| 841 |
+
elif fname == "model.safetensors.index.json":
|
| 842 |
+
try:
|
| 843 |
+
os.remove(os.path.join(out_path, fname))
|
| 844 |
+
except Exception:
|
| 845 |
+
pass
|
| 846 |
+
|
| 847 |
+
writer.finalize()
|
| 848 |
+
out_arch_info.config.save_pretrained(out_path)
|
| 849 |
+
|
| 850 |
+
tokenizer_out = transformers.AutoTokenizer.from_pretrained(
|
| 851 |
+
options.donor.model.path,
|
| 852 |
+
revision=options.donor.model.revision,
|
| 853 |
+
trust_remote_code=merge_options.trust_remote_code,
|
| 854 |
+
)
|
| 855 |
+
tokenizer_out.save_pretrained(out_path)
|
| 856 |
+
|
| 857 |
+
# Also copy generation_config.json if it exists in the donor
|
| 858 |
+
donor_gen_config = os.path.join(options.donor.model.path, "generation_config.json")
|
| 859 |
+
if os.path.exists(donor_gen_config):
|
| 860 |
+
import shutil
|
| 861 |
+
shutil.copy(donor_gen_config, os.path.join(out_path, "generation_config.json"))
|
| 862 |
+
|
| 863 |
+
LOG.info("Done!")
|
| 864 |
+
|
| 865 |
+
|
| 866 |
+
if __name__ == "__main__":
|
| 867 |
+
main()
|