开发者

Go手写数据库ZiyiDB的实现

目录
  • 项目结构
  • 原理介绍
  • 具体实现
    • 模块一: 词法分析器 (Lexer)
      • ①定义标记类型token.go
      • ② 实现词法分析器lexer.go
    • 模块二:抽象语法树 (AST)
      • 模块三:语法分析器 (Parser)
        • 模块四:存储引擎 (Storage)
          • 模块五:REPL 交互界面
          • 整体测试
            • 测试脚本
              • 运行效果

              ZiyiDB是一个简单的内存数据库实现,支持基本的SQL操作,包含create、insert、delete、select、update、drop。目前一期暂支持int类型以及字符类型数据,后续会支持更多数据结构以及能力。本项目基于https://github.com/eatonphil/gosql灵感进行开发。

              • 项目Github地址:https://github.com/ziyifast/ZiyiDB

              项目结构

              // 项目创建
              mkdir ZiyiDB
              cd ZiyiDB/
              go mod init ziyi.db.com
              
              ZiyiDB/
              ├── cmd/
              │   └── main.go           # 主程序入口
              ├── internal/
              │   ├── ast/
              │   │   └── ast.go        # 抽象语法树定义
              │   ├── lexer/
              │   │   ├── lexer.go      # 词法分析器实现
              │   │   └── token.go      # 词法单元定义
              │   ├── parser/
              │   │   └── parser.go     # 语法分析器实现
              │   └── storage/
              │       └── memory.go     # 内存存储引擎实现
              ├── go.mod                # Go模块定义
              └── go.sum                # 依赖版本锁定
              

              原理介绍

              流程图:

              Go手写数据库ZiyiDB的实现

              主要包含几大模块:

              • cmd/main.go:
              • 程序入口点
              • 实现交互式命令行界面
              • 处理用户输入
              • 显示执行结果
              • internal/ast/ast.go:
              • 定义抽象语法树节点
              • 定义 SQL 语句结构
              • 定义表达式结构
              • internal/lexer/token.go:
              • 定义词法单元类型
              • 定义 SQL 关键字
              • 定义运算符和分隔符
              • internal/lexer/lexer.go:
              • 实现词法分析器
              • 将输入文本转换为标记序列
              • 处理标识符和字面量
              • internal/parser/parser.go:
              • 实现语法分析器
              • 将标记序列转换为抽象语法树
              • 处理各种 SQL 语句
              • internal/storage/memory.go:
              • 实现内存存储引擎
              • 处理数据存储和检索
              • 实现索引和约束

              具体实现

              模块一: 词法分析器 (Lexer)

              词法分析器 (Lexer):SQL转token序列

              ①定义标记类型token.go

              思路

              新建ziyi-db/internal/lexer/token.go文件,完成词法分析器(Lexer)中的标记(Token)定义部分,用于将 SQL 语句分解成基本的语法单元。

              定义词法单元以及关键字:

              • 包含常见的SQL关键字,如:select、update等
              • 包含符号关键字:=、>、<
              • 包含字段类型:INT、字符型(TEXT)
              • 包含标识符:INDENT,解析出来的SQL列名、表名
              type TokenType string
              
              const (
                  SELECT  TokenType = "SELECT"
                  FROM    TokenType = "FROM"
                  IDENT   TokenType = "IDENT"  // 标识符(如列名、表名)
                  INT_LIT TokenType = "INT"    // 整数字面量
                  STRING  TokenType = "STRING" // 字符串字面量
                  EQ TokenType = "=" // 等于
                  GT TokenType = ">" // 大于
                  LT TokenType = "<" // 小于
                  ....
              )
              
              // Token 词法单元
              // Type:标记的类型(如 SELECT、IDENT 等)
              // Literal:标记的实际值(如具体的列名、数字等)
              type Token struct {
                  Type    TokenType // 标记类型
                  Literal string    // 标记的实际值
              }
              

              示例:

              SELECT id, name FROM users WHERE age > 18;
              
              该SQL 语句会被下面的词法分析器lexer.go分解成以下标记序列:
              {Type: SELECT, Literal: "SELECT"}
              {Type: IDENT, Literal: "id"}
              {Type: COMMA, Literal: ","}
              {Type: IDENT, Literal: "name"}
              {Type: FROM, Literal: "FROM"}
              {Type: IDENT, Literal: "users"}
              {Type: WHERE, Literal: "WHERE"}
              {Type: IDENT, Literal: "age"}
              {Type: GT, Literal: ">"}
              {Type: INT_LIT, Literal: "18"}
              {Type: SEMI, Literal: ";"}
              解析后的标记随后会被传递给语法分析器(Parser)进行进一步处理,构建抽象语法树(AST)。
              

              全部代码

              // internal/lexer/token.go
              package lexer
              
              // TokenType 表示词法单元类型
              type TokenType string
              
              const (
                  // 特殊标记
                  EOF   TokenType = "EOF"   // 文件结束标记
                  ERROR TokenType = "ERROR" // 错误标记
              
                  // 关键字
                  SELECT  TokenType = "SELECT"
                  FROM    TokenType = "FROM"
                  WHERE   TokenType = "WHERE"
                  CREATE  TokenType = "CREATE"
                  TABLE   TokenType = "TABLE"
                  INSERT  TokenType = "INSERT"
                  INTO    TokenType = "INTO"
                  VALUES  TokenType = "VALUES"
                  UPDATE  TokenType = "UPDATE"
                  SET     TokenType = "SET"
                  DELETE  TokenType = "DELETE"
                  DROP    TokenType = "DROP"
                  PRIMARY TokenType = "PRIMARY"
                  KEY     TokenType = "KEY"
                  INT     TokenType = "INT"
                  TEXT    TokenType = "TEXT"
                  LIKE    TokenType = "LIKE"
              
                  // 标识符和字面量
                  IDENT   TokenType = "IDENT"  // 标识符(如列名、表名)
                  INT_LIT TokenType = "INT"    // 整数字面量
                  STRING  TokenType = "STRING" // 字符串字面量
              
                  // 运算符
                  EQ TokenType = "="
                  GT TokenType = ">"
                  LT TokenType = "<"
              
                  // 标识符
                  COMMA    TokenType = ","
                  SEMI     TokenType = ";"
                  LPAREN   TokenType = "("
                  RPAREN   TokenType = ")"
                  ASTERISK TokenType = "*"
              )
              
              // Token 词法单元
              // Type:标记的类型(如 SELECT、IDENT 等)
              // Literal:标记的实际值(如具体的列名、数字等)
              type Token struct {
                  Type    TokenType // 标记类型
                  Literal string    // 标记的实际值
              }
              

              ② 实现词法分析器lexer.go

              思路

              新建ziyi-db/internal/lexer/lexer.go文件,这是词法分析器(Lexer)的核心实现,负责将输入的 SQL 语句分解成标记(Token)序列。

              词法分析器lexer.go:读取SQL到内存中并进行解析,将字符转换为对应关键字

              示例:

              SELECT id, name FROM users WHERE age > 18;
              
              处理过程:
              跳过空白字符
              读取 "SELECT" 并识别为关键字
              读取 "id" 并识别为标识符
              读取 "," 并识别为分隔符
              读取 "name" 并识别为标识符
              读取 "FROM" 并识别为关键字
              读取 "users" 并识别为标识符
              读取 "WHERE" 并识别为关键字
              读取 "age" 并识别为标识符
              读取 ">" 并识别为运算符
              读取 "18" 并识别为数字
              读取 ";" 并识别为分隔符
              这个词法分析器是 SQL 解析器的第一步,它将输入的 SQL 语句分解成标记序列,为后续的语法分析提供基础
              
              
              
              该SQL 语句会被词法分析器分解成以下标记序列:
              {Type: SELECT, Literal: "SELECT"}
              {Type: IDENT, Literal: "id"}
              {Type: COMMA, Literal: ","}
              {Type: IDENT, Literal: "name"}
              {Type: FROM, Literal: "FROM"}
              {Type: IDENT, Literal: "users"}
              {Type: WHERE, Literal: "WHERE"}
              {Type: IDENT, Literal: "age"}
              {Type: GT, Literal: ">"}
              {Type: INT_LIT, Literal: "18"}
              {Type: SEMI, Literal: ";"}
              解析后的标记随后会被传递给语法分析器(Parser)进行进一步处理,构建抽象语法树(AST)。
              

              全部代码

              // internal/lexer/lexer.go
              package lexer
              
              import (
                  "bufio"
                  "bytes"
                  "io"
                  "strings"
                  "unicode"
              )
              
              // Lexer 词法分析器
              // reader:使用 bufio.Reader 进行高效的字符读取
              // ch:存储当前正在处理的字符
              type Lexer struct {
                  reader *bufio.Reader // 用于读取输入
                  ch     rune          // 当前字符
              }
              
              // NewLexer 创建一个新的 词法分析器
              // 初始化 reader 并读取第一个字符
              func NewLexer(r io.Reader) *Lexer {
                  l := &Lexer{
                     reader: bufio.NewReader(r),
                  }
                  l.readChar()
                  return l
              }
              
              // 读取字符
              func (l *Lexer) readChar() {
                  ch, _, err := l.reader.ReadRune()
                  if err != nil {
                     l.ch = 0 // 遇到错误或EOF时设置为0
                  } else {
                     l.ch = ch
                  }
              }
              
              // NextToken 获取下一个词法单元
              // 识别并返回下一个标记
              // 处理各种类型的标记:运算符、分隔符、标识符、数字、字符串等
              func (l *Lexer) NextToken() Token {
                  var tok Token
                  // 跳过空白字符
                  l.skipWhitespace()
              
                  switch l.ch {
                  case '=':
                     tok = Token{Type: EQ, Literal: "="}
                  case '>':
                     tok = Token{Type: GT, Literal: ">"}
                  case '<':
                     tok = Token{Type: LT, Literal: "<"}
                  case ',':
                     tok = Token{Type: COMMA, Literal: ","}
                  case ';':
                     tok = Token{Type: SEMI, Literal: ";"}
                  case '(':
                     tok = Token{Type: LPAREN, Literal: "("}
                  case ')':
                     tok = Token{Type: RPAREN, Literal: ")"}
                  case '*':
                     tok = Token{Type: ASTERISK, Literal: "*"}
                  case '\'':
                     tok.Type = STRING
                     // 读取字符串字面量
                     tok.Literal = l.readString()
                     return tok
                  case 0:
                     tok = Token{Type: EOF, Literal: ""}
                  default:
                     if isLetter(l.ch) {
                        // 读取标识符(表名、列名等)
                        tok.Literal = l.readIdentifier()
                        // 将读取到的标识符转换为对应的标记类型(转换为对应tokenType)
                        tok.Type = l.lookupIdentifier(tok.Literal)
                        return tok
                     } else if isDigit(l.ch) {
                        tok.Type = INT_LIT
                        // 读取数字
                        tok.Literal = l.readNumber()
                        return tok
                     } else {
                        tok = Token{Type: ERROR, Literal: string(l.ch)}
                     }
                  }
              
                  l.readChar()
                  return tok
              }
              
              func (l *Lexer) skipWhitespace() {
                  for unicode.IsSpace(l.ch) {
                     l.readChar()
                  }
              }
              
              // 读取标识符,如:列名、表名
              func (l *Lexer) readIdentifier() string {
                  var ident bytes.Buffer
                  for isLetter(l.ch) || isDigit(l.ch) {
                     ident.WriteRune(l.ch)
                     l.readChar()
                  }
                  return ident.String()
              }
              
              func (l *Lexer) readNumber() string {
                  var num bytes.Buffer
                  for isDigit(l.ch) {
                     num.WriteRune(l.ch)
                     l.readChar()
                  }
                  return num.String()
              }
              
              // 读取字符串字面量
              func (l *Lexer) readString() string {
                  var str bytes.Buffer
                  l.readChar() // 跳过开始的引号
                  for l.ch != '\'' && l.ch != 0 {
                     str.WriteRune(l.ch)
                     l.readChar()
                  }
                  l.readChar() // 跳过结束的引号
                  return str.String()
              }
              
              func (l *Lexer) peekChar() rune {
                  ch, _, err := l.reader.ReadRune()
                  if err != nil {
                     return 0
                  }
                  l.reader.UnreadRune()
                  return ch
              }
              
              // lookupIdentifier 查找标识符类型
              // 将标识符转换为对应的标记类型
              // 识别 SQL 关键字
              func (l *Lexer) lookupIdentifier(ident string) TokenType {
                  switch strings.ToUpper(ident) {
                  case "SELECT":
                     return SELECT
                  case "FROM":
                     return FROM
                  case "WHERE":
                     return WHERE
                  case "CREATE":
                     return CREATE
                  case "TABLE":
                     return TABLE
                  case "INSERT":
                     return INSERT
                  case "INTO":
                     return INTO
                  case "VALUES":
                     return VALUES
                  case "UPDATE":
                     return UPDATE
                  case "SET":
                     return SET
                  case "DELETE":
                     return DELETE
                  case "DROP":
                     return DROP
                  case "PRIMARY":
                     return PRIMARY
                  case "KEY":
                     return KEY
                  case "INT":
                     return INT
                  case "TEXT":
                     return TEXT
                  case "LIKE":
                     return LIKE
                  default:
                     return IDENT
                  }
              }
              
              // 判断字符是否为字母或下划线
              func isLetter(ch rune) bool {
                  return unicode.IsLetter(ch) || ch == '_'
              }
              
              // 判断字符是否为数字
              func isDigit(ch rune) bool {
                  return unicode.IsDigit(ch)
              }
              

              模块二:抽象语法树 (AST)

              思路

              抽象语法树用于表示 SQL 语句的语法结构。我们需要为每种 SQL 语句定义相应的节点类型。

              我们新建internal/ast/ast.go。

              ast.go构建不同SQL语句的结构,以及查询结果等。

              这个 AST 定义文件是 SQL 解析器的核心部分,它:

              • 定义了所有 SQL 语句的语法结构
              • 提供了类型安全的方式来表示 SQL 语句
              • 支持复杂的表达式和条件
              • 便于后续的语义分析和执行

                通过这个 AST,我们可以:

              • 验证 SQL 语句的语法正确性
              • 进行语义分析
              • 生成执行计划
              • 执行 SQL 语句

              示例:

              SELECT id, name FROM users WHERE age > 18;
              
              交给语法分析器parser解析后的AST结构为:
              
              SelectStatement
              ├── Fields
              │   ├── Identifier{Value: "id"}
              │   └── Identifier{Value: "name"}
              ├── TableName: "users"
              └── Where
                  └── BinaryExpression
                      ├── Left: Identifier{Value: "age"}
                      ├── Operator: ">"
                      └── Right: IntegerLiteral{Value: "18"}
              

              全部代码

              package ast
              
              import (
                  "cursor-db/internal/lexer"
                  "fmt"
              )
              
              // Node 表示AST中的节点
              type Node interface {
                  TokenLiteral() string
              }
              
              // Statement 表示SQL语句
              type Statement interface {
                  Node
                  statementNode()
              }
              
              // Expression 表示表达式
              type Expression interface {
                  Node
                  expressionNode()
              }
              
              // Program 表示整个SQL程序
              type Program struct {
                  Statements []Statement
              }
              
              // SelectStatement 表示SELECT语句
              type SelectStatement struct {
                  Token     lexer.Token
                  Fields    []Expression
                  TableName string
                  Where     Expression
              }
              
              func (ss *SelectStatement) statementNode()       {}
              func (ss *SelectStatement) TokenLiteral() string { return ss.Token.Literal }
              
              // CreateTableStatement 表示CREATE TABLE语句
              type CreateTableStatement struct {
                  Token     lexer.Token
                  TableName string
                  Columns   []ColumnDefinition
              }
              
              func (cts *CreateTableStatement) statementNode()       {}
              func (cts *CreateTableStatement) TokenLiteral() string { return cts.Token.Literal }
              
              // InsertStatement 表示INSERT语句
              type InsertStatement struct {
                  Token     lexer.Token
                  TableName string
                  Values    []Expression
              }
              
              func (is *InsertStatement) statementNode()       {}
              func (is *InsertStatement) TokenLiteral() string { return is.Token.Literal }
              
              // ColumnDefinition 表示列定义
              type ColumnDefinition struct {
                  Name     string
                  Type     string
                  Primary  bool
                  Nullable bool
              }
              
              // Cell 表示数据单元格
              type Cell struct {
                  Type      CellType
                  IntValue  int32
                  TextValue string
              }
              
              // CellType 表示单元格类型
              type CellType int
              
              const (
                  CellTypeInt CellType = iota
                  CellTypeText
              )
              
              // AsText 返回单元格的文本值
              func (c *Cell) AsText() string {
                  switch c.Type {
                  case CellTypeInt:
                     s := fmt.Sprintf("%d", c.IntValue)
                     return s
                  case CellTypeText:
                     return c.TextValue
                  default:
                     return "NULL"
                  }
              }
              
              // AsInt 返回单元格的整数值
              func (c *Cell) AsInt() int32 {
                  if c.Type == CellTypeInt {
                     return c.IntValue
                  }
                  return 0
              }
              
              // String 返回单元格的字符串表示
              func (c Cell) String() string {
                  switch c.Type {
                  case CellTypeInt:
                     return fmt.Sprintf("%d", c.IntValue)
                  case CellTypeText:
                     return c.TextValue
                  default:
                     return "NULL"
                  }
              }
              
              // Results 表示查询结果
              type Results struct {
                  Columns []ResultColumn
                  Rows    [][]Cell
              }
              
              // ResultColumn 表示结果列
              type ResultColumn struct {
                  Name string
                  Type string
              }
              
              // StarEwww.devze.comxpression 表示星号表达式,如:select * from users;
              type StarExpression struct{}
              
              func (se *StarExpression) expressionNode()      {}
              func (se *StarExpression) TokenLiteral() string { return "*" }
              
              // LikeExpression 表示LIKE表达式, 如 LIKE '%b'
              type LikeExpression struct {
                  Token   lexer.Token
                  Left    Expression
                  Pattern string
              }
              
              func (le *LikeExpression) expressionNode()      {}
              func (le *LikeExpression) TokenLiteral() string { return le.Token.Literal }
              
              // BinaryExpression 表示二元表达式,如比较运算,大于小于比较等
              type BinaryExpression struct {
                  Token    lexer.Token
                  Left     Expression
                  Operator string
                  Right    Expression
              }
              
              func (be *BinaryExpression) expressionNode()      {}
              func (be *BinaryExpression) TokenLiteral() string { return be.Token.Literal }
              
              // IntegerLiteral 表示整数字面量
              type IntegerLiteral struct {
                  Token lexer.Token
                  Value string
              }
              
              func (il *IntegerLiteral) expressionNode()      {}
              func (il *IntegerLiteral) TokenLiteral() string { return il.Token.Literal }
              
              // StringLiteral 表示字符串字面量
              type StringLiteral struct {
                  Token lexer.Token
                  Value string
              }
              
              func (sl *StringLiteral) expressionNode()      {}
              func (sl *StringLiteral) TokenLiteral() string { return sl.Token.Literal }
              
              // Identifier 表示标识符(如列名)
              type Identifier struct {
                  Token lexer.Token
                  Value string
              }
              
              func (i *Identifier) expressionNode()      {}
              func (i *Identifier) TokenLiteral() string { return i.Token.Literal }
              
              // UpdateStatement 表示UPDATE语句
              type UpdateStatement struct {
                  Token     lexer.Token
                  TableName string
                  Set       []SetClause
                  Where     Expression
              }
              
              func (us *UpdateStatement) statementNode()       {}
              func (us *UpdateStatement) TokenLiteral() string { return us.Token.Literal }
              
              // SetClause 表示SET子句
              type SetClause struct {
                  Column string
                  Value  Expression
              }
              
              // DeleteStatement 表示DELETE语句
              type DeleteStatement struct {
                  Token     lexer.Token
                  TableName string
                  Where     Expression
              }
              
              func (ds *DeleteStatement) statementNode()       {}
              func (ds *DeleteStatement) TokenLiteral() string { return ds.Token.Literal }
              
              // DropTableStatement 表示DROP TABLE语句
              type DropTableStatement struct {
                  Token     lexer.Token
                  TableName string
              }
              
              func (ds *DropTableStatement) statementNode()       {}
              func (ds *DropTableStatement) TokenLiteral() string { return ds.Token.Literal }
              

              模块三:语法分析器 (Parser)

              思路

              语法分析器负责将词法分析器生成的标记序列转换为抽象语法树。将token序列构建成ast。

              SQL 解析器(Parser)的实现,负责将词法分析器(Lexer)产生的标记(Token)序列转换为抽象语法树(AST)。

              语法分析器SQL 数据库系统的关键组件,负责:

              • 验证 SQL 语句的语法正确性
              • 构建抽象语法树
              • 为后续的语义分析和执行提供基础

              我们新建internal/parser/parser.go。

              示例:

              CREATE TABLE users (
                  id INT PRIMARY KEY,
                  name TEXT
              );
              
              
              解析过程:
              1. 识别 CREATE 关键字
              2. 解析 TABLE 关键字
              3. 解析表名 "users"
              4. 解析列定义:
                  列名 "id",类型 INT,主键
                  列名 "name",类型 TEXT
              5. 生成 CREATE TABLE 语句的 AST
              编程客栈

              全部代码

              package parser
              
              import (
                  "fmt"
                  "ziyi.db.com/internal/ast"
                  "ziyi.db.com/internal/lexer"
              )
              
              // Parser 表示语法分析器
              // 维护当前和下一个标记,实现向前查看(lookahead)
              // 记录解析过程中的错误
              type Parser struct {
                  l         *lexer.Lexer // 词法分析器
                  curToken  lexer.Token  // 当前标记
                  peekToken lexer.Token  // 下一个标记
                  errors    []string     // 错误信息
              }
              
              // Newparser 创建新的语法分析器
              // 初始化解析器
              // 预读两个标记
              func NewParser(l *lexer.Lexer) *Parser {
                  p := &Parser{
                     l:      l,
                     errors: []string{},
                  }
              
                  // 读取两个token,设置curToken和peekToken
                  p.nextToken()
                  p.nextToken()
              
                  return p
              }
              
              // nextToken 移动到下一个词法单元
              func (p *Parser) nextToken() {
                  p.curToken = p.peekToken
                  p.peekToken = p.l.NextToken()
              }
              
              // ParseProgram 解析整个程序
              // 解析整个 SQL 程序
              // 循环解析每个语句直到结束
              func (p *Parser) ParseProgram() (*ast.Program, error) {
                  program := &ast.Program{
                     Statements: []ast.Statement{},
                  }
              
                  for p.curToken.Type != lexer.EOF {
                     stmt, err := p.parseStatement()
                     if err != nil {
                        return nil, err
                     }
                     if stmt != nil {
                        program.Statements = append(program.Statements, stmt)
                     }
                     p.nextToken()
                  }
              
                  return program, nil
              }
              
              // parseStatement 解析语句
              // 根据当前标记类型选择相应的解析方法
              func (p *Parser) parseStatement() (ast.Statement, error) {
                  switch p.curToken.Type {
                  case lexer.CREATE:
                     return p.parseCreateTableStatement()
                  case lexer.INSERT:
                     return p.parseInsertStatement()
                  case lexer.SELECT:
                     return p.parseSelectStatement()
                  case lexer.UPDATE:
                     return p.parseUpdateStatement()
                  case lexer.DELETE:
                     return p.parseDeleteStatement()
                  case lexer.DROP:
                     return p.parseDropTableStatement()
                  case lexer.SEMI:
                     return nil, nil
                  default:
                     return nil, fmt.Errorf("You have an error in your SQL syntax; check the manual that corresponds to your db server version for the right syntax to use near '%s'", p.curToken.Type)
                  }
              }
              
              // parseCreateTableStatement 解析CREATE TABLE语句
              // 解析表名
              // 解析列定义
              // 处理主键约束
              func (p *Parser) parseCreateTableStatement() (*ast.CreateTableStatement, error) {
                  stmt := &ast.CreateTableStatement{Token: p.curToken}
              
                  if !p.expectPeek(lexer.TABLE) {
                     return nil, fmt.Errorf("You have an error in your SQL syntax; check the manual that corresponds to your db server version for the right syntax to use near '%s'", p.peekToken.Literal)
                  }
              
                  if !p.expectPeek(lexer.IDENT) {
                     return nil, fmt.Errorf("You have an error in your SQL syntax; check the manual that corresponds to your db server version for the right syntax to use near '%s'", p.peekToken.Literal)
                  }
                  stmt.TableName = p.curToken.Literal
              
                  if !p.expectPeek(lexer.LPAREN) {
                     return nil, fmt.Errorf("You have an error in your SQL syntax; check the manual that corresponds to your db server version for the right syntax to use near '%s'", p.peekToken.Literal)
                  }
              
                  // 解析列定义
                  for !p.peekTokenIs(lexer.RPAREN) {
                     p.nextToken()
              
                     if !p.curTokenIs(lexer.IDENT) {
                        return nil, fmt.Errorf("You have an error in your SQL syntax; check the manual that corresponds to your db server version for the right syntax to use near '%s'", p.curToken.Literal)
                     }
              
                     col := ast.ColumnDefinition{
                        Name: p.curToken.Literal,
                     }
              
                     if !p.expectPeek(lexer.INT) && !p.expectPeek(lexer.TEXT) {
                        return nil, fmt.Errorf("You have an error in your SQL syntax; check the manual that corresponds to your db server version for the right syntax to use near '%s'", p.peekToken.Literal)
                     }
              
                     col.Type = string(p.curToken.Type)
              
                   uNqNS  if p.peekTokenIs(lexer.PRIMARY) {
                        p.nextToken()
                        if !p.expectPeek(lexer.KEY) {
                           return nil, fmt.Errorf("You have an error in your SQL syntax; check the manual that corresponds to your db server version for the right syntax to use near '%s'", p.peekToken.Literal)
                        }
                        col.Primary = true
                     }
              
                     stmt.Columns = append(stmt.Columns, col)
              
                     if p.peekTokenIs(lexer.COMMA) {
                        p.nextToken()
                     }
                  }
              
                  if !p.expectPeek(lexer.RPAREN) {
                     return nil, fmt.Errorf("You have an error in your SQL syntax; check the manual that corresponds to your db server version for the right syntax to use near '%s'", p.peekToken.Literal)
                  }
              
                  return stmt, nil
              }
              
              // parseInsertStatement 解析INSERT语句
              // 解析表名
              // 解析 VALUES 子句
              // 解析插入的值
              func (p *Parser) parseInsertStatement() (*ast.InsertStatement, error) {
                  stmt := &ast.InsertStatement{Token: p.curToken}
              
                  if !p.expectPeek(lexer.INTO) {
                     return nil, fmt.Errorf("You have an error in your SQL syntax; check the manual that corresponds to your db server version for the right syntax to use near '%s'", p.peekToken.Literal)
                  }
              
                  if !p.expectPeek(lexer.IDENT) {
                     return nil, fmt.Errorf("You have an error in your SQL syntax; check the manual that corresponds to your db server version for the right syntax to use near '%s'", p.peekToken.Literal)
                  }
                  stmt.TableName = p.curToken.Literal
              
                  if !p.expectPeek(lexer.VALUES) {
                     return nil, fmt.Errorf("You have an error in your SQL syntax; check the manual that corresponds to your db server version for the right syntax to use near '%s'", p.peekToken.Literal)
                  }
              
                  if !p.expectPeek(lexer.LPAREN) {
                     return nil, fmt.Errorf("You have an error in your SQL syntax; check the manual that corresponds to your db server version for the right syntax to use near '%s'", p.peekToken.Literal)
                  }
              
                  // 解析值列表
                  for !p.peekTokenIs(lexer.RPAREN) {
                     p.nextToken()
              
                     expr, err := p.parseExpression()
                     if err != nil {
                        return nil, err
                     }
              
                     stmt.Values = append(stmt.Values, expr)
              
                     if p.peekTokenIs(lexer.COMMA) {
                        p.nextToken()
                     }
                  }
              
                  if !p.expectPeek(lexer.RPAREN) {
                     return nil, fmt.Errorf("You have an error in your SQL syntax; check the manual that corresponds to your db server version for the right syntax to use near '%s'", p.peekToken.Literal)
                  }
              
                  return stmt, nil
              }
              
              // parseSelectStatement 解析SELECT语句
              // 解析选择列表
              // 解析 FROM 子句
              // 解析 WHERE 子句
              func (p *Parser) parseSelectStatement() (*ast.SelectStatement, error) {
                  stmt := &ast.SelectStatement{Token: p.curToken}
              
                  // 解析选择列表
                  for !p.peekTokenIs(lexer.FROM) {
                     p.nextToken()
              
                     if p.curToken.Type == lexer.ASTERISK {
                        stmt.Fields = append(stmt.Fields, &ast.StarExpression{})
                        break
                     }
              
                     expr, err := p.parseExpression()
                     if err != nil {
                        return nil, err
                     }
              
                     stmt.Fields = append(stmt.Fields, expr)
              
                     if p.peekTokenIs(lexer.COMMA) {
                        p.nextToken()
                     }
                  }
              
                  if !p.expectPeek(lexer.FROM) {
                     return nil, fmt.Errorf("You have an error in your SQL syntax; check the manual that corresponds to your db server version for the right syntax to use near '%s'", p.peekToken.Literal)
                  }
              
                  if !p.expectPeek(lexer.IDENT) {
                     return nil, fmt.Errorf("You have an error in your SQL syntax; check the manual that corresponds to your db server version for the right syntax to use near '%s'", p.peekToken.Literal)
                  }
                  stmt.TableName = p.curToken.Literal
              
                  // 解析WHERE子句
                  if p.peekTokenIs(lexer.WHERE) {
                     p.nextToken()
                     p.nextToken()
              
                     // 解析左操作数(列名)
                     if !p.curTokenIs(lexer.IDENT) {
                        return nil, fmt.Errorf("You have an error in your SQL syntax; check the manual that corresponds to your db server version for the right syntax to use near '%s'", p.curToken.Literal)
                     }
                     left := &ast.Identifier{
                        Token: p.curToken,
                        Value: p.curToken.Literal,
                     }
              
                     // 解析操作符
                     p.nextToken()
                     operator := p.curToken
              
                     // 处理LIKE操作符
                     if p.curTokenIs(lexer.LIKE) {
                        p.nextToken()
                        if !p.curTokenIs(lexer.STRING) {
                           return nil, fmt.Errorf("You have an error javascriptin your SQL syntax; check the manual that corresponds to your db server version for the right syntax to use near '%s'", p.curToken.Literal)
                        }
                        // 移除字符串字面量的引号
                        pattern := p.curToken.Literal
                        if len(pattern) >= 2 && (pattern[0] == '\'' || pattern[0] == '"') {
                           pattern = pattern[1 : len(pattern)-1]
                        }
                        stmt.Where = &ast.LikeExpression{
                           Token:   operator,
                           Left:    left,
                           Pattern: pattern,
                        }
                        return stmt, nil
                     }
              
                     // 处理其他操作符
                     if !p.curTokenIs(lexer.EQ) && !p.curTokenIs(lexe编程r.GT) && !p.curTokenIs(lexer.LT) {
                        return nil, fmt.Errorf("You have an error in your SQL syntax; check the manual that corresponds to your db server version for the right syntax to use near '%s'", operator.Type)
                     }
              
                     // 解析右操作数
                     p.nextToken()
                     right, err := p.parseExpression()
                     if err != nil {
                        return nil, err
                     }
              
                     stmt.Where = &ast.BinaryExpression{
                        Token:    operator,
                        Left:     left,
                        Operator: operator.Literal,
                        Right:    right,
                     }
                  }
              
                  return stmt, nil
              }
              
              // parseUpdateStatement 解析UPDATE语句
              // 解析表名
              // 解析 SET 子句
              // 解析 WHERE 子句
              func (p *Parser) parseUpdateStatement() (*ast.UpdateStatement, error) {
                  stmt := &ast.UpdateStatement{Token: p.curToken}
              
                  if !p.expectPeek(lexer.IDENT) {
                     return nil, fmt.Errorf("You have an error in your SQL syntax; check the manual that corresponds to your db server version for the right syntax to use near '%s'", p.peekToken.Literal)
                  }
                  stmt.TableName = p.curToken.Literal
              
                  if !p.expectPeek(lexer.SET) {
                     return nil, fmt.Errorf("You have an error in your SQL syntax; check the manual that corresponds to your db server version for the right syntax to use near '%s'", p.peekToken.Literal)
                  }
              
                  // 解析SET子句
                  for {
                     p.nextToken()
                     if !p.curTokenIs(lexer.IDENT) {
                        return nil, fmt.Errorf("You have an error in your SQL syntax; check the manual that corresponds to your db server version for the right syntax to use near '%s'", p.curToken.Literal)
                     }
                     column := p.curToken.Literal
              
                     if !p.expectPeek(lexer.EQ) {
                        return nil, fmt.Errorf("You have an error in your SQL syntax; check the manual that corresponds to your db server version for the right syntax to use near '%s'", p.peekToken.Literal)
                     }
              
                     p.nextToken()
                     value, err := p.parseExpression()
                     if err != nil {
                        return nil, err
                     }
              
                     stmt.Set = append(stmt.Set, ast.SetClause{
                        Column: column,
                        Value:  value,
                     })
              
                     if !p.peekTokenIs(lexer.COMMA) {
                        break
                     }
                     p.nextToken()
                  }
              
                  // 解析WHERE子句
                  if p.peekTokenIs(lexer.WHERE) {
                     p.nextToken()
                     p.nextToken()
              
                     // 解析左操作数(列名)
                     if !p.curTokenIs(lexer.IDENT) {
                        return nil, fmt.Errorf("You have an error in your SQL syntax; check the manual that corresponds to your db server version for the right syntax to use near '%s'", p.curToken.Literal)
                     }
                     left := &ast.Identifier{
                        Token: p.curToken,
                        Value: p.curToken.Literal,
                     }
              
                     // 解析操作符
                     p.nextToken()
                     operator := p.curToken
                     if !p.curTokenIs(lexer.EQ) && !p.curTokenIs(lexer.GT) && !p.curTokenIs(lexer.LT) {
                        return nil, fmt.Errorf("You have an error in your SQL syntax; check the manual that corresponds to your db server version for the right syntax to use near '%s'", operator.Type)
                     }
              
                     // 解析右操作数
                     p.nextToken()
                     right, err := p.parseExpression()
                     if err != nil {
                        return nil, err
                     }
              
                     stmt.Where = &ast.BinaryExpression{
                        Token:    operator,
                        Left:     left,
                        Operator: operator.Literal,
                        Right:    right,
                     }
                  }
              
                  return stmt, nil
              }
              
              // parseDeleteStatement 解析DELETE语句
              // 解析表名
              // 解析 WHERE 子句
              func (p *Parser) parseDeleteStatement() (*ast.DeleteStatement, error) {
                  stmt := &ast.DeleteStatement{Token: p.curToken}
              
                  if !p.expectPeek(lexer.FROM) {
                     return nil, fmt.Errorf("You have an error in your SQL syntax; check the manual that corresponds to your db server version for the right syntax to use near '%s'", p.peekToken.Literal)
                  }
              
                  if !p.expectPeek(lexer.IDENT) {
                     return nil, fmt.Errorf("You have an error in your SQL syntax; check the manual that corresponds to your db server version for the right syntax to use near '%s'", p.peekToken.Literal)
                  }
                  stmt.TableName = p.curToken.Literal
              
                  // 解析WHERE子句
                  if p.peekTokenIs(lexer.WHERE) {
                     p.nextToken()
                     p.nextToken()
              
                     // 解析左操作数(列名)
                     if !p.curTokenIs(lexer.IDENT) {
                        return nil, fmt.Errorf("You have an error in your SQL syntax; check the manual that corresponds to your db server version for the right syntax to use near '%s'", p.curToken.Literal)
                     }
                     left := &ast.Identifier{
                        Token: p.curToken,
                        Value: p.curToken.Literal,
                     }
              
                     // 解析操作符
                     p.nextToken()
                     operator := p.curToken
                     if !p.curTokenIs(lexer.EQ) && !p.curTokenIs(lexer.GT) && !p.curTokenIs(lexer.LT) {
                        return nil, fmt.Errorf("You have an error in your SQL syntax; check the manual that corresponds to your db server version for the right syntax to use near '%s'", operator.Type)
                     }
              
                     // 解析右操作数
                     p.nextToken()
                     right, err := p.parseExpression()
                     if err != nil {
                        return nil, err
                     }
              
                     stmt.Where = &ast.BinaryExpression{
                        Token:    operator,
                        Left:     left,
                        Operator: operator.Literal,
                        Right:    right,
                     }
                  }
              
                  return stmt, nil
              }
              
              // parseDropTableStatement 解析DROP TABLE语句
              func (p *Parser) parseDropTableStatement() (*ast.DropTableStatement, error) {
                  stmt := &ast.DropTableStatement{Token: p.curToken}
              
                  if !p.expectPeek(lexer.TABLE) {
                     return nil, fmt.Errorf("You have an error in your SQL syntax; check the manual that corresponds to your db server version for the right syntax to use near '%s'", p.peekToken.Literal)
                  }
              
                  if !p.expectPeek(lexer.IDENT) {
                     return nil, fmt.Errorf("You have an error in your SQL syntax; check the manual that corresponds to your db server version for the right syntax to use near '%s'", p.peekToken.Literal)
                  }
                  stmt.TableName = p.curToken.Literal
              
                  return stmt, nil
              }
              
              // parseExpression 解析表达式(字面量int、string类型,标识符列名、表名等)
              // 解析各种类型的表达式
              // 支持字面量、标识符等
              func (p *Parser) parseExpression() (ast.Expression, error) {
                  switch p.curToken.Type {
                  case lexer.INT_LIT:
                     return &ast.IntegerLiteral{
                        Token: p.curToken,
                        Value: p.curToken.Literal,
                     }, nil
                  case lexer.STRING:
                     return &ast.StringLiteral{
                        Token: p.curToken,
                        Value: p.curToken.Literal,
                     }, nil
                  case lexer.IDENT:
                     return &ast.Identifier{
                        Token: p.curToken,
                        Value: p.curToken.Literal,
                     }, nil
                  default:
                     return nil, fmt.Errorf("You have an error in your SQL syntax; check the manual that corresponds to your db server version for the right syntax to use near '%s'", p.curToken.Type)
                  }
              }
              
              // curTokenIs 检查当前token是否为指定类型
              func (p *Parser) curTokenIs(t lexer.TokenType) bool {
                  return p.curToken.Type == t
              }
              
              // peekTokenIs 检查下一个token是否为指定类型
              func (p *Parser) peekTokenIs(t lexer.TokenType) bool {
                  return p.peekToken.Type == t
              }
              
              // expectPeek 检查下一个词法单元是否为预期类型
              func (p *Parser) expectPeek(t lexer.TokenType) bool {
                  if p.peekTokenIs(t) {
                     p.nextToken()
                     return true
                  }
                  return false
              }
              
              // parseWhereClause 解析WHERE子句
              func (p *Parser) parseWhereClause() (ast.Expression, error) {
                  p.nextToken()
              
                  // 解析左操作数(列名)
                  if !p.curTokenIs(lexer.IDENT) {
                     return nil, fmt.Errorf("You have an error in your SQL syntax; check the manual that corresponds to your db server version for the right syntax to use near '%s'", p.curToken.Literal)
                  }
                  left := &ast.Identifier{
                     Token: p.curToken,
                     Value: p.curToken.Literal,
                  }
              
                  // 解析操作符
                  p.nextToken()
                  operator := p.curToken
              
                  // 处理LIKE操作符
                  if p.curTokenIs(lexer.LIKE) {
                     p.nextToken()
                     if !p.curTokenIs(lexer.STRING) {
                        return nil, fmt.Errorf("You have an error in your SQL syntax; check the manual that corresponds to your db server version for the right syntax to use near '%s'", p.curToken.Literal)
                     }
                     // 移除字符串字面量的引号
                     pattern := p.curToken.Literal
                     if len(pattern) >= 2 && (pattern[0] == '\'' || pattern[0] == '"') {
                        pattern = pattern[1 : len(pattern)-1]
                     }
                     return &ast.LikeExpression{
                        Token:   operator,
                        Left:    left,
                        Pattern: pattern,
                     }, nil
                  }
              
                  // 处理其他操作符
                  if !p.curTokenIs(lexer.EQ) && !p.curTokenIs(lexer.GT) && !p.curTokenIs(lexer.LT) {
                     return nil, fmt.Errorf("You have an error in your SQL syntax; check the manual that corresponds to your db server version for the right syntax to use near '%s'", operator.Type)
                  }
              
                  // 解析右操作数
                  p.nextToken()
                  right, err := p.parseExpression()
                  if err != nil {
                     return nil, err
                  }
              
                  return &ast.BinaryExpression{
                     Token:    operator,
                     Left:     left,
                     Operator: operator.Literal,
                     Right:    right,
                  }, nil
              }
              

              模块四:存储引擎 (Storage)

              思路

              存储引擎负责实际的数据存储和检索操作,执行引擎中的数据操作CURD。

              我们需要新建internal/storage/memory.go文件。

              这是内存存储引擎的实现,负责处理 SQL 语句的实际执行和数据存储。

              本期存储引擎实现了:

              • 完整的数据操作(CRUD)
              • 主键约束
              • 索引支持
              • 类型检查
              • 条件评估
              • 模式匹配

              它是 SQL 数据库系统的核心组件,负责:

              • 数据存储和管理
              • 查询执行
              • 数据完整性维护
              • 性能优化(通过索引)

              原理解析:

              -- 创建表
              CREATE TABLE users (
                  id INT PRIMARY KEY,
                  name TEXT
              );
              
              -- 插入数据
              INSERT INTO users VALUES (1, 'Alice');
              
              -- 查询数据
              SELECT * FROM users WHERE name LIKE 'A%';
              
              -- 更新数据
              UPDATE users SET name = 'Bob' WHERE id = 1;
              
              -- 删除数据
              DELETE FROM users WHERE id = 1;
              
              
              存储引擎会根据解析后的语法分析器,创建出对应的数据结构(如:在内存中),以及对外暴露对该数据的操作(CRUD)
              

              全部代码

              // internal/storage/memory.go
              package storage
              
              import (
                  "fmt"
                  "regexp"
                  "strconv"
                  "strings"
                  "ziyi.db.com/internal/ast"
              )
              
              // MemoryBackend 内存存储引擎,管理所有表
              type MemoryBackend struct {
                  tables map[string]*Table
              }
              
              // Table 数据表,包含列定义、数据行和索引
              type Table struct {
                  Name    string
                  Columns []ast.ColumnDefinition
                  Rows    [][]ast.Cell
                  Indexes map[string]*Index // 值到行索引的映射
              }
              
              // Index 索引,用于加速查询
              type Index struct {
                  Column string
                  Values map[string][]int // 值到行索引的映射
              }
              
              // NewMemoryBackend 创建新的内存存储引擎
              func NewMemoryBackend() *MemoryBackend {
                  return &MemoryBackend{
                     tables: make(map[string]*Table),
                  }
              }
              
              // CreateTable 创建表
              // 验证表名唯一性
              // 创建表结构
              // 为主键列创建索引
              func (b *MemoryBackend) CreateTable(stmt *ast.CreateTableStatement) error {
                  if _, exists := b.tables[stmt.TableName]; exists {
                     return fmt.Errorf("Table '%s' already exists", stmt.TableName)
                  }
              
                  table := &Table{
                     Name:    stmt.TableName,
                     Columns: stmt.Columns,
                     Rows:    make([][]ast.Cell, 0),
                     Indexes: make(map[string]*Index),
                  }
              
                  // 为主键创建索引
                  for _, col := range stmt.Columns {
                     if col.Primary {
                        table.Indexes[col.Name] = &Index{
                           Column: col.Name,
                           Values: make(map[string][]int),
                        }
                     }
                  }
              
                  b.tables[stmt.TableName] = table
                  return nil
              }
              
              // Insert 插入数据
              // 验证表存在性
              // 检查数据完整性
              // 处理主键约束
              // 维护索引
              func (b *MemoryBackend) Insert(stmt *ast.InsertStatement) error {
                  table, exists := b.tables[stmt.TableName]
                  if !exists {
                     return fmt.Errorf("Table '%s' doesn't exist", stmt.TableName)
                  }
              
                  if len(stmt.Values) != len(table.Columns) {
                     return fmt.Errorf("Column count doesn't match value count at row 1")
                  }
              
                  // 转换值
                  row := make([]ast.Cell, len(stmt.Values))
                  for i, expr := range stmt.Values {
                     value, err := evaLuateExpression(expr)
                     if err != nil {
                        return err
                     }
              
                     switch v := value.(type) {
                     case string:
                        if table.Columns[i].Type == "INT" {
                           // 尝试将字符串转换为整数
                           intVal, err := strconv.ParseInt(v, 10, 32)
                           if err != nil {
                              return fmt.Errorf("Incorrect integer value: '%s' for column '%s'", v, table.Columns[i].Name)
                           }
                           row[i] = ast.Cell{Type: ast.CellTypeInt, IntValue: int32(intVal)}
                        } else {
                           row[i] = ast.Cell{Type: ast.CellTypeText, TextValue: v}
                        }
                     case int32:
                        row[i] = ast.Cell{Type: ast.CellTypeInt, IntValue: v}
                     default:
                        return fmt.Errorf("Unsupported value type: %T for column '%s'", value, table.Columns[i].Name)
                     }
                  }
              
                  // 检查主键约束
                  for i, col := range table.Columns {
                     if col.Primary {
                        key := row[i].String()
                        if _, exists := table.Indexes[col.Name].Values[key]; exists {
                           return fmt.Errorf("Duplicate entry '%s' for key '%s'", key, col.Name)
                        }
                     }
                  }
              
                  // 插入数据
                  rowIndex := len(table.Rows)
                  table.Rows = append(table.Rows, row)
              
                  // 更新索引
                  for i, col := range table.Columns {
                     if col.Primary {
                        key := row[i].String()
                        table.Indexes[col.Name].Values[key] = append(table.Indexes[col.Name].Values[key], rowIndex)
                     }
                  }
              
                  return nil
              }
              
              // Select 查询数据
              // 支持 SELECT * 和指定列
              // 处理 WHERE 条件
              // 返回查询结果
              func (b *MemoryBackend) Select(stmt *ast.SelectStatement) (*ast.Results, error) {
                  table, exists := b.tables[stmt.TableName]
                  if !exists {
                     return nil, fmt.Errorf("Table '%s' doesn't exist", stmt.TableName)
                  }
              
                  results := &ast.Results{
                     Columns: make([]ast.ResultColumn, 0),
                     Rows:    make([][]ast.Cell, 0),
                  }
              
                  // 处理选择列表
                  if len(stmt.Fields) == 1 && stmt.Fields[0].(*ast.StarExpression) != nil {
                     // SELECT *
                     for _, col := range table.Columns {
                        results.Columns = append(results.Columns, ast.ResultColumn{
                           Name: col.Name,
                           Type: col.Type,
                        })
                     }
                  } else {
                     // 处理指定的列
                     for _, expr := range stmt.Fields {
                        switch e := expr.(type) {
                        case *ast.Identifier:
                           // 查找列
                           found := false
                           for _, col := range table.Columns {
                              if col.Name == e.Value {
                                 results.Columns = append(results.Columns, ast.ResultColumn{
                                    Name: col.Name,
                                    Type: col.Type,
                                 })
                                 found = true
                                 break
                              }
                           }
                           if !found {
                              return nil, fmt.Errorf("Unknown column '%s' in 'field list'", e.Value)
                           }
                        default:
                           return nil, fmt.Errorf("Unsupported select expression type")
                        }
                     }
                  }
              
                  // 处理WHERE子句
                  for _, row := range table.Rows {
                     if stmt.Where != nil {
                        match, err := evaluateWhereCondition(stmt.Where, row, table.Columns)
                        if err != nil {
                           return nil, err
                        }
                        if !match {
                           continue
                        }
                     }
              
                     // 构建结果行
                     resultRow := make([]ast.Cell, len(results.Columns))
                     for j, col := range results.Columns {
                        // 查找列在原始行中的位置
                        for k, tableCol := range table.Columns {
                           if tableCol.Name == col.Name {
                              resultRow[j] = row[k]
                              break
                           }
                        }
                     }
                     results.Rows = append(results.Rows, resultRow)
                  }
              
                  return results, nil
              }
              
              // Update 执行UPDATE操作
              // 验证表和列存在性
              // 处理 WHERE 条件
              // 更新符合条件的行
              func (mb *MemoryBackend) Update(stmt *ast.UpdateStatement) error {
                  table, ok := mb.tables[stmt.TableName]
                  if !ok {
                     return fmt.Errorf("Table '%s' doesn't exist", stmt.TableName)
                  }
              
                  // 获取列索引
                  columnIndices := make(map[string]int)
                  for i, col := range table.Columns {
                     columnIndices[col.Name] = i
                  }
              
                  // 验证所有要更新的列是否存在
                  for _, set := range stmt.Set {
                     if _, ok := columnIndices[set.Column]; !ok {
                        return fmt.Errorf("Unknown column '%s' in 'field list'", set.Column)
                     }
                  }
              
                  // 更新符合条件的行
                  for i := range table.Rows {
                     if stmt.Where != nil {
                        // 评估WHERE条件
                        result, err := evaluateWhereCondition(stmt.Where, table.Rows[i], table.Columns)
                        if err != nil {
                           return err
                        }
                        if !result {
                           continue
                        }
                     }
              
                     // 更新行
                     for _, set := range stmt.Set {
                        colIndex := columnIndices[set.Column]
                        value, err := evaluateExpression(set.Value)
                        if err != nil {
                           return err
                        }
              
                        switch v := value.(type) {
                        case int32:
                           table.Rows[i][colIndex] = ast.Cell{Type: ast.CellTypeInt, IntValue: v}
                        case string:
                           table.Rows[i][colIndex] = ast.Cell{Type: ast.CellTypeText, TextValue: v}
                        default:
                           return fmt.Errorf("Unsupported value type: %T for column '%s'", value, set.Column)
                        }
                     }
                  }
              
                  return nil
              }
              
              // Delete 执行DELETE操作
              // 验证表存在性
              // 处理 WHERE 条件
              // 删除符合条件的行
              func (mb *MemoryBackend) Delete(stmt *ast.DeleteStatement) error {
                  table, ok := mb.tables[stmt.TableName]
                  if !ok {
                     return fmt.Errorf("Table '%s' doesn't exist", stmt.TableName)
                  }
              
                  // 找出要删除的行
                  rowsToDelete := make([]int, 0)
                  for i := range table.Rows {
                     if stmt.Where != nil {
                        // 评估WHERE条件
                        result, err := evaluateWhereCondition(stmt.Where, table.Rows[i], table.Columns)
                        if err != nil {
                           return err
                        }
                        if !result {
                           continue
                        }
                     }
                     rowsToDelete = append(rowsToDelete, i)
                  }
              
                  // 从后向前删除行,以避免索引变化
                  for i := len(rowsToDelete) - 1; i >= 0; i-- {
                     rowIndex := rowsToDelete[i]
                     table.Rows = append(table.Rows[:rowIndex], table.Rows[rowIndex+1:]...)
                  }
              
                  return nil
              }
              
              // DropTable 删除表
              // 验证表是否存在
              // 从存储引擎中删除表
              func (mb *MemoryBackend) DropTable(stmt *ast.DropTableStatement) error {
                  if _, exists := mb.tables[stmt.TableName]; !exists {
                     return fmt.Errorf("Unknown table '%s'", stmt.TableName)
                  }
              
                  delete(mb.tables, stmt.TableName)
                  return nil
              }
              
              // evaluateExpression 评估表达式的值
              // 计算表达式的值
              // 处理不同类型的数据
              func evaluateExpression(expr ast.Expression) (interface{}, error) {
                  switch e := expr.(type) {
                  case *ast.IntegerLiteral:
                     val, err := strconv.ParseInt(e.Value, 10, 32)
                     if err != nil {
                        return nil, fmt.Errorf("Incorrect integer value: '%s'", e.Value)
                     }
                     return int32(val), nil
                  case *ast.StringLiteral:
                     return e.Value, nil
                  case *ast.Identifier:
                     return nil, fmt.Errorf("Cannot evaluate identifier: '%s'", e.Value)
                  default:
                     return nil, fmt.Errorf("Unknown expression type: %T", expr)
                  }
              }
              
              // matchLikePattern 检查字符串是否匹配LIKE模式
              func matchLikePattern(str, pattern string) bool {
                  // 将SQL LIKE模式转换为正则表达式
                  regexPattern := "^"
                  for i := 0; i < len(pattern); i++ {
                     switch pattern[i] {
                     case '%':
                        regexPattern += ".*"
                     case '_':
                        regexPattern += "."
                     case '\\':
                        if i+1 < len(pattern) {
                           regexPattern += "\\" + string(pattern[i+1])
                           i++
                        }
                     default:
                        // 转义正则表达式特殊字符
                        if strings.ContainsAny(string(pattern[i]), ".+*?^$()[]{}|") {
                           regexPattern += "\\" + string(pattern[i])
                        } else {
                           regexPattern += string(pattern[i])
                        }
                     }
                  }
                  regexPattern += "$"
              
                  // 编译正则表达式
                  re, err := regexp.Compile(regexPattern)
                  if err != nil {
                     return false
                  }
              
                  // 执行匹配
                  return re.MatchString(str)
              }
              
              // evaluateWhereCondition 评估WHERE条件
              // 评估 WHERE 条件
              // 支持比较运算符和 LIKE 操作符
              func evaluateWhereCondition(expr ast.Expression, row []ast.Cell, columns []ast.ColumnDefinition) (bool, error) {
                  switch e := expr.(type) {
                  case *ast.BinaryExpression:
                     // 获取左操作数的值
                     leftValue, err := getColumnValue(e.Left, row, columns)
                     if err != nil {
                        return false, err
                     }
              
                     // 获取右操作数的值
                     rightValue, err := getColumnValue(e.Right, row, columns)
                     if err != nil {
                        return false, err
                     }
              
                     // 根据操作符比较值
                     switch e.Operator {
                     case "=":
                        return compareValues(leftValue, rightValue, "=")
                     case ">":
                        return compareValues(leftValue, rightValue, ">")
                     case "<":
                        return compareValues(leftValue, rightValue, "<")
                     default:
                        return false, fmt.Errorf("Unknown operator: '%s'", e.Operator)
                     }
                  case *ast.LikeExpression:
                     // 获取左操作数的值
                     leftValue, err := getColumnValue(e.Left, row, columns)
                     if err != nil {
                        return false, err
                     }
              
                     // 确保左操作数是字符串类型
                     strValue, ok := leftValue.(string)
                     if !ok {
                        return false, fmt.Errorf("LIKE operator requires string operand")
                     }
              
                     // 执行LIKE匹配
                     return matchLikePattern(strValue, e.Pattern), nil
                  default:
                     return false, fmt.Errorf("Unknown expression type: %T", expr)
                  }
              }
              
              // compareValues 比较两个值
              func compareValues(left, right interface{}, operator string) (bool, error) {
                  switch l := left.(type) {
                  case int32:
                     if r, ok := right.(int32); ok {
                        switch operator {
                        case "=":
                           return l == r, nil
                        case ">":
                           return l > r, nil
                        case "<":
                           return l < r, nil
                        }
                     }
                  case string:
                     if r, ok := right.(string); ok {
                        switch operator {
                        case "=":
                           return l == r, nil
                        case ">":
                           return l > r, nil
                        case "<":
                           return l < r, nil
                        }
                     }
                  }
                  return false, fmt.Errorf("Cannot compare values of different types: %T and %T", left, right)
              }
              
              // getColumnValue 获取列的值
              func getColumnValue(expr ast.Expression, row []ast.Cell, columns []ast.ColumnDefinition) (interface{}, error) {
                  switch e := expr.(type) {
                  case *ast.Identifier:
                     // 查找列索引
                     for i, col := range columns {
                        if col.Name == e.Value {
                           switch row[i].Type {
                           case ast.CellTypeInt:
                              return row[i].IntValue, nil
                           case ast.CellTypeText:
                              return row[i].TextValue, nil
                           default:
                              return nil, fmt.Errorf("Unknown cell type: %v", row[i].Type)
                           }
                        }
                     }
                     return nil, fmt.Errorf("Unknown column '%s' in 'where clause'", e.Value)
                  case *ast.IntegerLiteral:
                     val, err := strconv.ParseInt(e.Value, 10, 32)
                     if err != nil {
                        return nil, fmt.Errorf("Incorrect integer value: '%s'", e.Value)
                     }
                     return int32(val), nil
                  case *ast.StringLiteral:
                     return e.Value, nil
                  default:
                     return nil, fmt.Errorf("Unknown expression type: %T", expr)
                  }
              }
              
              //后续拓展新的存储引擎,如落地到文件...
              

              模块五:REPL 交互界面

              思路

              最后,我们需要实现一个交互式的命令行界面,让用户可以输入 SQL 命令并查看结果。

              这是 ZiyiDB 的主程序,实现了一个交互式的 SQL 命令行界面。

              为了实现客户端可以上下翻找之前执行的命令以及cli客户端的美观,我们这里使用"github.com/c-BATa/go-prompt"库

              // 安装依赖
              go get "github.com/c-bata/go-prompt"
              

              我们需要新建cmd/main.go文件。

              主要实现:

              • 交互式命令行界面
              • SQL 命令解析和执行
              • 命令历史记录
              • 查询结果格式化
              • 错误处理和提示

              全部代码

              package main
              
              import (
                  "fmt"
                  "github.com/c-bata/go-prompt"
                  "os"
                  "strings"
                  "ziyi.db.com/internal/ast"
                  "ziyi.db.com/internal/lexer"
                  "ziyi.db.com/internal/parser"
                  "ziyi.db.com/internal/storage"
              )
              
              var history []string               // 存储命令历史
              var backend *storage.MemoryBackend // 存储引擎实例
              var historyIndex int               // 当前历史记录索引
              
              // 处理用户输入的命令
              func executor(t string) {
                  t = strings.TrimSpace(t)
                  if t == "" {
                     return
                  }
              
                  // 添加到历史记录
                  history = append(history, t)
                  historyIndex = len(history) // 重置历史记录索引
              
                  // 处理退出命令
                  if strings.ToLower(t) == "exit" {
                     fmt.Println("Bye!")
                     os.Exit(0)
                  }
              
                  // 创建词法分析器
                  l := lexer.NewLexer(strings.NewReader(t))
              
                  // 创建语法分析器
                  p := parser.NewParser(l)
              
                  // 解析SQL语句
                  stmt, err := p.ParseProgram()
                  if err != nil {
                     fmt.Printf("Parse error: %v\n", err)
                     return
                  }
              
                  // 执行SQL语句
                  for _, statement := range stmt.Statements {
                     switch s := statement.(type) {
                     case *ast.CreateTableStatement:
                        if err := backend.CreateTable(s); err != nil {
                           fmt.Printf("Error: %v\n", err)
                        } else {
                           fmt.Println("Table created successfully")
                        }
                     case *ast.InsertStatement:
                        if err := backend.Insert(s); err != nil {
                           fmt.Printf("Error: %v\n", err)
                        } else {
                           fmt.Println("1 row inserted")
                        }
                     case *ast.SelectStatement:
                        results, err := backend.Select(s)
                        if err != nil {
                           fmt.Printf("Error: %v\n", err)
                        } else {
                           // 计算每列的最大宽度
                           colWidths := make([]int, len(results.Columns))
                           for i, col := range results.Columns {
                              colWidths[i] = len(col.Name)
                           }
                           for _, row := range results.Rows {
                              for i, cell := range row {
                                 cellLen := len(cell.String())
                                 if cellLen > colWidths[i] {
                                    colWidths[i] = cellLen
                                 }
                              }
                           }
              
                           // 打印表头
                           fmt.Print("+")
                           for _, width := range colWidths {
                              fmt.Print(strings.Repeat("-", width+2))
                              fmt.Print("+")
                           }
                           fmt.Println()
              
                           // 打印列名
                           fmt.Print("|")
                           for i, col := range results.Columns {
                              fmt.Printf(" %-*s |", colWidths[i], col.Name)
                           }
                           fmt.Println()
              
                           // 打印分隔线
                           fmt.Print("+")
                           for _, width := range colWidths {
                              fmt.Print(strings.Repeat("-", width+2))
                              fmt.Print("+")
                           }
                           fmt.Println()
              
                           // 打印数据行
                           for _, row := range results.Rows {
                              fmt.Print("|")
                              for i, cell := range row {
                                 fmt.Printf(" %-*s |", colWidths[i], cell.String())
                              }
                              fmt.Println()
                           }
              
                           // 打印底部边框
                           fmt.Print("+")
                           for _, width := range colWidths {
                              fmt.Print(strings.Repeat("-", width+2))
                              fmt.Print("+")
                           }
                           fmt.Println()
              
                           // 打印行数统计
                           fmt.Printf("%d rows in set\n", len(results.Rows))
                        }
                     case *ast.UpdateStatement:
                        if err := backend.Update(s); err != nil {
                           fmt.Printf("Error: %v\n", err)
                        } else {
                           fmt.Println("Query OK, 1 row affected")
                        }
                     case *ast.DeleteStatement:
                        if err := backend.Delete(s); err != nil {
                           fmt.Printf("Error: %v\n", err)
                        } else {
                           fmt.Println("Query OK, 1 row affected")
                        }
                     case *ast.DropTableStatement:
                        if err := backend.DropTable(s); err != nil {
                           fmt.Printf("Error: %v\n", err)
                        } else {
                           fmt.Println("Table dropped successfully")
                        }
                     default:
                        fmt.Printf("Unsupported statement type: %T\n", s)
                     }
                  }
              }
              
              // 提供命令补全功能
              func completer(d prompt.Document) []prompt.Suggest {
                  s := []prompt.Suggest{}
                  return prompt.FilterHASPrefix(s, d.GetWordBeforeCursor(), true)
              }
              
              func main() {
                  // 初始化存储引擎
                  backend = storage.NewMemoryBackend()
                  historyIndex = 0
              
                  fmt.Println("Welcome to ZiyiDB!")
                  fmt.Println("Type your SQL commands (type 'exit' to quit)")
              
                  p := prompt.New(
                     executor,
                     completer,
                     prompt.OptionTitle("ZiyiDB: A Simple SQL Database"),
                     prompt.OptionPrefix("ziyidb> "),
                     prompt.OptionHistory(history),
                     prompt.OptionLivePrefix(func() (string, bool) {
                        return "ziyidb> ", true
                     }),
                     //实现方向键上下翻阅历史命令
                     // 上键绑定
                     prompt.OptionAddKeyBind(prompt.KeyBind{
                        Key: prompt.Up,
                        Fn: func(buf *prompt.Buffer) {
                           if historyIndex > 0 {
                              historyIndex--
                              buf.DeleteBeforeCursor(len(buf.Text()))
                              buf.InsertText(history[historyIndex], false, true)
                           }
                        },
                     }),
                     // 下键绑定
                     prompt.OptionAddKeyBind(prompt.KeyBind{
                        Key: prompt.Down,
                        Fn: func(buf *prompt.Buffer) {
                           if historyIndex < len(history)-1 {
                              historyIndex++
                              buf.DeleteBeforeCursor(len(buf.Text()))
                              buf.InsertText(history[historyIndex], false, true)
                           } else if historyIndex == len(history)-1 {
                              historyIndex++
                              buf.DeleteBeforeCursor(len(buf.Text()))
                           }
                        },
                     }),
                  )
                  p.Run()
              }
              

              整体测试

              编写完第一版后,现在我们来整体测试一下。

              测试脚本

              test_cast.sql:

              -- 1. 创建表
              CREATE TABLE users (id INT PRIMARY KEY,name TEXT ,age INT);
              
              
              -- 2. 插入用户数据
              INSERT INTO users VALUES (1, 'Alice', 20);
              INSERT INTO users VALUES (2, 'Bob', 25);
              INSERT INTO users VALUES (3, 'Charlie', 30);
              INSERT INTO users VALUES (4, 'David', 35);
              INSERT INTO users VALUES (5, 'Eve', 40);
              
              -- 3. 测试主键冲突
              INSERT INTO users VALUES (1, 'Tomas', 21);
              
              
              -- 4. 基本查询测试
              -- 4.1 查询所有数据
              SELECT * FROM users;
              
              
              -- 4.2 查询特定列
              SELECT id, name FROM users;
              
              -- 5. WHERE 子句测试
              SELECT * FROM users WHERE age > 25;
              SELECT * FROM users WHERE age < 30;
              
              -- 6. LIKE 操作符测试
              -- 6.1 基本模式匹配
              SELECT * FROM users WHERE name LIKE 'A%';  -- 以 A 开头
              SELECT * FROM users WHERE name LIKE '%e';  -- 以 e 结尾
              
              -- 6.2 转义字符测试
              INSERT INTO users VALUES (6, 'Bob%Smith', 45);
              SELECT * FROM users WHERE name LIKE 'Bob\%Smith';
              
              -- 7. 更新操作测试
              -- 7.1 更新单个字段
              UPDATE users SET age = 21 WHERE name = 'Alice';
              
              -- 7.2 更新多个字段
              UPDATE users SET name = 'Robert', age = 8 WHERE id = 2;
              
              
              -- 8. 删除操作测试
              DELETE FROM users WHERE age > 30;
              
              -- 9. 清理测试数据
              DROP TABLE users;
              
              -- 10. 验证表已删除
              SELECT * FROM users;    -- 应该失败
              
              
              todo::
                  1. 实现!= >= <=等运算符
                  2. 支持更多数据类型
                  3. 支持更多函数
                  4. 优化查询结果展示
                  5. 支持更多索引类型
                  6. 支持null值等
                  7. 支持数据落地本地文件
                  8. 支持事务操作等
              

              运行效果

              cd ZiyiDB
              go run cmd/main.go
              

              Go手写数据库ZiyiDB的实现

              参考文章:https://notes.eatonphil.com/database-basics.html

              到此这篇关于Go手写数据库ZiyiDB的实现的文章就介绍到这了,更多相关Go手写ZiyiDB内容请搜索编程客栈(www.devze.com)以前的文章或继续浏览下面的相关文章希望大家以后多多支持编程客栈(www.devze.com)!

              0

              上一篇:

              下一篇:

              精彩评论

              暂无评论...
              验证码 换一张
              取 消

              最新开发

              开发排行榜