Skip to content

Commit c4267b1

Browse files
domm99Filocava99
andauthored
Vmas integration (#31)
* feat: Setup module for the VMAS integration * feat: completed integration with vmas * feat: add rendering and cuda support * feat: add wandb logging and generalize logger for CTDESystem, Learner and DTDEAgent * feat: refactor classes names and deprecate VmasCTDEAgent class * chore: delete unnecessary class * feat: add CohesionAndCollision scenario without LiDAR * feat: boost performances of CohesionAndCollisionNoLidar scenario * feat: boost performances of CohesionAndCollisionNoLidar scenario * feat: add snapshot for tests * feat: add DSL for defining reward functions --------- Co-authored-by: Filippo Cavallari <[email protected]>
1 parent e11d10c commit c4267b1

27 files changed

+1461
-12
lines changed

138-2024-01-16-18-24-46-agent-0

24 KB
Binary file not shown.

scarlib-core/src/main/scala/it/unibo/scarlib/core/model/AutodiffDevice.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,4 +18,5 @@ object AutodiffDevice {
1818
def apply() =
1919
deepLearningLib()
2020
.device(if (deepLearningLib().cuda.is_available().as[Boolean]) "cuda" else "cpu")
21+
// .device("cpu")
2122
}

scarlib-core/src/main/scala/it/unibo/scarlib/core/model/DeepQLearner.scala

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,13 @@
1010
package it.unibo.scarlib.core.model
1111

1212
import it.unibo.scarlib.core.neuralnetwork.{NeuralNetworkEncoding, SimpleSequentialDQN, TorchSupport}
13-
import it.unibo.scarlib.core.util.TorchLiveLogger
13+
import it.unibo.scarlib.core.util.{Logger, TorchLiveLogger}
1414
import me.shadaj.scalapy.py
1515
import me.shadaj.scalapy.py.{PyQuote, SeqConverters}
1616

1717
import java.text.SimpleDateFormat
1818
import java.util.Date
19+
import scala.reflect.io.{File, Path}
1920
import scala.util.Random
2021

2122
/** The DQN learning algorithm
@@ -27,7 +28,8 @@ import scala.util.Random
2728
class DeepQLearner(
2829
memory: ReplayBuffer[State, Action],
2930
actionSpace: Seq[Action],
30-
learningConfiguration: LearningConfiguration
31+
learningConfiguration: LearningConfiguration,
32+
logger: Logger
3133
)(implicit encoding: NeuralNetworkEncoding[State]) extends Learner {
3234

3335
private val random = learningConfiguration.random
@@ -71,7 +73,7 @@ class DeepQLearner(
7173
val expectedValue = (nextStateValues * gamma) + rewards
7274
val criterion = TorchSupport.neuralNetworkModule().SmoothL1Loss()
7375
val loss = criterion(stateActionValue, expectedValue.unsqueeze(1))
74-
TorchLiveLogger.logScalar("Loss", loss.item().as[Double], updates)
76+
logger.logScalar("Loss", loss.item().as[Double], updates)
7577
optimizer.zero_grad()
7678
loss.backward()
7779
it.unibo.scarlib.core.neuralnetwork.TorchSupport
@@ -93,9 +95,14 @@ class DeepQLearner(
9395
.deepLearningLib()
9496
.save(
9597
targetNetwork.state_dict(),
96-
s"${learningConfiguration.snapshotPath}-$episode-$timeMark-agent-$agentId"
98+
s"${learningConfiguration.snapshotPath}${File.separator}$episode-$timeMark-agent-$agentId"
9799
)
98100
}
101+
102+
override def loadSnapshot(path: String): Unit = {
103+
targetNetwork.load_state_dict(TorchSupport.deepLearningLib().load(path, map_location = AutodiffDevice()))
104+
policyNetwork.load_state_dict(TorchSupport.deepLearningLib().load(path, map_location = AutodiffDevice()))
105+
}
99106
}
100107

101108
object DeepQLearner {

scarlib-core/src/main/scala/it/unibo/scarlib/core/model/Learner.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,4 +23,5 @@ trait Learner {
2323
/** Takes a snapshot of the current policy */
2424
def snapshot(episode: Int, agentId: Int): Unit
2525

26+
def loadSnapshot(path: String): Unit
2627
}

scarlib-core/src/main/scala/it/unibo/scarlib/core/neuralnetwork/DeepLearningSupport.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,5 +18,7 @@ trait DeepLearningSupport[M]{
1818

1919
def logger(): M
2020

21+
def arrayModule: M
22+
2123
}
2224

scarlib-core/src/main/scala/it/unibo/scarlib/core/neuralnetwork/TorchSupport.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,4 +20,6 @@ object TorchSupport extends DeepLearningSupport[py.Module] {
2020
override def optimizerModule(): py.Module = py.module("torch.optim")
2121

2222
override def logger(): py.Module = py.module("torch.utils.tensorboard")
23+
24+
override def arrayModule: py.Module = py.module("numpy")
2325
}

scarlib-core/src/main/scala/it/unibo/scarlib/core/system/CTDESystem.scala

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,8 @@
99

1010
package it.unibo.scarlib.core.system
1111

