SmashConfig

SmashConfig is an essential tool in pruna for configuring parameters to optimize your models. This manual explains how to define and use a SmashConfig.
Defining a simple SmashConfig
Define a SmashConfig using the following snippet:
from pruna import SmashConfig
smash_config = SmashConfig()
After creating an empty SmashConfig, you can activate an algorithm by adding it to the SmashConfig:
smash_config['quantizer'] = 'hqq'
Additionally, you can overwrite the defaults of the algorithm you have added by setting its hyperparameters in the SmashConfig:
smash_config['hqq_weight_bits'] = 4
You’re done! You created your SmashConfig and can now pass it to the smash function.
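Under the hood, this workflow treats the SmashConfig like a small mapping from algorithm groups and hyperparameter names to values. The stand-in below (a hypothetical `ToyConfig`, not pruna's actual implementation) sketches the activate-then-override pattern under that assumption; the default value is illustrative only:

```python
# Minimal stand-in for the dict-style interface shown above.
# ToyConfig is hypothetical; pruna's SmashConfig adds validation,
# hyperparameter spaces, and device handling on top of this idea.

class ToyConfig:
    # Assumed default for illustration; pruna's real default may differ.
    DEFAULTS = {"hqq_weight_bits": 8}

    def __init__(self):
        self._settings = dict(self.DEFAULTS)

    def __setitem__(self, key, value):
        self._settings[key] = value

    def __getitem__(self, key):
        return self._settings[key]

config = ToyConfig()
config["quantizer"] = "hqq"    # activate an algorithm
config["hqq_weight_bits"] = 4  # override one of its defaults
print(config["quantizer"], config["hqq_weight_bits"])  # hqq 4
```

The real SmashConfig follows the same two steps: first select the algorithm for a group (here, the quantizer), then override any of its hyperparameters by key.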
Adding a Dataset, Tokenizer or Processor
Some algorithms require a dataset, tokenizer, or processor to be passed to the SmashConfig.
For example, the gptq quantizer requires a dataset and a tokenizer. We can pass them to the SmashConfig as follows:
from pruna import SmashConfig
smash_config = SmashConfig()
smash_config.add_tokenizer("facebook/opt-125m")
smash_config.add_data("WikiText")
As this example shows, we can add a dataset simply by passing its name. However, the add_data() function also supports other input formats; for more information, see the dataset documentation.
We can now activate the gptq quantizer by adding it to the SmashConfig:
smash_config['quantizer'] = 'gptq'
Similarly, we can add a processor to the SmashConfig if it is required, for example by the c_whisper compiler:
from pruna import SmashConfig
smash_config = SmashConfig()
smash_config.add_processor("openai/whisper-large-v3")
smash_config['compiler'] = 'c_whisper'
If you try to activate an algorithm that requires a dataset, tokenizer, or processor and haven’t added them to the SmashConfig, you will receive an error. Make sure to add them before activating the algorithm! To find out which algorithms require a dataset, tokenizer, or processor, see the compression algorithm overview.
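The dependency rule behaves like a pre-check: when an algorithm is activated, the config verifies that its required components have already been added. The sketch below (hypothetical names, not pruna's internals) shows why the add_* calls must come first:

```python
# Hypothetical sketch of the requirement check described above;
# pruna performs a similar validation internally when an algorithm
# is activated. The requirement table here is illustrative.

REQUIREMENTS = {
    "gptq": {"tokenizer", "dataset"},  # per the example above
    "hqq": set(),                      # needs no extra components
}

def activate(available, algorithm):
    """Raise if the algorithm's required components were not added first."""
    missing = REQUIREMENTS[algorithm] - available
    if missing:
        raise ValueError(f"{algorithm} requires: {sorted(missing)}")
    return algorithm

# Adding the tokenizer and dataset first makes activation succeed.
activate({"tokenizer", "dataset"}, "gptq")

# Activating without them fails, mirroring the error described above.
try:
    activate(set(), "gptq")
except ValueError as e:
    print(e)  # gptq requires: ['dataset', 'tokenizer']
```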
SmashConfig Documentation
- class pruna.config.smash_config.SmashConfig(max_batch_size=1, device='cuda', cache_dir_prefix='/home/docs/.cache/pruna', configuration=None)
Wrapper class to hold a ConfigSpace Configuration object as a Smash configuration.
- Parameters:
max_batch_size (int, optional) – The maximum number of batches to process at once. Default is 1.
device (str, optional) – The device to be used for smashing, e.g., ‘cuda’ or ‘cpu’. Default is ‘cuda’.
cache_dir_prefix (str, optional) – The prefix for the cache directory. If None, a default cache directory will be created.
configuration (Configuration, optional) – The configuration to be used for smashing. If None, a default configuration will be created.
- add_data(arg)
- add_data(dataset_name, *args, **kwargs)
- add_data(datasets, collate_fn, *args, **kwargs)
- add_data(datamodule)
Add data to the SmashConfig.
- Parameters:
arg (Any) – The argument to be used.
- add_processor(processor)
Add a processor to the SmashConfig.
- Parameters:
processor (str | transformers.AutoProcessor) – The processor to be added to the SmashConfig.
- Return type:
None
- add_tokenizer(tokenizer)
Add a tokenizer to the SmashConfig.
- Parameters:
tokenizer (str | transformers.AutoTokenizer) – The tokenizer to be added to the SmashConfig.
- Return type:
None
- flush_configuration()
Remove all algorithm hyperparameters from the SmashConfig.
Examples
>>> config = SmashConfig()
>>> config['cacher'] = 'deepcache'
>>> config.flush_configuration()
>>> config
SmashConfig()
- Return type:
None
- load_dict(config_dict)
Load a dictionary of hyperparameters into the SmashConfig.
- Parameters:
config_dict (dict) – The dictionary to load into the SmashConfig.
- Return type:
None
Examples
>>> config = SmashConfig()
>>> config.load_dict({'cacher': 'deepcache', 'deepcache_interval': 4})
>>> config
SmashConfig(
  'cacher': 'deepcache',
  'deepcache_interval': 4,
)