Incorrect SAT witness in disjunction query #873

@Zinoex

The following minimal working example (MWE) uses a small fully connected network with ReLU activations (I can send the .onnx file). Marabou returns 'sat' with output values (vals[outputVars[0]], vals[outputVars[1]]) = (-0.9808001445708098, -0.3497275237786019). The query is a disjunction: a witness must place the output outside the [lb, ub] bounds along at least one dimension. However, the returned assignment lies inside the bounds in both dimensions (-0.9808 is in [-1.02, -0.96] and -0.3497 is in [-0.3695, -0.3095]), so it is not a valid witness for 'sat', and the assertion on the last line fails.
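To make the failure concrete, the snippet below checks the reported output values against the bounds used in the MWE (all numbers are copied from this report). For each dimension it computes the distance to the nearer bound, which is positive exactly when the value lies strictly inside [lb, ub]; the disjunction can only hold if some margin is zero or negative.

```python
# Reported output values from the 'sat' result and the bounds used in the MWE.
outputs = [-0.9808001445708098, -0.3497275237786019]
lb = [-1.0199999809265137, -0.3695250451564789]
ub = [-0.9600000381469727, -0.3095250427722931]

# Distance to the nearer bound: positive means strictly inside [lb, ub].
# v >= ub  <=>  hi - v <= 0,  and  v <= lb  <=>  v - lo <= 0.
margins = [min(v - lo, hi - v) for v, lo, hi in zip(outputs, lb, ub)]

# The disjunction (v >= ub or v <= lb in some dimension) holds only if some margin <= 0.
disjunction_holds = any(m <= 0 for m in margins)

for i, m in enumerate(margins):
    print(f"output {i}: margin to nearest bound = {m:.4f}")
print("valid witness:", disjunction_holds)
```

Both margins come out at roughly 0.0208 and 0.0198, far larger than typical floating-point tolerances, so boundary rounding does not explain the result.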

import numpy as np
from maraboupy import Marabou, MarabouCore, MarabouUtils

onnx_path = 'simple_nn.onnx'
network = Marabou.read_onnx(onnx_path)

outputVars = network.outputVars[0].flatten()
inputVars = network.inputVars[0].flatten()
options = Marabou.createOptions(verbosity=1)

sample = np.array([-0.45, -0.99])
delta = 0.01
ub = np.array([-0.9600000381469727, -0.3095250427722931])
lb = np.array([-1.0199999809265137, -0.3695250451564789])

# Constrain each input to the box [sample - delta, sample + delta]
for i, inputVar in enumerate(inputVars):
    network.setLowerBound(inputVar, sample[i] - delta)
    network.setUpperBound(inputVar, sample[i] + delta)

# Build a single disjunction over all output dimensions:
# sat iff some output lies outside [lb, ub]
disjuncts = []

for i, outputVar in enumerate(outputVars):
    # nn_output >= ub
    equation_GE = MarabouUtils.Equation(MarabouCore.Equation.GE)
    equation_GE.addAddend(1, outputVar)
    equation_GE.setScalar(ub[i])

    # nn_output <= lb
    equation_LE = MarabouUtils.Equation(MarabouCore.Equation.LE)
    equation_LE.addAddend(1, outputVar)
    equation_LE.setScalar(lb[i])

    # For this dimension, either GE or LE must be true
    disjuncts.extend([[equation_GE], [equation_LE]])

network.addDisjunctionConstraint(disjuncts)

# Solve
res, vals, _ = network.solve(verbose=True, options=options)
if res == 'sat':
    for i, inputVar in enumerate(inputVars):
        assert vals[inputVar] >= sample[i] - delta
        assert vals[inputVar] <= sample[i] + delta
    # A valid witness must put at least one output outside [lb, ub]
    example_found = False
    for i, outputVar in enumerate(outputVars):
        if vals[outputVar] >= ub[i] or vals[outputVar] <= lb[i]:
            example_found = True
            break
    assert example_found
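As a cross-check that does not depend on Marabou, the sketch below mirrors the disjunction built in the loop above in plain Python: each disjunct is a single bound on one output, and the whole constraint is satisfied when at least one disjunct holds. This assumes `addDisjunctionConstraint` treats each inner list as a conjunction and the outer list as a disjunction. The helpers `build_disjuncts` and `satisfied_disjuncts` are hypothetical, and `tol` is an illustrative parameter, not a Marabou setting.

```python
from typing import List, Tuple

# A disjunct is (output index, relation, bound), mirroring one Equation in the MWE.
Disjunct = Tuple[int, str, float]


def build_disjuncts(lb: List[float], ub: List[float]) -> List[Disjunct]:
    """Hypothetical helper: one GE and one LE disjunct per output dimension."""
    disjuncts: List[Disjunct] = []
    for i, (lo, hi) in enumerate(zip(lb, ub)):
        disjuncts.append((i, ">=", hi))  # output_i >= ub_i
        disjuncts.append((i, "<=", lo))  # output_i <= lb_i
    return disjuncts


def satisfied_disjuncts(outputs: List[float], disjuncts: List[Disjunct],
                        tol: float = 0.0) -> List[Disjunct]:
    """Return the disjuncts that hold for the given outputs, relaxed by tol."""
    hits: List[Disjunct] = []
    for i, rel, bound in disjuncts:
        v = outputs[i]
        if (rel == ">=" and v >= bound - tol) or (rel == "<=" and v <= bound + tol):
            hits.append((i, rel, bound))
    return hits


lb = [-1.0199999809265137, -0.3695250451564789]
ub = [-0.9600000381469727, -0.3095250427722931]
reported = [-0.9808001445708098, -0.3497275237786019]

disjuncts = build_disjuncts(lb, ub)
# Even with a loose tolerance of 1e-3, the reported outputs satisfy no disjunct.
print(satisfied_disjuncts(reported, disjuncts, tol=1e-3))
```

A genuine 'sat' witness should yield a non-empty list; the empty result for the reported outputs matches the failing assertion in the MWE.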

Metadata

Assignees: No one assigned
Labels: No labels
Type: No type
Projects: No projects
Milestone: No milestone
Relationships: None yet
Development: No branches or pull requests