前缀树___2023-08-05

前缀树

目录


定义

前缀树(Trie Tree) 用以存放一个字符串数组
其效果是能快速反应出具有某个前缀字符串的字符串的数量

例如, 在["a", "ab", "abc"]中, 有2个以"ab"为前缀的字符串

实现方式类似于多叉树, 只是字符的信息被记录在了节点与节点的连线上(通过索引表示字符)
节点的信息用于更丰富的功能实现

代码实现

节点类定义

前缀树节点(TrieNode)包含三个元素:

  • int pass 表示在该前缀树中该节点被经过了几次(换言之, 有多少字符串共用该节点)
  • int end 表示该节点是几个字符串的结尾节点
  • vector<TrieNode *> nexts 用来存放其下所有后续节点

节点的索引对应字母
例如myNode.nexts[0]对应字母'a', myNode.nexts[25]对应字母'z'
查看一个节点下的路径上是否有字母'x', 只要去对应索引23查询那个TrieNode *是否是nullptr即可

class TrieNode
{
public:
    int pass;
    int end;
    vector<TrieNode *> nexts;

    TrieNode()
    {
        this->pass = 0;
        this->end = 0;
        this->nexts = vector<TrieNode *>(26);
    }
};

这里依旧, 使用较为稳定安全的指针

如果要查询是否有单词存在, 只需通过end值来判断即可(同时还能获得有几个同样的字符串) 同样, 查询"有几个字符串具有该前缀"时, 只需要查询pass值即可

创建前缀树

由于是一个树状结构, 只需要hold住一个头结点即可

class Trie{
public:
    TrieNode* root;
    void createTrie(){
        root = new TrieNode();
    }
};

插入字符串

class Trie{
public:
    // ...

    void insert(string word)
    {
        if (word.empty())
        {
            return;
        }
        // 将字符串转化为字符数组
        vector<char> chars(word.begin(), word.end());
        TrieNode *cur = root;
        cur->pass++;
        for (int i = 0; i < chars.size(); i++)
        {
            int index = chars[i] - 'a';
            if (cur->nexts[index] == nullptr)
            {
                cur->nexts[index] = new TrieNode();
            }
            cur = cur->nexts[index];
            cur->pass++;
        }
        cur->end++;
    }  
};

删除字符串

class Trie{
public:
    // ...

    void delet(string word)
    {
        if (wordCount(word) < 1)
        {
            return;
        }
        vector<char> chars(word.begin(), word.end());
        TrieNode *cur = root;
        cur->pass--;
        for (int i = 0; i < chars.size(); i++)
        {
            int index = chars[i] - 'a';
            if (--cur->nexts[index]->pass < 1)
            {
                // 析构后续节点
                cur->nexts[index] = nullptr;
                return;
            }
            cur = cur->nexts[index];
        }
        cur->end--;
    }
};

在Java中, 由于有垃圾回收机制, 因此不需要手动释放内存, 只需要将cur->nexts[index]置为null即可
但是, C++ 中, “析构后续节点"处需要手动释放内存
具体析构的实现笔者还没有完全理解, 以下方案仅供参考

void delet(string word)
{
    if (wordCount(word) < 1)
    {
        return;
    }
    vector<char> chars(word.begin(), word.end());
    TrieNode *cur = root;
    cur->pass--;

    TrieNode *deletStartNode = nullptr;
    int deletStartIndex = -1;
    set<TrieNode *> deleteSet;
    for (int i = 0; i < chars.size(); i++)
    {
        int index = chars[i] - 'a';
        if (--cur->nexts[index]->pass < 1)
        {
            deletStartNode = deletStartNode == nullptr ? cur : deletStartNode;
            deletStartIndex = deletStartIndex == -1 ? i : deletStartIndex;
            deleteSet.insert(cur->nexts[index]);
        }
        cur = cur->nexts[index];
    }
    cur->end--;
    if (deletStartIndex == -1)
    {
        deletStartNode->nexts[deletStartIndex] = nullptr;
        // 遍历set, 析构
        for (auto it = deleteSet.begin(); it != deleteSet.end(); it++)
        {
            delete *it;
        }
    }
}

