Commit 627c9dc

mo khan <mo@mokhan.ca>
2025-09-24 18:28:04
feat: build simple sql formatter in oneshot
1 parent c68375f
src/formatter.rs
@@ -0,0 +1,634 @@
+use sqlparser::ast::*;
+
+pub fn format_statement(statement: &Statement, indent_level: usize) -> String {
+    match statement {
+        Statement::Query(query) => format_query(query, indent_level),
+        Statement::Insert {
+            table_name,
+            columns,
+            source,
+            ..
+        } => {
+            if let Some(source) = source {
+                format_insert(table_name, columns, source, indent_level)
+            } else {
+                format!("INSERT INTO {}", table_name)
+            }
+        }
+        Statement::Update {
+            table,
+            assignments,
+            selection,
+            ..
+        } => format_update(table, assignments, selection, indent_level),
+        Statement::Delete {
+            from, selection, ..
+        } => format_delete(from, selection, indent_level),
+        Statement::CreateTable { name, columns, .. } => {
+            format_create_table(name, columns, indent_level)
+        }
+        _ => statement.to_string(),
+    }
+}
+
+fn format_query(query: &Query, indent_level: usize) -> String {
+    let mut result = String::new();
+
+    if let Some(with) = &query.with {
+        result.push_str(&format_with(with, indent_level));
+        result.push('\n');
+    }
+
+    result.push_str(&format_set_expr(&query.body, indent_level));
+
+    if !query.order_by.is_empty() {
+        result.push('\n');
+        result.push_str(&format!("{}ORDER BY ", " ".repeat(indent_level)));
+        let order_items: Vec<String> = query
+            .order_by
+            .iter()
+            .map(|order| format_order_by_expr(order))
+            .collect();
+        result.push_str(&order_items.join(", "));
+    }
+
+    if let Some(limit) = &query.limit {
+        result.push('\n');
+        result.push_str(&format!(
+            "{}LIMIT {}",
+            " ".repeat(indent_level),
+            format_expr(limit)
+        ));
+    }
+
+    result
+}
+
+fn format_with(with: &With, indent_level: usize) -> String {
+    let mut result = format!("{}WITH", " ".repeat(indent_level));
+    if with.recursive {
+        result.push_str(" RECURSIVE");
+    }
+
+    for (i, cte) in with.cte_tables.iter().enumerate() {
+        if i == 0 {
+            result.push('\n');
+            result.push_str(&format!("  {}", format_cte(cte, indent_level + 2)));
+        } else {
+            result.push_str(",\n");
+            result.push_str(&format!("  {}", format_cte(cte, indent_level + 2)));
+        }
+    }
+
+    result
+}
+
+fn format_cte(cte: &Cte, indent_level: usize) -> String {
+    let mut result = cte.alias.name.value.clone();
+
+    if !cte.alias.columns.is_empty() {
+        result.push('(');
+        let columns: Vec<String> = cte
+            .alias
+            .columns
+            .iter()
+            .map(|col| col.value.clone())
+            .collect();
+        result.push_str(&columns.join(", "));
+        result.push(')');
+    }
+
+    result.push_str(" AS (\n");
+    result.push_str(&format_query(&cte.query, indent_level + 2));
+    result.push('\n');
+    result.push_str(&format!("{})", " ".repeat(indent_level)));
+
+    result
+}
+
+fn format_set_expr(set_expr: &SetExpr, indent_level: usize) -> String {
+    match set_expr {
+        SetExpr::Select(select) => format_select(select, indent_level),
+        SetExpr::Query(query) => {
+            let mut result = String::new();
+            result.push('(');
+            result.push('\n');
+            result.push_str(&format_query(query, indent_level + 2));
+            result.push('\n');
+            result.push_str(&format!("{})", " ".repeat(indent_level)));
+            result
+        }
+        SetExpr::SetOperation {
+            op,
+            set_quantifier,
+            left,
+            right,
+        } => {
+            let mut result = format_set_expr(left, indent_level);
+            result.push('\n');
+            result.push_str(&format!("{}{}", " ".repeat(indent_level), op));
+            result.push(' ');
+            result.push_str(&set_quantifier.to_string());
+            result.push('\n');
+            result.push_str(&format_set_expr(right, indent_level));
+            result
+        }
+        _ => set_expr.to_string(),
+    }
+}
+
+fn format_select(select: &Select, indent_level: usize) -> String {
+    let mut result = String::new();
+    let base_indent = " ".repeat(indent_level);
+    let item_indent = " ".repeat(indent_level + 2);
+
+    result.push_str(&format!("{}SELECT", base_indent));
+
+    if select.distinct.is_some() {
+        result.push_str(" DISTINCT");
+    }
+
+    for (i, item) in select.projection.iter().enumerate() {
+        if i == 0 {
+            result.push('\n');
+        } else {
+            result.push_str(",\n");
+        }
+        result.push_str(&format!("{}{}", item_indent, format_select_item(item)));
+    }
+
+    if !select.from.is_empty() {
+        result.push('\n');
+        result.push_str(&format!("{}FROM", base_indent));
+
+        for (i, table) in select.from.iter().enumerate() {
+            if i == 0 {
+                result.push('\n');
+                result.push_str(&format!(
+                    "{}{}",
+                    item_indent,
+                    format_table_with_joins(table, indent_level + 2)
+                ));
+            } else {
+                result.push_str(",\n");
+                result.push_str(&format!(
+                    "{}{}",
+                    item_indent,
+                    format_table_with_joins(table, indent_level + 2)
+                ));
+            }
+        }
+    }
+
+    if let Some(selection) = &select.selection {
+        result.push('\n');
+        result.push_str(&format!("{}WHERE {}", base_indent, format_expr(selection)));
+    }
+
+    match &select.group_by {
+        GroupByExpr::Expressions(exprs) if !exprs.is_empty() => {
+            result.push('\n');
+            result.push_str(&format!("{}GROUP BY ", base_indent));
+            let group_items: Vec<String> = exprs.iter().map(|expr| format_expr(expr)).collect();
+            result.push_str(&group_items.join(", "));
+        }
+        _ => {}
+    }
+
+    if let Some(having) = &select.having {
+        result.push('\n');
+        result.push_str(&format!("{}HAVING {}", base_indent, format_expr(having)));
+    }
+
+    result
+}
+
+fn format_select_item(item: &SelectItem) -> String {
+    match item {
+        SelectItem::UnnamedExpr(expr) => format_expr(expr),
+        SelectItem::ExprWithAlias { expr, alias } => {
+            format!("{} AS {}", format_expr(expr), alias.value)
+        }
+        SelectItem::QualifiedWildcard(object_name, _) => {
+            format!("{}.*", object_name)
+        }
+        SelectItem::Wildcard(_) => "*".to_string(),
+    }
+}
+
+fn format_table_with_joins(table: &TableWithJoins, indent_level: usize) -> String {
+    let mut result = format_table_factor(&table.relation);
+
+    for join in &table.joins {
+        result.push('\n');
+        result.push_str(&format!(
+            "{}{}",
+            " ".repeat(indent_level),
+            format_join(join)
+        ));
+    }
+
+    result
+}
+
+fn format_table_factor(table: &TableFactor) -> String {
+    match table {
+        TableFactor::Table { name, alias, .. } => {
+            let mut result = name.to_string();
+            if let Some(alias) = alias {
+                result.push_str(&format!(" AS {}", alias.name.value));
+            }
+            result
+        }
+        TableFactor::Derived {
+            subquery, alias, ..
+        } => {
+            let mut result = String::new();
+            result.push('(');
+            result.push('\n');
+            result.push_str(&format_query(subquery, 4));
+            result.push('\n');
+            result.push(')');
+            if let Some(alias) = alias {
+                result.push_str(&format!(" AS {}", alias.name.value));
+            }
+            result
+        }
+        _ => table.to_string(),
+    }
+}
+
+fn format_join(join: &Join) -> String {
+    let mut result = String::new();
+
+    match &join.join_operator {
+        JoinOperator::Inner(constraint) => {
+            result.push_str("INNER JOIN ");
+            result.push_str(&format_table_factor(&join.relation));
+            if let JoinConstraint::On(expr) = constraint {
+                result.push_str(&format!(" ON {}", format_expr(expr)));
+            }
+        }
+        JoinOperator::LeftOuter(constraint) => {
+            result.push_str("LEFT OUTER JOIN ");
+            result.push_str(&format_table_factor(&join.relation));
+            if let JoinConstraint::On(expr) = constraint {
+                result.push_str(&format!(" ON {}", format_expr(expr)));
+            }
+        }
+        JoinOperator::RightOuter(constraint) => {
+            result.push_str("RIGHT OUTER JOIN ");
+            result.push_str(&format_table_factor(&join.relation));
+            if let JoinConstraint::On(expr) = constraint {
+                result.push_str(&format!(" ON {}", format_expr(expr)));
+            }
+        }
+        JoinOperator::FullOuter(constraint) => {
+            result.push_str("FULL OUTER JOIN ");
+            result.push_str(&format_table_factor(&join.relation));
+            if let JoinConstraint::On(expr) = constraint {
+                result.push_str(&format!(" ON {}", format_expr(expr)));
+            }
+        }
+        _ => {
+            result.push_str("JOIN ");
+            result.push_str(&format_table_factor(&join.relation));
+        }
+    }
+
+    result
+}
+
+fn format_expr(expr: &Expr) -> String {
+    match expr {
+        Expr::Identifier(ident) => ident.value.clone(),
+        Expr::CompoundIdentifier(idents) => idents
+            .iter()
+            .map(|i| i.value.clone())
+            .collect::<Vec<_>>()
+            .join("."),
+        Expr::Value(value) => format_value(value),
+        Expr::BinaryOp { left, op, right } => {
+            format!("{} {} {}", format_expr(left), op, format_expr(right))
+        }
+        Expr::UnaryOp { op, expr } => {
+            format!("{} {}", op, format_expr(expr))
+        }
+        Expr::Cast {
+            expr, data_type, ..
+        } => {
+            format!("CAST({} AS {})", format_expr(expr), data_type)
+        }
+        Expr::Case {
+            operand,
+            conditions,
+            results,
+            else_result,
+        } => format_case_expr(operand, conditions, results, else_result),
+        Expr::Function(function) => format_function(function),
+        Expr::Subquery(query) => {
+            let mut result = String::new();
+            result.push('(');
+            result.push('\n');
+            result.push_str(&format_query(query, 4));
+            result.push('\n');
+            result.push(')');
+            result
+        }
+        Expr::InSubquery {
+            expr,
+            subquery,
+            negated,
+        } => {
+            let mut result = format_expr(expr);
+            if *negated {
+                result.push_str(" NOT");
+            }
+            result.push_str(" IN (");
+            result.push('\n');
+            result.push_str(&format_query(subquery, 4));
+            result.push('\n');
+            result.push(')');
+            result
+        }
+        Expr::Between {
+            expr,
+            negated,
+            low,
+            high,
+        } => {
+            let mut result = format_expr(expr);
+            if *negated {
+                result.push_str(" NOT");
+            }
+            result.push_str(&format!(
+                " BETWEEN {} AND {}",
+                format_expr(low),
+                format_expr(high)
+            ));
+            result
+        }
+        Expr::Like {
+            expr,
+            negated,
+            pattern,
+            ..
+        } => {
+            let mut result = format_expr(expr);
+            if *negated {
+                result.push_str(" NOT");
+            }
+            result.push_str(&format!(" LIKE {}", format_expr(pattern)));
+            result
+        }
+        Expr::IsNull(expr) => format!("{} IS NULL", format_expr(expr)),
+        Expr::IsNotNull(expr) => format!("{} IS NOT NULL", format_expr(expr)),
+        Expr::Nested(expr) => format!("({})", format_expr(expr)),
+        _ => expr.to_string(),
+    }
+}
+
+fn format_case_expr(
+    operand: &Option<Box<Expr>>,
+    conditions: &[Expr],
+    results: &[Expr],
+    else_result: &Option<Box<Expr>>,
+) -> String {
+    let mut result = String::new();
+
+    result.push_str("CASE");
+    if let Some(operand) = operand {
+        result.push_str(&format!(" {}", format_expr(operand)));
+    }
+
+    for (condition, case_result) in conditions.iter().zip(results.iter()) {
+        result.push_str(&format!(
+            "\n  WHEN {} THEN {}",
+            format_expr(condition),
+            format_expr(case_result)
+        ));
+    }
+
+    if let Some(else_result) = else_result {
+        result.push_str(&format!("\n  ELSE {}", format_expr(else_result)));
+    }
+
+    result.push_str("\nEND");
+    result
+}
+
+fn format_function(function: &Function) -> String {
+    let mut result = function.name.to_string().to_uppercase();
+    result.push('(');
+
+    if function.distinct {
+        result.push_str("DISTINCT ");
+    }
+
+    let args: Vec<String> = function
+        .args
+        .iter()
+        .map(|arg| match arg {
+            FunctionArg::Named { name, arg, .. } => {
+                format!("{} => {}", name.value, format_function_arg_expr(arg))
+            }
+            FunctionArg::Unnamed(arg) => format_function_arg_expr(arg),
+        })
+        .collect();
+
+    result.push_str(&args.join(", "));
+    result.push(')');
+
+    result
+}
+
+fn format_function_arg_expr(arg: &FunctionArgExpr) -> String {
+    match arg {
+        FunctionArgExpr::Expr(expr) => format_expr(expr),
+        FunctionArgExpr::QualifiedWildcard(name) => format!("{}.*", name),
+        FunctionArgExpr::Wildcard => "*".to_string(),
+    }
+}
+
+fn format_value(value: &Value) -> String {
+    match value {
+        Value::Number(n, _) => n.clone(),
+        Value::SingleQuotedString(s) => format!("'{}'", s),
+        Value::DoubleQuotedString(s) => format!("\"{}\"", s),
+        Value::Boolean(b) => b.to_string().to_uppercase(),
+        Value::Null => "NULL".to_string(),
+        _ => value.to_string(),
+    }
+}
+
+fn format_order_by_expr(order: &OrderByExpr) -> String {
+    let mut result = format_expr(&order.expr);
+
+    if let Some(asc) = order.asc {
+        if asc {
+            result.push_str(" ASC");
+        } else {
+            result.push_str(" DESC");
+        }
+    }
+
+    if let Some(nulls_first) = order.nulls_first {
+        if nulls_first {
+            result.push_str(" NULLS FIRST");
+        } else {
+            result.push_str(" NULLS LAST");
+        }
+    }
+
+    result
+}
+
+fn format_insert(
+    table_name: &ObjectName,
+    columns: &[Ident],
+    source: &Query,
+    indent_level: usize,
+) -> String {
+    let mut result = format!("{}INSERT INTO {}", " ".repeat(indent_level), table_name);
+
+    if !columns.is_empty() {
+        result.push(' ');
+        result.push('(');
+        let column_names: Vec<String> = columns.iter().map(|col| col.value.clone()).collect();
+        result.push_str(&column_names.join(", "));
+        result.push(')');
+    }
+
+    result.push('\n');
+    result.push_str(&format_query(source, indent_level));
+
+    result
+}
+
+fn format_update(
+    table: &TableWithJoins,
+    assignments: &[Assignment],
+    selection: &Option<Expr>,
+    indent_level: usize,
+) -> String {
+    let mut result = format!(
+        "{}UPDATE {}",
+        " ".repeat(indent_level),
+        format_table_with_joins(table, indent_level)
+    );
+
+    result.push('\n');
+    result.push_str(&format!("{}SET", " ".repeat(indent_level)));
+
+    for (i, assignment) in assignments.iter().enumerate() {
+        if i == 0 {
+            result.push('\n');
+        } else {
+            result.push_str(",\n");
+        }
+        result.push_str(&format!(
+            "{}  {} = {}",
+            " ".repeat(indent_level),
+            assignment
+                .id
+                .iter()
+                .map(|i| i.value.clone())
+                .collect::<Vec<_>>()
+                .join("."),
+            format_expr(&assignment.value)
+        ));
+    }
+
+    if let Some(selection) = selection {
+        result.push('\n');
+        result.push_str(&format!(
+            "{}WHERE {}",
+            " ".repeat(indent_level),
+            format_expr(selection)
+        ));
+    }
+
+    result
+}
+
+fn format_delete(from: &FromTable, selection: &Option<Expr>, indent_level: usize) -> String {
+    let mut result = format!("{}DELETE FROM", " ".repeat(indent_level));
+
+    match from {
+        FromTable::WithFromKeyword(tables) => {
+            for (i, table) in tables.iter().enumerate() {
+                if i == 0 {
+                    result.push('\n');
+                } else {
+                    result.push_str(",\n");
+                }
+                result.push_str(&format!(
+                    "{}  {}",
+                    " ".repeat(indent_level),
+                    format_table_with_joins(table, indent_level + 2)
+                ));
+            }
+        }
+        FromTable::WithoutKeyword(tables) => {
+            for (i, table) in tables.iter().enumerate() {
+                if i == 0 {
+                    result.push(' ');
+                } else {
+                    result.push_str(", ");
+                }
+                result.push_str(&format_table_with_joins(table, indent_level));
+            }
+        }
+    }
+
+    if let Some(selection) = selection {
+        result.push('\n');
+        result.push_str(&format!(
+            "{}WHERE {}",
+            " ".repeat(indent_level),
+            format_expr(selection)
+        ));
+    }
+
+    result
+}
+
+fn format_create_table(name: &ObjectName, columns: &[ColumnDef], indent_level: usize) -> String {
+    let mut result = format!("{}CREATE TABLE {} (", " ".repeat(indent_level), name);
+
+    for (i, column) in columns.iter().enumerate() {
+        if i == 0 {
+            result.push('\n');
+        } else {
+            result.push_str(",\n");
+        }
+        result.push_str(&format!(
+            "{}  {} {}",
+            " ".repeat(indent_level),
+            column.name.value,
+            column.data_type
+        ));
+
+        for option in &column.options {
+            match option.option {
+                ColumnOption::NotNull => result.push_str(" NOT NULL"),
+                ColumnOption::Null => result.push_str(" NULL"),
+                ColumnOption::Default(ref expr) => {
+                    result.push_str(&format!(" DEFAULT {}", format_expr(expr)))
+                }
+                ColumnOption::Unique { is_primary, .. } => {
+                    if is_primary {
+                        result.push_str(" PRIMARY KEY");
+                    } else {
+                        result.push_str(" UNIQUE");
+                    }
+                }
+                _ => {}
+            }
+        }
+    }
+
+    result.push('\n');
+    result.push_str(&format!("{})", " ".repeat(indent_level)));
+    result
+}
src/main.rs
@@ -1,3 +1,65 @@
+use anyhow::{Context, Result};
+use sqlparser::dialect::GenericDialect;
+use sqlparser::parser::Parser;
+use std::io::{self, Read};
+
+mod formatter;
+
+use formatter::format_statement;
+
 fn main() {
-    println!("Hello, world!");
+    match run() {
+        Ok(()) => std::process::exit(0),
+        Err(_) => std::process::exit(1),
+    }
+}
+
+fn run() -> Result<()> {
+    let mut input = String::new();
+    io::stdin()
+        .read_to_string(&mut input)
+        .context("Failed to read from stdin")?;
+
+    let input = input.trim();
+    if input.is_empty() {
+        return Ok(());
+    }
+
+    let dialect = GenericDialect {};
+    match Parser::parse_sql(&dialect, input) {
+        Ok(statements) => {
+            let mut output = Vec::new();
+            for (i, statement) in statements.iter().enumerate() {
+                if i > 0 {
+                    output.push("\n".to_string());
+                }
+                output.push(format_statement(statement, 0));
+            }
+            let formatted = output.join("");
+            let formatted = ensure_semicolon_and_cleanup(&formatted);
+            println!("{}", formatted);
+            Ok(())
+        }
+        Err(_) => {
+            print!("{}", input);
+            if !input.ends_with('\n') {
+                println!();
+            }
+            Err(anyhow::anyhow!("Parse error"))
+        }
+    }
+}
+
+fn ensure_semicolon_and_cleanup(sql: &str) -> String {
+    let mut result = sql.trim_end().to_string();
+
+    if !result.ends_with(';') {
+        result.push(';');
+    }
+
+    result
+        .lines()
+        .map(|line| line.trim_end())
+        .collect::<Vec<_>>()
+        .join("\n")
 }
Cargo.lock
@@ -0,0 +1,32 @@
+# This file is automatically @generated by Cargo.
+# It is not intended for manual editing.
+version = 4
+
+[[package]]
+name = "anyhow"
+version = "1.0.100"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "a23eb6b1614318a8071c9b2521f36b424b2c83db5eb3a0fead4a6c0809af6e61"
+
+[[package]]
+name = "log"
+version = "0.4.28"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "34080505efa8e45a4b816c349525ebe327ceaa8559756f0356cba97ef3bf7432"
+
+[[package]]
+name = "sqlfmt"
+version = "0.1.0"
+dependencies = [
+ "anyhow",
+ "sqlparser",
+]
+
+[[package]]
+name = "sqlparser"
+version = "0.45.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "f7bbffee862a796d67959a89859d6b1046bb5016d63e23835ad0da182777bbe0"
+dependencies = [
+ "log",
+]
Cargo.toml
@@ -4,3 +4,5 @@ version = "0.1.0"
 edition = "2024"
 
 [dependencies]
+sqlparser = "0.45"
+anyhow = "1.0"