Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -123,35 +123,23 @@ public Object agg(Object accumulator, Object inputField) {
InternalArray acc = (InternalArray) accumulator;
InternalArray input = (InternalArray) inputField;

if (acc.size() >= countLimit) {
return accumulator;
}

int remainCount = countLimit - acc.size();

List<InternalRow> rows = new ArrayList<>(acc.size() + input.size());
addNonNullRows(acc, rows);
addNonNullRows(input, rows, remainCount);

if (keyProjection != null) {
Map<BinaryRow, InternalRow> map = new HashMap<>();
for (InternalRow row : rows) {
BinaryRow key = keyProjection.apply(row).copy();
if (hasSequenceField) {
// When sequence field is configured, only update if the new sequence is greater
InternalRow existing = map.get(key);
if (existing == null || compareSequence(row, existing) >= 0) {
map.put(key, row);
}
} else {
map.put(key, row);
}
if (keyProjection == null) {
if (acc.size() >= countLimit) {
return accumulator;
}

rows = new ArrayList<>(map.values());
int remainCount = countLimit - acc.size();

List<InternalRow> rows = new ArrayList<>(acc.size() + input.size());
addNonNullRows(acc, rows);
addNonNullRows(input, rows, remainCount);
return new GenericArray(rows.toArray());
}

return new GenericArray(rows.toArray());
Map<BinaryRow, InternalRow> map = new HashMap<>();
addNestedRows(acc, map, false);
addNestedRows(input, map, true);
return new GenericArray(new ArrayList<>(map.values()).toArray());
}

@Override
Expand Down Expand Up @@ -235,4 +223,26 @@ private void addNonNullRows(InternalArray array, List<InternalRow> rows, int rem
count++;
}
}

/**
 * Merges the non-null rows of {@code array} into {@code rows}, keyed by the projection of each
 * row through {@code keyProjection}.
 *
 * <p>A row whose key is already present replaces the stored row when no sequence field is
 * configured, or when {@code compareSequence(row, stored) >= 0}. A row with a new key is inserted
 * unless {@code limitNewKeys} is set and the map already holds {@code countLimit} entries.
 *
 * @param array source array; null elements are skipped
 * @param rows key-to-row map that is updated in place
 * @param limitNewKeys whether inserting previously unseen keys is bounded by {@code countLimit}
 */
private void addNestedRows(
        InternalArray array, Map<BinaryRow, InternalRow> rows, boolean limitNewKeys) {
    checkNotNull(keyProjection);

    for (int pos = 0; pos < array.size(); pos++) {
        if (!array.isNullAt(pos)) {
            InternalRow candidate = array.getRow(pos, nestedFields);
            // Copy the projected key before using it for lookup and storage in the map.
            BinaryRow candidateKey = keyProjection.apply(candidate).copy();
            InternalRow current = rows.get(candidateKey);

            boolean accept;
            if (current == null) {
                // Unseen key: only the count limit can reject it.
                accept = !limitNewKeys || rows.size() < countLimit;
            } else {
                // Known key: updating never changes the map size, so the limit is irrelevant.
                accept = !hasSequenceField || compareSequence(candidate, current) >= 0;
            }

            if (accept) {
                rows.put(candidateKey, candidate);
            }
        }
    }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -830,6 +830,39 @@ public void testFieldNestedAppendAggWithCountLimit() {
.containsExactlyInAnyOrderElementsOf(Arrays.asList(row(0, 1, "B"), row(0, 1, "b")));
}

/**
 * Without a sequence field and with a count limit of 2: once two distinct keys are stored, a row
 * for an existing key must still replace the stored row, while a row with a third key must be
 * dropped.
 */
@Test
public void testFieldNestedUpdateAggWithCountLimitUpdatesExistingKeyAtLimitWithoutSequence() {
    DataType rowType =
            DataTypes.ROW(
                    DataTypes.FIELD(0, "k0", DataTypes.INT()),
                    DataTypes.FIELD(1, "k1", DataTypes.INT()),
                    DataTypes.FIELD(2, "v", DataTypes.STRING()));
    InternalArray.ElementGetter getter = InternalArray.createElementGetter(rowType);

    FieldNestedUpdateAgg nestedAgg =
            new FieldNestedUpdateAgg(
                    FieldNestedUpdateAggFactory.NAME,
                    DataTypes.ARRAY(rowType),
                    Arrays.asList("k0", "k1"),
                    2);

    // Both assertions expect the same content: the updated (0, 1) row plus the (1, 2) row.
    List<Object> expected = Arrays.asList(row(0, 1, "B_updated"), row(1, 2, "C"));

    // Fill the accumulator with two distinct keys, reaching the limit.
    InternalArray state = null;
    state = (InternalArray) nestedAgg.agg(state, singletonArray(row(0, 1, "B")));
    state = (InternalArray) nestedAgg.agg(state, singletonArray(row(1, 2, "C")));

    // Same key (0, 1) again: the stored row is replaced even though the limit is reached.
    state = (InternalArray) nestedAgg.agg(state, singletonArray(row(0, 1, "B_updated")));
    assertThat(unnest(state, getter)).containsExactlyInAnyOrderElementsOf(expected);

    // New key (2, 3): rejected because the limit is already reached.
    state = (InternalArray) nestedAgg.agg(state, singletonArray(row(2, 3, "D")));
    assertThat(unnest(state, getter)).containsExactlyInAnyOrderElementsOf(expected);
}

@Test
public void testFieldNestedUpdateAggWithSequenceField() {
DataType elementRowType =
Expand Down Expand Up @@ -1076,6 +1109,42 @@ public void testFieldNestedUpdateAggWithCountLimitWithSequenceField() {
Arrays.asList(row(0, 1, "B_updated", 2), row(1, 2, "C", 3)));
}

/**
 * With a sequence field and a count limit of 2: once two distinct keys are stored, a row for an
 * existing key with a higher sequence must still replace the stored row, while a row with a third
 * key must be dropped.
 */
@Test
public void testFieldNestedUpdateAggWithCountLimitUpdatesExistingKeyAtLimit() {
    DataType rowType =
            DataTypes.ROW(
                    DataTypes.FIELD(0, "k0", DataTypes.INT()),
                    DataTypes.FIELD(1, "k1", DataTypes.INT()),
                    DataTypes.FIELD(2, "v", DataTypes.STRING()),
                    DataTypes.FIELD(3, "seq", DataTypes.INT()));
    InternalArray.ElementGetter getter = InternalArray.createElementGetter(rowType);

    FieldNestedUpdateAgg nestedAgg =
            new FieldNestedUpdateAgg(
                    FieldNestedUpdateAggFactory.NAME,
                    DataTypes.ARRAY(rowType),
                    Arrays.asList("k0", "k1"),
                    Collections.singletonList("seq"),
                    2);

    // Both assertions expect the same content: the updated (0, 1) row plus the (1, 2) row.
    List<Object> expected = Arrays.asList(row(0, 1, "B_updated", 4), row(1, 2, "C", 3));

    // Fill the accumulator with two distinct keys, reaching the limit.
    InternalArray state = null;
    state = (InternalArray) nestedAgg.agg(state, singletonArray(row(0, 1, "B", 1)));
    state = (InternalArray) nestedAgg.agg(state, singletonArray(row(1, 2, "C", 3)));

    // Same key (0, 1) with sequence 4 > 1: the stored row is replaced despite the limit.
    state = (InternalArray) nestedAgg.agg(state, singletonArray(row(0, 1, "B_updated", 4)));
    assertThat(unnest(state, getter)).containsExactlyInAnyOrderElementsOf(expected);

    // New key (2, 3): rejected because the limit is already reached.
    state = (InternalArray) nestedAgg.agg(state, singletonArray(row(2, 3, "D", 5)));
    assertThat(unnest(state, getter)).containsExactlyInAnyOrderElementsOf(expected);
}

private List<Object> unnest(InternalArray array, InternalArray.ElementGetter elementGetter) {
return IntStream.range(0, array.size())
.mapToObj(i -> elementGetter.getElementOrNull(array, i))
Expand Down