1
0
Fork 0
mirror of https://gitlab.rlp.net/mobitar/ReCo.jl.git synced 2024-12-30 17:03:30 +00:00

Added tests

This commit is contained in:
MoBit 2021-12-15 04:45:15 +01:00
parent 8ad67229a8
commit 7cb7f2d619
6 changed files with 79 additions and 9 deletions

View file

@ -15,7 +15,6 @@ GLMakie = "e9467ef8-e4e7-5192-8a1a-b1aee30e663a"
Intervals = "d8418881-c3e1-53bb-8760-2df7ec849ed5" Intervals = "d8418881-c3e1-53bb-8760-2df7ec849ed5"
JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819" JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819"
JSON3 = "0f8b85d8-7281-11e9-16c2-39a750bddbf1" JSON3 = "0f8b85d8-7281-11e9-16c2-39a750bddbf1"
JuliaFormatter = "98e50ef6-434e-11e9-1051-2b60c6c9e899"
LaTeXStrings = "b964fa9f-0449-5b57-a5c2-d3ea65f4040f" LaTeXStrings = "b964fa9f-0449-5b57-a5c2-d3ea65f4040f"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LoopVectorization = "bdcacae8-1622-11e9-2a5c-532679323890" LoopVectorization = "bdcacae8-1622-11e9-2a5c-532679323890"
@ -29,4 +28,4 @@ StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
[compat] [compat]
julia = "1.6" julia = ">=1.7"

View file

@ -27,14 +27,14 @@ function restrict_coordinate(value::Float64, half_box_len::Float64)
return value return value
end end
function restrict_coordinates!(v::SVector{2,Float64}, half_box_len::Float64) function restrict_coordinates(v::SVector{2,Float64}, half_box_len::Float64)
return SVector( return SVector(
restrict_coordinate(v[1], half_box_len), restrict_coordinate(v[2], half_box_len) restrict_coordinate(v[1], half_box_len), restrict_coordinate(v[2], half_box_len)
) )
end end
function restrict_coordinates!(p::Particle, half_box_len::Float64) function restrict_coordinates!(p::Particle, half_box_len::Float64)
p.tmp_c = restrict_coordinates!(p.tmp_c, half_box_len) p.tmp_c = restrict_coordinates(p.tmp_c, half_box_len)
return nothing return nothing
end end

View file

@ -1,6 +1,6 @@
module RL module RL
export run export run_rl
using ReinforcementLearning using ReinforcementLearning
using Flux: InvDecay using Flux: InvDecay
@ -12,8 +12,6 @@ using ProgressMeter: @showprogress
using ..ReCo using ..ReCo
import Base: run
const INITIAL_REWARD = 0.0 const INITIAL_REWARD = 0.0
struct DistanceState{L<:Bound} struct DistanceState{L<:Bound}
@ -369,7 +367,7 @@ function post_integration_hook(
return nothing return nothing
end end
function run(; function run_rl(;
goal_shape_ratio::Float64, goal_shape_ratio::Float64,
n_episodes::Int64=100, n_episodes::Int64=100,
episode_duration::Float64=50.0, episode_duration::Float64=50.0,

View file

@ -52,7 +52,7 @@ function gyration_tensor(particles::Vector{Particle}, half_box_len::Float64)
S22 = 0.0 S22 = 0.0
for p in particles for p in particles
shifted_c = restrict_coordinates!(p.c - COM, half_box_len) shifted_c = restrict_coordinates(p.c - COM, half_box_len)
S11 += shifted_c[1]^2 S11 += shifted_c[1]^2
S12 += shifted_c[1] * shifted_c[2] S12 += shifted_c[1] * shifted_c[2]

3
test/Project.toml Normal file
View file

@ -0,0 +1,3 @@
[deps]
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

70
test/runtests.jl Normal file
View file

@ -0,0 +1,70 @@
using Test
using ReCo
using StaticArrays: SVector
@testset "Particle.jl" begin
half_box_len = 1.0
@testset "restrict_coordinates" begin
@test ReCo.restrict_coordinates(SVector(1.5, -0.8), half_box_len)
SVector(-0.5, -0.8)
end
@testset "are_overlapping" begin
overlapping_r2 = 1.0
overlapping, r⃗₁₂, distance2 = ReCo.are_overlapping(
SVector(-0.4, 0.0), SVector(0.4, 0.0), overlapping_r2, half_box_len
)
@test overlapping == true
@test r⃗₁₂ SVector(0.8, 0.0)
@test distance2 0.8^2
overlapping, r⃗₁₂, distance2 = ReCo.are_overlapping(
SVector(-0.6, 0.0), SVector(0.6, 0.0), overlapping_r2, half_box_len
)
@test overlapping == true
@test r⃗₁₂ SVector(-0.8, 0.0)
@test distance2 0.8^2
overlapping_r2 = 0.5^2
overlapping, r⃗₁₂, distance2 = ReCo.are_overlapping(
SVector(-0.3, 0.0), SVector(0.3, 0.0), overlapping_r2, half_box_len
)
@test overlapping == false
@test r⃗₁₂ SVector(0.6, 0.0)
@test distance2 0.6^2
end
end
@testset "shape.jl" begin
n_particles = 10
v₀ = 0.0
sim_consts = ReCo.gen_sim_consts(n_particles, v₀)
@testset "gen_sim_consts" begin
@test sim_consts.n_particles == 16
end
half_box_len = sim_consts.half_box_len
@testset "project_to_unit_circle" begin
@test ReCo.project_to_unit_circle(0.0, half_box_len) SVector(-1.0, 0.0)
@test ReCo.project_to_unit_circle(half_box_len, half_box_len)
ReCo.project_to_unit_circle(-half_box_len, half_box_len)
SVector(1.0, 0.0)
end
particles = ReCo.gen_particles(
sim_consts.grid_n, sim_consts.grid_box_width, half_box_len
)
@testset "center_of_mass" begin
@test ReCo.center_of_mass(particles, half_box_len) SVector(0.0, 0.0)
end
@testset "gyration_tensor" begin
@test ReCo.gyration_tensor_eigvals_ratio(particles, half_box_len) == 1.0
end
end