How to deserialize "NaN" as `nan` with serde_json?

I have data types that look like this:

#[derive(Serialize, Deserialize, Debug)]
#[serde(rename_all = "camelCase")]
pub struct Matrix {
    #[serde(rename = "numColumns")]
    pub num_cols: usize,
    #[serde(rename = "numRows")]
    pub num_rows: usize,
    pub data: Vec<f64>,
}

My JSON bodies look something like this:

{
    "numRows": 2,
    "numColumns": 1,
    "data": [1.0, "NaN"]
}

This is the serialization provided by Jackson (from a Java server we use), and is valid JSON. Unfortunately, if we call serde_json::from_str(&blob), we get an error:

Error("invalid type: string "NaN", expected f64", [snip]

I understand there are subtleties around floating point numbers and people get very opinionated about the way things ought to be. I respect that. Rust in particular likes to be very opinionated, and I like that.

However, at the end of the day these JSON blobs are what I'm going to receive, and I need that "NaN" string to deserialize to some f64 value for which is_nan() is true, and which serializes back to the string "NaN", because the rest of the ecosystem uses Jackson and this is fine there.
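One practical consequence of requiring is_nan() rather than a specific value: NaN never compares equal to itself, so a round-trip check cannot use == on the data. A small standard-library-only sketch of a NaN-aware comparison (same_floats is a hypothetical helper name, not part of any library):

```rust
/// Compares two float slices element by element, treating NaN as equal to NaN.
/// `same_floats` is a hypothetical helper name used only for illustration.
pub fn same_floats(a: &[f64], b: &[f64]) -> bool {
    a.len() == b.len()
        && a.iter()
            .zip(b)
            .all(|(x, y)| (x.is_nan() && y.is_nan()) || x == y)
}
```

A plain comparison of two Vec<f64> values that both contain NaN reports them as different, which is why a helper along these lines is handy when testing serialization round trips.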

Can this be achieved in a reasonable way?

Edit: the suggested linked questions talk about overriding the derived deserializer, but they do not explain how to deserialize floats specifically.

It actually seems like using a custom deserializer inside a Vec (or Map, etc.) is an open issue on serde and has been for a little over a year (as of the time of writing): https://github.com/serde-rs/serde/issues/723

I believe the solution is to write a custom deserializer for f64 (which is fine), as well as for everything that uses f64 as a component (e.g. Vec<f64>, HashMap<K, f64>, etc.). Unfortunately, it does not seem like these things are composable, as implementations of these methods look like

fn deserialize<'de, D>(deserializer: D) -> Result<Vec<f64>, D::Error>
where D: Deserializer<'de> { /* snip */ }

and once you have a Deserializer you can only interact with it through visitors.

Long story short, I eventually got it working, but it seems like a lot of code that shouldn't be necessary. I'm posting it here in the hope that either (a) someone knows how to clean this up, or (b) this really is how it should be done, and this answer will be useful to someone. I spent a whole day fervently reading docs and making trial-and-error guesses, so maybe this will save someone else the effort. The functions (de)serialize_float(s) should be used with an appropriate #[serde( (de)serialize_with="etc." )] attribute above the field.

use serde::de::{self, SeqAccess, Visitor};
use serde::ser::SerializeSeq;
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use std::fmt;

type Float = f64;

const NAN: Float = Float::NAN;

struct NiceFloat(Float);

impl Serialize for NiceFloat {
    #[inline]
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        serialize_float(&self.0, serializer)
    }
}

pub fn serialize_float<S>(x: &Float, serializer: S) -> Result<S::Ok, S::Error>
where
    S: Serializer,
{
    if x.is_nan() {
        serializer.serialize_str("NaN")
    } else {
        serializer.serialize_f64(*x)
    }
}

pub fn serialize_floats<S>(floats: &[Float], serializer: S) -> Result<S::Ok, S::Error>
where
    S: Serializer,
{
    let mut seq = serializer.serialize_seq(Some(floats.len()))?;

    for f in floats {
        seq.serialize_element(&NiceFloat(*f))?;
    }

    seq.end()
}

struct FloatDeserializeVisitor;

impl<'de> Visitor<'de> for FloatDeserializeVisitor {
    type Value = Float;

    fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
        formatter.write_str("a float or the string \"NaN\"")
    }

    fn visit_i32<E>(self, v: i32) -> Result<Self::Value, E>
    where
        E: de::Error,
    {
        Ok(v as Float)
    }

    fn visit_i64<E>(self, v: i64) -> Result<Self::Value, E>
    where
        E: de::Error,
    {
        Ok(v as Float)
    }

    fn visit_u32<E>(self, v: u32) -> Result<Self::Value, E>
    where
        E: de::Error,
    {
        Ok(v as Float)
    }

    fn visit_u64<E>(self, v: u64) -> Result<Self::Value, E>
    where
        E: de::Error,
    {
        Ok(v as Float)
    }

    fn visit_f32<E>(self, v: f32) -> Result<Self::Value, E>
    where
        E: de::Error,
    {
        Ok(v as Float)
    }

    fn visit_f64<E>(self, v: f64) -> Result<Self::Value, E>
    where
        E: de::Error,
    {
        Ok(v as Float)
    }

    fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
    where
        E: de::Error,
    {
        if v == "NaN" {
            Ok(NAN)
        } else {
            Err(E::invalid_value(de::Unexpected::Str(v), &self))
        }
    }
}

// Not referenced by the code below: NiceFloat's Deserialize impl goes through deserialize_float.
struct NiceFloatDeserializeVisitor;

impl<'de> Visitor<'de> for NiceFloatDeserializeVisitor {
    type Value = NiceFloat;

    fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
        formatter.write_str("a float or the string \"NaN\"")
    }

    fn visit_f32<E>(self, v: f32) -> Result<Self::Value, E>
    where
        E: de::Error,
    {
        Ok(NiceFloat(v as Float))
    }

    fn visit_f64<E>(self, v: f64) -> Result<Self::Value, E>
    where
        E: de::Error,
    {
        Ok(NiceFloat(v as Float))
    }

    fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
    where
        E: de::Error,
    {
        if v == "NaN" {
            Ok(NiceFloat(NAN))
        } else {
            Err(E::invalid_value(de::Unexpected::Str(v), &self))
        }
    }
}

pub fn deserialize_float<'de, D>(deserializer: D) -> Result<Float, D::Error>
where
    D: Deserializer<'de>,
{
    deserializer.deserialize_any(FloatDeserializeVisitor)
}

impl<'de> Deserialize<'de> for NiceFloat {
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        let raw = deserialize_float(deserializer)?;
        Ok(NiceFloat(raw))
    }
}

pub struct VecDeserializeVisitor<T>(std::marker::PhantomData<T>);

impl<'de, T> Visitor<'de> for VecDeserializeVisitor<T>
where
    T: Deserialize<'de> + Sized,
{
    type Value = Vec<T>;

    fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
        formatter.write_str("A sequence of floats or \"NaN\" string values")
    }

    fn visit_seq<S>(self, mut seq: S) -> Result<Self::Value, S::Error>
    where
        S: SeqAccess<'de>,
    {
        let mut out = Vec::with_capacity(seq.size_hint().unwrap_or(0));

        while let Some(value) = seq.next_element()? {
            out.push(value);
        }

        Ok(out)
    }
}

pub fn deserialize_floats<'de, D>(deserializer: D) -> Result<Vec<Float>, D::Error>
where
    D: Deserializer<'de>,
{
    let visitor: VecDeserializeVisitor<NiceFloat> = VecDeserializeVisitor(std::marker::PhantomData);

    let seq: Vec<NiceFloat> = deserializer.deserialize_seq(visitor)?;

    let raw: Vec<Float> = seq.into_iter().map(|nf| nf.0).collect::<Vec<Float>>();

    Ok(raw)
}
