Skip to content

Commit f13f4ce

Browse files
committed
Add support of absolute paths for enums on parse level
1 parent bd7b4b4 commit f13f4ce

14 files changed

Lines changed: 123 additions & 45 deletions

File tree

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
package io.kaitai.struct.exprlang
2+
3+
import org.scalatest.funspec.AnyFunSpec
4+
import org.scalatest.matchers.should.Matchers._
5+
6+
class EnumRefSpec extends AnyFunSpec {
7+
describe("Expressions.parseEnumRef") {
8+
describe("parses local enum refs") {
9+
it("some_enum") {
10+
Expressions.parseEnumRef("some_enum") should be(Ast.EnumRef(
11+
false, Seq(), "some_enum"
12+
))
13+
}
14+
it("with spaces: ' some_enum '") {
15+
Expressions.parseEnumRef(" some_enum ") should be(Ast.EnumRef(
16+
false, Seq(), "some_enum"
17+
))
18+
}
19+
20+
it("::some_enum") {
21+
Expressions.parseEnumRef("::some_enum") should be(Ast.EnumRef(
22+
true, Seq(), "some_enum"
23+
))
24+
}
25+
it("with spaces: ' :: some_enum '") {
26+
Expressions.parseEnumRef(" :: some_enum ") should be(Ast.EnumRef(
27+
true, Seq(), "some_enum"
28+
))
29+
}
30+
}
31+
32+
describe("parses path enum refs") {
33+
it("some::enum") {
34+
Expressions.parseEnumRef("some::enum") should be(Ast.EnumRef(
35+
false, Seq("some"), "enum"
36+
))
37+
}
38+
it("with spaces: ' some :: enum '") {
39+
Expressions.parseEnumRef(" some :: enum ") should be(Ast.EnumRef(
40+
false, Seq("some"), "enum"
41+
))
42+
}
43+
44+
it("::some::enum") {
45+
Expressions.parseEnumRef("::some::enum") should be(Ast.EnumRef(
46+
true, Seq("some"), "enum"
47+
))
48+
}
49+
it("with spaces: ' :: some :: enum '") {
50+
Expressions.parseEnumRef(" :: some :: enum ") should be(Ast.EnumRef(
51+
true, Seq("some"), "enum"
52+
))
53+
}
54+
}
55+
}
56+
}

shared/src/main/scala/io/kaitai/struct/GraphvizClassCompiler.scala

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -486,7 +486,7 @@ object GraphvizClassCompiler extends LanguageCompilerStatic {
486486
): LanguageCompiler = ???
487487

488488
def type2class(name: List[String]) = name.last
489-
def type2display(name: List[String]) = name.map(Utils.upperCamelCase).mkString("::")
489+
def type2display(name: Seq[String]) = name.map(Utils.upperCamelCase).mkString("::")
490490

