notex.nvim/lua/notex/query/builder.lua

370 lines
10 KiB
Lua
Raw Permalink Normal View History

2025-10-05 20:16:33 -04:00
-- SQL query builder module
local M = {}
local utils = require('notex.utils')
--- Build a complete SQL statement from a parsed query object.
-- Each clause comes from its `M.build_*_clause` builder; the results are
-- joined by `M.combine_clauses`.
-- @tparam table parsed_query Parsed query (read by the clause builders:
--   `filters`, `conditions`, `group_by`, `order_by`, `limit`).
-- @tparam[opt] table options Builder options; `count_only = true` makes the
--   SELECT clause a single `COUNT(DISTINCT d.id)` column.
-- @treturn string The SQL text.
-- @treturn table Named bind parameters (name -> value) for the placeholders
--   emitted in the WHERE clause.
function M.build_sql(parsed_query, options)
  options = options or {}
  local query = {
    select = "",
    from = "",
    where = "",
    group_by = "",
    order_by = "",
    limit = "",
    params = {},
  }

  query.select = M.build_select_clause(parsed_query, options)
  query.from = M.build_from_clause(parsed_query, options)
  -- build_where_clause returns (sql, params). Both values must be captured,
  -- otherwise every `:name` placeholder in the WHERE clause is left unbound.
  local where_sql, where_params = M.build_where_clause(parsed_query, options)
  query.where = where_sql or ""
  query.params = where_params or {}
  query.group_by = M.build_group_by_clause(parsed_query, options)
  query.order_by = M.build_order_by_clause(parsed_query, options)
  query.limit = M.build_limit_clause(parsed_query, options)

  return M.combine_clauses(query), query.params
