from dotenv import load_dotenv
load_dotenv()
import json
import statistics
from concurrent.futures import ThreadPoolExecutor, as_completed
from dataclasses import dataclass, field
from typing import Dict, List, Tuple
from tensorlake.sandbox import SandboxClient
# ─── Reuse RolloutResult from the CartPole section ────────────────────────────
# total_reward = mean reward per game, X's perspective (+1 win, -1 loss, 0 draw)
# steps = total moves across all games in the batch
# trajectory = list of per-game outcomes
@dataclass
class RolloutResult:
    """Outcome of one rollout batch (same shape as the CartPole section's result).

    total_reward is the mean per-game reward from X's perspective
    (+1 win, -1 loss, 0 draw); steps is the total move count across the
    batch; trajectory holds one dict per game as emitted by the harness.
    """
    seed: int                # RNG seed the batch was run with
    total_reward: float      # mean reward per game, X's perspective
    steps: int               # total moves across all games in the batch
    trajectory: List[dict] = field(default_factory=list)  # per-game outcome dicts
# ─── Tic-tac-toe config ───────────────────────────────────────────────────────
# Extends the RolloutConfig pattern: swap env_name for policy_x / policy_o,
# add n_games (games per sandbox call = one rollout batch).
@dataclass
class TttConfig:
    """Configuration for one tic-tac-toe rollout batch.

    Extends the RolloutConfig pattern from the CartPole section: env_name is
    replaced by a policy key per side, plus n_games (games per sandbox call).
    """
    seed: int          # seeds the sandbox-side random.Random
    policy_x: str      # key into POLICIES — plays X (moves first)
    policy_o: str      # key into POLICIES — plays O
    n_games: int = 50  # games simulated per sandbox call (one rollout batch)
