Skip to content

Commit d2bc4c8

Browse files
committed
Merge branch 'main' of github.com:cram2/semantic_world into world-copy
2 parents b5f31e2 + 0f4be1b commit d2bc4c8

25 files changed

Lines changed: 3397 additions & 690 deletions

examples/persistence_of_annotated_worlds.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ world = URDFParser.from_file(table).parse()
5454
Next, we create a semantic annotation that describes the table.
5555

5656
```{code-cell} ipython3
57-
table_semantic_annotation = Table([b for b in world.bodies if "top" in str(b.name)][0])
57+
table_semantic_annotation = Table(body=[b for b in world.bodies if "top" in str(b.name)][0])
5858
with world.modify_world():
5959
world.add_semantic_annotation(table_semantic_annotation)
6060
print(table_semantic_annotation)
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
-- create_postgres_database_and_user_if_not_exists.sql (PostgreSQL, idempotent, psql-compatible)
2+
-- Usage (as postgres superuser):
3+
-- sudo -u postgres psql -f create_postgres_database_and_user_if_not_exists.sql \
4+
-- -v db_name="semantic_digital_twin" \
5+
-- -v user_name="semantic_digital_twin" \
6+
-- -v user_password="a_strong_password_here"
7+
8+
-- 1) Create database if it does not exist (PostgreSQL has no CREATE DATABASE IF NOT EXISTS)
9+
SELECT 'CREATE DATABASE ' || quote_ident(:'db_name')
10+
WHERE NOT EXISTS (
11+
SELECT FROM pg_database WHERE datname = :'db_name'
12+
)\gexec
13+
14+
-- 2) Create role (user) if it does not exist
15+
SELECT format('CREATE ROLE %I LOGIN PASSWORD %L', :'user_name', :'user_password')
16+
WHERE NOT EXISTS (
17+
SELECT FROM pg_roles WHERE rolname = :'user_name'
18+
)\gexec
19+
20+
-- 3) Make that role the owner of the database
21+
ALTER DATABASE :"db_name" OWNER TO :"user_name";
22+
23+
-- 4) Connect to the target database as current superuser (not the new user),
24+
-- so we can set permissions and default privileges inside that database
25+
\c :"db_name"
26+
27+
-- 5) Grant privileges on the public schema to the application role
28+
GRANT ALL ON SCHEMA public TO :"user_name";
29+
30+
-- 6) Ensure future objects get privileges automatically
31+
-- Use FOR ROLE to set defaults owned by the application role
32+
ALTER DEFAULT PRIVILEGES FOR ROLE :"user_name" IN SCHEMA public GRANT ALL ON TABLES TO :"user_name";
33+
ALTER DEFAULT PRIVILEGES FOR ROLE :"user_name" IN SCHEMA public GRANT ALL ON SEQUENCES TO :"user_name";
34+
ALTER DEFAULT PRIVILEGES FOR ROLE :"user_name" IN SCHEMA public GRANT ALL ON FUNCTIONS TO :"user_name";

scripts/generate_orm.py

Lines changed: 26 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -10,41 +10,41 @@
1010
import os
1111
from dataclasses import is_dataclass
1212

13-
import krrood.entity_query_language.orm.model
1413
import trimesh
1514
from krrood.class_diagrams import ClassDiagram
16-
from krrood.entity_query_language.predicate import HasTypes, HasType, Symbol
17-
from krrood.entity_query_language.symbol_graph import SymbolGraph
1815
from krrood.ormatic.dao import AlternativeMapping
1916
from krrood.ormatic.ormatic import ORMatic
2017
from krrood.ormatic.utils import classes_of_module
2118
from krrood.utils import recursive_subclasses
2219

2320
import semantic_digital_twin.orm.model
21+
import semantic_digital_twin.reasoning.predicates
2422
import semantic_digital_twin.robots.abstract_robot
2523
import semantic_digital_twin.semantic_annotations.semantic_annotations
2624
import semantic_digital_twin.world # ensure the module attribute exists on the package
25+
import semantic_digital_twin.adapters.procthor.procthor_semantic_annotations
2726
import semantic_digital_twin.world_description.degree_of_freedom
2827
import semantic_digital_twin.world_description.geometry
2928
import semantic_digital_twin.world_description.shape_collection
3029
import semantic_digital_twin.world_description.world_entity
3130
from semantic_digital_twin.datastructures.prefixed_name import PrefixedName
31+
from semantic_digital_twin.reasoning.predicates import ContainsType
32+
from semantic_digital_twin.semantic_annotations.mixins import HasBody
3233
from semantic_digital_twin.spatial_computations.forward_kinematics import (
3334
ForwardKinematicsManager,
3435
)
3536
from semantic_digital_twin.world import (
3637
ResetStateContextManager,
3738
WorldModelUpdateContextManager,
3839
)
40+
from semantic_digital_twin.world import WorldModelManager
3941
from semantic_digital_twin.world_description.connections import (
4042
FixedConnection,
4143
HasUpdateState,
4244
)
45+
from semantic_digital_twin.orm.model import * # type: ignore
4346

44-
45-
# collect all semantic digital twin classes that should be mapped
46-
all_classes = set(classes_of_module(semantic_digital_twin.orm.model))
47-
all_classes |= set(
47+
all_classes = set(
4848
classes_of_module(semantic_digital_twin.world_description.world_entity)
4949
)
5050
all_classes |= set(classes_of_module(semantic_digital_twin.world_description.geometry))
@@ -66,6 +66,15 @@
6666
classes_of_module(semantic_digital_twin.world_description.degree_of_freedom)
6767
)
6868
all_classes |= set(classes_of_module(semantic_digital_twin.robots.abstract_robot))
69+
# classes |= set(recursive_subclasses(ViewFactory))
70+
all_classes |= set([HasBody] + recursive_subclasses(HasBody))
71+
all_classes |= set(classes_of_module(semantic_digital_twin.reasoning.predicates))
72+
all_classes |= set(classes_of_module(semantic_digital_twin.semantic_annotations.mixins))
73+
all_classes |= set(
74+
classes_of_module(
75+
semantic_digital_twin.adapters.procthor.procthor_semantic_annotations
76+
)
77+
)
6978

7079

7180
# remove classes that should not be mapped
@@ -74,29 +83,22 @@
7483
WorldModelUpdateContextManager,
7584
HasUpdateState,
7685
ForwardKinematicsManager,
86+
WorldModelManager,
87+
semantic_digital_twin.adapters.procthor.procthor_semantic_annotations.ProcthorResolver,
88+
ContainsType,
7789
}
78-
79-
# build the symbol graph
80-
symbol_graph = SymbolGraph()
81-
82-
# collect all KRROOD classes
83-
all_classes |= {c.clazz for c in symbol_graph.class_diagram.wrapped_classes}
84-
all_classes |= {am.original_class() for am in recursive_subclasses(AlternativeMapping)}
85-
all_classes |= set(classes_of_module(krrood.entity_query_language.symbol_graph))
86-
all_classes |= {Symbol}
87-
88-
# remove classes that don't need persistence
89-
all_classes -= {HasType, HasTypes}
90-
91-
9290
# keep only dataclasses that are NOT AlternativeMapping subclasses
9391
all_classes = {
9492
c for c in all_classes if is_dataclass(c) and not issubclass(c, AlternativeMapping)
9593
}
96-
97-
# ensure we have the original classes of the mappings (ORMatic uses these)
9894
all_classes |= {am.original_class() for am in recursive_subclasses(AlternativeMapping)}
9995

96+
alternative_mappings = [
97+
am
98+
for am in recursive_subclasses(AlternativeMapping)
99+
if am.original_class() in all_classes
100+
]
101+
100102

101103
def generate_orm():
102104
"""
@@ -109,7 +111,7 @@ def generate_orm():
109111
instance = ORMatic(
110112
class_dependency_graph=class_diagram,
111113
type_mappings={trimesh.Trimesh: semantic_digital_twin.orm.model.TrimeshType},
112-
alternative_mappings=recursive_subclasses(AlternativeMapping),
114+
alternative_mappings=alternative_mappings,
113115
)
114116

115117
instance.make_all_tables()

scripts/parse_procthor_files_and_save_to_database.py

Lines changed: 24 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,27 +6,36 @@
66
from typing import List
77

88
import tqdm
9+
from krrood.entity_query_language.symbol_graph import SymbolGraph
910
from krrood.ormatic.dao import to_dao
1011
from krrood.ormatic.utils import drop_database
12+
from krrood.utils import recursive_subclasses
1113
from sqlalchemy import create_engine
1214
from sqlalchemy.orm import Session
1315
from typing_extensions import TYPE_CHECKING
1416

17+
18+
from semantic_digital_twin.world import World
19+
20+
sg = SymbolGraph()
21+
22+
1523
from semantic_digital_twin.adapters.fbx import FBXParser
1624
from semantic_digital_twin.adapters.procthor.procthor_pipelines import (
1725
dresser_factory_from_body,
1826
)
1927
from semantic_digital_twin.orm.ormatic_interface import *
28+
from semantic_digital_twin.adapters.procthor.procthor_semantic_annotations import (
29+
ProcthorResolver,
30+
)
31+
from semantic_digital_twin.semantic_annotations.mixins import HasBody
2032
from semantic_digital_twin.pipeline.pipeline import (
2133
Pipeline,
2234
BodyFilter,
2335
BodyFactoryReplace,
2436
CenterLocalGeometryAndPreserveWorldPose,
2537
)
2638

27-
if TYPE_CHECKING:
28-
from semantic_digital_twin.world import World
29-
3039

3140
def remove_root_and_move_children_into_new_worlds(world: World) -> List[World]:
3241
"""
@@ -48,8 +57,10 @@ def remove_root_and_move_children_into_new_worlds(world: World) -> List[World]:
4857

4958
with world.modify_world():
5059

51-
worlds = [world.copy_subgraph_to_new_world(child) for child in root_children]
60+
worlds = [world.move_branch_to_new_world(child) for child in root_children]
5261
for world in worlds:
62+
if world.root is None:
63+
...
5364
world.name = world.root.name.name
5465

5566
return worlds
@@ -102,6 +113,12 @@ def parse_fbx_file_to_world_mapping_daos(fbx_file_path: str) -> List[WorldMappin
102113
worlds = remove_root_and_move_children_into_new_worlds(world)
103114

104115
worlds = replace_dresser_meshes_with_factories(worlds, dresser_pattern)
116+
resolver = ProcthorResolver(*[recursive_subclasses(HasBody)])
117+
for world in worlds:
118+
resolved = resolver.resolve(world.name)
119+
if resolved:
120+
with world.modify_world():
121+
world.add_semantic_annotation(resolved(body=world.root))
105122

106123
return [to_dao(world) for world in worlds]
107124

@@ -116,11 +133,11 @@ def parse_procthor_files_and_save_to_database(
116133
TODO: Ensure all relevant files, even those not inside a grp, are parsed.
117134
"""
118135
semantic_digital_twin_database_uri = os.environ.get(
119-
"semantic_digital_twin_DATABASE_URI"
136+
"SEMANTIC_DIGITAL_TWIN_DATABASE_URI"
120137
)
121138
assert (
122139
semantic_digital_twin_database_uri is not None
123-
), "Please set the semantic_digital_twin_DATABASE_URI environment variable."
140+
), "Please set the SEMANTIC_DIGITAL_TWIN_DATABASE_URI environment variable."
124141

125142
procthor_root = os.path.join(os.path.expanduser("~"), "ai2thor")
126143
# procthor_root = os.path.join(os.path.expanduser("~"), "work", "ai2thor")
@@ -147,7 +164,7 @@ def parse_procthor_files_and_save_to_database(
147164
if not any([e in f for e in excluded_words]) and fbx_file_pattern.fullmatch(f)
148165
]
149166
# Create database engine and session
150-
engine = create_engine(f"mysql+pymysql://{semantic_digital_twin_database_uri}")
167+
engine = create_engine(semantic_digital_twin_database_uri)
151168
session = Session(engine)
152169

153170
if drop_existing_database:
Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
import itertools
2+
import logging
3+
import os
4+
5+
import prior
6+
import tqdm
7+
from krrood.entity_query_language.symbol_graph import SymbolGraph
8+
from krrood.ormatic.dao import to_dao, ToDAOState
9+
from krrood.ormatic.utils import classes_of_module, drop_database
10+
from krrood.utils import recursive_subclasses
11+
from sqlalchemy import create_engine
12+
from sqlalchemy.orm import Session
13+
14+
import semantic_digital_twin.adapters.procthor.procthor_semantic_annotations
15+
from semantic_digital_twin.adapters.procthor.procthor_parser import ProcTHORParser
16+
from semantic_digital_twin.adapters.procthor.procthor_semantic_annotations import (
17+
ProcthorResolver,
18+
)
19+
from semantic_digital_twin.semantic_annotations.mixins import HasBody
20+
from semantic_digital_twin.reasoning.predicates import InsideOf
21+
from semantic_digital_twin.orm.ormatic_interface import *
22+
from semantic_digital_twin.world_description.world_entity import SemanticAnnotation
23+
24+
25+
def parse_procthor_worlds_and_calculate_containment_ratio():
26+
semantic_world_database_uri = os.environ.get("SEMANTIC_DIGITAL_TWIN_DATABASE_URI")
27+
semantic_world_engine = create_engine(semantic_world_database_uri, echo=False)
28+
semantic_world_session = Session(semantic_world_engine)
29+
30+
procthor_experiments_database_uri = os.environ.get(
31+
"PROCTHOR_EXPERIMENTS_DATABASE_URI"
32+
)
33+
procthor_experiments_engine = create_engine(
34+
procthor_experiments_database_uri, echo=False
35+
)
36+
# drop_database(procthor_experiments_engine)
37+
# Base.metadata.create_all(procthor_experiments_engine)
38+
procthor_experiments_session = Session(procthor_experiments_engine)
39+
40+
dataset = prior.load_dataset("procthor-10k")
41+
42+
# Iterate through all JSON files in the directory
43+
for index, house in enumerate(
44+
tqdm.tqdm(dataset["train"], desc="Parsing Procthor worlds")
45+
):
46+
if index < 5058:
47+
continue
48+
try:
49+
parser = ProcTHORParser(f"house_{index}", house, semantic_world_session)
50+
world = parser.parse()
51+
except Exception as e:
52+
logging.error(f"Error parsing house {index}: {e}")
53+
continue
54+
# resolve views
55+
resolver = ProcthorResolver(
56+
[
57+
cls
58+
for cls in classes_of_module(
59+
semantic_digital_twin.adapters.procthor.procthor_semantic_annotations
60+
)
61+
if issubclass(cls, SemanticAnnotation)
62+
]
63+
)
64+
for body in world.bodies:
65+
resolved = resolver.resolve(body.name.name)
66+
if resolved:
67+
with world.modify_world():
68+
world.add_semantic_annotation(
69+
resolved(body=body), skip_duplicates=True
70+
)
71+
72+
state = ToDAOState()
73+
daos = []
74+
75+
world_dao = to_dao(world, state=state)
76+
procthor_experiments_session.add(world_dao)
77+
78+
for kse, other in itertools.product(
79+
world.kinematic_structure_entities, world.kinematic_structure_entities
80+
):
81+
if kse != other:
82+
is_inside = InsideOf(kse, other)
83+
if is_inside() > 0.0:
84+
dao = to_dao(is_inside, state=state)
85+
daos.append(dao)
86+
87+
procthor_experiments_session.add_all(daos)
88+
procthor_experiments_session.commit()
89+
procthor_experiments_session.expunge_all()
90+
semantic_world_session.expunge_all()
91+
SymbolGraph().clear()
92+
93+
94+
if __name__ == "__main__":
95+
parse_procthor_worlds_and_calculate_containment_ratio()

0 commit comments

Comments
 (0)