MEND
MendRewriteExecutor
is the class for applying MEND to your model. It employs a hypernetwork to learn the delta needed to edit the language model.
MEND requires pre-trained hypernetwork weights for the specific model structure (train them with Trainer).
```python
def init_model(self, model, tok, params: MENDHyperParams)
```
Parameters
model(PreTrainedModel): model to be edited
tok(PreTrainedTokenizer): tokenizer for inputs
params(Hyperparams): hyperparameters for editing method
Return Type
Main function: given the requests, it applies MEND to your model and returns the changed weights of the model.
```python
def apply_to_model(
    self,
    model: AutoModelForCausalLM,
    tok: AutoTokenizer,
    requests: List[Dict],
    hparams: MENDHyperParams,
    copy=False,
    return_orig_weights=False,
    keep_original_weight=False,
    **kwargs
):
```
Parameters
model(PreTrainedModel): model to be edited
tok(PreTrainedTokenizer): tokenizer for inputs
requests(List[Dict]): The edit descriptors and targets.
hparams(Hyperparams): hyperparameters for editing method
copy(bool): whether to copy original model
return_orig_weights(bool): whether to return the weights of original model
keep_original_weight(bool): whether to restore the original weights after each edit
False: edit sequentially (the original weights are not restored after each edit)
True: do not edit sequentially (each edit starts from the original weights)
Return Type
edited_model(PreTrainedModel): model weights after editing
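Each entry in requests is a dict describing one edit. A minimal sketch of the expected shape is below; the key names ("prompt", "target_new", "ground_truth") follow EasyEdit's request conventions and should be treated as an assumption here, not a guarantee:

```python
# Sketch of the request format apply_to_model consumes.
# Key names are assumed from EasyEdit's conventions.
requests = [
    {
        "prompt": "What university did Watts Humphrey attend?",
        "target_new": "University of Michigan",
        "ground_truth": "Illinois Institute of Technology",
    },
    {
        "prompt": "What role does Denny Herzig play in football?",
        "target_new": "winger",
        "ground_truth": "defender",
    },
]

# Every request supplies the edit descriptor (prompt) and the new target.
for r in requests:
    assert {"prompt", "target_new"} <= r.keys()
```

This list is what the higher-level editor builds for you from parallel prompts/target_new lists before calling apply_to_model.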
```python
# ...
hparams = MENDHyperParams.from_hparams('llama-7b.yaml')
editor = BaseEditor.from_hparams(hparams)
prompts = ['What university did Watts Humphrey attend?',
           'Which family does Ramalinaceae belong to',
           'What role does Denny Herzig play in football?']
target_new = ['University of Michigan',
              'Lamiinae',
              'winger']
metrics, edited_model, _ = editor.edit(
    prompts=prompts,
    target_new=target_new,
    keep_original_weight=True
)
```
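The keep_original_weight flag used above can be pictured with a toy weight dict: with True, every edit starts from the original weights (so edits are independent and only the last one persists); with False, deltas accumulate sequentially. This is a minimal analogy, not EasyEdit code:

```python
def apply_edits(weights, deltas, keep_original_weight):
    """Toy model of sequential vs. non-sequential editing.

    weights: dict mapping parameter name -> value
    deltas:  list of dicts with per-edit weight deltas
    """
    original = dict(weights)   # snapshot of the pre-edit weights
    current = dict(weights)
    for delta in deltas:
        if keep_original_weight:
            # Non-sequential: every edit starts from the original weights.
            current = dict(original)
        for name, d in delta.items():
            current[name] = current[name] + d
    return current

w = {"layer.weight": 1.0}
edits = [{"layer.weight": 0.5}, {"layer.weight": 0.25}]

# Sequential (keep_original_weight=False): deltas accumulate -> 1.75.
assert apply_edits(w, edits, keep_original_weight=False) == {"layer.weight": 1.75}

# Non-sequential (keep_original_weight=True): only the last edit persists -> 1.25.
assert apply_edits(w, edits, keep_original_weight=True) == {"layer.weight": 1.25}
```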