# ─── Policies as code strings ─────────────────────────────────────────────────
# Treat these like LLM-generated completions: they run inside the sandbox,
# never in the host process. A buggy policy crashes its sandbox, not the loop.
POLICIES: Dict[str, str] = {
"random": """
def choose_action(board, player, rng):
moves = [i for i, v in enumerate(board) if v is None]
return rng.choice(moves)
""",
"greedy": """
def choose_action(board, player, rng):
WINS = [(0,1,2),(3,4,5),(6,7,8),(0,3,6),(1,4,7),(2,5,8),(0,4,8),(2,4,6)]
opponent = "O" if player == "X" else "X"
moves = [i for i, v in enumerate(board) if v is None]
# Take the win if available
for move in moves:
b = board[:]; b[move] = player
for a, c, d in WINS:
if b[a] and b[a] == b[c] == b[d]: return move
# Block the opponent's win
for move in moves:
b = board[:]; b[move] = opponent
for a, c, d in WINS:
if b[a] and b[a] == b[c] == b[d]: return move
return rng.choice(moves)
""",
}
# ─── Harness ──────────────────────────────────────────────────────────────────
# Runs n_games games inside a single sandbox and returns the batch return.
# Both policies execute in separate namespaces so they can't overwrite each
# other's globals — important when policies come from different sources.
_TTT_HARNESS = """
import json, random
WINS = [(0,1,2),(3,4,5),(6,7,8),(0,3,6),(1,4,7),(2,5,8),(0,4,8),(2,4,6)]
ns_x, ns_o = {{}}, {{}}
exec({policy_x!r}, ns_x); exec({policy_o!r}, ns_o)
choose_x = ns_x["choose_action"]; choose_o = ns_o["choose_action"]
def winner(b):
for a, c, d in WINS:
if b[a] and b[a] == b[c] == b[d]: return b[a]
return None
rng = random.Random({seed})
games = []
for _ in range({n_games}):
board, moves_played = [None] * 9, 0
for turn in range(9):
player = "X" if turn % 2 == 0 else "O"
action = (choose_x if player == "X" else choose_o)(board[:], player, rng)
board[action] = player; moves_played += 1
w = winner(board)
if w:
games.append({{"outcome": w + " wins", "reward": 1 if w == "X" else -1, "moves": moves_played}})
break
else:
games.append({{"outcome": "draw", "reward": 0, "moves": moves_played}})
print(json.dumps({{
"total_reward": sum(g["reward"] for g in games) / len(games),
"steps": sum(g["moves"] for g in games),
"trajectory": games,
}}))
"""
# ─── Interactive move oracle ──────────────────────────────────────────────────
# For interactive play the sandbox stays open for the whole game session.
# Each opponent turn sends the current board and gets one action back.
# timeout_secs gives the human up to 5 minutes of total think time.
# Single-move template: rendered with the current board/player/seed, it execs
# the policy source in a fresh namespace and prints the chosen square index.
# Used by play_against, one box.run() call per opponent turn.
_MOVE_HARNESS = """
import random
ns = {{}}
exec({policy!r}, ns)
action = ns["choose_action"]({board!r}, {player!r}, random.Random({seed}))
print(action)
"""
WINS = [(0,1,2),(3,4,5),(6,7,8),(0,3,6),(1,4,7),(2,5,8),(0,4,8),(2,4,6)]
def _winner(board: list):
for a, c, d in WINS:
if board[a] and board[a] == board[c] == board[d]:
return board[a]
return None
def _display(board: list) -> None:
row = lambda i: " | ".join(
str(i * 3 + j) if board[i * 3 + j] is None else board[i * 3 + j]
for j in range(3)
)
print(f" {row(0)}\n---+---+---\n {row(1)}\n---+---+---\n {row(2)}\n")
def play_against(human_side: str = "X", opponent_policy: str = "greedy") -> None:
    """
    Play a game of tic-tac-toe against a policy running in a sandbox.
    The sandbox opens once at the start of the game and stays live until the
    game ends. Each opponent turn is a single box.run() call — the policy code
    never executes in the host process.
    human_side: "X" (you move first) or "O" (opponent moves first)
    opponent_policy: any key in POLICIES
    """
    # NOTE(review): assert is stripped under `python -O`; fine for an
    # interactive demo, but a ValueError would survive optimization.
    assert human_side in ("X", "O"), "human_side must be 'X' or 'O'"
    opponent_side = "O" if human_side == "X" else "X"
    board = [None] * 9
    print(f"\nYou are {human_side}. Opponent: {opponent_policy}.")
    print("Empty squares show their position number (0–8).\n")
    _display(board)
    # Keep one sandbox alive for the whole game — no re-creation per move
    with SandboxClient().create_and_connect(memory_mb=1024, timeout_secs=300) as box:
        for turn in range(9):
            player = "X" if turn % 2 == 0 else "O"
            available = [i for i, v in enumerate(board) if v is None]
            if player == human_side:
                # Re-prompt until the human enters an integer naming an
                # empty square; non-integers raise ValueError from int().
                while True:
                    try:
                        move = int(input(f"Your move ({human_side}), choose from {available}: "))
                        if move in available:
                            break
                        print(f" Square {move} is taken. Choose from {available}.")
                    except ValueError:
                        print(f" Enter a number from {available}.")
            else:
                # The seed is the turn number — deterministic but varies per turn
                harness = _MOVE_HARNESS.format(
                    policy=POLICIES[opponent_policy],
                    board=board,
                    player=player,
                    seed=turn,
                )
                ex = box.run("python3", ["-c", harness])
                # Harness prints exactly one integer (the square index).
                move = int((ex.stdout or "").strip())
                print(f" {opponent_side} ({opponent_policy}) plays {move}")
            board[move] = player
            _display(board)
            w = _winner(board)
            if w:
                print("You win!" if w == human_side else f"{opponent_policy} wins!")
                return
    # All nine squares filled with no winner.
    print("Draw!")
