Addicted to reinventing the wheel

It really is true that you only understand an algorithm thoroughly once you've implemented it from scratch yourself.

The ID3 decision tree is built recursively, using information gain and a greedy strategy.

At each node, pick the attribute with the highest information gain to split on; once a sub-dataset contains only a single class, turn it into a leaf node.

Information gain: gain(X, A) = entropy(X) - sum_v (|Xv| / |X|) * entropy(Xv), where v ranges over the values attribute A takes in X, and Xv is the subset of X whose value of A is v

That is, the information gain of attribute A on dataset X is the entropy of X minus the sum, over each value of A, of the entropy of the corresponding subset weighted by the fraction of X that subset covers.

Entropy: entropy(X) = -sum(Pi * log2(Pi))

where Pi is the proportion of class i in the whole set.
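
To make the formulas concrete, here is a minimal standalone sketch (entropy_of and gain_of are illustrative names, not part of the class below; rows are assumed to end with the class label):

from collections import Counter
from math import log2

def entropy_of(labels):
    total = len(labels)
    return -sum(c / total * log2(c / total) for c in Counter(labels).values())

def gain_of(dataset, index):
    # entropy of the labels minus the weighted entropy of each value's subset
    total = len(dataset)
    remainder = 0.0
    for value in set(row[index] for row in dataset):
        subset = [row[-1] for row in dataset if row[index] == value]
        remainder += len(subset) / total * entropy_of(subset)
    return entropy_of([row[-1] for row in dataset]) - remainder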

The tree node data structure:

class TreeNode:
    def __init__(self):
        self.child = {}       # branch value -> child TreeNode
        self.index = -1       # which attribute this node splits on
        self.is_leaf = False
        self.decision = None  # class label, only meaningful on a leaf

    def add_node(self, key, value):
        self.child.setdefault(key, value)

    def find_child(self, key):
        # dict lookup compares keys by ==, unlike the identity check `is`
        return self.child.get(key)

Here is_leaf marks whether the node is a leaf, child stores every possible branch at this level, and index records which attribute this level splits on.
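
A quick sanity check of the node API (illustrative only):

root = TreeNode()
root.index = 0
leaf = TreeNode()
leaf.is_leaf = True
leaf.decision = 'yes'
root.add_node('sunny', leaf)
assert root.find_child('sunny') is leaf
assert root.find_child('foggy') is None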

The decision tree (ID3) implementation:


from math import log2


class DecisionTree:
    def __init__(self, dataset=None):
        self.dataset = dataset
        self.root = None
        self.attr_num = 0

    # entropy of column `index` (pass the label column to measure class purity)
    def entropy(self, dataset, index):
        counts = {}
        total = len(dataset)
        for row in dataset:
            counts[row[index]] = counts.get(row[index], 0) + 1

        num = 0.0
        for count in counts.values():
            p = count / total
            num -= p * log2(p)

        return num

    # information gain of splitting the dataset on attribute `index`
    def gain(self, dataset, index):
        total = len(dataset)
        label_index = len(dataset[0]) - 1
        num = self.entropy(dataset, label_index)
        for value in set(row[index] for row in dataset):
            subset = [row for row in dataset if row[index] == value]
            num -= len(subset) / total * self.entropy(subset, label_index)

        return num

    # majority class label, used when no attributes are left to split on
    @staticmethod
    def find_most_pop_ans(dataset):
        counts = {}
        for row in dataset:
            counts[row[-1]] = counts.get(row[-1], 0) + 1

        return max(counts, key=counts.get)

    def build_tree(self, dataset, used_index=None):
        # avoid the mutable-default-argument trap
        used_index = [] if used_index is None else used_index
        type_num = len(dataset[0]) - 1
        self.attr_num = type_num
        # every attribute has been used: fall back to a majority vote
        if len(used_index) == type_num:
            leaf_node = TreeNode()
            leaf_node.is_leaf = True
            leaf_node.decision = self.find_most_pop_ans(dataset)
            return leaf_node
        # the subset is pure: make it a leaf
        if len(set(row[-1] for row in dataset)) == 1:
            leaf_node = TreeNode()
            leaf_node.is_leaf = True
            leaf_node.decision = dataset[0][-1]
            return leaf_node
        # pick the unused attribute with the largest information gain;
        # start below zero so a zero-gain attribute is still selectable
        max_gain = -1.0
        max_index = -1
        for i in range(type_num):
            if i not in used_index:
                gain_tmp = self.gain(dataset, i)
                if gain_tmp > max_gain:
                    max_gain = gain_tmp
                    max_index = i

        node = TreeNode()
        node.index = max_index
        for value in set(row[max_index] for row in dataset):
            subset = [row for row in dataset if row[max_index] == value]
            # pass a copy so sibling branches don't see each other's attributes
            node.add_node(value, self.build_tree(subset, used_index + [max_index]))

        # the outermost call assigns last, so self.root ends up as the real root
        self.root = node
        return node

    def check(self, sample):
        if len(sample) != self.attr_num:
            print('Err data !')
            return None

        node = self.root
        while not node.is_leaf:
            node = node.find_child(sample[node.index])
            if node is None:
                print('Unindex Attr !')
                return None

        return node.decision

build_tree is the tree-building procedure; check classifies a sample to validate the tree.
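
To eyeball what build_tree produced, a small helper can dump the tree recursively (print_tree is a hypothetical addition, not part of the class above):

def print_tree(node, depth=0):
    pad = '  ' * depth
    if node.is_leaf:
        print(pad + '-> ' + str(node.decision))
        return
    for value, child in node.child.items():
        print(pad + 'attr[' + str(node.index) + '] == ' + str(value) + ':')
        print_tree(child, depth + 1)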

The dataset used:

test_data = [
    ['sunny', 85.0, 85.0, False, 'no'],
    ['rainy', 75.0, 80.0, False, 'yes'],
    ['sunny', 75.0, 70.0, True, 'yes'],
    ['overcast', 72.0, 90.0, True, 'yes'],
    ['overcast', 81.0, 75.0, False, 'yes'],
    ['rainy', 71.0, 91.0, True, 'no'],
    ['sunny', 80.0, 90.0, True, 'no'],
    ['overcast', 83.0, 86.0, False, 'yes'],
    ['rainy', 70.0, 96.0, False, 'yes'],
    ['rainy', 68.0, 80.0, False, 'yes'],
    ['rainy', 65.0, 70.0, True, 'no'],
    ['overcast', 64.0, 65.0, True, 'yes'],
    ['sunny', 72.0, 95.0, False, 'no'],
    ['sunny', 69.0, 70.0, False, 'yes'],
]

Since a basic ID3 decision tree can't handle continuous attributes, columns 1 and 2 are binarized into two classes, using each column's mean as the threshold.
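
The mean-split can be wrapped in a small helper (binarize_column is a hypothetical name; like the inline loops below, it mutates the rows in place):

def binarize_column(dataset, col):
    avg = sum(float(row[col]) for row in dataset) / len(dataset)
    for row in dataset:
        row[col] = 1 if row[col] >= avg else 0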

Test code:

d = DecisionTree()
avg = sum([float(item[1]) for item in test_data]) / float(len(test_data))
for i in range(len(test_data)):
    if test_data[i][1] >= avg:
        test_data[i][1] = 1
    else:
        test_data[i][1] = 0
avg = sum([float(item[2]) for item in test_data]) / float(len(test_data))
for i in range(len(test_data)):
    if test_data[i][2] >= avg:
        test_data[i][2] = 1
    else:
        test_data[i][2] = 0

t = d.build_tree(test_data, [])
print(d.check(['rainy', 1, 0, True]))