The actor update happens in verl/trainer/ppo/ray_trainer.py:

if self.config.trainer.critic_warmup <= self.global_steps:
    # update actor
    with _timer("update_actor", timing_raw):
        real_batch.meta_info["multi_turn"] = self.config.actor_rollout_ref.rollout.multi_turn.enable
        actor_output = self.actor_rollout_wg.update_actor(real_batch) # the actual update happens here
    actor_output_metrics = reduce_metrics(actor_output.meta_info["metrics"])
    metrics.update(actor_output_metrics)

Here self.actor_rollout_wg is of type ActorRolloutRefWorker, and the corresponding update_actor function is defined in verl/workers/fsdp_workers.py:

def update_actor(self, data: DataProto):
    # Support all hardwares
    data = data.to(torch.cuda.current_device())
    assert self._is_actor
 
    if self._is_offload_param:
        load_fsdp_model_to_gpu(self.actor_module_fsdp)
    if self._is_offload_optimizer:
        load_fsdp_optimizer(optimizer=self.actor_optimizer, device_id=torch.cuda.current_device())
 
    with self.ulysses_sharding_manager:
        data = self.ulysses_sharding_manager.preprocess_data(data=data)
        # perform training
        with Timer(name="update_policy", logger=None) as timer:
            metrics = self.actor.update_policy(data=data) # the actual update happens here
        delta_time = timer.last
        global_num_tokens = data.meta_info["global_token_num"]
        estimated_flops, promised_flops = self.flops_counter.estimate_flops(global_num_tokens, delta_time)
        metrics["perf/mfu/actor"] = estimated_flops * self.config.actor.ppo_epochs / promised_flops / self.world_size
        metrics["perf/max_memory_allocated_gb"] = torch.cuda.max_memory_allocated() / (1024**3)
        metrics["perf/max_memory_reserved_gb"] = torch.cuda.max_memory_reserved() / (1024**3)
        metrics["perf/cpu_memory_used_gb"] = psutil.virtual_memory().used / (1024**3)
 
        self.actor_lr_scheduler.step()
        lr = self.actor_lr_scheduler.get_last_lr()[0]
        metrics["actor/lr"] = lr
 
        # TODO: here, we should return all metrics
        output = DataProto(meta_info={"metrics": metrics})
        output = self.ulysses_sharding_manager.postprocess_data(data=output)
        output = output.to("cpu")
 
    if self._is_offload_param:
        offload_fsdp_model_to_cpu(self.actor_module_fsdp)
        log_gpu_memory_usage("After offload actor model during update_actor", logger=logger)
    if self._is_offload_optimizer:
        offload_fsdp_optimizer(optimizer=self.actor_optimizer)
        log_gpu_memory_usage("After offload actor optimizer during update_actor", logger=logger)
 
    return output
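
The perf/mfu/actor line turns the measured wall-clock time of update_policy into a model-FLOPs-utilization number: estimated FLOPs achieved per second, divided by the hardware's promised FLOPs and the world size. A rough sketch of the arithmetic with made-up numbers (the "6 FLOPs per parameter per token" rule of thumb and all constants below are assumptions for illustration, not the actual FlopsCounter logic):

# Minimal sketch of the MFU arithmetic behind perf/mfu/actor.
# The 6 * params * tokens rule of thumb and every number below are
# illustrative assumptions, not verl's FlopsCounter implementation.
def estimate_mfu(num_params, token_counts, delta_time_s,
                 ppo_epochs=1, world_size=8, promised_tflops_per_gpu=312.0):
    # achieved TFLOPs/s across all ranks for one forward+backward pass
    estimated_tflops = 6 * num_params * sum(token_counts) / delta_time_s / 1e12
    # mirrors: estimated_flops * ppo_epochs / promised_flops / world_size
    return estimated_tflops * ppo_epochs / promised_tflops_per_gpu / world_size

# e.g. a 7B model, 256 sequences of 4096 tokens, 30 s per update_policy call
print(estimate_mfu(7e9, [4096] * 256, 30.0))  # ~0.59, i.e. roughly 59% MFU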

Here self.actor is defined as:

self.actor = DataParallelPPOActor(config=self.config.actor, actor_module=self.actor_module_fsdp, actor_optimizer=self.actor_optimizer)

It is of type DataParallelPPOActor, defined in verl/workers/actor/dp_actor.py:

def update_policy(self, data: DataProto): # judging from the call chain, data here is the rollout batch after repeat, with advantages already computed; its size is train_batch_size * rollout_n
    # make sure we are in training mode
    self.actor_module.train()
 
    temperature = data.meta_info["temperature"]  # temperature must be in the data.meta_info to avoid silent error
    multi_turn = data.meta_info.get("multi_turn", False)
 
    select_keys = ["responses", "input_ids", "attention_mask", "position_ids", "old_log_probs", "advantages"]
    if multi_turn:
        select_keys.append("loss_mask")
    if self.config.use_kl_loss:
        select_keys.append("ref_log_prob")
    batch = data.select(batch_keys=select_keys).batch # a TensorDict
    has_multi_modal_inputs = "multi_modal_inputs" in data.non_tensor_batch.keys()
 
    # Split to make minibatch iterator for updating the actor
    # See PPO paper for details. https://arxiv.org/abs/1707.06347
    if has_multi_modal_inputs:
        num_mini_batches = data.batch.batch_size[0] // self.config.ppo_mini_batch_size
        non_tensor_select_keys = ["multi_modal_inputs"]
        dataloader = data.select(select_keys, non_tensor_select_keys).chunk(num_mini_batches)
    else:
        dataloader = batch.split(self.config.ppo_mini_batch_size) # split into mini batches; each optimizer update
                                                                  # uses ppo_mini_batch_size samples; type: List[TensorDictBase]
 
    metrics = {}
    for epoch in range(self.config.ppo_epochs): # usually a single epoch; configurable via actor_rollout_ref.actor.ppo_epochs
        for batch_idx, data in enumerate(dataloader): # batch_idx is not used here; print it if you want to inspect it
            # split batch into micro_batches
            mini_batch = data
            # set up the micro batches
            if has_multi_modal_inputs:
                self.gradient_accumulation = self.config.ppo_mini_batch_size // self.config.ppo_micro_batch_size_per_gpu
                num_micro_batches = mini_batch.batch.batch_size[0] // self.config.ppo_micro_batch_size_per_gpu
                micro_batches = data.select(select_keys, non_tensor_select_keys).chunk(num_micro_batches)
            elif self.config.use_dynamic_bsz: # a dynamic micro batch size can be enabled; see the note below
                max_token_len = self.config.ppo_max_token_len_per_gpu * self.ulysses_sequence_parallel_size
                micro_batches, _ = rearrange_micro_batches(batch=mini_batch, max_token_len=max_token_len)
            else:
                self.gradient_accumulation = self.config.ppo_mini_batch_size // self.config.ppo_micro_batch_size_per_gpu
                # split batch into micro_batches
                micro_batches = mini_batch.split(self.config.ppo_micro_batch_size_per_gpu)
 
            self.actor_optimizer.zero_grad()
 
            for data in micro_batches: # gradient accumulation happens here; note micro batches are not the same as the mini batch
                # Support all hardwares
                if isinstance(data, DataProto):
                    data = {**data.batch.to(torch.cuda.current_device()), **data.non_tensor_batch}
                else:
                    data = data.to(torch.cuda.current_device())  # actor device is cpu when using offload
                # everything up to the entropy line below just prepares the data for the computation
                responses = data["responses"]
                response_length = responses.size(1)
                attention_mask = data["attention_mask"]
                if multi_turn:
                    response_mask = data["loss_mask"][:, -response_length:]
                else:
                    response_mask = attention_mask[:, -response_length:]
 
                old_log_prob = data["old_log_probs"]
                advantages = data["advantages"]
 
                clip_ratio = self.config.clip_ratio
                clip_ratio_low = self.config.clip_ratio_low if self.config.clip_ratio_low is not None else clip_ratio
                clip_ratio_high = self.config.clip_ratio_high if self.config.clip_ratio_high is not None else clip_ratio
                clip_ratio_c = self.config.get("clip_ratio_c", 3.0)
                entropy_coeff = self.config.entropy_coeff
                loss_agg_mode = self.config.loss_agg_mode
 
                # all return: (bsz, response_length)
                calculate_entropy = False
                if entropy_coeff != 0:
                    calculate_entropy = True
                entropy, log_prob = self._forward_micro_batch(micro_batch=data, temperature=temperature, calculate_entropy=calculate_entropy)
 
                # compute the policy gradient loss; also returns the pg_clipfrac, ppo_kl and pg_clipfrac_lower metrics
                pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower = compute_policy_loss(
                    old_log_prob=old_log_prob,
                    log_prob=log_prob,
                    advantages=advantages,
                    response_mask=response_mask,
                    cliprange=clip_ratio,
                    cliprange_low=clip_ratio_low,
                    cliprange_high=clip_ratio_high,
                    clip_ratio_c=clip_ratio_c,
                    loss_agg_mode=loss_agg_mode,
                )
 
                # if entropy_coeff is enabled, subtract the corresponding entropy term from pg_loss
                if entropy_coeff != 0:
                    entropy_loss = agg_loss(loss_mat=entropy, loss_mask=response_mask, loss_agg_mode=loss_agg_mode)
 
                    # compute policy loss
                    policy_loss = pg_loss - entropy_loss * entropy_coeff
                else:
                    policy_loss = pg_loss
 
                if self.config.use_kl_loss:
                    ref_log_prob = data["ref_log_prob"]
                    # compute kl loss
                    kld = kl_penalty(logprob=log_prob, ref_logprob=ref_log_prob, kl_penalty=self.config.kl_loss_type)
                    kl_loss = agg_loss(loss_mat=kld, loss_mask=response_mask, loss_agg_mode=self.config.loss_agg_mode)
 
                    policy_loss = policy_loss + kl_loss * self.config.kl_loss_coef
                    metrics["actor/kl_loss"] = kl_loss.detach().item()
                    metrics["actor/kl_coef"] = self.config.kl_loss_coef
 
                # compute the final loss and accumulate gradients; divide by the appropriate factor so the loss is averaged over all samples in the mini batch
                if self.config.use_dynamic_bsz:
                    # relative to the dynamic bsz
                    loss = policy_loss * (len(data) / self.config.ppo_mini_batch_size)
                else:
                    loss = policy_loss / self.gradient_accumulation
                loss.backward()
 
                data = {
                    "actor/pg_loss": pg_loss.detach().item(),
                    "actor/pg_clipfrac": pg_clipfrac.detach().item(),
                    "actor/ppo_kl": ppo_kl.detach().item(),
                    "actor/pg_clipfrac_lower": pg_clipfrac_lower.detach().item(),
                }
                append_to_dict(metrics, data)
 
            # gradients for one mini batch have been accumulated; take one optimizer step
            grad_norm = self._optimizer_step()
            data = {"actor/grad_norm": grad_norm.detach().item()}
            append_to_dict(metrics, data)
    self.actor_optimizer.zero_grad()
    return metrics
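
To make the batch hierarchy concrete: the incoming batch of train_batch_size * rollout_n samples is split into mini batches of ppo_mini_batch_size, one optimizer step is taken per mini batch, and each mini batch is further split into micro batches of ppo_micro_batch_size_per_gpu over which gradients are accumulated. A toy walk-through (all numbers are made-up assumptions):

# Toy illustration of the batch hierarchy in update_policy; all numbers are assumptions.
train_batch_size, rollout_n = 128, 4              # prompts per step, rollouts per prompt
ppo_mini_batch_size = 64                          # one optimizer step per mini batch
ppo_micro_batch_size_per_gpu = 8                  # one forward/backward per micro batch

total_samples = train_batch_size * rollout_n                                   # 512
num_mini_batches = total_samples // ppo_mini_batch_size                        # 8 optimizer steps
gradient_accumulation = ppo_mini_batch_size // ppo_micro_batch_size_per_gpu    # 8 micro batches per step

# Each micro batch contributes policy_loss / gradient_accumulation to the summed
# gradient, so one optimizer step sees the mean loss over its ppo_mini_batch_size
# samples; with use_dynamic_bsz the factor len(data) / ppo_mini_batch_size plays
# the same role for variable-sized micro batches.
print(total_samples, num_mini_batches, gradient_accumulation)  # 512 8 8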

The official documentation describes use_dynamic_bsz as follows (Performance Tuning Guide — verl documentation):

Dynamic batch size is a technique that allows the model to process similar number of tokens in a single forward pass (with different actual batch sizes). This can significantly improve the training efficiency and reduce the memory usage.

See the official verl documentation for more detailed tuning guidance.
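
The effect of rearrange_micro_batches is therefore to group sequences by a token budget (ppo_max_token_len_per_gpu, scaled by the Ulysses sequence-parallel size) rather than by a fixed sequence count. A minimal greedy-packing sketch of that idea, assuming a simple first-fit strategy; this is not verl's actual implementation:

# Minimal sketch of token-balanced micro-batching (an assumption about the idea
# behind rearrange_micro_batches, not its actual implementation).
def pack_by_token_budget(seq_lens, max_token_len):
    micro_batches, current, current_tokens = [], [], 0
    for idx, n_tokens in enumerate(seq_lens):
        if current and current_tokens + n_tokens > max_token_len:
            micro_batches.append(current)
            current, current_tokens = [], 0
        current.append(idx)
        current_tokens += n_tokens
    if current:
        micro_batches.append(current)
    return micro_batches  # lists of sample indices, each under the token budget

print(pack_by_token_budget([1200, 300, 800, 2500, 400, 700], max_token_len=3000))
# [[0, 1, 2], [3, 4], [5]] -- micro batches of different sizes but similar token counts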

Next, drill down into the compute_policy_loss function in verl/trainer/ppo/core_algos.py:

def compute_policy_loss(
    old_log_prob,
    log_prob,
    advantages,
    response_mask,
    cliprange=None,
    cliprange_low=None,
    cliprange_high=None,
    clip_ratio_c=3.0,
    loss_agg_mode="token-mean",
):
    """Adapted from https://github.com/huggingface/trl/blob/main/trl/trainer/ppo_trainer.py#L1122
    Args:
        old_log_prob: `(torch.Tensor)`
            shape: (bs, response_length)
        log_prob: `(torch.Tensor)`
            shape: (bs, response_length)
        advantages: `(torch.Tensor)`
            shape: (bs, response_length)
        response_mask: `(torch.Tensor)`
            shape: (bs, response_length)
        cliprange: (float)
            The clip range used in PPO. See https://arxiv.org/abs/1707.06347
        cliprange_low: (float)
            The lower clip range used in PPO.
        cliprange_high: (float)
            The higher clip range used in PPO.
        clip_ratio_c: (float) default: 3.0
            The lower bound of the ratio for dual-clip PPO, See https://arxiv.org/pdf/1912.09729
        loss_agg_mode: (str) choices: "token-mean" /
                                      "seq-mean-token-sum" /
                                      "seq-mean-token-mean" /
                                      "seq-mean-token-sum-norm" /
            "token-mean" is the default behavior
 
    Returns:
        pg_loss: `a scalar torch.Tensor`
            policy gradient loss computed via PPO
        pg_clipfrac: (float)
            the fraction of policy gradient loss being clipped
        ppo_kl: (float)
            the estimated KL divergence between the latest updating policy and the old sampling policy
        pg_clipfrac_lower: (float)
            the fraction of policy gradient loss being clipped when the advantage is negative
    """
    assert clip_ratio_c > 1.0, "The lower bound of the clip_ratio_c for dual-clip PPO should be greater than 1.0," + f" but get the value: {clip_ratio_c}."
 
    negative_approx_kl = log_prob - old_log_prob
    ratio = torch.exp(negative_approx_kl)
    ppo_kl = verl_F.masked_mean(-negative_approx_kl, response_mask)
 
    pg_losses1 = -advantages * ratio
    if cliprange_low is None:
        cliprange_low = cliprange
    if cliprange_high is None:
        cliprange_high = cliprange
    pg_losses2 = -advantages * torch.clamp(ratio, 1 - cliprange_low, 1 + cliprange_high)  # - clip(ratio, 1-cliprange, 1+cliprange) * A
    clip_pg_losses1 = torch.maximum(pg_losses1, pg_losses2)  # max(-ratio * A, -clip(ratio, 1-cliprange, 1+cliprange) * A)
    pg_clipfrac = verl_F.masked_mean(torch.gt(pg_losses2, pg_losses1).float(), response_mask)
 
    pg_losses3 = -advantages * clip_ratio_c
    clip_pg_losses2 = torch.min(pg_losses3, clip_pg_losses1)
    pg_clipfrac_lower = verl_F.masked_mean(torch.gt(clip_pg_losses1, pg_losses3) * (advantages < 0).float(), response_mask)
 
    pg_losses = torch.where(advantages < 0, clip_pg_losses2, clip_pg_losses1)
    pg_loss = agg_loss(loss_mat=pg_losses, loss_mask=response_mask, loss_agg_mode=loss_agg_mode)
 
    return pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower
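
To see the clipping and the dual-clip lower bound in isolation, the following standalone snippet reproduces the core formula on a few hand-picked values (it does not call into verl):

import torch

# Standalone illustration of the dual-clip PPO surrogate used above.
advantages = torch.tensor([ 1.0, -1.0, -1.0])
ratio      = torch.tensor([ 2.0,  0.5,  5.0])   # new_prob / old_prob
clip_low, clip_high, clip_ratio_c = 0.2, 0.2, 3.0

pg_losses1 = -advantages * ratio
pg_losses2 = -advantages * torch.clamp(ratio, 1 - clip_low, 1 + clip_high)
clip_pg_losses1 = torch.maximum(pg_losses1, pg_losses2)       # standard PPO clipping
pg_losses3 = -advantages * clip_ratio_c                        # dual-clip lower bound
clip_pg_losses2 = torch.min(pg_losses3, clip_pg_losses1)
pg_losses = torch.where(advantages < 0, clip_pg_losses2, clip_pg_losses1)
print(pg_losses)  # tensor([-1.2000,  0.8000,  3.0000])

For the third sample the ordinary clipped objective would give 5.0, but because the advantage is negative dual-clip caps the loss at clip_ratio_c * |A| = 3.0; this is exactly the case counted by pg_clipfrac_lower.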

The verl_F.masked_mean helper is defined as follows (in verl/utils/torch_functional.py):

def masked_mean(values, mask, axis=None):
    """Compute mean of tensor with a masked values."""
    return (values * mask).sum(axis=axis) / (mask.sum(axis=axis) + 1e-8)
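
A quick check of masked_mean on a padded toy batch; positions where the mask is zero are excluded from both the numerator and the denominator, and the 1e-8 term only guards against an all-zero mask:

import torch

values = torch.tensor([[1.0, 2.0, 3.0],
                       [4.0, 0.0, 0.0]])
mask   = torch.tensor([[1.0, 1.0, 1.0],
                       [1.0, 0.0, 0.0]])  # second response has a single valid token
print((values * mask).sum() / (mask.sum() + 1e-8))  # (1+2+3+4) / 4 = 2.5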

agg_loss (in verl/trainer/ppo/core_algos.py):

def agg_loss(loss_mat: torch.Tensor, loss_mask: torch.Tensor, loss_agg_mode: str):
    """
    Aggregate the loss matrix into a scalar.
    Args:
        loss_mat: `(torch.Tensor)`
            shape: (bs, response_length)
        loss_mask: `(torch.Tensor)`
            shape: (bs, response_length)
        loss_agg_mode: (str) choices: "token-mean" /
                                      "seq-mean-token-sum" /
                                      "seq-mean-token-mean" /
                                      "seq-mean-token-sum-norm" /
            "token-mean" is the default behavior
    Returns:
        loss: `a scalar torch.Tensor`
            aggregated loss
    """
    if loss_agg_mode == "token-mean":
        loss = verl_F.masked_mean(loss_mat, loss_mask)
    elif loss_agg_mode == "seq-mean-token-sum":
        seq_losses = torch.sum(loss_mat * loss_mask, dim=-1)  # token-sum
        loss = torch.mean(seq_losses)  # seq-mean
    elif loss_agg_mode == "seq-mean-token-mean":
        seq_losses = torch.sum(loss_mat * loss_mask, dim=-1) / torch.sum(loss_mask, dim=-1)  # token-mean
        loss = torch.mean(seq_losses)  # seq-mean
    elif loss_agg_mode == "seq-mean-token-sum-norm":
        seq_losses = torch.sum(loss_mat * loss_mask, dim=-1)
        loss = torch.sum(seq_losses) / loss_mask.shape[-1]  # The divisor
        # (loss_mask.shape[-1]) should ideally be constant
        # throughout training to well-replicate the DrGRPO paper.
        # TODO: Perhaps add user-defined normalizer argument to
        # agg_loss to ensure divisor stays constant throughout.
    else:
        raise ValueError(f"Invalid loss_agg_mode: {loss_agg_mode}")
 
    return loss

The default value of loss_agg_mode is token-mean.
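
The choice matters when response lengths differ: token-mean weights every token equally (so long responses dominate the loss), while seq-mean-token-mean weights every sequence equally regardless of length. A standalone comparison on a toy loss matrix:

import torch

loss_mat  = torch.tensor([[1.0, 1.0, 1.0, 1.0],    # response with 4 valid tokens
                          [2.0, 0.0, 0.0, 0.0]])   # response with 1 valid token
loss_mask = torch.tensor([[1.0, 1.0, 1.0, 1.0],
                          [1.0, 0.0, 0.0, 0.0]])

token_mean = (loss_mat * loss_mask).sum() / loss_mask.sum()              # 6 / 5 = 1.2
seq_sums   = (loss_mat * loss_mask).sum(dim=-1)                          # [4., 2.]
seq_mean_token_sum  = seq_sums.mean()                                    # 3.0
seq_mean_token_mean = (seq_sums / loss_mask.sum(dim=-1)).mean()          # (1 + 2) / 2 = 1.5
print(token_mean, seq_mean_token_sum, seq_mean_token_mean)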

Entropy is computed as follows:

def entropy_from_logits(logits: torch.Tensor):
    """Calculate entropy from logits."""
    pd = torch.nn.functional.softmax(logits, dim=-1)
    entropy = torch.logsumexp(logits, dim=-1) - torch.sum(pd * logits, dim=-1)
    return entropy
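
This is the standard -sum(p * log p) entropy rewritten via log p_i = logit_i - logsumexp(logits), which avoids taking the log of the softmax output directly; it can be sanity-checked against torch.distributions.Categorical:

import torch

logits = torch.randn(3, 8)
pd = torch.softmax(logits, dim=-1)
h_logits = torch.logsumexp(logits, dim=-1) - (pd * logits).sum(dim=-1)  # same formula as above
h_ref = torch.distributions.Categorical(logits=logits).entropy()
print(torch.allclose(h_logits, h_ref, atol=1e-5))  # True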

The KL loss is computed as follows:

def kl_penalty(logprob: torch.FloatTensor, ref_logprob: torch.FloatTensor, kl_penalty) -> torch.FloatTensor:
    """Compute KL divergence given logprob and ref_logprob. Optionally using straight through to bind k2 on other
    kl penalty compute method for unbiased KL gradient estimation.
    See more description in http://joschu.net/blog/kl-approx.html
 
    Args:
        logprob:
        ref_logprob:
 
    Returns:
        kl_estimate
    """
    forward_score = kl_penalty_forward(logprob, ref_logprob, kl_penalty)
    if not kl_penalty.endswith("+") or kl_penalty in ("mse", "k2"):
        return forward_score
 
    """
    The expectation of the k1 and k3 estimators is the expected value of the KL, but their expected gradient is not
    the expected gradient of the KL. On the other hand, the k2 estimator gives the right gradient estimator,
    so we use a straight-through trick here if the kl_penalty method ends with '+', e.g., k3+.
    """
    backward_score = 0.5 * (logprob - ref_logprob).square()
 
    return backward_score - backward_score.detach() + forward_score.detach()
 
def kl_penalty_forward(logprob: torch.FloatTensor, ref_logprob: torch.FloatTensor, kl_penalty) -> torch.FloatTensor:
    """Compute KL divergence given logprob and ref_logprob.
    Copied from https://github.com/huggingface/trl/blob/main/trl/trainer/ppo_trainer.py#L1104
    See more description in http://joschu.net/blog/kl-approx.html
 
    Args:
        logprob:
        ref_logprob:
 
    Returns:
        kl_estimate
    """
    if kl_penalty in ("kl", "k1"):
        return logprob - ref_logprob
 
    if kl_penalty == "abs":
        return (logprob - ref_logprob).abs()
 
    if kl_penalty in ("mse", "k2"):
        return 0.5 * (logprob - ref_logprob).square()
 
    # J. Schulman. Approximating kl divergence, 2020.
    # URL http://joschu.net/blog/kl-approx.html.
    if kl_penalty in ("low_var_kl", "k3"):
        kl = ref_logprob - logprob
        # For numerical stability
        kl = torch.clamp(kl, min=-20, max=20)
        ratio = torch.exp(kl)
        kld = (ratio - kl - 1).contiguous()
        return torch.clamp(kld, min=-10, max=10)
 
    if kl_penalty == "full":
        # so, here logprob and ref_logprob should contain the logits for every token in vocabulary
        raise NotImplementedError
 
    raise NotImplementedError
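
A quick numerical comparison of the k1, k2 and k3 estimators on a toy pair of categorical policies (both distributions are made-up for illustration). Samples are drawn from the current policy, so k1 and k3 are unbiased estimates of KL(pi || pi_ref), while k2 is a biased approximation that, per the docstring above, gives the right gradient:

import torch

torch.manual_seed(0)
p   = torch.distributions.Categorical(logits=torch.tensor([2.0, 1.0, 0.5, 0.0]))  # current policy
ref = torch.distributions.Categorical(logits=torch.tensor([1.0, 1.0, 1.0, 1.0]))  # reference policy

samples = p.sample((100_000,))
logprob, ref_logprob = p.log_prob(samples), ref.log_prob(samples)

k1 = logprob - ref_logprob                                          # "kl" / "k1"
k2 = 0.5 * (logprob - ref_logprob).square()                         # "mse" / "k2"
k3 = (ref_logprob - logprob).exp() - (ref_logprob - logprob) - 1    # "low_var_kl" / "k3"

exact = torch.distributions.kl_divergence(p, ref)
print(exact.item(), k1.mean().item(), k2.mean().item(), k3.mean().item())
# k1 and k3 average to the exact KL; k3 is non-negative per sample and has lower variance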