Hooked on reinventing wheels.
Sure enough, you only really understand an algorithm once you have written it from scratch yourself.
An ID3 decision tree is built recursively, greedily splitting on information gain:
at each node, pick the attribute with the largest information gain to partition the data; when a subset contains only a single class, turn it into a leaf node.
Information gain: gain(X, A) = entropy(X) - sum(len(Xa) / len(X) * entropy(Xa)), summed over each value a of attribute A, where Xa = [x for x in X if x[A] == a]
That is, the information gain of attribute A on dataset X is the entropy of X minus the entropy of each value's subset, weighted by the fraction of X that subset covers.
Entropy: entropy(X) = -sum(Pi * log2(Pi))
where Pi is the proportion of class i in the whole set.
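Before the full class below, here is a minimal standalone sketch of both formulas (the bare entropy and gain functions are throwaway illustrations, separate from the class that follows). On the 9-yes / 5-no weather data used later it gives entropy ≈ 0.940:

from collections import Counter
from math import log2

def entropy(labels):
    # -sum(Pi * log2(Pi)) over the class proportions
    total = len(labels)
    return -sum(c / total * log2(c / total) for c in Counter(labels).values())

def gain(rows, attr_index):
    # entropy of the whole set minus the size-weighted entropies of the subsets
    result = entropy([row[-1] for row in rows])
    for value in set(row[attr_index] for row in rows):
        subset = [row[-1] for row in rows if row[attr_index] == value]
        result -= len(subset) / len(rows) * entropy(subset)
    return result

print(round(entropy(['yes'] * 9 + ['no'] * 5), 3))  # 0.94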
The tree node data structure:
class TreeNode:
    def __init__(self):
        self.child = {}        # attribute value -> child TreeNode
        self.index = -1        # which attribute column this node splits on
        self.is_leaf = False
        self.decision = None   # class label, only meaningful on leaves

    def add_node(self, key, value):
        self.child.setdefault(key, value)

    def find_child(self, key):
        # plain dict lookup; an `is` comparison only checks object identity
        # and silently fails for equal but distinct strings/numbers
        return self.child.get(key)
Here is_leaf marks whether the node is a leaf, child holds all possible branches below this node, and index records which attribute this level splits on.
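For intuition, a one-level tree can be assembled by hand like this (the attribute values are placeholders picked for illustration):

stump = TreeNode()
stump.index = 0  # split on column 0, e.g. outlook

leaf_yes = TreeNode()
leaf_yes.is_leaf = True
leaf_yes.decision = 'yes'

leaf_no = TreeNode()
leaf_no.is_leaf = True
leaf_no.decision = 'no'

stump.add_node('overcast', leaf_yes)
stump.add_node('sunny', leaf_no)
print(stump.find_child('overcast').decision)  # yes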
The decision tree (ID3) implementation:
from math import log2

class DecisionTree:
    def __init__(self, dataset=None):
        self.dataset = dataset
        self.root = None
        self.attr_num = 0

    # Purity of the training data in column `index`: its entropy
    def entropy(self, dataset, index):
        counts = {}
        total = len(dataset)
        for row in dataset:
            if row[index] not in counts:
                # == instead of `is`: identity comparison is unreliable for
                # equal but distinct string/float objects
                counts[row[index]] = len([item for item in dataset if item[index] == row[index]])
        num = 0
        for key in counts:
            p = float(counts[key]) / total
            num += -p * log2(p)
        return num
    # Information gain of the attribute in column `index`
    def gain(self, dataset, index):
        subsets = {}
        total = len(dataset)
        for row in dataset:
            if row[index] not in subsets:
                matching = [item for item in dataset if item[index] == row[index]]
                # store (subset size, subset entropy over the label column)
                subsets[row[index]] = (len(matching), self.entropy(matching, len(dataset[0]) - 1))
        num = self.entropy(dataset, len(dataset[0]) - 1)
        for key in subsets:
            num -= float(subsets[key][0]) / total * subsets[key][1]
        return num
    # Majority label, used when attributes run out before a subset is pure
    @staticmethod
    def find_most_pop_ans(dataset):
        counts = {}
        for row in dataset:
            if row[-1] not in counts:
                counts[row[-1]] = len([item for item in dataset if item[-1] == row[-1]])
        max_key = None
        max_value = 0
        for key in counts:
            if counts[key] > max_value:
                max_value = counts[key]
                max_key = key
        return max_key
    def build_tree(self, dataset, used_index=None):
        # a fresh list per call avoids the shared-mutable-default-argument trap
        if used_index is None:
            used_index = []
        type_num = len(dataset[0]) - 1
        self.attr_num = type_num
        # every attribute already used: fall back to the majority label
        if len(used_index) == type_num:
            leaf_node = TreeNode()
            leaf_node.is_leaf = True
            leaf_node.decision = self.find_most_pop_ans(dataset)
            return leaf_node
        # subset is pure: make it a leaf
        if len(set(item[-1] for item in dataset)) == 1:
            leaf_node = TreeNode()
            leaf_node.is_leaf = True
            leaf_node.decision = dataset[0][-1]
            return leaf_node
        max_gain = 0
        max_index = -1
        for i in range(type_num):
            if i not in used_index:
                gain_tmp = self.gain(dataset, i)
                if gain_tmp > max_gain:
                    max_gain = gain_tmp
                    max_index = i
        # no remaining attribute improves purity: stop with a majority leaf
        if max_index == -1:
            leaf_node = TreeNode()
            leaf_node.is_leaf = True
            leaf_node.decision = self.find_most_pop_ans(dataset)
            return leaf_node
        node = TreeNode()
        node.index = max_index
        type_set = set(item[max_index] for item in dataset)
        for typex in type_set:
            subset = [item for item in dataset if item[max_index] == typex]
            # pass a copy so sibling branches don't see each other's choices
            node.add_node(typex, self.build_tree(subset, used_index + [max_index]))
        # reassigned on every call; the outermost call returns last, so this
        # ends up pointing at the true root (one tree per instance only)
        self.root = node
        return node
    def check(self, data):
        if len(data) != self.attr_num:
            print('Err data !')
            return None
        node = self.root
        while not node.is_leaf:
            node = node.find_child(data[node.index])
            if node is None:
                # the trained tree has no branch for this attribute value
                print('Unindex Attr !')
                return None
        return node.decision
build_tree is the tree-building step; check runs a prediction for one sample.
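To inspect what build_tree actually produced, a small helper like this works (print_tree is my own addition, not part of the class above):

def print_tree(node, indent=0):
    # depth-first walk of the TreeNode structure, one line per branch
    pad = '    ' * indent
    if node.is_leaf:
        print(pad + '-> ' + str(node.decision))
        return
    for value, child in node.child.items():
        print(pad + 'attr[%d] == %s:' % (node.index, str(value)))
        print_tree(child, indent + 1)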
The dataset used:
test_data = [
['sunny', 85.0, 85.0, False, 'no'],
['rainy', 75.0, 80.0, False, 'yes'],
['sunny', 75.0, 70.0, True, 'yes'],
['overcast', 72.0, 90.0, True, 'yes'],
['overcast', 81.0, 75.0, False, 'yes'],
['rainy', 71.0, 91.0, True, 'no'],
['sunny', 80.0, 90.0, True, 'no'],
['overcast', 83.0, 86.0, False, 'yes'],
['rainy', 70.0, 96.0, False, 'yes'],
['rainy', 68.0, 80.0, False, 'yes'],
['rainy', 65.0, 70.0, True, 'no'],
['overcast', 64.0, 65.0, True, 'yes'],
['sunny', 72.0, 95.0, False, 'no'],
['sunny', 69.0, 70.0, False, 'yes'],
]
Since basic ID3 cannot handle continuous attributes, columns 1 and 2 (temperature and humidity) are binarized around their mean values.
Test code:
d = DecisionTree()
# binarize the two continuous columns around their respective means
for col in (1, 2):
    avg = sum(float(item[col]) for item in test_data) / float(len(test_data))
    for row in test_data:
        row[col] = 1 if row[col] >= avg else 0
t = d.build_tree(test_data)
print(d.check(['rainy', 1, 0, True]))
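With per-branch used_index copies, the root splits on outlook (gain about 0.247) and the rainy subtree splits on the windy column (gain 0.971), so this query should print no: in this dataset, rainy plus windy means no play.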