prompting.validators.reward.dpo

Module Contents

Classes

DirectPreferenceRewardModel

class prompting.validators.reward.dpo.DirectPreferenceRewardModel(device)

Bases: prompting.validators.reward.reward.BaseRewardModel

Parameters:

device (str) – device on which to load and run the reference model (e.g. 'cpu' or 'cuda').
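
A minimal instantiation sketch. This assumes the prompting package and its dependencies are installed; note that constructing the model loads the cerebras/btlm-3b-8k-base weights, which is slow on first run and memory-hungry:

```python
# Minimal sketch: instantiate the DPO reward model on a device.
from prompting.validators.reward.dpo import DirectPreferenceRewardModel

reward_model = DirectPreferenceRewardModel(device="cuda")  # or "cpu"
print(reward_model.name)               # registered model name
print(reward_model.reward_model_name)  # 'cerebras/btlm-3b-8k-base'
```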

property name: str
Return type:

str

reward_model_name: str = 'cerebras/btlm-3b-8k-base'

reward_single(prompt, completion, name, with_penalty=True)

Calculates a direct preference optimization (DPO) style reward for a completion: the reference model's average log-probability of the completion tokens given the prompt. Follows the approach of eric-mitchell/direct-preference-optimization.

Parameters:
  • prompt (str) – the prompt the completion responds to.

  • completion (str) – the completion text to score.

  • name (str) –

  • with_penalty (bool) –
Return type:

prompting.validators.reward.reward.BaseRewardEvent
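
To make the reward concrete, here is an illustrative re-implementation of the core computation, not the library's exact code: score a completion by the reference model's mean log-probability of its tokens conditioned on the prompt. The model name follows the reward_model_name attribute above; the prompt/completion token-boundary handling is simplified, and the with_penalty path is omitted:

```python
# Illustrative sketch of a DPO-style reward (not the library's exact code):
# the reference model's mean log-probability of completion tokens given the prompt.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(
    "cerebras/btlm-3b-8k-base", trust_remote_code=True
)
model = AutoModelForCausalLM.from_pretrained(
    "cerebras/btlm-3b-8k-base", trust_remote_code=True
)

def dpo_style_reward(prompt: str, completion: str) -> float:
    prompt_ids = tokenizer(prompt, return_tensors="pt").input_ids
    full_ids = tokenizer(prompt + completion, return_tensors="pt").input_ids
    with torch.no_grad():
        logits = model(full_ids).logits  # (1, seq_len, vocab)
    # Logits at position t predict token t+1, so shift by one.
    log_probs = torch.log_softmax(logits[:, :-1, :], dim=-1)
    targets = full_ids[:, 1:]
    token_log_probs = log_probs.gather(2, targets.unsqueeze(-1)).squeeze(-1)
    # Average only over completion tokens (everything after the prompt).
    n_prompt = prompt_ids.shape[1]
    return token_log_probs[:, n_prompt - 1:].mean().item()
```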

get_rewards(prompt, completions, name)

Computes a DPO-style reward for each completion against the same prompt, returning one reward event per completion.

Parameters:
  • prompt (str) – the prompt shared by all completions.

  • completions (List[str]) – the completions to score.

  • name (str) –

Return type:

List[prompting.validators.reward.reward.BaseRewardEvent]
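
A hypothetical usage example; the task name string and the numeric reward field on BaseRewardEvent are assumptions, not confirmed by this page:

```python
# Hypothetical usage: score several completions for one prompt.
from prompting.validators.reward.dpo import DirectPreferenceRewardModel

reward_model = DirectPreferenceRewardModel(device="cpu")
events = reward_model.get_rewards(
    prompt="What causes ocean tides?",
    completions=[
        "Tides are driven mainly by the Moon's gravitational pull.",
        "Tides happen because fish swim in circles.",
    ],
    name="augment",  # hypothetical task name
)
for event in events:
    # BaseRewardEvent is assumed to expose a numeric `reward` field.
    print(event.reward)
```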