diff --git a/CHANGELOG.md b/CHANGELOG.md index 271e877..ed373fd 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,7 @@ ## Unreleased +- added: Added named params support - added: Custom type extensions. See: `Exqlite.TypeExtensions`. - changed: Update sqlite to `3.50.1`. diff --git a/c_src/sqlite3_nif.c b/c_src/sqlite3_nif.c index 1984e1d..812de16 100644 --- a/c_src/sqlite3_nif.c +++ b/c_src/sqlite3_nif.c @@ -561,6 +561,30 @@ exqlite_bind_parameter_count(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[] return enif_make_int(env, bind_parameter_count); } +/// +/// Get the bind parameter index +/// +ERL_NIF_TERM +exqlite_bind_parameter_index(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) +{ + statement_t* statement; + if (!enif_get_resource(env, argv[0], statement_type, (void**)&statement)) { + return raise_badarg(env, argv[0]); + } + + ERL_NIF_TERM eos = enif_make_int(env, 0); + ErlNifBinary name; + + if (!enif_inspect_iolist_as_binary(env, enif_make_list2(env, argv[1], eos), &name)) { + return raise_badarg(env, argv[1]); + } + + statement_acquire_lock(statement); + int index = sqlite3_bind_parameter_index(statement->statement, (const char*)name.data); + statement_release_lock(statement); + return enif_make_int(env, index); +} + /// /// Binds a text parameter /// @@ -1423,6 +1447,7 @@ static ErlNifFunc nif_funcs[] = { {"prepare", 2, exqlite_prepare, ERL_NIF_DIRTY_JOB_IO_BOUND}, {"reset", 1, exqlite_reset, ERL_NIF_DIRTY_JOB_CPU_BOUND}, {"bind_parameter_count", 1, exqlite_bind_parameter_count}, + {"bind_parameter_index", 2, exqlite_bind_parameter_index}, {"bind_text", 3, exqlite_bind_text}, {"bind_blob", 3, exqlite_bind_blob}, {"bind_integer", 3, exqlite_bind_integer}, diff --git a/lib/exqlite/sqlite3.ex b/lib/exqlite/sqlite3.ex index 906b6c4..f5a061d 100644 --- a/lib/exqlite/sqlite3.ex +++ b/lib/exqlite/sqlite3.ex @@ -158,6 +158,12 @@ defmodule Exqlite.Sqlite3 do iex> Sqlite3.step(conn, stmt) {:row, [42, 3.14, "Alice", <<0, 0, 0>>, nil]} + iex> {:ok, conn} = Sqlite3.open(":memory:", [:readonly]) + iex> {:ok, stmt} = Sqlite3.prepare(conn, "SELECT :42, @pi, $name, @blob, :null") + iex> Sqlite3.bind(stmt, %{":42" => 42, "@pi" => 3.14, "$name" => "Alice", :"@blob" => {:blob, <<0, 0, 0>>}, ~c":null" => nil}) + iex> Sqlite3.step(conn, stmt) + {:row, [42, 3.14, "Alice", <<0, 0, 0>>, nil]} + iex> {:ok, conn} = Sqlite3.open(":memory:", [:readonly]) iex> {:ok, stmt} = Sqlite3.prepare(conn, "SELECT ?") iex> Sqlite3.bind(stmt, [42, 3.14, "Alice"]) @@ -174,10 +180,13 @@ defmodule Exqlite.Sqlite3 do ** (ArgumentError) unsupported type: #PID<0.0.0> """ - @spec bind(statement, [bind_value] | nil) :: :ok + @spec bind( + statement, + [bind_value] | %{optional(String.t()) => bind_value} | nil + ) :: :ok def bind(stmt, nil), do: bind(stmt, []) - def bind(stmt, args) do + def bind(stmt, args) when is_list(args) do params_count = bind_parameter_count(stmt) args_count = length(args) @@ -188,8 +197,40 @@ defmodule Exqlite.Sqlite3 do end end - # credo:disable-for-next-line Credo.Check.Refactor.CyclomaticComplexity + def bind(stmt, args) when is_map(args) do + params_count = bind_parameter_count(stmt) + args_count = map_size(args) + + if args_count == params_count do + bind_all_named(Map.to_list(args), stmt) + else + raise ArgumentError, + "expected #{params_count} named arguments, got #{args_count}: #{inspect(Map.keys(args))}" + end + end + defp bind_all([param | params], stmt, idx) do + do_bind(stmt, idx, param) + bind_all(params, stmt, idx + 1) + end + + defp bind_all([], _stmt, _idx), do: :ok + + defp bind_all_named([{name, param} | named_params], stmt) do + idx = Sqlite3NIF.bind_parameter_index(stmt, to_string(name)) + + if idx == 0 do + raise ArgumentError, "unknown named parameter: #{inspect(name)}" + end + + do_bind(stmt, idx, param) + bind_all_named(named_params, stmt) + end + + defp bind_all_named([], _stmt), do: :ok + + # credo:disable-for-next-line Credo.Check.Refactor.CyclomaticComplexity + defp do_bind(stmt, idx, param) do case convert(param) do i when is_integer(i) -> bind_integer(stmt, idx, i) f when is_float(f) -> bind_float(stmt, idx, f) @@ -202,12 +243,8 @@ defmodule Exqlite.Sqlite3 do {:blob, b} when is_list(b) -> bind_blob(stmt, idx, IO.iodata_to_binary(b)) _other -> raise ArgumentError, "unsupported type: #{inspect(param)}" end - - bind_all(params, stmt, idx + 1) end - defp bind_all([], _stmt, _idx), do: :ok - @spec columns(db(), statement()) :: {:ok, [binary()]} | {:error, reason()} def columns(conn, statement), do: Sqlite3NIF.columns(conn, statement) diff --git a/lib/exqlite/sqlite3_nif.ex b/lib/exqlite/sqlite3_nif.ex index 79874fe..d9416a2 100644 --- a/lib/exqlite/sqlite3_nif.ex +++ b/lib/exqlite/sqlite3_nif.ex @@ -72,6 +72,9 @@ defmodule Exqlite.Sqlite3NIF do @spec bind_parameter_count(statement) :: integer def bind_parameter_count(_stmt), do: :erlang.nif_error(:not_loaded) + @spec bind_parameter_index(statement, String.t()) :: integer + def bind_parameter_index(_stmt, _name), do: :erlang.nif_error(:not_loaded) + @spec bind_text(statement, non_neg_integer, String.t()) :: integer() def bind_text(_stmt, _index, _text), do: :erlang.nif_error(:not_loaded) diff --git a/test/exqlite/query_test.exs b/test/exqlite/query_test.exs index 34a0613..cc0e0c8 100644 --- a/test/exqlite/query_test.exs +++ b/test/exqlite/query_test.exs @@ -31,6 +31,16 @@ defmodule Exqlite.QueryTest do assert Enum.to_list(columns["y"]) == ["a", "b", "c"] end + test "named params", %{conn: conn} do + assert Exqlite.query!(conn, "select :a, @b, $c", %{":a" => 1, "@b" => 2, "$c" => 3}) == + %Exqlite.Result{ + command: :execute, + columns: [":a", "@b", "$c"], + rows: [[1, 2, 3]], + num_rows: 1 + } + end + defp create_conn!(_) do opts = [database: "#{Temp.path!()}.db"] diff --git a/test/exqlite/sqlite3_test.exs b/test/exqlite/sqlite3_test.exs index 22f4b98..7cd7e07 100644 --- a/test/exqlite/sqlite3_test.exs +++ b/test/exqlite/sqlite3_test.exs @@ -298,6 +298,55 @@ defmodule Exqlite.Sqlite3Test do Sqlite3.bind(statement, [other_tz]) end end + + test "binds named parameters" do + {:ok, conn} = Sqlite3.open(":memory:") + + {:ok, statement} = + Sqlite3.prepare(conn, "select :42, @pi, :name, $👋, :blob, :null") + + :ok = + Sqlite3.bind(statement, %{ + ":42" => 42, + "@pi" => 3.14, + :":name" => "Alice", + "$👋" => "👋", + ":blob" => {:blob, <<0, 1, 2>>}, + ~c":null" => nil + }) + + assert {:row, [42, 3.14, "Alice", "👋", <<0, 1, 2>>, nil]} = + Sqlite3.step(conn, statement) + end + + test "handles repeating named parameters" do + {:ok, conn} = Sqlite3.open(":memory:") + + {:ok, statement} = + Sqlite3.prepare(conn, "select :name, :name, :name") + + :ok = + Sqlite3.bind(statement, %{ + ":name" => "Alice" + }) + + assert {:row, ["Alice", "Alice", "Alice"]} = Sqlite3.step(conn, statement) + end + + test "raises an error when too few or too many named parameters" do + {:ok, conn} = Sqlite3.open(":memory:") + + {:ok, statement} = + Sqlite3.prepare(conn, "select :name, :age") + + assert_raise ArgumentError, ~r"expected 2 named arguments, got 1", fn -> + Sqlite3.bind(statement, %{":name" => "Alice"}) + end + + assert_raise ArgumentError, ~r"expected 2 named arguments, got 3", fn -> + Sqlite3.bind(statement, %{":name" => "Alice", ":age" => 30, ":extra" => "value"}) + end + end end describe ".bind_text/3" do