MEMIT

execute_memit() -> Dict[str, Tuple[torch.Tensor]]

Execution function: executes the MEMIT update algorithm for the given requests at the layers specified in the hyperparameters.

def execute_memit(
    model: AutoModelForCausalLM,
    tok: AutoTokenizer,
    requests: List[Dict],
    hparams: MEMITHyperParams,
) -> Dict[str, Tuple[torch.Tensor]]:
  • Parameters

    • model (PreTrainedModel): the model to be edited

    • tok (PreTrainedTokenizer): the tokenizer for the model's inputs

    • requests (List[Dict]): the edit descriptors and their targets

    • hparams (MEMITHyperParams): hyperparameters for the editing method

  • Return Type

    • delta (Dict[str, Tuple[torch.Tensor]]): the computed weight updates, keyed by the name of each edited weight
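For reference, the MEMIT paper (Meng et al., "Mass-Editing Memory in a Transformer") computes each layer's update roughly as follows. This is the paper's formulation, not a line-by-line transcription of execute_memit. Here l runs over the edited layers in increasing order with L the last one; the columns of K^l are the keys of the n edited facts at layer l; C^l is the covariance of keys estimated on generic text; λ is a weighting hyperparameter; the columns of Z are the optimized target hidden states z_i and the columns of H^L are the current hidden states h_i^L at layer L, recomputed after the earlier layers have been updated.

```latex
\begin{aligned}
R^{l} &= \frac{Z - H^{L}}{L - l + 1}, \\
\Delta^{l} &= R^{l}\,(K^{l})^{\top}\left(\lambda C^{l} + K^{l}(K^{l})^{\top}\right)^{-1}, \\
W^{l} &\leftarrow W^{l} + \Delta^{l}.
\end{aligned}
```

Since R^l has n columns, Δ^l has rank at most n and can be stored as two n-column factors instead of a full matrix. A tuple-valued delta is consistent with this, although this section does not specify the exact factorization that is returned.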

apply_memit_to_model() -> PreTrainedModel

Main function: given the requests, applies MEMIT to the model and returns the edited model.

def apply_memit_to_model(
    model: AutoModelForCausalLM,
    tok: AutoTokenizer,
    requests: List[Dict],
    hparams: MEMITHyperParams,
    copy=False,
    return_orig_weights=False,
    keep_original_weight=False,
    **kwargs
):
  • Parameters

    • model (PreTrainedModel): the model to be edited

    • tok (PreTrainedTokenizer): the tokenizer for the model's inputs

    • requests (List[Dict]): the edit descriptors and their targets

    • hparams (MEMITHyperParams): hyperparameters for the editing method

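    • copy (bool): whether to edit a copy of the model, leaving the original model unchanged

    • return_orig_weights (bool): whether to also return the original values of the weights that were edited

    • keep_original_weight (bool): whether to restore the original weights after each edit, i.e. whether to avoid sequential editing

      • False: edit sequentially; the original weights are not restored after each edit, so later edits build on earlier ones

      • True: do not edit sequentially; each edit is applied to the original, unedited weights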

  • Return Type

    • edited_model (PreTrainedModel): the model after editing
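The options above can be illustrated with a framework-free sketch. It rests on stated assumptions: each entry of the delta dictionary is assumed to be a pair of factors (key_mat, val_mat) whose product key_mat @ val_mat.T is added to the named weight, and the weight is assumed to already have that product's orientation (real transformer weights are torch tensors and may need a transpose). apply_deltas, run_edits and outer_update are hypothetical helpers, not part of the documented API.

```python
from copy import deepcopy
from typing import Dict, List, Tuple

Matrix = List[List[float]]
Deltas = Dict[str, Tuple[Matrix, Matrix]]


def outer_update(key_mat: Matrix, val_mat: Matrix) -> Matrix:
    """Compute key_mat @ val_mat.T for nested-list matrices.

    key_mat has shape (d_k, n) and val_mat has shape (d_v, n), where n is the
    number of edited facts, so the result has shape (d_k, d_v) and rank <= n.
    """
    return [[sum(k * v for k, v in zip(k_row, v_row)) for v_row in val_mat]
            for k_row in key_mat]


def apply_deltas(
    weights: Dict[str, Matrix],
    deltas: Deltas,
    return_orig_weights: bool = False,
) -> Tuple[Dict[str, Matrix], Dict[str, Matrix]]:
    """Add each low-rank update to the named weight, in place (hypothetical helper).

    Assumes each weight is stored as (d_k, d_v), matching the update's shape.
    """
    orig: Dict[str, Matrix] = {}
    for name, (key_mat, val_mat) in deltas.items():
        w = weights[name]
        if return_orig_weights and name not in orig:
            orig[name] = deepcopy(w)  # snapshot before modifying
        upd = outer_update(key_mat, val_mat)
        for i, row in enumerate(upd):
            for j, value in enumerate(row):
                w[i][j] += value
    return weights, orig


def run_edits(
    weights: Dict[str, Matrix],
    edit_batches: List[Deltas],
    keep_original_weight: bool,
) -> List[Dict[str, Matrix]]:
    """Apply several delta dicts one after another (hypothetical helper).

    keep_original_weight=True restores the saved weights after each edit, so
    every edit starts from the unedited model; False lets edits accumulate.
    """
    history: List[Dict[str, Matrix]] = []
    for deltas in edit_batches:
        weights, orig = apply_deltas(weights, deltas, return_orig_weights=True)
        history.append(deepcopy(weights))  # state right after this edit
        if keep_original_weight:
            for name, saved in orig.items():
                weights[name] = saved  # undo this edit before the next one
    return history
```

For example, applying the same one-element delta twice to a weight of 0.0 yields 2.0 and then 4.0 with keep_original_weight=False, but 2.0 both times with keep_original_weight=True.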

Example

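from easyeditor import BaseEditor, MEMITHyperParams

# Load MEMIT hyperparameters from a YAML config and build an editor from them
hparams = MEMITHyperParams.from_hparams("llama-7b.yaml")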
editor = BaseEditor.from_hparams(hparams)
prompts = ['What university did Watts Humphrey attend?',
    'Which family does Ramalinaceae belong to',
    'What role does Denny Herzig play in football?'
]
target_new = ['University of Michigan',
    'Lamiinae',
    'winger'
]
metrics, edited_model, _ = editor.edit(
    prompts=prompts,
    target_new=target_new,
    keep_original_weight=True  # restore the original weights after each edit (no sequential editing)
)
