diff --git a/src/RL/RL.jl b/src/RL/RL.jl index bf7250a..2ad332a 100644 --- a/src/RL/RL.jl +++ b/src/RL/RL.jl @@ -298,7 +298,11 @@ function gen_agent(n_states::Int64, n_actions::Int64, ϵ_stable::Float64) ), ) - return Agent(; policy=policy, trajectory=VectorSARTTrajectory()) + trajectory = VectorSARTTrajectory(; + state=Int64, action=Int64, reward=Float64, terminal=Bool + ) + + return Agent(; policy=policy, trajectory=trajectory) end function run_rl(;