"""
PPO Model Evaluation and Feasibility Analysis Script.

This module provides tools for loading a trained Stable Baselines3 (SB3) PPO model
and evaluating its performance in the HSAEnv. The core function, `analyze_actions`,
collects detailed time-series data on actions, joint positions, velocities, and
constraint violations across multiple evaluation episodes.

The results are saved as image files (plots) to diagnose control feasibility,
action smoothness, and adherence to physical limits.
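
Run directly to evaluate one of the demo checkpoints, for example::

    python <this_script>.py --demo corridor --episodes 3

(replace ``<this_script>.py`` with the filename of this module).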
"""
import gymnasium as gym
import numpy as np
import yaml
import os
import re
import glob
import argparse
import matplotlib.pyplot as plt
from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import DummyVecEnv, VecNormalize
from gymnasium.wrappers import TimeLimit
from hsa_gym.envs.hsa_constrained import HSAEnv
from numpy.typing import NDArray
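
# NOTE: ``extract_step_number`` is used by ``find_matching_vecnormalize`` below but is
# not defined in this file as shown. The following is a minimal sketch, assuming that
# checkpoint filenames follow the ``<name>_<steps>_steps.<ext>`` pattern used throughout
# this script (e.g. ``model_29000000_steps.zip``, ``vec_normalize_500000_steps.pkl``).
def extract_step_number(path: str) -> int:
    """Parse the training-step count out of a checkpoint filename."""
    match = re.search(r"_(\d+)_steps", os.path.basename(path))
    if match is None:
        raise ValueError(f"No step number found in filename: {path}")
    return int(match.group(1))
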
def find_matching_vecnormalize(checkpoint_dir: str, model_path: str) -> str | None:
"""
Find the VecNormalize statistics file that corresponds to the loaded model checkpoint.
The search attempts to find:
1. An exact match based on the model's step number.
2. A generic 'vec_normalize_final.pkl' file.
3. The latest available VecNormalize file in the directory.
:param checkpoint_dir: Directory containing checkpoints and VecNormalize stats.
:type checkpoint_dir: str
:param model_path: Path to the model checkpoint file being loaded.
:type model_path: str
:returns: Path to the matching VecNormalize file, or None if not found.
:rtype: str or None
"""
    try:
        model_steps = extract_step_number(model_path)
    except Exception:
        # Fall back to the step-agnostic strategies below if no step number can be parsed.
        model_steps = None
# Strategy 1: Look for exact match by step number
if model_steps:
exact_match = os.path.join(checkpoint_dir, f"vec_normalize_{model_steps}_steps.pkl")
if os.path.exists(exact_match):
print(f"[VecNormalize] Found exact match: {exact_match}")
return exact_match
# Strategy 2: Look for "final" version
final_path = os.path.join(checkpoint_dir, "vec_normalize_final.pkl")
if os.path.exists(final_path):
print(f"[VecNormalize] Found final version: {final_path}")
return final_path
# Strategy 3: Get the latest VecNormalize file
vecnorm_files = glob.glob(os.path.join(checkpoint_dir, "vec_normalize_*_steps.pkl"))
if vecnorm_files:
vecnorm_files.sort(key=lambda x: extract_step_number(x))
latest = vecnorm_files[-1]
print(f"[VecNormalize] Using latest available: {latest}")
return latest
print(f"[VecNormalize] WARNING: No VecNormalize stats found in {checkpoint_dir}")
return None
def load_config(checkpoint_dir: str) -> dict:
"""
Load the environment configuration used for the training run from the archived YAML file.
:param checkpoint_dir: Directory where the `used_config.yaml` file is stored.
:type checkpoint_dir: str
:returns: A dictionary containing the environment configuration.
:rtype: dict
"""
config_path = os.path.join(checkpoint_dir, "used_config.yaml")
with open(config_path, 'r') as file:
config = yaml.safe_load(file)
return config
def make_env(config: dict, render_mode: str = "human") -> gym.Wrapper:
"""
Instantiate and wrap a single HSAEnv environment based on configuration parameters.
The environment is wrapped with :py:class:`gymnasium.wrappers.TimeLimit`.
:param config: The configuration dictionary loaded from `used_config.yaml`.
:type config: dict
:param render_mode: The rendering mode for the environment ("human" or "rgb_array").
:type render_mode: str
:returns: The wrapped environment instance ready for evaluation.
:rtype: gymnasium.Wrapper
"""
env_config = config["env"]
env = HSAEnv(
render_mode=render_mode,
xml_file=env_config["xml_file"],
actuator_group=env_config["actuator_group"],
action_group=env_config["action_group"],
forward_reward_weight=env_config["forward_reward_weight"],
ctrl_cost_weight=env_config["ctrl_cost_weight"],
contact_cost_weight=env_config["contact_cost_weight"],
yvel_cost_weight=env_config["yvel_cost_weight"],
constraint_cost_weight=env_config["constraint_cost_weight"],
acc_cost_weight=env_config["acc_cost_weight"],
smooth_positions=env_config["smooth_positions"],
frame_skip=env_config["frame_skip"],
max_increment=env_config["max_increment"],
enable_terrain=env_config.get("enable_terrain", False),
terrain_type=env_config.get("terrain_type", "flat"),
early_termination_penalty=env_config.get("early_termination_penalty", 0.0),
alive_bonus=env_config.get("alive_bonus", 0.0),
goal_position=env_config.get("goal_position", None),
distance_reward_weight=env_config.get("distance_reward_weight", 0.0),
ensure_flat_spawn=env_config.get("ensure_flat_spawn", True),
)
env = TimeLimit(env, max_episode_steps=env_config["max_episode_steps"])
return env
def analyze_actions(checkpoint_dir: str, model_path: str, num_episodes: int = 5) -> None:
"""
Core function to evaluate a trained model and analyze its kinematic and control outputs.
This function runs the model over several episodes and collects time-series data for:
* Action distribution and smoothness (change).
* Actuated joint positions and velocities.
* Paired joint constraint difference.
It then generates and saves several diagnostic plots and prints summary statistics.
:param checkpoint_dir: Directory containing the model and configuration files.
:type checkpoint_dir: str
:param model_path: Path to the specific PPO model checkpoint (`.zip`) to load.
:type model_path: str
:param num_episodes: Number of episodes to run for data collection.
:type num_episodes: int
:returns: None
:rtype: None
"""
print("Loading configuration and model...")
config = load_config(checkpoint_dir)
base_env = make_env(config, render_mode="human")
env = DummyVecEnv([lambda: base_env])
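    # Restore the observation/reward normalization statistics saved alongside the checkpoint,
    # so the policy sees observations scaled the same way as during training.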
vecnorm_path = find_matching_vecnormalize(checkpoint_dir, model_path)
if vecnorm_path:
        print("\n[VecNormalize] Loading normalization stats from:")
print(f" {vecnorm_path}")
env = VecNormalize.load(vecnorm_path, env)
# Ensure the environment is in evaluation mode
env.training = False
env.norm_reward = False
else:
        # No saved stats were found: wrap the env anyway so the interface matches, but keep
        # the wrapper frozen (training=False) and rewards unnormalized. With fresh running
        # statistics, observations are passed through essentially unchanged.
env = VecNormalize(env, training=False, norm_reward=False)
model = PPO.load(model_path, env=env)
# Storage for data
all_actions = []
all_action_changes = []
all_joint_positions = []
all_joint_velocities = []
all_constraint_diffs = []
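    # Labels for the eight actuated joints; each "A" joint is paired with a "C" joint
    # (1A/1C ... 4A/4C), and the |A - C| difference is monitored in the constraint plots below.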
actuator_names = ['1A', '2A', '3A', '4A', '1C', '2C', '3C', '4C']
qpos_indices = [7, 22, 10, 24, 20, 9, 21, 12] # Indices for qpos logging
qvel_indices = [6, 20, 9, 22, 18, 8, 19, 11] # Indices for qvel logging
print(f"\nRunning {num_episodes} episodes to collect data...")
for ep in range(num_episodes):
obs = env.reset()
done = False
episode_actions = []
episode_joint_pos = []
episode_joint_vel = []
episode_constraint_diffs = []
prev_action = None
step_count = 0
        while not done:
            # Sample actions stochastically (deterministic=False) so the analysis reflects
            # the policy's full action distribution rather than only its mode.
            action, _ = model.predict(obs, deterministic=False)
            obs, reward, done_array, info = env.step(action)
            done = done_array[0]
            step_count += 1
            if done:
                reasons = info[0].get('termination_reasons', 'time limit / truncation')
                print(f" Episode terminated at step {step_count} with termination reasons: {reasons}")
            # Store action
            episode_actions.append(action[0].copy())
# Store action change
if prev_action is not None:
action_change = np.abs(action[0] - prev_action)
all_action_changes.append(action_change)
prev_action = action[0].copy()
unwrapped_env = env.envs[0].unwrapped
# Store joint positions and velocities if available
if hasattr(unwrapped_env, 'data'):
data = unwrapped_env.data
# Get joint positions for actuated joints
joint_pos = [data.qpos[idx] for idx in qpos_indices]
joint_vel = [data.qvel[idx] for idx in qvel_indices]
# Store data
episode_joint_pos.append(joint_pos)
episode_joint_vel.append(joint_vel)
# Calculate constraint differences (Absolute difference |A - C|)
# Order in joint_pos: [1A, 2A, 3A, 4A, 1C, 2C, 3C, 4C]
diffs = [
abs(joint_pos[0] - joint_pos[4]), # 1A - 1C
abs(joint_pos[1] - joint_pos[5]), # 2A - 2C
abs(joint_pos[2] - joint_pos[6]), # 3A - 3C
abs(joint_pos[3] - joint_pos[7]), # 4A - 4C
]
episode_constraint_diffs.append(diffs)
all_actions.extend(episode_actions)
all_joint_positions.extend(episode_joint_pos)
all_joint_velocities.extend(episode_joint_vel)
all_constraint_diffs.extend(episode_constraint_diffs)
print(f" Episode {ep+1}/{num_episodes} completed - {step_count} steps")
# Convert to numpy arrays
all_actions = np.array(all_actions)
all_action_changes = np.array(all_action_changes)
all_joint_positions = np.array(all_joint_positions)
all_joint_velocities = np.array(all_joint_velocities)
all_constraint_diffs = np.array(all_constraint_diffs)
print("\nGenerating plots...")
# Create figure with subplots
fig = plt.figure(figsize=(20, 12))
# ============================================================
# 1. Action Distribution (Histograms)
# ============================================================
for i in range(8):
ax = plt.subplot(4, 4, i+1)
ax.hist(all_actions[:, i], bins=50, alpha=0.7, edgecolor='black')
ax.set_title(f'Action {actuator_names[i]}')
ax.set_xlabel('Action Value')
ax.set_ylabel('Frequency')
ax.axvline(0, color='r', linestyle='--', alpha=0.5, label='Zero')
ax.legend()
ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.savefig('action_distributions.png', dpi=300, bbox_inches='tight')
print(" Saved: action_distributions.png")
plt.close()
# ============================================================
# 2. Action Time Series
# ============================================================
fig, axes = plt.subplots(4, 2, figsize=(16, 12))
fig.suptitle('Action Time Series', fontsize=16)
for i in range(8):
row = i // 2
col = i % 2
axes[row, col].plot(all_actions[:, i], linewidth=0.5)
axes[row, col].set_title(f'{actuator_names[i]}')
axes[row, col].set_xlabel('Timestep')
axes[row, col].set_ylabel('Action')
axes[row, col].axhline(0, color='r', linestyle='--', alpha=0.3)
axes[row, col].grid(True, alpha=0.3)
plt.tight_layout()
plt.savefig('action_timeseries.png', dpi=300, bbox_inches='tight')
print(" Saved: action_timeseries.png")
plt.close()
# ============================================================
# 3. Action Change Magnitudes
# ============================================================
fig, axes = plt.subplots(4, 2, figsize=(16, 12))
fig.suptitle('Action Change Magnitudes (Smoothness)', fontsize=16)
for i in range(8):
row = i // 2
col = i % 2
axes[row, col].plot(all_action_changes[:, i], linewidth=0.5, alpha=0.7)
axes[row, col].set_title(f'{actuator_names[i]} - Mean: {np.mean(all_action_changes[:, i]):.4f}')
axes[row, col].set_xlabel('Timestep')
axes[row, col].set_ylabel('|Action Change|')
axes[row, col].grid(True, alpha=0.3)
plt.tight_layout()
plt.savefig('action_smoothness.png', dpi=300, bbox_inches='tight')
print(" Saved: action_smoothness.png")
plt.close()
# ============================================================
# 4. Joint Positions
# ============================================================
fig, axes = plt.subplots(4, 2, figsize=(16, 12))
fig.suptitle('Joint Positions', fontsize=16)
for i in range(8):
row = i // 2
col = i % 2
axes[row, col].plot(all_joint_positions[:, i], linewidth=0.5)
axes[row, col].set_title(f'{actuator_names[i]}')
axes[row, col].set_xlabel('Timestep')
axes[row, col].set_ylabel('Position (rad)')
axes[row, col].axhline(0, color='r', linestyle='--', alpha=0.3)
axes[row, col].grid(True, alpha=0.3)
plt.tight_layout()
plt.savefig('joint_positions.png', dpi=300, bbox_inches='tight')
print(" Saved: joint_positions.png")
plt.close()
# ============================================================
# 5. Joint Velocities
# ============================================================
fig, axes = plt.subplots(4, 2, figsize=(16, 12))
fig.suptitle('Joint Velocities', fontsize=16)
for i in range(8):
row = i // 2
col = i % 2
axes[row, col].plot(all_joint_velocities[:, i], linewidth=0.5)
axes[row, col].set_title(f'{actuator_names[i]} - Max: {np.max(np.abs(all_joint_velocities[:, i])):.2f} rad/s')
axes[row, col].set_xlabel('Timestep')
axes[row, col].set_ylabel('Velocity (rad/s)')
axes[row, col].axhline(0, color='r', linestyle='--', alpha=0.3)
axes[row, col].grid(True, alpha=0.3)
plt.tight_layout()
plt.savefig('joint_velocities.png', dpi=300, bbox_inches='tight')
print(" Saved: joint_velocities.png")
plt.close()
# ============================================================
# 6. Constraint Violations (|A - C| for each pair)
# ============================================================
fig, axes = plt.subplots(2, 2, figsize=(14, 10))
fig.suptitle('Constraint Monitoring: |A - C| per Pair', fontsize=16)
pair_names = ['Pair 1 (1A-1C)', 'Pair 2 (2A-2C)', 'Pair 3 (3A-3C)', 'Pair 4 (4A-4C)']
for i in range(4):
row = i // 2
col = i % 2
axes[row, col].plot(all_constraint_diffs[:, i], linewidth=0.5)
axes[row, col].axhline(np.pi, color='r', linestyle='--', linewidth=2, label='π limit')
axes[row, col].set_title(f'{pair_names[i]} - Max: {np.max(all_constraint_diffs[:, i]):.3f} rad')
axes[row, col].set_xlabel('Timestep')
axes[row, col].set_ylabel('|A - C| (rad)')
axes[row, col].legend()
axes[row, col].grid(True, alpha=0.3)
axes[row, col].set_ylim([0, np.pi + 0.5])
plt.tight_layout()
plt.savefig('constraint_violations.png', dpi=300, bbox_inches='tight')
print(" Saved: constraint_violations.png")
plt.close()
    # ============================================================
    # 7. Summary Statistics
    # ============================================================
print("\n" + "="*60)
print("FEASIBILITY ANALYSIS SUMMARY")
print("="*60)
print("\nAction Statistics:")
for i in range(8):
print(f" {actuator_names[i]:3s}: mean={np.mean(all_actions[:, i]):6.3f}, "
f"std={np.std(all_actions[:, i]):6.3f}, "
f"min={np.min(all_actions[:, i]):6.3f}, "
f"max={np.max(all_actions[:, i]):6.3f}")
print("\nAction Smoothness (Mean Absolute Change):")
for i in range(8):
print(f" {actuator_names[i]:3s}: {np.mean(all_action_changes[:, i]):.4f} rad/step")
print("\nJoint Velocity Statistics (rad/s):")
for i in range(8):
max_vel = np.max(np.abs(all_joint_velocities[:, i]))
mean_vel = np.mean(np.abs(all_joint_velocities[:, i]))
print(f" {actuator_names[i]:3s}: max={max_vel:6.2f}, mean={mean_vel:6.2f}")
print("\n" + "="*60)
env.close()
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Evaluate trained PPO models for HSA Robot Locomotion.")
parser.add_argument(
'--demo',
type=str,
choices=['corridor', 'flat'],
default='flat',
help='Select which demo model to evaluate: corridor or flat (default: flat)'
)
parser.add_argument(
'--episodes',
type=int,
default=2,
help='Number of episodes to run for evaluation (default: 2)'
)
args = parser.parse_args()
script_dir = os.path.dirname(os.path.abspath(__file__))
    # Demo 1 Corridor
if args.demo == 'corridor':
print("Running Corridor Demo Evaluation...")
checkpoint_dir = os.path.join(script_dir, "../models/ppo_curriculum_corridor")
model_path = os.path.join(checkpoint_dir, "model_29000000_steps.zip")
# Demo 2 Flat
elif args.demo == 'flat':
print("Running Flat Demo Evaluation...")
checkpoint_dir = os.path.join(script_dir, "../models/ppo_curriculum_flat_small")
model_path = os.path.join(checkpoint_dir, "ppo_curriculum_flat_small_final_100000000_steps.zip")
analyze_actions(checkpoint_dir, model_path, num_episodes=args.episodes)