diff --git a/yaql/__init__.py b/yaql/__init__.py index 5284ebf..eddd54f 100644 --- a/yaql/__init__.py +++ b/yaql/__init__.py @@ -86,7 +86,8 @@ def create_context(data=utils.NO_VALUE, context=None, system=True, math=True, collections=True, queries=True, regex=True, branching=True, no_sets=False, finalizer=None, delegates=False, - convention=None, datetime=True, yaqlized=True): + convention=None, datetime=True, yaqlized=True, + group_by_agg_fallback=True): context = _setup_context(data, context, finalizer, convention) if system: @@ -104,7 +105,7 @@ def create_context(data=utils.NO_VALUE, context=None, system=True, if collections: std_collections.register(context, no_sets) if queries: - std_queries.register(context) + std_queries.register(context, group_by_agg_fallback) if regex: std_regex.register(context) if branching: diff --git a/yaql/standard_library/queries.py b/yaql/standard_library/queries.py index 49bd10d..4df2966 100644 --- a/yaql/standard_library/queries.py +++ b/yaql/standard_library/queries.py @@ -15,10 +15,17 @@ Queries module. """ +# Get python standard library collections module instead of +# yaql.standard_library.collections +from __future__ import absolute_import + +import collections import itertools +import sys import six +from yaql.language import exceptions from yaql.language import specs from yaql.language import utils from yaql.language import yaqltypes @@ -844,53 +851,111 @@ def then_by_descending(collection, selector, context): return collection -@specs.parameter('collection', yaqltypes.Iterable()) -@specs.parameter('key_selector', yaqltypes.Lambda()) -@specs.parameter('value_selector', yaqltypes.Lambda()) -@specs.parameter('aggregator', yaqltypes.Lambda()) -@specs.method -def group_by(engine, collection, key_selector, value_selector=None, - aggregator=None): - """:yaql:groupBy +class GroupAggregator(object): + """A function to aggregate the members of a group found by group_by(). - Returns a collection grouped by keySelector with applied valueSelector as - values. Returns a list of pairs where the first value is a result value - of keySelector and the second is a list of values which have common - keySelector return value. + The user-specified function is provided at creation. It is assumed to + accept the group value list as an argument and return an aggregated value. - :signature: collection.groupBy(keySelector, valueSelector => null, - aggregator => null) - :receiverArg collection: input collection - :argType collection: iterable - :arg keySelector: function to be applied to every collection element. - Values are grouped by return value of this function - :argType keySelector: lambda - :arg valueSelector: function to be applied to every collection element to - put it under appropriate group. null by default, which means return - element itself - :argType valueSelector: lambda - :arg aggregator: function to aggregate value within each group. null by - default, which means no function to be evaluated on groups - :argType aggregator: lambda - :returnType: list - - .. code:: - - yaql> [["a", 1], ["b", 2], ["c", 1], ["d", 2]].groupBy($[1], $[0]) - [[1, ["a", "c"]], [2, ["b", "d"]]] - yaql> [["a", 1], ["b", 2], ["c", 1]].groupBy($[1], $[0], $.sum()) - [[1, "ac"], [2, "b"]] + However, on error we will (optionally) fall back to the pre-1.1.1 behaviour + of assuming that the function expects a tuple containing both the key and + the value list, and similarly returns a tuple of the key and value. This + can still give the wrong results if the first group(s) to be aggregated + have value lists of length exactly 2, but for the most part is backwards + compatible to 1.1.1. """ - groups = {} - if aggregator is None: - new_aggregator = lambda x: x - else: - new_aggregator = lambda x: (x[0], aggregator(x[1])) - for t in collection: - value = t if value_selector is None else value_selector(t) - groups.setdefault(key_selector(t), []).append(value) - utils.limit_memory_usage(engine, (1, groups)) - return select(six.iteritems(groups), new_aggregator) + + def __init__(self, aggregator_func=None, allow_fallback=True): + self.aggregator = aggregator_func + self.allow_fallback = allow_fallback + self._failure_info = None + + def __call__(self, group_item): + if self.aggregator is None: + return group_item + + if self._failure_info is None: + key, value_list = group_item + try: + result = self.aggregator(value_list) + except (exceptions.NoMatchingMethodException, + exceptions.NoMatchingFunctionException, + IndexError): + self._failure_info = sys.exc_info() + else: + if not (len(value_list) == 2 and + isinstance(result, collections.Sequence) and + not isinstance(result, six.string_types) and + len(result) == 2 and + result[0] == value_list[0]): + # We are not dealing with (correct) version 1.1.1 syntax, + # so don't bother trying to fall back if there's an error + # with a later group. + self.allow_fallback = False + + return key, result + + if self.allow_fallback: + # Fall back to assuming version 1.1.1 syntax. + try: + result = self.aggregator(group_item) + if len(result) == 2: + return result + except Exception: + pass + + # If we are unable to successfully fall back, re-raise the first + # exception encountered to help the user debug in the new style. + six.reraise(*self._failure_info) + + +def group_by_function(allow_aggregator_fallback): + @specs.parameter('collection', yaqltypes.Iterable()) + @specs.parameter('key_selector', yaqltypes.Lambda()) + @specs.parameter('value_selector', yaqltypes.Lambda()) + @specs.parameter('aggregator', yaqltypes.Lambda()) + @specs.method + def group_by(engine, collection, key_selector, value_selector=None, + aggregator=None): + """:yaql:groupBy + + Returns a collection grouped by keySelector with applied valueSelector + as values. Returns a list of pairs where the first value is a result + value of keySelector and the second is a list of values which have + common keySelector return value. + + :signature: collection.groupBy(keySelector, valueSelector => null, + aggregator => null) + :receiverArg collection: input collection + :argType collection: iterable + :arg keySelector: function to be applied to every collection element. + Values are grouped by return value of this function + :argType keySelector: lambda + :arg valueSelector: function to be applied to every collection element + to put it under appropriate group. null by default, which means + return element itself + :argType valueSelector: lambda + :arg aggregator: function to aggregate value within each group. null by + default, which means no function to be evaluated on groups + :argType aggregator: lambda + :returnType: list + + .. code:: + + yaql> [["a", 1], ["b", 2], ["c", 1], ["d", 2]].groupBy($[1], $[0]) + [[1, ["a", "c"]], [2, ["b", "d"]]] + yaql> [["a", 1], ["b", 2], ["c", 1]].groupBy($[1], $[0], $.sum()) + [[1, "ac"], [2, "b"]] + """ + groups = {} + new_aggregator = GroupAggregator(aggregator, allow_aggregator_fallback) + for t in collection: + value = t if value_selector is None else value_selector(t) + groups.setdefault(key_selector(t), []).append(value) + utils.limit_memory_usage(engine, (1, groups)) + return select(six.iteritems(groups), new_aggregator) + + return group_by @specs.method @@ -1680,7 +1745,7 @@ def default_if_empty(engine, collection, default): return default -def register(context): +def register(context, allow_group_by_agg_fallback=True): context.register_function(where) context.register_function(where, name='filter') context.register_function(select) @@ -1711,7 +1776,7 @@ def register(context): context.register_function(order_by_descending) context.register_function(then_by) context.register_function(then_by_descending) - context.register_function(group_by) + context.register_function(group_by_function(allow_group_by_agg_fallback)) context.register_function(join) context.register_function(zip_) context.register_function(zip_longest) diff --git a/yaql/tests/test_queries.py b/yaql/tests/test_queries.py index da65ec3..bc2cbab 100644 --- a/yaql/tests/test_queries.py +++ b/yaql/tests/test_queries.py @@ -226,6 +226,30 @@ class TestQueries(yaql.tests.TestCase): 'groupBy($[1], aggregator => $.sum())', data=data)) + def test_group_by_old_syntax(self): + # Test the syntax used in 1.1.1 and earlier, where the aggregator + # function was passed the key as well as the value list, and returned + # the key along with the aggregated value. This ensures backward + # compatibility with existing expressions. + data = {'a': 1, 'b': 2, 'c': 1, 'd': 3, 'e': 2} + + self.assertItemsEqual( + [[1, 'ac'], [2, 'be'], [3, 'd']], + self.eval('$.items().orderBy($[0]).' + 'groupBy($[1], $[0], [$[0], $[1].sum()])', data=data)) + + self.assertItemsEqual( + [[1, ['a', 1, 'c', 1]], [2, ['b', 2, 'e', 2]], [3, ['d', 3]]], + self.eval('$.items().orderBy($[0]).' + 'groupBy($[1],, [$[0], $[1].sum()])', + data=data)) + + self.assertItemsEqual( + [[1, ['a', 1, 'c', 1]], [2, ['b', 2, 'e', 2]], [3, ['d', 3]]], + self.eval('$.items().orderBy($[0]).' + 'groupBy($[1], aggregator => [$[0], $[1].sum()])', + data=data)) + def test_join(self): self.assertEqual( [[2, 1], [3, 1], [3, 2], [4, 1], [4, 2], [4, 3]],