# ─── Single batch rollout ─────────────────────────────────────────────────────
def run_ttt_batch(config: TttConfig) -> RolloutResult:
    """
    Execute one batch of config.n_games games in a fresh sandbox.

    Mirrors run_single_rollout from the CartPole section: one config in,
    one RolloutResult out, one sandbox per call. The rendered harness prints
    a single JSON object on stdout, which is parsed into the result.
    """
    script = _TTT_HARNESS.format(
        seed=config.seed,
        n_games=config.n_games,
        policy_x=POLICIES[config.policy_x],
        policy_o=POLICIES[config.policy_o],
    )
    with SandboxClient().create_and_connect(memory_mb=1024) as box:
        execution = box.run("python3", ["-c", script])
        payload = json.loads((execution.stdout or "").strip())
        return RolloutResult(
            seed=config.seed,
            total_reward=payload["total_reward"],
            steps=payload["steps"],
            trajectory=payload["trajectory"],
        )
# ─── Policy evaluation ────────────────────────────────────────────────────────
def evaluate_matchup(
    policy_x: str,
    policy_o: str,
    seeds: List[int],
    n_games: int = 50,
) -> Tuple[float, float]:
    """
    Run one batch per seed in parallel; return (mean_return, std_return).

    Follows the same parallel dispatch pattern as collect_parallel_rollouts:
    one sandbox per seed, all running concurrently. More seeds = tighter
    estimate of the true policy return.

    policy_x / policy_o: keys into POLICIES
    seeds: one sandbox batch per seed; must be non-empty
    n_games: games played inside each batch

    Raises ValueError if seeds is empty.
    """
    # Fail fast with a clear message instead of ThreadPoolExecutor's
    # opaque "max_workers must be greater than 0".
    if not seeds:
        raise ValueError("seeds must contain at least one seed")
    configs = [
        TttConfig(seed=s, policy_x=policy_x, policy_o=policy_o, n_games=n_games)
        for s in seeds
    ]
    # Results are written back by index so output order matches `seeds`
    # regardless of sandbox completion order.
    returns: List[float] = [0.0] * len(configs)
    with ThreadPoolExecutor(max_workers=len(configs)) as pool:
        futures = {pool.submit(run_ttt_batch, cfg): i for i, cfg in enumerate(configs)}
        for future in as_completed(futures):
            returns[futures[future]] = future.result().total_reward
    # BUG FIX: statistics.stdev raises StatisticsError for fewer than two
    # samples; with a single seed the spread is undefined, so report 0.0.
    std = statistics.stdev(returns) if len(returns) > 1 else 0.0
    return statistics.mean(returns), std
