@@ -64,10 +64,8 @@ class ModelBase:
     endianess: gguf.GGUFEndian
     use_temp_file: bool
     lazy: bool
-    part_names: list[str]
-    is_safetensors: bool
     hparams: dict[str, Any]
-    tensor_names: set[str] | None
+    model_tensors: dict[str, Callable[[], Tensor]]
     gguf_writer: gguf.GGUFWriter
     model_name: str | None
     metadata_override: Path | None
@@ -99,24 +97,8 @@ def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path,
         self.use_temp_file = use_temp_file
         self.lazy = not eager or (remote_hf_model_id is not None)
         self.remote_hf_model_id = remote_hf_model_id
-        if remote_hf_model_id is not None:
-            self.is_safetensors = True
-
-            def get_remote_tensors() -> Iterator[tuple[str, Tensor]]:
-                logger.info(f"Using remote model with HuggingFace id: {remote_hf_model_id}")
-                remote_tensors = gguf.utility.SafetensorRemote.get_list_tensors_hf_model(remote_hf_model_id)
-                self.tensor_names = set(name for name in remote_tensors.keys())
-                for name, remote_tensor in gguf.utility.SafetensorRemote.get_list_tensors_hf_model(remote_hf_model_id).items():
-                    yield (name, LazyTorchTensor.from_remote_tensor(remote_tensor))
-
-            self.get_tensors = get_remote_tensors
-        else:
-            self.part_names = ModelBase.get_model_part_names(self.dir_model, "model", ".safetensors")
-            self.is_safetensors = len(self.part_names) > 0
-            if not self.is_safetensors:
-                self.part_names = ModelBase.get_model_part_names(self.dir_model, "pytorch_model", ".bin")
         self.hparams = ModelBase.load_hparams(self.dir_model) if hparams is None else hparams
-        self.tensor_names = None
+        self.model_tensors = self.index_tensors(remote_hf_model_id=remote_hf_model_id)
         self.metadata_override = metadata_override
         self.model_name = model_name
         self.dir_model_card = dir_model  # overridden in convert_lora_to_gguf.py
@@ -132,6 +114,8 @@ def get_remote_tensors() -> Iterator[tuple[str, Tensor]]:
             logger.info(f"choosing --outtype bf16 from first tensor type ({first_tensor.dtype})")
             self.ftype = gguf.LlamaFileType.MOSTLY_BF16
 
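+        # rewire pre-quantized checkpoints (bitnet, fp8, gptq) so their tensors
+        # are dequantized when materialized, before anything is written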
+        self.dequant_model()
+
         # Configure GGUF Writer
         self.gguf_writer = gguf.GGUFWriter(path=None, arch=gguf.MODEL_ARCH_NAMES[self.model_arch], endianess=self.endianess, use_temp_file=self.use_temp_file,
                                            split_max_tensors=split_max_tensors, split_max_size=split_max_size, dry_run=dry_run, small_first_shard=small_first_shard)
@@ -150,63 +134,209 @@ def find_hparam(self, keys: Iterable[str], optional: bool = False) -> Any:
             return None
         raise KeyError(f"could not find any of: {keys}")
 
-    def get_tensors(self) -> Iterator[tuple[str, Tensor]]:
-        tensor_names_from_parts: set[str] = set()
+    def index_tensors(self, remote_hf_model_id: str | None = None) -> dict[str, Callable[[], Tensor]]:
+        tensors: dict[str, Callable[[], Tensor]] = {}
+
+        if remote_hf_model_id is not None:
+            is_safetensors = True
 
-        index_name = "model.safetensors" if self.is_safetensors else "pytorch_model.bin"
+            logger.info(f"Using remote model with HuggingFace id: {remote_hf_model_id}")
+            remote_tensors = gguf.utility.SafetensorRemote.get_list_tensors_hf_model(remote_hf_model_id)
+            for name, remote_tensor in remote_tensors.items():
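+                # bind via a default argument: a plain closure would late-bind and
+                # every entry would load the last remote_tensor of the loop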
+                tensors[name] = lambda r=remote_tensor: LazyTorchTensor.from_remote_tensor(r)
+
+            return tensors
+
+        part_names: list[str] = ModelBase.get_model_part_names(self.dir_model, "model", ".safetensors")
+        is_safetensors: bool = len(part_names) > 0
+        if not is_safetensors:
+            part_names = ModelBase.get_model_part_names(self.dir_model, "pytorch_model", ".bin")
+
+        tensor_names_from_index: set[str] = set()
+
+        index_name = "model.safetensors" if is_safetensors else "pytorch_model.bin"
         index_name += ".index.json"
         index_file = self.dir_model / index_name
 
         if index_file.is_file():
-            self.tensor_names = set()
             logger.info(f"gguf: loading model weight map from '{index_name}'")
             with open(index_file, "r", encoding="utf-8") as f:
                 index: dict[str, Any] = json.load(f)
                 weight_map = index.get("weight_map")
                 if weight_map is None or not isinstance(weight_map, dict):
                     raise ValueError(f"Can't load 'weight_map' from {index_name!r}")
-                self.tensor_names.update(weight_map.keys())
+                tensor_names_from_index.update(weight_map.keys())
         else:
-            self.tensor_names = tensor_names_from_parts
             weight_map = {}
 
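+        # build a name -> loader mapping; with lazy conversion the tensor data
+        # itself is not read here, only when a loader is invoked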
-        for part_name in self.part_names:
-            logger.info(f"gguf: loading model part '{part_name}'")
+        for part_name in part_names:
+            logger.info(f"gguf: indexing model part '{part_name}'")
             ctx: ContextManager[Any]
-            if self.is_safetensors:
+            if is_safetensors:
                 from safetensors import safe_open
                 ctx = cast(ContextManager[Any], safe_open(self.dir_model / part_name, framework="pt", device="cpu"))
             else:
                 ctx = contextlib.nullcontext(torch.load(str(self.dir_model / part_name), map_location="cpu", mmap=True, weights_only=True))
 
             with ctx as model_part:
-                tensor_names_from_parts.update(model_part.keys())
+                assert model_part is not None
 
                 for name in model_part.keys():
-                    if self.is_safetensors:
+                    if is_safetensors:
                         if self.lazy:
                             data = model_part.get_slice(name)
-                            data = LazyTorchTensor.from_safetensors_slice(data)
+                            data_gen = lambda data=data: LazyTorchTensor.from_safetensors_slice(data)  # noqa: E731
                         else:
                             data = model_part.get_tensor(name)
+                            data_gen = lambda data=data: data  # noqa: E731
                     else:
                         data = model_part[name]
                         if self.lazy:
-                            data = LazyTorchTensor.from_eager(data)
-                    yield name, data
+                            data_gen = lambda data=data: LazyTorchTensor.from_eager(data)  # noqa: E731
+                        else:
+                            data_gen = lambda data=data: data  # noqa: E731
+                    tensors[name] = data_gen
 
         # verify tensor name presence and identify potentially missing files
-        if len(tensor_names_from_parts.symmetric_difference(self.tensor_names)) > 0:
-            missing = sorted(self.tensor_names.difference(tensor_names_from_parts))
-            extra = sorted(tensor_names_from_parts.difference(self.tensor_names))
-            missing_files = sorted(set(weight_map[n] for n in missing if n in weight_map))
-            if len(extra) == 0 and len(missing_files) > 0:
-                raise ValueError(f"Missing or incomplete model files: {missing_files}\n"
-                                 f"Missing tensors: {missing}")
+        if len(tensor_names_from_index) > 0:
+            tensor_names_from_parts = set(tensors.keys())
+            if len(tensor_names_from_parts.symmetric_difference(tensor_names_from_index)) > 0:
+                missing = sorted(tensor_names_from_index.difference(tensor_names_from_parts))
+                extra = sorted(tensor_names_from_parts.difference(tensor_names_from_index))
+                missing_files = sorted(set(weight_map[n] for n in missing if n in weight_map))
+                if len(extra) == 0 and len(missing_files) > 0:
+                    raise ValueError(f"Missing or incomplete model files: {missing_files}\n"
+                                     f"Missing tensors: {missing}")
+                else:
+                    raise ValueError("Mismatch between weight map and model parts for tensor names:\n"
+                                     f"Missing tensors: {missing}\n"
+                                     f"Extra tensors: {extra}")
+
+        return tensors
+
+    def dequant_model(self):
+        tensors_to_remove: list[str] = []
+        new_tensors: dict[str, Callable[[], Tensor]] = {}
+
+        if (quant_config := self.hparams.get("quantization_config")) and isinstance(quant_config, dict):
+            quant_method = quant_config.get("quant_method")
+
+            def dequant_bitnet(weight: Tensor, scale: Tensor) -> Tensor:
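+                # each uint8 packs four 2-bit values; unpack with shifts of 0/2/4/6,
+                # mask with 3, then map the stored {0,1,2} down to the ternary {-1,0,1}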
+                weight = weight.view(torch.uint8)
+                orig_shape = weight.shape
+
+                shift = torch.tensor([0, 2, 4, 6], dtype=torch.uint8).reshape((4, *(1 for _ in range(len(orig_shape)))))
+                data = weight.unsqueeze(0).expand((4, *orig_shape)) >> shift
+                data = data & 3
+                data = (data.float() - 1).reshape((orig_shape[0] * 4, *orig_shape[1:]))
+
+                # The scale is inverted
+                return data / scale.float()
+
+            def dequant_simple(weight: Tensor, scale: Tensor) -> Tensor:
+                scale = scale.float()
+
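+                # block-quantized weights: repeat each per-block scale so it spans
+                # its whole weight_block_size block along every axis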
+                if (weight_block_size := quant_config.get("weight_block_size")):
+                    # TODO: make sure it's a list of integers
+                    for i, size in enumerate(weight_block_size):
+                        scale = scale.repeat_interleave(size, i)
+
+                return weight.float() * scale
+
+            # ref: https://github.com/ModelCloud/GPTQModel/blob/037c5c0f6c9e33c500d975b038d02e7ca437546d/gptqmodel/nn_modules/qlinear/__init__.py#L437-L476
+            def dequant_gptq(g_idx: Tensor, qweight: Tensor, qzeros: Tensor, scales: Tensor) -> Tensor:
+                bits = quant_config["bits"]
+                assert bits in (2, 3, 4, 8)
+                assert qweight.dtype == qzeros.dtype
+                maxq = (2 ** bits) - 1
+                weight = None
+                zeros = None
+                pack_dtype_bits = qweight.dtype.itemsize * 8
+
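+                # the quantized values are bit-packed pack_factor-per-integer in
+                # qweight/qzeros; recover them by shifting with wf and masking with maxq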
+                if bits in [2, 4, 8]:
+                    pack_factor = pack_dtype_bits // bits
+                    wf = torch.tensor(list(range(0, pack_dtype_bits, bits)), dtype=torch.int32).unsqueeze(0)
+                    if self.lazy:
+                        wf = LazyTorchTensor.from_eager(wf)
+
+                    zeros = torch.bitwise_right_shift(
+                        qzeros.unsqueeze(2).expand(-1, -1, pack_factor),
+                        wf.unsqueeze(0)
+                    ).to(torch.int16 if bits == 8 else torch.int8)
+                    zeros = torch.bitwise_and(zeros, maxq).reshape(scales.shape)
+
+                    weight = torch.bitwise_and(
+                        torch.bitwise_right_shift(
+                            qweight.unsqueeze(1).expand(-1, pack_factor, -1),
+                            wf.unsqueeze(-1)
+                        ).to(torch.int16 if bits == 8 else torch.int8),
+                        maxq
+                    )
+                elif bits == 3:
+                    raise NotImplementedError("3-bit gptq dequantization is not yet implemented")
+
+                assert weight is not None
+                assert zeros is not None
+
+                weight = weight.reshape(weight.shape[0] * weight.shape[1], weight.shape[2])
+
+                # gptq_v2 doesn't need to offset zeros
+                if quant_config.get("checkpoint_format", "gptq") == "gptq":
+                    zeros += 1
+
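+                # g_idx maps each input row to its quantization group; the transpose
+                # restores the usual (out_features, in_features) layout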
+                return (scales[g_idx].float() * (weight - zeros[g_idx]).float()).T
+
+            if quant_method == "bitnet":
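+                # re-wrap the weight's loader so dequantization also runs lazily;
+                # the now-redundant scale entry is dropped afterwards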
+                for name in self.model_tensors.keys():
+                    if name.endswith(".weight_scale"):
+                        weight_name = name.removesuffix("_scale")
+                        w = self.model_tensors[weight_name]
+                        s = self.model_tensors[name]
+                        self.model_tensors[weight_name] = lambda w=w, s=s: dequant_bitnet(w(), s())
+                        tensors_to_remove.append(name)
+            elif quant_method == "fp8":
+                for name in self.model_tensors.keys():
+                    if name.endswith(".weight_scale_inv"):
+                        weight_name = name.removesuffix("_scale_inv")
+                        w = self.model_tensors[weight_name]
+                        s = self.model_tensors[name]
+                        self.model_tensors[weight_name] = lambda w=w, s=s: dequant_simple(w(), s())
+                        tensors_to_remove.append(name)
+            elif quant_method == "gptq":
+                for name in self.model_tensors.keys():
+                    if name.endswith(".qweight"):
+                        base_name = name.removesuffix(".qweight")
+                        g_idx = self.model_tensors[base_name + ".g_idx"]
+                        qweight = self.model_tensors[base_name + ".qweight"]
+                        qzeros = self.model_tensors[base_name + ".qzeros"]
+                        scales = self.model_tensors[base_name + ".scales"]
+                        new_tensors[base_name + ".weight"] = (
+                            lambda g=g_idx, z=qzeros, w=qweight, s=scales: dequant_gptq(
+                                g(), w(), z(), s()
+                            )
+                        )
+                        tensors_to_remove += [
+                            base_name + n
+                            for n in (
+                                ".g_idx",
+                                ".qzeros",
+                                ".qweight",
+                                ".scales",
+                            )
+                        ]
             else:
-                raise ValueError("Mismatch between weight map and model parts for tensor names:\n"
-                                 f"Missing tensors: {missing}\n"
-                                 f"Extra tensors: {extra}")
+                raise NotImplementedError(f"Quant method is not yet supported: {quant_method!r}")
+
+        for name in tensors_to_remove:
+            if name in self.model_tensors:
+                del self.model_tensors[name]
+
+        for name, value in new_tensors.items():
+            self.model_tensors[name] = value
+
+    def get_tensors(self) -> Iterator[tuple[str, Tensor]]:
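+        # tensors are materialized one at a time, only when the conversion asks for them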
+        for name, gen in self.model_tensors.items():
+            yield name, gen()
 
     def format_tensor_name(self, key: gguf.MODEL_TENSOR, bid: int | None = None, suffix: str = ".weight") -> str:
         if key not in gguf.MODEL_TENSORS[self.model_arch]:
@@ -3860,27 +3990,6 @@ def set_gguf_parameters(self):
         self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR)
         self.gguf_writer.add_rope_scaling_factor(1.0)
 
-    _has_tok_embd = False
-
-    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
-        del bid  # unused
-
-        output_name = self.format_tensor_name(gguf.MODEL_TENSOR.OUTPUT)
-        tok_embd_name = self.format_tensor_name(gguf.MODEL_TENSOR.TOKEN_EMBD)
-
-        new_name = self.map_tensor_name(name)
-
-        # assuming token_embd.weight is seen before output.weight
-        if not self._has_tok_embd and new_name == self.format_tensor_name(gguf.MODEL_TENSOR.OUTPUT):
-            # even though the tensor file(s) does not contain the word embeddings they are still in the weight map
-            if self.tensor_names and "transformer.wte.weight" in self.tensor_names:
-                logger.debug(f"{tok_embd_name} not found before {output_name}, assuming they are tied")
-                self.tensor_names.remove("transformer.wte.weight")
-        elif new_name == tok_embd_name:
-            self._has_tok_embd = True
-
-        return [(new_name, data_torch)]
-
 
 @ModelBase.register("InternLM2ForCausalLM")
 class InternLM2Model(TextModel):