@@ -11165,16 +11165,38 @@ def replace(
1116511165 raise TypeError (msg )
1116611166 new = list (old .values ())
1116711167 old = list (old .keys ())
11168+
11169+ old_is_seq = isinstance (old , Sequence ) and not isinstance (old , (str , pl .Series ))
11170+ new_is_seq = isinstance (new , Sequence ) and not isinstance (new , (str , pl .Series ))
11171+ has_expr_in_old = old_is_seq and any (isinstance (v , Expr ) for v in old ) # type: ignore[union-attr]
11172+ has_expr_in_new = new_is_seq and any (isinstance (v , Expr ) for v in new ) # type: ignore[union-attr]
11173+
11174+ if has_expr_in_old or has_expr_in_new :
11175+ old_list = list (old ) if old_is_seq else [old ] # type: ignore[arg-type, misc]
11176+ new_list = list (new ) if new_is_seq else [new ] # type: ignore[arg-type]
11177+
11178+ if len (new_list ) == 1 and len (old_list ) > 1 :
11179+ new_list = new_list * len (old_list )
11180+
11181+ if len (old_list ) != len (new_list ):
11182+ msg = f"lengths of `old` ({ len (old_list )} ) and `new` ({ len (new_list )} ) must match"
11183+ raise ValueError (msg )
11184+
11185+ # when(False).then(self) preserves the column name
11186+ result : Expr = F .when (F .lit (False )).then (self )
11187+ for old_val , new_val in zip (old_list , new_list , strict = True ):
11188+ result = result .when (self == old_val ).then (new_val ) # type: ignore[attr-defined]
11189+ result = result .otherwise (self ) # type: ignore[attr-defined]
1116811190 else :
11169- if isinstance ( old , Sequence ) and not isinstance ( old , ( str , pl . Series )) :
11191+ if old_is_seq :
1117011192 old = pl .Series (old )
11171- if isinstance ( new , Sequence ) and not isinstance ( new , ( str , pl . Series )) :
11193+ if new_is_seq :
1117211194 new = pl .Series (new )
1117311195
11174- old_pyexpr = parse_into_expression (old , str_as_lit = True ) # type: ignore[arg-type]
11175- new_pyexpr = parse_into_expression (new , str_as_lit = True )
11196+ old_pyexpr = parse_into_expression (old , str_as_lit = True ) # type: ignore[arg-type]
11197+ new_pyexpr = parse_into_expression (new , str_as_lit = True ) # type: ignore[arg-type]
1117611198
11177- result = wrap_expr (self ._pyexpr .replace (old_pyexpr , new_pyexpr ))
11199+ result = wrap_expr (self ._pyexpr .replace (old_pyexpr , new_pyexpr ))
1117811200
1117911201 if return_dtype is not None :
1118011202 result = result .cast (return_dtype )
@@ -11362,6 +11384,37 @@ def replace_strict(
1136211384 new = list (old .values ())
1136311385 old = list (old .keys ())
1136411386
11387+ old_is_seq = isinstance (old , Sequence ) and not isinstance (old , (str , pl .Series ))
11388+ new_is_seq = isinstance (new , Sequence ) and not isinstance (new , (str , pl .Series ))
11389+ has_expr_in_old = old_is_seq and any (isinstance (v , Expr ) for v in old ) # type: ignore[union-attr]
11390+ has_expr_in_new = new_is_seq and any (isinstance (v , Expr ) for v in new ) # type: ignore[union-attr]
11391+
11392+ if has_expr_in_old or has_expr_in_new :
11393+ old_list = list (old ) if old_is_seq else [old ] # type: ignore[arg-type, misc]
11394+ new_list = list (new ) if new_is_seq else [new ] # type: ignore[arg-type]
11395+
11396+ if len (new_list ) == 1 and len (old_list ) > 1 :
11397+ new_list = new_list * len (old_list )
11398+
11399+ if len (old_list ) != len (new_list ):
11400+ msg = f"lengths of `old` ({ len (old_list )} ) and `new` ({ len (new_list )} ) must match"
11401+ raise ValueError (msg )
11402+
11403+ # when(False).then(self) preserves the column name
11404+ result : Expr = F .when (F .lit (False )).then (self )
11405+ for old_val , new_val in zip (old_list , new_list , strict = True ):
11406+ result = result .when (self == old_val ).then (new_val ) # type: ignore[attr-defined]
11407+
11408+ if default is no_default :
11409+ result = result .otherwise (None ) # type: ignore[attr-defined]
11410+ else :
11411+ result = result .otherwise (default ) # type: ignore[attr-defined]
11412+
11413+ if return_dtype is not None :
11414+ result = result .cast (return_dtype )
11415+
11416+ return result
11417+
1136511418 old_pyexpr = parse_into_expression (old , str_as_lit = True ) # type: ignore[arg-type]
1136611419 new_pyexpr = parse_into_expression (new , str_as_lit = True ) # type: ignore[arg-type]
1136711420
0 commit comments