@@ -124,35 +124,24 @@ Exec_DefinitionRegistry::GetComputationDefinition(
124124 return nullptr ;
125125 }
126126
127- // Iterate over all ancestor types of the provider's schema type, from
128- // derived to base, starting with the schema type itself. Look for a
129- // matching plugin prim computation for the derived-most schema type that
130- // defines it, or null, if no matching computation can be found.
131- //
132- // TODO: Repeatedly traversing the schema type hierarchy like this is
133- // wasteful and we plan to cache results appropriately. But we still need to
134- // add support for applied schemas, and we will also need to add the ability
135- // to compose computation definitions, so for now we are leaving this
136- // inefficiency in place until more of that functionality lands. The current
137- // thinking is that exec will cache composed prim definitions that will be
138- // keyed off of a tuple of typed and applied schemas and will cache the
139- // resulting set of composed computation definitions. That will enable this
140- // code to construct a key and do a single lookup into the cache, rather
141- // than searching to find the computation definition.
142-
143- TfType foundType;
144- std::vector<TfType> schemaAncestorTypes;
145- schemaType.GetAllAncestorTypes (&schemaAncestorTypes);
146-
147- for (const TfType type : schemaAncestorTypes) {
148- if (const auto pluginIt = _pluginPrimComputationDefinitions.find (
149- {type, computationName});
150- pluginIt != _pluginPrimComputationDefinitions.end ()) {
151- return &pluginIt->second ;
152- }
127+ // Get the composed prim definition, creating it if necesseary, and use it
128+ // to look up the computation, or to determine that the requested
129+ // computation isn't defined for this prim.
130+ auto composedDefIt = _composedPrimDefinitions.find (schemaType);
131+ if (composedDefIt == _composedPrimDefinitions.end ()) {
132+ // Note that we allow concurrent callers to race to compose prim
133+ // definitions, since it is safe to do so and we don't expect it to
134+ // happen in the common case.
135+ _ComposedPrimDefinition primDef =
136+ _ComposePrimDefinition (schemaType);
137+
138+ composedDefIt = _composedPrimDefinitions.emplace (
139+ schemaType, std::move (primDef)).first ;
153140 }
154141
155- return nullptr ;
142+ const auto &compDefs = composedDefIt->second .primComputationDefinitions ;
143+ const auto it = compDefs.find (computationName);
144+ return it == compDefs.end () ? nullptr : it->second ;
156145}
157146
158147const Exec_ComputationDefinition *
@@ -176,6 +165,52 @@ Exec_DefinitionRegistry::GetComputationDefinition(
176165 return nullptr ;
177166}
178167
168+ Exec_DefinitionRegistry::_ComposedPrimDefinition
169+ Exec_DefinitionRegistry::_ComposePrimDefinition (
170+ const TfType schemaType) const
171+ {
172+ TRACE_FUNCTION ();
173+
174+ // Iterate over all ancestor types of the provider's schema type, from
175+ // derived to base, starting with the schema type itself. Ensure that plugin
176+ // computations have been loaded for each schema type for which they are
177+ // registered. Add all plugin computations registered for each type to the
178+ // composed prim definition.
179+ //
180+ // TODO: Add support for computations that are registered for applied
181+ // schemas. To do that, instead of keying off the schema type we will need
182+ // to use a "configuration key" that combines the typed schema with applied
183+ // schemas. We will also need to search through all applied schemas, in
184+ // strength order, in addition to searching up the typed schema type
185+ // hierarchy.
186+
187+ std::vector<TfType> schemaAncestorTypes;
188+ schemaType.GetAllAncestorTypes (&schemaAncestorTypes);
189+
190+ // Build up the composed prim definition.
191+ _ComposedPrimDefinition primDef;
192+
193+ for (const TfType type : schemaAncestorTypes) {
194+
195+ // TODO: For all but the first type, it makes sense to look in
196+ // _composedPrimDefinitions to see if we have already composed the base
197+ // type, and then to merge, rather than keep searching up the type
198+ // hierarchy.
199+
200+ if (const auto pluginIt = _pluginPrimComputationDefinitions.find (type);
201+ pluginIt != _pluginPrimComputationDefinitions.end ()) {
202+ for (const Exec_PluginComputationDefinition &computationDef :
203+ pluginIt->second ) {
204+ primDef.primComputationDefinitions .emplace (
205+ computationDef.GetComputationName (),
206+ &computationDef);
207+ }
208+ }
209+ }
210+
211+ return primDef;
212+ }
213+
179214void
180215Exec_DefinitionRegistry::_RegisterPrimComputation (
181216 TfType schemaType,
@@ -203,14 +238,11 @@ Exec_DefinitionRegistry::_RegisterPrimComputation(
203238 }
204239
205240 const bool emplaced =
206- _pluginPrimComputationDefinitions.emplace (
207- std::piecewise_construct,
208- std::forward_as_tuple (schemaType, computationName),
209- std::forward_as_tuple (
210- resultType,
211- computationName,
212- std::move (callback),
213- std::move (inputKeys))).second ;
241+ _pluginPrimComputationDefinitions[schemaType].emplace (
242+ resultType,
243+ computationName,
244+ std::move (callback),
245+ std::move (inputKeys)).second ;
214246
215247 if (!emplaced) {
216248 TF_CODING_ERROR (
0 commit comments