Skip to content

change class decorator to function decorator #36

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: dev
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions rltk/dataset.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from rltk.io.reader import Reader
from rltk.io.adapter import KeyValueAdapter, MemoryKeyValueAdapter
from rltk.record import Record, generate_record_property_cache, get_property_names
from rltk.record import Record, generate_record_property_cache, get_property_names, get_cached_property_names
from rltk.parallel_processor import ParallelProcessor

import pandas as pd
Expand Down Expand Up @@ -72,9 +72,10 @@ def add_records(self, reader: Reader, size: int = None,
"""

def generate(_raw_object):
cached_property_names = get_cached_property_names(self._record_class)
if not self._sampling_function or self._sampling_function(_raw_object):
record_instance = self._record_class(_raw_object)
generate_record_property_cache(record_instance)
generate_record_property_cache(record_instance, cached_property_names)
self._adapter.set(record_instance.id, record_instance)

if not self._record_class:
Expand Down
67 changes: 41 additions & 26 deletions rltk/record.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,32 +36,46 @@ def __eq__(self, other):
return self.id == other.id


class cached_property(property):
def get_decorators(cls):
import ast
import inspect

target = cls
decorators = {}

def visit_FunctionDef(node):
for n in node.decorator_list:
if isinstance(n, ast.Call):
name = n.func.attr if isinstance(n.func, ast.Attribute) else n.func.id
else:
name = n.attr if isinstance(n, ast.Attribute) else n.id

if name not in decorators:
decorators[name] = []
decorators[name].append(node.name)

node_iter = ast.NodeVisitor()
node_iter.visit_FunctionDef = visit_FunctionDef
node_iter.visit(ast.parse(inspect.getsource(target)))

return decorators


def get_cached_property_names(cls):
return get_decorators(cls).get('cached_property')


def cached_property(method):
"""
Decorator.
If a Record property is decorated, the final value of it will be pre-calculated.
"""
def __init__(self, func):
self.func = func

def __get__(self, obj, cls):
"""
Args:
obj (object): Record instance
cls (class): Record class
Returns:
object: cached value
"""
if obj is None:
return self

# create property if it's not there
cached_name = self.func.__name__
if cached_name not in obj.__dict__:
obj.__dict__[cached_name] = self.func(obj)

value = obj.__dict__.get(cached_name)
return value
def wrapper(record_instance):
cached_name = method.__name__
if cached_name not in record_instance.__dict__:
record_instance.__dict__[cached_name] = method(record_instance)
return record_instance.__dict__[cached_name]
return property(wrapper)


def remove_raw_object(cls):
Expand All @@ -73,15 +87,16 @@ def remove_raw_object(cls):
return cls


def generate_record_property_cache(obj):
def generate_record_property_cache(obj, cached_property_names=None):
"""
Generate final value on all cached_property decorated methods.

Args:
obj (Record): Record instance.
cached_property_names (list): Properties that need to cache.
"""
for prop_name, prop_type in obj.__class__.__dict__.items():
if isinstance(prop_type, cached_property):
if cached_property_names:
for prop_name in cached_property_names:
getattr(obj, prop_name)

validate_record(obj)
Expand Down Expand Up @@ -118,7 +133,7 @@ def get_property_names(cls: type):
"""
keys = []
for prop_name, prop_type in cls.__dict__.items():
if not isinstance(prop_type, property) and not isinstance(prop_type, cached_property):
if not isinstance(prop_type, property):
continue
keys.append(prop_name)
return keys
Expand Down