Skip to content

Commit 9b6dc6b

Browse files
committed
fix join types in TreeManager
1 parent 2ee01d1 commit 9b6dc6b

1 file changed

Lines changed: 11 additions & 11 deletions

File tree

backend/oasst_backend/tree_manager.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -636,7 +636,7 @@ def query_prompts_need_review(self) -> list[Message]:
636636
qry = (
637637
self.db.query(Message)
638638
.select_from(MessageTreeState)
639-
.outerjoin(Message, MessageTreeState.message_tree_id == Message.message_tree_id)
639+
.join(Message, MessageTreeState.message_tree_id == Message.message_tree_id)
640640
.filter(
641641
MessageTreeState.active,
642642
MessageTreeState.state == message_tree_state.State.INITIAL_PROMPT_REVIEW,
@@ -661,7 +661,7 @@ def query_replies_need_review(self) -> list[Message]:
661661
qry = (
662662
self.db.query(Message)
663663
.select_from(MessageTreeState)
664-
.outerjoin(Message, MessageTreeState.message_tree_id == Message.message_tree_id)
664+
.join(Message, MessageTreeState.message_tree_id == Message.message_tree_id)
665665
.filter(
666666
MessageTreeState.active,
667667
MessageTreeState.state == message_tree_state.State.GROWING,
@@ -682,7 +682,7 @@ def query_replies_need_review(self) -> list[Message]:
682682
SELECT m.parent_id, m.role, COUNT(m.id) children_count, MIN(m.ranking_count) child_min_ranking_count,
683683
COUNT(m.id) FILTER (WHERE m.ranking_count >= :num_required_rankings) as completed_rankings
684684
FROM message_tree_state mts
685-
LEFT JOIN message m ON mts.message_tree_id = m.message_tree_id
685+
INNER JOIN message m ON mts.message_tree_id = m.message_tree_id
686686
WHERE mts.active -- only consider active trees
687687
AND mts.state = :ranking_state -- message tree must be in ranking state
688688
AND m.review_result -- must be reviewed
@@ -708,7 +708,7 @@ def query_incomplete_rankings(self) -> list[IncompleteRankingsRow]:
708708
-- find all extendible parent nodes
709709
SELECT m.id as parent_id, m.role as parent_role, m.depth, m.message_tree_id, COUNT(c.id) active_children_count
710710
FROM message_tree_state mts
711-
LEFT JOIN message m ON mts.message_tree_id = m.message_tree_id -- all elements of message tree
711+
INNER JOIN message m ON mts.message_tree_id = m.message_tree_id -- all elements of message tree
712712
LEFT JOIN message c ON m.id = c.parent_id -- child nodes
713713
WHERE mts.active -- only consider active trees
714714
AND mts.state = :growing_state -- message tree must be growing
@@ -738,8 +738,8 @@ def query_extendible_parents(self) -> list[ExtendibleParentRow]:
738738
SELECT m.message_tree_id, mts.goal_tree_size, COUNT(m.id) AS tree_size
739739
FROM (
740740
SELECT DISTINCT message_tree_id FROM ({_sql_find_extendible_parents}) extendible_parents
741-
) trees LEFT JOIN message_tree_state mts ON trees.message_tree_id = mts.message_tree_id
742-
LEFT JOIN message m ON mts.message_tree_id = m.message_tree_id
741+
) trees INNER JOIN message_tree_state mts ON trees.message_tree_id = mts.message_tree_id
742+
INNER JOIN message m ON mts.message_tree_id = m.message_tree_id
743743
WHERE NOT m.deleted
744744
AND (
745745
m.parent_id IS NOT NULL AND (m.review_result OR m.review_count < :num_reviews_reply) -- children
@@ -787,7 +787,7 @@ def query_misssing_tree_states(self) -> list[UUID]:
787787
"""Find all initial prompt messages that have no associated message tree state"""
788788
qry_missing_tree_states = (
789789
self.db.query(Message.id)
790-
.join(MessageTreeState, isouter=True)
790+
.outerjoin(MessageTreeState, Message.message_tree_id == MessageTreeState.message_tree_id)
791791
.filter(
792792
Message.parent_id.is_(None),
793793
Message.message_tree_id == Message.id,
@@ -804,7 +804,7 @@ def query_misssing_tree_states(self) -> list[UUID]:
804804
-- find parents with > 1 children
805805
SELECT m.parent_id, m.message_tree_id, COUNT(m.id) children_count
806806
FROM message_tree_state mts
807-
LEFT JOIN message m ON mts.message_tree_id = m.message_tree_id
807+
INNER JOIN message m ON mts.message_tree_id = m.message_tree_id
808808
WHERE m.review_result -- must be reviewed
809809
AND NOT m.deleted -- not deleted
810810
AND m.parent_id IS NOT NULL -- ignore initial prompts
@@ -813,8 +813,8 @@ def query_misssing_tree_states(self) -> list[UUID]:
813813
GROUP BY m.parent_id, m.message_tree_id
814814
HAVING COUNT(m.id) > 1
815815
) as p
816-
LEFT JOIN task t ON p.parent_id = t.parent_message_id AND t.done AND (t.payload_type = 'RankPrompterRepliesPayload' OR t.payload_type = 'RankAssistantRepliesPayload')
817-
LEFT JOIN message_reaction mr ON mr.task_id = t.id AND mr.payload_type = 'RankingReactionPayload'
816+
INNER JOIN task t ON p.parent_id = t.parent_message_id AND t.done AND (t.payload_type = 'RankPrompterRepliesPayload' OR t.payload_type = 'RankAssistantRepliesPayload')
817+
INNER JOIN message_reaction mr ON mr.task_id = t.id AND mr.payload_type = 'RankingReactionPayload'
818818
"""
819819

820820
def query_tree_ranking_results(
@@ -925,7 +925,7 @@ def _insert_default_state(
925925
# print("query_num_active_trees", tm.query_num_active_trees())
926926
# print("query_incomplete_rankings", tm.query_incomplete_rankings())
927927
print("query_replies_need_review", tm.query_replies_need_review())
928-
print("query_incomplete_initial_prompt_reviews", tm.query_prompts_need_review())
928+
# print("query_incomplete_initial_prompt_reviews", tm.query_prompts_need_review())
929929
# print("query_extendible_trees", tm.query_extendible_trees())
930930
# print("query_extendible_parents", tm.query_extendible_parents())
931931
# print("query_tree_size", tm.query_tree_size(message_tree_id=UUID("bdf434cf-4df5-4b74-949c-a5a157bc3292")))

0 commit comments

Comments
 (0)