Evaluate

Contains evaluation utilities for PyTorch-based rewriting methods. To use them, call compute_edit_quality with the appropriate arguments; it returns a dictionary of evaluation metrics.

compute_rewrite_or_rephrase_quality() -> Dict

Evaluation method for Reliability and Generalization.

  • At each target token position, take the token with the highest predicted probability.

  • Compare these predictions with the ground-truth tokens of target_new and report the mean token-level accuracy.
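The two steps above can be sketched in plain Python as follows. This is a minimal illustration, not the library's implementation: it assumes `logits` already holds one score vector per target token (in a causal LM the scores that predict target token i come from the position just before it, so they must be shifted into place first), and the helper names `argmax` and `token_accuracy` are hypothetical.

```python
from typing import List, Sequence


def argmax(scores: Sequence[float]) -> int:
    """Return the index of the highest score (the first index wins on ties)."""
    return max(range(len(scores)), key=lambda i: scores[i])


def token_accuracy(logits: List[Sequence[float]], target_ids: List[int]) -> float:
    """Mean fraction of positions where the top-1 token equals the target token.

    Hypothetical helper: logits[i] is assumed to be the score vector over the
    vocabulary that predicts target_ids[i] (already aligned).
    """
    if len(logits) != len(target_ids) or not target_ids:
        raise ValueError("need one score vector per target token, and at least one token")
    predictions = [argmax(row) for row in logits]  # step 1: top-1 token per position
    correct = sum(int(p == t) for p, t in zip(predictions, target_ids))  # step 2: compare
    return correct / len(target_ids)


# Made-up scores over a 4-token vocabulary for a 3-token target.
example_logits = [
    [0.1, 2.0, 0.3, 0.0],  # top-1 token: 1
    [1.5, 0.2, 0.1, 0.0],  # top-1 token: 0
    [0.0, 0.1, 0.2, 3.0],  # top-1 token: 3
]
print(token_accuracy(example_logits, [1, 0, 2]))  # 2 of 3 positions match
```

With PyTorch tensors the same computation is usually written as `(logits.argmax(-1) == targets).float().mean()`.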

def compute_rewrite_or_rephrase_quality(
    model,
    model_name,
    hparams: HyperParams,
    tok: AutoTokenizer,
    prompt: str,
    target_new: str,
    device,
    test_rephrase: bool = False
) -> typing.Dict:
  • Parameters

    • model (PreTrainedModel): the model to evaluate (before or after editing)

    • model_name (str): model name or path

    • hparams (HyperParams): hyperparameters of the editing method

    • tok (PreTrainedTokenizer): tokenizer used to encode the inputs

    • prompt (str): the edit descriptor

    • target_new (str): the edit target

    • device: the device on which to run the evaluation

    • test_rephrase (bool): whether to evaluate the rephrased prompt (for Generalization)

  • Return Type

    • metrics (Dict): the computed accuracy, stored under rewrite_acc, or under rephrase_acc when test_rephrase is True

compute_locality_quality() -> Dict

The input and output format is the same as for compute_rewrite_or_rephrase_quality.

  • At each token position, take the token with the highest predicted probability.

  • Check, position by position, whether the predicted tokens before and after editing are identical, and report the average agreement rate.
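The locality check above can be sketched in plain Python. This is a minimal illustration under stated assumptions, not the library's code: `pre_logits` and `post_logits` are assumed to be the pre-edit and post-edit score vectors for the same locality prompt, aligned position by position, and `argmax` and `locality_rate` are hypothetical names.

```python
from typing import List, Sequence


def argmax(scores: Sequence[float]) -> int:
    """Return the index of the highest score (the first index wins on ties)."""
    return max(range(len(scores)), key=lambda i: scores[i])


def locality_rate(pre_logits: List[Sequence[float]],
                  post_logits: List[Sequence[float]]) -> float:
    """Fraction of positions whose top-1 token is unchanged by the edit.

    Hypothetical helper: both arguments hold one score vector per position
    for the same locality prompt, before and after editing.
    """
    if len(pre_logits) != len(post_logits) or not pre_logits:
        raise ValueError("pre- and post-edit outputs must be non-empty and equally long")
    pre_tokens = [argmax(row) for row in pre_logits]    # top-1 tokens before editing
    post_tokens = [argmax(row) for row in post_logits]  # top-1 tokens after editing
    same = sum(int(a == b) for a, b in zip(pre_tokens, post_tokens))
    return same / len(pre_tokens)


# Made-up scores over a 2-token vocabulary at 3 positions.
pre = [[0.9, 0.1], [0.2, 0.8], [0.6, 0.4]]     # top-1 tokens: 0, 1, 0
post = [[0.7, 0.3], [0.9, 0.1], [0.55, 0.45]]  # top-1 tokens: 0, 0, 0
print(locality_rate(pre, post))  # 2 of 3 positions unchanged
```

No ground truth is needed here: the pre-edit predictions serve as the reference, so a rate of 1.0 means the edit left the model's top-1 outputs on that prompt untouched.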

Example

  • metrics

{
    "post": {
        "rewrite_acc": ,
        "rephrase_acc": ,
        "locality": {
            "YOUR_LOCALITY_KEY": ,
            //...
        },
        "portability": {
            "YOUR_PORTABILITY_KEY": ,
            //...
        },
    },
    "pre": {
        "rewrite_acc": ,
        "rephrase_acc": ,
        "portability": {
            "YOUR_PORTABILITY_KEY": ,
            //...
        },
    }
}
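To consume a metrics dictionary laid out as above, a small summary helper can be sketched as follows. It is hypothetical: the name `summarize`, the locality key `neighborhood`, and the values in the usage example are made up. Because the example above leaves the value types blank, the sketch accepts either a single float or a list of floats for each accuracy.

```python
from statistics import mean
from typing import Any, Dict


def summarize(metrics: Dict[str, Any]) -> Dict[str, float]:
    """Collapse a pre/post metrics dict into a few headline numbers (hypothetical helper)."""

    def as_float(value: Any) -> float:
        # Average list-valued accuracies; pass plain numbers through.
        return float(mean(value)) if isinstance(value, (list, tuple)) else float(value)

    post, pre = metrics["post"], metrics["pre"]
    summary = {
        # How much the edit raised accuracy on the edit prompt itself.
        "rewrite_gain": as_float(post["rewrite_acc"]) - as_float(pre["rewrite_acc"]),
    }
    if "rephrase_acc" in post:
        summary["post_rephrase_acc"] = as_float(post["rephrase_acc"])
    locality = post.get("locality", {})
    if locality:
        # Average over all locality keys.
        summary["mean_locality"] = mean(as_float(v) for v in locality.values())
    return summary


# Made-up values, for illustration only.
example = {
    "post": {"rewrite_acc": 1.0, "rephrase_acc": 0.5,
             "locality": {"neighborhood": [1.0, 0.5]}},
    "pre": {"rewrite_acc": 0.25, "rephrase_acc": 0.0},
}
print(summarize(example))
# {'rewrite_gain': 0.75, 'post_rephrase_acc': 0.5, 'mean_locality': 0.75}
```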
