Naphula commited on
Commit
5f463e1
·
verified ·
1 Parent(s): 7080631

Upload 8 files

Browse files
Files changed (8) hide show
  1. __init__.py +36 -0
  2. donor_audit_v3.py +245 -0
  3. eos_scanner.py +17 -1
  4. llama.py +59 -0
  5. model_tools.md +14 -1
  6. moe_defs.py +197 -0
  7. tokeninspector.py +135 -0
  8. tokensurgeon.py +867 -0
__init__.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
from typing import List

from mergekit.moe.arch import MoEOutputArchitecture
from mergekit.moe.deepseek import DeepseekMoE
from mergekit.moe.mixtral import MixtralMoE

# Architectures that are always available.
ALL_OUTPUT_ARCHITECTURES: List[MoEOutputArchitecture] = [MixtralMoE(), DeepseekMoE()]

# Optional architectures: each one is registered only if its module imports
# cleanly (a missing or broken file simply leaves it unregistered).
try:
    from mergekit.moe.qwen import QwenMoE

    ALL_OUTPUT_ARCHITECTURES.append(QwenMoE())
except ImportError:
    pass

try:
    from mergekit.moe.qwen3 import Qwen3MoE

    ALL_OUTPUT_ARCHITECTURES.append(Qwen3MoE())
except ImportError:
    pass

try:
    from mergekit.moe.llama import LlamaMoE

    ALL_OUTPUT_ARCHITECTURES.append(LlamaMoE())
except ImportError:
    # Triggers if llama.py is missing or fails to import.
    pass

__all__ = [
    "ALL_OUTPUT_ARCHITECTURES",
    "MoEOutputArchitecture",
]
donor_audit_v3.py ADDED
@@ -0,0 +1,245 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (C) 2025 Arcee AI & Kraken Architect
2
+ # SPDX-License-Identifier: BUSL-1.1
3
+
4
+ import logging
5
+ import os
6
+ import sys
7
+ from typing import List, Optional
8
+
9
+ import click
10
+ import torch
11
+ import yaml
12
+ from tqdm import tqdm
13
+
14
+ from mergekit.common import ModelReference
15
+ from mergekit.config import MergeConfiguration
16
+ from mergekit.io.lazy_tensor_loader import LazyTensorLoader, ShardedTensorIndex
17
+ from mergekit.merge_methods.easy_define import merge_method
18
+
19
+ logging.basicConfig(level=logging.INFO, format="%(message)s")
20
+ LOG = logging.getLogger("donor_audit")
21
+
22
@merge_method(
    name="donor_audit",
    pretty_name="Donor Audit",
    reference_url="https://arxiv.org/abs/2408.07990",
)
def _donor_audit_registration(tensors: List[torch.Tensor]) -> torch.Tensor:
    """No-op merge body; exists solely to register the 'donor_audit' name."""
    first = tensors[0]
    return first
30
+
31
def rsce_weight(tvs: torch.Tensor) -> torch.Tensor:
    """Weight each task vector by its mean-square energy, normalized to sum to 1.

    Falls back to a uniform distribution when the total energy is ~zero.
    (Copied from RSCE v3.)
    """
    # Mean-square energy per task vector (reduce over all non-batch dims).
    non_batch_dims = list(range(1, tvs.dim()))
    energy = (tvs * tvs).mean(dim=non_batch_dims)
    total = float(energy.sum())
    if abs(total) < 1e-8:
        count = energy.shape[0]
        return torch.full_like(energy, 1.0 / count)
    return energy / total
42
+
43
+
44
def log_rsce_audit(layer_name: str, weights: torch.Tensor, names: List[str]):
    """Print a bar chart of per-donor influence and append it to rsce_audit.log.

    Args:
        layer_name: Name of the tensor being audited (used in the report header).
        weights: 1-D tensor of normalized donor weights, parallel to ``names``.
        names: Display name (path) for each donor.
    """
    w_list = weights.tolist()
    bar_char = "█"

    # Header
    print(f"\n{'='*60}")
    print("RSCE DONOR AUDIT REPORT")
    print(f"Target Tensor: {layer_name}")
    print(f"{'='*60}")

    # Scale bars relative to the loudest model so the max value fills the bar.
    # FIX: this was recomputed (two max() scans) on every loop iteration even
    # though it is loop-invariant; hoisted out, with an empty-list guard.
    max_val = max(w_list) if w_list and max(w_list) > 0 else 1.0

    lines = []
    for name, w in zip(names, w_list):
        pct = w * 100

        # Relative bar length (relative to the loudest model)
        bar_len = int((w / max_val) * 40)
        bar = bar_char * bar_len

        # Truncate name for clean display
        clean_name = os.path.basename(name)
        if len(clean_name) > 30:
            clean_name = clean_name[:27] + "..."

        lines.append(f"{clean_name:<30} | {bar:<40} | {pct:6.2f}% (Raw: {w:.4f})")

    log_entry = "\n".join(lines)
    print(log_entry)
    print(f"{'='*60}\n")

    # Append to file
    with open("rsce_audit.log", "a", encoding="utf-8") as f:
        f.write(f"\n[Audit {layer_name}]\n" + log_entry + "\n")
80
+
81
+
82
def find_layer0_tensor(loader: "LazyTensorLoader") -> str:
    """Pick a representative layer-0 weight tensor name from the model index.

    Preference order: down_proj, then gate_proj, then c_attn (GPT-NeoX/Qwen),
    then the first layer-0 weight found.

    Raises:
        RuntimeError: if no layer-0 weight tensor exists in the index.
    """
    layer0_markers = (".layers.0.", ".h.0.", ".blocks.0.")
    # Layer-0 weight tensors only (biases excluded).
    candidates = [
        key
        for key in loader.index.tensor_paths.keys()
        if key.endswith(".weight") and any(m in key for m in layer0_markers)
    ]

    # Priority: down_proj > gate_proj > c_attn (GPT-NeoX / Qwen), first match wins.
    for marker in ("down_proj", "gate_proj", "c_attn"):
        for cand in candidates:
            if marker in cand:
                return cand

    if not candidates:
        raise RuntimeError("Could not find any Layer 0 weights in the base model.")

    return candidates[0]
107
+
108
+
109
def load_tensor_safe(model_path: str, tensor_name: str, device="cpu") -> torch.Tensor:
    """Loads a single tensor from a model path.

    Args:
        model_path: Weights location -- either a single shard file or a model directory.
        tensor_name: Exact tensor name to load; on a miss, a fuzzy layer-0
            fallback match is attempted (for slightly different architectures).
        device: Torch device string to load onto (default "cpu").

    Returns:
        The requested tensor converted to float32 for the audit math.

    Note:
        Any failure (I/O error, missing tensor, ...) logs an error and
        terminates the process via ``sys.exit(1)`` instead of raising.
    """
    try:
        # We use ShardedTensorIndex directly to avoid caching overhead of LoaderCache for this simple script
        if os.path.isfile(model_path):
            index = ShardedTensorIndex.from_file(model_path)
        else:
            index = ShardedTensorIndex.from_disk(model_path)
        loader = LazyTensorLoader(index, lazy_unpickle=True)

        # Handle potential naming mismatches (simple check)
        if tensor_name not in index.tensor_paths:
            # Try to find a fuzzy match if exact name fails (e.g. if models have slightly different archs)
            # This is a basic fallback: match on the suffix after "layers.0."
            suffix = tensor_name.split("layers.0.")[-1]
            for k in index.tensor_paths.keys():
                if k.endswith(suffix) and ("layers.0." in k or "h.0." in k):
                    tensor_name = k
                    break

        t = loader.get_tensor(tensor_name, device=device)
        return t.float()  # Convert to float32 for math
    except Exception as e:
        LOG.error(f"Failed to load {tensor_name} from {model_path}: {e}")
        sys.exit(1)
