
Commit de12f8a

convert : begin handling pre-quantized models
1 parent 2be60cb commit de12f8a


convert_hf_to_gguf.py

Lines changed: 174 additions & 65 deletions
@@ -64,10 +64,8 @@ class ModelBase:
     endianess: gguf.GGUFEndian
     use_temp_file: bool
     lazy: bool
-    part_names: list[str]
-    is_safetensors: bool
     hparams: dict[str, Any]
-    tensor_names: set[str] | None
+    model_tensors: dict[str, Callable[[], Tensor]]
     gguf_writer: gguf.GGUFWriter
     model_name: str | None
     metadata_override: Path | None
@@ -99,24 +97,8 @@ def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path,
         self.use_temp_file = use_temp_file
         self.lazy = not eager or (remote_hf_model_id is not None)
         self.remote_hf_model_id = remote_hf_model_id
-        if remote_hf_model_id is not None:
-            self.is_safetensors = True
-
-            def get_remote_tensors() -> Iterator[tuple[str, Tensor]]:
-                logger.info(f"Using remote model with HuggingFace id: {remote_hf_model_id}")
-                remote_tensors = gguf.utility.SafetensorRemote.get_list_tensors_hf_model(remote_hf_model_id)
-                self.tensor_names = set(name for name in remote_tensors.keys())
-                for name, remote_tensor in gguf.utility.SafetensorRemote.get_list_tensors_hf_model(remote_hf_model_id).items():
-                    yield (name, LazyTorchTensor.from_remote_tensor(remote_tensor))
-
-            self.get_tensors = get_remote_tensors
-        else:
-            self.part_names = ModelBase.get_model_part_names(self.dir_model, "model", ".safetensors")
-            self.is_safetensors = len(self.part_names) > 0
-            if not self.is_safetensors:
-                self.part_names = ModelBase.get_model_part_names(self.dir_model, "pytorch_model", ".bin")
         self.hparams = ModelBase.load_hparams(self.dir_model) if hparams is None else hparams
-        self.tensor_names = None
+        self.model_tensors = self.index_tensors(remote_hf_model_id=remote_hf_model_id)
         self.metadata_override = metadata_override
         self.model_name = model_name
         self.dir_model_card = dir_model  # overridden in convert_lora_to_gguf.py
@@ -132,6 +114,8 @@ def get_remote_tensors() -> Iterator[tuple[str, Tensor]]:
                 logger.info(f"choosing --outtype bf16 from first tensor type ({first_tensor.dtype})")
                 self.ftype = gguf.LlamaFileType.MOSTLY_BF16

+        self.dequant_model()
+
         # Configure GGUF Writer
         self.gguf_writer = gguf.GGUFWriter(path=None, arch=gguf.MODEL_ARCH_NAMES[self.model_arch], endianess=self.endianess, use_temp_file=self.use_temp_file,
                                            split_max_tensors=split_max_tensors, split_max_size=split_max_size, dry_run=dry_run, small_first_shard=small_first_shard)
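The `self.dequant_model()` call added here is a no-op unless the model's config.json carries a `quantization_config` block; the new `dequant_model` method (later in this diff) dispatches on its `quant_method` field. A minimal sketch of that detection step, with an illustrative fp8 block-quant config (field values are examples, not taken from any specific model):

import json

# Illustrative only: hparams as loaded from a pre-quantized checkpoint's config.json.
hparams = {
    "quantization_config": {
        "quant_method": "fp8",            # this diff also handles "bitnet" and "gptq"
        "weight_block_size": [128, 128],  # per-block scales along both weight dims
    },
}

if (quant_config := hparams.get("quantization_config")) and isinstance(quant_config, dict):
    quant_method = quant_config.get("quant_method")
    print(f"pre-quantized with {quant_method!r}")  # -> pre-quantized with 'fp8'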
@@ -150,63 +134,209 @@ def find_hparam(self, keys: Iterable[str], optional: bool = False) -> Any:
                 return None
             raise KeyError(f"could not find any of: {keys}")

-    def get_tensors(self) -> Iterator[tuple[str, Tensor]]:
-        tensor_names_from_parts: set[str] = set()
+    def index_tensors(self, remote_hf_model_id: str | None = None) -> dict[str, Callable[[], Tensor]]:
+        tensors: dict[str, Callable[[], Tensor]] = {}
+
+        if remote_hf_model_id is not None:
+            is_safetensors = True

-        index_name = "model.safetensors" if self.is_safetensors else "pytorch_model.bin"
+            logger.info(f"Using remote model with HuggingFace id: {remote_hf_model_id}")
+            remote_tensors = gguf.utility.SafetensorRemote.get_list_tensors_hf_model(remote_hf_model_id)
+            for name, remote_tensor in remote_tensors.items():
+                tensors[name] = lambda r=remote_tensor: LazyTorchTensor.from_remote_tensor(r)
+
+            return tensors
+
+        part_names: list[str] = ModelBase.get_model_part_names(self.dir_model, "model", ".safetensors")
+        is_safetensors: bool = len(part_names) > 0
+        if not is_safetensors:
+            part_names = ModelBase.get_model_part_names(self.dir_model, "pytorch_model", ".bin")
+
+        tensor_names_from_index: set[str] = set()
+
+        index_name = "model.safetensors" if is_safetensors else "pytorch_model.bin"
         index_name += ".index.json"
         index_file = self.dir_model / index_name

         if index_file.is_file():
-            self.tensor_names = set()
             logger.info(f"gguf: loading model weight map from '{index_name}'")
             with open(index_file, "r", encoding="utf-8") as f:
                 index: dict[str, Any] = json.load(f)
                 weight_map = index.get("weight_map")
                 if weight_map is None or not isinstance(weight_map, dict):
                     raise ValueError(f"Can't load 'weight_map' from {index_name!r}")
-                self.tensor_names.update(weight_map.keys())
+                tensor_names_from_index.update(weight_map.keys())
         else:
-            self.tensor_names = tensor_names_from_parts
             weight_map = {}

-        for part_name in self.part_names:
-            logger.info(f"gguf: loading model part '{part_name}'")
+        for part_name in part_names:
+            logger.info(f"gguf: indexing model part '{part_name}'")
             ctx: ContextManager[Any]
-            if self.is_safetensors:
+            if is_safetensors:
                 from safetensors import safe_open
                 ctx = cast(ContextManager[Any], safe_open(self.dir_model / part_name, framework="pt", device="cpu"))
             else:
                 ctx = contextlib.nullcontext(torch.load(str(self.dir_model / part_name), map_location="cpu", mmap=True, weights_only=True))

             with ctx as model_part:
-                tensor_names_from_parts.update(model_part.keys())
+                assert model_part is not None

                 for name in model_part.keys():
-                    if self.is_safetensors:
+                    if is_safetensors:
                         if self.lazy:
                             data = model_part.get_slice(name)
-                            data = LazyTorchTensor.from_safetensors_slice(data)
+                            data_gen = lambda data=data: LazyTorchTensor.from_safetensors_slice(data)  # noqa: E731
                         else:
                             data = model_part.get_tensor(name)
+                            data_gen = lambda data=data: data  # noqa: E731
                     else:
                         data = model_part[name]
                         if self.lazy:
-                            data = LazyTorchTensor.from_eager(data)
-                    yield name, data
+                            data_gen = lambda data=data: LazyTorchTensor.from_eager(data)  # noqa: E731
+                        else:
+                            data_gen = lambda data=data: data  # noqa: E731
+                    tensors[name] = data_gen

         # verify tensor name presence and identify potentially missing files
-        if len(tensor_names_from_parts.symmetric_difference(self.tensor_names)) > 0:
-            missing = sorted(self.tensor_names.difference(tensor_names_from_parts))
-            extra = sorted(tensor_names_from_parts.difference(self.tensor_names))
-            missing_files = sorted(set(weight_map[n] for n in missing if n in weight_map))
-            if len(extra) == 0 and len(missing_files) > 0:
-                raise ValueError(f"Missing or incomplete model files: {missing_files}\n"
-                                 f"Missing tensors: {missing}")
+        if len(tensor_names_from_index) > 0:
+            tensor_names_from_parts = set(tensors.keys())
+            if len(tensor_names_from_parts.symmetric_difference(tensor_names_from_index)) > 0:
+                missing = sorted(tensor_names_from_index.difference(tensor_names_from_parts))
+                extra = sorted(tensor_names_from_parts.difference(tensor_names_from_index))
+                missing_files = sorted(set(weight_map[n] for n in missing if n in weight_map))
+                if len(extra) == 0 and len(missing_files) > 0:
+                    raise ValueError(f"Missing or incomplete model files: {missing_files}\n"
+                                     f"Missing tensors: {missing}")
+                else:
+                    raise ValueError("Mismatch between weight map and model parts for tensor names:\n"
+                                     f"Missing tensors: {missing}\n"
+                                     f"Extra tensors: {extra}")
+
+        return tensors
+
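The `lambda data=data: ...` default-argument trick above is what keeps each entry of the returned dict bound to its own tensor: a bare `lambda: data` in the loop would capture the variable by reference and every callable would end up loading the last tensor indexed. A minimal standalone sketch of the difference (illustrative names only):

values = ["tok_embd.weight", "output.weight"]

# Late binding: every closure sees the final value of `v`.
late = [lambda: v for v in values]
# Early binding via a default argument: each closure keeps its own `v`.
early = [lambda v=v: v for v in values]

print([f() for f in late])   # ['output.weight', 'output.weight']
print([f() for f in early])  # ['tok_embd.weight', 'output.weight']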
+    def dequant_model(self):
+        tensors_to_remove: list[str] = []
+        new_tensors: dict[str, Callable[[], Tensor]] = {}
+
+        if (quant_config := self.hparams.get("quantization_config")) and isinstance(quant_config, dict):
+            quant_method = quant_config.get("quant_method")
+
+            def dequant_bitnet(weight: Tensor, scale: Tensor) -> Tensor:
+                weight = weight.view(torch.uint8)
+                orig_shape = weight.shape
+
+                shift = torch.tensor([0, 2, 4, 6], dtype=torch.uint8).reshape((4, *(1 for _ in range(len(orig_shape)))))
+                data = weight.unsqueeze(0).expand((4, *orig_shape)) >> shift
+                data = data & 3
+                data = (data.float() - 1).reshape((orig_shape[0] * 4, *orig_shape[1:]))
+
+                # The scale is inverted
+                return data / scale.float()
+
+            def dequant_simple(weight: Tensor, scale: Tensor) -> Tensor:
+                scale = scale.float()
+
+                if (weight_block_size := quant_config.get("weight_block_size")):
+                    # TODO: make sure it's a list of integers
+                    for i, size in enumerate(weight_block_size):
+                        scale = scale.repeat_interleave(size, i)
+
+                return weight.float() * scale
+
+            # ref: https://github.com/ModelCloud/GPTQModel/blob/037c5c0f6c9e33c500d975b038d02e7ca437546d/gptqmodel/nn_modules/qlinear/__init__.py#L437-L476
+            def dequant_gptq(g_idx: Tensor, qweight: Tensor, qzeros: Tensor, scales: Tensor) -> Tensor:
+                bits = quant_config["bits"]
+                assert bits in (2, 3, 4, 8)
+                assert qweight.dtype == qzeros.dtype
+                maxq = (2 ** bits) - 1
+                weight = None
+                zeros = None
+                pack_dtype_bits = qweight.dtype.itemsize * 8
+
+                if bits in [2, 4, 8]:
+                    pack_factor = pack_dtype_bits // bits
+                    wf = torch.tensor(list(range(0, pack_dtype_bits, bits)), dtype=torch.int32).unsqueeze(0)
+                    if self.lazy:
+                        wf = LazyTorchTensor.from_eager(wf)
+
+                    zeros = torch.bitwise_right_shift(
+                        qzeros.unsqueeze(2).expand(-1, -1, pack_factor),
+                        wf.unsqueeze(0)
+                    ).to(torch.int16 if bits == 8 else torch.int8)
+                    zeros = torch.bitwise_and(zeros, maxq).reshape(scales.shape)
+
+                    weight = torch.bitwise_and(
+                        torch.bitwise_right_shift(
+                            qweight.unsqueeze(1).expand(-1, pack_factor, -1),
+                            wf.unsqueeze(-1)
+                        ).to(torch.int16 if bits == 8 else torch.int8),
+                        maxq
+                    )
+                elif bits == 3:
+                    raise NotImplementedError("3-bit gptq dequantization is not yet implemented")
+
+                assert weight is not None
+                assert zeros is not None
+
+                weight = weight.reshape(weight.shape[0] * weight.shape[1], weight.shape[2])
+
+                # gptq_v2 doesn't need to offset zeros
+                if quant_config.get("checkpoint_format", "gptq") == "gptq":
+                    zeros += 1
+
+                return (scales[g_idx].float() * (weight - zeros[g_idx]).float()).T
+
+            if quant_method == "bitnet":
+                for name in self.model_tensors.keys():
+                    if name.endswith(".weight_scale"):
+                        weight_name = name.removesuffix("_scale")
+                        w = self.model_tensors[weight_name]
+                        s = self.model_tensors[name]
+                        self.model_tensors[weight_name] = lambda w=w, s=s: dequant_bitnet(w(), s())
+                        tensors_to_remove.append(name)
+            elif quant_method == "fp8":
+                for name in self.model_tensors.keys():
+                    if name.endswith(".weight_scale_inv"):
+                        weight_name = name.removesuffix("_scale_inv")
+                        w = self.model_tensors[weight_name]
+                        s = self.model_tensors[name]
+                        self.model_tensors[weight_name] = lambda w=w, s=s: dequant_simple(w(), s())
+                        tensors_to_remove.append(name)
+            elif quant_method == "gptq":
+                for name in self.model_tensors.keys():
+                    if name.endswith(".qweight"):
+                        base_name = name.removesuffix(".qweight")
+                        g_idx = self.model_tensors[base_name + ".g_idx"]
+                        qweight = self.model_tensors[base_name + ".qweight"]
+                        qzeros = self.model_tensors[base_name + ".qzeros"]
+                        scales = self.model_tensors[base_name + ".scales"]
+                        new_tensors[base_name + ".weight"] = (
+                            lambda g=g_idx, z=qzeros, w=qweight, s=scales: dequant_gptq(
+                                g(), w(), z(), s()
+                            )
+                        )
+                        tensors_to_remove += [
+                            base_name + n
+                            for n in (
+                                ".g_idx",
+                                ".qzeros",
+                                ".qweight",
+                                ".scales",
+                            )
+                        ]
             else:
-                raise ValueError("Mismatch between weight map and model parts for tensor names:\n"
-                                 f"Missing tensors: {missing}\n"
-                                 f"Extra tensors: {extra}")
+                raise NotImplementedError(f"Quant method is not yet supported: {quant_method!r}")
+
+        for name in tensors_to_remove:
+            if name in self.model_tensors:
+                del self.model_tensors[name]
+
+        for name, value in new_tensors.items():
+            self.model_tensors[name] = value
+
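`dequant_bitnet` above unpacks four 2-bit codes from each uint8 byte by shifting by 0/2/4/6 and masking with 3, maps the codes {0, 1, 2} to {-1, 0, +1}, and finally divides by the stored (inverted) scale. A tiny freestanding sketch of the same unpacking on one made-up byte, without the scale step:

import torch

# one packed byte holding four 2-bit codes (LSB-first): 0b01_10_01_00
packed = torch.tensor([[0b01100100]], dtype=torch.uint8)  # shape (1, 1)

shift = torch.tensor([0, 2, 4, 6], dtype=torch.uint8).reshape((4, 1, 1))
codes = (packed.unsqueeze(0).expand((4, 1, 1)) >> shift) & 3
ternary = codes.float() - 1        # codes {0, 1, 2} -> {-1, 0, +1}
ternary = ternary.reshape((4, 1))  # (orig_rows * 4, cols), shift planes stacked row-wise

print(ternary.flatten().tolist())  # [-1.0, 0.0, 1.0, 0.0]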
+    def get_tensors(self) -> Iterator[tuple[str, Tensor]]:
+        for name, gen in self.model_tensors.items():
+            yield name, gen()

     def format_tensor_name(self, key: gguf.MODEL_TENSOR, bid: int | None = None, suffix: str = ".weight") -> str:
         if key not in gguf.MODEL_TENSORS[self.model_arch]:
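For the fp8 path, `dequant_simple` (added above) expands a per-block scale tensor back to the weight's shape with `repeat_interleave` along each dimension before multiplying. A minimal sketch with a toy 2x2 block size (illustrative shapes and values, not tied to any particular model):

import torch

# 4x4 weight quantized in 2x2 blocks -> one scale per block (a 2x2 scale tensor)
weight = torch.ones((4, 4))          # stands in for the stored low-precision weight
scale = torch.tensor([[0.5, 2.0],
                      [1.0, 4.0]])
weight_block_size = [2, 2]

for i, size in enumerate(weight_block_size):
    scale = scale.repeat_interleave(size, i)   # (2, 2) -> (4, 2) -> (4, 4)

dequant = weight.float() * scale
print(dequant[0].tolist())  # [0.5, 0.5, 2.0, 2.0]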
@@ -3860,27 +3990,6 @@ def set_gguf_parameters(self):
         self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR)
         self.gguf_writer.add_rope_scaling_factor(1.0)

-    _has_tok_embd = False
-
-    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
-        del bid  # unused
-
-        output_name = self.format_tensor_name(gguf.MODEL_TENSOR.OUTPUT)
-        tok_embd_name = self.format_tensor_name(gguf.MODEL_TENSOR.TOKEN_EMBD)
-
-        new_name = self.map_tensor_name(name)
-
-        # assuming token_embd.weight is seen before output.weight
-        if not self._has_tok_embd and new_name == self.format_tensor_name(gguf.MODEL_TENSOR.OUTPUT):
-            # even though the tensor file(s) does not contain the word embeddings they are still in the weight map
-            if self.tensor_names and "transformer.wte.weight" in self.tensor_names:
-                logger.debug(f"{tok_embd_name} not found before {output_name}, assuming they are tied")
-                self.tensor_names.remove("transformer.wte.weight")
-        elif new_name == tok_embd_name:
-            self._has_tok_embd = True
-
-        return [(new_name, data_torch)]
-

 @ModelBase.register("InternLM2ForCausalLM")
 class InternLM2Model(TextModel):
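As a footnote to the `dequant_gptq` helper added above: GPTQ packs several low-bit values into each 32-bit word, and the helper recovers them by right-shifting each word by 0, bits, 2*bits, ... and masking with maxq. A tiny standalone sketch of that unpacking for bits=4 (made-up values, simplified to a single packed word):

import torch

bits = 4
maxq = (1 << bits) - 1                                 # 0b1111
pack_factor = 32 // bits                               # eight 4-bit values per int32 word

# one 32-bit word packing the nibbles [1, 2, 3, 4, 5, 6, 7, 0], lowest nibble first
word = sum(v << (bits * i) for i, v in enumerate([1, 2, 3, 4, 5, 6, 7, 0]))
qweight = torch.tensor([[word]], dtype=torch.int32)    # shape (1, 1), like one packed qweight entry

wf = torch.tensor(list(range(0, 32, bits)), dtype=torch.int32).unsqueeze(0)  # shifts 0, 4, ..., 28
unpacked = torch.bitwise_and(
    torch.bitwise_right_shift(qweight.unsqueeze(1).expand(-1, pack_factor, -1), wf.unsqueeze(-1)),
    maxq,
)
print(unpacked.flatten().tolist())                     # [1, 2, 3, 4, 5, 6, 7, 0]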