查询字符串和前缀数量

class Trie{
public:
    // ...

    int wordCount(string word)
    {
        if (word.empty())
        {
            return 0;
        }
        vector<char> chars(word.begin(), word.end());
        TrieNode *cur = root;
        for (int i = 0; i < chars.size(); i++)
        {
            int index = chars[i] - 'a';
            if (cur->nexts[index] == nullptr)
            {
                return 0;
            }
            cur = cur->nexts[index];
        }
        return cur->end;
    }

    int prefixCount(string word)
    {
        if (word.empty())
        {
            return 0;
        }
        vector<char> chars(word.begin(), word.end());
        TrieNode *cur = root;
        for (int i = 0; i < chars.size(); i++)
        {
            int index = chars[i] - 'a';
            if (cur->nexts[index] == nullptr)
            {
                return 0;
            }
            cur = cur->nexts[index];
        }
        return cur->pass;
    }
};

例子

图书管理员

题目: P3955 [NOIP2017 普及组] 图书管理员

代码:

#include <bits/stdc++.h>
using namespace std;

struct TrieNode
{
    int min;
    vector<TrieNode *> nexts;

    TrieNode()
    {
        min = INT_MAX;
        nexts = vector<TrieNode *>(10, nullptr);
    }
};

class Solution
{

public:
    TrieNode *root;
    void buildTrieTree()
    {
        root = new TrieNode();
    }
    void insert(int number)
    {
        TrieNode *cur = root;
        int x = number;
        while (x != 0)
        {
            int index = x % 10;
            x /= 10;
            if (cur->nexts[index] == nullptr)
            {
                cur->nexts[index] = new TrieNode();
            }
            cur->min = min(cur->min, number);
            cur = cur->nexts[index];
        }
        cur->min = number;
    }
    void process(int number, int len)
    {
        TrieNode *cur = root;
        // number 一定是正整数
        while (number != 0)
        {
            int index = number % 10;
            number /= 10;
            if (cur->nexts[index] == nullptr)
            {
                cout << -1 << endl;
                return;
            }
            cur = cur->nexts[index];
        }
        cout << cur->min << endl;
    }
};

int main()
{
    Solution s;
    s.buildTrieTree();
    int n, q;
    cin >> n >> q;
    while (n--)
    {
        int number;
        cin >> number;
        s.insert(number);
    }
    while (q--)
    {
        int len, number;
        cin >> len >> number;
        s.process(number, len);
    }
}

本题应该称作 “后缀树”, 但其实本质上是一样的, 只需要通过取模整除的方式将数字个位数取出即可
实现思路是:

  • 套用模板, 建立TrieNode
  • 在节点上记录min值, 用于记录该节点下的最小值
  • passend值都不需要
  • 插入时, 从个位数开始, 依次插入, 同时更新沿途所有节点上min
  • 查询时, 从个位数开始, 依次查询, 如果有节点为nullptr, 则说明没有该数字, 返回-1, 否则直接返回该节点的min值即可

这也就是不需要end值的原因, 因为有min值的存在, 既说明了有该数字, 又说明了该数字是该后缀的最小值

解题过程中还运用了对数器, 用于验证算法的正确性 这里是笔者的暴力解法

class ShabbySolution
{
public:
    vector<int> books;
    void insert(int number)
    {
        books.push_back(number);
    }
    void process(int number, int len)
    {
        int minIndex = -1;
        int min = INT_MAX;
        for (int i = 0; i < books.size(); i++)
        {
            if (books[i] % (int)round(pow(10, len)) == number)
            {
                if (books[i] < min)
                {
                    min = books[i];
                    minIndex = i;
                }
            }
        }
        if (minIndex == -1)
        {
            cout << -1 << endl;
        }
        else
        {
            cout << min << endl;
        }
    }
};