-- Module:set utilities
--
-- From Wiktionary, the free dictionary.

--[[
-------------------------------------------------------------------------------------
-- This module includes a number of functions for dealing with Lua tables as sets. --
-- It is a meta-module, meant to be called from other Lua modules, and should      --
-- not be called directly from #invoke.                                            --
-------------------------------------------------------------------------------------
--]]

local export = {}

local libraryUtil = require("libraryUtil")
local table = table

local checkType = libraryUtil.checkType
local checkTypeMulti = libraryUtil.checkTypeMulti
local concat = table.concat
local format = string.format
local getmetatable = getmetatable
local insert = table.insert
local ipairs = ipairs
local next = next
local pairs = pairs
local rawequal = rawequal
local rawget = rawget
local setmetatable = setmetatable
local sort = table.sort
local type = type

local infinity = math.huge

--[==[
Create an argument checker bound to the function name `funcName`, for use in error messages.

* If `expectType` is a string, the returned checker is called as `(argIndex, arg, nilOk)` and always validates `arg`
  against that one fixed type name.
* Otherwise the returned checker is called as `(argIndex, arg, expectType, nilOk)`, with the expected type supplied
  per call: a table of acceptable type names is handled by checkTypeMulti(), and anything else by checkType().
]==]
local function _check(funcName, expectType)
	if type(expectType) ~= "string" then
		return function(argIndex, arg, argType, nilOk)
			if type(argType) ~= "table" then
				checkType(funcName, argIndex, arg, argType, nilOk)
			elseif not (nilOk and arg == nil) then
				-- checkTypeMulti() has no `nilOk` parameter, so an allowed nil is filtered out before calling it.
				checkTypeMulti(funcName, argIndex, arg, argType)
			end
		end
	end
	return function(argIndex, arg, nilOk)
		checkType(funcName, argIndex, arg, expectType, nilOk)
	end
end

--[==[
Turn `list` (a sequence of values) into a set, i.e. a table whose keys are those values. Checking whether a key
exists in such a table is far cheaper than scanning the list for a matching value.

Each item maps to `true`, unless `value` is given (non-nil), in which case every item maps to `value` instead.
Reading stops at the first nil element of `list`.

Note: support for a callable `value` (computing each item's value as `value(item, ...)`) is currently disabled, so
any extra arguments passed in `...` are ignored.
]==]
function export.list_to_set(list, value, ...)
	checkType("list_to_set", 1, list, "table")
	if value == nil then
		value = true
	end
	local set = {}
	local index = 1
	local item = list[index]
	while item ~= nil do
		set[item] = value
		index = index + 1
		item = list[index]
	end
	return set
end


--[==[
General set intersection.

Returns a new table containing every key of `set` that is also a key of each additional set passed in `...`. Each
resulting key maps to `true`. With no additional sets, the result is a copy of `set`'s keys.

A key counts as a member of a set when it is present (its value is non-nil), including when its value is `false`.
This matches how the first set's keys are gathered with pairs(), so that `intersection(s, s)` equals
`intersection(s)`.

Every argument must be a table; a non-table argument raises an error naming its position.
]==]
function export.intersection(set, ...)
	checkType("intersection", 1, set, "table")
	local result = {}
	for key in pairs(set) do
		result[key] = true
	end
	for i = 1, select("#", ...) do
		local this_set = select(i, ...)
		checkType("intersection", i + 1, this_set, "table")
		for key in pairs(result) do
			-- Test presence rather than truthiness: a key mapped to false is still a member of `this_set`.
			if this_set[key] == nil then
				-- Clearing an existing field during a pairs() traversal is allowed by the Lua manual (see next()).
				result[key] = nil
			end
		end
	end
	return result
end

--[==[
General set union.

Returns a new table whose keys are every key found in any of the given sets, each mapped to `true`. Every argument
must be a table (a non-table argument raises an error naming its position). With no arguments, the result is an
empty table.
]==]
function export.union(...)
	local sets, count = {...}, select("#", ...)
	local combined = {}
	for idx = 1, count do
		local current = sets[idx]
		checkType("union", idx, current, "table")
		for key in pairs(current) do
			combined[key] = true
		end
	end
	return combined
end

--[==[
Set difference: returns a new table containing every key of `set1` that is not a key of `set2`, each mapped to
`true`.

A key counts as a member of `set2` when it is present (its value is non-nil), including when its value is `false`.
This matches how `set1`'s keys are gathered with pairs(), so that `difference(s, s)` is always empty.

Both arguments must be tables; otherwise an error naming the offending argument is raised.
]==]
function export.difference(set1, set2)
	checkType("difference", 1, set1, "table")
	checkType("difference", 2, set2, "table")
	local diff = {}
	for key in pairs(set1) do
		-- Presence test, not truthiness: a key mapped to false in set2 must still be subtracted.
		if set2[key] == nil then
			diff[key] = true
		end
	end
	return diff
end