134
+
135
+
136
@click.command()
@click.argument("config_file", type=click.Path(exists=True))
@click.option("--lora-merge-cache", default=None, help="Cache directory for merged LoRAs")
@click.option("--cuda/--no-cuda", default=False, help="Use GPU for calculation (faster math, higher VRAM)")
def main(config_file, lora_merge_cache, cuda):
    """
    RSCE Donor Audit Tool V3.

    Loads Layer 0 from all models in the config and calculates their
    Task Vector magnitude/energy contribution relative to the base model.
    """
    device = "cuda" if cuda and torch.cuda.is_available() else "cpu"
    LOG.info(f"Running audit on {device}...")

    # 1. Parse Config
    with open(config_file, "r", encoding="utf-8") as f:
        config_data = yaml.safe_load(f)
    config = MergeConfiguration.model_validate(config_data)

    # 2. Identify Models
    base_model_ref = config.base_model
    if not base_model_ref:
        LOG.error("Config must specify a `base_model` for RSCE auditing.")
        sys.exit(1)

    # Extract donor models from slices or models list
    donor_refs = []
    if config.models:
        donor_refs = [m.model for m in config.models]
    elif config.slices:
        # Flatten slices to get unique models
        seen = set()
        for s in config.slices:
            for source in s.sources:
                if source.model != base_model_ref and source.model not in seen:
                    donor_refs.append(source.model)
                    seen.add(source.model)

    # Filter out base model if it appeared in donors
    donor_refs = [d for d in donor_refs if d != base_model_ref]

    LOG.info(f"Base Model: {base_model_ref.model.path}")
    LOG.info(f"Found {len(donor_refs)} donor models.")

    # 3. Resolve Paths (Handle LoRAs if necessary)
    def resolve_path(ref: ModelReference):
        # LoRA references must be merged first; fall back to a HF snapshot
        # download when the path does not exist locally.
        if ref.lora:
            if not lora_merge_cache:
                LOG.warning("LoRA detected but --lora-merge-cache not set. This might fail.")
            return ref.merged(cache_dir=lora_merge_cache).model.path

        if not os.path.exists(ref.model.path):
            try:
                from huggingface_hub import snapshot_download
                return snapshot_download(ref.model.path, allow_patterns=["*.safetensors", "*.bin", "*.json"])
            # FIX: was a bare `except:`, which also swallowed KeyboardInterrupt
            # and SystemExit; narrowed to Exception.
            except Exception:
                return ref.model.path
        return ref.model.path

    base_path = resolve_path(base_model_ref)
    donor_paths = [resolve_path(d) for d in donor_refs]

    # 4. Identify Target Tensor (Layer 0)
    base_index = ShardedTensorIndex.from_disk(base_path)
    base_loader = LazyTensorLoader(base_index, lazy_unpickle=True)
    target_tensor_name = find_layer0_tensor(base_loader)

    LOG.info(f"Selected audit tensor: {target_tensor_name}")
    LOG.info("Loading tensors into memory...")

    # 5. Load All Tensors
    base_tensor = load_tensor_safe(base_path, target_tensor_name, device)

    donor_tensors = []
    valid_donor_refs = []

    for d_path, d_ref in zip(tqdm(donor_paths, desc="Loading Donors"), donor_refs):
        dt = load_tensor_safe(d_path, target_tensor_name, device)

        # V3: Catch shape mismatches (e.g. a 7B model mixed into a 12B merge)
        if dt.shape != base_tensor.shape:
            LOG.warning(f"\n[!] Shape mismatch for {d_ref.model.path}: expected {base_tensor.shape}, got {dt.shape}. Skipping this model.")
            continue

        donor_tensors.append(dt)
        valid_donor_refs.append(d_ref)

    if not donor_tensors:
        LOG.error("No valid donor tensors found with matching shapes. Exiting.")
        sys.exit(1)

    # 6. Perform RSCE Audit Math
    LOG.info("Calculating Task Vector Energy...")

    # The base model's task vector is zero by definition (it is the anchor).
    base_tv = torch.zeros_like(base_tensor)
    donor_tvs = [dt - base_tensor for dt in donor_tensors]

    all_tvs = torch.stack([base_tv] + donor_tvs, dim=0)
    raw_weights = rsce_weight(all_tvs)

    display_names = ["Base Model (Anchor)"] + [d.model.path for d in valid_donor_refs]

    # 7. Output
    log_rsce_audit(target_tensor_name, raw_weights, display_names)

    LOG.info("Audit complete.")


if __name__ == "__main__":
    main()
eos_scanner.py CHANGED
@@ -31,9 +31,25 @@ def load_json(path):
31
  return None
32
 