# ─── Q-learning ───────────────────────────────────────────────────────────────
# Uses str(s)+","+str(a) as Q-key to avoid f-string braces conflicting
# with .format() when the harness template is rendered on the host.
_QLEARN_HARNESS = """
import json, random
WINS = [(0,1,2),(3,4,5),(6,7,8),(0,3,6),(1,4,7),(2,5,8),(0,4,8),(2,4,6)]
def greedy_move(board, rng):
moves = [i for i, v in enumerate(board) if v is None]
for move in moves:
b = board[:]; b[move] = "O"
for a, c, d in WINS:
if b[a] and b[a] == b[c] == b[d]: return move
for move in moves:
b = board[:]; b[move] = "X"
for a, c, d in WINS:
if b[a] and b[a] == b[c] == b[d]: return move
return rng.choice(moves)
def winner(b):
for a, c, d in WINS:
if b[a] and b[a] == b[c] == b[d]: return b[a]
return None
def skey(b): return tuple(0 if v is None else 1 if v == "X" else 2 for v in b)
def qkey(s, a): return str(s) + "," + str(a)
def qv(q, s, a): return q.get(qkey(s, a), 0.0)
q = json.loads({q_json!r})
rng = random.Random({seed})
alpha, gamma, epsilon = {alpha}, {gamma}, {epsilon}
ep_rewards = []
for _ in range({n_episodes}):
board = [None] * 9
ep_r = 0.0
while True:
moves = [i for i, v in enumerate(board) if v is None]
if not moves: ep_rewards.append(ep_r); break
s = skey(board)
a = rng.choice(moves) if rng.random() < epsilon else max(moves, key=lambda x: qv(q, s, x))
board[a] = "X"
w = winner(board)
if w or not any(v is None for v in board):
r = 1.0 if w == "X" else -1.0 if w == "O" else 0.0
q[qkey(s, a)] = qv(q, s, a) + alpha * (r - qv(q, s, a))
ep_r += r; ep_rewards.append(ep_r); break
board[greedy_move(board[:], rng)] = "O"
w = winner(board)
r = 1.0 if w == "X" else -1.0 if w == "O" else 0.0
s2 = skey(board)
moves2 = [i for i, v in enumerate(board) if v is None]
nq = max((qv(q, s2, x) for x in moves2), default=0.0) if moves2 else 0.0
q[qkey(s, a)] = qv(q, s, a) + alpha * (r + gamma * nq - qv(q, s, a))
ep_r += r
if w or not moves2: ep_rewards.append(ep_r); break
print(json.dumps({{"q_table": q, "mean_reward": sum(ep_rewards)/len(ep_rewards), "n_states": len(q)}}))
"""
@dataclass
class QConfig:
    """Hyperparameters and carried state for one Q-learning iteration."""
    seed: int                                    # seeds the sandbox-side RNG
    q_table: dict = field(default_factory=dict)  # carried between iterations; JSON-serialized into the harness
    epsilon: float = 0.3  # exploration rate — high early, can decay over iterations
    alpha: float = 0.5    # learning rate
    gamma: float = 0.9    # discount factor
    n_episodes: int = 300 # episodes played per sandbox call
def run_qlearning_iter(config: QConfig) -> dict:
    """Run one Q-learning iteration in a sandbox.

    Renders the harness with the config's hyperparameters plus the incoming
    Q-table and returns the parsed JSON payload:
    {"q_table": ..., "mean_reward": ..., "n_states": ...}.
    """
    script = _QLEARN_HARNESS.format(
        seed=config.seed,
        q_json=json.dumps(config.q_table),
        n_episodes=config.n_episodes,
        alpha=config.alpha,
        gamma=config.gamma,
        epsilon=config.epsilon,
    )
    with SandboxClient().create_and_connect(memory_mb=1024) as box:
        execution = box.run("python3", ["-c", script])
        return json.loads((execution.stdout or "").strip())
def train_q(n_iter: int = 8, episodes_per_iter: int = 300) -> dict:
    """
    Train a Q-table over n_iter sequential sandbox calls.
    Each call receives the Q-table from the previous iteration and returns
    an updated one. Mean reward moving from negative to positive confirms
    the policy is improving against the greedy opponent.
    """
    q: dict = {}
    print(f"{'Iter':>5} {'Mean reward':>13} {'Q-states':>10}")
    print("-" * 34)
    for iteration in range(n_iter):
        # Thread the table through: this iteration starts from the last one's result.
        stats = run_qlearning_iter(QConfig(seed=iteration, q_table=q, n_episodes=episodes_per_iter))
        q = stats["q_table"]
        print(f"{iteration+1:>5} {stats['mean_reward']:>+13.3f} {stats['n_states']:>10}")
    return q
def q_policy_code(q_table: dict) -> str:
    """
    Serialize the Q-table into a choose_action string compatible with POLICIES.
    This lets the learned policy plug directly into evaluate_matchup and
    play_against without any changes to those functions.
    """
    encoded = json.dumps(q_table)
    # Build the policy source line by line; the table travels embedded as a
    # repr'd JSON string and is decoded when the policy is exec'd.
    source_lines = [
        "import json as _j",
        f"_Q = _j.loads({encoded!r})",
        "def choose_action(board, player, rng):",
        "    def skey(b): return tuple(0 if v is None else 1 if v == 'X' else 2 for v in b)",
        "    def qkey(s, a): return str(s) + ',' + str(a)",
        "    moves = [i for i, v in enumerate(board) if v is None]",
        "    return max(moves, key=lambda a: _Q.get(qkey(skey(board), a), 0.0))",
    ]
    return "\n".join(source_lines) + "\n"
