Skip to content
31 changes: 19 additions & 12 deletions src/abstract-interpretation/data-frame/dataframe-domain.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,17 @@ import { SetRangeDomain } from '../domains/set-range-domain';

/**
 * The type of the abstract product representing the shape of data frames.
 *
 * All components are read-only: a changed shape is built as a new object
 * (e.g. via spread) rather than by assigning to the properties in place.
 */
export type AbstractDataFrameShape = {
    /** Abstract value of the column names domain */
    readonly colnames: SetRangeDomain<string>;
    /** Abstract value of the column count domain */
    readonly cols: PosIntervalDomain;
    /** Abstract value of the row count domain */
    readonly rows: PosIntervalDomain;
};

/**
* The data frame abstract domain as product domain of a column names domain, column count domain, and row count domain.
*/
export class DataFrameDomain extends ProductDomain<AbstractDataFrameShape> {
/**
 * Creates a new data frame domain element wrapping the given abstract shape.
 *
 * The method is declared with the polymorphic `this` return type, so the
 * freshly constructed `DataFrameDomain` has to be asserted to `this`.
 *
 * @param value - The abstract data frame shape of the new element
 * @returns A new `DataFrameDomain` for `value`
 */
public create(value: AbstractDataFrameShape): this {
    return new DataFrameDomain(value) as this;
}

/**
Expand Down Expand Up @@ -57,21 +56,29 @@ export class DataFrameDomain extends ProductDomain<AbstractDataFrameShape> {

protected reduce(value: AbstractDataFrameShape): AbstractDataFrameShape {
if(value.colnames.isValue() && value.cols.isValue()) {
if(value.colnames.value.min.size >= value.cols.value[1]) {
const minColNames = value.colnames.must.size;
const maxColNames = value.colnames.isFinite() ? value.colnames.must.size + value.colnames.may.size : Infinity;

if(minColNames >= value.cols.upper) {
value = {
...value,
colnames: value.colnames.create({ must: value.colnames.must, may: new Set() })
};
} else if(value.colnames.isFinite() && value.colnames.may.size > 0 && maxColNames <= value.cols.lower) {
value = {
...value,
colnames: value.colnames.meet({ min: new Set(), range: value.colnames.value.min })
colnames: value.colnames.create({ must: value.colnames.upper(), may: new Set() })
};
}
}
if(value.colnames.isValue() && value.cols.isValue()) {
const minColNames = value.colnames.value.min.size;
const maxColNames = value.colnames.isFinite() ? value.colnames.value.min.size + value.colnames.value.range.size : Infinity;
const minColNames = value.colnames.must.size;
const maxColNames = value.colnames.isFinite() ? value.colnames.must.size + value.colnames.may.size : Infinity;

if(minColNames > value.cols.value[0] || maxColNames < value.cols.value[1]) {
if((minColNames > value.cols.lower || maxColNames < value.cols.upper) && Math.max(minColNames, value.cols.lower) <= Math.min(maxColNames, value.cols.upper)) {
value = {
...value,
cols: value.cols.meet([minColNames, maxColNames])
cols: value.cols.create([Math.max(minColNames, value.cols.lower), Math.min(maxColNames, value.cols.upper)])
};
}
}
Expand Down
30 changes: 14 additions & 16 deletions src/abstract-interpretation/data-frame/semantics.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import { assertUnreachable, isNotUndefined } from '../../util/assert';
import { Bottom, Top } from '../domains/lattice';
import { PosIntervalDomain, PosIntervalTop } from '../domains/positive-interval-domain';
import type { ArrayRangeValue } from '../domains/set-range-domain';
import type { SetRangeValue } from '../domains/set-range-domain';
import type { DataFrameDomain } from './dataframe-domain';

/**
Expand Down Expand Up @@ -212,7 +212,7 @@ function applySetColNamesSemantics(
rows: value.rows
});
}
const allColNames = colnames?.every(isNotUndefined) && value.cols.value !== Bottom && colnames.length >= value.cols.value[1];
const allColNames = colnames?.every(isNotUndefined) && value.cols.isValue() && colnames.length >= value.cols.upper;

return value.create({
colnames: allColNames ? value.colnames.create(setRange(colnames)) : value.colnames.create(setRange(colnames)).widenUp(),
Expand All @@ -236,7 +236,7 @@ function applyAddRowsSemantics(
value: DataFrameDomain,
{ rows }: { rows: number | undefined }
): DataFrameDomain {
if(value.cols.value !== Bottom && value.cols.value[0] === 0) {
if(value.cols.isValue() && value.cols.lower === 0) {
return value.create({
colnames: value.colnames.top(),
cols: rows !== undefined ? value.cols.add([1, 1]) : value.cols.top(),
Expand Down Expand Up @@ -303,7 +303,7 @@ function applyConcatRowsSemantics(
value: DataFrameDomain,
{ other }: { other: DataFrameDomain }
): DataFrameDomain {
if(value.cols.value !== Bottom && value.cols.value[0] === 0) {
if(value.cols.value !== Bottom && value.cols.lower === 0) {
return value.create({
colnames: value.colnames.join(other.colnames),
cols: value.cols.join(other.cols),
Expand Down Expand Up @@ -417,19 +417,17 @@ function applyJoinSemantics(
): DataFrameDomain {
// Merge two intervals by taking the maximum of the lower bounds and adding the upper bounds.
// If either interval is not a value (i.e. it is bottom), the result is bottom.
const mergeInterval = (interval1: PosIntervalDomain, interval2: PosIntervalDomain): PosIntervalDomain => {
    if(interval1.isValue() && interval2.isValue()) {
        return new PosIntervalDomain([Math.max(interval1.lower, interval2.lower), interval1.upper + interval2.upper]);
    }
    return interval1.bottom();
};
// Create the Cartesian product of two intervals by keeping the lower bound of `lower` and multiplying the upper bounds.
// If any of the three intervals is not a value (i.e. it is bottom), the result is bottom.
const productInterval = (lower: PosIntervalDomain, interval1: PosIntervalDomain, interval2: PosIntervalDomain): PosIntervalDomain => {
    if(lower.isValue() && interval1.isValue() && interval2.isValue()) {
        // An upper bound of 0 forces an empty product; multiplying directly would yield NaN for 0 * Infinity
        const upper = interval1.upper === 0 || interval2.upper === 0 ? 0 : interval1.upper * interval2.upper;

        return new PosIntervalDomain([lower.lower, upper]);
    }
    return lower.bottom();
};
let duplicateCols: string[] | undefined; // columns that may be renamed due to occurring in both data frames
let productRows: boolean; // whether the resulting rows may be a Cartesian product of the rows of the data frames
Expand Down Expand Up @@ -461,10 +459,10 @@ function applyJoinSemantics(
rows = value.rows.max(other.rows).widenDown();
break;
case 'left':
rows = value.rows.max(other.rows.isValue() ? [0, other.rows.value[1]] : Bottom);
rows = value.rows.max(other.rows.isValue() ? [0, other.rows.upper] : Bottom);
break;
case 'right':
rows = other.rows.max(value.rows.isValue() ? [0, value.rows.value[1]] : Bottom);
rows = other.rows.max(value.rows.isValue() ? [0, value.rows.upper] : Bottom);
break;
case 'full':
rows = mergeInterval(value.rows, other.rows);
Expand Down Expand Up @@ -495,8 +493,8 @@ function applyUnknownSemantics(
return value.top();
}

/**
 * Converts a list of possibly unknown column names into a set range value.
 *
 * @param colnames - The column names, where `undefined` entries (or an `undefined` list) denote unknown names
 * @returns A value whose `must` set holds all known names, and whose `may` part is the empty set
 * if every name is known and `Top` otherwise
 */
function setRange(colnames: (string | undefined)[] | undefined): SetRangeValue<string> {
    const names = colnames?.filter(isNotUndefined) ?? [];
    // Only a fully known list rules out additional names; a missing list or any unknown entry widens `may` to Top
    const allKnown = colnames !== undefined && names.length === colnames.length;

    return { must: new Set(names), may: allKnown ? new Set() : Top };
}
Loading
Loading