Skip to content

Commit 387da58

Browse files
committed
feat: update subscription allocation during session update
1 parent 2cb0e2e commit 387da58

File tree

7 files changed

+115
-33
lines changed

7 files changed

+115
-33
lines changed

x/session/expected/keeper.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package expected
22

33
import (
4+
sdkmath "cosmossdk.io/math"
45
sdk "github.com/cosmos/cosmos-sdk/types"
56
authtypes "github.com/cosmos/cosmos-sdk/x/auth/types"
67

@@ -27,5 +28,6 @@ type NodeKeeper interface {
2728

2829
type SubscriptionKeeper interface {
2930
SessionInactivePreHook(ctx sdk.Context, id uint64) error
31+
SessionUpdatePreHook(ctx sdk.Context, id uint64, currBytes sdkmath.Int) error
3032
UpdateSessionMaxValues(ctx sdk.Context, session v3.Session) error
3133
}

x/session/keeper/alias.go

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package keeper
22

33
import (
4+
sdkmath "cosmossdk.io/math"
45
sdk "github.com/cosmos/cosmos-sdk/types"
56

67
"github.com/sentinel-official/hub/v12/x/session/types/v3"
@@ -33,6 +34,14 @@ func (k *Keeper) SessionInactivePreHook(ctx sdk.Context, id uint64) error {
3334
return nil
3435
}
3536

37+
func (k *Keeper) SessionUpdatePreHook(ctx sdk.Context, id uint64, currBytes sdkmath.Int) error {
38+
if err := k.subscription.SessionUpdatePreHook(ctx, id, currBytes); err != nil {
39+
return err
40+
}
41+
42+
return nil
43+
}
44+
3645
func (k *Keeper) UpdateMaxValues(ctx sdk.Context, session v3.Session) error {
3746
if err := k.node.UpdateSessionMaxValues(ctx, session); err != nil {
3847
return err

x/session/keeper/msg_handler.go

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,17 @@ func (k *Keeper) HandleMsgUpdateSession(ctx sdk.Context, msg *v3.MsgUpdateSessio
6666
return nil, types.NewErrorInvalidSessionStatus(session.GetID(), session.GetStatus())
6767
}
6868

69-
if k.ProofVerificationEnabled(ctx) {
69+
if msg.DownloadBytes.LT(session.GetDownloadBytes()) {
70+
return nil, types.NewErrorInvalidDownloadBytes(msg.DownloadBytes)
71+
}
72+
if msg.UploadBytes.LT(session.GetUploadBytes()) {
73+
return nil, types.NewErrorInvalidUploadBytes(msg.UploadBytes)
74+
}
75+
if msg.Duration < session.GetDuration() {
76+
return nil, types.NewErrorInvalidDuration(msg.Duration)
77+
}
78+
79+
if ok := k.ProofVerificationEnabled(ctx); ok {
7080
accAddr, err := sdk.AccAddressFromBech32(session.GetAccAddress())
7181
if err != nil {
7282
return nil, err
@@ -77,6 +87,10 @@ func (k *Keeper) HandleMsgUpdateSession(ctx sdk.Context, msg *v3.MsgUpdateSessio
7787
}
7888
}
7989

90+
if err := k.SessionUpdatePreHook(ctx, session.GetID(), msg.Bytes()); err != nil {
91+
return nil, err
92+
}
93+
8094
if session.GetStatus().Equal(v1base.StatusActive) {
8195
k.DeleteSessionForInactiveAt(ctx, session.GetInactiveAt(), session.GetID())
8296
}

x/session/types/errors.go

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,36 @@
11
package types
22

33
import (
4+
"time"
5+
46
sdkerrors "cosmossdk.io/errors"
7+
sdkmath "cosmossdk.io/math"
58

69
v1base "github.com/sentinel-official/hub/v12/types/v1"
710
)
811

912
var (
1013
ErrInvalidMessage = sdkerrors.Register(ModuleName, 101, "invalid message")
1114

12-
ErrInvalidSessionStatus = sdkerrors.Register(ModuleName, 201, "invalid session status")
13-
ErrInvalidSignature = sdkerrors.Register(ModuleName, 202, "invalid signature")
14-
ErrSessionNotFound = sdkerrors.Register(ModuleName, 203, "session not found")
15-
ErrUnauthorized = sdkerrors.Register(ModuleName, 204, "unauthorized")
15+
ErrInvalidDownloadBytes = sdkerrors.Register(ModuleName, 201, "invalid download bytes")
16+
ErrInvalidDuration = sdkerrors.Register(ModuleName, 202, "invalid duration")
17+
ErrInvalidSessionStatus = sdkerrors.Register(ModuleName, 203, "invalid session status")
18+
ErrInvalidSignature = sdkerrors.Register(ModuleName, 204, "invalid signature")
19+
ErrInvalidUploadBytes = sdkerrors.Register(ModuleName, 205, "invalid upload bytes")
20+
ErrSessionNotFound = sdkerrors.Register(ModuleName, 206, "session not found")
21+
ErrUnauthorized = sdkerrors.Register(ModuleName, 207, "unauthorized")
1622
)
1723

24+
// NewErrorInvalidDownloadBytes returns an error indicating that the download bytes are invalid.
25+
func NewErrorInvalidDownloadBytes(bytes sdkmath.Int) error {
26+
return sdkerrors.Wrapf(ErrInvalidDownloadBytes, "invalid download bytes %s", bytes)
27+
}
28+
29+
// NewErrorInvalidDuration returns an error indicating that the specified duration is invalid.
30+
func NewErrorInvalidDuration(duration time.Duration) error {
31+
return sdkerrors.Wrapf(ErrInvalidDuration, "invalid duration %d", duration)
32+
}
33+
1834
// NewErrorInvalidSessionStatus returns an error indicating that the provided status is invalid for the session.
1935
func NewErrorInvalidSessionStatus(id uint64, status v1base.Status) error {
2036
return sdkerrors.Wrapf(ErrInvalidSessionStatus, "invalid status %s for session %d", status, id)
@@ -25,6 +41,11 @@ func NewErrorInvalidSignature(signature []byte) error {
2541
return sdkerrors.Wrapf(ErrInvalidSignature, "invalid signature %X", signature)
2642
}
2743

44+
// NewErrorInvalidUploadBytes returns an error indicating that the upload bytes are invalid.
45+
func NewErrorInvalidUploadBytes(bytes sdkmath.Int) error {
46+
return sdkerrors.Wrapf(ErrInvalidUploadBytes, "invalid upload bytes %s", bytes)
47+
}
48+
2849
// NewErrorSessionNotFound returns an error indicating that the specified session does not exist.
2950
func NewErrorSessionNotFound(id uint64) error {
3051
return sdkerrors.Wrapf(ErrSessionNotFound, "session %d does not exist", id)

x/session/types/v3/msg.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,10 @@ func NewMsgUpdateSessionRequest(from base.NodeAddress, id uint64, downloadBytes,
5959
}
6060
}
6161

62+
func (m *MsgUpdateSessionRequest) Bytes() sdkmath.Int {
63+
return m.DownloadBytes.Add(m.UploadBytes)
64+
}
65+
6266
func (m *MsgUpdateSessionRequest) Proof() *Proof {
6367
return &Proof{
6468
ID: m.ID,

x/session/types/v3/session.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@ import (
1212
type Session interface {
1313
proto.Message
1414

15+
Bytes() sdkmath.Int
16+
1517
GetID() uint64
1618
GetAccAddress() string
1719
GetNodeAddress() string
@@ -39,6 +41,10 @@ type Session interface {
3941
SetStatusAt(v time.Time)
4042
}
4143

44+
func (m *BaseSession) Bytes() sdkmath.Int {
45+
return m.GetDownloadBytes().Add(m.GetUploadBytes())
46+
}
47+
4248
func (m *BaseSession) GetID() uint64 { return m.ID }
4349
func (m *BaseSession) GetAccAddress() string { return m.AccAddress }
4450
func (m *BaseSession) GetNodeAddress() string { return m.NodeAddress }

x/subscription/keeper/hooks.go

Lines changed: 54 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package keeper
22

33
import (
4+
sdkmath "cosmossdk.io/math"
45
sdk "github.com/cosmos/cosmos-sdk/types"
56

67
base "github.com/sentinel-official/hub/v12/types"
@@ -9,58 +10,94 @@ import (
910
"github.com/sentinel-official/hub/v12/x/subscription/types/v3"
1011
)
1112

12-
// SessionInactivePreHook handles the necessary operations when a session becomes inactive.
13+
// SessionInactivePreHook performs cleanup operations when a session transitions to an inactive state.
1314
func (k *Keeper) SessionInactivePreHook(ctx sdk.Context, id uint64) error {
1415
k.Logger(ctx).Info("Running session inactive pre-hook", "id", id)
1516

16-
// Retrieve the session by ID; return an error if not found.
17+
// Retrieve the session by ID; return an error if it doesn't exist.
1718
item, found := k.GetSession(ctx, id)
1819
if !found {
1920
return types.NewErrorSessionNotFound(id)
2021
}
2122

22-
// Assert the retrieved session to the v3.Session type; return nil if the assertion fails.
23+
// Ensure the session is of type v3.Session; do nothing if it's not.
2324
session, ok := item.(*v3.Session)
2425
if !ok {
2526
return nil
2627
}
2728

28-
// Ensure the session status is "InactivePending"; return an error if it has a different status.
29+
// Verify that the session's status is "InactivePending"; otherwise, return an error.
2930
if !session.Status.Equal(v1base.StatusInactivePending) {
3031
return types.NewErrorInvalidSessionStatus(session.ID, session.Status)
3132
}
3233

33-
// Retrieve the subscription associated with the session; return an error if not found.
34+
// Fetch the subscription associated with the session; return an error if it doesn't exist.
3435
subscription, found := k.GetSubscription(ctx, session.SubscriptionID)
3536
if !found {
3637
return types.NewErrorSubscriptionNotFound(session.SubscriptionID)
3738
}
3839

39-
// Convert the session's account address from Bech32 format.
40+
// Decode the session's account address from Bech32 format.
4041
accAddr, err := sdk.AccAddressFromBech32(session.AccAddress)
4142
if err != nil {
4243
return err
4344
}
4445

45-
// Retrieve the allocation for the subscription and account; return an error if not found.
46-
alloc, found := k.GetAllocation(ctx, subscription.ID, accAddr)
46+
// Decode the session's node address from Bech32 format.
47+
nodeAddr, err := base.NodeAddressFromBech32(session.NodeAddress)
48+
if err != nil {
49+
return err
50+
}
51+
52+
// Remove session references for allocation, node, plan, and subscription.
53+
k.DeleteSessionForAllocation(ctx, subscription.ID, accAddr, session.ID)
54+
k.DeleteSessionForPlanByNode(ctx, subscription.PlanID, nodeAddr, session.ID)
55+
k.DeleteSessionForSubscription(ctx, subscription.ID, session.ID)
56+
57+
return nil
58+
}
59+
60+
// SessionUpdatePreHook updates session and allocation details during a session update.
61+
func (k *Keeper) SessionUpdatePreHook(ctx sdk.Context, id uint64, currBytes sdkmath.Int) error {
62+
k.Logger(ctx).Info("Running session update pre-hook", "id", id)
63+
64+
// Retrieve the session by ID; return an error if it doesn't exist.
65+
item, found := k.GetSession(ctx, id)
4766
if !found {
48-
return types.NewErrorAllocationNotFound(subscription.ID, accAddr)
67+
return types.NewErrorSessionNotFound(id)
68+
}
69+
70+
// Ensure the session is of type v3.Session; do nothing if it's not.
71+
session, ok := item.(*v3.Session)
72+
if !ok {
73+
return nil
4974
}
5075

51-
// Calculate the total utilised bytes as the sum of download and upload bytes.
52-
utilisedBytes := session.DownloadBytes.Add(session.UploadBytes)
76+
// Ensure the session is not in the "Inactive" state; return an error if it is.
77+
if session.Status.Equal(v1base.StatusInactive) {
78+
return types.NewErrorInvalidSessionStatus(session.ID, session.Status)
79+
}
80+
81+
// Decode the session's account address from Bech32 format.
82+
accAddr, err := sdk.AccAddressFromBech32(session.AccAddress)
83+
if err != nil {
84+
return err
85+
}
5386

54-
// Update the utilised bytes in the allocation; cap it at the granted bytes if it exceeds the limit.
55-
alloc.UtilisedBytes = alloc.UtilisedBytes.Add(utilisedBytes)
56-
if alloc.UtilisedBytes.GT(alloc.GrantedBytes) {
57-
alloc.UtilisedBytes = alloc.GrantedBytes
87+
// Fetch the allocation for the subscription and account; return an error if it doesn't exist.
88+
alloc, found := k.GetAllocation(ctx, session.SubscriptionID, accAddr)
89+
if !found {
90+
return types.NewErrorAllocationNotFound(session.SubscriptionID, accAddr)
5891
}
5992

60-
// Save the updated allocation in the store.
93+
// Update allocation's utilised bytes based on the difference between current and previous session bytes.
94+
diffBytes := currBytes.Sub(session.Bytes())
95+
alloc.UtilisedBytes = alloc.UtilisedBytes.Add(diffBytes)
96+
97+
// Store the updated allocation in the keeper.
6198
k.SetAllocation(ctx, alloc)
6299

63-
// Emit an event to log the allocation update.
100+
// Emit an event logging the updated allocation details.
64101
ctx.EventManager().EmitTypedEvent(
65102
&v3.EventAllocate{
66103
ID: alloc.ID,
@@ -70,16 +107,5 @@ func (k *Keeper) SessionInactivePreHook(ctx sdk.Context, id uint64) error {
70107
},
71108
)
72109

73-
// Convert the session's node address from Bech32 format.
74-
nodeAddr, err := base.NodeAddressFromBech32(session.NodeAddress)
75-
if err != nil {
76-
return err
77-
}
78-
79-
// Delete the session records associated with allocation, node, plan, and subscription from the store.
80-
k.DeleteSessionForAllocation(ctx, subscription.ID, accAddr, session.ID)
81-
k.DeleteSessionForPlanByNode(ctx, subscription.PlanID, nodeAddr, session.ID)
82-
k.DeleteSessionForSubscription(ctx, subscription.ID, session.ID)
83-
84110
return nil
85111
}

0 commit comments

Comments
 (0)