IKE

encode_ike_facts

To retrieve nearest neighbors as in-context demonstrations, the samples in the dev set must first be encoded.

def encode_ike_facts(
    sentence_model: SentenceTransformer,
    ds: Dataset,
    hparams: IKEHyperParams
):

Parameters

  • sentence_model(SentenceTransformer): the model used to encode demonstrations

  • ds(Dataset): the dev set, used as the corpus from which demonstrations are retrieved

  • hparams(IKEHyperParams): hyperparameters for the editing method

The resulting dense embeddings are stored in pickle format.
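As a rough illustration of this step, the sketch below encodes each dev-set sample and pickles the embeddings. It is not EasyEdit's implementation: `toy_encode` is a hypothetical hashed bag-of-words encoder standing in for a real `SentenceTransformer` (with which you would call `sentence_model.encode(sentences)`), and `encode_facts_sketch`, the `prompt`/`target_new` record keys, and the pickle layout are all assumptions made for illustration.

```python
import math
import os
import pickle
import tempfile
import zlib
from typing import Dict, List

DIM = 64  # hypothetical embedding size for the toy encoder


def toy_encode(text: str, dim: int = DIM) -> List[float]:
    """Hypothetical stand-in for SentenceTransformer.encode:
    a hashed bag-of-words vector, L2-normalized."""
    vec = [0.0] * dim
    for token in text.lower().split():
        # crc32 gives a deterministic bucket for each token
        vec[zlib.crc32(token.encode("utf-8")) % dim] += 1.0
    norm = math.sqrt(sum(v * v for v in vec)) or 1.0
    return [v / norm for v in vec]


def encode_facts_sketch(dev_set: List[Dict[str, str]], out_path: str) -> None:
    """Encode every dev-set sample and store sentences + embeddings as a pickle.
    The record keys and the pickled dict layout are assumptions."""
    sentences = [f"{r['prompt']} {r['target_new']}" for r in dev_set]
    embeddings = [toy_encode(s) for s in sentences]
    with open(out_path, "wb") as f:
        pickle.dump({"sentences": sentences, "embeddings": embeddings}, f)


if __name__ == "__main__":
    dev_set = [
        {"prompt": "The capital of France is", "target_new": "Paris"},
        {"prompt": "The mother tongue of Danielle Darrieux is", "target_new": "French"},
    ]
    path = os.path.join(tempfile.gettempdir(), "ike_embeddings.pkl")
    encode_facts_sketch(dev_set, path)
    with open(path, "rb") as f:
        stored = pickle.load(f)
    print(len(stored["embeddings"]), len(stored["embeddings"][0]))
```

Encoding the whole dev set once and caching it on disk means retrieval at edit time only needs to encode the incoming request.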

apply_ike_to_model() -> PreTrainedModel

Main function: given the edit requests, it applies IKE to your model, using a preceding prompt built from in-context demonstrations to modify the model's behavior.
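The retrieval and prompt-construction idea can be sketched as follows, assuming demonstrations are ranked by cosine similarity between the query embedding and the stored dev-set embeddings. The helper names (`cosine`, `top_k_demos`, `build_ike_prompt`), the value of `k`, the demonstration order, and the `New Fact:` / `Prompt:` template are illustrative assumptions, not necessarily what `apply_ike_to_model` does internally. In this sketch only the input text changes.

```python
import math
from typing import List, Sequence


def cosine(a: Sequence[float], b: Sequence[float]) -> float:
    """Cosine similarity; returns 0.0 if either vector is all zeros."""
    dot = sum(x * y for x, y in zip(a, b))
    na = math.sqrt(sum(x * x for x in a))
    nb = math.sqrt(sum(y * y for y in b))
    return dot / (na * nb) if na and nb else 0.0


def top_k_demos(query_emb: Sequence[float],
                demo_embs: List[Sequence[float]],
                demos: List[str],
                k: int) -> List[str]:
    """Return the k demonstrations most similar to the query, most similar first."""
    ranked = sorted(range(len(demos)),
                    key=lambda i: cosine(query_emb, demo_embs[i]),
                    reverse=True)
    return [demos[i] for i in ranked[:k]]


def build_ike_prompt(new_fact: str, query: str, demos: List[str]) -> str:
    """Hypothetical template: demonstrations, then the new fact, then the query."""
    context = "".join(d + "\n\n" for d in demos)
    return f"{context}New Fact: {new_fact}\nPrompt: {query}"


if __name__ == "__main__":
    demos = ["demo A", "demo B", "demo C"]
    demo_embs = [[0.9, 0.1], [0.0, 1.0], [0.7, 0.7]]
    chosen = top_k_demos([1.0, 0.0], demo_embs, demos, k=2)
    print(build_ike_prompt("Watts Humphrey attended University of Michigan",
                           "What university did Watts Humphrey attend?",
                           chosen))
```

The prompt produced this way is then fed to the unmodified model, which is expected to answer the query consistently with the new fact.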

def apply_ike_to_model(
    self,
    model: AutoModelForCausalLM,
    tok: AutoTokenizer,
    requests: List[Dict],
    hparams: IKEHyperParams,
    copy=False,
    return_orig_weights=False,
    keep_original_weight=False,
    **kwargs
):
  • Parameters

    • model(PreTrainedModel): model to be edited

    • tok(PreTrainedTokenizer): tokenizer for inputs

    • requests(List[Dict]): The edit descriptors and targets.

    • hparams(IKEHyperParams): hyperparameters for the editing method

    • copy(bool): whether to copy the original model

    • return_orig_weights(bool): whether to return the original model's weights

    • keep_original_weight(bool): whether to keep the original weights for every edit, i.e. whether edits are applied independently rather than sequentially

      • False: edit sequentially (the original weights are not restored after each edit)

      • True: do not edit sequentially (each edit starts from the original weights)

  • Return Type

    • edited_model(PreTrainedModel): the model after editing
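The difference between the two `keep_original_weight` settings can be shown with a toy model whose "weights" are a plain dict. `apply_edit` and `run_edits` are hypothetical helpers, not EasyEdit functions; the sketch only mirrors the sequential versus independent semantics described in the parameter list above.

```python
import copy
from typing import Dict, List, Tuple

Weights = Dict[str, float]


def apply_edit(weights: Weights, edit: Tuple[str, float]) -> Weights:
    """Hypothetical edit: return a new weight dict with one entry overwritten."""
    name, value = edit
    edited = dict(weights)
    edited[name] = value
    return edited


def run_edits(original: Weights,
              edits: List[Tuple[str, float]],
              keep_original_weight: bool) -> List[Weights]:
    """Return the model state after each edit.

    keep_original_weight=False: each edit builds on the previous result (sequential).
    keep_original_weight=True: each edit starts again from the original weights.
    """
    states = []
    current = copy.deepcopy(original)
    for edit in edits:
        base = copy.deepcopy(original) if keep_original_weight else current
        current = apply_edit(base, edit)
        states.append(current)
    return states


if __name__ == "__main__":
    original = {"a": 0.0, "b": 0.0}
    edits = [("a", 1.0), ("b", 2.0)]
    print(run_edits(original, edits, keep_original_weight=False)[-1])
    print(run_edits(original, edits, keep_original_weight=True)[-1])
```

With sequential editing the final state carries both edits; with `keep_original_weight=True` it carries only the last one.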

Example

from easyeditor import BaseEditor, IKEHyperParams
# ...
hparams = IKEHyperParams.from_hparams("llama-7b.yaml")
editor = BaseEditor.from_hparams(hparams)
prompts = ['What university did Watts Humphrey attend?',
    'Which family does Ramalinaceae belong to?',
    'What role does Denny Herzig play in football?'
]
target_new = ['University of Michigan',
    'Lamiinae',
    'winger'
]
metrics, edited_model, _ = editor.edit(
    prompts=prompts,
    target_new=target_new,
    keep_original_weight=True
)
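The `prompts` and `target_new` lists in this example are parallel: the i-th prompt is paired with the i-th new target. A minimal sketch of turning them into a `requests`-style list of dicts, assuming `prompt` and `target_new` as the keys (the request dicts the library actually builds may carry more fields, and `make_requests` is a hypothetical helper):

```python
from typing import Dict, List


def make_requests(prompts: List[str], target_new: List[str]) -> List[Dict[str, str]]:
    """Pair each prompt with its new target; the key names are assumptions."""
    if len(prompts) != len(target_new):
        raise ValueError("prompts and target_new must have the same length")
    return [{"prompt": p, "target_new": t} for p, t in zip(prompts, target_new)]


if __name__ == "__main__":
    requests = make_requests(
        ["What role does Denny Herzig play in football?"],
        ["winger"],
    )
    print(requests)
```

Checking the lengths up front catches a mismatched pair of lists before any editing work starts.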