491491
def dataTypeName(dataType: DataType, valid: Option[ValidationSpec]): String = {
492492
dataType match {
@@ -509,8 +509,8 @@ object GraphvizClassCompiler extends LanguageCompilerStatic {
509509
val bytesStr = dataTypeName(basedOn, valid)
510510
val comma = if (bytesStr.isEmpty) "" else ", "
511511
s"str($bytesStr$comma$encoding)"
512-
case EnumType(name, basedOn) =>
513-
s"${dataTypeName(basedOn, valid)}${type2display(name)}"
512+
case EnumType(ref, basedOn) =>
513+
s"${dataTypeName(basedOn, valid)}${type2display(ref.fullName)}"
514514
case BitsType(width, bitEndian) => s"b$width${bitEndian.toSuffix}"
515515
case BitsType1(bitEndian) => s"b1${bitEndian.toSuffix}→bool"
516516
case _ => dataType.toString

shared/src/main/scala/io/kaitai/struct/datatype/DataType.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -259,7 +259,7 @@ object DataType {
259259
def isOwning = false
260260
}
261261

262-
case class EnumType(name: List[String], basedOn: IntType) extends DataType {
262+
case class EnumType(ref: Ast.EnumRef, basedOn: IntType) extends DataType {
263263
var enumSpec: Option[EnumSpec] = None
264264

265265
/**
@@ -459,7 +459,7 @@ object DataType {
459459
enumRef match {
460460
case Some(enumName) =>
461461
r match {
462-
case numType: IntType => EnumType(classNameToList(enumName), numType)
462+
case numType: IntType => EnumType(Expressions.parseEnumRef(enumName), numType)
463463
case _ =>
464464
throw KSYParseError(s"tried to resolve non-integer $r to enum", path).toException
465465
}

shared/src/main/scala/io/kaitai/struct/exprlang/Ast.scala

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,5 +143,18 @@ object Ast {
143143
case object GtE extends cmpop
144144
}
145145

146+
/**
147+
* Reference to an enum in scope. Scope is defined by the `absolute` flag and
148+
* a path to a type (which can be empty) in which enum is defined.
149+
*/
150+
case class EnumRef(absolute: Boolean, typePath: Seq[String], name: String) {
151+
/** @return Type path and name of enum in one list. */
152+
def fullName: Seq[String] = typePath :+ name
153+
/**
154+
* @return Enum designation name as human-readable string, to be used in compiler
155+
* error messages.
156+
*/
157+
def asStr: String = fullName.mkString(if (absolute) "::" else "", "::", "")
158+
}
146159
case class TypeWithArguments(typeName: typeId, arguments: expr.List)
147160
}

shared/src/main/scala/io/kaitai/struct/exprlang/Expressions.scala

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,13 @@ object Expressions {
195195
case (path, Some(args)) => Ast.TypeWithArguments(path, args)
196196
}
197197

198+
def enumRef[$: P]: P[Ast.EnumRef] = P(Start ~ "::".!.? ~ NAME.rep(1, "::") ~ End).map {
199+
case (absolute, names) =>
200+
// List have at least one element, so we always can split it into head and the last element
201+
val typePath :+ enumName = names
202+
Ast.EnumRef(absolute.nonEmpty, typePath.map(i => i.name), enumName.name)
203+
}
204+
198205
class ParseException(val src: String, val failure: Parsed.Failure)
199206
extends RuntimeException(failure.msg)
200207

@@ -211,6 +218,14 @@ object Expressions {
211218
*/
212219
def parseTypeRef(src: String): Ast.TypeWithArguments = realParse(src, typeRef(_))
213220

221+
/**
222+
* Parse string with reference to enumeration definition, optionally in full path format.
223+
*
224+
* @param src Enum reference as string, like `::path::to::enum`
225+
* @return Object that represents path to enum
226+
*/
227+
def parseEnumRef(src: String): Ast.EnumRef = realParse(src, enumRef(_))
228+
214229
private def realParse[T](src: String, parser: P[_] => P[T]): T = {
215230
val r = fastparse.parse(src.trim, parser)
216231
r match {

shared/src/main/scala/io/kaitai/struct/languages/CSharpCompiler.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -708,7 +708,7 @@ object CSharpCompiler extends LanguageCompilerStatic
708708
case KaitaiStreamType | OwnedKaitaiStreamType => kstreamName
709709

710710
case t: UserType => types2class(t.name)
711-
case EnumType(name, _) => types2class(name)
711+
case EnumType(ref, _) => types2class(ref.fullName)
712712

713713
case at: ArrayType => {
714714
importList.add("System.Collections.Generic")

shared/src/main/scala/io/kaitai/struct/languages/CppCompiler.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1168,7 +1168,7 @@ object CppCompiler extends LanguageCompilerStatic
11681168
types2class(if (absolute) {
11691169
t.enumSpec.get.name
11701170
} else {
1171-
t.name
1171+
t.ref.fullName
11721172
})
11731173

11741174
case at: ArrayType => {
@@ -1229,7 +1229,7 @@ object CppCompiler extends LanguageCompilerStatic
12291229
)
12301230
}
12311231

1232-
def types2class(components: List[String]) =
1232+
def types2class(components: Seq[String]) =
12331233
components.map(type2class).mkString("::")
12341234

12351235
def type2class(name: String) = Utils.lowerUnderscoreCase(name) + "_t"

shared/src/main/scala/io/kaitai/struct/languages/JavaCompiler.scala

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1305,7 +1305,7 @@ object JavaCompiler extends LanguageCompilerStatic
13051305
case KaitaiStructType | CalcKaitaiStructType(_) => kstructNameFull(config)
13061306

13071307
case t: UserType => types2class(t.name)
1308-
case EnumType(name, _) => types2class(name)
1308+
case EnumType(ref, _) => types2class(ref.fullName)
13091309

13101310
case _: ArrayType => kaitaiType2JavaTypeBoxed(attrType, importList, config)
13111311

@@ -1349,7 +1349,7 @@ object JavaCompiler extends LanguageCompilerStatic
13491349
case KaitaiStructType | CalcKaitaiStructType(_) => kstructNameFull(config)
13501350

13511351
case t: UserType => types2class(t.name)
1352-
case EnumType(name, _) => types2class(name)
1352+
case EnumType(ref, _) => types2class(ref.fullName)
13531353

13541354
case at: ArrayType => {
13551355
importList.add("java.util.List")
@@ -1360,7 +1360,7 @@ object JavaCompiler extends LanguageCompilerStatic
13601360
}
13611361
}
13621362

1363-
def types2class(names: List[String]) = names.map(x => type2class(x)).mkString(".")
1363+
def types2class(names: Seq[String]) = names.map(x => type2class(x)).mkString(".")
13641364

13651365
override def kstreamName: String = "KaitaiStream"
13661366
override def kstructName: String = "KaitaiStruct"

shared/src/main/scala/io/kaitai/struct/languages/RustCompiler.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1302,7 +1302,7 @@ object RustCompiler
13021302
def classTypeName(c: ClassSpec): String =
13031303
s"${types2class(c.name)}"
13041304

1305-
def types2class(names: List[String]): String =
1305+
def types2class(names: Seq[String]): String =
13061306
// TODO: Use `mod` to scope types instead of weird names
13071307
names.map(x => type2class(x)).mkString("_")
13081308

@@ -1329,7 +1329,7 @@ object RustCompiler
13291329
case t: EnumType =>
13301330
val baseName = t.enumSpec match {
13311331
case Some(spec) => s"${types2class(spec.name)}"
1332-
case None => s"${types2class(t.name)}"
1332+
case None => s"${types2class(t.ref.fullName)}"
13331333
}
13341334
baseName
13351335

shared/src/main/scala/io/kaitai/struct/languages/ZigCompiler.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -813,7 +813,7 @@ object ZigCompiler extends LanguageCompilerStatic
813813
if (isExternal) {
814814
externalTypeDeclaration(ExternalEnum(et.enumSpec.get), importList)
815815
}
816-
types2class(et.name, isExternal)
816+
types2class(et.ref.fullName, isExternal)
817817
}
818818

819819
case at: ArrayType => s"*_imp_std.ArrayList(${kaitaiType2NativeType(at.elType, importList, curClass)})"
@@ -854,7 +854,7 @@ object ZigCompiler extends LanguageCompilerStatic
854854
case ut: UserType => ut.name.last
855855
// NOTE: at the time of writing, this is unreachable because the `enum` key is not compatible
856856
// with type switching
857-
case et: EnumType => et.name.last
857+
case et: EnumType => et.ref.name
858858
}
859859

860860
def switchTaggedUnionName(id: Identifier): String = {

0 commit comments

Comments
 (0)