maitrix-org/llm-reasoners

Home | Paper (COLM2024) | Blog | Discord | @Maitrix.org


LLM Reasoners is a library designed to enhance LLMs' ability to perform complex reasoning using advanced algorithms.

News

  • Feb. 6, 2025: Thrilled to introduce ReasonerAgent - A fully open source, ready-to-run agent that does research 🧐 in a web browser and answers your queries. Check out this thread, and explore the code here!
  • Jan. 31, 2025: LLM Reasoners has integrated SGLang. Enjoy 100x speed-up with a one-line change! New applications like PRM-guided search for inference-time scaling are also available. See more details in this post.
  • Dec. 20, 2024: We now support planning algorithms (MCTS, DFS/BFS, Beam Search) in web environments with BrowserGym. Check the README to try it out!
  • Nov. 13, 2024: We integrated DRPO, a tuning-free alignment method published at EMNLP 2024 (link).

  • Jul. 10, 2024: Our paper on LLM Reasoners is accepted to COLM 2024!

  • Jun. 24, 2024: PromptAgent is in LLM Reasoners! Let it help you write down a super detailed prompt for your task (here).

  • May. 14, 2024: Check out Eurus, a suite of LLMs optimized for reasoning. With LLM Reasoners, Eurus-RM can easily boost Llama-8B from 0.49 to 0.73 📈 on GSM8k (code).

  • May. 2, 2024: We have integrated our first reasoning method for scientific reasoning, StructChem! Check it out here.

  • Apr. 22, 2024: We integrated Llama-3, with additional useful APIs (e.g., customizing EOS tokens, calculating likelihood)

  • Apr. 8, 2024: Our new paper introducing LLM Reasoners is available!

  • Mar. 29, 2024: Grace Decoding has been incorporated!

  • Oct. 25, 2023: A video tutorial on the visualizer of LLM Reasoners is available.

  • Oct. 23, 2023: Reasoning-via-Planning is accepted to EMNLP 2023! Check our paper with updated results and discussion!

Introduction to the library

Library Structure

We abstract an LLM reasoning algorithm into three key components: reward function, world model, and search algorithm (see the formulation in our paper). They correspond to three classes in the library: SearchConfig, WorldModel, and SearchAlgorithm, respectively. In addition, LLM APIs power the other modules, and Benchmark and Visualization help evaluate or debug the reasoning algorithm (middle). To implement a reasoning algorithm for a certain domain (a Reasoner object), a user may subclass SearchConfig and WorldModel and import a pre-implemented SearchAlgorithm. We also show a concrete example of solving Blocksworld with RAP using LLM Reasoners (bottom).
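
To make the division of labor concrete, here is a minimal, library-free sketch of the three roles on a toy counting problem. The class names (CounterWorld, CounterConfig, GreedySearch) are hypothetical and only mirror the method names used in this README; they are not the library's actual base classes, and greedy search stands in for a real search algorithm such as MCTS.

```python
class CounterWorld:
    """World model role: initial state, transition function, termination test."""

    def __init__(self, target: int) -> None:
        self.target = target

    def init_state(self) -> int:
        return 0

    def step(self, state: int, action: int) -> tuple[int, dict]:
        # The transition is deterministic here; an LLM would predict it in practice.
        return state + action, {}

    def is_terminal(self, state: int) -> bool:
        return state >= self.target


class CounterConfig:
    """Search config role: action space and reward."""

    def __init__(self, target: int) -> None:
        self.target = target

    def get_actions(self, state: int) -> list[int]:
        return [1, 2, 3]

    def reward(self, state: int, action: int) -> float:
        # Higher reward for landing closer to the target.
        return -abs(self.target - (state + action))


class GreedySearch:
    """Search algorithm role: pick the highest-reward action at every step."""

    def __call__(self, world: CounterWorld, config: CounterConfig,
                 max_steps: int = 10) -> list[int]:
        state = world.init_state()
        trace = []
        for _ in range(max_steps):
            if world.is_terminal(state):
                break
            action = max(config.get_actions(state),
                         key=lambda a: config.reward(state, a))
            state, _ = world.step(state, action)
            trace.append(action)
        return trace


plan = GreedySearch()(CounterWorld(target=7), CounterConfig(target=7))
print(plan)
```

Because the search only talks to the world model and the config through these few methods, swapping in a different search strategy does not require touching the domain code.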

Quick Tour

Let's go through the code of reasoning over Blocksworld problems. Note that the code is simplified for demonstration (check here for a runnable notebook).

The first step is to define the world model: you will set up an initial state given a question in init_state, judge whether a state is terminal in is_terminal, and most importantly, define the world dynamics with step:

import copy

import utils
from reasoners import WorldModel, LanguageModel

BWState = str
BWAction = str

class BlocksWorldModel(WorldModel[BWState, BWAction]):
    def __init__(self,
                 base_model: LanguageModel,
                 prompt: dict) -> None:
        super().__init__()
        self.base_model = base_model
        self.prompt = prompt

    def init_state(self) -> BWState:
        # extract the statement from a given problem
        # e.g., "the red block is clear, the blue block is clear..."
        return BWState(utils.extract_init_state(self.example)) 

    def step(self, state: BWState, action: BWAction) -> tuple[BWState, dict]:
        # call the LLM to predict the state transition
        state = copy.deepcopy(state)
        # load the prompt for the LLM to predict the next state
        # e.g. "... I have that <state>, if I <action>, then ..."
        world_update_prompt = self.prompt["update"].replace("<state>", state).replace("<action>", action)
        world_output = self.base_model.generate([world_update_prompt],
                                    eos_token_id="\n", hide_input=True, temperature=0).text[0].strip()
        new_state = utils.process_new_state(world_output)
        # till now, we have the new state after the action
        # the following part is to speed up the reward calculation

        # we want the portion of satisfied subgoals to be part of the reward;
        # since the new state is already predicted, it is convenient to check it here
        goal_reached = utils.goal_check(utils.extract_goals(self.example), new_state)
        # return the new state and the additional dictionary (to be passed to the reward function)
        return new_state, {"goal_reached": goal_reached}

    def is_terminal(self, state: BWState) -> bool:
        # the search stops at a terminal state,
        # e.g., when all the subgoals are met (BWState is a plain string here)
        return utils.goal_check(utils.extract_goals(self.example), state) == 1
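
The world model above relies on a goal check that reports which portion of the subgoals a state satisfies. As an illustration only, the sketch below shows one way such a check could work on string states. The function goal_fraction is a hypothetical stand-in, not the repository's utils.goal_check, and it assumes goals are given as a list of statements that must appear verbatim in the state description.

```python
def goal_fraction(goals: list[str], state: str) -> float:
    """Return the fraction of goal statements found verbatim in the state text."""
    if not goals:
        # With no goals, everything is trivially satisfied.
        return 1.0
    satisfied = sum(goal in state for goal in goals)
    return satisfied / len(goals)


state = ("the red block is on top of the blue block, "
         "the blue block is on the table, the orange block is clear")
goals = ["the red block is on top of the blue block",
         "the blue block is on top of the orange block"]

fraction = goal_fraction(goals, state)
print(fraction)  # 0.5: one of the two subgoals holds
```

A value of 1 then means every subgoal is met, which is the terminal condition used in is_terminal.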

Then, it's time to consider how to search for the optimal reasoning chain. This involves get_actions, which returns the action space for a given state, and, most importantly, reward, which guides the reasoning. For Monte-Carlo Tree Search, we can additionally define a fast_reward to speed up the roll-out stage.

import utils
from world_model import BWState, BWAction
from reasoners import SearchConfig, LanguageModel
class BWConfig(SearchConfig):
    def __init__(self,
                 base_model: LanguageModel,
                 prompt: dict,
                 reward_alpha=0.5,
                 goal_reward_default=0.,
                 goal_reached_reward=100) -> None:
        super().__init__()
        self.base_model = base_model
        self.example = None
        self.prompt = prompt
        # some parameters to calculate the fast reward or reward (explained below)
        self.reward_alpha = reward_alpha
        self.goal_reward_default = goal_reward_default
        self.goal_reached_reward = goal_reached_reward

    def get_actions(self, state: BWState) -> list[BWAction]:
        # use a rule-based function to extract all legal actions
        return utils.generate_all_actions(state)

    def fast_reward(self, state: BWState, action: BWAction) -> tuple[float, dict]:
        # build an in-context learning prompt (similar to the one used in Chain-of-thoughts reasoning)
        inputs = self.prompt["icl"].replace("<init_state>", state)\
            .replace("<goals>", utils.extract_goals(self.example))
        # concatenate a candidate action after the prompt, and test its loglikelihood
        intuition = self.base_model.get_loglikelihood(inputs, [inputs + action])[0]
        # the reward is a combination of intuition and goal satisfaction
        # in fast_reward, we skip the calculation of goal satisfaction and use a default value
        fast_reward = intuition * self.reward_alpha + self.goal_reward_default * (1 - self.reward_alpha)
        # cache some information for the reward calculation later (will be passed to `reward` function)
        details = {'intuition': intuition}
        return fast_reward, details

    def reward(self, state: BWState, action: BWAction,
               intuition: float = None,
               goal_reached: float = None) -> tuple[float, dict]:
        # note that `intuition` (cached in `fast_reward`) and `goal_reached` (cached in `step`) are automatically passed as parameters to this reward function
        if goal_reached == 1:
            # if the goal state is reached, we will assign a large reward
            goal_reward = self.goal_reached_reward
        else:
            # otherwise assign the reward based on the portion of satisfied subgoals
            goal_reward = goal_reached
        # the reward is a combination of intuition and goal satisfaction
        reward = intuition * self.reward_alpha + goal_reward * (1 - self.reward_alpha)
        # return the reward and an additional dictionary (to be saved in the log for visualization later)
        return reward, {'intuition': intuition, 'goal_reached': goal_reached}
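
To see how the two reward terms trade off, the weighted sum in reward can be reproduced in isolation. This is a small sketch, not library code: combine_reward is a hypothetical helper that copies the arithmetic above with the same default values (reward_alpha=0.5, goal_reached_reward=100), and the log-likelihood of -2.0 is made up for illustration.

```python
import math


def combine_reward(intuition: float, goal_reached: float,
                   reward_alpha: float = 0.5,
                   goal_reached_reward: float = 100.0) -> float:
    """Weighted sum of LLM log-likelihood (intuition) and goal satisfaction."""
    # A fully satisfied goal earns a large bonus; otherwise use the fraction itself.
    goal_reward = goal_reached_reward if goal_reached == 1 else goal_reached
    return intuition * reward_alpha + goal_reward * (1 - reward_alpha)


# Half of the subgoals satisfied: -2.0 * 0.5 + 0.5 * 0.5 = -0.75
partial = combine_reward(intuition=-2.0, goal_reached=0.5)
# All subgoals satisfied: -2.0 * 0.5 + 100 * 0.5 = 49.0
solved = combine_reward(intuition=-2.0, goal_reached=1.0)
print(partial, solved)
```

With these defaults, reaching the goal dominates the reward, while the intuition term mainly breaks ties among actions that make similar progress.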

Now, we are ready to apply a reasoning algorithm to solve the problem:

import json
import os
import pickle

from reasoners import Reasoner
from reasoners.algorithm import MCTS
from reasoners.lm import LLaMAModel
from world_model import BlocksWorldModel
from search_config import BWConfig

llama_model = LLaMAModel(llama_ckpts, llama_size, max_batch_size=1)
with open(prompt_path) as f:
    prompt = json.load(f)
world_model = BlocksWorldModel(base_model=llama_model, prompt=prompt)
config = BWConfig(base_model=llama_model, prompt=prompt)
# save the history of every iteration for visualization
search_algo = MCTS(output_trace_in_each_iter=True)
reasoner = Reasoner(world_model=world_model, search_config=config, search_algo=search_algo)
for i, example in enumerate(dataset):
    algo_output = reasoner(example)
    # save the MCTS results as pickle files
    with open(os.path.join(log_dir, 'algo_output', f'{resume + i + 1}.pkl'), 'wb') as f:
        pickle.dump(algo_output, f)
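
The loop above writes into an algo_output subdirectory, which has to exist before the first file is opened. The sketch below shows one way to handle that and to read a result back. The helpers save_output and load_output are hypothetical, and a plain dictionary stands in for the real search output.

```python
import os
import pickle
import tempfile


def save_output(algo_output, log_dir: str, index: int) -> str:
    """Pickle one result to <log_dir>/algo_output/<index>.pkl and return the path."""
    out_dir = os.path.join(log_dir, "algo_output")
    # Create the directory on first use; do nothing if it already exists.
    os.makedirs(out_dir, exist_ok=True)
    path = os.path.join(out_dir, f"{index}.pkl")
    with open(path, "wb") as f:
        pickle.dump(algo_output, f)
    return path


def load_output(path: str):
    """Load a previously pickled result."""
    with open(path, "rb") as f:
        return pickle.load(f)


with tempfile.TemporaryDirectory() as log_dir:
    # Placeholder for the object returned by the reasoner.
    fake_output = {"trace": ["pick up the red block"], "cum_reward": 49.0}
    path = save_output(fake_output, log_dir, index=1)
    restored = load_output(path)

print(restored)
```

Only unpickle files you produced yourself, since loading a pickle can execute arbitrary code.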

Finally, we can easily visualize the reasoning process:

import pickle
from reasoners.visualization import visualize
with open("logs/bw_MCTS/xxx/algo_output/1.pkl", 'rb') as f:
    mcts_result = pickle.load(f)

from reasoners.visualization.tree_snapshot import NodeData, EdgeData
from reasoners.algorithm.mcts import MCTSNode

# by default, a state will be presented along with the node, and the reward with saved dictionary in `SearchConfig.reward` will be presented along with the edge. 
# we can also define a helper function to customize what we want to see in the visualizer.
def blocksworld_node_data_factory(n: MCTSNode) -> NodeData:
    # in this simplified tour, BWState is a plain string, so the state is shown directly
    return NodeData({"block state": n.state if n.state else None,
                     "satisfied": n.fast_reward_details if n.fast_reward_details else "Not expanded"})
def blocksworld_edge_data_factory(n: MCTSNode) -> EdgeData:
    return EdgeData({"reward": n.reward, "intuition": n.fast_reward_details["intuition"]})
visualize(mcts_result, node_data_factory=blocksworld_node_data_factory,
                       edge_data_factory=blocksworld_edge_data_factory)

Then a URL of the visualized results will pop up. The figure will be interactive and look like the examples shown on our demo website.

Installation

Make sure to use Python 3.10 or later.

conda create -n reasoners python=3.10
conda activate reasoners
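
To confirm that the activated environment meets the version requirement, a quick check like the following can help. This is only a convenience sketch; meets_minimum is a hypothetical helper and not part of the library.

```python
import sys


def meets_minimum(version=sys.version_info, minimum=(3, 10)) -> bool:
    """Return True if the (major, minor) version is at least the required minimum."""
    return tuple(version[:2]) >= minimum


if meets_minimum():
    print("Python version is supported")
else:
    print("Python 3.10 or later is required, found", sys.version.split()[0])
```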

Install from pip

pip install llm-reasoners

Install from GitHub

(Recommended if you want to run the examples in the GitHub repo)

git clone https://github.com/Ber666/llm-reasoners --recursive
cd llm-reasoners
pip install -e .

Adding --recursive will help you clone exllama and LLM-Planning automatically. Note that some other optional modules may require other dependencies. Please refer to the error message for details.

Citation

This project is an extension of the following papers:

@inproceedings{hao2023reasoning,
  title={Reasoning with Language Model is Planning with World Model},
  author={Hao, Shibo and Gu, Yi and Ma, Haodi and Hong, Joshua and Wang, Zhen and Wang, Daisy and Hu, Zhiting},
  booktitle={Proceedings of the 2023 Conference on Empirical Methods in Natural Language Processing},
  pages={8154--8173},
  year={2023}
}
@article{hao2024llm,
  title={LLM Reasoners: New Evaluation, Library, and Analysis of Step-by-Step Reasoning with Large Language Models},
  author={Hao, Shibo and Gu, Yi and Luo, Haotian and Liu, Tianyang and Shao, Xiyan and Wang, Xinyuan and Xie, Shuhua and Ma, Haodi and Samavedhi, Adithya and Gao, Qiyue and others},
  journal={arXiv preprint arXiv:2404.05221},
  year={2024}
}