Skip to content

Commit 90846f6

Browse files
authored
Fix DiscreteSACExperimentBuilder not exposing with_actor_factory_default (#1250)
- [X] I have added the correct label(s) to this Pull Request or linked the relevant issue(s) - [X] I have provided a description of the changes in this Pull Request - [x] I have added documentation for my changes and have listed relevant changes in CHANGELOG.md - [x] If applicable, I have added tests to cover my changes. - [x] I have reformatted the code using `poe format` - [ ] I have checked style and types with `poe lint` and `poe type-check` - [ ] (Optional) I ran tests locally with `poe test` (or a subset of them with `poe test-reduced`) ,and they pass - [ ] (Optional) I have tested that documentation builds correctly with `poe doc-build` Fixes #1248 Additional Changes: * Add information on Windows-specific developer configuration * Fix an incorrect docstring
1 parent c4ae7cd commit 90846f6

File tree

3 files changed

+39
-4
lines changed

3 files changed

+39
-4
lines changed

CHANGELOG.md

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
# Changelog
22

3-
## Release 1.2.0
3+
## Unreleased
44

55
### Changes/Improvements
66

77
- trainer:
88
- Custom scoring now supported for selecting the best model. #1202
9+
- highlevel:
10+
- `DiscreteSACExperimentBuilder`: Expose method `with_actor_factory_default` #1248 #1250
911

1012
### Breaking Changes
1113

docs/04_contributing/04_contributing.rst

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,17 @@ to install all relevant requirements in editable mode you can simply call
1313
$ poetry install --with dev
1414
1515
16+
Platform-Specific Configuration
17+
-------------------------------
18+
19+
**Windows**:
20+
Since the repository contains symbolic links, make sure this is supported:
21+
22+
* Enable Windows Developer Mode to allow symbolic links to be created: Search Start Menu for "Developer Settings" and enable "Developer Mode"
23+
* Enable symbolic links for this repository: ``git config core.symlinks true``
24+
* Re-checkout the current git state: ``git checkout .``
25+
26+
1627
PEP8 Code Style Check and Formatting
1728
----------------------------------------
1829

tianshou/highlevel/experiment.py

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -801,6 +801,28 @@ def with_actor_factory_default(
801801
return super()._with_actor_factory_default(hidden_sizes, hidden_activation)
802802

803803

804+
class _BuilderMixinActorFactory_DiscreteOnly(_BuilderMixinActorFactory):
805+
"""Specialization of the actor mixin where only environments with discrete action spaces are supported."""
806+
807+
def __init__(self) -> None:
808+
super().__init__(ContinuousActorType.UNSUPPORTED)
809+
810+
def with_actor_factory_default(
811+
self,
812+
hidden_sizes: Sequence[int],
813+
hidden_activation: ModuleType = torch.nn.ReLU,
814+
) -> Self:
815+
"""Defines use of the default actor factory, allowing its parameters it to be customized.
816+
817+
The default actor factory uses an MLP-style architecture.
818+
819+
:param hidden_sizes: dimensions of hidden layers used by the network
820+
:param hidden_activation: the activation function to use for hidden layers
821+
:return: the builder
822+
"""
823+
return super()._with_actor_factory_default(hidden_sizes, hidden_activation)
824+
825+
804826
class _BuilderMixinCriticsFactory:
805827
def __init__(self, num_critics: int, actor_future_provider: ActorFutureProviderProtocol):
806828
self._actor_future_provider = actor_future_provider
@@ -959,7 +981,7 @@ def with_critic2_factory_default(
959981
return self
960982

961983
def with_critic2_factory_use_actor(self) -> Self:
962-
"""Makes the first critic reuse the actor's preprocessing network (parameter sharing)."""
984+
"""Makes the second critic reuse the actor's preprocessing network (parameter sharing)."""
963985
return self._with_critic_factory_use_actor(1)
964986

965987

@@ -1333,7 +1355,7 @@ def _create_agent_factory(self) -> AgentFactory:
13331355

13341356
class DiscreteSACExperimentBuilder(
13351357
ExperimentBuilder,
1336-
_BuilderMixinActorFactory,
1358+
_BuilderMixinActorFactory_DiscreteOnly,
13371359
_BuilderMixinDualCriticFactory,
13381360
):
13391361
def __init__(
@@ -1343,7 +1365,7 @@ def __init__(
13431365
sampling_config: SamplingConfig | None = None,
13441366
):
13451367
super().__init__(env_factory, experiment_config, sampling_config)
1346-
_BuilderMixinActorFactory.__init__(self, ContinuousActorType.UNSUPPORTED)
1368+
_BuilderMixinActorFactory_DiscreteOnly.__init__(self)
13471369
_BuilderMixinDualCriticFactory.__init__(self, self)
13481370
self._params: DiscreteSACParams = DiscreteSACParams()
13491371

0 commit comments

Comments
 (0)