12-
import it.unibo.scarlib.core.model._
12+
import it.unibo.scarlib.core.model.{Action, Decay, DeepQLearner, Environment, LearningConfiguration, ReplayBuffer, State}
13+
import it.unibo.scarlib.core.util.{Logger, TorchLiveLogger}
1314
import it.unibo.scarlib.core.neuralnetwork.{NeuralNetworkEncoding, NeuralNetworkSnapshot}
1415

1516
import scala.annotation.tailrec
@@ -30,12 +31,16 @@ class CTDESystem(
3031
environment: Environment,
3132
dataset: ReplayBuffer[State, Action],
3233
actionSpace: Seq[Action],
33-
learningConfiguration: LearningConfiguration
34+
learningConfiguration: LearningConfiguration,
35+
logger: Logger = TorchLiveLogger
3436
)(implicit context: ExecutionContext, encoding: NeuralNetworkEncoding[State]) {
3537

3638
private val epsilon: Decay[Double] = learningConfiguration.epsilon
37-
private val learner =
38-
new DeepQLearner(dataset, actionSpace, learningConfiguration)
39+
40+
private val learner = new DeepQLearner(dataset, actionSpace, learningConfiguration, logger)
41+
42+
43+
3944

4045
/** Starts the learning process
4146
*
@@ -47,6 +52,7 @@ class CTDESystem(
4752
@tailrec
4853
def singleEpisode(time: Int): Unit =
4954
if (time > 0) {
55+
println("Time: " + time)
5056
agents.foreach(_.notifyNewPolicy(learner.behavioural))
5157
Await.ready(Future.sequence(agents.map(_.step())), scala.concurrent.duration.Duration.Inf)
5258
environment.log()
@@ -65,6 +71,11 @@ class CTDESystem(
6571

6672
}
6773

74+
final def learn(episodes: Int, episodeLength: Int, snapshot: String): Unit = {
75+
learner.loadSnapshot(snapshot)
76+
learn(episodes, episodeLength)
77+
}
78+
6879
/** Starts the testing process
6980
*
7081
* @param episodeLength the length of the episode
@@ -79,6 +90,7 @@ class CTDESystem(
7990

8091
@tailrec
8192
def episode(time: Int): Unit = {
93+
println(time)
8294
if (time > 0) {
8395
Await.ready(Future.sequence(agents.map(_.step())), scala.concurrent.duration.Duration.Inf)
8496
episode(time - 1)

scarlib-core/src/main/scala/it/unibo/scarlib/core/system/DTDEAgent.scala

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,9 @@
99

1010
package it.unibo.scarlib.core.system
1111

12-
import it.unibo.scarlib.core.model._
12+
import it.unibo.scarlib.core.model.{Action, Agent, AgentMode, Decay, DeepQLearner, Environment, Experience, LearningConfiguration, ReplayBuffer, State}
1313
import it.unibo.scarlib.core.neuralnetwork.{NeuralNetworkEncoding, NeuralNetworkSnapshot}
14+
import it.unibo.scarlib.core.util.{Logger, TorchLiveLogger}
1415

1516
import scala.reflect.io.File
1617
import scala.concurrent.ExecutionContext.Implicits.global
@@ -31,12 +32,13 @@ class DTDEAgent(
3132
actionSpace: Seq[Action],
3233
datasetSize: Int,
3334
agentMode: AgentMode = AgentMode.Training,
34-
learningConfiguration: LearningConfiguration
35+
learningConfiguration: LearningConfiguration,
36+
logger: Logger = TorchLiveLogger
3537
)(implicit encoding: NeuralNetworkEncoding[State]) extends Agent {
3638

3739
private val dataset: ReplayBuffer[State, Action] = ReplayBuffer[State, Action](datasetSize)
3840
private val epsilon: Decay[Double] = learningConfiguration.epsilon
39-
private val learner = new DeepQLearner(dataset, actionSpace, learningConfiguration)
41+
private val learner = new DeepQLearner(dataset, actionSpace, learningConfiguration, logger)
4042
private var testPolicy: State => Action = _
4143

4244
/** A single interaction of the agent with the environment */
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
package it.unibo.scarlib.core.util
2+
3+
import me.shadaj.scalapy.py
4+
5+
trait Logger {
6+
def logScalar(tag: String, value: Double, tick: Int): Unit
7+
8+
def logAny(tag: String, value: py.Dynamic, tick: Int): Unit
9+
}

scarlib-core/src/main/scala/it/unibo/scarlib/core/util/TorchLiveLogger.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,12 @@ package it.unibo.scarlib.core.util
1212
import it.unibo.scarlib.core.neuralnetwork.TorchSupport
1313
import me.shadaj.scalapy.py
1414

15-
object TorchLiveLogger {
15+
object TorchLiveLogger extends Logger{
1616
private val writer = TorchSupport.logger().SummaryWriter()
1717

1818
def logScalar(tag: String, value: Double, tick: Int): Unit = writer.add_scalar(tag, value, tick)
1919

2020
def logAny(tag: String, value: py.Dynamic, tick: Int): Unit = writer.add_scalar(tag, value, tick)
2121
}
2222

23+

0 commit comments

Comments
 (0)