end
--- Build the SELECT clause.
-- @tparam table parsed_query Parsed query; reads `group_by` and `filters`.
-- @tparam[opt] table options Builder options; `count_only` replaces the
--   column list with a single `COUNT(DISTINCT d.id) as total_count`.
-- @treturn string The SELECT clause.
function M.build_select_clause(parsed_query, options)
  -- count_only discards every other column, so there is nothing else to build.
  if options and options.count_only then
    return "SELECT COUNT(DISTINCT d.id) as total_count"
  end

  local select_fields = { "d.*" }
  if parsed_query.group_by then
    select_fields[#select_fields + 1] = "COUNT(p.id) as document_count"
    select_fields[#select_fields + 1] = "GROUP_CONCAT(p.value) as aggregated_values"
  else
    -- Expose each filtered property as its own column. The field name is
    -- user-supplied, so escape it for the string literal ('' doubling) and
    -- for the quoted column alias ("" doubling) instead of splicing it raw.
    for field in pairs(parsed_query.filters or {}) do
      local literal = (tostring(field):gsub("'", "''"))
      local alias = (tostring(field):gsub('"', '""'))
      select_fields[#select_fields + 1] = string.format(
        "(SELECT p.value FROM properties p WHERE p.document_id = d.id AND p.key = '%s') as \"%s\"",
        literal,
        alias
      )
    end
  end

  return "SELECT " .. table.concat(select_fields, ", ")
end
--- Build the FROM clause.
-- Documents are always the base table (`d`). Properties are LEFT JOINed as
-- `p` whenever the query has at least one filter or a condition tree.
-- @tparam table parsed_query Parsed query; reads `filters` and `conditions`.
-- @tparam table options Unused; kept for a uniform builder signature.
-- @treturn string The FROM clause.
function M.build_from_clause(parsed_query, options)
  local has_filters = next(parsed_query.filters) ~= nil
  if has_filters or parsed_query.conditions then
    return "FROM documents d LEFT JOIN properties p ON d.id = p.document_id"
  end
  return "FROM documents d"
end
--- Build the WHERE clause and collect its bind parameters.
-- Each simple filter, plus the optional condition tree, becomes one
-- predicate. The predicates are joined with AND.
-- @tparam table parsed_query Parsed query; reads `filters` and `conditions`.
-- @tparam table options Unused; kept for a uniform builder signature.
-- @treturn string "WHERE ..." or "" when there are no predicates.
-- @treturn table Bind parameters (name -> value) gathered from the builders.
function M.build_where_clause(parsed_query, options)
  local predicates = {}
  local bindings = {}

  -- Copy a builder's parameter table (when it returned one) into bindings.
  local function absorb(extra)
    if not extra then
      return
    end
    for name, val in pairs(extra) do
      bindings[name] = val
    end
  end

  for field, value in pairs(parsed_query.filters) do
    local predicate, extra = M.build_filter_condition(field, value)
    predicates[#predicates + 1] = predicate
    absorb(extra)
  end

  local tree = parsed_query.conditions
  if tree then
    local predicate, extra = M.build_conditions(tree)
    if predicate then
      predicates[#predicates + 1] = predicate
      absorb(extra)
    end
  end

  if #predicates == 0 then
    return "", {}
  end
  return "WHERE " .. table.concat(predicates, " AND "), bindings
end
--- Build the predicate for a simple filter on one property.
-- A scalar value becomes `p.value = :name`. An array table becomes
-- `p.value IN (:name_1, :name_2, ...)`.
-- @tparam string field Property key to match.
-- @param value Scalar value, or an array table of values for an IN list.
-- @treturn string SQL predicate using named placeholders.
-- @treturn table Bind parameters (placeholder name -> value). Previously no
--   values were returned, which left every placeholder unbound.
function M.build_filter_condition(field, value)
  local param_name = field:gsub("[^%w]", "_") .. "_filter"
  -- The key is spliced into a SQL string literal, so double embedded quotes.
  local key = (field:gsub("'", "''"))
  local params = {}

  if type(value) == "table" then
    local placeholders = {}
    for i = 1, #value do
      local item_param = param_name .. "_" .. i
      placeholders[i] = ":" .. item_param
      params[item_param] = value[i]
    end
    return string.format("(p.key = '%s' AND p.value IN (%s))", key, table.concat(placeholders, ", ")), params
  end

  params[param_name] = value
  return string.format("(p.key = '%s' AND p.value = :%s)", key, param_name), params
end
--- Dispatch one node of the parsed condition tree to the matching builder.
-- Leaf nodes are selected by `type` ("comparison" / "existence"). Any other
-- node that carries a `clauses` list is handled as a logical group.
-- @tparam table conditions Condition tree node.
-- @treturn ?string SQL predicate, or nil when the node is not recognised.
-- @treturn ?table Bind parameters for the predicate, or nil.
function M.build_conditions(conditions)
  local leaf_builders = {
    comparison = M.build_comparison_condition,
    existence = M.build_existence_condition,
  }
  local builder = leaf_builders[conditions.type]
  if builder then
    return builder(conditions)
  end
  if not conditions.clauses then
    return nil, nil
  end
  return M.build_logical_condition(conditions)
end
-- Wildcards placed before/after the (escaped) value for LIKE-based operators.
local LIKE_AFFIXES = {
  CONTAINS = { "%", "%" },
  INCLUDES = { "%", "%" },
  STARTS_WITH = { "", "%" },
  ENDS_WITH = { "%", "" },
}

-- SQL operator for each remaining comparison operator. Unknown operators
-- fall back to "=".
local COMPARISON_OPS = {
  BEFORE = "<",
  AFTER = ">",
  WITHIN = ">=",
  ["="] = "=",
  ["!="] = "!=",
  [">"] = ">",
  ["<"] = "<",
  [">="] = ">=",
  ["<="] = "<=",
}

--- Build the predicate for a single `field <operator> value` comparison.
-- Supported operators:
-- * CONTAINS / INCLUDES / STARTS_WITH / ENDS_WITH use LIKE.
-- * BEFORE / AFTER use < / >.
-- * WITHIN uses >= a cutoff. A `relative_time` table value is converted by
--   `M.calculate_relative_time`.
-- * = != > < >= <= are passed through.
-- Any other operator falls back to "=".
-- @tparam table condition Node with `field`, `operator`, `value`, `negated`.
-- @treturn string SQL predicate (wrapped in NOT (...) when negated).
-- @treturn table Bind parameters for the predicate.
function M.build_comparison_condition(condition)
  local field = condition.field
  local operator = condition.operator
  local value = condition.value
  -- NOTE(review): the placeholder name depends only on the field, so two
  -- comparisons on the same field in one condition tree share a placeholder
  -- and the later value wins when build_logical_condition merges params.
  local param_name = field:gsub("[^%w]", "_") .. "_comp"
  -- The key is spliced into a SQL string literal, so double embedded quotes.
  local key = (field:gsub("'", "''"))
  local params = {}
  local sql_condition

  local affixes = LIKE_AFFIXES[operator]
  if affixes then
    -- Escape the value's own LIKE metacharacters (\, %, _) so they match
    -- literally; ESCAPE '\' tells SQLite which escape character is in use.
    local escaped = (tostring(value):gsub("[\\%%_]", "\\%0"))
    sql_condition = string.format("p.key = '%s' AND p.value LIKE :%s ESCAPE '\\'", key, param_name)
    params[param_name] = affixes[1] .. escaped .. affixes[2]
  else
    if operator == "WITHIN" and type(value) == "table" and value.type == "relative_time" then
      value = M.calculate_relative_time(value)
    end
    local sql_op = COMPARISON_OPS[operator] or "="
    sql_condition = string.format("p.key = '%s' AND p.value %s :%s", key, sql_op, param_name)
    params[param_name] = value
  end

  if condition.negated then
    sql_condition = "NOT (" .. sql_condition .. ")"
  end
  return sql_condition, params
end
--- Build an EXISTS predicate testing whether a document has a property.
-- The subquery uses its own `properties p2` alias, so it does not rely on the
-- outer `p` join.
-- @tparam table condition Node with `field` and optional `negated`.
-- @treturn string SQL predicate (prefixed with NOT when negated).
-- @treturn table Always an empty parameter table.
function M.build_existence_condition(condition)
  -- The key is spliced into a SQL string literal, so double embedded quotes.
  local key = (tostring(condition.field):gsub("'", "''"))
  local sql_condition = string.format(
    "EXISTS (SELECT 1 FROM properties p2 WHERE p2.document_id = d.id AND p2.key = '%s')",
    key
  )
  if condition.negated then
    sql_condition = "NOT " .. sql_condition
  end
  return sql_condition, {}
end
--- Combine the clauses of a logical node into one parenthesised predicate.
-- Children are built recursively via `M.build_conditions`; children that
-- produce no SQL are skipped. The node's `type`, upper-cased, is used as the
-- joining operator (e.g. "and" -> AND).
-- @tparam table conditions Node with `type` and a `clauses` array.
-- @treturn ?string "(a OP b ...)", or nil when no child produced SQL.
-- @treturn ?table Merged bind parameters of all children, or nil.
function M.build_logical_condition(conditions)
  local parts, merged = {}, {}
  for _, child in ipairs(conditions.clauses) do
    local child_sql, child_params = M.build_conditions(child)
    if child_sql then
      parts[#parts + 1] = child_sql
      for name, val in pairs(child_params or {}) do
        merged[name] = val
      end
    end
  end

  if #parts == 0 then
    return nil, nil
  end
  local joiner = string.format(" %s ", conditions.type:upper())
  return "(" .. table.concat(parts, joiner) .. ")", merged
end
--- Build the GROUP BY clause.
-- "property_key" groups by the property key. Any other value is treated as a
-- column of the joined `properties p` table and emitted as a double-quoted
-- identifier, so an arbitrary name cannot inject SQL.
-- @tparam table parsed_query Parsed query; reads `group_by`.
-- @tparam table options Unused; kept for a uniform builder signature.
-- @treturn string "GROUP BY ..." or "" when the query is not grouped.
function M.build_group_by_clause(parsed_query, options)
  local group_by = parsed_query.group_by
  if not group_by then
    return ""
  end
  if group_by == "property_key" then
    return "GROUP BY p.key"
  end
  local column = (tostring(group_by):gsub('"', '""'))
  return string.format('GROUP BY p."%s"', column)
end
--- Build the ORDER BY clause.
-- With no explicit order, documents are sorted by `d.updated_at` descending.
-- `created_at`, `updated_at` and `file_path` map to document columns, and
-- `title` maps to the title property's value. Any other field is treated as
-- a column of the joined `properties p` table and emitted as a quoted
-- identifier.
-- @tparam table parsed_query Parsed query; reads `order_by.field` and
--   `order_by.direction`.
-- @tparam table options Unused; kept for a uniform builder signature.
-- @treturn string The ORDER BY clause.
function M.build_order_by_clause(parsed_query, options)
  local order_by = parsed_query.order_by
  if not order_by then
    return "ORDER BY d.updated_at DESC"
  end

  local field_map = {
    created_at = "d.created_at",
    updated_at = "d.updated_at",
    file_path = "d.file_path",
    title = "CASE WHEN p.key = 'title' THEN p.value END",
  }
  local field = order_by.field
  local column = field_map[field]
  if not column then
    column = string.format('p."%s"', (tostring(field):gsub('"', '""')))
  end

  -- Only ASC and DESC are valid here. Anything else (including injected SQL)
  -- falls back to ascending order instead of being written into the query.
  local direction = tostring(order_by.direction or "ASC"):upper()
  if direction ~= "DESC" then
    direction = "ASC"
  end

  return string.format("ORDER BY %s %s", column, direction)
end
--- Build the LIMIT clause.
-- The limit is coerced to a non-negative whole number before it is written
-- into the SQL text, so a non-numeric value can never reach the query.
-- @tparam table parsed_query Parsed query; reads `limit` (number or numeric
--   string).
-- @tparam table options Unused; kept for a uniform builder signature.
-- @treturn string "LIMIT n", or "" when there is no usable limit.
function M.build_limit_clause(parsed_query, options)
  if not parsed_query.limit then
    return ""
  end
  local limit = tonumber(parsed_query.limit)
  -- Reject non-numeric, NaN, negative and infinite values.
  if not limit or limit ~= limit or limit < 0 or limit == math.huge then
    return ""
  end
  return string.format("LIMIT %.0f", math.floor(limit))
end
--- Assemble the final SQL text, one clause per line, in standard SQL order.
-- SELECT and FROM are always emitted. WHERE, GROUP BY, ORDER BY and LIMIT
-- are emitted only when their string is non-empty.
-- @tparam table query Clause strings keyed by select/from/where/group_by/
--   order_by/limit.
-- @treturn string The assembled SQL text.
function M.combine_clauses(query)
  local clause_order = { "select", "from", "where", "group_by", "order_by", "limit" }
  local lines = {}
  for position, name in ipairs(clause_order) do
    local clause = query[name]
    if position <= 2 or clause ~= "" then
      lines[#lines + 1] = clause
    end
  end
  return table.concat(lines, "\n")
end
-- Seconds per relative-time unit. "m" means minutes, so months use "M" or
-- "mo". (Previously a second `unit == "m"` branch made months unreachable.)
-- Month and year lengths are approximations: 30 and 365 days.
local RELATIVE_TIME_UNIT_SECONDS = {
  s = 1,
  m = 60,
  h = 3600,
  d = 86400,
  w = 604800,
  M = 2592000,
  mo = 2592000,
  y = 31536000,
}

--- Convert a relative time into the calendar date that far in the past.
-- @tparam table relative_time Table with numeric `amount` and string `unit`
--   (s, m = minutes, h, d, w, M/mo = months, y). Unknown units count as 0.
-- @treturn string Date formatted by os.date as "%Y-%m-%d" (local time).
function M.calculate_relative_time(relative_time)
  local per_unit = RELATIVE_TIME_UNIT_SECONDS[relative_time.unit]
  local seconds = 0
  if per_unit then
    seconds = relative_time.amount * per_unit
  end
  return os.date("%Y-%m-%d", os.time() - seconds)
end
--- Build a query that counts the distinct documents matching `parsed_query`.
-- Delegates to `M.build_sql` with the `count_only` option set.
-- @tparam table parsed_query Parsed query object.
-- @treturn string SQL text.
-- @treturn table Bind parameters.
function M.build_count_query(parsed_query)
  return M.build_sql(parsed_query, { count_only = true })
end
-- Statement keywords rejected by validate_sql. They are matched
-- case-insensitively and only as whole words.
local UNSAFE_SQL_KEYWORDS = { "DROP", "DELETE", "UPDATE", "INSERT" }

--- Run a coarse safety check on generated SQL.
-- This is a blacklist, not a substitute for bound parameters. It rejects any
-- statement separator and the keywords listed in UNSAFE_SQL_KEYWORDS.
-- @tparam string sql SQL text.
-- @treturn boolean true when the SQL passes the checks.
-- @treturn string Reason message.
function M.validate_sql(sql)
  if not sql or sql == "" then
    return false, "Empty SQL query"
  end
  if sql:find(";", 1, true) then
    return false, "Potentially unsafe SQL detected"
  end
  -- Upper-case first so "drop" and "Drop" are caught too. The frontier
  -- patterns ensure identifiers that merely contain a keyword (such as the
  -- `updated_at` column) are not rejected.
  local upper = sql:upper()
  for _, keyword in ipairs(UNSAFE_SQL_KEYWORDS) do
    if upper:find("%f[%w_]" .. keyword .. "%f[^%w_]") then
      return false, "Potentially unsafe SQL detected"
    end
  end
  return true, "SQL query is valid"
end
-- Export the module table.
return M