mirror of
https://gitlab.rlp.net/mobitar/ReCo.jl.git
synced 2024-12-21 00:51:21 +00:00
Move Q matrix export to analysis
This commit is contained in:
parent
77cd0eeab8
commit
6fda866f94
2 changed files with 71 additions and 0 deletions
14
analysis/Project.toml
Normal file
14
analysis/Project.toml
Normal file
|
@ -0,0 +1,14 @@
|
|||
[deps]
|
||||
CSV = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b"
|
||||
CairoMakie = "13f3f980-e62b-5c42-98c6-ff1f3baf88f0"
|
||||
CellListMap = "69e1c6dd-3888-40e6-b3c8-31ac5f578864"
|
||||
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
|
||||
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
|
||||
JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819"
|
||||
LaTeXStrings = "b964fa9f-0449-5b57-a5c2-d3ea65f4040f"
|
||||
PrettyTables = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d"
|
||||
ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
|
||||
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
|
||||
ReCo = "b25f7548-fcc9-4c91-bc24-841b54f4dd54"
|
||||
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
|
||||
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
|
57
analysis/Q_matrix_latex_export.jl
Normal file
57
analysis/Q_matrix_latex_export.jl
Normal file
|
@ -0,0 +1,57 @@
|
|||
using DataFrames: DataFrames
|
||||
using PrettyTables: pretty_table
|
||||
|
||||
using ReCo: ReCo
|
||||
|
||||
function latex_table(
|
||||
dataframe::DataFrames.DataFrame, filename::String; path::String = "exports/$filename"
|
||||
)
|
||||
open(path, "w") do f
|
||||
pretty_table(f, dataframe; backend = :latex, nosubheader = true, alignment = :c)
|
||||
end
|
||||
|
||||
return nothing
|
||||
end
|
||||
|
||||
"""
|
||||
export_q_matrix(env_helper::ReCo.RL.EnvHelper, filename_without_extension::String)
|
||||
|
||||
Generate a LaTeX table to the Q-matrix of `env_helper`.
|
||||
|
||||
The output is `ReCo.jl/exports/filename_without_extension.tex`. `env_helper` has to be an environment helper with the abstract type `EnvHelper`.
|
||||
|
||||
Return `nothing`.
|
||||
"""
|
||||
function export_q_matrix(env_helper::ReCo.RL.EnvHelper, filename_without_extension::String)
|
||||
table = copy(env_helper.shared.agent.policy.learner.approximator.table)
|
||||
|
||||
for col in 1:size(table)[2]
|
||||
table[:, col] ./= sum(table[:, col])
|
||||
end
|
||||
|
||||
table .= round.(table, digits = 2)
|
||||
|
||||
state_spaces_labels = env_helper.shared.env.shared.state_spaces_labels
|
||||
states = AbstractString[]
|
||||
for i in state_spaces_labels[1]
|
||||
for j in state_spaces_labels[2]
|
||||
push!(states, i * ";" * j)
|
||||
end
|
||||
end
|
||||
|
||||
action_spaces_labels = env_helper.shared.env.shared.action_spaces_labels
|
||||
actions = AbstractString[]
|
||||
|
||||
for i in action_spaces_labels[1]
|
||||
for j in action_spaces_labels[2]
|
||||
push!(actions, i * ";" * j)
|
||||
end
|
||||
end
|
||||
|
||||
df = DataFrames.DataFrame(table, states)
|
||||
DataFrames.insertcols!(df, 1, :Actions => actions)
|
||||
|
||||
latex_table(df, "$filename_without_extension.tex")
|
||||
|
||||
return nothing
|
||||
end
|
Loading…
Reference in a new issue