
no method matching logpdf when sampling from uniform distribution

I am trying to use reinforcement learning in Julia to teach a car, which is constantly being accelerated backwards but starts with a positive initial velocity, to apply its brakes so that it gets as close as possible to a target distance before it starts moving backwards.

To do this, I am using POMDPs.jl and Crux.jl, which provides many solvers (I'm using DQN). I will first list what I believe to be the relevant parts of the script, and then more of it towards the end.

To define the MDP, I sample the initial position, the initial velocity, and the brake force from uniform distributions:

@with_kw struct SliderMDP <: MDP{Array{Float32}, Array{Float32}}
    x0 = Distributions.Uniform(0., 80.) # Distribution to sample initial position
    v0 = Distributions.Uniform(0., 25.) # Distribution to sample initial velocity
    d0 = Distributions.Uniform(0., 2.)  # Distribution to sample brake force
    ...
end

My state holds (position, velocity, brake force), and the initial state is given as:

function POMDPs.initialstate(mdp::SliderMDP)
    ImplicitDistribution((rng) -> Float32.([rand(rng, mdp.x0), rand(rng, mdp.v0), rand(rng, mdp.d0)]))
end
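As I understand it, an ImplicitDistribution is sampled by calling the wrapped function with an RNG. The sketch below is a rough, base-Julia-only stand-in for that behaviour; the type `SamplerDist` and the helper `uniform` are hypothetical and not part of POMDPs.jl or Distributions.jl. It checks that a sampler of this shape returns a 3-element Float32 vector inside the stated bounds, which is one way to test the initial-state sampling in isolation.

```julia
using Random

# Hypothetical stand-in for an implicit (sampler-only) distribution:
# it stores a function and calls it with an RNG when sampled.
struct SamplerDist{F}
    f::F
end
Random.rand(rng::AbstractRNG, d::SamplerDist) = d.f(rng)

# Hypothetical helper: draw from Uniform(lo, hi) using only Base/Random.
uniform(rng::AbstractRNG, lo, hi) = lo + (hi - lo) * rand(rng)

# Same shape as the question's initial state: (position, velocity, brake force)
initial = SamplerDist(rng -> Float32.([uniform(rng, 0.0, 80.0),
                                       uniform(rng, 0.0, 25.0),
                                       uniform(rng, 0.0, 2.0)]))

rng = MersenneTwister(1)
s0 = rand(rng, initial)   # a Vector{Float32} of length 3
println(s0)
```

If a check like this passes with the real `initialstate`, the sampler itself is unlikely to be what produces a `nothing` further down the line.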

Then I set up my DQN solver using Crux.jl and call solve to obtain the policy:

solver_dqn = DQN(π=Q_network(), S=s, N=30000)
policy_dqn = solve(solver_dqn, mdp)

Calling solve() gives me the error MethodError: no method matching logpdf(::Distributions.Categorical{Float64, Vector{Float64}}, ::Nothing). I am quite sure that this comes from the initial state sampling, but I am not sure why, or how to fix it. I have only been learning RL from books and online lectures for a short time, so any help with the error, with the model I set up, or with anything else I may be missing would be appreciated.


More comprehensive code:

Packages:

using POMDPs
using POMDPModelTools
using POMDPPolicies
using POMDPSimulators

using Parameters
using Random

using Crux
using Flux

using Distributions

The rest of the script:

@with_kw struct SliderMDP <: MDP{Array{Float32}, Array{Float32}}
    x0 = Distributions.Uniform(0., 80.) # Distribution to sample initial position
    v0 = Distributions.Uniform(0., 25.) # Distribution to sample initial velocity
    d0 = Distributions.Uniform(0., 2.) # Distribution to sample brake force
    
    m::Float64 = 1.
    tension::Float64 = 3.
    dmax::Float64 = 2.
    target::Float64 = 80.
    dt::Float64 = .05
    
    γ::Float32 = 1.
    actions::Vector{Float64} = [-.1, 0., .1]
end
    
function POMDPs.gen(env::SliderMDP, s, a, rng::AbstractRNG = Random.GLOBAL_RNG)
    x, ẋ, d = s

    if x >= env.target
        a = .1
    end
    if d+a >= env.dmax || d+a <= 0
        a = 0.
    end
    
    force = (d + env.tension) * -1
    ẍ = force/env.m
    
    # Simulation
    x_ = x + env.dt * ẋ
    ẋ_ = ẋ + env.dt * ẍ
    d_ = d + a

    sp = vcat(x_, ẋ_, d_)
    reward = abs(env.target - x) * -1
        
    return (sp=sp, r=reward)
end

    

function POMDPs.initialstate(mdp::SliderMDP)
    ImplicitDistribution((rng) -> Float32.([rand(rng, mdp.x0), rand(rng, mdp.v0), rand(rng, mdp.d0)]))
end
    
POMDPs.isterminal(mdp::SliderMDP, s) = s[2] <= 0
POMDPs.discount(mdp::SliderMDP) = mdp.γ

mdp = SliderMDP();
s = state_space(mdp); # Using Crux.jl

function Q_network()
    layer1 = Dense(3, 64, relu)
    layer2 = Dense(64, 64, relu)
    layer3 = Dense(64, 3) # one output per action; note that length(3) == 1 in Julia
    return DiscreteNetwork(Chain(layer1, layer2, layer3), [-.1, 0, .1])
end

solver_dqn = DQN(π=Q_network(), S=s, N=30000) # Using Crux.jl
policy_dqn = solve(solver_dqn, mdp) # Error comes here
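Independently of the solver, the environment dynamics can be sanity-checked on their own. The sketch below re-implements the body of the question's `POMDPs.gen` in plain Julia; the function names `slider_step` and `rollout` and the NamedTuple `params` are hypothetical, and no POMDPs.jl types are involved. It rolls the dynamics out until the question's `isterminal` condition would fire (velocity at or below zero). With a constant brake force d the deceleration is (d + tension)/m, so in continuous time the car should travel about m·v0²/(2(d + tension)) before stopping; the explicit Euler update used in `gen` overshoots that by roughly v0·dt/2.

```julia
# Hypothetical plain-Julia copy of the question's environment parameters
params = (m = 1.0, tension = 3.0, dmax = 2.0, target = 80.0, dt = 0.05)

# One transition, mirroring the body of POMDPs.gen in the question
function slider_step(p, s, a)
    x, v, d = s
    if x >= p.target
        a = 0.1                      # past the target: increase the brake force
    end
    if d + a >= p.dmax || d + a <= 0
        a = 0.0                      # ignore actions that would leave (0, dmax)
    end
    accel = -(d + p.tension) / p.m   # constant backwards acceleration
    sp = [x + p.dt * v, v + p.dt * accel, d + a]
    r = -abs(p.target - x)           # reward uses the current (not next) position
    return sp, r
end

# Roll out with a fixed action until the terminal condition s[2] <= 0
function rollout(p, s0, a; maxsteps = 10_000)
    s, total = s0, 0.0
    for _ in 1:maxsteps
        s[2] <= 0 && break
        s, r = slider_step(p, s, a)
        total += r
    end
    return s, total
end

s_end, R = rollout(params, [0.0, 10.0, 0.0], 0.0)
predicted = params.m * 10.0^2 / (2 * (0.0 + params.tension))   # continuous-time stopping distance
println("stopped at x = ", s_end[1], " (continuous-time estimate ", predicted, ")")
```

Here the car starts at x = 0 with v0 = 10 and no braking, so the final position should land a little past 100/6 ≈ 16.67.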

Short answer:

Change your network's output vector to Float32, i.e. Float32[-.1, 0, .1].

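Long answer:

Crux creates a distribution over your network's output values and, at some point (policies.jl:298), samples a random value from it. It then converts this value to Float32. Later (utils.jl:15) it calls findfirst to look up the index of this value in the original output array (stored as objs within the distribution). Because that array is still Float64, the lookup fails: 0.1 has no exact binary representation, so Float32(0.1) == 0.1 is false (0 is the only one of your three values that survives the conversion unchanged), and findfirst returns nothing. That nothing is then passed as the index to logpdf on the underlying Categorical, hence the MethodError.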
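The mechanism can be reproduced in base Julia without Crux. This is not Crux's actual code: `objs`, `sampled` and the toy function `toy_logpdf` are hypothetical stand-ins for the stored output values, the converted sample, and `logpdf` on the underlying Categorical. It shows that `[-.1, 0, .1]` is a `Vector{Float64}`, that the Float32-converted sample no longer compares equal to its Float64 original, that `findfirst` therefore returns `nothing`, and that handing `nothing` to a method expecting an integer index raises a `MethodError`.

```julia
objs = [-.1, 0, .1]                 # the Int 0 is promoted, so this is a Vector{Float64}
println(eltype(objs))

sampled = Float32(objs[3])          # the sample is converted to Float32 ...
println(sampled == objs[3])         # false: Float32(0.1) != 0.1
idx = findfirst(==(sampled), objs)  # ... so the lookup finds no match
println(idx === nothing)

# Toy stand-in for logpdf(::Categorical, index): only integer indices are accepted
toy_logpdf(p::Vector{Float64}, i::Integer) = log(p[i])
probs = [0.25, 0.5, 0.25]
err = try
    toy_logpdf(probs, idx)
catch e
    e
end
println(err isa MethodError)        # no toy_logpdf method accepts ::Nothing

# With a Float32 output vector (the suggested fix) the round trip is exact
objs32 = Float32[-.1, 0, .1]
println(findfirst(==(Float32(objs32[3])), objs32))
```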

I believe that converting the sampled value but not the objs array is a bug, and I would encourage you to raise it as an issue on GitHub.
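If you do open an issue, one possible direction for a fix is to do the lookup in the sample's own precision, converting each stored value with the same conversion that was applied to the sample. The helper below is a hypothetical sketch (the name `index_of_sample` is made up and this is not Crux's code); it also fails with a descriptive error instead of returning `nothing`, which would have made the original problem easier to diagnose.

```julia
# Hypothetical helper: find which stored output produced `x`, where `x` may
# have been converted to a narrower float type after sampling.
function index_of_sample(objs::AbstractVector, x)
    T = typeof(x)
    i = findfirst(o -> convert(T, o) == x, objs)   # compare in the sample's precision
    i === nothing && throw(ArgumentError("sample $x not found among outputs $objs"))
    return i
end

objs = [-.1, 0, .1]                          # Float64, as in the question
println(index_of_sample(objs, Float32(.1)))
println(index_of_sample(objs, Float32(-.1)))
```

This assumes no two outputs collapse to the same value when narrowed; for a small action set like [-.1, 0, .1] that holds.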
