
unordered_map for custom class does not cause error when inserting the same key

I'm trying to understand some points about using unordered_map with a custom class. Below is the code I wrote as an exercise, where I define a simple class Line. I'm confused about why inserting line2 in main() does not make the program print "insert failed", even though m is 3 for both line1 and line2. Since operator== in class Line compares only the first value (i.e., m), line1 and line2 in this code should be the same key. Shouldn't inserting a key that already exists fail? Could someone explain why? Thanks!

#include <iostream>
#include <unordered_map>

using namespace std;
class Line {
public:
  float m;
  float c;

  Line() {m = 0; c = 0;}
  Line(float mInput, float cInput) {m = mInput; c = cInput;}
  float getM() const {return m;}
  float getC() const {return c;}
  void setM(float mInput) {m = mInput;}
  void setC(float cInput) {c = cInput;}

  // Two lines compare equal when their slopes (m) match; c is ignored.
  bool operator==(const Line &anotherLine) const
    {
      return (m == anotherLine.m);
    }
};

namespace std
{
  template <>
  struct hash<Line>
  {
    size_t operator()(const Line& k) const
      {
        // Compute individual hash values for two data members and combine them using XOR and bit shifting
        return ((hash<float>()(k.getM()) ^ (hash<float>()(k.getC()) << 1)) >> 1);
      }
  };
}

int main()
{
  unordered_map<Line, int> t;

  Line line1 = Line(3.0,4.0);
  Line line2 = Line(3.0,5.0);

  t.insert({line1, 1});
  auto x = t.insert({line2, 2});
  if (x.second == false)
    cout << "insert failed" << endl;

  for(unordered_map<Line, int>::const_iterator it = t.begin(); it != t.end(); it++)
  {
    Line key = it->first;  // renamed so it does not shadow the map t
    cout << key.m << " " << key.c << "\n";
  }

  return 0;
}

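Your hash and operator== must satisfy a consistency requirement that they currently violate: when two objects are equal according to ==, their hash codes must also be equal. In other words, unequal objects may share a hash code, but equal objects must have the same hash code. Because line1 and line2 almost certainly get different hash codes, the map looks for line2 in a different place, never compares it with line1, and in practice stores both. A hash that uses only m is consistent with your operator==: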

size_t operator()(const Line& k) const {
    // Hash only m, the same member that operator== compares.
    return hash<float>()(k.getM());
}

Since you compare only one component for equality and ignore the other, your hash function must also use only the component that decides equality.
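A minimal sketch of the question's program with that single change applied, assuming slope-only equality is really what is wanted. The helper secondInsertSucceeds is a hypothetical name used only for this illustration; it returns whether the second insert created a new element so the behaviour is easy to check.

```cpp
#include <cstddef>
#include <functional>
#include <iostream>
#include <unordered_map>

class Line {
public:
    float m;
    float c;

    Line(float mInput, float cInput) : m(mInput), c(cInput) {}

    // Two lines are "equal" when their slopes match; c is ignored.
    bool operator==(const Line& other) const { return m == other.m; }
};

namespace std {
template <>
struct hash<Line> {
    // Hash only m, the same field operator== compares, so equal keys
    // always produce equal hash values.
    size_t operator()(const Line& k) const { return hash<float>()(k.m); }
};
}  // namespace std

// Hypothetical helper: inserts two lines into a fresh map and reports
// whether the second insertion added a new element.
bool secondInsertSucceeds(const Line& a, const Line& b) {
    std::unordered_map<Line, int> t;
    t.insert({a, 1});
    auto result = t.insert({b, 2});
    if (!result.second)
        std::cout << "insert failed\n";
    return result.second;
}
```

Calling secondInsertSucceeds(Line(3.0f, 4.0f), Line(3.0f, 5.0f)) prints "insert failed" and returns false: both keys now hash to the same value and compare equal, so the second insert finds the existing element.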

You are using both "m" and "c" in your hash, so two "Line" instances are only reliably treated as the same key when both their "m" and "c" are equal, which is not the case in your example. So if you do this:

Line line1 = Line(3.0,4.0);
Line line2 = Line(3.0,4.0);

t.insert({line1, 1});
auto x = t.insert({line2, 2});
if (x.second == false)
  cout << "insert failed" << endl;

you'll see that it prints "insert failed".
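The opposite fix is also valid: if two lines should count as the same key only when both m and c match, make operator== compare both fields, and then a hash that combines both fields is consistent. This sketch keeps the question's XOR-and-shift combination; LineHash and countDistinct are hypothetical names chosen for the illustration.

```cpp
#include <cstddef>
#include <functional>
#include <unordered_map>
#include <vector>

struct Line {
    float m;
    float c;

    // Equality now looks at both fields...
    bool operator==(const Line& other) const {
        return m == other.m && c == other.c;
    }
};

// ...so a hash built from both fields satisfies
// "a == b implies hash(a) == hash(b)".
struct LineHash {
    std::size_t operator()(const Line& k) const {
        std::size_t hm = std::hash<float>()(k.m);
        std::size_t hc = std::hash<float>()(k.c);
        return (hm ^ (hc << 1)) >> 1;  // same combination as in the question
    }
};

// Hypothetical helper: inserts every line and returns how many
// distinct keys the map ends up holding.
std::size_t countDistinct(const std::vector<Line>& lines) {
    std::unordered_map<Line, int, LineHash> t;
    int value = 0;
    for (const Line& l : lines)
        t.insert({l, value++});
    return t.size();
}
```

With this definition, (3, 4) and (3, 5) are two different keys, while inserting (3, 4) twice leaves a single element. Which fix is right depends on what "the same line" should mean in your program.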

You can also supply a custom key-equality function object as the fourth template argument of unordered_map:

#include <cstddef>
#include <functional>
#include <iostream>
#include <unordered_map>

class Line {
private:
    float m;
    float c;
public:
    Line() { m = 0; c = 0; }
    Line(float mInput, float cInput) { m = mInput; c = cInput; }
    float getM() const { return m; }
    float getC() const { return c; }
};


// custom hash: it must agree with cmpKey, so it hashes only m
// (mixing in c would let keys that cmpKey treats as equal hash differently)
struct LineHash
{
    std::size_t operator()(const Line& k) const
    {
        return std::hash<float>()(k.getM());
    }
};

// custom key comparison: two lines are the same key if their slopes match
struct cmpKey
{
    bool operator() (Line const &l1, Line const &l2) const
    {
        return l1.getM() == l2.getM();
    }
};


int main()
{

    std::unordered_map<Line, int, LineHash, cmpKey> mymap; // with custom hash and key comparison

    Line line1 = Line(3.0, 4.0);
    Line line2 = Line(4.0, 5.0);
    Line line3 = Line(4.0, 4.0);

    auto x = mymap.insert({ line1, 1 });
    std::cout << std::boolalpha << "element inserted: " << x.second << std::endl;
    x = mymap.insert({ line2, 2 });
    std::cout << std::boolalpha << "element inserted: " << x.second << std::endl;
    x = mymap.insert({ line3, 3 });
    std::cout << std::boolalpha << "element inserted: " << x.second << std::endl;

    return 0;
}

Prints:

element inserted: true
element inserted: true
element inserted: false
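Nothing in std::unordered_map checks that the hash and the equality predicate agree, so a mismatch like the one in the question fails silently. A small test helper can catch it early. isHashConsistent, SameSlope, HashBoth and HashSlope below are hypothetical names for this sketch, and the helper only checks the sample keys it is given, so it can reveal a violation but cannot prove there is none.

```cpp
#include <cstddef>
#include <functional>
#include <vector>

// Hypothetical test helper: returns false if any two sample keys compare
// equal under eq but produce different values under hash.
template <typename Key, typename Hash, typename Eq>
bool isHashConsistent(const std::vector<Key>& samples, Hash hash, Eq eq) {
    for (std::size_t i = 0; i < samples.size(); ++i)
        for (std::size_t j = i + 1; j < samples.size(); ++j)
            if (eq(samples[i], samples[j]) && hash(samples[i]) != hash(samples[j]))
                return false;
    return true;
}

struct Line {
    float m;
    float c;
};

// Equality that looks only at the slope, like cmpKey.
struct SameSlope {
    bool operator()(const Line& a, const Line& b) const { return a.m == b.m; }
};

// Inconsistent with SameSlope: mixes in c, which SameSlope ignores.
struct HashBoth {
    std::size_t operator()(const Line& k) const {
        return (std::hash<float>()(k.m) ^ (std::hash<float>()(k.c) << 1)) >> 1;
    }
};

// Consistent with SameSlope: hashes only the field it compares.
struct HashSlope {
    std::size_t operator()(const Line& k) const { return std::hash<float>()(k.m); }
};
```

With samples such as (3, 4), (4, 5) and (4, 4), HashSlope passes while HashBoth is flagged, because (4, 5) and (4, 4) compare equal under SameSlope but hash differently.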
