1010from typing import IO , TYPE_CHECKING , Any , Callable , Dict , List , Optional , Type , Union
1111
1212import click
13- from jinja2 import Environment , FileSystemLoader , StrictUndefined
13+ from jinja2 import Environment , FileSystemLoader , StrictUndefined , Template
1414from paramiko import SSHException
1515from typeguard import TypeCheckError , check_type
1616
2626BLOCKSIZE = 65536
2727
2828# Caches
29- TEMPLATES : Dict [Any , Any ] = {}
29+ TEMPLATES : Dict [str , Template ] = {}
3030FILE_SHAS : Dict [Any , Any ] = {}
3131
3232PYINFRA_INSTALL_DIR = path .normpath (path .join (path .dirname (__file__ ), ".." ))
@@ -139,7 +139,9 @@ def get_operation_order_from_stack(state: "State"):
139139 return line_numbers
140140
141141
142- def get_template (filename_or_io : str | IO , jinja_env_kwargs : dict [str , Any ] | None = None ):
142+ def get_template (
143+ filename_or_io : str | IO , jinja_env_kwargs : dict [str , Any ] | None = None
144+ ) -> Template :
143145 """
144146 Gets a jinja2 ``Template`` object for the input filename or string, with caching
145147 based on the filename of the template, or the SHA1 of the input string.
@@ -155,10 +157,11 @@ def get_template(filename_or_io: str | IO, jinja_env_kwargs: dict[str, Any] | No
155157 with file_data as file_io :
156158 template_string = file_io .read ()
157159
160+ default_loader = FileSystemLoader (getcwd ())
158161 template = Environment (
159162 undefined = StrictUndefined ,
160163 keep_trailing_newline = True ,
161- loader = FileSystemLoader ( getcwd () ),
164+ loader = jinja_env_kwargs . pop ( "loader" , default_loader ),
162165 ** jinja_env_kwargs ,
163166 ).from_string (template_string )
164167
0 commit comments