# Removed all C dependencies and library loading.
# Introduces a GRPO class that uses PyTorch's autograd and vectorized operations.
import warnings
from typing import Any

import torch
import torch.nn as nn
import torch.optim as optim
from torch import Tensor
class GRPO:
    """Pair a PyTorch model with an optimizer and run single training steps.

    Gradients are computed by PyTorch autograd; this replaces the previous
    C-based implementation. The loss implemented by this class is a mean
    squared error plus a constant ``epsilon`` offset (see ``compute_loss``).

    Attributes are set directly from the constructor arguments:
    ``model``, ``optimizer`` and ``epsilon``.
    """

    def __init__(self, model: nn.Module, optimizer: optim.Optimizer, epsilon: float = 1e-8) -> None:
        """Store the model, optimizer and epsilon offset.

        Args:
            model: The PyTorch module called by ``forward``.
            optimizer: The optimizer stepped by ``training_step``. It is the
                caller's responsibility to construct it over ``model``'s
                parameters; this class does not check that.
            epsilon: Constant added to every squared error in
                ``compute_loss``.
        """
        self.model = model
        self.optimizer = optimizer
        self.epsilon = epsilon

    def forward(self, x: Tensor) -> Tensor:
        """Return the wrapped model's output for ``x``.

        Args:
            x: Input tensor, passed to the model unchanged.

        Returns:
            Whatever ``self.model(x)`` returns.
        """
        return self.model(x)

    def compute_loss(self, predictions: Tensor, targets: Tensor) -> Tensor:
        """Return ``mean((predictions - targets) ** 2 + epsilon)``.

        Because ``epsilon`` is added inside the mean, the result equals the
        mean squared error plus ``epsilon``; the offset is constant, so it
        does not change the gradients with respect to ``predictions``.

        Args:
            predictions: The predicted outputs.
            targets: The ground-truth values.

        Returns:
            A scalar tensor holding the loss.

        Warns:
            UserWarning: If the two shapes differ. The subtraction still
                broadcasts exactly as before (or raises if the shapes are
                incompatible), but e.g. ``(N, 1)`` against ``(N,)`` averages
                over an ``(N, N)`` matrix, which is rarely what is wanted.
        """
        if predictions.shape != targets.shape:
            # Warn rather than raise so callers that rely on broadcasting
            # keep working; the returned value is unchanged.
            warnings.warn(
                f"compute_loss: predictions shape {tuple(predictions.shape)} "
                f"differs from targets shape {tuple(targets.shape)}; the "
                "difference will be broadcast, which is usually unintended.",
                stacklevel=2,
            )
        squared_error = (predictions - targets) ** 2
        return torch.mean(squared_error + self.epsilon)

    def training_step(self, x: Tensor, targets: Tensor) -> float:
        """Run one optimisation step and return the loss as a Python float.

        The step clears existing gradients, runs the forward pass, computes
        the loss, backpropagates, and applies one optimizer update.

        Args:
            x: Input batch.
            targets: Target values for the batch.

        Returns:
            The loss value computed before the parameter update.
        """
        # Clear gradients left over from any previous backward pass.
        self.optimizer.zero_grad()
        predictions = self.forward(x)
        loss = self.compute_loss(predictions, targets)
        loss.backward()
        self.optimizer.step()
        return loss.item()
# NOTE: trailing page-scrape residue ("0 commit comments") was here; it is not part of the module.