-
Notifications
You must be signed in to change notification settings - Fork 3
Expand file tree
/
Copy pathmodule_list.h
More file actions
62 lines (50 loc) · 1.54 KB
/
module_list.h
File metadata and controls
62 lines (50 loc) · 1.54 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
// module_list.h
// List of arbitrary type modules
#ifndef TINYTENSOR_NN_MODULE_LIST_H_
#define TINYTENSOR_NN_MODULE_LIST_H_
#include <tt/device.h>
#include <tt/export.h>
#include <tt/nn/module.h>
#include <tt/tensor.h>
#include <concepts>
#include <memory>
#include <string>
#include <type_traits>
#include <utility>
namespace tinytensor::nn {
// Module that holds an ordered list of child modules of arbitrary
// (possibly different) concrete types. Each child is owned through a
// std::shared_ptr<nn::Module> and, when added via push_back, is passed to
// Module::register_module on this list.
class TINYTENSOR_EXPORT ModuleList : public Module {
private:
    CheckedVec<std::shared_ptr<nn::Module>> modules;
    using Iterator = decltype(modules)::Iterator;
    using ConstIterator = decltype(modules)::ConstIterator;

public:
    // Appends a module to the end of the list.
    //
    // Only rvalues are accepted (lvalues are rejected by the constraint), so
    // callers must std::move an existing module or pass a temporary. A new
    // heap instance of the module's plain (cv/ref-stripped) type is
    // constructed from the argument: move-constructed from a non-const
    // rvalue, copy-constructed from a const rvalue. That instance is stored
    // in the list and then passed to register_module.
    //
    // @param module  rvalue of a type derived from nn::Module
    template <typename M>
        requires(std::derived_from<std::remove_cvref_t<M>, nn::Module> && !std::is_lvalue_reference_v<M>)
    void push_back(M &&module) {
        // remove_cvref_t rather than remove_reference_t: for a const rvalue
        // the latter gives `const X`, and make_shared<const X> returns a
        // shared_ptr<const X> that cannot convert to shared_ptr<nn::Module>.
        using T = std::remove_cvref_t<M>;
        auto child = std::make_shared<T>(std::forward<M>(module));
        // The pointee is heap-allocated, so this reference stays valid after
        // the shared_ptr is moved into the container; no re-lookup needed.
        nn::Module &child_ref = *child;
        modules.push_back(std::move(child));
        // Register only once the list owns the child: if push_back throws,
        // nothing has been registered for an object that is about to die.
        register_module(child_ref);
    }

    // Returns the module at `idx`. Index validation (and any negative-index
    // support) is whatever CheckedVec::operator[] provides.
    auto operator[](int idx) -> nn::Module & {
        return *modules[idx];
    }

    // Const overload of operator[]; same indexing semantics.
    auto operator[](int idx) const -> const nn::Module & {
        return *modules[idx];
    }

    // Human-readable module type name.
    [[nodiscard]] auto name() const -> std::string override {
        return "ModuleList";
    }

    // Iteration over the stored std::shared_ptr<nn::Module> handles, in
    // insertion order; delegates directly to the underlying CheckedVec.
    [[nodiscard]] auto begin() -> Iterator {
        return modules.begin();
    }

    [[nodiscard]] auto begin() const -> ConstIterator {
        return modules.begin();
    }

    [[nodiscard]] auto end() -> Iterator {
        return modules.end();
    }

    [[nodiscard]] auto end() const -> ConstIterator {
        return modules.end();
    }
};
} // namespace tinytensor::nn
#endif // TINYTENSOR_NN_MODULE_LIST_H_