Medial Code Documentation
Loading...
Searching...
No Matches
lua.h
Go to the documentation of this file.
1
29#ifndef DMLC_LUA_H_
30#define DMLC_LUA_H_
31
32extern "C" {
33#include <lua.h>
34#include <luaT.h>
35#include <lualib.h>
36}
37
38#include <string>
39#include <stdexcept>
40#include <tuple>
41#include <mutex>
42#include <memory>
43#include <vector>
44#include <utility>
45#include <algorithm>
46#include <unordered_map>
47#include <type_traits>
48
49#include "./base.h"
50#include "./logging.h"
51#include "./thread_local.h"
52
53namespace dmlc {
54
// Forward declarations: LuaRef (below) refers to LuaState and befriends
// lua_stack::Handler<LuaRef> before either is defined.
class LuaState;

namespace lua_stack {
template<typename T>
struct Handler;
}  // namespace lua_stack
62
64class LuaRef {
65 public:
67 LuaRef() = default;
72 inline LuaRef(LuaRef&& other); // NOLINT(*)
77 inline LuaRef(const LuaRef& other); // NOLINT(*)
83 inline LuaRef& operator=(LuaRef&& other);
89 inline LuaRef& operator=(const LuaRef& other);
91 inline ~LuaRef();
96 inline void swap(LuaRef& other); // NOLINT(*)
103 template<typename T>
104 inline T Get() const;
116 template<typename T>
117 inline T* GetUDataPtr() const;
119 inline bool is_nil() const;
126 template<typename... Args>
127 inline LuaRef operator()(Args&& ...args) const;
134 inline LuaRef operator[](const std::string& key) const;
142 inline LuaRef operator[](size_t index) const;
150 template<typename T>
151 inline LuaRef& SetField(const std::string& key, const T& value); // NOLINT(*)
159 inline void SetByPopStack_(LuaState* s);
160
161 private:
162 // friend with luastate
163 friend struct lua_stack::Handler<LuaRef>;
164 friend class LuaState;
165 friend std::ostream &operator<<(std::ostream &os, const LuaRef &r);
167 LuaState* state_{nullptr};
169 int ref_;
170};
171
173class LuaState {
174 public:
176 enum Option {
177 kNoThreadProtect,
178 kThreadLocal,
179 kLocking,
180 };
182 inline ~LuaState();
189 inline LuaRef Eval(const char* lua_code);
196 inline LuaRef Eval(const std::string& lua_code) {
197 return this->Eval(lua_code.c_str());
198 }
206 template<typename T>
207 inline LuaRef Convert(const T& value);
213 inline LuaRef operator[](const std::string& key);
219 inline void SetGlobalField(const std::string& key, const LuaRef& value);
228 static inline LuaState* ThreadLocalState();
248 static inline LuaState* Create_(Option option);
249
258 template<typename F>
259 inline void PRun_(F f);
264 inline bool SameLuaState(lua_State *L) const {
265 return L_ == L;
266 }
267
268 protected:
269 struct StackReset;
270 friend class LuaRef;
271 friend struct ThreadLocalStore<LuaState>;
275 inline LuaState();
276
278 Option option_{kThreadLocal};
280 lua_State* L_;
282 std::mutex mutex_;
283};
284
285// implementations after this line
287
/*
 * LUA_CALL(x): evaluate the Lua C call x and abort via LOG(FATAL) if it
 * returns non-zero, printing the error message found on top of the stack.
 * NOTE: the expansion reads a variable named `L` (the lua_State*) that must be
 * in scope at the use site. Wrapped in do/while(0) so it behaves as a single
 * statement and cannot capture a following `else`.
 */
#define LUA_CALL(x)                                               \
  do {                                                            \
    if ((x)) {                                                    \
      LOG(FATAL) << "Lua Call Error:" << lua_tostring(L, -1);     \
    }                                                             \
  } while (0)
292
304namespace lua_stack {
305inline int lua_abs_index(lua_State* L, int index) {
306 if (index > 0 || index <= LUA_REGISTRYINDEX) return index;
307 return lua_gettop(L) + index + 1;
308}
309
// Redeclaration of the conversion trait; definitions and specializations follow.
template <typename T> struct Handler;
312
313template<typename T>
314struct NumberHandler {
315 static inline T Get(lua_State* L, int index, LuaState* s) {
316 CHECK_EQ(lua_type(L, index), LUA_TNUMBER)
317 << "Attempt to get number but type is \'"
318 << lua_typename(L, lua_type(L, index)) << '\'';
319 if (std::is_integral<T>::value) {
320 return static_cast<T>(lua_tointeger(L, index));
321 } else {
322 return static_cast<T>(lua_tonumber(L, index));
323 }
324 }
325 static inline void Push(lua_State* L, const T& v) {
326 if (std::is_integral<T>::value) {
327 lua_pushinteger(L, static_cast<lua_Integer>(v));
328 } else {
329 lua_pushnumber(L, static_cast<lua_Number>(v));
330 }
331 }
332};
333
334template<typename ContainerType>
335struct MapHandler {
336 using K = typename ContainerType::key_type;
337 using V = typename ContainerType::mapped_type;
338 static inline ContainerType Get(lua_State* L, int index, LuaState* s) {
339 ContainerType ret;
340 CHECK(lua_istable(L, index))
341 << "Expected a table but get "
342 << lua_typename(L, lua_type(L, index)) << '\'';
343 int tid = lua_abs_index(L, index);
344 lua_pushnil(L);
345 while (lua_next(L, -2)) {
346 ret[Handler<K>::Get(L, -2, s)] = Handler<V>::Pop(L, -1, s);
347 lua_pop(L, 1);
348 }
349 lua_settop(L, tid);
350 return ret;
351 }
352 static inline void Push(lua_State* L, const ContainerType& v) {
353 lua_createtable(L, v.size(), 0);
354 for (const auto& kv : v) {
355 Handler<K>::Push(L, kv.first);
356 Handler<V>::Push(L, kv.second);
357 lua_settable(L, -3);
358 }
359 }
360};
361
362struct UndefinedHandler {
363};
364
365template<typename T>
366struct Handler
367 : public std::conditional<std::is_arithmetic<T>::value,
368 NumberHandler<T>,
369 UndefinedHandler>::type {
370};
371
372template<>
373struct Handler<std::string> {
374 static inline std::string Get(lua_State* L, int index, LuaState* s) {
375 CHECK_EQ(lua_type(L, index), LUA_TSTRING);
376 return std::string(lua_tostring(L, index));
377 }
378 static inline void Push(lua_State* L, const std::string& v) {
379 lua_pushstring(L, v.c_str());
380 }
381};
382
383template<typename T>
384struct Handler<std::vector<T> > {
385 static inline std::vector<T> Get(lua_State* L, int index, LuaState* s) {
386 std::vector<T> ret;
387 CHECK(lua_istable(L, index))
388 << "Expected a table but get "
389 << lua_typename(L, lua_type(L, index)) << '\'';
390 int tid = lua_abs_index(L, index);
391 lua_pushnil(L);
392 while (lua_next(L, tid)) {
393 CHECK_EQ(Handler<size_t>::Get(L, -2, s), ret.size() + 1)
394 << "Target table is not an array";
395 ret.push_back(Handler<T>::Get(L, -1, s));
396 lua_pop(L, 1);
397 }
398 lua_settop(L, tid);
399 return ret;
400 }
401 static inline void Push(lua_State* L, const std::vector<T>& v) {
402 lua_createtable(L, v.size(), 0);
403 for (size_t i = 0; i < v.size(); ++i) {
404 Handler<T>::Push(L, v[i]);
405 lua_rawseti(L, -2, i + 1);
406 }
407 }
408};
409
410template<typename K, typename V>
411struct Handler<std::unordered_map<K, V> >
412 : public MapHandler<std::unordered_map<K, V> > {
413};
414
415template<>
416struct Handler<LuaRef> {
417 static inline LuaRef Get(lua_State* L, int index, LuaState* s) {
418 LuaRef ret;
419 lua_pushvalue(L, index);
420 ret.SetByPopStack_(s);
421 return ret;
422 }
423
424 static inline void Push(lua_State* L, const LuaRef& v) {
425 if (v.is_nil()) {
426 lua_pushnil(L);
427 } else {
428 CHECK(v.state_->SameLuaState(L))
429 << "Cannot pass LuaRef on a different LuaState's function";
430 lua_rawgeti(L, LUA_REGISTRYINDEX, v.ref_);
431 }
432 }
433};
434
435template<>
436struct Handler<std::nullptr_t> {
437 static inline LuaRef Get(lua_State* L, int index, LuaState* s) {
438 LOG(FATAL) << "not supported";
439 return LuaRef();
440 }
441 static inline void Push(lua_State* L, const std::nullptr_t& v) {
442 lua_pushnil(L);
443 }
444};
445
446// generic functor to call push the arguments.
447struct PushArg {
448 lua_State* L;
449 template<typename T>
450 inline void operator()(const T& v) const {
451 Handler<T>::Push(L, v);
452 }
453};
454
455} // namespace lua_stack
456
457inline LuaState::LuaState() {
458 L_ = luaL_newstate();
459 CHECK(L_ != nullptr)
460 << "Failed to create new lua state";
461 luaL_openlibs(L_);
462}
463
464inline LuaState::~LuaState() {
465 if (option_ != kThreadLocal && L_ != nullptr) {
466 // never close threadlocal, for save destruction.
467 lua_close(L_);
468 }
469}
470
471inline LuaState* LuaState::Create_(Option opt) {
472 LuaState* s = new LuaState();
473 s->option_ = opt;
474 CHECK_NE(opt, kThreadLocal)
475 << "use LuaState::ThreadLocalState() to get the thread local state";
476 return s;
477}
478
479inline void LuaRef::SetByPopStack_(LuaState* s) {
480 CHECK(state_ == nullptr);
481 lua_State* L = s->L_;
482 if (!lua_isnil(L, -1)) {
483 ref_ = lua_ref(L, LUA_REGISTRYINDEX);
484 state_ = s;
485 } else {
486 lua_pop(L, 1);
487 }
488}
489
490// RAII guard to reset stack
491struct LuaState::StackReset {
492 lua_State* L;
493 int top;
494 ~StackReset() {
495 lua_settop(L, top);
496 }
497};
498
499template<typename F>
500inline void LuaState::PRun_(F f) {
501 if (option_ != kLocking) {
502 StackReset reset{L_, lua_gettop(L_)};
503 if (option_ == kThreadLocal) {
504 CHECK_EQ(ThreadLocalState(), this)
505 << "Invoke lua from a different thread in ThreadLocal mode.";
506 }
507 f(L_);
508 CHECK_EQ(reset.top, lua_gettop(L_));
509 } else {
510 std::lock_guard<std::mutex> lock(mutex_);
511 StackReset reset{L_, lua_gettop(L_)};
512 f(L_);
513 CHECK_EQ(reset.top, lua_gettop(L_));
514 }
515}
516
517inline LuaState* LuaState::ThreadLocalState() {
519}
520
521inline LuaRef LuaState::Eval(const char* lua_code) {
522 LuaRef ret;
523 this->PRun_([this, lua_code, &ret](lua_State* L) {
524 luaL_loadstring(L, lua_code);
525 CHECK_EQ(lua_pcall(L, 0, 1, 0), 0)
526 << "Lua call error: " << lua_tostring(L, -1) << '\n'
527 << "---------\n"
528 << lua_code
529 << "\n----------";
530 ret.SetByPopStack_(this);
531 });
532 return ret;
533}
534
535template<typename T>
536inline LuaRef LuaState::Convert(const T& value) {
537 LuaRef ret;
538 this->PRun_([this, &value, &ret](lua_State* L) {
539 lua_stack::Handler<T>::Push(L, value);
540 ret.SetByPopStack_(this);
541 });
542 return ret;
543}
544
545inline LuaRef LuaState::operator[](const std::string& key) {
546 LuaRef ret;
547 this->PRun_([this, &key, &ret](lua_State* L) {
548 lua_getglobal(L, key.c_str());
549 ret.SetByPopStack_(this);
550 });
551 return ret;
552}
553
554inline void LuaState::SetGlobalField(
555 const std::string& key, const LuaRef& value) {
556 this->PRun_([this, &key, &value](lua_State* L) {
557 lua_rawgeti(L, LUA_REGISTRYINDEX, value.ref_);
558 lua_setglobal(L, key.c_str());
559 });
560}
561
562inline LuaRef::LuaRef(const LuaRef& other) {
563 if (other.state_ != nullptr) {
564 state_ = other.state_;
565 state_->PRun_([this, &other](lua_State* L) {
566 lua_rawgeti(L, LUA_REGISTRYINDEX, other.ref_);
567 ref_ = luaL_ref(L, LUA_REGISTRYINDEX);
568 });
569 }
570}
571
572inline LuaRef::LuaRef(LuaRef&& other) {
573 ref_ = other.ref_;
574 state_ = other.state_;
575 other.state_ = nullptr;
576}
577
578inline LuaRef& LuaRef::operator=(LuaRef&& other) {
579 LuaRef(std::move(other)).swap(*this);
580 return *this;
581}
582
583inline LuaRef& LuaRef::operator=(const LuaRef& other) {
584 LuaRef(other).swap(*this);
585 return *this;
586}
587
588inline void LuaRef::swap(LuaRef& other) { // NOLINT(*)
589 std::swap(state_, other.state_);
590 std::swap(ref_, other.ref_);
591}
592
593inline LuaRef::~LuaRef() {
594 if (state_ != nullptr) {
595 state_->PRun_([this](lua_State* L) {
596 luaL_unref(L, LUA_REGISTRYINDEX, ref_);
597 });
598 }
599}
600
601inline bool LuaRef::is_nil() const {
602 return state_ == nullptr;
603}
604
605std::ostream &operator<<(std::ostream &os, const LuaRef &r) {
606 if (!r.is_nil()) {
607 r.state_->PRun_([&os, &r](lua_State* L) {
608 lua_rawgeti(L, LUA_REGISTRYINDEX, r.ref_);
609 int type = lua_type(L, -1);
610 switch (type) {
611 case LUA_TSTRING:
612 os << "lua_string:'" << lua_tostring(L, -1) << "'"; break;
613 case LUA_TBOOLEAN:
614 os << "lua_bool:" << (lua_toboolean(L, -1) ? "true" : "false"); break;
615 case LUA_TNUMBER:
616 os << "lua_number:" << lua_tonumber(L, -1); break;
617 default:
618 os << "lua[ref=" << r.ref_ << ']' << lua_typename(L, type); break;
619 }
620 lua_pop(L, 1);
621 });
622 } else {
623 os << "lua_nil";
624 }
625 return os;
626}
627
628template<typename T>
629inline T LuaRef::Get() const {
630 CHECK(state_ != nullptr) << "Get:: LuaRef is nil";
631 T ret;
632 state_->PRun_([&ret, this](lua_State* L) {
633 lua_rawgeti(L, LUA_REGISTRYINDEX, ref_);
634 ret = lua_stack::Handler<T>::Get(L, -1, state_);
635 lua_pop(L, 1);
636 });
637 return ret;
638}
639
640template<typename T>
641inline T* LuaRef::GetUDataPtr() const {
642 CHECK(state_ != nullptr) << "Get:: LuaRef is nil";
643 T* ret;
644 state_->PRun_([&ret, this](lua_State* L) {
645 lua_rawgeti(L, LUA_REGISTRYINDEX, ref_);
646 ret = reinterpret_cast<T*>(lua_touserdata(L, -1));
647 lua_pop(L, 1);
648 });
649 return ret;
650}
651
// Recursive helper behind for_each: applies f to tuple element Idx, then moves
// on to Idx + 1. `done` selects the empty terminating specialization below.
template<bool done, std::size_t Idx, typename Fn, typename ...Ts>
struct for_each_dispatcher_ {
  static inline void run(const std::tuple<Ts...>& args, Fn f) {
    f(std::get<Idx>(args));
    constexpr std::size_t kNext = Idx + 1;
    for_each_dispatcher_<kNext == sizeof...(Ts), kNext, Fn, Ts...>::run(args, f);
  }
};
// Terminating case: every element has been visited.
template<std::size_t Idx, typename Fn, typename ...Ts>
struct for_each_dispatcher_<true, Idx, Fn, Ts...> {
  static inline void run(const std::tuple<Ts...>&, Fn) {}
};

// Call f on each tuple element in order (f is copied between steps).
template<typename Fn, typename ...Ts>
inline void for_each(const std::tuple<Ts...>& args, Fn f) {
  constexpr bool kEmpty = sizeof...(Ts) == 0;
  for_each_dispatcher_<kEmpty, 0, Fn, Ts...>::run(args, f);
}
672
673template<typename... Args>
674inline LuaRef LuaRef::operator()(Args&& ...args) const {
675 CHECK(state_ != nullptr) << "LuaRef is nil";
676 auto targ = std::make_tuple(std::forward<Args>(args)...);
677 size_t nargs = sizeof...(Args);
678 LuaRef ret;
679 state_->PRun_([this, nargs, &targ, &ret](lua_State* L) {
680 lua_rawgeti(L, LUA_REGISTRYINDEX, this->ref_);
681 CHECK(lua_isfunction(L, -1))
682 << "Expect to invoke a function but type='"
683 << lua_typename(L, lua_type(L, -1)) << '\'';
684 for_each(targ, lua_stack::PushArg{L});
685 LUA_CALL(lua_pcall(L, nargs, 1, 0));
686 ret.SetByPopStack_(state_);
687 });
688 return ret;
689}
690
691template<typename T>
692inline LuaRef& LuaRef::SetField(const std::string& key, const T& value) { // NOLINT(*)
693 CHECK(state_ != nullptr) << "LuaRef is nil";
694 state_->PRun_([this, &key, &value](lua_State* L) {
695 lua_rawgeti(L, LUA_REGISTRYINDEX, this->ref_);
696 CHECK(lua_istable(L, -1))
697 << "Expect a table but type='"
698 << lua_typename(L, lua_type(L, -1)) << '\'';
699 lua_stack::Handler<T>::Push(L, value);
700 lua_setfield(L, -2, key.c_str());
701 lua_pop(L, 1);
702 });
703 return *this;
704}
705
706inline LuaRef LuaRef::operator[](const std::string& key) const {
707 CHECK(state_ != nullptr) << "LuaRef is nil";
708 LuaRef ret;
709 state_->PRun_([this, &key, &ret](lua_State* L) {
710 lua_rawgeti(L, LUA_REGISTRYINDEX, this->ref_);
711 CHECK(lua_istable(L, -1))
712 << "Expect a table but type='"
713 << lua_typename(L, lua_type(L, -1)) << '\'';
714 lua_getfield(L, -1, key.c_str());
715 ret.SetByPopStack_(state_);
716 lua_pop(L, 1);
717 });
718 return ret;
719}
720
721inline LuaRef LuaRef::operator[](size_t index) const {
722 CHECK(state_ != nullptr) << "LuaRef is nil";
723 LuaRef ret;
724 state_->PRun_([this, index, &ret](lua_State* L) {
725 lua_rawgeti(L, LUA_REGISTRYINDEX, this->ref_);
726 CHECK(lua_istable(L, -1))
727 << "Expect a table but type='"
728 << lua_typename(L, lua_type(L, -1)) << '\'';
729 lua_rawgeti(L, -1, index);
730 ret.SetByPopStack_(state_);
731 lua_pop(L, 1);
732 });
733 return ret;
734}
735
737} // namespace dmlc
738
739#endif // DMLC_LUA_H_
a reference to a Lua object
Definition lua.h:64
T Get() const
Get content out as type T.
T * GetUDataPtr() const
Get user data pointer from LuaRef.
void swap(LuaRef &other)
swap content with another ref
LuaRef(LuaRef &&other)
move constructor from another LuaRef
LuaRef(const LuaRef &other)
copy constructor
LuaRef & SetField(const std::string &key, const T &value)
Set field of lua table. The reference must be a table.
LuaRef operator[](size_t index) const
Get a field from the Lua array. The reference must be an array.
~LuaRef()
destructor
void SetByPopStack_(LuaState *s)
Set the LuaRef to the value on top of the stack. This reference must be nil. This API is intended for developers.
LuaRef & operator=(const LuaRef &other)
assign operator from other
LuaRef operator[](const std::string &key) const
Get field from the lua table. The reference must be a table.
LuaRef()=default
construct a nil ref
LuaRef & operator=(LuaRef &&other)
assign operator from other
bool is_nil() const
LuaRef operator()(Args &&...args) const
invoke the LuaRef as function
A Lua state.
Definition lua.h:173
Option
options to be provided in lua state
Definition lua.h:176
LuaRef Eval(const char *lua_code)
evaluate a piece of lua code, return the first result.
~LuaState()
destructor
LuaRef Convert(const T &value)
convert a C++ type to lua type
static LuaState * Create_(Option option)
static LuaState * ThreadLocalState()
LuaRef Eval(const std::string &lua_code)
evaluate a piece of lua code, return the first result.
Definition lua.h:196
bool SameLuaState(lua_State *L) const
Definition lua.h:264
Option option_
internal option, defaults to thread local
Definition lua.h:278
std::mutex mutex_
internal lock about the state
Definition lua.h:282
lua_State * L_
internal lua state
Definition lua.h:280
void SetGlobalField(const std::string &key, const LuaRef &value)
Set the value to the global table.
LuaRef operator[](const std::string &key)
get global field from the state
void PRun_(F f)
Protected run of f; this is used by API developers. Always call this to access the Lua state; f must not dest...
LuaState()
constructor
A threadlocal store to store threadlocal variables. Will return a thread local singleton of type T.
Definition thread_local.h:35
static T * Get()
Definition thread_local.h:38
defines console logging options for xgboost. Use to enforce unified print behavior.
C++11 header only interface to easily interact with Lua and Torch. This code is evolved from torch pl...
@ string
string value
namespace for dmlc
Definition array_view.h:12
std::ostream & operator<<(std::ostream &os, const optional< T > &t)
serialize an optional object to string.
Definition optional.h:296
Definition StdDeque.h:58
NLOHMANN_BASIC_JSON_TPL_DECLARATION void swap(nlohmann::NLOHMANN_BASIC_JSON_TPL &j1, nlohmann::NLOHMANN_BASIC_JSON_TPL &j2) noexcept(//NOLINT(readability-inconsistent-declaration-parameter-name, cert-dcl58-cpp) is_nothrow_move_constructible< nlohmann::NLOHMANN_BASIC_JSON_TPL >::value &&//NOLINT(misc-redundant-expression) is_nothrow_move_assignable< nlohmann::NLOHMANN_BASIC_JSON_TPL >::value)
exchanges the values of two JSON objects
Definition json.hpp:24418
Macros common to all headers.
Definition lua.h:60
Portable thread local storage.