33
  def get_model_metadata(model_path):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
  data = {
35
  "path": model_path,
36
- "name": os.path.basename(model_path).replace("!models--", ""),
37
  "gen_eos_id": "MISSING", # From generation_config.json
38
  "tok_eos_str": "MISSING", # From tokenizer_config.json
39
  "vocab_eos_id": "MISSING", # The actual ID of the string in tokenizer.json
 
31
  return None
32
 
33
  def get_model_metadata(model_path):
34
+ # --- NAME FIX LOGIC START ---
35
+ # Normalize path to handle trailing slashes or mixed separators
36
+ norm_path = os.path.normpath(model_path)
37
+ base_name = os.path.basename(norm_path)
38
+
39
+ # If the folder is named "fixed", grab the parent folder name instead
40
+ if base_name == "fixed":
41
+ parent_name = os.path.basename(os.path.dirname(norm_path))
42
+ display_name = f"{parent_name}/fixed"
43
+ else:
44
+ display_name = base_name
45
+
46
+ # Clean up the huggingface cache prefix
47
+ display_name = display_name.replace("!models--", "")
48
+ # --- NAME FIX LOGIC END ---
49
+
50
  data = {
51
  "path": model_path,
52
+ "name": display_name,
53
  "gen_eos_id": "MISSING", # From generation_config.json
54
  "tok_eos_str": "MISSING", # From tokenizer_config.json
55
  "vocab_eos_id": "MISSING", # The actual ID of the string in tokenizer.json
llama.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import tqdm
3
+ import transformers
4
+ from mergekit.moe.arch import MoEOutputArchitecture
5
+ from mergekit.moe.common import copy_tensor_out, initialize_io, select_dtype
6
+ from mergekit.moe.config import MoEMergeConfig
7
+ from mergekit.options import MergeOptions
8
+ from mergekit.architecture import arch_info_for_config
9
+
10
class LlamaMoE(MoEOutputArchitecture):
    """Writes a Llama-based MoE using the Mixtral on-disk layout.

    Dense Llama MLP weights are fanned out into per-expert tensors named after
    Mixtral's w1/w2/w3 convention so standard loaders can read the result.
    """

    def name(self) -> str:
        return "LlamaMoE"

    def supports_config(self, config: MoEMergeConfig, explain: bool = False, trust_remote_code: bool = False) -> bool:
        """Accept only configs whose base model reports model_type 'llama'."""
        model_cfg = config.base_model.config(trust_remote_code=trust_remote_code)
        if model_cfg.model_type != "llama":
            if explain:
                print("LlamaMoE only supports Llama base models")
            return False
        return True

    def write_model(self, out_path: str, config: MoEMergeConfig, merge_options: MergeOptions, router_weights: list[torch.Tensor], shared_router_weights=None):
        """Assemble and write the MoE model to ``out_path``.

        Args:
            out_path: Destination directory for the merged model.
            config: MoE merge configuration (base model, experts, experts_per_token).
            merge_options: Global merge options (dtype, trust_remote_code, ...).
            router_weights: One router (gate) tensor per layer.
            shared_router_weights: Unused; accepted for interface compatibility.
        """
        import json
        import os

        base_model = config.base_model
        base_cfg = base_model.config(trust_remote_code=merge_options.trust_remote_code)

        # 1. Generate the config.json
        out_cfg = base_cfg.to_dict()
        # Note: Most Llama MoEs use the Mixtral architecture name for compatibility with loaders
        out_cfg["architectures"] = ["MixtralForCausalLM"]
        out_cfg["num_local_experts"] = len(config.experts)
        out_cfg["num_experts_per_tok"] = config.experts_per_token

        out_dtype = select_dtype(config, base_cfg)

        # 2. Initialize IO
        loaders, base_loader, writer = initialize_io(config, out_path, merge_options)

        # BUGFIX: out_cfg was assembled but never persisted, so the output
        # directory had no config.json. Write it out explicitly here.
        # (initialize_io is presumed to create out_path; makedirs is a safety net.)
        os.makedirs(out_path, exist_ok=True)
        with open(os.path.join(out_path, "config.json"), "w", encoding="utf-8") as f:
            json.dump(out_cfg, f, indent=2)

        # 3. Map Tensors
        for weight_info in tqdm.tqdm(arch_info_for_config(base_cfg).all_weights(base_cfg), desc="Weights"):
            tensor_name = weight_info.name
            if ".mlp." in tensor_name:
                # Fan the dense MLP out into one copy per expert, renamed to
                # Mixtral's convention: gate_proj->w1, down_proj->w2, up_proj->w3.
                for expert_idx, expert in enumerate(config.experts):
                    expert_name = tensor_name.replace(".mlp.gate_proj", f".block_sparse_moe.experts.{expert_idx}.w1")
                    expert_name = expert_name.replace(".mlp.down_proj", f".block_sparse_moe.experts.{expert_idx}.w2")
                    expert_name = expert_name.replace(".mlp.up_proj", f".block_sparse_moe.experts.{expert_idx}.w3")

                    expert_loader = loaders.get(expert.source_model)
                    copy_tensor_out(weight_info, expert_loader, writer, expert=expert, output_name=expert_name, out_dtype=out_dtype)
            else:
                # Copy Attention and Norms from base model
                copy_tensor_out(weight_info, base_loader, writer, out_dtype=out_dtype)

        # 4. Write Router Weights
        for layer_idx, weight in enumerate(router_weights):
            writer.save_tensor(f"model.layers.{layer_idx}.block_sparse_moe.gate.weight", weight.to(dtype=out_dtype))

        writer.finalize()
model_tools.md CHANGED
@@ -32,8 +32,21 @@ Tools to enhance LLM quantizations and merging
32
  # [metadata_audit.py](https://huggingface.co/spaces/Naphula/model_tools/blob/main/metadata_audit.py)
33
  - Checks multiple models within subdirectories for vocab or rope mismatch (useful for large merges). Calibrated for Mistral Nemo 12B by default.
34
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
  # [eos_scanner.py](https://huggingface.co/spaces/Naphula/model_tools/blob/main/eos_scanner.py)
36
- - This tool scans the tokenizer jsons to detect any mismatches with EOS tokens, which cause early termination bugs. You can then use the [gen_id_patcher.py](https://huggingface.co/spaces/Naphula/model_tools/blob/main/gen_id_patcher.py) to patch missing `generation_config.json` files for EOS token. See [this post](https://huggingface.co/Naphula/Q0_Bench/discussions/1?not-for-all-audiences=true#6987717c762f0a45f672e250) as well as the [EOS Scanner ReadMe](https://huggingface.co/spaces/Naphula/model_tools/blob/main/eos_scanner_readme.md) for more info.
37
 
38
  # [weight_counter.py](https://huggingface.co/spaces/Naphula/model_tools/blob/main/weight_counter.py)
39
  - This counts the number of models in a yaml and adds up the total weight values. Useful for large della/ties merges.
 
32
  # [metadata_audit.py](https://huggingface.co/spaces/Naphula/model_tools/blob/main/metadata_audit.py)
33
  - Checks multiple models within subdirectories for vocab or rope mismatch (useful for large merges). Calibrated for Mistral Nemo 12B by default.
34
 
35
+ # llama moe
36
+ - Add support for Llama Mixture of Experts. If you want to merge custom Llama MoE you can add these scripts to your mergekit environment:
37
+ - [mergekit-main\mergekit\architecture\moe_defs.py](https://huggingface.co/spaces/Naphula/model_tools/blob/main/moe_defs.py)
38
+ - [mergekit-main\mergekit\__init__.py](https://huggingface.co/spaces/Naphula/model_tools/blob/main/__init__.py)
39
+ - [mergekit-main\mergekit\moe\llama.py](https://huggingface.co/spaces/Naphula/model_tools/blob/main/llama.py)
40
+ - Then assign the num_experts_per_tok in config.json (or the config.yaml)
41
+
42
+ # [tokensurgeon.py](https://huggingface.co/spaces/Naphula/model_tools/blob/main/tokensurgeon.py)
43
+ - Uses adaptive VRAM from Grim Jim's `measure.py` like `graph_v18` to prevent OOM. Use recommended [batch file](https://huggingface.co/spaces/Naphula/model_tools/blob/main/fix_tokenizers.bat) here or modify sh. This supposedly avoids 'cardboard town' fake patches.
44
+
45
+ # [tokeninspector.py](https://huggingface.co/spaces/Naphula/model_tools/blob/main/tokeninspector.py)
46
+ - Audit your tokensurgeon results.
47
+
48
  # [eos_scanner.py](https://huggingface.co/spaces/Naphula/model_tools/blob/main/eos_scanner.py)
49
+ - Updated! This tool scans the tokenizer jsons to detect any mismatches with EOS tokens, which cause early termination bugs. You can then use the [gen_id_patcher.py](https://huggingface.co/spaces/Naphula/model_tools/blob/main/gen_id_patcher.py) to patch missing `generation_config.json` files for EOS token. See [this post](https://huggingface.co/Naphula/Q0_Bench/discussions/1?not-for-all-audiences=true#6987717c762f0a45f672e250) as well as the [EOS Scanner ReadMe](https://huggingface.co/spaces/Naphula/model_tools/blob/main/eos_scanner_readme.md) for more info.
50
 
51
  # [weight_counter.py](https://huggingface.co/spaces/Naphula/model_tools/blob/main/weight_counter.py)
52
  - This counts the number of models in a yaml and adds up the total weight values. Useful for large della/ties merges.
moe_defs.py ADDED
@@ -0,0 +1,197 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (C) 2025 Arcee AI
2
+ # SPDX-License-Identifier: LGPL-3.0-only
3
+
4
+ from typing import ClassVar, List, Optional
5
+
6
+ from pydantic import BaseModel
7
+ from transformers import PretrainedConfig
8
+
9
+ from mergekit.architecture.base import (
10
+ ModuleArchitecture,
11
+ WeightInfo,
12
+ )
13
+ from mergekit.architecture.json_definitions import NAME_TO_ARCH
14
+
15
MISTRAL_INFO = NAME_TO_ARCH["MistralForCausalLM"][0]
MISTRAL_MODULE_ARCH = MISTRAL_INFO.modules["default"].architecture


class MixtralModuleArchitecture(ModuleArchitecture, BaseModel):
    """Mixtral layer layout: per-expert w1/w2/w3 MLP weights plus a routing
    gate, with all non-MLP weights delegated to the dense Mistral definition."""

    ARCHITECTURE_NAME: ClassVar[str] = "MixtralForCausalLM"
    num_local_experts: int

    def name(self) -> str:
        return "mixtral"

    @classmethod
    def from_config(cls, config: PretrainedConfig):
        # Expert count comes straight from the model config.
        return MixtralModuleArchitecture(num_local_experts=config.num_local_experts)

    def pre_weights(self, config: PretrainedConfig) -> List[WeightInfo]:
        # Embeddings are identical to dense Mistral.
        return MISTRAL_MODULE_ARCH.pre_weights(config)

    def post_weights(self, config: PretrainedConfig) -> List[WeightInfo]:
        # Final norm / head are identical to dense Mistral.
        return MISTRAL_MODULE_ARCH.post_weights(config)

    def num_layers_config_key(self) -> str:
        return MISTRAL_MODULE_ARCH.num_layers_config_key()

    def layer_weights(
        self, index: int, config: PretrainedConfig
    ) -> Optional[List[WeightInfo]]:
        prefix = f"model.layers.{index}"
        # Per-expert w1/w2/w3 projections, then the routing gate.
        moe_names = [
            f"{prefix}.block_sparse_moe.experts.{expert_idx}.{param}.weight"
            for expert_idx in range(self.num_local_experts)
            for param in ("w1", "w2", "w3")
        ]
        moe_names.append(f"{prefix}.block_sparse_moe.gate.weight")
        res = [WeightInfo(name=n) for n in moe_names]
        # Keep Mistral's non-MLP weights; the dense MLP is replaced by experts.
        res.extend(
            wi
            for wi in MISTRAL_MODULE_ARCH.layer_weights(index, config)
            if ".mlp." not in wi.name
        )
        return res
59
+
60
+
61
QWEN3_INFO = NAME_TO_ARCH["Qwen3ForCausalLM"][0]
QWEN3_MODULE_ARCH = QWEN3_INFO.modules["default"].architecture


class Qwen3MoeModuleArchitecture(ModuleArchitecture, BaseModel):
    """Qwen3 MoE layer layout: per-expert up/gate/down MLP projections plus a
    routing gate, with all non-MLP weights delegated to dense Qwen3."""

    ARCHITECTURE_NAME: ClassVar[str] = "Qwen3MoeForCausalLM"
    # Number of experts per layer; populated from the model config.
    num_experts: int

    def name(self) -> str:
        return "qwen3_moe"

    @classmethod
    def from_config(cls, config: PretrainedConfig):
        # Requires `num_experts` to be present on the config.
        return Qwen3MoeModuleArchitecture(num_experts=config.num_experts)

    def pre_weights(self, config: PretrainedConfig) -> List[WeightInfo]:
        # Embeddings are identical to dense Qwen3.
        return QWEN3_MODULE_ARCH.pre_weights(config)

    def post_weights(self, config: PretrainedConfig) -> List[WeightInfo]:
        # Final norm / head are identical to dense Qwen3.
        return QWEN3_MODULE_ARCH.post_weights(config)

    def num_layers_config_key(self) -> str:
        return QWEN3_MODULE_ARCH.num_layers_config_key()

    def layer_weights(
        self, index: int, config: PretrainedConfig
    ) -> Optional[List[WeightInfo]]:
        prefix = f"model.layers.{index}"
        tensor_names = []
        # One up/gate/down projection per expert.
        for expert_idx in range(self.num_experts):
            for param in ("up_proj", "gate_proj", "down_proj"):
                tensor_names.append(
                    prefix + f".mlp.experts.{expert_idx}.{param}.weight"
                )
        # Router gate.
        tensor_names.append(prefix + ".mlp.gate.weight")
        res = []
        for name in tensor_names:
            res.append(WeightInfo(name=name))
        # Keep dense Qwen3's non-MLP weights; the dense MLP is replaced by experts.
        for weight_info in QWEN3_MODULE_ARCH.layer_weights(index, config):
            if ".mlp." in weight_info.name:
                continue
            res.append(weight_info)
        return res
104
+
105
+
106
AFMOE_PARTIAL_INFO = NAME_TO_ARCH["_AfmoePartialForCausalLM"][0]
AFMOE_PARTIAL_MODULE_ARCH = AFMOE_PARTIAL_INFO.modules["default"].architecture


class AfmoeModuleArchitecture(ModuleArchitecture, BaseModel):
    """AFMoE layer layout: extends the partial AFMoE definition with
    per-expert up/gate/down MLP projections (marked optional)."""

    ARCHITECTURE_NAME: ClassVar[str] = "AfmoeForCausalLM"
    # Number of experts per layer; populated from the model config.
    num_experts: int

    def name(self) -> str:
        return "afmoe"

    @classmethod
    def from_config(cls, config: PretrainedConfig):
        # Requires `num_experts` to be present on the config.
        return AfmoeModuleArchitecture(num_experts=config.num_experts)

    def pre_weights(self, config: PretrainedConfig) -> List[WeightInfo]:
        return AFMOE_PARTIAL_MODULE_ARCH.pre_weights(config)

    def post_weights(self, config: PretrainedConfig) -> List[WeightInfo]:
        return AFMOE_PARTIAL_MODULE_ARCH.post_weights(config)

    def num_layers_config_key(self) -> str:
        return AFMOE_PARTIAL_MODULE_ARCH.num_layers_config_key()

    def layer_weights(
        self, index: int, config: PretrainedConfig
    ) -> Optional[List[WeightInfo]]:
        # Start from the partial architecture's weights, then append experts.
        res = AFMOE_PARTIAL_MODULE_ARCH.layer_weights(index, config) or []
        prefix = f"model.layers.{index}"
        for expert_idx in range(self.num_experts):
            for param in ("up_proj", "gate_proj", "down_proj"):
                res.append(
                    WeightInfo(
                        name=prefix + f".mlp.experts.{expert_idx}.{param}.weight",
                        # optional=True -- presumably because not every AFMoE
                        # layer is an MoE layer; verify against the partial
                        # architecture definition.
                        optional=True,
                    )
                )
        return res
144
+
145
+
146
# Base Llama definitions from the architecture registry.
LLAMA_INFO = NAME_TO_ARCH["LlamaForCausalLM"][0]
LLAMA_MODULE_ARCH = LLAMA_INFO.modules["default"].architecture


class LlamaMoeModuleArchitecture(ModuleArchitecture, BaseModel):
    """Llama-based MoE layout: per-expert MLP weights plus a router gate,
    with all non-MLP weights inherited from the dense Llama architecture."""

    # Architecture string that will appear in the output config.json.
    ARCHITECTURE_NAME: ClassVar[str] = "LlamaMoeForCausalLM"
    num_experts: int

    def name(self) -> str:
        return "llama_moe"

    @classmethod
    def from_config(cls, config: PretrainedConfig):
        # Reads 'num_experts' from the model config, defaulting to 8.
        return LlamaMoeModuleArchitecture(num_experts=getattr(config, "num_experts", 8))

    def pre_weights(self, config: PretrainedConfig) -> List[WeightInfo]:
        # Standard Llama embeddings/norms.
        return LLAMA_MODULE_ARCH.pre_weights(config)

    def post_weights(self, config: PretrainedConfig) -> List[WeightInfo]:
        # Standard Llama final norm/head.
        return LLAMA_MODULE_ARCH.post_weights(config)

    def num_layers_config_key(self) -> str:
        return LLAMA_MODULE_ARCH.num_layers_config_key()

    def layer_weights(self, index: int, config: PretrainedConfig) -> Optional[List[WeightInfo]]:
        prefix = f"model.layers.{index}"
        # Expert copies of the dense MLP projections.
        res = [
            WeightInfo(name=f"{prefix}.block_sparse_moe.experts.{expert_idx}.{param}.weight")
            for expert_idx in range(self.num_experts)
            for param in ("gate_proj", "up_proj", "down_proj")
        ]
        # Router (gate) weight.
        res.append(WeightInfo(name=f"{prefix}.block_sparse_moe.gate.weight"))
        # Non-MLP weights (attention, norms) come straight from dense Llama;
        # the original .mlp. weights are replaced by the experts above.
        res.extend(
            wi
            for wi in LLAMA_MODULE_ARCH.layer_weights(index, config)
            if ".mlp." not in wi.name
        )
        return res
tokeninspector.py ADDED
@@ -0,0 +1,135 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # python tokeninspector.py "B:\12B\models--mistralai--Mistral-Nemo-Instruct-2407" "B:\12B\models--aixonlab--Aether-12b.backup" "B:\12B\models--aixonlab--Aether-12b"
2
+
3
+ import os
4
+ import click
5
+ import torch
6
+ import transformers
7
+ from mergekit.io.lazy_tensor_loader import LazyTensorLoader
8
+
9
def get_embed_tensor(model_path):
    """Lazily loads the embedding tensor from a model directory.

    Returns the first tensor whose name contains 'embed_tokens.weight' or
    'wte.weight'; returns None when loading fails or no such tensor exists.
    """
    embed_markers = ("embed_tokens.weight", "wte.weight")
    try:
        loader = LazyTensorLoader.from_disk(model_path)
        for key in loader.index.tensor_paths.keys():
            if any(marker in key for marker in embed_markers):
                return loader.get_tensor(key)
    except Exception as e:
        print(f" [!] Error loading tensors from {model_path}: {e}")
    return None
19
+
20
@click.command()
@click.argument("base_model", type=click.Path(exists=True))
@click.argument("donor_model", type=click.Path(exists=True))
@click.argument("output_model", type=click.Path(exists=True))
def main(base_model, donor_model, output_model):
    # Audits a token-surgeon result: compares BASE (original weights),
    # DONOR (tokenizer source) and OUTPUT (surgered model) on vocab size,
    # embedding geometry, and tokenizer encoding behavior.
    # (Comment instead of a docstring so click's --help output is unchanged.)
    print("="*60)
    print("🔍 TOKEN SURGEON AUDIT TOOL")
    print("="*60)

    print("\n[1] Loading Tokenizers...")
    tok_base = transformers.AutoTokenizer.from_pretrained(base_model, trust_remote_code=True)
    tok_donor = transformers.AutoTokenizer.from_pretrained(donor_model, trust_remote_code=True)
    tok_out = transformers.AutoTokenizer.from_pretrained(output_model, trust_remote_code=True)

    print(f" Base: {len(tok_base)} tokens")
    print(f" Donor: {len(tok_donor)} tokens")
    print(f" Output: {len(tok_out)} tokens")

    # The surgered output is expected to adopt the donor's vocabulary wholesale.
    if len(tok_out) != len(tok_donor):
        print(" ❌ FAIL: Output vocab size does not match Donor vocab size!")
    else:
        print(" ✅ PASS: Output vocab size matches Donor.")

    print("\n[2] Loading Embedding Tensors (Lazy Load)...")
    emb_base = get_embed_tensor(base_model)
    emb_donor = get_embed_tensor(donor_model)
    emb_out = get_embed_tensor(output_model)

    print(f" Base Matrix: {emb_base.shape if emb_base is not None else 'Not found'}")
    print(f" Donor Matrix: {emb_donor.shape if emb_donor is not None else 'Not found'}")
    print(f" Output Matrix: {emb_out.shape if emb_out is not None else 'Not found'}")

    if emb_out is not None and emb_donor is not None:
        # Row count (dim 0) must cover every donor token id.
        if emb_out.shape[0] >= len(tok_donor):
            print(" ✅ PASS: Output embedding matrix size is sufficient for Donor vocab.")
        else:
            print(" ❌ FAIL: Output embedding matrix is smaller than Donor vocab!")

    vocab_base = tok_base.get_vocab()
    vocab_donor = tok_donor.get_vocab()

    # Tokens present in both vocabs vs. tokens the donor introduced.
    shared_tokens = set(vocab_base.keys()).intersection(set(vocab_donor.keys()))
    donor_only_tokens = set(vocab_donor.keys()) - set(vocab_base.keys())

    print("\n[3] Testing a Shared Token (Verifying exact transfer)...")
    if shared_tokens:
        # Pick a common word that is likely to exist in both
        test_shared = None
        for candidate in [" the", " hello", "The", "Hello", "Ġthe", "Ġhello", "the", "hello"]:
            if candidate in shared_tokens:
                test_shared = candidate
                break
        if not test_shared:
            # Fall back to an arbitrary shared token from the middle of the set.
            test_shared = list(shared_tokens)[len(shared_tokens)//2]

        id_base = vocab_base[test_shared]
        id_out = vocab_donor[test_shared] # output uses donor vocab

        print(f" Token: '{test_shared}'")
        print(f" ID in Base: {id_base} | ID in Output: {id_out}")

        if emb_base is not None and emb_out is not None:
            vec_base = emb_base[id_base].float()
            vec_out = emb_out[id_out].float()

            # A shared token's vector should be copied verbatim to its new id.
            cos_sim = torch.nn.functional.cosine_similarity(vec_base, vec_out, dim=0).item()
            print(f" Cosine similarity between Base and Output vectors: {cos_sim:.6f}")
            if cos_sim > 0.999:
                print(" ✅ PASS: Embeddings match perfectly. The vector was successfully moved to the new ID.")
            else:
                print(" ❌ FAIL: Embeddings for shared token do not match!")
    else:
        print(" ⚠️ No shared tokens found between vocabularies.")

    print("\n[4] Testing a New Token (Verifying OMP approximation)...")
    if donor_only_tokens:
        # Try to find a special token or a distinct word
        test_new = list(donor_only_tokens)[0]
        for t in donor_only_tokens:
            if "<" in t or "[" in t or "im_start" in t:
                test_new = t
                break

        id_out = vocab_donor[test_new]
        print(f" Token: '{test_new}' (Only exists in Donor)")
        print(f" ID in Output: {id_out}")

        if emb_out is not None:
            vec_out = emb_out[id_out].float()
            # A (near-)zero norm means the new token never got a real embedding.
            norm = vec_out.norm().item()
            print(f" Vector L2 Norm: {norm:.4f}")
            if norm > 0.01:
                print(" ✅ PASS: Vector is non-zero. OMP successfully approximated a new embedding.")
            else:
                print(" ⚠️ WARN: Vector is zero or very close to zero. It may have been treated as a junk token.")
    else:
        print(" ⚠️ No donor-only tokens found. Vocabularies are identical.")

    print("\n[5] Testing Tokenizer Encoding Behavior...")
    test_text = "Hello world! This is a test of the new tokenizer. <|im_start|>system\n12345<|im_end|>"
    enc_donor = tok_donor.encode(test_text)
    enc_out = tok_out.encode(test_text)

    if enc_donor == enc_out:
        print(" ✅ PASS: Output model encodes text exactly identically to the Donor model.")
    else:
        print(" ❌ FAIL: Output model encoding differs from Donor model!")
        print(f" Donor: {enc_donor[:10]}...")
        print(f" Output: {enc_out[:10]}...")

    print("\n" + "="*60)
    print("Audit Complete.")
    print("="*60)

if __name__ == '__main__':
    main()
tokensurgeon.py ADDED
@@ -0,0 +1,867 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (C) 2025 Arcee AI
2
+ # SPDX-License-Identifier: LGPL-3.0-only
3
+
4
+ import enum
5
+ import logging
6
+ from typing import Dict, List, Optional, Tuple
7
+
8
+ import click
9
+ import torch
10
+ import torch.distributions.constraints
11
+ import tqdm
12
+ import transformers
13
+ from pydantic import BaseModel
14
+
15
+ from mergekit.architecture import (
16
+ ConfiguredModelArchitecture,
17
+ WeightInfo,
18
+ arch_info_for_config,
19
+ )
20
+ from mergekit.common import ModelReference, set_config_value
21
+ from mergekit.io.tasks import (
22
+ LoaderCache,
23
+ )
24
+ from mergekit.io.tensor_writer import TensorWriter
25
+ from mergekit.options import MergeOptions, PrettyPrintHelp, add_merge_options
26
+ from mergekit.tokenizer.normalization import (
27
+ NormalizedToken,
28
+ normalized_vocabulary,
29
+ token_prefixes,
30
+ )
31
+ from mergekit.tokensurgeon import (
32
+ SubwordMethod,
33
+ WeightingScheme,
34
+ batch_mp_rope,
35
+ batch_omp,
36
+ common_interp_approximate,
37
+ compute_token_basis,
38
+ landmark_pca_approximate,
39
+ subword_approximate,
40
+ well_trained_tokens,
41
+ )
42
+ from mergekit.tokensurgeon.common_interpolation import DistanceMetric
43
+
44
+ LOG = logging.getLogger(__name__)
45
+
46
+
47
class TokenAssignmentStats(BaseModel):
    """Counters recording how each donor-vocabulary token was assigned a row
    in the rebuilt embedding matrix."""

    exact_match: int = 0
    byte_match: int = 0
    prefix_match: int = 0
    to_approximate: int = 0

    def pretty_print(self):
        """Render a human-readable breakdown, omitting zero-valued categories."""
        labeled = [
            ("Exact matches", self.exact_match),
            ("Byte matches", self.byte_match),
            ("Prefix matches", self.prefix_match),
            ("Tokens to approximate", self.to_approximate),
        ]
        lines = ["Token Breakdown:"]
        for label, count in labeled:
            if count:
                lines.append(f"  {label}: {count}")
        lines.append(f"  Total: {sum(count for _, count in labeled)}")
        return "\n".join(lines)
67
+
68
+
69
class ApproximationMethod(enum.Enum):
    """Strategy used to synthesize embeddings for tokens the base model lacks.

    Dispatch on these values happens in ``compute_new_embeddings``.
    """

    COMMON_INTERPOLATION = "common_interpolation"  # weighted k-NN over shared-vocab embeddings
    SUBWORD = "subword"  # combine embeddings of the token's subword pieces (see subword_approximate)
    MEAN = "mean"  # mean of all original embeddings, repeated per new token
    ZERO = "zero"  # all-zero vectors
    RANDN = "randn"  # unit gaussian noise
    JOHN_HEWITT = "john_hewitt"  # sample from a gaussian fit to the original embeddings
    ORTHOGONAL_MATCHING_PURSUIT = "omp"  # sparse reconstruction via batch_omp
    LANDMARK_PCA = "landmark_pca"  # alignment via shared "landmark" tokens (landmark_pca_approximate)
    SPARSE_TOKEN_BASIS = "stb"  # least-squares coefficients over a precomputed token basis
    MATCHING_PURSUIT_ROPE = "mp_rope"  # matching pursuit variant aware of RoPE geometry (batch_mp_rope)
80
+
81
+
82
class TokenSurgeonOptions(BaseModel):
    """Settings controlling how the donor's vocabulary is grafted onto the base model."""

    model: ModelReference  # base model whose embeddings are re-indexed
    donor: ModelReference  # model supplying the target tokenizer/vocabulary
    out_path: str  # output directory for the resulting model
    method: ApproximationMethod = ApproximationMethod.COMMON_INTERPOLATION  # approximation for unseen tokens
    weight_scheme: WeightingScheme = WeightingScheme.DISTANCE_PROPORTIONAL  # weighting for common-vocab interpolation
    k: int = 64  # neighbours / basis elements used by the approximators
    cosine_similarity: bool = False  # cosine (vs. euclidean) distance for neighbour search
    subword_method: SubwordMethod = SubwordMethod.MEAN  # how subword pieces are combined (SUBWORD method)
    batch_size: Optional[int] = None  # tokens approximated per batch; None or <= 0 falls back to 512
    new_vocab_noise: Optional[float] = None  # stddev of gaussian noise added to newly approximated embeddings
    new_vocab_scale: Optional[float] = None  # scale factor applied to newly approximated embeddings
94
+
95
+
96
def get_arch_info(
    model: ModelReference, options: MergeOptions
) -> ConfiguredModelArchitecture:
    """Load a model's config and pair it with the matching architecture info."""
    config = model.config(trust_remote_code=options.trust_remote_code)
    return ConfiguredModelArchitecture(
        info=arch_info_for_config(config),
        config=config,
    )
102
+
103
+
104
def get_embedding_info(
    arch_info: ConfiguredModelArchitecture,
) -> Tuple[WeightInfo, WeightInfo]:
    """Return the (input embedding, output embedding) WeightInfo pair.

    Raises RuntimeError for multi-module models or when more than one
    embedding weight is found on either side.
    """
    if len(arch_info.info.modules) != 1:
        raise RuntimeError("Model has multiple modules - not supported by tokensurgeon")
    module_name = next(iter(arch_info.info.modules.keys()))
    module_def = arch_info.get_module(module_name)

    def _sole_embed(weights, kind: str) -> Optional[WeightInfo]:
        # Scan for the single weight flagged as an embedding; duplicates are an error.
        found = None
        for wi in weights:
            if wi.is_embed:
                if found is not None:
                    raise RuntimeError(f"Multiple {kind} embeddings found")
                found = wi
        return found

    embed = _sole_embed(module_def.pre_weights(), "input")
    lm_head = _sole_embed(module_def.post_weights(), "output")
    return embed, lm_head
127
+
128
+
129
def maybe_aliases(weight_info: WeightInfo, tied: bool) -> Tuple[str, ...]:
    """Collect a weight's alias names, optionally including tied-weight names."""
    names = list(weight_info.aliases or [])
    if tied:
        names.extend(weight_info.tied_names or [])
    return tuple(names)
134
+
135
+
136
def get_stuff(
    model: ModelReference,
    options: MergeOptions,
    arch_info: Optional[ConfiguredModelArchitecture] = None,
    get_tied: bool = False,
    device: str = "cpu",
) -> Tuple[Dict[NormalizedToken, int], Optional[torch.Tensor], Optional[torch.Tensor]]:
    """Load a model's normalized vocabulary plus its input-embedding and
    LM-head tensors.

    Returns (vocab, embed, lm_head); either tensor may be None when the
    corresponding weight is optional and absent.
    """
    if arch_info is None:
        arch_info = get_arch_info(model, options)
    tokenizer = transformers.AutoTokenizer.from_pretrained(
        model.model.path,
        revision=model.model.revision,
        trust_remote_code=options.trust_remote_code,
    )
    vocab = normalized_vocabulary(tokenizer)
    embed_wi, lm_head_wi = get_embedding_info(arch_info)
    loader = LoaderCache().get(model)

    def _fetch(wi: WeightInfo) -> Optional[torch.Tensor]:
        # Missing weights only raise when the architecture marks them required.
        return loader.get_tensor(
            wi.name,
            device=device,
            aliases=maybe_aliases(wi, get_tied),
            raise_on_missing=not wi.optional,
        )

    return vocab, _fetch(embed_wi), _fetch(lm_head_wi)
166
+
167
+
168
def match_byte_token(
    token: NormalizedToken, original_vocab: Dict[NormalizedToken, int]
) -> Optional[int]:
    """Cross-match single characters and ``<0xNN>`` byte tokens.

    A one-character token maps to its ``<0xNN>`` byte-token entry, and a
    ``<0xNN>`` token maps to its plain-character entry. Returns the matched
    vocabulary index, or None when no such counterpart exists.
    """
    if not isinstance(token, str):
        return None
    if len(token) == 1 and ord(token) < 256:
        # Character -> byte-token form, e.g. "A" -> "<0x41>".
        return original_vocab.get(f"<0x{ord(token):02X}>")
    if len(token) == 6 and token.startswith("<0x") and token.endswith(">"):
        # Byte-token -> character form, e.g. "<0x41>" -> "A".
        try:
            byte_value = int(token[3:-1], 16)
        except ValueError:
            return None
        return original_vocab.get(chr(byte_value))
    return None
188
+
189
+
190
def match_prefix(
    token: NormalizedToken, original_vocab: Dict[NormalizedToken, int]
) -> Optional[int]:
    """Return the vocab index of the first prefix of ``token`` present in
    ``original_vocab`` (in ``token_prefixes`` order), or None."""
    hits = (
        original_vocab[prefix]
        for prefix in token_prefixes(token)
        if prefix in original_vocab
    )
    return next(hits, None)
197
+
198
+
199
def get_out_arch_info(
    model: ModelReference,
    donor: ModelReference,
    new_vocab_size: int,
    common_options: MergeOptions,
) -> ConfiguredModelArchitecture:
    """Build the output architecture: the base model's config with the new
    vocab size and the donor's special-token / padding settings applied."""
    cfg_donor = donor.config(trust_remote_code=common_options.trust_remote_code)
    cfg_out = model.config(trust_remote_code=common_options.trust_remote_code)
    arch_info_out = arch_info_for_config(cfg_out)
    vocab_key = arch_info_out.vocab_size_config_key or "vocab_size"
    set_config_value(cfg_out, vocab_key, new_vocab_size)
    # Carry donor tokenizer-related config values over to the output config.
    donor_keys = (
        "pad_token_id",
        "eos_token_id",
        "bos_token_id",
        "unk_token_id",
        "mask_token_id",
        "padding_side",
    )
    for key in donor_keys:
        if hasattr(cfg_donor, key):
            set_config_value(cfg_out, key, getattr(cfg_donor, key))
    return ConfiguredModelArchitecture(info=arch_info_out, config=cfg_out)
222
+
223
+
224
def john_hewitt_init(orig_embed: torch.Tensor, num_new_tokens: int) -> torch.Tensor:
    """Sample new token embeddings from a Gaussian fit to existing embeddings.

    Fits a multivariate normal (mean + empirical covariance) to the rows of
    ``orig_embed`` and draws ``num_new_tokens`` samples from it. When the
    covariance is not positive definite (e.g. fewer rows than dimensions),
    falls back to small random vectors scaled by 0.02.

    Args:
        orig_embed: ``(vocab_size, hidden_dim)`` embedding matrix to model.
        num_new_tokens: number of new embedding rows to sample.

    Returns:
        ``(num_new_tokens, hidden_dim)`` tensor in ``orig_embed``'s dtype.
    """
    orig_embed_f32 = orig_embed.to(torch.float32)
    mean = orig_embed_f32.mean(dim=0)
    centered = orig_embed_f32 - mean
    covariance = centered.T @ centered / orig_embed_f32.shape[0]
    is_pd = torch.distributions.constraints.positive_definite.check(covariance).all()
    if not is_pd:
        LOG.warning(
            "Covariance matrix is not positive definite - falling back to small randn"
        )
        # BUG FIX: num_new_tokens is an int; the previous code called
        # len(num_new_tokens) here, raising TypeError on this fallback path.
        return (
            torch.randn(
                num_new_tokens,
                orig_embed.shape[1],
                device=orig_embed.device,
                dtype=orig_embed.dtype,
            )
            * 0.02
        )
    dist = torch.distributions.multivariate_normal.MultivariateNormal(
        loc=mean,
        covariance_matrix=covariance,
    )
    new_embeds = dist.sample((num_new_tokens,))
    return new_embeds.to(orig_embed.dtype)
249
+
250
+
251
def compute_new_embeddings(
    orig_embed: torch.Tensor,
    donor_embed: torch.Tensor,
    orig_vocab: Dict[NormalizedToken, int],
    donor_vocab: Dict[NormalizedToken, int],
    target_tokens: List[NormalizedToken],
    is_lm_head: bool,
    token_basis: Optional[Tuple[torch.Tensor, torch.Tensor]],
    orig_tokenizer: transformers.PreTrainedTokenizerBase,
    options: TokenSurgeonOptions,
    shared_data: Optional[Dict] = None,
    compute_device: torch.device = torch.device("cpu"),
) -> torch.Tensor:
    """Approximate original-model embeddings for ``target_tokens``.

    Each target token must exist in ``donor_vocab``. Dispatches on
    ``options.method``; returns a ``(len(target_tokens), hidden_dim)`` tensor.
    ``shared_data`` may carry precomputed shared-vocabulary embeddings
    ("donor_shared_embeds" / "orig_shared_embeds") to avoid recomputing them
    per batch; ``token_basis`` is required only for the STB method.
    """
    assert all(t in donor_vocab for t in target_tokens)
    if options.method == ApproximationMethod.MEAN:
        # Every new token gets the mean of the original embeddings.
        mean = orig_embed.mean(dim=0).to(compute_device)
        return mean.unsqueeze(0).expand(len(target_tokens), -1)
    elif options.method == ApproximationMethod.ZERO:
        return torch.zeros(
            len(target_tokens),
            orig_embed.shape[1],
            device=compute_device,
            dtype=orig_embed.dtype,
        )
    elif options.method == ApproximationMethod.RANDN:
        return torch.randn(
            len(target_tokens),
            orig_embed.shape[1],
            device=compute_device,
            dtype=orig_embed.dtype,
        )
    elif options.method == ApproximationMethod.JOHN_HEWITT:
        return john_hewitt_init(orig_embed.to(compute_device), len(target_tokens))
    elif options.method in (
        ApproximationMethod.COMMON_INTERPOLATION,
        ApproximationMethod.ORTHOGONAL_MATCHING_PURSUIT,
        ApproximationMethod.LANDMARK_PCA,
        ApproximationMethod.MATCHING_PURSUIT_ROPE,
    ):
        # These methods express each target as a combination of tokens shared
        # by both vocabularies, then transfer the combination to orig space.
        if shared_data is not None:
            donor_shared_embeds = shared_data["donor_shared_embeds"].to(compute_device)
            orig_shared_embeds = shared_data["orig_shared_embeds"].to(compute_device)
        else:
            # Sort by donor index so row order is deterministic.
            shared_vocab = list(
                sorted(
                    set(orig_vocab.keys()) & set(donor_vocab.keys()),
                    key=lambda x: donor_vocab[x],
                )
            )
            donor_shared_embeds = donor_embed[
                torch.tensor([donor_vocab[t] for t in shared_vocab])
            ].to(compute_device)
            orig_shared_embeds = orig_embed[
                torch.tensor([orig_vocab[t] for t in shared_vocab])
            ].to(compute_device)
        res = None
        in_donor = None  # NOTE(review): only assigned by batch_mp_rope and never read afterwards
        targets = donor_embed[torch.tensor([donor_vocab[t] for t in target_tokens])].to(compute_device)
        if options.method == ApproximationMethod.LANDMARK_PCA:
            return landmark_pca_approximate(
                targets,
                donor_shared_embeds,
                orig_shared_embeds,
            )
        elif options.method == ApproximationMethod.COMMON_INTERPOLATION:
            indices, coeffs = common_interp_approximate(
                targets,
                donor_shared_embeds,
                k=options.k,
                metric=(
                    DistanceMetric.COSINE
                    if options.cosine_similarity
                    else DistanceMetric.EUCLIDEAN
                ),
                weight_scheme=options.weight_scheme,
            )
        elif options.method == ApproximationMethod.MATCHING_PURSUIT_ROPE:
            model_config = options.model.config(trust_remote_code=False)
            donor_config = options.donor.config(trust_remote_code=False)
            indices, coeffs, res, in_donor = batch_mp_rope(
                targets,
                donor_shared_embeds,
                orig_shared_embeds,
                k=options.k,
                num_heads_a=donor_config.num_attention_heads,
                num_heads_b=model_config.num_attention_heads,
                a_rope_base=donor_config.rope_theta,
                b_rope_base=model_config.rope_theta,
            )
        else:
            indices, coeffs = batch_omp(targets, donor_shared_embeds, options.k)
        if res is None:
            # Apply the donor-space coefficients to the corresponding rows of
            # the original-space shared embeddings: (B,1,k) @ (B,k,d) -> (B,d).
            res = (
                torch.bmm(
                    coeffs.unsqueeze(1), orig_shared_embeds[indices].to(torch.float)
                )
                .squeeze(1)
                .to(orig_embed.dtype)
            )
        return res
    elif options.method == ApproximationMethod.SUBWORD:
        # Tokenize each target with the original tokenizer and combine the
        # resulting subword embeddings.
        return subword_approximate(
            orig_embed.to(compute_device),
            target_tokens,
            is_lm_head,
            orig_tokenizer,
            options.subword_method,
        )
    elif options.method == ApproximationMethod.SPARSE_TOKEN_BASIS:
        assert token_basis is not None, "Token basis must be provided for STB"
        donor_basis, orig_basis = token_basis
        donor_basis = donor_basis.to(compute_device).to(torch.float32)
        orig_basis = orig_basis.to(compute_device).to(torch.float32)
        # Center the donor-space targets before solving for basis coefficients.
        target_donor_embeds = donor_embed[
            torch.tensor([donor_vocab[t] for t in target_tokens])
        ].to(compute_device).to(torch.float32) - donor_embed.mean(dim=0).to(compute_device)
        # Least-squares: coefficients expressing each target in the donor basis.
        coeffs = torch.linalg.lstsq(
            donor_basis.T,
            target_donor_embeds.T,
        ).solution.T
        if LOG.isEnabledFor(logging.DEBUG):
            # Diagnostics only: reconstruction quality in donor space.
            donor_rt = coeffs @ donor_basis
            err = (donor_rt - target_donor_embeds).norm(dim=1)
            err_rel = err / target_donor_embeds.norm(dim=1).clamp_min(1e-6)
            sim = torch.nn.functional.cosine_similarity(
                donor_rt, target_donor_embeds, dim=1
            )
            LOG.debug(f"Reconstruction error: {err.mean().item():.4f}")
            LOG.debug(f"Relative reconstruction error: {err_rel.mean().item():.4f}")
            LOG.debug(f"Cosine similarity: {sim.mean().item():.4f}")
        # Re-apply the same coefficients in original space and un-center.
        return coeffs @ orig_basis + orig_embed.mean(dim=0).to(compute_device)
    else:
        raise ValueError(f"Unknown approximation method: {options.method}")
392
+
393
+
394
def build_embedding_matrix(
    weight_info: WeightInfo,
    orig_embed: torch.Tensor,
    donor_embed: torch.Tensor,
    orig_vocab: Dict[NormalizedToken, int],
    donor_vocab: Dict[NormalizedToken, int],
    junk_tokens: List[int],
    allow_prefix: bool,
    allow_byte: bool,
    is_lm_head: bool,
    options: TokenSurgeonOptions,
    compute_device: torch.device,
) -> torch.Tensor:
    """Construct an embedding (or LM-head) matrix indexed by the donor vocab.

    Rows for tokens shared with the original vocabulary are copied directly;
    byte-/prefix-matched tokens (when enabled) reuse the matched original
    row; all remaining tokens are approximated in adaptively-sized batches
    via ``compute_new_embeddings``. Tokens listed in ``junk_tokens`` are
    zero-initialized at the end.

    Args:
        weight_info: weight being rebuilt (name used for logging only).
        orig_embed / donor_embed: source matrices for the two models.
        orig_vocab / donor_vocab: normalized token -> index maps.
        junk_tokens: donor indices to zero out (poorly-trained tokens).
        allow_prefix / allow_byte: enable heuristic token matching.
        is_lm_head: forwarded to the approximation method.
        options: surgeon settings (method, k, batch size, noise/scale).
        compute_device: device used for the approximation math.

    Returns:
        A ``(donor_vocab_size, hidden_dim)`` tensor on ``orig_embed``'s device.
    """
    LOG.info(f"Building new tensor for {weight_info.name}")
    stats = TokenAssignmentStats()
    out_vocab_size = max(len(donor_vocab), max(donor_vocab.values()) + 1)

    if options.method == ApproximationMethod.SPARSE_TOKEN_BASIS:
        token_basis = compute_token_basis(
            orig_embed,
            donor_embed,
            orig_vocab,
            donor_vocab,
            junk_tokens,
            options,
        )
    else:
        token_basis = None

    res = torch.zeros(
        out_vocab_size,
        orig_embed.shape[1],
        device=orig_embed.device,
        dtype=orig_embed.dtype,
    )
    # Assign every donor token either a copied row or queue it for approximation.
    new_tokens = []
    for token, donor_idx in donor_vocab.items():
        if token in orig_vocab:
            orig_idx = orig_vocab[token]
            res[donor_idx] = orig_embed[orig_idx]
            stats.exact_match += 1
        elif (
            allow_byte and (orig_idx := match_byte_token(token, orig_vocab)) is not None
        ):
            res[donor_idx] = orig_embed[orig_idx]
            stats.byte_match += 1
        elif allow_prefix and (orig_idx := match_prefix(token, orig_vocab)) is not None:
            res[donor_idx] = orig_embed[orig_idx]
            stats.prefix_match += 1
        else:
            new_tokens.append(token)
            stats.to_approximate += 1

    # FIX: the donor tokenizer was previously loaded here as well but never
    # used anywhere in this function — the wasted download/load was removed.
    # Only the original tokenizer is needed (for subword approximation).
    orig_tokenizer = transformers.AutoTokenizer.from_pretrained(
        options.model.model.path,
        revision=options.model.model.revision,
        trust_remote_code=True,
    )

    LOG.info(stats.pretty_print())
    if new_tokens:
        LOG.info(f"Approximating {len(new_tokens)} tokens")

        # Precompute shared embeds to avoid doing it in every batch
        shared_vocab = list(
            sorted(
                set(orig_vocab.keys()) & set(donor_vocab.keys()),
                key=lambda x: donor_vocab[x],
            )
        )
        donor_shared_embeds = donor_embed[
            torch.tensor([donor_vocab[t] for t in shared_vocab])
        ]
        orig_shared_embeds = orig_embed[
            torch.tensor([orig_vocab[t] for t in shared_vocab])
        ]
        shared_data = {
            "donor_shared_embeds": donor_shared_embeds,
            "orig_shared_embeds": orig_shared_embeds,
        }

        batch_size = options.batch_size
        if batch_size is None or batch_size <= 0:
            batch_size = 512

        # Adaptive batching: shrink the batch on CUDA OOM and retry.
        i = 0
        total_tokens = len(new_tokens)
        oom_count = 0

        pbar = tqdm.tqdm(total=total_tokens, desc="Approximating tokens")

        while i < total_tokens:
            end = min(i + batch_size, total_tokens)
            current_batch = new_tokens[i:end]

            try:
                new_embeds = compute_new_embeddings(
                    orig_embed,
                    donor_embed,
                    orig_vocab,
                    donor_vocab,
                    target_tokens=current_batch,
                    is_lm_head=is_lm_head,
                    token_basis=token_basis,
                    orig_tokenizer=orig_tokenizer,
                    options=options,
                    shared_data=shared_data,
                    compute_device=compute_device,
                )

                # Optional post-processing of freshly approximated rows.
                if options.new_vocab_noise:
                    new_embeds += torch.randn_like(new_embeds) * options.new_vocab_noise
                if options.new_vocab_scale:
                    new_embeds *= options.new_vocab_scale

                for ne_idx, token in enumerate(current_batch):
                    res[donor_vocab[token]] = new_embeds[ne_idx].to(res.device)

                # Success: advance and reset the consecutive-OOM counter.
                pbar.update(end - i)
                i = end
                oom_count = 0

                if compute_device.type == "cuda":
                    torch.cuda.empty_cache()

            except torch.OutOfMemoryError:
                oom_count += 1
                if compute_device.type == "cuda":
                    torch.cuda.empty_cache()
                import gc
                gc.collect()

                old_batch = batch_size
                batch_size = max(1, int(batch_size * 0.75))

                # Already at batch size 1 and still OOM: nothing left to shrink.
                if batch_size == old_batch and batch_size == 1:
                    LOG.error("OOM even with batch size 1. Cannot continue.")
                    raise

                LOG.warning(f"OOM error. Reducing batch size from {old_batch} to {batch_size} (attempt {oom_count})")

                if oom_count > 10:
                    LOG.error("Too many OOM errors, giving up.")
                    raise

        pbar.close()

    if junk_tokens:
        LOG.info(f"Zero-initializing {len(junk_tokens)} junk tokens")
        for token_id in junk_tokens:
            res[token_id] = torch.zeros(
                orig_embed.shape[1],
                device=orig_embed.device,
                dtype=orig_embed.dtype,
            )
    return res
558
+
559
+
560
class AllowMatch(enum.Enum):
    """Scope in which a heuristic (prefix/byte) token match may be applied."""

    LM_HEAD_ONLY = "lm_head"  # apply only when building the LM head matrix
    EMBED_ONLY = "embed"  # apply only when building the input embedding matrix
    YES = "yes"  # apply to both matrices
    NO = "no"  # never apply
565
+
566
+
567
@click.command("mergekit-tokensurgeon", cls=PrettyPrintHelp)
@click.argument("model", type=str)
@click.argument("donor", type=str)
@click.argument("out_path", type=str)
@click.option(
    "--k",
    "-k",
    type=int,
    default=64,
    help="Number of nearest neighbours to use for embedding interpolation",
    show_default=True,
)
@click.option(
    "--cosine-similarity/--no-cosine-similarity",
    "-c/-nc",
    is_flag=True,
    default=False,
    help="Use cosine similarity for nearest neighbour search",
    show_default=True,
)
@click.option(
    "--approximation-method",
    "-a",
    type=click.Choice([m.value for m in ApproximationMethod]),
    default=ApproximationMethod.ORTHOGONAL_MATCHING_PURSUIT.value,
    help="Method for approximating missing tokens",
    show_default=True,
)
@click.option(
    "--weight-scheme",
    "-w",
    type=click.Choice([w.value for w in WeightingScheme]),
    default=WeightingScheme.DISTANCE_PROPORTIONAL.value,
    help="Weighting scheme for common-vocabulary interpolation",
    show_default=True,
)
@click.option(
    "--subword-method",
    "-s",
    type=click.Choice([m.value for m in SubwordMethod]),
    default=SubwordMethod.MEAN.value,
    help="Method for approximating embeddings with subword tokens",
    show_default=True,
)
@click.option(
    "--batch-size",
    type=int,
    default=512,
    help="Number of tokens to process in each batch. -1 for no batching.",
    show_default=True,
)
@click.option(
    "--prefix-match",
    "-pm",
    type=click.Choice([m.value for m in AllowMatch]),
    default=AllowMatch.NO.value,
    help="Allow prefix match for tokens",
    show_default=True,
)
@click.option(
    "--byte-match",
    "-bm",
    type=click.Choice([m.value for m in AllowMatch]),
    default=AllowMatch.NO.value,
    help="Allow byte match for tokens",
    show_default=True,
)
@click.option(
    "--magikarp/--no-magikarp",
    is_flag=True,
    default=False,
    help="Filter out poorly trained tokens",
    show_default=True,
)
@click.option(
    "--new-vocab-noise",
    "-nvn",
    type=float,
    default=None,
    help="Add gaussian noise to new vocab embeddings",
    show_default=True,
)
@click.option(
    "--new-vocab-scale",
    "-nvs",
    type=float,
    default=None,
    help="Scale computed new vocab embeddings by this factor",
    show_default=True,
)
@add_merge_options
def main(
    model: str,
    donor: str,
    out_path: str,
    k: int,
    cosine_similarity: bool,
    approximation_method: str,
    weight_scheme: str,
    subword_method: str,
    batch_size: Optional[int],
    prefix_match: str,
    byte_match: str,
    magikarp: bool,
    new_vocab_noise: Optional[float],
    new_vocab_scale: Optional[float],
    merge_options: MergeOptions,
):
    """Replace a model's tokenizer/vocabulary with a donor model's, copying
    shared-token embeddings and approximating the rest."""
    merge_options.apply_global_options()
    logging.warning("This script is experimental and may produce unexpected results.")
    options = TokenSurgeonOptions(
        model=ModelReference.model_validate(model),
        donor=ModelReference.model_validate(donor),
        out_path=out_path,
        k=k,
        cosine_similarity=cosine_similarity,
        method=ApproximationMethod(approximation_method),
        weight_scheme=WeightingScheme(weight_scheme),
        subword_method=SubwordMethod(subword_method),
        batch_size=batch_size,
        new_vocab_noise=new_vocab_noise,
        new_vocab_scale=new_vocab_scale,
    )
    prefix_match = AllowMatch(prefix_match)
    byte_match = AllowMatch(byte_match)

    cache = LoaderCache()
    cache.setup(options=merge_options)

    # Heavy math runs on GPU when available; tensors are staged on CPU.
    compute_device = torch.device(merge_options.device if merge_options.device else "cuda" if torch.cuda.is_available() else "cpu")
    storage_device = "cpu"

    arch_info = get_arch_info(options.model, merge_options)
    embed_wi, lm_head_wi = get_embedding_info(arch_info)
    orig_vocab, orig_embed, orig_lm_head = get_stuff(
        options.model, merge_options, arch_info=arch_info, device=storage_device
    )
    donor_vocab, donor_embed, donor_lm_head = get_stuff(
        options.donor, merge_options, arch_info=None, get_tied=True, device=storage_device
    )

    if magikarp:
        # "Magikarp" filtering: restrict source rows to tokens that look
        # well-trained in BOTH models, and zero out donor tokens that are
        # poorly trained in both.
        LOG.debug("Finding well-trained tokens in original model")
        well_trained_orig_tokens = set(
            well_trained_tokens(
                orig_vocab,
                orig_embed,
                orig_lm_head,
            )
        )
        LOG.debug("Finding well-trained tokens in donor model")
        well_trained_donor_tokens = set(
            well_trained_tokens(
                donor_vocab,
                donor_embed,
                donor_lm_head,
            )
        )
        common_well_trained_tokens = (
            well_trained_orig_tokens & well_trained_donor_tokens
        )
        LOG.info(f"Found {len(common_well_trained_tokens)} common well-trained tokens")
        orig_vocab = {
            tok: idx
            for tok, idx in orig_vocab.items()
            if tok in common_well_trained_tokens
        }
        junk_tokens = [
            idx
            for tok, idx in donor_vocab.items()
            if (tok not in well_trained_donor_tokens)
            and (tok not in well_trained_orig_tokens)
        ]
    else:
        junk_tokens = []

    if orig_embed is not None:
        if donor_embed is None:
            raise RuntimeError(
                f"Missing tensor {embed_wi.name} in model {options.donor}"
            )
        # NOTE(review): the embed build gates on LM_HEAD_ONLY and the lm_head
        # build below gates on EMBED_ONLY — this looks swapped relative to the
        # enum names; confirm intent before relying on these flags.
        new_embed = build_embedding_matrix(
            embed_wi,
            orig_embed,
            donor_embed,
            orig_vocab=orig_vocab,
            donor_vocab=donor_vocab,
            junk_tokens=junk_tokens,
            allow_prefix=prefix_match in (AllowMatch.YES, AllowMatch.LM_HEAD_ONLY),
            allow_byte=byte_match in (AllowMatch.YES, AllowMatch.LM_HEAD_ONLY),
            is_lm_head=False,
            options=options,
            compute_device=compute_device,
        )
    else:
        if not embed_wi.optional:
            raise RuntimeError(
                f"Missing tensor {embed_wi.name} in model {options.model}"
            )
        new_embed = None

    if orig_lm_head is not None:
        if donor_lm_head is None:
            raise RuntimeError(
                f"Missing tensor {lm_head_wi.name} in model {options.donor}"
            )
        new_lm_head = build_embedding_matrix(
            lm_head_wi,
            orig_lm_head,
            donor_lm_head,
            orig_vocab=orig_vocab,
            donor_vocab=donor_vocab,
            junk_tokens=junk_tokens,
            allow_prefix=prefix_match in (AllowMatch.YES, AllowMatch.EMBED_ONLY),
            allow_byte=byte_match in (AllowMatch.YES, AllowMatch.EMBED_ONLY),
            is_lm_head=True,
            options=options,
            compute_device=compute_device,
        )
    else:
        if not lm_head_wi.optional:
            raise RuntimeError(
                f"Missing tensor {lm_head_wi.name} in model {options.model}"
            )
        new_lm_head = None

    # The output vocab size comes from whichever matrix was actually built.
    new_vocab_size = None
    if new_embed is not None:
        new_vocab_size = new_embed.shape[0]
    elif new_lm_head is not None:
        new_vocab_size = new_lm_head.shape[0]
    LOG.info(f"Saving new model to {out_path}")
    out_arch_info = get_out_arch_info(
        options.model, options.donor, new_vocab_size, merge_options
    )
    writer = TensorWriter(
        out_path,
        max_shard_size=merge_options.out_shard_size,
        safe_serialization=merge_options.safe_serialization,
        use_async=merge_options.async_write,
        max_write_threads=merge_options.write_threads,
    )
    # Write every weight: the two rebuilt matrices, everything else verbatim
    # from the base model.
    for weight_info in tqdm.tqdm(out_arch_info.all_weights(), desc="Saving weights"):
        if weight_info.name == embed_wi.name:
            tensor = new_embed
        elif lm_head_wi is not None and weight_info.name == lm_head_wi.name:
            tensor = new_lm_head
        else:
            tensor = cache.get(options.model).get_tensor(
                weight_info.name, aliases=weight_info.aliases, raise_on_missing=False
            )
        if tensor is None:
            if weight_info.optional:
                continue
            raise RuntimeError(
                f"Missing tensor {weight_info.name} in model {options.model}"
            )
        writer.save_tensor(weight_info.name, tensor, clone=merge_options.clone_tensors)

    # Force close lazy loader file handles so Windows allows deletion/renaming
    cache.flush_all()
    import gc
    gc.collect()

    # Delete original safetensors files to prevent FileExistsError during rename
    import os
    import re
    temp_pattern = re.compile(r"^.*-\d+\.safetensors$")
    for fname in os.listdir(out_path):
        if fname.endswith(".safetensors") and not temp_pattern.match(fname):
            try:
                os.remove(os.path.join(out_path, fname))
            except Exception as e:
                LOG.warning(f"Could not remove old file {fname}: {e}")
        elif fname == "model.safetensors.index.json":
            try:
                os.remove(os.path.join(out_path, fname))
            except Exception:
                pass

    writer.finalize()
    out_arch_info.config.save_pretrained(out_path)

    # The output model ships with the DONOR's tokenizer.
    tokenizer_out = transformers.AutoTokenizer.from_pretrained(
        options.donor.model.path,
        revision=options.donor.model.revision,
        trust_remote_code=merge_options.trust_remote_code,
    )
    tokenizer_out.save_pretrained(out_path)

    # Also copy generation_config.json if it exists in the donor
    donor_gen_config = os.path.join(options.donor.model.path, "generation_config.json")
    if os.path.exists(donor_gen_config):
        import shutil
        shutil.copy(donor_gen_config, os.path.join(out_path, "generation_config.json"))

    LOG.info("Done!")
864
+
865
+
866
# CLI entry point when executed directly; click supplies the arguments.
if __name__ == "__main__":
    main()