Skip to content

Implement overloading #70

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 7 commits into
base: main
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
124 changes: 74 additions & 50 deletions flax/funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from random import randrange
import operator as ops

from flax.common import mp, mpc, inf, mpf
from flax.common import mp, mpc, ilist, inf, mpf

__all__ = [
"base",
Expand Down Expand Up @@ -115,23 +115,16 @@ def base_i(w, x):

def binary(x):
"""binary: converts x to binary"""
return [-i if x < 0 else i for i in map(int, bin(x)[3 if x < 0 else 2 :])]
return base_i(2, x)


def binary_i(x):
"""binary_i: convert x from binary"""
x = iterable(x, digits_=True)
sign = -1 if sum(x) < 0 else 1
num = 0
i = 0
for b in x[::-1]:
num += abs(b) * 2**i
i += 1
return num * sign
return base(2, x)


def boolify(f):
"""boolify: wrapper around boolean functions to only return 1/0"""
"""boolify: [helper] wrapper around boolean functions to only return 1/0"""
return lambda *args: int(f(*args))


Expand All @@ -155,7 +148,7 @@ def convolve(w, x):

def depth(x):
"""depth: how deeply x is nested"""
if type(x) != list:
if type2str(x) != "lst":
return 0
else:
if not x:
Expand Down Expand Up @@ -192,27 +185,39 @@ def diagonals(x, antidiagonals=False):

def digits(x):
"""digits: turn x into a list of digits"""
return [
-int(i) if x < 0 else int(i) for i in str(x)[1 if x < 0 else 0 :] if i != "."
]
if type2strn(x) == "mpc":
return [
mpc(i[0], i[1])
for i in transpose([digits(x.real)[::-1], digits(x.imag)[::-1]], filler=0)[
::-1
]
][:-1]
else:
return [
-int(i) if x < 0 else int(i)
for i in str(x)[1 if x < 0 else 0 :]
if i != "."
]


def digits_i(x):
"""digits_i: convert x from digits"""
x = iterable(x, range_=True)
sign = -1 if sum(x) < 0 else 1
num = 0
i = 0
for b in x[::-1]:
num += abs(b) * 10**i
i += 1
return num * sign
if type2strn(x[0]) == "mpc":
return mpc(digits_i([i.real for i in x]), digits_i([i.imag for i in x]))
else:
sign = -1 if sum(x) < 0 else 1
num = 0
i = 0
for b in x[::-1]:
num += abs(b) * 10**i
i += 1
return num * sign


def enumerate_md(x, upper_level=[]):
"""enumerate_md: enumerate multidimensionally"""
for i, e in enumerate(x):
if type(e) != list:
if type2str(e) != "lst":
yield [upper_level + [i], e]
else:
yield from enumerate_md(e, upper_level + [i])
Expand Down Expand Up @@ -328,10 +333,10 @@ def group_indicies(x, md=False):
def index_into(w, x):
"""index_into: index into w with x"""
w = iterable(w, digits_=True)
x = int(x) if type(x) != mpc and type(x) != list and int(x) == x else x
if type(x) == int:
x = int(x) if type2strn(x) == "int" else x
if type2strn(x) == "int":
return w[x % len(w)]
elif type(x) == mpc:
elif type2strn(x) == "mpc":
return index_into(index_into(w, x.real), x.imag)
else:
return [index_into(w, mp.floor(x)), index_into(w, mp.ceil(x))]
Expand All @@ -347,34 +352,38 @@ def index_into_md(w, x):

def iota(x):
"""iota: APL's ⍳ and BQN's ↕"""
if type(x) != list:
if type2strn(x) in ["int", "dec"]:
return list(range(int(x)))

res = cartesian_product(*(list(range(int(a))) for a in x))
for e in x:
res = split(int(e), res)
return res[0]
elif type2strn(x) == "mpc":
return [[mpc(j[0], j[1]) for j in i] for i in iota([x.real, x.imag])]
else:
res = cartesian_product(*(iota(i) for i in x))
for i in x[::-1]:
res = split(int(abs(i)) if type2strn(i) != "lst" else len(i), res)
return res[0]


def iota1(x):
"""iota1: iota but 1 based"""
if type(x) != list:
return [i + 1 for i in range(int(x))]

res = cartesian_product(*([i + 1 for i in range(int(a))] for a in x))
for e in x:
res = split(int(e), res)
return res[0]
if type2strn(x) in ["int", "dec"]:
return list(range(1, int(x) + 1))
elif type2strn(x) == "mpc":
return [[mpc(j[0], j[1]) for j in i] for i in iota1([x.real, x.imag])]
else:
res = cartesian_product(*(iota1(i) for i in x))
for i in x[::-1]:
res = split(int(abs(i)) if type2strn(i) != "lst" else len(i), res)
return res[0]


def iterable(x, digits_=False, range_=False, copy_=False):
"""iterable: make sure x is a list"""
if type(x) != list:
if type(x) == str:
if type2str(x) != "lst":
if type2str(x) == "str":
return list(x)
else:
if range_:
return list(range(int(x)))
return iota(x)
elif digits_:
return digits(x)
else:
Expand All @@ -392,13 +401,14 @@ def join(w, x):

def json_decode(x):
"""json_decode: convert jsoned x to flax arrays"""
if type(x) == list or type(x) == tuple:
t = type(x)
if t == list or t == tuple:
return [json_decode(i) for i in x]
elif type(x) == str:
elif t == str:
return x
elif type(x) == dict:
elif t == dict:
return [json_decode(i) for i in x.items()]
elif type(x) == bool:
elif t == bool:
return int(x)
elif x is None:
return inf
Expand Down Expand Up @@ -439,7 +449,7 @@ def maximal_indicies_md(x, m=None, upper_level=[]):
m = max(flatten(x) or [0])
res = []
for i, e in enumerate(x):
if type(e) != list:
if type2str(e) != "lst":
if e == m:
res.append(upper_level + [i])
else:
Expand All @@ -450,7 +460,7 @@ def maximal_indicies_md(x, m=None, upper_level=[]):
def mold(w, x):
"""mold: mold x to the shape w"""
for i in range(len(w)):
if type(w[i]) == list:
if type2str(w[i]) == "lst":
mold(x, w[i])
else:
item = x.pop(0)
Expand Down Expand Up @@ -696,6 +706,20 @@ def type2str(x):
return "num"


def type2strn(x):
"""type2strn: [helper] converts a number type to string"""
t = type2str(x)
if t == "num":
if type(x) == mpc:
return "mpc"
elif int(x) == x:
return "int"
else:
return "dec"
else:
return t


def transpose(x, filler=None):
"""transpose: transpose x"""
return list(
Expand All @@ -720,7 +744,7 @@ def unrepeat(x):
def where(x, upper_level=[]):
"""where: ngn/k's &:"""
x = iterable(x)
if type(x[0]) != list:
if type2str(x[0]) != "lst":
return flatten([(upper_level + [i]) * e for i, e in enumerate(x)])
else:
return [where(e, upper_level + [i]) for i, e in enumerate(x)]