Skip to content

Commit c803b25

Browse files
committed
ScaIR wow
1 parent a4c52c2 commit c803b25

File tree

1 file changed

+29
-1
lines changed
  • src/main/scala/uk/ac/ed/dal/structtensor

1 file changed

+29
-1
lines changed

src/main/scala/uk/ac/ed/dal/structtensor/Main.scala

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ import java.io.File
1111
import scopt.OParser
1212

1313
object Main {
14-
def main(args: Array[String]) = {
14+
def main(args: Array[String]): Unit = {
1515
import Optimizer._
1616
import Utils._
1717

@@ -112,37 +112,49 @@ object Main {
112112
.toSeq
113113
.filter(_.nonEmpty)
114114
.filterNot(_.startsWith("#"))
115+
116+
// Specify names of symbols
115117
val (symbols_lines, symbols_index) =
116118
lineSeqInit.zipWithIndex.filter(_._1.startsWith("symbols:")).unzip
119+
120+
// Specify names of tensors expected to be output
117121
val (outputs_lines, outputs_index) =
118122
lineSeqInit.zipWithIndex.filter(_._1.startsWith("outputs:")).unzip
123+
124+
// Specify names of iterators
119125
val (iters_lines, iters_index) =
120126
lineSeqInit.zipWithIndex.filter(_._1.startsWith("iters:")).unzip
127+
// parsed symbols names
121128
val symbols = symbols_lines
122129
.map(e => e.slice(8, e.length))
123130
.flatMap(_.split(",").map(_.trim).toSeq)
124131
.map(Variable(_))
132+
// parsed outputs names
125133
val outputs_names = outputs_lines
126134
.map(e => e.slice(8, e.length))
127135
.flatMap(_.split(",").map(_.trim).toSeq)
136+
// parsed iterators names and vars
128137
val iters_map = iters_lines
129138
.map(e => e.slice(6, e.length))
130139
.flatMap(_.split(";").map(_.trim).toSeq)
131140
.map(iter_str =>
132141
fastparse.parse(iter_str, Parser.iterators(_)).get.value
133142
)
134143
.toMap
144+
// Remaining input lines AKA the program + compression hatches
135145
val lineSeq = lineSeqInit.zipWithIndex
136146
.filterNot(x =>
137147
symbols_index.contains(x._2) ||
138148
outputs_index.contains(x._2) ||
139149
iters_index.contains(x._2)
140150
)
141151
.map(_._1)
152+
// Manual arbitrary access computation danger zone start
142153
val preprocess_start_index = lineSeq.indexOf("@preprocess_start")
143154
val preprocess_end_index = lineSeq.indexOf("@preprocess_end")
144155
val preprocess_lines =
145156
lineSeq.slice(preprocess_start_index + 1, preprocess_end_index)
157+
// Remaining input lines AKA JUST the program
146158
val computation_lines = lineSeq.slice(
147159
0,
148160
preprocess_start_index
@@ -154,19 +166,29 @@ object Main {
154166
res.head
155167
})
156168
.toSeq
169+
// Danger zone end?
170+
171+
// Program parsed as bunch of rules
157172
val parsedComputation = computation_lines
158173
.map(line => {
159174
val Parsed.Success(res, _) = parse(line, parser(_))
160175
res.head
161176
})
162177
.toSeq
178+
179+
// cOOKED COMPRESSION RULES
180+
// Ignore if ignoring manual compression hatch
163181
val (
164182
all_tensors_preprocess,
165183
tensorComputations_preprocess,
166184
dimInfo_preprocess,
167185
uniqueSets_preprocess,
168186
redundancyMaps_preprocess
169187
) = convertRules(parsedPreprocess)
188+
189+
// Tensor information extracted from the computation rules, before
190+
// any inference.
191+
// Probably what we want to ScaIR out!
170192
val (
171193
all_tensors_computation,
172194
tensorComputations_computation,
@@ -184,7 +206,13 @@ object Main {
184206
symbols,
185207
outputs_names
186208
)
209+
val ScaIR = false
187210

211+
if (ScaIR) {
212+
// wow
213+
return ()
214+
}
215+
188216
val (newUS, newRM, newCC, ccRuleSeq, rcRuleSeq) =
189217
tensorComputations_computation.foldLeft(
190218
(

0 commit comments

Comments
 (0)