# ─── Main ─────────────────────────────────────────────────────────────────────
if __name__ == "__main__":
    # ── Baseline evaluation: every pairing of the two hand-written policies ──
    matchups = [
        ("random", "random"),
        ("greedy", "random"),
        ("random", "greedy"),
        ("greedy", "greedy"),
    ]
    seeds = [0, 1, 2, 3]  # one sandbox per seed per matchup = 16 sandboxes total
    print("Evaluating all matchups (4 seeds × 50 games each, one sandbox per seed)...")
    print(f"\n{'X policy':>10} {'O policy':>10} {'mean return':>13} {'std':>6}")
    print("-" * 48)
    eval_results = {}
    for x, o in matchups:
        mean, std = evaluate_matchup(x, o, seeds=seeds)
        eval_results[(x, o)] = (mean, std)
        print(f"{x:>10} {o:>10} {mean:>+13.3f} {std:>6.3f}")
    print(
        f"\n → Mean return is the expected reward per game from X's perspective\n"
        # BUG FIX: was {seeds}, which interpolated the whole list
        # ("averaged over [0, 1, 2, 3] seeds"); the seed COUNT was intended.
        f"   (+1 win, −1 loss, 0 draw), averaged over {len(seeds)} seeds × 50 games.\n"
        f"   greedy-vs-random ({eval_results[('greedy','random')][0]:+.3f}) shows how\n"
        f"   strongly a win/block heuristic dominates pure chance.\n"
        f"   greedy-vs-greedy ({eval_results[('greedy','greedy')][0]:+.3f} ≠ 0) reveals a\n"
        f"   fork vulnerability: X can reach positions that greedy-O cannot\n"
        f"   simultaneously block, which a stronger policy would eliminate.\n"
        f"   Low std (0.04–0.10) confirms 4 seeds × 50 games is enough to\n"
        f"   rank policies reliably — scale up seeds for tighter confidence intervals."
    )
    # ── Train and add the learned policy ─────────────────────────────────────
    print("\nTraining Q-learner vs greedy opponent (8 iterations × 300 episodes)...")
    q_table = train_q(n_iter=8, episodes_per_iter=300)
    # Serialize the Q-table into a choose_action string — same interface as
    # random and greedy, so evaluate_matchup works without any changes.
    POLICIES["q_learned"] = q_policy_code(q_table)
    print("\nEvaluating learned policy against baselines:")
    print(f"\n{'Matchup':>28} {'mean return':>13} {'std':>6}")
    print("-" * 54)
    for x, o in [("q_learned", "greedy"), ("greedy", "q_learned"), ("q_learned", "random")]:
        mean, std = evaluate_matchup(x, o, seeds=seeds)
        print(f"{x+' vs '+o:>28} {mean:>+13.3f} {std:>6.3f}")
    print(
        "\n → q_learned was trained as X against greedy O.\n"
        "   It does not know how to play as O — greedy vs q_learned\n"
        "   exposes this: the policy is role-specialized, not general."
    )
    # ── Play a game ───────────────────────────────────────────────────────
    side = input("\nPlay a game? Choose your side [X/O] (or press Enter to skip): ").strip().upper()
    if side in ("X", "O"):
        available_policies = list(POLICIES.keys())
        opp = input(f"Opponent policy {available_policies} (default: greedy): ").strip().lower()
        # Unknown or blank entry falls back to the default opponent.
        if opp not in POLICIES:
            opp = "greedy"
        play_against(human_side=side, opponent_policy=opp)