12 KiB
Z3 约束求解器
Z3 是一个由微软开发的可满足性摸理论(Satisfiability Modulo Theories,SMT)的约束求解器。所谓约束求解器就是用户使用某种特定的语言描述对象(变量)的约束条件,求解器将试图求解出能够满足所有约束条件的每个变量的值。Z3 可以用来检查满足一个或多个理论的公式的可满足性,也就是说,它可以自动化地通过内置理论对一阶逻辑多种排列进行可满足性校验。目前其支持的理论有:
- equality over free 函数和谓词符号
- 实数和整形运算(有限支持非线性运算)
- 位向量
- 阵列
- 元组/记录/枚举类型和代数(递归)数据类型
- ...
因其强大的功能,Z3 已经被用于许多领域中,在安全领域,主要见于符号执行、Fuzzing、二进制逆向、密码学等。另外 Z3 提供了多种语言的接口,这里我们使用 Python。
安装
在 Linux 环境下,执行下面的命令:
$ git clone https://github.com/Z3Prover/z3.git
$ cd z3
$ python scripts/mk_make.py --python
$ cd build
$ make
$ sudo make install
另外还可以使用 pip 来安装 Python 接口(py2和py3均可),这是二进制分析框架 angr 里内置的修改版:
$ sudo pip install z3-solver
Z3 理论基础
Op | Mnmonics | Description |
---|---|---|
0 | true | 恒真 |
1 | flase | 恒假 |
2 | = | 相等 |
3 | distinct | 不同 |
4 | ite | if-then-else |
5 | and | n元 合取 |
6 | or | n元 析取 |
7 | iff | implication |
8 | xor | 异或 |
9 | not | 否定 |
10 | implies | Bi-implications |
使用 Z3
先来看一个简单的例子:
>>> from z3 import *
>>> x = Int('x')
>>> y = Int('y')
>>> solve(x > 2, y < 10, x + 2*y == 7)
[y = 0, x = 7]
首先定义了两个常量 x 和 y,类型是 Z3 内置的整数类型 Int
,solve()
函数会创造一个 solver,然后对括号中的约束条件进行求解,注意在 Z3 默认情况下只会找到满足条件的一组解。
>>> simplify(x + y + 2*x + 3)
3 + 3*x + y
>>> simplify(x < y + x + 2)
Not(y <= -2)
>>> simplify(And(x + 1 >= 3, x**2 + x**2 + y**2 + 2 >= 5))
And(x >= 2, 2*x**2 + y**2 >= 3)
>>>
>>> simplify((x + 1)*(y + 1))
(1 + x)*(1 + y)
>>> simplify((x + 1)*(y + 1), som=True) # sum-of-monomials:单项式的和
1 + x + y + x*y
>>> t = simplify((x + y)**3, som=True)
>>> t
x*x*x + 3*x*x*y + 3*x*y*y + y*y*y
>>> simplify(t, mul_to_power=True) # mul_to_power 将乘法转换成乘方
x**3 + 2*y*x**2 + x**2*y + 3*x*y**2 + y**3
simplify()
函数用于对表达式进行化简,同时可以设置一些选项来满足不同的要求。更多选项使用 help_simplify()
获得。
同时,Z3 提供了一些函数可以解析表达式:
>>> n = x + y >= 3
>>> "num args: ", n.num_args()
('num args: ', 2)
>>> "children: ", n.children()
('children: ', [x + y, 3])
>>> "1st child:", n.arg(0)
('1st child:', x + y)
>>> "2nd child:", n.arg(1)
('2nd child:', 3)
>>> "operator: ", n.decl()
('operator: ', >=)
>>> "op name: ", n.decl().name()
('op name: ', '>=')
set_param()
函数用于对 Z3 的全局变量进行配置,如运算精度,输出格式等等:
>>> x = Real('x')
>>> y = Real('y')
>>> solve(x**2 + y**2 == 3, x**3 == 2)
[x = 1.2599210498?, y = -1.1885280594?]
>>>
>>> set_param(precision=30)
>>> solve(x**2 + y**2 == 3, x**3 == 2)
[x = 1.259921049894873164767210607278?,
y = -1.188528059421316533710369365015?]
逻辑运算有 And
、Or
、Not
、Implies
、If
,另外 ==
表示 Bi-implications。
>>> p = Bool('p')
>>> q = Bool('q')
>>> r = Bool('r')
>>> solve(Implies(p, q), r == Not(q), Or(Not(p), r))
[q = False, p = False, r = True]
>>>
>>> x = Real('x')
>>> solve(Or(x < 5, x > 10), Or(p, x**2 == 2), Not(p))
[x = -1.4142135623?, p = False]
Z3 提供了多种 Solver,即 Solver
类,其中实现了很多 SMT 2.0 的命令,如 push
, pop
, check
等等。
>>> x = Int('x')
>>> y = Int('y')
>>> s = Solver() # 创造一个通用 solver
>>> type(s) # Solver 类
<class 'z3.z3.Solver'>
>>> s
[]
>>> s.add(x > 10, y == x + 2) # 添加约束到 solver 中
>>> s
[x > 10, y == x + 2]
>>> s.check() # 检查 solver 中的约束是否满足
sat # satisfiable/满足
>>> s.push() # 创建一个回溯点,即将当前栈的大小保存下来
>>> s.add(y < 11)
>>> s
[x > 10, y == x + 2, y < 11]
>>> s.check()
unsat # unsatisfiable/不满足
>>> s.pop(num=1) # 回溯 num 个点
>>> s
[x > 10, y == x + 2]
>>> s.check()
sat
>>> for c in s.assertions(): # assertions() 返回一个包含所有约束的AstVector
... print(c)
...
x > 10
y == x + 2
>>> s.statistics() # statistics() 返回最后一个 check() 的统计信息
(:max-memory 6.26
:memory 4.37
:mk-bool-var 1
:num-allocs 331960806
:rlimit-count 7016)
>>> m = s.model() # model() 返回最后一个 check() 的 model
>>> type(m) # ModelRef 类
<class 'z3.z3.ModelRef'>
>>> m
[x = 11, y = 13]
>>> for d in m.decls(): # decls() 返回 model 包含了所有符号的列表
... print("%s = %s" % (d.name(), m[d]))
...
x = 11
y = 13
为了将 Z3 中的数和 Python 区分开,应该使用 IntVal()
、RealVal()
和 RatVal()
分别返回 Z3 整数、实数和有理数值。
>>> 1/3
0.3333333333333333
>>> RealVal(1)/3
1/3
>>> Q(1, 3) # Q(a, b) 返回有理数 a/b
1/3
>>>
>>> x = Real('x')
>>> x + 1/3
x + 3333333333333333/10000000000000000
>>> x + Q(1, 3)
x + 1/3
>>> x + "1/3"
x + 1/3
>>> x + 0.25
x + 1/4
>>> solve(3*x == 1)
[x = 1/3]
>>> set_param(rational_to_decimal=True) # 以十进制形式表示有理数
>>> solve(3*x == 1)
[x = 0.3333333333?]
在混合使用实数和整数变量时,Z3Py 会自动添加强制类型转换将整数表达式转换成实数表达式。
>>> x = Real('x')
>>> y = Int('y')
>>> a, b, c = Reals('a b c') # 返回一个实数常量元组
>>> s, r = Ints('s r') # 返回一个整数常量元组
>>> x + y + 1 + a + s
x + ToReal(y) + 1 + a + ToReal(s) # ToReal() 将整数表达式转换成实数表达式
>>> ToReal(y) + c
ToReal(y) + c
现代的CPU使用固定大小的位向量进行算术运算,在 Z3 中,使用函数 BitVec()
创建位向量常量,BitVecVal()
返回给定位数的位向量值。
>>> x = BitVec('x', 16) # 16 位,命名为 x
>>> y = BitVec('x', 16)
>>> x + 2
x + 2
>>> (x + 2).sexpr() # .sexpr() 返回内部表现形式
'(bvadd x #x0002)'
>>> simplify(x + y - 1) # 16 位整数的 -1 等于 65535
65535 + 2*x
>>> a = BitVecVal(-1, 16) # 16 位,值为 -1
>>> a
65535
>>> b = BitVecVal(65535, 16)
>>> b
65535
>>> simplify(a == b)
True
Z3 在 CTF 中的运用
re PicoCTF2013 Harder_Serial
题目如下,是一段 Python 代码,要求输入一段 20 个数字构成的序列号,然后程序会对序列号的每一位进行验证,以满足各种要求。题目难度不大,但完全手工验证是一件麻烦的事,而使用 Z3 的话,只要定义好这些条件,就可以得出满足条件的值。
import sys
print ("Please enter a valid serial number from your RoboCorpIntergalactic purchase")
if len(sys.argv) < 2:
print ("Usage: %s [serial number]"%sys.argv[0])
exit()
print ("#>" + sys.argv[1] + "<#")
def check_serial(serial):
if (not set(serial).issubset(set(map(str,range(10))))):
print ("only numbers allowed")
return False
if len(serial) != 20:
return False
if int(serial[15]) + int(serial[4]) != 10:
return False
if int(serial[1]) * int(serial[18]) != 2:
return False
if int(serial[15]) / int(serial[9]) != 1:
return False
if int(serial[17]) - int(serial[0]) != 4:
return False
if int(serial[5]) - int(serial[17]) != -1:
return False
if int(serial[15]) - int(serial[1]) != 5:
return False
if int(serial[1]) * int(serial[10]) != 18:
return False
if int(serial[8]) + int(serial[13]) != 14:
return False
if int(serial[18]) * int(serial[8]) != 5:
return False
if int(serial[4]) * int(serial[11]) != 0:
return False
if int(serial[8]) + int(serial[9]) != 12:
return False
if int(serial[12]) - int(serial[19]) != 1:
return False
if int(serial[9]) % int(serial[17]) != 7:
return False
if int(serial[14]) * int(serial[16]) != 40:
return False
if int(serial[7]) - int(serial[4]) != 1:
return False
if int(serial[6]) + int(serial[0]) != 6:
return False
if int(serial[2]) - int(serial[16]) != 0:
return False
if int(serial[4]) - int(serial[6]) != 1:
return False
if int(serial[0]) % int(serial[5]) != 4:
return False
if int(serial[5]) * int(serial[11]) != 0:
return False
if int(serial[10]) % int(serial[15]) != 2:
return False
if int(serial[11]) / int(serial[3]) != 0:
return False
if int(serial[14]) - int(serial[13]) != -4:
return False
if int(serial[18]) + int(serial[19]) != 3:
return False
return True
if check_serial(sys.argv[1]):
print ("Thank you! Your product has been verified!")
else:
print ("I'm sorry that is incorrect. Please use a valid RoboCorpIntergalactic serial number")
首先创建一个求解器实例,然后将序列的每个数字定义为常量:
serial = [Int("serial[%d]" % i) for i in range(20)]
接着定义约束条件,注意,除了题目代码里的条件外,还有一些隐藏的条件,比如这一句:
solver.add(serial[11] / serial[3] == 0)
因为被除数不能为 0,所以 serial[3]
不能为 0。另外,每个序列号数字都是大于等于 0,小于 9 的。最后求解得到结果。
完整的 exp 如下,其他文件在 github 相应文件夹中。
from z3 import *
solver = Solver()
serial = [Int("serial[%d]" % i) for i in range(20)]
solver.add(serial[15] + serial[4] == 10)
solver.add(serial[1] * serial[18] == 2 )
solver.add(serial[15] / serial[9] == 1)
solver.add(serial[17] - serial[0] == 4)
solver.add(serial[5] - serial[17] == -1)
solver.add(serial[15] - serial[1] == 5)
solver.add(serial[1] * serial[10] == 18)
solver.add(serial[8] + serial[13] == 14)
solver.add(serial[18] * serial[8] == 5)
solver.add(serial[4] * serial[11] == 0)
solver.add(serial[8] + serial[9] == 12)
solver.add(serial[12] - serial[19] == 1)
solver.add(serial[9] % serial[17] == 7)
solver.add(serial[14] * serial[16] == 40)
solver.add(serial[7] - serial[4] == 1)
solver.add(serial[6] + serial[0] == 6)
solver.add(serial[2] - serial[16] == 0)
solver.add(serial[4] - serial[6] == 1)
solver.add(serial[0] % serial[5] == 4)
solver.add(serial[5] * serial[11] == 0)
solver.add(serial[10] % serial[15] == 2)
solver.add(serial[11] / serial[3] == 0) # serial[3] can't be 0
solver.add(serial[14] - serial[13] == -4)
solver.add(serial[18] + serial[19] == 3)
for i in range(20):
solver.add(serial[i] >= 0, serial[i] < 10)
solver.add(serial[3] != 0)
if solver.check() == sat:
m = solver.model()
for d in m.decls():
print("%s = %s" % (d.name(), m[d]))
print("".join([str(m.eval(serial[i])) for i in range(20)]))
Bingo!!!
$ python exp.py
serial[2] = 8
serial[11] = 0
serial[3] = 9
serial[4] = 3
serial[1] = 2
serial[0] = 4
serial[19] = 2
serial[14] = 5
serial[17] = 8
serial[16] = 8
serial[10] = 9
serial[8] = 5
serial[6] = 2
serial[9] = 7
serial[5] = 7
serial[13] = 9
serial[7] = 4
serial[18] = 1
serial[15] = 7
serial[12] = 3
42893724579039578812
$ python harder_serial.py 42893724579039578812
Please enter a valid serial number from your RoboCorpIntergalactic purchase
#>42893724579039578812<#
Thank you! Your product has been verified!
这一题简直是为 Z3 量身定做的,方法也很简单,但 Z3 远比这个强大,后面我们还会讲到它更高级的应用。