OpenRLHF Study Notes

NaiveExperienceMaker

Here we can see what the micro_rollout_batch_size setting actually does: prompts are fed to the actor for generation in chunks of this size.

def generate_samples(self, all_prompts: List[str], **generate_kwargs) -> List[Samples]:
    """
    Generate samples and return in batches.
    """
    assert not getattr(self, "packing_samples", False)
    args = self.strategy.args
    self.actor.eval()
    # sample multiple responses per prompt by duplicating each prompt n_samples_per_prompt times
    all_prompts = sum([[prompt] * args.n_samples_per_prompt for prompt in all_prompts], [])
    samples_list = []
    # generate rollouts in chunks of micro_rollout_batch_size prompts
    for i in range(0, len(all_prompts), args.micro_rollout_batch_size):
        prompts = all_prompts[i : i + args.micro_rollout_batch_size]
        inputs = self.tokenize_fn(prompts, self.prompt_max_len, device="cuda")
        sequences, attention_mask, action_mask = self.actor.generate(**inputs, **generate_kwargs)
        samples = Samples(
            sequences=sequences,
            attention_mask=attention_mask,
            action_mask=action_mask,
            num_actions=action_mask.size(1),
            packed_seq_lens=None,
            response_length=action_mask.float().sum(dim=-1),
            total_length=attention_mask.float().sum(dim=-1),
        )
        samples_list.append(samples)
    return samples_list
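
As a concrete example with made-up numbers (not taken from any real config), here is how n_samples_per_prompt and micro_rollout_batch_size interact in the loop above:

```python
# made-up numbers, just to illustrate how the two settings interact
n_prompts = 4
n_samples_per_prompt = 4       # each prompt is duplicated this many times
micro_rollout_batch_size = 8   # prompts passed to actor.generate per iteration

total_rollouts = n_prompts * n_samples_per_prompt                   # 16 sequences to generate
num_micro_batches = -(-total_rollouts // micro_rollout_batch_size)  # ceil(16 / 8) = 2 Samples objects
print(total_rollouts, num_micro_batches)                            # 16 2
```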

kl is computed by running samples.sequences through the actor model and the ref model separately and subtracting the resulting log_probs.
r is computed by feeding the sequence to the reward model.
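
Below is a minimal sketch of the KL part (not the exact OpenRLHF implementation): the names approx_kl, actor_lp, and ref_lp are made up for illustration, and the per-token KL estimate is simply the actor/ref log-prob difference, masked to response tokens.

```python
import torch

def approx_kl(action_log_probs: torch.Tensor,
              base_log_probs: torch.Tensor,
              action_mask: torch.Tensor) -> torch.Tensor:
    """Per-token KL estimate: log pi_actor(a|s) - log pi_ref(a|s), masked to response tokens."""
    log_ratio = action_log_probs - base_log_probs
    return log_ratio * action_mask  # zero out prompt / padding positions

# toy shapes: batch of 2 sequences, 5 response tokens each
actor_lp = torch.randn(2, 5)   # per-token log-probs from the actor's forward pass (assumed)
ref_lp = torch.randn(2, 5)     # per-token log-probs from the frozen ref model (assumed)
mask = torch.ones(2, 5)
kl = approx_kl(actor_lp, ref_lp, mask)   # shape (2, 5)
# r comes from a separate reward-model forward over the same sequences,
# typically one scalar per sequence.
```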

What is the difference between forward and generate? generate autoregressively samples new response tokens given a prompt (as in generate_samples above), while forward runs a single teacher-forced pass over an already-complete sequence to produce per-token logits/log-probs, which is what the kl and r computations above need.
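
For intuition, here is a generic HuggingFace causal-LM sketch (gpt2 is used purely for illustration; this is not OpenRLHF's Actor class): generate appends sampled tokens to the prompt, while a plain forward pass over the finished sequence yields the logits from which per-token log-probs are gathered.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "gpt2"  # any causal LM; chosen only as a small example
tok = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

prompt_ids = tok("The quick brown fox", return_tensors="pt").input_ids

# generate: autoregressive sampling, token by token -> new tokens appended to the prompt
with torch.no_grad():
    sequences = model.generate(prompt_ids, max_new_tokens=8, do_sample=True)

# forward: one teacher-forced pass over the full (prompt + response) sequence
# -> logits at every position, from which per-token log-probs are gathered
with torch.no_grad():
    logits = model(sequences).logits                       # (batch, seq_len, vocab)
log_probs = torch.log_softmax(logits[:, :-1], dim=-1)      # position t predicts token t+1
token_log_probs = log_probs.gather(-1, sequences[:, 1:].unsqueeze(-1)).squeeze(-1)
```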