2222 Set ,
2323)
2424
25+ from _pytest .skipping import xfailed_key
26+
2527from .constants import PYTEST_NODE_SEP
2628from .data import (
2729 Snapshot ,
@@ -70,6 +72,7 @@ class SnapshotReport:
7072 used : "SnapshotCollections" = field (default_factory = SnapshotCollections )
7173 _provided_test_paths : Dict [str , List [str ]] = field (default_factory = dict )
7274 _keyword_expressions : Set ["Expression" ] = field (default_factory = set )
75+ _num_xfails : int = field (default = 0 )
7376
7477 @property
7578 def update_snapshots (self ) -> bool :
@@ -89,6 +92,14 @@ def _collected_items_by_nodeid(self) -> Dict[str, "pytest.Item"]:
8992 getattr (item , "nodeid" ): item for item in self .collected_items # noqa: B009
9093 }
9194
95+ def _has_xfail (self , item : "pytest.Item" ) -> bool :
96+ # xfailed_key is 'private'. I'm open to a better way to do this:
97+ if xfailed_key in item .stash :
98+ result = item .stash [xfailed_key ]
99+ if result :
100+ return result .run
101+ return False
102+
92103 def __post_init__ (self ) -> None :
93104 self .__parse_invocation_args ()
94105
@@ -113,13 +124,17 @@ def __post_init__(self) -> None:
113124 Snapshot (name = result .snapshot_name , data = result .final_data )
114125 )
115126 self .used .update (snapshot_collection )
127+
116128 if result .created :
117129 self .created .update (snapshot_collection )
118130 elif result .updated :
119131 self .updated .update (snapshot_collection )
120132 elif result .success :
121133 self .matched .update (snapshot_collection )
122134 else :
135+ has_xfail = self ._has_xfail (item = result .test_location .item )
136+ if has_xfail :
137+ self ._num_xfails += 1
123138 self .failed .update (snapshot_collection )
124139
125140 def __parse_invocation_args (self ) -> None :
@@ -161,7 +176,7 @@ def __parse_invocation_args(self) -> None:
161176 def num_created (self ) -> int :
162177 return self ._count_snapshots (self .created )
163178
164- @property
179+ @cached_property
165180 def num_failed (self ) -> int :
166181 return self ._count_snapshots (self .failed )
167182
@@ -256,14 +271,22 @@ def lines(self) -> Iterator[str]:
256271 ```
257272 """
258273 summary_lines : List [str ] = []
259- if self .num_failed :
274+ if self .num_failed and self . _num_xfails < self . num_failed :
260275 summary_lines .append (
261276 ngettext (
262277 "{} snapshot failed." ,
263278 "{} snapshots failed." ,
264- self .num_failed ,
265- ).format (error_style (self .num_failed ))
279+ self .num_failed - self . _num_xfails ,
280+ ).format (error_style (self .num_failed - self . _num_xfails )),
266281 )
282+ if self ._num_xfails :
283+ summary_lines .append (
284+ ngettext (
285+ "{} snapshot xfailed." ,
286+ "{} snapshots xfailed." ,
287+ self ._num_xfails ,
288+ ).format (warning_style (self ._num_xfails )),
289+ )
267290 if self .num_matched :
268291 summary_lines .append (
269292 ngettext (
0 commit comments