Make changes to groupBy() backward-compatible

The behaviour of the groupBy() function was fixed in 1.1.2 by
3fb91784018de335440b01b3b069fe45dc53e025 to pass only the list of values
to be aggregated to the aggregator function, instead of passing both the
key and the list of values (and expecting both the key and the
aggregated value to be returned). This fix was incompatible with
existing expressions that used the aggregator argument to groupBy().

In the event of an error, fall back trying the previous syntax and see
if something plausible gets returned. If not, re-raise the original error.
This should mean that pre-existing expressions will continue to work,
while most outright bogus expressions should fail in a manner consistent
with the correct definition of the function.

Change-Id: Ic6c54be4ed99003fe56cf1a5329f3f1d84fd43c8
Closes-Bug: #1750032
This commit is contained in:
Zane Bitter 2018-02-19 17:00:22 -05:00
parent dc21a823bd
commit 74cb81b280
3 changed files with 138 additions and 48 deletions

View File

@ -86,7 +86,8 @@ def create_context(data=utils.NO_VALUE, context=None, system=True,
math=True, collections=True, queries=True, math=True, collections=True, queries=True,
regex=True, branching=True, regex=True, branching=True,
no_sets=False, finalizer=None, delegates=False, no_sets=False, finalizer=None, delegates=False,
convention=None, datetime=True, yaqlized=True): convention=None, datetime=True, yaqlized=True,
group_by_agg_fallback=True):
context = _setup_context(data, context, finalizer, convention) context = _setup_context(data, context, finalizer, convention)
if system: if system:
@ -104,7 +105,7 @@ def create_context(data=utils.NO_VALUE, context=None, system=True,
if collections: if collections:
std_collections.register(context, no_sets) std_collections.register(context, no_sets)
if queries: if queries:
std_queries.register(context) std_queries.register(context, group_by_agg_fallback)
if regex: if regex:
std_regex.register(context) std_regex.register(context)
if branching: if branching:

View File

@ -15,10 +15,17 @@
Queries module. Queries module.
""" """
# Get python standard library collections module instead of
# yaql.standard_library.collections
from __future__ import absolute_import
import collections
import itertools import itertools
import sys
import six import six
from yaql.language import exceptions
from yaql.language import specs from yaql.language import specs
from yaql.language import utils from yaql.language import utils
from yaql.language import yaqltypes from yaql.language import yaqltypes
@ -844,19 +851,78 @@ def then_by_descending(collection, selector, context):
return collection return collection
@specs.parameter('collection', yaqltypes.Iterable()) class GroupAggregator(object):
@specs.parameter('key_selector', yaqltypes.Lambda()) """A function to aggregate the members of a group found by group_by().
@specs.parameter('value_selector', yaqltypes.Lambda())
@specs.parameter('aggregator', yaqltypes.Lambda()) The user-specified function is provided at creation. It is assumed to
@specs.method accept the group value list as an argument and return an aggregated value.
def group_by(engine, collection, key_selector, value_selector=None,
However, on error we will (optionally) fall back to the pre-1.1.1 behaviour
of assuming that the function expects a tuple containing both the key and
the value list, and similarly returns a tuple of the key and value. This
can still give the wrong results if the first group(s) to be aggregated
have value lists of length exactly 2, but for the most part is backwards
compatible to 1.1.1.
"""
def __init__(self, aggregator_func=None, allow_fallback=True):
self.aggregator = aggregator_func
self.allow_fallback = allow_fallback
self._failure_info = None
def __call__(self, group_item):
if self.aggregator is None:
return group_item
if self._failure_info is None:
key, value_list = group_item
try:
result = self.aggregator(value_list)
except (exceptions.NoMatchingMethodException,
exceptions.NoMatchingFunctionException,
IndexError):
self._failure_info = sys.exc_info()
else:
if not (len(value_list) == 2 and
isinstance(result, collections.Sequence) and
not isinstance(result, six.string_types) and
len(result) == 2 and
result[0] == value_list[0]):
# We are not dealing with (correct) version 1.1.1 syntax,
# so don't bother trying to fall back if there's an error
# with a later group.
self.allow_fallback = False
return key, result
if self.allow_fallback:
# Fall back to assuming version 1.1.1 syntax.
try:
result = self.aggregator(group_item)
if len(result) == 2:
return result
except Exception:
pass
# If we are unable to successfully fall back, re-raise the first
# exception encountered to help the user debug in the new style.
six.reraise(*self._failure_info)
def group_by_function(allow_aggregator_fallback):
@specs.parameter('collection', yaqltypes.Iterable())
@specs.parameter('key_selector', yaqltypes.Lambda())
@specs.parameter('value_selector', yaqltypes.Lambda())
@specs.parameter('aggregator', yaqltypes.Lambda())
@specs.method
def group_by(engine, collection, key_selector, value_selector=None,
aggregator=None): aggregator=None):
""":yaql:groupBy """:yaql:groupBy
Returns a collection grouped by keySelector with applied valueSelector as Returns a collection grouped by keySelector with applied valueSelector
values. Returns a list of pairs where the first value is a result value as values. Returns a list of pairs where the first value is a result
of keySelector and the second is a list of values which have common value of keySelector and the second is a list of values which have
keySelector return value. common keySelector return value.
:signature: collection.groupBy(keySelector, valueSelector => null, :signature: collection.groupBy(keySelector, valueSelector => null,
aggregator => null) aggregator => null)
@ -865,9 +931,9 @@ def group_by(engine, collection, key_selector, value_selector=None,
:arg keySelector: function to be applied to every collection element. :arg keySelector: function to be applied to every collection element.
Values are grouped by return value of this function Values are grouped by return value of this function
:argType keySelector: lambda :argType keySelector: lambda
:arg valueSelector: function to be applied to every collection element to :arg valueSelector: function to be applied to every collection element
put it under appropriate group. null by default, which means return to put it under appropriate group. null by default, which means
element itself return element itself
:argType valueSelector: lambda :argType valueSelector: lambda
:arg aggregator: function to aggregate value within each group. null by :arg aggregator: function to aggregate value within each group. null by
default, which means no function to be evaluated on groups default, which means no function to be evaluated on groups
@ -882,16 +948,15 @@ def group_by(engine, collection, key_selector, value_selector=None,
[[1, "ac"], [2, "b"]] [[1, "ac"], [2, "b"]]
""" """
groups = {} groups = {}
if aggregator is None: new_aggregator = GroupAggregator(aggregator, allow_aggregator_fallback)
new_aggregator = lambda x: x
else:
new_aggregator = lambda x: (x[0], aggregator(x[1]))
for t in collection: for t in collection:
value = t if value_selector is None else value_selector(t) value = t if value_selector is None else value_selector(t)
groups.setdefault(key_selector(t), []).append(value) groups.setdefault(key_selector(t), []).append(value)
utils.limit_memory_usage(engine, (1, groups)) utils.limit_memory_usage(engine, (1, groups))
return select(six.iteritems(groups), new_aggregator) return select(six.iteritems(groups), new_aggregator)
return group_by
@specs.method @specs.method
@specs.parameter('collections', yaqltypes.Iterable()) @specs.parameter('collections', yaqltypes.Iterable())
@ -1680,7 +1745,7 @@ def default_if_empty(engine, collection, default):
return default return default
def register(context): def register(context, allow_group_by_agg_fallback=True):
context.register_function(where) context.register_function(where)
context.register_function(where, name='filter') context.register_function(where, name='filter')
context.register_function(select) context.register_function(select)
@ -1711,7 +1776,7 @@ def register(context):
context.register_function(order_by_descending) context.register_function(order_by_descending)
context.register_function(then_by) context.register_function(then_by)
context.register_function(then_by_descending) context.register_function(then_by_descending)
context.register_function(group_by) context.register_function(group_by_function(allow_group_by_agg_fallback))
context.register_function(join) context.register_function(join)
context.register_function(zip_) context.register_function(zip_)
context.register_function(zip_longest) context.register_function(zip_longest)

View File

@ -226,6 +226,30 @@ class TestQueries(yaql.tests.TestCase):
'groupBy($[1], aggregator => $.sum())', 'groupBy($[1], aggregator => $.sum())',
data=data)) data=data))
def test_group_by_old_syntax(self):
# Test the syntax used in 1.1.1 and earlier, where the aggregator
# function was passed the key as well as the value list, and returned
# the key along with the aggregated value. This ensures backward
# compatibility with existing expressions.
data = {'a': 1, 'b': 2, 'c': 1, 'd': 3, 'e': 2}
self.assertItemsEqual(
[[1, 'ac'], [2, 'be'], [3, 'd']],
self.eval('$.items().orderBy($[0]).'
'groupBy($[1], $[0], [$[0], $[1].sum()])', data=data))
self.assertItemsEqual(
[[1, ['a', 1, 'c', 1]], [2, ['b', 2, 'e', 2]], [3, ['d', 3]]],
self.eval('$.items().orderBy($[0]).'
'groupBy($[1],, [$[0], $[1].sum()])',
data=data))
self.assertItemsEqual(
[[1, ['a', 1, 'c', 1]], [2, ['b', 2, 'e', 2]], [3, ['d', 3]]],
self.eval('$.items().orderBy($[0]).'
'groupBy($[1], aggregator => [$[0], $[1].sum()])',
data=data))
def test_join(self): def test_join(self):
self.assertEqual( self.assertEqual(
[[2, 1], [3, 1], [3, 2], [4, 1], [4, 2], [4, 3]], [[2, 1], [3, 1], [3, 2], [4, 1], [4, 2], [4, 3]],